Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
f8a0e7fb
Unverified
Commit
f8a0e7fb
authored
Mar 14, 2022
by
Frank Lee
Committed by
GitHub
Mar 14, 2022
Browse files
Merge pull request #412 from hpcaitech/develop
merge develop to main
parents
fc5101f2
21dc54e0
Changes
37
Hide whitespace changes
Inline
Side-by-side
Showing
17 changed files
with
189 additions
and
101 deletions
+189
-101
colossalai/zero/shard_utils/bucket_tensor_shard_strategy.py
colossalai/zero/shard_utils/bucket_tensor_shard_strategy.py
+41
-0
colossalai/zero/sharded_model/sharded_model_v2.py
colossalai/zero/sharded_model/sharded_model_v2.py
+25
-4
colossalai/zero/sharded_optim/sharded_optim_v2.py
colossalai/zero/sharded_optim/sharded_optim_v2.py
+32
-7
examples
examples
+1
-1
tests/components_to_test/bert.py
tests/components_to_test/bert.py
+1
-4
tests/components_to_test/nested_model.py
tests/components_to_test/nested_model.py
+1
-4
tests/components_to_test/repeated_computed_layer.py
tests/components_to_test/repeated_computed_layer.py
+1
-4
tests/components_to_test/resnet.py
tests/components_to_test/resnet.py
+1
-4
tests/test_engine/test_engine.py
tests/test_engine/test_engine.py
+3
-3
tests/test_trainer/test_trainer_with_non_pipe_schedule.py
tests/test_trainer/test_trainer_with_non_pipe_schedule.py
+2
-2
tests/test_utils/test_activation_checkpointing.py
tests/test_utils/test_activation_checkpointing.py
+1
-0
tests/test_zero_data_parallel/test_init_context.py
tests/test_zero_data_parallel/test_init_context.py
+16
-15
tests/test_zero_data_parallel/test_shard_model_v2.py
tests/test_zero_data_parallel/test_shard_model_v2.py
+19
-15
tests/test_zero_data_parallel/test_shard_param.py
tests/test_zero_data_parallel/test_shard_param.py
+8
-7
tests/test_zero_data_parallel/test_sharded_optim_v2.py
tests/test_zero_data_parallel/test_sharded_optim_v2.py
+18
-14
tests/test_zero_data_parallel/test_sharded_optim_v2_with_cpu_adam.py
...zero_data_parallel/test_sharded_optim_v2_with_cpu_adam.py
+11
-9
tests/test_zero_data_parallel/test_state_dict.py
tests/test_zero_data_parallel/test_state_dict.py
+8
-8
No files found.
colossalai/zero/shard_utils/bucket_tensor_shard_strategy.py
0 → 100644
View file @
f8a0e7fb
from
typing
import
List
import
torch
import
torch.distributed
as
dist
from
colossalai.utils
import
get_current_device
from
colossalai.zero.sharded_param.sharded_tensor
import
ShardedTensor
from
torch._utils
import
_flatten_dense_tensors
as
flatten
from
.tensor_shard_strategy
import
TensorShardStrategy
class BucketTensorShardStrategy(TensorShardStrategy):
    """Tensor shard strategy whose ``gather`` batches all shards into one buffer.

    Rather than issuing a collective per tensor, every sharded tensor's local
    payload is flattened into a single contiguous buffer and exchanged with a
    single ``all_gather`` call; the gathered buffers are then sliced back into
    per-tensor payloads.
    """

    def gather(self, tensor_list: List[ShardedTensor]):
        """Gather the full payload of every sharded tensor in ``tensor_list``.

        Tensors whose ``is_sharded`` flag is false are ignored. Each remaining
        tensor has its payload replaced by the gathered data, viewed as
        ``origin_shape``, and its ``is_sharded`` flag cleared. The dtype and
        the final device are taken from the first sharded tensor and applied
        to the whole bucket.

        :param tensor_list: tensors to gather; modified in place.
        :type tensor_list: List[ShardedTensor]
        """
        sharded: List[ShardedTensor] = [tensor for tensor in tensor_list if tensor.is_sharded]
        if not sharded:
            # Nothing is sharded, so there is nothing to communicate.
            return

        first = sharded[0]
        target_device = first.device
        dtype = first.dtype

        # Per-tensor shard sizes; their sum is the length of one rank's bucket.
        shard_sizes = [tensor.payload.numel() for tensor in sharded]
        buffer_size = sum(shard_sizes)

        buffer_list: List[torch.Tensor] = []
        for rank in range(self.world_size):
            if rank != self.local_rank:
                # Receive slot for another rank's bucket.
                buffer_list.append(torch.zeros(buffer_size, dtype=dtype, device=get_current_device()))
                continue
            # This rank's slot holds its own shards, flattened back to back.
            buffer_list.append(flatten([tensor.payload for tensor in sharded]).cuda(get_current_device()))
            # Drop the per-tensor payloads right away to lower peak memory usage.
            for tensor in sharded:
                tensor.reset_payload(None)

        dist.all_gather(buffer_list, buffer_list[self.local_rank], group=self.process_group)

        # Transfer whole buckets to the target device before slicing them up,
        # so the copy happens in large chunks (maximum PCIE bandwidth).
        buffer_list = [buffer.to(target_device) for buffer in buffer_list]

        offset = 0
        for tensor, numel in zip(sharded, shard_sizes):
            end = offset + numel
            # The same slice of every rank's bucket belongs to this tensor.
            pieces = [buffer[offset:end] for buffer in buffer_list]
            # Keep only the first origin_numel elements (the remainder is
            # presumably shard padding) and restore the original shape.
            full_payload = torch.cat(pieces)[:tensor.origin_numel].view(tensor.origin_shape)
            tensor.reset_payload(full_payload)
            tensor.is_sharded = False
            offset = end
colossalai/zero/sharded_model/sharded_model_v2.py
View file @
f8a0e7fb
...
...
@@ -17,7 +17,8 @@ from colossalai.zero.sharded_model.reduce_scatter import ReduceScatterBucketer
from
colossalai.zero.sharded_param
import
ShardedParamV2
from
torch.distributed
import
ProcessGroup
from
torch.nn.parameter
import
Parameter
from
colossalai.utils.memory_tracer.memstats_collector
import
MemStatsCollector
from
colossalai.utils.memory_tracer.allocator
import
col_move_to_cpu
from
._zero3_utils
import
(
cast_float_arguments
,
cast_tensor_to_fp16
,
cast_tensor_to_fp32
,
chunk_and_pad
,
get_gradient_predivide_factor
)
...
...
@@ -33,7 +34,8 @@ class ShardedModelV2(nn.Module):
fp32_reduce_scatter
:
bool
=
False
,
offload_config
:
Optional
[
dict
]
=
None
,
gradient_predivide_factor
:
Optional
[
float
]
=
1.0
,
shard_param
:
bool
=
True
):
shard_param
:
bool
=
True
,
use_memory_tracer
:
bool
=
False
):
r
"""
A demo to reconfigure zero1 shared_model.
Currently do not consider the Optimizer States.
...
...
@@ -59,8 +61,16 @@ class ShardedModelV2(nn.Module):
if
self
.
shard_param
:
self
.
shard_strategy
.
shard
([
param
.
col_attr
.
data
])
# Init Memory Statistics Collector
self
.
_use_memory_tracer
=
use_memory_tracer
if
self
.
_use_memory_tracer
:
self
.
_memstats_collector
=
MemStatsCollector
()
else
:
self
.
_memstats_collector
=
None
self
.
_iter_cnter
=
0
# Register hooks
register_ophooks_recursively
(
self
.
module
,
[
ZeroHook
(
self
.
shard_strategy
)])
register_ophooks_recursively
(
self
.
module
,
[
ZeroHook
(
self
.
shard_strategy
,
self
.
_memstats_collector
)])
self
.
param_hook_mgr
=
BaseParamHookMgr
(
list
(
self
.
module
.
parameters
()))
self
.
param_hook_mgr
.
register_backward_hooks
(
self
.
_grad_post_backward_hook
)
...
...
@@ -84,6 +94,9 @@ class ShardedModelV2(nn.Module):
return
self
.
_cpu_offload
def
forward
(
self
,
*
args
:
Any
,
**
kwargs
:
Any
)
->
torch
.
Tensor
:
if
self
.
_iter_cnter
==
0
and
self
.
_memstats_collector
:
# the operation will affect the flag in ZeroHook
self
.
_memstats_collector
.
start_collection
()
args
,
kwargs
=
cast_float_arguments
(
cast_tensor_to_fp16
,
*
args
,
**
kwargs
)
outputs
=
self
.
module
(
*
args
,
**
kwargs
)
return
outputs
...
...
@@ -98,6 +111,12 @@ class ShardedModelV2(nn.Module):
@
torch
.
no_grad
()
def
_final_backward_hook
(
self
)
->
None
:
if
self
.
_iter_cnter
==
0
and
self
.
_memstats_collector
:
self
.
_memstats_collector
.
finish_collection
()
if
self
.
_memstats_collector
:
self
.
_memstats_collector
.
reset_sampling_cnter
()
self
.
_iter_cnter
+=
1
if
self
.
_require_backward_grad_sync
:
# Flush any unreduced buckets in the post_backward stream.
with
torch
.
cuda
.
stream
(
self
.
comm_stream
):
...
...
@@ -185,8 +204,10 @@ class ShardedModelV2(nn.Module):
reduced_grad
.
data
=
cast_tensor_to_fp32
(
reduced_grad
.
data
)
# Maybe offload
# TODO() optimize GPU->CPU bandwidth utilization
if
self
.
_cpu_offload
:
reduced_grad
.
data
=
reduced_grad
.
data
.
cpu
()
col_move_to_cpu
(
reduced_grad
)
# reduced_grad.data = reduced_grad.data.cpu()
if
param
.
col_attr
.
grad
is
None
:
param
.
col_attr
.
grad
=
reduced_grad
.
data
...
...
colossalai/zero/sharded_optim/sharded_optim_v2.py
View file @
f8a0e7fb
from
enum
import
Enum
from
typing
import
Dict
,
Optional
from
typing
import
Callable
,
Dict
,
Optional
,
Union
import
torch
import
torch.distributed
as
dist
...
...
@@ -15,7 +15,7 @@ from torch import Tensor
from
torch.distributed
import
ProcessGroup
from
torch.nn.parameter
import
Parameter
from
torch.optim
import
Optimizer
from
typing
import
Type
,
Any
from
._utils
import
has_inf_or_nan
...
...
@@ -27,8 +27,8 @@ class OptimState(Enum):
class
ShardedOptimizerV2
(
ColossalaiOptimizer
):
def
__init__
(
self
,
optimizer
:
Optimizer
,
sharded_model
:
ShardedModelV2
,
optimizer_class
:
Type
[
Optimizer
],
shard_strategy
:
BaseShardStrategy
,
cpu_offload
:
bool
=
False
,
initial_scale
:
float
=
2
**
32
,
...
...
@@ -39,9 +39,34 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
hysteresis
:
float
=
2
,
max_scale
:
int
=
2
**
32
,
dp_process_group
:
Optional
[
ProcessGroup
]
=
None
,
mp_process_group
:
Optional
[
ProcessGroup
]
=
None
)
->
None
:
mp_process_group
:
Optional
[
ProcessGroup
]
=
None
,
**
defaults
:
Any
)
->
None
:
"""
:param sharded_model: A sharded model initialized by class ShardedModelV2
:type sharded_model: sharded_model
:param optimizer_class: A type of Optimizer
:type optimizer_class: Type[Optimizer]
:param shard_strategy: The strategy to shard the sharded_model and optimizer model parameters.
:type shard_strategy: BaseShardStrategy
:param cpu_offload: is offloading the optimizer states to CPU.
:type cpu_offload: bool
:param shard_strategy: The strategy to shard the sharded_model and optimizer model parameters.
:type shard_strategy: BaseShardStrategy
:param defaults: any trailing keyword arguments, which are forwarded to the local optimizer.
:type defaults: dict()
"""
assert
isinstance
(
sharded_model
,
ShardedModelV2
),
'model must be wrapped with ShardedModel'
super
().
__init__
(
optimizer
)
self
.
_optim_defaults
=
defaults
# initialize the M, V as zeros tensors and initialize param fp32 from sharded_model.parameters()
self
.
optimizer
=
optimizer_class
(
sharded_model
.
parameters
(),
**
self
.
_optim_defaults
)
super
().
__init__
(
self
.
optimizer
)
self
.
shard_strategy
=
shard_strategy
self
.
model
:
ShardedModelV2
=
sharded_model
if
cpu_offload
and
not
sharded_model
.
cpu_offload
:
...
...
@@ -65,7 +90,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
# Store fp32 param shards
self
.
master_params
:
Dict
[
Parameter
,
Tensor
]
=
{}
for
group
in
optimizer
.
param_groups
:
for
group
in
self
.
optimizer
.
param_groups
:
for
p
in
group
[
'params'
]:
assert
hasattr
(
p
,
'col_attr'
),
'The parameter must be wrapped with ShardedParam'
is_param_sharded
=
p
.
col_attr
.
data
.
is_sharded
...
...
@@ -118,7 +143,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
# We have to use `copy_payload` instead of `reset_payload`
# Since p.data is fp32 and p.col_attr.data is fp16
# TODO() optimize this line
# TODO() optimize this line
CPU (fp32) -> GPU (fp16)
p
.
col_attr
.
data
.
copy_payload
(
p
.
data
)
if
not
is_param_sharded
:
...
...
examples
@
5345187a
Compare
d50ef2db
...
5345187a
Subproject commit
d50ef2db51e7d02ed3f7e9de13f9af86b04eaae9
Subproject commit
5345187ad55e8c80c111e0c5f7ad9b9241e8f913
tests/components_to_test/bert.py
View file @
f8a0e7fb
...
...
@@ -74,8 +74,5 @@ def get_training_components():
sequence_length
=
sequence_length
,
is_distrbuted
=
True
)
def
get_optim
(
model
):
return
torch
.
optim
.
Adam
(
model
.
parameters
(),
lr
=
0.001
)
criterion
=
None
return
bert_model_builder
,
trainloader
,
testloader
,
get_opti
m
,
criterion
return
bert_model_builder
,
trainloader
,
testloader
,
torch
.
optim
.
Ada
m
,
criterion
tests/components_to_test/nested_model.py
View file @
f8a0e7fb
...
...
@@ -49,8 +49,5 @@ def get_training_components():
trainloader
=
DummyDataLoader
()
testloader
=
DummyDataLoader
()
def
optim_builder
(
model
):
return
torch
.
optim
.
Adam
(
model
.
parameters
(),
lr
=
0.001
)
criterion
=
torch
.
nn
.
CrossEntropyLoss
()
return
model_builder
,
trainloader
,
testloader
,
optim_builder
,
criterion
return
model_builder
,
trainloader
,
testloader
,
torch
.
optim
.
Adam
,
criterion
tests/components_to_test/repeated_computed_layer.py
View file @
f8a0e7fb
...
...
@@ -43,8 +43,5 @@ def get_training_components():
trainloader
=
DummyDataLoader
()
testloader
=
DummyDataLoader
()
def
optim_builder
(
model
):
return
torch
.
optim
.
Adam
(
model
.
parameters
(),
lr
=
0.001
)
criterion
=
torch
.
nn
.
CrossEntropyLoss
()
return
model_builder
,
trainloader
,
testloader
,
optim_builder
,
criterion
return
model_builder
,
trainloader
,
testloader
,
torch
.
optim
.
Adam
,
criterion
tests/components_to_test/resnet.py
View file @
f8a0e7fb
...
...
@@ -29,8 +29,5 @@ def get_resnet_training_components():
trainloader
=
get_cifar10_dataloader
(
train
=
True
)
testloader
=
get_cifar10_dataloader
(
train
=
False
)
def
optim_builder
(
model
):
return
torch
.
optim
.
Adam
(
model
.
parameters
(),
lr
=
0.001
)
criterion
=
torch
.
nn
.
CrossEntropyLoss
()
return
model_builder
,
trainloader
,
testloader
,
optim_builder
,
criterion
return
model_builder
,
trainloader
,
testloader
,
torch
.
optim
.
Adam
,
criterion
tests/test_engine/test_engine.py
View file @
f8a0e7fb
...
...
@@ -19,11 +19,11 @@ def run_train():
# FIXME: test bert
for
model_name
in
test_models
:
get_components_func
=
non_distributed_component_funcs
.
get_callable
(
model_name
)
model_builder
,
train_dataloader
,
_
,
optimizer_
builder
,
criterion
=
get_components_func
()
model_builder
,
train_dataloader
,
_
,
optimizer_
class
,
criterion
=
get_components_func
()
model
=
model_builder
(
checkpoint
=
False
)
engine
,
train_dataloader
,
*
args
=
colossalai
.
initialize
(
model
=
model
,
optimizer
=
optimizer_
builder
(
model
),
optimizer
=
optimizer_
class
(
model
.
parameters
(),
lr
=
1e-3
),
criterion
=
criterion
,
train_dataloader
=
train_dataloader
)
...
...
@@ -84,7 +84,7 @@ def run_engine(rank, world_size, port):
@
pytest
.
mark
.
dist
def
test_engine
():
world_size
=
4
world_size
=
2
run_func
=
partial
(
run_engine
,
world_size
=
world_size
,
port
=
free_port
())
mp
.
spawn
(
run_func
,
nprocs
=
world_size
)
...
...
tests/test_trainer/test_trainer_with_non_pipe_schedule.py
View file @
f8a0e7fb
...
...
@@ -25,9 +25,9 @@ def run_trainer_no_pipeline(rank, world_size, port):
test_models
=
[
'repeated_computed_layers'
,
'resnet18'
,
'nested_model'
]
for
name
in
test_models
:
get_components_func
=
non_distributed_component_funcs
.
get_callable
(
name
)
model_builder
,
train_dataloader
,
test_dataloader
,
optimizer_
builder
,
criterion
=
get_components_func
()
model_builder
,
train_dataloader
,
test_dataloader
,
optimizer_
class
,
criterion
=
get_components_func
()
model
=
model_builder
()
optimizer
=
optimizer_
builder
(
model
)
optimizer
=
optimizer_
class
(
model
.
parameters
(),
lr
=
1e-3
)
engine
,
train_dataloader
,
*
_
=
colossalai
.
initialize
(
model
=
model
,
optimizer
=
optimizer
,
criterion
=
criterion
,
...
...
tests/test_utils/test_activation_checkpointing.py
View file @
f8a0e7fb
...
...
@@ -56,6 +56,7 @@ def test_activation_checkpointing(cpu_offload):
assert
torch
.
all
(
data
.
grad
==
data_
.
grad
),
'Gradient of the input does not match'
torch
.
cuda
.
empty_cache
()
# as seed manager is singleton
# if we don't reset seeds here,
# other tests will fail if running together with this test
...
...
tests/test_zero_data_parallel/test_init_context.py
View file @
f8a0e7fb
...
...
@@ -4,21 +4,20 @@
from
functools
import
partial
import
colossalai
from
colossalai.utils.cuda
import
get_current_device
import
pytest
import
torch
import
torch.multiprocessing
as
mp
from
colossalai.utils
import
free_port
from
colossalai.utils.cuda
import
get_current_device
from
colossalai.zero.init_ctx
import
ZeroInitContext
from
colossalai.zero.shard_utils.tensor_shard_strategy
import
\
TensorShardStrategy
from
colossalai.zero.shard_utils
import
(
BucketTensorShardStrategy
,
TensorShardStrategy
)
from
tests.components_to_test.registry
import
non_distributed_component_funcs
from
common
import
CONFIG
from
colossalai.utils.memory_tracer.
allocator
import
GLOBAL_MODEL_DATA_TRACER
from
colossalai.utils.memory_tracer.
model_data_memtracer
import
ModelDataTracer
def
run_dist
(
rank
,
world_size
,
port
,
init_device
):
def
run_dist
(
rank
,
world_size
,
port
,
init_device
,
shard_strategy
):
colossalai
.
launch
(
config
=
CONFIG
,
rank
=
rank
,
world_size
=
world_size
,
host
=
'localhost'
,
port
=
port
,
backend
=
'nccl'
)
for
get_components_func
in
non_distributed_component_funcs
:
...
...
@@ -26,7 +25,7 @@ def run_dist(rank, world_size, port, init_device):
model_numel_tensor
=
torch
.
zeros
(
1
,
dtype
=
torch
.
int
)
with
ZeroInitContext
(
convert_fp16
=
True
,
target_device
=
init_device
,
shard_strategy
=
TensorS
hard
S
trategy
(),
shard_strategy
=
s
hard
_s
trategy
(),
shard_param
=
True
,
model_numel_tensor
=
model_numel_tensor
):
model
=
model_builder
(
checkpoint
=
True
)
...
...
@@ -38,23 +37,25 @@ def run_dist(rank, world_size, port, init_device):
assert
param
.
col_attr
.
data
.
payload
.
device
.
type
==
init_device
.
type
,
\
f
'
{
param
.
col_attr
.
data
.
payload
.
device
.
type
}
vs.
{
init_device
.
type
}
'
print
(
f
'cpu usgae
{
GLOBAL_MODEL_DATA_TRACER
.
cpu_usage
}
'
)
print
(
f
'cuda usgae
{
GLOBAL_MODEL_DATA_TRACER
.
cuda_usage
}
'
)
print
(
f
'cuda usgae
{
ModelDataTracer
().
cuda_usage
}
'
)
print
(
f
'numel
{
model_numel_tensor
}
'
)
if
init_device
.
type
==
'cuda'
:
assert
(
GLOBAL_MODEL_DATA_TRACER
.
cuda_usage
>
0
)
elif
init_device
.
type
==
'cpu'
:
assert
(
GLOBAL_MODEL_DATA_TRACER
.
cpu_usage
>
0
)
assert
(
ModelDataTracer
().
cuda_usage
>
0
)
@
pytest
.
mark
.
dist
@
pytest
.
mark
.
parametrize
(
"world_size"
,
[
1
,
4
])
@
pytest
.
mark
.
parametrize
(
"init_device"
,
[
torch
.
device
(
'cpu'
),
torch
.
device
(
f
'cuda:
{
get_current_device
()
}
'
)])
def
test_zero_init_context
(
world_size
,
init_device
):
run_func
=
partial
(
run_dist
,
world_size
=
world_size
,
port
=
free_port
(),
init_device
=
init_device
)
@
pytest
.
mark
.
parametrize
(
"shard_strategy"
,
[
TensorShardStrategy
,
BucketTensorShardStrategy
])
def
test_zero_init_context
(
world_size
,
init_device
,
shard_strategy
):
run_func
=
partial
(
run_dist
,
world_size
=
world_size
,
port
=
free_port
(),
init_device
=
init_device
,
shard_strategy
=
shard_strategy
)
mp
.
spawn
(
run_func
,
nprocs
=
world_size
)
if
__name__
==
'__main__'
:
test_zero_init_context
(
2
,
torch
.
device
(
'cpu'
))
test_zero_init_context
(
2
,
torch
.
device
(
f
'c
uda:
{
get_current_device
()
}
'
)
)
#
test_zero_init_context(2, torch.device('cpu')
, TensorShardStrategy
)
test_zero_init_context
(
4
,
torch
.
device
(
'c
pu'
),
BucketTensorShardStrategy
)
tests/test_zero_data_parallel/test_shard_model_v2.py
View file @
f8a0e7fb
...
...
@@ -3,30 +3,29 @@
import
copy
from
functools
import
partial
import
pytest
import
colossalai
import
pytest
import
torch
import
torch.multiprocessing
as
mp
from
torch.nn.parallel
import
DistributedDataParallel
as
DDP
import
colossalai
from
colossalai.zero.init_ctx
import
ZeroInitContext
from
colossalai.utils
import
free_port
from
colossalai.zero.
shard_utils.tensor_shard_strategy
import
\
TensorShardStrategy
from
colossalai.zero.
init_ctx
import
ZeroInitContext
from
colossalai.zero.shard_utils
import
(
BucketTensorShardStrategy
,
TensorShardStrategy
)
from
colossalai.zero.sharded_model
import
ShardedModelV2
from
colossalai.zero.sharded_model._zero3_utils
import
cast_tensor_to_fp16
from
colossalai.zero.sharded_model.utils
import
col_model_deepcopy
from
tests.components_to_test.registry
import
non_distributed_component_funcs
from
torch.nn.parallel
import
DistributedDataParallel
as
DDP
from
common
import
CONFIG
,
check_grads_padding
,
run_fwd_bwd
from
colossalai.zero.sharded_model.utils
import
col_model_deepcopy
def
run_dist
(
rank
,
world_size
,
port
,
use_zero_init_ctx
,
enable_autocast
):
def
run_dist
(
rank
,
world_size
,
port
,
use_zero_init_ctx
,
enable_autocast
,
shard_strategy
):
colossalai
.
launch
(
config
=
CONFIG
,
rank
=
rank
,
world_size
=
world_size
,
host
=
'localhost'
,
port
=
port
,
backend
=
'nccl'
)
test_models
=
[
'repeated_computed_layers'
,
'resnet18'
,
'bert'
]
shard_strategy
=
TensorS
hard
S
trategy
()
shard_strategy
=
s
hard
_s
trategy
()
for
model_name
in
test_models
:
get_components_func
=
non_distributed_component_funcs
.
get_callable
(
model_name
)
model_builder
,
train_dataloader
,
_
,
_
,
criterion
=
get_components_func
()
...
...
@@ -35,12 +34,12 @@ def run_dist(rank, world_size, port, use_zero_init_ctx, enable_autocast):
if
use_zero_init_ctx
:
with
ZeroInitContext
(
convert_fp16
=
True
,
target_device
=
torch
.
device
(
'cpu'
),
target_device
=
torch
.
device
(
f
'cpu
:0
'
),
shard_strategy
=
shard_strategy
,
shard_param
=
True
,
rm_torch_payload_on_the_fly
=
rm_torch_payload_on_the_fly
):
zero_model
=
model_builder
(
checkpoint
=
True
)
zero_model
=
ShardedModelV2
(
zero_model
,
shard_strategy
)
zero_model
=
ShardedModelV2
(
zero_model
,
shard_strategy
,
use_memory_tracer
=
True
)
model
=
model_builder
(
checkpoint
=
True
).
half
()
col_model_deepcopy
(
zero_model
,
model
)
...
...
@@ -61,19 +60,24 @@ def run_dist(rank, world_size, port, use_zero_init_ctx, enable_autocast):
check_grads_padding
(
model
,
zero_model
,
loose
=
True
)
print
(
'overall cuda '
,
zero_model
.
_memstats_collector
.
_overall_cuda
)
print
(
'model cuda '
,
zero_model
.
_memstats_collector
.
_model_data_cuda
)
@
pytest
.
mark
.
dist
@
pytest
.
mark
.
parametrize
(
"world_size"
,
[
1
,
2
])
@
pytest
.
mark
.
parametrize
(
"enable_autocast"
,
[
True
])
@
pytest
.
mark
.
parametrize
(
"use_zero_init_ctx"
,
[
True
])
def
test_shard_model_v2
(
world_size
,
use_zero_init_ctx
,
enable_autocast
):
@
pytest
.
mark
.
parametrize
(
"shard_strategy"
,
[
TensorShardStrategy
,
BucketTensorShardStrategy
])
def
test_shard_model_v2
(
world_size
,
use_zero_init_ctx
,
enable_autocast
,
shard_strategy
):
run_func
=
partial
(
run_dist
,
world_size
=
world_size
,
port
=
free_port
(),
use_zero_init_ctx
=
use_zero_init_ctx
,
enable_autocast
=
enable_autocast
)
enable_autocast
=
enable_autocast
,
shard_strategy
=
shard_strategy
)
mp
.
spawn
(
run_func
,
nprocs
=
world_size
)
if
__name__
==
'__main__'
:
test_shard_model_v2
(
world_size
=
2
,
use_zero_init_ctx
=
True
,
enable_autocast
=
True
)
test_shard_model_v2
(
world_size
=
2
,
use_zero_init_ctx
=
True
,
enable_autocast
=
True
,
shard_strategy
=
TensorShardStrategy
)
tests/test_zero_data_parallel/test_shard_param.py
View file @
f8a0e7fb
...
...
@@ -10,20 +10,20 @@ import torch
import
torch.multiprocessing
as
mp
from
colossalai.logging
import
disable_existing_loggers
,
get_dist_logger
from
colossalai.utils
import
free_port
from
colossalai.zero.shard_utils
import
TensorShardStrategy
from
colossalai.zero.shard_utils
import
(
Bucket
TensorShardStrategy
,
TensorShardStrategy
)
from
colossalai.zero.sharded_param
import
ShardedParam
,
ShardedTensor
from
colossalai.zero.sharded_param.sharded_param
import
ShardedParamV2
from
tests.test_zero_data_parallel.common
import
CONFIG
,
allclose
from
tests.components_to_test.registry
import
non_distributed_component_funcs
from
tests.test_zero_data_parallel.common
import
CONFIG
,
allclose
def
_run_shard_tensor
(
rank
,
world_size
,
port
):
def
_run_shard_tensor
(
rank
,
world_size
,
port
,
shard_strategy
):
colossalai
.
launch
(
config
=
CONFIG
,
rank
=
rank
,
world_size
=
world_size
,
host
=
'localhost'
,
port
=
port
,
backend
=
'nccl'
)
t
=
ShardedTensor
(
tensor
=
torch
.
randn
(
world_size
*
2
,
3
))
assert
list
(
t
.
origin_shape
)
==
[
world_size
*
2
,
3
]
assert
list
(
t
.
shape
)
==
[
world_size
*
2
,
3
]
shard_strategy
=
TensorS
hard
S
trategy
(
process_group
=
None
)
shard_strategy
=
s
hard
_s
trategy
(
process_group
=
None
)
# test shard strategy
shard_strategy
.
shard
([
t
])
...
...
@@ -34,8 +34,9 @@ def _run_shard_tensor(rank, world_size, port):
@
pytest
.
mark
.
dist
@
pytest
.
mark
.
parametrize
(
"world_size"
,
[
1
,
2
])
def
test_shard_tensor
(
world_size
):
run_func
=
partial
(
_run_shard_tensor
,
world_size
=
world_size
,
port
=
free_port
())
@
pytest
.
mark
.
parametrize
(
"shard_strategy"
,
[
TensorShardStrategy
,
BucketTensorShardStrategy
])
def
test_shard_tensor
(
world_size
,
shard_strategy
):
run_func
=
partial
(
_run_shard_tensor
,
world_size
=
world_size
,
port
=
free_port
(),
shard_strategy
=
shard_strategy
)
mp
.
spawn
(
run_func
,
nprocs
=
world_size
)
...
...
@@ -121,7 +122,7 @@ def test_init_shard_param(world_size):
if
__name__
==
'__main__'
:
test_shard_tensor
(
2
)
test_shard_tensor
(
2
,
TensorShardStrategy
)
test_shard_param
(
2
)
test_shard_param_v2
(
2
)
test_init_shard_param
(
4
)
tests/test_zero_data_parallel/test_sharded_optim_v2.py
View file @
f8a0e7fb
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import
copy
from
functools
import
partial
...
...
@@ -10,7 +7,7 @@ import torch
import
torch.distributed
as
dist
import
torch.multiprocessing
as
mp
from
colossalai.utils
import
free_port
from
colossalai.zero.shard_utils
import
TensorShardStrategy
from
colossalai.zero.shard_utils
import
(
Bucket
TensorShardStrategy
,
TensorShardStrategy
)
from
colossalai.zero.sharded_model
import
ShardedModelV2
from
colossalai.zero.sharded_optim
import
ShardedOptimizerV2
from
tests.components_to_test.registry
import
non_distributed_component_funcs
...
...
@@ -38,25 +35,27 @@ def run_step(model, optimizer, data, label, criterion, enable_autocast=False):
optimizer
.
step
()
def
run_dist
(
rank
,
world_size
,
port
,
cpu_offload
):
def
run_dist
(
rank
,
world_size
,
port
,
cpu_offload
,
shard_strategy
):
colossalai
.
launch
(
config
=
CONFIG
,
rank
=
rank
,
world_size
=
world_size
,
host
=
'localhost'
,
port
=
port
,
backend
=
'nccl'
)
test_models
=
[
'repeated_computed_layers'
,
'resnet18'
,
'bert'
]
shard_strategy
=
shard_strategy
()
for
model_name
in
test_models
:
get_components_func
=
non_distributed_component_funcs
.
get_callable
(
model_name
)
shard_strategy
=
TensorShardStrategy
()
model
,
train_dataloader
,
test_dataloader
,
optimizer
,
criterion
=
get_components_func
()
model
,
train_dataloader
,
test_dataloader
,
optimizer_class
,
criterion
=
get_components_func
()
model
=
model
(
checkpoint
=
True
).
cuda
()
zero_model
=
ShardedModelV2
(
copy
.
deepcopy
(
model
),
shard_strategy
,
offload_config
=
dict
(
device
=
'cpu'
)
if
cpu_offload
else
None
)
if
dist
.
get_world_size
()
>
1
:
model
=
DDP
(
model
)
optim
=
Adam
(
model
.
parameters
(),
lr
=
1e-3
)
sharded_optim
=
ShardedOptimizerV2
(
Adam
(
zero_model
.
parameters
(),
lr
=
1e-3
),
zero_model
,
lr
=
1e-3
optim
=
optimizer_class
(
model
.
parameters
(),
lr
=
lr
)
sharded_optim
=
ShardedOptimizerV2
(
zero_model
,
optimizer_class
,
shard_strategy
,
cpu_offload
=
cpu_offload
,
initial_scale
=
2
**
5
)
initial_scale
=
2
**
5
,
lr
=
lr
)
for
i
,
(
data
,
label
)
in
enumerate
(
train_dataloader
):
if
i
>
2
:
break
...
...
@@ -69,10 +68,15 @@ def run_dist(rank, world_size, port, cpu_offload):
@
pytest
.
mark
.
dist
@
pytest
.
mark
.
parametrize
(
"world_size"
,
[
1
,
2
])
@
pytest
.
mark
.
parametrize
(
"cpu_offload"
,
[
True
,
False
])
def
test_sharded_optim_v2
(
world_size
,
cpu_offload
):
run_func
=
partial
(
run_dist
,
world_size
=
world_size
,
port
=
free_port
(),
cpu_offload
=
cpu_offload
)
@
pytest
.
mark
.
parametrize
(
"shard_strategy"
,
[
TensorShardStrategy
,
BucketTensorShardStrategy
])
def
test_sharded_optim_v2
(
world_size
,
cpu_offload
,
shard_strategy
):
run_func
=
partial
(
run_dist
,
world_size
=
world_size
,
port
=
free_port
(),
cpu_offload
=
cpu_offload
,
shard_strategy
=
shard_strategy
)
mp
.
spawn
(
run_func
,
nprocs
=
world_size
)
if
__name__
==
'__main__'
:
test_sharded_optim_v2
(
world_size
=
2
,
cpu_offload
=
True
)
test_sharded_optim_v2
(
world_size
=
2
,
cpu_offload
=
True
,
shard_strategy
=
TensorShardStrategy
)
\ No newline at end of file
tests/test_zero_data_parallel/test_sharded_optim_v2_with_cpu_adam.py
View file @
f8a0e7fb
...
...
@@ -11,7 +11,7 @@ import torch.distributed as dist
import
torch.multiprocessing
as
mp
from
colossalai.nn.optimizer
import
CPUAdam
from
colossalai.utils
import
free_port
from
colossalai.zero.shard_utils
import
TensorShardStrategy
from
colossalai.zero.shard_utils
import
(
Bucket
TensorShardStrategy
,
TensorShardStrategy
)
from
colossalai.zero.sharded_model
import
ShardedModelV2
from
colossalai.zero.sharded_optim
import
ShardedOptimizerV2
from
tests.components_to_test.registry
import
non_distributed_component_funcs
...
...
@@ -47,23 +47,24 @@ def run_step_no_criterion(model, optimizer, data, label, enable_autocast=False):
optimizer
.
step
()
def
run_dist
(
rank
,
world_size
,
port
):
def
run_dist
(
rank
,
world_size
,
port
,
shard_strategy
):
colossalai
.
launch
(
config
=
CONFIG
,
rank
=
rank
,
world_size
=
world_size
,
host
=
'localhost'
,
port
=
port
,
backend
=
'nccl'
)
test_models
=
[
'repeated_computed_layers'
,
'resnet18'
,
'bert'
]
shard_strategy
=
shard_strategy
()
for
model_name
in
test_models
:
get_components_func
=
non_distributed_component_funcs
.
get_callable
(
model_name
)
shard_strategy
=
TensorShardStrategy
()
model
,
train_dataloader
,
test_dataloader
,
optimizer
,
criterion
=
get_components_func
()
model
=
model
(
checkpoint
=
True
).
cuda
()
zero_model
=
ShardedModelV2
(
copy
.
deepcopy
(
model
),
shard_strategy
,
offload_config
=
{
'device'
:
'cpu'
})
if
dist
.
get_world_size
()
>
1
:
model
=
DDP
(
model
)
optim
=
Adam
(
model
.
parameters
(),
lr
=
1e-3
)
sharded_optim
=
ShardedOptimizerV2
(
CPUAdam
(
zero_model
.
parameters
(),
lr
=
1e-3
)
,
zero_model
,
sharded_optim
=
ShardedOptimizerV2
(
zero_model
,
CPUAdam
,
shard_strategy
,
initial_scale
=
2
**
5
,
cpu_offload
=
True
)
cpu_offload
=
True
,
lr
=
1e-3
)
for
i
,
(
data
,
label
)
in
enumerate
(
train_dataloader
):
if
i
>
2
:
break
...
...
@@ -79,10 +80,11 @@ def run_dist(rank, world_size, port):
@
pytest
.
mark
.
dist
@
pytest
.
mark
.
parametrize
(
"world_size"
,
[
1
,
2
])
def
test_sharded_optim_v2
(
world_size
):
run_func
=
partial
(
run_dist
,
world_size
=
world_size
,
port
=
free_port
())
@
pytest
.
mark
.
parametrize
(
"shard_strategy"
,
[
TensorShardStrategy
,
BucketTensorShardStrategy
])
def
test_sharded_optim_v2
(
world_size
,
shard_strategy
):
run_func
=
partial
(
run_dist
,
world_size
=
world_size
,
port
=
free_port
(),
shard_strategy
=
shard_strategy
)
mp
.
spawn
(
run_func
,
nprocs
=
world_size
)
if
__name__
==
'__main__'
:
test_sharded_optim_v2
(
world_size
=
2
)
test_sharded_optim_v2
(
world_size
=
2
,
shard_strategy
=
TensorShardStrategy
)
tests/test_zero_data_parallel/test_state_dict.py
View file @
f8a0e7fb
...
...
@@ -9,22 +9,21 @@ import pytest
import
torch
import
torch.multiprocessing
as
mp
from
colossalai.utils
import
free_port
from
colossalai.zero.shard_utils.tensor_shard_strategy
import
\
TensorShardStrategy
from
colossalai.zero.shard_utils
import
(
BucketTensorShardStrategy
,
TensorShardStrategy
)
from
colossalai.zero.sharded_model
import
ShardedModelV2
from
tests.components_to_test.registry
import
non_distributed_component_funcs
from
common
import
CONFIG
def
run_dist
(
rank
,
world_size
,
port
):
def
run_dist
(
rank
,
world_size
,
port
,
shard_strategy
):
colossalai
.
launch
(
config
=
CONFIG
,
rank
=
rank
,
world_size
=
world_size
,
host
=
'localhost'
,
port
=
port
,
backend
=
'nccl'
)
test_models
=
[
'repeated_computed_layers'
,
'resnet18'
]
shard_strategy
=
shard_strategy
()
for
model_name
in
test_models
:
get_components_func
=
non_distributed_component_funcs
.
get_callable
(
model_name
)
model_builder
,
train_dataloader
,
test_dataloader
,
optimizer
,
criterion
=
get_components_func
()
model
=
model_builder
()
shard_strategy
=
TensorShardStrategy
()
model
=
model
.
half
().
cuda
()
zero_model
=
ShardedModelV2
(
deepcopy
(
model
),
shard_strategy
)
zero_state_dict
=
zero_model
.
state_dict
()
...
...
@@ -33,11 +32,12 @@ def run_dist(rank, world_size, port):
@
pytest
.
mark
.
dist
def
test_zero_state_dict
():
world_size
=
2
run_func
=
partial
(
run_dist
,
world_size
=
world_size
,
port
=
free_port
())
@
pytest
.
mark
.
parametrize
(
"world_size"
,
[
1
,
2
])
@
pytest
.
mark
.
parametrize
(
"shard_strategy"
,
[
TensorShardStrategy
,
BucketTensorShardStrategy
])
def
test_zero_state_dict
(
world_size
,
shard_strategy
):
run_func
=
partial
(
run_dist
,
world_size
=
world_size
,
port
=
free_port
(),
shard_strategy
=
shard_strategy
)
mp
.
spawn
(
run_func
,
nprocs
=
world_size
)
if
__name__
==
'__main__'
:
test_zero_state_dict
()
test_zero_state_dict
(
2
,
TensorShardStrategy
)
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment