OpenDAS / Megatron-LM · Commits

Commit 0f2a9f19, authored Mar 14, 2022 by Lawrence McAfee

    cleaned/commented checkpointing.py.

Parent: 9c86abd9
Showing 1 changed file with 6 additions and 67 deletions.

megatron/checkpointing.py  (+6, -67)
@@ -28,10 +28,6 @@ from megatron import (get_args,
                       update_num_microbatches,
                       utils)
 
-# >>>
-from lutil import pax
-# <<<
-
 _CHECKPOINT_VERSION = None
 
 def set_checkpoint_version(value):
@@ -85,25 +81,6 @@ def ensure_directory_exists(filename):
         os.makedirs(dirname)
 
-# >>
-# def get_checkpoint_name(checkpoints_path, iteration,
-#                         release=False):
-#     """A unified checkpoint name."""
-#     if release:
-#         directory = 'release'
-#     else:
-#         directory = 'iter_{:07d}'.format(iteration)
-#     # Use both the tensor and pipeline MP rank.
-#     if mpu.get_pipeline_model_parallel_world_size() == 1:
-#         return os.path.join(checkpoints_path, directory,
-#                             'mp_rank_{:02d}'.format(
-#                                 mpu.get_tensor_model_parallel_rank()),
-#                             'model_optim_rng.pt')
-#     return os.path.join(checkpoints_path, directory,
-#                         'mp_rank_{:02d}_{:03d}'.format(
-#                             mpu.get_tensor_model_parallel_rank(),
-#                             mpu.get_pipeline_model_parallel_rank()),
-#                         'model_optim_rng.pt')
 
 def get_checkpoint_names(checkpoints_path, iteration, use_distributed_optimizer,
                          release=False):
     """A unified checkpoint name."""
@@ -111,7 +88,9 @@ def get_checkpoint_names(checkpoints_path, iteration, use_distributed_optimizer,
         directory = 'release'
     else:
         directory = 'iter_{:07d}'.format(iteration)
-    # Use both the tensor and pipeline MP rank.
+    # Use both the tensor and pipeline MP rank. If using the distributed
+    # optimizer, then the optimizer's path must additionally include the
+    # data parallel rank.
     common_path = os.path.join(checkpoints_path, directory,
@@ -126,7 +105,6 @@ def get_checkpoint_names(checkpoints_path, iteration, use_distributed_optimizer,
     else:
         optim_name = os.path.join(common_path, "optim.pt")
 
     return model_name, optim_name
-# <<<
 
 def get_checkpoint_tracker_filename(checkpoints_path):
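For orientation, below is a minimal, self-contained sketch of the naming scheme these two hunks describe. Only the "optim.pt" file name and the mp_rank directory pattern (from the old, commented-out get_checkpoint_name) appear in the diff; the model file name, the data-parallel suffix format, and passing ranks as plain arguments instead of querying megatron.mpu are assumptions made for illustration, not the code in this commit.

# Hypothetical sketch only -- not Megatron's get_checkpoint_names.
import os

def sketch_checkpoint_names(checkpoints_path, iteration, use_distributed_optimizer,
                            release=False, tp_rank=0, pp_rank=0, pp_world_size=1,
                            dp_rank=0):
    """Return (model_name, optim_name) under the layout described in the diff."""
    directory = 'release' if release else 'iter_{:07d}'.format(iteration)
    # Use both the tensor and pipeline MP rank (pattern from the old code).
    if pp_world_size == 1:
        common_path = os.path.join(checkpoints_path, directory,
                                   'mp_rank_{:02d}'.format(tp_rank))
    else:
        common_path = os.path.join(checkpoints_path, directory,
                                   'mp_rank_{:02d}_{:03d}'.format(tp_rank, pp_rank))
    model_name = os.path.join(common_path, 'model_rng.pt')  # assumed file name
    if use_distributed_optimizer:
        # Distributed optimizer: the optimizer's path additionally includes the
        # data-parallel rank; the exact suffix format here is a guess.
        optim_name = os.path.join(common_path + '_{:03d}'.format(dp_rank), 'optim.pt')
    else:
        optim_name = os.path.join(common_path, 'optim.pt')  # "optim.pt" is from the diff
    return model_name, optim_name

# Example: names for TP rank 0, DP rank 2, iteration 100.
print(sketch_checkpoint_names('/ckpts', 100, use_distributed_optimizer=True, dp_rank=2))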
@@ -212,11 +190,6 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler):
     model_checkpoint_name, optim_checkpoint_name = \
         get_checkpoint_names(args.save, iteration, args.use_distributed_optimizer)
 
-    pax(0, {
-        "model_checkpoint_name" : model_checkpoint_name,
-        "optim_checkpoint_name" : optim_checkpoint_name,
-    })
-
     # Save args, model, RNG.
     if not torch.distributed.is_initialized() \
        or mpu.get_data_parallel_rank() == 0:
@@ -233,15 +206,6 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler):
                 mpu.set_virtual_pipeline_model_parallel_rank(i)
                 state_dict['model%d' % i] = model[i].state_dict_for_save_checkpoint()
 
-        # >>>
-        # # Optimizer stuff.
-        # if not args.no_save_optim:
-        #     if optimizer is not None:
-        #         state_dict['optimizer'] = optimizer.state_dict()
-        #     if opt_param_scheduler is not None:
-        #         state_dict['opt_param_scheduler'] = opt_param_scheduler.state_dict()
-        # <<<
-
         # RNG states.
         if not args.no_save_rng:
             state_dict["rng_state"] = rng_state
@@ -250,8 +214,8 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler):
         ensure_directory_exists(model_checkpoint_name)
         torch.save(state_dict, model_checkpoint_name)
 
-    # >>>
-    # Save optimizer state.
+    # Save optimizer state. (Optimizer is saved separately from the model, due
+    # to the conflicting data pattern when using the distributed optimizer.)
     if not args.no_save_optim \
        and (not torch.distributed.is_initialized()
             or mpu.get_data_parallel_rank() == 0
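The two replacement comment lines are the substantive change in this hunk: model state and optimizer state now go to separate files, so the optimizer file can be written per data-parallel rank when the distributed optimizer is used. As a rough, self-contained illustration of that split (not the actual save_checkpoint code; the paths, file names, and toy model below are invented):

# Toy illustration of the split-save pattern, not Megatron's save_checkpoint.
import os
import torch

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

model_checkpoint_name = '/tmp/ckpt/iter_0000100/mp_rank_00/model_rng.pt'  # assumed name
optim_checkpoint_name = '/tmp/ckpt/iter_0000100/mp_rank_00/optim.pt'      # "optim.pt" as in the diff

for name in (model_checkpoint_name, optim_checkpoint_name):
    os.makedirs(os.path.dirname(name), exist_ok=True)

# Model file: weights (plus, in Megatron, args and RNG state).
torch.save({'model': model.state_dict()}, model_checkpoint_name)

# Optimizer file: optimizer (and scheduler) state, written separately so that a
# distributed optimizer can later produce one such file per data-parallel rank.
torch.save({'optimizer': optimizer.state_dict()}, optim_checkpoint_name)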
@@ -267,14 +231,6 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler):
         # Save.
         ensure_directory_exists(optim_checkpoint_name)
         torch.save(state_dict, optim_checkpoint_name)
-        # >>>
-        # pax({
-        #     "model_checkpoint_name" : model_checkpoint_name,
-        #     "optim_checkpoint_name" : optim_checkpoint_name,
-        #     "state_dict" : state_dict,
-        # })
-        # <<<
-    # <<<
 
     # Wait so everyone is done (necessary)
     if torch.distributed.is_initialized():
@@ -415,11 +371,7 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
             print_rank_0(e)
             sys.exit()
 
-    # >>>
-    pax({"hi.": "there."})
-    # <<<
-
-    # set checkpoint version
+    # Set checkpoint version.
     set_checkpoint_version(model_state_dict.get('checkpoint_version', 0))
 
     # Set iteration.
@@ -464,19 +416,7 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
         print_rank_0(f' checkpoint version {checkpoint_version}')
         fix_query_key_value_ordering(model, checkpoint_version)
 
-    # >>>
-    # pax(0, {
-    #     "model_state_dict" : model_state_dict,
-    #     "optim_state_dict" : optim_state_dict,
-    # })
-    # <<<
-
     # Optimizer.
-    pax({
-        "release" : release,
-        "finetune" : args.finetune,
-        "no_load_optim" : args.no_load_optim,
-    })
     if not release and not args.finetune and not args.no_load_optim:
         try:
             if optimizer is not None:
@@ -530,7 +470,6 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
             sys.exit()
 
     # Some utilities want to load a checkpoint without distributed being initialized
-    # pax({"hi.": "there."})
     if torch.distributed.is_initialized():
         torch.distributed.barrier()