OpenDAS / ColossalAI · Commits

Commit 74d176c8
Authored Jun 19, 2023 by FoolPlayer
Committed by Frank Lee, Jul 04, 2023
[shardformer] fix bert and gpt downstream with new api (#4024)
* fix bert downstream with new api
* remove comment line

Parent: e253a070
Showing 6 changed files with 97 additions and 39 deletions (+97 -39):

    colossalai/shardformer/policies/basepolicy.py           +12  -2
    colossalai/shardformer/policies/bert.py                  +73 -19
    colossalai/shardformer/shard/shard_config.py              +3  -3
    colossalai/shardformer/shard/sharder.py                   +5  -4
    colossalai/shardformer/shard/shardformer.py               +0  -4
    tests/test_shardformer/test_model/test_shard_bert.py      +4  -7
colossalai/shardformer/policies/basepolicy.py

@@ -76,6 +76,7 @@ class Policy(ABC):

     def __init__(self) -> None:
         self.model = None
+        self.shard_config = None

     def set_model(self, model: nn.Module) -> None:
         r"""
...
@@ -86,14 +87,23 @@ class Policy(ABC):
         """
         self.model = model

+    def set_shard_config(self, shard_config: ShardConfig) -> None:
+        r"""
+        Set shard config as an attribute of the Policy object.
+
+        Args:
+            shard_config (:class:`ShardConfig`): The shard config to be applied
+        """
+        self.shard_config = shard_config
+
     @abstractmethod
-    def preprocess(self, shard_config: ShardConfig = None) -> nn.Module:
+    def preprocess(self) -> nn.Module:
         r"""
         Perform some preprocessing of the model, like reshaping the embedding layer
         """

     @abstractmethod
-    def module_policy(self, shard_config: ShardConfig = None) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
+    def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
         r"""
         Return the dict for the module policy: the key is the original layer class and the value is the
         argument for the replacement layer
...
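Taken together, these two hunks change the Policy contract: the ShardConfig is no longer threaded through preprocess and module_policy as a parameter; the sharder injects it once through the new set_shard_config, and the methods read self.shard_config. A minimal, self-contained sketch of the new calling convention — the Fake* classes are stand-ins for illustration, not the real colossalai imports:

    from dataclasses import dataclass

    @dataclass
    class FakeShardConfig:            # stand-in for colossalai's ShardConfig
        tensor_parallel_size: int = 2
        tensor_parallel_mode: str = '1d'

    class FakePolicy:                 # stand-in for Policy(ABC)
        def __init__(self) -> None:
            self.model = None
            self.shard_config = None  # new attribute introduced by this commit

        def set_model(self, model) -> None:
            self.model = model

        def set_shard_config(self, shard_config) -> None:
            self.shard_config = shard_config

        def module_policy(self):
            # reads the injected config instead of taking it as a parameter
            return {"tp_size": self.shard_config.tensor_parallel_size}

    policy = FakePolicy()
    policy.set_model(object())
    policy.set_shard_config(FakeShardConfig())
    print(policy.module_policy())     # -> {'tp_size': 2}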
colossalai/shardformer/policies/bert.py

@@ -4,41 +4,40 @@ from transformers.models.bert.modeling_bert import BertEmbeddings, BertLayer, Be
 import colossalai.shardformer.layer.layers as col_nn

 from colossalai.shardformer.layer.dropout import Dropout1D
-from ..shard.shard_config import ShardConfig
 from ..utils import getattr_, setattr_
 from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription


 class BertPolicy(Policy):

-    def preprocess(self, shard_config: ShardConfig = None):
+    def preprocess(self):
         # reshape the embedding layer
         r"""
         Reshape the Embedding layer to make the embedding dimension divisible by world_size
         """
         # TODO:
         vocab_size = self.model.config.vocab_size
-        world_size = shard_config.tensor_parallel_size
+        world_size = self.shard_config.tensor_parallel_size
         if vocab_size % world_size != 0:
             new_vocab_size = vocab_size + world_size - vocab_size % world_size
             self.model.resize_token_embeddings(new_vocab_size)
         return self.model

-    def module_policy(self, shard_config: ShardConfig = None):
+    def module_policy(self):
         return {
             BertLayer:
                 ModulePolicyDescription(
                     attribute_replacement={
                         # 1. shard hidden size
-                        "attention.self.all_head_size": self.model.config.hidden_size // shard_config.tensor_parallel_size,
+                        "attention.self.all_head_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
-                        "crossattention.self.all_head_size": self.model.config.hidden_size // shard_config.tensor_parallel_size,
+                        "crossattention.self.all_head_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
                         # 2. shard number of heads
-                        "attention.self.num_attention_heads": self.model.config.num_attention_heads // shard_config.tensor_parallel_size,
+                        "attention.self.num_attention_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
-                        "crossattention.self.num_attention_heads": self.model.config.num_attention_heads // shard_config.tensor_parallel_size,
+                        "crossattention.self.num_attention_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
                     },
                     param_replacement=[],
                     sub_module_replacement=[
...
@@ -100,13 +99,43 @@ class BertPolicy(Policy):
         return self.model


 # BertModel
 class BertModelPolicy(BertPolicy):

+    def __init__(self) -> None:
+        super().__init__()
+

 # BertForPreTraining
 class BertForPretrainingPolicy(BertPolicy):

+    def __init__(self) -> None:
+        super().__init__()
+
+    def module_policy(self):
+        module_policy = super().module_policy()
+        addon_module = {
+            BertLMPredictionHead:
+                ModulePolicyDescription(attribute_replacement={},
+                                        param_replacement=[],
+                                        sub_module_replacement=[
+                                            SubModuleReplacementDescription(
+                                                suffix="decoder",
+                                                target_module=col_nn.Linear1D_Col,
+                                                kwargs={"gather_output": True})
+                                        ])
+        }
+        module_policy.update(addon_module)
+        return module_policy
+

 # BertForMaskedLM
 class BertForMaskedLMPolicy(BertPolicy):

+    def __init__(self) -> None:
+        super().__init__()
+
-    def module_policy(self, shard_config: ShardConfig = None):
-        module_policy = super().module_policy(shard_config)
+    def module_policy(self):
+        module_policy = super().module_policy()
         addon_module = {
             BertLMPredictionHead:
                 ModulePolicyDescription(attribute_replacement={},
...
@@ -124,16 +153,41 @@ class BertForMaskedLMPolicy(BertPolicy):

 # BertLMHeadModel
 class BertLMHeadModelPolicy(BertPolicy):

-    @staticmethod
-    def argument_policy(config, world_size):
-        base_argument = BertPolicy.argument_policy(config, world_size)
-        argument = {
-            BertLMPredictionHead: Argument(attr_dict={}, param_funcs=[
-                BertPolicy.unembedding,
-            ]),
-        }
-        argument.update(base_argument)
-        return argument
+    def __init__(self) -> None:
+        super().__init__()
+
+    def module_policy(self):
+        module_policy = super().module_policy()
+        addon_module = {
+            BertLMPredictionHead:
+                ModulePolicyDescription(attribute_replacement={},
+                                        param_replacement=[],
+                                        sub_module_replacement=[
+                                            SubModuleReplacementDescription(
+                                                suffix="decoder",
+                                                target_module=col_nn.Linear1D_Col,
+                                                kwargs={"gather_output": True})
+                                        ])
+        }
+        module_policy.update(addon_module)
+        return module_policy


 # BertForNextSentencePrediction
 class BertForNextSentencePredictionPolicy(BertPolicy):

+    def __init__(self) -> None:
+        super().__init__()
+

 # BertForSequenceClassification
 class BertForSequenceClassificationPolicy(BertPolicy):

+    def __init__(self) -> None:
+        super().__init__()
+

 # BertForMultipleChoice
 class BertForMultipleChoicePolicy(BertPolicy):

+    def __init__(self) -> None:
+        super().__init__()
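The preprocess change above keeps BertPolicy's vocabulary-padding rule but sources world_size from self.shard_config. The rule itself rounds the vocabulary up to the next multiple of the tensor-parallel size; shown standalone below for clarity (the helper name is ours, not the commit's):

    def padded_vocab_size(vocab_size: int, world_size: int) -> int:
        # Same arithmetic as BertPolicy.preprocess: round the vocabulary up
        # to a multiple of world_size before resize_token_embeddings.
        if vocab_size % world_size != 0:
            return vocab_size + world_size - vocab_size % world_size
        return vocab_size

    assert padded_vocab_size(30522, 4) == 30524   # BERT's vocab on 4 GPUs
    assert padded_vocab_size(30522, 2) == 30522   # already divisible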
colossalai/shardformer/shard/shard_config.py

@@ -18,10 +18,10 @@ class ShardConfig:
         will not calculate the loss and just return the output.
         gather_output (bool): Whether to gather the output of the model of the last layer
     """
-    data_parallel_size: int
     tensor_parallel_size: int
-    pipeline_parallel_size: int
+    # TODO: add support for tensor parallel
+    # pipeline_parallel_size: int
+    # data_parallel_size: int

     tensor_parallel_mode: Literal['1d', '2d', '2.5d', '3d']
     inference_only: bool = True
     gather_output: bool = True
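After this change a ShardConfig describes only tensor parallelism; the data- and pipeline-parallel fields are commented out as TODOs. Constructing one now looks like this (a sketch; inference_only and gather_output keep their defaults of True):

    from colossalai.shardformer.shard.shard_config import ShardConfig

    shard_config = ShardConfig(
        tensor_parallel_size=2,      # only tensor parallelism remains configurable
        tensor_parallel_mode='1d',
    )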
colossalai/shardformer/shard/sharder.py

@@ -40,6 +40,7 @@ class ModelSharder(object):
         Shard the model according to the policy
         """
         self.policy.set_model(self.model)
+        self.policy.set_shard_config(self.shard_config)
         self.preprocess()
         self.replace_model_class()
         self.replace_module()
...
@@ -57,12 +58,12 @@ class ModelSharder(object):
         self.model_config = self.model.config

     def preprocess(self) -> None:
-        self.model = self.policy.preprocess(self.shard_config)
+        self.model = self.policy.preprocess()

     def postprocess(self) -> None:
         self.model = self.policy.postprocess()

-    def replace_model_class(self,) -> None:
+    def replace_model_class(self) -> None:
         r"""
         Replace the model with the policy-defined model
         Mainly modify the forward and backward to fit the distributed model
...
@@ -83,14 +84,14 @@ class ModelSharder(object):
                 getattr(new_model_class, key),
             )

-    def replace_module(self,) -> None:
+    def replace_module(self) -> None:
         r"""
         Replace the module according to the policy, and replace the module one by one

         Args:
             model (:class:`torch.nn.Module`): The model to shard
         """
-        module_descriptions = self.policy.module_policy(self.shard_config)
+        module_descriptions = self.policy.module_policy()
         for module_description in module_descriptions.items():
             origin_layer_cls = module_description[0]
             attr_replacement = module_description[1].attribute_replacement
...
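replace_module now calls self.policy.module_policy() with no arguments and walks the returned dict. A simplified sketch of what the loop does with the attribute_replacement part of each description — the helper is hypothetical, and the real sharder also applies param_replacement and sub_module_replacement:

    import torch.nn as nn

    def apply_attribute_replacement(model: nn.Module, policy_dict: dict) -> None:
        # policy_dict maps an original layer class (e.g. BertLayer) to the
        # attribute_replacement dict of its ModulePolicyDescription.
        for origin_layer_cls, attribute_replacement in policy_dict.items():
            for module in model.modules():
                if not isinstance(module, origin_layer_cls):
                    continue
                for attr_path, value in attribute_replacement.items():
                    # resolve dotted paths like "attention.self.all_head_size"
                    parent_path, _, leaf = attr_path.rpartition(".")
                    target = module
                    for part in filter(None, parent_path.split(".")):
                        target = getattr(target, part)
                    setattr(target, leaf, value)  # e.g. all_head_size // tp_size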
colossalai/shardformer/shard/shardformer.py

@@ -25,11 +25,7 @@ class ShardFormer:
     org_model = BertForMaskedLM.from_pretrained('bert-base-uncased')
     shard_config = ShardConfig(
         tensor_parallel_size=2,
-        data_parallel_size=1,
-        pipeline_parallel_size=1,
         tensor_parallel_mode='1d',
-        inference_only=True,
-        gather_output=True
     )
     shard_former = ShardFormer(shard_config=shard_config)
     shard_former.init_distributed()
...
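With the parallel-layout kwargs gone, the docstring's end-to-end flow reduces to the following sketch (it assumes the distributed environment is already launched, e.g. via colossalai.launch, and the direct module import paths are ours):

    from transformers import BertForMaskedLM
    from colossalai.shardformer.shard.shard_config import ShardConfig
    from colossalai.shardformer.shard.shardformer import ShardFormer

    org_model = BertForMaskedLM.from_pretrained('bert-base-uncased')
    shard_config = ShardConfig(tensor_parallel_size=2, tensor_parallel_mode='1d')
    shard_former = ShardFormer(shard_config=shard_config)
    shard_former.init_distributed()
    sharded_model = shard_former.shard_model(org_model)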
tests/test_shardformer/test_model/test_shard_bert.py

@@ -7,7 +7,6 @@ from transformers import (
     AutoTokenizer,
     BertConfig,
     BertForMaskedLM,
-    BertForMultipleChoice,
     BertForNextSentencePrediction,
     BertForPreTraining,
     BertForSequenceClassification,
...
@@ -36,12 +35,10 @@ def build_model(rank, world_size, model):
     org_model.to('cuda')
     # TODO: no need to transfer to cuda
     org_model_forshard.to('cuda')

-    shard_config = ShardConfig(tensor_parallel_size=2,
-                               data_parallel_size=1,
-                               pipeline_parallel_size=1,
+    shard_config = ShardConfig(tensor_parallel_size=2,
                                tensor_parallel_mode='1d',
                                inference_only=True,
                                gather_output=True)
     shard_former = ShardFormer(shard_config=shard_config)
     shard_former.init_distributed()
     sharded_model = shard_former.shard_model(org_model_forshard).to('cuda')
...