"...git@developer.sourcefind.cn:OpenDAS/dynamo.git" did not exist on "1ff119c7b7e2bc0ce0fdf06abaa2e9930421a750"
Unverified Commit 19e1a5cf authored by Hongxin Liu, committed by GitHub

[shardformer] update colo attention to support custom mask (#5510)

* [feature] refactor colo attention (#5462)

* [extension] update api

* [feature] add colo attention

* [feature] update sdpa

* [feature] update npu attention

* [feature] update flash-attn

* [test] add flash attn test

* [test] update flash attn test

* [shardformer] update modeling to fit colo attention (#5465)

* [misc] refactor folder structure

* [shardformer] update llama flash-attn

* [shardformer] fix llama policy

* [devops] update tensornvme install

* [test] update llama test

* [shardformer] update colo attn kernel dispatch

* [shardformer] update blip2

* [shardformer] update chatglm

* [shardformer] update gpt2

* [shardformer] update gptj

* [shardformer] update opt

* [shardformer] update vit

* [shardformer] update colo attention mask prep

* [shardformer] update whisper

* [test] fix shardformer tests (#5514)

* [test] fix shardformer tests

* [test] fix shardformer tests
parent 9a3321e9
@@ -117,7 +117,7 @@ jobs:
          cd TensorNVMe
          conda install cmake
          pip install -r requirements.txt
-         pip install -v .
+         DISABLE_URING=1 pip install -v .
      - name: Store TensorNVMe Cache
        run: |
@@ -201,4 +201,4 @@ jobs:
      uses: actions/upload-artifact@v3
      with:
        name: report
-       path: report/
\ No newline at end of file
+       path: report/
@@ -44,7 +44,7 @@ jobs:
          cd TensorNVMe
          conda install cmake
          pip install -r requirements.txt
-         pip install -v .
+         DISABLE_URING=1 pip install -v .
      - uses: actions/checkout@v2
        if: steps.check-avai.outputs.avai == 'true'
@@ -66,7 +66,7 @@ jobs:
          cd TensorNVMe
          apt update && apt install -y cmake
          pip install -r requirements.txt
-         pip install -v .
+         DISABLE_URING=1 pip install -v .
      - uses: actions/checkout@v2
        with:
          ssh-key: ${{ secrets.SSH_KEY_FOR_CI }}
@@ -60,7 +60,7 @@ jobs:
          cd TensorNVMe
          apt update && apt install -y cmake
          pip install -r requirements.txt
-         pip install -v .
+         DISABLE_URING=1 pip install -v .
      - uses: actions/checkout@v2
        with:
          ssh-key: ${{ secrets.SSH_KEY_FOR_CI }}
@@ -56,7 +56,7 @@ jobs:
          cd TensorNVMe
          apt update && apt install -y cmake
          pip install -r requirements.txt
-         pip install -v .
+         DISABLE_URING=1 pip install -v .
      - uses: actions/checkout@v2
        with:
          ssh-key: ${{ secrets.SSH_KEY_FOR_CI }}
@@ -6,7 +6,7 @@ from .extensions import (
     CpuAdamX86Extension,
     FlashAttentionDaoCudaExtension,
     FlashAttentionNpuExtension,
-    FlashAttentionXformersCudaExtension,
+    FlashAttentionSdpaCudaExtension,
     FusedOptimizerCudaExtension,
     LayerNormCudaExtension,
     MoeCudaExtension,
@@ -65,9 +65,9 @@ class KernelLoader:
         else:
             usable_exts = []
             for ext in exts:
-                if ext.is_hardware_available():
+                if ext.is_available():
                     # make sure the machine is compatible during kernel loading
-                    ext.assert_hardware_compatible()
+                    ext.assert_compatible()
                     usable_exts.append(ext)
         assert len(usable_exts) != 0, f"No usable kernel found for {self.__class__.__name__} on the current machine."
@@ -106,4 +106,20 @@ class ScaledUpperTriangleMaskedSoftmaxLoader(KernelLoader):
 class FlashAttentionLoader(KernelLoader):
-    REGISTRY = [FlashAttentionNpuExtension, FlashAttentionDaoCudaExtension, FlashAttentionXformersCudaExtension]
+    REGISTRY = [
+        FlashAttentionNpuExtension,
+        FlashAttentionDaoCudaExtension,
+        FlashAttentionSdpaCudaExtension,
+    ]
+
+
+class FlashAttentionWithPaddingMaskLoader(KernelLoader):
+    REGISTRY = [FlashAttentionNpuExtension, FlashAttentionDaoCudaExtension]
+
+
+class FlashAttentionWithCustomMaskLoader(KernelLoader):
+    REGISTRY = [FlashAttentionNpuExtension, FlashAttentionSdpaCudaExtension]
+
+
+class FlashAttentionForFloatAndCustomMaskLoader(KernelLoader):
+    REGISTRY = [FlashAttentionSdpaCudaExtension]
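
The loaders above all share KernelLoader.load() selection semantics. A minimal sketch of that logic, assuming only the is_available()/assert_compatible() extension interface visible in this diff (everything else here is illustrative, not ColossalAI API):

# Hypothetical mini version of KernelLoader.load(): walk REGISTRY in priority
# order, keep the extensions whose backend is present, and build the first one.
class MiniKernelLoader:
    REGISTRY = []  # subclasses list extension classes, highest priority first

    def load(self):
        usable = []
        for ext_cls in self.REGISTRY:
            ext = ext_cls()
            if ext.is_available():       # e.g. a CUDA or NPU runtime is present
                ext.assert_compatible()  # fail fast on incompatible machines
                usable.append(ext)
        assert usable, f"No usable kernel found for {self.__class__.__name__}."
        return usable[0].load()          # build and return the kernel callable

This ordering is why FlashAttentionWithCustomMaskLoader lists the NPU extension before the SDPA one: the first available, compatible backend wins.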
import enum
import math
import warnings
from dataclasses import dataclass
from typing import Iterable, Optional, Tuple
import torch
import torch.nn.functional as F
from einops import rearrange
from colossalai.accelerator import get_accelerator
from colossalai.kernel.kernel_loader import FlashAttentionLoader
@dataclass
class SeqLenInfo:
seqlens: Iterable[int] = None
indices: torch.Tensor = None
max_seqlen: int = None
cu_seqlens: torch.Tensor = None
@staticmethod
def materialize(
attn_mask: torch.Tensor = None, size: Tuple[int] = None, device=get_accelerator().get_current_device()
):
if attn_mask is not None:
indices = torch.nonzero(attn_mask.flatten(), as_tuple=False).flatten().to(device)
seqlens = attn_mask.sum(dim=-1, dtype=torch.int32).flatten()
else:
batch_size, tgt_len = size[0], size[1]
indices = torch.arange(batch_size * tgt_len, dtype=torch.long, device=device)
seqlens = torch.full((batch_size,), tgt_len, dtype=torch.long, device=device)
max_seqlen = max(seqlens)
cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0)).to(device)
return SeqLenInfo(seqlens.tolist(), indices, max_seqlen, cu_seqlens)
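
As a reference point for the refactor, a small worked example of the legacy SeqLenInfo.materialize above (mask values made up; device pinned to CPU so the sketch runs anywhere):

import torch

# Two sequences of lengths 3 and 1, padded to length 4.
attn_mask = torch.tensor([[1, 1, 1, 0],
                          [1, 0, 0, 0]])
info = SeqLenInfo.materialize(attn_mask=attn_mask, device="cpu")
# info.seqlens    -> [3, 1]
# info.max_seqlen -> 3
# info.cu_seqlens -> tensor([0, 3, 4], dtype=torch.int32)
# info.indices    -> tensor([0, 1, 2, 4])  (valid positions in the flat [b*s] layout)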
class AttnMaskType(enum.Enum):
padding = 1
causal = 2
paddedcausal = 3
class Unpad(torch.autograd.Function):
"""
Adapted from
https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/bert_padding.py
"""
@staticmethod
def forward(ctx, tensor: torch.Tensor, indices: torch.Tensor):
ctx.save_for_backward(indices)
# [b, s, ...]
assert tensor.ndim >= 3
ctx.bsz = tensor.shape[0]
out = rearrange(tensor, "b s ... -> (b s) ...")
ctx.shape = out.shape
# [ntokens, ...]
return out[indices]
@staticmethod
def backward(ctx, grad_output):
(indices,) = ctx.saved_tensors
# [ntokens, ...]
grad = torch.zeros(ctx.shape, dtype=grad_output.dtype, device=grad_output.device)
grad[indices] = grad_output
grad = rearrange(grad, "(b s) ... -> b s ...", b=ctx.bsz)
# [b, s, ...]
return grad, None
class Repad(torch.autograd.Function):
"""
Adapted from
https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/bert_padding.py
"""
@staticmethod
def forward(ctx, tensor: torch.Tensor, indices: torch.Tensor, batch_size: int, seq_len: int):
ctx.save_for_backward(indices)
# [ntokens, ...]
tensor = tensor
out = torch.zeros((batch_size * seq_len, *tensor.shape[1:]), dtype=tensor.dtype, device=tensor.device)
# [b*s, ...]
out[indices] = tensor
return out
@staticmethod
def backward(ctx, grad_output):
(indices,) = ctx.saved_tensors
# [b*s, ...]
grad = grad_output[indices]
# [ntokens, ...]
return grad, None, None, None
class ColoAttention(torch.nn.Module):
def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.0, scale=None):
super().__init__()
assert (
embed_dim % num_heads == 0
), f"the embed dim ({embed_dim}) is not divisible by the number of attention heads ({num_heads})."
if scale is not None:
self.scale = scale
else:
self.scale = 1 / math.sqrt(embed_dim // num_heads)
self.dropout = dropout
self.attn = FlashAttentionLoader().load()
@staticmethod
def unpad(tensor: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
return Unpad.apply(tensor, indices)
@staticmethod
def repad(tensor: torch.Tensor, indices: torch.Tensor, batch_size: int, seq_len: int) -> torch.Tensor:
return Repad.apply(tensor, indices, batch_size, seq_len)
def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
origin_attn_mask: Optional[torch.Tensor] = None,
attn_mask_type: Optional[AttnMaskType] = None,
bias: Optional[torch.Tensor] = None,
):
"""
ColoAttention
Args:
q: (batch, q_seqlen, nheads, headdim)
k: (batch, kv_seqlen, nheads, headdim)
v: (batch, kv_seqlen, nheads, headdim)
origin_attn_mask: (nheads, q_seqlen, kv_seqlen)
bias: will not be used
Return:
attn_out: (batch, q_seqlen, nheads, headdim).
"""
# if flash attention is not applicable, switch to memory efficient attention
if self.attn.__name__ == "flash_attention" and (
query.dtype not in [torch.float16, torch.bfloat16] or bias is not None
):
warnings.warn(
f"flash-attn expects fp16 or bf16 but got {query.dtype}, switching to xformers' implementation."
)
self.attn = FlashAttentionLoader().load(ext_name="flash_attention_xformers_cuda")
padded = attn_mask_type is not None and attn_mask_type.value % 2 == 1
causal = attn_mask_type is not None and attn_mask_type.value > 1
batch_size, tgt_len, src_len = query.shape[0], query.shape[1], key.shape[1]
# unpad
seq_len_info_q = None
seq_len_info_kv = None
if padded:
# bert style, unpad process
assert (
attn_mask is not None
), f"attention mask {attn_mask} is not valid for attention mask type {attn_mask_type}."
assert attn_mask.dim() == 2, (
"attention mask is supposed to have shape (batch_size, seq_len), "
+ f"but got {attn_mask.dim()} dimensions."
)
# bert style
if tgt_len == src_len:
seq_len_info_q = SeqLenInfo.materialize(attn_mask=attn_mask, device=query.device)
if batch_size > 1:
query, key, value = self.unpad(
torch.stack([query, key, value], dim=2), seq_len_info_q.indices
).unbind(dim=1)
else:
query, key, value = torch.stack([query, key, value], dim=2).squeeze(0).unbind(dim=1)
seq_len_info_kv = seq_len_info_q
else:
seq_len_info_q = SeqLenInfo.materialize(size=(batch_size, tgt_len), device=query.device)
seq_len_info_kv = SeqLenInfo.materialize(attn_mask=attn_mask, device=query.device)
if batch_size > 1:
query = rearrange(query, "b s ... -> c (b s) ...", c=1)
key, value = self.unpad(torch.stack([key, value], dim=2), seq_len_info_kv.indices).unbind(
dim=1
)
else:
query, key, value = torch.stack([query, key, value], dim=2).squeeze(0).unbind(dim=1)
out = self.attn(
query,
key,
value,
seq_len_info_q=seq_len_info_q,
seq_len_info_kv=seq_len_info_kv,
origin_attn_mask=origin_attn_mask,
dropout_p=self.dropout,
scale=self.scale,
causal=causal,
padded=padded,
)
# repad
if padded:
if batch_size > 1:
out = self.repad(out, seq_len_info_q.indices, batch_size, tgt_len)
out = rearrange(out, "(b s) h d -> b s h d", b=batch_size)
if len(out.shape) == 4:
out = rearrange(out, "b s h d -> b s (h d)")
return out
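
For contrast with the new API further down, a hedged usage sketch of this legacy module (it assumes a CUDA machine where FlashAttentionLoader can build a kernel; shapes follow the docstring above):

import torch

attn = ColoAttention(embed_dim=64, num_heads=4, dropout=0.0)
# [batch, seqlen, nheads, headdim]
q = k = v = torch.randn(2, 8, 4, 16, dtype=torch.float16, device="cuda")
mask = torch.ones(2, 8, dtype=torch.long, device="cuda")  # no actual padding
out = attn(q, k, v, attn_mask=mask, attn_mask_type=AttnMaskType.paddedcausal)
# out: [2, 8, 64] after the heads are merged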
+from .attn import AttnMaskType, ColoAttention
 from .dropout import DropoutForParallelInput, DropoutForReplicatedInput
 from .embedding import Embedding1D, VocabParallelEmbedding1D
 from .linear import Linear1D_Col, Linear1D_Row
@@ -23,4 +24,6 @@ __all__ = [
     "FusedRMSNorm",
     "FusedLinear1D_Col",
     "ParallelModule",
+    "AttnMaskType",
+    "ColoAttention",
 ]
from enum import Enum
from typing import Callable, Dict, Optional, Tuple
import torch
import torch.nn.functional as F
from colossalai.kernel.kernel_loader import (
FlashAttentionForFloatAndCustomMaskLoader,
FlashAttentionLoader,
FlashAttentionWithCustomMaskLoader,
FlashAttentionWithPaddingMaskLoader,
KernelLoader,
)
__all__ = [
"AttnMaskType",
"ColoAttention",
]
class AttnMaskType(Enum):
CUSTOM = 0
PADDED = 1
CAUSAL = 2
PADDED_CAUSAL = 3
def invert_mask(mask: torch.Tensor) -> torch.Tensor:
"""Invert the mask tensor.
Args:
mask (torch.Tensor): Mask tensor. Shape should be [B, 1, Sq, Skv]
Returns:
torch.Tensor: Inverted mask tensor.
"""
inverted_mask = 1.0 - mask
return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(mask.dtype).min)
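
A quick worked example of invert_mask (values illustrative): allowed positions (1.0) become 0.0 and masked positions (0.0) become the dtype minimum, so the result can be added to attention scores as a bias.

import torch

mask = torch.tensor([[[[1.0, 1.0, 0.0]]]])  # [B=1, 1, Sq=1, Skv=3]
invert_mask(mask)
# tensor([[[[ 0.0000e+00,  0.0000e+00, -3.4028e+38]]]])  for float32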
# adapted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/bert_padding.py
def get_pad_info(padding_mask: torch.Tensor) -> Tuple[int, torch.Tensor, torch.Tensor]:
"""Get padding information from padding mask.
Args:
padding_mask (torch.Tensor): Padding mask tensor. Shape should be [B, S]
Returns:
Tuple[int, torch.Tensor, torch.Tensor]: Tuple of (max_seq_len, cu_seqlens, indices)
"""
seqlens_in_batch = padding_mask.sum(dim=-1, dtype=torch.int32)
indices = torch.nonzero(padding_mask.flatten(), as_tuple=False).flatten()
max_seqlen_in_batch = seqlens_in_batch.max().item()
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
return max_seqlen_in_batch, cu_seqlens, indices
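
A small worked example of get_pad_info (mask values made up):

import torch

padding_mask = torch.tensor([[1, 1, 1, 0],
                             [1, 1, 0, 0]])
max_seqlen, cu_seqlens, indices = get_pad_info(padding_mask)
# max_seqlen -> 3
# cu_seqlens -> tensor([0, 3, 5], dtype=torch.int32)  (prefix sums of [3, 2])
# indices    -> tensor([0, 1, 2, 4, 5])  (valid positions in the flat [B*S] layout)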
class ColoAttention:
_kernel_dispatch_map: Optional[Dict[torch.dtype, Dict[Optional[AttnMaskType], Callable]]] = None
@staticmethod
def _init_kernels_dispatch():
if ColoAttention._kernel_dispatch_map is None:
# fp16/bf16
half_dispatch_map = {
None: FlashAttentionLoader(),
AttnMaskType.CUSTOM: FlashAttentionWithCustomMaskLoader(),
AttnMaskType.PADDED: FlashAttentionWithPaddingMaskLoader(),
AttnMaskType.CAUSAL: FlashAttentionLoader(),
AttnMaskType.PADDED_CAUSAL: FlashAttentionWithPaddingMaskLoader(),
}
# fp32
float_dispatch_map = {
None: FlashAttentionForFloatAndCustomMaskLoader(),
AttnMaskType.CUSTOM: FlashAttentionForFloatAndCustomMaskLoader(),
AttnMaskType.CAUSAL: FlashAttentionForFloatAndCustomMaskLoader(),
}
ColoAttention._kernel_dispatch_map = {
torch.float16: half_dispatch_map,
torch.bfloat16: half_dispatch_map,
torch.float32: float_dispatch_map,
}
@staticmethod
def _dispatch_kernel(dtype: torch.dtype, mask_type: Optional[AttnMaskType]) -> Callable:
ColoAttention._init_kernels_dispatch()
if (
dtype not in ColoAttention._kernel_dispatch_map
or mask_type not in ColoAttention._kernel_dispatch_map[dtype]
):
raise ValueError(
"FlashAttention kernel is not available for dtype {} and mask_type {}".format(dtype, mask_type)
)
# lazy load
if isinstance(ColoAttention._kernel_dispatch_map[dtype][mask_type], KernelLoader):
ColoAttention._kernel_dispatch_map[dtype][mask_type] = ColoAttention._kernel_dispatch_map[dtype][
mask_type
].load()
return ColoAttention._kernel_dispatch_map[dtype][mask_type]
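
Illustrative dispatch outcomes implied by the two maps above (this only runs on a machine where the corresponding extensions can actually be built, and _dispatch_kernel is internal API):

import torch

# fp16 + causal -> the general FlashAttentionLoader chain
# (NPU, then Dao flash-attn, then SDPA; first available wins).
fn = ColoAttention._dispatch_kernel(torch.float16, AttnMaskType.CAUSAL)

# fp32 is only served by the SDPA-backed loader, and fp32 + PADDED is not
# registered at all, so the following raises ValueError:
#   ColoAttention._dispatch_kernel(torch.float32, AttnMaskType.PADDED)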
@staticmethod
def prepare_attn_kwargs(
shape_4d: Tuple[int],
dtype: torch.dtype,
device: torch.device,
q_padding_mask: Optional[torch.Tensor] = None,
kv_padding_mask: Optional[torch.Tensor] = None,
is_causal: bool = False,
) -> Dict[str, torch.Tensor]:
"""Return a dictionary of keyword arguments for attention function. It supports 4 mask type.
1. custom mask: no padding mask and is_causal=False, return {}, users should handle attention mask by themselves.
2. padded mask: recv padding mask and is_causal=False, return {attention_mask, attention_mask_type, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv, q_indices, kv_indices}.
3. causal mask: no padding mask and is_causal=True, return {attention_mask, attention_mask_type}.
4. padded causal mask: recv padding mask and is_causal=True, return {attention_mask, attention_mask_type, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv, q_indices, kv_indices}.
Args:
shape_4d (Tuple[int]): Should be (B, 1, Sq, Skv)
dtype (torch.dtype): Dtype of attention mask, generally should be ``hidden_states.dtype``
device (torch.device): Device of attention mask, generally should be ``hidden_states.device``
q_padding_mask (Optional[torch.Tensor], optional): Padding mask of query. It should be a long tensor or int tensor.
The shape should be [B, Sq]. ``1`` means valid token, and ``0`` means padding token. Defaults to None.
kv_padding_mask (Optional[torch.Tensor], optional): Padding mask of key and value. It should be a long tensor or int tensor.
The shape should be [B, Skv]. ``1`` means valid token, and ``0`` means padding token.
If it's None and ``q_padding_mask`` is not None, it will be set to ``q_padding_mask``. Defaults to None.
is_causal (bool, optional): Whether to use causal attention mask. Defaults to False.
Returns:
Dict[str, torch.Tensor]: Dictionary of keyword arguments for attention function.
"""
if q_padding_mask is None and not is_causal:
return {}
assert len(shape_4d) == 4 and shape_4d[1] == 1
b, _, s_q, s_kv = shape_4d
outputs = {}
if (q_padding_mask is None or q_padding_mask.bool().all()) and (
kv_padding_mask is None or kv_padding_mask.bool().all()
):
# no padding
assert is_causal
outputs["attention_mask_type"] = AttnMaskType.CAUSAL
attention_mask = torch.ones(s_q, s_kv, dtype=dtype, device=device).tril(diagonal=0).expand(b, s_q, s_kv)
else:
if kv_padding_mask is None:
# self attention
kv_padding_mask = q_padding_mask
assert q_padding_mask.shape == (b, s_q) and kv_padding_mask.shape == (
b,
s_kv,
), f"q_padding_mask shape {q_padding_mask.shape} and kv_padding_mask shape {kv_padding_mask.shape} should be the same. ({shape_4d})"
attention_mask = torch.einsum("bi,bj->bij", q_padding_mask, kv_padding_mask).to(dtype=dtype, device=device)
max_seqlen_q, cu_seqlens_q, q_indices = get_pad_info(q_padding_mask)
max_seqlen_kv, cu_seqlens_kv, kv_indices = get_pad_info(kv_padding_mask)
outputs.update(
{
"cu_seqlens_q": cu_seqlens_q,
"cu_seqlens_kv": cu_seqlens_kv,
"max_seqlen_q": max_seqlen_q,
"max_seqlen_kv": max_seqlen_kv,
"q_indices": q_indices,
"kv_indices": kv_indices,
}
)
if is_causal:
outputs["attention_mask_type"] = AttnMaskType.PADDED_CAUSAL
attention_mask = attention_mask * attention_mask.new_ones(s_q, s_kv).tril(diagonal=0)
else:
outputs["attention_mask_type"] = AttnMaskType.PADDED
attention_mask = invert_mask(attention_mask).unsqueeze(1)
outputs["attention_mask"] = attention_mask
return outputs
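
A usage sketch of prepare_attn_kwargs for the padded-causal case (CPU device and mask values chosen only for illustration):

import torch

B, S = 2, 4
padding_mask = torch.tensor([[1, 1, 1, 0],
                             [1, 1, 0, 0]])
attn_kwargs = ColoAttention.prepare_attn_kwargs(
    (B, 1, S, S),
    dtype=torch.float16,
    device=torch.device("cpu"),
    q_padding_mask=padding_mask,
    is_causal=True,
)
# attn_kwargs["attention_mask_type"] -> AttnMaskType.PADDED_CAUSAL
# attn_kwargs["attention_mask"]      -> [B, 1, S, S] float bias of 0 / dtype-min
# plus cu_seqlens_q/kv, max_seqlen_q/kv, q_indices, kv_indices from get_pad_info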
@staticmethod
def attention(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
attention_mask_type: AttnMaskType = AttnMaskType.CUSTOM,
cu_seqlens_q: Optional[torch.Tensor] = None,
cu_seqlens_kv: Optional[torch.Tensor] = None,
max_seqlen_q: Optional[int] = None,
max_seqlen_kv: Optional[int] = None,
q_indices: Optional[torch.Tensor] = None,
kv_indices: Optional[torch.Tensor] = None,
dropout_p: float = 0.0,
scale: Optional[float] = None,
) -> torch.Tensor:
"""Flash Attention function. It supports 4 mask type.
1. custom mask: recv attention_mask
2. padded mask: recv attention_mask, attention_mask_type, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv, indices
3. causal mask: recv attention_mask, attention_mask_type
4. padded causal mask: recv attention_mask, attention_mask_type, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv, indices
Args:
q (torch.Tensor): Query tensor. Shape should be [B, N, Sq, D]
k (torch.Tensor): Key tensor. Shape should be [B, N, Skv, D]
v (torch.Tensor): Value tensor. Shape should be [B, N, Skv, D]
attention_mask (Optional[torch.Tensor], optional): Attention mask tensor. Shape should be [B, 1, Sq, Skv]. Defaults to None.
attention_mask_type (AttnMaskType, optional): Attention mask type. Defaults to AttnMaskType.CUSTOM.
cu_seqlens_q (Optional[torch.Tensor], optional): The cumulative sequence lengths
of the sequences in the batch, used to index into q.
Shape should be [B+1]. Defaults to None.
cu_seqlens_kv (Optional[torch.Tensor], optional): The cumulative sequence lengths
of the sequences in the batch, used to index into kv.
Shape should be [B+1]. Defaults to None.
max_seqlen_q (Optional[int], optional): Maximum query sequence length in the batch. Defaults to None.
max_seqlen_kv (Optional[int], optional): Maximum key/value sequence length in the batch. Defaults to None.
q_indices (Optional[torch.Tensor], optional): The indices of non-masked tokens from the flattened query sequence.
Shape should be [NUM_TOKENS]. Defaults to None.
kv_indices (Optional[torch.Tensor], optional): The indices of non-masked tokens from the flattened key/value sequence.
Shape should be [NUM_TOKENS]. Defaults to None.
dropout_p (float, optional): Dropout probability. Defaults to 0.0.
scale (Optional[float], optional): Scaling factor applied prior to softmax. Defaults to None.
Returns:
torch.Tensor: Output tensor. Shape should be [B, N, Sq, D]
"""
# known issue: sdpa does not support an attention mask that contains a whole row of masked tokens, which leads to nan
# this case is usual when a padding mask is used and self-attention is performed
# thus, we don't use sdpa when a padding mask is used
# sanity check
if attention_mask is not None:
assert torch.is_floating_point(attention_mask), "attention_mask should be a floating point tensor."
if attention_mask_type in (AttnMaskType.CUSTOM, AttnMaskType.CAUSAL):
assert (
cu_seqlens_q is None
and cu_seqlens_kv is None
and max_seqlen_q is None
and max_seqlen_kv is None
and q_indices is None
and kv_indices is None
)
if attention_mask_type == AttnMaskType.CUSTOM:
assert not torch.all(attention_mask != 0, dim=-1).any()
elif attention_mask_type in (
AttnMaskType.PADDED,
AttnMaskType.PADDED_CAUSAL,
):
assert (
cu_seqlens_q is not None
and cu_seqlens_kv is not None
and max_seqlen_q is not None
and max_seqlen_kv is not None
and q_indices is not None
and kv_indices is not None
)
else:
# if attention_mask is None, attention_mask_type should be the default value
assert attention_mask_type == AttnMaskType.CUSTOM
# kernel dispatch
mask_type = attention_mask_type if attention_mask is not None else None
attn_func = ColoAttention._dispatch_kernel(q.dtype, mask_type)
is_causal = attention_mask is not None and attention_mask_type in (
AttnMaskType.CAUSAL,
AttnMaskType.PADDED_CAUSAL,
)
return attn_func(
q,
k,
v,
dropout_p=dropout_p,
scale=scale,
attention_mask=attention_mask,
is_causal=is_causal,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_kv,
max_seqlen_q=max_seqlen_q,
max_seqlen_kv=max_seqlen_kv,
q_indices=q_indices,
kv_indices=kv_indices,
)
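
Putting the two static methods together, a minimal end-to-end sketch (assumes a CUDA machine where one of the registered flash-attention kernels loads):

import torch

B, N, S, D = 2, 8, 4, 64  # batch, heads, seqlen, head dim
q, k, v = (torch.randn(B, N, S, D, dtype=torch.float16, device="cuda") for _ in range(3))

# Causal self-attention without padding: the helper returns the mask and its
# type, and attention() dispatches to the first available fp16 kernel.
attn_kwargs = ColoAttention.prepare_attn_kwargs(
    (B, 1, S, S), dtype=q.dtype, device=q.device, is_causal=True
)
out = ColoAttention.attention(q, k, v, **attn_kwargs)  # [B, N, S, D]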
@@ -3,6 +3,8 @@ from typing import Optional, Tuple
 import torch
 import torch.nn as nn
+from colossalai.shardformer.layer import ColoAttention
 def forward_fn():
     def forward(
@@ -62,8 +64,6 @@ def forward_fn():
 def get_blip2_flash_attention_forward():
     from transformers.models.blip_2.modeling_blip_2 import Blip2Attention
-    from colossalai.nn.layer.colo_attention import ColoAttention
     def forward(
         self: Blip2Attention,
         hidden_states: torch.Tensor,
@@ -71,16 +71,25 @@ def get_blip2_flash_attention_forward():
         output_attentions: Optional[bool] = False,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         """Input shape: Batch x Time x Channel"""
+        assert head_mask is None, "head_mask is not supported in FlashAttention"
         bsz, tgt_len, embed_dim = hidden_states.size()
         mixed_qkv = self.qkv(hidden_states)
-        mixed_qkv = mixed_qkv.reshape(bsz, tgt_len, 3, self.num_heads, -1).permute(2, 0, 1, 3, 4)
-        query_states, key_states, value_states = mixed_qkv[0], mixed_qkv[1], mixed_qkv[2]
-        attention = ColoAttention(
-            embed_dim=self.embed_dim, num_heads=self.num_heads, dropout=self.dropout.p, scale=self.scale
-        )
-        context_layer = attention(query_states, key_states, value_states)
+        mixed_qkv = mixed_qkv.reshape(bsz, tgt_len, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
+        query_states, key_states, value_states = (
+            mixed_qkv[0],
+            mixed_qkv[1],
+            mixed_qkv[2],
+        )
+        dropout_p = self.dropout.p if self.training else 0.0
+        context_layer = ColoAttention.attention(
+            query_states,
+            key_states,
+            value_states,
+            dropout_p=dropout_p,
+            scale=self.scale,
+        )
+        context_layer = context_layer.permute(0, 2, 1, 3).reshape(bsz, tgt_len, self.embed_dim)
         output = self.projection(context_layer)
         outputs = (output, None)
@@ -93,7 +102,11 @@ def get_blip2_flash_attention_forward():
 def get_jit_fused_blip2_QFormer_self_output_forward():
     from transformers.models.blip_2.modeling_blip_2 import Blip2QFormerSelfOutput
-    def forward(self: Blip2QFormerSelfOutput, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
+    def forward(
+        self: Blip2QFormerSelfOutput,
+        hidden_states: torch.Tensor,
+        input_tensor: torch.Tensor,
+    ) -> torch.Tensor:
         hidden_states = self.dense(hidden_states)
         hidden_states = self.dropout_add(hidden_states, input_tensor, self.dropout.p, self.dropout.training)
         hidden_states = self.LayerNorm(hidden_states)
@@ -105,7 +118,11 @@ def get_jit_fused_blip2_QFormer_self_output_forward():
 def get_jit_fused_blip2_QFormer_output_forward():
     from transformers.models.blip_2.modeling_blip_2 import Blip2QFormerOutput
-    def forward(self: Blip2QFormerOutput, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
+    def forward(
+        self: Blip2QFormerOutput,
+        hidden_states: torch.Tensor,
+        input_tensor: torch.Tensor,
+    ) -> torch.Tensor:
         hidden_states = self.dense(hidden_states)
         hidden_states = self.dropout_add(hidden_states, input_tensor, self.dropout.p, self.dropout.training)
         hidden_states = self.LayerNorm(hidden_states)
""" PyTorch ChatGLM model. """ """ PyTorch ChatGLM model. """
from typing import List, Optional, Tuple from typing import List, Optional, Tuple
import torch import torch
...@@ -9,63 +10,49 @@ from transformers.utils import logging ...@@ -9,63 +10,49 @@ from transformers.utils import logging
from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer import ShardConfig from colossalai.shardformer import ShardConfig
from colossalai.shardformer.layer import AttnMaskType, ColoAttention
from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward
from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMForConditionalGeneration, ChatGLMModel from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMForConditionalGeneration, ChatGLMModel
def get_flash_core_attention_forward(): def get_flash_core_attention_forward():
from colossalai.nn.layer.colo_attention import AttnMaskType, ColoAttention
from .chatglm2_6b.modeling_chatglm import CoreAttention from .chatglm2_6b.modeling_chatglm import CoreAttention
def forward(self: CoreAttention, query_layer, key_layer, value_layer, attention_mask): def forward(self: CoreAttention, query_layer, key_layer, value_layer, attention_mask):
pytorch_major_version = int(torch.__version__.split(".")[0]) query_layer, key_layer, value_layer = [k.permute(1, 2, 0, 3) for k in [query_layer, key_layer, value_layer]]
if pytorch_major_version >= 2: if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]:
query_layer, key_layer, value_layer = [k.permute(1, 2, 0, 3) for k in [query_layer, key_layer, value_layer]] attention_mask_type = AttnMaskType.CAUSAL
if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]: attn_bias = torch.zeros(
context_layer = torch.nn.functional.scaled_dot_product_attention( query_layer.shape[0],
query_layer, key_layer, value_layer, is_causal=True 1,
) query_layer.shape[2],
else: key_layer.shape[2],
if attention_mask is not None: dtype=query_layer.dtype,
attention_mask = ~attention_mask device=query_layer.device,
context_layer = torch.nn.functional.scaled_dot_product_attention(
query_layer, key_layer, value_layer, attention_mask
)
context_layer = context_layer.permute(2, 0, 1, 3)
new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
context_layer = context_layer.reshape(*new_context_layer_shape)
else:
# Raw attention scores
query_layer = query_layer.permute(1, 0, 2, 3).contiguous()
key_layer = key_layer.permute(1, 0, 2, 3).contiguous()
value_layer = value_layer.permute(1, 0, 2, 3).contiguous()
scale = 1.0 / self.norm_factor
if self.coeff is not None:
scale = scale * self.coeff
flash_attention_mask = None
attn_mask_type = None
if attention_mask is None:
attn_mask_type = AttnMaskType.causal
else:
flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous()
if not torch.all(flash_attention_mask):
attn_mask_type = AttnMaskType.paddedcausal
attention = ColoAttention(
embed_dim=self.hidden_size_per_partition,
num_heads=self.num_attention_heads_per_partition,
dropout=self.attention_dropout.p,
scale=scale,
) )
context_layer = attention( temp_mask = (
query_layer, key_layer, value_layer, attn_mask=flash_attention_mask, attn_mask_type=attn_mask_type torch.ones(query_layer.shape[2], key_layer.shape[2], dtype=torch.bool, device=query_layer.device)
.tril(diagonal=0)
.expand(query_layer.shape[0], 1, -1, -1)
) )
attn_bias.masked_fill_(temp_mask.logical_not(), torch.finfo(query_layer.dtype).min)
context_layer = context_layer.permute(1, 0, -1).contiguous() else:
attention_mask_type = AttnMaskType.CUSTOM
if attention_mask is not None:
attn_bias = torch.zeros_like(attention_mask, dtype=query_layer.dtype)
attn_bias.masked_fill_(attention_mask, torch.finfo(query_layer.dtype).min)
dropout_p = self.attention_dropout.p if self.training else 0.0
context_layer = ColoAttention.attention(
query_layer,
key_layer,
value_layer,
attention_mask=attn_bias,
attention_mask_type=attention_mask_type,
dropout_p=dropout_p,
)
context_layer = context_layer.permute(2, 0, 1, 3)
new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
context_layer = context_layer.reshape(*new_context_layer_shape)
return context_layer return context_layer
return forward return forward
@@ -169,11 +156,17 @@ class ChatGLMPipelineForwards:
         if self.pre_seq_len is not None:
             if past_key_values is None:
                 past_key_values = self.get_prompt(
-                    batch_size=batch_size, device=input_ids.device, dtype=inputs_embeds.dtype
+                    batch_size=batch_size,
+                    device=input_ids.device,
+                    dtype=inputs_embeds.dtype,
                 )
             if attention_mask is not None:
                 attention_mask = torch.cat(
-                    [attention_mask.new_ones((batch_size, self.pre_seq_len)), attention_mask], dim=-1
+                    [
+                        attention_mask.new_ones((batch_size, self.pre_seq_len)),
+                        attention_mask,
+                    ],
+                    dim=-1,
                 )
         if full_attention_mask is None:
             if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1):
@@ -200,7 +193,9 @@ class ChatGLMPipelineForwards:
         if shard_config.enable_sequence_parallelism:
             hidden_states = split_forward_gather_backward(
-                hidden_states, dim=0, process_group=shard_config.tensor_parallel_process_group
+                hidden_states,
+                dim=0,
+                process_group=shard_config.tensor_parallel_process_group,
             )
         for idx in range(start_idx, end_idx):
             layer = self.encoder._get_layer(idx)
@@ -208,7 +203,12 @@ class ChatGLMPipelineForwards:
                 all_hidden_states = all_hidden_states + (hidden_states,)
             if self.encoder.gradient_checkpointing and self.encoder.training:
                 layer_ret = torch.utils.checkpoint.checkpoint(
-                    layer, hidden_states, attention_mask, rotary_pos_emb, past_key_values[idx], use_cache
+                    layer,
+                    hidden_states,
+                    attention_mask,
+                    rotary_pos_emb,
+                    past_key_values[idx],
+                    use_cache,
                 )
             else:
                 layer_ret = layer(
@@ -224,7 +224,9 @@ class ChatGLMPipelineForwards:
         if shard_config.enable_sequence_parallelism:
             hidden_states = gather_forward_split_backward(
-                hidden_states, dim=0, process_group=shard_config.tensor_parallel_process_group
+                hidden_states,
+                dim=0,
+                process_group=shard_config.tensor_parallel_process_group,
             )
         if output_hidden_states:
             all_hidden_states = all_hidden_states + (hidden_states,)
@@ -234,7 +236,14 @@ class ChatGLMPipelineForwards:
             hidden_states = self.encoder.final_layernorm(hidden_states)
         if not return_dict:
             return tuple(
-                v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None
+                v
+                for v in [
+                    hidden_states,
+                    presents,
+                    all_hidden_states,
+                    all_self_attentions,
+                ]
+                if v is not None
             )
         return BaseModelOutputWithPast(
             last_hidden_state=hidden_states,
@@ -368,7 +377,9 @@ def get_chatglm_sequence_parallel_forward_fn(shard_config: ShardConfig):
         # Run encoder.
         # [seq_len, batch_size, hidden_size] -> [seq_len/TP_size, batch_size, hidden_size]
         inputs_embeds = split_forward_gather_backward(
-            inputs_embeds, dim=0, process_group=shard_config.tensor_parallel_process_group
+            inputs_embeds,
+            dim=0,
+            process_group=shard_config.tensor_parallel_process_group,
        )
        hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder(
            inputs_embeds,
@@ -380,7 +391,9 @@ def get_chatglm_sequence_parallel_forward_fn(shard_config: ShardConfig):
        )
        hidden_states = gather_forward_split_backward(
-            hidden_states, dim=0, process_group=shard_config.tensor_parallel_process_group
+            hidden_states,
+            dim=0,
+            process_group=shard_config.tensor_parallel_process_group,
        )
        if not return_dict:
@@ -15,7 +15,9 @@ from transformers.utils import logging
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer.shard import ShardConfig
-from ..layer import cross_entropy_1d
+from ..layer import ColoAttention, cross_entropy_1d
 try:
     from transformers.models.llama.modeling_llama import _prepare_4d_causal_attention_mask
@@ -105,18 +107,25 @@ class LlamaPipelineForwards:
         # embed positions, for the first stage, hidden_states is the input embeddings,
         # for the other stages, hidden_states is the output of the previous stage
-        if attention_mask is None:
-            attention_mask = torch.ones(
-                (batch_size, seq_length_with_past), dtype=torch.bool, device=hidden_states.device
-            )
-        if LATEST_VERSION:
-            attention_mask = _prepare_4d_causal_attention_mask(
-                attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
-            )
-        else:
-            attention_mask = self._prepare_decoder_attention_mask(
-                attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
-            )
+        if shard_config.enable_flash_attention:
+            # in this case, attention_mask is a dict rather than a tensor
+            mask_shape = (batch_size, 1, seq_length_with_past, seq_length_with_past)
+            attention_mask = ColoAttention.prepare_attn_kwargs(
+                mask_shape, hidden_states.dtype, hidden_states.device, q_padding_mask=attention_mask, is_causal=True
+            )
+        else:
+            if attention_mask is None:
+                attention_mask = torch.ones(
+                    (batch_size, seq_length_with_past), dtype=torch.bool, device=hidden_states.device
+                )
+            if LATEST_VERSION:
+                attention_mask = _prepare_4d_causal_attention_mask(
+                    attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
+                )
+            else:
+                attention_mask = self._prepare_decoder_attention_mask(
+                    attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
+                )
         if self.gradient_checkpointing and self.training:
             if use_cache:
@@ -262,6 +271,7 @@ class LlamaPipelineForwards:
             stage_manager=stage_manager,
             hidden_states=hidden_states,
             stage_index=stage_index,
+            shard_config=shard_config,
         )
         past_key_values = None
@@ -352,6 +362,7 @@ class LlamaPipelineForwards:
             stage_manager=stage_manager,
             hidden_states=hidden_states,
             stage_index=stage_index,
+            shard_config=shard_config,
         )
         if input_ids is not None:
@@ -420,8 +431,6 @@ class LlamaPipelineForwards:
 def get_llama_flash_attention_forward(shard_config: ShardConfig):
     from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb
-    from colossalai.nn.layer.colo_attention import AttnMaskType, ColoAttention
     llama_version = 2
     try:
         from transformers.models.llama.modeling_llama import repeat_kv
@@ -432,7 +441,7 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig):
     def forward(
         self: LlamaAttention,
         hidden_states: torch.Tensor,
-        attention_mask: Optional[torch.Tensor] = None,
+        attention_mask: Optional[dict] = None,
         position_ids: Optional[torch.LongTensor] = None,
         past_key_value: Optional[Tuple[torch.Tensor]] = None,
         output_attentions: bool = False,
@@ -466,31 +475,10 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig):
         key_states = repeat_kv(key_states, self.num_key_value_groups)
         value_states = repeat_kv(value_states, self.num_key_value_groups)
-        me_input_shape = (bsz, q_len, self.num_heads, self.head_dim)
-        query_states = query_states.transpose(1, 2).contiguous().view(*me_input_shape)
-        key_states = key_states.transpose(1, 2).contiguous().view(*me_input_shape)
-        value_states = value_states.transpose(1, 2).contiguous().view(*me_input_shape)
-        flash_attention_mask = None
-        attn_mask_type = AttnMaskType.causal
-        if not getattr(shard_config, "causal_lm", False) and attention_mask is not None:
-            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
-                raise ValueError(
-                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
-                )
-            flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous()
-            attn_mask_type = AttnMaskType.paddedcausal
-        if not hasattr(self, "attention"):
-            self.attention = ColoAttention(embed_dim=self.hidden_size, num_heads=self.num_heads)
-        attn_output = self.attention(
-            query_states,
-            key_states,
-            value_states,
-            attn_mask=flash_attention_mask,
-            attn_mask_type=attn_mask_type,
-            origin_attn_mask=attention_mask,
-        )
+        assert isinstance(attention_mask, dict), "Flash Attention Error: attention_mask should be a dict."
+        attn_output = ColoAttention.attention(query_states, key_states, value_states, **attention_mask)
+        attn_output = attn_output.transpose(1, 2).contiguous()
+        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
         attn_output = self.o_proj(attn_output)
@@ -499,6 +487,137 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig):
     return forward
def get_llama_model_forward_for_flash_attn(shard_config: ShardConfig):
logger = logging.get_logger(__name__)
assert shard_config.enable_flash_attention, "Flash Attention is not enabled."
def forward(
self: LlamaModel,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# retrieve input_ids and inputs_embeds
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
elif input_ids is not None:
batch_size, seq_length = input_ids.shape
elif inputs_embeds is not None:
batch_size, seq_length, _ = inputs_embeds.shape
else:
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
seq_length_with_past = seq_length
past_key_values_length = 0
if past_key_values is not None:
past_key_values_length = past_key_values[0][0].shape[2]
seq_length_with_past = seq_length_with_past + past_key_values_length
if position_ids is None:
device = input_ids.device if input_ids is not None else inputs_embeds.device
position_ids = torch.arange(
past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
)
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
else:
position_ids = position_ids.view(-1, seq_length).long()
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
# embed positions
hidden_states = inputs_embeds
# in this case, attention_mask is a dict rather than a tensor
mask_shape = (batch_size, 1, seq_length_with_past, seq_length_with_past)
attention_mask = ColoAttention.prepare_attn_kwargs(
mask_shape, hidden_states.dtype, hidden_states.device, q_padding_mask=attention_mask, is_causal=True
)
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = () if use_cache else None
for idx, decoder_layer in enumerate(self.layers):
if output_hidden_states:
all_hidden_states += (hidden_states,)
past_key_value = past_key_values[idx] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, past_key_value, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(decoder_layer),
hidden_states,
attention_mask,
position_ids,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
if output_attentions:
all_self_attns += (layer_outputs[1],)
hidden_states = self.norm(hidden_states)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
next_cache = next_decoder_cache if use_cache else None
if not return_dict:
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
return forward
def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
    from transformers import LlamaForCausalLM
-import math
 from typing import List, Optional, Tuple, Union
 import torch
@@ -6,6 +5,7 @@ from transformers.models.vit.modeling_vit import BaseModelOutput, ViTEncoder
 from transformers.utils import logging
 from colossalai.pipeline.stage_manager import PipelineStageManager
+from colossalai.shardformer.layer import ColoAttention
 def _encoder_forward(
@@ -98,7 +98,9 @@ def ViTModel_pipeline_forward(stage_manager: PipelineStageManager, stage_index:
             pixel_values = pixel_values.to(expected_dtype)
         embedding_output = self.embeddings(
-            pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding
+            pixel_values,
+            bool_masked_pos=bool_masked_pos,
+            interpolate_pos_encoding=interpolate_pos_encoding,
         )
         hidden_states = embedding_output
     else:
@@ -336,34 +338,27 @@ def ViTForMaskedImageModeling_pipeline_forward(stage_manager: PipelineStageManag
 def get_vit_flash_self_attention_forward():
     from transformers.models.vit.modeling_vit import ViTSelfAttention
-    from colossalai.nn.layer.colo_attention import ColoAttention
-    def transpose_for_scores(x: torch.Tensor, num_attention_heads, attention_head_size) -> torch.Tensor:
-        new_x_shape = x.size()[:-1] + (num_attention_heads, attention_head_size)
-        x = x.view(new_x_shape)
-        return x
     def forward(
         self: ViTSelfAttention,
         hidden_states: torch.Tensor,
         head_mask: Optional[torch.Tensor] = None,
         output_attentions: bool = False,
     ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
+        assert head_mask is None, "head_mask is not supported for FlashAttention"
         mixed_query_layer = self.query(hidden_states)
-        key_layer = transpose_for_scores(self.key(hidden_states), self.num_attention_heads, self.attention_head_size)
-        value_layer = transpose_for_scores(
-            self.value(hidden_states), self.num_attention_heads, self.attention_head_size
-        )
-        query_layer = transpose_for_scores(mixed_query_layer, self.num_attention_heads, self.attention_head_size)
-        scale = 1.0 / math.sqrt(self.attention_head_size)
-        attention = ColoAttention(
-            embed_dim=self.all_head_size, num_heads=self.num_attention_heads, dropout=self.dropout.p, scale=scale
-        )
-        context_layer = attention(query_layer, key_layer, value_layer)
-        outputs = (context_layer,)
+        key_layer = self.transpose_for_scores(self.key(hidden_states))
+        value_layer = self.transpose_for_scores(self.value(hidden_states))
+        query_layer = self.transpose_for_scores(mixed_query_layer)
+        dropout_p = self.dropout.p if self.training else 0.0
+        context_layer = ColoAttention.attention(query_layer, key_layer, value_layer, dropout_p=dropout_p)
+        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
+        context_layer = context_layer.view(new_context_layer_shape)
+        outputs = (context_layer, None) if output_attentions else (context_layer,)
         return outputs
...@@ -8,6 +8,7 @@ import colossalai.shardformer.layer as col_nn ...@@ -8,6 +8,7 @@ import colossalai.shardformer.layer as col_nn
from ..modeling.gpt2 import ( from ..modeling.gpt2 import (
GPT2PipelineForwards, GPT2PipelineForwards,
get_gpt2_flash_attention_forward, get_gpt2_flash_attention_forward,
get_gpt_model_forward_for_flash_attn,
get_lm_forward_with_dist_cross_entropy, get_lm_forward_with_dist_cross_entropy,
gpt2_sequence_parallel_forward_fn, gpt2_sequence_parallel_forward_fn,
) )
...@@ -75,7 +76,11 @@ class GPT2Policy(Policy): ...@@ -75,7 +76,11 @@ class GPT2Policy(Policy):
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="attn.c_attn", suffix="attn.c_attn",
target_module=col_nn.GPT2FusedLinearConv1D_Col, target_module=col_nn.GPT2FusedLinearConv1D_Col,
kwargs={"n_fused": 3, "seq_parallel": use_sequence_parallel, "overlap": overlap}, kwargs={
"n_fused": 3,
"seq_parallel": use_sequence_parallel,
"overlap": overlap,
},
), ),
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="attn.c_proj", suffix="attn.c_proj",
...@@ -87,7 +92,11 @@ class GPT2Policy(Policy): ...@@ -87,7 +92,11 @@ class GPT2Policy(Policy):
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="mlp.c_fc", suffix="mlp.c_fc",
target_module=col_nn.GPT2FusedLinearConv1D_Col, target_module=col_nn.GPT2FusedLinearConv1D_Col,
kwargs={"n_fused": 1, "seq_parallel": use_sequence_parallel, "overlap": overlap}, kwargs={
"n_fused": 1,
"seq_parallel": use_sequence_parallel,
"overlap": overlap,
},
), ),
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="mlp.c_proj", suffix="mlp.c_proj",
...@@ -150,6 +159,10 @@ class GPT2Policy(Policy): ...@@ -150,6 +159,10 @@ class GPT2Policy(Policy):
policy=policy, policy=policy,
target_key=GPT2Attention, target_key=GPT2Attention,
) )
if not self.shard_config.pipeline_stage_manager:
policy[GPT2Model].method_replacement = {
"forward": get_gpt_model_forward_for_flash_attn(self.shard_config)
}
if self.shard_config.enable_sequence_parallelism: if self.shard_config.enable_sequence_parallelism:
policy[GPT2Model].method_replacement = {"forward": gpt2_sequence_parallel_forward_fn(self.shard_config)} policy[GPT2Model].method_replacement = {"forward": gpt2_sequence_parallel_forward_fn(self.shard_config)}
...@@ -223,14 +236,21 @@ class GPT2Policy(Policy): ...@@ -223,14 +236,21 @@ class GPT2Policy(Policy):
num_stages=stage_manager.num_stages, num_stages=stage_manager.num_stages,
) )
method_replacement = { method_replacement = {
"forward": partial(new_forward, stage_manager=stage_manager, shard_config=self.shard_config) "forward": partial(
new_forward,
stage_manager=stage_manager,
shard_config=self.shard_config,
)
} }
else: else:
layers_per_stage = Policy.distribute_layers(len(module.h), stage_manager.num_stages) layers_per_stage = Policy.distribute_layers(len(module.h), stage_manager.num_stages)
stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
method_replacement = { method_replacement = {
"forward": partial( "forward": partial(
new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config new_forward,
stage_manager=stage_manager,
stage_index=stage_index,
shard_config=self.shard_config,
) )
} }
self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls) self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls)
...@@ -245,7 +265,9 @@ class GPT2ModelPolicy(GPT2Policy): ...@@ -245,7 +265,9 @@ class GPT2ModelPolicy(GPT2Policy):
if self.pipeline_stage_manager is not None: if self.pipeline_stage_manager is not None:
self.set_pipeline_forward( self.set_pipeline_forward(
model_cls=GPT2Model, new_forward=GPT2PipelineForwards.gpt2_model_forward, policy=policy model_cls=GPT2Model,
new_forward=GPT2PipelineForwards.gpt2_model_forward,
policy=policy,
) )
return policy return policy
...@@ -299,7 +321,12 @@ class GPT2LMHeadModelPolicy(GPT2Policy): ...@@ -299,7 +321,12 @@ class GPT2LMHeadModelPolicy(GPT2Policy):
if stage_manager is not None: if stage_manager is not None:
if stage_manager.num_stages > 1 and id(module.transformer.wte.weight) == id(module.lm_head.weight): if stage_manager.num_stages > 1 and id(module.transformer.wte.weight) == id(module.lm_head.weight):
first_stage, last_stage = 0, stage_manager.num_stages - 1 first_stage, last_stage = 0, stage_manager.num_stages - 1
return [{first_stage: module.transformer.wte.weight, last_stage: module.lm_head.weight}] return [
{
first_stage: module.transformer.wte.weight,
last_stage: module.lm_head.weight,
}
]
return [] return []
...@@ -315,7 +342,9 @@ class GPT2DoubleHeadsModelPolicy(GPT2Policy): ...@@ -315,7 +342,9 @@ class GPT2DoubleHeadsModelPolicy(GPT2Policy):
GPT2DoubleHeadsModel: ModulePolicyDescription( GPT2DoubleHeadsModel: ModulePolicyDescription(
sub_module_replacement=[ sub_module_replacement=[
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="lm_head", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True} suffix="lm_head",
target_module=col_nn.Linear1D_Col,
kwargs={"gather_output": True},
) )
] ]
) )
@@ -350,7 +379,12 @@ class GPT2DoubleHeadsModelPolicy(GPT2Policy):
         if stage_manager is not None:
             if stage_manager.num_stages > 1 and id(module.transformer.wte.weight) == id(module.lm_head.weight):
                 first_stage, last_stage = 0, stage_manager.num_stages - 1
-                return [{first_stage: module.transformer.wte.weight, last_stage: module.lm_head.weight}]
+                return [
+                    {
+                        first_stage: module.transformer.wte.weight,
+                        last_stage: module.lm_head.weight,
+                    }
+                ]
         return []
@@ -392,7 +426,10 @@ class GPT2ForTokenClassificationPolicy(GPT2Policy):
         addon_module = {
             GPT2ForTokenClassification: ModulePolicyDescription(
                 sub_module_replacement=[
-                    SubModuleReplacementDescription(suffix="dropout", target_module=col_nn.DropoutForParallelInput)
+                    SubModuleReplacementDescription(
+                        suffix="dropout",
+                        target_module=col_nn.DropoutForParallelInput,
+                    )
                 ]
             )
         }
...
@@ -6,7 +6,11 @@ from torch import Tensor, nn
 import colossalai.shardformer.layer as col_nn
 
-from ..modeling.gptj import GPTJPipelineForwards, get_gptj_flash_attention_forward
+from ..modeling.gptj import (
+    GPTJPipelineForwards,
+    get_gptj_flash_attention_forward,
+    gptj_model_forward_for_flash_attention,
+)
 from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
 
 __all__ = [
@@ -71,17 +75,26 @@ class GPTJPolicy(Policy):
                     SubModuleReplacementDescription(
                         suffix="attn.k_proj",
                         target_module=col_nn.Linear1D_Col,
-                        kwargs={"seq_parallel": use_sequence_parallel, "overlap": overlap},
+                        kwargs={
+                            "seq_parallel": use_sequence_parallel,
+                            "overlap": overlap,
+                        },
                     ),
                     SubModuleReplacementDescription(
                         suffix="attn.q_proj",
                         target_module=col_nn.Linear1D_Col,
-                        kwargs={"seq_parallel": use_sequence_parallel, "overlap": overlap},
+                        kwargs={
+                            "seq_parallel": use_sequence_parallel,
+                            "overlap": overlap,
+                        },
                     ),
                     SubModuleReplacementDescription(
                         suffix="attn.v_proj",
                         target_module=col_nn.Linear1D_Col,
-                        kwargs={"seq_parallel": use_sequence_parallel, "overlap": overlap},
+                        kwargs={
+                            "seq_parallel": use_sequence_parallel,
+                            "overlap": overlap,
+                        },
                     ),
                     SubModuleReplacementDescription(
                         suffix="attn.out_proj",
@@ -143,6 +156,12 @@ class GPTJPolicy(Policy):
                 policy=policy,
                 target_key=GPTJAttention,
             )
+            if not self.shard_config.pipeline_stage_manager:
+                self.append_or_create_method_replacement(
+                    description={"forward": gptj_model_forward_for_flash_attention(self.shard_config)},
+                    policy=policy,
+                    target_key=GPTJModel,
+                )
         return policy
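Note: when there is no pipeline stage manager, the whole-model forward is swapped for a flash-attention variant built from shard_config. In essence, a method replacement patches a bound method on one instance; a generic sketch that assumes nothing about ColossalAI internals:

    from types import MethodType

    class TinyModel:
        def forward(self, x):
            return x + 1

    def make_fast_forward(scale):
        # Closure over configuration, analogous to building a new
        # forward from shard_config.
        def forward(self, x):
            return (x + 1) * scale
        return forward

    model = TinyModel()
    # Patch this instance only; the class and other instances are untouched.
    model.forward = MethodType(make_fast_forward(scale=2), model)
    print(model.forward(3))  # -> 8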
@@ -185,7 +204,10 @@ class GPTJPolicy(Policy):
             stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
             method_replacement = {
                 "forward": partial(
-                    new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config
+                    new_forward,
+                    stage_manager=stage_manager,
+                    stage_index=stage_index,
+                    shard_config=self.shard_config,
                 )
             }
         self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls)
@@ -203,7 +225,9 @@ class GPTJModelPolicy(GPTJPolicy):
         if self.pipeline_stage_manager is not None:
             self.set_pipeline_forward(
-                model_cls=GPTJModel, new_forward=GPTJPipelineForwards.gptj_model_forward, policy=policy
+                model_cls=GPTJModel,
+                new_forward=GPTJPipelineForwards.gptj_model_forward,
+                policy=policy,
             )
         return policy
@@ -230,7 +254,9 @@ class GPTJForCausalLMPolicy(GPTJPolicy):
             GPTJForCausalLM: ModulePolicyDescription(
                 sub_module_replacement=[
                     SubModuleReplacementDescription(
-                        suffix="lm_head", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True}
+                        suffix="lm_head",
+                        target_module=col_nn.Linear1D_Col,
+                        kwargs={"gather_output": True},
                     )
                 ]
             )
@@ -239,7 +265,9 @@ class GPTJForCausalLMPolicy(GPTJPolicy):
         if self.pipeline_stage_manager is not None:
             self.set_pipeline_forward(
-                model_cls=GPTJForCausalLM, new_forward=GPTJPipelineForwards.gptj_causallm_model_forward, policy=policy
+                model_cls=GPTJForCausalLM,
+                new_forward=GPTJPipelineForwards.gptj_causallm_model_forward,
+                policy=policy,
             )
         return policy
@@ -256,7 +284,12 @@ class GPTJForCausalLMPolicy(GPTJPolicy):
         if stage_manager is not None:
             if stage_manager.num_stages > 1 and id(module.transformer.wte.weight) == id(module.lm_head.weight):
                 first_stage, last_stage = 0, stage_manager.num_stages - 1
-                return [{first_stage: module.transformer.wte.weight, last_stage: module.lm_head.weight}]
+                return [
+                    {
+                        first_stage: module.transformer.wte.weight,
+                        last_stage: module.lm_head.weight,
+                    }
+                ]
         return []
...
@@ -11,6 +11,7 @@ from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col, Linear1D_Ro
 from ..modeling.llama import (
     LlamaPipelineForwards,
     get_llama_flash_attention_forward,
+    get_llama_model_forward_for_flash_attn,
     get_lm_forward_with_dist_cross_entropy,
 )
 from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
@@ -135,6 +136,15 @@ class LlamaPolicy(Policy):
                 policy=policy,
                 target_key=LlamaAttention,
             )
+            if self.pipeline_stage_manager is None:
+                # replace llama model forward method
+                self.append_or_create_method_replacement(
+                    description={
+                        "forward": get_llama_model_forward_for_flash_attn(self.shard_config),
+                    },
+                    policy=policy,
+                    target_key=LlamaModel,
+                )
         return policy
...
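Note: these hooks route the model forward through an attention path that accepts an explicit mask, which is the point of this PR's custom-mask support. As a framework-agnostic illustration of a mask-aware attention call, here is plain PyTorch scaled_dot_product_attention with a custom additive mask (standard PyTorch >= 2.0, not the ColoAttention API):

    import torch
    import torch.nn.functional as F

    batch, heads, seq, dim = 2, 4, 8, 16
    q = torch.randn(batch, heads, seq, dim)
    k = torch.randn(batch, heads, seq, dim)
    v = torch.randn(batch, heads, seq, dim)

    # Additive mask: 0.0 where attention is allowed, -inf where blocked.
    # Here: causal structure plus two padded key positions at the end.
    mask = torch.zeros(seq, seq)
    causal = torch.triu(torch.ones(seq, seq, dtype=torch.bool), diagonal=1)
    mask = mask.masked_fill(causal, float("-inf"))
    mask[:, -2:] = float("-inf")  # e.g. padding tokens

    # The mask broadcasts over (batch, heads); any pattern works, not
    # just the built-in is_causal=True fast path.
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
    print(out.shape)  # torch.Size([2, 4, 8, 16])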