OpenDAS / ColossalAI
Commit dfca9678, authored Jun 16, 2023 by FoolPlayer; committed by Frank Lee on Jul 04, 2023

integrate with dist layer (#4011)

Parent: 015af592
Showing 3 changed files with 42 additions and 24 deletions (+42, −24)
colossalai/shardformer/policies/bert.py                  +21 −7
colossalai/shardformer/shard/sharder.py                   +7 −8
tests/test_shardformer/test_model/test_shard_bert.py     +14 −9
colossalai/shardformer/policies/bert.py @ dfca9678
@@ -8,12 +8,6 @@ from ..utils import getattr_, setattr_
 from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
 
 
-class ParallelModule():
-
-    def __init__(self):
-        pass
-
-
 class BertPolicy(Policy):
 
     def preprocess(self, shard_config: ShardConfig = None):

@@ -49,7 +43,27 @@ class BertPolicy(Policy):
                 sub_module_replacement=[
                     SubModuleReplacementDescription(
                         suffix="attention.self.query",
-                        target_module=ParallelModule,
+                        target_module=col_nn.Linear1D_Col,
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="attention.self.key",
+                        target_module=col_nn.Linear1D_Col,
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="attention.self.value",
+                        target_module=col_nn.Linear1D_Col,
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="attention.output.dense",
+                        target_module=col_nn.Linear1D_Row,
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="intermediate.dense",
+                        target_module=col_nn.Linear1D_Col,
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="output.dense",
+                        target_module=col_nn.Linear1D_Row,
                     ),
                 ])
         }
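The hunk above drops the placeholder ParallelModule and maps BERT's sub-modules onto ColossalAI's distributed layers: the query/key/value and intermediate.dense projections become Linear1D_Col (output features split across ranks), while attention.output.dense and output.dense become Linear1D_Row (input features split, partial results reduced). A minimal plain-torch sketch of why that column-then-row pairing reproduces the unsharded result, using made-up sizes and a loop in place of real process groups:

    # Numeric sketch (plain torch, not the ColossalAI layers) of the Col/Row pairing:
    # a column-split projection followed by a row-split projection matches the
    # unsharded two-layer result with a single reduction at the end.
    import torch

    torch.manual_seed(0)
    hidden, intermediate, world_size = 8, 16, 2

    w1 = torch.randn(intermediate, hidden)    # e.g. intermediate.dense (column-parallel)
    w2 = torch.randn(hidden, intermediate)    # e.g. output.dense (row-parallel)
    x = torch.randn(4, hidden)

    ref = (x @ w1.t()) @ w2.t()               # unsharded reference

    # Each "rank" keeps a slice of w1's output features and the matching input slice of w2.
    w1_shards = w1.chunk(world_size, dim=0)
    w2_shards = w2.chunk(world_size, dim=1)
    partials = [(x @ a.t()) @ b.t() for a, b in zip(w1_shards, w2_shards)]
    out = torch.stack(partials).sum(dim=0)    # stands in for the all-reduce in the row-parallel layer

    print(torch.allclose(ref, out, atol=1e-4))    # True

Because the column-parallel output is consumed directly by the row-parallel layer, no communication is needed between the two projections; only the final sum crosses ranks.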
colossalai/shardformer/shard/sharder.py @ dfca9678
@@ -7,8 +7,8 @@ from transformers.pytorch_utils import Conv1D
 from colossalai.cluster.process_group_manager import ProcessGroupManager
 
 from ..policies.autopolicy import get_autopolicy
-from ..policies.basepolicy import Policy
-from ..utils.utils import setattr_
+from ..policies.basepolicy import Policy, SubModuleReplacementDescription
+from ..utils.utils import getattr_, setattr_
 from .shard_config import ShardConfig
 
 __all__ = ['ModelSharder', 'shard_model']

@@ -90,9 +90,7 @@ class ModelSharder(object):
         Args:
             model (:class:`torch.nn.Module`): The model to shard
         """
-        print(self.policy)
         module_descriptions = self.policy.module_policy(self.shard_config)
-        print(f"******* {module_descriptions}")
         for module_description in module_descriptions.items():
             origin_layer_cls = module_description[0]
             attr_replacement = module_description[1].attribute_replacement

@@ -160,7 +158,7 @@ class ModelSharder(object):
     def _replace_sub_module(
         self,
         org_layer: nn.Module,
-        sub_module_replacement: List[Callable],
+        sub_module_replacement: List[SubModuleReplacementDescription],
     ) -> None:
         r"""
         Shard one layer according to the policy, the layer should be the same class as the key in policy's argument_policy return dict

@@ -177,7 +175,8 @@ class ModelSharder(object):
             assert target_module is not None, 'target_module should not be None'
 
-            # TODO: integrate with new layer
-            # replace_layer = target_module.from_native_layer(org_layer, self.pg_manager)
-            replace_layer = None
+            # TODO: support different parallel mode
+            native_sub_module = getattr_(org_layer, suffix)
+            replace_layer = target_module.from_native_module(native_sub_module,
+                                                              self.pg_manager.pg_store['tp1d'])
 
             setattr_(org_layer, suffix, replace_layer)
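The sharder now resolves the native sub-module named by the policy suffix with getattr_, builds its distributed replacement via target_module.from_native_module(...) using the 'tp1d' process group, and writes it back with setattr_. A rough sketch of dotted-path helpers like getattr_/setattr_ (assumed implementation for illustration, not the actual colossalai.shardformer utils):

    # Assumed dotted-path helpers: resolve "output.dense"-style suffixes on a layer
    # and swap the leaf sub-module in place, as the sharder does above.
    import torch.nn as nn


    def getattr_(obj, suffix):
        # Walk the dotted path and return the leaf sub-module.
        for name in suffix.split('.'):
            obj = getattr(obj, name)
        return obj


    def setattr_(obj, suffix, new_module):
        # Walk to the parent of the leaf, then replace the leaf attribute.
        *parents, leaf = suffix.split('.')
        for name in parents:
            obj = getattr(obj, name)
        setattr(obj, leaf, new_module)


    class ToyOutput(nn.Module):

        def __init__(self):
            super().__init__()
            self.dense = nn.Linear(8, 8)


    class ToyLayer(nn.Module):

        def __init__(self):
            super().__init__()
            self.output = ToyOutput()


    layer = ToyLayer()
    native = getattr_(layer, 'output.dense')            # module named by the policy suffix
    setattr_(layer, 'output.dense', nn.Linear(8, 16))   # stand-in for from_native_module(...)
    print(type(native).__name__, layer.output.dense.out_features)    # Linear 16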
tests/test_shardformer/test_model/test_shard_bert.py @ dfca9678
@@ -17,7 +17,7 @@ from transformers import (
 import colossalai
 from colossalai.logging import disable_existing_loggers
-from colossalai.shardformer.shard import ShardConfig, shard_model
+from colossalai.shardformer import ShardConfig, ShardFormer
 from colossalai.testing import rerun_if_address_is_in_use, spawn
 
 os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'

@@ -30,16 +30,21 @@ def build_model(rank, world_size, model):
     config.hidden_dropout_prob = 0
     config.attention_probs_dropout_prob = 0
 
-    org_model = model(config=config)
+    org_model = BertForMaskedLM.from_pretrained('bert-base-uncased', config=config)
     org_model_forshard = copy.deepcopy(org_model)
 
-    org_model = org_model.to('cuda')
-    shardconfig = ShardConfig(
-        rank=rank,
-        world_size=world_size,
-        gather_output=True,
-    )
-    sharded_model = shard_model(org_model_forshard, shardconfig).to('cuda')
+    org_model.to('cuda')
+    # TODO: no need to transfer to cuda
+    org_model_forshard.to('cuda')
+    shard_config = ShardConfig(tensor_parallel_size=2,
+                               data_parallel_size=1,
+                               pipeline_parallel_size=1,
+                               tensor_parallel_mode='1d',
+                               inference_only=True,
+                               gather_output=True)
+    shard_former = ShardFormer(shard_config=shard_config)
+    shard_former.init_distributed()
+    sharded_model = shard_former.shard_model(org_model_forshard).to('cuda')
 
     return org_model, sharded_model
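The test now builds the sharded model through the ShardFormer entry point instead of the old shard_model function; with gather_output=True the sharded forward pass is expected to match the original model. A hypothetical comparison helper (assumed code, not the actual test body) showing how the two models returned by build_model could be checked:

    # Hypothetical check: with gather_output=True the sharded logits should
    # agree with the unsharded BertForMaskedLM on the same input.
    import torch
    from transformers import BertTokenizer


    def check_forward(org_model, sharded_model):
        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        inputs = tokenizer('Hello, my dog is cute', return_tensors='pt').to('cuda')

        org_model.eval()
        sharded_model.eval()
        with torch.no_grad():
            org_logits = org_model(**inputs).logits
            shard_logits = sharded_model(**inputs).logits

        # Tolerance is a guess; cross-rank reduction order introduces small drift.
        assert torch.allclose(org_logits, shard_logits, atol=1e-5), 'sharded output diverged'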