OpenDAS / deepspeed · Commit b2c87edf

Fix global_steps checkpoint loading. (#139)

Unverified commit, authored May 06, 2020 by Shaden Smith; committed via GitHub on May 06, 2020.
Parent: 4f42bbb0
Showing 3 changed files, with 56 additions and 7 deletions (+56 -7):

  README.md                         +1 -1
  deepspeed/pt/deepspeed_light.py   +1 -1
  tests/unit/test_checkpointing.py  +54 -5
README.md
@@ -8,7 +8,7 @@ efficient, and effective.
 <p align="center"><i><b>5x Faster Training</b></i></p>
 <p align="center"><i><b>Minimal Code Change</b></i></p>
-DeepSpeed can train DL models with over a hundred billion parameters on current
+DeepSpeed can train deep learning models with over a hundred billion parameters on current
 generation of GPU clusters, while achieving over 5x in system performance
 compared to the state-of-art. Early adopters of DeepSpeed have already produced
 a language model (LM) with over 17B parameters called
...
deepspeed/pt/deepspeed_light.py
@@ -1021,7 +1021,7 @@ class DeepSpeedLight(Module):
             'optimizer',
             'csr_tensor_module_names',
             'skipped_steps',
-            'global_step'
+            'global_steps'
         ]
         client_state = {
             key: value
...
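For context on why a one-character change fixes checkpoint loading: the context lines above suggest client_state is rebuilt by filtering engine-owned keys out of the loaded checkpoint dict. A minimal, hypothetical sketch of that pattern (toy values; the surrounding code is elided in this hunk):

# With the misspelled key 'global_step', the real checkpoint entry
# 'global_steps' is not recognized as engine-owned, so it leaks into
# client_state instead of being consumed by the engine on load.
checkpoint = {'optimizer': {}, 'skipped_steps': 3, 'global_steps': 100,
              'user_key': 'user_value'}                 # toy stand-ins
deepspeed_keys = ['optimizer', 'csr_tensor_module_names',
                  'skipped_steps', 'global_step']       # pre-fix typo
client_state = {key: value for key, value in checkpoint.items()
                if key not in deepspeed_keys}
assert 'global_steps' in client_state   # leaked; with 'global_steps' in the
                                        # list, this assert would fail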
tests/unit/test_checkpointing.py
@@ -13,7 +13,33 @@ from common import distributed_test
 from simple_model import SimpleModel, random_dataloader, args_from_dict


+def compare_deepspeed_states(saved_model, loaded_model):
+    # These are compared in more depth in other places
+    assert hasattr(loaded_model, 'module')
+    assert saved_model.csr_tensor_module_names == loaded_model.csr_tensor_module_names
+    assert saved_model.skipped_steps == loaded_model.skipped_steps
+    assert saved_model.global_steps == loaded_model.global_steps
+
+
+def compare_lr_scheduler_states(saved_model, loaded_model):
+    if saved_model.lr_scheduler is None:
+        assert loaded_model.lr_scheduler is None
+        return
+
+    saved = saved_model.lr_scheduler.state_dict()
+    loaded = loaded_model.lr_scheduler.state_dict()
+
+    assert sorted(saved.keys()) == sorted(loaded.keys())
+    for key in saved.keys():
+        if isinstance(saved[key], torch.Tensor):
+            assert torch.equal(saved[key], loaded[key])
+        else:
+            assert saved[key] == loaded[key]
+
+
 def compare_model_states(saved_model, loaded_model):
+    compare_deepspeed_states(saved_model, loaded_model)
+
     for p0, p1 in zip(saved_model.module.parameters(), loaded_model.module.parameters()):
         assert torch.allclose(p0, p1, atol=1e-07), f"FP16 model state {p0} is not equal to {p1}"
@@ -37,6 +63,8 @@ def compare_model_states(saved_model, loaded_model):

 def compare_optimizer_states(saved_model, loaded_model, hidden_dim):
     compare_model_states(saved_model, loaded_model)
+    assert hasattr(loaded_model, 'optimizer')
+
     for state0, state1 in zip(saved_model.optimizer.optimizer.state.values(),
                               loaded_model.optimizer.optimizer.state.values()):
         for s0, s1 in zip(state0.values(), state1.values()):
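The doubled attribute access optimizer.optimizer reads as a wrapper (for example an FP16 optimizer) around an underlying torch.optim instance; that nesting is inferred from the diff, not stated by it. A runnable sketch of iterating per-parameter optimizer state the way the test does, using a plain Adam optimizer:

import torch

net = torch.nn.Linear(4, 2)
opt = torch.optim.Adam(net.parameters(), lr=1e-3)
net(torch.randn(8, 4)).sum().backward()
opt.step()

# opt.state maps each parameter to its state entries (for Adam: 'step',
# 'exp_avg', 'exp_avg_sq'); the test zips these dicts between the saved
# and loaded engines and compares them value by value.
for state in opt.state.values():
    for name, value in state.items():
        print(name, type(value).__name__)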
@@ -46,7 +74,8 @@ def compare_optimizer_states(saved_model, loaded_model, hidden_dim):
             assert s0 == s1


-def checkpoint_correctness_verification(args,
+def checkpoint_correctness_verification(save_folder,
+                                        args,
                                         model,
                                         hidden_dim,
                                         load_optimizer_states=True):
@@ -65,7 +94,6 @@ def checkpoint_correctness_verification(args,
     trained_model = ds_model

-    save_folder = 'saved_checkpoint'
     save_tag = '1'

     trained_model.save_checkpoint(save_folder, save_tag)
@@ -78,6 +106,8 @@ def checkpoint_correctness_verification(args,
                                            save_tag,
                                            load_optimizer_states=load_optimizer_states)

+    compare_lr_scheduler_states(trained_model, loaded_model)
+
     if load_optimizer_states:
         compare_optimizer_states(trained_model, loaded_model, hidden_dim)
     else:
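Pieced together, the helper now saves from one engine, loads into a fresh one, and checks scheduler state before the optimizer/model comparison. A toy, runnable mirror of that flow (illustrative names only, not DeepSpeed's API; the engine internals are elided in the diff):

import os
import pickle
import tempfile

class ToyEngine:
    """Stand-in for a DeepSpeed engine with one piece of engine state."""
    def __init__(self):
        self.global_steps = 0

    def save_checkpoint(self, save_folder, save_tag):
        with open(os.path.join(save_folder, f'ckpt_{save_tag}.pkl'), 'wb') as f:
            pickle.dump({'global_steps': self.global_steps}, f)

    def load_checkpoint(self, save_folder, save_tag):
        with open(os.path.join(save_folder, f'ckpt_{save_tag}.pkl'), 'rb') as f:
            self.global_steps = pickle.load(f)['global_steps']

with tempfile.TemporaryDirectory() as save_folder:
    trained_model = ToyEngine()
    trained_model.global_steps = 100       # pretend we trained
    trained_model.save_checkpoint(save_folder, '1')

    loaded_model = ToyEngine()
    loaded_model.load_checkpoint(save_folder, '1')
    assert trained_model.global_steps == loaded_model.global_steps  # the fixed field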
@@ -97,6 +127,22 @@ def test_checkpoint_unfused_optimizer(tmpdir):
         },
         "fp16": {
             "enabled": True
-        }
+        },
+        "scheduler": {
+            "type": "OneCycle",
+            "params": {
+                "cycle_first_step_size": 1000,
+                "cycle_first_stair_count": 500,
+                "cycle_second_step_size": 1000,
+                "cycle_second_stair_count": 500,
+                "decay_step_size": 1000,
+                "cycle_min_lr": 0.0001,
+                "cycle_max_lr": 0.0010,
+                "decay_lr_rate": 0.001,
+                "cycle_min_mom": 0.85,
+                "cycle_max_mom": 0.99,
+                "decay_mom_rate": 0.0
+            }
+        }
     }
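Presumably the scheduler config is added so lr_scheduler is non-None and the new compare_lr_scheduler_states path is actually exercised; without it, that helper returns early. The state round-trip it relies on can be illustrated with a stock torch scheduler (DeepSpeed's OneCycle is its own implementation, but the state_dict/load_state_dict contract is the same idea):

import torch

opt = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.1)
saved_sched = torch.optim.lr_scheduler.StepLR(opt, step_size=1000)
state = saved_sched.state_dict()        # plain dict of scheduler fields

loaded_sched = torch.optim.lr_scheduler.StepLR(opt, step_size=1)
loaded_sched.load_state_dict(state)     # restores step_size=1000, etc.

saved, loaded = saved_sched.state_dict(), loaded_sched.state_dict()
assert sorted(saved.keys()) == sorted(loaded.keys())
for key in saved:
    assert saved[key] == loaded[key]    # same comparison as the test helper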
@@ -110,7 +156,8 @@ def test_checkpoint_unfused_optimizer(tmpdir):
                                            model,
                                            hidden_dim,
                                            load_optimizer_states):
-        checkpoint_correctness_verification(args,
+        checkpoint_correctness_verification(tmpdir,
+                                            args,
                                             model,
                                             hidden_dim,
                                             load_optimizer_states=load_optimizer_states)
@@ -151,7 +198,8 @@ def test_checkpoint_fused_optimizer(tmpdir):
     @distributed_test(world_size=[2])
     def _test_checkpoint_fused_optimizer(args, model, hidden_dim, load_optimizer_states):
-        checkpoint_correctness_verification(args,
+        checkpoint_correctness_verification(tmpdir,
+                                            args,
                                             model,
                                             hidden_dim,
                                             load_optimizer_states=load_optimizer_states)
@@ -192,7 +240,8 @@ def test_checkpoint_zero_optimizer(tmpdir):
     @distributed_test(world_size=[2])
     def _test_checkpoint_zero_optimizer(args, model, hidden_dim, load_optimizer_states):
-        checkpoint_correctness_verification(args,
+        checkpoint_correctness_verification(tmpdir,
+                                            args,
                                             model,
                                             hidden_dim,
                                             load_optimizer_states=load_optimizer_states)
...
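The three hunks above are the same mechanical change: each test forwards pytest's built-in tmpdir fixture as the checkpoint directory, replacing the hard-coded 'saved_checkpoint' folder removed earlier, so repeated or concurrent test runs cannot clobber each other's checkpoints. A minimal sketch of the fixture (test name is illustrative):

import os

# pytest creates a unique temporary directory per test and injects it as
# the `tmpdir` fixture (a py.path.local object usable as a path).
def test_checkpoint_roundtrip(tmpdir):
    ckpt_dir = str(tmpdir)
    assert os.path.isdir(ckpt_dir)      # exists and is unique per test
    # The tests in this commit forward it directly, e.g.:
    # checkpoint_correctness_verification(tmpdir, args, model, hidden_dim,
    #                                     load_optimizer_states=True)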