OpenDAS / ColossalAI, commit b528eea0 (unverified)
Authored Jan 29, 2023 by HELSON; committed by GitHub on Jan 29, 2023
[zero] add zero wrappers (#2523)
* [zero] add zero wrappers
* change names
* add wrapper functions to init
Parent: c198c7c0
Showing 7 changed files with 128 additions and 19 deletions (+128 -19):

* colossalai/nn/optimizer/zero_optimizer.py (+2 -1)
* colossalai/nn/parallel/__init__.py (+2 -1)
* colossalai/nn/parallel/zero_wrapper.py (+106 -0)
* colossalai/zero/sharded_optim/low_level_optim.py (+6 -3)
* tests/test_zero/low_level_zero/test_grad_acc.py (+6 -7)
* tests/test_zero/low_level_zero/test_zero1_2.py (+6 -6)
* tests/test_zero/low_level_zero/test_zero_tp.py (+0 -1)
colossalai/nn/optimizer/zero_optimizer.py

```diff
@@ -65,7 +65,8 @@ class ZeroOptimizer(ColossalaiOptimizer):
                  **defaults: Any):
         super().__init__(optim)
         assert isinstance(module, ZeroDDP)
-        assert type(optim) in _AVAIL_OPTIM_LIST, "you should use the optimizer in the available list"
+        assert type(optim) in _AVAIL_OPTIM_LIST, "You should use an optimizer in the available list:\n" \
+            f"{_AVAIL_OPTIM_LIST}"
         self.module = module
         self.gemini_manager = module.gemini_manager
         self.chunk_manager: ChunkManager = self.gemini_manager.chunk_manager
```
colossalai/nn/parallel/__init__.py

```diff
 from .data_parallel import ColoDDP, ZeroDDP
 from .gemini_parallel import GeminiDDP
+from .zero_wrapper import zero_model_wrapper, zero_optim_wrapper

-__all__ = ['ColoDDP', 'ZeroDDP', 'GeminiDDP']
+__all__ = ['ColoDDP', 'ZeroDDP', 'GeminiDDP', 'zero_model_wrapper', 'zero_optim_wrapper']
```
colossalai/nn/parallel/zero_wrapper.py (new file, mode 100644)

```python
from copy import copy
from typing import Dict, Optional

import torch
import torch.nn as nn

from .gemini_parallel import GeminiDDP


def zero_model_wrapper(model: nn.Module, zero_stage: int = 1, gemini_config: Optional[Dict] = None):
    """This wrapper function is used to wrap your training model for ZeRO DDP.

    Example:
        >>> with ColoInitContext():
        >>>     my_model = Bert()
        >>> my_optim = SGD(my_model.parameters(), lr=1e-3)
        >>> zero_model = zero_model_wrapper(my_model, zero_stage=1)
        >>> zero_optim = zero_optim_wrapper(zero_model, my_optim)

    Args:
        model (nn.Module): The model used in ZeRO DDP.
        zero_stage (int, optional): The stage of ZeRO DDP. You can find more information in ZeRO's paper:
            https://arxiv.org/abs/1910.02054
        gemini_config (dict, optional): The configuration dictionary of `GeminiDDP`. `GeminiDDP` is enabled
            when the stage is set to 3. You can set the arguments of `GeminiDDP` in the gemini_config.
            Here is an example where we set the device of the model, the placement policy of Gemini, and the
            size of the hidden dimension to help Gemini find a unified chunk size.

        Example:
            >>> config_dict = dict(device=torch.cuda.current_device(), hidden_dim=1024, placement_policy='auto')
            >>> model = zero_model_wrapper(model, zero_stage=3, gemini_config=config_dict)
    """
    setattr(model, "_colo_zero_stage", zero_stage)

    assert zero_stage in [1, 2, 3], "The stage of ZeRO should be 1, 2 or 3"

    if gemini_config is None:
        gemini_config = dict()

    if zero_stage in [1, 2]:
        return model
    else:
        return GeminiDDP(model, **gemini_config)


def zero_optim_wrapper(model: nn.Module,
                       optimizer: torch.optim.Optimizer,
                       initial_scale: float = 2**16,
                       growth_factor: float = 2,
                       backoff_factor: float = 0.5,
                       growth_interval: int = 1000,
                       hysteresis: int = 2,
                       min_scale: float = 1,
                       max_scale: float = 2**32,
                       max_norm: float = 0.0,
                       norm_type: float = 2.0,
                       optim_config: Optional[Dict] = None):
    """This wrapper function is used to wrap your training optimizer for ZeRO DDP.

    Args:
        model (nn.Module): Your model wrapped by `zero_model_wrapper`.
        optimizer (torch.optim.Optimizer): Your initialized optimizer.
        initial_scale (float, optional): initial_scale used by DynamicGradScaler.
        min_scale (float, optional): min_scale used by DynamicGradScaler.
        growth_factor (float, optional): growth_factor used by DynamicGradScaler.
        backoff_factor (float, optional): backoff_factor used by DynamicGradScaler.
        growth_interval (float, optional): growth_interval used by DynamicGradScaler.
        hysteresis (float, optional): hysteresis used by DynamicGradScaler.
        max_scale (int, optional): max_scale used by DynamicGradScaler.
        max_norm (float, optional): max_norm used for `clip_grad_norm`. Note that you should not call
            clip_grad_norm yourself when using ZeRO DDP; the ZeRO optimizer takes care of it.
        norm_type (float, optional): norm_type used for `clip_grad_norm`.
        optim_config (dict, optional): The configuration used for the ZeRO optimizer.

            Example:
                >>> zero2_config = dict(reduce_bucket_size=12 * 1024 * 1024, overlap_communication=True)
                >>> optim = zero_optim_wrapper(model, optim, optim_config=zero2_config)
    """
    assert hasattr(model, "_colo_zero_stage"), "You should use `zero_model_wrapper` first"
    zero_stage = getattr(model, "_colo_zero_stage")
    assert norm_type == 2.0, "Current ZeRO optimizers only support 'norm_type=2'"

    if optim_config is None:
        config_dict = dict()
    else:
        config_dict = copy(optim_config)

    config_dict['initial_scale'] = initial_scale
    config_dict['growth_factor'] = growth_factor
    config_dict['backoff_factor'] = backoff_factor
    config_dict['growth_interval'] = growth_interval
    config_dict['hysteresis'] = hysteresis
    config_dict['min_scale'] = min_scale
    config_dict['max_scale'] = max_scale

    if zero_stage in [1, 2]:
        from colossalai.zero.sharded_optim.low_level_optim import LowLevelZeroOptimizer
        config_dict['partition_grad'] = zero_stage == 2
        config_dict['clip_grad_norm'] = max_norm
        return LowLevelZeroOptimizer(optimizer, **config_dict)
    else:
        from colossalai.nn.optimizer.zero_optimizer import ZeroOptimizer
        config_dict['clipping_norm'] = max_norm
        return ZeroOptimizer(optimizer, model, **config_dict)
```
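For orientation, here is a minimal end-to-end sketch of how the two new wrappers are intended to be combined. It is not part of the commit: the toy `Net` module, the use of `torch.optim.Adam`, the stage-2 configuration values, and the assumption that a ColossalAI distributed context has already been launched are all illustrative.

```python
import torch
import torch.nn as nn

from colossalai.nn.parallel import zero_model_wrapper, zero_optim_wrapper


class Net(nn.Module):
    """Toy model used only to show the wrapper call sequence (illustrative)."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(128, 128)

    def forward(self, x):
        return self.fc(x)


# assumes a ColossalAI distributed context has already been initialized
model = Net().cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# stage 1/2 tag the model and return it unchanged; stage 3 wraps it in GeminiDDP
zero_model = zero_model_wrapper(model, zero_stage=2)

# the optimizer wrapper reads the stage tag set above and builds either
# LowLevelZeroOptimizer (stage 1/2) or ZeroOptimizer (stage 3)
zero_optim = zero_optim_wrapper(zero_model,
                                optimizer,
                                max_norm=1.0,
                                optim_config=dict(overlap_communication=True))
```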
colossalai/zero/sharded_optim/low_level_optim.py

```diff
@@ -17,7 +17,6 @@ from ._utils import (
     calculate_global_norm_from_list,
     compute_norm,
     flatten,
-    get_grad_accumulate_object,
     has_inf_or_nan,
     reduce_tensor_dp_group,
     release_param_grad,
@@ -386,7 +385,7 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
     # torch.optim.Optimizer methods
     ################################

-    def backward(self, loss, retain_graph=False):
+    def backward(self, loss, retain_graph=False, sync_grad=True):
         loss = self.loss_scale * loss
         loss.backward(retain_graph=retain_graph)
@@ -402,6 +401,10 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
         torch.cuda.synchronize()
         self._param_store.clear_grads_of_previous_reduced_params()

+        # gradient synchronization
+        if sync_grad:
+            self._sync_grad()
+
     def zero_grad(self, set_to_none=True):
         """
         Set parameter gradients to zero. If set_to_none = True, gradient
@@ -537,7 +540,7 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
     # Gradient Synchronization #
     ############################

-    def sync_grad(self):
+    def _sync_grad(self):
         # update param already reduced flag
         reduction_states = self._param_store.get_param_reduction_states()
         for tensor, state in reduction_states.items():
```
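In short, `backward()` now performs the gradient synchronization that callers previously had to request via `sync_grad()`, and that method becomes the private `_sync_grad()`. Below is a minimal sketch of the new calling convention, modeled on the updated tests that follow; `model`, `data`, `reference_grads`, and `optimizer` (a `LowLevelZeroOptimizer`) are placeholder names.

```python
import torch

# placeholders: model, data, reference_grads, optimizer (LowLevelZeroOptimizer)
loss = model(data).sum().float()

# sync_grad=False defers gradient synchronization, so the raw per-parameter
# gradients can be inspected first, as the updated tests do
optimizer.backward(loss, sync_grad=False)

for p, ref in zip(model.parameters(), reference_grads):
    if p.grad is not None:
        assert torch.equal(p.grad, ref)

# trigger synchronization explicitly; with the default sync_grad=True,
# backward() would have done this automatically
optimizer._sync_grad()
optimizer.step()
```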
tests/test_zero/low_level_zero/test_grad_acc.py

```diff
@@ -9,7 +9,6 @@ from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.testing import assert_close

 import colossalai
-from colossalai.tensor import ProcessGroup
 from colossalai.testing.random import seed_all
 from colossalai.utils import free_port
 from colossalai.zero import LowLevelZeroOptimizer
@@ -60,16 +59,16 @@ def exam_zero_1_2_grad_acc():
         assert torch.equal(zero1_output, zero2_output)

         # zero-dp backward
-        zero1_optimizer.backward(zero1_output.sum().float())
-        zero2_optimizer.backward(zero2_output.sum().float())
+        zero1_optimizer.backward(zero1_output.sum().float(), sync_grad=False)
+        zero2_optimizer.backward(zero2_output.sum().float(), sync_grad=False)

         for (n, z1p), z2p in zip(zero1_model.named_parameters(), zero2_model.parameters()):
             if z2p.grad is not None:
                 # print(local_rank, n, z1p.shape, torch.max(z2p.grad), torch.max(torch.abs(z1p.grad - z2p.grad)))
                 assert torch.equal(z1p.grad, z2p.grad)

-        zero1_optimizer.sync_grad()
-        zero2_optimizer.sync_grad()
+        zero1_optimizer._sync_grad()
+        zero2_optimizer._sync_grad()

     fwd_bwd_func(0, input_data1)
     fwd_bwd_func(1, input_data2)
@@ -124,7 +123,7 @@ def exam_zero_1_grad_acc():
         assert torch.equal(zero_output, torch_output)

         # zero-dp backward
-        zero_optimizer.backward(zero_output.sum().float())
+        zero_optimizer.backward(zero_output.sum().float(), sync_grad=False)

         # torch-ddp backward
         torch_output.sum().backward()
@@ -135,7 +134,7 @@ def exam_zero_1_grad_acc():
             # print(n, p.shape, torch.max(torch.abs(p.grad - unscale_grad)))
             assert torch.equal(p.grad, unscale_grad)

-        zero_optimizer.sync_grad()
+        zero_optimizer._sync_grad()

     fwd_bwd_func(0, input_data1, True)
     fwd_bwd_func(1, input_data2, False)
```
tests/test_zero/low_level_zero/test_zero1_2.py

```diff
@@ -78,16 +78,16 @@ def exam_zero_1_2():
         assert torch.equal(zero1_output, zero2_output)

         # zero-dp backward
-        zero1_optimizer.backward(zero1_output.mean().float())
-        zero2_optimizer.backward(zero2_output.mean().float())
+        zero1_optimizer.backward(zero1_output.mean().float(), sync_grad=False)
+        zero2_optimizer.backward(zero2_output.mean().float(), sync_grad=False)

         for (n, z1p), z2p in zip(zero1_model.named_parameters(), zero2_model.parameters()):
             if z2p.grad is not None:
                 # print(local_rank, n, z1p.shape, torch.max(z2p.grad), torch.max(torch.abs(z1p.grad - z2p.grad)))
                 assert torch.equal(z1p.grad, z2p.grad)

-        zero1_optimizer.sync_grad()
-        zero2_optimizer.sync_grad()
+        zero1_optimizer._sync_grad()
+        zero2_optimizer._sync_grad()

         # step
         zero1_optimizer.step()
@@ -146,7 +146,7 @@ def exam_zero_1_torch_ddp():
         half_close(zero_output, torch_output, loose=True)

         # zero-dp backward
-        zero_optimizer.backward(zero_output.mean().float())
+        zero_optimizer.backward(zero_output.mean().float(), sync_grad=False)

         # torch-ddp backward
         torch_output.mean().backward()
@@ -156,7 +156,7 @@ def exam_zero_1_torch_ddp():
             half_close(p.grad, z1p.grad, loose=True)

         # zero-dp step
-        zero_optimizer.sync_grad()
+        zero_optimizer._sync_grad()
         zero_optimizer.step()

         # torch ddp step
```
tests/test_zero/low_level_zero/test_zero_tp.py

```diff
@@ -74,7 +74,6 @@ def exam_zero_with_tp(overlap_flag, partition_flag):
     torch_loss.backward()
     torch.nn.utils.clip_grad_norm_(torch_model.parameters(), 1.0)
     hybrid_optim.backward(hybrid_loss)
-    hybrid_optim.sync_grad()

     torch_optim.step()
     hybrid_optim.step()
```