OpenDAS / ColossalAI · Commit 79cf1b5f

[zero] support no_sync method for zero1 plugin (#4138)

* support no sync for zero1 plugin
* polish
* polish

Authored Jul 04, 2023 by LuGY; committed by Hongxin Liu on Jul 31, 2023.
Parent: c6ab9698

Showing 8 changed files with 45 additions and 49 deletions (+45 −49):
- colossalai/booster/booster.py (+7 −5)
- colossalai/booster/plugin/gemini_plugin.py (+1 −1)
- colossalai/booster/plugin/low_level_zero_plugin.py (+7 −3)
- colossalai/booster/plugin/plugin_base.py (+1 −1)
- colossalai/booster/plugin/torch_ddp_plugin.py (+1 −1)
- colossalai/booster/plugin/torch_fsdp_plugin.py (+1 −1)
- colossalai/zero/low_level/low_level_optim.py (+12 −17)
- tests/test_zero/test_low_level/test_grad_acc.py (+15 −20)
colossalai/booster/booster.py

```diff
@@ -9,7 +9,7 @@ from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
 from torch.utils.data import DataLoader
 
 from colossalai.checkpoint_io import GeneralCheckpointIO
-from colossalai.interface import ModelWrapper
+from colossalai.interface import ModelWrapper, OptimizerWrapper
 
 from .accelerator import Accelerator
 from .mixed_precision import MixedPrecision, mixed_precision_factory
@@ -153,18 +153,20 @@ class Booster:
         # return loss or outputs if needed
         pass
 
-    def no_sync(self, model: nn.Module) -> contextmanager:
+    def no_sync(self, model: nn.Module = None, optimizer: OptimizerWrapper = None) -> contextmanager:
         """Context manager to disable gradient synchronization across DP process groups.
+           Support torch DDP and Low Level ZeRO-1 for now.
 
         Args:
-            model (nn.Module): The model to be disabled gradient synchronization.
+            model (nn.Module): The model to be disabled gradient synchronization, for DDP
+            optimizer (OptimizerWrapper): The optimizer to be disabled gradient synchronization, for ZeRO1-1
 
         Returns:
             contextmanager: Context to disable gradient synchronization.
         """
         assert self.plugin is not None, f'no_sync is only enabled when a plugin is provided and the plugin supports no_sync.'
-        assert self.plugin.support_no_sync, f'The plugin {self.plugin.__class__.__name__} does not support no_sync.'
-        return self.plugin.no_sync(model)
+        assert self.plugin.support_no_sync(), f'The plugin {self.plugin.__class__.__name__} does not support no_sync.'
+        return self.plugin.no_sync(model, optimizer)
 
     def load_model(self, model: Union[nn.Module, ModelWrapper], checkpoint: str, strict: bool = True):
         """Load model from checkpoint.
```
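From the caller's side, the new signature makes gradient accumulation under ZeRO-1 look the same as under DDP: wrap the accumulation steps in `booster.no_sync(...)` and pass the optimizer, since ZeRO-1 dispatches on it rather than on the model. A minimal sketch of a training loop against the updated API, assuming `model`, `optimizer`, `criterion`, and `dataloader` are already constructed (the loop itself is illustrative, not part of the commit):

```python
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin

plugin = LowLevelZeroPlugin(stage=1)    # support_no_sync() is True only for stage 1
booster = Booster(plugin=plugin)
model, optimizer, criterion, dataloader, _ = booster.boost(model, optimizer, criterion, dataloader)

accumulate_interval = 4
for step, (inputs, targets) in enumerate(dataloader):
    if (step + 1) % accumulate_interval != 0:
        # accumulation step: gradients stay local, no all-reduce
        with booster.no_sync(model, optimizer):
            loss = criterion(model(inputs), targets)
            booster.backward(loss, optimizer)
    else:
        # boundary step: backward with synchronization, then update
        loss = criterion(model(inputs), targets)
        booster.backward(loss, optimizer)
        optimizer.step()
        optimizer.zero_grad()
```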
colossalai/booster/plugin/gemini_plugin.py

```diff
@@ -408,5 +408,5 @@ class GeminiPlugin(DPPluginBase):
     def get_checkpoint_io(self) -> CheckpointIO:
         return GeminiCheckpointIO()
 
-    def no_sync(self, model: nn.Module) -> Iterator[None]:
+    def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]:
         raise NotImplementedError
```
colossalai/booster/plugin/low_level_zero_plugin.py

```diff
@@ -179,8 +179,11 @@ class LowLevelZeroPlugin(DPPluginBase):
                                                      norm_type=norm_type)
         self.verbose = verbose
 
+        # set class name with stage, for better error message
+        setattr(self.__class__, "__name__", f"LowLevelZeroPlugin_ZeRO-{stage}")
+
     def support_no_sync(self) -> bool:
-        return False
+        return self.stage == 1
 
     def control_precision(self) -> bool:
         return True
@@ -219,5 +222,6 @@ class LowLevelZeroPlugin(DPPluginBase):
     def get_checkpoint_io(self) -> CheckpointIO:
         return LowLevelZeroCheckpointIO()
 
-    def no_sync(self, model: nn.Module) -> Iterator[None]:
-        raise NotImplementedError
+    def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]:
+        assert isinstance(optimizer, LowLevelZeroOptimizer)
+        return optimizer.optim.no_sync()
```
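The `setattr` on `__name__` exists only to sharpen the assertion message in `Booster.no_sync` shown above: the interpolated class name now carries the ZeRO stage. A hypothetical failing session (illustrative, not from the commit):

```python
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin

booster = Booster(plugin=LowLevelZeroPlugin(stage=2))   # ZeRO-2 partitions gradients
with booster.no_sync(optimizer=optimizer):              # fails support_no_sync()
    ...
# AssertionError: The plugin LowLevelZeroPlugin_ZeRO-2 does not support no_sync.
```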
colossalai/booster/plugin/plugin_base.py

```diff
@@ -61,7 +61,7 @@ class Plugin(ABC):
         pass
 
     @abstractmethod
-    def no_sync(self, model: nn.Module) -> Iterator[None]:
+    def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]:
         """
         Context manager to disable gradient synchronization.
         """
```
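Because the abstract signature changed, any third-party plugin must now accept both arguments even if it uses only one of them. A skeletal conforming plugin (hypothetical, for illustration; the import path follows the file shown above, and only the `no_sync`-related methods are sketched):

```python
from typing import Iterator

import torch.nn as nn

from colossalai.booster.plugin.plugin_base import Plugin
from colossalai.interface import OptimizerWrapper


class MyPlugin(Plugin):
    # ... configure(), control_precision(), and the other abstract methods elided ...

    def support_no_sync(self) -> bool:
        return True

    def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]:
        # delegate to whichever wrapper this plugin created when boosting
        return model.module.no_sync()
```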
colossalai/booster/plugin/torch_ddp_plugin.py

```diff
@@ -168,6 +168,6 @@ class TorchDDPPlugin(DPPluginBase):
     def get_checkpoint_io(self) -> CheckpointIO:
         return TorchDDPCheckpointIO()
 
-    def no_sync(self, model: nn.Module) -> Iterator[None]:
+    def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]:
         assert isinstance(model, TorchDDPModel), 'Model is not boosted by TorchDDPPlugin.'
         return model.module.no_sync()
```
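`TorchDDPPlugin` simply unwraps the boosted model and defers to PyTorch DDP's built-in context manager; the `optimizer` argument is accepted only to satisfy the shared interface. The underlying behavior in plain PyTorch (standard `DistributedDataParallel.no_sync()` semantics, independent of this commit):

```python
import torch
from torch.nn.parallel import DistributedDataParallel as DDP

ddp_model = DDP(model)                   # assumes the process group is initialized
with ddp_model.no_sync():
    ddp_model(inputs).sum().backward()   # gradients accumulate locally, no all-reduce
ddp_model(inputs).sum().backward()       # first backward outside the context syncs
```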
colossalai/booster/plugin/torch_fsdp_plugin.py

```diff
@@ -177,7 +177,7 @@ class TorchFSDPPlugin(DPPluginBase):
     def support_no_sync(self) -> bool:
         False
 
-    def no_sync(self, model: nn.Module) -> Iterator[None]:
+    def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]:
         raise NotImplementedError("Torch fsdp no_sync func not supported yet.")
 
     def control_precision(self) -> bool:
```
colossalai/zero/low_level/low_level_optim.py

```diff
@@ -14,10 +14,10 @@ from colossalai.amp.naive_amp.mixed_precision_mixin import (
 )
 
 from colossalai.context import ParallelMode
 from colossalai.core import global_context as gpc
+from colossalai.interface import OptimizerWrapper
 from colossalai.logging import get_dist_logger
-from colossalai.nn.optimizer import ColossalaiOptimizer
 from colossalai.tensor import ColoParameter, ProcessGroup
 from colossalai.utils import conditional_context
 from colossalai.utils.cuda import get_current_device
 from ._utils import (
@@ -56,7 +56,7 @@ class LowLevelZeroFP16MixedPrecisionMixin(FP16MixedPrecisionMixin):
         return False
 
 
-class LowLevelZeroOptimizer(ColossalaiOptimizer):
+class LowLevelZeroOptimizer(OptimizerWrapper):
     """Optimizer used for ZeRO-1 and ZeRO-2.
     """
@@ -77,11 +77,12 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
                  overlap_communication: bool = False,
                  partition_grad: bool = False,    # stage 2 flag
                  cpu_offload: bool = False,    # cpu offload
-                 grad_accumulate_interval: int = 1,
                  forced_dtype: Optional[torch.dtype] = None):
 
-        assert not (partition_grad and grad_accumulate_interval > 1), \
-            "gradient accumulation is not compatible with ZeRO-2"
-
         # TODO:
         # 1. process group api
         # 2. checkpoint IO
         super(LowLevelZeroOptimizer, self).__init__(optim=optimizer)
         self._dtype = self.optim.param_groups[0]['params'][0].dtype
         self._logger = get_dist_logger()
@@ -94,8 +95,6 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
         # grad accumulation
+        self.require_grad_sync = True
-        self._accumulate_intervel = grad_accumulate_interval
-        self._accumulate_step = 0
 
         colo_pg = self._search_colo_process_group()
         if isinstance(colo_pg, ProcessGroup):
@@ -340,15 +339,15 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
     ################################
 
     def backward(self, loss, retain_graph=False):
+        assert not (self._partition_grads and not self.require_grad_sync), \
+            "ZeRO2(partition_grads) and gradient accumulation(no_sync) are not compatible"
+
         if self.mixed_precision_mixin is not None:
             loss = self.mixed_precision_mixin.pre_backward(loss)
 
-        self._accumulate_step += 1
-        no_sync = self._accumulate_step < self._accumulate_intervel
-        with conditional_context(self.no_sync(), enable=no_sync):
-            loss.backward(retain_graph=retain_graph)
+        loss.backward(retain_graph=retain_graph)
 
-        if no_sync:
+        if not self.require_grad_sync:
             return
 
         self._reduce_grad(self._partition_grads)
@@ -385,7 +384,7 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
     def step(self, closure=None):
         assert closure is None, 'closure is not supported by step()'
 
-        if not self._accumulate_step == self._accumulate_intervel:
+        if not self.require_grad_sync:
             return
 
         if self.mixed_precision_mixin is not None and self.mixed_precision_mixin.should_skip_step():
@@ -393,7 +392,6 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
             if self._verbose:
                 self._logger.info(f'Found overflow. Skip step')
             self.zero_grad()
-            self._accumulate_step -= 1
             return
 
         # record all grads for unscale and clip
@@ -463,9 +461,6 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
             self.optim.param_groups[group_id]['params'] = self._master_param_groups_of_current_rank[group_id]
 
-        # reset accumulate step
-        self._accumulate_step = 0
-
     #############################
     # Mixed Precision Utilities #
     #############################
```
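The hunks above read `require_grad_sync` in `backward()` and `step()`, but the `no_sync()` context manager that `LowLevelZeroPlugin.no_sync` returns lies outside the displayed context. Given how the flag is consumed, its body plausibly amounts to the following (a sketch; the attribute name follows the diff, the implementation is an assumption):

```python
from contextlib import contextmanager

@contextmanager
def no_sync(self):
    # while the flag is down, backward() skips _reduce_grad()
    # and step() returns early without updating parameters
    old_require_grad_sync = self.require_grad_sync
    self.require_grad_sync = False
    try:
        yield
    finally:
        self.require_grad_sync = old_require_grad_sync
```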
tests/test_zero/test_low_level/test_grad_acc.py

```diff
@@ -9,6 +9,7 @@ from torch.testing import assert_close
 import colossalai
 from colossalai.testing import spawn
 from colossalai.testing.random import seed_all
+from colossalai.utils import conditional_context
 from colossalai.zero import LowLevelZeroOptimizer
@@ -39,14 +40,12 @@ def exam_zero_1_2_grad_acc():
                                             overlap_communication=True,
                                             initial_scale=32,
                                             clip_grad_norm=1.0,
-                                            grad_accumulate_interval=2,
                                             verbose=True)
     zero2_optimizer = LowLevelZeroOptimizer(zero2_optimizer,
                                             overlap_communication=True,
                                             partition_grad=True,
                                             initial_scale=32,
-                                            clip_grad_norm=1.0,
-                                            grad_accumulate_interval=2)
+                                            clip_grad_norm=1.0)
 
     # create data
     seed_all(2021 + local_rank)
     input_data1 = torch.randn(32, 128).cuda()
@@ -59,8 +58,11 @@ def exam_zero_1_2_grad_acc():
         assert torch.equal(zero1_output, zero2_output)
 
         # zero-dp backward
-        zero1_optimizer.backward(zero1_output.sum().float())
-        zero2_optimizer.backward(zero2_output.sum().float())
+        no_sync = number == 0
+        with conditional_context(zero1_optimizer.no_sync(), no_sync):
+            zero1_optimizer.backward(zero1_output.sum().float())
+        with conditional_context(zero2_optimizer.no_sync(), no_sync):
+            zero2_optimizer.backward(zero2_output.sum().float())
 
         if check_flag:
             for (n, z1p), z2p in zip(zero1_model.named_parameters(), zero2_model.parameters()):
@@ -101,8 +103,7 @@ def exam_zero_1_grad_acc():
     zero_optimizer = LowLevelZeroOptimizer(zero_optimizer,
                                            overlap_communication=False,
                                            reduce_bucket_size=262144,
-                                           clip_grad_norm=1.0,
-                                           grad_accumulate_interval=2)
+                                           clip_grad_norm=1.0)
 
     torch_optimizer = torch.optim.Adam(torch_model.parameters(), lr=1)
@@ -112,20 +113,15 @@ def exam_zero_1_grad_acc():
     input_data2 = torch.randn(32, 128).cuda()
 
     def fwd_bwd_func(number, cur_data, check_flag):
-        # zero-dp forward
-        zero_output = zero_model(cur_data)
-        # torch-ddp forward
 
-        # zero-dp backward
-        zero_optimizer.backward(zero_output.sum().float())
-        # torch-ddp backward
-        if number < 1:
-            with torch_model.no_sync():
-                torch_output = torch_model(cur_data)
-                assert torch.equal(zero_output, torch_output)
-                torch_output.sum().backward()
-        else:
+        no_sync = number == 0
+        # zero1 fwd and bwd
+        with conditional_context(zero_optimizer.no_sync(), no_sync):
+            zero_output = zero_model(cur_data)
+            zero_optimizer.backward(zero_output.sum().float())
+
+        # torch-ddp fwd and bwd
+        with conditional_context(torch_model.no_sync(), no_sync):
             torch_output = torch_model(cur_data)
             assert torch.equal(zero_output, torch_output)
             torch_output.sum().backward()
@@ -133,7 +129,6 @@ def exam_zero_1_grad_acc():
         if check_flag:
             # check grad
             for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()):
-                # print(n, p.shape, torch.max(torch.abs(p.grad - unscale_grad)))
                 assert torch.equal(p.grad, z1p.grad)
 
     fwd_bwd_func(0, input_data1, True)
```
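The updated tests drive both code paths through `colossalai.utils.conditional_context`, entering `no_sync()` only on the first of the two accumulation iterations. A minimal equivalent of that helper, for reference (illustrative, not the library source):

```python
from contextlib import contextmanager

@contextmanager
def conditional_context(context_manager, enable: bool = True):
    # enter the wrapped context only when `enable` is truthy
    if enable:
        with context_manager:
            yield
    else:
        yield
```

With `no_sync = number == 0`, the first `fwd_bwd_func` call accumulates gradients locally on every rank and the second call triggers the reduction, after which the per-parameter gradients can be compared across the ZeRO and torch DDP setups.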