OpenDAS / ColossalAI / Commits / 640a6cd3

Unverified commit 640a6cd3, authored Mar 16, 2022 by Jiarui Fang, committed by GitHub on Mar 16, 2022

[refactory] refactory the initialize method for new zero design (#431)

Parent: 4f85b687

Showing 5 changed files with 184 additions and 24 deletions (+184 -24)
colossalai/initialize.py                                            +34 -21
colossalai/zero/__init__.py                                         +51 -0
colossalai/zero/sharded_optim/sharded_optim_v2.py                   +3  -0
tests/test_zero_data_parallel/test_sharded_optim_with_sync_bn.py    +4  -3
tests/test_zero_data_parallel/test_zero_init_v2.py                  +92 -0
colossalai/initialize.py (+34 -21)

@@ -5,7 +5,7 @@ import argparse
 import os
 import pprint
 from pathlib import Path
-from typing import Dict, Iterable, List, Optional, Tuple, Union
+from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union, Type
 import torch
 import torch.nn as nn
@@ -26,8 +26,9 @@ from colossalai.logging import get_dist_logger
 from colossalai.nn.optimizer.colossalai_optimizer import ColossalaiOptimizer
 from colossalai.utils import (accumulate_gradient, get_current_device, is_using_ddp, is_using_pp, is_using_sequence,
                               sync_model_param)
-from colossalai.zero import convert_to_zero, ShardedOptimizer
+from colossalai.zero import convert_to_zero_v2
 from colossalai.engine.ophooks import BaseOpHook
+from colossalai.zero.sharded_optim.sharded_optim_v2 import ShardedOptimizerV2


 def get_default_parser():
@@ -216,8 +217,8 @@ def launch_from_torch(config: Union[str, Path, Config, Dict],
                       verbose=verbose)


-def initialize(model: nn.Module,
-               optimizer: Optimizer,
+def initialize(model: Union[Callable, nn.Module],
+               optimizer: Union[Type[Optimizer], Optimizer],
                criterion: Optional[_Loss] = None,
                train_dataloader: Optional[Iterable] = None,
                test_dataloader: Optional[Iterable] = None,
@@ -227,10 +228,10 @@ def initialize(model: nn.Module,
     """Core function to wrap the essential training components with our functionality based on the config which is
     loaded into gpc.config.

-    :param model: Your model instance
-    :type model: :class:`torch.nn.Module`
+    :param model: Your model instance or a function to build the model
+    :type model: :class:`torch.nn.Module` or Callable
     :param optimizer: Your optimizer instance
-    :type optimizer: :class:`torch.optim.optimizer.Optimizer`
+    :type optimizer: :class:`torch.optim.optimizer.Optimizer` or :class:`Type[torch.optim.optimizer]`
     :param criterion: Your criterion instance
     :type criterion: :class:`torch.nn.modules.loss._Loss`, optional
     :param train_dataloader: Dataloader for training
@@ -267,10 +268,28 @@ def initialize(model: nn.Module,
     if verbose:
         logger.info(f"cuDNN benchmark = {cudnn_benchmark}, deterministic = {cudnn_deterministic}", ranks=[0])

-    # first sync model across dp ranks
-    model.to(get_current_device())
+    use_zero = hasattr(gpc.config, 'zero')
+    if use_zero:
+        zero_cfg = gpc.config.get('zero', None)
+        if zero_cfg is not None:
+            cfg_ = zero_cfg.copy()
+        else:
+            cfg_ = {}
+        optimizer_config = zero_cfg.get('optimizer', None)
+        model, optimizer = convert_to_zero_v2(model_builder=model, optimizer_config=optimizer_config)
+
+        logger.info("Initializing ZeRO model and optimizer finished!", ranks=[0])
+        # FIXME() throw a warning if using zero with MP
+        if gpc.get_world_size(ParallelMode.MODEL) > 1:
+            logger.warning("ZeRO currently has not been tested with model parallelism.", ranks=[0])
+    else:
+        if isinstance(model, nn.Module):
+            # first sync model across dp ranks
+            model.to(get_current_device())
+        elif isinstance(model, Callable):
+            model = model().to(get_current_device())

-    use_zero3 = hasattr(gpc.config, 'zero') and gpc.config.zero.level == 3
-    if not moe_env.is_initialized() and not use_zero3:
+    if not moe_env.is_initialized() and not use_zero:
         if is_using_sequence():
             sync_model_param(model, ParallelMode.SEQUENCE_DP)
         elif is_using_ddp():
@@ -283,16 +302,15 @@ def initialize(model: nn.Module,
     # check amp and zero
     fp16_cfg = gpc.config.get('fp16', None)
-    zero_cfg = gpc.config.get('zero', None)

-    if fp16_cfg is not None and fp16_cfg.mode is not None and zero_cfg is not None:
+    if fp16_cfg is not None and fp16_cfg.mode is not None and use_zero:
         raise ConfigException(
             "It is not allowed to set fp16 and zero configuration in your config file at the same time")

     # clip grad norm
     clip_grad_norm = gpc.config.get('clip_grad_norm', 0.0)
     if clip_grad_norm > 0:
-        if zero_cfg is not None:
+        if use_zero and zero_cfg is not None:
             raise ConfigException(
                 "clip_grad_norm should be specified with zero, you should specify clip_grad in zero configuration")
@@ -311,11 +329,6 @@ def initialize(model: nn.Module,
             mode=amp_mode,
             amp_config=cfg_)

-    if zero_cfg is not None:
-        cfg_ = zero_cfg.copy()
-        level = cfg_.pop('level')
-        model, optimizer = convert_to_zero(model=model, optimizer=optimizer, level=level, zero_config=cfg_)
-
     # gradient handler
     gradient_handler_cfg = gpc.config.get('gradient_handler', None)
     if gradient_handler_cfg is None:
@@ -324,7 +337,7 @@ def initialize(model: nn.Module,
         # 1. if optimizer is ZERO, then use zero grad handler
         # 2. if dp size is larger than 1 and pipeline is not used, use pytorch ddp
         # 3. if using pipeline and dp size larger than 1, use data parallel grad handler
-        if isinstance(optimizer, ShardedOptimizer):
+        if isinstance(optimizer, ShardedOptimizerV2):
             gradient_handler_cfg = [dict(type='ZeROGradientHandler')]
             if verbose:
                 logger.info(
@@ -392,7 +405,7 @@ def initialize(model: nn.Module,
         gradient_handlers = [build_gradient_handler(cfg, model, optimizer) for cfg in gradient_handler_cfg]

     # check if optimizer is ColossalaiOptimizer
-    if not isinstance(optimizer, (ColossalaiOptimizer, ShardedOptimizer)):
+    if not isinstance(optimizer, (ColossalaiOptimizer, ShardedOptimizerV2)):
         optimizer = ColossalaiOptimizer(optim=optimizer)

     # gradient accumulation
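The refactored initialize() now accepts either a constructed nn.Module or a model-builder callable, and either an optimizer instance or an optimizer class. A minimal sketch of the new zero path, assuming a launched distributed context; build_model and the config values are hypothetical, but the config layout mirrors the tests in this commit:

import torch
import torch.nn as nn
import colossalai


def build_model() -> nn.Module:
    # hypothetical builder; under the zero path its parameters are
    # materialized inside ZeroInitContext and sharded at construction time
    return nn.Sequential(nn.Linear(512, 512), nn.ReLU(), nn.Linear(512, 10))


# hypothetical config passed to colossalai.launch; the optimizer class and
# its kwargs now live under the zero block instead of a zero `level`
config = dict(zero=dict(optimizer=dict(optimizer_type=torch.optim.Adam,
                                       optimizer_config=dict(lr=1e-3))))

# after colossalai.launch(config=config, ...), one could write:
# engine, train_dataloader, _, _ = colossalai.initialize(build_model,
#                                                        optimizer=torch.optim.Adam,
#                                                        criterion=nn.CrossEntropyLoss())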
colossalai/zero/__init__.py (+51 -0)

+from asyncio.log import logger
 from distutils.command.config import config
+from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2
+from colossalai.zero.sharded_optim.sharded_optim_v2 import ShardedOptimizerV2
+from colossalai.zero.shard_utils import TensorShardStrategy
 import torch
 import torch.nn as nn
 from colossalai.amp.naive_amp import NaiveAMPModel
@@ -7,6 +11,53 @@ from colossalai.core import global_context as gpc
 from torch.optim import Optimizer
 from .sharded_model import ShardedModel
 from .sharded_optim import ShardedOptimizer
+from colossalai.zero.init_ctx import ZeroInitContext
+from typing import Callable, Type
+from colossalai.core import global_context as gpc
+from colossalai.logging import get_dist_logger
+
+
+def convert_to_zero_v2(model_builder: Callable, optimizer_config) -> (ShardedModelV2, ShardedOptimizerV2):
+    """
+    A helper function to integrate the model and optimizer with ZeRO optimizer and off-loading
+
+    :param model_builder: Your model instance or a function to build the model
+    :type model_builder: :class:`torch.nn.Module` or Callable
+    :param optimizer_config: Your optimizer config
+    :type optimizer_config: :class:`dict`
+
+    :return: (model, optimizer)
+    :rtype: Tuple
+    """
+
+    logger = get_dist_logger('convert_to_zero_v2')
+
+    # FIXME() pass shard strategy from config
+    shard_strategy = TensorShardStrategy()
+
+    if isinstance(model_builder, nn.Module):
+        model = model_builder
+    elif isinstance(model_builder, Callable):
+        with ZeroInitContext(convert_fp16='fp16' in gpc.config,
+                             target_device=torch.cuda.current_device(),
+                             shard_strategy=shard_strategy,
+                             shard_param=True):
+            model = model_builder()
+    else:
+        raise TypeError(f"convert_to_zero_v2 does not support model_builder of type {type(model_builder)}")
+
+    zero_model = ShardedModelV2(model, shard_strategy=shard_strategy)
+
+    optimizer_class = optimizer_config.get('optimizer_type', None)
+    if optimizer_class is None:
+        raise RuntimeError("Set optimizer_type in zero config")
+    logger.info(f'optimizer class is {optimizer_class}')
+
+    cfg = optimizer_config.get('optimizer_config', None)
+    logger.info(f'optimizer_config is {cfg}')
+
+    zero_optimizer = ShardedOptimizerV2(zero_model, optimizer_class, **optimizer_config.get('optimizer_config', None))
+    return zero_model, zero_optimizer
+
+
 def convert_to_zero(model: nn.Module, optimizer: Optimizer, level: int, zero_config: dict):
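For callers that bypass initialize(), convert_to_zero_v2 can presumably be used directly once a distributed context is launched. A hedged sketch of the expected optimizer_config shape; the keys are taken from the function body above, while build_model is a hypothetical builder:

import torch
from colossalai.zero import convert_to_zero_v2

optimizer_config = dict(optimizer_type=torch.optim.Adam,    # an optimizer class, not an instance
                        optimizer_config=dict(lr=1e-3))     # kwargs forwarded to that class

# with gpc initialized via colossalai.launch(...):
# zero_model, zero_optimizer = convert_to_zero_v2(model_builder=build_model,
#                                                 optimizer_config=optimizer_config)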
colossalai/zero/sharded_optim/sharded_optim_v2.py (+3 -0)

@@ -223,3 +223,6 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
         # Because we will judge whether local grad accumulation
         # is enabled by whether grad is None
         self.optim.zero_grad(set_to_none=True)
+
+    def sync_grad(self):
+        pass
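The new sync_grad stub presumably keeps ShardedOptimizerV2 compatible with callers that expect an explicit gradient-sync hook; in the sharded design, gradients are already reduced during backward, so there is nothing left to synchronize. A hypothetical sketch of such a calling pattern (this handler class is illustrative only, not the actual ZeROGradientHandler):

class SyncGradHandlerSketch:
    """Illustrative handler that defers gradient synchronization to the optimizer."""

    def __init__(self, optimizer):
        self.optimizer = optimizer

    def handle_gradient(self):
        # a no-op for ShardedOptimizerV2: grads were already reduced in backward
        self.optimizer.sync_grad()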
tests/test_zero_data_parallel/test_sharded_optim_with_sync_bn.py (+4 -3)

@@ -19,9 +19,10 @@ def run_dist(rank, world_size, port):
     # as this model has sync batch normalization
     # need to configure cudnn deterministic so that
     # randomness of convolution layers will be disabled
-    colossalai.launch(config=dict(zero=dict(level=2, partition_grad=True),
-                                  cudnn_deterministic=True,
-                                  cudnn_benchmark=False),
+    colossalai.launch(config=dict(cudnn_deterministic=True,
+                                  cudnn_benchmark=False,
+                                  zero=dict(optimizer=dict(optimizer_type=torch.optim.Adam,
+                                                           optimizer_config=dict(lr=1e-3)))),
                       rank=rank,
                       world_size=world_size,
                       host='localhost',
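This test shows the config migration in miniature: the level-based zero block is replaced by an optimizer-centric one. A hedged before/after sketch of the two layouts:

import torch

# old layout (removed in this commit): zero behaviour selected by level
old_cfg = dict(zero=dict(level=2, partition_grad=True))

# new layout (added): the optimizer class and its kwargs live under zero,
# and initialize()/convert_to_zero_v2 construct the sharded optimizer
new_cfg = dict(zero=dict(optimizer=dict(optimizer_type=torch.optim.Adam,
                                        optimizer_config=dict(lr=1e-3))))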
tests/test_zero_data_parallel/test_zero_init_v2.py (new file, 0 → 100644, +92 -0)

#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import copy
from functools import partial

import pytest

import torch
import torch.multiprocessing as mp

import colossalai
from colossalai.utils import free_port
from tests.components_to_test.registry import non_distributed_component_funcs
from common import check_sharded_params_padding


def run_dist(rank, world_size, port):
    _config = dict(fp16=dict(mode=None,),
                   zero=dict(optimizer=dict(optimizer_type=torch.optim.Adam,
                                            optimizer_config=dict(lr=1e-3)),
                             offload_optimizer_config=dict(device='cpu',
                                                           pin_memory=True,
                                                           buffer_count=5,
                                                           fast_init=False),
                             offload_param_config=dict(device='cpu',
                                                       pin_memory=True,
                                                       buffer_count=5,
                                                       buffer_size=1e8,
                                                       max_in_cpu=1e9)),
                   parallel=dict(pipeline=dict(size=1), tensor=dict(size=1, mode=None)))

    colossalai.launch(config=_config,
                      rank=rank,
                      world_size=world_size,
                      host='localhost',
                      port=port,
                      backend='nccl')

    # FIXME revert back
    # test_models = ['repeated_computed_layers', 'resnet18', 'bert']
    test_models = ['bert']
    for model_name in test_models:
        get_components_func = non_distributed_component_funcs.get_callable(model_name)
        model_builder, train_dataloader, _, optimizer_class, criterion = get_components_func()

        # adapt to a Callable with empty parameters
        # def module_builder_new():
        #     return model_builder(checkpoint=True)

        zero_model = model_builder(checkpoint=True)
        torch_model = copy.deepcopy(zero_model).cuda()

        engine, train_dataloader, _, _ = colossalai.initialize(zero_model,
                                                               optimizer=optimizer_class,
                                                               criterion=criterion,
                                                               train_dataloader=train_dataloader)
        engine.train()
        torch_optimizer = optimizer_class(torch_model.parameters())

        i = 0
        for data, label in train_dataloader:
            if i > 3:
                break

            data, label = data.cuda(), label.cuda()

            engine.zero_grad()
            torch_optimizer.zero_grad()

            if criterion:
                output = engine(data)
                loss = engine.criterion(output, label)

                torch_output = torch_model(data, label)
                torch_loss = criterion(torch_output, label)
            else:
                loss = engine(data, label)
                torch_loss = torch_model(data, label)

            engine.backward(loss)
            engine.step()

            torch_loss.backward()
            torch_optimizer.step()
            i += 1

        check_sharded_params_padding(torch_model, zero_model, loose=True)


@pytest.mark.dist
@pytest.mark.parametrize("world_size", [1, 2])
def test_zero_init(world_size):
    run_func = partial(run_dist, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_zero_init(world_size=2)
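The new test checks the ZeRO engine against an unsharded PyTorch baseline step by step and then compares parameters with check_sharded_params_padding. A distilled sketch of that parity loop; the single-argument torch_model(data) call here assumes a model whose forward takes only inputs, unlike the test's bert component:

def parity_step(engine, torch_model, torch_optimizer, criterion, data, label):
    # ZeRO path: forward/backward/step through the Colossal-AI engine
    engine.zero_grad()
    zero_loss = engine.criterion(engine(data), label)
    engine.backward(zero_loss)
    engine.step()

    # reference path: the same batch through the unsharded model copy
    torch_optimizer.zero_grad()
    torch_loss = criterion(torch_model(data), label)
    torch_loss.backward()
    torch_optimizer.step()
    return zero_loss, torch_loss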