OpenDAS / ColossalAI · Commits

Unverified commit 839847b7, authored Aug 25, 2023 by LuGY, committed by GitHub on Aug 25, 2023
Parent: c0efc3eb

[zero]support zero2 with gradient accumulation (#4511)

* support gradient accumulation with zero2
* fix type
Showing 4 changed files, with 61 additions and 28 deletions:

* colossalai/zero/low_level/bookkeeping/gradient_store.py (+2 -2)
* colossalai/zero/low_level/low_level_optim.py (+10 -3)
* colossalai/zero/low_level/readme.md (+40 -4)
* tests/test_zero/test_low_level/test_grad_acc.py (+9 -19)
colossalai/zero/low_level/bookkeeping/gradient_store.py
```diff
@@ -57,8 +57,8 @@ class GradientStore(BaseStore):
         self._grads_of_params[group_id][param_id].append(grad)
 
     def add_gradients_by_param_id(self, grad: Tensor, grad_idx: int, group_id: int, param_id: int):
-        """For old gradient accumulation, not in use now.
-        Add a gradient slice on an existing slice of the parameter's gradient
+        """Add a gradient slice on an existing slice of the parameter's gradient
+        Used when no_sync is not activated.
 
         Args:
             grad (Tensor): The split gradient to append to list
```
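As a rough aid to reading this hunk, here is a minimal, hypothetical stand-in for the bookkeeping these two methods perform. The nested `_grads_of_params` layout and the `append` call are taken from the hunk above; the in-place `add_` is an assumption about the real implementation, not code from the repository.

```python
# Hypothetical, simplified stand-in for GradientStore's bookkeeping (not the actual class).
from collections import defaultdict

import torch

# group_id -> param_id -> list of reduced gradient slices
_grads_of_params = defaultdict(lambda: defaultdict(list))

def append_gradients_by_param_id(grad: torch.Tensor, group_id: int, param_id: int) -> None:
    # first reduced slice of this parameter in the current accumulation window: store it
    _grads_of_params[group_id][param_id].append(grad)

def add_gradients_by_param_id(grad: torch.Tensor, grad_idx: int, group_id: int, param_id: int) -> None:
    # a slice is already stored (sync happened on an earlier micro-batch): accumulate in place
    _grads_of_params[group_id][param_id][grad_idx].add_(grad)
```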
colossalai/zero/low_level/low_level_optim.py
```diff
@@ -277,7 +277,11 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
                     sync_tensor(flat_grads_per_rank[rank], grad_list)
                     for grad in grad_list:
                         param_id = self._bucket_store.get_param_id_of_grad(grad)
-                        self._grad_store.append_gradients_by_param_id(grad, group_id, param_id)
+                        if len(self._grad_store.get_partitioned_gradients_by_param_id(group_id, param_id)) < self._world_size:
+                            self._grad_store.append_gradients_by_param_id(grad, group_id, param_id)
+                        else:
+                            self._grad_store.add_gradients_by_param_id(grad, rank, group_id, param_id)
 
             else:
                 flat_grads_list = list(flat_grads.split(len(flat_grads) // self._world_size))
@@ -291,7 +295,10 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
                 sync_tensor(recieved_grad, grad_in_bucket_current_rank)
                 for grad in grad_in_bucket_current_rank:
                     param_id = self._bucket_store.get_param_id_of_grad(grad)
-                    self._grad_store.append_gradients_by_param_id(grad, group_id, param_id)
+                    if len(self._grad_store.get_partitioned_gradients_by_param_id(group_id, param_id)) < 1:
+                        self._grad_store.append_gradients_by_param_id(grad, group_id, param_id)
+                    else:
+                        self._grad_store.add_gradients_by_param_id(grad, 0, group_id, param_id)
 
         self._bucket_store.reset()
@@ -315,7 +322,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
     def backward(self, loss, retain_graph=False):
         assert not (self._partition_grads and not self.require_grad_sync), \
-            "ZeRO2(partition_grads) and gradient accumulation(no_sync) are not compatible"
+            "ZeRO2(partition_grads) and no_sync are not compatible"
         if self.mixed_precision_mixin is not None:
             loss = self.mixed_precision_mixin.pre_backward(loss)
```
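To make the intent of the two reduction hunks above concrete, here is a small self-contained sketch under assumed names (`TinyGradStore` and `store_reduced_slice` are hypothetical, not ColossalAI classes): the first reduced slice seen for a parameter in an accumulation window is appended, later slices are added onto the stored slice, so gradients accumulate across micro-batches even though every backward step syncs. ZeRO-1 expects `world_size` stored slices per parameter, ZeRO-2 expects one.

```python
# Hypothetical sketch of the accumulate-or-append decision added above (simplified, not the repository's code).
import torch

class TinyGradStore:
    """Keeps reduced gradient slices per (group_id, param_id)."""
    def __init__(self):
        self._slices = {}

    def get(self, group_id, param_id):
        return self._slices.setdefault((group_id, param_id), [])

    def append(self, grad, group_id, param_id):
        self.get(group_id, param_id).append(grad.clone())

    def add(self, grad, idx, group_id, param_id):
        self.get(group_id, param_id)[idx].add_(grad)


def store_reduced_slice(store, grad, rank, group_id, param_id, partition_grads, world_size):
    # ZeRO-1 ends up with one slice per rank; ZeRO-2 keeps only the local slice.
    expected = 1 if partition_grads else world_size
    if len(store.get(group_id, param_id)) < expected:
        store.append(grad, group_id, param_id)  # first micro-batch: keep the slice
    else:
        # later micro-batches: accumulate onto the stored slice
        store.add(grad, 0 if partition_grads else rank, group_id, param_id)


# Two synced backward steps under ZeRO-2: the local slice accumulates across steps.
store = TinyGradStore()
for step_grad in (torch.ones(4), torch.ones(4)):
    store_reduced_slice(store, step_grad, rank=0, group_id=0, param_id=0,
                        partition_grads=True, world_size=2)
print(store.get(0, 0)[0])  # tensor([2., 2., 2., 2.])
```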
colossalai/zero/low_level/readme.md
# Low Level ZeRO
> Low Level ZeRO == ZeRO-DP stage 1 and 2; we denote it simply as ZeRO below.
## Examples of ZeRO and gradient accumulation

The code below only sketches a typical gradient accumulation flow; it omits many details, such as how the loss is processed.
```python
# examples of ZeRO1 with gradient accumulation
...
outputs = model(input)
loss = SomeLoss(outputs)
if (idx + 1) % ACCUMULATE_STEP != 0:
    with booster.no_sync(model, optimizer):
        # Under this context the gradients are not synced during backward,
        # leaving each rank with different gradients.
        # This saves backward time.
        booster.backward(loss, optimizer)
        continue
else:
    # need to sync all the accumulated gradients
    booster.backward(loss, optimizer)
    optimizer.step()
...
```
```python
# example of ZeRO2 with gradient accumulation
...
outputs = model(input)
loss = SomeLoss(outputs)
# ZeRO2 splits the gradients, so it can NOT accumulate gradients without syncing
# (no_sync is not allowed); backward syncs and accumulates every step.
booster.backward(loss, optimizer)
if (idx + 1) % ACCUMULATE_STEP == 0:
    optimizer.step()
...
```
## Design:
### Notion
...
...
````diff
@@ -25,11 +61,11 @@ The data structure looks like this:
 ```
 After that, the gradients would be flattened by rank, and the data structure looks like this:
 ```
-# g-0 means flatten([g-00, g-10])
+# g-X0 means flatten([g-00, g-10])
 {
-0: [g-0],
-1: [g-1],
-2: [g-2]
+0: [g-X0],
+1: [g-X1],
+2: [g-X2]
 }
 ```
 For zero1, we iterate the dictionary and do `all_reduce`. For zero2, we can just do `reduce-scatter`.
````
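A hedged illustration of that last sentence (a sketch only, with assumed tensor sizes and setup; not code from the repository): ZeRO-1 all-reduces every rank's flattened chunk so each rank ends up with the fully reduced gradient, while ZeRO-2 does a single reduce-scatter so each rank keeps only its own reduced chunk.

```python
# Hedged sketch of the two reduction paths; assumes torch.distributed is already
# initialized and that every rank runs this code. Names and sizes are illustrative.
import torch
import torch.distributed as dist

world_size = dist.get_world_size()
# one flattened gradient chunk per rank, as in the dictionary above
flat_grads_per_rank = {r: torch.randn(1024) for r in range(world_size)}

# ZeRO-1: iterate the dictionary and all-reduce every chunk;
# afterwards every rank holds the reduced gradient for all chunks.
for chunk in flat_grads_per_rank.values():
    dist.all_reduce(chunk)

# ZeRO-2: a single reduce-scatter; each rank keeps only its own reduced chunk.
my_chunk = torch.empty(1024)
dist.reduce_scatter(my_chunk, list(flat_grads_per_rank.values()))
```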
tests/test_zero/test_low_level/test_grad_acc.py
```diff
@@ -58,18 +58,9 @@ def exam_zero_1_2_grad_acc():
         assert torch.equal(zero1_output, zero2_output)
 
         # zero-dp backward
-        no_sync = number == 0
-        with conditional_context(zero1_optimizer.no_sync(), no_sync):
-            zero1_optimizer.backward(zero1_output.sum().float())
-        with conditional_context(zero2_optimizer.no_sync(), no_sync):
-            zero2_optimizer.backward(zero2_output.sum().float())
-
-        if check_flag:
-            for (n, z1p), z2p in zip(zero1_model.named_parameters(), zero2_model.parameters()):
-                if z2p.grad is not None:
-                    # print(local_rank, n, z1p.shape, torch.max(z2p.grad), torch.max(torch.abs(z1p.grad - z2p.grad)))
-                    assert torch.equal(z1p.grad, z2p.grad)
+        zero1_optimizer.backward(zero1_output.sum().float())
+        zero2_optimizer.backward(zero2_output.sum().float())
 
     fwd_bwd_func(0, input_data1, True)
     fwd_bwd_func(1, input_data2, False)
@@ -82,7 +73,7 @@ def exam_zero_1_2_grad_acc():
             assert torch.equal(z1p.data, z2p.data)
 
 
-def exam_zero_1_grad_acc():
+def exam_zero_1_grad_acc(sync):
     local_rank = torch.distributed.get_rank()
     seed_all(2008)
@@ -112,9 +103,8 @@ def exam_zero_1_grad_acc():
     input_data1 = torch.randn(32, 128).cuda()
     input_data2 = torch.randn(32, 128).cuda()
 
-    def fwd_bwd_func(number, cur_data, check_flag):
-        no_sync = number == 0
+    def fwd_bwd_func(no_sync, cur_data, check_flag):
 
         # zero1 fwd and bwd
         with conditional_context(zero_optimizer.no_sync(), no_sync):
            zero_output = zero_model(cur_data)
@@ -131,8 +121,8 @@ def exam_zero_1_grad_acc():
             for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()):
                 assert torch.equal(p.grad, z1p.grad)
 
-    fwd_bwd_func(0, input_data1, True)
-    fwd_bwd_func(1, input_data2, False)
+    fwd_bwd_func(sync, input_data1, sync)
+    fwd_bwd_func(False, input_data2, False)
 
     zero_optimizer.step()
     torch.nn.utils.clip_grad_norm_(torch_model.parameters(), 1.0)
@@ -147,9 +137,9 @@ def exam_zero_1_grad_acc():
 def run_dist(rank, world_size, port):
     colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')
 
-    exam_zero_1_grad_acc()
-    # gradient accumulation is not compatible with ZeRO-2
-    # exam_zero_1_2_grad_acc()
+    exam_zero_1_grad_acc(sync=True)
+    exam_zero_1_grad_acc(sync=False)
+    exam_zero_1_2_grad_acc()
 
 
 @pytest.mark.dist
```
...
...