Commit 906426cb authored by flybird1111's avatar flybird1111 Committed by Hongxin Liu
Browse files

[Shardformer] Merge flash attention branch to pipeline branch (#4362)



* [shardformer] supported flash attention test dependency (#4158)

* [shardformer] fix flash attention utils test (#4180)

* [shardformer] opt support flash attention (#4163)

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] move to modeling

* [shardformer] move to modeling

* [shardformer] add performance benchmark of shardformer (#4175)

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] benchmark fix

* [shardformer] benchmark fix

* [shardformer] llama support flash attention (#4185)

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] move to modeling

* [shardformer] move to modeling

* [shardformer] llama support flash attention

* [shardformer] llama support flash attention

* [shardformer] Move the import statement for xformer outside the forward function.

* [shardformer] gpt2 support flash attention. (#4191)

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] move to modeling

* [shardformer] move to modeling

* [shardformer] gpt2 support flash attention

* [shardformer] gpt2 support flash attention

* [shardformer] gpt2 support flash attention

* [shardformer] bloom support flash attention (#4188)

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] move to modeling

* [shardformer] move to modeling

* [shardformer] bloom suport flash attention

* [shardformer] add assert to sequence length

* [shardformer] fix

* [shardformer] fix

* [shardformer] fix

* [shardformer] bert support flash attention. (#4206)

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] move to modeling

* [shardformer] move to modeling

* [shardformer] bert support flash attention

* [shardformer] t5 support flash attention. (#4216)

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] move to modeling

* [shardformer] move to modeling

* [shardformer] t5 support flash attention

* [shardformer] t5 support flash attention

* fix typo

* fix typo

* fix typo

* fix typo

* fix typo

* fix typo

* [shardformer] support 'paddedcausal'  type of attention mask in Coloattention. (#4215)

* added padded causal attn mask type for ColoAttention

* [shardformer]t5 flash attention fix (#4239)

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] move to modeling

* [shardformer] move to modeling

* [shardformer] t5 flash attention fix

* [shardformer] update gpt2 to use coloattention. (#4234)

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] move to modeling

* [shardformer] move to modeling

* [shardformer] update gpt2 to use coloattention

* [shardformer] update gpt2 to use coloattention

* [shardformer] update gpt2 to use coloattention

* [shardformer] update gpt2 to use coloattention

* [shardformer] update gpt2

* [shardformer] update opt and llama to use coloattention. (#4226)

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] move to modeling

* [shardformer] move to modeling

* update opt to use coloattention

* [shardformer]update opt to use coloattention

* [shardformer]update opt to use coloattention

* [shardformer]update opt to use coloattention

* [shardformer]update opt to use coloattention

* [shardformer]update opt to use coloattention

* [shardformer]update opt to use coloattention

* [shardformer]update opt

* [shardformer] shardformer support jit fused operator. (#4236)

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] opt support flash attention

* [shardformer] move to modeling

* [shardformer] move to modeling

* [shardformer] bloom support jit fused operator

* [shardformer] bloom support jit fused operator

* [shardformer] bloom support jit fused operator

* [shardformer] t5 support jit fused operator

* [shardformer] t5 support jit fused operator

* [shardformer] t5 support jit fused operator

* [shardformer] add roadmap of flash attention

* [shardformer] add roadmap of flash attention

* [shardformer] add roadmap of flash attention

* [shardformer] add type hint to 'self' param of forward

* [shardformer] merge feature/shardformer-models branch to feature/flash-attention-shardformer branch. (#4290)

* Feature/vit support (#4182)

* [shardformer] added tests

* [shardformer] vit test finish and support

* fix attention dropout

* [shardformer] support SAM (#4231)

* 1.support sam 2.add fused qkv for nn.Linear

* update utils support set element in list

* overtwrite SamVisionAttention foward to use DropoutForParallelInput

* remove unused code

* [shardformer] support whisper (#4212)

* support whisper

* fix bug in vocabembedding

* support downstream model of whisper

* update readme

* Feature/chatglm (#4240)

* [shardformer] added tests

* [shardformer] vit test finish and support

* [shardformer] chatglm ready

* import chatglm

* [shardformer] add test kit in model zoo for chatglm

* [sharformer] add first version of policy of chatglm

* [shardformer] polish chatglm code

* [shardformer] polish code

* [shardformer] support chatglm without layernorm

* [shardformer] chatglm shard without mlp sharding

* [shardformer] delete some file

* [shardformer] ChatGLM support layernorm sharding

* [shardformer] register without auto policy

* [shardformer] pre-commit check files

* [shardformer] fix chatglm configuration with pre-commit

---------
Co-authored-by: default avatarKun Lin <81014421+klhhhhh@users.noreply.github.com>
Co-authored-by: default avatarFoolPlayer <45593998+FoolPlayer@users.noreply.github.com>

* [shardformer] whisper support flash attention (#4301)

* Feature/vit support (#4182)

* [shardformer] added tests

* [shardformer] vit test finish and support

* fix attention dropout

* [shardformer] support SAM (#4231)

* 1.support sam 2.add fused qkv for nn.Linear

* update utils support set element in list

* overtwrite SamVisionAttention foward to use DropoutForParallelInput

* remove unused code

* [shardformer] support whisper (#4212)

* support whisper

* fix bug in vocabembedding

* support downstream model of whisper

* update readme

* Feature/chatglm (#4240)

* [shardformer] added tests

* [shardformer] vit test finish and support

* [shardformer] chatglm ready

* import chatglm

* [shardformer] add test kit in model zoo for chatglm

* [sharformer] add first version of policy of chatglm

* [shardformer] polish chatglm code

* [shardformer] polish code

* [shardformer] support chatglm without layernorm

* [shardformer] chatglm shard without mlp sharding

* [shardformer] delete some file

* [shardformer] ChatGLM support layernorm sharding

* [shardformer] register without auto policy

* [shardformer] pre-commit check files

* [shardformer] fix chatglm configuration with pre-commit

* [shardformer] whisper support flash attention

* [shardformer] whisper support flash attention

* [shardformer]whisper support jit operator

---------
Co-authored-by: default avatarKun Lin <81014421+klhhhhh@users.noreply.github.com>
Co-authored-by: default avatarFoolPlayer <45593998+FoolPlayer@users.noreply.github.com>

* [shardformer] sam support flash attention (#4316)

* Feature/vit support (#4182)

* [shardformer] added tests

* [shardformer] vit test finish and support

* fix attention dropout

* [shardformer] support SAM (#4231)

* 1.support sam 2.add fused qkv for nn.Linear

* update utils support set element in list

* overtwrite SamVisionAttention foward to use DropoutForParallelInput

* remove unused code

* [shardformer] support whisper (#4212)

* support whisper

* fix bug in vocabembedding

* support downstream model of whisper

* update readme

* Feature/chatglm (#4240)

* [shardformer] added tests

* [shardformer] vit test finish and support

* [shardformer] chatglm ready

* import chatglm

* [shardformer] add test kit in model zoo for chatglm

* [sharformer] add first version of policy of chatglm

* [shardformer] polish chatglm code

* [shardformer] polish code

* [shardformer] support chatglm without layernorm

* [shardformer] chatglm shard without mlp sharding

* [shardformer] delete some file

* [shardformer] ChatGLM support layernorm sharding

* [shardformer] register without auto policy

* [shardformer] pre-commit check files

* [shardformer] fix chatglm configuration with pre-commit

* [shardformer] sam support flash attention

---------
Co-authored-by: default avatarKun Lin <81014421+klhhhhh@users.noreply.github.com>
Co-authored-by: default avatarFoolPlayer <45593998+FoolPlayer@users.noreply.github.com>

* [shardformer] merge blip2/chatglm  (#4321)

* Feature/vit support (#4182)

* [shardformer] added tests

* [shardformer] vit test finish and support

* fix attention dropout

* [shardformer] support SAM (#4231)

* 1.support sam 2.add fused qkv for nn.Linear

* update utils support set element in list

* overtwrite SamVisionAttention foward to use DropoutForParallelInput

* remove unused code

* [shardformer] support whisper (#4212)

* support whisper

* fix bug in vocabembedding

* support downstream model of whisper

* update readme

* Feature/chatglm (#4240)

* [shardformer] added tests

* [shardformer] vit test finish and support

* [shardformer] chatglm ready

* import chatglm

* [shardformer] add test kit in model zoo for chatglm

* [sharformer] add first version of policy of chatglm

* [shardformer] polish chatglm code

* [shardformer] polish code

* [shardformer] support chatglm without layernorm

* [shardformer] chatglm shard without mlp sharding

* [shardformer] delete some file

* [shardformer] ChatGLM support layernorm sharding

* [shardformer] register without auto policy

* [shardformer] pre-commit check files

* [shardformer] fix chatglm configuration with pre-commit

* [shardformer] added tests

* [shardformer] vit test finish and support

* import chatglm

* [shardformer] add test kit in model zoo for chatglm

* [sharformer] add first version of policy of chatglm

* [shardformer] polish chatglm code

* [shardformer] polish code

* [shardformer] support chatglm without layernorm

* [shardformer] delete some file

* [shardformer] ChatGLM support layernorm sharding

* [shardformer] register without auto policy

* [shardformer] pre-commit check files

* [shardformer] support ChatGLMForConditionalGeneration & add fusedlayernorm for vit

* [shardformer] support Blip2 (#4243)

* support base blip2

* add support for downstream blip2 model

* update readme

* add forward injection

* skip not compatible models test

* fix test for gemini and low_level_zero_pugin

---------
Co-authored-by: default avatarKun Lin <81014421+klhhhhh@users.noreply.github.com>
Co-authored-by: default avatarFoolPlayer <45593998+FoolPlayer@users.noreply.github.com>
Co-authored-by: default avatarklhhhhh <1412841649@qq.com>

* [shardformer] blip2 support flash attention and jit operator (#4325)

* Feature/vit support (#4182)

* [shardformer] added tests

* [shardformer] vit test finish and support

* fix attention dropout

* [shardformer] support SAM (#4231)

* 1.support sam 2.add fused qkv for nn.Linear

* update utils support set element in list

* overtwrite SamVisionAttention foward to use DropoutForParallelInput

* remove unused code

* [shardformer] support whisper (#4212)

* support whisper

* fix bug in vocabembedding

* support downstream model of whisper

* update readme

* Feature/chatglm (#4240)

* [shardformer] added tests

* [shardformer] vit test finish and support

* [shardformer] chatglm ready

* import chatglm

* [shardformer] add test kit in model zoo for chatglm

* [sharformer] add first version of policy of chatglm

* [shardformer] polish chatglm code

* [shardformer] polish code

* [shardformer] support chatglm without layernorm

* [shardformer] chatglm shard without mlp sharding

* [shardformer] delete some file

* [shardformer] ChatGLM support layernorm sharding

* [shardformer] register without auto policy

* [shardformer] pre-commit check files

* [shardformer] fix chatglm configuration with pre-commit

* [shardformer] added tests

* [shardformer] vit test finish and support

* import chatglm

* [shardformer] add test kit in model zoo for chatglm

* [sharformer] add first version of policy of chatglm

* [shardformer] polish chatglm code

* [shardformer] polish code

* [shardformer] support chatglm without layernorm

* [shardformer] delete some file

* [shardformer] ChatGLM support layernorm sharding

* [shardformer] register without auto policy

* [shardformer] pre-commit check files

* [shardformer] support ChatGLMForConditionalGeneration & add fusedlayernorm for vit

* [shardformer] support Blip2 (#4243)

* support base blip2

* add support for downstream blip2 model

* update readme

* add forward injection

* skip not compatible models test

* fix test for gemini and low_level_zero_pugin

* [shardformer] blip2 support flash attention and jit operator

* [shardformer] blip2 support flash attention and jit operator

* [shardformer] blip2 support flash attention and jit operator

---------
Co-authored-by: default avatarKun Lin <81014421+klhhhhh@users.noreply.github.com>
Co-authored-by: default avatarFoolPlayer <45593998+FoolPlayer@users.noreply.github.com>
Co-authored-by: default avatarklhhhhh <1412841649@qq.com>

* [shardformer] chatglm support flash attention and jit operator (#4330)

* Feature/vit support (#4182)

* [shardformer] added tests

* [shardformer] vit test finish and support

* fix attention dropout

* [shardformer] support SAM (#4231)

* 1.support sam 2.add fused qkv for nn.Linear

* update utils support set element in list

* overtwrite SamVisionAttention foward to use DropoutForParallelInput

* remove unused code

* [shardformer] support whisper (#4212)

* support whisper

* fix bug in vocabembedding

* support downstream model of whisper

* update readme

* Feature/chatglm (#4240)

* [shardformer] added tests

* [shardformer] vit test finish and support

* [shardformer] chatglm ready

* import chatglm

* [shardformer] add test kit in model zoo for chatglm

* [sharformer] add first version of policy of chatglm

* [shardformer] polish chatglm code

* [shardformer] polish code

* [shardformer] support chatglm without layernorm

* [shardformer] chatglm shard without mlp sharding

* [shardformer] delete some file

* [shardformer] ChatGLM support layernorm sharding

* [shardformer] register without auto policy

* [shardformer] pre-commit check files

* [shardformer] fix chatglm configuration with pre-commit

* [shardformer] added tests

* [shardformer] vit test finish and support

* import chatglm

* [shardformer] add test kit in model zoo for chatglm

* [sharformer] add first version of policy of chatglm

* [shardformer] polish chatglm code

* [shardformer] polish code

* [shardformer] support chatglm without layernorm

* [shardformer] delete some file

* [shardformer] ChatGLM support layernorm sharding

* [shardformer] register without auto policy

* [shardformer] pre-commit check files

* [shardformer] support ChatGLMForConditionalGeneration & add fusedlayernorm for vit

* [shardformer] support Blip2 (#4243)

* support base blip2

* add support for downstream blip2 model

* update readme

* add forward injection

* skip not compatible models test

* fix test for gemini and low_level_zero_pugin

* [shardformer] chatglm support flash attention and jit operator

* [shardformer] chatglm support flash attention and jit operator

* [shardformer] chatglm support flash attention and jit operator

* [shardformer] chatglm support flash attention and jit operator

---------
Co-authored-by: default avatarKun Lin <81014421+klhhhhh@users.noreply.github.com>
Co-authored-by: default avatarFoolPlayer <45593998+FoolPlayer@users.noreply.github.com>
Co-authored-by: default avatarklhhhhh <1412841649@qq.com>

* [shardformer] vit support flash attention and jit operator (#4334)

* Feature/vit support (#4182)

* [shardformer] added tests

* [shardformer] vit test finish and support

* fix attention dropout

* [shardformer] support SAM (#4231)

* 1.support sam 2.add fused qkv for nn.Linear

* update utils support set element in list

* overtwrite SamVisionAttention foward to use DropoutForParallelInput

* remove unused code

* [shardformer] support whisper (#4212)

* support whisper

* fix bug in vocabembedding

* support downstream model of whisper

* update readme

* Feature/chatglm (#4240)

* [shardformer] added tests

* [shardformer] vit test finish and support

* [shardformer] chatglm ready

* import chatglm

* [shardformer] add test kit in model zoo for chatglm

* [sharformer] add first version of policy of chatglm

* [shardformer] polish chatglm code

* [shardformer] polish code

* [shardformer] support chatglm without layernorm

* [shardformer] chatglm shard without mlp sharding

* [shardformer] delete some file

* [shardformer] ChatGLM support layernorm sharding

* [shardformer] register without auto policy

* [shardformer] pre-commit check files

* [shardformer] fix chatglm configuration with pre-commit

* [shardformer] added tests

* [shardformer] vit test finish and support

* import chatglm

* [shardformer] add test kit in model zoo for chatglm

* [sharformer] add first version of policy of chatglm

* [shardformer] polish chatglm code

* [shardformer] polish code

* [shardformer] support chatglm without layernorm

* [shardformer] delete some file

* [shardformer] ChatGLM support layernorm sharding

* [shardformer] register without auto policy

* [shardformer] pre-commit check files

* [shardformer] support ChatGLMForConditionalGeneration & add fusedlayernorm for vit

* [shardformer] support Blip2 (#4243)

* support base blip2

* add support for downstream blip2 model

* update readme

* add forward injection

* skip not compatible models test

* fix test for gemini and low_level_zero_pugin

* [shardformer] vit support flash attention and jit operator

* [shardformer] vit support flash attention and jit operator

---------
Co-authored-by: default avatarKun Lin <81014421+klhhhhh@users.noreply.github.com>
Co-authored-by: default avatarFoolPlayer <45593998+FoolPlayer@users.noreply.github.com>
Co-authored-by: default avatarklhhhhh <1412841649@qq.com>

* [pipeline] merge flash attention branch

* [pipeline] merge flash attention branch

* [pipeline] merge flash attention branch

* [pipeline] fix conflict

* [pipeline] fix conflict

* Merge branch 'feature/pipeline' into feature/pipeline

* Merge branch 'feature/pipeline' into feature/pipeline

* Merge branch 'feature/pipeline' into feature/pipeline

* activate checks

* activate checks

* activate checks

* activate checks

* activate checks

* activate checks

* activate checks

* activate checks

* fix flash attention tests

* gemini ignore whisper

* fix vit

* fix xformers import handle

---------
Co-authored-by: default avatarFrank Lee <somerlee.9@gmail.com>
Co-authored-by: default avatarKun Lin <81014421+klhhhhh@users.noreply.github.com>
Co-authored-by: default avatarFoolPlayer <45593998+FoolPlayer@users.noreply.github.com>
Co-authored-by: default avatarklhhhhh <1412841649@qq.com>
parent a88e9225
...@@ -5,7 +5,8 @@ from torch import Tensor, nn ...@@ -5,7 +5,8 @@ from torch import Tensor, nn
import colossalai.shardformer.layer as col_nn import colossalai.shardformer.layer as col_nn
from ..modeling.gpt2 import GPT2PipelineForwards from .._utils import getattr_, setattr_
from ..modeling.gpt2 import GPT2PipelineForwards, get_gpt2_flash_attention_forward
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
__all__ = [ __all__ = [
...@@ -33,7 +34,7 @@ class GPT2Policy(Policy): ...@@ -33,7 +34,7 @@ class GPT2Policy(Policy):
return self.model return self.model
def module_policy(self): def module_policy(self):
from transformers.models.gpt2.modeling_gpt2 import GPT2Block, GPT2Model from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2Block, GPT2Model
policy = {} policy = {}
...@@ -53,42 +54,42 @@ class GPT2Policy(Policy): ...@@ -53,42 +54,42 @@ class GPT2Policy(Policy):
"attn.split_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, "attn.split_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
"attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size, "attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
}, },
sub_module_replacement=[ sub_module_replacement=[
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="attn.c_attn", suffix="attn.c_attn",
target_module=col_nn.GPT2FusedLinearConv1D_Col, target_module=col_nn.GPT2FusedLinearConv1D_Col,
kwargs={ kwargs={
"n_fused": 3, "n_fused": 3,
}, },
), ),
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="attn.c_proj", suffix="attn.c_proj",
target_module=col_nn.GPT2FusedLinearConv1D_Row, target_module=col_nn.GPT2FusedLinearConv1D_Row,
), ),
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="mlp.c_fc", suffix="mlp.c_fc",
target_module=col_nn.GPT2FusedLinearConv1D_Col, target_module=col_nn.GPT2FusedLinearConv1D_Col,
kwargs={ kwargs={
"n_fused": 1, "n_fused": 1,
}, },
), ),
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="mlp.c_proj", suffix="mlp.c_proj",
target_module=col_nn.GPT2FusedLinearConv1D_Row, target_module=col_nn.GPT2FusedLinearConv1D_Row,
), ),
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="attn.attn_dropout", suffix="attn.attn_dropout",
target_module=col_nn.DropoutForParallelInput, target_module=col_nn.DropoutForParallelInput,
), ),
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="attn.resid_dropout", suffix="attn.resid_dropout",
target_module=col_nn.DropoutForParallelInput, target_module=col_nn.DropoutForParallelInput,
), ),
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="mlp.dropout", suffix="mlp.dropout",
target_module=col_nn.DropoutForParallelInput, target_module=col_nn.DropoutForParallelInput,
), ),
]) ])
# optimization configuration # optimization configuration
if self.shard_config.enable_fused_normalization: if self.shard_config.enable_fused_normalization:
...@@ -96,8 +97,8 @@ class GPT2Policy(Policy): ...@@ -96,8 +97,8 @@ class GPT2Policy(Policy):
suffix="ln_f", suffix="ln_f",
target_module=col_nn.FusedLayerNorm, target_module=col_nn.FusedLayerNorm,
), ),
policy=policy, policy=policy,
target_key=GPT2Model) target_key=GPT2Model)
self.append_or_create_submodule_replacement(description=[ self.append_or_create_submodule_replacement(description=[
SubModuleReplacementDescription( SubModuleReplacementDescription(
...@@ -112,8 +113,13 @@ class GPT2Policy(Policy): ...@@ -112,8 +113,13 @@ class GPT2Policy(Policy):
target_module=col_nn.FusedLayerNorm, target_module=col_nn.FusedLayerNorm,
ignore_if_not_exist=True) ignore_if_not_exist=True)
], ],
policy=policy, policy=policy,
target_key=GPT2Block) target_key=GPT2Block)
if self.shard_config.enable_flash_attention:
policy[GPT2Attention] = ModulePolicyDescription(method_replacement={
'forward': get_gpt2_flash_attention_forward(),
})
return policy return policy
def postprocess(self): def postprocess(self):
......
...@@ -7,7 +7,7 @@ from torch.nn import Module ...@@ -7,7 +7,7 @@ from torch.nn import Module
from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D
from ..modeling.llama import LlamaPipelineForwards from ..modeling.llama import LlamaPipelineForwards, get_llama_flash_attention_forward
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
__all__ = ['LlamaPolicy', 'LlamaForCausalLMPolicy', 'LlamaForSequenceClassificationPolicy'] __all__ = ['LlamaPolicy', 'LlamaForCausalLMPolicy', 'LlamaForSequenceClassificationPolicy']
...@@ -31,7 +31,7 @@ class LlamaPolicy(Policy): ...@@ -31,7 +31,7 @@ class LlamaPolicy(Policy):
return self.model return self.model
def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaModel from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel
policy = {} policy = {}
...@@ -104,6 +104,11 @@ class LlamaPolicy(Policy): ...@@ -104,6 +104,11 @@ class LlamaPolicy(Policy):
policy=policy, policy=policy,
target_key=LlamaModel) target_key=LlamaModel)
if self.shard_config.enable_flash_attention:
policy[LlamaAttention] = ModulePolicyDescription(method_replacement={
'forward': get_llama_flash_attention_forward(),
})
return policy return policy
def postprocess(self): def postprocess(self):
......
...@@ -25,6 +25,8 @@ from colossalai.pipeline.stage_manager import PipelineStageManager ...@@ -25,6 +25,8 @@ from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.layer import FusedLayerNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D from colossalai.shardformer.layer import FusedLayerNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D
from .._utils import getattr_, setattr_ from .._utils import getattr_, setattr_
from ..modeling.jit import get_jit_fused_dropout_add_func
from ..modeling.opt import get_jit_fused_opt_decoder_layer_forward, get_opt_flash_attention_forward
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
__all__ = [ __all__ = [
...@@ -114,6 +116,19 @@ class OPTPolicy(Policy): ...@@ -114,6 +116,19 @@ class OPTPolicy(Policy):
policy=policy, policy=policy,
target_key=OPTDecoderLayer) target_key=OPTDecoderLayer)
# use flash attention
if self.shard_config.enable_flash_attention:
policy[OPTAttention] = ModulePolicyDescription(method_replacement={
'forward': get_opt_flash_attention_forward(),
})
# use jit fused operator
if self.shard_config.enable_jit_fused:
policy[OPTDecoderLayer] = ModulePolicyDescription(method_replacement={
'forward': get_jit_fused_opt_decoder_layer_forward(),
'dropout_add': get_jit_fused_dropout_add_func(),
})
return policy return policy
def postprocess(self): def postprocess(self):
...@@ -189,13 +204,11 @@ class OPTForCausalLMPolicy(OPTPolicy): ...@@ -189,13 +204,11 @@ class OPTForCausalLMPolicy(OPTPolicy):
from transformers.models.opt.modeling_opt import OPTForCausalLM from transformers.models.opt.modeling_opt import OPTForCausalLM
policy = super().module_policy() policy = super().module_policy()
if self.shard_config.enable_tensor_parallelism: if self.shard_config.enable_tensor_parallelism:
self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription( self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription(
suffix="lm_head", target_module=Linear1D_Col, kwargs=dict(gather_output=True)), suffix="lm_head", target_module=Linear1D_Col, kwargs=dict(gather_output=True)),
policy=policy, policy=policy,
target_key=OPTForCausalLM) target_key=OPTForCausalLM)
if self.pipeline_stage_manager: if self.pipeline_stage_manager:
self.set_pipeline_forward(model_cls=OPTForCausalLM, self.set_pipeline_forward(model_cls=OPTForCausalLM,
new_forward=OPTPipelineForwards.opt_for_causal_lm_forward, new_forward=OPTPipelineForwards.opt_for_causal_lm_forward,
......
...@@ -3,7 +3,7 @@ import torch.nn as nn ...@@ -3,7 +3,7 @@ import torch.nn as nn
import colossalai.shardformer.layer as col_nn import colossalai.shardformer.layer as col_nn
from .._utils import getattr_, setattr_ from .._utils import getattr_, setattr_
from ..modeling.sam import forward_fn from ..modeling.sam import forward_fn, get_sam_flash_attention_forward, get_sam_vision_flash_attention_forward
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
__all__ = ['SamPolicy', 'SamModelPolicy'] __all__ = ['SamPolicy', 'SamModelPolicy']
...@@ -19,6 +19,7 @@ class SamPolicy(Policy): ...@@ -19,6 +19,7 @@ class SamPolicy(Policy):
def module_policy(self): def module_policy(self):
from transformers.models.sam.modeling_sam import ( from transformers.models.sam.modeling_sam import (
SamAttention,
SamFeedForward, SamFeedForward,
SamTwoWayAttentionBlock, SamTwoWayAttentionBlock,
SamTwoWayTransformer, SamTwoWayTransformer,
...@@ -196,6 +197,15 @@ class SamPolicy(Policy): ...@@ -196,6 +197,15 @@ class SamPolicy(Policy):
policy=policy, policy=policy,
target_key=SamTwoWayTransformer) target_key=SamTwoWayTransformer)
# use flash attention
if self.shard_config.enable_flash_attention:
policy[SamAttention] = ModulePolicyDescription(method_replacement={
'forward': get_sam_flash_attention_forward(),
})
policy[SamVisionAttention] = ModulePolicyDescription(method_replacement={
'forward': get_sam_vision_flash_attention_forward(),
})
return policy return policy
def postprocess(self): def postprocess(self):
......
...@@ -14,7 +14,14 @@ from colossalai.shardformer.layer import ( ...@@ -14,7 +14,14 @@ from colossalai.shardformer.layer import (
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription from colossalai.shardformer.policies.base_policy import ModulePolicyDescription
from .._utils import getattr_, setattr_ from .._utils import getattr_, setattr_
from ..modeling.t5 import T5PipelineForwards from ..modeling.jit import get_jit_fused_dropout_add_func
from ..modeling.t5 import (
T5PipelineForwards,
get_jit_fused_T5_layer_ff_forward,
get_t5_flash_attention_forward,
get_T5_layer_cross_attention_forward,
get_T5_layer_self_attention_forward,
)
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
__all__ = ["distribute_t5_layers", "T5ModelPolicy", "T5ForConditionalGenerationPolicy", "T5EncoderPolicy"] __all__ = ["distribute_t5_layers", "T5ModelPolicy", "T5ForConditionalGenerationPolicy", "T5EncoderPolicy"]
...@@ -168,6 +175,27 @@ class T5BasePolicy(Policy): ...@@ -168,6 +175,27 @@ class T5BasePolicy(Policy):
suffix="final_layer_norm", target_module=FusedRMSNorm), suffix="final_layer_norm", target_module=FusedRMSNorm),
policy=policy, policy=policy,
target_key=T5Stack) target_key=T5Stack)
# use flash attention
if self.shard_config.enable_flash_attention:
policy[T5Attention] = ModulePolicyDescription(method_replacement={
'forward': get_t5_flash_attention_forward(),
})
# use jit operator
if self.shard_config.enable_jit_fused:
policy[T5LayerFF] = ModulePolicyDescription(method_replacement={
'forward': get_jit_fused_T5_layer_ff_forward(),
'dropout_add': get_jit_fused_dropout_add_func(),
})
policy[T5LayerSelfAttention] = ModulePolicyDescription(method_replacement={
'forward': get_T5_layer_self_attention_forward(),
'dropout_add': get_jit_fused_dropout_add_func(),
})
policy[T5LayerCrossAttention] = ModulePolicyDescription(method_replacement={
'forward': get_T5_layer_cross_attention_forward(),
'dropout_add': get_jit_fused_dropout_add_func(),
})
return policy return policy
def postprocess(self): def postprocess(self):
......
...@@ -3,11 +3,15 @@ from typing import Callable, Dict, List, Union ...@@ -3,11 +3,15 @@ from typing import Callable, Dict, List, Union
import torch.nn as nn import torch.nn as nn
import colossalai.shardformer.layer as col_nn import colossalai.shardformer.layer as col_nn
from colossalai.shardformer.layer import DropoutForReplicatedInput, Linear1D_Col
from ..modeling.jit import get_jit_fused_dropout_add_func
from ..modeling.vit import ( from ..modeling.vit import (
ViTForImageClassification_pipeline_forward, ViTForImageClassification_pipeline_forward,
ViTForMaskedImageModeling_pipeline_forward, ViTForMaskedImageModeling_pipeline_forward,
ViTModel_pipeline_forward, ViTModel_pipeline_forward,
get_jit_fused_vit_output_forward,
get_vit_flash_self_attention_forward,
) )
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
...@@ -23,7 +27,8 @@ class ViTPolicy(Policy): ...@@ -23,7 +27,8 @@ class ViTPolicy(Policy):
return self.model return self.model
def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
from transformers.models.vit.modeling_vit import ViTEmbeddings, ViTLayer, ViTModel
from transformers.models.vit.modeling_vit import ViTEmbeddings, ViTLayer, ViTModel, ViTOutput, ViTSelfAttention
policy = {} policy = {}
...@@ -33,7 +38,7 @@ class ViTPolicy(Policy): ...@@ -33,7 +38,7 @@ class ViTPolicy(Policy):
sub_module_replacement=[ sub_module_replacement=[
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="dropout", suffix="dropout",
target_module=col_nn.DropoutForReplicatedInput, target_module=DropoutForReplicatedInput,
) )
]) ])
...@@ -83,8 +88,18 @@ class ViTPolicy(Policy): ...@@ -83,8 +88,18 @@ class ViTPolicy(Policy):
), ),
]) ])
return policy # use flash attention
if self.shard_config.enable_flash_attention:
policy[ViTSelfAttention] = ModulePolicyDescription(method_replacement={
'forward': get_vit_flash_self_attention_forward(),
})
# use jit fused operator
if self.shard_config.enable_jit_fused:
policy[ViTOutput] = ModulePolicyDescription(method_replacement={
'forward': get_jit_fused_vit_output_forward(),
'dropout_add': get_jit_fused_dropout_add_func(),
})
return policy return policy
def new_model_class(self): def new_model_class(self):
...@@ -167,7 +182,7 @@ class ViTForImageClassificationPolicy(ViTPolicy): ...@@ -167,7 +182,7 @@ class ViTForImageClassificationPolicy(ViTPolicy):
ViTForImageClassification: ViTForImageClassification:
ModulePolicyDescription(sub_module_replacement=[ ModulePolicyDescription(sub_module_replacement=[
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="classifier", target_module=col_nn.Linear1D_Col, kwargs=dict(gather_output=True)) suffix="classifier", target_module=Linear1D_Col, kwargs=dict(gather_output=True))
]) ])
} }
policy.update(new_item) policy.update(new_item)
......
...@@ -3,6 +3,12 @@ import torch.nn as nn ...@@ -3,6 +3,12 @@ import torch.nn as nn
import colossalai.shardformer.layer as col_nn import colossalai.shardformer.layer as col_nn
from .._utils import getattr_, setattr_ from .._utils import getattr_, setattr_
from ..modeling.jit import get_jit_fused_dropout_add_func
from ..modeling.whisper import (
get_jit_fused_whisper_decoder_layer_forward,
get_jit_fused_whisper_encoder_layer_forward,
get_whisper_flash_attention_forward,
)
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
__all__ = [ __all__ = [
...@@ -30,6 +36,7 @@ class WhisperPolicy(Policy): ...@@ -30,6 +36,7 @@ class WhisperPolicy(Policy):
def module_policy(self): def module_policy(self):
from transformers.models.whisper.modeling_whisper import ( from transformers.models.whisper.modeling_whisper import (
WhisperAttention,
WhisperDecoder, WhisperDecoder,
WhisperDecoderLayer, WhisperDecoderLayer,
WhisperEncoder, WhisperEncoder,
...@@ -181,6 +188,24 @@ class WhisperPolicy(Policy): ...@@ -181,6 +188,24 @@ class WhisperPolicy(Policy):
], ],
policy=policy, policy=policy,
target_key=WhisperDecoder) target_key=WhisperDecoder)
# enable flash attention
if self.shard_config.enable_flash_attention:
policy[WhisperAttention] = ModulePolicyDescription(method_replacement={
'forward': get_whisper_flash_attention_forward(),
})
# use jit fused operator
if self.shard_config.enable_jit_fused:
policy[WhisperEncoderLayer] = ModulePolicyDescription(method_replacement={
'forward': get_jit_fused_whisper_encoder_layer_forward(),
'dropout_add': get_jit_fused_dropout_add_func(),
})
policy[WhisperDecoderLayer] = ModulePolicyDescription(method_replacement={
'forward': get_jit_fused_whisper_decoder_layer_forward(),
'dropout_add': get_jit_fused_dropout_add_func(),
})
return policy return policy
def add_lm_head_policy(self, base_policy): def add_lm_head_policy(self, base_policy):
......
...@@ -26,6 +26,8 @@ class ShardConfig: ...@@ -26,6 +26,8 @@ class ShardConfig:
enable_tensor_parallelism: bool = True enable_tensor_parallelism: bool = True
enable_fused_normalization: bool = False enable_fused_normalization: bool = False
enable_all_optimization: bool = False enable_all_optimization: bool = False
enable_flash_attention: bool = False
enable_jit_fused: bool = False
# TODO: add support for tensor parallel # TODO: add support for tensor parallel
# pipeline_parallel_size: int # pipeline_parallel_size: int
...@@ -44,7 +46,6 @@ class ShardConfig: ...@@ -44,7 +46,6 @@ class ShardConfig:
else: else:
# get the parallel size # get the parallel size
self._tensor_parallel_size = dist.get_world_size(self.tensor_parallel_process_group) self._tensor_parallel_size = dist.get_world_size(self.tensor_parallel_process_group)
# turn on all optimization if all_optimization is set to True # turn on all optimization if all_optimization is set to True
if self.enable_all_optimization: if self.enable_all_optimization:
self._turn_on_all_optimization() self._turn_on_all_optimization()
...@@ -55,3 +56,5 @@ class ShardConfig: ...@@ -55,3 +56,5 @@ class ShardConfig:
""" """
# you can add all the optimization flag here # you can add all the optimization flag here
self.enable_fused_normalization = True self.enable_fused_normalization = True
self.enable_flash_attention = True
self.enable_jit_fused = True
...@@ -18,3 +18,5 @@ SentencePiece ...@@ -18,3 +18,5 @@ SentencePiece
ninja ninja
flash_attn>=2.0 flash_attn>=2.0
datasets datasets
ninja
flash-attn
...@@ -20,7 +20,7 @@ def data_gen(): ...@@ -20,7 +20,7 @@ def data_gen():
# token_type_ids = tokenized_input['token_type_ids'] # token_type_ids = tokenized_input['token_type_ids']
input_ids = torch.tensor([[101, 7592, 1010, 2026, 3899, 2003, 10140, 102]], dtype=torch.int64) input_ids = torch.tensor([[101, 7592, 1010, 2026, 3899, 2003, 10140, 102]], dtype=torch.int64)
token_type_ids = torch.tensor([[0, 0, 0, 0, 0, 0, 0, 0]], dtype=torch.int64) token_type_ids = torch.tensor([[0, 0, 0, 0, 0, 0, 0, 0]], dtype=torch.int64)
attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64) attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 0]], dtype=torch.int64)
return dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask) return dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
...@@ -69,19 +69,21 @@ def data_gen_for_mcq(): ...@@ -69,19 +69,21 @@ def data_gen_for_mcq():
# data['labels'] = torch.tensor([0], dtype=torch.int64) # data['labels'] = torch.tensor([0], dtype=torch.int64)
input_ids = torch.tensor([[[ input_ids = torch.tensor([[[
101, 1999, 3304, 1010, 10733, 2366, 1999, 5337, 10906, 1010, 2107, 2004, 2012, 1037, 4825, 1010, 2003, 3591, 101, 1999, 3304, 1010, 10733, 2366, 1999, 5337, 10906, 1010, 2107, 2004, 2012, 1037, 4825, 1010, 2003, 3591,
4895, 14540, 6610, 2094, 1012, 102, 2009, 2003, 8828, 2007, 1037, 9292, 1998, 1037, 5442, 1012, 102 4895, 14540, 6610, 2094, 1012, 102, 2009, 2003, 8828, 2007, 1037, 9292, 1998, 1037, 5442, 1012, 102, 102
], ],
[ [
101, 1999, 3304, 1010, 10733, 2366, 1999, 5337, 10906, 1010, 2107, 2004, 2012, 1037, 101, 1999, 3304, 1010, 10733, 2366, 1999, 5337, 10906, 1010, 2107, 2004, 2012, 1037,
4825, 1010, 2003, 3591, 4895, 14540, 6610, 2094, 1012, 102, 2009, 2003, 8828, 2096, 4825, 1010, 2003, 3591, 4895, 14540, 6610, 2094, 1012, 102, 2009, 2003, 8828, 2096,
2218, 1999, 1996, 2192, 1012, 102, 0 2218, 1999, 1996, 2192, 1012, 102, 0, 0
]]]) ]]])
token_type_ids = torch.tensor( token_type_ids = torch.tensor(
[[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0]]]) [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0,
0]]])
attention_mask = torch.tensor( attention_mask = torch.tensor(
[[[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [[[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0]]]) [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0,
0]]])
labels = torch.tensor([0], dtype=torch.int64) labels = torch.tensor([0], dtype=torch.int64)
return dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask, labels=labels) return dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask, labels=labels)
......
...@@ -38,6 +38,7 @@ output_transform_fn = lambda x: x ...@@ -38,6 +38,7 @@ output_transform_fn = lambda x: x
loss_fn_blip2_model = lambda x: x.loss loss_fn_blip2_model = lambda x: x.loss
config = transformers.Blip2Config() config = transformers.Blip2Config()
config.vision_config.patch_size = 14
config.text_config.num_hidden_layers = 1 config.text_config.num_hidden_layers = 1
config.qformer_config.num_hidden_layers = 1 config.qformer_config.num_hidden_layers = 1
config.vision_config.num_hidden_layers = 1 config.vision_config.num_hidden_layers = 1
......
...@@ -16,8 +16,8 @@ def data_gen(): ...@@ -16,8 +16,8 @@ def data_gen():
# tokenized_input = tokenizer(input, return_tensors='pt') # tokenized_input = tokenizer(input, return_tensors='pt')
# input_ids = tokenized_input['input_ids'] # input_ids = tokenized_input['input_ids']
# attention_mask = tokenized_input['attention_mask'] # attention_mask = tokenized_input['attention_mask']
input_ids = torch.tensor([[59414, 15, 2670, 35433, 632, 207595]], dtype=torch.int64) input_ids = torch.tensor([[59414, 15, 2670, 35433, 632, 207595, 632, 207595]], dtype=torch.int64)
attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1]], dtype=torch.int64) attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64)
return dict(input_ids=input_ids, attention_mask=attention_mask) return dict(input_ids=input_ids, attention_mask=attention_mask)
...@@ -33,7 +33,7 @@ def data_gen_for_token_classification(): ...@@ -33,7 +33,7 @@ def data_gen_for_token_classification():
# token classification data gen # token classification data gen
# `labels` is the type not the token id for token classification, 0 or 1 # `labels` is the type not the token id for token classification, 0 or 1
data = data_gen() data = data_gen()
data['labels'] = torch.tensor([[0, 0, 0, 0, 0, 0]], dtype=torch.int64) data['labels'] = torch.tensor([[0, 0, 0, 0, 0, 0, 0, 0]], dtype=torch.int64)
return data return data
...@@ -53,8 +53,8 @@ def data_gen_for_question_answering(): ...@@ -53,8 +53,8 @@ def data_gen_for_question_answering():
# inputs = tokenizer(question, text, return_tensors="pt") # inputs = tokenizer(question, text, return_tensors="pt")
input_ids = torch.tensor( input_ids = torch.tensor(
[[57647, 1620, 23967, 620, 107373, 34, 91514, 620, 107373, 1620, 267, 35378, 48946, 18161]], dtype=torch.int64) [[57647, 1620, 23967, 620, 107373, 34, 91514, 620, 107373, 1620, 267, 35378, 48946, 18161, 48946, 18161]], dtype=torch.int64)
attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64) attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64)
start_positions = torch.tensor([1], dtype=torch.int64) start_positions = torch.tensor([1], dtype=torch.int64)
end_positions = torch.tensor([10], dtype=torch.int64) end_positions = torch.tensor([10], dtype=torch.int64)
return dict(input_ids=input_ids, return dict(input_ids=input_ids,
......
...@@ -6,7 +6,6 @@ from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLM ...@@ -6,7 +6,6 @@ from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLM
from ..registry import ModelAttribute, model_zoo from ..registry import ModelAttribute, model_zoo
# ================================ # ================================
# Register single-sentence ChatGLM # Register single-sentence ChatGLM
# ================================ # ================================
......
from transformers import PretrainedConfig
class ChatGLMConfig(PretrainedConfig):
model_type = "chatglm"
def __init__(self,
num_layers=28,
padded_vocab_size=65024,
hidden_size=4096,
ffn_hidden_size=13696,
kv_channels=128,
num_attention_heads=32,
seq_length=2048,
hidden_dropout=0.0,
attention_dropout=0.0,
layernorm_epsilon=1e-5,
rmsnorm=True,
apply_residual_connection_post_layernorm=False,
post_layer_norm=True,
add_bias_linear=False,
add_qkv_bias=False,
bias_dropout_fusion=True,
multi_query_attention=False,
multi_query_group_num=1,
apply_query_key_layer_scaling=True,
attention_softmax_in_fp32=True,
fp32_residual_connection=False,
quantization_bit=0,
pre_seq_len=None,
prefix_projection=False,
**kwargs):
self.num_layers = num_layers
self.vocab_size = padded_vocab_size
self.padded_vocab_size = padded_vocab_size
self.hidden_size = hidden_size
self.ffn_hidden_size = ffn_hidden_size
self.kv_channels = kv_channels
self.num_attention_heads = num_attention_heads
self.seq_length = seq_length
self.hidden_dropout = hidden_dropout
self.attention_dropout = attention_dropout
self.layernorm_epsilon = layernorm_epsilon
self.rmsnorm = rmsnorm
self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm
self.post_layer_norm = post_layer_norm
self.add_bias_linear = add_bias_linear
self.add_qkv_bias = add_qkv_bias
self.bias_dropout_fusion = bias_dropout_fusion
self.multi_query_attention = multi_query_attention
self.multi_query_group_num = multi_query_group_num
self.apply_query_key_layer_scaling = apply_query_key_layer_scaling
self.attention_softmax_in_fp32 = attention_softmax_in_fp32
self.fp32_residual_connection = fp32_residual_connection
self.quantization_bit = quantization_bit
self.pre_seq_len = pre_seq_len
self.prefix_projection = prefix_projection
super().__init__(**kwargs)
"""
The ChatGLM2-6B License
1. Definitions
“Licensor” means the ChatGLM2-6B Model Team that distributes its Software.
“Software” means the ChatGLM2-6B model parameters made available under this license.
2. License Grant
Subject to the terms and conditions of this License, the Licensor hereby grants to you a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable, royalty-free copyright license to use the Software solely for your non-commercial research purposes.
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
3. Restriction
You will not use, copy, modify, merge, publish, distribute, reproduce, or create derivative works of the Software, in whole or in part, for any commercial, military, or illegal purposes.
You will not use the Software for any act that may undermine China's national security and national unity, harm the public interest of society, or infringe upon the rights and interests of human beings.
4. Disclaimer
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
5. Limitation of Liability
EXCEPT TO THE EXTENT PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL THEORY, WHETHER BASED IN TORT, NEGLIGENCE, CONTRACT, LIABILITY, OR OTHERWISE WILL ANY LICENSOR BE LIABLE TO YOU FOR ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES, OR ANY OTHER COMMERCIAL LOSSES, EVEN IF THE LICENSOR HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES.
6. Dispute Resolution
This license shall be governed and construed in accordance with the laws of People’s Republic of China. Any dispute arising from or in connection with this License shall be submitted to Haidian District People's Court in Beijing.
Note that the license is subject to update to a more comprehensive version. For any questions related to the license and copyright, please contact us at glm-130b@googlegroups.com.
"""
""" PyTorch ChatGLM model. """
import copy
import math
import re
import sys
import warnings
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss, LayerNorm
from torch.nn.utils import skip_init
from transformers.generation.logits_process import LogitsProcessor
from transformers.generation.utils import GenerationConfig, LogitsProcessorList, ModelOutput, StoppingCriteriaList
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import logging
from .configuration_chatglm import ChatGLMConfig
# flags required to enable jit fusion kernels
if sys.platform != "darwin":
torch._C._jit_set_profiling_mode(False)
torch._C._jit_set_profiling_executor(False)
torch._C._jit_override_can_fuse_on_cpu(True)
torch._C._jit_override_can_fuse_on_gpu(True)
logger = logging.get_logger(__name__)
_CHECKPOINT_FOR_DOC = "THUDM/ChatGLM2-6B"
_CONFIG_FOR_DOC = "ChatGLM6BConfig"
CHATGLM_6B_PRETRAINED_MODEL_ARCHIVE_LIST = [
"THUDM/chatglm2-6b",
# See all ChatGLM models at https://huggingface.co/models?filter=chatglm
]
def default_init(cls, *args, **kwargs):
return cls(*args, **kwargs)
class InvalidScoreLogitsProcessor(LogitsProcessor):
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
if torch.isnan(scores).any() or torch.isinf(scores).any():
scores.zero_()
scores[..., 5] = 5e4
return scores
class PrefixEncoder(torch.nn.Module):
"""
The torch.nn model to encode the prefix
Input shape: (batch-size, prefix-length)
Output shape: (batch-size, prefix-length, 2*layers*hidden)
"""
def __init__(self, config: ChatGLMConfig):
super().__init__()
self.prefix_projection = config.prefix_projection
if self.prefix_projection:
# Use a two-layer MLP to encode the prefix
kv_size = (config.num_layers * config.kv_channels * config.multi_query_group_num * 2)
self.embedding = torch.nn.Embedding(config.pre_seq_len, kv_size)
self.trans = torch.nn.Sequential(
torch.nn.Linear(kv_size, config.hidden_size),
torch.nn.Tanh(),
torch.nn.Linear(config.hidden_size, kv_size),
)
else:
self.embedding = torch.nn.Embedding(
config.pre_seq_len,
config.num_layers * config.kv_channels * config.multi_query_group_num * 2,
)
def forward(self, prefix: torch.Tensor):
if self.prefix_projection:
prefix_tokens = self.embedding(prefix)
past_key_values = self.trans(prefix_tokens)
else:
past_key_values = self.embedding(prefix)
return past_key_values
def split_tensor_along_last_dim(
tensor: torch.Tensor,
num_partitions: int,
contiguous_split_chunks: bool = False,
) -> List[torch.Tensor]:
"""Split a tensor along its last dimension.
Arguments:
tensor: input tensor.
num_partitions: number of partitions to split the tensor
contiguous_split_chunks: If True, make each chunk contiguous
in memory.
Returns:
A list of Tensors
"""
# Get the size and dimension.
last_dim = tensor.dim() - 1
last_dim_size = tensor.size()[last_dim] // num_partitions
# Split.
tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
# Note: torch.split does not create contiguous tensors by default.
if contiguous_split_chunks:
return tuple(chunk.contiguous() for chunk in tensor_list)
return tensor_list
class RotaryEmbedding(nn.Module):
def __init__(self, dim, original_impl=False, device=None, dtype=None):
super().__init__()
inv_freq = 1.0 / (10000**(torch.arange(0, dim, 2, device=device).to(dtype=dtype) / dim))
self.register_buffer("inv_freq", inv_freq)
self.dim = dim
self.original_impl = original_impl
def forward_impl(
self,
seq_len: int,
n_elem: int,
dtype: torch.dtype,
device: torch.device,
base: int = 10000,
):
"""Enhanced Transformer with Rotary Position Embedding.
Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/
transformers/rope/__init__.py. MIT License:
https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license.
"""
# $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$
theta = 1.0 / (base**(torch.arange(0, n_elem, 2, dtype=dtype, device=device) / n_elem))
# Create position indexes `[0, 1, ..., seq_len - 1]`
seq_idx = torch.arange(seq_len, dtype=dtype, device=device)
# Calculate the product of position index and $\theta_i$
idx_theta = torch.outer(seq_idx, theta).float()
cache = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1)
# this is to mimic the behaviour of complex32, else we will get different results
if dtype in (torch.float16, torch.bfloat16, torch.int8):
cache = cache.bfloat16() if dtype == torch.bfloat16 else cache.half()
return cache
def forward(self, max_seq_len, offset=0):
return self.forward_impl(
max_seq_len,
self.dim,
dtype=self.inv_freq.dtype,
device=self.inv_freq.device,
)
@torch.jit.script
def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:
# x: [sq, b, np, hn]
sq, b, np, hn = x.size(0), x.size(1), x.size(2), x.size(3)
rot_dim = rope_cache.shape[-2] * 2
x, x_pass = x[..., :rot_dim], x[..., rot_dim:]
# truncate to support variable sizes
rope_cache = rope_cache[:sq]
xshaped = x.reshape(sq, -1, np, rot_dim // 2, 2)
rope_cache = rope_cache.view(sq, -1, 1, xshaped.size(3), 2)
x_out2 = torch.stack(
[
xshaped[..., 0] * rope_cache[..., 0] - xshaped[..., 1] * rope_cache[..., 1],
xshaped[..., 1] * rope_cache[..., 0] + xshaped[..., 0] * rope_cache[..., 1],
],
-1,
)
x_out2 = x_out2.flatten(3)
return torch.cat((x_out2, x_pass), dim=-1)
class RMSNorm(torch.nn.Module):
def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None, **kwargs):
super().__init__()
self.weight = torch.nn.Parameter(torch.empty(normalized_shape, device=device, dtype=dtype))
self.eps = eps
def forward(self, hidden_states: torch.Tensor):
input_dtype = hidden_states.dtype
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
return (self.weight * hidden_states).to(input_dtype)
class CoreAttention(torch.nn.Module):
def __init__(self, config: ChatGLMConfig, layer_number):
super(CoreAttention, self).__init__()
self.apply_query_key_layer_scaling = config.apply_query_key_layer_scaling
self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32
if self.apply_query_key_layer_scaling:
self.attention_softmax_in_fp32 = True
self.layer_number = max(1, layer_number)
projection_size = config.kv_channels * config.num_attention_heads
# Per attention head and per partition values.
self.hidden_size_per_partition = projection_size
self.hidden_size_per_attention_head = (projection_size // config.num_attention_heads)
self.num_attention_heads_per_partition = config.num_attention_heads
coeff = None
self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
if self.apply_query_key_layer_scaling:
coeff = self.layer_number
self.norm_factor *= coeff
self.coeff = coeff
self.attention_dropout = torch.nn.Dropout(config.attention_dropout)
def forward(self, query_layer, key_layer, value_layer, attention_mask):
pytorch_major_version = int(torch.__version__.split(".")[0])
if pytorch_major_version >= 2:
query_layer, key_layer, value_layer = [k.permute(1, 2, 0, 3) for k in [query_layer, key_layer, value_layer]]
if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]:
context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer,
key_layer,
value_layer,
is_causal=True)
else:
if attention_mask is not None:
attention_mask = ~attention_mask
context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer,
attention_mask)
context_layer = context_layer.permute(2, 0, 1, 3)
new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
context_layer = context_layer.reshape(*new_context_layer_shape)
else:
# Raw attention scores
# [b, np, sq, sk]
output_size = (
query_layer.size(1),
query_layer.size(2),
query_layer.size(0),
key_layer.size(0),
)
# [sq, b, np, hn] -> [sq, b * np, hn]
query_layer = query_layer.view(output_size[2], output_size[0] * output_size[1], -1)
# [sk, b, np, hn] -> [sk, b * np, hn]
key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1)
# preallocting input tensor: [b * np, sq, sk]
matmul_input_buffer = torch.empty(
output_size[0] * output_size[1],
output_size[2],
output_size[3],
dtype=query_layer.dtype,
device=query_layer.device,
)
# Raw attention scores. [b * np, sq, sk]
matmul_result = torch.baddbmm(
matmul_input_buffer,
query_layer.transpose(0, 1), # [b * np, sq, hn]
key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk]
beta=0.0,
alpha=(1.0 / self.norm_factor),
)
# change view to [b, np, sq, sk]
attention_scores = matmul_result.view(*output_size)
# ===========================
# Attention probs and dropout
# ===========================
# attention scores and attention mask [b, np, sq, sk]
if self.attention_softmax_in_fp32:
attention_scores = attention_scores.float()
if self.coeff is not None:
attention_scores = attention_scores * self.coeff
if (attention_mask is None and attention_scores.shape[2] == attention_scores.shape[3]):
attention_mask = torch.ones(
output_size[0],
1,
output_size[2],
output_size[3],
device=attention_scores.device,
dtype=torch.bool,
)
attention_mask.tril_()
attention_mask = ~attention_mask
if attention_mask is not None:
attention_scores = attention_scores.masked_fill(attention_mask, float("-inf"))
attention_probs = F.softmax(attention_scores, dim=-1)
attention_probs = attention_probs.type_as(value_layer)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
attention_probs = self.attention_dropout(attention_probs)
# =========================
# Context layer. [sq, b, hp]
# =========================
# value_layer -> context layer.
# [sk, b, np, hn] --> [b, np, sq, hn]
# context layer shape: [b, np, sq, hn]
output_size = (
value_layer.size(1),
value_layer.size(2),
query_layer.size(0),
value_layer.size(3),
)
# change view [sk, b * np, hn]
value_layer = value_layer.view(value_layer.size(0), output_size[0] * output_size[1], -1)
# change view [b * np, sq, sk]
attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1)
# matmul: [b * np, sq, hn]
context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))
# change view [b, np, sq, hn]
context_layer = context_layer.view(*output_size)
# [b, np, sq, hn] --> [sq, b, np, hn]
context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
# [sq, b, np, hn] --> [sq, b, hp]
new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
context_layer = context_layer.view(*new_context_layer_shape)
return context_layer
class SelfAttention(torch.nn.Module):
"""Parallel self-attention layer abstract class.
Self-attention layer takes input with size [s, b, h]
and returns output of the same size.
"""
def __init__(self, config: ChatGLMConfig, layer_number, device=None):
super(SelfAttention, self).__init__()
self.layer_number = max(1, layer_number)
self.projection_size = config.kv_channels * config.num_attention_heads
# Per attention head and per partition values.
self.hidden_size_per_attention_head = (self.projection_size // config.num_attention_heads)
self.num_attention_heads_per_partition = config.num_attention_heads
self.multi_query_attention = config.multi_query_attention
self.qkv_hidden_size = 3 * self.projection_size
if self.multi_query_attention:
self.num_multi_query_groups_per_partition = config.multi_query_group_num
self.qkv_hidden_size = (self.projection_size +
2 * self.hidden_size_per_attention_head * config.multi_query_group_num)
self.query_key_value = nn.Linear(
config.hidden_size,
self.qkv_hidden_size,
bias=config.add_bias_linear or config.add_qkv_bias,
device=device,
**_config_to_kwargs(config),
)
self.core_attention = CoreAttention(config, self.layer_number)
# Output.
self.dense = nn.Linear(
self.projection_size,
config.hidden_size,
bias=config.add_bias_linear,
device=device,
**_config_to_kwargs(config),
)
def _allocate_memory(self, inference_max_sequence_len, batch_size, device=None, dtype=None):
if self.multi_query_attention:
num_attention_heads = self.num_multi_query_groups_per_partition
else:
num_attention_heads = self.num_attention_heads_per_partition
return torch.empty(
inference_max_sequence_len,
batch_size,
num_attention_heads,
self.hidden_size_per_attention_head,
dtype=dtype,
device=device,
)
def forward(
self,
hidden_states,
attention_mask,
rotary_pos_emb,
kv_cache=None,
use_cache=True,
):
# hidden_states: [sq, b, h]
# =================================================
# Pre-allocate memory for key-values for inference.
# =================================================
# =====================
# Query, Key, and Value
# =====================
# Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
mixed_x_layer = self.query_key_value(hidden_states)
if self.multi_query_attention:
(query_layer, key_layer, value_layer) = mixed_x_layer.split(
[
self.num_attention_heads_per_partition * self.hidden_size_per_attention_head,
self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head,
self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head,
],
dim=-1,
)
query_layer = query_layer.view(query_layer.size()[:-1] + (
self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head,
))
key_layer = key_layer.view(key_layer.size()[:-1] + (
self.num_multi_query_groups_per_partition,
self.hidden_size_per_attention_head,
))
value_layer = value_layer.view(value_layer.size()[:-1] + (
self.num_multi_query_groups_per_partition,
self.hidden_size_per_attention_head,
))
else:
new_tensor_shape = mixed_x_layer.size()[:-1] + (
self.num_attention_heads_per_partition,
3 * self.hidden_size_per_attention_head,
)
mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
# [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
(query_layer, key_layer, value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3)
# apply relative positional encoding (rotary embedding)
if rotary_pos_emb is not None:
query_layer = apply_rotary_pos_emb(query_layer, rotary_pos_emb)
key_layer = apply_rotary_pos_emb(key_layer, rotary_pos_emb)
# adjust key and value for inference
if kv_cache is not None:
cache_k, cache_v = kv_cache
key_layer = torch.cat((cache_k, key_layer), dim=0)
value_layer = torch.cat((cache_v, value_layer), dim=0)
if use_cache:
kv_cache = (key_layer, value_layer)
else:
kv_cache = None
if self.multi_query_attention:
key_layer = key_layer.unsqueeze(-2)
key_layer = key_layer.expand(
-1,
-1,
-1,
self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition,
-1,
)
key_layer = key_layer.contiguous().view(key_layer.size()[:2] + (
self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head,
))
value_layer = value_layer.unsqueeze(-2)
value_layer = value_layer.expand(
-1,
-1,
-1,
self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition,
-1,
)
value_layer = value_layer.contiguous().view(value_layer.size()[:2] + (
self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head,
))
# ==================================
# core attention computation
# ==================================
context_layer = self.core_attention(query_layer, key_layer, value_layer, attention_mask)
# =================
# Output. [sq, b, h]
# =================
output = self.dense(context_layer)
return output, kv_cache
def _config_to_kwargs(args):
common_kwargs = {
"dtype": args.torch_dtype,
}
return common_kwargs
class MLP(torch.nn.Module):
"""MLP.
MLP will take the input with h hidden state, project it to 4*h
hidden dimension, perform nonlinear transformation, and project the
state back into h hidden dimension.
"""
def __init__(self, config: ChatGLMConfig, device=None):
super(MLP, self).__init__()
self.add_bias = config.add_bias_linear
# Project to 4h. If using swiglu double the output width, see https://arxiv.org/pdf/2002.05202.pdf
self.dense_h_to_4h = nn.Linear(
config.hidden_size,
config.ffn_hidden_size * 2,
bias=self.add_bias,
device=device,
**_config_to_kwargs(config),
)
def swiglu(x):
x = torch.chunk(x, 2, dim=-1)
return F.silu(x[0]) * x[1]
self.activation_func = swiglu
# Project back to h.
self.dense_4h_to_h = nn.Linear(
config.ffn_hidden_size,
config.hidden_size,
bias=self.add_bias,
device=device,
**_config_to_kwargs(config),
)
def forward(self, hidden_states):
# [s, b, 4hp]
intermediate_parallel = self.dense_h_to_4h(hidden_states)
intermediate_parallel = self.activation_func(intermediate_parallel)
# [s, b, h]
output = self.dense_4h_to_h(intermediate_parallel)
return output
class GLMBlock(torch.nn.Module):
"""A single transformer layer.
Transformer layer takes input with size [s, b, h] and returns an
output of the same size.
"""
def __init__(self, config: ChatGLMConfig, layer_number, device=None):
super(GLMBlock, self).__init__()
self.layer_number = layer_number
self.apply_residual_connection_post_layernorm = (config.apply_residual_connection_post_layernorm)
self.fp32_residual_connection = config.fp32_residual_connection
LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm
# Layernorm on the input data.
self.input_layernorm = LayerNormFunc(
config.hidden_size,
eps=config.layernorm_epsilon,
device=device,
dtype=config.torch_dtype,
)
# Self attention.
self.self_attention = SelfAttention(config, layer_number, device=device)
self.hidden_dropout = config.hidden_dropout
# Layernorm on the attention output
self.post_attention_layernorm = LayerNormFunc(
config.hidden_size,
eps=config.layernorm_epsilon,
device=device,
dtype=config.torch_dtype,
)
# MLP
self.mlp = MLP(config, device=device)
def forward(
self,
hidden_states,
attention_mask,
rotary_pos_emb,
kv_cache=None,
use_cache=True,
):
# hidden_states: [s, b, h]
# Layer norm at the beginning of the transformer layer.
layernorm_output = self.input_layernorm(hidden_states)
# Self attention.
attention_output, kv_cache = self.self_attention(
layernorm_output,
attention_mask,
rotary_pos_emb,
kv_cache=kv_cache,
use_cache=use_cache,
)
# Residual connection.
if self.apply_residual_connection_post_layernorm:
residual = layernorm_output
else:
residual = hidden_states
layernorm_input = torch.nn.functional.dropout(attention_output, p=self.hidden_dropout, training=self.training)
layernorm_input = residual + layernorm_input
# Layer norm post the self attention.
layernorm_output = self.post_attention_layernorm(layernorm_input)
# MLP.
mlp_output = self.mlp(layernorm_output)
# Second residual connection.
if self.apply_residual_connection_post_layernorm:
residual = layernorm_output
else:
residual = layernorm_input
output = torch.nn.functional.dropout(mlp_output, p=self.hidden_dropout, training=self.training)
output = residual + output
return output, kv_cache
class GLMTransformer(torch.nn.Module):
"""Transformer class."""
def __init__(self, config: ChatGLMConfig, device=None):
super(GLMTransformer, self).__init__()
self.fp32_residual_connection = config.fp32_residual_connection
self.post_layer_norm = config.post_layer_norm
# Number of layers.
self.num_layers = config.num_layers
# Transformer layers.
def build_layer(layer_number):
return GLMBlock(config, layer_number, device=device)
self.layers = torch.nn.ModuleList([build_layer(i + 1) for i in range(self.num_layers)])
if self.post_layer_norm:
LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm
# Final layer norm before output.
self.final_layernorm = LayerNormFunc(
config.hidden_size,
eps=config.layernorm_epsilon,
device=device,
dtype=config.torch_dtype,
)
self.gradient_checkpointing = False
def _get_layer(self, layer_number):
return self.layers[layer_number]
def forward(
self,
hidden_states,
attention_mask,
rotary_pos_emb,
kv_caches=None,
use_cache: Optional[bool] = True,
output_hidden_states: Optional[bool] = False,
):
if not kv_caches:
kv_caches = [None for _ in range(self.num_layers)]
presents = () if use_cache else None
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
use_cache = False
all_self_attentions = None
all_hidden_states = () if output_hidden_states else None
for index in range(self.num_layers):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
layer = self._get_layer(index)
if self.gradient_checkpointing and self.training:
layer_ret = torch.utils.checkpoint.checkpoint(
layer,
hidden_states,
attention_mask,
rotary_pos_emb,
kv_caches[index],
use_cache,
)
else:
layer_ret = layer(
hidden_states,
attention_mask,
rotary_pos_emb,
kv_cache=kv_caches[index],
use_cache=use_cache,
)
hidden_states, kv_cache = layer_ret
if use_cache:
presents = presents + (kv_cache,)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
# Final layer norm.
if self.post_layer_norm:
hidden_states = self.final_layernorm(hidden_states)
return hidden_states, presents, all_hidden_states, all_self_attentions
class ChatGLMPreTrainedModel(PreTrainedModel):
"""
An abstract class to handle weights initialization and
a simple interface for downloading and loading pretrained models.
"""
is_parallelizable = False
supports_gradient_checkpointing = True
config_class = ChatGLMConfig
base_model_prefix = "transformer"
_no_split_modules = ["GLMBlock"]
def _init_weights(self, module: nn.Module):
"""Initialize the weights."""
return
def get_masks(self, input_ids, past_key_values, padding_mask=None):
batch_size, seq_length = input_ids.shape
full_attention_mask = torch.ones(batch_size, seq_length, seq_length, device=input_ids.device)
full_attention_mask.tril_()
past_length = 0
if past_key_values:
past_length = past_key_values[0][0].shape[0]
if past_length:
full_attention_mask = torch.cat(
(
torch.ones(batch_size, seq_length, past_length, device=input_ids.device),
full_attention_mask,
),
dim=-1,
)
if padding_mask is not None:
full_attention_mask = full_attention_mask * padding_mask.unsqueeze(1)
if not past_length and padding_mask is not None:
full_attention_mask -= padding_mask.unsqueeze(-1) - 1
full_attention_mask = (full_attention_mask < 0.5).bool()
full_attention_mask.unsqueeze_(1)
return full_attention_mask
def get_position_ids(self, input_ids, device):
batch_size, seq_length = input_ids.shape
position_ids = (torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1))
return position_ids
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, GLMTransformer):
module.gradient_checkpointing = value
class Embedding(torch.nn.Module):
"""Language model embeddings."""
def __init__(self, config: ChatGLMConfig, device=None):
super(Embedding, self).__init__()
self.hidden_size = config.hidden_size
# Word embeddings (parallel).
self.word_embeddings = nn.Embedding(
config.padded_vocab_size,
self.hidden_size,
dtype=config.torch_dtype,
device=device,
)
self.fp32_residual_connection = config.fp32_residual_connection
def forward(self, input_ids):
# Embeddings.
words_embeddings = self.word_embeddings(input_ids)
embeddings = words_embeddings
# Data format change to avoid explicit tranposes : [b s h] --> [s b h].
embeddings = embeddings.transpose(0, 1).contiguous()
# If the input flag for fp32 residual connection is set, convert for float.
if self.fp32_residual_connection:
embeddings = embeddings.float()
return embeddings
class ChatGLMModel(ChatGLMPreTrainedModel):
def __init__(self, config: ChatGLMConfig, device=None, empty_init=True):
super().__init__(config)
if empty_init:
init_method = skip_init
else:
init_method = default_init
init_kwargs = {}
if device is not None:
init_kwargs["device"] = device
self.embedding = init_method(Embedding, config, **init_kwargs)
self.num_layers = config.num_layers
self.multi_query_group_num = config.multi_query_group_num
self.kv_channels = config.kv_channels
# Rotary positional embeddings
self.seq_length = config.seq_length
rotary_dim = (config.hidden_size //
config.num_attention_heads if config.kv_channels is None else config.kv_channels)
self.rotary_pos_emb = RotaryEmbedding(
rotary_dim // 2,
original_impl=config.original_rope,
device=device,
dtype=config.torch_dtype,
)
self.encoder = init_method(GLMTransformer, config, **init_kwargs)
self.output_layer = init_method(
nn.Linear,
config.hidden_size,
config.padded_vocab_size,
bias=False,
dtype=config.torch_dtype,
**init_kwargs,
)
self.pre_seq_len = config.pre_seq_len
self.prefix_projection = config.prefix_projection
if self.pre_seq_len is not None:
for param in self.parameters():
param.requires_grad = False
self.prefix_tokens = torch.arange(self.pre_seq_len).long()
self.prefix_encoder = PrefixEncoder(config)
self.dropout = torch.nn.Dropout(0.1)
def get_input_embeddings(self):
return self.embedding.word_embeddings
def get_prompt(self, batch_size, device, dtype=torch.half):
prefix_tokens = (self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(device))
past_key_values = self.prefix_encoder(prefix_tokens).type(dtype)
past_key_values = past_key_values.view(
batch_size,
self.pre_seq_len,
self.num_layers * 2,
self.multi_query_group_num,
self.kv_channels,
)
# seq_len, b, nh, hidden_size
past_key_values = self.dropout(past_key_values)
past_key_values = past_key_values.permute([2, 1, 0, 3, 4]).split(2)
return past_key_values
def forward(
self,
input_ids,
position_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.BoolTensor] = None,
full_attention_mask: Optional[torch.BoolTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
inputs_embeds: Optional[torch.Tensor] = None,
use_cache: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
output_hidden_states = (output_hidden_states
if output_hidden_states is not None else self.config.output_hidden_states)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = (return_dict if return_dict is not None else self.config.use_return_dict)
batch_size, seq_length = input_ids.shape
if inputs_embeds is None:
inputs_embeds = self.embedding(input_ids)
if self.pre_seq_len is not None:
if past_key_values is None:
past_key_values = self.get_prompt(
batch_size=batch_size,
device=input_ids.device,
dtype=inputs_embeds.dtype,
)
if attention_mask is not None:
attention_mask = torch.cat(
[
attention_mask.new_ones((batch_size, self.pre_seq_len)),
attention_mask,
],
dim=-1,
)
if full_attention_mask is None:
if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1):
full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask)
# Rotary positional embeddings
rotary_pos_emb = self.rotary_pos_emb(self.seq_length)
if position_ids is not None:
rotary_pos_emb = rotary_pos_emb[position_ids]
else:
rotary_pos_emb = rotary_pos_emb[None, :seq_length]
rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous()
# Run encoder.
hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder(
inputs_embeds,
full_attention_mask,
rotary_pos_emb=rotary_pos_emb,
kv_caches=past_key_values,
use_cache=use_cache,
output_hidden_states=output_hidden_states,
)
if not return_dict:
return tuple(v for v in [
hidden_states,
presents,
all_hidden_states,
all_self_attentions,
] if v is not None)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=presents,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
)
def quantize(self, weight_bit_width: int):
from .quantization import quantize
quantize(self.encoder, weight_bit_width)
return self
class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
def __init__(self, config: ChatGLMConfig, empty_init=True, device=None):
super().__init__(config)
self.max_sequence_length = config.max_length
self.transformer = ChatGLMModel(config, empty_init=empty_init, device=device)
self.config = config
self.quantized = False
if self.config.quantization_bit:
self.quantize(self.config.quantization_bit, empty_init=True)
def _update_model_kwargs_for_generation(
self,
outputs: ModelOutput,
model_kwargs: Dict[str, Any],
is_encoder_decoder: bool = False,
standardize_cache_format: bool = False,
) -> Dict[str, Any]:
# update past_key_values
model_kwargs["past_key_values"] = self._extract_past_from_model_output(
outputs, standardize_cache_format=standardize_cache_format)
# update attention mask
if "attention_mask" in model_kwargs:
attention_mask = model_kwargs["attention_mask"]
model_kwargs["attention_mask"] = torch.cat(
[attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))],
dim=-1,
)
# update position ids
if "position_ids" in model_kwargs:
position_ids = model_kwargs["position_ids"]
new_position_id = position_ids[..., -1:].clone()
new_position_id += 1
model_kwargs["position_ids"] = torch.cat([position_ids, new_position_id], dim=-1)
model_kwargs["is_first_forward"] = False
return model_kwargs
def prepare_inputs_for_generation(
self,
input_ids: torch.LongTensor,
past_key_values: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
is_first_forward: bool = True,
**kwargs,
) -> dict:
# only last token for input_ids if past is not None
if position_ids is None:
position_ids = self.get_position_ids(input_ids, device=input_ids.device)
if not is_first_forward:
position_ids = position_ids[..., -1:]
input_ids = input_ids[:, -1:]
return {
"input_ids": input_ids,
"past_key_values": past_key_values,
"position_ids": position_ids,
"attention_mask": attention_mask,
"return_last_logit": True,
}
def forward(
self,
input_ids: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
return_last_logit: Optional[bool] = False,
):
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = (return_dict if return_dict is not None else self.config.use_return_dict)
transformer_outputs = self.transformer(
input_ids=input_ids,
position_ids=position_ids,
attention_mask=attention_mask,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = transformer_outputs[0]
if return_last_logit:
hidden_states = hidden_states[-1:]
lm_logits = self.transformer.output_layer(hidden_states)
lm_logits = lm_logits.transpose(0, 1).contiguous()
loss = None
if labels is not None:
lm_logits = lm_logits.to(torch.float32)
# Shift so that tokens < n predict n
shift_logits = lm_logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss(ignore_index=-100)
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
lm_logits = lm_logits.to(hidden_states.dtype)
loss = loss.to(hidden_states.dtype)
if not return_dict:
output = (lm_logits,) + transformer_outputs[1:]
return ((loss,) + output) if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=lm_logits,
past_key_values=transformer_outputs.past_key_values,
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
)
@staticmethod
def _reorder_cache(past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...],
beam_idx: torch.LongTensor) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]:
"""
This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
[`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
beam_idx at every generation step.
Output shares the same memory storage as `past`.
"""
return tuple((
layer_past[0].index_select(1, beam_idx.to(layer_past[0].device)),
layer_past[1].index_select(1, beam_idx.to(layer_past[1].device)),
) for layer_past in past)
def process_response(self, response):
response = response.strip()
response = response.replace("[[训练时间]]", "2023年")
return response
def build_inputs(self, tokenizer, query: str, history: List[Tuple[str, str]] = None):
prompt = tokenizer.build_prompt(query, history=history)
inputs = tokenizer([prompt], return_tensors="pt")
inputs = inputs.to(self.device)
return inputs
def build_stream_inputs(self, tokenizer, query: str, history: List[Tuple[str, str]] = None):
if history:
prompt = "\n\n[Round {}]\n\n问:{}\n\n答:".format(len(history) + 1, query)
input_ids = tokenizer.encode(prompt, add_special_tokens=False)
input_ids = input_ids[1:]
inputs = tokenizer.batch_encode_plus([(input_ids, None)], return_tensors="pt", add_special_tokens=False)
else:
prompt = "[Round {}]\n\n问:{}\n\n答:".format(len(history) + 1, query)
inputs = tokenizer([prompt], return_tensors="pt")
inputs = inputs.to(self.device)
return inputs
@torch.no_grad()
def chat(
self,
tokenizer,
query: str,
history: List[Tuple[str, str]] = None,
max_length: int = 8192,
num_beams=1,
do_sample=True,
top_p=0.8,
temperature=0.8,
logits_processor=None,
**kwargs,
):
if history is None:
history = []
if logits_processor is None:
logits_processor = LogitsProcessorList()
logits_processor.append(InvalidScoreLogitsProcessor())
gen_kwargs = {
"max_length": max_length,
"num_beams": num_beams,
"do_sample": do_sample,
"top_p": top_p,
"temperature": temperature,
"logits_processor": logits_processor,
**kwargs,
}
inputs = self.build_inputs(tokenizer, query, history=history)
outputs = self.generate(**inputs, **gen_kwargs)
outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):]
response = tokenizer.decode(outputs)
response = self.process_response(response)
history = history + [(query, response)]
return response, history
@torch.no_grad()
def stream_chat(
self,
tokenizer,
query: str,
history: List[Tuple[str, str]] = None,
past_key_values=None,
max_length: int = 8192,
do_sample=True,
top_p=0.8,
temperature=0.8,
logits_processor=None,
return_past_key_values=False,
**kwargs,
):
if history is None:
history = []
if logits_processor is None:
logits_processor = LogitsProcessorList()
logits_processor.append(InvalidScoreLogitsProcessor())
gen_kwargs = {
"max_length": max_length,
"do_sample": do_sample,
"top_p": top_p,
"temperature": temperature,
"logits_processor": logits_processor,
**kwargs,
}
if past_key_values is None and not return_past_key_values:
inputs = self.build_inputs(tokenizer, query, history=history)
else:
inputs = self.build_stream_inputs(tokenizer, query, history=history)
if past_key_values is not None:
past_length = past_key_values[0][0].shape[0]
if self.transformer.pre_seq_len is not None:
past_length -= self.transformer.pre_seq_len
inputs.position_ids += past_length
attention_mask = inputs.attention_mask
attention_mask = torch.cat((attention_mask.new_ones(1, past_length), attention_mask), dim=1)
inputs["attention_mask"] = attention_mask
for outputs in self.stream_generate(
**inputs,
past_key_values=past_key_values,
return_past_key_values=return_past_key_values,
**gen_kwargs,
):
if return_past_key_values:
outputs, past_key_values = outputs
outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):]
response = tokenizer.decode(outputs)
if response and response[-1] != "�":
response = self.process_response(response)
new_history = history + [(query, response)]
if return_past_key_values:
yield response, new_history, past_key_values
else:
yield response, new_history
@torch.no_grad()
def stream_generate(
self,
input_ids,
generation_config: Optional[GenerationConfig] = None,
logits_processor: Optional[LogitsProcessorList] = None,
stopping_criteria: Optional[StoppingCriteriaList] = None,
prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
return_past_key_values=False,
**kwargs,
):
batch_size, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1]
if generation_config is None:
generation_config = self.generation_config
generation_config = copy.deepcopy(generation_config)
model_kwargs = generation_config.update(**kwargs)
bos_token_id, eos_token_id = (
generation_config.bos_token_id,
generation_config.eos_token_id,
)
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
has_default_max_length = (kwargs.get("max_length") is None and generation_config.max_length is not None)
if has_default_max_length and generation_config.max_new_tokens is None:
warnings.warn(
f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. "
"This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we"
" recommend using `max_new_tokens` to control the maximum length of the generation.",
UserWarning,
)
elif generation_config.max_new_tokens is not None:
generation_config.max_length = (generation_config.max_new_tokens + input_ids_seq_length)
if not has_default_max_length:
logger.warn(
f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
"Please refer to the documentation for more information. "
"(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)",
UserWarning,
)
if input_ids_seq_length >= generation_config.max_length:
input_ids_string = ("decoder_input_ids" if self.config.is_encoder_decoder else "input_ids")
logger.warning(f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to"
f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
" increasing `max_new_tokens`.")
# 2. Set generation parameters if not already defined
logits_processor = (logits_processor if logits_processor is not None else LogitsProcessorList())
stopping_criteria = (stopping_criteria if stopping_criteria is not None else StoppingCriteriaList())
logits_processor = self._get_logits_processor(
generation_config=generation_config,
input_ids_seq_length=input_ids_seq_length,
encoder_input_ids=input_ids,
prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
logits_processor=logits_processor,
)
stopping_criteria = self._get_stopping_criteria(generation_config=generation_config,
stopping_criteria=stopping_criteria)
logits_warper = self._get_logits_warper(generation_config)
unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
scores = None
while True:
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
# forward pass to get next token
outputs = self(
**model_inputs,
return_dict=True,
output_attentions=False,
output_hidden_states=False,
)
next_token_logits = outputs.logits[:, -1, :]
# pre-process distribution
next_token_scores = logits_processor(input_ids, next_token_logits)
next_token_scores = logits_warper(input_ids, next_token_scores)
# sample
probs = nn.functional.softmax(next_token_scores, dim=-1)
if generation_config.do_sample:
next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
else:
next_tokens = torch.argmax(probs, dim=-1)
# update generated ids, model inputs, and length for next step
input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
model_kwargs = self._update_model_kwargs_for_generation(outputs,
model_kwargs,
is_encoder_decoder=self.config.is_encoder_decoder)
unfinished_sequences = unfinished_sequences.mul((sum(next_tokens != i for i in eos_token_id)).long())
if return_past_key_values:
yield input_ids, outputs.past_key_values
else:
yield input_ids
# stop when each sentence is finished, or if we exceed the maximum length
if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
break
def quantize(self, bits: int, empty_init=False, device=None, **kwargs):
if bits == 0:
return
from .quantization import quantize
if self.quantized:
logger.info("Already quantized.")
return self
self.quantized = True
self.config.quantization_bit = bits
self.transformer.encoder = quantize(
self.transformer.encoder,
bits,
empty_init=empty_init,
device=device,
**kwargs,
)
return self
...@@ -18,8 +18,8 @@ def data_gen(): ...@@ -18,8 +18,8 @@ def data_gen():
# tokenized_input = tokenizer(input, return_tensors='pt') # tokenized_input = tokenizer(input, return_tensors='pt')
# input_ids = tokenized_input['input_ids'] # input_ids = tokenized_input['input_ids']
# attention_mask = tokenized_input['attention_mask'] # attention_mask = tokenized_input['attention_mask']
input_ids = torch.tensor([[15496, 11, 616, 3290, 318, 13779]], dtype=torch.int64) input_ids = torch.tensor([[15496, 11, 616, 3290, 318, 13779, 318, 13779]], dtype=torch.int64)
attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1]], dtype=torch.int64) attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64)
return dict(input_ids=input_ids, attention_mask=attention_mask) return dict(input_ids=input_ids, attention_mask=attention_mask)
...@@ -46,7 +46,7 @@ def data_gen_for_token_classification(): ...@@ -46,7 +46,7 @@ def data_gen_for_token_classification():
# token classification data gen # token classification data gen
# `labels` is the type not the token id for token classification, 0 or 1 # `labels` is the type not the token id for token classification, 0 or 1
data = data_gen() data = data_gen()
data['labels'] = torch.tensor([[0, 0, 0, 0, 0, 1]], dtype=torch.int64) data['labels'] = torch.tensor([[0, 0, 0, 0, 0, 0, 0, 1]], dtype=torch.int64)
return data return data
......
...@@ -16,8 +16,9 @@ def data_gen_for_encoder_only(): ...@@ -16,8 +16,9 @@ def data_gen_for_encoder_only():
# config = T5Config(decoder_start_token_id=0) # config = T5Config(decoder_start_token_id=0)
# tokenizer = T5Tokenizer.from_pretrained("t5-small") # tokenizer = T5Tokenizer.from_pretrained("t5-small")
# input_ids = tokenizer("translate English to German: The house is wonderful.", return_tensors="pt").input_ids # input_ids = tokenizer("translate English to German: The house is wonderful.", return_tensors="pt").input_ids
input_ids = torch.Tensor([[13959, 1566, 12, 2968, 10, 37, 629, 19, 1627, 5, 1]]).long() input_ids = torch.Tensor([[13959, 1566, 12, 2968, 10, 37, 629, 19, 1627, 5, 1, 12]]).long()
return dict(input_ids=input_ids) attention_mask = torch.Tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0]]).long()
return dict(input_ids=input_ids, attention_mask=attention_mask)
def data_gen_for_conditional_generation(): def data_gen_for_conditional_generation():
...@@ -25,17 +26,16 @@ def data_gen_for_conditional_generation(): ...@@ -25,17 +26,16 @@ def data_gen_for_conditional_generation():
# #
# labels = tokenizer("Das Haus ist wunderbar.", return_tensors="pt").input_ids # labels = tokenizer("Das Haus ist wunderbar.", return_tensors="pt").input_ids
data = data_gen_for_encoder_only() data = data_gen_for_encoder_only()
labels = torch.Tensor([[644, 4598, 229, 19250, 5, 1]]).long() labels = torch.Tensor([[644, 4598, 229, 19250, 5, 1, 644, 4598, 229, 19250, 5, 1]]).long()
data['labels'] = labels data['labels'] = labels
return data return data
def data_gen_for_t5_model(): def data_gen_for_t5_model():
# decoder_inputs_ids is obtained with the following code # decoder_inputs_ids is obtained with the following code
#
# decoder_input_ids = model._shift_right(input_ids) # decoder_input_ids = model._shift_right(input_ids)
data = data_gen_for_encoder_only() data = data_gen_for_encoder_only()
decoder_input_ids = torch.Tensor([[0, 13959, 1566, 12, 2968, 10, 37, 629, 19, 1627, 5]]).long() decoder_input_ids = torch.Tensor([[0, 13959, 1566, 12, 2968, 10, 37, 629, 19, 1627, 5, 5]]).long()
data['decoder_input_ids'] = decoder_input_ids data['decoder_input_ids'] = decoder_input_ids
return data return data
......
...@@ -76,14 +76,14 @@ model_zoo.register(name='transformers_whisper', ...@@ -76,14 +76,14 @@ model_zoo.register(name='transformers_whisper',
loss_fn=loss_fn, loss_fn=loss_fn,
model_attribute=ModelAttribute(has_control_flow=True)) model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_whisperForConditionalGeneration', model_zoo.register(name='transformers_whisper_for_conditional_generation',
model_fn=lambda: transformers.WhisperForConditionalGeneration(config), model_fn=lambda: transformers.WhisperForConditionalGeneration(config),
data_gen_fn=data_gen_for_conditional_generation, data_gen_fn=data_gen_for_conditional_generation,
output_transform_fn=output_transform_fn, output_transform_fn=output_transform_fn,
loss_fn=loss_fn_attr, loss_fn=loss_fn_attr,
model_attribute=ModelAttribute(has_control_flow=True)) model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_whisperWhisperForAudioClassification', model_zoo.register(name='transformers_whisper_for_audio_classification',
model_fn=lambda: transformers.WhisperForAudioClassification(config), model_fn=lambda: transformers.WhisperForAudioClassification(config),
data_gen_fn=data_gen_for_audio_classification, data_gen_fn=data_gen_for_audio_classification,
output_transform_fn=output_transform_fn, output_transform_fn=output_transform_fn,
......
...@@ -93,7 +93,7 @@ def check_gemini_plugin(init_method: str = 'none', early_stop: bool = True): ...@@ -93,7 +93,7 @@ def check_gemini_plugin(init_method: str = 'none', early_stop: bool = True):
'transformers_vit_for_image_classification', 'transformers_chatglm', 'transformers_vit_for_image_classification', 'transformers_chatglm',
'transformers_chatglm_for_conditional_generation', 'transformers_blip2', 'transformers_chatglm_for_conditional_generation', 'transformers_blip2',
'transformers_blip2_conditional_gerneration', 'transformers_sam', 'transformers_whisper', 'transformers_blip2_conditional_gerneration', 'transformers_sam', 'transformers_whisper',
'transformers_whisperForConditionalGeneration', 'transformers_whisperWhisperForAudioClassification' 'transformers_whisper_for_conditional_generation', 'transformers_whisper_for_audio_classification'
]: ]:
continue continue
......
...@@ -21,7 +21,13 @@ from colossalai.shardformer._utils import getattr_ ...@@ -21,7 +21,13 @@ from colossalai.shardformer._utils import getattr_
from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor
def build_model(model_fn, enable_fused_normalization=True, enable_tensor_parallelism=True, use_lazy_init: bool = False): def build_model(model_fn,
enable_fused_normalization=True,
enable_tensor_parallelism=True,
enable_flash_attention=False,
enable_jit_fused=False,
use_lazy_init: bool = False):
# create new model
ctx = LazyInitContext() if use_lazy_init else nullcontext() ctx = LazyInitContext() if use_lazy_init else nullcontext()
with ctx: with ctx:
# create new model # create new model
...@@ -31,7 +37,10 @@ def build_model(model_fn, enable_fused_normalization=True, enable_tensor_paralle ...@@ -31,7 +37,10 @@ def build_model(model_fn, enable_fused_normalization=True, enable_tensor_paralle
ctx.materialize(org_model) ctx.materialize(org_model)
# shard model # shard model
shard_config = ShardConfig(enable_fused_normalization=enable_fused_normalization, shard_config = ShardConfig(enable_fused_normalization=enable_fused_normalization,
enable_tensor_parallelism=enable_tensor_parallelism) enable_tensor_parallelism=enable_tensor_parallelism,
enable_flash_attention=enable_flash_attention,
enable_jit_fused=enable_jit_fused)
model_copy = copy.deepcopy(org_model)
shard_former = ShardFormer(shard_config=shard_config) shard_former = ShardFormer(shard_config=shard_config)
sharded_model, shared_params = shard_former.optimize(model_copy) sharded_model, shared_params = shard_former.optimize(model_copy)
return org_model.cuda(), sharded_model.cuda() return org_model.cuda(), sharded_model.cuda()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment