OpenDAS / deepspeed · Commits · 607814fe

Unverified commit 607814fe, authored Jul 15, 2020 by Olatunji Ruwase, committed via GitHub on Jul 15, 2020.

Fix bug in fp32 optimizer state loading (#289)

Parent: 7ccc9daf

Showing 3 changed files with 54 additions and 13 deletions (+54, -13).
deepspeed/pt/deepspeed_light.py   +6   -2
tests/unit/simple_model.py        +2   -2
tests/unit/test_checkpointing.py  +46  -9
deepspeed/pt/deepspeed_light.py

@@ -1140,8 +1140,12 @@ class DeepSpeedLight(Module):
         self.load_module_state_dict(state_dict=checkpoint['module'],
                                     strict=load_module_strict)

         if not self.zero_optimization():
-            self.optimizer.load_state_dict(checkpoint['optimizer'],
-                                           load_optimizer_states=load_optimizer_states)
+            if self.fp16_enabled():
+                self.optimizer.load_state_dict(
+                    checkpoint['optimizer'],
+                    load_optimizer_states=load_optimizer_states)
+            else:
+                self.optimizer.load_state_dict(checkpoint['optimizer'])

         if load_lr_scheduler_states and self.lr_scheduler is not None:
             self.lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
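The new branch exists because DeepSpeed's fp16 optimizer wrapper accepts a load_optimizer_states keyword, while a plain torch.optim.Optimizer does not; before this fix, an fp32 (non-ZeRO) run passed the extra keyword to the vanilla optimizer during checkpoint loading. A minimal sketch of that failure mode, not taken from the commit, using a stock torch.optim.Adam:

# Minimal sketch (not from the commit): torch.optim.Optimizer.load_state_dict
# only accepts a state_dict, so DeepSpeed's extra keyword is rejected.
import torch

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
state = optimizer.state_dict()

optimizer.load_state_dict(state)  # fine: standard PyTorch signature
try:
    optimizer.load_state_dict(state, load_optimizer_states=True)
except TypeError as err:
    print(f"rejected as expected: {err}")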
tests/unit/simple_model.py

@@ -41,9 +41,9 @@ class SimpleOptimizer(torch.optim.Optimizer):
         return loss


-def random_dataloader(model, total_samples, hidden_dim, device):
+def random_dataloader(model, total_samples, hidden_dim, device, dtype=torch.half):
     batch_size = model.train_micro_batch_size_per_gpu()
-    train_data = torch.randn(total_samples, hidden_dim, device=device, dtype=torch.half)
+    train_data = torch.randn(total_samples, hidden_dim, device=device, dtype=dtype)
     train_label = torch.empty(total_samples,
                               dtype=torch.long,
                               device=device).random_(hidden_dim)
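For reference, here is a self-contained sketch of the dtype plumbing added above; it is not the test helper itself, and it replaces the DeepSpeed engine's micro-batch size with a plain batch_size argument:

# Hedged sketch: the new dtype argument lets the same random-data loader feed
# both fp16 (default) and fp32 training runs.
import torch
from torch.utils.data import DataLoader, TensorDataset

def random_dataloader_sketch(total_samples, hidden_dim, device, dtype=torch.half, batch_size=2):
    # batch_size is a plain argument here; the real helper reads it from
    # model.train_micro_batch_size_per_gpu().
    train_data = torch.randn(total_samples, hidden_dim, device=device, dtype=dtype)
    train_label = torch.empty(total_samples, dtype=torch.long, device=device).random_(hidden_dim)
    return DataLoader(TensorDataset(train_data, train_label), batch_size=batch_size)

fp32_loader = random_dataloader_sketch(50, 10, device='cpu', dtype=torch.float32)
batch, labels = next(iter(fp32_loader))
print(batch.dtype, labels.dtype)  # torch.float32 torch.int64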
tests/unit/test_checkpointing.py

@@ -47,14 +47,18 @@ def compare_model_states(saved_model, loaded_model):
         for params0, params1 in zip(saved_model.optimizer.fp32_groups, loaded_model.optimizer.fp32_groups):
             for p0, p1 in zip(params0, params1):
                 assert torch.allclose(p0, p1, atol=1e-07), f"FP32 model states {p0} is not equal to {p1}"
+    elif isinstance(saved_model.optimizer, torch.optim.Optimizer):
+        pass
     else:
-        assert False, 'Unexpected Optimizer Type'
+        assert False, f'Unexpected Optimizer Type: {saved_model.optimizer}'


-def compare_optimizer_states(saved_model, loaded_model, hidden_dim):
-    for state0, state1 in zip(saved_model.optimizer.optimizer.state.values(),
-                              loaded_model.optimizer.optimizer.state.values()):
+def compare_optimizer_states(saved_model, loaded_model, hidden_dim, fp16=True):
+    saved_optimizer = saved_model.optimizer.optimizer if fp16 else saved_model.optimizer
+    loaded_optimizer = loaded_model.optimizer.optimizer if fp16 else loaded_model.optimizer
+
+    for state0, state1 in zip(saved_optimizer.state.values(),
+                              loaded_optimizer.state.values()):
         for s0, s1 in zip(state0.values(), state1.values()):
             if isinstance(s0, torch.Tensor) and isinstance(s1, torch.Tensor):
                 assert torch.equal(s0, s1)
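The new fp16 flag is needed because the object held by the engine differs by precision: with fp16 the base optimizer sits one level down on the wrapper, with fp32 the engine's optimizer already is the base torch optimizer. A hedged, self-contained sketch of that unwrapping (the wrapper class below is a stand-in, not DeepSpeed's):

# Stand-in wrapper to illustrate the fp16/fp32 unwrapping; it only assumes the
# wrapper exposes the base optimizer as .optimizer, as the test code above does.
import torch

class Fp16WrapperSketch:
    def __init__(self, inner):
        self.optimizer = inner

base = torch.optim.Adam(torch.nn.Linear(2, 2).parameters(), lr=1e-3)

def unwrap(opt, fp16=True):
    # mirrors: saved_model.optimizer.optimizer if fp16 else saved_model.optimizer
    return opt.optimizer if fp16 else opt

assert unwrap(Fp16WrapperSketch(base), fp16=True) is base   # fp16: go one level down
assert unwrap(base, fp16=False) is base                     # fp32: already the base optimizer
print("unwrapping behaves as expected")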
@@ -90,15 +94,17 @@ def checkpoint_correctness_verification(args,
                                         hidden_dim,
                                         tmpdir,
                                         load_optimizer_states=False,
-                                        load_lr_scheduler_states=False):
+                                        load_lr_scheduler_states=False,
+                                        fp16=True):
+    dtype = torch.half if fp16 else torch.float32
     ds_model, _, _, _ = deepspeed.initialize(args=args,
                                              model=model,
                                              model_parameters=model.parameters())
     data_loader = random_dataloader(model=ds_model,
                                     total_samples=50,
                                     hidden_dim=hidden_dim,
-                                    device=ds_model.device)
+                                    device=ds_model.device,
+                                    dtype=dtype)
     for n, batch in enumerate(data_loader):
         loss = ds_model(batch[0], batch[1])
         ds_model.backward(loss)
@@ -123,7 +129,7 @@ def checkpoint_correctness_verification(args,
     compare_model_states(trained_model, loaded_model)

     if load_optimizer_states:
-        compare_optimizer_states(trained_model, loaded_model, hidden_dim)
+        compare_optimizer_states(trained_model, loaded_model, hidden_dim, fp16)

     if load_lr_scheduler_states:
         compare_lr_scheduler_states(trained_model, loaded_model)
@@ -420,3 +426,34 @@ def test_checkpoint_no_lr_scheduler(tmpdir, zero_stage):
                                         hidden_dim=hidden_dim,
                                         load_optimizer_states=False,
                                         load_lr_scheduler_states=False)
+
+
+def test_checkpoint_fp32_optimizer(tmpdir):
+    config_dict = {
+        "train_batch_size": 2,
+        "steps_per_print": 1,
+        "optimizer": {
+            "type": "Adam",
+            "params": {
+                "lr": 0.00015,
+                "betas": [0.8,
+                          0.999],
+                "eps": 1e-8,
+                "weight_decay": 3e-7
+            }
+        },
+        "fp16": {
+            "enabled": False
+        }
+    }
+
+    args = args_from_dict(tmpdir, config_dict)
+    hidden_dim = 10
+
+    model = SimpleModel(hidden_dim, empty_grad=False)
+
+    @distributed_test(world_size=[2])
+    def _test_checkpoint_fp32_optimizer(args, model, hidden_dim):
+        checkpoint_correctness_verification(args,
+                                            model,
+                                            hidden_dim,
+                                            tmpdir,
+                                            fp16=False)
+
+    _test_checkpoint_fp32_optimizer(args=args, model=model, hidden_dim=hidden_dim)
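The new test drives the fp32 code path purely through configuration. As a rough illustration, the same settings written out as a standalone DeepSpeed JSON config would look like the sketch below; the test itself goes through args_from_dict, which presumably materializes an equivalent config under tmpdir:

# Hedged sketch: the test's fp32 settings as a standalone DeepSpeed config file.
# The key entry is "fp16": {"enabled": False}, which forces the plain
# torch.optim path that the deepspeed_light.py fix now handles.
import json

ds_config = {
    "train_batch_size": 2,
    "steps_per_print": 1,
    "optimizer": {
        "type": "Adam",
        "params": {
            "lr": 0.00015,
            "betas": [0.8, 0.999],
            "eps": 1e-8,
            "weight_decay": 3e-7
        }
    },
    "fp16": {
        "enabled": False
    }
}

with open("ds_config_fp32.json", "w") as f:
    json.dump(ds_config, f, indent=2)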