OpenDAS / ColossalAI · Commit c1d5453e

[shardformer] adapted llama to the new API (#4036)

Authored Jun 19, 2023 by Frank Lee
Parent: 74d176c8

Showing 9 changed files with 245 additions and 208 deletions (+245 −208):
colossalai/shardformer/policies/autopolicy.py          +72  −62
colossalai/shardformer/policies/basepolicy.py          +5   −0
colossalai/shardformer/policies/llama.py               +97  −98
colossalai/shardformer/shard/shard_config.py           +15  −4
colossalai/shardformer/shard/sharder.py                +8   −10
colossalai/shardformer/shard/shardformer.py            +4   −2
tests/test_shardformer/test_model/test_shard_bert.py   +16  −12
tests/test_shardformer/test_model/test_shard_llama.py  +26  −19
tests/test_shardformer/test_model/test_shard_t5.py     +2   −1
colossalai/shardformer/policies/autopolicy.py (view file @ c1d5453e)

This commit replaces the eager build_policies() registry, which imported every
supported model class up front, with a lazy, string-keyed _POLICY_LIST whose
entries are resolved on demand through importlib. The file after the change:

import importlib
from dataclasses import dataclass

import torch.nn as nn

from .basepolicy import Policy


@dataclass
class PolicyLocation:
    """
    PolicyLocation describes the location of a policy class.

    Args:
        file_name (str): The file name of the policy under colossalai.shardformer.policies
        class_name (str): The class name of the policy class
    """
    file_name: str
    class_name: str


# we don't want to import all policies here
# as each policy file imports its own model zoo library
# we will allow the user to only import the policy file needed
_POLICY_LIST = {
    # BERT
    "transformers.models.bert.modeling_bert.BertModel":
        PolicyLocation(file_name="bert", class_name="BertPolicy"),
    "transformers.models.bert.modeling_bert.BertForPreTraining":
        PolicyLocation(file_name="bert", class_name="BertForPretrainingPolicy"),
    "transformers.models.bert.modeling_bert.BertForMaskedLM":
        PolicyLocation(file_name="bert", class_name="BertForMaskedLMPolicy"),
    "transformers.models.bert.modeling_bert.BertLMHeadModel":
        PolicyLocation(file_name="bert", class_name="BertLMHeadModelPolicy"),
    "transformers.models.bert.modeling_bert.BertForNextSentencePrediction":
        PolicyLocation(file_name="bert", class_name="BertForNextSentencePredictionPolicy"),
    "transformers.models.bert.modeling_bert.BertForSequenceClassification":
        PolicyLocation(file_name="bert", class_name="BertForSequenceClassificationPolicy"),
    "transformers.models.bert.modeling_bert.BertForMultipleChoice":
        PolicyLocation(file_name="bert", class_name="BertForMultipleChoicePolicy"),

    # LLaMA
    "transformers.models.llama.modeling_llama.LlamaModel":
        PolicyLocation(file_name="llama", class_name="LlamaPolicy"),
    "transformers.models.llama.modeling_llama.LlamaForCausalLM":
        PolicyLocation(file_name="llama", class_name="LlamaForCausalLMPolicy"),
    "transformers.models.llama.modeling_llama.LlamaForSequenceClassification":
        PolicyLocation(file_name="llama", class_name="LlamaForSequenceClassificationPolicy"),

    # T5

    # GPT2
}


def import_policy(policy_location: PolicyLocation) -> Policy:
    """
    Dynamically import a Policy class based on the policy location.
    """
    module_name = f"colossalai.shardformer.policies.{policy_location.file_name}"
    module = importlib.import_module(module_name)
    return getattr(module, policy_location.class_name)


def _fullname(obj):
    """
    Return the full name of an object, including the module name.
    """
    klass = obj.__class__
    module = klass.__module__
    if module == 'builtins':
        return klass.__qualname__    # avoid outputs like 'builtins.str'
    return module + '.' + klass.__qualname__


def get_autopolicy(model: nn.Module) -> Policy:
    r"""
    ...

    Return:
        :class:`Policy`: The auto policy for the model
    """
    full_name = _fullname(model)
    policy_location = _POLICY_LIST.get(full_name, None)

    if policy_location is None:
        raise NotImplementedError(
            f"Auto policy for {model.__class__.__qualname__} is not implemented.\n"
            f"Supported models are {list(_POLICY_LIST.keys())}")
    else:
        policy = import_policy(policy_location)
    return policy()


# from transformers.models.bert.modeling_bert import BertForMaskedLM, BertForPreTraining
# model = BertForPreTraining
# policy = get_autopolicy(model)
# print(policy)
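The lookup is keyed on the fully qualified class name of the model instance, so
no model zoo is imported until its policy is actually requested. A minimal
usage sketch along the lines of the commented example above (assuming
transformers is installed; note that get_autopolicy() expects a model
instance, since _fullname() reads obj.__class__):

# Sketch of the dynamic lookup, following the commented example above.
from transformers import BertConfig, BertForPreTraining

from colossalai.shardformer.policies.autopolicy import get_autopolicy

model = BertForPreTraining(BertConfig())
# _fullname(model) == "transformers.models.bert.modeling_bert.BertForPreTraining",
# which resolves to PolicyLocation(file_name="bert", class_name="BertForPretrainingPolicy")
policy = get_autopolicy(model)
print(policy)    # a BertForPretrainingPolicy instance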
colossalai/shardformer/policies/basepolicy.py (view file @ c1d5453e)

This commit adds an explicit pass to each abstract method stub and touches the
attribute initialization in Policy.__init__:

...
@@ -75,6 +75,7 @@ class Policy(ABC):
    """

    def __init__(self) -> None:
        self.shard_config = None
        self.model = None

...
@@ -101,6 +102,7 @@ class Policy(ABC):
        r"""
        Perform some preprocessing of the model, like reshaping the embedding layer
        """
        pass

    @abstractmethod
    def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
...
@@ -135,6 +137,7 @@ class Policy(ABC):
            ...
        }
        """
        pass

    @abstractmethod
    def new_model_class(self) -> Union[Type[nn.Module], None]:
...
@@ -149,6 +152,7 @@ class Policy(ABC):
            return BertModel_
        ```
        """
        pass

    @abstractmethod
    def postprocess(self) -> nn.Module:
...
@@ -156,3 +160,4 @@ class Policy(ABC):
        Perform some postprocessing of the model, like binding the weight of the
        embedding layer with the classifier layer
        """
        pass
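A concrete policy fills in these four hooks. Below is a minimal no-op sketch
(a hypothetical NoOpPolicy, not part of this commit) showing the shape a
subclass takes; llama.py in the next file implements the same interface for
real:

from typing import Dict, Type, Union

import torch.nn as nn

from colossalai.shardformer.policies.basepolicy import ModulePolicyDescription, Policy


class NoOpPolicy(Policy):
    """Hypothetical policy that shards nothing; every hook is a pass-through."""

    def preprocess(self) -> nn.Module:
        # real policies may reshape the embedding here (see LlamaPolicy.preprocess)
        return self.model

    def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
        # an empty mapping means no submodule is replaced
        return {}

    def new_model_class(self) -> Union[Type[nn.Module], None]:
        # None keeps the original model class
        return None

    def postprocess(self) -> nn.Module:
        # real policies may re-tie embedding and classifier weights here
        return self.model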
colossalai/shardformer/policies/llama.py (view file @ c1d5453e)

The LlamaPolicy family is rewritten from the old Argument / Col_Layer / Row_Layer
param_funcs style to the new ModulePolicyDescription API. The file after the
change:

from typing import Dict, Union

import torch.nn as nn
from transformers import LlamaForCausalLM, LlamaForSequenceClassification
from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaModel

from colossalai.shardformer.layer.layers import Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D

from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription


class LlamaPolicy(Policy):

    def preprocess(self):
        # Resize embedding
        vocab_size = self.model.config.vocab_size
        world_size = self.shard_config.tensor_parallel_size

        if vocab_size % world_size != 0:
            new_vocab_size = vocab_size + world_size - vocab_size % world_size
            self.model.resize_token_embeddings(new_vocab_size)

        return self.model

    def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
        return {
            LlamaDecoderLayer:
                ModulePolicyDescription(
                    attribute_replacement={
                        "self_attn.hidden_size":
                            self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
                        "self_attn.num_heads":
                            self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
                    },
                    param_replacement=[],
                    sub_module_replacement=[
                        SubModuleReplacementDescription(
                            suffix="self_attn.q_proj",
                            target_module=Linear1D_Col,
                        ),
                        SubModuleReplacementDescription(
                            suffix="self_attn.k_proj",
                            target_module=Linear1D_Col,
                        ),
                        SubModuleReplacementDescription(
                            suffix="self_attn.v_proj",
                            target_module=Linear1D_Col,
                        ),
                        SubModuleReplacementDescription(
                            suffix="self_attn.o_proj",
                            target_module=Linear1D_Row,
                        ),
                        SubModuleReplacementDescription(
                            suffix="mlp.gate_proj",
                            target_module=Linear1D_Col,
                        ),
                        SubModuleReplacementDescription(
                            suffix="mlp.up_proj",
                            target_module=Linear1D_Col,
                        ),
                        SubModuleReplacementDescription(
                            suffix="mlp.down_proj",
                            target_module=Linear1D_Row,
                        )
                    ],
                ),
            LlamaModel:
                ModulePolicyDescription(attribute_replacement={},
                                        param_replacement=[],
                                        sub_module_replacement=[
                                            SubModuleReplacementDescription(
                                                suffix="embed_tokens",
                                                target_module=VocabParallelEmbedding1D,
                                            )
                                        ])
        }

    def new_model_class(self):
        return None

    def postprocess(self):
        return self.model


class LlamaForCausalLMPolicy(LlamaPolicy):

    def module_policy(self):
        policy = super().module_policy()
        # add a new item for causal lm
        new_item = {
            LlamaForCausalLM:
                ModulePolicyDescription(attribute_replacement={},
                                        param_replacement=[],
                                        sub_module_replacement=[
                                            SubModuleReplacementDescription(suffix="lm_head",
                                                                            target_module=Linear1D_Col,
                                                                            kwargs=dict(gather_output=True))
                                        ])
        }
        policy.update(new_item)
        return policy


class LlamaForSequenceClassificationPolicy(LlamaPolicy):

    def module_policy(self):
        policy = super().module_policy()
        # add a new item for sequence classification
        new_item = {
            LlamaForSequenceClassification:
                ModulePolicyDescription(attribute_replacement={},
                                        param_replacement=[],
                                        sub_module_replacement=[
                                            SubModuleReplacementDescription(suffix="score",
                                                                            target_module=Linear1D_Col,
                                                                            kwargs=dict(gather_output=True))
                                        ])
        }
        policy.update(new_item)
        return policy
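The padding in preprocess() rounds the vocabulary up to the next multiple of
the tensor-parallel size so that VocabParallelEmbedding1D can split it evenly.
The arithmetic in isolation (plain Python; the example sizes are illustrative):

def padded_vocab_size(vocab_size: int, world_size: int) -> int:
    # mirrors LlamaPolicy.preprocess(): round up to a multiple of world_size
    if vocab_size % world_size != 0:
        return vocab_size + world_size - vocab_size % world_size
    return vocab_size

assert padded_vocab_size(32000, 2) == 32000    # already divisible, unchanged
assert padded_vocab_size(32000, 3) == 32001    # remainder 2, padded by 1
assert padded_vocab_size(50257, 4) == 50260    # remainder 1, padded by 3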
colossalai/shardformer/shard/shard_config.py (view file @ c1d5453e)

ShardConfig keeps tensor_parallel_size as its only active field, comments out
the not-yet-supported options, and gains a __post_init__ that derives the data
parallel size from the world size reported by DistCoordinator:

from dataclasses import dataclass

from colossalai.cluster.dist_coordinator import DistCoordinator

__all__ = ['ShardConfig']

...
@@ -19,9 +20,19 @@ class ShardConfig:
        gather_output (bool): Whether to gather the output of the model of the last layer
    """
    tensor_parallel_size: int

    # TODO: add support for tensor parallel
    # pipeline_parallel_size: int
    # data_parallel_size: int
    # tensor_parallel_mode: Literal['1d', '2d', '2.5d', '3d']
    # inference_only: bool = True
    # gather_output: bool = True

    def __post_init__(self):
        coordinator = DistCoordinator()

        # ensure the parallel size can match the world size
        world_size = coordinator.world_size
        self.data_parallel_size = world_size // self.tensor_parallel_size
        assert world_size == self.data_parallel_size * self.tensor_parallel_size, \
            f"The world size ({world_size}) should be divisible by the data parallel size {self.data_parallel_size} and tensor parallel size {self.tensor_parallel_size}"
colossalai/shardformer/shard/sharder.py (view file @ c1d5453e)

ModelSharder's lifecycle hooks become private (_preprocess, _replace_model_class,
_replace_module, _postprocess) and the unused torch / Conv1D imports are dropped:

from typing import Any, Callable, Dict, List

import torch.nn as nn

from colossalai.cluster.process_group_manager import ProcessGroupManager

...
@@ -41,10 +39,10 @@ class ModelSharder(object):
        """
        self.policy.set_model(self.model)
        self.policy.set_shard_config(self.shard_config)
        self._preprocess()
        self._replace_model_class()
        self._replace_module()
        self._postprocess()

    def reshape_embedding(self) -> None:
        r"""
...
@@ -57,13 +55,13 @@ class ModelSharder(object):
            self.model.resize_token_embeddings(new_vocab_size)
            self.model_config = self.model.config

    def _preprocess(self) -> None:
        self.model = self.policy.preprocess()

    def _postprocess(self) -> None:
        self.model = self.policy.postprocess()

    def _replace_model_class(self) -> None:
        r"""
        Replace the model with the policy-defined model,
        mainly modifying the forward and backward passes to fit the distributed model
...
@@ -84,7 +82,7 @@ class ModelSharder(object):
                    getattr(new_model_class, key),
                )

    def _replace_module(self) -> None:
        r"""
        Replace the modules according to the policy, one by one
...
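The renaming makes the sharding pipeline read as a fixed sequence: bind the
policy, preprocess, swap the model class, replace modules, postprocess. A
schematic sketch of that control flow (a simplified stand-in, not the real
ModelSharder):

class ModelSharderSketch:
    """Schematic of ModelSharder.shard()'s call order after this commit."""

    def __init__(self, model, policy, shard_config):
        self.model = model
        self.policy = policy
        self.shard_config = shard_config

    def shard(self) -> None:
        self.policy.set_model(self.model)
        self.policy.set_shard_config(self.shard_config)
        self._preprocess()             # e.g. pad the embedding (Policy.preprocess)
        self._replace_model_class()    # swap in the policy-defined model class, if any
        self._replace_module()         # apply the ModulePolicyDescription replacements
        self._postprocess()            # e.g. re-tie weights (Policy.postprocess)

    def _preprocess(self) -> None:
        self.model = self.policy.preprocess()

    def _replace_model_class(self) -> None:
        ...    # omitted here; see the real _replace_model_class above

    def _replace_module(self) -> None:
        ...    # omitted here; see the real _replace_module above

    def _postprocess(self) -> None:
        self.model = self.policy.postprocess()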
colossalai/shardformer/shard/shardformer.py (view file @ c1d5453e)

init_distributed() now always creates the 1D tensor-parallel process group
instead of branching on tensor_parallel_mode:

...
@@ -47,10 +47,12 @@ class ShardFormer:
        """
        Initialize the distributed process group according to the shard configuration.
        """
        # create process group manager and 1d process group
        # TODO: may need to support other parallel modes when the config has such a field
        pg_manager = ProcessGroupManager()
        pg_manager.create_process_group(name='tp1d', ranks=range(self.coordinator.world_size))
        self.pg_manager = pg_manager

        return pg_manager

    def shard_model(self, model: nn.Module, policy: Policy = None):
...
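End to end, the tests below exercise the new entry point roughly like this
(a sketch distilled from the updated test_shard_llama.py; it assumes
colossalai.launch() has already initialized the distributed environment on
each rank):

import copy

from transformers import LlamaConfig, LlamaForCausalLM

from colossalai.shardformer import ShardConfig, ShardFormer

config = LlamaConfig(num_hidden_layers=8)
org_model = LlamaForCausalLM(config).cuda()

shard_config = ShardConfig(tensor_parallel_size=2)    # must divide the world size
shard_former = ShardFormer(shard_config=shard_config)
shard_former.init_distributed()                       # creates the 'tp1d' process group
sharded_model = shard_former.shard_model(copy.deepcopy(org_model))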
tests/test_shardformer/test_model/test_shard_bert.py (view file @ c1d5453e)

build_model() now takes the model constructor as model_fn and builds from a
fresh BertConfig, and check_bert() iterates over the model classes that
currently work:

...
@@ -24,21 +24,18 @@ CONFIG = dict(parallel=dict(data=1, pipeline=1, tensor=dict(size=2, mode='1d')),
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")


def build_model(world_size, model_fn):
    config = BertConfig()
    config.hidden_dropout_prob = 0
    config.attention_probs_dropout_prob = 0

    org_model = model_fn(config=config)
    org_model_forshard = copy.deepcopy(org_model)

    org_model.to('cuda')
    # TODO: no need to transfer to cuda
    org_model_forshard.to('cuda')

    shard_config = ShardConfig(tensor_parallel_size=world_size,)
    shard_former = ShardFormer(shard_config=shard_config)
    shard_former.init_distributed()
    sharded_model = shard_former.shard_model(org_model_forshard).to('cuda')
...
@@ -99,15 +96,22 @@ def check_bert(rank, world_size, port):
    disable_existing_loggers()
    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    forward_list = [
        BertForMaskedLM,
        BertForPreTraining,
        BertLMHeadModel,

        # TODO: do not work yet
        # BertModel,
        # BertForSequenceClassification
        # BertForNextSentencePrediction,
    ]
    backward_list = [BertForMaskedLM, BertLMHeadModel]

    for model_fn in forward_list:
        org_model, sharded_model = build_model(world_size, model_fn)
        check_forward(org_model, sharded_model)

        if model_fn in backward_list:
            check_backward(org_model, sharded_model)

        torch.cuda.empty_cache()
...
tests/test_shardformer/test_model/test_shard_llama.py (view file @ c1d5453e)

The llama test switches from the old shard_model() function to the ShardFormer
API, builds the model through a model_fn list, and launches with an empty
config:

...
@@ -4,31 +4,28 @@ import random
import pytest
import torch
from transformers import LlamaConfig, LlamaForCausalLM, LlamaForSequenceClassification, LlamaModel, LlamaTokenizerFast

import colossalai
from colossalai.logging import disable_existing_loggers
from colossalai.shardformer import ShardConfig, ShardFormer
from colossalai.testing import rerun_if_address_is_in_use, spawn

os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'
tokenizer = LlamaTokenizerFast.from_pretrained("hf-internal-testing/llama-tokenizer")


def build_model(world_size, model_fn):
    # create new model
    config = LlamaConfig(num_hidden_layers=8)
    org_model = model_fn(config).cuda()

    # shard model
    shard_config = ShardConfig(tensor_parallel_size=world_size)
    model_copy = copy.deepcopy(org_model)
    shard_former = ShardFormer(shard_config=shard_config)
    shard_former.init_distributed()
    sharded_model = shard_former.shard_model(model_copy)

    return org_model, sharded_model
...
@@ -38,6 +35,7 @@ def check_forward(org_model, sharded_model):
    inputs = tokenizer(input, return_tensors='pt').to('cuda')
    del inputs["token_type_ids"]
    del inputs["attention_mask"]

    # original model
    org_model.eval()
    org_out = org_model(**inputs)
...
@@ -87,11 +85,20 @@ def check_backward(org_model, sharded_model):
def check_llama(rank, world_size, port):
    disable_existing_loggers()
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')

    model_list = [
        LlamaForCausalLM,

        # TODO: do not work yet
        # LlamaModel,
        # LlamaForSequenceClassification
    ]

    for model_fn in model_list:
        org_model, sharded_model = build_model(world_size, model_fn)
        check_forward(org_model, sharded_model)
        check_backward(org_model, sharded_model)

    torch.cuda.empty_cache()
...
tests/test_shardformer/test_model/test_shard_t5.py (view file @ c1d5453e)

The t5 test imports ShardFormer instead of shard_model and is marked as
skipped for now:

...
@@ -8,7 +8,7 @@ from transformers import AutoTokenizer, BertConfig, BertForMaskedLM, T5Config, T
import colossalai
from colossalai.logging import disable_existing_loggers
from colossalai.shardformer.shard import ShardConfig, ShardFormer
from colossalai.testing import rerun_if_address_is_in_use, spawn

os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'
...
@@ -90,6 +90,7 @@ def check_t5(rank, world_size, port):
@pytest.mark.dist
@pytest.mark.skip
@rerun_if_address_is_in_use()
def test_t5():
    spawn(check_t5, 2)
...
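All three test files share the same launch skeleton: a per-rank check function
started with spawn and guarded by rerun_if_address_is_in_use. The pattern in
isolation (with a hypothetical check_mymodel; each real test file follows this
shape):

import pytest

import colossalai
from colossalai.logging import disable_existing_loggers
from colossalai.testing import rerun_if_address_is_in_use, spawn


def check_mymodel(rank, world_size, port):
    disable_existing_loggers()
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    # build the original and sharded models here, then run
    # check_forward / check_backward against each other


@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_mymodel():
    spawn(check_mymodel, 2)    # two ranks, matching tensor_parallel_size=2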