Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
d857f3db
Commit
d857f3db
authored
Jun 19, 2023
by
Frank Lee
Browse files
[shardformer] supported T5 and its variants (#4045)
parent
c1d5453e
Changes
10
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
316 additions
and
221 deletions
+316
-221
colossalai/shardformer/README.md
colossalai/shardformer/README.md
+2
-3
colossalai/shardformer/layer/layers.py
colossalai/shardformer/layer/layers.py
+17
-9
colossalai/shardformer/policies/autopolicy.py
colossalai/shardformer/policies/autopolicy.py
+6
-0
colossalai/shardformer/policies/basepolicy.py
colossalai/shardformer/policies/basepolicy.py
+1
-0
colossalai/shardformer/policies/t5.py
colossalai/shardformer/policies/t5.py
+136
-122
colossalai/shardformer/shard/sharder.py
colossalai/shardformer/shard/sharder.py
+10
-1
colossalai/testing/__init__.py
colossalai/testing/__init__.py
+2
-1
colossalai/testing/comparison.py
colossalai/testing/comparison.py
+50
-1
tests/test_shardformer/test_model/test_shard_llama.py
tests/test_shardformer/test_model/test_shard_llama.py
+42
-40
tests/test_shardformer/test_model/test_shard_t5.py
tests/test_shardformer/test_model/test_shard_t5.py
+50
-44
No files found.
colossalai/shardformer/README.md
View file @
d857f3db
...
...
@@ -81,8 +81,8 @@ We will follow this roadmap to develop Shardformer:
-
[ ] Hugging Face
-
[ ] NLP
-
[x] BERT
-
[
] T5
-
[
] LlaMa
-
[
x
] T5
-
[
x
] LlaMa
-
[ ] GPT2
-
[ ] BLOOM
-
[ ] RoBERTa
...
...
@@ -90,7 +90,6 @@ We will follow this roadmap to develop Shardformer:
-
[ ] ERNIE
-
[ ] GPT Neo
-
[ ] GPT-J
-
[ ] CV
-
[ ] CV
-
[ ] ViT
-
[ ] BEiT
...
...
colossalai/shardformer/layer/layers.py
View file @
d857f3db
...
...
@@ -469,13 +469,14 @@ class Embedding1D(ParallelModule):
dtype
:
torch
.
dtype
=
None
,
device
:
torch
.
device
=
None
,
process_group
:
ProcessGroup
=
None
,
gather_output
:
bool
=
True
,
weight_initializer
:
Callable
=
init
.
normal_
(),
*
args
,
**
kwargs
):
super
().
__init__
()
self
.
num_embeddings
=
num_embeddings
self
.
embed_dim
=
embedding_dim
self
.
embed
ding
_dim
=
embedding_dim
self
.
process_group
=
process_group
self
.
num_partitions
=
dist
.
get_world_size
(
process_group
)
self
.
embed_dim_per_partition
=
divide
(
embedding_dim
,
self
.
num_partitions
)
...
...
@@ -499,7 +500,9 @@ class Embedding1D(ParallelModule):
@
staticmethod
def
from_native_module
(
module
:
nn
.
Embedding
,
process_group
:
Union
[
ProcessGroup
,
List
[
ProcessGroup
]]
=
None
)
->
"Embedding1D"
:
process_group
:
Union
[
ProcessGroup
,
List
[
ProcessGroup
]]
=
None
,
*
args
,
**
kwargs
)
->
"Embedding1D"
:
r
"""
Build a 1D parallelized Embedding from a native nn.Embedding module.
"""
...
...
@@ -527,7 +530,9 @@ class Embedding1D(ParallelModule):
max_norm
=
max_norm
,
norm_type
=
norm_type
,
scale_grad_by_freq
=
scale_grad_by_freq
,
sparse
=
sparse
)
sparse
=
sparse
,
*
args
,
**
kwargs
)
# copy the weight
with
torch
.
no_grad
():
...
...
@@ -537,7 +542,7 @@ class Embedding1D(ParallelModule):
return
embedding
def
reset_parameters
(
self
,
weight_initializer
)
->
None
:
fan_in
,
fan_out
=
self
.
num_embeddings
,
self
.
embed_dim
fan_in
,
fan_out
=
self
.
num_embeddings
,
self
.
embed
ding
_dim
weight_initializer
(
self
.
weight
,
fan_in
=
fan_in
,
fan_out
=
fan_out
)
self
.
_fill_padding_idx_with_zero
()
...
...
@@ -548,9 +553,12 @@ class Embedding1D(ParallelModule):
def
forward
(
self
,
input_
:
Tensor
)
->
Tensor
:
output_parallel
=
F
.
embedding
(
input_
,
self
.
weight
,
self
.
padding_idx
,
*
self
.
embed_args
,
**
self
.
embed_kwargs
)
output
=
gather_forward_split_backward
(
output_parallel
,
dim
=-
1
,
process_group
=
self
.
process_group
)
return
output
if
self
.
gather_output
:
output
=
gather_forward_split_backward
(
output_parallel
,
dim
=-
1
,
process_group
=
self
.
process_group
)
return
output
else
:
return
output_parallel
class
VocabParallelEmbedding1D
(
ParallelLayer
):
...
...
@@ -595,7 +603,7 @@ class VocabParallelEmbedding1D(ParallelLayer):
**
kwargs
):
super
().
__init__
()
self
.
num_embeddings
=
num_embeddings
self
.
embed_dim
=
embedding_dim
self
.
embed
ding
_dim
=
embedding_dim
self
.
padding_idx
=
padding_idx
self
.
embed_args
=
args
self
.
embed_kwargs
=
kwargs
...
...
@@ -610,7 +618,7 @@ class VocabParallelEmbedding1D(ParallelLayer):
self
.
vocab_end_index
=
self
.
vocab_start_index
+
self
.
num_embeddings_per_partition
self
.
weight
=
Parameter
(
torch
.
empty
((
self
.
num_embeddings_per_partition
,
self
.
embed_dim
),
device
=
device
,
dtype
=
dtype
))
torch
.
empty
((
self
.
num_embeddings_per_partition
,
self
.
embed
ding
_dim
),
device
=
device
,
dtype
=
dtype
))
# offset the seed with randomizer index and rank
seed
=
torch
.
random
.
initial_seed
()
...
...
@@ -662,7 +670,7 @@ class VocabParallelEmbedding1D(ParallelLayer):
def
reset_parameters
(
self
,
weight_initializer
)
->
None
:
with
seed
(
ParallelMode
.
TENSOR
):
fan_in
,
fan_out
=
self
.
num_embeddings
,
self
.
embed_dim
fan_in
,
fan_out
=
self
.
num_embeddings
,
self
.
embed
ding
_dim
weight_initializer
(
self
.
weight
,
fan_in
=
fan_in
,
fan_out
=
fan_out
)
self
.
_fill_padding_idx_with_zero
()
...
...
colossalai/shardformer/policies/autopolicy.py
View file @
d857f3db
...
...
@@ -48,6 +48,12 @@ _POLICY_LIST = {
PolicyLocation
(
file_name
=
"llama"
,
class_name
=
"LlamaForSequenceClassificationPolicy"
),
# T5
"transformers.models.t5.modeling_t5.T5Model"
:
PolicyLocation
(
file_name
=
"t5"
,
class_name
=
"T5ModelPolicy"
),
"transformers.models.t5.modeling_t5.T5ForConditionalGeneration"
:
PolicyLocation
(
file_name
=
"t5"
,
class_name
=
"T5ForConditionalGenerationPolicy"
),
"transformers.models.t5.modeling_t5.T5EncoderModel"
:
PolicyLocation
(
file_name
=
"t5"
,
class_name
=
"T5EncoderPolicy"
),
# GPT2
}
...
...
colossalai/shardformer/policies/basepolicy.py
View file @
d857f3db
...
...
@@ -27,6 +27,7 @@ class SubModuleReplacementDescription:
suffix
:
str
target_module
:
ParallelModule
kwargs
:
Dict
[
str
,
Any
]
=
None
ignore_if_not_exist
:
bool
=
False
@
dataclass
...
...
colossalai/shardformer/policies/t5.py
View file @
d857f3db
from
typing
import
Dict
import
torch
import
torch.nn
as
nn
from
t
orch.nn
import
Embedding
from
t
ransformers
import
T5ForConditionalGeneration
from
transformers.models.t5.modeling_t5
import
(
T5Attention
,
T5Block
,
T5DenseActDense
,
T5DenseGatedActDense
,
T5LayerCrossAttention
,
T5LayerFF
,
T5LayerSelfAttention
,
T5Model
,
T5Stack
,
)
import
colossalai.shardformer.layer.layers
as
col_nn
from
colossalai.shardformer.layer.dropout
import
Dropout1D
from
colossalai.shardformer.layer.layers
import
Embedding1D
,
Linear1D_Col
,
Linear1D_Row
from
.basepolicy
import
Argument
,
Col_Layer
,
Dropout_Layer
,
Embedding_Layer
,
Policy
,
Row_Layer
from
.basepolicy
import
ModulePolicyDescription
,
Policy
,
SubModuleReplacementDescription
__all__
=
[
"T5ModelPolicy"
,
"T5ForConditionalGenerationPolicy"
,
"T5EncoderPolicy"
]
class
T5ModelPolicy
(
Policy
):
@
staticmethod
def
argument_policy
(
config
,
world_size
:
int
)
->
Dict
[
nn
.
Module
,
Argument
]:
print
(
'config heads'
,
config
.
num_heads
)
def
preprocess
(
self
):
# reshape the embedding layer
r
"""
Reshape the Embedding layer to make the embedding dimension divisible by world_size
"""
vocab_size
=
self
.
model
.
config
.
vocab_size
world_size
=
self
.
shard_config
.
tensor_parallel_size
if
vocab_size
%
world_size
!=
0
:
new_vocab_size
=
vocab_size
+
world_size
-
vocab_size
%
world_size
self
.
model
.
resize_token_embeddings
(
new_vocab_size
)
return
self
.
model
def
module_policy
(
self
):
return
{
T5Stack
:
Argument
(
attr_dict
=
{},
param_funcs
=
[
T5ModelPolicy
.
dropout
,
T5ModelPolicy
.
embedding
]),
T5Block
:
Argument
(
attr_dict
=
{},
param_funcs
=
[]),
ModulePolicyDescription
(
attribute_replacement
=
{},
param_replacement
=
[],
sub_module_replacement
=
[
SubModuleReplacementDescription
(
suffix
=
"dropout"
,
target_module
=
Dropout1D
,
)
]),
T5LayerSelfAttention
:
Argument
(
attr_dict
=
{},
param_funcs
=
[
T5ModelPolicy
.
dropout
]),
ModulePolicyDescription
(
attribute_replacement
=
{},
param_replacement
=
[],
sub_module_replacement
=
[
SubModuleReplacementDescription
(
suffix
=
"dropout"
,
target_module
=
Dropout1D
,
),
]),
T5LayerCrossAttention
:
Argument
(
attr_dict
=
{},
param_funcs
=
[
T5ModelPolicy
.
dropout
]),
ModulePolicyDescription
(
attribute_replacement
=
{},
param_replacement
=
[],
sub_module_replacement
=
[
SubModuleReplacementDescription
(
suffix
=
"dropout"
,
target_module
=
Dropout1D
,
)
]),
T5Attention
:
Argument
(
attr_dict
=
{
"d_model"
:
config
.
d_model
//
world_size
,
"n_heads"
:
config
.
num_heads
//
world_size
,
"inner_dim"
:
config
.
num_heads
*
config
.
d_kv
//
world_size
,
ModulePolicyDescription
(
attribute_replacement
=
{
"d_model"
:
self
.
model
.
config
.
d_model
//
self
.
shard_config
.
tensor_parallel_size
,
"n_heads"
:
self
.
model
.
config
.
num_heads
//
self
.
shard_config
.
tensor_parallel_size
,
"inner_dim"
:
self
.
model
.
config
.
num_heads
*
self
.
model
.
config
.
d_kv
//
self
.
shard_config
.
tensor_parallel_size
},
param_funcs
=
[
T5ModelPolicy
.
attn_layer
]),
param_replacement
=
[],
sub_module_replacement
=
[
SubModuleReplacementDescription
(
suffix
=
"q"
,
target_module
=
Linear1D_Col
,
),
SubModuleReplacementDescription
(
suffix
=
"k"
,
target_module
=
Linear1D_Col
,
),
SubModuleReplacementDescription
(
suffix
=
"v"
,
target_module
=
Linear1D_Col
,
),
SubModuleReplacementDescription
(
suffix
=
"o"
,
target_module
=
Linear1D_Row
,
),
SubModuleReplacementDescription
(
suffix
=
"relative_attention_bias"
,
target_module
=
Embedding1D
,
kwargs
=
dict
(
gather_output
=
False
),
ignore_if_not_exist
=
True
)
]),
T5LayerFF
:
Argument
(
attr_dict
=
{},
param_funcs
=
[
T5ModelPolicy
.
dropout
]),
ModulePolicyDescription
(
attribute_replacement
=
{},
param_replacement
=
[],
sub_module_replacement
=
[
SubModuleReplacementDescription
(
suffix
=
"dropout"
,
target_module
=
Dropout1D
,
),
]),
T5DenseGatedActDense
:
Argument
(
attr_dict
=
{},
param_funcs
=
[
T5ModelPolicy
.
dropout
,
T5ModelPolicy
.
dense_gated_layer
]),
ModulePolicyDescription
(
attribute_replacement
=
{},
param_replacement
=
[],
sub_module_replacement
=
[
SubModuleReplacementDescription
(
suffix
=
"wi_0"
,
target_module
=
Linear1D_Col
,
),
SubModuleReplacementDescription
(
suffix
=
"wi_1"
,
target_module
=
Linear1D_Row
,
),
SubModuleReplacementDescription
(
suffix
=
"wo"
,
target_module
=
Linear1D_Col
,
kwargs
=
dict
(
gather_output
=
True
)),
SubModuleReplacementDescription
(
suffix
=
"dropout"
,
target_module
=
Dropout1D
,
)
]),
T5DenseActDense
:
Argument
(
attr_dict
=
{},
param_funcs
=
[
T5ModelPolicy
.
dropout
,
T5ModelPolicy
.
dense_act_layer
]),
ModulePolicyDescription
(
attribute_replacement
=
{},
param_replacement
=
[],
sub_module_replacement
=
[
SubModuleReplacementDescription
(
suffix
=
"wi"
,
target_module
=
Linear1D_Col
,
),
SubModuleReplacementDescription
(
suffix
=
"wo"
,
target_module
=
Linear1D_Row
,
),
SubModuleReplacementDescription
(
suffix
=
"dropout"
,
target_module
=
Dropout1D
,
)
])
}
@
staticmethod
def
dense_gated_layer
():
return
[
Col_Layer
(
suffix
=
"wi_0"
,
weight
=
"weight"
,
replace_layer
=
col_nn
.
Linear1D_Col
,
),
Row_Layer
(
suffix
=
"wi_1"
,
weight
=
"weight"
,
replace_layer
=
col_nn
.
Linear1D_Row
,
),
Col_Layer
(
suffix
=
"wo"
,
weight
=
"weight"
,
replace_layer
=
col_nn
.
Linear1D_Col
,
gather_output
=
True
)
]
@
staticmethod
def
dense_act_layer
():
return
[
Col_Layer
(
suffix
=
"wi"
,
weight
=
"weight"
,
replace_layer
=
col_nn
.
Linear1D_Col
,
),
Row_Layer
(
suffix
=
"wo"
,
weight
=
"weight"
,
replace_layer
=
col_nn
.
Linear1D_Row
,
)
]
@
staticmethod
def
attn_layer
():
return
[
Col_Layer
(
suffix
=
"q"
,
weight
=
"weight"
,
bias
=
"bias"
,
replace_layer
=
col_nn
.
Linear1D_Col
,
),
Col_Layer
(
suffix
=
"k"
,
weight
=
"weight"
,
bias
=
"bias"
,
replace_layer
=
col_nn
.
Linear1D_Col
,
),
Col_Layer
(
suffix
=
"v"
,
weight
=
"weight"
,
bias
=
"bias"
,
replace_layer
=
col_nn
.
Linear1D_Col
,
),
Row_Layer
(
suffix
=
"o"
,
weight
=
"weight"
,
bias
=
"bias"
,
replace_layer
=
col_nn
.
Linear1D_Row
,
),
]
@
staticmethod
def
dropout
():
return
[
Dropout_Layer
(
suffix
=
"dropout"
,
p
=
"p"
,
replace_layer
=
col_nn
.
Dropout1D
,
)]
@
staticmethod
def
embedding
():
return
[
Embedding_Layer
(
suffix
=
"block[0].layer[0].SelfAttention.relative_attention_bias"
,
weight
=
"weight"
,
replace_layer
=
col_nn
.
Embedding1D
,
gather_output
=
False
,
)
]
def
new_model_class
(
self
):
return
None
from
transformers
import
T5ForConditionalGeneration
def
postprocess
(
self
):
return
self
.
model
class
T5ForConditionalGenerationPolicy
(
T5ModelPolicy
):
@
staticmethod
def
argument_policy
(
config
,
world_size
):
base_argument
=
T5ModelPolicy
.
argument_policy
(
config
,
world_size
)
argument
=
{
T5ForConditionalGeneration
:
Argument
(
attr_dict
=
{},
param_funcs
=
[
T5ForConditionalGenerationPolicy
.
lm_head
])
def
module_policy
(
self
):
policy
=
super
().
module_policy
()
new_item
=
{
T5ForConditionalGeneration
:
ModulePolicyDescription
(
attribute_replacement
=
{},
param_replacement
=
[],
sub_module_replacement
=
[
SubModuleReplacementDescription
(
suffix
=
"lm_head"
,
target_module
=
Linear1D_Col
,
kwargs
=
dict
(
gather_output
=
True
))
])
}
argument
.
update
(
base_argument
)
return
argument
@
staticmethod
def
lm_head
():
return
[
Col_Layer
(
suffix
=
"lm_head"
,
weight
=
"weight"
,
replace_layer
=
col_nn
.
Linear1D_Col
,
gather_output
=
True
,
)]
from
transformers
import
T5EncoderModel
policy
.
update
(
new_item
)
return
policy
class
T5Encoder
Model
Policy
(
T5ModelPolicy
):
class
T5EncoderPolicy
(
T5ModelPolicy
):
pass
colossalai/shardformer/shard/sharder.py
View file @
d857f3db
...
...
@@ -175,7 +175,16 @@ class ModelSharder(object):
assert
target_module
is
not
None
,
'target_module should not be None'
# TODO: support different parallel mode
native_sub_module
=
getattr_
(
org_layer
,
suffix
)
native_sub_module
=
getattr_
(
org_layer
,
suffix
,
ignore
=
True
)
assert
not
isinstance
(
native_sub_module
,
target_module
),
\
f
"The module with suffix
{
suffix
}
has been replaced, please check the policy"
# if it is None and we are allowed to ignore this module
# just skip
if
description
.
ignore_if_not_exist
and
native_sub_module
is
None
:
continue
replace_layer
=
target_module
.
from_native_module
(
native_sub_module
,
self
.
pg_manager
.
pg_store
[
'tp1d'
],
**
kwargs
)
...
...
colossalai/testing/__init__.py
View file @
d857f3db
...
...
@@ -3,6 +3,7 @@ from .comparison import (
assert_close_loose
,
assert_equal
,
assert_equal_in_group
,
assert_hf_output_close
,
assert_not_equal
,
check_state_dict_equal
,
)
...
...
@@ -20,5 +21,5 @@ from .utils import (
__all__
=
[
'assert_equal'
,
'assert_not_equal'
,
'assert_close'
,
'assert_close_loose'
,
'assert_equal_in_group'
,
'parameterize'
,
'rerun_on_exception'
,
'rerun_if_address_is_in_use'
,
'skip_if_not_enough_gpus'
,
'free_port'
,
'spawn'
,
'clear_cache_before_run'
,
'run_on_environment_flag'
,
'check_state_dict_equal'
'clear_cache_before_run'
,
'run_on_environment_flag'
,
'check_state_dict_equal'
,
'assert_hf_output_close'
]
colossalai/testing/comparison.py
View file @
d857f3db
from
typing
import
OrderedDict
from
typing
import
Any
,
List
,
OrderedDict
import
torch
import
torch.distributed
as
dist
...
...
@@ -52,3 +52,52 @@ def check_state_dict_equal(d1: OrderedDict, d2: OrderedDict, ignore_device: bool
assert
torch
.
equal
(
v
,
d2
[
k
])
else
:
assert
v
==
d2
[
k
]
def
assert_hf_output_close
(
out1
:
Any
,
out2
:
Any
,
ignore_keys
:
List
[
str
]
=
None
,
track_name
:
str
=
""
,
atol
=
1e-5
,
rtol
=
1e-5
):
"""
Check if two outputs from huggingface are equal.
Args:
out1 (Any): the first output
out2 (Any): the second output
ignore_keys (List[str]): the keys to ignore when comparing two dicts
track_name (str): the name of the value compared, used to track the path
"""
if
isinstance
(
out1
,
dict
)
and
isinstance
(
out2
,
dict
):
# if two values are dict
# we recursively check the keys
assert
set
(
out1
.
keys
())
==
set
(
out2
.
keys
())
for
k
in
out1
.
keys
():
if
ignore_keys
is
not
None
and
k
in
ignore_keys
:
continue
assert_hf_output_close
(
out1
[
k
],
out2
[
k
],
track_name
=
f
"
{
track_name
}
.
{
k
}
"
,
ignore_keys
=
ignore_keys
,
atol
=
atol
,
rtol
=
rtol
)
elif
isinstance
(
out1
,
(
list
,
tuple
))
and
isinstance
(
out2
,
(
list
,
tuple
)):
# if two values are list
# we recursively check the elements
assert
len
(
out1
)
==
len
(
out2
)
for
i
in
range
(
len
(
out1
)):
assert_hf_output_close
(
out1
[
i
],
out2
[
i
],
track_name
=
f
"
{
track_name
}
.
{
i
}
"
,
ignore_keys
=
ignore_keys
,
atol
=
atol
,
rtol
=
rtol
)
elif
isinstance
(
out1
,
Tensor
)
and
isinstance
(
out2
,
Tensor
):
if
out1
.
shape
!=
out2
.
shape
:
raise
AssertionError
(
f
"
{
track_name
}
: shape mismatch:
{
out1
.
shape
}
vs
{
out2
.
shape
}
"
)
assert
torch
.
allclose
(
out1
,
out2
,
atol
=
atol
,
rtol
=
rtol
),
f
"
{
track_name
}
: tensor value mismatch
\n
value 1:
{
out1
}
\n
value 2:
{
out2
}
, mean error:
{
torch
.
abs
(
out1
-
out2
).
mean
()
}
"
else
:
assert
out1
==
out2
,
f
"
{
track_name
}
: value mismatch.
\n
out1:
{
out1
}
\n
out2:
{
out2
}
"
tests/test_shardformer/test_model/test_shard_llama.py
View file @
d857f3db
...
...
@@ -9,7 +9,7 @@ from transformers import LlamaConfig, LlamaForCausalLM, LlamaForSequenceClassifi
import
colossalai
from
colossalai.logging
import
disable_existing_loggers
from
colossalai.shardformer
import
ShardConfig
,
ShardFormer
from
colossalai.testing
import
rerun_if_address_is_in_use
,
spawn
from
colossalai.testing
import
assert_hf_output_close
,
clear_cache_before_run
,
rerun_if_address_is_in_use
,
spawn
os
.
environ
[
'TRANSFORMERS_NO_ADVISORY_WARNINGS'
]
=
'true'
tokenizer
=
LlamaTokenizerFast
.
from_pretrained
(
"hf-internal-testing/llama-tokenizer"
)
...
...
@@ -17,7 +17,11 @@ tokenizer = LlamaTokenizerFast.from_pretrained("hf-internal-testing/llama-tokeni
def
build_model
(
world_size
,
model_fn
):
# create new model
config
=
LlamaConfig
(
num_hidden_layers
=
8
)
config
=
LlamaConfig
(
num_hidden_layers
=
4
,
hidden_size
=
128
,
intermediate_size
=
256
,
num_attention_heads
=
4
,
max_position_embeddings
=
128
)
org_model
=
model_fn
(
config
).
cuda
()
# shard model
...
...
@@ -30,49 +34,47 @@ def build_model(world_size, model_fn):
return
org_model
,
sharded_model
def
check_forward
(
org_model
,
sharded_model
):
input
=
'Hello, my dog is cute'
inputs
=
tokenizer
(
input
,
return_tensors
=
'pt'
).
to
(
'cuda'
)
del
inputs
[
"token_type_ids"
]
del
inputs
[
"attention_mask"
]
#orgin model
org_model
.
eval
()
org_out
=
org_model
(
**
inputs
)
#shard model
sharded_model
.
eval
()
shard_out
=
sharded_model
(
**
inputs
)
assert
torch
.
allclose
(
org_out
[
0
],
shard_out
[
0
],
atol
=
1e-4
),
f
"shard model output is not equal to orgin model output
\n
{
org_out
[
0
]
}
\n
{
shard_out
[
0
]
}
"
def
check_backward
(
org_model
,
sharded_model
):
def
check_forward_backward
(
org_model
,
sharded_model
):
# prepare input
input
=
'Hello, my dog is cute'
tokenized_input
=
tokenizer
(
input
,
return_tensors
=
'pt'
).
to
(
'cuda'
)
del
tokenized_input
[
"token_type_ids"
]
del
tokenized_input
[
"attention_mask"
]
labels
=
tokenized_input
[
'input_ids'
].
clone
()
labels
[
labels
==
tokenizer
.
pad_token_id
]
=
-
100
tokenized_input
[
'labels'
]
=
labels
#
org
in mode
l
#
switch to tra
in mode
org_model
.
train
()
org_out
=
org_model
(
**
tokenized_input
)
org_loss
=
org_out
.
loss
org_loss
.
backward
()
org_grad
=
org_model
.
model
.
layers
[
0
].
self_attn
.
q_proj
.
weight
.
grad
torch
.
cuda
.
empty_cache
()
#shard model
sharded_model
.
train
()
shard_out
=
sharded_model
(
**
tokenized_input
)
shard_loss
=
shard_out
.
loss
if
isinstance
(
org_model
,
(
LlamaModel
,
LlamaForSequenceClassification
)):
org_output
=
org_model
(
**
tokenized_input
)
org_loss
=
org_output
.
last_hidden_state
.
mean
()
shard_output
=
sharded_model
(
**
tokenized_input
)
shard_loss
=
shard_output
.
last_hidden_state
.
mean
()
elif
isinstance
(
org_model
,
LlamaForCausalLM
):
labels
=
tokenized_input
[
'input_ids'
].
clone
()
labels
[
labels
==
tokenizer
.
pad_token_id
]
=
-
100
tokenized_input
[
'labels'
]
=
labels
org_output
=
org_model
(
**
tokenized_input
)
org_loss
=
org_output
.
loss
shard_output
=
sharded_model
(
**
tokenized_input
)
shard_loss
=
shard_output
.
loss
assert_hf_output_close
(
org_output
,
shard_output
,
ignore_keys
=
[
'past_key_values'
],
rtol
=
1e-4
)
# run backward
org_loss
.
backward
()
shard_loss
.
backward
()
shard_grad
=
sharded_model
.
model
.
layers
[
0
].
self_attn
.
q_proj
.
weight
.
grad
# check grad
if
isinstance
(
org_model
,
LlamaModel
):
llama_model
=
org_model
shard_llama_model
=
sharded_model
else
:
llama_model
=
org_model
.
model
shard_llama_model
=
sharded_model
.
model
org_grad
=
llama_model
.
layers
[
0
].
self_attn
.
q_proj
.
weight
.
grad
shard_grad
=
shard_llama_model
.
layers
[
0
].
self_attn
.
q_proj
.
weight
.
grad
shard_grad_list
=
[
torch
.
zeros
([
*
shard_grad
.
shape
]).
to
(
'cuda'
)
for
_
in
range
(
4
)]
shard_grad
=
torch
.
distributed
.
all_gather
(
shard_grad_list
,
shard_grad
)
all_shard_grad
=
torch
.
cat
(
shard_grad_list
,
dim
=
0
)
...
...
@@ -88,23 +90,23 @@ def check_llama(rank, world_size, port):
colossalai
.
launch
(
config
=
{},
rank
=
rank
,
world_size
=
world_size
,
host
=
'localhost'
,
port
=
port
,
backend
=
'nccl'
)
model_list
=
[
LlamaForCausalLM
,
LlamaModel
,
# LlamaForCausalLM,
# TODO: do not work yet
# LlamaModel,
# LlamaForSequenceClassification
]
for
model_fn
in
model_list
:
org_model
,
sharded_model
=
build_model
(
world_size
,
model_fn
)
check_forward
(
org_model
,
sharded_model
)
check_backward
(
org_model
,
sharded_model
)
check_forward_backward
(
org_model
,
sharded_model
)
torch
.
cuda
.
empty_cache
()
@
pytest
.
mark
.
dist
@
rerun_if_address_is_in_use
()
@
clear_cache_before_run
()
def
test_llama
():
spawn
(
check_llama
,
4
)
...
...
tests/test_shardformer/test_model/test_shard_t5.py
View file @
d857f3db
import
copy
import
os
import
random
import
pytest
import
torch
from
transformers
import
AutoTokenizer
,
BertConfig
,
BertForMaskedLM
,
T5Config
,
T5ForConditionalGeneration
,
T5Tokenizer
from
transformers
import
T5Config
,
T5EncoderModel
,
T5ForConditionalGeneration
,
T5Model
,
T5Tokenizer
,
T5TokenizerFast
import
colossalai
from
colossalai.logging
import
disable_existing_loggers
from
colossalai.shardformer.shard
import
ShardConfig
,
ShardFormer
from
colossalai.testing
import
rerun_if_address_is_in_use
,
spawn
from
colossalai.testing
import
assert_hf_output_close
,
clear_cache_before_run
,
rerun_if_address_is_in_use
,
spawn
os
.
environ
[
'TRANSFORMERS_NO_ADVISORY_WARNINGS'
]
=
'true'
CONFIG
=
dict
(
parallel
=
dict
(
data
=
1
,
pipeline
=
1
,
tensor
=
dict
(
size
=
2
,
mode
=
'1d'
)),)
tokenizer
=
T5Tokenizer
.
from_pretrained
(
"t5-small"
)
def
build_model
(
rank
,
world_size
):
config
=
T5Config
.
from_pretrained
(
"t5-small"
)
def
build_model
(
world_size
,
model_fn
):
config
=
T5Config
(
decoder_start_token_id
=
0
)
config
.
dropout_rate
=
0
org_model
=
T5ForConditionalGeneration
.
from_pretrained
(
"t5-small"
,
config
=
config
).
to
(
'cuda'
)
org_model
=
model_fn
(
config
=
config
).
to
(
'cuda'
)
shard_config
=
ShardConfig
(
tensor_parallel_size
=
world_size
)
shardconfig
=
ShardConfig
(
rank
=
rank
,
world_size
=
world_size
,
gather_output
=
True
,
)
org_model_for_shard
=
copy
.
deepcopy
(
org_model
)
sharded_model
=
shard_model
(
org_model_for_shard
,
shardconfig
).
to
(
'cuda'
)
# shard model
shard_config
=
ShardConfig
(
tensor_parallel_size
=
world_size
)
model_copy
=
copy
.
deepcopy
(
org_model
)
shard_former
=
ShardFormer
(
shard_config
=
shard_config
)
shard_former
.
init_distributed
()
sharded_model
=
shard_former
.
shard_model
(
model_copy
)
return
org_model
,
sharded_model
def
check_forward
(
org_model
,
sharded_model
):
input_ids
=
tokenizer
(
"translate English to German: The house is wonderful."
,
return_tensors
=
"pt"
).
input_ids
.
to
(
'cuda'
)
#orgin model
org_model
.
eval
()
org_output
=
org_model
.
generate
(
input_ids
)
#shard model
sharded_model
.
eval
()
shard_output
=
sharded_model
.
generate
(
input_ids
)
assert
torch
.
allclose
(
org_output
[
0
],
shard_output
[
0
],
atol
=
1e-5
),
f
"shard model output is not equal to orgin model output
\n
{
org_out
[
0
]
}
\n
{
shard_out
[
0
]
}
"
def
check_backward
(
org_model
,
sharded_model
):
def
check_forward_backward
(
org_model
,
sharded_model
):
# prepare input
input_ids
=
tokenizer
(
"translate English to German: The house is wonderful."
,
return_tensors
=
"pt"
).
input_ids
.
to
(
'cuda'
)
labels
=
tokenizer
(
"Das Haus ist wunderbar."
,
return_tensors
=
"pt"
).
input_ids
.
to
(
'cuda'
)
#
org
in mode
l
#
switch to tra
in mode
org_model
.
train
()
org_loss
=
org_model
(
input_ids
=
input_ids
,
labels
=
labels
).
loss
org_loss
.
backward
()
org_grad
=
org_model
.
encoder
.
block
[
0
].
layer
[
0
].
SelfAttention
.
q
.
weight
.
grad
#shard model
sharded_model
.
train
()
shard_loss
=
sharded_model
(
input_ids
=
input_ids
,
labels
=
labels
).
loss
if
isinstance
(
org_model
,
T5ForConditionalGeneration
):
org_output
=
org_model
(
input_ids
=
input_ids
,
labels
=
labels
)
org_loss
=
org_output
.
loss
shard_output
=
sharded_model
(
input_ids
=
input_ids
,
labels
=
labels
)
shard_loss
=
shard_output
.
loss
elif
isinstance
(
org_model
,
T5Model
):
decoder_input_ids
=
org_model
.
_shift_right
(
input_ids
)
org_output
=
org_model
(
input_ids
=
input_ids
,
decoder_input_ids
=
decoder_input_ids
)
org_loss
=
org_output
.
last_hidden_state
.
mean
()
shard_output
=
sharded_model
(
input_ids
=
input_ids
,
decoder_input_ids
=
decoder_input_ids
)
shard_loss
=
shard_output
.
last_hidden_state
.
mean
()
elif
isinstance
(
org_model
,
T5EncoderModel
):
org_output
=
org_model
(
input_ids
=
input_ids
)
org_loss
=
org_output
.
last_hidden_state
.
mean
()
shard_output
=
sharded_model
(
input_ids
=
input_ids
)
shard_loss
=
shard_output
.
last_hidden_state
.
mean
()
# key is sharded, so we ignore
assert_hf_output_close
(
org_output
,
shard_output
,
ignore_keys
=
[
'past_key_values'
])
# do backward
org_loss
.
backward
()
shard_loss
.
backward
()
# check grad equality
org_grad
=
org_model
.
encoder
.
block
[
0
].
layer
[
0
].
SelfAttention
.
q
.
weight
.
grad
shard_grad
=
sharded_model
.
encoder
.
block
[
0
].
layer
[
0
].
SelfAttention
.
q
.
weight
.
grad
shard_grad_list
=
[
torch
.
zeros
([
*
shard_grad
.
shape
]).
to
(
'cuda'
)
for
_
in
range
(
2
)]
...
...
@@ -82,16 +83,21 @@ def check_t5(rank, world_size, port):
disable_existing_loggers
()
colossalai
.
launch
(
config
=
CONFIG
,
rank
=
rank
,
world_size
=
world_size
,
host
=
'localhost'
,
port
=
port
,
backend
=
'nccl'
)
org_model
,
sharded_model
=
build_model
(
rank
,
world_size
)
check_forward
(
org_model
,
sharded_model
)
check_backward
(
org_model
,
sharded_model
)
model_fn_list
=
[
T5Model
,
T5ForConditionalGeneration
,
T5EncoderModel
,
]
torch
.
cuda
.
empty_cache
()
for
model_fn
in
model_fn_list
:
org_model
,
sharded_model
=
build_model
(
world_size
,
model_fn
)
check_forward_backward
(
org_model
,
sharded_model
)
torch
.
cuda
.
empty_cache
()
@
pytest
.
mark
.
dist
@
pytest
.
mark
.
skip
@
rerun_if_address_is_in_use
()
@
clear_cache_before_run
()
def
test_t5
():
spawn
(
check_t5
,
2
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment