OpenDAS / ColossalAI: commit cb34cd38

[test] polish zero related unitest (#351)

Authored Mar 10, 2022 by Jiarui Fang; committed by Frank Lee on Mar 11, 2022
Parent: 534e0bb1
5 changed files with 75 additions and 123 deletions (+75 -123)
Changed files:

- colossalai/zero/sharded_model/utils.py (new file): +19 -0
- tests/test_zero_data_parallel/common.py: +17 -0
- tests/test_zero_data_parallel/test_shard_model_v2.py: +38 -49
- tests/test_zero_data_parallel/test_sharded_model_with_ctx.py (deleted): +0 -73
- tests/test_zero_data_parallel/test_sharded_optim_v2_with_cpu_adam.py: +1 -1
colossalai/zero/sharded_model/utils.py (new file, +19 -0):

```python
import torch
from colossalai.zero.sharded_model import ShardedModelV2

import copy


def col_model_deepcopy(sharded_model: ShardedModelV2, other_model: torch.nn.Module):
    """
    Copy the parameters of a ShardedModelV2 into other_model.
    Note that other_model must have the same structure as sharded_model.
    """
    for zero_param, param in zip(sharded_model.parameters(), other_model.parameters()):
        assert hasattr(zero_param, 'col_attr')
        shard_flag = zero_param.col_attr.data.is_sharded
        if shard_flag:
            # Temporarily gather the full parameter payload before copying.
            sharded_model.shard_strategy.gather([zero_param.col_attr.data])
        param.data = copy.deepcopy(zero_param.col_attr.data.payload)
        if shard_flag:
            # Re-shard so the sharded model's memory footprint is unchanged.
            sharded_model.shard_strategy.shard([zero_param.col_attr.data])
```
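For reference, the rewritten test below uses this helper to build an unsharded fp16 twin of a model created under ZeroInitContext. A compressed sketch of that call pattern (names are taken from the test itself; it is not runnable on its own, since it assumes colossalai.launch() has initialized the environment and that model_builder is a callable returning a torch.nn.Module):

```python
# Sketch only: assumes an initialized ColossalAI environment and a model_builder callable.
with ZeroInitContext(convert_fp16=True, convert_cuda=True,
                     shard_strategy=shard_strategy, shard_param=True):
    zero_model = model_builder(checkpoint=True)      # params are created already sharded
zero_model = ShardedModelV2(zero_model, shard_strategy)

model = model_builder(checkpoint=True).half()        # plain fp16 twin
col_model_deepcopy(zero_model, model)                # gather -> copy payloads -> re-shard
model = model.cuda()
```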
tests/test_zero_data_parallel/common.py (+17 -0):

@@ -3,8 +3,10 @@ from functools import partial

```python
import torch
import torch.distributed as dist
import torch.nn as nn

from colossalai.logging import get_dist_logger
from colossalai.utils import checkpoint
from colossalai.zero.sharded_model import ShardedModelV2

LOGGER = get_dist_logger()
```

@@ -20,6 +22,21 @@ CONFIG = dict(fp16=dict(mode=None,),

```python
              parallel=dict(pipeline=dict(size=1), tensor=dict(size=1, mode=None)))


def run_fwd_bwd(model, data, label, criterion, enable_autocast=False):
    model.train()
    with torch.cuda.amp.autocast(enabled=enable_autocast):
        if criterion:
            y = model(data)
            loss = criterion(y, label)
        else:
            loss = model(data, label)
        loss = loss.float()
    if isinstance(model, ShardedModelV2):
        model.backward(loss)
    else:
        loss.backward()


def checkpoint_wrapper(module, enable=True):
    if enable:
        module.forward = partial(checkpoint, module.forward)
```
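As a quick illustration of the shared helper's dispatch (this is not part of the commit; it is a minimal sketch that assumes it is run from tests/test_zero_data_parallel so that common is importable, and it exercises only the plain, non-sharded branch, which should also work without a GPU since autocast stays disabled):

```python
import torch
import torch.nn as nn

from common import run_fwd_bwd  # the helper added above

# A toy model and batch; run_fwd_bwd only wraps forward/backward,
# so the non-sharded path needs no ColossalAI launch.
model = nn.Linear(4, 2)
data = torch.randn(8, 4)
label = torch.randn(8, 2)

run_fwd_bwd(model, data, label, criterion=nn.MSELoss())

# The plain branch ends in loss.backward(), so gradients are populated.
print(model.weight.grad is not None)  # True
```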
tests/test_zero_data_parallel/test_shard_model_v2.py (+38 -49):

@@ -3,81 +3,70 @@

GitLab's inline view interleaves the removed and added lines here; reconstructed, the test after this commit reads roughly as follows:

```python
import copy
from functools import partial

import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP

import colossalai
from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.utils import free_port
from colossalai.zero.shard_utils.tensor_shard_strategy import \
    TensorShardStrategy
from colossalai.zero.sharded_model import ShardedModelV2
from colossalai.zero.sharded_model._zero3_utils import cast_tensor_to_fp16
from tests.components_to_test.registry import non_distributed_component_funcs
from common import CONFIG, check_grads_padding, run_fwd_bwd
from colossalai.zero.sharded_model.utils import col_model_deepcopy


def run_dist(rank, world_size, port, use_zero_init_ctx, enable_autocast):
    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')

    test_models = ['repeated_computed_layers', 'resnet18', 'bert']
    shard_strategy = TensorShardStrategy()
    for model_name in test_models:
        get_components_func = non_distributed_component_funcs.get_callable(model_name)
        model_builder, train_dataloader, _, _, criterion = get_components_func()

        if use_zero_init_ctx:
            with ZeroInitContext(convert_fp16=True, convert_cuda=True,
                                 shard_strategy=shard_strategy, shard_param=True):
                zero_model = model_builder(checkpoint=True)
            zero_model = ShardedModelV2(zero_model, shard_strategy)

            model = model_builder(checkpoint=True).half()
            col_model_deepcopy(zero_model, model)
            model = model.cuda()
        else:
            model = model_builder(checkpoint=True).half().cuda()
            zero_model = ShardedModelV2(copy.deepcopy(model), shard_strategy)

        model = DDP(model)

        for i, (data, label) in enumerate(train_dataloader):
            if i > 3:
                break

            data, label = cast_tensor_to_fp16(data).cuda(), label.cuda()
            run_fwd_bwd(model, data, label, criterion, enable_autocast)
            run_fwd_bwd(zero_model, data, label, criterion, enable_autocast)

            check_grads_padding(model, zero_model, loose=True)


@pytest.mark.dist
@pytest.mark.parametrize("world_size", [1, 2])
@pytest.mark.parametrize("enable_autocast", [True])
@pytest.mark.parametrize("use_zero_init_ctx", [True])
def test_shard_model_v2(world_size, use_zero_init_ctx, enable_autocast):
    run_func = partial(run_dist,
                       world_size=world_size,
                       port=free_port(),
                       use_zero_init_ctx=use_zero_init_ctx,
                       enable_autocast=enable_autocast)
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_shard_model_v2(world_size=2, use_zero_init_ctx=True, enable_autocast=True)
```

Removed in the same hunk: the file-local run_fwd_bwd and run_fwd_bwd_no_criterion helpers shown below, the old run_dist(rank, world_size, port) signature, the old per-model setup (model(checkpoint=True).half().cuda() with zero_model = ShardedModelV2(copy.deepcopy(model), shard_strategy) and DDP applied only when dist.get_world_size() > 1), the `if i > 2: break` loop limit, the per-batch branch that dispatched to run_fwd_bwd_no_criterion when criterion was None, the old @pytest.mark.parametrize("world_size", [1, 2, 4]) entry point test_shard_model_v2(world_size), and the old test_shard_model_v2(world_size=2) call under __main__.

```python
# Removed helpers (their logic now lives in common.run_fwd_bwd):
def run_fwd_bwd(model, data, label, criterion, enable_autocast=False):
    model.train()
    with torch.cuda.amp.autocast(enabled=enable_autocast):
        y = model(data)
        loss = criterion(y, label)
        loss = loss.float()
    if isinstance(model, ShardedModelV2):
        model.backward(loss)
    else:
        loss.backward()


# with no criterion
def run_fwd_bwd_no_criterion(model, data, label, enable_autocast=False):
    model.train()
    with torch.cuda.amp.autocast(enabled=enable_autocast):
        loss = model(data, label)
    if isinstance(model, ShardedModelV2):
        model.backward(loss)
    else:
        loss.backward()
```
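The test entry point binds its keyword arguments with functools.partial and lets torch.multiprocessing.spawn supply the rank as the first positional argument. A stripped-down, runnable sketch of that launch pattern (illustrative only; no ColossalAI or NCCL involved, and the fixed port stands in for free_port()):

```python
from functools import partial

import torch.multiprocessing as mp


def run_dist(rank, world_size, port, use_zero_init_ctx, enable_autocast):
    # mp.spawn passes the process index as the first positional argument;
    # everything else arrives through the partial's bound keywords.
    print(f"rank={rank}/{world_size} port={port} "
          f"zero_ctx={use_zero_init_ctx} autocast={enable_autocast}")


if __name__ == '__main__':
    world_size = 2
    run_func = partial(run_dist,
                       world_size=world_size,
                       port=29500,  # the real test calls free_port() here
                       use_zero_init_ctx=True,
                       enable_autocast=True)
    mp.spawn(run_func, nprocs=world_size)
```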
tests/test_zero_data_parallel/test_sharded_model_with_ctx.py (deleted, +0 -73; last at parent 534e0bb1):

```python
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import copy
from functools import partial

import colossalai
import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from colossalai.utils import free_port
from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.zero.shard_utils.tensor_shard_strategy import \
    TensorShardStrategy
from colossalai.zero.sharded_model import ShardedModelV2
from tests.components_to_test.registry import non_distributed_component_funcs
from torch.nn.parallel import DistributedDataParallel as DDP

from common import CONFIG, check_grads, check_grads_padding


def run_fwd_bwd(model, data, label, criterion, enable_autocast=False):
    model.train()
    with torch.cuda.amp.autocast(enabled=enable_autocast):
        y = model(data)
        loss = criterion(y, label)
        loss = loss.float()
    if isinstance(model, ShardedModelV2):
        model.backward(loss)
    else:
        loss.backward()


def run_dist(rank, world_size, port):
    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')

    test_models = ['repeated_computed_layers', 'resnet18']
    for model_name in test_models:
        get_components_func = non_distributed_component_funcs.get_callable(model_name)
        shard_strategy = TensorShardStrategy()
        with ZeroInitContext(convert_fp16=True, convert_cuda=True,
                             shard_strategy=shard_strategy, shard_param=True):
            zero_model, train_dataloader, test_dataloader, optimizer, criterion = get_components_func()
            zero_model = zero_model()

        model = copy.deepcopy(zero_model)
        zero_model = ShardedModelV2(zero_model, shard_strategy)
        model_state_dict = zero_model.state_dict()
        for n, p in model.named_parameters():
            p.data = model_state_dict[n]
        model = model.half().cuda()
        if dist.get_world_size() > 1:
            model = DDP(model)

        for i, (data, label) in enumerate(train_dataloader):
            if i > 2:
                break
            data, label = data.half().cuda(), label.cuda()
            run_fwd_bwd(model, data, label, criterion, False)
            run_fwd_bwd(zero_model, data, label, criterion, False)
            if dist.get_world_size() > 1:
                check_grads_padding(model, zero_model, loose=True)
            else:
                check_grads(model, zero_model, loose=True)


@pytest.mark.dist
def test_shard_model_v2():
    world_size = 2
    run_func = partial(run_dist, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_shard_model_v2()
```
tests/test_zero_data_parallel/test_sharded_optim_v2_with_cpu_adam.py (+1 -1):

```diff
@@ -78,7 +78,7 @@ def run_dist(rank, world_size, port):
 @pytest.mark.dist
-@pytest.mark.parametrize("world_size", [1, 2, 4])
+@pytest.mark.parametrize("world_size", [1, 2])
 def test_sharded_optim_v2(world_size):
     run_func = partial(run_dist, world_size=world_size, port=free_port())
     mp.spawn(run_func, nprocs=world_size)
```