Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
8120eca0
Commit
8120eca0
authored
Jul 20, 2023
by
klhhhhh
Committed by
Hongxin Liu
Aug 15, 2023
Browse files
[shardformer] support ChatGLMForConditionalGeneration & add fusedlayernorm for vit
parent
4da05052
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
37 additions
and
4 deletions
+37
-4
colossalai/shardformer/policies/chatglm.py
colossalai/shardformer/policies/chatglm.py
+24
-0
colossalai/shardformer/policies/vit.py
colossalai/shardformer/policies/vit.py
+1
-1
tests/kit/model_zoo/transformers/chatglm.py
tests/kit/model_zoo/transformers/chatglm.py
+9
-2
tests/test_shardformer/test_model/test_shard_chatglm.py
tests/test_shardformer/test_model/test_shard_chatglm.py
+3
-1
No files found.
colossalai/shardformer/policies/chatglm.py
View file @
8120eca0
...
@@ -90,7 +90,31 @@ class ChatGLMModelPolicy(Policy):
...
@@ -90,7 +90,31 @@ class ChatGLMModelPolicy(Policy):
policy
=
policy
,
policy
=
policy
,
target_key
=
ChatGLMModel
)
target_key
=
ChatGLMModel
)
else
:
self
.
append_or_create_submodule_replacement
(
description
=
[
SubModuleReplacementDescription
(
suffix
=
"input_layernorm"
,
target_module
=
col_nn
.
FusedRMSNorm
),
SubModuleReplacementDescription
(
suffix
=
"post_attention_layernorm"
,
target_module
=
col_nn
.
FusedRMSNorm
)
],
policy
=
policy
,
target_key
=
GLMBlock
)
if
self
.
model
.
config
.
post_layer_norm
:
self
.
append_or_create_submodule_replacement
(
description
=
[
SubModuleReplacementDescription
(
suffix
=
"encoder.final_layernorm"
,
target_module
=
col_nn
.
FusedRMSNorm
)
],
policy
=
policy
,
target_key
=
ChatGLMModel
)
return
policy
return
policy
def postprocess(self):
    """No post-sharding fixup is required: hand back the wrapped model as-is."""
    model = self.model
    return model
class ChatGLMForConditionalGenerationPolicy(ChatGLMModelPolicy):
    """Shardformer policy for ChatGLMForConditionalGeneration.

    Reuses the base ChatGLMModel policy unchanged — the visible code adds no
    extra module replacements for the conditional-generation head.
    """

    def module_policy(self):
        # Delegate entirely to the base policy; nothing model-specific to add here.
        return super().module_policy()
colossalai/shardformer/policies/vit.py
View file @
8120eca0
...
@@ -23,7 +23,7 @@ class ViTPolicy(Policy):
...
@@ -23,7 +23,7 @@ class ViTPolicy(Policy):
return
self
.
model
return
self
.
model
def
module_policy
(
self
)
->
Dict
[
Union
[
str
,
nn
.
Module
],
ModulePolicyDescription
]:
def
module_policy
(
self
)
->
Dict
[
Union
[
str
,
nn
.
Module
],
ModulePolicyDescription
]:
from
transformers.models.vit.modeling_vit
import
ViTEmbeddings
,
ViTLayer
from
transformers.models.vit.modeling_vit
import
ViTEmbeddings
,
ViTLayer
,
ViTModel
policy
=
{}
policy
=
{}
...
...
tests/kit/model_zoo/transformers/chatglm.py
View file @
8120eca0
...
@@ -3,7 +3,7 @@ import transformers
...
@@ -3,7 +3,7 @@ import transformers
from
..registry
import
ModelAttribute
,
model_zoo
from
..registry
import
ModelAttribute
,
model_zoo
from
.chatglm2_6b.configuration_chatglm
import
ChatGLMConfig
from
.chatglm2_6b.configuration_chatglm
import
ChatGLMConfig
from
.chatglm2_6b.modeling_chatglm
import
ChatGLMModel
from
.chatglm2_6b.modeling_chatglm
import
ChatGLMForConditionalGeneration
,
ChatGLMModel
# ================================
# ================================
# Register single-sentence ChatGLM
# Register single-sentence ChatGLM
...
@@ -21,7 +21,7 @@ output_transform_fn = lambda x: x
...
@@ -21,7 +21,7 @@ output_transform_fn = lambda x: x
# define loss function
# define loss function
loss_fn_for_chatglm_model
=
lambda
x
:
x
.
last_hidden_state
.
mean
()
loss_fn_for_chatglm_model
=
lambda
x
:
x
.
last_hidden_state
.
mean
()
loss_fn
=
lambda
x
:
x
.
lo
ss
loss_fn
=
lambda
x
:
x
.
lo
gits
.
mean
()
config
=
ChatGLMConfig
(
num_layers
=
1
,
config
=
ChatGLMConfig
(
num_layers
=
1
,
padded_vocab_size
=
65024
,
padded_vocab_size
=
65024
,
hidden_size
=
64
,
hidden_size
=
64
,
...
@@ -36,3 +36,10 @@ model_zoo.register(name='transformers_chatglm',
...
@@ -36,3 +36,10 @@ model_zoo.register(name='transformers_chatglm',
output_transform_fn
=
output_transform_fn
,
output_transform_fn
=
output_transform_fn
,
loss_fn
=
loss_fn_for_chatglm_model
,
loss_fn
=
loss_fn_for_chatglm_model
,
model_attribute
=
ModelAttribute
(
has_control_flow
=
True
))
model_attribute
=
ModelAttribute
(
has_control_flow
=
True
))
model_zoo
.
register
(
name
=
"transformers_chatglm_for_conditional_generation"
,
model_fn
=
lambda
:
ChatGLMForConditionalGeneration
(
config
,
empty_init
=
False
),
data_gen_fn
=
data_gen
,
output_transform_fn
=
output_transform_fn
,
loss_fn
=
loss_fn
,
model_attribute
=
ModelAttribute
(
has_control_flow
=
True
))
tests/test_shardformer/test_model/test_shard_chatglm.py
View file @
8120eca0
...
@@ -7,7 +7,7 @@ import torch
...
@@ -7,7 +7,7 @@ import torch
import
colossalai
import
colossalai
from
colossalai.logging
import
disable_existing_loggers
from
colossalai.logging
import
disable_existing_loggers
from
colossalai.shardformer
import
ShardConfig
,
ShardFormer
from
colossalai.shardformer
import
ShardConfig
,
ShardFormer
from
colossalai.shardformer.policies.chatglm
import
ChatGLMModelPolicy
from
colossalai.shardformer.policies.chatglm
import
ChatGLMForConditionalGenerationPolicy
,
ChatGLMModelPolicy
from
colossalai.tensor.d_tensor.api
import
is_customized_distributed_tensor
,
is_distributed_tensor
from
colossalai.tensor.d_tensor.api
import
is_customized_distributed_tensor
,
is_distributed_tensor
from
colossalai.testing
import
(
from
colossalai.testing
import
(
assert_hf_output_close
,
assert_hf_output_close
,
...
@@ -85,6 +85,8 @@ def run_chatglm_test(enable_fused_normalization, enable_tensor_parallelism):
...
@@ -85,6 +85,8 @@ def run_chatglm_test(enable_fused_normalization, enable_tensor_parallelism):
shard_former
=
ShardFormer
(
shard_config
=
shard_config
)
shard_former
=
ShardFormer
(
shard_config
=
shard_config
)
if
name
==
"transformers_chatglm"
:
if
name
==
"transformers_chatglm"
:
sharded_model
=
shard_former
.
optimize
(
model_copy
,
ChatGLMModelPolicy
()).
cuda
()
sharded_model
=
shard_former
.
optimize
(
model_copy
,
ChatGLMModelPolicy
()).
cuda
()
else
:
sharded_model
=
shard_former
.
optimize
(
model_copy
,
ChatGLMForConditionalGenerationPolicy
()).
cuda
()
check_forward_backward
(
org_model
,
sharded_model
,
data_gen_fn
,
output_transform_fn
,
loss_fn
)
check_forward_backward
(
org_model
,
sharded_model
,
data_gen_fn
,
output_transform_fn
,
loss_fn
)
torch
.
cuda
.
empty_cache
()
torch
.
cuda
.
empty_cache
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment