Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
deepspeed
Commits
845921b3
Unverified
Commit
845921b3
authored
Dec 02, 2020
by
Jeff Rasley
Committed by
GitHub
Dec 02, 2020
Browse files
Add 'latest' checkpoint save/load support (#569)
parent
7a75f8b3
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
85 additions
and
9 deletions
+85
-9
deepspeed/runtime/engine.py
deepspeed/runtime/engine.py
+24
-5
tests/unit/test_checkpointing.py
tests/unit/test_checkpointing.py
+61
-4
No files found.
deepspeed/runtime/engine.py
View file @
845921b3
...
@@ -1300,15 +1300,15 @@ class DeepSpeedEngine(Module):
...
@@ -1300,15 +1300,15 @@ class DeepSpeedEngine(Module):
def
load_checkpoint
(
self
,
def
load_checkpoint
(
self
,
load_dir
,
load_dir
,
tag
,
tag
=
None
,
load_module_strict
=
True
,
load_module_strict
=
True
,
load_optimizer_states
=
True
,
load_optimizer_states
=
True
,
load_lr_scheduler_states
=
True
):
load_lr_scheduler_states
=
True
):
r
"""Load training checkpoint
"""Load training checkpoint
Arguments:
Arguments:
load_dir: Required. Directory to load the checkpoint from
load_dir: Required. Directory to load the checkpoint from
tag:
Required.
Checkpoint tag used as a unique identifier for
the
checkpoint
. Ex. Global Step.
tag: Checkpoint tag used as a unique identifier for checkpoint
, if not provided will attempt to load tag in 'latest' file
load_module_strict: Optional. Boolean to strictly enforce that the keys in state_dict of module and checkpoint match.
load_module_strict: Optional. Boolean to strictly enforce that the keys in state_dict of module and checkpoint match.
load_optimizer_states: Optional. Boolean to load the training optimizer states from Checkpoint. Ex. ADAM's momentum and variance
load_optimizer_states: Optional. Boolean to load the training optimizer states from Checkpoint. Ex. ADAM's momentum and variance
load_lr_scheduler_states: Optional. Boolean to add the learning rate scheduler states from Checkpoint.
load_lr_scheduler_states: Optional. Boolean to add the learning rate scheduler states from Checkpoint.
...
@@ -1317,6 +1317,13 @@ class DeepSpeedEngine(Module):
...
@@ -1317,6 +1317,13 @@ class DeepSpeedEngine(Module):
client_state: State dictionary used for loading required training states in the client code.
client_state: State dictionary used for loading required training states in the client code.
"""
"""
if
tag
is
None
:
latest_path
=
os
.
path
.
join
(
load_dir
,
'latest'
)
assert
os
.
path
.
isfile
(
latest_path
),
f
"Unable to find latest file at
{
latest_path
}
, if trying to load latest "
\
"checkpoint please ensure this file exists or pass an explicit checkpoint tag when loading a checkpoint."
with
open
(
latest_path
,
'r'
)
as
fd
:
tag
=
fd
.
read
().
strip
()
load_path
,
client_states
=
self
.
_load_checkpoint
(
load_dir
,
load_path
,
client_states
=
self
.
_load_checkpoint
(
load_dir
,
tag
,
tag
,
load_module_strict
=
load_module_strict
,
load_module_strict
=
load_module_strict
,
...
@@ -1454,18 +1461,25 @@ class DeepSpeedEngine(Module):
...
@@ -1454,18 +1461,25 @@ class DeepSpeedEngine(Module):
)
)
return
zero_optimizer_sd
return
zero_optimizer_sd
def
save_checkpoint
(
self
,
save_dir
,
tag
,
client_state
=
{}):
def
save_checkpoint
(
self
,
save_dir
,
tag
=
None
,
client_state
=
{}
,
save_latest
=
True
):
r
"""Save training checkpoint
r
"""Save training checkpoint
Arguments:
Arguments:
save_dir: Required. Directory for saving the checkpoint
save_dir: Required. Directory for saving the checkpoint
tag:
Required
. Checkpoint tag used as a unique identifier for the checkpoint
. Ex. G
lobal
S
tep.
tag:
Optional
. Checkpoint tag used as a unique identifier for the checkpoint
, g
lobal
s
tep
is used if not provided
.
client_state: Optional. State dictionary used for saving required training states in the client code.
client_state: Optional. State dictionary used for saving required training states in the client code.
save_latest: Optional. Save a file 'latest' pointing to the latest saved checkpoint.
"""
"""
# This is to make sure the checkpoint names are created without collision
# This is to make sure the checkpoint names are created without collision
# There seems to be issue creating them in parallel
# There seems to be issue creating them in parallel
# Ensure save_dir directory exists
os
.
makedirs
(
save_dir
,
exist_ok
=
True
)
if
tag
is
None
:
tag
=
f
"global_step
{
self
.
global_steps
}
"
if
self
.
save_non_zero_checkpoint
:
if
self
.
save_non_zero_checkpoint
:
self
.
_create_checkpoint_file
(
save_dir
,
tag
,
False
)
self
.
_create_checkpoint_file
(
save_dir
,
tag
,
False
)
self
.
_save_checkpoint
(
save_dir
,
tag
,
client_state
=
client_state
)
self
.
_save_checkpoint
(
save_dir
,
tag
,
client_state
=
client_state
)
...
@@ -1474,6 +1488,11 @@ class DeepSpeedEngine(Module):
...
@@ -1474,6 +1488,11 @@ class DeepSpeedEngine(Module):
self
.
_create_zero_checkpoint_files
(
save_dir
,
tag
)
self
.
_create_zero_checkpoint_files
(
save_dir
,
tag
)
self
.
_save_zero_checkpoint
(
save_dir
,
tag
)
self
.
_save_zero_checkpoint
(
save_dir
,
tag
)
# Save latest checkpoint tag
if
save_latest
:
with
open
(
os
.
path
.
join
(
save_dir
,
'latest'
),
'w'
)
as
fd
:
fd
.
write
(
tag
)
return
True
return
True
def
_create_checkpoint_file
(
self
,
save_dir
,
tag
,
zero_checkpoint
):
def
_create_checkpoint_file
(
self
,
save_dir
,
tag
,
zero_checkpoint
):
...
...
tests/unit/test_checkpointing.py
View file @
845921b3
...
@@ -128,7 +128,8 @@ def checkpoint_correctness_verification(args,
...
@@ -128,7 +128,8 @@ def checkpoint_correctness_verification(args,
fp16
=
True
,
fp16
=
True
,
train_batch
=
False
,
train_batch
=
False
,
base_optimizers
=
[
None
,
base_optimizers
=
[
None
,
None
]):
None
],
empty_tag
=
False
):
dtype
=
torch
.
half
if
fp16
else
torch
.
float32
dtype
=
torch
.
half
if
fp16
else
torch
.
float32
ds_model
=
create_deepspeed_model
(
args
=
args
,
ds_model
=
create_deepspeed_model
(
args
=
args
,
model
=
models
[
0
],
model
=
models
[
0
],
...
@@ -153,16 +154,16 @@ def checkpoint_correctness_verification(args,
...
@@ -153,16 +154,16 @@ def checkpoint_correctness_verification(args,
trained_model
=
ds_model
trained_model
=
ds_model
save_folder
=
os
.
path
.
join
(
tmpdir
,
'saved_checkpoint'
)
save_folder
=
os
.
path
.
join
(
tmpdir
,
'saved_checkpoint'
)
save_tag
=
'1'
save_tag
=
None
if
empty_tag
else
'1'
trained_model
.
save_checkpoint
(
save_folder
,
save_tag
)
trained_model
.
save_checkpoint
(
save_folder
,
tag
=
save_tag
)
loaded_model
=
create_deepspeed_model
(
args
=
args
,
loaded_model
=
create_deepspeed_model
(
args
=
args
,
model
=
models
[
1
],
model
=
models
[
1
],
base_optimizer
=
base_optimizers
[
1
])
base_optimizer
=
base_optimizers
[
1
])
loaded_model
.
load_checkpoint
(
save_folder
,
loaded_model
.
load_checkpoint
(
save_folder
,
save_tag
,
tag
=
save_tag
,
load_optimizer_states
=
load_optimizer_states
,
load_optimizer_states
=
load_optimizer_states
,
load_lr_scheduler_states
=
load_lr_scheduler_states
)
load_lr_scheduler_states
=
load_lr_scheduler_states
)
...
@@ -704,3 +705,59 @@ def test_checkpoint_zero_hybrid_optimizer_state(tmpdir, zero_stage):
...
@@ -704,3 +705,59 @@ def test_checkpoint_zero_hybrid_optimizer_state(tmpdir, zero_stage):
models
=
models
,
models
=
models
,
optimizers
=
optimizers
,
optimizers
=
optimizers
,
hidden_dim
=
hidden_dim
)
hidden_dim
=
hidden_dim
)
def
test_checkpoint_latest
(
tmpdir
):
config_dict
=
{
"train_batch_size"
:
2
,
"steps_per_print"
:
1
,
"optimizer"
:
{
"type"
:
"Adam"
,
"params"
:
{
"lr"
:
0.00015
}
}
}
hidden_dim
=
10
args
=
args_from_dict
(
tmpdir
,
config_dict
)
models
=
[
SimpleModel
(
hidden_dim
=
hidden_dim
)
for
_
in
range
(
2
)]
@
distributed_test
(
world_size
=
[
1
])
def
_helper
(
args
,
models
):
checkpoint_correctness_verification
(
args
,
models
=
models
,
hidden_dim
=
hidden_dim
,
tmpdir
=
tmpdir
,
load_optimizer_states
=
True
,
load_lr_scheduler_states
=
False
,
fp16
=
False
,
empty_tag
=
True
)
_helper
(
args
,
models
)
def
test_checkpoint_missing_latest
(
tmpdir
):
config_dict
=
{
"train_batch_size"
:
2
,
"steps_per_print"
:
1
,
"optimizer"
:
{
"type"
:
"Adam"
,
"params"
:
{
"lr"
:
0.00015
}
}
}
hidden_dim
=
10
args
=
args_from_dict
(
tmpdir
,
config_dict
)
model
=
SimpleModel
(
hidden_dim
,
rank
=
args
.
local_rank
)
@
distributed_test
(
world_size
=
[
1
])
def
_helper
(
args
,
model
,
hidden_dim
):
model
,
_
,
_
,
_
=
deepspeed
.
initialize
(
args
=
args
,
model
=
model
,
model_parameters
=
model
.
parameters
())
with
pytest
.
raises
(
AssertionError
):
model
.
load_checkpoint
(
tmpdir
)
_helper
(
args
=
args
,
model
=
model
,
hidden_dim
=
hidden_dim
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment