"git@developer.sourcefind.cn:OpenDAS/colossalai.git" did not exist on "1b0dd059408108c3fadc46a5a9193d5061a23e5d"
Commit 906426cb authored by flybird1111, committed by Hongxin Liu

[Shardformer] Merge flash attention branch to pipeline branch (#4362)



* [shardformer] supported flash attention test dependency (#4158)

* [shardformer] fix flash attention utils test (#4180)

* [shardformer] opt support flash attention (#4163)

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] move to modeling

* [shardformer] move to modeling

* [shardformer] add performance benchmark of shardformer (#4175)

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] benchmark fix

* [shardformer] benchmark fix

* [shardformer] llama support flash attention (#4185)

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] move to modeling

* [shardformer] move to modeling

* [shardformer] llama support flash attention

* [shardformer] llama support flash attention

* [shardformer] Move the import statement for xformer outside the forward function.

* [shardformer] gpt2 support flash attention. (#4191)

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] move to modeling

* [shardformer] move to modeling

* [shardformer] gpt2 support flash attention

* [shardformer] gpt2 support flash attention

* [shardformer] gpt2 support flash attention

* [shardformer] bloom support flash attention (#4188)

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] move to modeling

* [shardformer] move to modeling

* [shardformer] bloom support flash attention

* [shardformer] add assert to sequence length

* [shardformer] fix

* [shardformer] fix

* [shardformer] fix

* [shardformer] bert support flash attention. (#4206)

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] move to modeling

* [shardformer] move to modeling

* [shardformer] bert support flash attention

* [shardformer] t5 support flash attention. (#4216)

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] move to modeling

* [shardformer] move to modeling

* [shardformer] t5 support flash attention

* [shardformer] t5 support flash attention

* fix typo

* fix typo

* fix typo

* fix typo

* fix typo

* fix typo

* [shardformer] support 'paddedcausal' type of attention mask in ColoAttention. (#4215)

* added padded causal attn mask type for ColoAttention

* [shardformer]t5 flash attention fix (#4239)

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] move to modeling

* [shardformer] move to modeling

* [shardformer] t5 flash attention fix

* [shardformer] update gpt2 to use coloattention. (#4234)

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] move to modeling

* [shardformer] move to modeling

* [shardformer] update gpt2 to use coloattention

* [shardformer] update gpt2 to use coloattention

* [shardformer] update gpt2 to use coloattention

* [shardformer] update gpt2 to use coloattention

* [shardformer] update gpt2

* [shardformer] update opt and llama to use coloattention. (#4226)

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] move to modeling

* [shardformer] move to modeling

* update opt to use coloattention

* [shardformer]update opt to use coloattention

* [shardformer]update opt to use coloattention

* [shardformer]update opt to use coloattention

* [shardformer]update opt to use coloattention

* [shardformer]update opt to use coloattention

* [shardformer]update opt to use coloattention

* [shardformer]update opt

* [shardformer] shardformer support jit fused operator. (#4236)

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] move to modeling

* [shardformer] move to modeling

* [shardformer] bloom support jit fused operator

* [shardformer] bloom support jit fused operator

* [shardformer] bloom support jit fused operator

* [shardformer] t5 support jit fused operator

* [shardformer] t5 support jit fused operator

* [shardformer] t5 support jit fused operator

* [shardformer] add roadmap of flash attention

* [shardformer] add roadmap of flash attention

* [shardformer] add roadmap of flash attention

* [shardformer] add type hint to 'self' param of forward

* [shardformer] merge feature/shardformer-models branch to feature/flash-attention-shardformer branch. (#4290)

* Feature/vit support (#4182)

* [shardformer] added tests

* [shardformer] vit test finish and support

* fix attention dropout

* [shardformer] support SAM (#4231)

* 1.support sam 2.add fused qkv for nn.Linear

* update utils support set element in list

* overwrite SamVisionAttention forward to use DropoutForParallelInput

* remove unused code

* [shardformer] support whisper (#4212)

* support whisper

* fix bug in vocabembedding

* support downstream model of whisper

* update readme

* Feature/chatglm (#4240)

* [shardformer] added tests

* [shardformer] vit test finish and support

* [shardformer] chatglm ready

* import chatglm

* [shardformer] add test kit in model zoo for chatglm

* [shardformer] add first version of policy of chatglm

* [shardformer] polish chatglm code

* [shardformer] polish code

* [shardformer] support chatglm without layernorm

* [shardformer] chatglm shard without mlp sharding

* [shardformer] delete some file

* [shardformer] ChatGLM support layernorm sharding

* [shardformer] register without auto policy

* [shardformer] pre-commit check files

* [shardformer] fix chatglm configuration with pre-commit

---------
Co-authored-by: Kun Lin <81014421+klhhhhh@users.noreply.github.com>
Co-authored-by: FoolPlayer <45593998+FoolPlayer@users.noreply.github.com>

* [shardformer] whisper support flash attention (#4301)

* Feature/vit support (#4182)

* [shardformer] added tests

* [shardformer] vit test finish and support

* fix attention dropout

* [shardformer] support SAM (#4231)

* 1.support sam 2.add fused qkv for nn.Linear

* update utils support set element in list

* overwrite SamVisionAttention forward to use DropoutForParallelInput

* remove unused code

* [shardformer] support whisper (#4212)

* support whisper

* fix bug in vocabembedding

* support downstream model of whisper

* update readme

* Feature/chatglm (#4240)

* [shardformer] added tests

* [shardformer] vit test finish and support

* [shardformer] chatglm ready

* import chatglm

* [shardformer] add test kit in model zoo for chatglm

* [shardformer] add first version of policy of chatglm

* [shardformer] polish chatglm code

* [shardformer] polish code

* [shardformer] support chatglm without layernorm

* [shardformer] chatglm shard without mlp sharding

* [shardformer] delete some file

* [shardformer] ChatGLM support layernorm sharding

* [shardformer] register without auto policy

* [shardformer] pre-commit check files

* [shardformer] fix chatglm configuration with pre-commit

* [shardformer] whisper support flash attention

* [shardformer] whisper support flash attention

* [shardformer]whisper support jit operator

---------
Co-authored-by: Kun Lin <81014421+klhhhhh@users.noreply.github.com>
Co-authored-by: FoolPlayer <45593998+FoolPlayer@users.noreply.github.com>

* [shardformer] sam support flash attention (#4316)

* Feature/vit support (#4182)

* [shardformer] added tests

* [shardformer] vit test finish and support

* fix attention dropout

* [shardformer] support SAM (#4231)

* 1.support sam 2.add fused qkv for nn.Linear

* update utils support set element in list

* overwrite SamVisionAttention forward to use DropoutForParallelInput

* remove unused code

* [shardformer] support whisper (#4212)

* support whisper

* fix bug in vocabembedding

* support downstream model of whisper

* update readme

* Feature/chatglm (#4240)

* [shardformer] added tests

* [shardformer] vit test finish and support

* [shardformer] chatglm ready

* import chatglm

* [shardformer] add test kit in model zoo for chatglm

* [shardformer] add first version of policy of chatglm

* [shardformer] polish chatglm code

* [shardformer] polish code

* [shardformer] support chatglm without layernorm

* [shardformer] chatglm shard without mlp sharding

* [shardformer] delete some file

* [shardformer] ChatGLM support layernorm sharding

* [shardformer] register without auto policy

* [shardformer] pre-commit check files

* [shardformer] fix chatglm configuration with pre-commit

* [shardformer] sam support flash attention

---------
Co-authored-by: Kun Lin <81014421+klhhhhh@users.noreply.github.com>
Co-authored-by: FoolPlayer <45593998+FoolPlayer@users.noreply.github.com>

* [shardformer] merge blip2/chatglm  (#4321)

* Feature/vit support (#4182)

* [shardformer] added tests

* [shardformer] vit test finish and support

* fix attention dropout

* [shardformer] support SAM (#4231)

* 1.support sam 2.add fused qkv for nn.Linear

* update utils support set element in list

* overwrite SamVisionAttention forward to use DropoutForParallelInput

* remove unused code

* [shardformer] support whisper (#4212)

* support whisper

* fix bug in vocabembedding

* support downstream model of whisper

* update readme

* Feature/chatglm (#4240)

* [shardformer] added tests

* [shardformer] vit test finish and support

* [shardformer] chatglm ready

* import chatglm

* [shardformer] add test kit in model zoo for chatglm

* [shardformer] add first version of policy of chatglm

* [shardformer] polish chatglm code

* [shardformer] polish code

* [shardformer] support chatglm without layernorm

* [shardformer] chatglm shard without mlp sharding

* [shardformer] delete some file

* [shardformer] ChatGLM support layernorm sharding

* [shardformer] register without auto policy

* [shardformer] pre-commit check files

* [shardformer] fix chatglm configuration with pre-commit

* [shardformer] added tests

* [shardformer] vit test finish and support

* import chatglm

* [shardformer] add test kit in model zoo for chatglm

* [shardformer] add first version of policy of chatglm

* [shardformer] polish chatglm code

* [shardformer] polish code

* [shardformer] support chatglm without layernorm

* [shardformer] delete some file

* [shardformer] ChatGLM support layernorm sharding

* [shardformer] register without auto policy

* [shardformer] pre-commit check files

* [shardformer] support ChatGLMForConditionalGeneration & add fusedlayernorm for vit

* [shardformer] support Blip2 (#4243)

* support base blip2

* add support for downstream blip2 model

* update readme

* add forward injection

* skip not compatible models test

* fix test for gemini and low_level_zero_plugin

---------
Co-authored-by: Kun Lin <81014421+klhhhhh@users.noreply.github.com>
Co-authored-by: FoolPlayer <45593998+FoolPlayer@users.noreply.github.com>
Co-authored-by: klhhhhh <1412841649@qq.com>

* [shardformer] blip2 support flash attention and jit operator (#4325)

* Feature/vit support (#4182)

* [shardformer] added tests

* [shardformer] vit test finish and support

* fix attention dropout

* [shardformer] support SAM (#4231)

* 1.support sam 2.add fused qkv for nn.Linear

* update utils support set element in list

* overwrite SamVisionAttention forward to use DropoutForParallelInput

* remove unused code

* [shardformer] support whisper (#4212)

* support whisper

* fix bug in vocabembedding

* support downstream model of whisper

* update readme

* Feature/chatglm (#4240)

* [shardformer] added tests

* [shardformer] vit test finish and support

* [shardformer] chatglm ready

* import chatglm

* [shardformer] add test kit in model zoo for chatglm

* [shardformer] add first version of policy of chatglm

* [shardformer] polish chatglm code

* [shardformer] polish code

* [shardformer] support chatglm without layernorm

* [shardformer] chatglm shard without mlp sharding

* [shardformer] delete some file

* [shardformer] ChatGLM support layernorm sharding

* [shardformer] register without auto policy

* [shardformer] pre-commit check files

* [shardformer] fix chatglm configuration with pre-commit

* [shardformer] added tests

* [shardformer] vit test finish and support

* import chatglm

* [shardformer] add test kit in model zoo for chatglm

* [shardformer] add first version of policy of chatglm

* [shardformer] polish chatglm code

* [shardformer] polish code

* [shardformer] support chatglm without layernorm

* [shardformer] delete some file

* [shardformer] ChatGLM support layernorm sharding

* [shardformer] register without auto policy

* [shardformer] pre-commit check files

* [shardformer] support ChatGLMForConditionalGeneration & add fusedlayernorm for vit

* [shardformer] support Blip2 (#4243)

* support base blip2

* add support for downstream blip2 model

* update readme

* add forward injection

* skip not compatible models test

* fix test for gemini and low_level_zero_plugin

* [shardformer] blip2 support flash attention and jit operator

* [shardformer] blip2 support flash attention and jit operator

* [shardformer] blip2 support flash attention and jit operator

---------
Co-authored-by: Kun Lin <81014421+klhhhhh@users.noreply.github.com>
Co-authored-by: FoolPlayer <45593998+FoolPlayer@users.noreply.github.com>
Co-authored-by: klhhhhh <1412841649@qq.com>

* [shardformer] chatglm support flash attention and jit operator (#4330)

* Feature/vit support (#4182)

* [shardformer] added tests

* [shardformer] vit test finish and support

* fix attention dropout

* [shardformer] support SAM (#4231)

* 1.support sam 2.add fused qkv for nn.Linear

* update utils support set element in list

* overwrite SamVisionAttention forward to use DropoutForParallelInput

* remove unused code

* [shardformer] support whisper (#4212)

* support whisper

* fix bug in vocabembedding

* support downstream model of whisper

* update readme

* Feature/chatglm (#4240)

* [shardformer] added tests

* [shardformer] vit test finish and support

* [shardformer] chatglm ready

* import chatglm

* [shardformer] add test kit in model zoo for chatglm

* [shardformer] add first version of policy of chatglm

* [shardformer] polish chatglm code

* [shardformer] polish code

* [shardformer] support chatglm without layernorm

* [shardformer] chatglm shard without mlp sharding

* [shardformer] delete some file

* [shardformer] ChatGLM support layernorm sharding

* [shardformer] register without auto policy

* [shardformer] pre-commit check files

* [shardformer] fix chatglm configuration with pre-commit

* [shardformer] added tests

* [shardformer] vit test finish and support

* import chatglm

* [shardformer] add test kit in model zoo for chatglm

* [shardformer] add first version of policy of chatglm

* [shardformer] polish chatglm code

* [shardformer] polish code

* [shardformer] support chatglm without layernorm

* [shardformer] delete some file

* [shardformer] ChatGLM support layernorm sharding

* [shardformer] register without auto policy

* [shardformer] pre-commit check files

* [shardformer] support ChatGLMForConditionalGeneration & add fusedlayernorm for vit

* [shardformer] support Blip2 (#4243)

* support base blip2

* add support for downstream blip2 model

* update readme

* add forward injection

* skip not compatible models test

* fix test for gemini and low_level_zero_plugin

* [shardformer] chatglm support flash attention and jit operator

* [shardformer] chatglm support flash attention and jit operator

* [shardformer] chatglm support flash attention and jit operator

* [shardformer] chatglm support flash attention and jit operator

---------
Co-authored-by: Kun Lin <81014421+klhhhhh@users.noreply.github.com>
Co-authored-by: FoolPlayer <45593998+FoolPlayer@users.noreply.github.com>
Co-authored-by: klhhhhh <1412841649@qq.com>

* [shardformer] vit support flash attention and jit operator (#4334)

* Feature/vit support (#4182)

* [shardformer] added tests

* [shardformer] vit test finish and support

* fix attention dropout

* [shardformer] support SAM (#4231)

* 1.support sam 2.add fused qkv for nn.Linear

* update utils support set element in list

* overwrite SamVisionAttention forward to use DropoutForParallelInput

* remove unused code

* [shardformer] support whisper (#4212)

* support whisper

* fix bug in vocabembedding

* support downstream model of whisper

* update readme

* Feature/chatglm (#4240)

* [shardformer] added tests

* [shardformer] vit test finish and support

* [shardformer] chatglm ready

* import chatglm

* [shardformer] add test kit in model zoo for chatglm

* [shardformer] add first version of policy of chatglm

* [shardformer] polish chatglm code

* [shardformer] polish code

* [shardformer] support chatglm without layernorm

* [shardformer] chatglm shard without mlp sharding

* [shardformer] delete some file

* [shardformer] ChatGLM support layernorm sharding

* [shardformer] register without auto policy

* [shardformer] pre-commit check files

* [shardformer] fix chatglm configuration with pre-commit

* [shardformer] added tests

* [shardformer] vit test finish and support

* import chatglm

* [shardformer] add test kit in model zoo for chatglm

* [shardformer] add first version of policy of chatglm

* [shardformer] polish chatglm code

* [shardformer] polish code

* [shardformer] support chatglm without layernorm

* [shardformer] delete some file

* [shardformer] ChatGLM support layernorm sharding

* [shardformer] register without auto policy

* [shardformer] pre-commit check files

* [shardformer] support ChatGLMForConditionalGeneration & add fusedlayernorm for vit

* [shardformer] support Blip2 (#4243)

* support base blip2

* add support for downstream blip2 model

* update readme

* add forward injection

* skip not compatible models test

* fix test for gemini and low_level_zero_plugin

* [shardformer] vit support flash attention and jit operator

* [shardformer] vit support flash attention and jit operator

---------
Co-authored-by: Kun Lin <81014421+klhhhhh@users.noreply.github.com>
Co-authored-by: FoolPlayer <45593998+FoolPlayer@users.noreply.github.com>
Co-authored-by: klhhhhh <1412841649@qq.com>

* [pipeline] merge flash attention branch

* [pipeline] merge flash attention branch

* [pipeline] merge flash attention branch

* [pipeline] fix conflict

* [pipeline] fix conflict

* Merge branch 'feature/pipeline' into feature/pipeline

* Merge branch 'feature/pipeline' into feature/pipeline

* Merge branch 'feature/pipeline' into feature/pipeline

* activate checks

* activate checks

* activate checks

* activate checks

* activate checks

* activate checks

* activate checks

* activate checks

* fix flash attention tests

* gemini ignore whisper

* fix vit

* fix xformers import handle

---------
Co-authored-by: Frank Lee <somerlee.9@gmail.com>
Co-authored-by: Kun Lin <81014421+klhhhhh@users.noreply.github.com>
Co-authored-by: FoolPlayer <45593998+FoolPlayer@users.noreply.github.com>
Co-authored-by: klhhhhh <1412841649@qq.com>
parent a88e9225
......@@ -30,7 +30,7 @@
### Quick Start
The sample API usage is given below:
The sample API usage is given below (if you want to enable flash attention, please install xformers first):
```python
from colossalai.shardformer import ShardConfig, ShardFormer
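
# A minimal usage sketch; everything below the imports is an assumption for
# illustration, since the original README excerpt is cut off here. ShardConfig
# and ShardFormer are used exactly as in the performance benchmark included
# later in this diff; run the script under torchrun so the distributed launch
# succeeds.
import colossalai
from transformers import BertForSequenceClassification

colossalai.launch_from_torch(config={})
model = BertForSequenceClassification.from_pretrained('bert-base-uncased')

shard_config = ShardConfig(enable_fused_normalization=True, enable_tensor_parallelism=True)
shard_former = ShardFormer(shard_config=shard_config)
sharded_model = shard_former.optimize(model).cuda()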
......@@ -106,6 +106,20 @@ We will follow this roadmap to develop Shardformer:
- [ ] Multi-modal
- [x] SAM
- [x] BLIP-2
- [ ] Flash Attention Support
- [ ] NLP
- [x] BERT
- [x] T5
- [x] LLaMA
- [x] GPT2
- [x] OPT
- [x] BLOOM
- [ ] GLM
- [ ] RoBERTa
- [ ] ALBERT
- [ ] ERNIE
- [ ] GPT Neo
- [ ] GPT-J
## 💡 API Design
......@@ -373,11 +387,49 @@ pytest tests/test_shardformer
### System Performance
To be added.
We conducted [benchmark tests](./examples/performance_benchmark.py) to evaluate the performance improvement of Shardformer. We compared the training time between the original model and the shard model.
We set the batch size to 4, the number of attention heads to 8, and the head dimension to 64. 'N_CTX' refers to the sequence length.
In the case of using 2 GPUs, the training times are as follows.
| N_CTX | org_model | shard_model |
| :------: | :-----: | :-----: |
| 256 | 11.2ms | 17.2ms |
| 512 | 9.8ms | 19.5ms |
| 1024 | 19.6ms | 18.9ms |
| 2048 | 46.6ms | 30.8ms |
| 4096 | 160.5ms | 90.4ms |
<p align="center">
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/shardformer/performance_benchmark_gpus2.png" width="600" />
<br/>
</p>
In the case of using 4 GPUs, the training times are as follows.
| N_CTX | org_model | shard_model |
| :------: | :-----: | :-----: |
| 256 | 10.0ms | 21.1ms |
| 512 | 11.5ms | 20.2ms |
| 1024 | 22.1ms | 20.6ms |
| 2048 | 46.9ms | 24.8ms |
| 4096 | 160.4ms | 68.0ms |
<p align="center">
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/shardformer/performance_benchmark_gpus4.png" width="600" />
<br/>
</p>
As the figures above show, once the sequence length reaches roughly 1024, the benefit of Shardformer's parallel optimization for long sequences becomes evident: at N_CTX = 4096, for example, the sharded model is about 1.8x faster than the original model with 2 GPUs (90.4ms vs. 160.5ms) and about 2.4x faster with 4 GPUs (68.0ms vs. 160.4ms).
### Convergence
To validate that training a model with Shardformer does not impact its convergence, we [fine-tuned the BERT model](./examples/shardformer_benchmark.py) using both the shardformer and non-shardformer approaches and compared the accuracy, loss, and F1 score of the training results.
To validate that training a model with Shardformer does not impact its convergence, we [fine-tuned the BERT model](./examples/convergence_benchmark.py) using both the shardformer and non-shardformer approaches and compared the accuracy, loss, and F1 score of the training results.
| accuracy | f1 | loss | GPU number | model shard |
| :------: | :-----: | :-----: | :--------: | :---------: |
......
torchrun --standalone --nproc_per_node=4 shardformer_benchmark.py \
torchrun --standalone --nproc_per_node=4 convergence_benchmark.py \
--model "bert" \
--pretrain "bert-base-uncased" \
--max_epochs 1 \
......
"""
Shardformer Benchmark
"""
import torch
import torch.distributed as dist
import transformers
import triton
import colossalai
from colossalai.shardformer import ShardConfig, ShardFormer
def data_gen(batch_size, seq_length):
input_ids = torch.randint(0, seq_length, (batch_size, seq_length), dtype=torch.long)
attention_mask = torch.ones((batch_size, seq_length), dtype=torch.long)
return dict(input_ids=input_ids, attention_mask=attention_mask)
def data_gen_for_sequence_classification(batch_size, seq_length):
# generate dummy labels for sequence classification;
# a constant class index is sufficient for a speed benchmark
data = data_gen(batch_size, seq_length)
data['labels'] = torch.ones((batch_size), dtype=torch.long)
return data
MODEL_CONFIG = transformers.LlamaConfig(num_hidden_layers=4,
hidden_size=128,
intermediate_size=256,
num_attention_heads=4,
max_position_embeddings=128,
num_labels=16)
BATCH, N_HEADS, N_CTX, D_HEAD = 4, 8, 4096, 64
model_func = lambda: transformers.LlamaForSequenceClassification(MODEL_CONFIG)
# vary seq length for fixed head and batch=4
configs = [
triton.testing.Benchmark(x_names=['N_CTX'],
x_vals=[2**i for i in range(8, 13)],
line_arg='provider',
line_vals=['org_model', 'shard_model'],
line_names=['org_model', 'shard_model'],
styles=[('red', '-'), ('blue', '-')],
ylabel='ms',
plot_name=f'llama_for_sequence_classification-batch-{BATCH}',
args={
'BATCH': BATCH,
'dtype': torch.float16,
'model_func': model_func
})
]
def train(model, data):
output = model(**data)
loss = output.logits.mean()
loss.backward()
@triton.testing.perf_report(configs)
def bench_shardformer(BATCH, N_CTX, provider, model_func, dtype=torch.float32, device="cuda"):
warmup = 10
rep = 100
# prepare data
data = data_gen_for_sequence_classification(BATCH, N_CTX)
data = {k: v.cuda() for k, v in data.items()}
model = model_func().to(device)
model.train()
if provider == "org_model":
fn = lambda: train(model, data)
ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
return ms
if provider == "shard_model":
shard_config = ShardConfig(enable_fused_normalization=True, enable_tensor_parallelism=True)
shard_former = ShardFormer(shard_config=shard_config)
sharded_model = shard_former.optimize(model).cuda()
fn = lambda: train(sharded_model, data)
ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
return ms
# start benchmark, command:
# torchrun --standalone --nproc_per_node=2 performance_benchmark.py
if __name__ == "__main__":
colossalai.launch_from_torch({})
bench_shardformer.run(save_path='.', print_data=dist.get_rank() == 0)
import math
import warnings
from typing import Any, Dict, List, Optional, Tuple
from typing import Dict, List, Optional, Tuple
import torch
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
......@@ -962,3 +963,138 @@ class BertPipelineForwards:
else:
hidden_states = outputs.get('hidden_states')
return {'hidden_states': hidden_states}
def get_bert_flash_attention_forward():
try:
from xformers.ops import memory_efficient_attention as me_attention
except ImportError:
raise ImportError("Error: xformers module is not installed. Please install it to use flash attention.")
from transformers.models.bert.modeling_bert import BertAttention
def forward(
self: BertAttention,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor]:
mixed_query_layer = self.query(hidden_states)
# If this is instantiated as a cross-attention module, the keys
# and values come from an encoder; the attention mask needs to be
# such that the encoder's padding tokens are not attended to.
is_cross_attention = encoder_hidden_states is not None
if is_cross_attention and past_key_value is not None:
# reuse k,v, cross_attentions
key_layer = past_key_value[0]
value_layer = past_key_value[1]
attention_mask = encoder_attention_mask
elif is_cross_attention:
key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
attention_mask = encoder_attention_mask
elif past_key_value is not None:
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
else:
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
query_layer = self.transpose_for_scores(mixed_query_layer)
use_cache = past_key_value is not None
if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
# Further calls to cross_attention layer can then reuse all cross-attention
# key/value_states (first "if" case)
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
# all previous decoder key/value_states. Further calls to uni-directional self-attention
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_layer, value_layer)
final_attention_mask = None
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
query_length, key_length = query_layer.shape[2], key_layer.shape[2]
if use_cache:
position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view(-1, 1)
else:
position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
distance = position_ids_l - position_ids_r
positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
if self.position_embedding_type == "relative_key":
relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
final_attention_mask = relative_position_scores
elif self.position_embedding_type == "relative_key_query":
relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
final_attention_mask = relative_position_scores_query + relative_position_scores_key
scale = 1 / math.sqrt(self.attention_head_size)
if attention_mask is not None:
if final_attention_mask is not None:
final_attention_mask = final_attention_mask * scale + attention_mask
else:
final_attention_mask = attention_mask
if final_attention_mask is not None:
    batch_size, src_len = query_layer.size()[0], query_layer.size()[2]
    tgt_len = key_layer.size()[2]
    final_attention_mask = final_attention_mask.expand(batch_size, self.num_attention_heads, src_len, tgt_len)
query_layer = query_layer.permute(0, 2, 1, 3).contiguous()
key_layer = key_layer.permute(0, 2, 1, 3).contiguous()
value_layer = value_layer.permute(0, 2, 1, 3).contiguous()
context_layer = me_attention(query_layer,
key_layer,
value_layer,
attn_bias=final_attention_mask,
p=self.dropout.p,
scale=scale)
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(new_context_layer_shape)
outputs = (context_layer, None)
if self.is_decoder:
outputs = outputs + (past_key_value,)
return outputs
return forward
def get_jit_fused_bert_self_output_forward():
from transformers.models.bert.modeling_bert import BertSelfOutput
def forward(self: BertSelfOutput, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout_add(hidden_states, input_tensor, self.dropout.p, self.dropout.training)
hidden_states = self.LayerNorm(hidden_states)
return hidden_states
return forward
def get_jit_fused_bert_output_forward():
from transformers.models.bert.modeling_bert import BertOutput
def forward(self: BertOutput, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout_add(hidden_states, input_tensor, self.dropout.p, self.dropout.training)
hidden_states = self.LayerNorm(hidden_states)
return hidden_states
return forward
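The factory functions above return plain functions whose first argument is the module instance; this diff does not show how Shardformer's policies attach them, so the snippet below is only a minimal illustration of how such a forward could be bound to an existing attention module (the helper name `patch_flash_attention` is made up for this example).
```python
from types import MethodType

def patch_flash_attention(attn_module) -> None:
    # Illustrative only: bind the xformers-backed forward returned by
    # get_bert_flash_attention_forward() to an existing attention module,
    # replacing its original forward method.
    attn_module.forward = MethodType(get_bert_flash_attention_forward(), attn_module)
```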
import math
from typing import Optional, Tuple, Union
import torch
......@@ -58,3 +59,62 @@ def forward_fn():
return outputs
return forward
def get_blip2_flash_attention_forward():
from transformers.models.blip_2.modeling_blip_2 import Blip2Attention
from colossalai.kernel.cuda_native.flash_attention import AttnMaskType, ColoAttention
def forward(
self: Blip2Attention,
hidden_states: torch.Tensor,
head_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""
bsz, tgt_len, embed_dim = hidden_states.size()
mixed_qkv = self.qkv(hidden_states)
mixed_qkv = mixed_qkv.reshape(bsz, tgt_len, 3, self.num_heads, -1).permute(2, 0, 1, 3, 4)
query_states, key_states, value_states = mixed_qkv[0], mixed_qkv[1], mixed_qkv[2]
attention = ColoAttention(embed_dim=self.embed_dim,
num_heads=self.num_heads,
dropout=self.dropout.p,
scale=self.scale)
context_layer = attention(query_states, key_states, value_states)
output = self.projection(context_layer)
outputs = (output, None)
return outputs
return forward
def get_jit_fused_blip2_QFormer_self_output_forward():
from transformers.models.blip_2.modeling_blip_2 import Blip2QFormerSelfOutput
def forward(self: Blip2QFormerSelfOutput, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout_add(hidden_states, input_tensor, self.dropout.p, self.dropout.training)
hidden_states = self.LayerNorm(hidden_states)
return hidden_states
return forward
def get_jit_fused_blip2_QFormer_output_forward():
from transformers.models.blip_2.modeling_blip_2 import Blip2QFormerOutput
def forward(self: Blip2QFormerOutput, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout_add(hidden_states, input_tensor, self.dropout.p, self.dropout.training)
hidden_states = self.LayerNorm(hidden_states)
return hidden_states
return forward
......@@ -5,6 +5,7 @@ import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from torch.nn import functional as F
from transformers.modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
CausalLMOutputWithCrossAttentions,
......@@ -675,3 +676,223 @@ class BloomPipelineForwards:
else:
hidden_states = outputs.get('hidden_states')
return {'hidden_states': hidden_states}
def get_bloom_flash_attention_forward(enable_jit_fused=False):
try:
from xformers.ops import memory_efficient_attention as me_attention
except ImportError:
raise ImportError("Error: xformers module is not installed. Please install it to use flash attention.")
from transformers.models.bloom.modeling_bloom import BloomAttention
def forward(
self: BloomAttention,
hidden_states: torch.Tensor,
residual: torch.Tensor,
alibi: torch.Tensor,
attention_mask: torch.Tensor,
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
head_mask: Optional[torch.Tensor] = None,
use_cache: bool = False,
output_attentions: bool = False,
):
fused_qkv = self.query_key_value(hidden_states)
(query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)
batch_size, tgt_len, _ = hidden_states.size()
assert tgt_len % 4 == 0, "Flash Attention Error: The sequence length should be a multiple of 4."
_, kv_length, _, _ = key_layer.size()
proj_shape = (batch_size, tgt_len, self.num_heads, self.head_dim)
query_layer = query_layer.contiguous().view(*proj_shape)
key_layer = key_layer.contiguous().view(*proj_shape)
value_layer = value_layer.contiguous().view(*proj_shape)
if layer_past is not None:
past_key, past_value = layer_past
# concatenate along seq_length dimension:
# - key: [batch_size * self.num_heads, head_dim, kv_length]
# - value: [batch_size * self.num_heads, kv_length, head_dim]
key_layer = torch.cat((past_key, key_layer), dim=1)
value_layer = torch.cat((past_value, value_layer), dim=1)
if use_cache is True:
present = (key_layer, value_layer)
else:
present = None
tgt_len = key_layer.size()[1]
attention_numerical_mask = torch.zeros((batch_size, self.num_heads, tgt_len, kv_length),
dtype=torch.float32,
device=query_layer.device,
requires_grad=True)
attention_numerical_mask = attention_numerical_mask + alibi.view(batch_size, self.num_heads, 1,
kv_length) * self.beta
attention_numerical_mask = torch.masked_fill(attention_numerical_mask, attention_mask,
torch.finfo(torch.float32).min)
context_layer = me_attention(query_layer,
key_layer,
value_layer,
attn_bias=attention_numerical_mask,
scale=self.inv_norm_factor,
p=self.attention_dropout.p)
context_layer = context_layer.reshape(-1, kv_length, self.hidden_size)
if self.pretraining_tp > 1 and self.slow_but_exact:
slices = self.hidden_size / self.pretraining_tp
output_tensor = torch.zeros_like(context_layer)
for i in range(self.pretraining_tp):
output_tensor = output_tensor + F.linear(
context_layer[:, :, int(i * slices):int((i + 1) * slices)],
self.dense.weight[:, int(i * slices):int((i + 1) * slices)],
)
else:
output_tensor = self.dense(context_layer)
# TODO to replace with the bias_dropout_add function in jit
output_tensor = self.dropout_add(output_tensor, residual, self.hidden_dropout, self.training)
outputs = (output_tensor, present, None)
return outputs
return forward
def get_jit_fused_bloom_attention_forward():
from transformers.models.bloom.modeling_bloom import BloomAttention
def forward(
self: BloomAttention,
hidden_states: torch.Tensor,
residual: torch.Tensor,
alibi: torch.Tensor,
attention_mask: torch.Tensor,
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
head_mask: Optional[torch.Tensor] = None,
use_cache: bool = False,
output_attentions: bool = False,
):
fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size]
# 3 x [batch_size, seq_length, num_heads, head_dim]
(query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)
batch_size, q_length, _, _ = query_layer.shape
query_layer = query_layer.transpose(1, 2).reshape(batch_size * self.num_heads, q_length, self.head_dim)
key_layer = key_layer.permute(0, 2, 3, 1).reshape(batch_size * self.num_heads, self.head_dim, q_length)
value_layer = value_layer.transpose(1, 2).reshape(batch_size * self.num_heads, q_length, self.head_dim)
if layer_past is not None:
past_key, past_value = layer_past
# concatenate along seq_length dimension:
# - key: [batch_size * self.num_heads, head_dim, kv_length]
# - value: [batch_size * self.num_heads, kv_length, head_dim]
key_layer = torch.cat((past_key, key_layer), dim=2)
value_layer = torch.cat((past_value, value_layer), dim=1)
_, _, kv_length = key_layer.shape
if use_cache is True:
present = (key_layer, value_layer)
else:
present = None
# [batch_size * num_heads, q_length, kv_length]
# we use `torch.Tensor.baddbmm` instead of `torch.baddbmm` as the latter isn't supported by TorchScript v1.11
matmul_result = alibi.baddbmm(
batch1=query_layer,
batch2=key_layer,
beta=self.beta,
alpha=self.inv_norm_factor,
)
# change view to [batch_size, num_heads, q_length, kv_length]
attention_scores = matmul_result.view(batch_size, self.num_heads, q_length, kv_length)
# cast attention scores to fp32, compute scaled softmax and cast back to initial dtype - [batch_size, num_heads, q_length, kv_length]
input_dtype = attention_scores.dtype
# `float16` has a minimum value of -65504.0, whereas `bfloat16` and `float32` have a minimum value of `-3.4e+38`
if input_dtype == torch.float16:
attention_scores = attention_scores.to(torch.float)
attn_weights = torch.masked_fill(attention_scores, attention_mask, torch.finfo(attention_scores.dtype).min)
attention_probs = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(input_dtype)
# [batch_size, num_heads, q_length, kv_length]
attention_probs = self.attention_dropout(attention_probs)
if head_mask is not None:
attention_probs = attention_probs * head_mask
# change view [batch_size x num_heads, q_length, kv_length]
attention_probs_reshaped = attention_probs.view(batch_size * self.num_heads, q_length, kv_length)
# matmul: [batch_size * num_heads, q_length, head_dim]
context_layer = torch.bmm(attention_probs_reshaped, value_layer)
# change view [batch_size, num_heads, q_length, head_dim]
context_layer = self._merge_heads(context_layer)
# aggregate results across tp ranks. See here: https://github.com/pytorch/pytorch/issues/76232
if self.pretraining_tp > 1 and self.slow_but_exact:
slices = self.hidden_size / self.pretraining_tp
output_tensor = torch.zeros_like(context_layer)
for i in range(self.pretraining_tp):
output_tensor = output_tensor + F.linear(
context_layer[:, :, int(i * slices):int((i + 1) * slices)],
self.dense.weight[:, int(i * slices):int((i + 1) * slices)],
)
else:
output_tensor = self.dense(context_layer)
output_tensor = self.dropout_add(output_tensor, residual, self.hidden_dropout, self.training)
outputs = (output_tensor, present)
if output_attentions:
outputs += (attention_probs,)
return outputs
return forward
def get_jit_fused_bloom_mlp_forward():
from transformers.models.bloom.modeling_bloom import BloomMLP
def forward(self: BloomMLP, hidden_states: torch.Tensor, residual: torch.Tensor) -> torch.Tensor:
hidden_states = self.gelu_impl(self.dense_h_to_4h(hidden_states))
if self.pretraining_tp > 1 and self.slow_but_exact:
intermediate_output = torch.zeros_like(residual)
slices = self.dense_4h_to_h.weight.shape[-1] / self.pretraining_tp
for i in range(self.pretraining_tp):
intermediate_output = intermediate_output + F.linear(
hidden_states[:, :, int(i * slices):int((i + 1) * slices)],
self.dense_4h_to_h.weight[:, int(i * slices):int((i + 1) * slices)],
)
else:
intermediate_output = self.dense_4h_to_h(hidden_states)
output = self.dropout_add(intermediate_output, residual, self.hidden_dropout, self.training)
return output
return forward
def get_jit_fused_bloom_gelu_forward():
from transformers.models.bloom.modeling_bloom import BloomGelu
from colossalai.kernel.jit.bias_gelu import GeLUFunction as JitGeLUFunction
def forward(self: BloomGelu, x: torch.Tensor) -> torch.Tensor:
bias = torch.zeros_like(x)
if self.training:
return JitGeLUFunction.apply(x, bias)
else:
return self.bloom_gelu_forward(x, bias)
return forward
......@@ -17,6 +17,116 @@ from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import (
)
def get_flash_core_attention_forward():
from colossalai.kernel.cuda_native.flash_attention import AttnMaskType, ColoAttention
from .chatglm2_6b.modeling_chatglm import CoreAttention
def forward(self: CoreAttention, query_layer, key_layer, value_layer, attention_mask):
pytorch_major_version = int(torch.__version__.split(".")[0])
if pytorch_major_version >= 2:
query_layer, key_layer, value_layer = [k.permute(1, 2, 0, 3) for k in [query_layer, key_layer, value_layer]]
if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]:
context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer,
key_layer,
value_layer,
is_causal=True)
else:
if attention_mask is not None:
attention_mask = ~attention_mask
context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer,
attention_mask)
context_layer = context_layer.permute(2, 0, 1, 3)
new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
context_layer = context_layer.reshape(*new_context_layer_shape)
else:
# Raw attention scores
query_layer = query_layer.permute(1, 0, 2, 3).contiguous()
key_layer = key_layer.permute(1, 0, 2, 3).contiguous()
value_layer = value_layer.permute(1, 0, 2, 3).contiguous()
scale = 1.0 / self.norm_factor
if self.coeff is not None:
scale = scale * self.coeff
flash_attention_mask = None
attn_mask_type = None
if attention_mask is None:
attn_mask_type = AttnMaskType.causal
else:
flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous()
attn_mask_type = AttnMaskType.paddedcausal
attention = ColoAttention(embed_dim=self.hidden_size_per_partition,
num_heads=self.num_attention_heads_per_partition,
dropout=self.attention_dropout.p,
scale=scale)
context_layer = attention(query_layer,
key_layer,
value_layer,
attn_mask=flash_attention_mask,
attn_mask_type=attn_mask_type)
context_layer = context_layer.permute(1, 0, -1).contiguous()
return context_layer
return forward
def get_jit_fused_glm_block_forward():
from .chatglm2_6b.modeling_chatglm import GLMBlock
def forward(
self: GLMBlock,
hidden_states,
attention_mask,
rotary_pos_emb,
kv_cache=None,
use_cache=True,
):
# hidden_states: [s, b, h]
# Layer norm at the beginning of the transformer layer.
layernorm_output = self.input_layernorm(hidden_states)
# Self attention.
attention_output, kv_cache = self.self_attention(
layernorm_output,
attention_mask,
rotary_pos_emb,
kv_cache=kv_cache,
use_cache=use_cache,
)
# Residual connection.
if self.apply_residual_connection_post_layernorm:
residual = layernorm_output
else:
residual = hidden_states
layernorm_input = self.dropout_add(attention_output, residual, self.hidden_dropout, self.training)
# Layer norm post the self attention.
layernorm_output = self.post_attention_layernorm(layernorm_input)
# MLP.
mlp_output = self.mlp(layernorm_output)
# Second residual connection.
if self.apply_residual_connection_post_layernorm:
residual = layernorm_output
else:
residual = layernorm_input
output = self.dropout_add(mlp_output, residual, self.hidden_dropout, self.training)
return output, kv_cache
return forward
class ChatGLMPipelineForwards:
'''
This class serves as a micro library for ChatGLM model forwards under pipeline parallelism.
......
......@@ -668,3 +668,88 @@ class GPT2PipelineForwards:
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def get_gpt2_flash_attention_forward():
from transformers.models.gpt2.modeling_gpt2 import GPT2Attention
from colossalai.kernel.cuda_native.flash_attention import AttnMaskType, ColoAttention
def split_heads(tensor, num_heads, attn_head_size):
"""
Splits hidden_size dim into attn_head_size and num_heads
"""
new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
tensor = tensor.view(new_shape)
return tensor
def forward(
self: GPT2Attention,
hidden_states: Optional[Tuple[torch.FloatTensor]],
layer_past: Optional[Tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = False,
output_attentions: Optional[bool] = False,
) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
_, tgt_len, _ = hidden_states.size()
assert tgt_len % 4 == 0, "Flash Attention Error: The sequence length should be a multiple of 4."
if encoder_hidden_states is not None:
if not hasattr(self, "q_attn"):
raise ValueError(
"If class is used as cross attention, the weights `q_attn` have to be defined. "
"Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`.")
query = self.q_attn(hidden_states)
key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
attention_mask = encoder_attention_mask
else:
query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)
query = split_heads(query, self.num_heads, self.head_dim)
key = split_heads(key, self.num_heads, self.head_dim)
value = split_heads(value, self.num_heads, self.head_dim)
if layer_past is not None:
past_key, past_value = layer_past
key = torch.cat((past_key, key), dim=1)
value = torch.cat((past_value, value), dim=1)
if use_cache is True:
present = (key, value)
else:
present = None
if not self.is_cross_attention:
attn_mask_type = AttnMaskType.causal
flash_attention_mask = None
if attention_mask is not None:
    if attn_mask_type == AttnMaskType.causal:
        attn_mask_type = AttnMaskType.paddedcausal
else:
attn_mask_type = AttnMaskType.padding
flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous()
scale = value.size(-1)**-0.5
if self.scale_attn_by_inverse_layer_idx:
scale = scale * (1 / float(self.layer_idx + 1))
# use coloattention
attention = ColoAttention(embed_dim=self.embed_dim,
num_heads=self.num_heads,
dropout=self.attn_dropout.p,
scale=scale)
attn_output = attention(query, key, value, attn_mask=flash_attention_mask, attn_mask_type=attn_mask_type)
attn_output = self.c_proj(attn_output)
attn_output = self.resid_dropout(attn_output)
outputs = (attn_output, present, None)
return outputs
return forward
import torch
def get_dropout_add_func():
from transformers.models.bloom.modeling_bloom import dropout_add
def self_dropout_add(self, x: torch.Tensor, residual: torch.Tensor, prob: float, training: bool) -> torch.Tensor:
return dropout_add(x, residual, prob, training)
return self_dropout_add
def get_jit_fused_dropout_add_func():
from colossalai.kernel.jit import bias_dropout_add_fused_inference, bias_dropout_add_fused_train
def self_dropout_add(self, x: torch.Tensor, residual: torch.Tensor, prob: float, training: bool) -> torch.Tensor:
bias = torch.zeros_like(x)
if training:
return bias_dropout_add_fused_train(x, bias, residual, prob)
return bias_dropout_add_fused_inference(x, bias, residual, prob)
return self_dropout_add
def get_jit_fused_gelu_forward_func():
from colossalai.kernel.jit.bias_gelu import bias_gelu
def bloom_gelu_forward(x: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
return bias_gelu(bias, x)
return bloom_gelu_forward
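For reference, with a zero bias the jit-fused helpers above compute the same thing as the eager residual-dropout pattern used throughout these forwards; the sketch below only illustrates that equivalence and is not part of the patch.
```python
import torch
import torch.nn.functional as F

def eager_dropout_add(x: torch.Tensor, residual: torch.Tensor, prob: float, training: bool) -> torch.Tensor:
    # Equivalent to the fused bias-dropout-add with bias == 0: residual + dropout(x)
    return residual + F.dropout(x, p=prob, training=training)
```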
from typing import Callable, List, Optional
from typing import Callable, List, Optional, Tuple
import torch
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
......@@ -386,3 +386,67 @@ class LlamaPipelineForwards:
else:
hidden_states = transformer_outputs.get('hidden_states')
return {'hidden_states': hidden_states}
def get_llama_flash_attention_forward():
from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb
from colossalai.kernel.cuda_native.flash_attention import AttnMaskType, ColoAttention
def forward(
self: LlamaAttention,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()
assert q_len % 4 == 0, "Flash Attention Error: The sequence length should be a multiple of 4."
query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
if past_key_value is not None:
# reuse k, v, self_attention
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
past_key_value = (key_states, value_states) if use_cache else None
me_input_shape = (bsz, q_len, self.num_heads, self.head_dim)
query_states = query_states.transpose(1, 2).contiguous().view(*me_input_shape)
key_states = key_states.transpose(1, 2).contiguous().view(*me_input_shape)
value_states = value_states.transpose(1, 2).contiguous().view(*me_input_shape)
flash_attention_mask = None
attn_mask_type = AttnMaskType.causal
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}")
flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous()
attn_mask_type = AttnMaskType.paddedcausal
attention = ColoAttention(embed_dim=self.hidden_size, num_heads=self.num_heads)
attn_output = attention(query_states,
key_states,
value_states,
attn_mask=flash_attention_mask,
attn_mask_type=attn_mask_type)
attn_output = self.o_proj(attn_output)
return attn_output, None, past_key_value
return forward
from typing import Optional, Tuple
import torch
from torch import nn
def get_opt_flash_attention_forward():
from transformers.models.opt.modeling_opt import OPTAttention
from colossalai.kernel.cuda_native.flash_attention import AttnMaskType, ColoAttention
def forward(
self: OPTAttention,
hidden_states: torch.Tensor,
key_value_states: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""
# if key_value_states are provided this layer is used as a cross-attention layer
# for the decoder
is_cross_attention = key_value_states is not None
bsz, tgt_len, _ = hidden_states.size()
assert tgt_len % 4 == 0, "Flash Attention Error: The sequence length should be a multiple of 4."
attention_input_shape = (bsz, -1, self.num_heads, self.head_dim)
# get query proj
query_states = self.q_proj(hidden_states).view(*attention_input_shape)
# get key, value proj
if is_cross_attention and past_key_value is not None:
# reuse k, v, cross_attentions
key_states = past_key_value[0].transpose(1, 2).contiguous().view(*attention_input_shape)
value_states = past_key_value[1].transpose(1, 2).contiguous().view(*attention_input_shape)
elif is_cross_attention:
# cross_attentions
key_states = self.k_proj(key_value_states).view(*attention_input_shape)
value_states = self.v_proj(key_value_states).view(*attention_input_shape)
elif past_key_value is not None:
# reuse k, v, self_attention
key_states = self.k_proj(hidden_states).view(*attention_input_shape)
value_states = self.v_proj(hidden_states).view(*attention_input_shape)
key_states = torch.cat([past_key_value[0], key_states], dim=1)
value_states = torch.cat([past_key_value[1], value_states], dim=1)
else:
# self_attention
key_states = self.k_proj(hidden_states).view(*attention_input_shape)
value_states = self.v_proj(hidden_states).view(*attention_input_shape)
if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
# Further calls to cross_attention layer can then reuse all cross-attention
# key/value_states (first "if" case)
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
# all previous decoder key/value_states. Further calls to uni-directional self-attention
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_states, value_states)
src_len = key_states.size(1)
if layer_head_mask is not None:
if layer_head_mask.size() != (self.num_heads,):
raise ValueError(f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
f" {layer_head_mask.size()}")
flash_attention_mask = None
attn_mask_type = AttnMaskType.causal
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, tgt_len, src_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}")
flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous()
attn_mask_type = AttnMaskType.paddedcausal
attention = ColoAttention(embed_dim=self.embed_dim,
num_heads=self.num_heads,
dropout=self.dropout,
scale=self.scaling)
attn_output = attention(query_states,
key_states,
value_states,
attn_mask=flash_attention_mask,
attn_mask_type=attn_mask_type)
attn_output = self.out_proj(attn_output)
return attn_output, None, past_key_value
return forward
def get_jit_fused_opt_decoder_layer_forward():
from transformers.models.opt.modeling_opt import OPTDecoderLayer
def forward(
self: OPTDecoderLayer,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
layer_head_mask (`torch.FloatTensor`, *optional*): mask for attention heads in a given layer of size
`(encoder_attention_heads,)`.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
(see `past_key_values`).
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
"""
residual = hidden_states
# 125m, 1.7B, ..., 175B applies layer norm BEFORE attention
if self.do_layer_norm_before:
hidden_states = self.self_attn_layer_norm(hidden_states)
# Self Attention
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
past_key_value=past_key_value,
attention_mask=attention_mask,
layer_head_mask=layer_head_mask,
output_attentions=output_attentions,
)
hidden_states = self.dropout_add(hidden_states, residual, self.dropout, self.training)
# 350m applies layer norm AFTER attention
if not self.do_layer_norm_before:
hidden_states = self.self_attn_layer_norm(hidden_states)
# Fully Connected
hidden_states_shape = hidden_states.shape
hidden_states = hidden_states.reshape(-1, hidden_states.size(-1))
residual = hidden_states
# 125m, 1.7B, ..., 175B applies layer norm BEFORE attention
if self.do_layer_norm_before:
hidden_states = self.final_layer_norm(hidden_states)
hidden_states = self.fc1(hidden_states)
hidden_states = self.activation_fn(hidden_states)
hidden_states = self.fc2(hidden_states)
hidden_states = self.dropout_add(hidden_states, residual, self.dropout, self.training).view(hidden_states_shape)
# 350m applies layer norm AFTER attention
if not self.do_layer_norm_before:
hidden_states = self.final_layer_norm(hidden_states)
outputs = (hidden_states,)
if output_attentions:
outputs += (self_attn_weights,)
if use_cache:
outputs += (present_key_value,)
return outputs
return forward
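# The fused decoder-layer forward above calls a `dropout_add` method that the policy
# layer is expected to patch in (see `get_jit_fused_dropout_add_func` below). As a
# hedged reference for the (x, residual, prob, training) signature only -- not the
# library's jit-scripted implementation -- an equivalent plain-PyTorch helper could be:
import torch
import torch.nn.functional as F

def dropout_add_reference(x: torch.Tensor, residual: torch.Tensor, prob: float, training: bool) -> torch.Tensor:
    # Apply dropout to the branch output, then add the residual connection.
    return residual + F.dropout(x, p=prob, training=training)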
import math
from typing import Tuple
import torch
import torch.nn.functional as F
from torch import Tensor
def forward_fn():
@@ -37,3 +42,162 @@ def forward_fn():
return outputs
return forward
def get_sam_flash_attention_forward():
from transformers.models.sam.modeling_sam import SamAttention
try:
from xformers.ops import memory_efficient_attention as me_attention
except ImportError:
raise ImportError("Error: xformers module is not installed. Please install it to use flash attention.")
def _separate_heads(hidden_states: Tensor, num_attention_heads: int) -> Tensor:
batch, point_batch_size, n_tokens, channel = hidden_states.shape
c_per_head = channel // num_attention_heads
hidden_states = hidden_states.reshape(batch * point_batch_size, n_tokens, num_attention_heads, c_per_head)
return hidden_states
def _recombine_heads(hidden_states: Tensor, point_batch_size: int) -> Tensor:
batch, n_tokens, n_heads, c_per_head = hidden_states.shape
return hidden_states.reshape(batch // point_batch_size, point_batch_size, n_tokens, n_heads * c_per_head)
def forward(self: SamAttention,
query: Tensor,
key: Tensor,
value: Tensor,
attention_similarity: Tensor = None) -> Tensor:
# Input projections
query = self.q_proj(query)
key = self.k_proj(key)
value = self.v_proj(value)
point_batch_size = query.shape[1]
# Separate into heads
query = _separate_heads(query, self.num_attention_heads)
key = _separate_heads(key, self.num_attention_heads)
value = _separate_heads(value, self.num_attention_heads)
# SamAttention
_, _, _, c_per_head = query.shape
bias = None
if attention_similarity is not None:
bias = attention_similarity
scale = 1.0 / math.sqrt(c_per_head)
out = me_attention(query, key, value, attn_bias=bias, scale=scale)
out = _recombine_heads(out, point_batch_size)
out = self.out_proj(out)
return out
return forward
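# Both SAM forwards in this file hand xformers' memory_efficient_attention tensors in
# (batch, seq_len, num_heads, head_dim) layout, and the result comes back in the same
# layout before `_recombine_heads`. A hedged sanity check of that contract (sizes are
# illustrative; a CUDA build of xformers is assumed):
def _check_me_attention_layout():
    from xformers.ops import memory_efficient_attention as me_attention
    bsz, seq_len, num_heads, head_dim = 2, 16, 8, 64
    q = torch.randn(bsz, seq_len, num_heads, head_dim, device="cuda", dtype=torch.float16)
    k = torch.randn_like(q)
    v = torch.randn_like(q)
    out = me_attention(q, k, v, scale=1.0 / math.sqrt(head_dim))
    assert out.shape == (bsz, seq_len, num_heads, head_dim)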
def get_sam_vision_flash_attention_forward():
from transformers.models.sam.modeling_sam import SamVisionAttention
try:
from xformers.ops import memory_efficient_attention as me_attention
except ImportError:
raise ImportError("Error: xformers module is not installed. Please install it to use flash attention.")
def add_decomposed_rel_pos(
query: torch.Tensor,
rel_pos_h: torch.Tensor,
rel_pos_w: torch.Tensor,
q_size: Tuple[int, int],
k_size: Tuple[int, int],
) -> torch.Tensor:
"""
Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py
Args:
query (`torch.Tensor`):
query q in the attention layer with shape (batch_size, query_height * query_width, num_heads, channel).
rel_pos_h (`torch.Tensor`):
relative position embeddings (Lh, channel) for the height axis.
rel_pos_w (`torch.Tensor`):
relative position embeddings (Lw, channel) for the width axis.
q_size (tuple):
spatial sequence size of query q with (query_height, query_width).
k_size (tuple):
spatial sequence size of key k with (key_height, key_width).
Returns:
rel_pos (`torch.Tensor`):
decomposed relative positional bias of shape (batch_size, num_heads, query_height * query_width,
key_height * key_width), to be passed to the attention kernel as `attn_bias`.
"""
query_height, query_width = q_size
key_height, key_width = k_size
relative_position_height = get_rel_pos(query_height, key_height, rel_pos_h)
relative_position_width = get_rel_pos(query_width, key_width, rel_pos_w)
batch_size, _, nHead, dim = query.shape
reshaped_query = query.transpose(1, 2).reshape(batch_size * nHead, query_height, query_width, dim)
rel_h = torch.einsum("bhwc,hkc->bhwk", reshaped_query, relative_position_height)
rel_w = torch.einsum("bhwc,wkc->bhwk", reshaped_query, relative_position_width)
rel_pos = rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
rel_pos = rel_pos.reshape(batch_size, nHead, query_height * query_width, key_height * key_width)
return rel_pos
def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
"""
Get relative positional embeddings according to the relative positions of
query and key sizes.
Args:
q_size (int):
size of the query.
k_size (int):
size of key k.
rel_pos (`torch.Tensor`):
relative position embeddings (L, channel).
Returns:
Extracted positional embeddings according to relative positions.
"""
max_rel_dist = int(2 * max(q_size, k_size) - 1)
# Interpolate rel pos.
rel_pos_resized = F.interpolate(
rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
size=max_rel_dist,
mode="linear",
)
rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
# Scale the coords with short length if shapes for q and k are different.
q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
return rel_pos_resized[relative_coords.long()]
def forward(self: SamVisionAttention, hidden_states: torch.Tensor, output_attentions=False) -> torch.Tensor:
batch_size, height, width, _ = hidden_states.shape
# qkv with shape (3, batch_size, nHead, height * width, channel)
qkv = (self.qkv(hidden_states).reshape(batch_size, height * width, 3, self.num_attention_heads,
-1).permute(2, 0, 1, 3, 4))
query, key, value = qkv.reshape(3, batch_size, height * width, self.num_attention_heads, -1).unbind(0)
rel_pos = None
if self.use_rel_pos:
rel_pos = add_decomposed_rel_pos(query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width))
attn_output = me_attention(query, key, value, attn_bias=rel_pos, p=self.dropout, scale=self.scale)
attn_output = attn_output.reshape(batch_size, height, width, -1)
attn_output = self.proj(attn_output)
outputs = (attn_output, None)
return outputs
return forward
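# Hedged illustration of the decomposed relative-position bias built above, assuming
# the nested `add_decomposed_rel_pos` helper is in scope (e.g. lifted to module level);
# the sizes are illustrative only.
def _check_decomposed_rel_pos_shape():
    height = width = 14
    num_heads, head_dim = 12, 64
    query = torch.randn(1, height * width, num_heads, head_dim)
    rel_pos_h = torch.randn(2 * height - 1, head_dim)    # one learned row per vertical offset
    rel_pos_w = torch.randn(2 * width - 1, head_dim)     # one learned row per horizontal offset
    bias = add_decomposed_rel_pos(query, rel_pos_h, rel_pos_w, (height, width), (height, width))
    # the per-head bias is what the forward above passes to me_attention as `attn_bias`
    assert bias.shape == (1, num_heads, height * width, height * width)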
@@ -587,3 +587,209 @@ class T5PipelineForwards:
decoder_starting_stage=decoder_starting_stage)
return outputs
def get_t5_flash_attention_forward():
try:
from xformers.ops import memory_efficient_attention as me_attention
except ImportError:
raise ImportError("Error: xformers module is not installed. Please install it to use flash attention.")
from transformers.models.t5.modeling_t5 import T5Attention
def forward(
self: T5Attention,
hidden_states: torch.Tensor,
mask: Optional[torch.Tensor] = None,
key_value_states: Optional[torch.Tensor] = None,
position_bias: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
layer_head_mask: Optional[torch.Tensor] = None,
query_length: Optional[int] = None,
use_cache: bool = False,
output_attentions: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
"""
Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).
"""
# Input is (batch_size, seq_length, dim)
# Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length)
# past_key_value[0] is (batch_size, q_len - 1, n_heads, dim_per_head) in this flash-attention layout
batch_size, seq_length = hidden_states.shape[:2]
real_seq_length = seq_length
if past_key_value is not None:
if len(past_key_value) != 2:
raise ValueError(
f"past_key_value should have 2 past states: keys and values. Got { len(past_key_value)} past states"
)
real_seq_length += past_key_value[0].shape[2] if query_length is None else query_length
key_length = real_seq_length if key_value_states is None else key_value_states.shape[1]
def shape(states):
"""projection"""
return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim)
def unshape(states):
"""reshape"""
return states.view(batch_size, -1, self.inner_dim)
def project(hidden_states, proj_layer, key_value_states, past_key_value):
"""projects hidden states correctly to key/query states"""
if key_value_states is None:
# self-attn
# (batch_size, seq_length, n_heads, dim_per_head)
hidden_states = shape(proj_layer(hidden_states))
elif past_key_value is None:
# cross-attn
# (batch_size, seq_length, n_heads, dim_per_head)
hidden_states = shape(proj_layer(key_value_states))
if past_key_value is not None:
if key_value_states is None:
# self-attn
# (batch_size, key_length, n_heads, dim_per_head)
hidden_states = torch.cat([past_key_value, hidden_states], dim=1)
elif past_key_value.shape[1] != key_value_states.shape[1]:
# checking that the `sequence_length` of the `past_key_value` is the same as
# the provided `key_value_states` to support prefix tuning
# cross-attn
# (batch_size, seq_length, n_heads, dim_per_head)
hidden_states = shape(proj_layer(key_value_states))
else:
# cross-attn
hidden_states = past_key_value
return hidden_states
# get query states
query_states = shape(self.q(hidden_states))    # (batch_size, seq_length, n_heads, dim_per_head)
# get key/value states
key_states = project(hidden_states, self.k, key_value_states,
past_key_value[0] if past_key_value is not None else None)
value_states = project(hidden_states, self.v, key_value_states,
past_key_value[1] if past_key_value is not None else None)
if position_bias is None:
if not self.has_relative_attention_bias:
position_bias = torch.zeros((1, self.n_heads, real_seq_length, key_length),
device=query_states.device,
dtype=query_states.dtype)
if self.gradient_checkpointing and self.training:
position_bias.requires_grad = True
else:
position_bias = self.compute_bias(real_seq_length, key_length, device=query_states.device)
# if key and values are already calculated
# we want only the last query position bias
if past_key_value is not None:
position_bias = position_bias[:, :, -hidden_states.size(1):, :]
if mask is not None:
position_bias = position_bias + mask # (batch_size, n_heads, seq_length, key_length)
if self.pruned_heads:
mask = torch.ones(position_bias.shape[1])
mask[list(self.pruned_heads)] = 0
position_bias_masked = position_bias[:, mask.bool()]
else:
position_bias_masked = position_bias
position_bias_masked = position_bias_masked.contiguous()
attn_output = me_attention(query_states,
key_states,
value_states,
attn_bias=position_bias_masked,
p=self.dropout,
scale=1.0)
attn_output = unshape(attn_output)
attn_output = self.o(attn_output)
present_key_value_state = (key_states, value_states) if (self.is_decoder and use_cache) else None
outputs = (attn_output,) + (present_key_value_state,) + (position_bias,)
return outputs
return forward
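# Hedged note on the layout used by shape()/unshape() above: the states stay in
# (batch_size, seq_length, n_heads, dim_per_head) because xformers'
# memory_efficient_attention consumes that layout directly, so no transpose to
# (batch_size, n_heads, seq_length, dim_per_head) is needed. A quick round-trip check:
def _check_t5_shape_unshape_roundtrip():
    batch_size, seq_length, n_heads, dim_per_head = 2, 10, 8, 64
    inner_dim = n_heads * dim_per_head
    states = torch.randn(batch_size, seq_length, inner_dim)
    shaped = states.view(batch_size, -1, n_heads, dim_per_head)    # what shape() does
    restored = shaped.view(batch_size, -1, inner_dim)              # what unshape() does
    assert torch.equal(states, restored)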
def get_jit_fused_T5_layer_ff_forward():
from transformers.models.t5.modeling_t5 import T5LayerFF
def forward(self: T5LayerFF, hidden_states: torch.Tensor) -> torch.Tensor:
forwarded_states = self.layer_norm(hidden_states)
forwarded_states = self.DenseReluDense(forwarded_states)
hidden_states = self.dropout_add(forwarded_states, hidden_states, self.dropout.p, self.dropout.training)
return hidden_states
return forward
def get_T5_layer_self_attention_forward():
from transformers.models.t5.modeling_t5 import T5LayerSelfAttention
def forward(
self: T5LayerSelfAttention,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_bias: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
use_cache: bool = False,
output_attentions: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
normed_hidden_states = self.layer_norm(hidden_states)
attention_output = self.SelfAttention(
normed_hidden_states,
mask=attention_mask,
position_bias=position_bias,
layer_head_mask=layer_head_mask,
past_key_value=past_key_value,
use_cache=use_cache,
output_attentions=output_attentions,
)
hidden_states = self.dropout_add(attention_output[0], hidden_states, self.dropout.p, self.dropout.training)
outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them
return outputs
return forward
def get_T5_layer_cross_attention_forward():
from transformers.models.t5.modeling_t5 import T5LayerCrossAttention
def forward(
self: T5LayerCrossAttention,
hidden_states: torch.Tensor,
key_value_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_bias: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
use_cache: bool = False,
query_length: Optional[int] = None,
output_attentions: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
normed_hidden_states = self.layer_norm(hidden_states)
attention_output = self.EncDecAttention(
normed_hidden_states,
mask=attention_mask,
key_value_states=key_value_states,
position_bias=position_bias,
layer_head_mask=layer_head_mask,
past_key_value=past_key_value,
use_cache=use_cache,
query_length=query_length,
output_attentions=output_attentions,
)
layer_output = self.dropout_add(attention_output[0], hidden_states, self.dropout.p, self.dropout.training)
outputs = (layer_output,) + attention_output[1:] # add attentions if we output them
return outputs
return forward
import logging
import math
from typing import Dict, List, Optional, Set, Tuple, Union
import torch
@@ -335,3 +336,51 @@ def ViTForMaskedImageModeling_pipeline_forward(stage_manager: PipelineStageManag
)
return pp_forward
def get_vit_flash_self_attention_forward():
from transformers.models.vit.modeling_vit import ViTSelfAttention
from colossalai.kernel.cuda_native.flash_attention import ColoAttention
def transpose_for_scores(x: torch.Tensor, num_attention_heads, attention_head_size) -> torch.Tensor:
new_x_shape = x.size()[:-1] + (num_attention_heads, attention_head_size)
x = x.view(new_x_shape)
return x
def forward(self: ViTSelfAttention,
hidden_states: torch.Tensor,
head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
mixed_query_layer = self.query(hidden_states)
key_layer = transpose_for_scores(self.key(hidden_states), self.num_attention_heads, self.attention_head_size)
value_layer = transpose_for_scores(self.value(hidden_states), self.num_attention_heads,
self.attention_head_size)
query_layer = transpose_for_scores(mixed_query_layer, self.num_attention_heads, self.attention_head_size)
scale = 1.0 / math.sqrt(self.attention_head_size)
attention = ColoAttention(embed_dim=self.all_head_size,
num_heads=self.num_attention_heads,
dropout=self.dropout.p,
scale=scale)
context_layer = attention(query_layer, key_layer, value_layer)
outputs = (context_layer,)
return outputs
return forward
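# Hedged usage sketch of ColoAttention mirroring the construction and call pattern
# above; it assumes a CUDA device, half-precision inputs and that the flash-attention
# kernel's sequence-length constraints are satisfied. The sizes are illustrative only.
def _example_colo_attention_call():
    from colossalai.kernel.cuda_native.flash_attention import ColoAttention
    bsz, seq_len, num_heads, head_dim = 2, 196, 12, 64
    q = torch.randn(bsz, seq_len, num_heads, head_dim, device="cuda", dtype=torch.float16)
    k = torch.randn_like(q)
    v = torch.randn_like(q)
    attention = ColoAttention(embed_dim=num_heads * head_dim,
                              num_heads=num_heads,
                              dropout=0.0,
                              scale=1.0 / math.sqrt(head_dim))
    # returns the merged-head context states, as consumed by the ViT forward above
    return attention(q, k, v)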
def get_jit_fused_vit_output_forward():
from transformers.models.vit.modeling_vit import ViTOutput
def forward(self: ViTOutput, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout_add(hidden_states, input_tensor, self.dropout.p, self.dropout.training)
return hidden_states
return forward
from typing import Optional, Tuple
import torch
from torch import nn
def get_whisper_flash_attention_forward():
from transformers.models.whisper.modeling_whisper import WhisperAttention
from colossalai.kernel.cuda_native.flash_attention import AttnMaskType, ColoAttention
def shape(tensor: torch.Tensor, seq_len: int, bsz: int, num_heads: int, head_dim: int):
return tensor.view(bsz, seq_len, num_heads, head_dim).contiguous()
def forward(
self: WhisperAttention,
hidden_states: torch.Tensor,
key_value_states: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""
# if key_value_states are provided this layer is used as a cross-attention layer
# for the decoder
is_cross_attention = key_value_states is not None
bsz, tgt_len, _ = hidden_states.size()
# get key, value proj
# `past_key_value[0].shape[1] == key_value_states.shape[1]`
# is checking that the `sequence_length` of the `past_key_value` is the same as
# the provided `key_value_states` to support prefix tuning
if (is_cross_attention and past_key_value is not None
and past_key_value[0].shape[1] == key_value_states.shape[1]):
# reuse k,v, cross_attentions
key_states = past_key_value[0]
value_states = past_key_value[1]
elif is_cross_attention:
# cross_attentions
key_states = shape(self.k_proj(key_value_states), -1, bsz, self.num_heads, self.head_dim)
value_states = shape(self.v_proj(key_value_states), -1, bsz, self.num_heads, self.head_dim)
elif past_key_value is not None:
# reuse k, v, self_attention
key_states = shape(self.k_proj(hidden_states), -1, bsz, self.num_heads, self.head_dim)
value_states = shape(self.v_proj(hidden_states), -1, bsz, self.num_heads, self.head_dim)
key_states = torch.cat([past_key_value[0], key_states], dim=1)
value_states = torch.cat([past_key_value[1], value_states], dim=1)
else:
# self_attention
key_states = shape(self.k_proj(hidden_states), -1, bsz, self.num_heads, self.head_dim)
value_states = shape(self.v_proj(hidden_states), -1, bsz, self.num_heads, self.head_dim)
if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
# Further calls to cross_attention layer can then reuse all cross-attention
# key/value_states (first "if" case)
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
# all previous decoder key/value_states. Further calls to uni-directional self-attention
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_states, value_states)
# get query proj
query_states = shape(self.q_proj(hidden_states), tgt_len, bsz, self.num_heads, self.head_dim)
src_len = key_states.size(1)
if layer_head_mask is not None:
if layer_head_mask.size() != (self.num_heads,):
raise ValueError(f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
f" {layer_head_mask.size()}")
attn_type = None
flash_attention_mask = None
if self.is_decoder:
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, tgt_len, src_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
)
flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool).contiguous())
attn_type = AttnMaskType.paddedcausal
attention = ColoAttention(embed_dim=self.embed_dim,
num_heads=self.num_heads,
dropout=self.dropout,
scale=self.scaling)
attn_output = attention(query_states,
key_states,
value_states,
attn_mask=flash_attention_mask,
attn_mask_type=attn_type)
attn_output = self.out_proj(attn_output)
return attn_output, None, past_key_value
return forward
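# Hedged illustration of the additive-to-boolean mask conversion used above. An HF-style
# additive mask of shape (bsz, 1, tgt_len, src_len) holds 0 where attention is allowed
# and a large negative value where it is masked; taking its last query row, casting to
# bool and inverting yields a (bsz, src_len) mask that is True at visible key positions
# and False at padded ones.
def _example_flash_attention_mask():
    bsz, tgt_len, src_len = 1, 4, 4
    additive_mask = torch.zeros(bsz, 1, tgt_len, src_len)
    additive_mask[..., -1] = torch.finfo(torch.float32).min    # last key position is padding
    flash_attention_mask = ~(additive_mask[:, :, -1].squeeze(1).to(torch.bool).contiguous())
    assert flash_attention_mask.tolist() == [[True, True, True, False]]
    return flash_attention_mask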
def get_jit_fused_whisper_encoder_layer_forward():
from transformers.models.whisper.modeling_whisper import WhisperEncoderLayer
def forward(
self: WhisperEncoderLayer,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
layer_head_mask: torch.Tensor,
output_attentions: bool = False,
) -> torch.Tensor:
"""
Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
attention_mask (`torch.FloatTensor`): attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
`(encoder_attention_heads,)`.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
"""
residual = hidden_states
hidden_states = self.self_attn_layer_norm(hidden_states)
hidden_states, attn_weights, _ = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
layer_head_mask=layer_head_mask,
output_attentions=output_attentions,
)
hidden_states = self.dropout_add(hidden_states, residual, self.dropout, self.training)
residual = hidden_states
hidden_states = self.final_layer_norm(hidden_states)
hidden_states = self.activation_fn(self.fc1(hidden_states))
hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
hidden_states = self.fc2(hidden_states)
hidden_states = self.dropout_add(hidden_states, residual, self.dropout, self.training)
if hidden_states.dtype == torch.float16 and (torch.isinf(hidden_states).any()
or torch.isnan(hidden_states).any()):
clamp_value = torch.finfo(hidden_states.dtype).max - 1000
hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
outputs = (hidden_states,)
if output_attentions:
outputs += (attn_weights,)
return outputs
return forward
def get_jit_fused_whisper_decoder_layer_forward():
from transformers.models.whisper.modeling_whisper import WhisperDecoderLayer
def forward(
self: WhisperDecoderLayer,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = True,
) -> torch.Tensor:
"""
Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
attention_mask (`torch.FloatTensor`): attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
encoder_hidden_states (`torch.FloatTensor`):
cross attention input to the layer of shape `(batch, seq_len, embed_dim)`
encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
`(encoder_attention_heads,)`.
cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of
size `(decoder_attention_heads,)`.
past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
"""
residual = hidden_states
hidden_states = self.self_attn_layer_norm(hidden_states)
# Self Attention
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
# add present self-attn cache to positions 1,2 of present_key_value tuple
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
past_key_value=self_attn_past_key_value,
attention_mask=attention_mask,
layer_head_mask=layer_head_mask,
output_attentions=output_attentions,
)
hidden_states = self.dropout_add(hidden_states, residual, self.dropout, self.training)
# Cross-Attention Block
cross_attn_present_key_value = None
cross_attn_weights = None
if encoder_hidden_states is not None:
residual = hidden_states
hidden_states = self.encoder_attn_layer_norm(hidden_states)
# cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn(
hidden_states=hidden_states,
key_value_states=encoder_hidden_states,
attention_mask=encoder_attention_mask,
layer_head_mask=cross_attn_layer_head_mask,
past_key_value=cross_attn_past_key_value,
output_attentions=output_attentions,
)
hidden_states = self.dropout_add(hidden_states, residual, self.dropout, self.training)
# add cross-attn to positions 3,4 of present_key_value tuple
present_key_value = present_key_value + cross_attn_present_key_value
# Fully Connected
residual = hidden_states
hidden_states = self.final_layer_norm(hidden_states)
hidden_states = self.activation_fn(self.fc1(hidden_states))
hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
hidden_states = self.fc2(hidden_states)
hidden_states = self.dropout_add(hidden_states, residual, self.dropout, self.training)
outputs = (hidden_states,)
if output_attentions:
outputs += (self_attn_weights, cross_attn_weights)
if use_cache:
outputs += (present_key_value,)
return outputs
return forward
@@ -7,7 +7,14 @@ from torch.nn import Module
import colossalai.shardformer.layer as col_nn
from ..modeling.bert import BertPipelineForwards
from .._utils import getattr_, setattr_
from ..modeling.bert import (
BertPipelineForwards,
get_bert_flash_attention_forward,
get_jit_fused_bert_output_forward,
get_jit_fused_bert_self_output_forward,
)
from ..modeling.jit import get_jit_fused_dropout_add_func
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
__all__ = [
@@ -37,7 +44,13 @@ class BertPolicy(Policy):
return self.model
def module_policy(self):
from transformers.models.bert.modeling_bert import BertEmbeddings, BertLayer
from transformers.models.bert.modeling_bert import (
BertEmbeddings,
BertLayer,
BertOutput,
BertSelfAttention,
BertSelfOutput,
)
policy = {}
@@ -126,6 +139,23 @@ class BertPolicy(Policy):
policy=policy,
target_key=BertEmbeddings)
# use flash attention
if self.shard_config.enable_flash_attention:
policy[BertSelfAttention] = ModulePolicyDescription(method_replacement={
'forward': get_bert_flash_attention_forward(),
})
# use jit operator
if self.shard_config.enable_jit_fused:
policy[BertSelfOutput] = ModulePolicyDescription(method_replacement={
'forward': get_jit_fused_bert_self_output_forward(),
'dropout_add': get_jit_fused_dropout_add_func(),
})
policy[BertOutput] = ModulePolicyDescription(method_replacement={
'forward': get_jit_fused_bert_output_forward(),
'dropout_add': get_jit_fused_dropout_add_func(),
})
return policy
def add_lm_head_policy(self, base_policy):
......
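# Hedged usage sketch: the replacements registered in the policy above only take effect
# when the corresponding flags on the shard configuration are enabled. The flag names
# mirror the `shard_config` attributes referenced in that policy; `ShardConfig` is
# assumed importable from `colossalai.shardformer`, with its other fields left at defaults.
def _example_shard_config():
    from colossalai.shardformer import ShardConfig
    return ShardConfig(enable_flash_attention=True, enable_jit_fused=True)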
@@ -3,7 +3,13 @@ import torch.nn as nn
import colossalai.shardformer.layer as col_nn
from .._utils import getattr_, setattr_
from ..modeling.blip2 import forward_fn
from ..modeling.blip2 import (
forward_fn,
get_blip2_flash_attention_forward,
get_jit_fused_blip2_QFormer_output_forward,
get_jit_fused_blip2_QFormer_self_output_forward,
)
from ..modeling.jit import get_jit_fused_dropout_add_func
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
__all__ = ['BlipPolicy', 'BlipModelPolicy']
@@ -33,6 +39,8 @@ class BlipPolicy(Policy):
Blip2EncoderLayer,
Blip2QFormerLayer,
Blip2QFormerModel,
Blip2QFormerOutput,
Blip2QFormerSelfOutput,
Blip2VisionModel,
)
from transformers.models.opt.modeling_opt import OPTDecoderLayer, OPTForCausalLM
@@ -275,6 +283,24 @@ class BlipPolicy(Policy):
policy=policy,
target_key=OPTDecoderLayer)
# use flash attention
if self.shard_config.enable_flash_attention:
policy[Blip2Attention] = ModulePolicyDescription(method_replacement={
'forward': get_blip2_flash_attention_forward(),
})
# use jit operator
if self.shard_config.enable_jit_fused:
policy[Blip2QFormerSelfOutput] = ModulePolicyDescription(
method_replacement={
'forward': get_jit_fused_blip2_QFormer_self_output_forward(),
'dropout_add': get_jit_fused_dropout_add_func(),
})
policy[Blip2QFormerOutput] = ModulePolicyDescription(method_replacement={
'forward': get_jit_fused_blip2_QFormer_output_forward(),
'dropout_add': get_jit_fused_dropout_add_func(),
})
return policy
def postprocess(self):
......
@@ -7,7 +7,16 @@ from torch.nn import Module
import colossalai.shardformer.layer as col_nn
from ..modeling.bloom import BloomPipelineForwards, build_bloom_alibi_tensor_fn
from .._utils import getattr_, setattr_
from ..modeling.bloom import (
BloomPipelineForwards,
build_bloom_alibi_tensor_fn,
get_bloom_flash_attention_forward,
get_jit_fused_bloom_attention_forward,
get_jit_fused_bloom_gelu_forward,
get_jit_fused_bloom_mlp_forward,
)
from ..modeling.jit import get_dropout_add_func, get_jit_fused_dropout_add_func, get_jit_fused_gelu_forward_func
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
@@ -30,7 +39,7 @@ class BloomPolicy(Policy):
return self.model
def module_policy(self):
from transformers.models.bloom.modeling_bloom import BloomBlock, BloomModel
from transformers.models.bloom.modeling_bloom import BloomAttention, BloomBlock, BloomGelu, BloomMLP, BloomModel
policy = {}
@@ -107,6 +116,27 @@ class BloomPolicy(Policy):
policy=policy,
target_key=BloomBlock)
if self.shard_config.enable_flash_attention:
policy[BloomAttention] = ModulePolicyDescription(method_replacement={
'forward': get_bloom_flash_attention_forward(),
'dropout_add': get_dropout_add_func()
})
# enable jit fused operator
if self.shard_config.enable_jit_fused:
policy[BloomAttention] = ModulePolicyDescription(method_replacement={
'forward': get_jit_fused_bloom_attention_forward(),
'dropout_add': get_jit_fused_dropout_add_func(),
})
policy[BloomMLP] = ModulePolicyDescription(method_replacement={
'forward': get_jit_fused_bloom_mlp_forward(),
'dropout_add': get_jit_fused_dropout_add_func(),
})
policy[BloomGelu] = ModulePolicyDescription(method_replacement={
'forward': get_jit_fused_bloom_gelu_forward(),
'bloom_gelu_forward': get_jit_fused_gelu_forward_func(),
})
return policy
def postprocess(self):
......
@@ -15,6 +15,8 @@ from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import (
GLMBlock,
)
from ..modeling.chatglm import get_flash_core_attention_forward, get_jit_fused_glm_block_forward
from ..modeling.jit import get_jit_fused_dropout_add_func
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
__all__ = ['ChatGLMPolicy', 'ChatGLMModelPolicy', 'ChatGLMForConditionalGenerationPolicy']
@@ -35,12 +37,11 @@ class ChatGLMPolicy(Policy):
new_vocab_size = vocab_size + world_size - vocab_size % world_size
self.model.resize_token_embeddings(new_vocab_size)
return self.model
def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMModel, GLMBlock
from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMModel, CoreAttention, GLMBlock
policy = {}
@@ -121,6 +122,19 @@ class ChatGLMPolicy(Policy):
policy=policy,
target_key=ChatGLMModel)
# use flash attention
if self.shard_config.enable_flash_attention:
policy[CoreAttention] = ModulePolicyDescription(method_replacement={
'forward': get_flash_core_attention_forward(),
})
# use jit fused operator
if self.shard_config.enable_jit_fused:
policy[GLMBlock] = ModulePolicyDescription(method_replacement={
'forward': get_jit_fused_glm_block_forward(),
'dropout_add': get_jit_fused_dropout_add_func(),
})
return policy
def postprocess(self):
@@ -192,7 +206,6 @@ class ChatGLMModelPolicy(ChatGLMPolicy):
return []
class ChatGLMForConditionalGenerationPolicy(ChatGLMModelPolicy):
def module_policy(self):
@@ -213,4 +226,3 @@ class ChatGLMForConditionalGenerationPolicy(ChatGLMModelPolicy):
def get_shared_params(self) -> List[Dict[int, Tensor]]:
"""No shared params in ChatGLMForConditionalGenerationModel."""
return []