OpenDAS / ColossalAI / Commits

Commit 8620009d authored Jul 10, 2023 by klhhhhh, committed by Hongxin Liu on Aug 15, 2023

[shardformer] add first version of policy of chatglm

parent 6ee4c9ee
Changes: 2

Showing 2 changed files with 44 additions and 1 deletion (+44 / -1):

colossalai/shardformer/policies/chatglm.py                 +44 / -0
tests/test_shardformer/test_model/test_shard_chatglm.py     +0 / -1
colossalai/shardformer/policies/chatglm.py    View file @ 8620009d
from typing import Dict, Union

import torch.nn as nn

from ....tests.kit.model_zoo.transformers.chatglm2_6b.modeling_chatglm import ChatGLMModel, GLMBlock

import colossalai.shardformer.layer as col_nn

...
@@ -8,6 +9,49 @@ from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDes

__all__ = ['ChatGLMModelPolicy', 'ChatGLMForConditionalGenerationPolicy']


class ChatGLMModelPolicy(Policy):

    def config_sanity_check(self):
        pass

    def preprocess(self):
        # Resize embedding so the vocabulary divides evenly across tensor-parallel ranks
        vocab_size = self.model.config.vocab_size
        world_size = self.shard_config.tensor_parallel_size
        if vocab_size % world_size != 0:
            new_vocab_size = vocab_size + world_size - vocab_size % world_size
            self.model.resize_token_embeddings(new_vocab_size)
        return self.model

    def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
        from ....tests.kit.model_zoo.transformers.chatglm2_6b.modeling_chatglm import ChatGLMModel, GLMBlock

        policy = {}

        if self.shard_config.enable_tensor_parallelism:
            policy[GLMBlock] = ModulePolicyDescription(
                attribute_replacement={},
                sub_module_replacement=[
                    # SubModuleReplacementDescription(
                    #     suffix="self_attention.query_key_value",
                    #     target_module=col_nn.Linear1D_Col,
                    # ),
                    # SubModuleReplacementDescription(
                    #     suffix="self_attention.dense",
                    #     target_module=col_nn.Linear1D_Row,
                    # ),
                    # SubModuleReplacementDescription(
                    #     suffix="self_attention.core_attention.attention_dropout",
                    #     target_module=col_nn.DropoutForParallelInput,
                    # ),
                ],
            )
        # hand back the (currently empty) policy mapping, matching the declared return type
        return policy

    def postprocess(self):
        return self.model


class ChatGLMModelPolicy(Policy):
...
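In preprocess, the embedding table is padded so the vocabulary divides evenly across tensor-parallel ranks: new_vocab_size = vocab_size + world_size - vocab_size % world_size rounds the size up to the next multiple of world_size. For example, a hypothetical vocab_size of 100 with tensor_parallel_size 8 would be padded to 100 + 8 - 4 = 104.

All of the tensor-parallel sub-module replacements are still commented out in this first version, so the returned policy is effectively empty. As a hedged sketch of what the entries might look like once enabled (the suffixes and target modules below are taken verbatim from the commented lines in the diff, while the exact keyword layout of ModulePolicyDescription is assumed from this same diff):

    # Sketch only: the commented-out replacements above, enabled.
    # Suffixes and target modules come from the comments in this commit.
    policy[GLMBlock] = ModulePolicyDescription(
        attribute_replacement={},
        sub_module_replacement=[
            # Column-parallel split of the fused QKV projection.
            SubModuleReplacementDescription(
                suffix="self_attention.query_key_value",
                target_module=col_nn.Linear1D_Col,
            ),
            # Row-parallel split of the attention output projection.
            SubModuleReplacementDescription(
                suffix="self_attention.dense",
                target_module=col_nn.Linear1D_Row,
            ),
            # Dropout variant aware of tensor-parallel input partitioning.
            SubModuleReplacementDescription(
                suffix="self_attention.core_attention.attention_dropout",
                target_module=col_nn.DropoutForParallelInput,
            ),
        ],
    )

Pairing a column-parallel QKV projection with a row-parallel output projection is the usual Megatron-style layout: it keeps the attention computation local to each rank and defers the all-reduce until after the dense layer.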
tests/test_shardformer/test_model/test_shard_chatglm.py    View file @ 8620009d

...
@@ -19,7 +19,6 @@ from colossalai.testing import (
from tests.kit.model_zoo import model_zoo
from tests.test_shardformer.test_model._utils import build_model, run_forward


def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
    # check forward
    org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn,
...
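The diff is truncated at the run_forward call. For illustration, a hedged sketch of how such a forward check commonly finishes in shardformer tests; the remaining arguments and the exact tolerance are assumptions, not part of this commit:

    # Sketch only (assumed continuation, not part of this commit):
    # run both models forward and check that the sharded model's loss
    # matches the original model's loss within a small tolerance.
    org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn,
                                                                 output_transform_fn, loss_fn)
    assert torch.allclose(org_loss, shard_loss, atol=1e-5), \
        f"shard model loss is not equal to origin model loss\n{org_loss}\n{shard_loss}"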