Commit 1b2db724 authored by Jared Casper

Addressing comments.

parent 5a1aa663
@@ -83,20 +83,20 @@ def ensure_directory_exists(filename):
 def get_checkpoint_name(checkpoints_path, iteration, release=False,
-                        pipeline_parallel_size=None, tensor_rank=None, pipeline_rank=None):
-    """A unified checkpoint name."""
+                        pipeline_parallel=None, tensor_rank=None, pipeline_rank=None):
+    """Determine the directory name for this rank's checkpoint."""
     if release:
         directory = 'release'
     else:
         directory = 'iter_{:07d}'.format(iteration)
     # Use both the tensor and pipeline MP rank.
-    if pipeline_parallel_size is None:
-        pipeline_parallel_size = mpu.get_pipeline_model_parallel_world_size()
+    if pipeline_parallel is None:
+        pipeline_parallel = (mpu.get_pipeline_model_parallel_world_size() > 1)
     if tensor_rank is None:
         tensor_rank = mpu.get_tensor_model_parallel_rank()
     if pipeline_rank is None:
         pipeline_rank = mpu.get_pipeline_model_parallel_rank()
-    if pipeline_parallel_size == 1:
+    if not pipeline_parallel:
         return os.path.join(checkpoints_path, directory,
                             f'mp_rank_{tensor_rank:02d}',
                             'model_optim_rng.pt')
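For reference, a minimal usage sketch of the new boolean flag. The checkpoint root and iteration below are made-up values, and only the non-pipelined path layout is shown in this hunk:

    # Hypothetical call; all three parallel arguments are passed explicitly,
    # so no mpu state needs to be queried.
    name = get_checkpoint_name('/checkpoints/gpt3', 100, release=False,
                               pipeline_parallel=False,
                               tensor_rank=0, pipeline_rank=0)
    # With pipeline_parallel=False this resolves to:
    #   /checkpoints/gpt3/iter_0000100/mp_rank_00/model_optim_rng.pt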
@@ -116,14 +116,14 @@ def find_checkpoint_rank_0(checkpoints_path, iteration, release=False):
     # Look for checkpoint with no pipelining
     filename = get_checkpoint_name(checkpoints_path, iteration, release,
-                                   pipeline_parallel_size=1,
+                                   pipeline_parallel=False,
                                    tensor_rank=0, pipeline_rank=0)
     if os.path.isfile(filename):
         return filename
     # Look for checkpoint with pipelining
     filename = get_checkpoint_name(checkpoints_path, iteration, release,
-                                   pipeline_parallel_size=2,
+                                   pipeline_parallel=True,
                                    tensor_rank=0, pipeline_rank=0)
     if os.path.isfile(filename):
         return filename
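A hedged sketch of how this probe might be used from a rank-0 utility, assuming the function falls through (to None, say) when neither layout exists on disk; the path and iteration are made up:

    # Probe both on-disk layouts without knowing how the checkpoint was saved.
    ckpt = find_checkpoint_rank_0('/checkpoints/gpt3', 100, release=False)
    if ckpt is not None:
        print(f'found a rank-0 checkpoint at {ckpt}')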
@@ -404,7 +404,7 @@ def load_args_from_checkpoint(args, load_arg='load'):
     if load_dir is None:
         return args
-    state_dict, release = _load_base_checkpoint(load_dir, True)
+    state_dict, release = _load_base_checkpoint(load_dir, rank0=True)
     if not state_dict:
         return args
@@ -460,7 +460,7 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
     model = unwrap_model(model)
-    state_dict, release = _load_base_checkpoint(load_dir, False)
+    state_dict, release = _load_base_checkpoint(load_dir, rank0=False)
     # set checkpoint version
     set_checkpoint_version(state_dict.get('checkpoint_version', 0))
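Taken together with the previous hunk, the keyword makes the two call patterns explicit. A hedged summary (the exact semantics of rank0 are inferred from the call sites, not spelled out in this diff):

    # rank0=True: args-only probe, presumably resolving a representative
    # checkpoint via find_checkpoint_rank_0 (assumption).
    state_dict, release = _load_base_checkpoint(load_dir, rank0=True)

    # rank0=False: the normal path, loading this rank's own checkpoint file.
    state_dict, release = _load_base_checkpoint(load_dir, rank0=False)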
@@ -587,7 +587,7 @@ def load_biencoder_checkpoint(model, only_query_model=False,
     with open(tracker_filename, 'r') as f:
         iteration = int(f.read().strip())
-    checkpoint_name = get_checkpoint_name(load_path, iteration, False)
+    checkpoint_name = get_checkpoint_name(load_path, iteration, release=False)
     if mpu.get_data_parallel_rank() == 0:
         print('global rank {} is loading checkpoint {}'.format(
             torch.distributed.get_rank(), checkpoint_name))