OpenDAS / ColossalAI · Commits · 640a6cd3

Unverified commit 640a6cd3, authored Mar 16, 2022 by Jiarui Fang and committed by GitHub on Mar 16, 2022.

[refactory] refactory the initialize method for new zero design (#431)

parent 4f85b687

Showing 5 changed files with 184 additions and 24 deletions (+184 −24).
colossalai/initialize.py                                               +34 −21
colossalai/zero/__init__.py                                            +51 −0
colossalai/zero/sharded_optim/sharded_optim_v2.py                      +3 −0
tests/test_zero_data_parallel/test_sharded_optim_with_sync_bn.py       +4 −3
tests/test_zero_data_parallel/test_zero_init_v2.py                     +92 −0
colossalai/initialize.py

@@ -5,7 +5,7 @@ import argparse
 import os
 import pprint
 from pathlib import Path
-from typing import Dict, Iterable, List, Optional, Tuple, Union
+from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union, Type

 import torch
 import torch.nn as nn
@@ -26,8 +26,9 @@ from colossalai.logging import get_dist_logger
 from colossalai.nn.optimizer.colossalai_optimizer import ColossalaiOptimizer
 from colossalai.utils import (accumulate_gradient, get_current_device, is_using_ddp, is_using_pp, is_using_sequence,
                               sync_model_param)
-from colossalai.zero import convert_to_zero, ShardedOptimizer
+from colossalai.zero import convert_to_zero_v2
 from colossalai.engine.ophooks import BaseOpHook
+from colossalai.zero.sharded_optim.sharded_optim_v2 import ShardedOptimizerV2


 def get_default_parser():
@@ -216,8 +217,8 @@ def launch_from_torch(config: Union[str, Path, Config, Dict],
                       verbose=verbose)


-def initialize(model: nn.Module,
-               optimizer: Optimizer,
+def initialize(model: Union[Callable, nn.Module],
+               optimizer: Union[Type[Optimizer], Optimizer],
                criterion: Optional[_Loss] = None,
                train_dataloader: Optional[Iterable] = None,
                test_dataloader: Optional[Iterable] = None,
@@ -227,10 +228,10 @@ def initialize(model: nn.Module,
     """Core function to wrap the essential training components with our functionality based on the config which is
     loaded into gpc.config.

-    :param model: Your model instance
-    :type model: :class:`torch.nn.Module`
+    :param model: Your model instance or a function to build the model
+    :type model: :class:`torch.nn.Module` or Callbale
     :param optimizer: Your optimizer instance
-    :type optimizer: :class:`torch.optim.optimizer.Optimizer`
+    :type optimizer: :class:`torch.optim.optimizer.Optimizer` or :class:`Type[torch.optim.optimizer]`
     :param criterion: Your criterion instance
     :type criterion: :class:`torch.nn.modules.loss._Loss`, optional
     :param train_dataloader: Dataloader for training
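With this change, initialize() accepts either a constructed nn.Module or a parameterless callable that builds one, and either an optimizer instance or an optimizer class. A minimal sketch of the new calling convention (not part of this commit; build_model and the Adam choice are illustrative placeholders, and colossalai.launch is assumed to have already loaded a config with a zero section into gpc.config):

    import torch
    import torch.nn as nn
    import colossalai

    def build_model():
        # parameterless builder; on the new ZeRO path it is executed inside ZeroInitContext
        return nn.Linear(128, 10)

    # hypothetical call after colossalai.launch(...) has set up gpc.config
    engine, train_dataloader, _, _ = colossalai.initialize(model=build_model,
                                                           optimizer=torch.optim.Adam,
                                                           criterion=nn.CrossEntropyLoss())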
@@ -267,10 +268,28 @@ def initialize(model: nn.Module,
     if verbose:
         logger.info(f"cuDNN benchmark = {cudnn_benchmark}, deterministic = {cudnn_deterministic}", ranks=[0])

-    # first sync model across dp ranks
-    model.to(get_current_device())
-    use_zero3 = hasattr(gpc.config, 'zero') and gpc.config.zero.level == 3
-    if not moe_env.is_initialized() and not use_zero3:
+    use_zero = hasattr(gpc.config, 'zero')
+    if use_zero:
+        zero_cfg = gpc.config.get('zero', None)
+        if zero_cfg is not None:
+            cfg_ = zero_cfg.copy()
+        else:
+            cfg_ = {}
+        optimizer_config = zero_cfg.get('optimzer', None)
+        model, optimizer = convert_to_zero_v2(model_builder=model, optimizer_config=optimizer_config)
+
+        logger.info("Initializing ZeRO model and optimzer finished!", ranks=[0])
+        #FIXME() throw a warning if using zero with MP
+        if gpc.get_world_size(ParallelMode.MODEL) > 1:
+            logger.warning("ZeRO currently has not been tested with model parallelism.", ranks=[0])
+    else:
+        if isinstance(model, nn.Module):
+            # first sync model across dp ranks
+            model.to(get_current_device())
+        elif isinstance(model, Callable):
+            model = model().to(get_current_device())
+
+    if not moe_env.is_initialized() and not use_zero:
         if is_using_sequence():
             sync_model_param(model, ParallelMode.SEQUENCE_DP)
         elif is_using_ddp():
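The new branch above reads the zero section of gpc.config and its optimzer sub-dict (spelled as in this commit), then forwards the model builder and that sub-dict to convert_to_zero_v2. The tests added in this commit use a section of this shape; a trimmed sketch:

    # shape of the `zero` config section consumed above (mirroring the tests in this commit)
    zero = dict(optimzer=dict(optimizer_type=torch.optim.Adam,      # optimizer class to instantiate
                              optimizer_config=dict(lr=1e-3)))      # kwargs forwarded to that optimizer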
@@ -283,16 +302,15 @@ def initialize(model: nn.Module,
     # check amp and zero
     fp16_cfg = gpc.config.get('fp16', None)
-    zero_cfg = gpc.config.get('zero', None)

-    if fp16_cfg is not None and fp16_cfg.mode is not None and zero_cfg is not None:
+    if fp16_cfg is not None and fp16_cfg.mode is not None and use_zero:
         raise ConfigException(
             "It is not allowed to set fp16 and zero configuration in your config file at the same time")

     # clip grad norm
     clip_grad_norm = gpc.config.get('clip_grad_norm', 0.0)
     if clip_grad_norm > 0:
-        if zero_cfg is not None:
+        if use_zero and zero_cfg is not None:
             raise ConfigException(
                 "clip_grad_norm should be specified with zero, you should specify clip_grad in zero configuration")
@@ -311,11 +329,6 @@ def initialize(model: nn.Module,
                                                  mode=amp_mode,
                                                  amp_config=cfg_)

-    if zero_cfg is not None:
-        cfg_ = zero_cfg.copy()
-        level = cfg_.pop('level')
-        model, optimizer = convert_to_zero(model=model, optimizer=optimizer, level=level, zero_config=cfg_)
-
     # gradient handler
     gradient_handler_cfg = gpc.config.get('gradient_handler', None)
     if gradient_handler_cfg is None:
@@ -324,7 +337,7 @@ def initialize(model: nn.Module,
         # 1. if optimizer is ZERO, then use zero grad handler
         # 2. if dp size is larger than 1 and pipeline is not used, use pytorch ddp
         # 3. if using pipeline and dp size larger than 1, use data parallel grad handler
-        if isinstance(optimizer, ShardedOptimizer):
+        if isinstance(optimizer, ShardedOptimizerV2):
            gradient_handler_cfg = [dict(type='ZeROGradientHandler')]
            if verbose:
                logger.info(
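Because the ZeRO optimizer class is now ShardedOptimizerV2, this branch still selects the ZeRO gradient handler automatically. For reference, the explicit form of the handler configuration it builds is simply:

    # explicit equivalent of the handler list built by the branch above
    gradient_handler = [dict(type='ZeROGradientHandler')]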
@@ -392,7 +405,7 @@ def initialize(model: nn.Module,
         gradient_handlers = [build_gradient_handler(cfg, model, optimizer) for cfg in gradient_handler_cfg]

     # check if optimizer is ColossalaiOptimizer
-    if not isinstance(optimizer, (ColossalaiOptimizer, ShardedOptimizer)):
+    if not isinstance(optimizer, (ColossalaiOptimizer, ShardedOptimizerV2)):
         optimizer = ColossalaiOptimizer(optim=optimizer)

     # gradient accumulation
colossalai/zero/__init__.py

+from asyncio.log import logger
+from distutils.command.config import config
+from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2
+from colossalai.zero.sharded_optim.sharded_optim_v2 import ShardedOptimizerV2
+from colossalai.zero.shard_utils import TensorShardStrategy
 import torch
 import torch.nn as nn

 from colossalai.amp.naive_amp import NaiveAMPModel
@@ -7,6 +11,53 @@ from colossalai.core import global_context as gpc
 from torch.optim import Optimizer

 from .sharded_model import ShardedModel
 from .sharded_optim import ShardedOptimizer
+from colossalai.zero.init_ctx import ZeroInitContext
+from typing import Callable, Type
+from colossalai.core import global_context as gpc
+from colossalai.logging import get_dist_logger
+
+
+def convert_to_zero_v2(model_builder: Callable, optimizer_config) -> (ShardedModelV2, ShardedOptimizerV2):
+    """
+    A helper function to integrate the model and optimizer with ZeRO optimizer and off-loading
+
+    :param model: Your model object
+    :type model: :class:`torch.nn.Module`
+    :param optimizer_config: Your optimizer object
+    :type optimizer_config: :class:`dict`
+
+    :return: (model, optimizer)
+    :rtype: Tuple
+    """
+
+    logger = get_dist_logger('convert_to_zero_v2')
+
+    # FIXME() pass shard strategy from config
+    shard_strategy = TensorShardStrategy()
+
+    if isinstance(model_builder, nn.Module):
+        model = model_builder
+    elif isinstance(model_builder, Callable):
+        with ZeroInitContext(convert_fp16='fp16' in gpc.config,
+                             target_device=torch.cuda.current_device(),
+                             shard_strategy=shard_strategy,
+                             shard_param=True):
+            model = model_builder()
+    else:
+        raise TypeError(f"convert_to_zero_v2 dose not support model_builder of type {type(convert_to_zero_v2)}")
+
+    zero_model = ShardedModelV2(model, shard_strategy=shard_strategy)
+
+    optimizer_class = optimizer_config.get('optimizer_type', None)
+    if optimizer_class is None:
+        raise RuntimeError("Set optimizer_class in zero_config")
+    logger.info(f'optimizer class is {optimizer_class}')
+
+    cfg = optimizer_config.get('optimizer_config', None)
+    logger.info(f'optimizer_config is {cfg}')
+
+    zero_optimizer = ShardedOptimizerV2(zero_model, optimizer_class, **optimizer_config.get('optimizer_config', None))
+    return zero_model, zero_optimizer
+

 def convert_to_zero(model: nn.Module, optimizer: Optimizer, level: int, zero_config: dict):
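Called directly, convert_to_zero_v2 accepts either a built nn.Module or a model-building callable plus the optimizer sub-config, and returns the sharded model/optimizer pair. A hedged sketch with a placeholder builder (config keys mirror the tests in this commit; a launched gpc.config and a CUDA device are assumed):

    import torch
    import torch.nn as nn
    from colossalai.zero import convert_to_zero_v2

    def build_model():
        # placeholder builder; run inside ZeroInitContext so parameters are sharded as they are created
        return nn.Linear(128, 10)

    zero_model, zero_optimizer = convert_to_zero_v2(model_builder=build_model,
                                                    optimizer_config=dict(optimizer_type=torch.optim.Adam,
                                                                          optimizer_config=dict(lr=1e-3)))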
colossalai/zero/sharded_optim/sharded_optim_v2.py

@@ -223,3 +223,6 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
         # Because we will judge whether local grad accumulation
         # is enabled by wheter grad is None
         self.optim.zero_grad(set_to_none=True)
+
+    def sync_grad(self):
+        pass
tests/test_zero_data_parallel/test_sharded_optim_with_sync_bn.py

@@ -19,7 +19,8 @@ def run_dist(rank, world_size, port):
     # as this model has sync batch normalization
     # need to configure cudnn deterministic so that
     # randomness of convolution layers will be disabled
-    colossalai.launch(config=dict(zero=dict(level=2, partition_grad=True),
+    colossalai.launch(config=dict(zero=dict(optimzer=dict(optimizer_type=torch.optim.Adam,
+                                                          optimizer_config=dict(lr=1e-3))),
                                   cudnn_determinstic=True,
                                   cudnn_benchmark=False),
                       rank=rank,
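Inside convert_to_zero_v2 this optimzer sub-dict is unpacked into the ShardedOptimizerV2 constructor, so the configuration above amounts to roughly the following (a sketch of the expansion, not code from this commit):

    # what the config above expands to inside convert_to_zero_v2 (sketch)
    zero_optimizer = ShardedOptimizerV2(zero_model, torch.optim.Adam, lr=1e-3)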
tests/test_zero_data_parallel/test_zero_init_v2.py (new file, 0 → 100644)

#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import copy
from functools import partial

import pytest
import colossalai
from colossalai.utils import free_port

import torch
import torch.multiprocessing as mp
from tests.components_to_test.registry import non_distributed_component_funcs
from common import check_sharded_params_padding


def run_dist(rank, world_size, port):
    _config = dict(fp16=dict(mode=None,),
                   zero=dict(optimzer=dict(optimizer_type=torch.optim.Adam, optimizer_config=dict(lr=1e-3)),
                             offload_optimizer_config=dict(device='cpu',
                                                           pin_memory=True,
                                                           buffer_count=5,
                                                           fast_init=False),
                             offload_param_config=dict(device='cpu',
                                                       pin_memory=True,
                                                       buffer_count=5,
                                                       buffer_size=1e8,
                                                       max_in_cpu=1e9)),
                   parallel=dict(pipeline=dict(size=1), tensor=dict(size=1, mode=None)))

    colossalai.launch(config=_config,
                      rank=rank,
                      world_size=world_size,
                      host='localhost',
                      port=port,
                      backend='nccl')

    # FIXME revert back
    # test_models = ['repeated_computed_layers', 'resnet18', 'bert']
    test_models = ['bert']
    for model_name in test_models:
        get_components_func = non_distributed_component_funcs.get_callable(model_name)
        model_builder, train_dataloader, _, optimizer_class, criterion = get_components_func()

        # adapt to a Callbale with empty parameters
        # def module_builder_new():
        #     return model_builder(checkpoint=True)

        zero_model = model_builder(checkpoint=True)
        torch_model = copy.deepcopy(zero_model).cuda()

        engine, train_dataloader, _, _ = colossalai.initialize(zero_model,
                                                               optimizer=optimizer_class,
                                                               criterion=criterion,
                                                               train_dataloader=train_dataloader)

        engine.train()
        torch_optimizer = optimizer_class(torch_model.parameters())

        i = 0
        for data, label in train_dataloader:
            if i > 3:
                break

            data, label = data.cuda(), label.cuda()

            engine.zero_grad()
            torch_optimizer.zero_grad()

            if criterion:
                output = engine(data)
                loss = engine.criterion(output, label)

                torch_model(data, label)
                torch_loss = engine.criterion(output, label)
            else:
                loss = engine(data, label)
                torch_loss = torch_model(data, label)

            engine.backward(loss)
            engine.step()

            torch_loss.backward()
            torch_optimizer.step()
            i += 1

        check_sharded_params_padding(torch_model, zero_model, loose=True)


@pytest.mark.dist
@pytest.mark.parametrize("world_size", [1, 2])
def test_zero_init(world_size):
    run_func = partial(run_dist, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_zero_init(world_size=2)