"tests/vscode:/vscode.git/clone" did not exist on "1f5d2e80621fbf3d71ade0e93d7ebbb0899e4805"
Commit 906426cb authored by flybird1111, committed by Hongxin Liu

[Shardformer] Merge flash attention branch to pipeline branch (#4362)



* [shardformer] supported flash attention test dependency (#4158)

* [shardformer] fix flash attention utils test (#4180)

* [shardformer] opt support flash attention (#4163)

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] move to modeling

* [shardformer] move to modeling

* [shardformer] add performance benchmark of shardformer (#4175)

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] benchmark fix

* [shardformer] benchmark fix

* [shardformer] llama support flash attention (#4185)

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] move to modeling

* [shardformer] move to modeling

* [shardformer] llama support flash attention

* [shardformer] llama support flash attention

* [shardformer] Move the import statement for xformer outside the forward function.

* [shardformer] gpt2 support flash attention. (#4191)

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] move to modeling

* [shardformer] move to modeling

* [shardformer] gpt2 support flash attention

* [shardformer] gpt2 support flash attention

* [shardformer] gpt2 support flash attention

* [shardformer] bloom support flash attention (#4188)

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] move to modeling

* [shardformer] move to modeling

* [shardformer] bloom support flash attention

* [shardformer] add assert to sequence length

* [shardformer] fix

* [shardformer] fix

* [shardformer] fix

* [shardformer] bert support flash attention. (#4206)

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] move to modeling

* [shardformer] move to modeling

* [shardformer] bert support flash attention

* [shardformer] t5 support flash attention. (#4216)

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] move to modeling

* [shardformer] move to modeling

* [shardformer] t5 support flash attention

* [shardformer] t5 support flash attention

* fix typo

* fix typo

* fix typo

* fix typo

* fix typo

* fix typo

* [shardformer] support 'paddedcausal' type of attention mask in ColoAttention. (#4215)

* added padded causal attn mask type for ColoAttention

* [shardformer]t5 flash attention fix (#4239)

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] move to modeling

* [shardformer] move to modeling

* [shardformer] t5 flash attention fix

* [shardformer] update gpt2 to use coloattention. (#4234)

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] move to modeling

* [shardformer] move to modeling

* [shardformer] update gpt2 to use coloattention

* [shardformer] update gpt2 to use coloattention

* [shardformer] update gpt2 to use coloattention

* [shardformer] update gpt2 to use coloattention

* [shardformer] update gpt2

* [shardformer] update opt and llama to use coloattention. (#4226)

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] move to modeling

* [shardformer] move to modeling

* update opt to use coloattention

* [shardformer]update opt to use coloattention

* [shardformer]update opt to use coloattention

* [shardformer]update opt to use coloattention

* [shardformer]update opt to use coloattention

* [shardformer]update opt to use coloattention

* [shardformer]update opt to use coloattention

* [shardformer]update opt

* [shardformer] shardformer support jit fused operator. (#4236)

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] move to modeling

* [shardformer] move to modeling

* [shardformer] bloom support jit fused operator

* [shardformer] bloom support jit fused operator

* [shardformer] bloom support jit fused operator

* [shardformer] t5 support jit fused operator

* [shardformer] t5 support jit fused operator

* [shardformer] t5 support jit fused operator

* [shardformer] add roadmap of flash attention

* [shardformer] add roadmap of flash attention

* [shardformer] add roadmap of flash attention

* [shardformer] add type hint to 'self' param of forward

* [shardformer] merge feature/shardformer-models branch to feature/flash-attention-shardformer branch. (#4290)

* Feature/vit support (#4182)

* [shardformer] added tests

* [shardformer] vit test finish and support

* fix attention dropout

* [shardformer] support SAM (#4231)

* 1.support sam 2.add fused qkv for nn.Linear

* update utils support set element in list

* overwrite SamVisionAttention forward to use DropoutForParallelInput

* remove unused code

* [shardformer] support whisper (#4212)

* support whisper

* fix bug in vocabembedding

* support downstream model of whisper

* update readme

* Feature/chatglm (#4240)

* [shardformer] added tests

* [shardformer] vit test finish and support

* [shardformer] chatglm ready

* import chatglm

* [shardformer] add test kit in model zoo for chatglm

* [shardformer] add first version of policy of chatglm

* [shardformer] polish chatglm code

* [shardformer] polish code

* [shardformer] support chatglm without layernorm

* [shardformer] chatglm shard without mlp sharding

* [shardformer] delete some file

* [shardformer] ChatGLM support layernorm sharding

* [shardformer] register without auto policy

* [shardformer] pre-commit check files

* [shardformer] fix chatglm configuration with pre-commit

---------
Co-authored-by: Kun Lin <81014421+klhhhhh@users.noreply.github.com>
Co-authored-by: FoolPlayer <45593998+FoolPlayer@users.noreply.github.com>

* [shardformer] whisper support flash attention (#4301)

* Feature/vit support (#4182)

* [shardformer] added tests

* [shardformer] vit test finish and support

* fix attention dropout

* [shardformer] support SAM (#4231)

* 1.support sam 2.add fused qkv for nn.Linear

* update utils support set element in list

* overwrite SamVisionAttention forward to use DropoutForParallelInput

* remove unused code

* [shardformer] support whisper (#4212)

* support whisper

* fix bug in vocabembedding

* support downstream model of whisper

* update readme

* Feature/chatglm (#4240)

* [shardformer] added tests

* [shardformer] vit test finish and support

* [shardformer] chatglm ready

* import chatglm

* [shardformer] add test kit in model zoo for chatglm

* [shardformer] add first version of policy of chatglm

* [shardformer] polish chatglm code

* [shardformer] polish code

* [shardformer] support chatglm without layernorm

* [shardformer] chatglm shard without mlp sharding

* [shardformer] delete some file

* [shardformer] ChatGLM support layernorm sharding

* [shardformer] register without auto policy

* [shardformer] pre-commit check files

* [shardformer] fix chatglm configuration with pre-commit

* [shardformer] whisper support flash attention

* [shardformer] whisper support flash attention

* [shardformer]whisper support jit operator

---------
Co-authored-by: Kun Lin <81014421+klhhhhh@users.noreply.github.com>
Co-authored-by: FoolPlayer <45593998+FoolPlayer@users.noreply.github.com>

* [shardformer] sam support flash attention (#4316)

* Feature/vit support (#4182)

* [shardformer] added tests

* [shardformer] vit test finish and support

* fix attention dropout

* [shardformer] support SAM (#4231)

* 1.support sam 2.add fused qkv for nn.Linear

* update utils support set element in list

* overwrite SamVisionAttention forward to use DropoutForParallelInput

* remove unused code

* [shardformer] support whisper (#4212)

* support whisper

* fix bug in vocabembedding

* support downstream model of whisper

* update readme

* Feature/chatglm (#4240)

* [shardformer] added tests

* [shardformer] vit test finish and support

* [shardformer] chatglm ready

* import chatglm

* [shardformer] add test kit in model zoo for chatglm

* [shardformer] add first version of policy of chatglm

* [shardformer] polish chatglm code

* [shardformer] polish code

* [shardformer] support chatglm without layernorm

* [shardformer] chatglm shard without mlp sharding

* [shardformer] delete some file

* [shardformer] ChatGLM support layernorm sharding

* [shardformer] register without auto policy

* [shardformer] pre-commit check files

* [shardformer] fix chatglm configuration with pre-commit

* [shardformer] sam support flash attention

---------
Co-authored-by: Kun Lin <81014421+klhhhhh@users.noreply.github.com>
Co-authored-by: FoolPlayer <45593998+FoolPlayer@users.noreply.github.com>

* [shardformer] merge blip2/chatglm  (#4321)

* Feature/vit support (#4182)

* [shardformer] added tests

* [shardformer] vit test finish and support

* fix attention dropout

* [shardformer] support SAM (#4231)

* 1.support sam 2.add fused qkv for nn.Linear

* update utils support set element in list

* overwrite SamVisionAttention forward to use DropoutForParallelInput

* remove unused code

* [shardformer] support whisper (#4212)

* support whisper

* fix bug in vocabembedding

* support downstream model of whisper

* update readme

* Feature/chatglm (#4240)

* [shardformer] added tests

* [shardformer] vit test finish and support

* [shardformer] chatglm ready

* import chatglm

* [shardformer] add test kit in model zoo for chatglm

* [shardformer] add first version of policy of chatglm

* [shardformer] polish chatglm code

* [shardformer] polish code

* [shardformer] support chatglm without layernorm

* [shardformer] chatglm shard without mlp sharding

* [shardformer] delete some file

* [shardformer] ChatGLM support layernorm sharding

* [shardformer] register without auto policy

* [shardformer] pre-commit check files

* [shardformer] fix chatglm configuration with pre-commit

* [shardformer] added tests

* [shardformer] vit test finish and support

* import chatglm

* [shardformer] add test kit in model zoo for chatglm

* [shardformer] add first version of policy of chatglm

* [shardformer] polish chatglm code

* [shardformer] polish code

* [shardformer] support chatglm without layernorm

* [shardformer] delete some file

* [shardformer] ChatGLM support layernorm sharding

* [shardformer] register without auto policy

* [shardformer] pre-commit check files

* [shardformer] support ChatGLMForConditionalGeneration & add fusedlayernorm for vit

* [shardformer] support Blip2 (#4243)

* support base blip2

* add support for downstream blip2 model

* update readme

* add forward injection

* skip not compatible models test

* fix test for gemini and low_level_zero_plugin

---------
Co-authored-by: Kun Lin <81014421+klhhhhh@users.noreply.github.com>
Co-authored-by: FoolPlayer <45593998+FoolPlayer@users.noreply.github.com>
Co-authored-by: klhhhhh <1412841649@qq.com>

* [shardformer] blip2 support flash attention and jit operator (#4325)

* Feature/vit support (#4182)

* [shardformer] added tests

* [shardformer] vit test finish and support

* fix attention dropout

* [shardformer] support SAM (#4231)

* 1.support sam 2.add fused qkv for nn.Linear

* update utils support set element in list

* overwrite SamVisionAttention forward to use DropoutForParallelInput

* remove unused code

* [shardformer] support whisper (#4212)

* support whisper

* fix bug in vocabembedding

* support downstream model of whisper

* update readme

* Feature/chatglm (#4240)

* [shardformer] added tests

* [shardformer] vit test finish and support

* [shardformer] chatglm ready

* import chatglm

* [shardformer] add test kit in model zoo for chatglm

* [shardformer] add first version of policy of chatglm

* [shardformer] polish chatglm code

* [shardformer] polish code

* [shardformer] support chatglm without layernorm

* [shardformer] chatglm shard without mlp sharding

* [shardformer] delete some file

* [shardformer] ChatGLM support layernorm sharding

* [shardformer] register without auto policy

* [shardformer] pre-commit check files

* [shardformer] fix chatglm configuration with pre-commit

* [shardformer] added tests

* [shardformer] vit test finish and support

* import chatglm

* [shardformer] add test kit in model zoo for chatglm

* [shardformer] add first version of policy of chatglm

* [shardformer] polish chatglm code

* [shardformer] polish code

* [shardformer] support chatglm without layernorm

* [shardformer] delete some file

* [shardformer] ChatGLM support layernorm sharding

* [shardformer] register without auto policy

* [shardformer] pre-commit check files

* [shardformer] support ChatGLMForConditionalGeneration & add fusedlayernorm for vit

* [shardformer] support Blip2 (#4243)

* support base blip2

* add support for downstream blip2 model

* update readme

* add forward injection

* skip not compatible models test

* fix test for gemini and low_level_zero_plugin

* [shardformer] blip2 support flash attention and jit operator

* [shardformer] blip2 support flash attention and jit operator

* [shardformer] blip2 support flash attention and jit operator

---------
Co-authored-by: Kun Lin <81014421+klhhhhh@users.noreply.github.com>
Co-authored-by: FoolPlayer <45593998+FoolPlayer@users.noreply.github.com>
Co-authored-by: klhhhhh <1412841649@qq.com>

* [shardformer] chatglm support flash attention and jit operator (#4330)

* Feature/vit support (#4182)

* [shardformer] added tests

* [shardformer] vit test finish and support

* fix attention dropout

* [shardformer] support SAM (#4231)

* 1.support sam 2.add fused qkv for nn.Linear

* update utils support set element in list

* overwrite SamVisionAttention forward to use DropoutForParallelInput

* remove unused code

* [shardformer] support whisper (#4212)

* support whisper

* fix bug in vocabembedding

* support downstream model of whisper

* update readme

* Feature/chatglm (#4240)

* [shardformer] added tests

* [shardformer] vit test finish and support

* [shardformer] chatglm ready

* import chatglm

* [shardformer] add test kit in model zoo for chatglm

* [shardformer] add first version of policy of chatglm

* [shardformer] polish chatglm code

* [shardformer] polish code

* [shardformer] support chatglm without layernorm

* [shardformer] chatglm shard without mlp sharding

* [shardformer] delete some file

* [shardformer] ChatGLM support layernorm sharding

* [shardformer] register without auto policy

* [shardformer] pre-commit check files

* [shardformer] fix chatglm configuration with pre-commit

* [shardformer] added tests

* [shardformer] vit test finish and support

* import chatglm

* [shardformer] add test kit in model zoo for chatglm

* [shardformer] add first version of policy of chatglm

* [shardformer] polish chatglm code

* [shardformer] polish code

* [shardformer] support chatglm without layernorm

* [shardformer] delete some file

* [shardformer] ChatGLM support layernorm sharding

* [shardformer] register without auto policy

* [shardformer] pre-commit check files

* [shardformer] support ChatGLMForConditionalGeneration & add fusedlayernorm for vit

* [shardformer] support Blip2 (#4243)

* support base blip2

* add support for downstream blip2 model

* update readme

* add forward injection

* skip not compatible models test

* fix test for gemini and low_level_zero_plugin

* [shardformer] chatglm support flash attention and jit operator

* [shardformer] chatglm support flash attention and jit operator

* [shardformer] chatglm support flash attention and jit operator

* [shardformer] chatglm support flash attention and jit operator

---------
Co-authored-by: Kun Lin <81014421+klhhhhh@users.noreply.github.com>
Co-authored-by: FoolPlayer <45593998+FoolPlayer@users.noreply.github.com>
Co-authored-by: klhhhhh <1412841649@qq.com>

* [shardformer] vit support flash attention and jit operator (#4334)

* Feature/vit support (#4182)

* [shardformer] added tests

* [shardformer] vit test finish and support

* fix attention dropout

* [shardformer] support SAM (#4231)

* 1.support sam 2.add fused qkv for nn.Linear

* update utils support set element in list

* overwrite SamVisionAttention forward to use DropoutForParallelInput

* remove unused code

* [shardformer] support whisper (#4212)

* support whisper

* fix bug in vocabembedding

* support downstream model of whisper

* update readme

* Feature/chatglm (#4240)

* [shardformer] added tests

* [shardformer] vit test finish and support

* [shardformer] chatglm ready

* import chatglm

* [shardformer] add test kit in model zoo for chatglm

* [shardformer] add first version of policy of chatglm

* [shardformer] polish chatglm code

* [shardformer] polish code

* [shardformer] support chatglm without layernorm

* [shardformer] chatglm shard without mlp sharding

* [shardformer] delete some file

* [shardformer] ChatGLM support layernorm sharding

* [shardformer] register without auto policy

* [shardformer] pre-commit check files

* [shardformer] fix chatglm configuration with pre-commit

* [shardformer] added tests

* [shardformer] vit test finish and support

* import chatglm

* [shardformer] add test kit in model zoo for chatglm

* [shardformer] add first version of policy of chatglm

* [shardformer] polish chatglm code

* [shardformer] polish code

* [shardformer] support chatglm without layernorm

* [shardformer] delete some file

* [shardformer] ChatGLM support layernorm sharding

* [shardformer] register without auto policy

* [shardformer] pre-commit check files

* [shardformer] support ChatGLMForConditionalGeneration & add fusedlayernorm for vit

* [shardformer] support Blip2 (#4243)

* support base blip2

* add support for downstream blip2 model

* update readme

* add forward injection

* skip not compatible models test

* fix test for gemini and low_level_zero_plugin

* [shardformer] vit support flash attention and jit operator

* [shardformer] vit support flash attention and jit operator

---------
Co-authored-by: Kun Lin <81014421+klhhhhh@users.noreply.github.com>
Co-authored-by: FoolPlayer <45593998+FoolPlayer@users.noreply.github.com>
Co-authored-by: klhhhhh <1412841649@qq.com>

* [pipeline] merge flash attention branch

* [pipeline] merge flash attention branch

* [pipeline] merge flash attention branch

* [pipeline] fix conflict

* [pipeline] fix conflict

* Merge branch 'feature/pipeline' into feature/pipeline

* Merge branch 'feature/pipeline' into feature/pipeline

* Merge branch 'feature/pipeline' into feature/pipeline

* activate checks

* activate checks

* activate checks

* activate checks

* activate checks

* activate checks

* activate checks

* activate checks

* fix flash attention tests

* gemini ignore whisper

* fix vit

* fix xformers import handle

---------
Co-authored-by: Frank Lee <somerlee.9@gmail.com>
Co-authored-by: Kun Lin <81014421+klhhhhh@users.noreply.github.com>
Co-authored-by: FoolPlayer <45593998+FoolPlayer@users.noreply.github.com>
Co-authored-by: klhhhhh <1412841649@qq.com>
parent a88e9225
@@ -5,7 +5,8 @@ from torch import Tensor, nn
 import colossalai.shardformer.layer as col_nn

-from ..modeling.gpt2 import GPT2PipelineForwards
+from .._utils import getattr_, setattr_
+from ..modeling.gpt2 import GPT2PipelineForwards, get_gpt2_flash_attention_forward
 from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription

 __all__ = [
@@ -33,7 +34,7 @@ class GPT2Policy(Policy):
         return self.model

     def module_policy(self):
-        from transformers.models.gpt2.modeling_gpt2 import GPT2Block, GPT2Model
+        from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2Block, GPT2Model

         policy = {}
@@ -53,42 +54,42 @@ class GPT2Policy(Policy):
                 "attn.split_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
                 "attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
             },
             sub_module_replacement=[
                 SubModuleReplacementDescription(
                     suffix="attn.c_attn",
                     target_module=col_nn.GPT2FusedLinearConv1D_Col,
                     kwargs={
                         "n_fused": 3,
                     },
                 ),
                 SubModuleReplacementDescription(
                     suffix="attn.c_proj",
                     target_module=col_nn.GPT2FusedLinearConv1D_Row,
                 ),
                 SubModuleReplacementDescription(
                     suffix="mlp.c_fc",
                     target_module=col_nn.GPT2FusedLinearConv1D_Col,
                     kwargs={
                         "n_fused": 1,
                     },
                 ),
                 SubModuleReplacementDescription(
                     suffix="mlp.c_proj",
                     target_module=col_nn.GPT2FusedLinearConv1D_Row,
                 ),
                 SubModuleReplacementDescription(
                     suffix="attn.attn_dropout",
                     target_module=col_nn.DropoutForParallelInput,
                 ),
                 SubModuleReplacementDescription(
                     suffix="attn.resid_dropout",
                     target_module=col_nn.DropoutForParallelInput,
                 ),
                 SubModuleReplacementDescription(
                     suffix="mlp.dropout",
                     target_module=col_nn.DropoutForParallelInput,
                 ),
             ])

         # optimization configuration
         if self.shard_config.enable_fused_normalization:
@@ -96,8 +97,8 @@ class GPT2Policy(Policy):
                 suffix="ln_f",
                 target_module=col_nn.FusedLayerNorm,
             ),
                                                         policy=policy,
                                                         target_key=GPT2Model)

             self.append_or_create_submodule_replacement(description=[
                 SubModuleReplacementDescription(
@@ -112,8 +113,13 @@ class GPT2Policy(Policy):
                 target_module=col_nn.FusedLayerNorm,
                 ignore_if_not_exist=True)
         ],
                                                         policy=policy,
                                                         target_key=GPT2Block)
+
+        if self.shard_config.enable_flash_attention:
+            policy[GPT2Attention] = ModulePolicyDescription(method_replacement={
+                'forward': get_gpt2_flash_attention_forward(),
+            })
+
         return policy

     def postprocess(self):
@@ -7,7 +7,7 @@ from torch.nn import Module
 from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D

-from ..modeling.llama import LlamaPipelineForwards
+from ..modeling.llama import LlamaPipelineForwards, get_llama_flash_attention_forward
 from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription

 __all__ = ['LlamaPolicy', 'LlamaForCausalLMPolicy', 'LlamaForSequenceClassificationPolicy']
@@ -31,7 +31,7 @@ class LlamaPolicy(Policy):
         return self.model

     def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
-        from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaModel
+        from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel

         policy = {}
@@ -104,6 +104,11 @@ class LlamaPolicy(Policy):
                                                         policy=policy,
                                                         target_key=LlamaModel)

+        if self.shard_config.enable_flash_attention:
+            policy[LlamaAttention] = ModulePolicyDescription(method_replacement={
+                'forward': get_llama_flash_attention_forward(),
+            })
+
         return policy

     def postprocess(self):
@@ -25,6 +25,8 @@ from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer.layer import FusedLayerNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D

 from .._utils import getattr_, setattr_
+from ..modeling.jit import get_jit_fused_dropout_add_func
+from ..modeling.opt import get_jit_fused_opt_decoder_layer_forward, get_opt_flash_attention_forward
 from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription

 __all__ = [
@@ -114,6 +116,19 @@ class OPTPolicy(Policy):
                                                         policy=policy,
                                                         target_key=OPTDecoderLayer)

+        # use flash attention
+        if self.shard_config.enable_flash_attention:
+            policy[OPTAttention] = ModulePolicyDescription(method_replacement={
+                'forward': get_opt_flash_attention_forward(),
+            })
+
+        # use jit fused operator
+        if self.shard_config.enable_jit_fused:
+            policy[OPTDecoderLayer] = ModulePolicyDescription(method_replacement={
+                'forward': get_jit_fused_opt_decoder_layer_forward(),
+                'dropout_add': get_jit_fused_dropout_add_func(),
+            })
+
         return policy

     def postprocess(self):
@@ -189,13 +204,11 @@ class OPTForCausalLMPolicy(OPTPolicy):
         from transformers.models.opt.modeling_opt import OPTForCausalLM

         policy = super().module_policy()

         if self.shard_config.enable_tensor_parallelism:
             self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription(
                 suffix="lm_head", target_module=Linear1D_Col, kwargs=dict(gather_output=True)),
                                                         policy=policy,
                                                         target_key=OPTForCausalLM)

         if self.pipeline_stage_manager:
             self.set_pipeline_forward(model_cls=OPTForCausalLM,
                                       new_forward=OPTPipelineForwards.opt_for_causal_lm_forward,
@@ -3,7 +3,7 @@ import torch.nn as nn
 import colossalai.shardformer.layer as col_nn

 from .._utils import getattr_, setattr_
-from ..modeling.sam import forward_fn
+from ..modeling.sam import forward_fn, get_sam_flash_attention_forward, get_sam_vision_flash_attention_forward
 from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription

 __all__ = ['SamPolicy', 'SamModelPolicy']
@@ -19,6 +19,7 @@ class SamPolicy(Policy):
     def module_policy(self):
         from transformers.models.sam.modeling_sam import (
+            SamAttention,
             SamFeedForward,
             SamTwoWayAttentionBlock,
             SamTwoWayTransformer,
@@ -196,6 +197,15 @@ class SamPolicy(Policy):
                                                         policy=policy,
                                                         target_key=SamTwoWayTransformer)

+        # use flash attention
+        if self.shard_config.enable_flash_attention:
+            policy[SamAttention] = ModulePolicyDescription(method_replacement={
+                'forward': get_sam_flash_attention_forward(),
+            })
+            policy[SamVisionAttention] = ModulePolicyDescription(method_replacement={
+                'forward': get_sam_vision_flash_attention_forward(),
+            })
+
         return policy

     def postprocess(self):
@@ -14,7 +14,14 @@ from colossalai.shardformer.layer import (
 from colossalai.shardformer.policies.base_policy import ModulePolicyDescription

 from .._utils import getattr_, setattr_
-from ..modeling.t5 import T5PipelineForwards
+from ..modeling.jit import get_jit_fused_dropout_add_func
+from ..modeling.t5 import (
+    T5PipelineForwards,
+    get_jit_fused_T5_layer_ff_forward,
+    get_t5_flash_attention_forward,
+    get_T5_layer_cross_attention_forward,
+    get_T5_layer_self_attention_forward,
+)
 from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription

 __all__ = ["distribute_t5_layers", "T5ModelPolicy", "T5ForConditionalGenerationPolicy", "T5EncoderPolicy"]
@@ -168,6 +175,27 @@ class T5BasePolicy(Policy):
             suffix="final_layer_norm", target_module=FusedRMSNorm),
                                                         policy=policy,
                                                         target_key=T5Stack)

+        # use flash attention
+        if self.shard_config.enable_flash_attention:
+            policy[T5Attention] = ModulePolicyDescription(method_replacement={
+                'forward': get_t5_flash_attention_forward(),
+            })
+
+        # use jit operator
+        if self.shard_config.enable_jit_fused:
+            policy[T5LayerFF] = ModulePolicyDescription(method_replacement={
+                'forward': get_jit_fused_T5_layer_ff_forward(),
+                'dropout_add': get_jit_fused_dropout_add_func(),
+            })
+            policy[T5LayerSelfAttention] = ModulePolicyDescription(method_replacement={
+                'forward': get_T5_layer_self_attention_forward(),
+                'dropout_add': get_jit_fused_dropout_add_func(),
+            })
+            policy[T5LayerCrossAttention] = ModulePolicyDescription(method_replacement={
+                'forward': get_T5_layer_cross_attention_forward(),
+                'dropout_add': get_jit_fused_dropout_add_func(),
+            })
+
         return policy

     def postprocess(self):
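The `'dropout_add'` entries above replace each module's dropout-plus-residual step with a TorchScript-compiled helper. A rough sketch of what such a fused helper looks like (the real one lives in `..modeling.jit`; this version is an illustration under that assumption, not the repository code):

```python
import torch
import torch.nn.functional as F


@torch.jit.script
def dropout_add(x: torch.Tensor, residual: torch.Tensor, prob: float, training: bool) -> torch.Tensor:
    # dropout followed by the residual add, compiled together so the two
    # elementwise ops can be fused instead of being launched separately
    out = F.dropout(x, p=prob, training=training)
    return residual + out


def get_jit_fused_dropout_add_func():
    # returned as a plain callable so a policy can attach it to a module
    # through the 'dropout_add' method_replacement slot shown above
    return dropout_add
```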
@@ -3,11 +3,15 @@ from typing import Callable, Dict, List, Union
 import torch.nn as nn

 import colossalai.shardformer.layer as col_nn
+from colossalai.shardformer.layer import DropoutForReplicatedInput, Linear1D_Col

+from ..modeling.jit import get_jit_fused_dropout_add_func
 from ..modeling.vit import (
     ViTForImageClassification_pipeline_forward,
     ViTForMaskedImageModeling_pipeline_forward,
     ViTModel_pipeline_forward,
+    get_jit_fused_vit_output_forward,
+    get_vit_flash_self_attention_forward,
 )
 from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
@@ -23,7 +27,8 @@ class ViTPolicy(Policy):
         return self.model

     def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
-        from transformers.models.vit.modeling_vit import ViTEmbeddings, ViTLayer, ViTModel
+        from transformers.models.vit.modeling_vit import ViTEmbeddings, ViTLayer, ViTModel, ViTOutput, ViTSelfAttention

         policy = {}
@@ -33,7 +38,7 @@ class ViTPolicy(Policy):
             sub_module_replacement=[
                 SubModuleReplacementDescription(
                     suffix="dropout",
-                    target_module=col_nn.DropoutForReplicatedInput,
+                    target_module=DropoutForReplicatedInput,
                 )
             ])
@@ -83,8 +88,18 @@ class ViTPolicy(Policy):
                 ),
             ])

-        return policy
+        # use flash attention
+        if self.shard_config.enable_flash_attention:
+            policy[ViTSelfAttention] = ModulePolicyDescription(method_replacement={
+                'forward': get_vit_flash_self_attention_forward(),
+            })
+
+        # use jit fused operator
+        if self.shard_config.enable_jit_fused:
+            policy[ViTOutput] = ModulePolicyDescription(method_replacement={
+                'forward': get_jit_fused_vit_output_forward(),
+                'dropout_add': get_jit_fused_dropout_add_func(),
+            })
+
+        return policy

     def new_model_class(self):
@@ -167,7 +182,7 @@ class ViTForImageClassificationPolicy(ViTPolicy):
             ViTForImageClassification:
                 ModulePolicyDescription(sub_module_replacement=[
                     SubModuleReplacementDescription(
-                        suffix="classifier", target_module=col_nn.Linear1D_Col, kwargs=dict(gather_output=True))
+                        suffix="classifier", target_module=Linear1D_Col, kwargs=dict(gather_output=True))
                 ])
         }
         policy.update(new_item)
@@ -3,6 +3,12 @@ import torch.nn as nn
 import colossalai.shardformer.layer as col_nn

 from .._utils import getattr_, setattr_
+from ..modeling.jit import get_jit_fused_dropout_add_func
+from ..modeling.whisper import (
+    get_jit_fused_whisper_decoder_layer_forward,
+    get_jit_fused_whisper_encoder_layer_forward,
+    get_whisper_flash_attention_forward,
+)
 from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription

 __all__ = [
@@ -30,6 +36,7 @@ class WhisperPolicy(Policy):
     def module_policy(self):
         from transformers.models.whisper.modeling_whisper import (
+            WhisperAttention,
             WhisperDecoder,
             WhisperDecoderLayer,
             WhisperEncoder,
@@ -181,6 +188,24 @@ class WhisperPolicy(Policy):
         ],
                                                         policy=policy,
                                                         target_key=WhisperDecoder)

+        # enable flash attention
+        if self.shard_config.enable_flash_attention:
+            policy[WhisperAttention] = ModulePolicyDescription(method_replacement={
+                'forward': get_whisper_flash_attention_forward(),
+            })
+
+        # use jit fused operator
+        if self.shard_config.enable_jit_fused:
+            policy[WhisperEncoderLayer] = ModulePolicyDescription(method_replacement={
+                'forward': get_jit_fused_whisper_encoder_layer_forward(),
+                'dropout_add': get_jit_fused_dropout_add_func(),
+            })
+            policy[WhisperDecoderLayer] = ModulePolicyDescription(method_replacement={
+                'forward': get_jit_fused_whisper_decoder_layer_forward(),
+                'dropout_add': get_jit_fused_dropout_add_func(),
+            })
+
         return policy

     def add_lm_head_policy(self, base_policy):
@@ -26,6 +26,8 @@ class ShardConfig:
     enable_tensor_parallelism: bool = True
     enable_fused_normalization: bool = False
     enable_all_optimization: bool = False
+    enable_flash_attention: bool = False
+    enable_jit_fused: bool = False

     # TODO: add support for tensor parallel
     # pipeline_parallel_size: int
@@ -44,7 +46,6 @@ class ShardConfig:
         else:
             # get the parallel size
             self._tensor_parallel_size = dist.get_world_size(self.tensor_parallel_process_group)
-
         # turn on all optimization if all_optimization is set to True
         if self.enable_all_optimization:
             self._turn_on_all_optimization()
@@ -55,3 +56,5 @@ class ShardConfig:
         """
         # you can add all the optimization flag here
         self.enable_fused_normalization = True
+        self.enable_flash_attention = True
+        self.enable_jit_fused = True
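As the last hunk shows, `_turn_on_all_optimization()` now also switches on the two new flags, so `enable_all_optimization=True` implies flash attention and the JIT-fused operators. A small sketch of the resulting behaviour (import path assumed; tensor parallelism is left off so no process group is needed):

```python
from colossalai.shardformer import ShardConfig  # import path assumed

cfg = ShardConfig(enable_tensor_parallelism=False, enable_all_optimization=True)
# __post_init__ calls _turn_on_all_optimization(), which per this diff sets:
assert cfg.enable_fused_normalization is True
assert cfg.enable_flash_attention is True
assert cfg.enable_jit_fused is True
```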
@@ -18,3 +18,5 @@ SentencePiece
 ninja
 flash_attn>=2.0
 datasets
+ninja
+flash-attn
@@ -20,7 +20,7 @@ def data_gen():
     # token_type_ids = tokenized_input['token_type_ids']
     input_ids = torch.tensor([[101, 7592, 1010, 2026, 3899, 2003, 10140, 102]], dtype=torch.int64)
     token_type_ids = torch.tensor([[0, 0, 0, 0, 0, 0, 0, 0]], dtype=torch.int64)
-    attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64)
+    attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 0]], dtype=torch.int64)
     return dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
@@ -69,19 +69,21 @@ def data_gen_for_mcq():
     # data['labels'] = torch.tensor([0], dtype=torch.int64)
     input_ids = torch.tensor([[[
         101, 1999, 3304, 1010, 10733, 2366, 1999, 5337, 10906, 1010, 2107, 2004, 2012, 1037, 4825, 1010, 2003, 3591,
-        4895, 14540, 6610, 2094, 1012, 102, 2009, 2003, 8828, 2007, 1037, 9292, 1998, 1037, 5442, 1012, 102
+        4895, 14540, 6610, 2094, 1012, 102, 2009, 2003, 8828, 2007, 1037, 9292, 1998, 1037, 5442, 1012, 102, 102
     ],
     [
         101, 1999, 3304, 1010, 10733, 2366, 1999, 5337, 10906, 1010, 2107, 2004, 2012, 1037,
         4825, 1010, 2003, 3591, 4895, 14540, 6610, 2094, 1012, 102, 2009, 2003, 8828, 2096,
-        2218, 1999, 1996, 2192, 1012, 102, 0
+        2218, 1999, 1996, 2192, 1012, 102, 0, 0
     ]]])
     token_type_ids = torch.tensor(
-        [[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
-         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0]]])
+        [[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
+          [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0,
+           0]]])
     attention_mask = torch.tensor(
-        [[[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
-         [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0]]])
+        [[[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
+          [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0,
+           0]]])
     labels = torch.tensor([0], dtype=torch.int64)
     return dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask, labels=labels)
@@ -38,6 +38,7 @@ output_transform_fn = lambda x: x
 loss_fn_blip2_model = lambda x: x.loss

 config = transformers.Blip2Config()
+config.vision_config.patch_size = 14
 config.text_config.num_hidden_layers = 1
 config.qformer_config.num_hidden_layers = 1
 config.vision_config.num_hidden_layers = 1
@@ -16,8 +16,8 @@ def data_gen():
     # tokenized_input = tokenizer(input, return_tensors='pt')
     # input_ids = tokenized_input['input_ids']
     # attention_mask = tokenized_input['attention_mask']
-    input_ids = torch.tensor([[59414, 15, 2670, 35433, 632, 207595]], dtype=torch.int64)
-    attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1]], dtype=torch.int64)
+    input_ids = torch.tensor([[59414, 15, 2670, 35433, 632, 207595, 632, 207595]], dtype=torch.int64)
+    attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64)
     return dict(input_ids=input_ids, attention_mask=attention_mask)
@@ -33,7 +33,7 @@ def data_gen_for_token_classification():
     # token classification data gen
     # `labels` is the type not the token id for token classification, 0 or 1
     data = data_gen()
-    data['labels'] = torch.tensor([[0, 0, 0, 0, 0, 0]], dtype=torch.int64)
+    data['labels'] = torch.tensor([[0, 0, 0, 0, 0, 0, 0, 0]], dtype=torch.int64)
     return data
@@ -53,8 +53,8 @@ def data_gen_for_question_answering():
     # inputs = tokenizer(question, text, return_tensors="pt")
     input_ids = torch.tensor(
-        [[57647, 1620, 23967, 620, 107373, 34, 91514, 620, 107373, 1620, 267, 35378, 48946, 18161]], dtype=torch.int64)
-    attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64)
+        [[57647, 1620, 23967, 620, 107373, 34, 91514, 620, 107373, 1620, 267, 35378, 48946, 18161, 48946, 18161]], dtype=torch.int64)
+    attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64)
     start_positions = torch.tensor([1], dtype=torch.int64)
     end_positions = torch.tensor([10], dtype=torch.int64)
     return dict(input_ids=input_ids,
@@ -6,7 +6,6 @@ from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLM
 from ..registry import ModelAttribute, model_zoo

 # ================================
 # Register single-sentence ChatGLM
 # ================================
from transformers import PretrainedConfig
class ChatGLMConfig(PretrainedConfig):
model_type = "chatglm"
def __init__(self,
num_layers=28,
padded_vocab_size=65024,
hidden_size=4096,
ffn_hidden_size=13696,
kv_channels=128,
num_attention_heads=32,
seq_length=2048,
hidden_dropout=0.0,
attention_dropout=0.0,
layernorm_epsilon=1e-5,
rmsnorm=True,
apply_residual_connection_post_layernorm=False,
post_layer_norm=True,
add_bias_linear=False,
add_qkv_bias=False,
bias_dropout_fusion=True,
multi_query_attention=False,
multi_query_group_num=1,
apply_query_key_layer_scaling=True,
attention_softmax_in_fp32=True,
fp32_residual_connection=False,
quantization_bit=0,
pre_seq_len=None,
prefix_projection=False,
**kwargs):
self.num_layers = num_layers
self.vocab_size = padded_vocab_size
self.padded_vocab_size = padded_vocab_size
self.hidden_size = hidden_size
self.ffn_hidden_size = ffn_hidden_size
self.kv_channels = kv_channels
self.num_attention_heads = num_attention_heads
self.seq_length = seq_length
self.hidden_dropout = hidden_dropout
self.attention_dropout = attention_dropout
self.layernorm_epsilon = layernorm_epsilon
self.rmsnorm = rmsnorm
self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm
self.post_layer_norm = post_layer_norm
self.add_bias_linear = add_bias_linear
self.add_qkv_bias = add_qkv_bias
self.bias_dropout_fusion = bias_dropout_fusion
self.multi_query_attention = multi_query_attention
self.multi_query_group_num = multi_query_group_num
self.apply_query_key_layer_scaling = apply_query_key_layer_scaling
self.attention_softmax_in_fp32 = attention_softmax_in_fp32
self.fp32_residual_connection = fp32_residual_connection
self.quantization_bit = quantization_bit
self.pre_seq_len = pre_seq_len
self.prefix_projection = prefix_projection
super().__init__(**kwargs)
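The configuration class above mirrors the ChatGLM2-6B defaults; for the shardformer model zoo a much smaller instance is enough. A hypothetical example of shrinking it for tests, using only fields defined in the class above:

```python
# hypothetical tiny config for unit tests; every field name comes from ChatGLMConfig above
tiny_cfg = ChatGLMConfig(num_layers=1,
                         padded_vocab_size=1024,
                         hidden_size=64,
                         ffn_hidden_size=128,
                         kv_channels=16,
                         num_attention_heads=4,
                         seq_length=128)
assert tiny_cfg.vocab_size == tiny_cfg.padded_vocab_size == 1024
```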
@@ -18,8 +18,8 @@ def data_gen():
     # tokenized_input = tokenizer(input, return_tensors='pt')
     # input_ids = tokenized_input['input_ids']
     # attention_mask = tokenized_input['attention_mask']
-    input_ids = torch.tensor([[15496, 11, 616, 3290, 318, 13779]], dtype=torch.int64)
-    attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1]], dtype=torch.int64)
+    input_ids = torch.tensor([[15496, 11, 616, 3290, 318, 13779, 318, 13779]], dtype=torch.int64)
+    attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64)
     return dict(input_ids=input_ids, attention_mask=attention_mask)
@@ -46,7 +46,7 @@ def data_gen_for_token_classification():
     # token classification data gen
     # `labels` is the type not the token id for token classification, 0 or 1
     data = data_gen()
-    data['labels'] = torch.tensor([[0, 0, 0, 0, 0, 1]], dtype=torch.int64)
+    data['labels'] = torch.tensor([[0, 0, 0, 0, 0, 0, 0, 1]], dtype=torch.int64)
     return data
@@ -16,8 +16,9 @@ def data_gen_for_encoder_only():
     # config = T5Config(decoder_start_token_id=0)
     # tokenizer = T5Tokenizer.from_pretrained("t5-small")
     # input_ids = tokenizer("translate English to German: The house is wonderful.", return_tensors="pt").input_ids
-    input_ids = torch.Tensor([[13959, 1566, 12, 2968, 10, 37, 629, 19, 1627, 5, 1]]).long()
-    return dict(input_ids=input_ids)
+    input_ids = torch.Tensor([[13959, 1566, 12, 2968, 10, 37, 629, 19, 1627, 5, 1, 12]]).long()
+    attention_mask = torch.Tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0]]).long()
+    return dict(input_ids=input_ids, attention_mask=attention_mask)

 def data_gen_for_conditional_generation():
@@ -25,17 +26,16 @@ def data_gen_for_conditional_generation():
     #
     # labels = tokenizer("Das Haus ist wunderbar.", return_tensors="pt").input_ids
     data = data_gen_for_encoder_only()
-    labels = torch.Tensor([[644, 4598, 229, 19250, 5, 1]]).long()
+    labels = torch.Tensor([[644, 4598, 229, 19250, 5, 1, 644, 4598, 229, 19250, 5, 1]]).long()
     data['labels'] = labels
     return data

 def data_gen_for_t5_model():
     # decoder_inputs_ids is obtained with the following code
-    #
     # decoder_input_ids = model._shift_right(input_ids)
     data = data_gen_for_encoder_only()
-    decoder_input_ids = torch.Tensor([[0, 13959, 1566, 12, 2968, 10, 37, 629, 19, 1627, 5]]).long()
+    decoder_input_ids = torch.Tensor([[0, 13959, 1566, 12, 2968, 10, 37, 629, 19, 1627, 5, 5]]).long()
     data['decoder_input_ids'] = decoder_input_ids
     return data
@@ -76,14 +76,14 @@ model_zoo.register(name='transformers_whisper',
                    loss_fn=loss_fn,
                    model_attribute=ModelAttribute(has_control_flow=True))

-model_zoo.register(name='transformers_whisperForConditionalGeneration',
+model_zoo.register(name='transformers_whisper_for_conditional_generation',
                    model_fn=lambda: transformers.WhisperForConditionalGeneration(config),
                    data_gen_fn=data_gen_for_conditional_generation,
                    output_transform_fn=output_transform_fn,
                    loss_fn=loss_fn_attr,
                    model_attribute=ModelAttribute(has_control_flow=True))

-model_zoo.register(name='transformers_whisperWhisperForAudioClassification',
+model_zoo.register(name='transformers_whisper_for_audio_classification',
                    model_fn=lambda: transformers.WhisperForAudioClassification(config),
                    data_gen_fn=data_gen_for_audio_classification,
                    output_transform_fn=output_transform_fn,
@@ -93,7 +93,7 @@ def check_gemini_plugin(init_method: str = 'none', early_stop: bool = True):
                 'transformers_vit_for_image_classification', 'transformers_chatglm',
                 'transformers_chatglm_for_conditional_generation', 'transformers_blip2',
                 'transformers_blip2_conditional_gerneration', 'transformers_sam', 'transformers_whisper',
-                'transformers_whisperForConditionalGeneration', 'transformers_whisperWhisperForAudioClassification'
+                'transformers_whisper_for_conditional_generation', 'transformers_whisper_for_audio_classification'
         ]:
             continue
@@ -21,7 +21,13 @@ from colossalai.shardformer._utils import getattr_
 from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor

-def build_model(model_fn, enable_fused_normalization=True, enable_tensor_parallelism=True, use_lazy_init: bool = False):
+def build_model(model_fn,
+                enable_fused_normalization=True,
+                enable_tensor_parallelism=True,
+                enable_flash_attention=False,
+                enable_jit_fused=False,
+                use_lazy_init: bool = False):
+    # create new model
     ctx = LazyInitContext() if use_lazy_init else nullcontext()
     with ctx:
         # create new model
@@ -31,7 +37,10 @@ def build_model(model_fn, enable_fused_normalization=True, enable_tensor_paralle
         ctx.materialize(org_model)

     # shard model
     shard_config = ShardConfig(enable_fused_normalization=enable_fused_normalization,
-                               enable_tensor_parallelism=enable_tensor_parallelism)
+                               enable_tensor_parallelism=enable_tensor_parallelism,
+                               enable_flash_attention=enable_flash_attention,
+                               enable_jit_fused=enable_jit_fused)
+    model_copy = copy.deepcopy(org_model)
     shard_former = ShardFormer(shard_config=shard_config)
     sharded_model, shared_params = shard_former.optimize(model_copy)
     return org_model.cuda(), sharded_model.cuda()
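With the two extra switches, the test helper can now build a flash-attention / JIT-fused sharded copy next to the original model. A sketch of a call site follows (the `model_fn` here is made up for illustration; the real tests take it from the model zoo, and the helper still needs a CUDA device plus an initialized process group when tensor parallelism is on):

```python
import transformers

def model_fn():
    # hypothetical tiny GPT-2 stand-in for whatever the model zoo provides
    return transformers.GPT2Model(transformers.GPT2Config(n_layer=2))

org_model, sharded_model = build_model(model_fn,
                                       enable_fused_normalization=True,
                                       enable_tensor_parallelism=True,
                                       enable_flash_attention=True,
                                       enable_jit_fused=True)
```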