Commit 8a6e56b8 authored by mohammad, committed by Deepak Narayanan

Cleaned up load/save checkpoint printing

parent b81cad66
@@ -110,6 +110,11 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
     # Only rank zero of the data parallel writes to the disk.
     if isinstance(model, torchDDP):
         model = model.module
+
+    if torch.distributed.get_rank() == 0:
+        print('saving checkpoint at iteration {:7d} to {}'.format(
+            iteration, args.save), flush=True)
+
     if mpu.get_data_parallel_rank() == 0:
 
         # Arguments, iteration, and model.
@@ -137,14 +142,14 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
         # Save.
         checkpoint_name = get_checkpoint_name(args.save, iteration)
-        print('global rank {} is saving checkpoint at iteration {:7d} to {}'.
-              format(torch.distributed.get_rank(), iteration, checkpoint_name))
         ensure_directory_exists(checkpoint_name)
         torch.save(state_dict, checkpoint_name)
-        print(' successfully saved {}'.format(checkpoint_name))
 
     # Wait so everyone is done (necessary)
     torch.distributed.barrier()
+    if torch.distributed.get_rank() == 0:
+        print(' successfully saved checkpoint at iteration {:7d} to {}'.format(
+            iteration, args.save), flush=True)
 
     # And update the latest iteration
     if torch.distributed.get_rank() == 0:
         tracker_filename = get_checkpoint_tracker_filename(args.save)
@@ -192,9 +197,9 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load'):
 
     # Checkpoint.
     checkpoint_name = get_checkpoint_name(load_dir, iteration, release)
-    if mpu.get_data_parallel_rank() == 0:
-        print('global rank {} is loading checkpoint {}'.format(
-            torch.distributed.get_rank(), checkpoint_name))
+    if torch.distributed.get_rank() == 0:
+        print(' loading checkpoint from {} at iteration {}'.format(
+            args.load, iteration), flush=True)
 
     # Load the checkpoint.
     try:
@@ -276,8 +281,9 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load'):
         sys.exit()
 
     torch.distributed.barrier()
-    if mpu.get_data_parallel_rank() == 0:
-        print(' successfully loaded {}'.format(checkpoint_name))
+    if torch.distributed.get_rank() == 0:
+        print(' successfully loaded checkpoint from {} at iteration {}'.format(
+            args.load, iteration), flush=True)
 
     return iteration
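The pattern this commit adopts is: print progress messages only from global rank 0, and place the success message after torch.distributed.barrier() so it is emitted once, only after every rank has finished writing or reading its file. Below is a minimal, self-contained sketch of that pattern; it is illustrative only, and the helper and function names are hypothetical, not part of this commit or of Megatron.

    # Illustrative sketch of the rank-0 logging pattern adopted by this commit.
    # Assumes torch.distributed has already been initialized, e.g. with
    # torch.distributed.init_process_group(backend='nccl'); names below are
    # hypothetical, not Megatron APIs.
    import torch
    import torch.distributed


    def _print_rank_0(message):
        # Print only from global rank 0 so the message appears once,
        # not once per process.
        if torch.distributed.get_rank() == 0:
            print(message, flush=True)


    def save_with_clean_logging(state_dict, checkpoint_name, iteration, save_dir):
        _print_rank_0('saving checkpoint at iteration {:7d} to {}'.format(
            iteration, save_dir))

        # Each rank that owns a shard writes its own file here.
        torch.save(state_dict, checkpoint_name)

        # Wait until every rank has finished writing before announcing success,
        # mirroring the barrier in save_checkpoint above.
        torch.distributed.barrier()
        _print_rank_0(' successfully saved checkpoint at iteration {:7d} to {}'.format(
            iteration, save_dir))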