Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Megatron-LM
Commits
2725dc0b
Commit
2725dc0b
authored
Jun 24, 2022
by
Jared Casper
Browse files
Fixing up checkpointing.
parent
4eb802c4
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
31 additions
and
20 deletions
+31
-20
megatron/checkpointing.py
megatron/checkpointing.py
+31
-20
No files found.
megatron/checkpointing.py
View file @
2725dc0b
...
...
@@ -117,7 +117,7 @@ def get_checkpoint_names(checkpoints_path, iteration, use_distributed_optimizer,
model_name
=
optim_name
=
os
.
path
.
join
(
common_path
,
"model_optim_rng.pt"
)
return
model_name
,
optim_name
def
find_checkpoint_rank_0
(
checkpoints_path
,
iteration
,
release
=
False
):
def
find_checkpoint_rank_0
(
checkpoints_path
,
iteration
,
use_distributed_optimizer
,
release
=
False
):
"""Finds the checkpoint for rank 0 without knowing if we are using
pipeline parallelism or not.
...
...
@@ -128,20 +128,20 @@ def find_checkpoint_rank_0(checkpoints_path, iteration, release=False):
"""
# Look for checkpoint with no pipelining
filename
=
get_checkpoint_name
(
checkpoints_path
,
iteration
,
release
,
pipeline_parallel
=
False
,
tensor_rank
=
0
,
pipeline_rank
=
0
)
if
os
.
path
.
isfile
(
filename
):
return
filename
filename
s
=
get_checkpoint_name
s
(
checkpoints_path
,
iteration
,
use_distributed_optimizer
,
release
,
pipeline_parallel
=
False
,
tensor_rank
=
0
,
pipeline_rank
=
0
)
if
os
.
path
.
isfile
(
filename
s
[
0
]
):
return
filename
s
# Look for checkpoint with pipelining
filename
=
get_checkpoint_name
(
checkpoints_path
,
iteration
,
release
,
pipeline_parallel
=
True
,
tensor_rank
=
0
,
pipeline_rank
=
0
)
if
os
.
path
.
isfile
(
filename
):
return
filename
filename
s
=
get_checkpoint_name
s
(
checkpoints_path
,
iteration
,
use_distributed_optimizer
,
release
,
pipeline_parallel
=
True
,
tensor_rank
=
0
,
pipeline_rank
=
0
)
if
os
.
path
.
isfile
(
filename
s
[
0
]
):
return
filename
s
return
None
return
None
,
None
def
get_checkpoint_tracker_filename
(
checkpoints_path
):
...
...
@@ -370,7 +370,7 @@ def fix_query_key_value_ordering(model, checkpoint_version):
print_rank_0
(
" succesfully fixed query-key-values ordering for"
" checkpoint version {}"
.
format
(
checkpoint_version
))
def
_load_base_checkpoint
(
load_dir
,
rank0
=
False
):
def
_load_base_checkpoint
(
load_dir
,
use_distributed_optimizer
,
rank0
=
False
):
""" Load the base state_dict from the given directory
If rank0 is true, just loads rank 0 checkpoint, ignoring arguments.
...
...
@@ -395,11 +395,11 @@ def _load_base_checkpoint(load_dir, rank0=False):
# Checkpoint.
if
rank0
:
checkpoint_names
=
find_checkpoint_rank_0
(
load_dir
,
iteration
,
args
.
use_distributed_optimizer
,
checkpoint_names
=
find_checkpoint_rank_0
(
load_dir
,
iteration
,
use_distributed_optimizer
,
release
)
else
:
checkpoint_names
=
get_checkpoint_name
(
load_dir
,
iteration
,
args
.
use_distributed_optimizer
,
release
)
checkpoint_names
=
get_checkpoint_name
s
(
load_dir
,
iteration
,
use_distributed_optimizer
,
release
)
if
release
:
print_rank_0
(
f
' loading release checkpoint from
{
load_dir
}
'
)
else
:
...
...
@@ -410,7 +410,7 @@ def _load_base_checkpoint(load_dir, rank0=False):
# Load the checkpoint.
try
:
model_state_dict
=
torch
.
load
(
model_checkpoint_name
,
map_location
=
'cpu'
)
if
args
.
use_distributed_optimizer
:
if
use_distributed_optimizer
:
optim_state_dict
=
torch
.
load
(
optim_checkpoint_name
,
map_location
=
'cpu'
)
else
:
optim_state_dict
=
model_state_dict
...
...
@@ -450,18 +450,23 @@ def load_args_from_checkpoint(args, load_arg='load'):
load_dir
=
getattr
(
args
,
load_arg
)
if
load_dir
is
None
:
print_rank_0
(
'No load directory specified, using provided arguments.'
)
return
args
model_state_dict
,
optim_state_dict
,
release
=
_load_base_checkpoint
(
load_dir
,
rank0
=
True
)
model_state_dict
,
optim_state_dict
,
release
=
\
_load_base_checkpoint
(
load_dir
,
use_distributed_optimizer
=
args
.
use_distributed_optimizer
,
rank0
=
True
)
# For args we only care about model state dict
state_dict
=
model_state_dict
if
not
state_dict
:
print_rank_0
(
'Checkpoint not found to provide arguments, using provided arguments.'
)
return
args
if
'args'
not
in
state_dict
:
print
(
'Checkpoint provided does not have arguments saved.'
)
print
_rank_0
(
'Checkpoint provided does not have arguments saved
, using provided arguments
.'
)
return
args
checkpoint_args
=
state_dict
[
'args'
]
...
...
@@ -511,7 +516,13 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
model
=
unwrap_model
(
model
)
model_state_dict
,
optim_state_dict
,
release
=
_load_base_checkpoint
(
load_dir
,
rank0
=
False
)
model_state_dict
,
optim_state_dict
,
release
=
\
_load_base_checkpoint
(
load_dir
,
use_distributed_optimizer
=
args
.
use_distributed_optimizer
,
rank0
=
False
)
if
model_state_dict
is
None
:
return
0
# set checkpoint version
set_checkpoint_version
(
model_state_dict
.
get
(
'checkpoint_version'
,
0
))
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment