Commit 08a848c7 authored by Jared Casper's avatar Jared Casper
Browse files

Improve handling of rng states in checkpoints.

parent 8863af8c
......@@ -494,9 +494,9 @@ def _add_checkpointing_args(parser):
help='Output directory to save checkpoints to.')
group.add_argument('--save-interval', type=int, default=None,
help='Number of iterations between checkpoint saves.')
group.add_argument('--no-save-optim', action='store_true',
group.add_argument('--no-save-optim', action='store_true', default=None,
help='Do not save current optimizer.')
group.add_argument('--no-save-rng', action='store_true',
group.add_argument('--no-save-rng', action='store_true', default=None,
help='Do not save current rng state.')
group.add_argument('--load', type=str, default=None,
help='Directory containing a model checkpoint.')
......
......@@ -343,12 +343,15 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True
np.random.set_state(state_dict['np_rng_state'])
torch.set_rng_state(state_dict['torch_rng_state'])
torch.cuda.set_rng_state(state_dict['cuda_rng_state'])
# Check for empty states array
if not state_dict['rng_tracker_states']:
raise KeyError
mpu.get_cuda_rng_tracker().set_states(
state_dict['rng_tracker_states'])
except KeyError:
print_rank_0('Unable to load optimizer from checkpoint {}. '
print_rank_0('Unable to load rng state from checkpoint {}. '
'Specify --no-load-rng or --finetune to prevent '
'attempting to load the optimizer state, '
'attempting to load the rng state, '
'exiting ...'.format(checkpoint_name))
sys.exit()
......
......@@ -92,7 +92,9 @@ def main():
"""Main program."""
initialize_megatron(extra_args_provider=add_text_generate_args,
args_defaults={'tokenizer_type': 'GPT2BPETokenizer'})
args_defaults={'tokenizer_type': 'GPT2BPETokenizer',
'no_load_rng': True,
'no_load_optim': True})
# Set up model and load checkpoint.
model = get_model(model_provider)
......
......@@ -200,6 +200,8 @@ def main():
'micro_batch_size': 1,
'no_load_optim': True,
'no_load_rng': True,
'no_save_optim': True,
'no_save_rng': True,
'save_interval': 1})
args = get_args()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment