Commit 2725dc0b authored by Jared Casper

Fixing up checkpointing.

parent 4eb802c4
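
The diff below threads use_distributed_optimizer explicitly through the checkpoint-loading helpers and moves callers from the single-path get_checkpoint_name to the pair-returning get_checkpoint_names, so model and optimizer state can live in separate files when the distributed optimizer is enabled. As an editor's sketch (not the real module), the naming scheme the diff implies might look like the following; the directory layout and the split file names model_rng.pt / optim.pt are assumptions from Megatron-LM convention, while model_optim_rng.pt appears verbatim in the first hunk:

import os

def get_checkpoint_names_sketch(checkpoints_path, iteration,
                                use_distributed_optimizer, release=False,
                                pipeline_parallel=False,
                                tensor_rank=0, pipeline_rank=0):
    # 'release' vs. zero-padded iteration directory: Megatron-LM convention
    # (assumed; not shown in this diff).
    directory = 'release' if release else 'iter_{:07d}'.format(iteration)
    if pipeline_parallel:
        sub_dir = 'mp_rank_{:02d}_{:03d}'.format(tensor_rank, pipeline_rank)
    else:
        sub_dir = 'mp_rank_{:02d}'.format(tensor_rank)
    common_path = os.path.join(checkpoints_path, directory, sub_dir)
    if use_distributed_optimizer:
        # Split layout: model/RNG state and optimizer state in separate
        # files (these two file names are assumed).
        model_name = os.path.join(common_path, 'model_rng.pt')
        optim_name = os.path.join(common_path, 'optim.pt')
    else:
        # Combined layout: a single file serves as both, as in the hunk below.
        model_name = optim_name = os.path.join(common_path, 'model_optim_rng.pt')
    return model_name, optim_name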
@@ -117,7 +117,7 @@ def get_checkpoint_names(checkpoints_path, iteration, use_distributed_optimizer,
         model_name = optim_name = os.path.join(common_path, "model_optim_rng.pt")
     return model_name, optim_name
 
-def find_checkpoint_rank_0(checkpoints_path, iteration, release=False):
+def find_checkpoint_rank_0(checkpoints_path, iteration, use_distributed_optimizer, release=False):
     """Finds the checkpoint for rank 0 without knowing if we are using
     pipeline parallelism or not.
@@ -128,20 +128,20 @@ def find_checkpoint_rank_0(checkpoints_path, iteration, release=False):
     """
 
     # Look for checkpoint with no pipelining
-    filename = get_checkpoint_name(checkpoints_path, iteration, release,
-                                   pipeline_parallel=False,
-                                   tensor_rank=0, pipeline_rank=0)
-    if os.path.isfile(filename):
-        return filename
+    filenames = get_checkpoint_names(checkpoints_path, iteration, use_distributed_optimizer, release,
+                                     pipeline_parallel=False,
+                                     tensor_rank=0, pipeline_rank=0)
+    if os.path.isfile(filenames[0]):
+        return filenames
 
     # Look for checkpoint with pipelining
-    filename = get_checkpoint_name(checkpoints_path, iteration, release,
-                                   pipeline_parallel=True,
-                                   tensor_rank=0, pipeline_rank=0)
-    if os.path.isfile(filename):
-        return filename
+    filenames = get_checkpoint_names(checkpoints_path, iteration, use_distributed_optimizer, release,
+                                     pipeline_parallel=True,
+                                     tensor_rank=0, pipeline_rank=0)
+    if os.path.isfile(filenames[0]):
+        return filenames
 
-    return None
+    return None, None
 
 def get_checkpoint_tracker_filename(checkpoints_path):
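
Since get_checkpoint_names now returns a (model, optim) pair, find_checkpoint_rank_0 probes for the model file (filenames[0]) and falls back to (None, None), keeping two-value unpacking uniform at call sites. A self-contained sketch of that search; get_names is a hypothetical stand-in parameter so the example does not depend on the real module:

import os

def find_checkpoint_rank_0_sketch(checkpoints_path, iteration,
                                  use_distributed_optimizer, release,
                                  get_names):
    # Probe the non-pipelined layout first, then the pipelined one; the
    # model file (filenames[0]) is the existence test for the whole set.
    for pipeline_parallel in (False, True):
        filenames = get_names(checkpoints_path, iteration,
                              use_distributed_optimizer, release,
                              pipeline_parallel=pipeline_parallel,
                              tensor_rank=0, pipeline_rank=0)
        if os.path.isfile(filenames[0]):
            return filenames
    # Return a pair rather than a bare None so callers can always unpack.
    return None, None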
@@ -370,7 +370,7 @@ def fix_query_key_value_ordering(model, checkpoint_version):
         print_rank_0(" successfully fixed query-key-values ordering for"
                      " checkpoint version {}".format(checkpoint_version))
 
-def _load_base_checkpoint(load_dir, rank0=False):
+def _load_base_checkpoint(load_dir, use_distributed_optimizer, rank0=False):
     """ Load the base state_dict from the given directory
 
     If rank0 is true, just loads rank 0 checkpoint, ignoring arguments.
@@ -395,11 +395,11 @@ def _load_base_checkpoint(load_dir, rank0=False):
 
     # Checkpoint.
     if rank0:
-        checkpoint_names = find_checkpoint_rank_0(load_dir, iteration, args.use_distributed_optimizer,
+        checkpoint_names = find_checkpoint_rank_0(load_dir, iteration, use_distributed_optimizer,
                                                   release)
     else:
-        checkpoint_names = get_checkpoint_name(load_dir, iteration, args.use_distributed_optimizer,
-                                               release)
+        checkpoint_names = get_checkpoint_names(load_dir, iteration, use_distributed_optimizer,
+                                                release)
 
     if release:
         print_rank_0(f' loading release checkpoint from {load_dir}')
     else:
@@ -410,7 +410,7 @@ def _load_base_checkpoint(load_dir, rank0=False):
     # Load the checkpoint.
     try:
         model_state_dict = torch.load(model_checkpoint_name, map_location='cpu')
-        if args.use_distributed_optimizer:
+        if use_distributed_optimizer:
             optim_state_dict = torch.load(optim_checkpoint_name, map_location='cpu')
         else:
             optim_state_dict = model_state_dict
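
The load path in _load_base_checkpoint mirrors the naming split: the model file is always read, and the optimizer state either comes from its own file or aliases the combined dict. A minimal sketch of just that branch, assuming both checkpoint paths already exist on disk:

import torch

def load_split_state_dicts(model_checkpoint_name, optim_checkpoint_name,
                           use_distributed_optimizer):
    # Model (and RNG) state always comes from the model checkpoint file.
    model_state_dict = torch.load(model_checkpoint_name, map_location='cpu')
    if use_distributed_optimizer:
        # Distributed optimizer: optimizer state sits in its own file.
        optim_state_dict = torch.load(optim_checkpoint_name, map_location='cpu')
    else:
        # Combined checkpoint: the same dict carries the optimizer state too.
        optim_state_dict = model_state_dict
    return model_state_dict, optim_state_dict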
@@ -450,18 +450,23 @@ def load_args_from_checkpoint(args, load_arg='load'):
     load_dir = getattr(args, load_arg)
 
     if load_dir is None:
         print_rank_0('No load directory specified, using provided arguments.')
         return args
 
-    model_state_dict, optim_state_dict, release = _load_base_checkpoint(load_dir, rank0=True)
+    model_state_dict, optim_state_dict, release = \
+        _load_base_checkpoint(load_dir,
+                              use_distributed_optimizer=args.use_distributed_optimizer,
+                              rank0=True)
+
+    # For args we only care about model state dict
+    state_dict = model_state_dict
 
     if not state_dict:
         print_rank_0('Checkpoint not found to provide arguments, using provided arguments.')
         return args
 
     if 'args' not in state_dict:
-        print('Checkpoint provided does not have arguments saved.')
+        print_rank_0('Checkpoint provided does not have arguments saved, using provided arguments.')
         return args
 
     checkpoint_args = state_dict['args']
@@ -511,7 +516,13 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', strict=True):
     model = unwrap_model(model)
 
-    model_state_dict, optim_state_dict, release = _load_base_checkpoint(load_dir, rank0=False)
+    model_state_dict, optim_state_dict, release = \
+        _load_base_checkpoint(load_dir,
+                              use_distributed_optimizer=args.use_distributed_optimizer,
+                              rank0=False)
 
     if model_state_dict is None:
         return 0
 
     # set checkpoint version
     set_checkpoint_version(model_state_dict.get('checkpoint_version', 0))
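
Putting the sketches together, a throwaway round trip over a temporary directory (illustrative only; it exercises the sketch functions above, not the real megatron.checkpointing module):

import os
import tempfile
import torch

with tempfile.TemporaryDirectory() as root:
    # Write a fake split checkpoint at iteration 1000.
    model_name, optim_name = get_checkpoint_names_sketch(
        root, 1000, use_distributed_optimizer=True)
    os.makedirs(os.path.dirname(model_name), exist_ok=True)
    torch.save({'checkpoint_version': 3.0}, model_name)
    torch.save({'optimizer': {}}, optim_name)

    # The rank-0 search finds both files; loading re-splits them.
    names = find_checkpoint_rank_0_sketch(
        root, 1000, use_distributed_optimizer=True, release=False,
        get_names=get_checkpoint_names_sketch)
    model_sd, optim_sd = load_split_state_dicts(
        names[0], names[1], use_distributed_optimizer=True)
    assert model_sd.get('checkpoint_version', 0) == 3.0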