OpenDAS / deepspeed

Commit b2c87edf (unverified)
Authored May 06, 2020 by Shaden Smith; committed by GitHub on May 06, 2020
Parent: 4f42bbb0

Fix global_steps checkpoint loading. (#139)
Changes: 3 files changed, with 56 additions and 7 deletions (+56 −7)

  README.md                           +1   −1
  deepspeed/pt/deepspeed_light.py     +1   −1
  tests/unit/test_checkpointing.py    +54  −5
README.md

```diff
@@ -8,7 +8,7 @@ efficient, and effective.
 <p align="center"><i><b>5x Faster Training</b></i></p>
 <p align="center"><i><b>Minimal Code Change</b></i></p>
 
-DeepSpeed can train DL models with over a hundred billion parameters on current
+DeepSpeed can train deep learning models with over a hundred billion parameters on current
 generation of GPU clusters, while achieving over 5x in system performance
 compared to the state-of-art. Early adopters of DeepSpeed have already produced
 a language model (LM) with over 17B parameters called
```
deepspeed/pt/deepspeed_light.py

```diff
@@ -1021,7 +1021,7 @@ class DeepSpeedLight(Module):
                             'optimizer',
                             'csr_tensor_module_names',
                             'skipped_steps',
-                            'global_step'
+                            'global_steps'
         ]
 
         client_state = {key: value
```
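This one-character rename is the actual bug fix. The list appears to enumerate engine-owned checkpoint keys that get filtered out when building the `client_state` returned to the caller during loading; with the misspelled `'global_step'`, the engine's saved `'global_steps'` entry never matched the filter. The snippet below is a minimal standalone sketch of that filtering pattern, not DeepSpeed's own code; the checkpoint contents and the `my_custom_counter` key are made up for illustration.

```python
# Hypothetical checkpoint contents, for illustration only.
checkpoint = {
    'optimizer': '<optimizer state_dict>',
    'csr_tensor_module_names': [],
    'skipped_steps': 0,
    'global_steps': 100,        # key the engine actually writes when saving
    'my_custom_counter': 7,     # user-supplied client state
}

# Before the fix: the filter list names 'global_step' (no trailing "s"),
# which never matches the saved key 'global_steps'.
deepspeed_states = ['optimizer', 'csr_tensor_module_names',
                    'skipped_steps', 'global_step']
client_state = {k: v for k, v in checkpoint.items() if k not in deepspeed_states}
assert 'global_steps' in client_state   # engine state leaks into client_state

# After the fix, only the user-supplied entry remains.
deepspeed_states[-1] = 'global_steps'
client_state = {k: v for k, v in checkpoint.items() if k not in deepspeed_states}
assert client_state == {'my_custom_counter': 7}
```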
tests/unit/test_checkpointing.py
```diff
@@ -13,7 +13,33 @@ from common import distributed_test
 from simple_model import SimpleModel, random_dataloader, args_from_dict
 
 
+def compare_deepspeed_states(saved_model, loaded_model):
+    # These are compared in more depth in other places
+    assert hasattr(loaded_model, 'module')
+
+    assert saved_model.csr_tensor_module_names == loaded_model.csr_tensor_module_names
+    assert saved_model.skipped_steps == loaded_model.skipped_steps
+    assert saved_model.global_steps == loaded_model.global_steps
+
+
+def compare_lr_scheduler_states(saved_model, loaded_model):
+    if saved_model.lr_scheduler is None:
+        assert loaded_model.lr_scheduler is None
+        return
+
+    saved = saved_model.lr_scheduler.state_dict()
+    loaded = loaded_model.lr_scheduler.state_dict()
+
+    assert sorted(saved.keys()) == sorted(loaded.keys())
+    for key in saved.keys():
+        if isinstance(saved[key], torch.Tensor):
+            assert torch.equal(saved[key], loaded[key])
+        else:
+            assert saved[key] == loaded[key]
+
+
 def compare_model_states(saved_model, loaded_model):
+    compare_deepspeed_states(saved_model, loaded_model)
+
     for p0, p1 in zip(saved_model.module.parameters(),
                       loaded_model.module.parameters()):
         assert torch.allclose(p0, p1, atol=1e-07), f"FP16 model state {p0} is not equal to {p1}"
```
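The `isinstance(saved[key], torch.Tensor)` branch in `compare_lr_scheduler_states` is needed because `==` on tensors is element-wise, and a multi-element result cannot be used directly as an `assert` condition, whereas `torch.equal` returns a single boolean. A standalone illustration (not part of the diff):

```python
import torch

a = torch.tensor([1.0, 2.0])
b = torch.tensor([1.0, 2.0])

print(a == b)             # tensor([True, True]) -- element-wise, not a bool
print(torch.equal(a, b))  # True -- a single boolean, safe inside an assert

# assert a == b           # would raise: "Boolean value of Tensor with more
#                         # than one element is ambiguous"
```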
```diff
@@ -37,6 +63,8 @@ def compare_model_states(saved_model, loaded_model):
 
 
 def compare_optimizer_states(saved_model, loaded_model, hidden_dim):
+    compare_model_states(saved_model, loaded_model)
+
     assert hasattr(loaded_model, 'optimizer')
 
     for state0, state1 in zip(saved_model.optimizer.optimizer.state.values(),
                               loaded_model.optimizer.optimizer.state.values()):
         for s0, s1 in zip(state0.values(), state1.values()):
```
```diff
@@ -46,7 +74,8 @@ def compare_optimizer_states(saved_model, loaded_model, hidden_dim):
             assert s0 == s1
 
 
-def checkpoint_correctness_verification(args,
+def checkpoint_correctness_verification(save_folder,
+                                        args,
                                         model,
                                         hidden_dim,
                                         load_optimizer_states=True):
```
```diff
@@ -65,7 +94,6 @@ def checkpoint_correctness_verification(args,
     trained_model = ds_model
 
-    save_folder = 'saved_checkpoint'
     save_tag = '1'
 
     trained_model.save_checkpoint(save_folder, save_tag)
```
```diff
@@ -78,6 +106,8 @@ def checkpoint_correctness_verification(args,
                                            save_tag,
                                            load_optimizer_states=load_optimizer_states)
 
+    compare_lr_scheduler_states(trained_model, loaded_model)
+
     if load_optimizer_states:
         compare_optimizer_states(trained_model, loaded_model, hidden_dim)
     else:
```
```diff
@@ -97,6 +127,22 @@ def test_checkpoint_unfused_optimizer(tmpdir):
         },
         "fp16": {
             "enabled": True
         },
+        "scheduler": {
+            "type": "OneCycle",
+            "params": {
+                "cycle_first_step_size": 1000,
+                "cycle_first_stair_count": 500,
+                "cycle_second_step_size": 1000,
+                "cycle_second_stair_count": 500,
+                "decay_step_size": 1000,
+                "cycle_min_lr": 0.0001,
+                "cycle_max_lr": 0.0010,
+                "decay_lr_rate": 0.001,
+                "cycle_min_mom": 0.85,
+                "cycle_max_mom": 0.99,
+                "decay_mom_rate": 0.0
+            }
+        }
     }
```
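Adding this scheduler block is what makes the new `compare_lr_scheduler_states` assertions meaningful: without a `"scheduler"` entry, `lr_scheduler` is `None` on both models and the comparison short-circuits. For context, a complete config dict of the shape this test appears to build might look like the sketch below; only the `"scheduler"` block comes from this diff, and the remaining keys are assumptions modelled on typical DeepSpeed unit-test configs.

```python
# Hedged sketch of a full DeepSpeed config for the unfused-optimizer
# checkpoint test. Keys other than "scheduler" are illustrative assumptions.
config_dict = {
    "train_batch_size": 2,
    "steps_per_print": 1,
    "optimizer": {
        "type": "Lamb",          # assumed unfused optimizer; not from this diff
        "params": {
            "lr": 0.00015
        }
    },
    "fp16": {
        "enabled": True
    },
    "scheduler": {
        "type": "OneCycle",
        "params": {
            "cycle_first_step_size": 1000,
            # ... remaining OneCycle params exactly as in the hunk above ...
            "decay_mom_rate": 0.0
        }
    }
}
```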
```diff
@@ -110,7 +156,8 @@ def test_checkpoint_unfused_optimizer(tmpdir):
                                            model,
                                            hidden_dim,
                                            load_optimizer_states):
-        checkpoint_correctness_verification(args,
+        checkpoint_correctness_verification(tmpdir,
+                                            args,
                                             model,
                                             hidden_dim,
                                             load_optimizer_states=load_optimizer_states)
```
```diff
@@ -151,7 +198,8 @@ def test_checkpoint_fused_optimizer(tmpdir):
     @distributed_test(world_size=[2])
     def _test_checkpoint_fused_optimizer(args,
                                          model,
                                          hidden_dim,
                                          load_optimizer_states):
-        checkpoint_correctness_verification(args,
+        checkpoint_correctness_verification(tmpdir,
+                                            args,
                                             model,
                                             hidden_dim,
                                             load_optimizer_states=load_optimizer_states)
```
```diff
@@ -192,7 +240,8 @@ def test_checkpoint_zero_optimizer(tmpdir):
     @distributed_test(world_size=[2])
     def _test_checkpoint_zero_optimizer(args,
                                         model,
                                         hidden_dim,
                                         load_optimizer_states):
-        checkpoint_correctness_verification(args,
+        checkpoint_correctness_verification(tmpdir,
+                                            args,
                                             model,
                                             hidden_dim,
                                             load_optimizer_states=load_optimizer_states)
```
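Taken together, the test changes route every checkpoint through the pytest `tmpdir` fixture instead of a hard-coded `'saved_checkpoint'` directory, and after a reload they now verify module weights, optimizer state, LR-scheduler state, and the engine counters, including the `global_steps` value whose loading this commit fixes. Below is a condensed sketch of that round trip; `deepspeed.initialize` and the `SimpleModel` / `args_from_dict` helpers follow the imports shown in the test file, but their exact signatures and arguments here are assumptions rather than a copy of the test.

```python
import deepspeed
from simple_model import SimpleModel, args_from_dict   # test-suite helpers (assumed signatures)

def roundtrip_sketch(config_dict, tmpdir, hidden_dim=10):
    args = args_from_dict(tmpdir, config_dict)

    # Train-and-save side.
    model = SimpleModel(hidden_dim)
    engine, _, _, _ = deepspeed.initialize(args=args,
                                           model=model,
                                           model_parameters=model.parameters())
    # ... a few forward/backward/step iterations would go here ...
    engine.save_checkpoint(str(tmpdir), '1')

    # Load side: a fresh model instance restored from the same folder and tag.
    model2 = SimpleModel(hidden_dim)
    engine2, _, _, _ = deepspeed.initialize(args=args,
                                            model=model2,
                                            model_parameters=model2.parameters())
    engine2.load_checkpoint(str(tmpdir), '1', load_optimizer_states=True)

    # With the deepspeed_light.py fix, the step counter survives the reload.
    assert engine.global_steps == engine2.global_steps
```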