wuxk1 / Megatron-LM · Commit af2b136f

Authored Mar 14, 2022 by Lawrence McAfee

optimizer saves list(group), not list(param).

Parent: 37ca7859
Showing 2 changed files with 29 additions and 18 deletions:

    megatron/checkpointing.py                  +10 -10
    megatron/optimizer/distrib_optimizer.py    +19 -8
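The commit message describes the core format change: instead of serializing one flat list of main parameters, the distributed optimizer's state dict now stores one parameter list per optimizer param group. A minimal sketch of the two layouts, using a toy torch.optim.SGD optimizer as a stand-in for self.optimizer (not code from this commit):

    import torch

    # Toy stand-in for self.optimizer inside DistributedOptimizer.state_dict().
    params_a = [torch.nn.Parameter(torch.zeros(2)) for _ in range(2)]
    params_b = [torch.nn.Parameter(torch.zeros(3))]
    optimizer = torch.optim.SGD(
        [{"params": params_a, "lr": 1e-3},
         {"params": params_b, "lr": 1e-4}])

    # Old layout: a single flat list of parameters; group boundaries are lost.
    flat_params = [p for g in optimizer.param_groups for p in g["params"]]
    assert len(flat_params) == 3

    # New layout: list(group) -- one list of parameters per param group,
    # which is what this commit writes under state_dict['groups'].
    grouped_params = [g["params"] for g in optimizer.param_groups]
    assert [len(g) for g in grouped_params] == [2, 1]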
megatron/checkpointing.py  (view file @ af2b136f)

@@ -402,17 +402,17 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
             sys.exit()
 
     # set checkpoint version
-    set_checkpoint_version(state_dict.get('checkpoint_version', 0))
+    set_checkpoint_version(model_state_dict.get('checkpoint_version', 0))
 
     # Set iteration.
     if args.finetune or release:
         iteration = 0
     else:
         try:
-            iteration = state_dict['iteration']
+            iteration = model_state_dict['iteration']
         except KeyError:
             try:  # Backward compatible with older checkpoints
-                iteration = state_dict['total_iters']
+                iteration = model_state_dict['total_iters']
             except KeyError:
                 print_rank_0('A metadata file exists but unable to load '
                              'iteration from checkpoint {}, exiting'.format(

@@ -422,8 +422,8 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
     # Check arguments.
     assert args.consumed_train_samples == 0
     assert args.consumed_valid_samples == 0
-    if 'args' in state_dict:
-        checkpoint_args = state_dict['args']
+    if 'args' in model_state_dict:
+        checkpoint_args = model_state_dict['args']
         check_checkpoint_args(checkpoint_args)
         args.consumed_train_samples = getattr(checkpoint_args,
                                               'consumed_train_samples', 0)

@@ -435,11 +435,11 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
     # Model.
     if len(model) == 1:
-        model[0].load_state_dict(state_dict['model'], strict=strict)
+        model[0].load_state_dict(model_state_dict['model'], strict=strict)
     else:
         for i in range(len(model)):
             mpu.set_virtual_pipeline_model_parallel_rank(i)
-            model[i].load_state_dict(state_dict['model%d' % i], strict=strict)
+            model[i].load_state_dict(model_state_dict['model%d' % i], strict=strict)
 
     # Fix up query/key/value matrix ordering if needed
     checkpoint_version = get_checkpoint_version()

@@ -450,12 +450,12 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
     if not release and not args.finetune and not args.no_load_optim:
         try:
             if optimizer is not None:
-                optimizer.load_state_dict(state_dict['optimizer'])
+                optimizer.load_state_dict(optim_state_dict['optimizer'])
             if opt_param_scheduler is not None:
                 if 'lr_scheduler' in state_dict: # backward compatbility
-                    opt_param_scheduler.load_state_dict(state_dict['lr_scheduler'])
+                    opt_param_scheduler.load_state_dict(optim_state_dict['lr_scheduler'])
                 else:
-                    opt_param_scheduler.load_state_dict(state_dict['opt_param_scheduler'])
+                    opt_param_scheduler.load_state_dict(optim_state_dict['opt_param_scheduler'])
         except KeyError:
             print_rank_0('Unable to load optimizer from checkpoint {}. '
                          'Specify --no-load-optim or --finetune to prevent '
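The renames above replace the single state_dict with model_state_dict for model and metadata keys and optim_state_dict for optimizer and scheduler keys, which suggests the checkpoint contents are read as two separate dictionaries earlier in load_checkpoint. A rough sketch of what such a split could look like; the helper name, file paths, and torch.load calls here are assumptions, not code from this commit:

    import torch

    def _load_checkpoint_dicts(model_path, optim_path):
        """Hypothetical helper: read the model and optimizer parts of a
        checkpoint as two separate dictionaries, mirroring the
        model_state_dict / optim_state_dict names used in the diff above."""
        model_state_dict = torch.load(model_path, map_location='cpu')
        optim_state_dict = torch.load(optim_path, map_location='cpu')
        return model_state_dict, optim_state_dict

    # Downstream, load_checkpoint would then index the two dicts separately, e.g.:
    #   iteration = model_state_dict['iteration']
    #   optimizer.load_state_dict(optim_state_dict['optimizer'])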
megatron/optimizer/distrib_optimizer.py  (view file @ af2b136f)

@@ -306,8 +306,9 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
         state_dict['optimizer'] = self.optimizer.state_dict()
         if self.grad_scaler:
             state_dict['grad_scaler'] = self.grad_scaler.state_dict()
-        state_dict['params'] = \
-            [ p for g in self.optimizer.param_groups for p in g["params"] ]
+        # state_dict['params'] = \
+        #     [ p for g in self.optimizer.param_groups for p in g["params"] ]
+        state_dict['groups'] = [ g["params"] for g in self.optimizer.param_groups ]
         # pax(0, { # ... only called on model rank 0
         #     # "optimizer" : self.optimizer,
         #     "state_dict" : state_dict,

@@ -329,10 +330,10 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
                 'an old checkpoint ...')
         self.optimizer.load_state_dict(state_dict[optimizer_key])
-        pax(0, {
-            "state_dict" : state_dict,
-            "params" : state_dict["params"],
-        })
+        # pax(0, {
+        #     "state_dict" : state_dict,
+        #     "params" : state_dict["params"],
+        # })
 
         # Grad scaler.
         if 'grad_scaler' not in state_dict:
             print_rank_0('***WARNING*** found an old checkpoint, will not '

@@ -346,11 +347,21 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
                 'Skipping loading grad scaler ...')
 
         # Copy data for the main params.
+        current_groups = [ g["params"] for g in self.optimizer.param_groups ]
+        params_key = 'params'
+        assert params_key in state_dict, "key 'params' not in state_dict."
+        # pax(0, {
+        #     "state_dict" : state_dict,
+        #     "current_groups" : current_groups,
+        #     "saved_groups" : state_dict[params_key],
+        # })
         for current_group, saved_group in zip(
-                self.fp32_from_float16_groups,
-                state_dict[fp32_from_float16_params_key]):
+                current_groups,
+                state_dict[params_key]):
+            pax(0, {
+                "current_group" : current_group,
+                "saved_group" : saved_group,
+            })
             for current_param, saved_param in zip(current_group, saved_group):
                 current_param.data.copy_(saved_param.data)
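On the load side, the grouped parameter lists (read back under the 'params' key here) are zipped against the live self.optimizer.param_groups and copied group by group, parameter by parameter; the pax(...) calls are debugging probes left in the commit. A self-contained sketch of that copy pattern with plain tensors, outside the DistributedOptimizer class (the tensor shapes and values are illustrative only):

    import torch

    # Live parameters, organized per optimizer param group.
    current_groups = [[torch.zeros(2), torch.zeros(3)], [torch.zeros(4)]]

    # Saved parameters from a checkpoint, grouped the same way
    # (the commit stores them as a list of per-group parameter lists).
    saved_groups = [[torch.ones(2), torch.full((3,), 2.0)], [torch.full((4,), 3.0)]]

    # Group-by-group, parameter-by-parameter in-place copy,
    # mirroring the loop in load_state_dict() above.
    for current_group, saved_group in zip(current_groups, saved_groups):
        for current_param, saved_param in zip(current_group, saved_group):
            current_param.data.copy_(saved_param.data)

    assert torch.equal(current_groups[0][0], torch.ones(2))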