OpenDAS / ColossalAI / Commits / 2fe68b35

Unverified commit 2fe68b35, authored Mar 14, 2022 by Frank Lee, committed via GitHub on Mar 14, 2022.

Merge pull request #403 from ver217/feature/shard-strategy

[zero] Add bucket tensor shard strategy

Parents: cf92a779, 63469c0f
Showing 9 changed files with 127 additions and 61 deletions.

colossalai/engine/ophooks/zero_hook.py                                  +20  -7
colossalai/zero/shard_utils/__init__.py                                  +4  -3
colossalai/zero/shard_utils/bucket_tensor_shard_strategy.py             +41  -0
tests/test_zero_data_parallel/test_init_context.py                      +14 -10
tests/test_zero_data_parallel/test_shard_model_v2.py                    +14 -14
tests/test_zero_data_parallel/test_shard_param.py                        +8  -7
tests/test_zero_data_parallel/test_sharded_optim_v2.py                  +11  -6
tests/test_zero_data_parallel/test_sharded_optim_v2_with_cpu_adam.py     +7  -6
tests/test_zero_data_parallel/test_state_dict.py                         +8  -8
colossalai/engine/ophooks/zero_hook.py

 import torch
 from colossalai.registry import OPHOOKS
-from colossalai.zero.shard_utils import BaseShardStrategy
 from colossalai.utils import get_current_device
+from colossalai.zero.shard_utils import BaseShardStrategy
 from ._base_ophook import BaseOpHook

@@ -18,23 +19,32 @@ class ZeroHook(BaseOpHook):
         self.computing_device = torch.device(f'cuda:{get_current_device()}')

     def pre_fwd_exec(self, module: torch.nn.Module, *args):
+        tensor_list = []
         for param in module.parameters():
             assert hasattr(param, 'col_attr')
-            self.shard_strategy.gather([param.col_attr.data])
+            tensor_list.append(param.col_attr.data)
+        self.shard_strategy.gather(tensor_list)
+        for param in module.parameters():
             if param.col_attr.data.device != self.computing_device:
                 param.col_attr.data.to(self.computing_device)
             param.data = param.col_attr.data.payload

     def post_fwd_exec(self, module: torch.nn.Module, *args):
+        tensor_list = []
         for param in module.parameters():
             assert hasattr(param, 'col_attr')
-            self.shard_strategy.shard([param.col_attr.data])
-            param.data = torch.empty([], dtype=param.col_attr.data.dtype, device=param.col_attr.data.payload.device)
+            tensor_list.append(param.col_attr.data)
+        self.shard_strategy.shard(tensor_list)
+        for param in module.parameters():
+            param.col_attr.remove_torch_payload()

     def pre_bwd_exec(self, module: torch.nn.Module, input, output):
+        tensor_list = []
         for param in module.parameters():
             assert hasattr(param, 'col_attr')
-            self.shard_strategy.gather([param.col_attr.data])
+            tensor_list.append(param.col_attr.data)
+        self.shard_strategy.gather(tensor_list)
+        for param in module.parameters():
             if param.col_attr.data.device != self.computing_device:
                 param.col_attr.data.to(self.computing_device)
             param.data = param.col_attr.data.payload

@@ -52,10 +62,13 @@ class ZeroHook(BaseOpHook):
             param.col_attr.bwd_count += 1

     def post_bwd_exec(self, module: torch.nn.Module, input):
+        tensor_list = []
         for param in module.parameters():
             assert hasattr(param, 'col_attr')
-            self.shard_strategy.shard([param.col_attr.data])
-            param.data = torch.empty([], dtype=param.col_attr.data.dtype, device=param.col_attr.data.payload.device)
+            tensor_list.append(param.col_attr.data)
+        self.shard_strategy.shard(tensor_list)
+        for param in module.parameters():
+            param.col_attr.remove_torch_payload()

     def pre_iter(self):
         pass
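The substance of the hook change is that gather() and shard() are now called once per module with a list of every parameter's sharded tensor, instead of once per parameter with a single-element list; only a batched call gives the new bucketing strategy something to flatten into a single collective. A minimal sketch of the before/after calling pattern (CountingStrategy is a made-up stand-in, not a ColossalAI class):

import torch


class CountingStrategy:
    """Stand-in strategy that only counts how often a collective would run."""

    def __init__(self):
        self.calls = 0

    def gather(self, tensors):
        # A bucketed implementation would issue one all_gather for the whole list.
        self.calls += 1


params = [torch.randn(4), torch.randn(8), torch.randn(2)]

old = CountingStrategy()
for p in params:
    old.gather([p])           # old hook: one collective per parameter

new = CountingStrategy()
new.gather(list(params))      # new hook: one collective per module

print(old.calls, new.calls)   # 3 vs. 1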
colossalai/zero/shard_utils/__init__.py

-from colossalai.zero.shard_utils.base_shard_strategy import BaseShardStrategy
-from colossalai.zero.shard_utils.tensor_shard_strategy import TensorShardStrategy
+from .base_shard_strategy import BaseShardStrategy
+from .bucket_tensor_shard_strategy import BucketTensorShardStrategy
+from .tensor_shard_strategy import TensorShardStrategy

-__all__ = ['BaseShardStrategy', 'TensorShardStrategy']
+__all__ = ['BaseShardStrategy', 'TensorShardStrategy', 'BucketTensorShardStrategy']
colossalai/zero/shard_utils/bucket_tensor_shard_strategy.py (new file, mode 100644)

from typing import List

import torch
import torch.distributed as dist
from colossalai.utils import get_current_device
from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
from torch._utils import _flatten_dense_tensors as flatten

from .tensor_shard_strategy import TensorShardStrategy


class BucketTensorShardStrategy(TensorShardStrategy):

    def gather(self, tensor_list: List[ShardedTensor]):
        tensor_list: List[ShardedTensor] = [t for t in tensor_list if t.is_sharded]
        if len(tensor_list) == 0:
            return
        target_device = tensor_list[0].device
        dtype = tensor_list[0].dtype
        buffer_list: List[torch.Tensor] = []
        tensor_numels = [t.payload.numel() for t in tensor_list]
        buffer_size = sum(tensor_numels)
        for i in range(self.world_size):
            if i == self.local_rank:
                buffer_list.append(flatten([t.payload for t in tensor_list]).cuda(get_current_device()))
                # Release payload here, to decrease peak memory usage
                for t in tensor_list:
                    t.reset_payload(None)
            else:
                buffer_list.append(torch.zeros(buffer_size, dtype=dtype, device=get_current_device()))
        dist.all_gather(buffer_list, buffer_list[self.local_rank], group=self.process_group)
        # Move to target device before splitting buffer
        # Ensure we utilize maximum PCIE bandwidth
        buffer_list = [buffer.to(target_device) for buffer in buffer_list]
        offset = 0
        for i, t in enumerate(tensor_list):
            gathered_payload = [buffer[offset:offset + tensor_numels[i]] for buffer in buffer_list]
            gathered_payload = torch.cat(gathered_payload)[:t.origin_numel].view(t.origin_shape)
            t.reset_payload(gathered_payload)
            t.is_sharded = False
            offset += tensor_numels[i]
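The gather path above packs this rank's shards into one flat buffer, runs a single all_gather across the process group, moves the gathered buffers back to the original device, and then slices each parameter out by offset. A standalone sketch of that flatten/split bookkeeping, run on a single process with a faked world size of 2 (no process group or ShardedTensor involved, so the values here are illustrative only):

import torch
from torch._utils import _flatten_dense_tensors as flatten

# Pretend these are the local shards of three parameters on this rank.
shards = [torch.randn(4), torch.randn(6), torch.randn(2)]
numels = [t.numel() for t in shards]

# One contiguous buffer per rank: this is what each rank contributes to the
# single dist.all_gather call in BucketTensorShardStrategy.gather.
local_buffer = flatten(shards)
assert local_buffer.numel() == sum(numels)

# After all_gather every rank holds world_size buffers; fake world_size=2
# by duplicating the local buffer.
buffer_list = [local_buffer, local_buffer.clone()]

# Slice each parameter's pieces out of every buffer by offset and concatenate
# them, which is how the full tensor is reassembled (the real code then trims
# to origin_numel and views to origin_shape).
offset = 0
for i in range(len(shards)):
    pieces = [buf[offset:offset + numels[i]] for buf in buffer_list]
    full = torch.cat(pieces)
    assert full.numel() == 2 * numels[i]
    offset += numels[i]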
tests/test_zero_data_parallel/test_init_context.py

@@ -4,21 +4,20 @@
 from functools import partial

 import colossalai
-from colossalai.utils.cuda import get_current_device
 import pytest
 import torch
 import torch.multiprocessing as mp
 from colossalai.utils import free_port
+from colossalai.utils.cuda import get_current_device
+from colossalai.utils.memory_tracer.allocator import GLOBAL_MODEL_DATA_TRACER
 from colossalai.zero.init_ctx import ZeroInitContext
-from colossalai.zero.shard_utils.tensor_shard_strategy import \
-    TensorShardStrategy
+from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
 from tests.components_to_test.registry import non_distributed_component_funcs
 from common import CONFIG
-from colossalai.utils.memory_tracer.allocator import GLOBAL_MODEL_DATA_TRACER


-def run_dist(rank, world_size, port, init_device):
+def run_dist(rank, world_size, port, init_device, shard_strategy):
     colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')

     for get_components_func in non_distributed_component_funcs:

@@ -26,7 +25,7 @@ def run_dist(rank, world_size, port, init_device):
         model_numel_tensor = torch.zeros(1, dtype=torch.int)
         with ZeroInitContext(convert_fp16=True,
                              target_device=init_device,
-                             shard_strategy=TensorShardStrategy(),
+                             shard_strategy=shard_strategy(),
                              shard_param=True,
                              model_numel_tensor=model_numel_tensor):
             model = model_builder(checkpoint=True)

@@ -50,11 +49,16 @@ def run_dist(rank, world_size, port, init_device):
 @pytest.mark.dist
 @pytest.mark.parametrize("world_size", [1, 4])
 @pytest.mark.parametrize("init_device", [torch.device('cpu'), torch.device(f'cuda:{get_current_device()}')])
-def test_zero_init_context(world_size, init_device):
-    run_func = partial(run_dist, world_size=world_size, port=free_port(), init_device=init_device)
+@pytest.mark.parametrize("shard_strategy", [TensorShardStrategy, BucketTensorShardStrategy])
+def test_zero_init_context(world_size, init_device, shard_strategy):
+    run_func = partial(run_dist,
+                       world_size=world_size,
+                       port=free_port(),
+                       init_device=init_device,
+                       shard_strategy=shard_strategy)
     mp.spawn(run_func, nprocs=world_size)


 if __name__ == '__main__':
-    test_zero_init_context(2, torch.device('cpu'))
-    test_zero_init_context(2, torch.device(f'cuda:{get_current_device()}'))
+    test_zero_init_context(2, torch.device('cpu'), TensorShardStrategy)
+    test_zero_init_context(2, torch.device(f'cuda:{get_current_device()}'), TensorShardStrategy)
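This test, and the four below it, are all updated the same way: the strategy class (not an instance) becomes a pytest parameter, and run_dist instantiates it only after colossalai.launch has brought up the process group inside each spawned worker. A stripped-down sketch of that pattern follows; the _DummyStrategy class, the port number, and the body of run_dist are placeholders, not the real test logic:

from functools import partial

import pytest
import torch.multiprocessing as mp


class _DummyStrategy:
    """Placeholder for TensorShardStrategy / BucketTensorShardStrategy."""


def run_dist(rank, world_size, port, shard_strategy):
    # The real tests call colossalai.launch(...) first, so the strategy is
    # constructed only after the NCCL process group exists on this worker.
    strategy = shard_strategy()
    assert strategy is not None


@pytest.mark.parametrize("world_size", [1])
@pytest.mark.parametrize("shard_strategy", [_DummyStrategy])
def test_pattern(world_size, shard_strategy):
    run_func = partial(run_dist, world_size=world_size, port=29500, shard_strategy=shard_strategy)
    mp.spawn(run_func, nprocs=world_size)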
tests/test_zero_data_parallel/test_shard_model_v2.py

@@ -3,30 +3,28 @@
 import copy
 from functools import partial

-import pytest
+import colossalai
+import pytest
 import torch
 import torch.multiprocessing as mp
-from torch.nn.parallel import DistributedDataParallel as DDP
-import colossalai
-from colossalai.zero.init_ctx import ZeroInitContext
 from colossalai.utils import free_port
-from colossalai.zero.shard_utils.tensor_shard_strategy import \
-    TensorShardStrategy
+from colossalai.zero.init_ctx import ZeroInitContext
+from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
 from colossalai.zero.sharded_model import ShardedModelV2
 from colossalai.zero.sharded_model._zero3_utils import cast_tensor_to_fp16
+from colossalai.zero.sharded_model.utils import col_model_deepcopy
 from tests.components_to_test.registry import non_distributed_component_funcs
+from torch.nn.parallel import DistributedDataParallel as DDP

 from common import CONFIG, check_grads_padding, run_fwd_bwd
-from colossalai.zero.sharded_model.utils import col_model_deepcopy


-def run_dist(rank, world_size, port, use_zero_init_ctx, enable_autocast):
+def run_dist(rank, world_size, port, use_zero_init_ctx, enable_autocast, shard_strategy):
     colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')

     test_models = ['repeated_computed_layers', 'resnet18', 'bert']
-    shard_strategy = TensorShardStrategy()
+    shard_strategy = shard_strategy()
     for model_name in test_models:
         get_components_func = non_distributed_component_funcs.get_callable(model_name)
         model_builder, train_dataloader, _, _, criterion = get_components_func()

@@ -66,14 +64,16 @@ def run_dist(rank, world_size, port, use_zero_init_ctx, enable_autocast):
 @pytest.mark.parametrize("world_size", [1, 2])
 @pytest.mark.parametrize("enable_autocast", [True])
 @pytest.mark.parametrize("use_zero_init_ctx", [True])
-def test_shard_model_v2(world_size, use_zero_init_ctx, enable_autocast):
+@pytest.mark.parametrize("shard_strategy", [TensorShardStrategy, BucketTensorShardStrategy])
+def test_shard_model_v2(world_size, use_zero_init_ctx, enable_autocast, shard_strategy):
     run_func = partial(run_dist,
                        world_size=world_size,
                        port=free_port(),
                        use_zero_init_ctx=use_zero_init_ctx,
-                       enable_autocast=enable_autocast)
+                       enable_autocast=enable_autocast,
+                       shard_strategy=shard_strategy)
     mp.spawn(run_func, nprocs=world_size)


 if __name__ == '__main__':
-    test_shard_model_v2(world_size=2, use_zero_init_ctx=True, enable_autocast=True)
+    test_shard_model_v2(world_size=2, use_zero_init_ctx=True, enable_autocast=True, shard_strategy=TensorShardStrategy)
tests/test_zero_data_parallel/test_shard_param.py

@@ -10,20 +10,20 @@ import torch
 import torch.multiprocessing as mp
 from colossalai.logging import disable_existing_loggers, get_dist_logger
 from colossalai.utils import free_port
-from colossalai.zero.shard_utils import TensorShardStrategy
+from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
 from colossalai.zero.sharded_param import ShardedParam, ShardedTensor
 from colossalai.zero.sharded_param.sharded_param import ShardedParamV2
-from tests.test_zero_data_parallel.common import CONFIG, allclose
 from tests.components_to_test.registry import non_distributed_component_funcs
+from tests.test_zero_data_parallel.common import CONFIG, allclose


-def _run_shard_tensor(rank, world_size, port):
+def _run_shard_tensor(rank, world_size, port, shard_strategy):
     colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     t = ShardedTensor(tensor=torch.randn(world_size * 2, 3))
     assert list(t.origin_shape) == [world_size * 2, 3]
     assert list(t.shape) == [world_size * 2, 3]
-    shard_strategy = TensorShardStrategy(process_group=None)
+    shard_strategy = shard_strategy(process_group=None)

     # test shard strategy
     shard_strategy.shard([t])

@@ -34,8 +34,9 @@ def _run_shard_tensor(rank, world_size, port):
 @pytest.mark.dist
 @pytest.mark.parametrize("world_size", [1, 2])
-def test_shard_tensor(world_size):
-    run_func = partial(_run_shard_tensor, world_size=world_size, port=free_port())
+@pytest.mark.parametrize("shard_strategy", [TensorShardStrategy, BucketTensorShardStrategy])
+def test_shard_tensor(world_size, shard_strategy):
+    run_func = partial(_run_shard_tensor, world_size=world_size, port=free_port(), shard_strategy=shard_strategy)
     mp.spawn(run_func, nprocs=world_size)

@@ -121,7 +122,7 @@ def test_init_shard_param(world_size):
 if __name__ == '__main__':
-    test_shard_tensor(2)
+    test_shard_tensor(2, TensorShardStrategy)
     test_shard_param(2)
     test_shard_param_v2(2)
     test_init_shard_param(4)
tests/test_zero_data_parallel/test_sharded_optim_v2.py

@@ -10,7 +10,7 @@ import torch
 import torch.distributed as dist
 import torch.multiprocessing as mp
 from colossalai.utils import free_port
-from colossalai.zero.shard_utils import TensorShardStrategy
+from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
 from colossalai.zero.sharded_model import ShardedModelV2
 from colossalai.zero.sharded_optim import ShardedOptimizerV2
 from tests.components_to_test.registry import non_distributed_component_funcs

@@ -38,12 +38,12 @@ def run_step(model, optimizer, data, label, criterion, enable_autocast=False):
     optimizer.step()


-def run_dist(rank, world_size, port, cpu_offload):
+def run_dist(rank, world_size, port, cpu_offload, shard_strategy):
     colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     test_models = ['repeated_computed_layers', 'resnet18', 'bert']
+    shard_strategy = shard_strategy()
     for model_name in test_models:
         get_components_func = non_distributed_component_funcs.get_callable(model_name)
-        shard_strategy = TensorShardStrategy()
         model, train_dataloader, test_dataloader, optimizer, criterion = get_components_func()
         model = model(checkpoint=True).cuda()
         zero_model = ShardedModelV2(copy.deepcopy(model),

@@ -69,10 +69,15 @@ def run_dist(rank, world_size, port, cpu_offload):
 @pytest.mark.dist
 @pytest.mark.parametrize("world_size", [1, 2])
 @pytest.mark.parametrize("cpu_offload", [True, False])
-def test_sharded_optim_v2(world_size, cpu_offload):
-    run_func = partial(run_dist, world_size=world_size, port=free_port(), cpu_offload=cpu_offload)
+@pytest.mark.parametrize("shard_strategy", [TensorShardStrategy, BucketTensorShardStrategy])
+def test_sharded_optim_v2(world_size, cpu_offload, shard_strategy):
+    run_func = partial(run_dist,
+                       world_size=world_size,
+                       port=free_port(),
+                       cpu_offload=cpu_offload,
+                       shard_strategy=shard_strategy)
     mp.spawn(run_func, nprocs=world_size)


 if __name__ == '__main__':
-    test_sharded_optim_v2(world_size=2, cpu_offload=True)
+    test_sharded_optim_v2(world_size=2, cpu_offload=True, shard_strategy=TensorShardStrategy)
tests/test_zero_data_parallel/test_sharded_optim_v2_with_cpu_adam.py

@@ -11,7 +11,7 @@ import torch.distributed as dist
 import torch.multiprocessing as mp
 from colossalai.nn.optimizer import CPUAdam
 from colossalai.utils import free_port
-from colossalai.zero.shard_utils import TensorShardStrategy
+from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
 from colossalai.zero.sharded_model import ShardedModelV2
 from colossalai.zero.sharded_optim import ShardedOptimizerV2
 from tests.components_to_test.registry import non_distributed_component_funcs

@@ -47,12 +47,12 @@ def run_step_no_criterion(model, optimizer, data, label, enable_autocast=False):
     optimizer.step()


-def run_dist(rank, world_size, port):
+def run_dist(rank, world_size, port, shard_strategy):
     colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     test_models = ['repeated_computed_layers', 'resnet18', 'bert']
+    shard_strategy = shard_strategy()
     for model_name in test_models:
         get_components_func = non_distributed_component_funcs.get_callable(model_name)
-        shard_strategy = TensorShardStrategy()
         model, train_dataloader, test_dataloader, optimizer, criterion = get_components_func()
         model = model(checkpoint=True).cuda()
         zero_model = ShardedModelV2(copy.deepcopy(model), shard_strategy, offload_config={'device': 'cpu'})

@@ -79,10 +79,11 @@ def run_dist(rank, world_size, port):
 @pytest.mark.dist
 @pytest.mark.parametrize("world_size", [1, 2])
-def test_sharded_optim_v2(world_size):
-    run_func = partial(run_dist, world_size=world_size, port=free_port())
+@pytest.mark.parametrize("shard_strategy", [TensorShardStrategy, BucketTensorShardStrategy])
+def test_sharded_optim_v2(world_size, shard_strategy):
+    run_func = partial(run_dist, world_size=world_size, port=free_port(), shard_strategy=shard_strategy)
     mp.spawn(run_func, nprocs=world_size)


 if __name__ == '__main__':
-    test_sharded_optim_v2(world_size=2)
+    test_sharded_optim_v2(world_size=2, shard_strategy=TensorShardStrategy)
tests/test_zero_data_parallel/test_state_dict.py

@@ -9,22 +9,21 @@ import pytest
 import torch
 import torch.multiprocessing as mp
 from colossalai.utils import free_port
-from colossalai.zero.shard_utils.tensor_shard_strategy import \
-    TensorShardStrategy
+from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
 from colossalai.zero.sharded_model import ShardedModelV2
 from tests.components_to_test.registry import non_distributed_component_funcs
 from common import CONFIG


-def run_dist(rank, world_size, port):
+def run_dist(rank, world_size, port, shard_strategy):
     colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     test_models = ['repeated_computed_layers', 'resnet18']
+    shard_strategy = shard_strategy()
     for model_name in test_models:
         get_components_func = non_distributed_component_funcs.get_callable(model_name)
         model_builder, train_dataloader, test_dataloader, optimizer, criterion = get_components_func()
         model = model_builder()
-        shard_strategy = TensorShardStrategy()
         model = model.half().cuda()
         zero_model = ShardedModelV2(deepcopy(model), shard_strategy)
         zero_state_dict = zero_model.state_dict()

@@ -33,11 +32,12 @@ def run_dist(rank, world_size, port):
 @pytest.mark.dist
-def test_zero_state_dict():
-    world_size = 2
-    run_func = partial(run_dist, world_size=world_size, port=free_port())
+@pytest.mark.parametrize("world_size", [1, 2])
+@pytest.mark.parametrize("shard_strategy", [TensorShardStrategy, BucketTensorShardStrategy])
+def test_zero_state_dict(world_size, shard_strategy):
+    run_func = partial(run_dist, world_size=world_size, port=free_port(), shard_strategy=shard_strategy)
     mp.spawn(run_func, nprocs=world_size)


 if __name__ == '__main__':
-    test_zero_state_dict()
+    test_zero_state_dict(2, TensorShardStrategy)
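Taken together, the changes make the shard strategy a pluggable choice: anywhere the old code constructed TensorShardStrategy, BucketTensorShardStrategy can now be constructed instead and handed to the ZeRO components. A hedged end-to-end sketch of that wiring, modeled on the test code above rather than on documented API (the offload_config dict and argument order are copied from the tests; the model and the surrounding training setup are placeholders, and a distributed environment as set up by colossalai.launch is assumed):

import copy

import torch.nn as nn
from colossalai.zero.shard_utils import BucketTensorShardStrategy
from colossalai.zero.sharded_model import ShardedModelV2

# Build the strategy once; the tests instantiate it only after colossalai.launch
# has initialized the process group, so the same ordering is assumed here.
shard_strategy = BucketTensorShardStrategy()

model = nn.Linear(128, 128).cuda()   # placeholder model

# Wrap the model the way tests/test_zero_data_parallel does, optionally
# offloading to CPU; BucketTensorShardStrategy is a drop-in replacement for
# TensorShardStrategy in this position.
zero_model = ShardedModelV2(copy.deepcopy(model), shard_strategy, offload_config={'device': 'cpu'})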