OpenDAS / ColossalAI

Commit 74d176c8

[shardformer] fix bert and gpt downstream with new api (#4024)

* fix bert downstream with new api
* remove comment line

Authored Jun 19, 2023 by FoolPlayer; committed Jul 04, 2023 by Frank Lee.
Parent: e253a070

Showing 6 changed files with 97 additions and 39 deletions (+97 -39).
colossalai/shardformer/policies/basepolicy.py          +12   -2
colossalai/shardformer/policies/bert.py                +73  -19
colossalai/shardformer/shard/shard_config.py            +3   -3
colossalai/shardformer/shard/sharder.py                 +5   -4
colossalai/shardformer/shard/shardformer.py             +0   -4
tests/test_shardformer/test_model/test_shard_bert.py    +4   -7
colossalai/shardformer/policies/basepolicy.py

@@ -76,6 +76,7 @@ class Policy(ABC):
     def __init__(self) -> None:
         self.model = None
+        self.shard_config = None
 
     def set_model(self, model: nn.Module) -> None:
         r"""
@@ -86,14 +87,23 @@ class Policy(ABC):
         """
         self.model = model
 
+    def set_shard_config(self, shard_config: ShardConfig) -> None:
+        r"""
+        Set shard config as an attribute of the Policy object.
+
+        Args:
+            shard_config (:class:`ShardConfig`): The shard config to be applied
+        """
+        self.shard_config = shard_config
+
     @abstractmethod
-    def preprocess(self, shard_config: ShardConfig = None) -> nn.Module:
+    def preprocess(self) -> nn.Module:
         r"""
         Perform some preprocessing of the model, like reshaping the embedding layer
         """
 
     @abstractmethod
-    def module_policy(self, shard_config: ShardConfig = None) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
+    def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
         r"""
         Return the dict for the modify policy, the key is the original layer class and the value is the
         argument for the modify layer
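With this change, the sharder injects both the model and the shard config into the policy as attributes, so `preprocess()` and `module_policy()` no longer take a `shard_config` argument. A minimal sketch of the new lifecycle for a custom policy, given some `model` and `shard_config` (the `MyPolicy` class and its body are illustrative only, and the real `Policy` may declare further abstract methods not visible in this hunk):

    import torch.nn as nn

    class MyPolicy(Policy):

        def preprocess(self) -> nn.Module:
            # the shard config is read from the attribute set via set_shard_config()
            world_size = self.shard_config.tensor_parallel_size
            assert self.model.config.vocab_size % world_size == 0
            return self.model

        def module_policy(self):
            return {}

    # Driven the way ModelSharder.shard() drives it (see sharder.py below):
    policy = MyPolicy()
    policy.set_model(model)                # stores the model on self.model
    policy.set_shard_config(shard_config)  # new: stores the config on self.shard_config
    policy.preprocess()                    # no shard_config parameter any more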
colossalai/shardformer/policies/bert.py

@@ -4,41 +4,40 @@ from transformers.models.bert.modeling_bert import BertEmbeddings, BertLayer, Be
 import colossalai.shardformer.layer.layers as col_nn
 from colossalai.shardformer.layer.dropout import Dropout1D
 
-from ..shard.shard_config import ShardConfig
 from ..utils import getattr_, setattr_
 from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
 
 
 class BertPolicy(Policy):
 
-    def preprocess(self, shard_config: ShardConfig = None):
+    def preprocess(self):
         # reshape the embedding layer
         r"""
         Reshape the Embedding layer to make the embedding dimension divisible by world_size
         """
         # TODO:
         vocab_size = self.model.config.vocab_size
-        world_size = shard_config.tensor_parallel_size
+        world_size = self.shard_config.tensor_parallel_size
         if vocab_size % world_size != 0:
             new_vocab_size = vocab_size + world_size - vocab_size % world_size
             self.model.resize_token_embeddings(new_vocab_size)
         return self.model
 
-    def module_policy(self, shard_config: ShardConfig = None):
+    def module_policy(self):
         return {
             BertLayer:
                 ModulePolicyDescription(
                     attribute_replacement={
                         # 1. shard hidden size
-                        "attention.self.all_head_size": self.model.config.hidden_size // shard_config.tensor_parallel_size,
-                        "crossattention.self.all_head_size": self.model.config.hidden_size // shard_config.tensor_parallel_size,
+                        "attention.self.all_head_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
+                        "crossattention.self.all_head_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
                         # 2. shard number of heads
-                        "attention.self.num_attention_heads": self.model.config.num_attention_heads // shard_config.tensor_parallel_size,
-                        "crossattention.self.num_attention_heads": self.model.config.num_attention_heads // shard_config.tensor_parallel_size,
+                        "attention.self.num_attention_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
+                        "crossattention.self.num_attention_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
                     },
                     param_replacement=[],
                     sub_module_replacement=[
@@ -100,13 +99,43 @@ class BertPolicy(Policy):
         return self.model
 
 
+# BertModel
+class BertModelPolicy(BertPolicy):
+
+    def __init__(self) -> None:
+        super().__init__()
+
+
+# BertForPreTraining
+class BertForPretrainingPolicy(BertPolicy):
+
+    def __init__(self) -> None:
+        super().__init__()
+
+    def module_policy(self):
+        module_policy = super().module_policy()
+        addon_module = {
+            BertLMPredictionHead:
+                ModulePolicyDescription(attribute_replacement={},
+                                        param_replacement=[],
+                                        sub_module_replacement=[
+                                            SubModuleReplacementDescription(
+                                                suffix="decoder",
+                                                target_module=col_nn.Linear1D_Col,
+                                                kwargs={"gather_output": True})
+                                        ])
+        }
+        module_policy.update(addon_module)
+        return module_policy
+
+
+# BertForMaskedLM
 class BertForMaskedLMPolicy(BertPolicy):
 
     def __init__(self) -> None:
         super().__init__()
 
-    def module_policy(self, shard_config: ShardConfig = None):
-        module_policy = super().module_policy(shard_config)
+    def module_policy(self):
+        module_policy = super().module_policy()
         addon_module = {
             BertLMPredictionHead:
                 ModulePolicyDescription(attribute_replacement={},
@@ -124,16 +153,41 @@ class BertForMaskedLMPolicy(BertPolicy):
 # BertLMHeadModel
 class BertLMHeadModelPolicy(BertPolicy):
 
-    @staticmethod
-    def argument_policy(config, world_size):
-        base_argument = BertPolicy.argument_policy(config, world_size)
-        argument = {
-            BertLMPredictionHead: Argument(attr_dict={}, param_funcs=[
-                BertPolicy.unembedding,
-            ]),
-        }
-        argument.update(base_argument)
-        return argument
+    def __init__(self) -> None:
+        super().__init__()
+
+    def module_policy(self):
+        module_policy = super().module_policy()
+        addon_module = {
+            BertLMPredictionHead:
+                ModulePolicyDescription(attribute_replacement={},
+                                        param_replacement=[],
+                                        sub_module_replacement=[
+                                            SubModuleReplacementDescription(
+                                                suffix="decoder",
+                                                target_module=col_nn.Linear1D_Col,
+                                                kwargs={"gather_output": True})
+                                        ])
+        }
+        module_policy.update(addon_module)
+        return module_policy
+
+
+# BertForNextSentencePrediction
+class BertForNextSentencePredictionPolicy(BertPolicy):
+
+    def __init__(self) -> None:
+        super().__init__()
+
+
+# BertForSequenceClassification
+class BertForSequenceClassificationPolicy(BertPolicy):
+
+    def __init__(self) -> None:
+        super().__init__()
+
+
+# BertForMultipleChoice
+class BertForMultipleChoicePolicy(BertPolicy):
 
     def __init__(self) -> None:
         super().__init__()
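The padding rule in `BertPolicy.preprocess()` rounds the vocabulary size up to the next multiple of the tensor-parallel world size, so the token embedding can be split evenly across ranks. The same arithmetic extracted as a standalone helper for illustration (`pad_vocab_size` is a hypothetical name, not part of the commit):

    def pad_vocab_size(vocab_size: int, world_size: int) -> int:
        # identical arithmetic to BertPolicy.preprocess() above
        if vocab_size % world_size != 0:
            return vocab_size + world_size - vocab_size % world_size
        return vocab_size

    print(pad_vocab_size(30522, 2))   # 30522: already divisible, unchanged
    print(pad_vocab_size(30522, 4))   # 30524: padded, each rank holds 7631 rows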
colossalai/shardformer/shard/shard_config.py

@@ -18,10 +18,10 @@ class ShardConfig:
         will not calculate the loss and just return the output.
         gather_output (bool): Whether to gather the output of the model of the last layer
     """
-    data_parallel_size: int
     tensor_parallel_size: int
-    pipeline_parallel_size: int
+    # TODO: add support for tensor parallel
+    # pipeline_parallel_size: int
+    # data_parallel_size: int
     tensor_parallel_mode: Literal['1d', '2d', '2.5d', '3d']
     inference_only: bool = True
     gather_output: bool = True
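After this change only the tensor-parallel fields remain required; `inference_only` and `gather_output` keep their defaults. A minimal construction sketch, assuming `ShardConfig` is the dataclass shown above and the module path is as in this diff:

    from colossalai.shardformer.shard.shard_config import ShardConfig

    config = ShardConfig(tensor_parallel_size=2, tensor_parallel_mode='1d')
    print(config.inference_only, config.gather_output)   # True True (defaults)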
colossalai/shardformer/shard/sharder.py

@@ -40,6 +40,7 @@ class ModelSharder(object):
         Shard the model according to the policy
         """
         self.policy.set_model(self.model)
+        self.policy.set_shard_config(self.shard_config)
         self.preprocess()
         self.replace_model_class()
         self.replace_module()
@@ -57,12 +58,12 @@ class ModelSharder(object):
         self.model_config = self.model.config
 
     def preprocess(self) -> None:
-        self.model = self.policy.preprocess(self.shard_config)
+        self.model = self.policy.preprocess()
 
     def postprocess(self) -> None:
         self.model = self.policy.postprocess()
 
-    def replace_model_class(self,) -> None:
+    def replace_model_class(self) -> None:
         r"""
         Replace the model to policy defined model
         Mainly modify the forward and backward to fit distributed model
@@ -83,14 +84,14 @@ class ModelSharder(object):
                     getattr(new_model_class, key),
                 )
 
-    def replace_module(self,) -> None:
+    def replace_module(self) -> None:
         r"""
         Replace the module according to the policy, and replace the module one by one
 
         Args:
             model (:class:`torch.nn.Module`): The model to shard
         """
-        module_descriptions = self.policy.module_policy(self.shard_config)
+        module_descriptions = self.policy.module_policy()
        for module_description in module_descriptions.items():
             origin_layer_cls = module_description[0]
             attr_replacement = module_description[1].attribute_replacement
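For reference, `replace_module()` now pulls everything it needs from the policy object itself. A condensed sketch of the loop above, with the key/value pair unpacked directly for readability (behaviour unchanged; variable names are illustrative):

    module_descriptions = self.policy.module_policy()   # no shard_config argument
    for origin_layer_cls, description in module_descriptions.items():
        # description is a ModulePolicyDescription; its attribute_replacement dict
        # maps attribute paths (e.g. "attention.self.all_head_size") to new values
        attr_replacement = description.attribute_replacement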
colossalai/shardformer/shard/shardformer.py

@@ -25,11 +25,7 @@ class ShardFormer:
         org_model = BertForMaskedLM.from_pretrained('bert-base-uncased')
         shard_config = ShardConfig(
             tensor_parallel_size=2,
-            data_parallel_size=1,
-            pipeline_parallel_size=1,
             tensor_parallel_mode='1d',
-            inference_only=True,
-            gather_output=True
         )
         shard_former = ShardFormer(shard_config=shard_config)
         shard_former.init_distributed()
tests/test_shardformer/test_model/test_shard_bert.py

@@ -7,7 +7,6 @@ from transformers import (
     AutoTokenizer,
     BertConfig,
     BertForMaskedLM,
-    BertForMultipleChoice,
     BertForNextSentencePrediction,
     BertForPreTraining,
     BertForSequenceClassification,
@@ -36,12 +35,10 @@ def build_model(rank, world_size, model):
     org_model.to('cuda')
     # TODO: no need to transfer to cuda
     org_model_forshard.to('cuda')
 
-    shard_config = ShardConfig(data_parallel_size=1,
-                               tensor_parallel_size=2,
-                               pipeline_parallel_size=1,
-                               tensor_parallel_mode='1d',
-                               inference_only=True,
-                               gather_output=True)
+    shard_config = ShardConfig(tensor_parallel_size=2,
+                               tensor_parallel_mode='1d',
+                               )
     shard_former = ShardFormer(shard_config=shard_config)
     shard_former.init_distributed()
     sharded_model = shard_former.shard_model(org_model_forshard).to('cuda')
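Putting the pieces together, the post-commit usage condensed from the docstring example and the test above (this assumes a two-GPU distributed process group has already been launched, e.g. by the test harness):

    org_model = BertForMaskedLM.from_pretrained('bert-base-uncased')
    shard_config = ShardConfig(tensor_parallel_size=2,
                               tensor_parallel_mode='1d')
    shard_former = ShardFormer(shard_config=shard_config)
    shard_former.init_distributed()
    sharded_model = shard_former.shard_model(org_model).to('cuda')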