Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
1a29e8fc
Commit
1a29e8fc
authored
Jul 12, 2023
by
klhhhhh
Committed by
Hongxin Liu
Aug 15, 2023
Browse files
[shardformer] polish chatglm code
parent
8620009d
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
4 additions
and
44 deletions
+4
-44
colossalai/shardformer/policies/auto_policy.py
colossalai/shardformer/policies/auto_policy.py
+3
-0
colossalai/shardformer/policies/chatglm.py
colossalai/shardformer/policies/chatglm.py
+0
-44
tests/test_shardformer/test_model/test_shard_chatglm.py
tests/test_shardformer/test_model/test_shard_chatglm.py
+1
-0
No files found.
colossalai/shardformer/policies/auto_policy.py
View file @
1a29e8fc
...
...
@@ -116,6 +116,9 @@ _POLICY_LIST = {
# Sam
"transformers.models.sam.modeling_sam.SamModel"
:
PolicyLocation
(
file_name
=
"sam"
,
class_name
=
"SamModelPolicy"
),
# ChatGLM
"tests.kit.model_zoo.transformers.chatglm2_6b.modeling_chatglm.ChatGLMModel"
:
PolicyLocation
(
file_name
=
"chatglm"
,
class_name
=
"ChatGLMModelPolicy"
),
}
...
...
colossalai/shardformer/policies/chatglm.py
View file @
1a29e8fc
from
typing
import
Dict
,
Union
import
torch.nn
as
nn
from
....tests.kit.model_zoo.transformers.chatglm2_6b.modeling_chatglm
import
ChatGLMModel
,
GLMBlock
import
colossalai.shardformer.layer
as
col_nn
...
...
@@ -9,49 +8,6 @@ from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDes
__all__
=
[
'ChatGLMModelPolicy'
,
'ChatGLMForConditionalGenerationPolicy'
]
class ChatGLMModelPolicy(Policy):
    """Shardformer policy for ChatGLMModel.

    Pads the token-embedding vocabulary so it divides evenly across the
    tensor-parallel world size, and describes which GLMBlock sub-modules
    are replaced when tensor parallelism is enabled.
    """

    def config_sanity_check(self):
        """No configuration constraints are enforced for ChatGLM."""
        pass

    def preprocess(self):
        """Resize token embeddings so vocab_size is divisible by TP size.

        Returns:
            The wrapped model (with embeddings resized when padding was
            required).
        """
        vocab_size = self.model.config.vocab_size
        world_size = self.shard_config.tensor_parallel_size
        if vocab_size % world_size != 0:
            # Round vocab_size up to the next multiple of world_size so each
            # tensor-parallel rank holds an equal slice of the embedding.
            new_vocab_size = vocab_size + world_size - vocab_size % world_size
            self.model.resize_token_embeddings(new_vocab_size)
        return self.model

    def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
        """Build the module-replacement policy for tensor parallelism.

        Returns:
            Mapping from module class (or qualified name) to the
            ModulePolicyDescription that shards it. Empty when tensor
            parallelism is disabled.
        """
        # Imported lazily: the model-zoo test package is only needed when a
        # policy is actually constructed.
        from ....tests.kit.model_zoo.transformers.chatglm2_6b.modeling_chatglm import ChatGLMModel, GLMBlock

        policy = {}

        if self.shard_config.enable_tensor_parallelism:
            policy[GLMBlock] = ModulePolicyDescription(
                attribute_replacement={},
                sub_module_replacement=[
                    # SubModuleReplacementDescription(
                    #     suffix="self_attention.query_key_value",
                    #     target_module=col_nn.Linear1D_Col,
                    # ),
                    # SubModuleReplacementDescription(
                    #     suffix="self_attention.dense",
                    #     target_module=col_nn.Linear1D_Row,
                    # )
                    # SubModuleReplacementDescription(
                    #     suffix="self_attention.core_attention.attention_dropout",
                    #     target_module=col_nn.DropoutForParallelInput,
                    # )
                ],
            )

        # BUG FIX: the original fell off the end of this method and implicitly
        # returned None, contradicting the annotated Dict return type and
        # breaking any caller that iterates the policy. Return the dict.
        # NOTE(review): source is a mangled diff scrape — confirm the return
        # was truly absent in the original file and not lost in extraction.
        return policy

    def postprocess(self):
        """No post-sharding fix-ups are required; return the model unchanged."""
        return self.model
class
ChatGLMModelPolicy
(
Policy
):
...
...
tests/test_shardformer/test_model/test_shard_chatglm.py
View file @
1a29e8fc
...
...
@@ -19,6 +19,7 @@ from colossalai.testing import (
from
tests.kit.model_zoo
import
model_zoo
from
tests.test_shardformer.test_model._utils
import
build_model
,
run_forward
def
check_forward_backward
(
org_model
,
sharded_model
,
data_gen_fn
,
output_transform_fn
,
loss_fn
):
# check forward
org_output
,
org_loss
,
shard_output
,
shard_loss
=
run_forward
(
org_model
,
sharded_model
,
data_gen_fn
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment