OpenDAS / ColossalAI / Commits

Commit e17e54e3, authored Mar 02, 2022 by Frank Lee
added buffer sync to naive amp model wrapper (#291)
parent 8d653af4

Showing 4 changed files with 191 additions and 46 deletions:
- colossalai/amp/naive_amp/naive_amp.py (+67, -5)
- colossalai/initialize.py (+32, -35)
- colossalai/zero/__init__.py (+8, -6)
- tests/test_zero_data_parallel/test_sharded_optim_with_sync_bn.py (+84, -0)
colossalai/amp/naive_amp/naive_amp.py
```diff
@@ -3,12 +3,15 @@
 import torch
 import torch.nn as nn
+import torch.distributed as dist
 from torch import Tensor
-from typing import Union, List, Any, Dict
+from typing import Any
 from torch.optim import Optimizer
 import torch.cuda.amp as torch_amp
+from torch.distributed import ReduceOp
+from colossalai.core import global_context as gpc
+from colossalai.context import ParallelMode
 from colossalai.nn.optimizer import ColossalaiOptimizer
+from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
 from ._fp16_optimizer import FP16Optimizer
```
```diff
@@ -43,16 +46,36 @@ class NaiveAMPOptimizer(ColossalaiOptimizer):
 class NaiveAMPModel(nn.Module):
     """A wrapper class for model to cast the model into fp16 and
     automatically cast the input and output
     """

     def __init__(self,
                  model: nn.Module,
-                 output_to_fp32: bool = True):
+                 output_to_fp32: bool = True,
+                 parallel_mode: ParallelMode = ParallelMode.DATA,
+                 sync_buffer: bool = True):
         super().__init__()
         self.model = model.half()
         self._output_to_fp32 = output_to_fp32
+        self._sync_buf = sync_buffer
+
+        if gpc.is_initialized(parallel_mode) and gpc.get_world_size(parallel_mode) > 1:
+            self._process_group = gpc.get_group(parallel_mode)
+            self._world_size = gpc.get_world_size(parallel_mode)
+        else:
+            self._process_group = None
+            self._world_size = 1
+            self._sync_buf = False
+        self._first_eval_run = False
+
+    @property
+    def sync_buffer(self):
+        return self._sync_buf
+
+    @sync_buffer.setter
+    def sync_buffer(self, state: bool):
+        self._sync_buf = state

     def _convert_to_fp16(self, input_: Any):
         if isinstance(input_, Tensor) and input_.dtype == torch.float32:
```
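The extended constructor and the new `sync_buffer` property can be used roughly as below. This is a minimal sketch, not an official ColossalAI example: it assumes a launched config where `ParallelMode.DATA` is initialized (without an initialized group the wrapper simply falls back to `sync_buffer=False`), and the toy model is a placeholder.

```python
import torch.nn as nn
from colossalai.context import ParallelMode
from colossalai.amp.naive_amp import NaiveAMPModel

# a toy model with a buffer-carrying layer (BatchNorm running stats)
model = nn.Sequential(nn.Linear(16, 16), nn.BatchNorm1d(16))

# wrap it: parameters and buffers are cast to fp16, outputs are cast back to
# fp32, and buffers are all-reduced across the data-parallel group on forward
wrapped = NaiveAMPModel(model,
                        output_to_fp32=True,
                        parallel_mode=ParallelMode.DATA,
                        sync_buffer=True)

# the new property lets callers (e.g. colossalai.initialize) toggle syncing later
wrapped.sync_buffer = False
```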
```diff
@@ -64,7 +87,46 @@ class NaiveAMPModel(nn.Module):
             input_ = input_.float()
         return input_

+    def _reduce_module_buffer(self):
+        """
+        All-reduce the buffers (e.g. running stats of batch normalization) across
+        data parallel ranks so that all the ranks will produce consistent results
+        when given the same input
+        """
+        buf_list = []
+
+        # find valid buffers
+        for buf in self.model.buffers():
+            if buf is not None:
+                buf_list.append(buf)
+
+        # reduce buffers across data parallel ranks
+        if buf_list:
+            coalesced_buf = _flatten_dense_tensors(buf_list)
+            coalesced_buf.div_(self._world_size)
+            dist.all_reduce(coalesced_buf, op=ReduceOp.SUM, group=self._process_group)
+            unflattened_buf_list = _unflatten_dense_tensors(coalesced_buf, buf_list)
+            for old, new in zip(buf_list, unflattened_buf_list):
+                old.copy_(new)
+
+    def eval(self):
+        self.model.eval()
+
+        # we only sync buffer in the first eval iteration
+        # so that future eval iterations can be done without communication
+        self._first_eval_run = True
+
     def forward(self, *args, **kwargs):
+        # reduce buffers after forward will lead to error
+        # as we cannot change the variables needed for gradient computation after forward
+        # so we sync buffer before forward
+        if (self.training or self._first_eval_run) and self._sync_buf:
+            with torch.no_grad():
+                self._reduce_module_buffer()
+
+            if self._first_eval_run:
+                self._first_eval_run = False
+
         if args:
             args = [self._convert_to_fp16(arg) for arg in args]
         if kwargs:
```
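For reference, the same flatten, all-reduce, unflatten pattern can be written as a stand-alone helper in plain PyTorch. This is a sketch only: the function name is not part of ColossalAI, and it assumes `torch.distributed` has already been initialized.

```python
import torch.nn as nn
import torch.distributed as dist
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors


def all_reduce_buffers(module: nn.Module, group=None) -> None:
    """Average all module buffers across the ranks of ``group`` (sketch)."""
    bufs = [b for b in module.buffers() if b is not None]
    if not bufs:
        return
    # pack every buffer into one flat tensor so a single all-reduce suffices
    coalesced = _flatten_dense_tensors(bufs)
    coalesced.div_(dist.get_world_size(group))
    dist.all_reduce(coalesced, op=dist.ReduceOp.SUM, group=group)
    # scatter the averaged values back into the original buffers
    for old, new in zip(bufs, _unflatten_dense_tensors(coalesced, bufs)):
        old.copy_(new)
```

As in the wrapper above, such a helper would be called under `torch.no_grad()` before the forward pass, so the in-place `copy_` does not touch tensors needed for gradient computation.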
colossalai/initialize.py
```diff
@@ -16,6 +16,7 @@ from torch.optim.optimizer import Optimizer
 from torch.utils.data import DataLoader

 from colossalai.amp import AMP_TYPE, convert_to_amp
+from colossalai.amp.naive_amp import NaiveAMPModel
 from colossalai.builder.builder import build_gradient_handler
 from colossalai.context import Config, ConfigException, ParallelMode
 from colossalai.core import global_context as gpc
```
```diff
@@ -23,8 +24,7 @@ from colossalai.engine import Engine
 from colossalai.global_variables import moe_env
 from colossalai.logging import get_dist_logger
 from colossalai.nn.optimizer.colossalai_optimizer import ColossalaiOptimizer
 from colossalai.utils import (accumulate_gradient, get_current_device, is_using_ddp, is_using_pp, is_using_sequence,
                               sync_model_param)
 from colossalai.zero import convert_to_zero, ShardedOptimizer
 from colossalai.engine.ophooks import register_ophooks_recursively, BaseOpHook
```
```diff
@@ -39,21 +39,12 @@ def get_default_parser():
     """
     parser = argparse.ArgumentParser()
     parser.add_argument('--config', type=str, help='path to the config file')
     parser.add_argument('--host', type=str, help='the master address for distributed training')
     parser.add_argument('--port', type=int, help='the master port for distributed training')
     parser.add_argument('--world_size', type=int, help='world size for distributed training')
     parser.add_argument('--rank', type=int, help='rank for the default process group')
     parser.add_argument('--local_rank', type=int, help='local rank on the node')
     parser.add_argument('--backend', type=str, default='nccl', help='backend for distributed communication')
     return parser
```
```diff
@@ -116,9 +107,11 @@ def launch(config: Union[str, Path, Config, Dict],
     if verbose:
         logger = get_dist_logger()
         logger.info(f'Distributed environment is initialized, '
                     f'data parallel size: {gpc.data_parallel_size}, pipeline parallel size: {gpc.pipeline_parallel_size}, '
                     f'tensor parallel size: {gpc.tensor_parallel_size}', ranks=[0])


 def launch_from_slurm(config: Union[str, Path, Config, Dict],
```
```diff
@@ -261,9 +254,11 @@ def initialize(model: Union[nn.Module, List[nn.Module]],
     # print config
     if verbose:
         logger.info(f"\n========== Your Config ========\n"
                     f"{pprint.pformat(gpc.config)}\n"
                     f"================================\n", ranks=[0])

     # cudnn
     cudnn_benchmark = config.get('cudnn_benchmark', True)
```
@@ -271,8 +266,7 @@ def initialize(model: Union[nn.Module, List[nn.Module]],
torch
.
backends
.
cudnn
.
benchmark
=
cudnn_benchmark
torch
.
backends
.
cudnn
.
deterministic
=
cudnn_deterministic
if
verbose
:
logger
.
info
(
f
"cuDNN benchmark =
{
cudnn_benchmark
}
, deterministic =
{
cudnn_deterministic
}
"
,
ranks
=
[
0
])
logger
.
info
(
f
"cuDNN benchmark =
{
cudnn_benchmark
}
, deterministic =
{
cudnn_deterministic
}
"
,
ranks
=
[
0
])
# first sync model across dp ranks
model
.
to
(
get_current_device
())
...
...
```diff
@@ -321,11 +315,7 @@ def initialize(model: Union[nn.Module, List[nn.Module]],
     if zero_cfg is not None:
         cfg_ = zero_cfg.copy()
         level = cfg_.pop('level')
         model, optimizer = convert_to_zero(model=model, optimizer=optimizer, level=level, zero_config=cfg_)

     # gradient handler
     gradient_handler_cfg = gpc.config.get('gradient_handler', None)
```
```diff
@@ -350,21 +340,22 @@ def initialize(model: Union[nn.Module, List[nn.Module]],
                         "added even though not specified in the configuration", ranks=[0])
     elif is_using_sequence():
         model = DDP(model,
                     process_group=gpc.get_group(ParallelMode.SEQUENCE_DP),
                     device_ids=[torch.cuda.current_device()])
         if verbose:
             logger.info('Model is using torch.nn.parallel.DistributedDataParallel for Sequence Parallelism', ranks=[0])
     elif is_using_ddp() and not is_using_pp() and amp_mode != AMP_TYPE.NAIVE:
         model = DDP(model, process_group=gpc.get_group(ParallelMode.DATA), device_ids=[torch.cuda.current_device()])
         if verbose:
             logger.info('Model is using torch.nn.parallel.DistributedDataParallel for Data Parallelism', ranks=[0])
     elif is_using_ddp():
         gradient_handler_cfg = [dict(type='DataParallelGradientHandler')]
         if verbose:
-            logger.info("Data parallel training is detected when using pipeline parallel, DataParallelGradientHandler is automatically "
+            logger.info("Data parallel training is detected when using pipeline parallel, "
+                        "DataParallelGradientHandler is automatically "
                         "added even though not specified in the configuration", ranks=[0])

     # add pipeline parallel gradient handler, if pipeline shared module is detected
```
```diff
@@ -383,7 +374,13 @@ def initialize(model: Union[nn.Module, List[nn.Module]],
     else:
         if not isinstance(gradient_handler_cfg, list):
-            raise ConfigException(f"expected gradient_handler in the configuration file to be a list but got {type(gradient_handler_cfg)}")
+            raise ConfigException(
+                f"expected gradient_handler in the configuration file to be a list but got {type(gradient_handler_cfg)}")
+
+    # turn off sync buffer for NaiveAMPModel if using torch DDP and NaiveAMPModel at the same time
+    # to avoid duplicated buffer synchronization
+    if isinstance(model, DDP) and isinstance(model.module, NaiveAMPModel):
+        model.module.sync_buffer = False

     if gradient_handler_cfg is None:
         gradient_handlers = None
```
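The guard above avoids doing the same work twice: torch's `DistributedDataParallel` already keeps module buffers consistent on its own (with its default `broadcast_buffers=True` it broadcasts buffers from rank 0 at each forward pass), so when DDP wraps a `NaiveAMPModel` the wrapper's all-reduce would be redundant. A minimal sketch of the same idea, generalized with `hasattr` rather than the exact `isinstance` check used in the commit:

```python
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP


def disable_redundant_buffer_sync(model: nn.Module) -> None:
    # DDP already synchronizes buffers (broadcast_buffers=True by default),
    # so switch off the inner wrapper's sync to avoid duplicated communication
    if isinstance(model, DDP) and hasattr(model.module, 'sync_buffer'):
        model.module.sync_buffer = False
```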
colossalai/zero/__init__.py
```diff
@@ -9,10 +9,7 @@ from .sharded_model import ShardedModel
 from .sharded_optim import ShardedOptimizer


-def convert_to_zero(model: nn.Module,
-                    optimizer: Optimizer,
-                    level: int,
-                    zero_config: dict):
+def convert_to_zero(model: nn.Module, optimizer: Optimizer, level: int, zero_config: dict):
     """
     A helper function to integrate the model and optimizer with ZeRO optimizer and off-loading
```
@@ -31,11 +28,16 @@ def convert_to_zero(model: nn.Module,
assert
1
<=
level
<=
3
,
'Only ZERO Optimizer Level 1-3 are provided'
if
level
in
[
1
,
2
]:
if
level
==
2
:
assert
config
[
'partition_grad'
],
'ZeRO Optimizer requires partition_grad to be True'
if
'partition_grad'
in
zero_config
:
assert
zero_config
[
'partition_grad'
],
\
'Sharded Optimizer requires partition_grad to be True'
else
:
zero_config
[
'partiton_grad'
]
=
True
model
=
NaiveAMPModel
(
model
,
output_to_fp32
=
True
)
optimizer
=
ShardedOptimizer
(
model
.
parameters
()
,
*
zero_config
)
optimizer
=
ShardedOptimizer
(
optimizer
,
*
*
zero_config
)
else
:
model
=
ShardedModel
(
module
=
model
,
**
zero_config
)
return
model
,
optimizer
__all__
=
[
'convert_to_zero'
,
'ShardedModel'
,
'ShardedOptimizer'
]
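Put together with the `initialize.py` change above, a level-2 ZeRO config reaches `convert_to_zero` roughly as sketched below. This is a hedged illustration: the toy model and optimizer are placeholders, and only `partition_grad` is a key the diff itself references; any other `zero_config` keys would be forwarded to `ShardedOptimizer` as keyword arguments.

```python
import torch.nn as nn
from torch.optim import Adam
from colossalai.zero import convert_to_zero

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8))
optimizer = Adam(model.parameters(), lr=1e-3)

# 'level' is popped from the zero config by colossalai.initialize; the rest of
# the config is passed through to ShardedOptimizer
model, optimizer = convert_to_zero(model=model,
                                   optimizer=optimizer,
                                   level=2,
                                   zero_config=dict(partition_grad=True))
```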
tests/test_zero_data_parallel/test_sharded_optim_with_sync_bn.py (new file, 0 → 100644)
```python
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

from functools import partial

import colossalai
import pytest
import torch
import torch.multiprocessing as mp
from colossalai.utils import free_port
from colossalai.core import global_context as gpc
from colossalai.context.parallel_mode import ParallelMode
from torchvision.models import resnet50
import torch.distributed as dist


def run_dist(rank, world_size, port):
    # need to configure cudnn deterministic so that
    # randomness of convolution layers will be disabled
    colossalai.launch(config=dict(zero=dict(level=2, partition_grad=True),
                                  cudnn_determinstic=True,
                                  cudnn_benchmark=False),
                      rank=rank,
                      world_size=world_size,
                      host='localhost',
                      port=port,
                      backend='nccl')

    model = resnet50()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    criterion = torch.nn.CrossEntropyLoss()

    engine, *args = colossalai.initialize(model, optimizer, criterion)

    # train for dummy iterations
    engine.train()
    for _ in range(2):
        data = torch.rand(4, 3, 128, 128).cuda().half()
        label = torch.randint(0, 10, size=(4,)).cuda()
        engine.zero_grad()
        out = engine(data)
        loss = engine.criterion(out, label)
        engine.backward(loss)
        engine.step()

    # test
    # need to make sure the batch norm stats are synchronized
    # so that given the same input, the model will produce the same
    # output on different ranks
    engine.eval()
    data = torch.rand(4, 3, 128, 128).cuda().half()
    dist.broadcast(data, src=0, group=gpc.get_group(ParallelMode.DATA))

    # predict
    out = engine(data)

    # test if results are equal
    tensor_list = [torch.empty_like(out) for _ in range(world_size - 1)]
    tensor_list.insert(rank, out)
    dist.all_gather(tensor_list=tensor_list, tensor=out, group=gpc.get_group(ParallelMode.DATA))

    assert torch.all(tensor_list[0] == tensor_list[1]), \
        'expected the output from different ranks to be the same, but got different values'


@pytest.mark.dist
def test_sharded_optim_with_sync_bn():
    """
    This test is to make sure that buffers are synchronized between ranks
    when using ZeRO. An example of module buffer is the running stats of
    BatchNormalization layer, i.e. mean and var.

    If the buffers are not synchronized, the model will produce different
    output even though the input and parameters are the same. This is not
    wanted if we are doing predictions.
    """
    world_size = 2
    run_func = partial(run_dist, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_sharded_optim_with_sync_bn()
```
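A single-process illustration of the failure mode this test guards against (plain PyTorch, not part of the test suite): two copies of the same BatchNorm layer that train on different batches end up with different running statistics, so an identical eval input no longer produces identical outputs unless the buffers are synchronized.

```python
import copy
import torch
import torch.nn as nn

bn_rank0 = nn.BatchNorm1d(8)
bn_rank1 = copy.deepcopy(bn_rank0)

# simulate two data-parallel ranks seeing different batches
bn_rank0.train()
bn_rank1.train()
bn_rank0(torch.randn(16, 8))
bn_rank1(torch.randn(16, 8) * 3)

# same eval input, but running_mean/running_var have diverged
bn_rank0.eval()
bn_rank1.eval()
x = torch.randn(4, 8)
print(torch.allclose(bn_rank0(x), bn_rank1(x)))  # almost certainly False
```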