wuxk1 / Megatron-LM

Commit 9c86abd9, authored Mar 14, 2022 by Lawrence McAfee
Parent: af2b136f

more specific formatting of model/optim checkpoint paths.
Showing 2 changed files with 61 additions and 31 deletions

  megatron/checkpointing.py                  +54  -21
  megatron/optimizer/distrib_optimizer.py     +7  -10
megatron/checkpointing.py
@@ -28,6 +28,10 @@ from megatron import (get_args,
                       update_num_microbatches,
                       utils)
 
+# >>>
+from lutil import pax
+# <<<
+
 _CHECKPOINT_VERSION = None
 
 def set_checkpoint_version(value):
@@ -100,8 +104,8 @@ def ensure_directory_exists(filename):
 #                         mpu.get_tensor_model_parallel_rank(),
 #                         mpu.get_pipeline_model_parallel_rank()),
 #                         'model_optim_rng.pt')
-def get_checkpoint_names(checkpoints_path, iteration,
+def get_checkpoint_names(checkpoints_path, iteration, use_distributed_optimizer,
                          release=False):
     """A unified checkpoint name."""
     if release:
         directory = 'release'
@@ -111,12 +115,16 @@ def get_checkpoint_names(checkpoints_path, iteration,
     common_path = os.path.join(checkpoints_path, directory,
-                        "mp_rank_%02d_%03d_%03d" % (
+                        "mp_rank_%02d_%03d" % (
                             mpu.get_tensor_model_parallel_rank(),
-                            mpu.get_pipeline_model_parallel_rank(),
-                            mpu.get_data_parallel_rank()))
+                            mpu.get_pipeline_model_parallel_rank()))
 
     model_name = os.path.join(common_path, "model_rng.pt")
-    optim_name = os.path.join(common_path, "optim.pt")
+    if use_distributed_optimizer:
+        optim_name = os.path.join(
+            common_path + "_%03d" % mpu.get_data_parallel_rank(),
+            "optim.pt")
+    else:
+        optim_name = os.path.join(common_path, "optim.pt")
 
     return model_name, optim_name
     # <<<
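
For illustration only: a minimal, standalone sketch (not part of this commit) of the path layout the new get_checkpoint_names produces. The 'iter_%07d' directory naming and the concrete rank values are assumptions for the example; only the mp_rank suffixing and the per-data-parallel-rank optimizer directory come from the hunk above.

import os

def sketch_checkpoint_names(checkpoints_path, iteration, use_distributed_optimizer,
                            tp_rank=0, pp_rank=0, dp_rank=0, release=False):
    # Assumed iteration-directory naming; the real function builds it outside this hunk.
    directory = 'release' if release else 'iter_%07d' % iteration
    # Model weights stay keyed by tensor/pipeline rank only.
    common_path = os.path.join(checkpoints_path, directory,
                               "mp_rank_%02d_%03d" % (tp_rank, pp_rank))
    model_name = os.path.join(common_path, "model_rng.pt")
    if use_distributed_optimizer:
        # Optimizer state is additionally keyed by data-parallel rank.
        optim_name = os.path.join(common_path + "_%03d" % dp_rank, "optim.pt")
    else:
        optim_name = os.path.join(common_path, "optim.pt")
    return model_name, optim_name

print(sketch_checkpoint_names("ckpts", 1000, use_distributed_optimizer=True, dp_rank=2))
# ('ckpts/iter_0001000/mp_rank_00_000/model_rng.pt',
#  'ckpts/iter_0001000/mp_rank_00_000_002/optim.pt')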
@@ -202,7 +210,12 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler):
     # Checkpoint file names.
     model_checkpoint_name, optim_checkpoint_name = \
-        get_checkpoint_names(args.save, iteration)
+        get_checkpoint_names(args.save, iteration, args.use_distributed_optimizer)
+
+    pax(0, {
+        "model_checkpoint_name" : model_checkpoint_name,
+        "optim_checkpoint_name" : optim_checkpoint_name,
+    })
 
     # Save args, model, RNG.
     if not torch.distributed.is_initialized() \
@@ -255,7 +268,6 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler):
         ensure_directory_exists(optim_checkpoint_name)
         torch.save(state_dict, optim_checkpoint_name)
 
     # >>>
-    # from lutil import pax
     # pax({
     #     "model_checkpoint_name" : model_checkpoint_name,
     #     "optim_checkpoint_name" : optim_checkpoint_name,
@@ -377,7 +389,9 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
     # Checkpoint.
     model_checkpoint_name, optim_checkpoint_name = \
-        get_checkpoint_names(load_dir, iteration, release)
+        get_checkpoint_names(load_dir, iteration,
+                             args.use_distributed_optimizer,
+                             release)
     print_rank_0(f' loading checkpoint from {args.load} at iteration {iteration}')
 
     # Load the checkpoint.
@@ -401,6 +415,10 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
             print_rank_0(e)
             sys.exit()
 
+    # >>>
+    pax({"hi.": "there."})
+    # <<<
+
     # set checkpoint version
     set_checkpoint_version(model_state_dict.get('checkpoint_version', 0))
@@ -446,13 +464,25 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
     print_rank_0(f' checkpoint version {checkpoint_version}')
     fix_query_key_value_ordering(model, checkpoint_version)
 
+    # >>>
+    # pax(0, {
+    #     "model_state_dict" : model_state_dict,
+    #     "optim_state_dict" : optim_state_dict,
+    # })
+    # <<<
+
     # Optimizer.
+    pax({
+        "release" : release,
+        "finetune" : args.finetune,
+        "no_load_optim" : args.no_load_optim,
+    })
     if not release and not args.finetune and not args.no_load_optim:
         try:
             if optimizer is not None:
                 optimizer.load_state_dict(optim_state_dict['optimizer'])
             if opt_param_scheduler is not None:
-                if 'lr_scheduler' in state_dict: # backward compatbility
+                if 'lr_scheduler' in optim_state_dict: # backward compatbility
                     opt_param_scheduler.load_state_dict(optim_state_dict['lr_scheduler'])
                 else:
                     opt_param_scheduler.load_state_dict(optim_state_dict['opt_param_scheduler'])
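
As a side note, a tiny self-contained sketch (not from the commit) of the backward-compatibility fallback used above: older optimizer checkpoints store the scheduler state under 'lr_scheduler', newer ones under 'opt_param_scheduler'. The helper name is hypothetical.

def get_scheduler_state(optim_state_dict):
    # Hypothetical helper; mirrors the key fallback in load_checkpoint.
    if 'lr_scheduler' in optim_state_dict:  # backward compatibility
        return optim_state_dict['lr_scheduler']
    return optim_state_dict['opt_param_scheduler']

assert get_scheduler_state({'lr_scheduler': {'num_steps': 1}}) == {'num_steps': 1}
assert get_scheduler_state({'opt_param_scheduler': {'num_steps': 2}}) == {'num_steps': 2}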
@@ -466,13 +496,13 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
     # rng states.
     if not release and not args.finetune and not args.no_load_rng:
         try:
-            if 'rng_state' in state_dict:
+            if 'rng_state' in model_state_dict:
                 # access rng_state for data parallel rank
                 if args.data_parallel_random_init:
-                    rng_state = state_dict['rng_state'][mpu.get_data_parallel_rank()]
+                    rng_state = model_state_dict['rng_state'][mpu.get_data_parallel_rank()]
                 else:
-                    rng_state = state_dict['rng_state'][0]
+                    rng_state = model_state_dict['rng_state'][0]
                 random.setstate(rng_state['random_rng_state'])
                 np.random.set_state(rng_state['np_rng_state'])
                 torch.set_rng_state(rng_state['torch_rng_state'])
@@ -483,15 +513,15 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
                 mpu.get_cuda_rng_tracker().set_states(
                     rng_state['rng_tracker_states'])
             else:  # backward compatability
-                random.setstate(state_dict['random_rng_state'])
-                np.random.set_state(state_dict['np_rng_state'])
-                torch.set_rng_state(state_dict['torch_rng_state'])
-                torch.cuda.set_rng_state(state_dict['cuda_rng_state'])
+                random.setstate(model_state_dict['random_rng_state'])
+                np.random.set_state(model_state_dict['np_rng_state'])
+                torch.set_rng_state(model_state_dict['torch_rng_state'])
+                torch.cuda.set_rng_state(model_state_dict['cuda_rng_state'])
                 # Check for empty states array
-                if not state_dict['rng_tracker_states']:
+                if not model_state_dict['rng_tracker_states']:
                     raise KeyError
                 mpu.get_cuda_rng_tracker().set_states(
-                    state_dict['rng_tracker_states'])
+                    model_state_dict['rng_tracker_states'])
         except KeyError:
             print_rank_0('Unable to load rng state from checkpoint {}. '
                          'Specify --no-load-rng or --finetune to prevent '
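
The rng restore above follows the standard capture/restore pattern for the Python, NumPy and PyTorch generators. A minimal sketch, assuming a CPU-only, non-model-parallel setting (so torch.cuda and mpu's rng tracker are omitted):

import random
import numpy as np
import torch

# Capture the three RNG states, as a checkpoint's 'rng_state' entry would hold them.
rng_state = {
    'random_rng_state': random.getstate(),
    'np_rng_state': np.random.get_state(),
    'torch_rng_state': torch.get_rng_state(),
}
before = (random.random(), float(np.random.rand()), float(torch.rand(1)))

# Restore them, as load_checkpoint does from model_state_dict['rng_state'].
random.setstate(rng_state['random_rng_state'])
np.random.set_state(rng_state['np_rng_state'])
torch.set_rng_state(rng_state['torch_rng_state'])
after = (random.random(), float(np.random.rand()), float(torch.rand(1)))

assert before == after  # restoring the states reproduces the same draws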
@@ -500,6 +530,7 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
             sys.exit()
 
     # Some utilities want to load a checkpoint without distributed being initialized
+    # pax({"hi.": "there."})
     if torch.distributed.is_initialized():
         torch.distributed.barrier()
@@ -526,12 +557,14 @@ def load_biencoder_checkpoint(model, only_query_model=False,
     with open(tracker_filename, 'r') as f:
         iteration = int(f.read().strip())
 
-    checkpoint_name = get_checkpoint_name(load_path, iteration, False)
+    checkpoint_name, _ = get_checkpoint_names(load_path, iteration,
+                                              args.use_distributed_optimizer,
+                                              False)
     if mpu.get_data_parallel_rank() == 0:
         print('global rank {} is loading checkpoint {}'.format(
             torch.distributed.get_rank(), checkpoint_name))
 
-    state_dict = torch.load(checkpoint_name, map_location='cpu')
+    state_dict = torch.load(model_checkpoint_name, map_location='cpu')
     ret_state_dict = state_dict['model']
 
     if only_query_model:
megatron/optimizer/distrib_optimizer.py
@@ -308,7 +308,7 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
             state_dict['grad_scaler'] = self.grad_scaler.state_dict()
         # state_dict['params'] = \
         #     [ p for g in self.optimizer.param_groups for p in g["params"] ]
-        state_dict['groups'] = [ g["params"] for g in self.optimizer.param_groups ]
+        state_dict['groups'] = [ g['params'] for g in self.optimizer.param_groups ]
         # pax(0, { # ... only called on model rank 0
         #     # "optimizer" : self.optimizer,
         #     "state_dict" : state_dict,
@@ -348,20 +348,17 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
         # Copy data for the main params.
         current_groups = [ g["params"] for g in self.optimizer.param_groups ]
-        params_key = 'params'
-        assert params_key in state_dict, "key 'params' not in state_dict."
+        assert "groups" in state_dict, "key 'groups' not in state_dict."
 
         # pax(0, {
         #     "state_dict" : state_dict,
         #     "current_groups" : current_groups,
         #     "saved_groups" : state_dict[params_key],
         # })
 
         for current_group, saved_group in zip(
                 current_groups,
-                state_dict[params_key]):
-            # pax(0, {
-            #     "current_group" : current_group,
-            #     "saved_group" : saved_group,
-            # })
+                state_dict["groups"]):
+            pax(0, {
+                "current_group" : current_group,
+                "saved_group" : saved_group,
+            })
             for current_param, saved_param in zip(current_group, saved_group):
                 current_param.data.copy_(saved_param.data)
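
For context, a small hypothetical sketch (not part of the commit) of the save/load symmetry used in distrib_optimizer.py: 'groups' holds one list of parameter tensors per optimizer param group, and loading copies the saved values back element-wise with the same zip pattern as above. The helper names are invented for the example.

import torch

def save_groups(optimizer):
    # Mirrors: state_dict['groups'] = [ g['params'] for g in self.optimizer.param_groups ]
    return [g['params'] for g in optimizer.param_groups]

def load_groups(optimizer, saved_groups):
    current_groups = [g['params'] for g in optimizer.param_groups]
    assert len(current_groups) == len(saved_groups), "key 'groups' shape mismatch."
    for current_group, saved_group in zip(current_groups, saved_groups):
        for current_param, saved_param in zip(current_group, saved_group):
            current_param.data.copy_(saved_param.data)

params = [torch.nn.Parameter(torch.zeros(4))]
opt = torch.optim.SGD(params, lr=0.1)
saved = [[p.detach().clone() + 1.0 for p in group] for group in save_groups(opt)]
load_groups(opt, saved)
print(params[0].data)  # tensor([1., 1., 1., 1.])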