Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
df018fc3
Commit
df018fc3
authored
Jun 16, 2023
by
FoolPlayer
Committed by
Frank Lee
Jul 04, 2023
Browse files
support bert with new api
parent
507c0ad3
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
37 additions
and
3 deletions
+37
-3
colossalai/shardformer/policies/bert.py
colossalai/shardformer/policies/bert.py
+34
-1
colossalai/shardformer/shard/sharder.py
colossalai/shardformer/shard/sharder.py
+3
-2
No files found.
colossalai/shardformer/policies/bert.py
View file @
df018fc3
...
@@ -2,6 +2,7 @@ import torch.nn as nn
...
@@ -2,6 +2,7 @@ import torch.nn as nn
from
transformers.models.bert.modeling_bert
import
BertEmbeddings
,
BertLayer
,
BertLMPredictionHead
from
transformers.models.bert.modeling_bert
import
BertEmbeddings
,
BertLayer
,
BertLMPredictionHead
import
colossalai.shardformer.layer.layers
as
col_nn
import
colossalai.shardformer.layer.layers
as
col_nn
from
colossalai.shardformer.layer.dropout
import
Dropout1D
from
..shard.shard_config
import
ShardConfig
from
..shard.shard_config
import
ShardConfig
from
..utils
import
getattr_
,
setattr_
from
..utils
import
getattr_
,
setattr_
...
@@ -65,7 +66,24 @@ class BertPolicy(Policy):
...
@@ -65,7 +66,24 @@ class BertPolicy(Policy):
suffix
=
"output.dense"
,
suffix
=
"output.dense"
,
target_module
=
col_nn
.
Linear1D_Row
,
target_module
=
col_nn
.
Linear1D_Row
,
),
),
])
SubModuleReplacementDescription
(
suffix
=
"attention.self.dropout"
,
target_module
=
Dropout1D
,
),
SubModuleReplacementDescription
(
suffix
=
"attention.output.dropout"
,
target_module
=
Dropout1D
,
)
]),
BertEmbeddings
:
ModulePolicyDescription
(
attribute_replacement
=
{},
param_replacement
=
[],
sub_module_replacement
=
[
SubModuleReplacementDescription
(
suffix
=
"word_embeddings"
,
target_module
=
col_nn
.
VocabParallelEmbedding1D
,
)
])
}
}
def
new_model_class
(
self
):
def
new_model_class
(
self
):
...
@@ -87,6 +105,21 @@ class BertForMaskedLMPolicy(BertPolicy):
...
@@ -87,6 +105,21 @@ class BertForMaskedLMPolicy(BertPolicy):
def
__init__
(
self
)
->
None
:
def
__init__
(
self
)
->
None
:
super
().
__init__
()
super
().
__init__
()
def
module_policy
(
self
,
shard_config
:
ShardConfig
=
None
):
module_policy
=
super
().
module_policy
(
shard_config
)
addon_module
=
{
BertLMPredictionHead
:
ModulePolicyDescription
(
attribute_replacement
=
{},
param_replacement
=
[],
sub_module_replacement
=
[
SubModuleReplacementDescription
(
suffix
=
"decoder"
,
target_module
=
col_nn
.
Linear1D_Col
,
kwargs
=
{
"gather_output"
:
True
})
])
}
module_policy
.
update
(
addon_module
)
return
module_policy
# BertLMHeadModel
# BertLMHeadModel
class
BertLMHeadModelPolicy
(
BertPolicy
):
class
BertLMHeadModelPolicy
(
BertPolicy
):
...
...
colossalai/shardformer/shard/sharder.py
View file @
df018fc3
...
@@ -171,12 +171,13 @@ class ModelSharder(object):
...
@@ -171,12 +171,13 @@ class ModelSharder(object):
for
description
in
sub_module_replacement
:
for
description
in
sub_module_replacement
:
suffix
=
description
.
suffix
suffix
=
description
.
suffix
target_module
=
description
.
target_module
target_module
=
description
.
target_module
kwargs
=
description
.
kwargs
kwargs
=
{}
if
description
.
kwargs
is
None
else
description
.
kwargs
assert
target_module
is
not
None
,
'target_module should not be None'
assert
target_module
is
not
None
,
'target_module should not be None'
# TODO: support different parallel mode
# TODO: support different parallel mode
native_sub_module
=
getattr_
(
org_layer
,
suffix
)
native_sub_module
=
getattr_
(
org_layer
,
suffix
)
replace_layer
=
target_module
.
from_native_module
(
native_sub_module
,
self
.
pg_manager
.
pg_store
[
'tp1d'
])
replace_layer
=
target_module
.
from_native_module
(
native_sub_module
,
self
.
pg_manager
.
pg_store
[
'tp1d'
],
**
kwargs
)
setattr_
(
org_layer
,
suffix
,
replace_layer
)
setattr_
(
org_layer
,
suffix
,
replace_layer
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment