Unverified Commit 19e1a5cf authored by Hongxin Liu's avatar Hongxin Liu Committed by GitHub
Browse files

[shardformer] update colo attention to support custom mask (#5510)

* [feature] refactor colo attention (#5462)

* [extension] update api

* [feature] add colo attention

* [feature] update sdpa

* [feature] update npu attention

* [feature] update flash-attn

* [test] add flash attn test

* [test] update flash attn test

* [shardformer] update modeling to fit colo attention (#5465)

* [misc] refactor folder structure

* [shardformer] update llama flash-attn

* [shardformer] fix llama policy

* [devops] update tensornvme install

* [test] update llama test

* [shardformer] update colo attn kernel dispatch

* [shardformer] update blip2

* [shardformer] update chatglm

* [shardformer] update gpt2

* [shardformer] update gptj

* [shardformer] update opt

* [shardformer] update vit

* [shardformer] update colo attention mask prep

* [shardformer] update whisper

* [test] fix shardformer tests (#5514)

* [test] fix shardformer tests

* [test] fix shardformer tests
parent 9a3321e9
...@@ -117,7 +117,7 @@ jobs: ...@@ -117,7 +117,7 @@ jobs:
cd TensorNVMe cd TensorNVMe
conda install cmake conda install cmake
pip install -r requirements.txt pip install -r requirements.txt
pip install -v . DISABLE_URING=1 pip install -v .
- name: Store TensorNVMe Cache - name: Store TensorNVMe Cache
run: | run: |
...@@ -201,4 +201,4 @@ jobs: ...@@ -201,4 +201,4 @@ jobs:
uses: actions/upload-artifact@v3 uses: actions/upload-artifact@v3
with: with:
name: report name: report
path: report/ path: report/
\ No newline at end of file
...@@ -44,7 +44,7 @@ jobs: ...@@ -44,7 +44,7 @@ jobs:
cd TensorNVMe cd TensorNVMe
conda install cmake conda install cmake
pip install -r requirements.txt pip install -r requirements.txt
pip install -v . DISABLE_URING=1 pip install -v .
- uses: actions/checkout@v2 - uses: actions/checkout@v2
if: steps.check-avai.outputs.avai == 'true' if: steps.check-avai.outputs.avai == 'true'
......
...@@ -66,7 +66,7 @@ jobs: ...@@ -66,7 +66,7 @@ jobs:
cd TensorNVMe cd TensorNVMe
apt update && apt install -y cmake apt update && apt install -y cmake
pip install -r requirements.txt pip install -r requirements.txt
pip install -v . DISABLE_URING=1 pip install -v .
- uses: actions/checkout@v2 - uses: actions/checkout@v2
with: with:
ssh-key: ${{ secrets.SSH_KEY_FOR_CI }} ssh-key: ${{ secrets.SSH_KEY_FOR_CI }}
......
...@@ -60,7 +60,7 @@ jobs: ...@@ -60,7 +60,7 @@ jobs:
cd TensorNVMe cd TensorNVMe
apt update && apt install -y cmake apt update && apt install -y cmake
pip install -r requirements.txt pip install -r requirements.txt
pip install -v . DISABLE_URING=1 pip install -v .
- uses: actions/checkout@v2 - uses: actions/checkout@v2
with: with:
ssh-key: ${{ secrets.SSH_KEY_FOR_CI }} ssh-key: ${{ secrets.SSH_KEY_FOR_CI }}
......
...@@ -56,7 +56,7 @@ jobs: ...@@ -56,7 +56,7 @@ jobs:
cd TensorNVMe cd TensorNVMe
apt update && apt install -y cmake apt update && apt install -y cmake
pip install -r requirements.txt pip install -r requirements.txt
pip install -v . DISABLE_URING=1 pip install -v .
- uses: actions/checkout@v2 - uses: actions/checkout@v2
with: with:
ssh-key: ${{ secrets.SSH_KEY_FOR_CI }} ssh-key: ${{ secrets.SSH_KEY_FOR_CI }}
......
...@@ -6,7 +6,7 @@ from .extensions import ( ...@@ -6,7 +6,7 @@ from .extensions import (
CpuAdamX86Extension, CpuAdamX86Extension,
FlashAttentionDaoCudaExtension, FlashAttentionDaoCudaExtension,
FlashAttentionNpuExtension, FlashAttentionNpuExtension,
FlashAttentionXformersCudaExtension, FlashAttentionSdpaCudaExtension,
FusedOptimizerCudaExtension, FusedOptimizerCudaExtension,
LayerNormCudaExtension, LayerNormCudaExtension,
MoeCudaExtension, MoeCudaExtension,
...@@ -65,9 +65,9 @@ class KernelLoader: ...@@ -65,9 +65,9 @@ class KernelLoader:
else: else:
usable_exts = [] usable_exts = []
for ext in exts: for ext in exts:
if ext.is_hardware_available(): if ext.is_available():
# make sure the machine is compatible during kernel loading # make sure the machine is compatible during kernel loading
ext.assert_hardware_compatible() ext.assert_compatible()
usable_exts.append(ext) usable_exts.append(ext)
assert len(usable_exts) != 0, f"No usable kernel found for {self.__class__.__name__} on the current machine." assert len(usable_exts) != 0, f"No usable kernel found for {self.__class__.__name__} on the current machine."
...@@ -106,4 +106,20 @@ class ScaledUpperTriangleMaskedSoftmaxLoader(KernelLoader): ...@@ -106,4 +106,20 @@ class ScaledUpperTriangleMaskedSoftmaxLoader(KernelLoader):
class FlashAttentionLoader(KernelLoader): class FlashAttentionLoader(KernelLoader):
REGISTRY = [FlashAttentionNpuExtension, FlashAttentionDaoCudaExtension, FlashAttentionXformersCudaExtension] REGISTRY = [
FlashAttentionNpuExtension,
FlashAttentionDaoCudaExtension,
FlashAttentionSdpaCudaExtension,
]
class FlashAttentionWithPaddingMaskLoader(KernelLoader):
REGISTRY = [FlashAttentionNpuExtension, FlashAttentionDaoCudaExtension]
class FlashAttentionWithCustomMaskLoader(KernelLoader):
REGISTRY = [FlashAttentionNpuExtension, FlashAttentionSdpaCudaExtension]
class FlashAttentionForFloatAndCustomMaskLoader(KernelLoader):
REGISTRY = [FlashAttentionSdpaCudaExtension]
import enum
import math
import warnings
from dataclasses import dataclass
from typing import Iterable, Optional, Tuple
import torch
import torch.nn.functional as F
from einops import rearrange
from colossalai.accelerator import get_accelerator
from colossalai.kernel.kernel_loader import FlashAttentionLoader
@dataclass
class SeqLenInfo:
seqlens: Iterable[int] = None
indices: torch.Tensor = None
max_seqlen: int = None
cu_seqlens: torch.Tensor = None
@staticmethod
def materialize(
attn_mask: torch.Tensor = None, size: Tuple[int] = None, device=get_accelerator().get_current_device()
):
if attn_mask is not None:
indices = torch.nonzero(attn_mask.flatten(), as_tuple=False).flatten().to(device)
seqlens = attn_mask.sum(dim=-1, dtype=torch.int32).flatten()
else:
batch_size, tgt_len = size[0], size[1]
indices = torch.arange(batch_size * tgt_len, dtype=torch.long, device=device)
seqlens = torch.LongTensor([tgt_len] * batch_size, device=device)
max_seqlen = max(seqlens)
cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0)).to(device)
return SeqLenInfo(seqlens.tolist(), indices, max_seqlen, cu_seqlens)
class AttnMaskType(enum.Enum):
padding = 1
causal = 2
paddedcausal = 3
class Unpad(torch.autograd.Function):
"""
Adapted from
https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/bert_padding.py
"""
@staticmethod
def forward(ctx, tensor: torch.Tensor, indices: torch.Tensor):
ctx.save_for_backward(indices)
# [b, s, ...]
assert tensor.ndim >= 3
ctx.bsz = tensor.shape[0]
out = rearrange(tensor, "b s ... -> (b s) ...")
ctx.shape = out.shape
# [ntokens, ...]
return out[indices]
@staticmethod
def backward(ctx, grad_output):
(indices,) = ctx.saved_tensors
# [ntokens, ...]
grad = torch.zeros(ctx.shape, dtype=grad_output.dtype, device=grad_output.device)
grad[indices] = grad_output
grad = rearrange(grad, "(b s) ... -> b s ...", b=ctx.bsz)
# [b, s, ...]
return grad, None
class Repad(torch.autograd.Function):
"""
Adapted from
https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/bert_padding.py
"""
@staticmethod
def forward(ctx, tensor: torch.Tensor, indices: torch.Tensor, batch_size: int, seq_len: int):
ctx.save_for_backward(indices)
# [ntokens, ...]
tensor = tensor
out = torch.zeros((batch_size * seq_len, *tensor.shape[1:]), dtype=tensor.dtype, device=tensor.device)
# [b*s, ...]
out[indices] = tensor
return out
@staticmethod
def backward(ctx, grad_output):
(indices,) = ctx.saved_tensors
# [b*s, ...]
grad = grad_output[indices]
# [ntokens, ...]
return grad, None, None, None
class ColoAttention(torch.nn.Module):
def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.0, scale=None):
super().__init__()
assert (
embed_dim % num_heads == 0
), f"the embed dim ({embed_dim}) is not divisible by the number of attention heads ({num_heads})."
if scale is not None:
self.scale = scale
else:
self.scale = 1 / math.sqrt(embed_dim // num_heads)
self.dropout = dropout
self.attn = FlashAttentionLoader().load()
@staticmethod
def unpad(tensor: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
return Unpad.apply(tensor, indices)
@staticmethod
def repad(tensor: torch.Tensor, indices: torch.Tensor, batch_size: int, seq_len: int) -> torch.Tensor:
return Repad.apply(tensor, indices, batch_size, seq_len)
def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
origin_attn_mask: Optional[torch.Tensor] = None,
attn_mask_type: Optional[AttnMaskType] = None,
bias: Optional[torch.Tensor] = None,
):
"""
ColoAttention
Args:
q: (batch, q_seqlen, nheads, headdim)
k: (batch, kv_seqlen, nheads, headdim)
v: (batch, kv_seqlen, nheads, headdim)
origin_attn_mask: (nheads, q_seqlen, kv_seqlen)
bias: will not be used
Return:
attn_out: (batch, q_seqlen, nheads, headdim).
"""
# if flash attention is not applicable, switch to memory effcient attention
if self.attn.__name__ == "flash_attention" and (
query.dtype not in [torch.float16, torch.bfloat16] or bias != None
):
warnings.warn(
f"flash-attn expects fp16 or bf16 but got {query.dtype}, switching to xformers' implementation."
)
self.attn = FlashAttentionLoader().load(ext_name="flash_attention_xformers_cuda")
padded = attn_mask_type is not None and attn_mask_type.value % 2 == 1
causal = attn_mask_type is not None and attn_mask_type.value > 1
batch_size, tgt_len, src_len = query.shape[0], query.shape[1], key.shape[1]
# unpad
seq_len_info_q = None
seq_len_info_kv = None
if padded:
# bert style, unpad process
assert (
attn_mask is not None
), f"attention mask {attn_mask} is not valid for attention mask type {attn_mask_type}."
assert attn_mask.dim() == 2, (
"attention mask is supposed to have shape (batch_size, seq_len), "
+ f"but got {attn_mask.dim()} dimensions."
)
# bert style
if tgt_len == src_len:
seq_len_info_q = SeqLenInfo.materialize(attn_mask=attn_mask, device=query.device)
if batch_size > 1:
query, key, value = self.unpad(
torch.stack([query, key, value], dim=2), seq_len_info_q.indices
).unbind(dim=1)
else:
query, key, value = torch.stack([query, key, value], dim=2).squeeze(0).unbind(dim=1)
seq_len_info_kv = seq_len_info_q
else:
seq_len_info_q = SeqLenInfo.materialize(size=(batch_size, tgt_len), device=query.device)
seq_len_info_kv = SeqLenInfo.materialize(attn_mask=attn_mask, device=query.device)
if batch_size > 1:
query = rearrange(query, "b s ... -> c (b s) ...", c=1)
key, value = self.unpad(torch.stack([query, key, value], dim=2), seq_len_info_kv.indices).unbind(
dim=1
)
else:
query, key, value = torch.stack([query, key, value], dim=2).squeeze(0).unbind(dim=1)
out = self.attn(
query,
key,
value,
seq_len_info_q=seq_len_info_q,
seq_len_info_kv=seq_len_info_kv,
origin_attn_mask=origin_attn_mask,
dropout_p=self.dropout,
scale=self.scale,
causal=causal,
padded=padded,
)
# repad
if padded:
if batch_size > 1:
out = self.repad(out, seq_len_info_q.indices, batch_size, tgt_len)
out = rearrange(out, "(b s) h d -> b s h d", b=batch_size)
if len(out.shape) == 4:
out = rearrange(out, "b s h d -> b s (h d)")
return out
from .attn import AttnMaskType, ColoAttention
from .dropout import DropoutForParallelInput, DropoutForReplicatedInput from .dropout import DropoutForParallelInput, DropoutForReplicatedInput
from .embedding import Embedding1D, VocabParallelEmbedding1D from .embedding import Embedding1D, VocabParallelEmbedding1D
from .linear import Linear1D_Col, Linear1D_Row from .linear import Linear1D_Col, Linear1D_Row
...@@ -23,4 +24,6 @@ __all__ = [ ...@@ -23,4 +24,6 @@ __all__ = [
"FusedRMSNorm", "FusedRMSNorm",
"FusedLinear1D_Col", "FusedLinear1D_Col",
"ParallelModule", "ParallelModule",
"AttnMaskType",
"ColoAttention",
] ]
from enum import Enum
from typing import Callable, Dict, Optional, Tuple
import torch
import torch.nn.functional as F
from colossalai.kernel.kernel_loader import (
FlashAttentionForFloatAndCustomMaskLoader,
FlashAttentionLoader,
FlashAttentionWithCustomMaskLoader,
FlashAttentionWithPaddingMaskLoader,
KernelLoader,
)
__all__ = [
"AttnMaskType",
"ColoAttention",
]
class AttnMaskType(Enum):
CUSTOM = 0
PADDED = 1
CAUSAL = 2
PADDED_CAUSAL = 3
def invert_mask(mask: torch.Tensor) -> torch.Tensor:
"""Invert the mask tensor.
Args:
mask (torch.Tensor): Mask tensor. Shape should be [B, 1, Sq, Skv]
Returns:
torch.Tensor: Inverted mask tensor.
"""
inverted_mask = 1.0 - mask
return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(mask.dtype).min)
# adapted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/bert_padding.py
def get_pad_info(padding_mask: torch.Tensor) -> Tuple[int, torch.Tensor, torch.Tensor]:
"""Get padding information from padding mask.
Args:
padding_mask (torch.Tensor): Padding mask tensor. Shape should be [B, S]
Returns:
Tuple[int, torch.Tensor, torch.Tensor]: Tuple of (max_seq_len, cu_seqlens, indices)
"""
seqlens_in_batch = padding_mask.sum(dim=-1, dtype=torch.int32)
indices = torch.nonzero(padding_mask.flatten(), as_tuple=False).flatten()
max_seqlen_in_batch = seqlens_in_batch.max().item()
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
return max_seqlen_in_batch, cu_seqlens, indices
class ColoAttention:
_kernel_dispatch_map: Optional[Dict[torch.dtype, Dict[Optional[AttnMaskType], Callable]]] = None
@staticmethod
def _init_kernels_dispatch():
if ColoAttention._kernel_dispatch_map is None:
# fp16/bf16
half_dispatch_map = {
None: FlashAttentionLoader(),
AttnMaskType.CUSTOM: FlashAttentionWithCustomMaskLoader(),
AttnMaskType.PADDED: FlashAttentionWithPaddingMaskLoader(),
AttnMaskType.CAUSAL: FlashAttentionLoader(),
AttnMaskType.PADDED_CAUSAL: FlashAttentionWithPaddingMaskLoader(),
}
# fp32
float_dispatch_map = {
None: FlashAttentionForFloatAndCustomMaskLoader(),
AttnMaskType.CUSTOM: FlashAttentionForFloatAndCustomMaskLoader(),
AttnMaskType.CAUSAL: FlashAttentionForFloatAndCustomMaskLoader(),
}
ColoAttention._kernel_dispatch_map = {
torch.float16: half_dispatch_map,
torch.bfloat16: half_dispatch_map,
torch.float32: float_dispatch_map,
}
@staticmethod
def _dispatch_kernel(dtype: torch.dtype, mask_type: Optional[AttnMaskType]) -> Callable:
ColoAttention._init_kernels_dispatch()
if (
dtype not in ColoAttention._kernel_dispatch_map
or mask_type not in ColoAttention._kernel_dispatch_map[dtype]
):
raise ValueError(
"FlashAttention kernel is not available for dtype {} and mask_type {}".format(dtype, mask_type)
)
# lazy load
if isinstance(ColoAttention._kernel_dispatch_map[dtype][mask_type], KernelLoader):
ColoAttention._kernel_dispatch_map[dtype][mask_type] = ColoAttention._kernel_dispatch_map[dtype][
mask_type
].load()
return ColoAttention._kernel_dispatch_map[dtype][mask_type]
@staticmethod
def prepare_attn_kwargs(
shape_4d: Tuple[int],
dtype: torch.dtype,
device: torch.device,
q_padding_mask: Optional[torch.Tensor] = None,
kv_padding_mask: Optional[torch.Tensor] = None,
is_causal: bool = False,
) -> Dict[str, torch.Tensor]:
"""Return a dictionary of keyword arguments for attention function. It supports 4 mask type.
1. custom mask: no padding mask and is_causal=False, return {}, users should handle attention mask by themselves.
2. padded mask: recv padding mask and is_causal=False, return {attention_mask, attention_mask_type, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv, q_indices, kv_indices}.
3. causal mask: no padding mask and is_causal=True, return {attention_mask, attention_mask_type}.
4. padded causal mask: recv padding mask and is_causal=True, return {attention_mask, attention_mask_type, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv, q_indices, kv_indices}.
Args:
shape_4d (Tuple[int]): Should be (B, 1, Sq, Skv)
dtype (torch.dtype): Dtype of attention mask, generally should be ``hidden_states.dtype``
device (torch.device): Device of attention mask, generally should be ``hidden_states.device``
q_padding_mask (Optional[torch.Tensor], optional): Padding mask of query. It should be a long tensor or int tensor.
The shape should be [B, Sq]. ``1`` means valid token, and ``0`` means padding token. Defaults to None.
kv_padding_mask (Optional[torch.Tensor], optional): Padding mask of key and value. It should be a long tensor or int tensor.
The shape should be [B, Skv]. ``1`` means valid token, and ``0`` means padding token.
If it's None and ``q_padding_mask`` is not None, it will be set to ``q_padding_mask``. Defaults to None.
is_causal (bool, optional): Whether to use causal attention mask. Defaults to False.
Returns:
Dict[str, torch.Tensor]: Dictionary of keyword arguments for attention function.
"""
if q_padding_mask is None and not is_causal:
return {}
assert len(shape_4d) == 4 and shape_4d[1] == 1
b, _, s_q, s_kv = shape_4d
outputs = {}
if (q_padding_mask is None or q_padding_mask.bool().all()) and (
kv_padding_mask is None or kv_padding_mask.bool().all()
):
# no padding
assert is_causal
outputs["attention_mask_type"] = AttnMaskType.CAUSAL
attention_mask = torch.ones(s_q, s_kv, dtype=dtype, device=device).tril(diagonal=0).expand(b, s_q, s_kv)
else:
if kv_padding_mask is None:
# self attention
kv_padding_mask = q_padding_mask
assert q_padding_mask.shape == (b, s_q) and kv_padding_mask.shape == (
b,
s_kv,
), f"q_padding_mask shape {q_padding_mask.shape} and kv_padding_mask shape {kv_padding_mask.shape} should be the same. ({shape_4d})"
attention_mask = torch.einsum("bi,bj->bij", q_padding_mask, kv_padding_mask).to(dtype=dtype, device=device)
max_seqlen_q, cu_seqlens_q, q_indices = get_pad_info(q_padding_mask)
max_seqlen_kv, cu_seqlens_kv, kv_indices = get_pad_info(kv_padding_mask)
outputs.update(
{
"cu_seqlens_q": cu_seqlens_q,
"cu_seqlens_kv": cu_seqlens_kv,
"max_seqlen_q": max_seqlen_q,
"max_seqlen_kv": max_seqlen_kv,
"q_indices": q_indices,
"kv_indices": kv_indices,
}
)
if is_causal:
outputs["attention_mask_type"] = AttnMaskType.PADDED_CAUSAL
attention_mask = attention_mask * attention_mask.new_ones(s_q, s_kv).tril(diagonal=0)
else:
outputs["attention_mask_type"] = AttnMaskType.PADDED
attention_mask = invert_mask(attention_mask).unsqueeze(1)
outputs["attention_mask"] = attention_mask
return outputs
@staticmethod
def attention(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
attention_mask_type: AttnMaskType = AttnMaskType.CUSTOM,
cu_seqlens_q: Optional[torch.Tensor] = None,
cu_seqlens_kv: Optional[torch.Tensor] = None,
max_seqlen_q: Optional[int] = None,
max_seqlen_kv: Optional[int] = None,
q_indices: Optional[torch.Tensor] = None,
kv_indices: Optional[torch.Tensor] = None,
dropout_p: float = 0.0,
scale: Optional[float] = None,
) -> torch.Tensor:
"""Flash Attention function. It supports 4 mask type.
1. custom mask: recv attention_mask
2. padded mask: recv attention_mask, attention_mask_type, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv, indices
3. causal mask: recv attention_mask, attention_mask_type
4. padded causal mask: recv attention_mask, attention_mask_type, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv, indices
Args:
q (torch.Tensor): Query tensor. Shape should be [B, N, Sq, D]
k (torch.Tensor): Key tensor. Shape should be [B, N, Skv, D]
v (torch.Tensor): Value tensor. Shape should be [B, N, Skv, D]
attention_mask (Optional[torch.Tensor], optional): Attention mask tensor. Shape should be [B, 1, Sq, Skv]. Defaults to None.
attention_mask_type (AttnMaskType, optional): Attention mask type. Defaults to AttnMaskType.CUSTOM.
cu_seqlens_q (Optional[torch.Tensor], optional): The cumulative sequence lengths
of the sequences in the batch, used to index into q.
Shape should be [B+1]. Defaults to None.
cu_seqlens_kv (Optional[torch.Tensor], optional): The cumulative sequence lengths
of the sequences in the batch, used to index into kv.
Shape should be [B+1]. Defaults to None.
max_seqlen_q (Optional[int], optional): Maximum query sequence length in the batch. Defaults to None.
max_seqlen_kv (Optional[int], optional): Maximum key/value sequence length in the batch. Defaults to None.
indices (Optional[torch.Tensor], optional): The indices of non-masked tokens from the flattened input sequence.
Shape should be [NUM_TOKENS]. Defaults to None.
dropout_p (float, optional): Dropout probability. Defaults to 0.0.
scale (Optional[float], optional): Scaling factor applied prior to softmax. Defaults to None.
Returns:
torch.Tensor: Output tensor. Shape should be [B, N, Sq, D]
"""
# known issue: sdpa does not support attention mask which contains whole row of masked tokens, which leads to nan
# this case is usaul when padding mask is used and self attention is performed
# thus, we don't use sdpa when padding mask is used
# sanity check
if attention_mask is not None:
assert torch.is_floating_point(attention_mask), "attention_mask should be a floating point tensor."
if attention_mask_type in (AttnMaskType.CUSTOM, AttnMaskType.CAUSAL):
assert (
cu_seqlens_q is None
and cu_seqlens_kv is None
and max_seqlen_q is None
and max_seqlen_kv is None
and q_indices is None
and kv_indices is None
)
if attention_mask_type == AttnMaskType.CUSTOM:
assert not torch.all(attention_mask != 0, dim=-1).any()
elif attention_mask_type in (
AttnMaskType.PADDED,
AttnMaskType.PADDED_CAUSAL,
):
assert (
cu_seqlens_q is not None
and cu_seqlens_kv is not None
and max_seqlen_q is not None
and max_seqlen_kv is not None
and q_indices is not None
and kv_indices is not None
)
else:
# if attention_mask is None, attention_mask_type should be the default value
assert attention_mask_type == AttnMaskType.CUSTOM
# kernel dispatch
mask_type = attention_mask_type if attention_mask is not None else None
attn_func = ColoAttention._dispatch_kernel(q.dtype, mask_type)
is_causal = attention_mask is not None and attention_mask_type in (
AttnMaskType.CAUSAL,
AttnMaskType.PADDED_CAUSAL,
)
return attn_func(
q,
k,
v,
dropout_p=dropout_p,
scale=scale,
attention_mask=attention_mask,
is_causal=is_causal,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_kv,
max_seqlen_q=max_seqlen_q,
max_seqlen_kv=max_seqlen_kv,
q_indices=q_indices,
kv_indices=kv_indices,
)
...@@ -3,6 +3,8 @@ from typing import Optional, Tuple ...@@ -3,6 +3,8 @@ from typing import Optional, Tuple
import torch import torch
import torch.nn as nn import torch.nn as nn
from colossalai.shardformer.layer import ColoAttention
def forward_fn(): def forward_fn():
def forward( def forward(
...@@ -62,8 +64,6 @@ def forward_fn(): ...@@ -62,8 +64,6 @@ def forward_fn():
def get_blip2_flash_attention_forward(): def get_blip2_flash_attention_forward():
from transformers.models.blip_2.modeling_blip_2 import Blip2Attention from transformers.models.blip_2.modeling_blip_2 import Blip2Attention
from colossalai.nn.layer.colo_attention import ColoAttention
def forward( def forward(
self: Blip2Attention, self: Blip2Attention,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
...@@ -71,16 +71,25 @@ def get_blip2_flash_attention_forward(): ...@@ -71,16 +71,25 @@ def get_blip2_flash_attention_forward():
output_attentions: Optional[bool] = False, output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel""" """Input shape: Batch x Time x Channel"""
assert head_mask is None, "head_mask is not supported in FlashAttention"
bsz, tgt_len, embed_dim = hidden_states.size() bsz, tgt_len, embed_dim = hidden_states.size()
mixed_qkv = self.qkv(hidden_states) mixed_qkv = self.qkv(hidden_states)
mixed_qkv = mixed_qkv.reshape(bsz, tgt_len, 3, self.num_heads, -1).permute(2, 0, 1, 3, 4) mixed_qkv = mixed_qkv.reshape(bsz, tgt_len, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
query_states, key_states, value_states = mixed_qkv[0], mixed_qkv[1], mixed_qkv[2] query_states, key_states, value_states = (
mixed_qkv[0],
mixed_qkv[1],
mixed_qkv[2],
)
attention = ColoAttention( dropout_p = self.dropout.p if self.training else 0.0
embed_dim=self.embed_dim, num_heads=self.num_heads, dropout=self.dropout.p, scale=self.scale context_layer = ColoAttention.attention(
query_states,
key_states,
value_states,
dropout_p=dropout_p,
scale=self.scale,
) )
context_layer = attention(query_states, key_states, value_states) context_layer = context_layer.permute(0, 2, 1, 3).reshape(bsz, tgt_len, self.embed_dim)
output = self.projection(context_layer) output = self.projection(context_layer)
outputs = (output, None) outputs = (output, None)
...@@ -93,7 +102,11 @@ def get_blip2_flash_attention_forward(): ...@@ -93,7 +102,11 @@ def get_blip2_flash_attention_forward():
def get_jit_fused_blip2_QFormer_self_output_forward(): def get_jit_fused_blip2_QFormer_self_output_forward():
from transformers.models.blip_2.modeling_blip_2 import Blip2QFormerSelfOutput from transformers.models.blip_2.modeling_blip_2 import Blip2QFormerSelfOutput
def forward(self: Blip2QFormerSelfOutput, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: def forward(
self: Blip2QFormerSelfOutput,
hidden_states: torch.Tensor,
input_tensor: torch.Tensor,
) -> torch.Tensor:
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
hidden_states = self.dropout_add(hidden_states, input_tensor, self.dropout.p, self.dropout.training) hidden_states = self.dropout_add(hidden_states, input_tensor, self.dropout.p, self.dropout.training)
hidden_states = self.LayerNorm(hidden_states) hidden_states = self.LayerNorm(hidden_states)
...@@ -105,7 +118,11 @@ def get_jit_fused_blip2_QFormer_self_output_forward(): ...@@ -105,7 +118,11 @@ def get_jit_fused_blip2_QFormer_self_output_forward():
def get_jit_fused_blip2_QFormer_output_forward(): def get_jit_fused_blip2_QFormer_output_forward():
from transformers.models.blip_2.modeling_blip_2 import Blip2QFormerOutput from transformers.models.blip_2.modeling_blip_2 import Blip2QFormerOutput
def forward(self: Blip2QFormerOutput, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: def forward(
self: Blip2QFormerOutput,
hidden_states: torch.Tensor,
input_tensor: torch.Tensor,
) -> torch.Tensor:
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
hidden_states = self.dropout_add(hidden_states, input_tensor, self.dropout.p, self.dropout.training) hidden_states = self.dropout_add(hidden_states, input_tensor, self.dropout.p, self.dropout.training)
hidden_states = self.LayerNorm(hidden_states) hidden_states = self.LayerNorm(hidden_states)
......
""" PyTorch ChatGLM model. """ """ PyTorch ChatGLM model. """
from typing import List, Optional, Tuple from typing import List, Optional, Tuple
import torch import torch
...@@ -9,63 +10,49 @@ from transformers.utils import logging ...@@ -9,63 +10,49 @@ from transformers.utils import logging
from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer import ShardConfig from colossalai.shardformer import ShardConfig
from colossalai.shardformer.layer import AttnMaskType, ColoAttention
from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward
from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMForConditionalGeneration, ChatGLMModel from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMForConditionalGeneration, ChatGLMModel
def get_flash_core_attention_forward(): def get_flash_core_attention_forward():
from colossalai.nn.layer.colo_attention import AttnMaskType, ColoAttention
from .chatglm2_6b.modeling_chatglm import CoreAttention from .chatglm2_6b.modeling_chatglm import CoreAttention
def forward(self: CoreAttention, query_layer, key_layer, value_layer, attention_mask): def forward(self: CoreAttention, query_layer, key_layer, value_layer, attention_mask):
pytorch_major_version = int(torch.__version__.split(".")[0]) query_layer, key_layer, value_layer = [k.permute(1, 2, 0, 3) for k in [query_layer, key_layer, value_layer]]
if pytorch_major_version >= 2: if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]:
query_layer, key_layer, value_layer = [k.permute(1, 2, 0, 3) for k in [query_layer, key_layer, value_layer]] attention_mask_type = AttnMaskType.CAUSAL
if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]: attn_bias = torch.zeros(
context_layer = torch.nn.functional.scaled_dot_product_attention( query_layer.shape[0],
query_layer, key_layer, value_layer, is_causal=True 1,
) query_layer.shape[2],
else: key_layer.shape[2],
if attention_mask is not None: dtype=query_layer.dtype,
attention_mask = ~attention_mask device=query_layer.device,
context_layer = torch.nn.functional.scaled_dot_product_attention(
query_layer, key_layer, value_layer, attention_mask
)
context_layer = context_layer.permute(2, 0, 1, 3)
new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
context_layer = context_layer.reshape(*new_context_layer_shape)
else:
# Raw attention scores
query_layer = query_layer.permute(1, 0, 2, 3).contiguous()
key_layer = key_layer.permute(1, 0, 2, 3).contiguous()
value_layer = value_layer.permute(1, 0, 2, 3).contiguous()
scale = 1.0 / self.norm_factor
if self.coeff is not None:
scale = scale * self.coeff
flash_attention_mask = None
attn_mask_type = None
if attention_mask is None:
attn_mask_type = AttnMaskType.causal
else:
flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous()
if not torch.all(flash_attention_mask):
attn_mask_type = AttnMaskType.paddedcausal
attention = ColoAttention(
embed_dim=self.hidden_size_per_partition,
num_heads=self.num_attention_heads_per_partition,
dropout=self.attention_dropout.p,
scale=scale,
) )
context_layer = attention( temp_mask = (
query_layer, key_layer, value_layer, attn_mask=flash_attention_mask, attn_mask_type=attn_mask_type torch.ones(query_layer.shape[2], key_layer.shape[2], dtype=torch.bool, device=query_layer.device)
.tril(diagonal=0)
.expand(query_layer.shape[0], 1, -1, -1)
) )
attn_bias.masked_fill_(temp_mask.logical_not(), torch.finfo(query_layer.dtype).min)
context_layer = context_layer.permute(1, 0, -1).contiguous() else:
attention_mask_type = AttnMaskType.CUSTOM
if attention_mask is not None:
attn_bias = torch.zeros_like(attention_mask, dtype=query_layer.dtype)
attn_bias.masked_fill_(attention_mask, torch.finfo(query_layer.dtype).min)
dropout_p = self.attention_dropout.p if self.training else 0.0
context_layer = ColoAttention.attention(
query_layer,
key_layer,
value_layer,
attention_mask=attn_bias,
attention_mask_type=attention_mask_type,
dropout_p=dropout_p,
)
context_layer = context_layer.permute(2, 0, 1, 3)
new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
context_layer = context_layer.reshape(*new_context_layer_shape)
return context_layer return context_layer
return forward return forward
...@@ -169,11 +156,17 @@ class ChatGLMPipelineForwards: ...@@ -169,11 +156,17 @@ class ChatGLMPipelineForwards:
if self.pre_seq_len is not None: if self.pre_seq_len is not None:
if past_key_values is None: if past_key_values is None:
past_key_values = self.get_prompt( past_key_values = self.get_prompt(
batch_size=batch_size, device=input_ids.device, dtype=inputs_embeds.dtype batch_size=batch_size,
device=input_ids.device,
dtype=inputs_embeds.dtype,
) )
if attention_mask is not None: if attention_mask is not None:
attention_mask = torch.cat( attention_mask = torch.cat(
[attention_mask.new_ones((batch_size, self.pre_seq_len)), attention_mask], dim=-1 [
attention_mask.new_ones((batch_size, self.pre_seq_len)),
attention_mask,
],
dim=-1,
) )
if full_attention_mask is None: if full_attention_mask is None:
if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1): if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1):
...@@ -200,7 +193,9 @@ class ChatGLMPipelineForwards: ...@@ -200,7 +193,9 @@ class ChatGLMPipelineForwards:
if shard_config.enable_sequence_parallelism: if shard_config.enable_sequence_parallelism:
hidden_states = split_forward_gather_backward( hidden_states = split_forward_gather_backward(
hidden_states, dim=0, process_group=shard_config.tensor_parallel_process_group hidden_states,
dim=0,
process_group=shard_config.tensor_parallel_process_group,
) )
for idx in range(start_idx, end_idx): for idx in range(start_idx, end_idx):
layer = self.encoder._get_layer(idx) layer = self.encoder._get_layer(idx)
...@@ -208,7 +203,12 @@ class ChatGLMPipelineForwards: ...@@ -208,7 +203,12 @@ class ChatGLMPipelineForwards:
all_hidden_states = all_hidden_states + (hidden_states,) all_hidden_states = all_hidden_states + (hidden_states,)
if self.encoder.gradient_checkpointing and self.encoder.training: if self.encoder.gradient_checkpointing and self.encoder.training:
layer_ret = torch.utils.checkpoint.checkpoint( layer_ret = torch.utils.checkpoint.checkpoint(
layer, hidden_states, attention_mask, rotary_pos_emb, past_key_values[idx], use_cache layer,
hidden_states,
attention_mask,
rotary_pos_emb,
past_key_values[idx],
use_cache,
) )
else: else:
layer_ret = layer( layer_ret = layer(
...@@ -224,7 +224,9 @@ class ChatGLMPipelineForwards: ...@@ -224,7 +224,9 @@ class ChatGLMPipelineForwards:
if shard_config.enable_sequence_parallelism: if shard_config.enable_sequence_parallelism:
hidden_states = gather_forward_split_backward( hidden_states = gather_forward_split_backward(
hidden_states, dim=0, process_group=shard_config.tensor_parallel_process_group hidden_states,
dim=0,
process_group=shard_config.tensor_parallel_process_group,
) )
if output_hidden_states: if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,) all_hidden_states = all_hidden_states + (hidden_states,)
...@@ -234,7 +236,14 @@ class ChatGLMPipelineForwards: ...@@ -234,7 +236,14 @@ class ChatGLMPipelineForwards:
hidden_states = self.encoder.final_layernorm(hidden_states) hidden_states = self.encoder.final_layernorm(hidden_states)
if not return_dict: if not return_dict:
return tuple( return tuple(
v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None v
for v in [
hidden_states,
presents,
all_hidden_states,
all_self_attentions,
]
if v is not None
) )
return BaseModelOutputWithPast( return BaseModelOutputWithPast(
last_hidden_state=hidden_states, last_hidden_state=hidden_states,
...@@ -368,7 +377,9 @@ def get_chatglm_sequence_parallel_forward_fn(shard_config: ShardConfig): ...@@ -368,7 +377,9 @@ def get_chatglm_sequence_parallel_forward_fn(shard_config: ShardConfig):
# Run encoder. # Run encoder.
# [seq_len, batch_size, hidden_size] -> [seq_len/TP_size, batch_size, hidden_size] # [seq_len, batch_size, hidden_size] -> [seq_len/TP_size, batch_size, hidden_size]
inputs_embeds = split_forward_gather_backward( inputs_embeds = split_forward_gather_backward(
inputs_embeds, dim=0, process_group=shard_config.tensor_parallel_process_group inputs_embeds,
dim=0,
process_group=shard_config.tensor_parallel_process_group,
) )
hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder( hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder(
inputs_embeds, inputs_embeds,
...@@ -380,7 +391,9 @@ def get_chatglm_sequence_parallel_forward_fn(shard_config: ShardConfig): ...@@ -380,7 +391,9 @@ def get_chatglm_sequence_parallel_forward_fn(shard_config: ShardConfig):
) )
hidden_states = gather_forward_split_backward( hidden_states = gather_forward_split_backward(
hidden_states, dim=0, process_group=shard_config.tensor_parallel_process_group hidden_states,
dim=0,
process_group=shard_config.tensor_parallel_process_group,
) )
if not return_dict: if not return_dict:
......
...@@ -21,12 +21,82 @@ from transformers.models.gpt2.modeling_gpt2 import ( ...@@ -21,12 +21,82 @@ from transformers.models.gpt2.modeling_gpt2 import (
from transformers.utils import logging from transformers.utils import logging
from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.layer import ColoAttention
from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward
from colossalai.shardformer.shard import ShardConfig from colossalai.shardformer.shard import ShardConfig
from ..layer import cross_entropy_1d from ..layer import cross_entropy_1d
from ..layer._operation import gather_forward_split_backward from ..layer._operation import gather_forward_split_backward
logger = logging.get_logger(__name__)
def _get_attention_mask(
self: GPT2Model,
shard_config: ShardConfig,
hidden_states: torch.Tensor,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]],
attention_mask: Optional[torch.FloatTensor],
encoder_hidden_states: Optional[torch.Tensor],
encoder_attention_mask: Optional[torch.FloatTensor],
) -> Tuple[Optional[Union[torch.Tensor, dict]], Optional[Union[torch.Tensor, dict]]]:
batch_size, seq_len = hidden_states.shape[:2]
# If a 2D or 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
if self.config.add_cross_attention and encoder_hidden_states is not None:
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
if shard_config.enable_flash_attention:
encoder_attention_mask = ColoAttention.prepare_attn_kwargs(
(encoder_batch_size, 1, seq_len, encoder_sequence_length),
dtype=hidden_states.dtype,
dtype2=encoder_hidden_states.dtype,
q_padding_mask=attention_mask,
kv_padding_mask=encoder_attention_mask,
)
else:
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
if encoder_attention_mask is None:
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=encoder_hidden_states.device)
encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask)
else:
if shard_config.enable_flash_attention:
encoder_attention_mask = {"attention_mask": None}
else:
encoder_attention_mask = None
# GPT2Attention mask.
past_key_values_length = 0
if past_key_values is not None and past_key_values[0] is not None:
past_key_values_length = past_key_values[0][0].shape[2]
if shard_config.enable_flash_attention:
if attention_mask is not None:
attention_mask = attention_mask.view(batch_size, -1)
attention_mask = ColoAttention.prepare_attn_kwargs(
(batch_size, 1, seq_len, seq_len + past_key_values_length),
hidden_states.dtype,
hidden_states.device,
attention_mask,
is_causal=True,
)
elif attention_mask is not None:
if batch_size <= 0:
raise ValueError("batch_size has to be defined and > 0")
attention_mask = attention_mask.view(batch_size, -1)
# We create a 3D attention mask from a 2D tensor mask.
# Sizes are [batch_size, 1, 1, to_seq_length]
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
# this attention mask is more simple than the triangular masking of causal attention
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
attention_mask = attention_mask[:, None, None, :]
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
# masked positions, this operation will create a tensor which is 0.0 for
# positions we want to attend and the dtype's smallest value for masked positions.
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
return attention_mask, encoder_attention_mask
class GPT2PipelineForwards: class GPT2PipelineForwards:
""" """
...@@ -83,10 +153,10 @@ class GPT2PipelineForwards: ...@@ -83,10 +153,10 @@ class GPT2PipelineForwards:
elif input_ids is not None: elif input_ids is not None:
input_shape = input_ids.size() input_shape = input_ids.size()
input_ids = input_ids.view(-1, input_shape[-1]) input_ids = input_ids.view(-1, input_shape[-1])
batch_size = input_ids.shape[0] input_ids.shape[0]
elif inputs_embeds is not None: elif inputs_embeds is not None:
input_shape = inputs_embeds.size()[:-1] input_shape = inputs_embeds.size()[:-1]
batch_size = inputs_embeds.shape[0] inputs_embeds.shape[0]
else: else:
raise ValueError("You have to specify either input_ids or inputs_embeds") raise ValueError("You have to specify either input_ids or inputs_embeds")
...@@ -99,38 +169,7 @@ class GPT2PipelineForwards: ...@@ -99,38 +169,7 @@ class GPT2PipelineForwards:
input_shape = hidden_states.size()[:-1] input_shape = hidden_states.size()[:-1]
device = hidden_states.device device = hidden_states.device
hidden_states = hidden_states.view((-1,) + hidden_states.shape[-2:]) hidden_states = hidden_states.view((-1,) + hidden_states.shape[-2:])
batch_size = hidden_states.shape[0] hidden_states.shape[0]
# GPT2Attention mask.
if attention_mask is not None:
if batch_size <= 0:
raise ValueError("batch_size has to be defined and > 0")
attention_mask = attention_mask.view(batch_size, -1)
# We create a 3D attention mask from a 2D tensor mask.
# Sizes are [batch_size, 1, 1, to_seq_length]
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
# this attention mask is more simple than the triangular masking of causal attention
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
attention_mask = attention_mask[:, None, None, :]
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
# masked positions, this operation will create a tensor which is 0.0 for
# positions we want to attend and the dtype's smallest value for masked positions.
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
# If a 2D or 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
if self.config.add_cross_attention and encoder_hidden_states is not None:
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
if encoder_attention_mask is None:
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask)
else:
encoder_attention_mask = None
# Prepare head mask if needed # Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head # 1.0 in head_mask indicate we keep the head
...@@ -156,6 +195,16 @@ class GPT2PipelineForwards: ...@@ -156,6 +195,16 @@ class GPT2PipelineForwards:
output_shape = input_shape + (hidden_states.size(-1),) output_shape = input_shape + (hidden_states.size(-1),)
attention_mask, encoder_attention_mask = _get_attention_mask(
self,
shard_config,
hidden_states,
past_key_values,
attention_mask,
encoder_hidden_states,
encoder_attention_mask,
)
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
if use_cache: if use_cache:
logger.warning_once( logger.warning_once(
...@@ -171,7 +220,9 @@ class GPT2PipelineForwards: ...@@ -171,7 +220,9 @@ class GPT2PipelineForwards:
# [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size] # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size]
if shard_config.enable_sequence_parallelism: if shard_config.enable_sequence_parallelism:
hidden_states = split_forward_gather_backward( hidden_states = split_forward_gather_backward(
hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group hidden_states,
dim=1,
process_group=shard_config.tensor_parallel_process_group,
) )
# Going through held blocks. # Going through held blocks.
...@@ -180,7 +231,7 @@ class GPT2PipelineForwards: ...@@ -180,7 +231,7 @@ class GPT2PipelineForwards:
block = self.h[i] block = self.h[i]
torch.cuda.set_device(hidden_states.device) torch.cuda.set_device(hidden_states.device)
# Ensure that attention_mask is always on the same device as hidden_states # Ensure that attention_mask is always on the same device as hidden_states
if attention_mask is not None: if torch.is_tensor(attention_mask):
attention_mask = attention_mask.to(hidden_states.device) attention_mask = attention_mask.to(hidden_states.device)
if isinstance(head_mask, torch.Tensor): if isinstance(head_mask, torch.Tensor):
head_mask = head_mask.to(hidden_states.device) head_mask = head_mask.to(hidden_states.device)
...@@ -229,7 +280,9 @@ class GPT2PipelineForwards: ...@@ -229,7 +280,9 @@ class GPT2PipelineForwards:
# When sequence parallelism done, gather the output tensor in forward and split it in backward # When sequence parallelism done, gather the output tensor in forward and split it in backward
if shard_config.enable_sequence_parallelism: if shard_config.enable_sequence_parallelism:
hidden_states = gather_forward_split_backward( hidden_states = gather_forward_split_backward(
hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group hidden_states,
dim=1,
process_group=shard_config.tensor_parallel_process_group,
) )
if stage_manager.is_last_stage(): if stage_manager.is_last_stage():
...@@ -245,7 +298,13 @@ class GPT2PipelineForwards: ...@@ -245,7 +298,13 @@ class GPT2PipelineForwards:
if not return_dict: if not return_dict:
return tuple( return tuple(
v v
for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions] for v in [
hidden_states,
presents,
all_hidden_states,
all_self_attentions,
all_cross_attentions,
]
if v is not None if v is not None
) )
...@@ -333,7 +392,9 @@ class GPT2PipelineForwards: ...@@ -333,7 +392,9 @@ class GPT2PipelineForwards:
shift_labels = shift_labels.view(-1) shift_labels = shift_labels.view(-1)
if shard_config.enable_tensor_parallelism and shard_config.parallel_output: if shard_config.enable_tensor_parallelism and shard_config.parallel_output:
loss = cross_entropy_1d( loss = cross_entropy_1d(
shift_logits, shift_labels, process_group=shard_config.tensor_parallel_process_group shift_logits,
shift_labels,
process_group=shard_config.tensor_parallel_process_group,
) )
else: else:
loss = loss_fct(shift_logits, shift_labels) loss = loss_fct(shift_logits, shift_labels)
...@@ -733,27 +794,18 @@ class GPT2PipelineForwards: ...@@ -733,27 +794,18 @@ class GPT2PipelineForwards:
def get_gpt2_flash_attention_forward(): def get_gpt2_flash_attention_forward():
from transformers.models.gpt2.modeling_gpt2 import GPT2Attention from transformers.models.gpt2.modeling_gpt2 import GPT2Attention
from colossalai.nn.layer.colo_attention import AttnMaskType, ColoAttention
def split_heads(tensor, num_heads, attn_head_size):
"""
Splits hidden_size dim into attn_head_size and num_heads
"""
new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
tensor = tensor.view(new_shape)
return tensor
def forward( def forward(
self: GPT2Attention, self: GPT2Attention,
hidden_states: Optional[Tuple[torch.FloatTensor]], hidden_states: Optional[Tuple[torch.FloatTensor]],
layer_past: Optional[Tuple[torch.Tensor]] = None, layer_past: Optional[Tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.FloatTensor] = None, attention_mask: Optional[dict] = None,
head_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[dict] = None,
use_cache: Optional[bool] = False, use_cache: Optional[bool] = False,
output_attentions: Optional[bool] = False, output_attentions: Optional[bool] = False,
) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]: ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
assert head_mask is None, "FlashAttention does not support head_mask"
if encoder_hidden_states is not None: if encoder_hidden_states is not None:
if not hasattr(self, "q_attn"): if not hasattr(self, "q_attn"):
raise ValueError( raise ValueError(
...@@ -766,10 +818,9 @@ def get_gpt2_flash_attention_forward(): ...@@ -766,10 +818,9 @@ def get_gpt2_flash_attention_forward():
attention_mask = encoder_attention_mask attention_mask = encoder_attention_mask
else: else:
query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2) query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)
query = self._split_heads(query, self.num_heads, self.head_dim)
query = split_heads(query, self.num_heads, self.head_dim) key = self._split_heads(key, self.num_heads, self.head_dim)
key = split_heads(key, self.num_heads, self.head_dim) value = self._split_heads(value, self.num_heads, self.head_dim)
value = split_heads(value, self.num_heads, self.head_dim)
if layer_past is not None: if layer_past is not None:
past_key, past_value = layer_past past_key, past_value = layer_past
...@@ -781,29 +832,14 @@ def get_gpt2_flash_attention_forward(): ...@@ -781,29 +832,14 @@ def get_gpt2_flash_attention_forward():
else: else:
present = None present = None
if not self.is_cross_attention: scale = 1.0
attn_mask_type = AttnMaskType.causal if self.scale_attn_weights:
flash_attention_mask = None scale /= value.size(-1) ** 0.5
if attention_mask != None:
flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous()
if not torch.all(flash_attention_mask):
if attn_mask_type == AttnMaskType.causal:
attn_mask_type == AttnMaskType.paddedcausal
else:
attn_mask_type = AttnMaskType.padding
scale = value.size(-1) ** -0.5
if self.scale_attn_by_inverse_layer_idx: if self.scale_attn_by_inverse_layer_idx:
scale = scale * (1 / float(self.layer_idx + 1)) scale /= float(self.layer_idx + 1)
dropout_p = self.attn_dropout.p if self.training else 0.0
# use coloattention attn_output = ColoAttention.attention(query, key, value, **attention_mask, dropout_p=dropout_p, scale=scale)
if not hasattr(self, "attention"): attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
self.attention = ColoAttention(
embed_dim=self.embed_dim, num_heads=self.num_heads, dropout=self.attn_dropout.p, scale=scale
)
attn_output = self.attention(query, key, value, attn_mask=flash_attention_mask, attn_mask_type=attn_mask_type)
attn_output = self.c_proj(attn_output) attn_output = self.c_proj(attn_output)
attn_output = self.resid_dropout(attn_output) attn_output = self.resid_dropout(attn_output)
outputs = (attn_output, present, None) outputs = (attn_output, present, None)
...@@ -813,9 +849,9 @@ def get_gpt2_flash_attention_forward(): ...@@ -813,9 +849,9 @@ def get_gpt2_flash_attention_forward():
return forward return forward
def gpt2_sequence_parallel_forward_fn(shard_config: ShardConfig): def get_gpt_model_forward_for_flash_attn(shard_config: ShardConfig):
def forward( def forward(
self, self: GPT2Model,
input_ids: Optional[torch.LongTensor] = None, input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
attention_mask: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.FloatTensor] = None,
...@@ -840,12 +876,13 @@ def gpt2_sequence_parallel_forward_fn(shard_config: ShardConfig): ...@@ -840,12 +876,13 @@ def gpt2_sequence_parallel_forward_fn(shard_config: ShardConfig):
if input_ids is not None and inputs_embeds is not None: if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None: elif input_ids is not None:
self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
input_shape = input_ids.size() input_shape = input_ids.size()
input_ids = input_ids.view(-1, input_shape[-1]) input_ids = input_ids.view(-1, input_shape[-1])
batch_size = input_ids.shape[0] input_ids.shape[0]
elif inputs_embeds is not None: elif inputs_embeds is not None:
input_shape = inputs_embeds.size()[:-1] input_shape = inputs_embeds.size()[:-1]
batch_size = inputs_embeds.shape[0] inputs_embeds.shape[0]
else: else:
raise ValueError("You have to specify either input_ids or inputs_embeds") raise ValueError("You have to specify either input_ids or inputs_embeds")
...@@ -862,39 +899,201 @@ def gpt2_sequence_parallel_forward_fn(shard_config: ShardConfig): ...@@ -862,39 +899,201 @@ def gpt2_sequence_parallel_forward_fn(shard_config: ShardConfig):
else: else:
past_length = past_key_values[0][0].size(-2) past_length = past_key_values[0][0].size(-2)
if position_ids is None: if position_ids is None:
position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device) position_ids = torch.arange(
past_length,
input_shape[-1] + past_length,
dtype=torch.long,
device=device,
)
position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1]) position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
# GPT2Attention mask. # Prepare head mask if needed
if attention_mask is not None: # 1.0 in head_mask indicate we keep the head
if batch_size <= 0: # attention_probs has shape bsz x n_heads x N x N
raise ValueError("batch_size has to be defined and > 0") # head_mask has shape n_layer x batch x n_heads x N x N
attention_mask = attention_mask.view(batch_size, -1) head_mask = self.get_head_mask(head_mask, self.config.n_layer)
# We create a 3D attention mask from a 2D tensor mask.
# Sizes are [batch_size, 1, 1, to_seq_length] if inputs_embeds is None:
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] inputs_embeds = self.wte(input_ids)
# this attention mask is more simple than the triangular masking of causal attention position_embeds = self.wpe(position_ids)
# used in OpenAI GPT, we just need to prepare the broadcast dimension here. hidden_states = inputs_embeds + position_embeds
attention_mask = attention_mask[:, None, None, :]
if token_type_ids is not None:
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for token_type_embeds = self.wte(token_type_ids)
# masked positions, this operation will create a tensor which is 0.0 for hidden_states = hidden_states + token_type_embeds
# positions we want to attend and the dtype's smallest value for masked positions.
# Since we are adding it to the raw scores before the softmax, this is hidden_states = self.drop(hidden_states)
# effectively the same as removing these entirely.
attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility output_shape = (-1,) + input_shape[1:] + (hidden_states.size(-1),)
attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
attention_mask, encoder_attention_mask = _get_attention_mask(
# If a 2D or 3D attention mask is provided for the cross-attention self,
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] shard_config,
if self.config.add_cross_attention and encoder_hidden_states is not None: hidden_states,
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() past_key_values,
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) attention_mask,
if encoder_attention_mask is None: encoder_hidden_states,
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) encoder_attention_mask,
encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask) )
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
presents = () if use_cache else None
all_self_attentions = () if output_attentions else None
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
all_hidden_states = () if output_hidden_states else None
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
# Model parallel
if self.model_parallel:
torch.cuda.set_device(hidden_states.device)
# Ensure layer_past is on same device as hidden_states (might not be correct)
if layer_past is not None:
layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past)
# Ensure that attention_mask is always on the same device as hidden_states
if torch.is_tensor(attention_mask):
attention_mask = attention_mask.to(hidden_states.device)
if isinstance(head_mask, torch.Tensor):
head_mask = head_mask.to(hidden_states.device)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, use_cache, output_attentions)
return custom_forward
outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
None,
attention_mask,
head_mask[i],
encoder_hidden_states,
encoder_attention_mask,
)
else:
outputs = block(
hidden_states,
layer_past=layer_past,
attention_mask=attention_mask,
head_mask=head_mask[i],
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
use_cache=use_cache,
output_attentions=output_attentions,
)
hidden_states = outputs[0]
if use_cache is True:
presents = presents + (outputs[1],)
if output_attentions:
all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
if self.config.add_cross_attention:
all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],)
# Model Parallel: If it's the last layer for that device, put things on the next device
if self.model_parallel:
for k, v in self.device_map.items():
if i == v[-1] and "cuda:" + str(k) != self.last_device:
hidden_states = hidden_states.to("cuda:" + str(k + 1))
hidden_states = self.ln_f(hidden_states)
hidden_states = hidden_states.view(output_shape)
# Add last hidden state
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if not return_dict:
return tuple(
v
for v in [
hidden_states,
presents,
all_hidden_states,
all_self_attentions,
all_cross_attentions,
]
if v is not None
)
return BaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states,
past_key_values=presents,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
cross_attentions=all_cross_attentions,
)
return forward
def gpt2_sequence_parallel_forward_fn(shard_config: ShardConfig):
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
input_shape = input_ids.size()
input_ids = input_ids.view(-1, input_shape[-1])
input_ids.shape[0]
elif inputs_embeds is not None:
input_shape = inputs_embeds.size()[:-1]
inputs_embeds.shape[0]
else: else:
encoder_attention_mask = None raise ValueError("You have to specify either input_ids or inputs_embeds")
device = input_ids.device if input_ids is not None else inputs_embeds.device
if token_type_ids is not None:
token_type_ids = token_type_ids.view(-1, input_shape[-1])
if position_ids is not None:
position_ids = position_ids.view(-1, input_shape[-1])
if past_key_values is None:
past_length = 0
past_key_values = tuple([None] * len(self.h))
else:
past_length = past_key_values[0][0].size(-2)
if position_ids is None:
position_ids = torch.arange(
past_length,
input_shape[-1] + past_length,
dtype=torch.long,
device=device,
)
position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
# Prepare head mask if needed # Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head # 1.0 in head_mask indicate we keep the head
...@@ -914,6 +1113,15 @@ def gpt2_sequence_parallel_forward_fn(shard_config: ShardConfig): ...@@ -914,6 +1113,15 @@ def gpt2_sequence_parallel_forward_fn(shard_config: ShardConfig):
hidden_states = self.drop(hidden_states) hidden_states = self.drop(hidden_states)
output_shape = input_shape + (hidden_states.size(-1),) output_shape = input_shape + (hidden_states.size(-1),)
attention_mask, encoder_attention_mask = _get_attention_mask(
self,
shard_config,
hidden_states,
past_key_values,
attention_mask,
encoder_hidden_states,
encoder_attention_mask,
)
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
if use_cache: if use_cache:
...@@ -931,7 +1139,9 @@ def gpt2_sequence_parallel_forward_fn(shard_config: ShardConfig): ...@@ -931,7 +1139,9 @@ def gpt2_sequence_parallel_forward_fn(shard_config: ShardConfig):
# split the input tensor along sequence dimension # split the input tensor along sequence dimension
# [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size] # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size]
hidden_states = split_forward_gather_backward( hidden_states = split_forward_gather_backward(
hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group hidden_states,
dim=1,
process_group=shard_config.tensor_parallel_process_group,
) )
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
...@@ -942,7 +1152,7 @@ def gpt2_sequence_parallel_forward_fn(shard_config: ShardConfig): ...@@ -942,7 +1152,7 @@ def gpt2_sequence_parallel_forward_fn(shard_config: ShardConfig):
if layer_past is not None: if layer_past is not None:
layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past) layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past)
# Ensure that attention_mask is always on the same device as hidden_states # Ensure that attention_mask is always on the same device as hidden_states
if attention_mask is not None: if torch.is_tensor(attention_mask):
attention_mask = attention_mask.to(hidden_states.device) attention_mask = attention_mask.to(hidden_states.device)
if isinstance(head_mask, torch.Tensor): if isinstance(head_mask, torch.Tensor):
head_mask = head_mask.to(hidden_states.device) head_mask = head_mask.to(hidden_states.device)
...@@ -996,7 +1206,9 @@ def gpt2_sequence_parallel_forward_fn(shard_config: ShardConfig): ...@@ -996,7 +1206,9 @@ def gpt2_sequence_parallel_forward_fn(shard_config: ShardConfig):
# When sequence parallelism done, gather the output tensor in forward and split it in backward # When sequence parallelism done, gather the output tensor in forward and split it in backward
hidden_states = gather_forward_split_backward( hidden_states = gather_forward_split_backward(
hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group hidden_states,
dim=1,
process_group=shard_config.tensor_parallel_process_group,
) )
hidden_states = self.ln_f(hidden_states) hidden_states = self.ln_f(hidden_states)
...@@ -1008,7 +1220,13 @@ def gpt2_sequence_parallel_forward_fn(shard_config: ShardConfig): ...@@ -1008,7 +1220,13 @@ def gpt2_sequence_parallel_forward_fn(shard_config: ShardConfig):
if not return_dict: if not return_dict:
return tuple( return tuple(
v v
for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions] for v in [
hidden_states,
presents,
all_hidden_states,
all_self_attentions,
all_cross_attentions,
]
if v is not None if v is not None
) )
......
...@@ -19,9 +19,54 @@ from transformers.models.gptj.modeling_gptj import ( ...@@ -19,9 +19,54 @@ from transformers.models.gptj.modeling_gptj import (
from transformers.utils import is_torch_fx_proxy, logging from transformers.utils import is_torch_fx_proxy, logging
from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.layer import ColoAttention
from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward
from colossalai.shardformer.shard import ShardConfig from colossalai.shardformer.shard import ShardConfig
logger = logging.get_logger(__name__)
def _get_attention_mask(
self: GPTJModel,
shard_config: ShardConfig,
hidden_states: torch.Tensor,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]],
attention_mask: Optional[torch.FloatTensor],
) -> Optional[Union[torch.Tensor, dict]]:
batch_size, seq_len = hidden_states.shape[:2]
past_key_values_length = 0
if past_key_values is not None and past_key_values[0] is not None:
past_key_values_length = past_key_values[0][0].shape[2]
if shard_config.enable_flash_attention:
if attention_mask is not None:
attention_mask = attention_mask.view(batch_size, -1)
attention_mask = ColoAttention.prepare_attn_kwargs(
(batch_size, 1, seq_len, seq_len + past_key_values_length),
hidden_states.dtype,
hidden_states.device,
attention_mask,
is_causal=True,
)
elif attention_mask is not None:
if batch_size <= 0:
raise ValueError("batch_size has to be defined and > 0")
attention_mask = attention_mask.view(batch_size, -1)
# We create a 3D attention mask from a 2D tensor mask.
# Sizes are [batch_size, 1, 1, to_seq_length]
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
# this attention mask is more simple than the triangular masking of causal attention
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
attention_mask = attention_mask[:, None, None, :]
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
# masked positions, this operation will create a tensor which is 0.0 for
# positions we want to attend and the dtype's smallest value for masked positions.
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
return attention_mask
class GPTJPipelineForwards: class GPTJPipelineForwards:
""" """
...@@ -96,26 +141,6 @@ class GPTJPipelineForwards: ...@@ -96,26 +141,6 @@ class GPTJPipelineForwards:
batch_size, seq_length = input_shape[0], input_shape[1] batch_size, seq_length = input_shape[0], input_shape[1]
device = hidden_states.device device = hidden_states.device
# Attention mask.
if attention_mask is not None:
if batch_size <= 0:
raise ValueError("batch_size has to be defined and > 0")
attention_mask = attention_mask.view(batch_size, -1)
# We create a 3D attention mask from a 2D tensor mask.
# Sizes are [batch_size, 1, 1, to_seq_length]
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
# this attention mask is more simple than the triangular masking of causal attention
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
attention_mask = attention_mask[:, None, None, :]
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
# masked positions, this operation will create a tensor which is 0.0 for
# positions we want to attend and the dtype's smallest value for masked positions.
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
# Prepare head mask if needed # Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head # 1.0 in head_mask indicate we keep the head
# attention_probs has shape bsz x num_attention_heads x N x N # attention_probs has shape bsz x num_attention_heads x N x N
...@@ -139,6 +164,8 @@ class GPTJPipelineForwards: ...@@ -139,6 +164,8 @@ class GPTJPipelineForwards:
output_shape = input_shape + (hidden_states.size(-1),) output_shape = input_shape + (hidden_states.size(-1),)
attention_mask = _get_attention_mask(self, shard_config, hidden_states, past_key_values, attention_mask)
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
if use_cache: if use_cache:
logger.warning_once( logger.warning_once(
...@@ -154,7 +181,9 @@ class GPTJPipelineForwards: ...@@ -154,7 +181,9 @@ class GPTJPipelineForwards:
# [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size] # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size]
if shard_config.enable_sequence_parallelism: if shard_config.enable_sequence_parallelism:
hidden_states = split_forward_gather_backward( hidden_states = split_forward_gather_backward(
hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group hidden_states,
dim=1,
process_group=shard_config.tensor_parallel_process_group,
) )
# Going through held blocks. # Going through held blocks.
...@@ -209,7 +238,9 @@ class GPTJPipelineForwards: ...@@ -209,7 +238,9 @@ class GPTJPipelineForwards:
# When sequence parallelism done, gather the output tensor in forward and split it in backward # When sequence parallelism done, gather the output tensor in forward and split it in backward
if shard_config.enable_sequence_parallelism: if shard_config.enable_sequence_parallelism:
hidden_states = gather_forward_split_backward( hidden_states = gather_forward_split_backward(
hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group hidden_states,
dim=1,
process_group=shard_config.tensor_parallel_process_group,
) )
if stage_manager.is_last_stage(): if stage_manager.is_last_stage():
...@@ -223,7 +254,14 @@ class GPTJPipelineForwards: ...@@ -223,7 +254,14 @@ class GPTJPipelineForwards:
if stage_manager.is_last_stage(): if stage_manager.is_last_stage():
if not return_dict: if not return_dict:
return tuple( return tuple(
v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None v
for v in [
hidden_states,
presents,
all_hidden_states,
all_self_attentions,
]
if v is not None
) )
return BaseModelOutputWithPast( return BaseModelOutputWithPast(
...@@ -530,24 +568,11 @@ class GPTJPipelineForwards: ...@@ -530,24 +568,11 @@ class GPTJPipelineForwards:
def get_gptj_flash_attention_forward(): def get_gptj_flash_attention_forward():
from transformers.models.gptj.modeling_gptj import GPTJAttention from transformers.models.gptj.modeling_gptj import GPTJAttention
from colossalai.nn.layer.colo_attention import AttnMaskType, ColoAttention
def split_heads(tensor, num_attention_heads, attn_head_size, rotary):
"""
Splits hidden dim into attn_head_size and num_attention_heads
"""
new_shape = tensor.size()[:-1] + (num_attention_heads, attn_head_size)
tensor = tensor.view(new_shape)
if rotary or len(tensor.shape) in [4, 5]:
return tensor
else:
raise ValueError(f"Input tensor rank should be one of [4, 5], but is: {len(tensor.shape)}")
def forward( def forward(
self: GPTJAttention, self: GPTJAttention,
hidden_states: torch.FloatTensor, hidden_states: torch.FloatTensor,
layer_past: Optional[Tuple[torch.Tensor]] = None, layer_past: Optional[Tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.FloatTensor] = None, attention_mask: Optional[dict] = None,
position_ids: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = False, use_cache: Optional[bool] = False,
...@@ -556,13 +581,14 @@ def get_gptj_flash_attention_forward(): ...@@ -556,13 +581,14 @@ def get_gptj_flash_attention_forward():
Tuple[torch.Tensor, Tuple[torch.Tensor]], Tuple[torch.Tensor, Tuple[torch.Tensor]],
Optional[Tuple[torch.Tensor, Tuple[torch.Tensor], Tuple[torch.Tensor, ...]]], Optional[Tuple[torch.Tensor, Tuple[torch.Tensor], Tuple[torch.Tensor, ...]]],
]: ]:
assert head_mask is None, "head_mask is not supported for FlashAttention"
query = self.q_proj(hidden_states) query = self.q_proj(hidden_states)
key = self.k_proj(hidden_states) key = self.k_proj(hidden_states)
value = self.v_proj(hidden_states) value = self.v_proj(hidden_states)
query = split_heads(query, self.num_attention_heads, self.head_dim, True) query = self._split_heads(query, self.num_attention_heads, self.head_dim, True)
key = split_heads(key, self.num_attention_heads, self.head_dim, True) key = self._split_heads(key, self.num_attention_heads, self.head_dim, True)
value = split_heads(value, self.num_attention_heads, self.head_dim, False) value = self._split_heads(value, self.num_attention_heads, self.head_dim, False)
if is_torch_fx_proxy(position_ids) or torch.jit.is_tracing(): if is_torch_fx_proxy(position_ids) or torch.jit.is_tracing():
# The logic to conditionally copy to GPU could not be traced, so we do this # The logic to conditionally copy to GPU could not be traced, so we do this
...@@ -591,46 +617,202 @@ def get_gptj_flash_attention_forward(): ...@@ -591,46 +617,202 @@ def get_gptj_flash_attention_forward():
key = apply_rotary_pos_emb(key, sin, cos) key = apply_rotary_pos_emb(key, sin, cos)
query = apply_rotary_pos_emb(query, sin, cos) query = apply_rotary_pos_emb(query, sin, cos)
# key = key.permute(0, 2, 1, 3) key = key.permute(0, 2, 1, 3)
# query = query.permute(0, 2, 1, 3) query = query.permute(0, 2, 1, 3)
key = key.to(dtype=value.dtype) # fp16 compatibility
query = query.to(dtype=value.dtype)
if layer_past is not None: if layer_past is not None:
past_key = layer_past[0] past_key = layer_past[0]
past_value = layer_past[1] past_value = layer_past[1]
key = torch.cat((past_key, key), dim=1) key = torch.cat((past_key, key), dim=-2)
value = torch.cat((past_value, value), dim=1) value = torch.cat((past_value, value), dim=-2)
if use_cache is True: if use_cache is True:
present = (key, value) present = (key, value)
else: else:
present = None present = None
# use AttnMaskType and ColoAttention dropout_p = self.attn_dropout.p if self.training else 0.0
attn_mask_type = AttnMaskType.causal attn_output = ColoAttention.attention(query, key, value, **attention_mask, dropout_p=dropout_p)
flash_attention_mask = None attn_output = self._merge_heads(attn_output, self.num_attention_heads, self.head_dim)
if attention_mask != None: attn_output = self.out_proj(attn_output)
if attn_mask_type == AttnMaskType.causal: attn_output = self.resid_dropout(attn_output)
attn_mask_type == AttnMaskType.paddedcausal outputs = (attn_output, present, None)
else:
attn_mask_type = AttnMaskType.padding
flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous()
# use coloattention return outputs # a, present, (attentions)
scale = value.size(-1) ** -0.5
return forward
attention = ColoAttention(
embed_dim=self.embed_dim, num_heads=self.num_attention_heads, dropout=self.attn_dropout.p, scale=scale def gptj_model_forward_for_flash_attention(shard_config: ShardConfig):
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
) )
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
input_shape = input_ids.size()
input_ids = input_ids.view(-1, input_shape[-1])
input_ids.shape[0]
elif inputs_embeds is not None:
input_shape = inputs_embeds.size()[:-1]
inputs_embeds.shape[0]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
attn_output = attention(query, key, value, attn_mask=flash_attention_mask, attn_mask_type=attn_mask_type) device = input_ids.device if input_ids is not None else inputs_embeds.device
attn_output = self.out_proj(attn_output) if token_type_ids is not None:
attn_output = self.resid_dropout(attn_output) token_type_ids = token_type_ids.view(-1, input_shape[-1])
outputs = (attn_output, present, None)
return outputs # a, present, (attentions) if position_ids is not None:
position_ids = position_ids.view(-1, input_shape[-1]).long()
if past_key_values is None:
past_length = 0
past_key_values = tuple([None] * len(self.h))
else:
past_length = past_key_values[0][0].size(-2)
if position_ids is None:
position_ids = torch.arange(
past_length,
input_shape[-1] + past_length,
dtype=torch.long,
device=device,
)
position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape bsz x num_attention_heads x N x N
# head_mask has shape n_layer x batch x num_attention_heads x N x N
head_mask = self.get_head_mask(head_mask, self.config.n_layer)
if inputs_embeds is None:
inputs_embeds = self.wte(input_ids)
hidden_states = inputs_embeds
if token_type_ids is not None:
token_type_embeds = self.wte(token_type_ids)
hidden_states = hidden_states + token_type_embeds
hidden_states = self.drop(hidden_states)
attention_mask = _get_attention_mask(self, shard_config, hidden_states, past_key_values, attention_mask)
output_shape = (-1,) + input_shape[1:] + (hidden_states.size(-1),)
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
presents = () if use_cache else None
all_self_attentions = () if output_attentions else None
all_hidden_states = () if output_hidden_states else None
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
# Model parallel
if self.model_parallel:
torch.cuda.set_device(hidden_states.device)
# Ensure layer_past is on same device as hidden_states (might not be correct)
if layer_past is not None:
layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past)
# Ensure that attention_mask is always on the same device as hidden_states
if attention_mask is not None:
attention_mask = attention_mask.to(hidden_states.device)
if isinstance(head_mask, torch.Tensor):
head_mask = head_mask.to(hidden_states.device)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, use_cache, output_attentions)
return custom_forward
outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
None,
attention_mask,
position_ids,
head_mask[i],
)
else:
outputs = block(
hidden_states=hidden_states,
layer_past=layer_past,
attention_mask=attention_mask,
position_ids=position_ids,
head_mask=head_mask[i],
use_cache=use_cache,
output_attentions=output_attentions,
)
hidden_states = outputs[0]
if use_cache is True:
presents = presents + (outputs[1],)
if output_attentions:
all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
# Model Parallel: If it's the last layer for that device, put things on the next device
if self.model_parallel:
for k, v in self.device_map.items():
if i == v[-1] and "cuda:" + str(k) != self.last_device:
hidden_states = hidden_states.to("cuda:" + str(k + 1))
hidden_states = self.ln_f(hidden_states)
hidden_states = hidden_states.view(output_shape)
# Add last hidden state
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if not return_dict:
return tuple(
v
for v in [
hidden_states,
presents,
all_hidden_states,
all_self_attentions,
]
if v is not None
)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=presents,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
)
return forward return forward
...@@ -662,10 +844,10 @@ def gptj_sequence_parallel_forward_fn(shard_config: ShardConfig): ...@@ -662,10 +844,10 @@ def gptj_sequence_parallel_forward_fn(shard_config: ShardConfig):
elif input_ids is not None: elif input_ids is not None:
input_shape = input_ids.size() input_shape = input_ids.size()
input_ids = input_ids.view(-1, input_shape[-1]) input_ids = input_ids.view(-1, input_shape[-1])
batch_size = input_ids.shape[0] input_ids.shape[0]
elif inputs_embeds is not None: elif inputs_embeds is not None:
input_shape = inputs_embeds.size()[:-1] input_shape = inputs_embeds.size()[:-1]
batch_size = inputs_embeds.shape[0] inputs_embeds.shape[0]
else: else:
raise ValueError("You have to specify either input_ids or inputs_embeds") raise ValueError("You have to specify either input_ids or inputs_embeds")
...@@ -684,29 +866,14 @@ def gptj_sequence_parallel_forward_fn(shard_config: ShardConfig): ...@@ -684,29 +866,14 @@ def gptj_sequence_parallel_forward_fn(shard_config: ShardConfig):
past_length = past_key_values[0][0].size(-2) past_length = past_key_values[0][0].size(-2)
if position_ids is None: if position_ids is None:
position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device) position_ids = torch.arange(
past_length,
input_shape[-1] + past_length,
dtype=torch.long,
device=device,
)
position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1]) position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
# Attention mask.
if attention_mask is not None:
if batch_size <= 0:
raise ValueError("batch_size has to be defined and > 0")
attention_mask = attention_mask.view(batch_size, -1)
# We create a 3D attention mask from a 2D tensor mask.
# Sizes are [batch_size, 1, 1, to_seq_length]
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
# this attention mask is more simple than the triangular masking of causal attention
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
attention_mask = attention_mask[:, None, None, :]
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
# masked positions, this operation will create a tensor which is 0.0 for
# positions we want to attend and the dtype's smallest value for masked positions.
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
# Prepare head mask if needed # Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head # 1.0 in head_mask indicate we keep the head
# attention_probs has shape bsz x num_attention_heads x N x N # attention_probs has shape bsz x num_attention_heads x N x N
...@@ -725,6 +892,7 @@ def gptj_sequence_parallel_forward_fn(shard_config: ShardConfig): ...@@ -725,6 +892,7 @@ def gptj_sequence_parallel_forward_fn(shard_config: ShardConfig):
hidden_states = self.drop(hidden_states) hidden_states = self.drop(hidden_states)
output_shape = input_shape + (hidden_states.size(-1),) output_shape = input_shape + (hidden_states.size(-1),)
attention_mask = _get_attention_mask(self, shard_config, hidden_states, past_key_values, attention_mask)
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
if use_cache: if use_cache:
...@@ -740,7 +908,9 @@ def gptj_sequence_parallel_forward_fn(shard_config: ShardConfig): ...@@ -740,7 +908,9 @@ def gptj_sequence_parallel_forward_fn(shard_config: ShardConfig):
# split the input tensor along sequence dimension # split the input tensor along sequence dimension
# [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size] # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size]
hidden_states = split_forward_gather_backward( hidden_states = split_forward_gather_backward(
hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group hidden_states,
dim=1,
process_group=shard_config.tensor_parallel_process_group,
) )
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
...@@ -801,7 +971,9 @@ def gptj_sequence_parallel_forward_fn(shard_config: ShardConfig): ...@@ -801,7 +971,9 @@ def gptj_sequence_parallel_forward_fn(shard_config: ShardConfig):
# When sequence parallelism done, gather the output tensor in forward and split it in backward # When sequence parallelism done, gather the output tensor in forward and split it in backward
hidden_states = gather_forward_split_backward( hidden_states = gather_forward_split_backward(
hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group hidden_states,
dim=1,
process_group=shard_config.tensor_parallel_process_group,
) )
hidden_states = self.ln_f(hidden_states) hidden_states = self.ln_f(hidden_states)
...@@ -812,7 +984,16 @@ def gptj_sequence_parallel_forward_fn(shard_config: ShardConfig): ...@@ -812,7 +984,16 @@ def gptj_sequence_parallel_forward_fn(shard_config: ShardConfig):
all_hidden_states = all_hidden_states + (hidden_states,) all_hidden_states = all_hidden_states + (hidden_states,)
if not return_dict: if not return_dict:
return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None) return tuple(
v
for v in [
hidden_states,
presents,
all_hidden_states,
all_self_attentions,
]
if v is not None
)
return BaseModelOutputWithPast( return BaseModelOutputWithPast(
last_hidden_state=hidden_states, last_hidden_state=hidden_states,
......
...@@ -15,7 +15,9 @@ from transformers.utils import logging ...@@ -15,7 +15,9 @@ from transformers.utils import logging
from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.shard import ShardConfig from colossalai.shardformer.shard import ShardConfig
from ..layer import cross_entropy_1d
from ..layer import ColoAttention, cross_entropy_1d
try: try:
from transformers.models.llama.modeling_llama import _prepare_4d_causal_attention_mask from transformers.models.llama.modeling_llama import _prepare_4d_causal_attention_mask
...@@ -105,18 +107,25 @@ class LlamaPipelineForwards: ...@@ -105,18 +107,25 @@ class LlamaPipelineForwards:
# embed positions, for the first stage, hidden_states is the input embeddings, # embed positions, for the first stage, hidden_states is the input embeddings,
# for the other stages, hidden_states is the output of the previous stage # for the other stages, hidden_states is the output of the previous stage
if attention_mask is None: if shard_config.enable_flash_attention:
attention_mask = torch.ones( # in this case, attention_mask is a dict rather than a tensor
(batch_size, seq_length_with_past), dtype=torch.bool, device=hidden_states.device mask_shape = (batch_size, 1, seq_length_with_past, seq_length_with_past)
) attention_mask = ColoAttention.prepare_attn_kwargs(
if LATEST_VERSION: mask_shape, hidden_states.dtype, hidden_states.device, q_padding_mask=attention_mask, is_causal=True
attention_mask = _prepare_4d_causal_attention_mask(
attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
) )
else: else:
attention_mask = self._prepare_decoder_attention_mask( if attention_mask is None:
attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length attention_mask = torch.ones(
) (batch_size, seq_length_with_past), dtype=torch.bool, device=hidden_states.device
)
if LATEST_VERSION:
attention_mask = _prepare_4d_causal_attention_mask(
attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
)
else:
attention_mask = self._prepare_decoder_attention_mask(
attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
)
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
if use_cache: if use_cache:
...@@ -262,6 +271,7 @@ class LlamaPipelineForwards: ...@@ -262,6 +271,7 @@ class LlamaPipelineForwards:
stage_manager=stage_manager, stage_manager=stage_manager,
hidden_states=hidden_states, hidden_states=hidden_states,
stage_index=stage_index, stage_index=stage_index,
shard_config=shard_config,
) )
past_key_values = None past_key_values = None
...@@ -352,6 +362,7 @@ class LlamaPipelineForwards: ...@@ -352,6 +362,7 @@ class LlamaPipelineForwards:
stage_manager=stage_manager, stage_manager=stage_manager,
hidden_states=hidden_states, hidden_states=hidden_states,
stage_index=stage_index, stage_index=stage_index,
shard_config=shard_config,
) )
if input_ids is not None: if input_ids is not None:
...@@ -420,8 +431,6 @@ class LlamaPipelineForwards: ...@@ -420,8 +431,6 @@ class LlamaPipelineForwards:
def get_llama_flash_attention_forward(shard_config: ShardConfig): def get_llama_flash_attention_forward(shard_config: ShardConfig):
from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb
from colossalai.nn.layer.colo_attention import AttnMaskType, ColoAttention
llama_version = 2 llama_version = 2
try: try:
from transformers.models.llama.modeling_llama import repeat_kv from transformers.models.llama.modeling_llama import repeat_kv
...@@ -432,7 +441,7 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig): ...@@ -432,7 +441,7 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig):
def forward( def forward(
self: LlamaAttention, self: LlamaAttention,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None, attention_mask: Optional[dict] = None,
position_ids: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False, output_attentions: bool = False,
...@@ -466,31 +475,10 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig): ...@@ -466,31 +475,10 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig):
key_states = repeat_kv(key_states, self.num_key_value_groups) key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups)
me_input_shape = (bsz, q_len, self.num_heads, self.head_dim) assert isinstance(attention_mask, dict), "Flash Attention Error: attention_mask should be a dict."
query_states = query_states.transpose(1, 2).contiguous().view(*me_input_shape) attn_output = ColoAttention.attention(query_states, key_states, value_states, **attention_mask)
key_states = key_states.transpose(1, 2).contiguous().view(*me_input_shape) attn_output = attn_output.transpose(1, 2).contiguous()
value_states = value_states.transpose(1, 2).contiguous().view(*me_input_shape) attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
flash_attention_mask = None
attn_mask_type = AttnMaskType.causal
if not getattr(shard_config, "causal_lm", False) and attention_mask != None:
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
)
flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous()
attn_mask_type = AttnMaskType.paddedcausal
if not hasattr(self, "attention"):
self.attention = ColoAttention(embed_dim=self.hidden_size, num_heads=self.num_heads)
attn_output = self.attention(
query_states,
key_states,
value_states,
attn_mask=flash_attention_mask,
attn_mask_type=attn_mask_type,
origin_attn_mask=attention_mask,
)
attn_output = self.o_proj(attn_output) attn_output = self.o_proj(attn_output)
...@@ -499,6 +487,137 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig): ...@@ -499,6 +487,137 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig):
return forward return forward
def get_llama_model_forward_for_flash_attn(shard_config: ShardConfig):
logger = logging.get_logger(__name__)
assert shard_config.enable_flash_attention, "Flash Attention is not enabled."
def forward(
self: LlamaModel,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# retrieve input_ids and inputs_embeds
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
elif input_ids is not None:
batch_size, seq_length = input_ids.shape
elif inputs_embeds is not None:
batch_size, seq_length, _ = inputs_embeds.shape
else:
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
seq_length_with_past = seq_length
past_key_values_length = 0
if past_key_values is not None:
past_key_values_length = past_key_values[0][0].shape[2]
seq_length_with_past = seq_length_with_past + past_key_values_length
if position_ids is None:
device = input_ids.device if input_ids is not None else inputs_embeds.device
position_ids = torch.arange(
past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
)
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
else:
position_ids = position_ids.view(-1, seq_length).long()
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
# embed positions
hidden_states = inputs_embeds
# in this case, attention_mask is a dict rather than a tensor
mask_shape = (batch_size, 1, seq_length_with_past, seq_length_with_past)
attention_mask = ColoAttention.prepare_attn_kwargs(
mask_shape, hidden_states.dtype, hidden_states.device, q_padding_mask=attention_mask, is_causal=True
)
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = () if use_cache else None
for idx, decoder_layer in enumerate(self.layers):
if output_hidden_states:
all_hidden_states += (hidden_states,)
past_key_value = past_key_values[idx] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, past_key_value, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(decoder_layer),
hidden_states,
attention_mask,
position_ids,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
if output_attentions:
all_self_attns += (layer_outputs[1],)
hidden_states = self.norm(hidden_states)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
next_cache = next_decoder_cache if use_cache else None
if not return_dict:
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
return forward
def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig): def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
from transformers import LlamaForCausalLM from transformers import LlamaForCausalLM
......
...@@ -18,6 +18,37 @@ from transformers.models.opt.modeling_opt import ( ...@@ -18,6 +18,37 @@ from transformers.models.opt.modeling_opt import (
from transformers.utils import logging from transformers.utils import logging
from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.layer import ColoAttention
from colossalai.shardformer.shard import ShardConfig
logger = logging.get_logger(__name__)
def _get_attention_mask(
self: OPTModel,
shard_config: ShardConfig,
hidden_states: torch.Tensor,
past_key_values_length: int,
attention_mask: Optional[torch.FloatTensor],
):
batch_size, seq_length = hidden_states.shape[:2]
mask_seq_length = past_key_values_length + seq_length
if shard_config.enable_flash_attention:
attention_mask = ColoAttention.prepare_attn_kwargs(
(batch_size, 1, seq_length, mask_seq_length),
hidden_states.dtype,
hidden_states.device,
attention_mask,
is_causal=True,
)
else:
attention_mask = self.decoder._prepare_decoder_attention_mask(
attention_mask,
(batch_size, seq_length),
hidden_states,
past_key_values_length,
)
return attention_mask
class OPTPipelineForwards: class OPTPipelineForwards:
...@@ -26,46 +57,6 @@ class OPTPipelineForwards: ...@@ -26,46 +57,6 @@ class OPTPipelineForwards:
under pipeline setting. under pipeline setting.
""" """
@staticmethod
def _prepare_decoder_attention_mask(attention_mask, input_shape, _dtype, device, past_key_values_length):
# create causal mask
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
from transformers.models.opt.modeling_opt import _make_causal_mask
combined_attention_mask = None
if input_shape[-1] > 1:
combined_attention_mask = _make_causal_mask(
input_shape,
_dtype,
device,
past_key_values_length=past_key_values_length,
)
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
expanded_attn_mask = OPTPipelineForwards._expand_mask(attention_mask, _dtype, tgt_len=input_shape[-1]).to(
device
)
combined_attention_mask = (
expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
)
return combined_attention_mask
@staticmethod
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
"""
bsz, src_len = mask.size()
tgt_len = tgt_len if tgt_len is not None else src_len
expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
inverted_mask = 1.0 - expanded_mask
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
@staticmethod @staticmethod
def opt_model_forward( def opt_model_forward(
self: OPTModel, self: OPTModel,
...@@ -81,6 +72,7 @@ class OPTPipelineForwards: ...@@ -81,6 +72,7 @@ class OPTPipelineForwards:
stage_manager: Optional[PipelineStageManager] = None, stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None, hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None, stage_index: Optional[List[int]] = None,
shard_config: Optional[ShardConfig] = None,
) -> Union[Tuple, BaseModelOutputWithPast]: ) -> Union[Tuple, BaseModelOutputWithPast]:
""" """
This forward method is modified based on transformers.models.opt.modeling_opt.OPTModel.forward This forward method is modified based on transformers.models.opt.modeling_opt.OPTModel.forward
...@@ -119,7 +111,7 @@ class OPTPipelineForwards: ...@@ -119,7 +111,7 @@ class OPTPipelineForwards:
if decoder.project_in is not None: if decoder.project_in is not None:
inputs_embeds = decoder.project_in(inputs_embeds) inputs_embeds = decoder.project_in(inputs_embeds)
device = input_ids.device if input_ids is not None else inputs_embeds.device device = input_ids.device if input_ids is not None else inputs_embeds.device
_dtype = inputs_embeds.dtype inputs_embeds.dtype
else: else:
if hidden_states is None: if hidden_states is None:
...@@ -127,7 +119,7 @@ class OPTPipelineForwards: ...@@ -127,7 +119,7 @@ class OPTPipelineForwards:
input_shape = hidden_states.size()[:-1] input_shape = hidden_states.size()[:-1]
batch_size, seq_length = input_shape[0], input_shape[1] batch_size, seq_length = input_shape[0], input_shape[1]
device = hidden_states.device device = hidden_states.device
_dtype = hidden_states.dtype hidden_states.dtype
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
# required mask seq length can be calculated via length of past # required mask seq length can be calculated via length of past
...@@ -141,13 +133,24 @@ class OPTPipelineForwards: ...@@ -141,13 +133,24 @@ class OPTPipelineForwards:
f"{mask_seq_length} (sum of the lengths of current and past inputs)" f"{mask_seq_length} (sum of the lengths of current and past inputs)"
) )
causal_attention_mask = OPTPipelineForwards._prepare_decoder_attention_mask(
attention_mask, input_shape, _dtype, device, past_key_values_length
)
if stage_manager.is_first_stage(): if stage_manager.is_first_stage():
causal_attention_mask = _get_attention_mask(
self,
shard_config,
inputs_embeds,
past_key_values_length,
attention_mask,
)
pos_embeds = decoder.embed_positions(attention_mask, past_key_values_length) pos_embeds = decoder.embed_positions(attention_mask, past_key_values_length)
hidden_states = inputs_embeds + pos_embeds hidden_states = inputs_embeds + pos_embeds
else:
causal_attention_mask = _get_attention_mask(
self,
shard_config,
hidden_states,
past_key_values_length,
attention_mask,
)
if decoder.gradient_checkpointing and decoder.training: if decoder.gradient_checkpointing and decoder.training:
if use_cache: if use_cache:
...@@ -249,7 +252,16 @@ class OPTPipelineForwards: ...@@ -249,7 +252,16 @@ class OPTPipelineForwards:
if stage_manager.is_last_stage(): if stage_manager.is_last_stage():
if not return_dict: if not return_dict:
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) return tuple(
v
for v in [
hidden_states,
next_cache,
all_hidden_states,
all_self_attns,
]
if v is not None
)
return BaseModelOutputWithPast( return BaseModelOutputWithPast(
last_hidden_state=hidden_states, last_hidden_state=hidden_states,
...@@ -276,6 +288,7 @@ class OPTPipelineForwards: ...@@ -276,6 +288,7 @@ class OPTPipelineForwards:
stage_manager: Optional[PipelineStageManager] = None, stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None, hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None, stage_index: Optional[List[int]] = None,
shard_config: Optional[ShardConfig] = None,
) -> Union[Tuple, CausalLMOutputWithPast]: ) -> Union[Tuple, CausalLMOutputWithPast]:
r""" r"""
This function is modified on the basis of transformers.models.opt.modeling_opt.OPTForCausalLM.forward. This function is modified on the basis of transformers.models.opt.modeling_opt.OPTForCausalLM.forward.
...@@ -303,6 +316,7 @@ class OPTPipelineForwards: ...@@ -303,6 +316,7 @@ class OPTPipelineForwards:
stage_manager=stage_manager, stage_manager=stage_manager,
hidden_states=hidden_states, hidden_states=hidden_states,
stage_index=stage_index, stage_index=stage_index,
shard_config=shard_config,
) )
if stage_manager.is_last_stage(): if stage_manager.is_last_stage():
logits = self.lm_head(outputs[0]).contiguous() logits = self.lm_head(outputs[0]).contiguous()
...@@ -347,6 +361,7 @@ class OPTPipelineForwards: ...@@ -347,6 +361,7 @@ class OPTPipelineForwards:
stage_manager: Optional[PipelineStageManager] = None, stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None, hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None, stage_index: Optional[List[int]] = None,
shard_config: Optional[ShardConfig] = None,
) -> Union[Tuple, SequenceClassifierOutputWithPast]: ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
r""" r"""
This function is modified on the basis of transformers.models.opt.modeling_opt.OPTForSequenceClassification.forward. This function is modified on the basis of transformers.models.opt.modeling_opt.OPTForSequenceClassification.forward.
...@@ -371,6 +386,7 @@ class OPTPipelineForwards: ...@@ -371,6 +386,7 @@ class OPTPipelineForwards:
stage_manager=stage_manager, stage_manager=stage_manager,
hidden_states=hidden_states, hidden_states=hidden_states,
stage_index=stage_index, stage_index=stage_index,
shard_config=shard_config,
) )
if stage_manager.is_last_stage(): if stage_manager.is_last_stage():
...@@ -448,6 +464,7 @@ class OPTPipelineForwards: ...@@ -448,6 +464,7 @@ class OPTPipelineForwards:
stage_manager: Optional[PipelineStageManager] = None, stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None, hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None, stage_index: Optional[List[int]] = None,
shard_config: Optional[ShardConfig] = None,
) -> Union[Tuple, QuestionAnsweringModelOutput]: ) -> Union[Tuple, QuestionAnsweringModelOutput]:
r""" r"""
This function is modified on the basis of transformers.models.opt.modeling_opt.OPTForQuestionAnswering.forward. This function is modified on the basis of transformers.models.opt.modeling_opt.OPTForQuestionAnswering.forward.
...@@ -469,6 +486,7 @@ class OPTPipelineForwards: ...@@ -469,6 +486,7 @@ class OPTPipelineForwards:
stage_manager=stage_manager, stage_manager=stage_manager,
hidden_states=hidden_states, hidden_states=hidden_states,
stage_index=stage_index, stage_index=stage_index,
shard_config=shard_config,
) )
if stage_manager.is_last_stage(): if stage_manager.is_last_stage():
hidden_states = transformer_outputs[0] hidden_states = transformer_outputs[0]
...@@ -511,49 +529,47 @@ class OPTPipelineForwards: ...@@ -511,49 +529,47 @@ class OPTPipelineForwards:
return {"hidden_states": hidden_states} return {"hidden_states": hidden_states}
def get_opt_flash_attention_forward(): def get_opt_flash_attention_forward(shard_config: ShardConfig):
from transformers.models.opt.modeling_opt import OPTAttention from transformers.models.opt.modeling_opt import OPTAttention
from colossalai.nn.layer.colo_attention import AttnMaskType, ColoAttention
def forward( def forward(
self: OPTAttention, self: OPTAttention,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
key_value_states: Optional[torch.Tensor] = None, key_value_states: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None, attention_mask: Optional[dict] = None,
layer_head_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False, output_attentions: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel""" """Input shape: Batch x Time x Channel"""
assert layer_head_mask is None, "layer_head_mask is not supported for FlashAttention"
# if key_value_states are provided this layer is used as a cross-attention layer # if key_value_states are provided this layer is used as a cross-attention layer
# for the decoder # for the decoder
is_cross_attention = key_value_states is not None is_cross_attention = key_value_states is not None
bsz, tgt_len, _ = hidden_states.size() bsz, tgt_len, _ = hidden_states.size()
attention_input_shape = (bsz, -1, self.num_heads, self.head_dim)
# get query proj # get query proj
query_states = self.q_proj(hidden_states).view(*attention_input_shape) query_states = self.q_proj(hidden_states)
# get key, value proj # get key, value proj
if is_cross_attention and past_key_value is not None: if is_cross_attention and past_key_value is not None:
# reuse k, v, cross_attentions # reuse k,v, cross_attentions
key_states = past_key_value[0].transpose(1, 2).contiguous().view(*attention_input_shape) key_states = past_key_value[0]
value_states = past_key_value[1].transpose(1, 2).contiguous().view(*attention_input_shape) value_states = past_key_value[1]
elif is_cross_attention: elif is_cross_attention:
# cross_attentions # cross_attentions
key_states = self.k_proj(key_value_states).view(*attention_input_shape) key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
value_states = self.v_proj(key_value_states).view(*attention_input_shape) value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
elif past_key_value is not None: elif past_key_value is not None:
# reuse k, v, self_attention # reuse k, v, self_attention
key_states = self.k_proj(hidden_states).view(*attention_input_shape) key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
value_states = self.v_proj(hidden_states).view(*attention_input_shape) value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
key_states = torch.cat([past_key_value[0], key_states], dim=1) key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=1) value_states = torch.cat([past_key_value[1], value_states], dim=2)
else: else:
# self_attention # self_attention
key_states = self.k_proj(hidden_states).view(*attention_input_shape) key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
value_states = self.v_proj(hidden_states).view(*attention_input_shape) value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
if self.is_decoder: if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
...@@ -565,38 +581,181 @@ def get_opt_flash_attention_forward(): ...@@ -565,38 +581,181 @@ def get_opt_flash_attention_forward():
# if encoder bi-directional self-attention `past_key_value` is always `None` # if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_states, value_states) past_key_value = (key_states, value_states)
src_len = key_states.size(1) query_states = self._shape(query_states, tgt_len, bsz)
if layer_head_mask != None:
if layer_head_mask.size() != (self.num_heads,):
raise ValueError(
f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
f" {layer_head_mask.size()}"
)
flash_attention_mask = None
attn_mask_type = AttnMaskType.causal
if attention_mask != None:
if attention_mask.size() != (bsz, 1, tgt_len, src_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
)
flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous()
if not torch.all(flash_attention_mask):
attn_mask_type = AttnMaskType.paddedcausal
attention = ColoAttention( dropout_p = self.dropout if self.training else 0.0
embed_dim=self.embed_dim, num_heads=self.num_heads, dropout=self.dropout, scale=self.scaling attn_output = ColoAttention.attention(
) query_states,
attn_output = attention( key_states,
query_states, key_states, value_states, attn_mask=flash_attention_mask, attn_mask_type=attn_mask_type value_states,
**attention_mask,
dropout_p=dropout_p,
scale=self.scaling,
) )
attn_output = attn_output.transpose(1, 2)
# Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
# partitioned aross GPUs when using tensor-parallelism.
attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
attn_output = self.out_proj(attn_output) attn_output = self.out_proj(attn_output)
return attn_output, None, past_key_value return attn_output, None, past_key_value
return forward return forward
def get_opt_decoder_forward_for_flash_attention(shard_config: ShardConfig):
from transformers.models.opt.modeling_opt import OPTDecoder
def forward(
self: OPTDecoder,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# retrieve input_ids and inputs_embeds
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
elif input_ids is not None:
input_shape = input_ids.size()
input_ids = input_ids.view(-1, input_shape[-1])
elif inputs_embeds is not None:
input_shape = inputs_embeds.size()[:-1]
else:
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
batch_size, seq_length = input_shape
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
# required mask seq length can be calculated via length of past
mask_seq_length = past_key_values_length + seq_length
# embed positions
if attention_mask is None:
attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
elif attention_mask.shape[1] != mask_seq_length:
raise ValueError(
f"The provided attention mask has length {attention_mask.shape[1]}, but its length should be "
f"{mask_seq_length} (sum of the lengths of current and past inputs)"
)
causal_attention_mask = _get_attention_mask(
self, shard_config, inputs_embeds, past_key_values_length, attention_mask
)
pos_embeds = self.embed_positions(attention_mask, past_key_values_length)
if self.project_in is not None:
inputs_embeds = self.project_in(inputs_embeds)
hidden_states = inputs_embeds + pos_embeds
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = () if use_cache else None
# check if head_mask has a correct number of layers specified if desired
for attn_mask, mask_name in zip([head_mask], ["head_mask"]):
if attn_mask is not None:
if attn_mask.size()[0] != (len(self.layers)):
raise ValueError(
f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
f" {head_mask.size()[0]}."
)
for idx, decoder_layer in enumerate(self.layers):
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
if output_hidden_states:
all_hidden_states += (hidden_states,)
if self.training:
dropout_probability = torch.rand([])
if dropout_probability < self.layerdrop:
continue
past_key_value = past_key_values[idx] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, output_attentions, None)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(decoder_layer),
hidden_states,
causal_attention_mask,
head_mask[idx] if head_mask is not None else None,
None,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=causal_attention_mask,
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
if output_attentions:
all_self_attns += (layer_outputs[1],)
if self.final_layer_norm is not None:
hidden_states = self.final_layer_norm(hidden_states)
if self.project_out is not None:
hidden_states = self.project_out(hidden_states)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
next_cache = next_decoder_cache if use_cache else None
if not return_dict:
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
return forward
def get_jit_fused_opt_decoder_layer_forward(): def get_jit_fused_opt_decoder_layer_forward():
from transformers.models.opt.modeling_opt import OPTDecoderLayer from transformers.models.opt.modeling_opt import OPTDecoderLayer
......
import math
from typing import List, Optional, Tuple, Union from typing import List, Optional, Tuple, Union
import torch import torch
...@@ -6,6 +5,7 @@ from transformers.models.vit.modeling_vit import BaseModelOutput, ViTEncoder ...@@ -6,6 +5,7 @@ from transformers.models.vit.modeling_vit import BaseModelOutput, ViTEncoder
from transformers.utils import logging from transformers.utils import logging
from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.layer import ColoAttention
def _encoder_forward( def _encoder_forward(
...@@ -98,7 +98,9 @@ def ViTModel_pipeline_forward(stage_manager: PipelineStageManager, stage_index: ...@@ -98,7 +98,9 @@ def ViTModel_pipeline_forward(stage_manager: PipelineStageManager, stage_index:
pixel_values = pixel_values.to(expected_dtype) pixel_values = pixel_values.to(expected_dtype)
embedding_output = self.embeddings( embedding_output = self.embeddings(
pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding pixel_values,
bool_masked_pos=bool_masked_pos,
interpolate_pos_encoding=interpolate_pos_encoding,
) )
hidden_states = embedding_output hidden_states = embedding_output
else: else:
...@@ -336,34 +338,27 @@ def ViTForMaskedImageModeling_pipeline_forward(stage_manager: PipelineStageManag ...@@ -336,34 +338,27 @@ def ViTForMaskedImageModeling_pipeline_forward(stage_manager: PipelineStageManag
def get_vit_flash_self_attention_forward(): def get_vit_flash_self_attention_forward():
from transformers.models.vit.modeling_vit import ViTSelfAttention from transformers.models.vit.modeling_vit import ViTSelfAttention
from colossalai.nn.layer.colo_attention import ColoAttention
def transpose_for_scores(x: torch.Tensor, num_attention_heads, attention_head_size) -> torch.Tensor:
new_x_shape = x.size()[:-1] + (num_attention_heads, attention_head_size)
x = x.view(new_x_shape)
return x
def forward( def forward(
self: ViTSelfAttention, self: ViTSelfAttention,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
head_mask: Optional[torch.Tensor] = None, head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False, output_attentions: bool = False,
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
assert head_mask is None, "head_mask is not supported for FlashAttention"
mixed_query_layer = self.query(hidden_states) mixed_query_layer = self.query(hidden_states)
key_layer = transpose_for_scores(self.key(hidden_states), self.num_attention_heads, self.attention_head_size) key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = transpose_for_scores( value_layer = self.transpose_for_scores(self.value(hidden_states))
self.value(hidden_states), self.num_attention_heads, self.attention_head_size query_layer = self.transpose_for_scores(mixed_query_layer)
)
query_layer = transpose_for_scores(mixed_query_layer, self.num_attention_heads, self.attention_head_size)
scale = 1.0 / math.sqrt(self.attention_head_size) dropout_p = self.dropout.p if self.training else 0.0
attention = ColoAttention( context_layer = ColoAttention.attention(query_layer, key_layer, value_layer, dropout_p=dropout_p)
embed_dim=self.all_head_size, num_heads=self.num_attention_heads, dropout=self.dropout.p, scale=scale
) context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
context_layer = attention(query_layer, key_layer, value_layer) new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(new_context_layer_shape)
outputs = (context_layer,) outputs = (context_layer, None) if output_attentions else (context_layer,)
return outputs return outputs
......
...@@ -13,41 +13,74 @@ from transformers.modeling_outputs import ( ...@@ -13,41 +13,74 @@ from transformers.modeling_outputs import (
SequenceClassifierOutput, SequenceClassifierOutput,
) )
from transformers.models.whisper.modeling_whisper import ( from transformers.models.whisper.modeling_whisper import (
WhisperDecoder,
WhisperEncoder, WhisperEncoder,
WhisperForAudioClassification, WhisperForAudioClassification,
WhisperForConditionalGeneration, WhisperForConditionalGeneration,
WhisperModel, WhisperModel,
shift_tokens_right,
) )
from transformers.utils import logging from transformers.utils import logging
from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.layer import ColoAttention
from colossalai.shardformer.shard import ShardConfig
logger = logging.get_logger(__name__)
def _get_attention_mask(
self: WhisperDecoder,
shard_config: ShardConfig,
hidden_states: torch.Tensor,
past_key_values_length: int,
attention_mask: Optional[torch.FloatTensor],
):
batch_size, seq_length = hidden_states.shape[:2]
mask_seq_length = past_key_values_length + seq_length
if shard_config.enable_flash_attention:
attention_mask = ColoAttention.prepare_attn_kwargs(
(batch_size, 1, seq_length, mask_seq_length),
hidden_states.dtype,
hidden_states.device,
attention_mask,
is_causal=True,
)
else:
attention_mask = self._prepare_decoder_attention_mask(
attention_mask,
(batch_size, seq_length),
hidden_states,
past_key_values_length,
)
return attention_mask
def get_whisper_flash_attention_forward(): def get_whisper_flash_attention_forward():
from transformers.models.whisper.modeling_whisper import WhisperAttention from transformers.models.whisper.modeling_whisper import WhisperAttention
from colossalai.nn.layer.colo_attention import AttnMaskType, ColoAttention
def shape(tensor: torch.Tensor, seq_len: int, bsz: int, num_heads: int, head_dim: int):
return tensor.view(bsz, seq_len, num_heads, head_dim).contiguous()
def forward( def forward(
self: WhisperAttention, self: WhisperAttention,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
key_value_states: Optional[torch.Tensor] = None, key_value_states: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None, attention_mask: Optional[dict] = None,
layer_head_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False, output_attentions: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel""" """Input shape: Batch x Time x Channel"""
assert layer_head_mask is None, "layer_head_mask is not supported for FlashAttention"
# for encoder, attention_mask is None
if attention_mask is None:
attention_mask = {}
# if key_value_states are provided this layer is used as a cross-attention layer # if key_value_states are provided this layer is used as a cross-attention layer
# for the decoder # for the decoder
is_cross_attention = key_value_states is not None is_cross_attention = key_value_states is not None
bsz, tgt_len, _ = hidden_states.size() bsz, tgt_len, _ = hidden_states.size()
# get query proj
query_states = self.q_proj(hidden_states)
# get key, value proj # get key, value proj
# `past_key_value[0].shape[2] == key_value_states.shape[1]` # `past_key_value[0].shape[2] == key_value_states.shape[1]`
# is checking that the `sequence_length` of the `past_key_value` is the same as # is checking that the `sequence_length` of the `past_key_value` is the same as
...@@ -55,25 +88,25 @@ def get_whisper_flash_attention_forward(): ...@@ -55,25 +88,25 @@ def get_whisper_flash_attention_forward():
if ( if (
is_cross_attention is_cross_attention
and past_key_value is not None and past_key_value is not None
and past_key_value[0].shape[1] == key_value_states.shape[1] and past_key_value[0].shape[2] == key_value_states.shape[1]
): ):
# reuse k,v, cross_attentions # reuse k,v, cross_attentions
key_states = past_key_value[0] key_states = past_key_value[0]
value_states = past_key_value[1] value_states = past_key_value[1]
elif is_cross_attention: elif is_cross_attention:
# cross_attentions # cross_attentions
key_states = shape(self.k_proj(key_value_states), -1, bsz, self.num_heads, self.head_dim) key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
value_states = shape(self.v_proj(key_value_states), -1, bsz, self.num_heads, self.head_dim) value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
elif past_key_value is not None: elif past_key_value is not None:
# reuse k, v, self_attention # reuse k, v, self_attention
key_states = shape(self.k_proj(hidden_states), -1, bsz, self.num_heads, self.head_dim) key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
value_states = shape(self.v_proj(hidden_states), -1, bsz, self.num_heads, self.head_dim) value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
key_states = torch.cat([past_key_value[0], key_states], dim=1) key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=1) value_states = torch.cat([past_key_value[1], value_states], dim=2)
else: else:
# self_attention # self_attention
key_states = shape(self.k_proj(hidden_states), -1, bsz, self.num_heads, self.head_dim) key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
value_states = shape(self.v_proj(hidden_states), -1, bsz, self.num_heads, self.head_dim) value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
if self.is_decoder: if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
...@@ -85,42 +118,178 @@ def get_whisper_flash_attention_forward(): ...@@ -85,42 +118,178 @@ def get_whisper_flash_attention_forward():
# if encoder bi-directional self-attention `past_key_value` is always `None` # if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_states, value_states) past_key_value = (key_states, value_states)
# get query proj query_states = self._shape(query_states, tgt_len, bsz)
query_states = shape(self.q_proj(hidden_states), tgt_len, bsz, self.num_heads, self.head_dim)
src_len = key_states.size(1) dropout_p = self.dropout if self.training else 0.0
if layer_head_mask is not None: attn_output = ColoAttention.attention(
if layer_head_mask.size() != (self.num_heads,): query_states,
raise ValueError( key_states,
f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" value_states,
f" {layer_head_mask.size()}" **attention_mask,
) dropout_p=dropout_p,
scale=self.scaling,
)
attn_output = attn_output.transpose(1, 2)
attn_type = None # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
flash_attention_mask = None # partitioned across GPUs when using tensor-parallelism.
attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
if self.is_decoder: attn_output = self.out_proj(attn_output)
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, tgt_len, src_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
)
flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool).contiguous())
if not torch.all(flash_attention_mask):
attn_type = AttnMaskType.paddedcausal
else:
attn_type = AttnMaskType.causal
attention = ColoAttention( return attn_output, None, past_key_value
embed_dim=self.embed_dim, num_heads=self.num_heads, dropout=self.dropout, scale=self.scaling
) return forward
attn_output = attention(
query_states, key_states, value_states, attn_mask=flash_attention_mask, attn_mask_type=attn_type
def get_whisper_decoder_forward_for_flash_attention(shard_config: ShardConfig):
def forward(
self: WhisperDecoder,
input_ids=None,
attention_mask=None,
encoder_hidden_states=None,
head_mask=None,
cross_attn_head_mask=None,
past_key_values=None,
inputs_embeds=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
) )
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
attn_output = self.out_proj(attn_output) # retrieve input_ids and inputs_embeds
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
elif input_ids is not None:
input_shape = input_ids.size()
input_ids = input_ids.view(-1, input_shape[-1])
elif inputs_embeds is not None:
input_shape = inputs_embeds.size()[:-1]
else:
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
return attn_output, None, past_key_value # past_key_values_length
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
attention_mask = _get_attention_mask(self, shard_config, inputs_embeds, past_key_values_length, attention_mask)
# embed positions
if input_ids is not None:
positions = self.embed_positions(input_ids, past_key_values_length=past_key_values_length)
else:
positions = self.embed_positions(inputs_embeds, past_key_values_length=past_key_values_length)
hidden_states = inputs_embeds + positions
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache = True` is incompatible with gradient checkpointing. Setting `use_cache = False`..."
)
use_cache = False
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
next_decoder_cache = () if use_cache else None
# check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
if attn_mask is not None:
assert attn_mask.size()[0] == (len(self.layers)), (
f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
f" {head_mask.size()[0]}."
)
for idx, decoder_layer in enumerate(self.layers):
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
if output_hidden_states:
all_hidden_states += (hidden_states,)
if self.training:
dropout_probability = torch.rand([])
if dropout_probability < self.layerdrop:
continue
past_key_value = past_key_values[idx] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, output_attentions, use_cache)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(decoder_layer),
hidden_states,
attention_mask,
encoder_hidden_states,
None, # encoder attention mask
head_mask[idx] if head_mask is not None else None,
(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None),
None, # past_key_value
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
cross_attn_layer_head_mask=(
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
),
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache += (layer_outputs[3 if output_attentions else 1],)
if output_attentions:
all_self_attns += (layer_outputs[1],)
if encoder_hidden_states is not None:
all_cross_attentions += (layer_outputs[2],)
hidden_states = self.layer_norm(hidden_states)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
next_cache = next_decoder_cache if use_cache else None
if not return_dict:
return tuple(
v
for v in [
hidden_states,
next_cache,
all_hidden_states,
all_self_attns,
all_cross_attentions,
]
if v is not None
)
return BaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
cross_attentions=all_cross_attentions,
)
return forward return forward
...@@ -292,6 +461,7 @@ class WhisperPipelineForwards: ...@@ -292,6 +461,7 @@ class WhisperPipelineForwards:
all_attentions=None, all_attentions=None,
stage_index: Optional[List[int]] = None, stage_index: Optional[List[int]] = None,
decoder_starting_stage: Optional[int] = None, decoder_starting_stage: Optional[int] = None,
shard_config: Optional[ShardConfig] = None,
): ):
r""" r"""
Args: Args:
...@@ -403,7 +573,9 @@ class WhisperPipelineForwards: ...@@ -403,7 +573,9 @@ class WhisperPipelineForwards:
if not return_dict: if not return_dict:
return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
return BaseModelOutput( return BaseModelOutput(
last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions last_hidden_state=hidden_states,
hidden_states=encoder_states,
attentions=all_attentions,
) )
else: else:
...@@ -411,7 +583,7 @@ class WhisperPipelineForwards: ...@@ -411,7 +583,7 @@ class WhisperPipelineForwards:
@staticmethod @staticmethod
def whisper_decoder_forward( def whisper_decoder_forward(
self, self: WhisperDecoder,
input_ids=None, input_ids=None,
attention_mask=None, attention_mask=None,
encoder_hidden_states=None, encoder_hidden_states=None,
...@@ -427,6 +599,7 @@ class WhisperPipelineForwards: ...@@ -427,6 +599,7 @@ class WhisperPipelineForwards:
hidden_states: Optional[torch.FloatTensor] = None, hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None, stage_index: Optional[List[int]] = None,
decoder_starting_stage: Optional[int] = None, decoder_starting_stage: Optional[int] = None,
shard_config: Optional[ShardConfig] = None,
): ):
r""" r"""
Args: Args:
...@@ -535,8 +708,12 @@ class WhisperPipelineForwards: ...@@ -535,8 +708,12 @@ class WhisperPipelineForwards:
else: else:
positions = self.embed_positions(inputs_embeds, past_key_values_length=past_key_values_length) positions = self.embed_positions(inputs_embeds, past_key_values_length=past_key_values_length)
attention_mask = self._prepare_decoder_attention_mask( attention_mask = _get_attention_mask(
attention_mask, input_shape, inputs_embeds, past_key_values_length self,
shard_config,
inputs_embeds,
past_key_values_length,
attention_mask,
) )
hidden_states = inputs_embeds + positions hidden_states = inputs_embeds + positions
...@@ -556,8 +733,12 @@ class WhisperPipelineForwards: ...@@ -556,8 +733,12 @@ class WhisperPipelineForwards:
) )
input_shape = hidden_states.size()[:-1] input_shape = hidden_states.size()[:-1]
attention_mask = self._prepare_decoder_attention_mask( attention_mask = _get_attention_mask(
attention_mask, input_shape, hidden_states, past_key_values_length self,
shard_config,
hidden_states,
past_key_values_length,
attention_mask,
) )
start_idx, end_idx = stage_index[0], stage_index[1] start_idx, end_idx = stage_index[0], stage_index[1]
...@@ -590,7 +771,7 @@ class WhisperPipelineForwards: ...@@ -590,7 +771,7 @@ class WhisperPipelineForwards:
encoder_hidden_states, encoder_hidden_states,
None, # encoder attention mask None, # encoder attention mask
head_mask[idx] if head_mask is not None else None, head_mask[idx] if head_mask is not None else None,
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, (cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None),
None, # past_key_value None, # past_key_value
) )
else: else:
...@@ -626,7 +807,13 @@ class WhisperPipelineForwards: ...@@ -626,7 +807,13 @@ class WhisperPipelineForwards:
if not return_dict: if not return_dict:
return tuple( return tuple(
v v
for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions] for v in [
hidden_states,
next_cache,
all_hidden_states,
all_self_attns,
all_cross_attentions,
]
if v is not None if v is not None
) )
return BaseModelOutputWithPastAndCrossAttentions( return BaseModelOutputWithPastAndCrossAttentions(
...@@ -666,6 +853,7 @@ class WhisperPipelineForwards: ...@@ -666,6 +853,7 @@ class WhisperPipelineForwards:
encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None, stage_index: Optional[List[int]] = None,
decoder_starting_stage: Optional[int] = None, decoder_starting_stage: Optional[int] = None,
shard_config: Optional[ShardConfig] = None,
): ):
r""" r"""
Returns: Returns:
...@@ -735,7 +923,7 @@ class WhisperPipelineForwards: ...@@ -735,7 +923,7 @@ class WhisperPipelineForwards:
elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
encoder_outputs = BaseModelOutput( encoder_outputs = BaseModelOutput(
last_hidden_state=encoder_outputs[0], last_hidden_state=encoder_outputs[0],
hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, hidden_states=(encoder_outputs[1] if len(encoder_outputs) > 1 else None),
attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
) )
...@@ -767,6 +955,7 @@ class WhisperPipelineForwards: ...@@ -767,6 +955,7 @@ class WhisperPipelineForwards:
hidden_states=hidden_states, hidden_states=hidden_states,
stage_index=stage_index, stage_index=stage_index,
decoder_starting_stage=decoder_starting_stage, decoder_starting_stage=decoder_starting_stage,
shard_config=shard_config,
) )
# Directly return outputs of overloaded Whisper forward if not at last stage. # Directly return outputs of overloaded Whisper forward if not at last stage.
...@@ -810,6 +999,7 @@ class WhisperPipelineForwards: ...@@ -810,6 +999,7 @@ class WhisperPipelineForwards:
encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None, stage_index: Optional[List[int]] = None,
decoder_starting_stage: Optional[int] = None, decoder_starting_stage: Optional[int] = None,
shard_config: Optional[ShardConfig] = None,
) -> Union[Tuple[torch.Tensor], Seq2SeqLMOutput]: ) -> Union[Tuple[torch.Tensor], Seq2SeqLMOutput]:
r""" r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
...@@ -870,6 +1060,7 @@ class WhisperPipelineForwards: ...@@ -870,6 +1060,7 @@ class WhisperPipelineForwards:
encoder_hidden_states=encoder_hidden_states, encoder_hidden_states=encoder_hidden_states,
stage_index=stage_index, stage_index=stage_index,
decoder_starting_stage=decoder_starting_stage, decoder_starting_stage=decoder_starting_stage,
shard_config=shard_config,
) )
if not in_decoder: if not in_decoder:
return outputs return outputs
...@@ -920,6 +1111,7 @@ class WhisperPipelineForwards: ...@@ -920,6 +1111,7 @@ class WhisperPipelineForwards:
all_attentions=None, all_attentions=None,
stage_index: Optional[List[int]] = None, stage_index: Optional[List[int]] = None,
decoder_starting_stage: Optional[int] = None, decoder_starting_stage: Optional[int] = None,
shard_config: Optional[ShardConfig] = None,
): ):
r""" r"""
This function is modified on the basis of transformers.models.whisper.modeling_whisper.WhisperForAudioClassification.forward. This function is modified on the basis of transformers.models.whisper.modeling_whisper.WhisperForAudioClassification.forward.
......
...@@ -8,6 +8,7 @@ import colossalai.shardformer.layer as col_nn ...@@ -8,6 +8,7 @@ import colossalai.shardformer.layer as col_nn
from ..modeling.gpt2 import ( from ..modeling.gpt2 import (
GPT2PipelineForwards, GPT2PipelineForwards,
get_gpt2_flash_attention_forward, get_gpt2_flash_attention_forward,
get_gpt_model_forward_for_flash_attn,
get_lm_forward_with_dist_cross_entropy, get_lm_forward_with_dist_cross_entropy,
gpt2_sequence_parallel_forward_fn, gpt2_sequence_parallel_forward_fn,
) )
...@@ -75,7 +76,11 @@ class GPT2Policy(Policy): ...@@ -75,7 +76,11 @@ class GPT2Policy(Policy):
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="attn.c_attn", suffix="attn.c_attn",
target_module=col_nn.GPT2FusedLinearConv1D_Col, target_module=col_nn.GPT2FusedLinearConv1D_Col,
kwargs={"n_fused": 3, "seq_parallel": use_sequence_parallel, "overlap": overlap}, kwargs={
"n_fused": 3,
"seq_parallel": use_sequence_parallel,
"overlap": overlap,
},
), ),
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="attn.c_proj", suffix="attn.c_proj",
...@@ -87,7 +92,11 @@ class GPT2Policy(Policy): ...@@ -87,7 +92,11 @@ class GPT2Policy(Policy):
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="mlp.c_fc", suffix="mlp.c_fc",
target_module=col_nn.GPT2FusedLinearConv1D_Col, target_module=col_nn.GPT2FusedLinearConv1D_Col,
kwargs={"n_fused": 1, "seq_parallel": use_sequence_parallel, "overlap": overlap}, kwargs={
"n_fused": 1,
"seq_parallel": use_sequence_parallel,
"overlap": overlap,
},
), ),
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="mlp.c_proj", suffix="mlp.c_proj",
...@@ -150,6 +159,10 @@ class GPT2Policy(Policy): ...@@ -150,6 +159,10 @@ class GPT2Policy(Policy):
policy=policy, policy=policy,
target_key=GPT2Attention, target_key=GPT2Attention,
) )
if not self.shard_config.pipeline_stage_manager:
policy[GPT2Model].method_replacement = {
"forward": get_gpt_model_forward_for_flash_attn(self.shard_config)
}
if self.shard_config.enable_sequence_parallelism: if self.shard_config.enable_sequence_parallelism:
policy[GPT2Model].method_replacement = {"forward": gpt2_sequence_parallel_forward_fn(self.shard_config)} policy[GPT2Model].method_replacement = {"forward": gpt2_sequence_parallel_forward_fn(self.shard_config)}
...@@ -223,14 +236,21 @@ class GPT2Policy(Policy): ...@@ -223,14 +236,21 @@ class GPT2Policy(Policy):
num_stages=stage_manager.num_stages, num_stages=stage_manager.num_stages,
) )
method_replacement = { method_replacement = {
"forward": partial(new_forward, stage_manager=stage_manager, shard_config=self.shard_config) "forward": partial(
new_forward,
stage_manager=stage_manager,
shard_config=self.shard_config,
)
} }
else: else:
layers_per_stage = Policy.distribute_layers(len(module.h), stage_manager.num_stages) layers_per_stage = Policy.distribute_layers(len(module.h), stage_manager.num_stages)
stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
method_replacement = { method_replacement = {
"forward": partial( "forward": partial(
new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config new_forward,
stage_manager=stage_manager,
stage_index=stage_index,
shard_config=self.shard_config,
) )
} }
self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls) self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls)
...@@ -245,7 +265,9 @@ class GPT2ModelPolicy(GPT2Policy): ...@@ -245,7 +265,9 @@ class GPT2ModelPolicy(GPT2Policy):
if self.pipeline_stage_manager is not None: if self.pipeline_stage_manager is not None:
self.set_pipeline_forward( self.set_pipeline_forward(
model_cls=GPT2Model, new_forward=GPT2PipelineForwards.gpt2_model_forward, policy=policy model_cls=GPT2Model,
new_forward=GPT2PipelineForwards.gpt2_model_forward,
policy=policy,
) )
return policy return policy
...@@ -299,7 +321,12 @@ class GPT2LMHeadModelPolicy(GPT2Policy): ...@@ -299,7 +321,12 @@ class GPT2LMHeadModelPolicy(GPT2Policy):
if stage_manager is not None: if stage_manager is not None:
if stage_manager.num_stages > 1 and id(module.transformer.wte.weight) == id(module.lm_head.weight): if stage_manager.num_stages > 1 and id(module.transformer.wte.weight) == id(module.lm_head.weight):
first_stage, last_stage = 0, stage_manager.num_stages - 1 first_stage, last_stage = 0, stage_manager.num_stages - 1
return [{first_stage: module.transformer.wte.weight, last_stage: module.lm_head.weight}] return [
{
first_stage: module.transformer.wte.weight,
last_stage: module.lm_head.weight,
}
]
return [] return []
...@@ -315,7 +342,9 @@ class GPT2DoubleHeadsModelPolicy(GPT2Policy): ...@@ -315,7 +342,9 @@ class GPT2DoubleHeadsModelPolicy(GPT2Policy):
GPT2DoubleHeadsModel: ModulePolicyDescription( GPT2DoubleHeadsModel: ModulePolicyDescription(
sub_module_replacement=[ sub_module_replacement=[
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="lm_head", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True} suffix="lm_head",
target_module=col_nn.Linear1D_Col,
kwargs={"gather_output": True},
) )
] ]
) )
...@@ -350,7 +379,12 @@ class GPT2DoubleHeadsModelPolicy(GPT2Policy): ...@@ -350,7 +379,12 @@ class GPT2DoubleHeadsModelPolicy(GPT2Policy):
if stage_manager is not None: if stage_manager is not None:
if stage_manager.num_stages > 1 and id(module.transformer.wte.weight) == id(module.lm_head.weight): if stage_manager.num_stages > 1 and id(module.transformer.wte.weight) == id(module.lm_head.weight):
first_stage, last_stage = 0, stage_manager.num_stages - 1 first_stage, last_stage = 0, stage_manager.num_stages - 1
return [{first_stage: module.transformer.wte.weight, last_stage: module.lm_head.weight}] return [
{
first_stage: module.transformer.wte.weight,
last_stage: module.lm_head.weight,
}
]
return [] return []
...@@ -392,7 +426,10 @@ class GPT2ForTokenClassificationPolicy(GPT2Policy): ...@@ -392,7 +426,10 @@ class GPT2ForTokenClassificationPolicy(GPT2Policy):
addon_module = { addon_module = {
GPT2ForTokenClassification: ModulePolicyDescription( GPT2ForTokenClassification: ModulePolicyDescription(
sub_module_replacement=[ sub_module_replacement=[
SubModuleReplacementDescription(suffix="dropout", target_module=col_nn.DropoutForParallelInput) SubModuleReplacementDescription(
suffix="dropout",
target_module=col_nn.DropoutForParallelInput,
)
] ]
) )
} }
......
...@@ -6,7 +6,11 @@ from torch import Tensor, nn ...@@ -6,7 +6,11 @@ from torch import Tensor, nn
import colossalai.shardformer.layer as col_nn import colossalai.shardformer.layer as col_nn
from ..modeling.gptj import GPTJPipelineForwards, get_gptj_flash_attention_forward from ..modeling.gptj import (
GPTJPipelineForwards,
get_gptj_flash_attention_forward,
gptj_model_forward_for_flash_attention,
)
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
__all__ = [ __all__ = [
...@@ -71,17 +75,26 @@ class GPTJPolicy(Policy): ...@@ -71,17 +75,26 @@ class GPTJPolicy(Policy):
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="attn.k_proj", suffix="attn.k_proj",
target_module=col_nn.Linear1D_Col, target_module=col_nn.Linear1D_Col,
kwargs={"seq_parallel": use_sequence_parallel, "overlap": overlap}, kwargs={
"seq_parallel": use_sequence_parallel,
"overlap": overlap,
},
), ),
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="attn.q_proj", suffix="attn.q_proj",
target_module=col_nn.Linear1D_Col, target_module=col_nn.Linear1D_Col,
kwargs={"seq_parallel": use_sequence_parallel, "overlap": overlap}, kwargs={
"seq_parallel": use_sequence_parallel,
"overlap": overlap,
},
), ),
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="attn.v_proj", suffix="attn.v_proj",
target_module=col_nn.Linear1D_Col, target_module=col_nn.Linear1D_Col,
kwargs={"seq_parallel": use_sequence_parallel, "overlap": overlap}, kwargs={
"seq_parallel": use_sequence_parallel,
"overlap": overlap,
},
), ),
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="attn.out_proj", suffix="attn.out_proj",
...@@ -143,6 +156,12 @@ class GPTJPolicy(Policy): ...@@ -143,6 +156,12 @@ class GPTJPolicy(Policy):
policy=policy, policy=policy,
target_key=GPTJAttention, target_key=GPTJAttention,
) )
if not self.shard_config.pipeline_stage_manager:
self.append_or_create_method_replacement(
description={"forward": gptj_model_forward_for_flash_attention(self.shard_config)},
policy=policy,
target_key=GPTJModel,
)
return policy return policy
...@@ -185,7 +204,10 @@ class GPTJPolicy(Policy): ...@@ -185,7 +204,10 @@ class GPTJPolicy(Policy):
stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
method_replacement = { method_replacement = {
"forward": partial( "forward": partial(
new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config new_forward,
stage_manager=stage_manager,
stage_index=stage_index,
shard_config=self.shard_config,
) )
} }
self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls) self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls)
...@@ -203,7 +225,9 @@ class GPTJModelPolicy(GPTJPolicy): ...@@ -203,7 +225,9 @@ class GPTJModelPolicy(GPTJPolicy):
if self.pipeline_stage_manager is not None: if self.pipeline_stage_manager is not None:
self.set_pipeline_forward( self.set_pipeline_forward(
model_cls=GPTJModel, new_forward=GPTJPipelineForwards.gptj_model_forward, policy=policy model_cls=GPTJModel,
new_forward=GPTJPipelineForwards.gptj_model_forward,
policy=policy,
) )
return policy return policy
...@@ -230,7 +254,9 @@ class GPTJForCausalLMPolicy(GPTJPolicy): ...@@ -230,7 +254,9 @@ class GPTJForCausalLMPolicy(GPTJPolicy):
GPTJForCausalLM: ModulePolicyDescription( GPTJForCausalLM: ModulePolicyDescription(
sub_module_replacement=[ sub_module_replacement=[
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="lm_head", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True} suffix="lm_head",
target_module=col_nn.Linear1D_Col,
kwargs={"gather_output": True},
) )
] ]
) )
...@@ -239,7 +265,9 @@ class GPTJForCausalLMPolicy(GPTJPolicy): ...@@ -239,7 +265,9 @@ class GPTJForCausalLMPolicy(GPTJPolicy):
if self.pipeline_stage_manager is not None: if self.pipeline_stage_manager is not None:
self.set_pipeline_forward( self.set_pipeline_forward(
model_cls=GPTJForCausalLM, new_forward=GPTJPipelineForwards.gptj_causallm_model_forward, policy=policy model_cls=GPTJForCausalLM,
new_forward=GPTJPipelineForwards.gptj_causallm_model_forward,
policy=policy,
) )
return policy return policy
...@@ -256,7 +284,12 @@ class GPTJForCausalLMPolicy(GPTJPolicy): ...@@ -256,7 +284,12 @@ class GPTJForCausalLMPolicy(GPTJPolicy):
if stage_manager is not None: if stage_manager is not None:
if stage_manager.num_stages > 1 and id(module.transformer.wte.weight) == id(module.lm_head.weight): if stage_manager.num_stages > 1 and id(module.transformer.wte.weight) == id(module.lm_head.weight):
first_stage, last_stage = 0, stage_manager.num_stages - 1 first_stage, last_stage = 0, stage_manager.num_stages - 1
return [{first_stage: module.transformer.wte.weight, last_stage: module.lm_head.weight}] return [
{
first_stage: module.transformer.wte.weight,
last_stage: module.lm_head.weight,
}
]
return [] return []
......
...@@ -11,6 +11,7 @@ from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col, Linear1D_Ro ...@@ -11,6 +11,7 @@ from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col, Linear1D_Ro
from ..modeling.llama import ( from ..modeling.llama import (
LlamaPipelineForwards, LlamaPipelineForwards,
get_llama_flash_attention_forward, get_llama_flash_attention_forward,
get_llama_model_forward_for_flash_attn,
get_lm_forward_with_dist_cross_entropy, get_lm_forward_with_dist_cross_entropy,
) )
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
...@@ -135,6 +136,15 @@ class LlamaPolicy(Policy): ...@@ -135,6 +136,15 @@ class LlamaPolicy(Policy):
policy=policy, policy=policy,
target_key=LlamaAttention, target_key=LlamaAttention,
) )
if self.pipeline_stage_manager is None:
# replace llama model forward method
self.append_or_create_method_replacement(
description={
"forward": get_llama_model_forward_for_flash_attn(self.shard_config),
},
policy=policy,
target_key=LlamaModel,
)
return policy return policy
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment