Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
wuxk1
Megatron-LM
Commits
08a848c7
Commit
08a848c7
authored
Feb 08, 2021
by
Jared Casper
Browse files
Improve handling of rng states in checkpoints.
parent
8863af8c
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
12 additions
and
5 deletions
+12
-5
megatron/arguments.py
megatron/arguments.py
+2
-2
megatron/checkpointing.py
megatron/checkpointing.py
+5
-2
tools/generate_samples_gpt.py
tools/generate_samples_gpt.py
+3
-1
tools/merge_mp_partitions.py
tools/merge_mp_partitions.py
+2
-0
No files found.
megatron/arguments.py
View file @
08a848c7
...
...
@@ -494,9 +494,9 @@ def _add_checkpointing_args(parser):
                        help='Output directory to save checkpoints to.')
     group.add_argument('--save-interval', type=int, default=None,
                        help='Number of iterations between checkpoint saves.')
-    group.add_argument('--no-save-optim', action='store_true',
+    group.add_argument('--no-save-optim', action='store_true', default=None,
                        help='Do not save current optimizer.')
-    group.add_argument('--no-save-rng', action='store_true',
+    group.add_argument('--no-save-rng', action='store_true', default=None,
                        help='Do not save current rng state.')
     group.add_argument('--load', type=str, default=None,
                        help='Directory containing a model checkpoint.')
...
...
megatron/checkpointing.py
View file @
08a848c7
...
...
@@ -343,12 +343,15 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True
                 np.random.set_state(state_dict['np_rng_state'])
                 torch.set_rng_state(state_dict['torch_rng_state'])
                 torch.cuda.set_rng_state(state_dict['cuda_rng_state'])
+                # Check for empty states array
+                if not state_dict['rng_tracker_states']:
+                    raise KeyError
                 mpu.get_cuda_rng_tracker().set_states(
                     state_dict['rng_tracker_states'])
             except KeyError:
-                print_rank_0('Unable to load optimizer from checkpoint {}. '
+                print_rank_0('Unable to load rng state from checkpoint {}. '
                              'Specify --no-load-rng or --finetune to prevent '
-                             'attempting to load the optimizer state, '
+                             'attempting to load the rng state, '
                              'exiting ...'.format(checkpoint_name))
                 sys.exit()
...
...
tools/generate_samples_gpt.py
View file @
08a848c7
...
...
@@ -92,7 +92,9 @@ def main():
"""Main program."""
     initialize_megatron(extra_args_provider=add_text_generate_args,
-                        args_defaults={'tokenizer_type': 'GPT2BPETokenizer'})
+                        args_defaults={'tokenizer_type': 'GPT2BPETokenizer',
+                                       'no_load_rng': True,
+                                       'no_load_optim': True})

     # Set up model and load checkpoint.
     model = get_model(model_provider)
...
...
tools/merge_mp_partitions.py
View file @
08a848c7
...
...
@@ -200,6 +200,8 @@ def main():
                                       'micro_batch_size': 1,
                                       'no_load_optim': True,
                                       'no_load_rng': True,
                                       'no_save_optim': True,
                                       'no_save_rng': True,
                                       'save_interval': 1})
    args = get_args()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment