OpenDAS / deepspeed · Commits · fa87a73a

Commit fa87a73a (unverified), authored Mar 16, 2021 by Olatunji Ruwase, committed by GitHub on Mar 16, 2021

Fix ZeRO3 save_checkpoint (#857)

Co-authored-by: Jeff Rasley <jerasley@microsoft.com>

Parent: 871f3048
Changes: 2 changed files with 29 additions and 27 deletions (+29 -27)

- deepspeed/runtime/zero/stage3.py (+5 -7)
- tests/unit/test_checkpointing.py (+24 -20)
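For orientation (this context is inferred from the diff below, not stated in the commit message): the ZeRO-3 stage-3 optimizer keeps its fp32 master weights as flat, partitioned buffers in fp32_partitioned_groups_flat, one per parameter sub-group, so checkpoint save/compare logic has to operate on those buffers rather than on per-parameter tensors or the stage-1/2 fp32_groups_flat attribute. A toy sketch of what comparing partitioned flat groups means; the tensors and sizes are made up:

```python
import torch

# hypothetical per-rank partitions for two parameter sub-groups
saved_flat_groups = [torch.randn(16), torch.randn(8)]
loaded_flat_groups = [g.clone() for g in saved_flat_groups]

# the same comparison the updated compare_model_states test performs, group by group
for p0, p1 in zip(saved_flat_groups, loaded_flat_groups):
    assert torch.allclose(p0, p1, atol=1e-07), f"Fp32 model states {p0} is not equal to {p1}"
```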
deepspeed/runtime/zero/stage3.py
@@ -2269,7 +2269,7 @@ class FP16_DeepSpeedZeroOptimizer_Stage3(object):
             assert single_grad_partition.numel() == self.fp32_partitioned_groups_flat[sub_group_id].numel(), \
                 "averaged gradients have different number of elements that partition size {} {} {} {}".format(
-                    single_grad_partition.numel(), self.partition_size[sub_group_id], sub_group_id, partition_id)
+                    single_grad_partition.numel(), self.fp32_partitioned_groups_flat[sub_group_id].numel(), sub_group_id, partition_id)

             self.fp32_partitioned_groups_flat[sub_group_id].grad = single_grad_partition
@@ -2638,14 +2638,12 @@ class FP16_DeepSpeedZeroOptimizer_Stage3(object):
     def _set_fp32_optimizer_param_groups(self):
         for sub_group_id, _ in enumerate(self.fp16_groups):
             param_group_id = self.sub_group_to_group_id[sub_group_id]
-            self.optimizer.param_groups[param_group_id]['params'] = [self.fp32_partitioned_groups_flat[sub_group_id]]
+            self.optimizer.param_groups[param_group_id]['params'].append(self.fp32_partitioned_groups_flat[sub_group_id])

     def _clear_fp32_optimizer_param_groups(self):
-        for sub_group_id, _ in enumerate(self.fp16_groups):
-            param_group_id = self.sub_group_to_group_id[sub_group_id]
-            self.optimizer.param_groups[param_group_id]['params'] = []
+        for param_group in self.optimizer.param_groups:
+            param_group['params'] = []

     def _rigid_state_dict(self):
         state_dict = {}
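The two helpers above suggest a set/snapshot/clear pattern around checkpointing: the first now appends to the existing params list instead of replacing it, and the second resets every group. A rough sketch of that pattern with a plain PyTorch optimizer standing in for the wrapped one; this is an illustration of the pattern, not DeepSpeed's actual save_checkpoint path:

```python
import torch

# Stand-ins: a module, its optimizer, and one flat fp32 partition per sub-group.
model = torch.nn.Linear(4, 4)
optimizer = torch.optim.Adam(model.parameters())
fp32_partitioned_groups_flat = [torch.nn.Parameter(torch.zeros(20))]

# "set": point the param group at its flat fp32 partition (appending, as the new code does)
optimizer.param_groups[0]['params'] = []
optimizer.param_groups[0]['params'].append(fp32_partitioned_groups_flat[0])

# snapshot the optimizer while the groups reference the flat partitions; no step has been
# taken here, so there is no per-parameter state that would need remapping
state = optimizer.state_dict()

# "clear": reset every group, mirroring the rewritten _clear_fp32_optimizer_param_groups
for param_group in optimizer.param_groups:
    param_group['params'] = []
```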
tests/unit/test_checkpointing.py
@@ -47,7 +47,7 @@ def compare_model_states(saved_model, loaded_model, compare_optimizer=True):
     if FP16_DeepSpeedZeroOptimizer_Stage3 is not None and isinstance(saved_model.optimizer, FP16_DeepSpeedZeroOptimizer_Stage3):
-        for p0, p1 in zip(saved_model.optimizer.fp32_groups_flat, loaded_model.optimizer.fp32_groups_flat):
+        for p0, p1 in zip(saved_model.optimizer.fp32_partitioned_groups_flat, loaded_model.optimizer.fp32_partitioned_groups_flat):
             assert torch.allclose(p0, p1, atol=1e-07), f"Fp32 model states {p0} is not equal to {p1}"

     elif isinstance(saved_model.optimizer, FP16_DeepSpeedZeroOptimizer):
@@ -303,12 +303,13 @@ def test_checkpoint_fused_optimizer(tmpdir):
                              (2, True, 'deepspeed_adam'),
-                             (3, False, 'Adam')])
+                             (3, False, 'Adam'),
+                             (3, True, 'deepspeed_adam')])
 def test_checkpoint_zero_optimizer(tmpdir, zero_stage, use_cpu_offload, adam_optimizer):
     if use_cpu_offload and not deepspeed.ops.__compatible_ops__[CPUAdamBuilder.NAME]:
         pytest.skip("cpu-adam is not compatible")
     if zero_stage == 3:
         pytest.skip('Skip checkpointing tests for ZeRO3')
     config_dict = {
         "train_batch_size": 2,
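A minimal, self-contained sketch of how the new (3, True, 'deepspeed_adam') case reaches the test: each tuple becomes one invocation with zero_stage, use_cpu_offload and adam_optimizer bound. The decorator below is illustrative only; the full tuple list and parameter names of the real decorator are not shown in this hunk.

```python
import pytest

@pytest.mark.parametrize('zero_stage, use_cpu_offload, adam_optimizer',
                         [(3, False, 'Adam'),
                          (3, True, 'deepspeed_adam')])
def test_zero3_cases(zero_stage, use_cpu_offload, adam_optimizer):
    # the new tuple exercises ZeRO stage 3 with CPU offload and DeepSpeed's CPU Adam
    assert zero_stage == 3
    if use_cpu_offload:
        assert adam_optimizer == 'deepspeed_adam'
```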
@@ -324,8 +325,10 @@ def test_checkpoint_zero_optimizer(tmpdir, zero_stage, use_cpu_offload, adam_optimizer):
             }
         },
         "fp16": {
-            "enabled": True
+            "enabled": True,
+            "initial_scale_power": 8
         },
         "wall_clock_breakdown": True,
         "zero_optimization": {
             "stage": zero_stage,
             "cpu_offload": use_cpu_offload
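For reference, initial_scale_power is DeepSpeed's fp16 knob for the starting dynamic loss scale, 2**initial_scale_power. A minimal sketch of the block added above; the reasoning that a smaller starting scale avoids early overflow skips in a tiny unit-test model is an assumption, not stated in the commit:

```python
# "initial_scale_power": 8 gives an initial dynamic loss scale of 2**8 = 256
fp16_config = {
    "enabled": True,
    "initial_scale_power": 8,
}
assert 2 ** fp16_config["initial_scale_power"] == 256
```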
@@ -340,9 +343,7 @@ def test_checkpoint_zero_optimizer(tmpdir, zero_stage, use_cpu_offload, adam_optimizer):
                                        hidden_dim,
                                        load_optimizer_states):
     if zero_stage == 3:
         global FP16_DeepSpeedZeroOptimizer_Stage3
         from deepspeed.runtime.zero.stage3 import FP16_DeepSpeedZeroOptimizer_Stage3
-        with deepspeed.ScatteredParameters(zero_modules=True):
+        with deepspeed.zero.Init():
             models = [SimpleModel(hidden_dim, empty_grad=False) for _ in range(2)]
     else:
         models = [SimpleModel(hidden_dim, empty_grad=False) for _ in range(2)]
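For context, deepspeed.zero.Init() is the ZeRO-3 context manager under which module construction allocates parameters directly in partitioned (scattered) form; the hunks above and below swap it in for the removed deepspeed.ScatteredParameters(zero_modules=True) call. A minimal sketch of the pattern, assuming a distributed environment is already set up (these unit tests run under DeepSpeed's distributed test launcher):

```python
import torch
import deepspeed

# Under zero.Init(), parameters are partitioned across data-parallel ranks as each
# submodule is constructed, so the full model never materializes on a single rank.
with deepspeed.zero.Init():
    models = [torch.nn.Linear(10, 10) for _ in range(2)]
```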
@@ -371,15 +372,16 @@ def test_checkpoint_zero_optimizer(tmpdir, zero_stage, use_cpu_offload, adam_optimizer):
                              (2, True, 'deepspeed_adam'),
-                             (3, False, 'Adam')])
+                             (3, False, 'Adam'),
+                             (3, True, 'deepspeed_adam')])
 def test_checkpoint_zero_no_optimizer(tmpdir, zero_stage, use_cpu_offload, adam_optimizer):
     if use_cpu_offload and not deepspeed.ops.__compatible_ops__[CPUAdamBuilder.NAME]:
         pytest.skip("cpu-adam is not compatible")
     if zero_stage == 3:
         pytest.skip('Skip checkpointing tests for ZeRO3')
     config_dict = {
         "train_batch_size": 2,
@@ -413,7 +415,7 @@ def test_checkpoint_zero_no_optimizer(tmpdir,
     if zero_stage == 3:
         global FP16_DeepSpeedZeroOptimizer_Stage3
         from deepspeed.runtime.zero.stage3 import FP16_DeepSpeedZeroOptimizer_Stage3
-        with deepspeed.ScatteredParameters(zero_modules=True):
+        with deepspeed.zero.Init():
             models = [SimpleModel(hidden_dim, empty_grad=False) for _ in range(2)]
     else:
         models = [SimpleModel(hidden_dim, empty_grad=False) for _ in range(2)]
@@ -445,12 +447,13 @@ def test_checkpoint_zero_no_optimizer(tmpdir,
                              (2, True, 'deepspeed_adam'),
-                             (3, False, 'Adam')])
+                             (3, False, 'Adam'),
+                             (3, True, 'deepspeed_adam')])
 def test_checkpoint_lr_scheduler(tmpdir, zero_stage, use_cpu_offload, adam_optimizer):
     if use_cpu_offload and not deepspeed.ops.__compatible_ops__[CPUAdamBuilder.NAME]:
         pytest.skip("cpu-adam is not compatible")
     if zero_stage == 3:
         pytest.skip('Skip checkpointing tests for ZeRO3')
     config_dict = {
         "train_batch_size": 2,
@@ -493,7 +496,7 @@ def test_checkpoint_lr_scheduler(tmpdir, zero_stage, use_cpu_offload, adam_optimizer):
     if zero_stage == 3:
         global FP16_DeepSpeedZeroOptimizer_Stage3
         from deepspeed.runtime.zero.stage3 import FP16_DeepSpeedZeroOptimizer_Stage3
-        with deepspeed.ScatteredParameters(zero_modules=True):
+        with deepspeed.zero.Init():
             models = [SimpleModel(hidden_dim, empty_grad=False) for _ in range(2)]
     else:
         models = [SimpleModel(hidden_dim, empty_grad=False) for _ in range(2)]
@@ -526,14 +529,15 @@ def test_checkpoint_lr_scheduler(tmpdir, zero_stage, use_cpu_offload, adam_optimizer):
                              (2, True, 'deepspeed_adam'),
                              (3, False, 'Adam'),
-                             (3, True, 'Adam')])
+                             (3, True, 'deepspeed_adam')])
 def test_checkpoint_no_lr_scheduler(tmpdir, zero_stage, use_cpu_offload, adam_optimizer):
     if use_cpu_offload and not deepspeed.ops.__compatible_ops__[CPUAdamBuilder.NAME]:
         pytest.skip("cpu-adam is not compatible")
     if zero_stage == 3:
         pytest.skip('Skip checkpointing tests for ZeRO3')
     config_dict = {
         "train_batch_size": 2,
@@ -570,7 +574,7 @@ def test_checkpoint_no_lr_scheduler(tmpdir, zero_stage, use_cpu_offload, adam_optimizer):
                                        load_optimizer_states,
                                        load_lr_scheduler_states):
     if zero_stage == 3:
-        with deepspeed.ScatteredParameters(zero_modules=True):
+        with deepspeed.zero.Init():
             models = [SimpleModel(hidden_dim, empty_grad=False) for _ in range(2)]
     else:
         models = [SimpleModel(hidden_dim, empty_grad=False) for _ in range(2)]