Commit 1b2db724 authored by Jared Casper

Addressing comments.

parent 5a1aa663
@@ -83,20 +83,20 @@ def ensure_directory_exists(filename):
 def get_checkpoint_name(checkpoints_path, iteration, release=False,
-                        pipeline_parallel_size=None, tensor_rank=None, pipeline_rank=None):
-    """A unified checkpoint name."""
+                        pipeline_parallel=None, tensor_rank=None, pipeline_rank=None):
+    """Determine the directory name for this rank's checkpoint."""
     if release:
         directory = 'release'
     else:
         directory = 'iter_{:07d}'.format(iteration)
     # Use both the tensor and pipeline MP rank.
-    if pipeline_parallel_size is None:
-        pipeline_parallel_size = mpu.get_pipeline_model_parallel_world_size()
+    if pipeline_parallel is None:
+        pipeline_parallel = (mpu.get_pipeline_model_parallel_world_size() > 1)
     if tensor_rank is None:
         tensor_rank = mpu.get_tensor_model_parallel_rank()
     if pipeline_rank is None:
         pipeline_rank = mpu.get_pipeline_model_parallel_rank()
-    if pipeline_parallel_size == 1:
+    if not pipeline_parallel:
         return os.path.join(checkpoints_path, directory,
                             f'mp_rank_{tensor_rank:02d}',
                             'model_optim_rng.pt')
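
For orientation, here is a minimal standalone sketch of the non-pipelined branch visible in this hunk; the function name and example arguments are illustrative rather than part of the commit, and the pipelined branch plus the mpu-based defaults are omitted because they fall outside the diff.

import os

def checkpoint_dir_sketch(checkpoints_path, iteration, release=False,
                          pipeline_parallel=False, tensor_rank=0):
    # Mirrors the branch shown above: release checkpoints go under
    # 'release', others under a zero-padded iteration directory.
    directory = 'release' if release else 'iter_{:07d}'.format(iteration)
    if not pipeline_parallel:
        return os.path.join(checkpoints_path, directory,
                            f'mp_rank_{tensor_rank:02d}',
                            'model_optim_rng.pt')
    raise NotImplementedError('pipelined layout is outside this hunk')

print(checkpoint_dir_sketch('/checkpoints', 5000))
# -> /checkpoints/iter_0005000/mp_rank_00/model_optim_rng.pt
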
@@ -116,14 +116,14 @@ def find_checkpoint_rank_0(checkpoints_path, iteration, release=False):
     # Look for checkpoint with no pipelining
     filename = get_checkpoint_name(checkpoints_path, iteration, release,
-                                   pipeline_parallel_size=1,
+                                   pipeline_parallel=False,
                                    tensor_rank=0, pipeline_rank=0)
     if os.path.isfile(filename):
         return filename
     # Look for checkpoint with pipelining
     filename = get_checkpoint_name(checkpoints_path, iteration, release,
-                                   pipeline_parallel_size=2,
+                                   pipeline_parallel=True,
                                    tensor_rank=0, pipeline_rank=0)
     if os.path.isfile(filename):
         return filename
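
The two probes above can equivalently be read as a loop over the new boolean: try the non-pipelined layout first, then the pipelined one, both with rank-0 names. This is only a sketch of the call pattern, not code from the repository, and it assumes get_checkpoint_name is in scope.

import os

def find_rank0_checkpoint_sketch(checkpoints_path, iteration, release=False):
    # Try both on-disk layouts; the boolean now says "is pipeline
    # parallelism in use" rather than encoding a world size (1 vs 2).
    for pipelined in (False, True):
        filename = get_checkpoint_name(checkpoints_path, iteration, release,
                                       pipeline_parallel=pipelined,
                                       tensor_rank=0, pipeline_rank=0)
        if os.path.isfile(filename):
            return filename
    return None
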
@@ -404,7 +404,7 @@ def load_args_from_checkpoint(args, load_arg='load'):
     if load_dir is None:
         return args
-    state_dict, release = _load_base_checkpoint(load_dir, True)
+    state_dict, release = _load_base_checkpoint(load_dir, rank0=True)
     if not state_dict:
         return args
@@ -460,7 +460,7 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
     model = unwrap_model(model)
-    state_dict, release = _load_base_checkpoint(load_dir, False)
+    state_dict, release = _load_base_checkpoint(load_dir, rank0=False)
     # set checkpoint version
     set_checkpoint_version(state_dict.get('checkpoint_version', 0))
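
Both call sites above now name the rank0 flag explicitly instead of passing a bare boolean. The body of _load_base_checkpoint is not shown in this diff; the following is a hypothetical sketch of how such a flag might route the filename lookup, using the two helpers visible earlier, and is not the repository's actual implementation.

import torch

def load_base_checkpoint_sketch(load_dir, rank0=False, iteration=0, release=False):
    # Hypothetical routing: rank0=True means "any rank-0 checkpoint will do"
    # (e.g. when only the saved args are needed); otherwise load the file
    # that belongs to this rank's tensor/pipeline position.
    if rank0:
        checkpoint_name = find_checkpoint_rank_0(load_dir, iteration, release)
    else:
        checkpoint_name = get_checkpoint_name(load_dir, iteration, release)
    state_dict = torch.load(checkpoint_name, map_location='cpu')
    return state_dict, release
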
@@ -587,7 +587,7 @@ def load_biencoder_checkpoint(model, only_query_model=False,
     with open(tracker_filename, 'r') as f:
         iteration = int(f.read().strip())
-    checkpoint_name = get_checkpoint_name(load_path, iteration, False)
+    checkpoint_name = get_checkpoint_name(load_path, iteration, release=False)
     if mpu.get_data_parallel_rank() == 0:
         print('global rank {} is loading checkpoint {}'.format(
             torch.distributed.get_rank(), checkpoint_name))