OpenDAS / ColossalAI · Commits

Commit 54fd37f0, authored Mar 14, 2022 by ver217

polish unit test

Parent: 88804aee
Showing 6 changed files with 62 additions and 51 deletions.
tests/test_zero_data_parallel/test_init_context.py                    +14 -10
tests/test_zero_data_parallel/test_shard_model_v2.py                  +14 -14
tests/test_zero_data_parallel/test_shard_param.py                      +8  -7
tests/test_zero_data_parallel/test_sharded_optim_v2.py                +11  -6
tests/test_zero_data_parallel/test_sharded_optim_v2_with_cpu_adam.py   +7  -6
tests/test_zero_data_parallel/test_state_dict.py                       +8  -8
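Every file in this commit applies the same refactor: the shard strategy is no longer hard-coded to TensorShardStrategy inside the worker; instead, the strategy class is added to @pytest.mark.parametrize, threaded through functools.partial, and instantiated inside the spawned process. Below is a minimal, self-contained sketch of that pattern; the Dummy* classes are stand-ins for the real TensorShardStrategy and BucketTensorShardStrategy, and the fixed port stands in for colossalai.utils.free_port().

from functools import partial

import pytest
import torch.multiprocessing as mp


class DummyTensorShardStrategy:
    """Stand-in for colossalai's TensorShardStrategy."""

    def shard(self, tensor_list):
        pass


class DummyBucketTensorShardStrategy(DummyTensorShardStrategy):
    """Stand-in for BucketTensorShardStrategy."""


def run_dist(rank, world_size, port, shard_strategy):
    # The strategy *class* crosses the mp.spawn boundary and is only
    # instantiated here, in the child process: classes pickle cleanly,
    # while a pre-built instance could carry process-group state that
    # does not.
    strategy = shard_strategy()
    strategy.shard([])


@pytest.mark.parametrize("world_size", [1, 2])
@pytest.mark.parametrize("shard_strategy", [DummyTensorShardStrategy, DummyBucketTensorShardStrategy])
def test_pattern(world_size, shard_strategy):
    # Port fixed only for the sketch; the real tests call free_port().
    run_func = partial(run_dist, world_size=world_size, port=29500, shard_strategy=shard_strategy)
    mp.spawn(run_func, nprocs=world_size)

pytest then generates the cross product of the parametrized values, so each test body runs once per (world_size, shard_strategy) pair.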
tests/test_zero_data_parallel/test_init_context.py

@@ -4,21 +4,20 @@
 from functools import partial

 import colossalai
-from colossalai.utils.cuda import get_current_device
 import pytest
 import torch
 import torch.multiprocessing as mp
 from colossalai.utils import free_port
+from colossalai.utils.cuda import get_current_device
+from colossalai.utils.memory_tracer.allocator import GLOBAL_MODEL_DATA_TRACER
 from colossalai.zero.init_ctx import ZeroInitContext
-from colossalai.zero.shard_utils.tensor_shard_strategy import \
-    TensorShardStrategy
+from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
 from tests.components_to_test.registry import non_distributed_component_funcs

 from common import CONFIG
-from colossalai.utils.memory_tracer.allocator import GLOBAL_MODEL_DATA_TRACER


-def run_dist(rank, world_size, port, init_device):
+def run_dist(rank, world_size, port, init_device, shard_strategy):
     colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     for get_components_func in non_distributed_component_funcs:

@@ -26,7 +25,7 @@ def run_dist(rank, world_size, port, init_device):
         model_numel_tensor = torch.zeros(1, dtype=torch.int)
         with ZeroInitContext(convert_fp16=True,
                              target_device=init_device,
-                             shard_strategy=TensorShardStrategy(),
+                             shard_strategy=shard_strategy(),
                              shard_param=True,
                              model_numel_tensor=model_numel_tensor):
             model = model_builder(checkpoint=True)

@@ -50,11 +49,16 @@ def run_dist(rank, world_size, port, init_device):
 @pytest.mark.dist
 @pytest.mark.parametrize("world_size", [1, 4])
 @pytest.mark.parametrize("init_device", [torch.device('cpu'), torch.device(f'cuda:{get_current_device()}')])
-def test_zero_init_context(world_size, init_device):
-    run_func = partial(run_dist, world_size=world_size, port=free_port(), init_device=init_device)
+@pytest.mark.parametrize("shard_strategy", [TensorShardStrategy, BucketTensorShardStrategy])
+def test_zero_init_context(world_size, init_device, shard_strategy):
+    run_func = partial(run_dist,
+                       world_size=world_size,
+                       port=free_port(),
+                       init_device=init_device,
+                       shard_strategy=shard_strategy)
     mp.spawn(run_func, nprocs=world_size)


 if __name__ == '__main__':
-    test_zero_init_context(2, torch.device('cpu'))
-    test_zero_init_context(2, torch.device(f'cuda:{get_current_device()}'))
+    test_zero_init_context(2, torch.device('cpu'), TensorShardStrategy)
+    test_zero_init_context(2, torch.device(f'cuda:{get_current_device()}'), TensorShardStrategy)
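For context, a hedged sketch of how the parametrized strategy reaches ZeroInitContext inside run_dist, mirroring the hunk above. The keyword arguments are exactly those shown in the diff; the toy build_model stands in for the registry's model_builder, and a process group is assumed to have been initialized already via colossalai.launch, as in run_dist.

import torch
import torch.nn as nn

from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.zero.shard_utils import BucketTensorShardStrategy


def build_model():
    # Toy stand-in for the registry's model_builder(checkpoint=True).
    return nn.Linear(16, 16)


model_numel_tensor = torch.zeros(1, dtype=torch.int)
with ZeroInitContext(convert_fp16=True,
                     target_device=torch.device('cpu'),
                     shard_strategy=BucketTensorShardStrategy(),  # any parametrized strategy
                     shard_param=True,
                     model_numel_tensor=model_numel_tensor):
    # Parameters created under the context are sharded as they are built,
    # and their element count is accumulated into model_numel_tensor.
    model = build_model()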
tests/test_zero_data_parallel/test_shard_model_v2.py

@@ -3,30 +3,28 @@
 import copy
 from functools import partial

+import colossalai
 import pytest
 import torch
 import torch.multiprocessing as mp
-from torch.nn.parallel import DistributedDataParallel as DDP
-
-import colossalai
-from colossalai.zero.init_ctx import ZeroInitContext
 from colossalai.utils import free_port
-from colossalai.zero.shard_utils.tensor_shard_strategy import \
-    TensorShardStrategy
+from colossalai.zero.init_ctx import ZeroInitContext
+from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
 from colossalai.zero.sharded_model import ShardedModelV2
 from colossalai.zero.sharded_model._zero3_utils import cast_tensor_to_fp16
+from colossalai.zero.sharded_model.utils import col_model_deepcopy
 from tests.components_to_test.registry import non_distributed_component_funcs
+from torch.nn.parallel import DistributedDataParallel as DDP

 from common import CONFIG, check_grads_padding, run_fwd_bwd
-from colossalai.zero.sharded_model.utils import col_model_deepcopy


-def run_dist(rank, world_size, port, use_zero_init_ctx, enable_autocast):
+def run_dist(rank, world_size, port, use_zero_init_ctx, enable_autocast, shard_strategy):
     colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     test_models = ['repeated_computed_layers', 'resnet18', 'bert']
-    shard_strategy = TensorShardStrategy()
+    shard_strategy = shard_strategy()
     for model_name in test_models:
         get_components_func = non_distributed_component_funcs.get_callable(model_name)
         model_builder, train_dataloader, _, _, criterion = get_components_func()

@@ -66,14 +64,16 @@ def run_dist(rank, world_size, port, use_zero_init_ctx, enable_autocast):
 @pytest.mark.parametrize("world_size", [1, 2])
 @pytest.mark.parametrize("enable_autocast", [True])
 @pytest.mark.parametrize("use_zero_init_ctx", [True])
-def test_shard_model_v2(world_size, use_zero_init_ctx, enable_autocast):
+@pytest.mark.parametrize("shard_strategy", [TensorShardStrategy, BucketTensorShardStrategy])
+def test_shard_model_v2(world_size, use_zero_init_ctx, enable_autocast, shard_strategy):
     run_func = partial(run_dist,
                        world_size=world_size,
                        port=free_port(),
                        use_zero_init_ctx=use_zero_init_ctx,
-                       enable_autocast=enable_autocast)
+                       enable_autocast=enable_autocast,
+                       shard_strategy=shard_strategy)
     mp.spawn(run_func, nprocs=world_size)


 if __name__ == '__main__':
-    test_shard_model_v2(world_size=2, use_zero_init_ctx=True, enable_autocast=True)
+    test_shard_model_v2(world_size=2, use_zero_init_ctx=True, enable_autocast=True, shard_strategy=TensorShardStrategy)
tests/test_zero_data_parallel/test_shard_param.py

@@ -10,20 +10,20 @@ import torch
 import torch.multiprocessing as mp
 from colossalai.logging import disable_existing_loggers, get_dist_logger
 from colossalai.utils import free_port
-from colossalai.zero.shard_utils import TensorShardStrategy
+from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
 from colossalai.zero.sharded_param import ShardedParam, ShardedTensor
 from colossalai.zero.sharded_param.sharded_param import ShardedParamV2
-from tests.test_zero_data_parallel.common import CONFIG, allclose
 from tests.components_to_test.registry import non_distributed_component_funcs
+from tests.test_zero_data_parallel.common import CONFIG, allclose


-def _run_shard_tensor(rank, world_size, port):
+def _run_shard_tensor(rank, world_size, port, shard_strategy):
     colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     t = ShardedTensor(tensor=torch.randn(world_size * 2, 3))
     assert list(t.origin_shape) == [world_size * 2, 3]
     assert list(t.shape) == [world_size * 2, 3]
-    shard_strategy = TensorShardStrategy(process_group=None)
+    shard_strategy = shard_strategy(process_group=None)

     # test shard strategy
     shard_strategy.shard([t])

@@ -34,8 +34,9 @@ def _run_shard_tensor(rank, world_size, port):
 @pytest.mark.dist
 @pytest.mark.parametrize("world_size", [1, 2])
-def test_shard_tensor(world_size):
-    run_func = partial(_run_shard_tensor, world_size=world_size, port=free_port())
+@pytest.mark.parametrize("shard_strategy", [TensorShardStrategy, BucketTensorShardStrategy])
+def test_shard_tensor(world_size, shard_strategy):
+    run_func = partial(_run_shard_tensor, world_size=world_size, port=free_port(), shard_strategy=shard_strategy)
     mp.spawn(run_func, nprocs=world_size)

@@ -121,7 +122,7 @@ def test_init_shard_param(world_size):
 if __name__ == '__main__':
-    test_shard_tensor(2)
+    test_shard_tensor(2, TensorShardStrategy)
     test_shard_param(2)
     test_shard_param_v2(2)
     test_init_shard_param(4)
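_run_shard_tensor exercises the strategy's core contract: shard() splits each ShardedTensor's payload across ranks while origin_shape keeps the full size. A hedged sketch of the round trip, written as a function so the worker's free variables become parameters; it assumes colossalai.launch has already initialized the process group, and it assumes the strategy interface also exposes a matching gather() as the inverse of shard(), which this commit's diff does not itself show.

import torch
from colossalai.zero.sharded_param import ShardedTensor


def shard_roundtrip(world_size, shard_strategy):
    # Mirrors the body of _run_shard_tensor above.
    t = ShardedTensor(tensor=torch.randn(world_size * 2, 3))
    assert list(t.origin_shape) == [world_size * 2, 3]

    strategy = shard_strategy(process_group=None)
    strategy.shard([t])   # each rank now keeps a 1/world_size slice of the payload
    strategy.gather([t])  # assumption: restores the full tensor on every rank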
tests/test_zero_data_parallel/test_sharded_optim_v2.py

@@ -10,7 +10,7 @@ import torch
 import torch.distributed as dist
 import torch.multiprocessing as mp
 from colossalai.utils import free_port
-from colossalai.zero.shard_utils import TensorShardStrategy
+from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
 from colossalai.zero.sharded_model import ShardedModelV2
 from colossalai.zero.sharded_optim import ShardedOptimizerV2
 from tests.components_to_test.registry import non_distributed_component_funcs

@@ -38,12 +38,12 @@ def run_step(model, optimizer, data, label, criterion, enable_autocast=False):
     optimizer.step()


-def run_dist(rank, world_size, port, cpu_offload):
+def run_dist(rank, world_size, port, cpu_offload, shard_strategy):
     colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     test_models = ['repeated_computed_layers', 'resnet18', 'bert']
+    shard_strategy = shard_strategy()
     for model_name in test_models:
         get_components_func = non_distributed_component_funcs.get_callable(model_name)
-        shard_strategy = TensorShardStrategy()
         model, train_dataloader, test_dataloader, optimizer, criterion = get_components_func()
         model = model(checkpoint=True).cuda()
         zero_model = ShardedModelV2(copy.deepcopy(model),

@@ -69,10 +69,15 @@ def run_dist(rank, world_size, port, cpu_offload):
 @pytest.mark.dist
 @pytest.mark.parametrize("world_size", [1, 2])
 @pytest.mark.parametrize("cpu_offload", [True, False])
-def test_sharded_optim_v2(world_size, cpu_offload):
-    run_func = partial(run_dist, world_size=world_size, port=free_port(), cpu_offload=cpu_offload)
+@pytest.mark.parametrize("shard_strategy", [TensorShardStrategy, BucketTensorShardStrategy])
+def test_sharded_optim_v2(world_size, cpu_offload, shard_strategy):
+    run_func = partial(run_dist,
+                       world_size=world_size,
+                       port=free_port(),
+                       cpu_offload=cpu_offload,
+                       shard_strategy=shard_strategy)
     mp.spawn(run_func, nprocs=world_size)


 if __name__ == '__main__':
-    test_sharded_optim_v2(world_size=2, cpu_offload=True)
+    test_sharded_optim_v2(world_size=2, cpu_offload=True, shard_strategy=TensorShardStrategy)
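Note that this file (and the two that follow) also hoists strategy construction out of the per-model loop: one strategy instance per worker is now shared by every test model, rather than a fresh TensorShardStrategy() per iteration. A minimal sketch of the shape of that change, with run_one_model as a hypothetical helper standing in for the loop body:

from colossalai.zero.shard_utils import TensorShardStrategy


def run_dist_before(test_models):
    for model_name in test_models:
        strategy = TensorShardStrategy()      # rebuilt on every iteration
        run_one_model(model_name, strategy)   # hypothetical loop body


def run_dist_after(test_models, shard_strategy):
    strategy = shard_strategy()               # built once per worker
    for model_name in test_models:
        run_one_model(model_name, strategy)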
tests/test_zero_data_parallel/test_sharded_optim_v2_with_cpu_adam.py

@@ -11,7 +11,7 @@ import torch.distributed as dist
 import torch.multiprocessing as mp
 from colossalai.nn.optimizer import CPUAdam
 from colossalai.utils import free_port
-from colossalai.zero.shard_utils import TensorShardStrategy
+from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
 from colossalai.zero.sharded_model import ShardedModelV2
 from colossalai.zero.sharded_optim import ShardedOptimizerV2
 from tests.components_to_test.registry import non_distributed_component_funcs

@@ -47,12 +47,12 @@ def run_step_no_criterion(model, optimizer, data, label, enable_autocast=False):
     optimizer.step()


-def run_dist(rank, world_size, port):
+def run_dist(rank, world_size, port, shard_strategy):
     colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     test_models = ['repeated_computed_layers', 'resnet18', 'bert']
+    shard_strategy = shard_strategy()
     for model_name in test_models:
         get_components_func = non_distributed_component_funcs.get_callable(model_name)
-        shard_strategy = TensorShardStrategy()
         model, train_dataloader, test_dataloader, optimizer, criterion = get_components_func()
         model = model(checkpoint=True).cuda()
         zero_model = ShardedModelV2(copy.deepcopy(model), shard_strategy, offload_config={'device': 'cpu'})

@@ -79,10 +79,11 @@ def run_dist(rank, world_size, port):
 @pytest.mark.dist
 @pytest.mark.parametrize("world_size", [1, 2])
-def test_sharded_optim_v2(world_size):
-    run_func = partial(run_dist, world_size=world_size, port=free_port())
+@pytest.mark.parametrize("shard_strategy", [TensorShardStrategy, BucketTensorShardStrategy])
+def test_sharded_optim_v2(world_size, shard_strategy):
+    run_func = partial(run_dist, world_size=world_size, port=free_port(), shard_strategy=shard_strategy)
     mp.spawn(run_func, nprocs=world_size)


 if __name__ == '__main__':
-    test_sharded_optim_v2(world_size=2)
+    test_sharded_optim_v2(world_size=2, shard_strategy=TensorShardStrategy)
tests/test_zero_data_parallel/test_state_dict.py

@@ -9,22 +9,21 @@ import pytest
 import torch
 import torch.multiprocessing as mp
 from colossalai.utils import free_port
-from colossalai.zero.shard_utils.tensor_shard_strategy import \
-    TensorShardStrategy
+from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
 from colossalai.zero.sharded_model import ShardedModelV2
 from tests.components_to_test.registry import non_distributed_component_funcs

 from common import CONFIG


-def run_dist(rank, world_size, port):
+def run_dist(rank, world_size, port, shard_strategy):
     colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     test_models = ['repeated_computed_layers', 'resnet18']
+    shard_strategy = shard_strategy()
     for model_name in test_models:
         get_components_func = non_distributed_component_funcs.get_callable(model_name)
         model_builder, train_dataloader, test_dataloader, optimizer, criterion = get_components_func()
         model = model_builder()
-        shard_strategy = TensorShardStrategy()
         model = model.half().cuda()
         zero_model = ShardedModelV2(deepcopy(model), shard_strategy)
         zero_state_dict = zero_model.state_dict()

@@ -33,11 +32,12 @@ def run_dist(rank, world_size, port):
 @pytest.mark.dist
-def test_zero_state_dict():
-    world_size = 2
-    run_func = partial(run_dist, world_size=world_size, port=free_port())
+@pytest.mark.parametrize("world_size", [1, 2])
+@pytest.mark.parametrize("shard_strategy", [TensorShardStrategy, BucketTensorShardStrategy])
+def test_zero_state_dict(world_size, shard_strategy):
+    run_func = partial(run_dist, world_size=world_size, port=free_port(), shard_strategy=shard_strategy)
    mp.spawn(run_func, nprocs=world_size)


 if __name__ == '__main__':
-    test_zero_state_dict()
+    test_zero_state_dict(2, TensorShardStrategy)
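With every entry point now parametrized, the whole suite can be driven uniformly. A hedged example of invoking just these distributed tests, assuming the dist marker is registered in the project's pytest configuration and a CUDA/NCCL-capable machine, since every worker launches with the 'nccl' backend:

# Programmatic equivalent of `pytest -m dist tests/test_zero_data_parallel`.
import pytest

raise SystemExit(pytest.main(["-m", "dist", "tests/test_zero_data_parallel"]))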