Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
wuxk1
Megatron-LM
Commits
d5b526d5
Commit
d5b526d5
authored
Oct 26, 2020
by
Deepak Narayanan
Browse files
Back compatibility of checkpoints: use `model_parallel_size` when checking for equality of args
parent
318d68c2
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
13 additions
and
5 deletions
+13
-5
megatron/checkpointing.py
megatron/checkpointing.py
+13
-5
No files found.
megatron/checkpointing.py
View file @
d5b526d5
...
...
@@ -41,11 +41,14 @@ def get_checkpoint_version():
def
check_checkpoint_args
(
checkpoint_args
):
"""Ensure fixed arguments for a model are the same for the input
arguments and the one retr
e
ived frm checkpoint."""
arguments and the one retri
e
ved fr
o
m checkpoint."""
args
=
get_args
()
def
_compare
(
arg_name
):
checkpoint_value
=
getattr
(
checkpoint_args
,
arg_name
)
def
_compare
(
arg_name
,
old_arg_name
=
None
):
if
old_arg_name
is
not
None
:
checkpoint_value
=
getattr
(
checkpoint_args
,
old_arg_name
)
else
:
checkpoint_value
=
getattr
(
checkpoint_args
,
arg_name
)
args_value
=
getattr
(
args
,
arg_name
)
error_message
=
'{} value from checkpoint ({}) is not equal to the '
\
'input argument value ({}).'
.
format
(
...
...
@@ -59,7 +62,12 @@ def check_checkpoint_args(checkpoint_args):
_compare
(
'make_vocab_size_divisible_by'
)
_compare
(
'padded_vocab_size'
)
_compare
(
'tokenizer_type'
)
_compare
(
'tensor_model_parallel_size'
)
if
get_checkpoint_version
()
<
3.0
:
_compare
(
'tensor_model_parallel_size'
,
old_arg_name
=
'model_parallel_size'
)
if
get_checkpoint_version
()
>=
3.0
:
_compare
(
'tensor_model_parallel_size'
)
_compare
(
'pipeline_model_parallel_size'
)
def
ensure_directory_exists
(
filename
):
...
...
@@ -107,7 +115,7 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
# Arguments, iteration, and model.
state_dict
=
{}
state_dict
[
'args'
]
=
args
state_dict
[
'checkpoint_version'
]
=
2
.0
state_dict
[
'checkpoint_version'
]
=
3
.0
state_dict
[
'iteration'
]
=
iteration
state_dict
[
'model'
]
=
model
.
state_dict_for_save_checkpoint
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment