OpenDAS / ColossalAI · Commits
Commit e04436a8 (unverified)
Authored Aug 23, 2023 by Jianghai; committed by GitHub on Aug 23, 2023

[shardformer] tests for 3d parallel (#4493)

Parent: 59e252ec

Changes: 10 changed files with 324 additions and 5 deletions (+324 -5)
tests/test_shardformer/test_model/_utils.py                 +0  -1
tests/test_shardformer/test_model/test_shard_bert.py        +36 -0
tests/test_shardformer/test_model/test_shard_bloom.py       +38 -0
tests/test_shardformer/test_model/test_shard_chatglm2.py    +35 -0
tests/test_shardformer/test_model/test_shard_gpt2.py        +36 -0
tests/test_shardformer/test_model/test_shard_llama.py       +36 -1
tests/test_shardformer/test_model/test_shard_opt.py         +35 -0
tests/test_shardformer/test_model/test_shard_t5.py          +35 -0
tests/test_shardformer/test_model/test_shard_vit.py         +35 -0
tests/test_shardformer/test_model/test_shard_whisper.py     +38 -3
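Every test file below gains the same 3D-parallel scaffold: a `run_<model>_3d_test` body parameterized with `tp_size=2` and `pp_size=2`, a `check_<model>_3d` launcher, and a `test_<model>_3d` entry point spawned on 8 processes (2 tensor-parallel times 2 pipeline-parallel ranks leaves a data-parallel degree of 2, hence "3d"). Below is a minimal sketch of that shared shape, distilled from the diffs that follow; `example` and `transformers_example` are placeholders for the per-model names, and `check_forward_backward` stands in for the helper each test file already defines.

```python
# Sketch only: the per-model 3D-parallel test scaffold this commit repeats across files.
import pytest
import torch

import colossalai
from colossalai.logging import disable_existing_loggers
from colossalai.tensor.d_tensor.api import clear_layout_converter
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo


@parameterize('test_config', [{
    'tp_size': 2,                        # tensor-parallel degree
    'pp_size': 2,                        # pipeline-parallel degree
    'num_microbatches': 4,
    'enable_all_optimization': False,
    'use_lazy_init': False,
    'precision': 'fp32',
    'initial_scale': 1,
}])
def run_example_3d_test(test_config):
    # 'transformers_example' is a placeholder for the real registry key, e.g. 'transformers_bert'.
    sub_model_zoo = model_zoo.get_sub_registry('transformers_example')
    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
        check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
    clear_layout_converter()
    torch.cuda.empty_cache()


def check_example_3d(rank, world_size, port):
    disable_existing_loggers()
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    run_example_3d_test()


@pytest.mark.largedist               # 8-GPU job, kept separate from the 4-GPU `dist` tests
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_example_3d():
    spawn(check_example_3d, 8)       # 8 ranks: tp(2) x pp(2) x dp(2)
```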
tests/test_shardformer/test_model/_utils.py (view file @ e04436a8)

```diff
@@ -245,7 +245,6 @@ def check_grad(org_model: Module,
     org_grad = getattr_(org_model, suffix).weight.grad
     shard_grad = getattr_(sharded_model, suffix).weight.grad
     shard_weight = getattr_(sharded_model, suffix).weight
     if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
         shard_grad_list = [torch.zeros_like(shard_grad).to('cuda') for _ in range(dist.get_world_size(tp_group))]
         dist.all_gather(shard_grad_list, shard_grad, tp_group)
```
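For context, the hunk above sits in `check_grad`, which all-gathers each rank's gradient shard over the tensor-parallel group and compares the result with the unsharded model's gradient. A minimal sketch of how such a gathered list is typically merged for the comparison, assuming a 1-D sharding along `dim` and using generic PyTorch rather than the repo's own helpers:

```python
# Illustration only (not ColossalAI internals): rebuild a full gradient from gathered
# shards and compare it against the reference gradient from the unsharded model.
import torch


def compare_gathered_grad(shard_grad_list, org_grad, dim, atol=1e-5, rtol=1e-5):
    # shard_grad_list: one gradient shard per tensor-parallel rank (e.g. from dist.all_gather)
    # dim: the dimension along which the parallel layer shards its weight
    full_grad = torch.cat(shard_grad_list, dim=dim)
    torch.testing.assert_close(full_grad, org_grad, atol=atol, rtol=rtol)
```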
tests/test_shardformer/test_model/test_shard_bert.py (view file @ e04436a8)

```diff
@@ -120,12 +120,40 @@ def run_bert_test(test_config):
     torch.cuda.empty_cache()


+@parameterize('test_config', [{
+    'tp_size': 2,
+    'pp_size': 2,
+    'num_microbatches': 4,
+    'enable_all_optimization': False,
+    'use_lazy_init': False,
+    'precision': 'fp32',
+    'initial_scale': 1,
+},])
+def run_bert_3d_test(test_config):
+    sub_model_zoo = model_zoo.get_sub_registry('transformers_bert')
+
+    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
+        check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
+
+    clear_layout_converter()
+    Randomizer.reset_index()
+    torch.cuda.empty_cache()
+
+
 def check_bert(rank, world_size, port):
     disable_existing_loggers()
     colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     run_bert_test()


+def check_bert_3d(rank, world_size, port):
+    disable_existing_loggers()
+    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    run_bert_3d_test()
+
+
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
 @clear_cache_before_run()
@@ -133,5 +161,13 @@ def test_bert():
     spawn(check_bert, 4)


+@pytest.mark.largedist
+@rerun_if_address_is_in_use()
+@clear_cache_before_run()
+def test_bert_3d():
+    spawn(check_bert_3d, 8)
+
+
 if __name__ == "__main__":
     test_bert()
+    test_bert_3d()
```
tests/test_shardformer/test_model/test_shard_bloom.py (view file @ e04436a8)

```diff
@@ -3,6 +3,7 @@ import torch
 import colossalai
 from colossalai.logging import disable_existing_loggers
+from colossalai.shardformer.layer.utils import Randomizer
 from colossalai.tensor.d_tensor.api import clear_layout_converter
 from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
 from tests.kit.model_zoo import model_zoo
@@ -118,6 +119,29 @@ def run_bloom_test(test_config):
         check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)

     clear_layout_converter()
+    Randomizer.reset_index()
     torch.cuda.empty_cache()
+
+
+@parameterize('test_config', [{
+    'tp_size': 2,
+    'pp_size': 2,
+    'num_microbatches': 4,
+    'enable_all_optimization': False,
+    'use_lazy_init': False,
+    'precision': 'fp32',
+    'initial_scale': 1,
+},])
+def run_bloom_3d_test(test_config):
+    sub_model_zoo = model_zoo.get_sub_registry('transformers_bloom')
+
+    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
+        check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
+
+    clear_layout_converter()
+    Randomizer.reset_index()
+    torch.cuda.empty_cache()
@@ -127,6 +151,12 @@ def check_bloom(rank, world_size, port):
     run_bloom_test()


+def check_bloom_3d(rank, world_size, port):
+    disable_existing_loggers()
+    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    run_bloom_3d_test()
+
+
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
 @clear_cache_before_run()
@@ -134,5 +164,13 @@ def test_bloom():
     spawn(check_bloom, 4)


+@pytest.mark.largedist
+@rerun_if_address_is_in_use()
+@clear_cache_before_run()
+def test_bloom_3d():
+    spawn(check_bloom_3d, 8)
+
+
 if __name__ == "__main__":
     test_bloom()
+    test_bloom_3d()
```
tests/test_shardformer/test_model/test_shard_chatglm2.py (view file @ e04436a8)

```diff
@@ -145,12 +145,39 @@ def run_chatglm_test(test_config):
     torch.cuda.empty_cache()


+@parameterize('test_config', [{
+    'tp_size': 2,
+    'pp_size': 2,
+    'num_microbatches': 4,
+    'enable_all_optimization': False,
+    'use_lazy_init': False,
+    'precision': 'fp32',
+    'initial_scale': 1,
+},])
+def run_chatglm_3d_test(test_config):
+    sub_model_zoo = model_zoo.get_sub_registry('transformers_chatglm')
+
+    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
+        check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
+
+    clear_layout_converter()
+    torch.cuda.empty_cache()
+
+
 def check_chatglm(rank, world_size, port):
     disable_existing_loggers()
     colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     run_chatglm_test()


+def check_chatglm_3d(rank, world_size, port):
+    disable_existing_loggers()
+    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    run_chatglm_3d_test()
+
+
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
 @clear_cache_before_run()
@@ -158,5 +185,13 @@ def test_chatglm():
     spawn(check_chatglm, 4)


+@pytest.mark.largedist
+@rerun_if_address_is_in_use()
+@clear_cache_before_run()
+def test_chatglm_3d():
+    spawn(check_chatglm_3d, 8)
+
+
 if __name__ == "__main__":
     test_chatglm()
+    test_chatglm_3d()
```
tests/test_shardformer/test_model/test_shard_gpt2.py (view file @ e04436a8)

```diff
@@ -141,12 +141,40 @@ def run_gpt2_test(test_config):
     torch.cuda.empty_cache()


+@parameterize('test_config', [{
+    'tp_size': 2,
+    'pp_size': 2,
+    'num_microbatches': 4,
+    'enable_all_optimization': False,
+    'use_lazy_init': False,
+    'precision': 'fp32',
+    'initial_scale': 1,
+},])
+@clear_cache_before_run()
+def run_gpt2_3d_test(test_config):
+    sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt')
+
+    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
+        check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
+
+    clear_layout_converter()
+    torch.cuda.empty_cache()
+
+
 def check_gpt2(rank, world_size, port):
     disable_existing_loggers()
     colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     run_gpt2_test()


+def check_gpt2_3d(rank, world_size, port):
+    disable_existing_loggers()
+    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    run_gpt2_3d_test()
+
+
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
 @clear_cache_before_run()
@@ -154,5 +182,13 @@ def test_gpt2():
     spawn(check_gpt2, 4)


+@pytest.mark.largedist
+@rerun_if_address_is_in_use()
+@clear_cache_before_run()
+def test_gpt2_3d():
+    spawn(check_gpt2_3d, 8)
+
+
 if __name__ == "__main__":
     test_gpt2()
+    test_gpt2_3d()
```
tests/test_shardformer/test_model/test_shard_llama.py (view file @ e04436a8)

```diff
@@ -56,7 +56,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     # unwrap model
     llama_model = unwrap_model(org_model, 'LlamaModel', 'model')
     shard_llama_model = unwrap_model(sharded_model, 'LlamaModel', 'model')
     # check grad
     row_layer_for_check = ['layers[0].self_attn.q_proj', 'embed_tokens']
     col_layer_for_check = ['layers[0].self_attn.o_proj']
@@ -156,12 +155,40 @@ def run_llama_test(test_config):
     torch.cuda.empty_cache()


+@parameterize('test_config', [{
+    'tp_size': 2,
+    'pp_size': 2,
+    'num_microbatches': 4,
+    'enable_all_optimization': False,
+    'use_lazy_init': False,
+    'precision': 'fp32',
+    'initial_scale': 1,
+},])
+def run_llama_3d_test(test_config):
+    sub_model_zoo = model_zoo.get_sub_registry('transformers_llama')
+
+    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
+        check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
+
+    clear_layout_converter()
+    Randomizer.reset_index()
+    torch.cuda.empty_cache()
+
+
 def check_llama(rank, world_size, port):
     disable_existing_loggers()
     colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     run_llama_test()


+def check_llama_3d(rank, world_size, port):
+    disable_existing_loggers()
+    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    run_llama_3d_test()
+
+
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
 @clear_cache_before_run()
@@ -169,5 +196,13 @@ def test_llama():
     spawn(check_llama, 4)


+@pytest.mark.largedist
+@rerun_if_address_is_in_use()
+@clear_cache_before_run()
+def test_llama_3d():
+    spawn(check_llama_3d, 8)
+
+
 if __name__ == "__main__":
     test_llama()
+    test_llama_3d()
```
tests/test_shardformer/test_model/test_shard_opt.py (view file @ e04436a8)

```diff
@@ -146,12 +146,39 @@ def run_opt_test(test_config):
     torch.cuda.empty_cache()


+@parameterize('test_config', [{
+    'tp_size': 2,
+    'pp_size': 2,
+    'num_microbatches': 4,
+    'enable_all_optimization': False,
+    'use_lazy_init': False,
+    'precision': 'fp32',
+    'initial_scale': 1,
+},])
+def run_opt_3d_test(test_config):
+    sub_model_zoo = model_zoo.get_sub_registry('transformers_opt')
+
+    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
+        check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
+
+    clear_layout_converter()
+    torch.cuda.empty_cache()
+
+
 def check_OPTModel(rank, world_size, port):
     disable_existing_loggers()
     colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     run_opt_test()


+def check_opt_3d(rank, world_size, port):
+    disable_existing_loggers()
+    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    run_opt_3d_test()
+
+
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
 @clear_cache_before_run()
@@ -159,5 +186,13 @@ def test_OPTModel():
     spawn(check_OPTModel, 4)


+@pytest.mark.largedist
+@rerun_if_address_is_in_use()
+@clear_cache_before_run()
+def test_opt_3d():
+    spawn(check_opt_3d, 8)
+
+
 if __name__ == '__main__':
     test_OPTModel()
+    test_opt_3d()
```
tests/test_shardformer/test_model/test_shard_t5.py (view file @ e04436a8)

```diff
@@ -137,12 +137,39 @@ def run_t5_test(test_config):
     torch.cuda.empty_cache()


+@parameterize('test_config', [{
+    'tp_size': 2,
+    'pp_size': 2,
+    'num_microbatches': 4,
+    'enable_all_optimization': False,
+    'use_lazy_init': False,
+    'precision': 'fp32',
+    'initial_scale': 1,
+},])
+def run_t5_3d_test(test_config):
+    sub_model_zoo = model_zoo.get_sub_registry('transformers_t5')
+
+    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
+        check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
+
+    clear_layout_converter()
+    torch.cuda.empty_cache()
+
+
 def check_t5(rank, world_size, port):
     disable_existing_loggers()
     colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     run_t5_test()


+def check_t5_3d(rank, world_size, port):
+    disable_existing_loggers()
+    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    run_t5_3d_test()
+
+
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
 @clear_cache_before_run()
@@ -150,5 +177,13 @@ def test_t5():
     spawn(check_t5, 4)


+@pytest.mark.largedist
+@rerun_if_address_is_in_use()
+@clear_cache_before_run()
+def test_t5_3d():
+    spawn(check_t5_3d, 8)
+
+
 if __name__ == "__main__":
     test_t5()
+    test_t5_3d()
```
tests/test_shardformer/test_model/test_shard_vit.py (view file @ e04436a8)

```diff
@@ -146,12 +146,39 @@ def run_vit_test(test_config):
     torch.cuda.empty_cache()


+@parameterize('test_config', [{
+    'tp_size': 2,
+    'pp_size': 2,
+    'num_microbatches': 4,
+    'enable_all_optimization': False,
+    'use_lazy_init': False,
+    'precision': 'fp32',
+    'initial_scale': 1,
+},])
+def run_vit_3d_test(test_config):
+    sub_model_zoo = model_zoo.get_sub_registry('transformers_vit')
+
+    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
+        check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
+
+    clear_layout_converter()
+    torch.cuda.empty_cache()
+
+
 def check_vit(rank, world_size, port):
     disable_existing_loggers()
     colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     run_vit_test()


+def check_vit_3d(rank, world_size, port):
+    disable_existing_loggers()
+    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    run_vit_3d_test()
+
+
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
 @clear_cache_before_run()
@@ -159,5 +186,13 @@ def test_vit():
     spawn(check_vit, 4)


+@pytest.mark.largedist
+@rerun_if_address_is_in_use()
+@clear_cache_before_run()
+def test_vit_3d():
+    spawn(check_vit_3d, 8)
+
+
 if __name__ == "__main__":
     test_vit()
+    test_vit_3d()
```
tests/test_shardformer/test_model/test_shard_whisper.py (view file @ e04436a8)

```diff
@@ -82,8 +82,8 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
         atol, rtol = 5e-3, 5e-3
     if stage_manager is None or stage_manager.is_first_stage():
-        check_grad(whisper, sharded_whisper, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0)
-        check_grad(whisper, sharded_whisper, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1)
+        check_grad(whisper, sharded_whisper, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1)
+        check_grad(whisper, sharded_whisper, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0)

     # check weights after optimizer.step()
     org_optimizer.step()
```
```diff
@@ -99,7 +99,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
                      tp_group,
                      atol=atol,
                      rtol=rtol,
-                     dim=0,
+                     dim=1,
                      verbose=False)
         check_weight(whisper,
                      sharded_whisper,
```
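The two hunks above only change the `dim` argument handed to `check_grad` and `check_weight`, that is, the dimension along which the sharded gradients and weights are gathered before being compared. A toy illustration of why that dimension has to match how the tensor is actually sharded (generic PyTorch, not repo code):

```python
# Gathering along the wrong dimension produces a tensor with the wrong shape/layout,
# so the comparison against the unsharded reference would fail even for correct shards.
import torch

full = torch.arange(8.0).reshape(2, 4)
shards = list(full.chunk(2, dim=1))   # pretend the tensor is sharded along dim 1 across 2 ranks
assert torch.equal(torch.cat(shards, dim=1), full)      # gather along the shard dim: exact match
assert torch.cat(shards, dim=0).shape != full.shape     # gather along the wrong dim: shape mismatch
```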
```diff
@@ -155,12 +155,39 @@ def run_whisper_test(test_config):
     torch.cuda.empty_cache()


+@parameterize('test_config', [{
+    'tp_size': 2,
+    'pp_size': 2,
+    'num_microbatches': 4,
+    'enable_all_optimization': False,
+    'use_lazy_init': False,
+    'precision': 'fp32',
+    'initial_scale': 1,
+},])
+def run_whisper_3d_test(test_config):
+    sub_model_zoo = model_zoo.get_sub_registry('transformers_whisper')
+
+    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
+        check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
+
+    clear_layout_converter()
+    torch.cuda.empty_cache()
+
+
 def check_whisper(rank, world_size, port):
     disable_existing_loggers()
     colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     run_whisper_test()


+def check_whisper_3d(rank, world_size, port):
+    disable_existing_loggers()
+    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    run_whisper_3d_test()
+
+
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
 @clear_cache_before_run()
@@ -168,5 +195,13 @@ def test_whisper():
     spawn(check_whisper, 4)


+@pytest.mark.largedist
+@rerun_if_address_is_in_use()
+@clear_cache_before_run()
+def test_whisper_3d():
+    spawn(check_whisper_3d, 8)
+
+
 if __name__ == "__main__":
     test_whisper()
+    test_whisper_3d()
```