Commit 09d38884 authored by Vijay Korthikanti

checkpointing rng_state of all data parallel ranks

parent 343dc97a
@@ -140,6 +140,32 @@ def read_metadata(tracker_filename):
     return max_iter, release
 
 
+def get_rng_state():
+    """ collect rng state across data parallel ranks """
+    rng_state = {
+        'random_rng_state': random.getstate(),
+        'np_rng_state': np.random.get_state(),
+        'torch_rng_state': torch.get_rng_state(),
+        'cuda_rng_state': torch.cuda.get_rng_state(),
+        'rng_tracker_states': mpu.get_cuda_rng_tracker().get_states()}
+
+    rng_state_list = None
+    if torch.distributed.is_initialized() and \
+            mpu.get_data_parallel_world_size() > 1:
+        if mpu.get_data_parallel_rank() == 0:
+            rng_state_list = \
+                [None for i in range(mpu.get_data_parallel_world_size())]
+        torch.distributed.gather_object(
+            rng_state,
+            rng_state_list,
+            dst=mpu.get_data_parallel_src_rank(),
+            group=mpu.get_data_parallel_group())
+    else:
+        rng_state_list = [rng_state]
+
+    return rng_state_list
+
+
 def save_checkpoint(iteration, model, optimizer, lr_scheduler):
     """Save a model checkpoint."""
     args = get_args()
@@ -150,6 +176,9 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
     print_rank_0('saving checkpoint at iteration {:7d} to {}'.format(
         iteration, args.save))
 
+    # collect rng state across data parallel ranks
+    rng_state = get_rng_state()
+
     if not torch.distributed.is_initialized() or mpu.get_data_parallel_rank() == 0:
 
         # Arguments, iteration, and model.
@@ -173,12 +202,7 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
 
         # RNG states.
         if not args.no_save_rng:
-            state_dict['random_rng_state'] = random.getstate()
-            state_dict['np_rng_state'] = np.random.get_state()
-            state_dict['torch_rng_state'] = torch.get_rng_state()
-            state_dict['cuda_rng_state'] = torch.cuda.get_rng_state()
-            state_dict['rng_tracker_states'] \
-                = mpu.get_cuda_rng_tracker().get_states()
+            state_dict["rng_state"] = rng_state
 
         # Save.
         checkpoint_name = get_checkpoint_name(args.save, iteration)
@@ -381,15 +405,28 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True
     # rng states.
     if not release and not args.finetune and not args.no_load_rng:
         try:
-            random.setstate(state_dict['random_rng_state'])
-            np.random.set_state(state_dict['np_rng_state'])
-            torch.set_rng_state(state_dict['torch_rng_state'])
-            torch.cuda.set_rng_state(state_dict['cuda_rng_state'])
-            # Check for empty states array
-            if not state_dict['rng_tracker_states']:
-                raise KeyError
-            mpu.get_cuda_rng_tracker().set_states(
-                state_dict['rng_tracker_states'])
+            if 'rng_state' in state_dict:
+                # access rng_state for data parallel rank
+                rng_state = state_dict['rng_state'][mpu.get_data_parallel_rank()]
+                random.setstate(rng_state['random_rng_state'])
+                np.random.set_state(rng_state['np_rng_state'])
+                torch.set_rng_state(rng_state['torch_rng_state'])
+                torch.cuda.set_rng_state(rng_state['cuda_rng_state'])
+                # Check for empty states array
+                if not rng_state['rng_tracker_states']:
+                    raise KeyError
+                mpu.get_cuda_rng_tracker().set_states(
+                    rng_state['rng_tracker_states'])
+            else:  # backward compatibility
+                random.setstate(state_dict['random_rng_state'])
+                np.random.set_state(state_dict['np_rng_state'])
+                torch.set_rng_state(state_dict['torch_rng_state'])
+                torch.cuda.set_rng_state(state_dict['cuda_rng_state'])
+                # Check for empty states array
+                if not state_dict['rng_tracker_states']:
+                    raise KeyError
+                mpu.get_cuda_rng_tracker().set_states(
+                    state_dict['rng_tracker_states'])
         except KeyError:
             print_rank_0('Unable to load rng state from checkpoint {}. '
                          'Specify --no-load-rng or --finetune to prevent '
...
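
The gather in get_rng_state() follows the standard torch.distributed.gather_object pattern: the destination rank pre-allocates a world-size list, every rank in the group contributes its own state dict, and at load time each rank indexes the saved list by its data parallel rank. Below is a minimal, self-contained sketch of that pattern; it is not part of the commit, uses the gloo backend on localhost purely for illustration, and only snapshots CPU RNG states. The worker function name, port, and world size are arbitrary choices for the example.

# Sketch of the gather_object pattern used by get_rng_state() (illustrative only).
import os
import random

import numpy as np
import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def worker(rank, world_size):
    # Minimal single-node process group for the demo (gloo backend, CPU only).
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = '29511'
    dist.init_process_group('gloo', rank=rank, world_size=world_size)

    # Each rank snapshots its own RNG state, mirroring the dict built in get_rng_state().
    rng_state = {
        'random_rng_state': random.getstate(),
        'np_rng_state': np.random.get_state(),
        'torch_rng_state': torch.get_rng_state(),
    }

    # Only the destination rank allocates the receive list; other ranks pass None.
    rng_state_list = [None] * world_size if rank == 0 else None
    dist.gather_object(rng_state, rng_state_list, dst=0)

    if rank == 0:
        # rng_state_list[i] now holds rank i's state; this is what the commit stores
        # under state_dict["rng_state"] and indexes by data parallel rank on load.
        print([sorted(s.keys()) for s in rng_state_list])

    dist.destroy_process_group()


if __name__ == '__main__':
    world_size = 2
    mp.spawn(worker, args=(world_size,), nprocs=world_size)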