OpenDAS / ColossalAI / Commits / ae035d30

Commit ae035d30 authored Jun 30, 2023 by Frank Lee
[shardformer] added embedding gradient check (#4124)
parent 44a190e6
Showing 14 changed files with 253 additions and 72 deletions
colossalai/shardformer/_utils.py                          +2  -2
colossalai/shardformer/policies/bert.py                   +1  -1
colossalai/shardformer/policies/bloom.py                  +16 -3
colossalai/shardformer/policies/opt.py                    +15 -2
colossalai/shardformer/policies/t5.py                     +87 -18
colossalai/shardformer/shard/sharder.py                   +0  -11
tests/kit/model_zoo/registry.py                           +2  -0
tests/test_shardformer/test_model/test_shard_bert.py      +22 -7
tests/test_shardformer/test_model/test_shard_bloom.py     +23 -7
tests/test_shardformer/test_model/test_shard_gpt2.py      +23 -7
tests/test_shardformer/test_model/test_shard_llama.py     +13 -3
tests/test_shardformer/test_model/test_shard_opt.py       +19 -5
tests/test_shardformer/test_model/test_shard_t5.py        +29 -6
tests/test_shardformer/test_model/test_shard_vit.py       +1  -0
colossalai/shardformer/_utils.py

@@ -55,7 +55,7 @@ def setattr_(obj, attr: str, value, ignore: bool = False):
         except AttributeError:
             if ignore:
                 return
-            raise AttributeError(f"Object {obj} has no attribute {attr}")
+            raise AttributeError(f"Object {obj.__class__.__name__} has no attribute {attr}")
     setattr(obj, attrs[-1], value)

@@ -76,5 +76,5 @@ def getattr_(obj, attr: str, ignore: bool = False):
         except AttributeError:
             if ignore:
                 return None
-            raise AttributeError(f"Object {obj} has no attribute {attr}")
+            raise AttributeError(f"Object {obj.__class__.__name__} has no attribute {attr}")
     return obj
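Both helpers walk dotted attribute paths such as "transformer.word_embeddings.weight". For readers who want a self-contained reference, a minimal sketch of how such dotted-path helpers can be written follows; it is an illustration only, not the project's exact implementation.

import functools

def getattr_(obj, attr: str, ignore: bool = False):
    """Fetch a nested attribute such as "decoder.embed_tokens"."""
    try:
        return functools.reduce(getattr, attr.split('.'), obj)
    except AttributeError:
        if ignore:
            return None
        raise AttributeError(f"Object {obj.__class__.__name__} has no attribute {attr}")

def setattr_(obj, attr: str, value, ignore: bool = False):
    """Assign a nested attribute such as "decoder.embed_tokens"."""
    attrs = attr.split('.')
    try:
        parent = functools.reduce(getattr, attrs[:-1], obj)
    except AttributeError:
        if ignore:
            return
        raise AttributeError(f"Object {obj.__class__.__name__} has no attribute {attr}")
    setattr(parent, attrs[-1], value)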
colossalai/shardformer/policies/bert.py

@@ -97,7 +97,7 @@ class BertPolicy(Policy):
                 ),
                 SubModuleReplacementDescription(
                     suffix="dropout",
-                    target_module=col_nn.DropoutForParallelInput,
+                    target_module=col_nn.DropoutForReplicatedInput,
                 )
             ])
     }
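The switch from DropoutForParallelInput to DropoutForReplicatedInput reflects that this dropout sees the same (replicated) activations on every tensor-parallel rank, so every rank must draw the same mask or the replicas diverge. A toy sketch of that idea, driving dropout from an identically seeded generator; this is an illustration of the general technique, not ColossalAI's implementation.

import torch

def replicated_dropout(x: torch.Tensor, p: float, shared_seed: int) -> torch.Tensor:
    # every rank seeds its generator identically, so every replica drops the same elements
    gen = torch.Generator(device=x.device).manual_seed(shared_seed)
    mask = (torch.rand(x.shape, generator=gen, device=x.device) > p).to(x.dtype)
    return x * mask / (1.0 - p)    # inverted-dropout scaling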
colossalai/shardformer/policies/bloom.py

 import torch
 import torch.distributed as dist
+import torch.nn as nn

 import colossalai.shardformer.layer as col_nn

+from .._utils import getattr_, setattr_
 from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription

@@ -73,7 +75,6 @@ class BloomPolicy(Policy):
         r"""
         Reshape the Embedding layer to make the embedding dimension divisible by world_size
         """
-        # TODO:
         vocab_size = self.model.config.vocab_size
         world_size = self.shard_config.tensor_parallel_size
         if vocab_size % world_size != 0:

@@ -161,13 +162,12 @@ class BloomPolicy(Policy):
     def new_model_class(self):
         # do nothing
-        return self.model
+        return None

     def postprocess(self):
         return self.model


-# BertModel
 class BloomModelPolicy(BloomPolicy):
     pass

@@ -191,6 +191,19 @@ class BloomForCausalLMPolicy(BloomPolicy):
         policy.update(new_item)
         return policy

+    def postprocess(self):
+        binding_map = {"transformer.word_embeddings.weight": "lm_head.weight"}
+
+        for k, v in binding_map.items():
+            param = getattr_(self.model, k)
+
+            if not isinstance(param, nn.Parameter):
+                param = nn.Parameter(param)
+
+            # tie weights
+            setattr_(self.model, k, param)
+            setattr_(self.model, v, param)
+        return self.model


 class BloomForSequenceClassificationPolicy(BloomPolicy):
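The new BloomForCausalLMPolicy.postprocess ties the word embedding and the LM head so that both resolve to a single parameter after sharding. A minimal standalone illustration of the same tying pattern, using toy modules rather than the Bloom model itself:

import torch.nn as nn

embed = nn.Embedding(10, 4)               # stand-in for transformer.word_embeddings
lm_head = nn.Linear(4, 10, bias=False)    # stand-in for lm_head

param = embed.weight
if not isinstance(param, nn.Parameter):
    param = nn.Parameter(param)

# tie weights: both modules now hold the very same Parameter object
embed.weight = param
lm_head.weight = param
assert embed.weight.data_ptr() == lm_head.weight.data_ptr()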
colossalai/shardformer/policies/opt.py

-from colossalai.shardformer.layer import Embedding1D, FusedLayerNorm, Linear1D_Col, Linear1D_Row
+from colossalai.shardformer.layer import FusedLayerNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D

+from .._utils import getattr_, setattr_
 from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription

 __all__ = [

@@ -35,7 +36,7 @@ class OPTPolicy(Policy):
             sub_module_replacement=[
                 SubModuleReplacementDescription(
                     suffix="embed_tokens",
-                    target_module=Embedding1D,
+                    target_module=VocabParallelEmbedding1D,
                 )
             ]),
         OPTDecoderLayer:

@@ -127,6 +128,18 @@ class OPTForCausalLMPolicy(OPTPolicy):
         policy.update(new_item)
         return policy

+    def postprocess(self):
+        binding_map = {
+            'model.decoder.embed_tokens': 'lm_head',
+        }
+
+        for k, v in binding_map.items():
+            src_mod = getattr_(self.model, k)
+            dst_mod = getattr_(self.model, v)
+            dst_mod.weight = src_mod.weight
+
+        return self.model


 class OPTForSequenceClassificationPolicy(OPTPolicy):
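Replacing Embedding1D with VocabParallelEmbedding1D means the embedding table is split along the vocabulary dimension rather than the hidden dimension. A hedged, forward-only sketch of the general vocab-parallel idea (each rank owns a contiguous slice of rows, masks out ids it does not own, and an all-reduce sums the partial lookups); the class below is this note's invention and not the VocabParallelEmbedding1D source:

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F

class ToyVocabParallelEmbedding(nn.Module):
    """Each rank stores vocab_size // world_size rows; foreign ids contribute zeros."""

    def __init__(self, vocab_size: int, embed_dim: int, rank: int, world_size: int):
        super().__init__()
        assert vocab_size % world_size == 0, "pad the vocabulary to a multiple of world_size first"
        self.part = vocab_size // world_size
        self.start = rank * self.part                     # first vocab id owned by this rank
        self.weight = nn.Parameter(torch.randn(self.part, embed_dim) * 0.02)

    def forward(self, ids: torch.Tensor) -> torch.Tensor:
        foreign = (ids < self.start) | (ids >= self.start + self.part)
        local_ids = (ids - self.start).clamp(0, self.part - 1)
        out = F.embedding(local_ids, self.weight)
        out[foreign] = 0.0                                # rows owned by other ranks
        dist.all_reduce(out)                              # sum the partial lookups across ranks
        return out

A production implementation additionally routes the all-reduce through an autograd function so gradients flow back to each rank's shard, which is exactly what the new embedding gradient checks below exercise.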
colossalai/shardformer/policies/t5.py

-from colossalai.shardformer.layer import DropoutForParallelInput, Embedding1D, Linear1D_Col, Linear1D_Row
+from colossalai.shardformer.layer import (
+    DropoutForParallelInput,
+    Embedding1D,
+    FusedRMSNorm,
+    Linear1D_Col,
+    Linear1D_Row,
+    VocabParallelEmbedding1D,
+)
+from colossalai.shardformer.policies.basepolicy import ModulePolicyDescription

+from .._utils import getattr_, setattr_
 from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription

 __all__ = ["T5ModelPolicy", "T5ForConditionalGenerationPolicy", "T5EncoderPolicy"]


-class T5ModelPolicy(Policy):
+class T5BasePolicy(Policy):

     def config_sanity_check(self):
         pass

@@ -33,7 +42,7 @@ class T5ModelPolicy(Policy):
             T5Stack,
         )

-        return {
+        base_policy = {
             T5Stack:
                 ModulePolicyDescription(attribute_replacement={},
                                         param_replacement=[],

@@ -41,6 +50,10 @@ class T5ModelPolicy(Policy):
                                             SubModuleReplacementDescription(
                                                 suffix="dropout",
                                                 target_module=DropoutForParallelInput,
                                             ),
+                                            SubModuleReplacementDescription(
+                                                suffix="embed_tokens",
+                                                target_module=Embedding1D,
+                                            )
                                         ]),
             T5LayerSelfAttention:

@@ -158,30 +171,86 @@ class T5ModelPolicy(Policy):
         return None

     def postprocess(self):
+        binding_map = [["shared", "encoder.embed_tokens"], ["shared", "decoder.embed_tokens"]]
+
+        for k, v in binding_map:
+            mod = getattr_(self.model, k)
+            setattr_(self.model, v, mod)
         return self.model


-class T5ForConditionalGenerationPolicy(T5ModelPolicy):
+class T5ModelPolicy(T5BasePolicy):
+
+    def module_policy(self):
+        from transformers import T5Model
+        base_policy = super().module_policy()
+        base_policy[T5Model] = ModulePolicyDescription(attribute_replacement={},
+                                                       param_replacement=[],
+                                                       sub_module_replacement=[
+                                                           SubModuleReplacementDescription(
+                                                               suffix="shared",
+                                                               target_module=VocabParallelEmbedding1D,
+                                                           )
+                                                       ])
+        return base_policy
+
+
+class T5ForConditionalGenerationPolicy(T5BasePolicy):

     def module_policy(self):
         from transformers import T5ForConditionalGeneration

         policy = super().module_policy()

-        policy[T5ForConditionalGeneration] = ModulePolicyDescription(attribute_replacement={},
-                                                                     param_replacement=[],
-                                                                     sub_module_replacement=[
-                                                                         SubModuleReplacementDescription(
-                                                                             suffix="lm_head",
-                                                                             target_module=Linear1D_Col,
-                                                                             kwargs=dict(gather_output=True))
-                                                                     ])
+        new_item = {
+            T5ForConditionalGeneration:
+                ModulePolicyDescription(attribute_replacement={},
+                                        param_replacement=[],
+                                        sub_module_replacement=[
+                                            SubModuleReplacementDescription(
+                                                suffix="shared",
+                                                target_module=VocabParallelEmbedding1D,
+                                            ),
+                                            SubModuleReplacementDescription(suffix="lm_head",
+                                                                            target_module=Linear1D_Col,
+                                                                            kwargs=dict(gather_output=True))
+                                        ])
+        }
+        policy.update(new_item)
         return policy

+    def postprocess(self):
+        super().postprocess()
+
+        binding_map = {"shared": "lm_head"}
+
+        for k, v in binding_map.items():
+            src_mod = getattr_(self.model, k)
+            dst_mod = getattr_(self.model, v)
+            dst_mod.weight = src_mod.weight
+
+        return self.model
+

-class T5EncoderPolicy(T5ModelPolicy):
-    pass
+class T5EncoderPolicy(T5BasePolicy):
+
+    def module_policy(self):
+        from transformers import T5EncoderModel
+        base_policy = super().module_policy()
+        base_policy[T5EncoderModel] = ModulePolicyDescription(attribute_replacement={},
+                                                              param_replacement=[],
+                                                              sub_module_replacement=[
+                                                                  SubModuleReplacementDescription(
+                                                                      suffix="shared",
+                                                                      target_module=VocabParallelEmbedding1D,
+                                                                  )
+                                                              ])
+        return base_policy
+
+    def postprocess(self):
+        binding_map = [
+            ["shared", "encoder.embed_tokens"],
+        ]
+
+        for k, v in binding_map:
+            mod = getattr_(self.model, k)
+            setattr_(self.model, v, mod)
+        return self.model
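The binding_map in these postprocess methods re-points encoder.embed_tokens (and, for the full model, decoder.embed_tokens) at the replaced "shared" module, so every name resolves to one module after sharding. A toy illustration of this module-level rebinding, using hypothetical classes rather than transformers' T5:

import torch.nn as nn

class ToyStack(nn.Module):
    def __init__(self):
        super().__init__()
        self.embed_tokens = nn.Embedding(8, 4)

class ToyT5(nn.Module):
    def __init__(self):
        super().__init__()
        self.shared = nn.Embedding(8, 4)
        self.encoder = ToyStack()
        self.decoder = ToyStack()

model = ToyT5()
# rebind: encoder and decoder embed_tokens now *are* the shared module
for dst in ("encoder", "decoder"):
    getattr(model, dst).embed_tokens = model.shared

assert model.encoder.embed_tokens.weight.data_ptr() == model.shared.weight.data_ptr()
assert model.decoder.embed_tokens.weight.data_ptr() == model.shared.weight.data_ptr()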
colossalai/shardformer/shard/sharder.py

@@ -38,17 +38,6 @@ class ModelSharder(object):
         self._replace_module()
         self._postprocess()

-    def reshape_embedding(self) -> None:
-        r"""
-        Reshape the Embedding layer to make the embedding dimension divisible by world_size
-        """
-        vocab_size = self.model_config.vocab_size
-        world_size = self.shard_config.world_size
-        if vocab_size % world_size != 0:
-            new_vocab_size = vocab_size + world_size - vocab_size % world_size
-            self.model.resize_token_embeddings(new_vocab_size)
-        self.model_config = self.model.config
-
     def _preprocess(self) -> None:
         self.model = self.policy.preprocess()
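reshape_embedding is removed from ModelSharder; the vocabulary-padding step now lives in each policy's preprocess (see bloom.py above). The padding rule itself is simple arithmetic, restated here as a standalone helper:

def padded_vocab_size(vocab_size: int, world_size: int) -> int:
    # round vocab_size up to the next multiple of world_size
    if vocab_size % world_size != 0:
        return vocab_size + world_size - vocab_size % world_size
    return vocab_size

assert padded_vocab_size(50257, 4) == 50260   # GPT-2-sized vocab on 4 ranks
assert padded_vocab_size(50260, 4) == 50260   # already divisible, unchanged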
tests/kit/model_zoo/registry.py

@@ -70,6 +70,8 @@ class ModelZooRegistry(dict):
         for k, v in self.items():
             if keyword in k:
                 new_dict[k] = v
+
+        assert len(new_dict) > 0, f'No model found with keyword {keyword}'
         return new_dict
tests/test_shardformer/test_model/test_shard_bert.py

@@ -18,20 +18,35 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo
     org_loss.backward()
     shard_loss.backward()

+    assert torch.allclose(org_loss, shard_loss,
+                          atol=1e-5), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}"
+
-    # check grad equality
+    # check grad
     if org_model.__class__.__name__ == 'BertModel':
-        org_grad = org_model.encoder.layer[0].attention.self.query.weight.grad
-        shard_grad = sharded_model.encoder.layer[0].attention.self.query.weight.grad
+        bert = org_model
+        sharded_bert = sharded_model
     else:
-        org_grad = org_model.bert.encoder.layer[0].attention.self.query.weight.grad
-        shard_grad = sharded_model.bert.encoder.layer[0].attention.self.query.weight.grad
+        bert = org_model.bert
+        sharded_bert = sharded_model.bert

+    # compare self attention grad
+    org_grad = bert.encoder.layer[0].attention.self.query.weight.grad
+    shard_grad = sharded_bert.encoder.layer[0].attention.self.query.weight.grad
     shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
     shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad)
     all_shard_grad = torch.cat(shard_grad_list, dim=0)
+    assert torch.allclose(org_grad, all_shard_grad,
+                          atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}"

-    assert torch.allclose(org_loss, shard_loss,
-                          atol=1e-5), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}"
+    # compare embedding grad
+    org_grad = bert.embeddings.word_embeddings.weight.grad
+    shard_grad = sharded_bert.embeddings.word_embeddings.weight.grad
+    shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
+    shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad)
+    all_shard_grad = torch.cat(shard_grad_list, dim=0)
     assert torch.allclose(org_grad, all_shard_grad,
                           atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}"
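All of the new embedding checks follow the same pattern: gather the per-rank gradient shards with all_gather, concatenate along the sharded dimension, and compare against the unsharded model's gradient. A hedged, generic sketch of that pattern (the helper name and signature are this note's invention, not test code from the repository):

import torch
import torch.distributed as dist

def gather_and_compare(org_grad: torch.Tensor, shard_grad: torch.Tensor,
                       world_size: int, dim: int = 0, atol: float = 1e-5) -> None:
    # collect every rank's local shard of the gradient
    shard_grad_list = [torch.zeros_like(shard_grad) for _ in range(world_size)]
    dist.all_gather(shard_grad_list, shard_grad)
    # rebuild the full gradient and compare it with the unsharded model's gradient
    all_shard_grad = torch.cat(shard_grad_list, dim=dim)
    assert torch.allclose(org_grad, all_shard_grad, atol=atol), \
        f"shard model grad is not equal to origin model grad\n{org_grad}\n{all_shard_grad}"

For column-sharded weights such as GPT-2's c_fc, the shards are concatenated along dim=1 instead of dim=0.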
tests/test_shardformer/test_model/test_shard_bloom.py

@@ -18,20 +18,36 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo
     org_loss.backward()
     shard_loss.backward()

+    assert torch.allclose(org_loss, shard_loss,
+                          atol=1e-5), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}"
+
-    # check grad equality
+    # unwrap model
     if org_model.__class__.__name__ == 'BloomModel':
-        org_grad = org_model.h[0].self_attention.query_key_value.weight.grad
-        shard_grad = sharded_model.h[0].self_attention.query_key_value.weight.grad
+        bloom = org_model
+        sharded_bloom = sharded_model
     else:
-        org_grad = org_model.transformer.h[0].self_attention.query_key_value.weight.grad
-        shard_grad = sharded_model.transformer.h[0].self_attention.query_key_value.weight.grad
+        bloom = org_model.transformer
+        sharded_bloom = sharded_model.transformer

+    # check attention grad
+    org_grad = bloom.h[0].self_attention.query_key_value.weight.grad
+    shard_grad = sharded_bloom.h[0].self_attention.query_key_value.weight.grad
+    shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
+    torch.distributed.all_gather(shard_grad_list, shard_grad)
+    all_shard_grad = torch.cat(shard_grad_list, dim=0)
+    assert torch.allclose(org_grad, all_shard_grad,
+                          atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}"

+    # check embedding weights
+    org_grad = bloom.word_embeddings.weight.grad
+    shard_grad = sharded_bloom.word_embeddings.weight.grad
     shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
     torch.distributed.all_gather(shard_grad_list, shard_grad)
     all_shard_grad = torch.cat(shard_grad_list, dim=0)

-    assert torch.allclose(org_loss, shard_loss,
-                          atol=1e-5), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}"
     assert torch.allclose(org_grad, all_shard_grad,
                           atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}"
tests/test_shardformer/test_model/test_shard_gpt2.py

@@ -18,20 +18,36 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo
     org_loss.backward()
     shard_loss.backward()

+    assert torch.allclose(org_loss, shard_loss,
+                          atol=1e-5), f"shard model loss is not equal to origin model loss\n{org_loss}\n{shard_loss}"
+
+    # unwrap model
     if org_model.__class__.__name__ == 'GPT2Model':
-        org_grad = org_model.h[0].mlp.c_fc.weight.grad
-        shard_grad = sharded_model.h[0].mlp.c_fc.weight.grad
+        org_model = org_model
+        sharded_model = sharded_model
     else:
-        org_grad = org_model.transformer.h[0].mlp.c_fc.weight.grad
-        shard_grad = sharded_model.transformer.h[0].mlp.c_fc.weight.grad
+        org_model = org_model.transformer
+        sharded_model = sharded_model.transformer

+    # check mlp grad
+    org_grad = org_model.h[0].mlp.c_fc.weight.grad
+    shard_grad = sharded_model.h[0].mlp.c_fc.weight.grad
     shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
     shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad)
     all_shard_grad = torch.cat(shard_grad_list, dim=1)
+    assert torch.allclose(org_grad, all_shard_grad,
+                          atol=1e-5), f"shard model grad is not equal to origin model grad\n{org_grad}\n{all_shard_grad}"

-    assert torch.allclose(org_loss, shard_loss,
-                          atol=1e-5), f"shard model loss is not equal to origin model loss\n{org_loss}\n{shard_loss}"
+    # check embedding weights
+    org_grad = org_model.wte.weight.grad
+    shard_grad = sharded_model.wte.weight.grad
+    shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
+    shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad)
+    all_shard_grad = torch.cat(shard_grad_list, dim=0)
     assert torch.allclose(org_grad, all_shard_grad,
                           atol=1e-5), f"shard model grad is not equal to origin model grad\n{org_grad}\n{all_shard_grad}"
tests/test_shardformer/test_model/test_shard_llama.py

@@ -23,7 +23,10 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo
     org_loss.backward()
     shard_loss.backward()

+    # check grad
+    assert torch.allclose(org_loss, shard_loss,
+                          atol=1e-5), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}"
+
+    # unwrap model
     if hasattr(org_model, 'model'):
         llama_model = org_model.model
         shard_llama_model = sharded_model.model

@@ -31,14 +34,21 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo
         llama_model = org_model
         shard_llama_model = sharded_model

+    # check attention grad
     org_grad = llama_model.layers[0].self_attn.q_proj.weight.grad
     shard_grad = shard_llama_model.layers[0].self_attn.q_proj.weight.grad
     shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(4)]
     shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad)
     all_shard_grad = torch.cat(shard_grad_list, dim=0)
+    assert torch.allclose(org_grad, all_shard_grad,
+                          atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{shard_grad}"

-    assert torch.allclose(org_loss, shard_loss,
-                          atol=1e-5), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}"
+    # check embedding grad
+    org_grad = llama_model.embed_tokens.weight.grad
+    shard_grad = shard_llama_model.embed_tokens.weight.grad
+    shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(4)]
+    shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad)
+    all_shard_grad = torch.cat(shard_grad_list, dim=0)
     assert torch.allclose(org_grad, all_shard_grad,
                           atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{shard_grad}"
tests/test_shardformer/test_model/test_shard_opt.py

@@ -28,7 +28,10 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo
     org_loss.backward()
     shard_loss.backward()

+    # check grad
+    assert torch.allclose(org_loss, shard_loss,
+                          atol=1e-5), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}"
+
+    # unwrap model
     if hasattr(org_model, 'model'):
         opt_model = org_model.model
         shard_opt_model = sharded_model.model

@@ -36,16 +39,23 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo
         opt_model = org_model
         shard_opt_model = sharded_model

+    # check attention grad
     org_grad = opt_model.decoder.layers[0].self_attn.q_proj.weight.grad
     shard_grad = shard_opt_model.decoder.layers[0].self_attn.q_proj.weight.grad
+    shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(4)]
+    torch.distributed.all_gather(shard_grad_list, shard_grad)
+    all_shard_grad = torch.cat(shard_grad_list, dim=0)
+    assert torch.allclose(org_grad, all_shard_grad,
+                          atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}"

+    # check embedding grad
+    org_grad = opt_model.decoder.embed_tokens.weight.grad
+    shard_grad = shard_opt_model.decoder.embed_tokens.weight.grad
     shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(4)]
     shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad)
     all_shard_grad = torch.cat(shard_grad_list, dim=0)

-    assert torch.allclose(org_loss, shard_loss,
-                          atol=1e-5), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}"
     assert torch.allclose(org_grad, all_shard_grad,
-                          atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{shard_grad}"
+                          atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}"

@@ -65,3 +75,7 @@ def check_OPTModel(rank, world_size, port):
 @clear_cache_before_run()
 def test_OPTModel():
     spawn(check_OPTModel, 4)
+
+
+if __name__ == '__main__':
+    test_OPTModel()
tests/test_shardformer/test_model/test_shard_t5.py

@@ -21,19 +21,43 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo
     org_loss.backward()
     shard_loss.backward()

+    assert torch.allclose(org_loss, shard_loss,
+                          atol=1e-5), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}"
+
+    # check attention grad
     org_grad = org_model.encoder.block[0].layer[0].SelfAttention.q.weight.grad
     shard_grad = sharded_model.encoder.block[0].layer[0].SelfAttention.q.weight.grad
     shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
     shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad)
     all_shard_grad = torch.cat(shard_grad_list, dim=0)

-    assert torch.allclose(org_loss, shard_loss,
-                          atol=1e-5), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}"
     assert torch.allclose(org_grad, all_shard_grad,
                           atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{shard_grad}"

+    # check self attention embed
+    org_grad = org_model.encoder.block[0].layer[0].SelfAttention.relative_attention_bias.weight.grad
+    shard_grad = sharded_model.encoder.block[0].layer[0].SelfAttention.relative_attention_bias.weight.grad
+    shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
+    torch.distributed.all_gather(shard_grad_list, shard_grad)
+    all_shard_grad = torch.cat(shard_grad_list, dim=1)
+    assert torch.allclose(org_grad, all_shard_grad,
+                          atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}"
+
+    # check token embedding grad
+    org_grad = org_model.shared.weight.grad
+
+    # check weights are tied
+    if hasattr(org_model, 'lm_head'):
+        assert org_model.shared.weight.data.data_ptr() == org_model.lm_head.weight.data.data_ptr()
+        assert sharded_model.shared.weight.data.data_ptr() == sharded_model.lm_head.weight.data.data_ptr()
+
+    shard_grad = sharded_model.shared.weight.grad
+    shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
+    torch.distributed.all_gather(shard_grad_list, shard_grad)
+    all_shard_grad = torch.cat(shard_grad_list, dim=0)
+    assert torch.allclose(org_grad, all_shard_grad,
+                          atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}"


 def check_t5(rank, world_size, port):
     disable_existing_loggers()

@@ -44,7 +68,6 @@ def check_t5(rank, world_size, port):
     for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
         org_model, sharded_model = build_model(model_fn)
         check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
-
     torch.cuda.empty_cache()
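The data_ptr() comparison above is a cheap way to confirm weight tying: two parameters are tied exactly when they refer to the same underlying storage. A tiny standalone illustration with toy layers, not the T5 model:

import torch.nn as nn

embed = nn.Embedding(8, 4)
lm_head = nn.Linear(4, 8, bias=False)
lm_head.weight = embed.weight                  # tie the two parameters

assert embed.weight.data.data_ptr() == lm_head.weight.data.data_ptr()

untied = nn.Linear(4, 8, bias=False)
assert embed.weight.data.data_ptr() != untied.weight.data.data_ptr()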
tests/test_shardformer/test_model/test_shard_vit.py

@@ -45,6 +45,7 @@ def check_vit(rank, world_size, port):
 @pytest.mark.dist
+@pytest.mark.skip
 @rerun_if_address_is_in_use()
 @clear_cache_before_run()
 def test_vit():