Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
57567ee7
Commit
57567ee7
authored
Mar 18, 2022
by
ver217
Browse files
update sharded optim and fix zero init ctx
parent
f27d801a
Changes
11
Show whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
145 additions
and
140 deletions
+145
-140
colossalai/initialize.py
colossalai/initialize.py
+6
-8
colossalai/zero/__init__.py
colossalai/zero/__init__.py
+6
-25
colossalai/zero/init_ctx/init_context.py
colossalai/zero/init_ctx/init_context.py
+8
-0
colossalai/zero/sharded_model/sharded_model_v2.py
colossalai/zero/sharded_model/sharded_model_v2.py
+14
-15
colossalai/zero/sharded_optim/sharded_optim_v2.py
colossalai/zero/sharded_optim/sharded_optim_v2.py
+11
-20
tests/test_zero_data_parallel/common.py
tests/test_zero_data_parallel/common.py
+7
-7
tests/test_zero_data_parallel/test_shard_model_v2.py
tests/test_zero_data_parallel/test_shard_model_v2.py
+16
-28
tests/test_zero_data_parallel/test_sharded_optim_v2.py
tests/test_zero_data_parallel/test_sharded_optim_v2.py
+28
-19
tests/test_zero_data_parallel/test_sharded_optim_with_sync_bn.py
...est_zero_data_parallel/test_sharded_optim_with_sync_bn.py
+8
-3
tests/test_zero_data_parallel/test_state_dict.py
tests/test_zero_data_parallel/test_state_dict.py
+15
-3
tests/test_zero_data_parallel/test_zero_engine.py
tests/test_zero_data_parallel/test_zero_engine.py
+26
-12
No files found.
colossalai/initialize.py
View file @
57567ee7
...
...
@@ -5,7 +5,7 @@ import argparse
import
os
import
pprint
from
pathlib
import
Path
from
typing
import
Callable
,
Dict
,
Iterable
,
List
,
Optional
,
Tuple
,
Union
,
Type
from
typing
import
Callable
,
Dict
,
Iterable
,
List
,
Optional
,
Tuple
,
Union
import
torch
import
torch.nn
as
nn
...
...
@@ -21,13 +21,13 @@ from colossalai.builder.builder import build_gradient_handler
from
colossalai.context
import
Config
,
ConfigException
,
ParallelMode
from
colossalai.core
import
global_context
as
gpc
from
colossalai.engine
import
Engine
from
colossalai.engine.ophooks
import
BaseOpHook
from
colossalai.global_variables
import
moe_env
from
colossalai.logging
import
get_dist_logger
from
colossalai.nn.optimizer.colossalai_optimizer
import
ColossalaiOptimizer
from
colossalai.utils
import
(
accumulate_gradient
,
get_current_device
,
is_using_ddp
,
is_using_pp
,
is_using_sequence
,
sync_model_param
)
from
colossalai.zero
import
convert_to_zero_v2
from
colossalai.engine.ophooks
import
BaseOpHook
from
colossalai.zero.sharded_optim.sharded_optim_v2
import
ShardedOptimizerV2
...
...
@@ -217,8 +217,8 @@ def launch_from_torch(config: Union[str, Path, Config, Dict],
verbose
=
verbose
)
def
initialize
(
model
:
Union
[
Callable
,
nn
.
Module
]
,
optimizer
:
Union
[
Type
[
Optimizer
],
Optimizer
]
,
def
initialize
(
model
:
nn
.
Module
,
optimizer
:
Optimizer
,
criterion
:
Optional
[
_Loss
]
=
None
,
train_dataloader
:
Optional
[
Iterable
]
=
None
,
test_dataloader
:
Optional
[
Iterable
]
=
None
,
...
...
@@ -278,12 +278,10 @@ def initialize(model: Union[Callable, nn.Module],
cfg_
=
{}
optimizer_config
=
zero_cfg
.
get
(
'optimizer_config'
,
None
)
model_config
=
zero_cfg
.
get
(
'model_config'
,
None
)
model
,
optimizer
=
convert_to_zero_v2
(
model_builder
=
model
,
model_config
=
model_config
,
optimizer_config
=
optimizer_config
)
model
,
optimizer
=
convert_to_zero_v2
(
model
,
model_config
=
model_config
,
optimizer_config
=
optimizer_config
)
logger
.
info
(
"Initializing ZeRO model and optimizer finished!"
,
ranks
=
[
0
])
#FIXME() throw a warning if using zero with MP
#
FIXME() throw a warning if using zero with MP
if
gpc
.
get_world_size
(
ParallelMode
.
MODEL
)
>
1
:
logger
.
warning
(
"ZeRO currently has not been tested with model parallelism."
,
ranks
=
[
0
])
else
:
...
...
colossalai/zero/__init__.py
View file @
57567ee7
from
typing
import
Callab
le
from
typing
import
Tup
le
import
torch
import
torch.nn
as
nn
from
torch.optim
import
Optimizer
from
colossalai.zero.sharded_model.sharded_model_v2
import
ShardedModelV2
from
colossalai.zero.sharded_optim.sharded_optim_v2
import
ShardedOptimizerV2
from
colossalai.zero.shard_utils
import
TensorShardStrategy
from
colossalai.amp.naive_amp
import
NaiveAMPModel
from
colossalai.core
import
global_context
as
gpc
from
colossalai.zero.init_ctx
import
ZeroInitContext
from
colossalai.logging
import
get_dist_logger
from
colossalai.zero.sharded_model.sharded_model_v2
import
ShardedModelV2
from
colossalai.zero.sharded_optim.sharded_optim_v2
import
ShardedOptimizerV2
from
torch.optim
import
Optimizer
from
.sharded_model
import
ShardedModel
from
.sharded_optim
import
ShardedOptimizer
def
convert_to_zero_v2
(
model
_builder
:
Callab
le
,
model_config
,
optimizer_config
)
->
(
ShardedModelV2
,
ShardedOptimizerV2
)
:
def
convert_to_zero_v2
(
model
:
nn
.
Modu
le
,
model_config
,
optimizer_config
)
->
Tuple
[
ShardedModelV2
,
ShardedOptimizerV2
]
:
"""
A helper function to integrate the model and optimizer with ZeRO optimizer and off-loading
...
...
@@ -31,9 +26,6 @@ def convert_to_zero_v2(model_builder: Callable, model_config, optimizer_config)
logger
=
get_dist_logger
(
'convert_to_zero_v2'
)
# FIXME() pass shard strategy from config
shard_strategy
=
TensorShardStrategy
()
logger
.
info
(
f
'optimizer_config is
{
optimizer_config
}
'
)
if
optimizer_config
is
None
:
optimizer_config
=
dict
()
...
...
@@ -41,18 +33,7 @@ def convert_to_zero_v2(model_builder: Callable, model_config, optimizer_config)
if
model_config
is
None
:
model_config
=
dict
()
if
isinstance
(
model_builder
,
nn
.
Module
):
model
=
model_builder
elif
isinstance
(
model_builder
,
Callable
):
with
ZeroInitContext
(
convert_fp16
=
'fp16'
in
gpc
.
config
,
target_device
=
torch
.
cuda
.
current_device
(),
shard_strategy
=
shard_strategy
,
shard_param
=
model_config
.
get
(
'shard_param'
,
True
)):
model
=
model_builder
()
else
:
raise
TypeError
(
f
"convert_to_zero_v2 dose not support model_builder of type
{
type
(
convert_to_zero_v2
)
}
"
)
zero_model
=
ShardedModelV2
(
model
,
shard_strategy
=
shard_strategy
,
**
model_config
)
zero_model
=
ShardedModelV2
(
model
,
**
model_config
)
zero_optimizer
=
ShardedOptimizerV2
(
zero_model
,
**
optimizer_config
)
return
zero_model
,
zero_optimizer
...
...
colossalai/zero/init_ctx/init_context.py
View file @
57567ee7
...
...
@@ -4,6 +4,7 @@ import torch
from
colossalai.utils.memory_tracer.model_data_memtracer
import
\
GLOBAL_MODEL_DATA_TRACER
from
colossalai.zero.shard_utils
import
BaseShardStrategy
from
colossalai.zero.sharded_model._zero3_utils
import
cast_tensor_to_fp16
from
colossalai.zero.sharded_param
import
ShardedParamV2
# Inserts _post_init_method at the end of init method
...
...
@@ -158,3 +159,10 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
# if param.col_attr.grad and self.shard_grad:
# self.shard_strategy.shard(tensor_list=[param.col_attr._grad_sharded_tensor])
# GLOBAL_MODEL_DATA_TRACER.add_tensor(param.col_attr._grad_sharded_tensor.payload)
# We must cast buffers
# If we use BN, buffers may be on CPU and Float
# We must cast them
for
buffer
in
module
.
buffers
():
buffer
.
data
=
buffer
.
data
.
to
(
device
=
torch
.
cuda
.
current_device
())
if
self
.
convert_fp16
:
buffer
.
data
=
cast_tensor_to_fp16
(
buffer
.
data
)
colossalai/zero/sharded_model/sharded_model_v2.py
View file @
57567ee7
import
functools
from
collections
import
OrderedDict
from
typing
import
Any
,
Optional
from
typing
import
Any
,
Optional
,
Type
import
torch
import
torch.distributed
as
dist
...
...
@@ -28,14 +28,13 @@ class ShardedModelV2(nn.Module):
def
__init__
(
self
,
module
:
nn
.
Module
,
shard_strategy
:
BaseShardStrategy
,
shard_strategy
:
Type
[
BaseShardStrategy
]
,
process_group
:
Optional
[
ProcessGroup
]
=
None
,
reduce_scatter_process_group
:
Optional
[
ProcessGroup
]
=
None
,
reduce_scatter_bucket_size_mb
:
int
=
25
,
fp32_reduce_scatter
:
bool
=
False
,
offload_config
:
Optional
[
dict
]
=
None
,
gradient_predivide_factor
:
Optional
[
float
]
=
1.0
,
shard_param
:
bool
=
True
,
use_memory_tracer
:
bool
=
False
):
r
"""
A demo to reconfigure zero1 shared_model.
...
...
@@ -44,23 +43,23 @@ class ShardedModelV2(nn.Module):
super
().
__init__
()
self
.
logger
=
get_dist_logger
()
# We force users to use ZeroInitContext
sharded
=
[]
unsharded
=
[]
for
param
in
module
.
parameters
():
assert
hasattr
(
param
,
'col_attr'
),
'You must use ZeroInitContext to init your module first.'
sharded
.
append
(
param
.
col_attr
.
param_is_sharded
)
unsharded
.
append
(
not
param
.
col_attr
.
param_is_sharded
)
assert
all
(
sharded
)
or
all
(
unsharded
),
'Parameters must be all sharded or all unsharded! Parameters are partially sharded nwo.'
self
.
shard_param
=
all
(
sharded
)
self
.
module
=
module
self
.
process_group
=
process_group
or
gpc
.
get_group
(
ParallelMode
.
DATA
)
self
.
reduce_scatter_process_group
=
reduce_scatter_process_group
or
self
.
process_group
self
.
world_size
=
dist
.
get_world_size
(
self
.
process_group
)
self
.
rank
=
dist
.
get_rank
(
self
.
process_group
)
# Cast module to fp16 and cuda, in case user didn't use ZeroInitContext
self
.
module
=
module
.
half
().
cuda
()
self
.
shard_strategy
=
shard_strategy
self
.
shard_param
=
shard_param
# In case user didn't use ZeroInitContext
for
param
in
self
.
module
.
parameters
():
if
not
hasattr
(
param
,
'col_attr'
):
param
.
col_attr
=
ShardedParamV2
(
param
,
process_group
,
rm_torch_payload
=
True
)
if
self
.
shard_param
:
self
.
shard_strategy
.
shard
([
param
.
col_attr
.
data
])
# Init Memory Statistics Collector
self
.
_use_memory_tracer
=
use_memory_tracer
...
...
colossalai/zero/sharded_optim/sharded_optim_v2.py
View file @
57567ee7
from
enum
import
Enum
from
typing
import
Dict
,
Optional
,
Type
,
Any
from
typing
import
Dict
,
Optional
import
torch
import
torch.distributed
as
dist
import
torch.nn
as
nn
from
torch
import
Tensor
from
torch.distributed
import
ProcessGroup
from
torch.nn.parameter
import
Parameter
from
torch.optim
import
Optimizer
from
colossalai.amp.naive_amp.grad_scaler
import
DynamicGradScaler
from
colossalai.context.parallel_mode
import
ParallelMode
from
colossalai.core
import
global_context
as
gpc
from
colossalai.nn.optimizer
import
ColossalaiOptimizer
from
colossalai.zero.shard_utils
import
BaseShardStrategy
from
colossalai.zero.sharded_model
import
ShardedModelV2
from
colossalai.zero.sharded_model._zero3_utils
import
cast_tensor_to_fp32
from
colossalai.logging
import
get_dist_logger
from
torch
import
Tensor
from
torch.distributed
import
ProcessGroup
from
torch.nn.parameter
import
Parameter
from
torch.optim
import
Optimizer
from
._utils
import
has_inf_or_nan
...
...
@@ -30,7 +27,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
def
__init__
(
self
,
sharded_model
:
ShardedModelV2
,
optimizer
_class
:
Type
[
Optimizer
]
,
optimizer
:
Optimizer
,
cpu_offload
:
bool
=
False
,
initial_scale
:
float
=
2
**
32
,
min_scale
:
float
=
1
,
...
...
@@ -40,8 +37,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
hysteresis
:
float
=
2
,
max_scale
:
int
=
2
**
32
,
dp_process_group
:
Optional
[
ProcessGroup
]
=
None
,
mp_process_group
:
Optional
[
ProcessGroup
]
=
None
,
**
defaults
:
Any
)
->
None
:
mp_process_group
:
Optional
[
ProcessGroup
]
=
None
)
->
None
:
"""
:param sharded_model: A sharded model initialized by class ShardedModelV2. The optimizer will use the
shard strategy provided by sharded model to shard param fp32 tensors.
...
...
@@ -84,13 +80,8 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
:type defaults: dict()
"""
assert
isinstance
(
sharded_model
,
ShardedModelV2
),
'model must be wrapped with ShardedModel'
self
.
_logger
=
get_dist_logger
(
'ShardedOptimV2 logger'
)
self
.
_optim_defaults
=
defaults
# initialize the M, V as zeros tensors and initialize param fp32 from sharded_model.parameters()
self
.
optimizer
=
optimizer_class
(
sharded_model
.
parameters
(),
**
self
.
_optim_defaults
)
super
().
__init__
(
self
.
optimizer
)
super
().
__init__
(
optimizer
)
self
.
shard_strategy
=
sharded_model
.
shard_strategy
self
.
model
:
ShardedModelV2
=
sharded_model
if
cpu_offload
and
not
sharded_model
.
cpu_offload
:
...
...
@@ -114,7 +105,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
# Store fp32 param shards
self
.
master_params
:
Dict
[
Parameter
,
Tensor
]
=
{}
for
group
in
self
.
optim
izer
.
param_groups
:
for
group
in
self
.
optim
.
param_groups
:
for
p
in
group
[
'params'
]:
assert
hasattr
(
p
,
'col_attr'
),
'The parameter must be wrapped with ShardedParam'
is_param_sharded
=
p
.
col_attr
.
data
.
is_sharded
...
...
tests/test_zero_data_parallel/common.py
View file @
57567ee7
import
imp
from
functools
import
partial
import
torch
import
torch.distributed
as
dist
from
colossalai.logging
import
get_dist_logger
from
colossalai.nn.optimizer
import
CPUAdam
from
colossalai.utils
import
checkpoint
from
colossalai.zero.shard_utils
import
TensorShardStrategy
from
colossalai.zero.sharded_model
import
ShardedModelV2
from
colossalai.nn.optimizer
import
CPUAdam
LOGGER
=
get_dist_logger
(
'zero_test'
)
...
...
@@ -16,11 +17,10 @@ _ZERO_MODEL_CONFIG = dict(reduce_scatter_bucket_size_mb=25,
fp32_reduce_scatter
=
False
,
offload_config
=
None
,
gradient_predivide_factor
=
1.0
,
shard_param
=
Tru
e
,
use_memory_tracer
=
False
)
use_memory_tracer
=
Fals
e
,
shard_strategy
=
TensorShardStrategy
)
_ZERO_OPTIMIZER_CONFIG
=
dict
(
optimizer_class
=
torch
.
optim
.
Adam
,
#CPUAdam
cpu_offload
=
False
,
initial_scale
=
2
**
5
,
min_scale
=
1
,
...
...
@@ -35,7 +35,7 @@ ZERO_PARALLEL_CONFIG = dict(fp16=dict(mode=None,),
zero
=
dict
(
model_config
=
_ZERO_MODEL_CONFIG
,
optimizer_config
=
_ZERO_OPTIMIZER_CONFIG
,
),
),
parallel
=
dict
(
pipeline
=
dict
(
size
=
1
),
tensor
=
dict
(
size
=
1
,
mode
=
None
)))
CONFIG
=
dict
(
fp16
=
dict
(
mode
=
None
,),
...
...
tests/test_zero_data_parallel/test_shard_model_v2.py
View file @
57567ee7
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import
copy
from
asyncio.log
import
logger
from
functools
import
partial
import
colossalai
import
pytest
import
torch
import
torch.multiprocessing
as
mp
from
colossalai.
logg
ing
import
get_dist_logger
from
colossalai.
test
ing
import
parameterize
from
colossalai.utils
import
free_port
from
colossalai.zero.init_ctx
import
ZeroInitContext
from
colossalai.zero.shard_utils
import
(
BucketTensorShardStrategy
,
TensorShardStrategy
)
from
colossalai.zero.shard_utils
import
(
BucketTensorShardStrategy
,
TensorShardStrategy
)
from
colossalai.zero.sharded_model
import
ShardedModelV2
from
colossalai.zero.sharded_model._zero3_utils
import
cast_tensor_to_fp16
from
colossalai.zero.sharded_model.utils
import
col_model_deepcopy
...
...
@@ -20,13 +19,11 @@ from tests.components_to_test.registry import non_distributed_component_funcs
from
torch.nn.parallel
import
DistributedDataParallel
as
DDP
from
common
import
CONFIG
,
check_grads_padding
,
run_fwd_bwd
from
colossalai.testing
import
parameterize
@
parameterize
(
"enable_autocast"
,
[
True
])
@
parameterize
(
"use_zero_init_ctx"
,
[
True
])
@
parameterize
(
"shard_strategy"
,
[
TensorShardStrategy
,
BucketTensorShardStrategy
])
def
run_model_test
(
enable_autocast
,
use_zero_init_ctx
,
shard_strategy
,
logger
):
def
run_model_test
(
enable_autocast
,
shard_strategy
):
test_models
=
[
'repeated_computed_layers'
,
'resnet18'
,
'bert'
]
shard_strategy
=
shard_strategy
()
for
model_name
in
test_models
:
...
...
@@ -35,9 +32,8 @@ def run_model_test(enable_autocast, use_zero_init_ctx, shard_strategy, logger):
rm_torch_payload_on_the_fly
=
False
if
use_zero_init_ctx
:
with
ZeroInitContext
(
convert_fp16
=
True
,
target_device
=
torch
.
device
(
f
'cpu:0'
),
target_device
=
torch
.
cuda
.
current_device
(
),
shard_strategy
=
shard_strategy
,
shard_param
=
True
,
rm_torch_payload_on_the_fly
=
rm_torch_payload_on_the_fly
):
...
...
@@ -47,9 +43,6 @@ def run_model_test(enable_autocast, use_zero_init_ctx, shard_strategy, logger):
model
=
model_builder
(
checkpoint
=
True
).
half
()
col_model_deepcopy
(
zero_model
,
model
)
model
=
model
.
cuda
()
else
:
model
=
model_builder
(
checkpoint
=
True
).
half
().
cuda
()
zero_model
=
ShardedModelV2
(
copy
.
deepcopy
(
model
),
shard_strategy
)
model
=
DDP
(
model
)
...
...
@@ -63,15 +56,10 @@ def run_model_test(enable_autocast, use_zero_init_ctx, shard_strategy, logger):
check_grads_padding
(
model
,
zero_model
,
loose
=
True
)
# logger.debug('overall cuda ', zero_model._memstats_collector._overall_cuda)
# logger.debug('model cuda ', zero_model._memstats_collector._model_data_cuda)
def
run_dist
(
rank
,
world_size
,
port
):
colossalai
.
launch
(
config
=
CONFIG
,
rank
=
rank
,
world_size
=
world_size
,
host
=
'localhost'
,
port
=
port
,
backend
=
'nccl'
)
logger
=
get_dist_logger
()
logger
.
set_level
(
'DEBUG'
)
run_model_test
(
logger
=
logger
)
run_model_test
()
@
pytest
.
mark
.
dist
...
...
tests/test_zero_data_parallel/test_sharded_optim_v2.py
View file @
57567ee7
import
copy
from
functools
import
partial
import
colossalai
...
...
@@ -6,15 +5,19 @@ import pytest
import
torch
import
torch.distributed
as
dist
import
torch.multiprocessing
as
mp
from
colossalai.nn.optimizer
import
CPUAdam
from
colossalai.testing
import
parameterize
from
colossalai.utils
import
free_port
from
colossalai.zero.shard_utils
import
(
BucketTensorShardStrategy
,
TensorShardStrategy
)
from
colossalai.zero.init_ctx
import
ZeroInitContext
from
colossalai.zero.shard_utils
import
(
BucketTensorShardStrategy
,
TensorShardStrategy
)
from
colossalai.zero.sharded_model
import
ShardedModelV2
from
colossalai.zero.sharded_model.utils
import
col_model_deepcopy
from
colossalai.zero.sharded_optim
import
ShardedOptimizerV2
from
colossalai.zero.sharded_optim._utils
import
has_inf_or_nan
from
tests.components_to_test.registry
import
non_distributed_component_funcs
from
torch.nn.parallel
import
DistributedDataParallel
as
DDP
from
colossalai.nn.optimizer
import
CPUAdam
from
colossalai.zero.sharded_optim._utils
import
has_inf_or_nan
from
colossalai.testing
import
parameterize
from
common
import
CONFIG
,
check_sharded_params_padding
...
...
@@ -48,26 +51,32 @@ def _run_test_sharded_optim_v2(cpu_offload, shard_strategy, use_cpuadam):
for
model_name
in
test_models
:
get_components_func
=
non_distributed_component_funcs
.
get_callable
(
model_name
)
model
,
train_dataloader
,
_
,
optimizer_class
,
criterion
=
get_components_func
()
model
=
model
(
checkpoint
=
True
).
cuda
()
zero_model
=
ShardedModelV2
(
copy
.
deepcopy
(
model
),
model_builder
,
train_dataloader
,
_
,
optimizer_class
,
criterion
=
get_components_func
()
with
ZeroInitContext
(
convert_fp16
=
True
,
target_device
=
torch
.
device
(
f
'cpu:0'
),
shard_strategy
=
shard_strategy
,
shard_param
=
True
,
rm_torch_payload_on_the_fly
=
False
):
zero_model
=
model_builder
(
checkpoint
=
True
)
zero_model
=
ShardedModelV2
(
zero_model
,
shard_strategy
,
offload_config
=
dict
(
device
=
'cpu'
)
if
cpu_offload
else
None
)
model
=
model_builder
(
checkpoint
=
True
).
half
()
col_model_deepcopy
(
zero_model
,
model
)
model
=
model
.
cuda
().
float
()
if
dist
.
get_world_size
()
>
1
:
model
=
DDP
(
model
)
lr
=
1e-3
if
use_cpuadam
:
optim
=
torch
.
optim
.
Adam
(
model
.
parameters
(),
lr
=
lr
)
sharded_optim
=
ShardedOptimizerV2
(
zero_model
,
CPUAdam
,
cpu_offload
=
cpu_offload
,
initial_scale
=
2
**
5
,
lr
=
lr
)
else
:
optim
=
optimizer_class
(
model
.
parameters
(),
lr
=
lr
)
sharded_optim
=
ShardedOptimizerV2
(
zero_model
,
optimizer_class
,
cpu_offload
=
cpu_offload
,
initial_scale
=
2
**
5
,
lr
=
lr
)
optimizer_class
=
CPUAdam
optim
=
optimizer_class
(
model
.
parameters
(),
lr
=
1e-3
)
sharded_optim
=
optimizer_class
(
zero_model
.
parameters
(),
lr
=
1e-3
)
sharded_optim
=
ShardedOptimizerV2
(
zero_model
,
sharded_optim
,
cpu_offload
=
cpu_offload
,
initial_scale
=
2
**
5
)
for
i
,
(
data
,
label
)
in
enumerate
(
train_dataloader
):
#FIXME() if i > 5, the unittest will fail
#
FIXME() if i > 5, the unittest will fail
if
i
>
3
:
break
data
,
label
=
data
.
cuda
(),
label
.
cuda
()
...
...
tests/test_zero_data_parallel/test_sharded_optim_with_sync_bn.py
View file @
57567ee7
...
...
@@ -4,14 +4,15 @@
from
functools
import
partial
import
colossalai
import
pyte
import
pytest
import
torch
import
torch.distributed
as
dist
import
torch.multiprocessing
as
mp
from
colossalai.utils
import
free_port
from
colossalai.core
import
global_context
as
gpc
from
colossalai.context.parallel_mode
import
ParallelMode
from
colossalai.core
import
global_context
as
gpc
from
colossalai.utils
import
free_port
from
torchvision.models
import
resnet50
import
torch.distributed
as
dist
def
run_dist
(
rank
,
world_size
,
port
):
...
...
@@ -64,6 +65,10 @@ def run_dist(rank, world_size, port):
'expected the output from different ranks to be the same, but got different values'
# FIXME: enable this test in next PR
@
pytest
.
mark
.
skip
@
pytest
.
mark
.
dist
def
test_sharded_optim_with_sync_bn
():
"""
...
...
tests/test_zero_data_parallel/test_state_dict.py
View file @
57567ee7
...
...
@@ -9,8 +9,10 @@ import pytest
import
torch
import
torch.multiprocessing
as
mp
from
colossalai.utils
import
free_port
from
colossalai.zero.init_ctx
import
ZeroInitContext
from
colossalai.zero.shard_utils
import
(
BucketTensorShardStrategy
,
TensorShardStrategy
)
from
colossalai.zero.sharded_model
import
ShardedModelV2
from
colossalai.zero.sharded_model.utils
import
col_model_deepcopy
from
tests.components_to_test.registry
import
non_distributed_component_funcs
from
colossalai.testing
import
parameterize
from
common
import
CONFIG
...
...
@@ -23,9 +25,19 @@ def run_zero_state_dict(shard_strategy):
for
model_name
in
test_models
:
get_components_func
=
non_distributed_component_funcs
.
get_callable
(
model_name
)
model_builder
,
train_dataloader
,
test_dataloader
,
optimizer
,
criterion
=
get_components_func
()
model
=
model_builder
()
model
=
model
.
half
().
cuda
()
zero_model
=
ShardedModelV2
(
deepcopy
(
model
),
shard_strategy
)
with
ZeroInitContext
(
convert_fp16
=
True
,
target_device
=
torch
.
cuda
.
current_device
(),
shard_strategy
=
shard_strategy
,
shard_param
=
True
,
rm_torch_payload_on_the_fly
=
False
):
zero_model
=
model_builder
(
checkpoint
=
True
)
zero_model
=
ShardedModelV2
(
zero_model
,
shard_strategy
)
model
=
model_builder
(
checkpoint
=
True
).
half
()
col_model_deepcopy
(
zero_model
,
model
)
model
=
model
.
cuda
()
zero_state_dict
=
zero_model
.
state_dict
()
for
key
,
val
in
model
.
state_dict
().
items
():
assert
torch
.
equal
(
val
,
zero_state_dict
[
key
])
...
...
tests/test_zero_data_parallel/test_zero_engine.py
View file @
57567ee7
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import
copy
from
functools
import
partial
from
colossalai.zero.sharded_model.sharded_model_v2
import
ShardedModelV2
import
pytest
import
colossalai
import
pytest
import
torch
import
torch.distributed
as
dist
import
torch.multiprocessing
as
mp
from
colossalai.context.parallel_mode
import
ParallelMode
from
colossalai.core
import
global_context
as
gpc
from
colossalai.utils
import
free_port
from
colossalai.zero.init_ctx
import
ZeroInitContext
from
colossalai.zero.sharded_model.utils
import
col_model_deepcopy
from
colossalai.zero.sharded_optim._utils
import
has_inf_or_nan
import
torch.multiprocessing
as
mp
import
torch.distributed
as
dist
from
tests.components_to_test.registry
import
non_distributed_component_funcs
from
torch.nn.parallel
import
DistributedDataParallel
as
DDP
from
tests.
com
p
on
ents_to_test.registry
import
non_distributed_component_funcs
from
common
import
check_sharded_params_padding
,
ZERO_PARALLEL_CONFIG
,
MP_PARALLEL_CONFIG
,
check_params
from
com
m
on
import
(
MP_PARALLEL_CONFIG
,
ZERO_PARALLEL_CONFIG
,
check_params
,
check_sharded_params_padding
)
def
run_dist
(
rank
,
world_size
,
port
,
parallel_config
):
...
...
@@ -30,10 +33,16 @@ def run_dist(rank, world_size, port, parallel_config):
for
model_name
in
test_models
:
get_components_func
=
non_distributed_component_funcs
.
get_callable
(
model_name
)
model_builder
,
train_dataloader
,
_
,
optimizer_class
,
criterion
=
get_components_func
()
with
ZeroInitContext
(
convert_fp16
=
hasattr
(
gpc
.
config
,
'fp16'
),
target_device
=
torch
.
cuda
.
current_device
(),
shard_strategy
=
gpc
.
config
.
zero
.
model_config
.
shared_strategy
(
gpc
.
get_group
(
ParallelMode
.
DATA
)),
shard_param
=
True
):
colo_model
=
model_builder
(
checkpoint
=
True
)
torch_model
=
copy
.
deepcopy
(
colo_model
).
cuda
()
torch_model
.
train
()
torch_model
=
model_builder
(
checkpoint
=
True
).
half
()
col_model_deepcopy
(
colo_model
,
torch_model
)
torch_model
=
torch_model
.
cuda
().
float
()
engine
,
train_dataloader
,
_
,
_
=
colossalai
.
initialize
(
colo_model
,
optimizer
=
optimizer_class
,
criterion
=
criterion
,
...
...
@@ -82,6 +91,10 @@ def run_dist(rank, world_size, port, parallel_config):
check_sharded_params_padding
(
torch_model
,
colo_model
,
loose
=
True
)
# FIXME: enable this test in next PR
@
pytest
.
mark
.
skip
@
pytest
.
mark
.
dist
@
pytest
.
mark
.
parametrize
(
"world_size"
,
[
2
,
4
])
def
test_mp_engine
(
world_size
):
...
...
@@ -89,6 +102,7 @@ def test_mp_engine(world_size):
mp
.
spawn
(
run_func
,
nprocs
=
world_size
)
@
pytest
.
mark
.
skip
@
pytest
.
mark
.
dist
@
pytest
.
mark
.
parametrize
(
"world_size"
,
[
1
,
2
])
def
test_zero_engine
(
world_size
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment