ColossalAI commit c1d5453e, authored Jun 19, 2023 by Frank Lee
[shardformer] adapted llama to the new API (#4036)
parent 74d176c8

Showing 9 changed files with 245 additions and 208 deletions:
colossalai/shardformer/policies/autopolicy.py          +72  -62
colossalai/shardformer/policies/basepolicy.py          +5   -0
colossalai/shardformer/policies/llama.py               +97  -98
colossalai/shardformer/shard/shard_config.py           +15  -4
colossalai/shardformer/shard/sharder.py                +8   -10
colossalai/shardformer/shard/shardformer.py            +4   -2
tests/test_shardformer/test_model/test_shard_bert.py   +16  -12
tests/test_shardformer/test_model/test_shard_llama.py  +26  -19
tests/test_shardformer/test_model/test_shard_t5.py     +2   -1
colossalai/shardformer/policies/autopolicy.py

+import importlib
+from dataclasses import dataclass
+
 import torch.nn as nn

 from .basepolicy import Policy


-def build_policies():
-    r"""
-    Build the policies for the model
-
-    Return:
-        The dict for the policies
-    """
-    auto_policy_dict = {}
-
-    from transformers import BertModel
-    from .bert import BertModelPolicy
-    auto_policy_dict[BertModel] = BertModelPolicy
-
-    from transformers import BertForPreTraining
-    from .bert import BertForPretrainingPolicy
-    auto_policy_dict[BertForPreTraining] = BertForPretrainingPolicy
-
-    from transformers import BertLMHeadModel
-    from .bert import BertLMHeadModelPolicy
-    auto_policy_dict[BertLMHeadModel] = BertLMHeadModelPolicy
-
-    from transformers import BertForMaskedLM
-    from .bert import BertForMaskedLMPolicy
-    auto_policy_dict[BertForMaskedLM] = BertForMaskedLMPolicy
-
-    from transformers import BertForNextSentencePrediction
-    from .bert import BertForNextSentencePredictionPolicy
-    auto_policy_dict[BertForNextSentencePrediction] = BertForNextSentencePredictionPolicy
-
-    from transformers import BertForSequenceClassification
-    from .bert import BertForSequenceClassificationPolicy
-    auto_policy_dict[BertForSequenceClassification] = BertForSequenceClassificationPolicy
-
-    from transformers.models.llama.modeling_llama import LlamaModel
-
-    # from .llama import LlamaPolicy
-    # auto_policy_dict[LlamaModel] = LlamaPolicy
-
-    # from transformers import LlamaForSequenceClassification
-    # from .llama import LlamaForSequenceClassificationPolicy
-    # auto_policy_dict[LlamaForSequenceClassification] = LlamaForSequenceClassificationPolicy
-
-    # from transformers import LlamaForCausalLM
-    # from .llama import LlamaForCausalLMPolicy
-    # auto_policy_dict[LlamaForCausalLM] = LlamaForCausalLMPolicy
-
-    # from transformers import GPT2Model
-    # from .gpt2 import GPT2Policy
-    # auto_policy_dict[GPT2Model] = GPT2Policy
-
-    # from transformers import GPT2LMHeadModel
-    # from .gpt2 import GPT2LMHeadModelPolicy
-    # auto_policy_dict[GPT2LMHeadModel] = GPT2LMHeadModelPolicy
-
-    return auto_policy_dict
+@dataclass
+class PolicyLocation:
+    """
+    PolicyLocation describes the location of a policy class.
+
+    Args:
+        file_name (str): The file name of the policy under colossalai.shardformer.policies
+        class_name (str): The class name of the policy class
+    """
+    file_name: str
+    class_name: str
+
+
+# we don't want to import all policies here
+# as each policy file imports its own model zoo library
+# we will allow the user to only import the policy file needed
+_POLICY_LIST = {
+    # BERT
+    "transformers.models.bert.modeling_bert.BertModel":
+        PolicyLocation(file_name="bert", class_name="BertPolicy"),
+    "transformers.models.bert.modeling_bert.BertForPreTraining":
+        PolicyLocation(file_name="bert", class_name="BertForPretrainingPolicy"),
+    "transformers.models.bert.modeling_bert.BertForMaskedLM":
+        PolicyLocation(file_name="bert", class_name="BertForMaskedLMPolicy"),
+    "transformers.models.bert.modeling_bert.BertLMHeadModel":
+        PolicyLocation(file_name="bert", class_name="BertLMHeadModelPolicy"),
+    "transformers.models.bert.modeling_bert.BertForNextSentencePrediction":
+        PolicyLocation(file_name="bert", class_name="BertForNextSentencePredictionPolicy"),
+    "transformers.models.bert.modeling_bert.BertForSequenceClassification":
+        PolicyLocation(file_name="bert", class_name="BertForSequenceClassificationPolicy"),
+    "transformers.models.bert.modeling_bert.BertForMultipleChoice":
+        PolicyLocation(file_name="bert", class_name="BertForMultipleChoicePolicy"),
+
+    # LLaMA
+    "transformers.models.llama.modeling_llama.LlamaModel":
+        PolicyLocation(file_name="llama", class_name="LlamaPolicy"),
+    "transformers.models.llama.modeling_llama.LlamaForCausalLM":
+        PolicyLocation(file_name="llama", class_name="LlamaForCausalLMPolicy"),
+    "transformers.models.llama.modeling_llama.LlamaForSequenceClassification":
+        PolicyLocation(file_name="llama", class_name="LlamaForSequenceClassificationPolicy"),
+
+    # T5
+
+    # GPT2
+}
+
+
+def import_policy(policy_location: PolicyLocation) -> Policy:
+    """
+    Dynamically import a Policy class based on the policy location.
+    """
+    module_name = f"colossalai.shardformer.policies.{policy_location.file_name}"
+    module = importlib.import_module(module_name)
+    return getattr(module, policy_location.class_name)
+
+
+def _fullname(obj):
+    """
+    Return the full name of an object, including the module name.
+    """
+    klass = obj.__class__
+    module = klass.__module__
+    if module == 'builtins':
+        return klass.__qualname__    # avoid outputs like 'builtins.str'
+    return module + '.' + klass.__qualname__


 def get_autopolicy(model: nn.Module) -> Policy:
...
@@ -71,16 +83,14 @@ def get_autopolicy(model: nn.Module) -> Policy:
     Return:
         :class:`Policy`: The auto policy for the model
     """
-    auto_policy_dict = build_policies()
-    policy = auto_policy_dict.get(model.__class__, None)
-    if policy is None:
+    full_name = _fullname(model)
+    policy_location = _POLICY_LIST.get(full_name, None)
+
+    if policy_location is None:
         raise NotImplementedError(
-            f"Auto policy for {model.__class__.__qualname__} is not implemented\n Supported models are {[i.__qualname__ for i in auto_policy_dict.keys()]}"
-        )
-    return policy()
+            f"Auto policy for {model.__class__.__qualname__} is not implemented\n. Supported models are {list(_POLICY_LIST.keys())}")
+    else:
+        policy = import_policy(policy_location)
+    return policy()
-
-# from transformers.models.bert.modeling_bert import BertForMaskedLM, BertForPreTraining
-# model = BertForPreTraining
-# policy = get_autopolicy(model)
-# print(policy)
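
The effect of this rewrite is that policies are resolved by the model's fully qualified class name and imported lazily, instead of importing every model-zoo dependency up front. A minimal usage sketch (the tiny LlamaConfig is only for illustration and is not part of the commit):

    from transformers import LlamaConfig, LlamaForCausalLM

    from colossalai.shardformer.policies.autopolicy import get_autopolicy

    # a deliberately small model, purely for illustration
    model = LlamaForCausalLM(LlamaConfig(num_hidden_layers=2))

    # _fullname(model) -> "transformers.models.llama.modeling_llama.LlamaForCausalLM",
    # which is looked up in _POLICY_LIST; only then is
    # colossalai.shardformer.policies.llama imported to build the policy instance.
    policy = get_autopolicy(model)
    print(type(policy).__name__)    # LlamaForCausalLMPolicy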
colossalai/shardformer/policies/basepolicy.py

...
@@ -75,6 +75,7 @@ class Policy(ABC):
     """

     def __init__(self) -> None:
+        self.shard_config = None
         self.model = None
         self.shard_config = None
...
@@ -101,6 +102,7 @@ class Policy(ABC):
         r"""
         Perform some preprocessing of the model, like reshaping the embedding layer
         """
+        pass

     @abstractmethod
     def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
...
@@ -135,6 +137,7 @@ class Policy(ABC):
             ...
         }
         """
+        pass

     @abstractmethod
     def new_model_class(self) -> Union[Type[nn.Module], None]:
...
@@ -149,6 +152,7 @@ class Policy(ABC):
             return BertModel_
             ```
         """
+        pass

     @abstractmethod
     def postprocess(self) -> nn.Module:
...
@@ -156,3 +160,4 @@ class Policy(ABC):
         Perform some postprocessing of the model, like binding the weight of embedding layer with
         the classifier layer
         """
+        pass
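
For orientation, the Policy interface these hunks finalize has four hooks: preprocess, module_policy, new_model_class and postprocess (the added pass statements simply give the docstring-only stubs a body). A hedged sketch of the smallest possible subclass; IdentityPolicy is a made-up name and not part of the commit:

    from typing import Dict, Type, Union

    import torch.nn as nn

    from colossalai.shardformer.policies.basepolicy import ModulePolicyDescription, Policy


    class IdentityPolicy(Policy):
        """A do-nothing policy sketch: no attribute, parameter or submodule replacement."""

        def preprocess(self) -> nn.Module:
            # nothing to reshape, hand the model back unchanged
            return self.model

        def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
            # no module is rewritten
            return {}

        def new_model_class(self) -> Union[Type[nn.Module], None]:
            # keep the original model class
            return None

        def postprocess(self) -> nn.Module:
            # no weight re-binding needed
            return self.model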
colossalai/shardformer/policies/llama.py

-from dataclasses import dataclass
-from typing import Any, Callable, Dict, List, Tuple, Type
+from typing import Dict, Union

 import torch.nn as nn
+from transformers import LlamaForCausalLM, LlamaForSequenceClassification
 from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaModel

-import colossalai.shardformer.layer.layers as col_nn
+from colossalai.shardformer.layer.layers import Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D

-from .basepolicy import Argument, Col_Layer, Policy, Row_Layer
+from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription


 class LlamaPolicy(Policy):

-    @staticmethod
-    def argument_policy(config, world_size: int) -> Dict[nn.Module, Argument]:
-        return {
-            LlamaDecoderLayer:
-                Argument(attr_dict={
-                    "self_attn.hidden_size": config.hidden_size // world_size,
-                    "self_attn.num_heads": config.num_attention_heads // world_size,
-                },
-                         param_funcs=[LlamaPolicy.attn_layer, LlamaPolicy.mlp_layer]),
-            LlamaModel:
-                Argument(attr_dict={}, param_funcs=[LlamaPolicy.embeddings])
-        }
-
-    @staticmethod
-    def attn_layer() -> List:
-        return [
-            Col_Layer(suffix="self_attn.q_proj", weight="weight", bias="bias", replace_layer=col_nn.Linear1D_Col),
-            Col_Layer(suffix="self_attn.k_proj", weight="weight", bias="bias", replace_layer=col_nn.Linear1D_Col),
-            Col_Layer(suffix="self_attn.v_proj", weight="weight", bias="bias", replace_layer=col_nn.Linear1D_Col),
-            Row_Layer(suffix="self_attn.o_proj", weight="weight", bias="bias", replace_layer=col_nn.Linear1D_Row),
-        ]
-
-    @staticmethod
-    def mlp_layer() -> List:
-        return [
-            Col_Layer(suffix="mlp.gate_proj", weight="weight", bias="bias", replace_layer=col_nn.Linear1D_Col, gather_output=True),
-            Col_Layer(suffix="mlp.up_proj", weight="weight", bias="bias", replace_layer=col_nn.Linear1D_Row, gather_output=True),
-            Col_Layer(suffix="mlp.down_proj", weight="weight", bias="bias", replace_layer=col_nn.Linear1D_Col, gather_output=True),
-        ]
-
-    @staticmethod
-    def embeddings() -> List:
-        return [Col_Layer(suffix="embed_tokens", weight="weight", replace_layer=col_nn.VocabParallelEmbedding1D)]
-
-
-from transformers import LlamaForCausalLM
+    def preprocess(self):
+        # Resize embedding
+        vocab_size = self.model.config.vocab_size
+        world_size = self.shard_config.tensor_parallel_size
+
+        if vocab_size % world_size != 0:
+            new_vocab_size = vocab_size + world_size - vocab_size % world_size
+            self.model.resize_token_embeddings(new_vocab_size)
+
+        return self.model
+
+    def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
+        return {
+            LlamaDecoderLayer:
+                ModulePolicyDescription(
+                    attribute_replacement={
+                        "self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
+                        "self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
+                    },
+                    param_replacement=[],
+                    sub_module_replacement=[
+                        SubModuleReplacementDescription(suffix="self_attn.q_proj", target_module=Linear1D_Col),
+                        SubModuleReplacementDescription(suffix="self_attn.k_proj", target_module=Linear1D_Col),
+                        SubModuleReplacementDescription(suffix="self_attn.v_proj", target_module=Linear1D_Col),
+                        SubModuleReplacementDescription(suffix="self_attn.o_proj", target_module=Linear1D_Row),
+                        SubModuleReplacementDescription(suffix="mlp.gate_proj", target_module=Linear1D_Col),
+                        SubModuleReplacementDescription(suffix="mlp.up_proj", target_module=Linear1D_Col),
+                        SubModuleReplacementDescription(suffix="mlp.down_proj", target_module=Linear1D_Row),
+                    ],
+                ),
+            LlamaModel:
+                ModulePolicyDescription(attribute_replacement={},
+                                        param_replacement=[],
+                                        sub_module_replacement=[
+                                            SubModuleReplacementDescription(suffix="embed_tokens",
+                                                                            target_module=VocabParallelEmbedding1D),
+                                        ])
+        }
+
+    def new_model_class(self):
+        return None
+
+    def postprocess(self):
+        return self.model


 class LlamaForCausalLMPolicy(LlamaPolicy):

-    @staticmethod
-    def argument(config, world_size):
-        llamapolicy = LlamaPolicy.argument_policy(config, world_size)
-        argument = {LlamaForCausalLM: Argument(attr_dict={}, param_funcs=[LlamaForCausalLMPolicy.lm_head])}
-        argument.update(llamapolicy)
-
-    @staticmethod
-    def lm_head() -> List:
-        return [Col_Layer(suffix="lm_head", weight="weight", replace_layer=col_nn.Linear1D_Col, gather_output=True)]
-
-
-from transformers import LlamaForSequenceClassification
+    def module_policy(self):
+        policy = super().module_policy()
+        # add a new item for casual lm
+        new_item = {
+            LlamaForCausalLM:
+                ModulePolicyDescription(attribute_replacement={},
+                                        param_replacement=[],
+                                        sub_module_replacement=[
+                                            SubModuleReplacementDescription(
+                                                suffix="lm_head", target_module=Linear1D_Col, kwargs=dict(gather_output=True))
+                                        ])
+        }
+        policy.update(new_item)
+        return policy


 class LlamaForSequenceClassificationPolicy(LlamaPolicy):

-    @staticmethod
-    def argument(config, world_size):
-        llamapolicy = LlamaPolicy.argument_policy(config, world_size)
-        argument = {
-            LlamaForSequenceClassification:
-                Argument(attr_dict={}, param_funcs=[LlamaForSequenceClassificationPolicy.score])
-        }
-        argument.update(llamapolicy)
-
-    @staticmethod
-    def score() -> List:
-        return [Col_Layer(suffix="score", weight="weight", replace_layer=col_nn.Linear1D_Col, gather_output=True)]
+    def module_policy(self):
+        policy = super().module_policy()
+        # add a new item for sequence classification
+        new_item = {
+            LlamaForSequenceClassification:
+                ModulePolicyDescription(attribute_replacement={},
+                                        param_replacement=[],
+                                        sub_module_replacement=[
+                                            SubModuleReplacementDescription(
+                                                suffix="score", target_module=Linear1D_Col, kwargs=dict(gather_output=True))
+                                        ])
+        }
+        policy.update(new_item)
+        return policy
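
One detail worth noting in the new preprocess: the token embedding is padded so that the vocabulary divides evenly across tensor-parallel ranks before VocabParallelEmbedding1D splits it. A small self-contained sketch of that arithmetic (the sizes are illustrative, not taken from the commit):

    def pad_vocab(vocab_size: int, tp_size: int) -> int:
        """Same padding rule as LlamaPolicy.preprocess."""
        if vocab_size % tp_size != 0:
            return vocab_size + tp_size - vocab_size % tp_size
        return vocab_size


    # LLaMA's 32000-token vocabulary already divides by 2 or 4,
    # but with 3 tensor-parallel ranks it would be padded to 32001.
    assert pad_vocab(32000, 2) == 32000
    assert pad_vocab(32000, 3) == 32001    # 32001 % 3 == 0
    assert pad_vocab(32000, 4) == 32000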
colossalai/shardformer/shard/shard_config.py

 from dataclasses import dataclass
-from typing import List, Literal
+
+from colossalai.cluster.dist_coordinator import DistCoordinator

 __all__ = ['ShardConfig']
...
@@ -19,9 +20,19 @@ class ShardConfig:
         gather_output (bool): Whether to gather the output of the model of the last layer
     """
     tensor_parallel_size: int
-    tensor_parallel_mode: Literal['1d', '2d', '2.5d', '3d']
-    inference_only: bool = True
-    gather_output: bool = True
+
+    # TODO: add support for tensor parallel
+    # pipeline_parallel_size: int
+    # data_parallel_size: int
+    # tensor_parallel_mode: Literal['1d', '2d', '2.5d', '3d']
+    # inference_only: bool = True
+    # gather_output: bool = True
+
+    def __post_init__(self):
+        coordinator = DistCoordinator()
+
+        # ensure the parallel size can match the world size
+        world_size = coordinator.world_size
+        self.data_parallel_size = world_size // self.tensor_parallel_size
+        assert world_size == self.data_parallel_size * self.tensor_parallel_size, \
+            f"The world size ({world_size}) should be divisible by the data parallel size {self.data_parallel_size} and tensor parallel size {self.tensor_parallel_size}"
colossalai/shardformer/shard/sharder.py

 from typing import Any, Callable, Dict, List

-import torch
 import torch.nn as nn
-from transformers.pytorch_utils import Conv1D

 from colossalai.cluster.process_group_manager import ProcessGroupManager
...
@@ -41,10 +39,10 @@ class ModelSharder(object):
         """
         self.policy.set_model(self.model)
         self.policy.set_shard_config(self.shard_config)
-        self.preprocess()
-        self.replace_model_class()
-        self.replace_module()
-        self.postprocess()
+        self._preprocess()
+        self._replace_model_class()
+        self._replace_module()
+        self._postprocess()

     def reshape_embedding(self) -> None:
         r"""
...
@@ -57,13 +55,13 @@ class ModelSharder(object):
             self.model.resize_token_embeddings(new_vocab_size)
             self.model_config = self.model.config

-    def preprocess(self) -> None:
+    def _preprocess(self) -> None:
         self.model = self.policy.preprocess()

-    def postprocess(self) -> None:
+    def _postprocess(self) -> None:
         self.model = self.policy.postprocess()

-    def replace_model_class(self) -> None:
+    def _replace_model_class(self,) -> None:
         r"""
         Replace the model to policy defined model
         Mainly modify the forward and backward to fit distributed model
...
@@ -84,7 +82,7 @@ class ModelSharder(object):
                 getattr(new_model_class, key),
             )

-    def replace_module(self) -> None:
+    def _replace_module(self,) -> None:
         r"""
         Replace the module according to the policy, and replace the module one by one
...
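
The renames make the four pipeline steps private without changing their order; each simply delegates the model-level work to the policy. A toy, self-contained illustration of that delegation (not ColossalAI code; _replace_model_class and _replace_module are elided):

    class ToyPolicy:
        """Stand-in for a shardformer Policy that owns the model-level hooks."""

        def __init__(self, model):
            self.model = model

        def preprocess(self):
            return self.model    # e.g. resize token embeddings

        def postprocess(self):
            return self.model    # e.g. re-bind tied weights


    class ToySharder:
        """Mirrors the private pipeline order used by ModelSharder."""

        def __init__(self, model, policy):
            self.model = model
            self.policy = policy

        def shard(self):
            self._preprocess()
            # _replace_model_class() / _replace_module() would run here
            self._postprocess()
            return self.model

        def _preprocess(self):
            self.model = self.policy.preprocess()

        def _postprocess(self):
            self.model = self.policy.postprocess()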
colossalai/shardformer/shard/shardformer.py

...
@@ -47,10 +47,12 @@ class ShardFormer:
         """
         Initialize the distributed process group according to the
         """
+        # create process group manager and 1d process group
+        # TODO: may need to support other parallel mode when the config has such as field
         pg_manager = ProcessGroupManager()
-        if (self.shard_config.tensor_parallel_mode == '1d'):
-            pg_manager.create_process_group(name='tp1d', ranks=range(self.coordinator.world_size))
+        pg_manager.create_process_group(name='tp1d', ranks=range(self.coordinator.world_size))
         self.pg_manager = pg_manager

         return pg_manager

     def shard_model(self, model: nn.Module, policy: Policy = None):
...
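
End to end, the new API is exercised the same way the updated tests below do: build a ShardConfig, wrap it in a ShardFormer, initialize the process groups, then shard a copy of the model. Roughly (assuming colossalai.launch has already set up 2 ranks; the tiny config is illustrative):

    import copy

    from transformers import LlamaConfig, LlamaForCausalLM

    from colossalai.shardformer import ShardConfig, ShardFormer

    org_model = LlamaForCausalLM(LlamaConfig(num_hidden_layers=2)).cuda()

    shard_config = ShardConfig(tensor_parallel_size=2)
    shard_former = ShardFormer(shard_config=shard_config)
    shard_former.init_distributed()    # creates the 'tp1d' process group
    sharded_model = shard_former.shard_model(copy.deepcopy(org_model))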
tests/test_shardformer/test_model/test_shard_bert.py

...
@@ -24,21 +24,18 @@ CONFIG = dict(parallel=dict(data=1, pipeline=1, tensor=dict(size=2, mode='1d')),
 tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")


-def build_model(rank, world_size, model):
-    config = BertConfig.from_pretrained('bert-base-uncased')
+def build_model(world_size, model_fn):
+    config = BertConfig()
     config.hidden_dropout_prob = 0
     config.attention_probs_dropout_prob = 0

-    org_model = BertForMaskedLM.from_pretrained('bert-base-uncased', config=config)
+    org_model = model_fn(config=config)
     org_model_forshard = copy.deepcopy(org_model)

     org_model.to('cuda')
+    # TODO: no need to transfer to cuda
     org_model_forshard.to('cuda')

-    shard_config = ShardConfig(
-        tensor_parallel_size=2,
-        tensor_parallel_mode='1d',
-    )
+    shard_config = ShardConfig(tensor_parallel_size=world_size,)
     shard_former = ShardFormer(shard_config=shard_config)
     shard_former.init_distributed()
     sharded_model = shard_former.shard_model(org_model_forshard).to('cuda')
...
@@ -99,15 +96,22 @@ def check_bert(rank, world_size, port):
     disable_existing_loggers()
     colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     forward_list = [
-        BertModel,
-        BertForPreTraining,
-        BertForMaskedLM,
-        BertLMHeadModel,
-        BertForNextSentencePrediction,
-        BertForSequenceClassification
+        BertForMaskedLM,
+        BertForPreTraining,
+        BertLMHeadModel,
+
+        # TODO: do not work yet
+        # BertModel,
+        # BertForSequenceClassification
+        # BertForNextSentencePrediction,
     ]
     backward_lsit = [BertForMaskedLM, BertLMHeadModel]

-    for model in forward_list:
-        org_model, sharded_model = build_model(rank, world_size, model)
+    for model_fn in forward_list:
+        org_model, sharded_model = build_model(model_fn)
         check_forward(org_model, sharded_model)

-        if model in backward_lsit:
+        if model_fn in backward_lsit:
             check_backward(org_model, sharded_model)

         torch.cuda.empty_cache()
...
tests/test_shardformer/test_model/test_shard_llama.py

...
@@ -4,31 +4,28 @@ import random
 import pytest
 import torch
-from transformers import AutoTokenizer, LlamaConfig, LlamaForCausalLM, LlamaModel, LlamaTokenizerFast
+from transformers import LlamaConfig, LlamaForCausalLM, LlamaForSequenceClassification, LlamaModel, LlamaTokenizerFast

 import colossalai
 from colossalai.logging import disable_existing_loggers
-from colossalai.shardformer.shard import ShardConfig, shard_model
+from colossalai.shardformer import ShardConfig, ShardFormer
 from colossalai.testing import rerun_if_address_is_in_use, spawn

 os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'
-CONFIG = dict(parallel=dict(data=1, pipeline=1, tensor=dict(size=4, mode='1d')),)
 tokenizer = LlamaTokenizerFast.from_pretrained("hf-internal-testing/llama-tokenizer")


-def build_model(rank, world_size):
-    cfg = LlamaConfig(num_hidden_layers=16)
-    org_model = LlamaForCausalLM(cfg)
+def build_model(world_size, model_fn):
+    # create new model
+    config = LlamaConfig(num_hidden_layers=8)
+    org_model = model_fn(config).cuda()

-    shardconfig = ShardConfig(
-        rank=rank,
-        world_size=world_size,
-        gather_output=True,
-    )
-    org_model = org_model.to('cuda')
-
-    org_model_forshard = copy.deepcopy(org_model)
-
-    sharded_model = shard_model(org_model_forshard, shardconfig).to('cuda')
+    # shard model
+    shard_config = ShardConfig(tensor_parallel_size=world_size)
+    model_copy = copy.deepcopy(org_model)
+    shard_former = ShardFormer(shard_config=shard_config)
+    shard_former.init_distributed()
+    sharded_model = shard_former.shard_model(model_copy)

     return org_model, sharded_model
...
@@ -38,6 +35,7 @@ def check_forward(org_model, sharded_model):
     inputs = tokenizer(input, return_tensors='pt').to('cuda')
     del inputs["token_type_ids"]
+    del inputs["attention_mask"]

     # orgin model
     org_model.eval()
     org_out = org_model(**inputs)
...
@@ -87,11 +85,20 @@ def check_backward(org_model, sharded_model):
 def check_llama(rank, world_size, port):
     disable_existing_loggers()
-    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+
+    model_list = [
+        LlamaForCausalLM,
+
+        # TODO: do not work yet
+        # LlamaModel,
+        # LlamaForSequenceClassification
+    ]

-    org_model, sharded_model = build_model(rank, world_size)
-    check_forward(org_model, sharded_model)
-    check_backward(org_model, sharded_model)
+    for model_fn in model_list:
+        org_model, sharded_model = build_model(world_size, model_fn)
+        check_forward(org_model, sharded_model)
+        check_backward(org_model, sharded_model)

     torch.cuda.empty_cache()
...
tests/test_shardformer/test_model/test_shard_t5.py

...
@@ -8,7 +8,7 @@ from transformers import AutoTokenizer, BertConfig, BertForMaskedLM, T5Config, T
 import colossalai
 from colossalai.logging import disable_existing_loggers
-from colossalai.shardformer.shard import ShardConfig, shard_model
+from colossalai.shardformer.shard import ShardConfig, ShardFormer
 from colossalai.testing import rerun_if_address_is_in_use, spawn

 os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'
...
@@ -90,6 +90,7 @@ def check_t5(rank, world_size, port):
 @pytest.mark.dist
+@pytest.mark.skip
 @rerun_if_address_is_in_use()
 def test_t5():
     spawn(check_t5, 2)
...