OpenDAS / ColossalAI · Commits

Commit cb34cd38
[test] polish zero related unitest (#351)
Authored Mar 10, 2022 by Jiarui Fang; committed by Frank Lee on Mar 11, 2022
Parent: 534e0bb1

Showing 5 changed files with 75 additions and 123 deletions:

  colossalai/zero/sharded_model/utils.py                                  +19   -0
  tests/test_zero_data_parallel/common.py                                 +17   -0
  tests/test_zero_data_parallel/test_shard_model_v2.py                    +38  -49
  tests/test_zero_data_parallel/test_sharded_model_with_ctx.py             +0  -73
  tests/test_zero_data_parallel/test_sharded_optim_v2_with_cpu_adam.py     +1   -1
colossalai/zero/sharded_model/utils.py (new file, 0 → 100644)
import torch
from colossalai.zero.sharded_model import ShardedModelV2
import copy


def col_model_deepcopy(sharded_model: ShardedModelV2, other_model: torch.nn.Module):
    """
    copy param of the ShardedModelV2 to other_model.
    Note the other_model has to be the same as self.
    """
    for zero_param, param in zip(sharded_model.parameters(), other_model.parameters()):
        assert hasattr(zero_param, 'col_attr')
        shard_flag = zero_param.col_attr.data.is_sharded
        if shard_flag:
            sharded_model.shard_strategy.gather([zero_param.col_attr.data])
        param.data = copy.deepcopy(zero_param.col_attr.data.payload)
        if shard_flag:
            sharded_model.shard_strategy.shard([zero_param.col_attr.data])
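For orientation, a minimal sketch of how this helper is meant to be used, mirroring the pattern in the updated test further below. It assumes colossalai.launch has already initialised the process group, and MyNet is a hypothetical toy module, not part of this commit: build the sharded model under ZeroInitContext, then copy its parameters into a plain fp16 replica of the same architecture.

# Sketch only (assumes an already-launched distributed context; MyNet is illustrative).
import torch.nn as nn
from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.zero.shard_utils.tensor_shard_strategy import TensorShardStrategy
from colossalai.zero.sharded_model import ShardedModelV2
from colossalai.zero.sharded_model.utils import col_model_deepcopy


class MyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 4)

    def forward(self, x):
        return self.fc(x)


shard_strategy = TensorShardStrategy()
with ZeroInitContext(convert_fp16=True, convert_cuda=True,
                     shard_strategy=shard_strategy, shard_param=True):
    zero_model = MyNet()                     # parameters are sharded as they are created
zero_model = ShardedModelV2(zero_model, shard_strategy)

reference = MyNet().half()                   # same architecture, ordinary nn.Module
col_model_deepcopy(zero_model, reference)    # gather -> deepcopy payload -> re-shard
reference = reference.cuda()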
tests/test_zero_data_parallel/common.py
@@ -3,8 +3,10 @@ from functools import partial
 import torch
 import torch.distributed as dist
 import torch.nn as nn
 from colossalai.logging import get_dist_logger
 from colossalai.utils import checkpoint
+from colossalai.zero.sharded_model import ShardedModelV2

 LOGGER = get_dist_logger()
@@ -20,6 +22,21 @@ CONFIG = dict(fp16=dict(mode=None,),
               parallel=dict(pipeline=dict(size=1), tensor=dict(size=1, mode=None)))


+def run_fwd_bwd(model, data, label, criterion, enable_autocast=False):
+    model.train()
+    with torch.cuda.amp.autocast(enabled=enable_autocast):
+        if criterion:
+            y = model(data)
+            loss = criterion(y, label)
+        else:
+            loss = model(data, label)
+        loss = loss.float()
+    if isinstance(model, ShardedModelV2):
+        model.backward(loss)
+    else:
+        loss.backward()
+
+
 def checkpoint_wrapper(module, enable=True):
     if enable:
         module.forward = partial(checkpoint, module.forward)
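As an aside, the checkpoint_wrapper helper kept in this file simply rebinds a module's forward so every call runs under colossalai.utils.checkpoint (activation checkpointing), exactly as in the last context line above. A rough sketch of the same idea; SubNet is a hypothetical block, not part of this commit:

# Sketch of the idea behind checkpoint_wrapper (SubNet is an illustrative example only).
from functools import partial

import torch
import torch.nn as nn
from colossalai.utils import checkpoint


class SubNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(8, 8)

    def forward(self, x):
        return torch.relu(self.fc(x))


block = SubNet()
# After rebinding, block(x) recomputes activations during backward instead of storing them.
block.forward = partial(checkpoint, block.forward)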
tests/test_zero_data_parallel/test_shard_model_v2.py
@@ -3,81 +3,70 @@
 import copy
 from functools import partial

+import colossalai
 import pytest
 import torch
 import torch.distributed as dist
 import torch.multiprocessing as mp
+from torch.nn.parallel import DistributedDataParallel as DDP

-import colossalai
+from colossalai.zero.init_ctx import ZeroInitContext
 from colossalai.utils import free_port
 from colossalai.zero.shard_utils.tensor_shard_strategy import \
     TensorShardStrategy
 from colossalai.zero.sharded_model import ShardedModelV2
 from colossalai.zero.sharded_model._zero3_utils import cast_tensor_to_fp16
 from tests.components_to_test.registry import non_distributed_component_funcs
-from torch.nn.parallel import DistributedDataParallel as DDP
-from common import CONFIG, check_grads_padding
-
-
-def run_fwd_bwd(model, data, label, criterion, enable_autocast=False):
-    model.train()
-    with torch.cuda.amp.autocast(enabled=enable_autocast):
-        y = model(data)
-        loss = criterion(y, label)
-        loss = loss.float()
-    if isinstance(model, ShardedModelV2):
-        model.backward(loss)
-    else:
-        loss.backward()
-
-
-# with no criterion
-def run_fwd_bwd_no_criterion(model, data, label, enable_autocast=False):
-    model.train()
-    with torch.cuda.amp.autocast(enabled=enable_autocast):
-        loss = model(data, label)
-    if isinstance(model, ShardedModelV2):
-        model.backward(loss)
-    else:
-        loss.backward()
+
+from common import CONFIG, check_grads_padding, run_fwd_bwd
+from colossalai.zero.sharded_model.utils import col_model_deepcopy


-def run_dist(rank, world_size, port):
+def run_dist(rank, world_size, port, use_zero_init_ctx, enable_autocast):
     colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')

     test_models = ['repeated_computed_layers', 'resnet18', 'bert']
     shard_strategy = TensorShardStrategy()
     for model_name in test_models:
         get_components_func = non_distributed_component_funcs.get_callable(model_name)
-        model, train_dataloader, test_dataloader, optimizer, criterion = get_components_func()
-        model = model(checkpoint=True).half().cuda()
-        zero_model = ShardedModelV2(copy.deepcopy(model), shard_strategy)
+        model_builder, train_dataloader, _, _, criterion = get_components_func()
+
+        if use_zero_init_ctx:
+            with ZeroInitContext(convert_fp16=True, convert_cuda=True, shard_strategy=shard_strategy, shard_param=True):
+                zero_model = model_builder(checkpoint=True)
+            zero_model = ShardedModelV2(zero_model, shard_strategy)
+
+            model = model_builder(checkpoint=True).half()
+            col_model_deepcopy(zero_model, model)
+            model = model.cuda()
+        else:
+            model = model_builder(checkpoint=True).half().cuda()
+            zero_model = ShardedModelV2(copy.deepcopy(model), shard_strategy)

         if dist.get_world_size() > 1:
             model = DDP(model)

         for i, (data, label) in enumerate(train_dataloader):
-            if i > 2:
+            if i > 3:
                 break

-            if criterion is None:
-                data, label = data.cuda(), label.cuda()
-                run_fwd_bwd_no_criterion(model, data, label, False)
-                run_fwd_bwd_no_criterion(zero_model, data, label, False)
-            else:
-                data, label = cast_tensor_to_fp16(data).cuda(), label.cuda()
-                run_fwd_bwd(model, data, label, criterion, False)
-                run_fwd_bwd(zero_model, data, label, criterion, False)
+            data, label = cast_tensor_to_fp16(data).cuda(), label.cuda()
+            run_fwd_bwd(model, data, label, criterion, enable_autocast)
+            run_fwd_bwd(zero_model, data, label, criterion, enable_autocast)

             check_grads_padding(model, zero_model, loose=True)


 @pytest.mark.dist
-@pytest.mark.parametrize("world_size", [1, 2, 4])
-def test_shard_model_v2(world_size):
-    run_func = partial(run_dist, world_size=world_size, port=free_port())
+@pytest.mark.parametrize("world_size", [1, 2])
+@pytest.mark.parametrize("enable_autocast", [True])
+@pytest.mark.parametrize("use_zero_init_ctx", [True])
+def test_shard_model_v2(world_size, use_zero_init_ctx, enable_autocast):
+    run_func = partial(run_dist,
+                       world_size=world_size,
+                       port=free_port(),
+                       use_zero_init_ctx=use_zero_init_ctx,
+                       enable_autocast=enable_autocast)
     mp.spawn(run_func, nprocs=world_size)


 if __name__ == '__main__':
-    test_shard_model_v2(world_size=2)
+    test_shard_model_v2(world_size=2, use_zero_init_ctx=True, enable_autocast=True)
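For readers unfamiliar with the launch pattern in this test, a stripped-down sketch of how functools.partial plus torch.multiprocessing.spawn fans a worker out across ranks. The worker below is hypothetical and only prints its rank; no ColossalAI, process group, or GPU is involved:

# Minimal illustration of the spawn pattern used by these tests (hypothetical worker).
from functools import partial

import torch.multiprocessing as mp


def run_dist(rank, world_size, port):
    # In the real tests this is where colossalai.launch(...) and the model checks run.
    print(f'rank {rank}/{world_size} would bind to port {port}')


def main(world_size=2):
    # mp.spawn passes the process index as the first positional argument,
    # so every other argument is pre-bound with functools.partial.
    run_func = partial(run_dist, world_size=world_size, port=29500)
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    main()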
tests/test_zero_data_parallel/test_sharded_model_with_ctx.py (deleted, 100644 → 0)
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import copy
from functools import partial

import colossalai
import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from colossalai.utils import free_port
from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.zero.shard_utils.tensor_shard_strategy import \
    TensorShardStrategy
from colossalai.zero.sharded_model import ShardedModelV2
from tests.components_to_test.registry import non_distributed_component_funcs
from torch.nn.parallel import DistributedDataParallel as DDP

from common import CONFIG, check_grads, check_grads_padding


def run_fwd_bwd(model, data, label, criterion, enable_autocast=False):
    model.train()
    with torch.cuda.amp.autocast(enabled=enable_autocast):
        y = model(data)
        loss = criterion(y, label)
        loss = loss.float()
    if isinstance(model, ShardedModelV2):
        model.backward(loss)
    else:
        loss.backward()


def run_dist(rank, world_size, port):
    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')

    test_models = ['repeated_computed_layers', 'resnet18']
    for model_name in test_models:
        get_components_func = non_distributed_component_funcs.get_callable(model_name)
        shard_strategy = TensorShardStrategy()
        with ZeroInitContext(convert_fp16=True, convert_cuda=True, shard_strategy=shard_strategy, shard_param=True):
            zero_model, train_dataloader, test_dataloader, optimizer, criterion = get_components_func()
            zero_model = zero_model()

        model = copy.deepcopy(zero_model)
        zero_model = ShardedModelV2(zero_model, shard_strategy)
        model_state_dict = zero_model.state_dict()
        for n, p in model.named_parameters():
            p.data = model_state_dict[n]
        model = model.half().cuda()
        if dist.get_world_size() > 1:
            model = DDP(model)

        for i, (data, label) in enumerate(train_dataloader):
            if i > 2:
                break
            data, label = data.half().cuda(), label.cuda()
            run_fwd_bwd(model, data, label, criterion, False)
            run_fwd_bwd(zero_model, data, label, criterion, False)
            if dist.get_world_size() > 1:
                check_grads_padding(model, zero_model, loose=True)
            else:
                check_grads(model, zero_model, loose=True)


@pytest.mark.dist
def test_shard_model_v2():
    world_size = 2
    run_func = partial(run_dist, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_shard_model_v2()
tests/test_zero_data_parallel/test_sharded_optim_v2_with_cpu_adam.py
@@ -78,7 +78,7 @@ def run_dist(rank, world_size, port):
 @pytest.mark.dist
-@pytest.mark.parametrize("world_size", [1, 2, 4])
+@pytest.mark.parametrize("world_size", [1, 2])
 def test_sharded_optim_v2(world_size):
     run_func = partial(run_dist, world_size=world_size, port=free_port())
     mp.spawn(run_func, nprocs=world_size)