OpenDAS / ColossalAI, commit c1c672d0
Authored Jun 15, 2023 by wukong1992; committed by Frank Lee on Jul 04, 2023
[shardformer] shardformer support t5 model (#3994)
test t5
Parent: 6b30dfb7

Showing 10 changed files, with 320 additions and 10 deletions (+320, -10)
Changed files:

  applications/Chat/coati/trainer/.sft.py.swp          +0    -0
  colossalai/shardformer/layer/layers.py               +6    -2
  colossalai/shardformer/policies/autopolicy.py        +9    -0
  colossalai/shardformer/policies/basepolicy.py        +12   -0
  colossalai/shardformer/policies/t5.py                +159  -0
  colossalai/shardformer/shard/sharder.py              +8    -3
  colossalai/shardformer/shard/slicer.py               +4    -2
  colossalai/shardformer/utils/utils.py                +22   -3
  requirements/requirements-test.txt                   +1    -0
  tests/test_shardformer/test_model/test_shard_t5.py   +99   -0
applications/Chat/coati/trainer/.sft.py.swp  (new file, mode 100644)

File added (binary Vim swap file).
colossalai/shardformer/layer/layers.py

@@ -770,6 +770,7 @@ class Embedding1D(ParallelLayer):
                  embedding_dim: int,
                  padding_idx: int = None,
                  dtype: torch.dtype = None,
+                 gather_output: bool = True,
                  weight_initializer: Callable = init.normal_(),
                  *args,
                  **kwargs):
@@ -782,6 +783,7 @@ class Embedding1D(ParallelLayer):
         self.padding_idx = padding_idx
         self.embed_args = args
         self.embed_kwargs = kwargs
+        self.gather_output = gather_output

         self.weight = Parameter(
             torch.empty((num_embeddings, embed_dim_per_partition), device=get_current_device(), dtype=dtype))
@@ -832,8 +834,10 @@ class Embedding1D(ParallelLayer):
     def forward(self, input_: Tensor) -> Tensor:
         output_parallel = F.embedding(input_, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)
-        output = gather_forward_split_backward(output_parallel, ParallelMode.PARALLEL_1D, dim=-1)
+        if self.gather_output:
+            output = gather_forward_split_backward(output_parallel, ParallelMode.PARALLEL_1D, dim=-1)
+        else:
+            output = output_parallel

         return output
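The new gather_output flag mirrors the one on Linear1D_Col: with gather_output=False the forward pass returns the partition-local embedding slice instead of all-gathering along the last dimension. A minimal sketch of the shapes involved (illustrative values only, assuming a 1D tensor-parallel group of world_size=2 is already initialized):

import torch

# Sketch: shapes an Embedding1D with embedding_dim=512 would see on a
# 2-way tensor-parallel group (each rank holds 256 columns of the weight).
num_embeddings, embedding_dim, world_size = 32128, 512, 2
local_dim = embedding_dim // world_size

local_weight = torch.empty(num_embeddings, local_dim)    # per-rank weight shard
input_ids = torch.randint(0, num_embeddings, (1, 10))

output_parallel = torch.nn.functional.embedding(input_ids, local_weight)
assert output_parallel.shape == (1, 10, local_dim)       # gather_output=False: stays sharded
# With gather_output=True, gather_forward_split_backward(..., dim=-1) would
# all-gather the shards, restoring the full (1, 10, embedding_dim) output.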
colossalai/shardformer/policies/autopolicy.py

@@ -43,6 +43,15 @@ def build_policies():
     from .gpt2 import GPT2LMHeadModelPolicy
     auto_policy_dict[GPT2LMHeadModel] = GPT2LMHeadModelPolicy

+    from .t5 import T5ForConditionalGenerationPolicy, T5EncoderModelPolicy, T5ModelPolicy
+    from transformers import T5ForConditionalGeneration, T5EncoderModel, T5Model
+    t5 = {
+        T5ForConditionalGeneration: T5ForConditionalGenerationPolicy,
+        T5EncoderModel: T5EncoderModelPolicy,
+        T5Model: T5ModelPolicy,
+    }
+    auto_policy_dict.update(t5)
+
     return auto_policy_dict
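With these entries registered, a policy can be resolved from a model instance by exact type lookup. A sketch of how the mapping would be consumed (import path taken from this commit's tree; the get_autopolicy helper imported in sharder.py below presumably wraps this lookup):

from transformers import T5Config, T5Model

from colossalai.shardformer.policies.autopolicy import build_policies

model = T5Model(T5Config())                   # randomly initialized toy instance
policy_cls = build_policies()[type(model)]    # -> T5ModelPolicy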
colossalai/shardformer/policies/basepolicy.py

@@ -80,6 +80,18 @@ class Dropout_Layer(Layer):
     p: str = None


+@dataclass
+class Embedding_Layer(Layer):
+    r"""
+    Class for col shard layer in tensor parallel
+
+    Args:
+        weight (str): The weight suffix of the layer
+    """
+    weight: str = None
+    gather_output: bool = True
+
+
 class Policy():
     r"""
     The base class for all the policies
colossalai/shardformer/policies/t5.py  (new file, mode 100644)

from typing import Dict

import torch.nn as nn
from torch.nn import Embedding
from transformers.models.t5.modeling_t5 import (
    T5Attention,
    T5Block,
    T5DenseActDense,
    T5DenseGatedActDense,
    T5LayerCrossAttention,
    T5LayerFF,
    T5LayerSelfAttention,
    T5Model,
    T5Stack,
)

import colossalai.shardformer.layer.layers as col_nn

from .basepolicy import Argument, Col_Layer, Dropout_Layer, Embedding_Layer, Policy, Row_Layer


class T5ModelPolicy(Policy):

    @staticmethod
    def argument_policy(config, world_size: int) -> Dict[nn.Module, Argument]:
        print('config heads', config.num_heads)
        return {
            T5Stack:
                Argument(attr_dict={}, param_funcs=[T5ModelPolicy.dropout, T5ModelPolicy.embedding]),
            T5Block:
                Argument(attr_dict={}, param_funcs=[]),
            T5LayerSelfAttention:
                Argument(attr_dict={}, param_funcs=[T5ModelPolicy.dropout]),
            T5LayerCrossAttention:
                Argument(attr_dict={}, param_funcs=[T5ModelPolicy.dropout]),
            T5Attention:
                Argument(attr_dict={
                    "d_model": config.d_model // world_size,
                    "n_heads": config.num_heads // world_size,
                    "inner_dim": config.num_heads * config.d_kv // world_size,
                },
                         param_funcs=[T5ModelPolicy.attn_layer]),
            T5LayerFF:
                Argument(attr_dict={}, param_funcs=[T5ModelPolicy.dropout]),
            T5DenseGatedActDense:
                Argument(attr_dict={}, param_funcs=[T5ModelPolicy.dropout, T5ModelPolicy.dense_gated_layer]),
            T5DenseActDense:
                Argument(attr_dict={}, param_funcs=[T5ModelPolicy.dropout, T5ModelPolicy.dense_act_layer]),
        }

    @staticmethod
    def dense_gated_layer():
        return [
            Col_Layer(suffix="wi_0", weight="weight", replace_layer=col_nn.Linear1D_Col),
            Row_Layer(suffix="wi_1", weight="weight", replace_layer=col_nn.Linear1D_Row),
            Col_Layer(suffix="wo", weight="weight", replace_layer=col_nn.Linear1D_Col, gather_output=True),
        ]

    @staticmethod
    def dense_act_layer():
        return [
            Col_Layer(suffix="wi", weight="weight", replace_layer=col_nn.Linear1D_Col),
            Row_Layer(suffix="wo", weight="weight", replace_layer=col_nn.Linear1D_Row),
        ]

    @staticmethod
    def attn_layer():
        return [
            Col_Layer(suffix="q", weight="weight", bias="bias", replace_layer=col_nn.Linear1D_Col),
            Col_Layer(suffix="k", weight="weight", bias="bias", replace_layer=col_nn.Linear1D_Col),
            Col_Layer(suffix="v", weight="weight", bias="bias", replace_layer=col_nn.Linear1D_Col),
            Row_Layer(suffix="o", weight="weight", bias="bias", replace_layer=col_nn.Linear1D_Row),
        ]

    @staticmethod
    def dropout():
        return [Dropout_Layer(suffix="dropout", p="p", replace_layer=col_nn.Dropout1D)]

    @staticmethod
    def embedding():
        return [
            Embedding_Layer(suffix="block[0].layer[0].SelfAttention.relative_attention_bias",
                            weight="weight",
                            replace_layer=col_nn.Embedding1D,
                            gather_output=False),
        ]


from transformers import T5ForConditionalGeneration


class T5ForConditionalGenerationPolicy(T5ModelPolicy):

    @staticmethod
    def argument_policy(config, world_size):
        base_argument = T5ModelPolicy.argument_policy(config, world_size)
        argument = {
            T5ForConditionalGeneration:
                Argument(attr_dict={}, param_funcs=[T5ForConditionalGenerationPolicy.lm_head]),
        }
        argument.update(base_argument)
        return argument

    @staticmethod
    def lm_head():
        return [Col_Layer(suffix="lm_head", weight="weight", replace_layer=col_nn.Linear1D_Col, gather_output=True)]


from transformers import T5EncoderModel


class T5EncoderModelPolicy(T5ModelPolicy):
    pass
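For concreteness, the T5Attention attr_dict rewrites the module's bookkeeping attributes to per-partition values. With the t5-small configuration used by the new test (d_model=512, num_heads=8, d_kv=64) on world_size=2, the arithmetic works out as below (a quick check, not part of the commit):

# Per-rank attribute values for t5-small sharded across a 2-way group,
# matching the T5Attention attr_dict above.
d_model, num_heads, d_kv, world_size = 512, 8, 64, 2

assert d_model // world_size == 256               # per-rank "d_model"
assert num_heads // world_size == 4               # per-rank "n_heads"
assert num_heads * d_kv // world_size == 256      # per-rank "inner_dim"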
colossalai/shardformer/shard/sharder.py

@@ -5,7 +5,7 @@ import torch.nn as nn
 from transformers.pytorch_utils import Conv1D

 from ..policies.autopolicy import get_autopolicy
-from ..policies.basepolicy import Col_Layer, Dropout_Layer, Policy, Row_Layer
+from ..policies.basepolicy import Col_Layer, Dropout_Layer, Policy, Row_Layer, Embedding_Layer
 from ..utils.utils import getattr_, hasattr_, setattr_
 from .shard_config import ShardConfig
 from .slicer import Slicer
@@ -155,11 +155,11 @@ class ModelSharder(object):
             assert suffix_layer is not None or ignore, f"Layer {org_layer.__class__.__qualname__} has no attribute {suffix}"
             if suffix_layer is None and ignore:
                 continue
-            if isinstance(policy_layer, (Col_Layer, Row_Layer)):
+            if isinstance(policy_layer, (Col_Layer, Row_Layer, Embedding_Layer)):
                 weight = None
                 bias = None
                 weight_attr = suffix + '.' + policy_layer.weight if policy_layer.weight is not None else None
-                bias_attr = suffix + '.' + policy_layer.bias if policy_layer.bias is not None else None
+                bias_attr = suffix + '.' + policy_layer.bias if hasattr(policy_layer, 'bias') and policy_layer.bias is not None else None

                 if weight_attr is not None:
                     if hasattr_(org_layer, weight_attr):
@@ -189,6 +189,11 @@ class ModelSharder(object):
                         weight.shape[1],
                         bias=False if bias is None else True,
                         gather_output=gather_output)
+                elif replace_layer_cls.__name__ == "Embedding1D":
+                    gather_output = policy_layer.gather_output
+                    replace_layer = replace_layer_cls(weight.shape[0], weight.shape[1], gather_output=gather_output)
+                elif replace_layer_cls.__name__ == "VocabParallelEmbedding1D":
+                    replace_layer = replace_layer_cls(weight.shape[0], weight.shape[1],
+                                                      getattr_(org_layer, f"{suffix}.padding_idx", ignore=True))
colossalai/shardformer/shard/slicer.py

 import torch

-from ..policies.basepolicy import Col_Layer, Dropout_Layer, Layer, Row_Layer
+from ..policies.basepolicy import Col_Layer, Dropout_Layer, Layer, Row_Layer, Embedding_Layer
 from .shard_config import ShardConfig

-dim_mapping = {Col_Layer: 0, Row_Layer: 1}
+dim_mapping = {Col_Layer: 0, Row_Layer: 1, Embedding_Layer: 1}


 class Slicer():
@@ -43,6 +43,8 @@ class Slicer():
                 bias = self.slice_tensor(bias, 0, True, n_cast)
             elif policy_layer_cls == Row_Layer:
                 weight = self.slice_tensor(weight, dim, False, n_cast)
+            elif policy_layer_cls == Embedding_Layer:
+                weight = self.slice_tensor(weight, dim, False, n_cast)
             else:
                 raise NotImplementedError(f"The policy layer class {policy_layer_cls} is not supported")
             if reversed:
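Embedding_Layer is mapped to dim 1, so the embedding weight is sliced along its embedding dimension, matching Embedding1D's (num_embeddings, embed_dim_per_partition) weight above. Roughly what that slice amounts to (a sketch, not the actual slice_tensor implementation):

import torch

# Dim-1 slicing of an embedding weight on a 2-way tensor-parallel group.
weight = torch.randn(32128, 512)                  # (num_embeddings, embedding_dim)
world_size, rank = 2, 0

shard = torch.chunk(weight, world_size, dim=1)[rank]
assert shard.shape == (32128, 256)                # each rank keeps a column block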
colossalai/shardformer/utils/utils.py

 import re


+def get_obj_list_element(obj, a):
+    re_pattern = r'\[\d+\]'
+    prog = re.compile(re_pattern)
+    result = prog.search(a)
+    if result:
+        matched_brackets = result.group()
+        matched_index = matched_brackets.replace('[', '')
+        matched_index = matched_index.replace(']', '')
+        a_ = a.replace(matched_brackets, '')
+        container_obj = getattr(obj, a_)
+        obj = container_obj[int(matched_index)]
+    else:
+        obj = getattr(obj, a)
+    return obj


 def hasattr_(obj, attr: str):
     r"""
     Check whether the object has the multi sublevel attr
@@ -9,7 +28,7 @@ def hasattr_(obj, attr: str):
     attrs = attr.split('.')
     for a in attrs:
         try:
-            obj = getattr(obj, a)
+            obj = get_obj_list_element(obj, a)
         except AttributeError:
             return False
     return True
@@ -29,7 +48,7 @@ def setattr_(obj, attr: str, value, ignore: bool = False):
     attrs = attr.split('.')
     for a in attrs[:-1]:
         try:
-            obj = getattr(obj, a)
+            obj = get_obj_list_element(obj, a)
         except AttributeError:
             if ignore:
                 return
@@ -50,7 +69,7 @@ def getattr_(obj, attr: str, ignore: bool = False):
     attrs = attr.split('.')
     for a in attrs:
         try:
-            obj = getattr(obj, a)
+            obj = get_obj_list_element(obj, a)
         except AttributeError:
             if ignore:
                 return None
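get_obj_list_element is what lets policy suffixes such as block[0].layer[0].SelfAttention.relative_attention_bias in the new T5 policy resolve through getattr_/setattr_/hasattr_. A standalone illustration of the same indexed-segment trick on toy objects (not ColossalAI code):

import re

def resolve_segment(obj, a):
    # Resolve one dotted-path segment, allowing a trailing "[<int>]" index,
    # the same way get_obj_list_element does above.
    match = re.search(r'\[\d+\]', a)
    if match:
        index = int(match.group().strip('[]'))
        return getattr(obj, a.replace(match.group(), ''))[index]
    return getattr(obj, a)

class Node:
    pass

root = Node()
root.block = [Node(), Node()]
root.block[0].layer = ['attn', 'ff']
assert resolve_segment(resolve_segment(root, 'block[0]'), 'layer[1]') == 'ff'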
requirements/requirements-test.txt

@@ -15,3 +15,4 @@ einops
 triton==2.0.0.dev20221202
 git+https://github.com/HazyResearch/flash-attention.git@c422fee3776eb3ea24e011ef641fd5fbeb212623#egg=flash_attn
 requests==2.27.1 # downgrade to avoid huggingface error https://github.com/huggingface/transformers/issues/17611
+SentencePiece
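The added SentencePiece requirement presumably supports the new test below: the slow T5Tokenizer is SentencePiece-based, so T5Tokenizer.from_pretrained("t5-small") fails without the sentencepiece package installed.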
tests/test_shardformer/test_model/test_shard_t5.py  (new file, mode 100644)

import copy
import os
import random

import pytest
import torch
from transformers import AutoTokenizer, BertConfig, BertForMaskedLM, T5Config, T5ForConditionalGeneration, T5Tokenizer

import colossalai
from colossalai.logging import disable_existing_loggers
from colossalai.shardformer.shard import ShardConfig, shard_model
from colossalai.testing import rerun_if_address_is_in_use, spawn

os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'
CONFIG = dict(parallel=dict(data=1, pipeline=1, tensor=dict(size=2, mode='1d')),)
tokenizer = T5Tokenizer.from_pretrained("t5-small")


def build_model(rank, world_size):
    config = T5Config.from_pretrained("t5-small")
    config.dropout_rate = 0
    org_model = T5ForConditionalGeneration.from_pretrained("t5-small", config=config).to('cuda')

    shardconfig = ShardConfig(
        rank=rank,
        world_size=world_size,
        gather_output=True,
    )
    org_model_for_shard = copy.deepcopy(org_model)

    sharded_model = shard_model(org_model_for_shard, shardconfig).to('cuda')

    return org_model, sharded_model


def check_forward(org_model, sharded_model):
    input_ids = tokenizer("translate English to German: The house is wonderful.",
                          return_tensors="pt").input_ids.to('cuda')

    # original model
    org_model.eval()
    org_output = org_model.generate(input_ids)

    # sharded model
    sharded_model.eval()
    shard_output = sharded_model.generate(input_ids)
    assert torch.allclose(
        org_output[0], shard_output[0],
        atol=1e-5), f"shard model output is not equal to origin model output\n{org_output[0]}\n{shard_output[0]}"


def check_backward(org_model, sharded_model):
    # prepare input
    input_ids = tokenizer("translate English to German: The house is wonderful.",
                          return_tensors="pt").input_ids.to('cuda')
    labels = tokenizer("Das Haus ist wunderbar.", return_tensors="pt").input_ids.to('cuda')

    # original model
    org_model.train()
    org_loss = org_model(input_ids=input_ids, labels=labels).loss
    org_loss.backward()
    org_grad = org_model.encoder.block[0].layer[0].SelfAttention.q.weight.grad

    # sharded model
    sharded_model.train()
    shard_loss = sharded_model(input_ids=input_ids, labels=labels).loss
    shard_loss.backward()
    shard_grad = sharded_model.encoder.block[0].layer[0].SelfAttention.q.weight.grad

    shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
    torch.distributed.all_gather(shard_grad_list, shard_grad)    # fills shard_grad_list in place
    all_shard_grad = torch.cat(shard_grad_list, dim=0)

    assert torch.allclose(org_loss, shard_loss,
                          atol=1e-5), f"shard model loss is not equal to origin model loss\n{org_loss}\n{shard_loss}"
    assert torch.allclose(org_grad, all_shard_grad,
                          atol=1e-5), f"shard model grad is not equal to origin model grad\n{org_grad}\n{all_shard_grad}"


def check_t5(rank, world_size, port):
    disable_existing_loggers()
    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')

    org_model, sharded_model = build_model(rank, world_size)
    check_forward(org_model, sharded_model)
    check_backward(org_model, sharded_model)

    torch.cuda.empty_cache()


@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_t5():
    spawn(check_t5, 2)


if __name__ == "__main__":
    test_t5()