OpenDAS / ColossalAI / Commits

Commit dfca9678
integrate with dist layer (#4011)

Authored Jun 16, 2023 by FoolPlayer; committed Jul 04, 2023 by Frank Lee
Parent: 015af592
Showing 3 changed files with 42 additions and 24 deletions:

- colossalai/shardformer/policies/bert.py (+21 -7)
- colossalai/shardformer/shard/sharder.py (+7 -8)
- tests/test_shardformer/test_model/test_shard_bert.py (+14 -9)
colossalai/shardformer/policies/bert.py

@@ -8,12 +8,6 @@ from ..utils import getattr_, setattr_
 from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
 
 
-class ParallelModule():
-
-    def __init__(self):
-        pass
-
-
 class BertPolicy(Policy):
 
     def preprocess(self, shard_config: ShardConfig = None):
@@ -49,7 +43,27 @@ class BertPolicy(Policy):
                 sub_module_replacement=[
                     SubModuleReplacementDescription(
                         suffix="attention.self.query",
-                        target_module=ParallelModule,
+                        target_module=col_nn.Linear1D_Col,
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="attention.self.key",
+                        target_module=col_nn.Linear1D_Col,
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="attention.self.value",
+                        target_module=col_nn.Linear1D_Col,
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="attention.output.dense",
+                        target_module=col_nn.Linear1D_Row,
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="intermediate.dense",
+                        target_module=col_nn.Linear1D_Col,
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="output.dense",
+                        target_module=col_nn.Linear1D_Row,
+                    ),
                 ])
         }
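This change retires the temporary ParallelModule placeholder: the BERT policy now points each attention and MLP projection at ColossalAI's 1D tensor-parallel layers, column-parallel (col_nn.Linear1D_Col) for query/key/value and intermediate.dense, row-parallel (col_nn.Linear1D_Row) for attention.output.dense and output.dense. A single-process sketch of why that column-then-row pairing reproduces the unsharded computation (plain tensor chunking stands in for the real distributed layers; all names below are illustrative, not ColossalAI API):

    # Megatron-style 1D tensor parallelism, simulated on one process.
    import torch

    torch.manual_seed(0)
    x = torch.randn(4, 8)        # (batch, hidden)
    w_col = torch.randn(16, 8)   # column-parallel weight, e.g. intermediate.dense
    w_row = torch.randn(8, 16)   # row-parallel weight, e.g. output.dense

    # Each tensor-parallel rank keeps one output-feature shard of the
    # column-parallel weight and the matching input-feature shard of the
    # row-parallel weight.
    w_col_shards = w_col.chunk(2, dim=0)
    w_row_shards = w_row.chunk(2, dim=1)

    # Every "rank" computes a partial result; summing the partials plays the
    # role of the all-reduce the row-parallel layer performs across the group.
    partials = [x @ wc.t() @ wr.t() for wc, wr in zip(w_col_shards, w_row_shards)]
    out = sum(partials)

    assert torch.allclose(out, x @ w_col.t() @ w_row.t(), atol=1e-4)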
colossalai/shardformer/shard/sharder.py

@@ -7,8 +7,8 @@ from transformers.pytorch_utils import Conv1D
 from colossalai.cluster.process_group_manager import ProcessGroupManager
 
 from ..policies.autopolicy import get_autopolicy
-from ..policies.basepolicy import Policy
-from ..utils.utils import setattr_
+from ..policies.basepolicy import Policy, SubModuleReplacementDescription
+from ..utils.utils import getattr_, setattr_
 from .shard_config import ShardConfig
 
 __all__ = ['ModelSharder', 'shard_model']
@@ -90,9 +90,7 @@ class ModelSharder(object):
         Args:
             model (:class:`torch.nn.Module`): The model to shard
         """
-        print(self.policy)
         module_descriptions = self.policy.module_policy(self.shard_config)
-        print(f"******* {module_descriptions}")
         for module_description in module_descriptions.items():
             origin_layer_cls = module_description[0]
             attr_replacement = module_description[1].attribute_replacement
@@ -160,7 +158,7 @@ class ModelSharder(object):
     def _replace_sub_module(
         self,
         org_layer: nn.Module,
-        sub_module_replacement: List[Callable],
+        sub_module_replacement: List[SubModuleReplacementDescription],
     ) -> None:
         r"""
         Shard one layer according to the policy, the layer should be the same class as the key in policy's argument_policy return dict
@@ -177,7 +175,8 @@ class ModelSharder(object):
         assert target_module is not None, 'target_module should not be None'
 
-        # TODO: integrate with new layer
-        # replace_layer = target_module.from_native_layer(org_layer, self.pg_manager)
-        replace_layer = None
+        # TODO: support different parallel mode
+        native_sub_module = getattr_(org_layer, suffix)
+        replace_layer = target_module.from_native_module(native_sub_module, self.pg_manager.pg_store['tp1d'])
 
         setattr_(org_layer, suffix, replace_layer)
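The sharder now resolves each policy suffix to the live submodule with getattr_, builds the replacement via target_module.from_native_module(...) against the 'tp1d' process group, and installs it with setattr_. For orientation, a minimal sketch of what dotted-path helpers like getattr_/setattr_ plausibly do (an assumption for illustration, not the repository's actual implementation):

    import functools

    def getattr_(obj, path):
        # Walk the dot-separated attribute chain,
        # e.g. getattr_(layer, 'attention.self.query').
        return functools.reduce(getattr, path.split('.'), obj)

    def setattr_(obj, path, value):
        # Resolve the parent of the final attribute, then assign the
        # replacement module in place of the native one.
        parent_path, _, name = path.rpartition('.')
        parent = getattr_(obj, parent_path) if parent_path else obj
        setattr(parent, name, value)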
tests/test_shardformer/test_model/test_shard_bert.py

@@ -17,7 +17,7 @@ from transformers import (
 import colossalai
 from colossalai.logging import disable_existing_loggers
-from colossalai.shardformer.shard import ShardConfig, shard_model
+from colossalai.shardformer import ShardConfig, ShardFormer
 from colossalai.testing import rerun_if_address_is_in_use, spawn
 
 os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'
@@ -30,16 +30,21 @@ def build_model(rank, world_size, model):
     config.hidden_dropout_prob = 0
     config.attention_probs_dropout_prob = 0
 
-    org_model = model(config=config)
+    org_model = BertForMaskedLM.from_pretrained('bert-base-uncased', config=config)
     org_model_forshard = copy.deepcopy(org_model)
 
-    org_model = org_model.to('cuda')
-    shardconfig = ShardConfig(
-        rank=rank,
-        world_size=world_size,
-        gather_output=True,
-    )
-    sharded_model = shard_model(org_model_forshard, shardconfig).to('cuda')
+    org_model.to('cuda')
+    # TODO: no need to transfer to cuda
+    org_model_forshard.to('cuda')
+    shard_config = ShardConfig(tensor_parallel_size=2,
+                               data_parallel_size=1,
+                               pipeline_parallel_size=1,
+                               tensor_parallel_mode='1d',
+                               inference_only=True,
+                               gather_output=True)
+    shard_former = ShardFormer(shard_config=shard_config)
+    shard_former.init_distributed()
+    sharded_model = shard_former.shard_model(org_model_forshard).to('cuda')
 
     return org_model, sharded_model
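The test now builds the reference model from the bert-base-uncased checkpoint and shards a deep copy through the new ShardFormer front end. A hypothetical follow-up check, not part of this commit, showing how one might verify that with gather_output=True and inference_only=True the sharded model reproduces the reference logits:

    import torch
    from transformers import BertTokenizer

    def check_forward_equivalence(org_model, sharded_model):
        # Hypothetical helper: tokenize one sentence and compare logits.
        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        inputs = tokenizer('Hello, my dog is cute', return_tensors='pt').to('cuda')
        with torch.no_grad():
            org_logits = org_model(**inputs).logits
            shard_logits = sharded_model(**inputs).logits
        # With gather_output=True the sharded outputs are full-size tensors,
        # so they should match the unsharded model up to numerical noise.
        assert torch.allclose(org_logits, shard_logits, atol=1e-5)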