Commit 2725dc0b authored by Jared Casper

Fixing up checkpointing.

parent 4eb802c4
@@ -117,7 +117,7 @@ def get_checkpoint_names(checkpoints_path, iteration, use_distributed_optimizer,
         model_name = optim_name = os.path.join(common_path, "model_optim_rng.pt")
     return model_name, optim_name

-def find_checkpoint_rank_0(checkpoints_path, iteration, release=False):
+def find_checkpoint_rank_0(checkpoints_path, iteration, use_distributed_optimizer, release=False):
     """Finds the checkpoint for rank 0 without knowing if we are using
     pipeline parallelism or not.
@@ -128,20 +128,20 @@ def find_checkpoint_rank_0(checkpoints_path, iteration, release=False):
     """
     # Look for checkpoint with no pipelining
-    filename = get_checkpoint_name(checkpoints_path, iteration, release,
+    filenames = get_checkpoint_names(checkpoints_path, iteration, use_distributed_optimizer, release,
                                     pipeline_parallel=False,
                                     tensor_rank=0, pipeline_rank=0)
-    if os.path.isfile(filename):
-        return filename
+    if os.path.isfile(filenames[0]):
+        return filenames

     # Look for checkpoint with pipelining
-    filename = get_checkpoint_name(checkpoints_path, iteration, release,
+    filenames = get_checkpoint_names(checkpoints_path, iteration, use_distributed_optimizer, release,
                                     pipeline_parallel=True,
                                     tensor_rank=0, pipeline_rank=0)
-    if os.path.isfile(filename):
-        return filename
+    if os.path.isfile(filenames[0]):
+        return filenames

-    return None
+    return None, None

 def get_checkpoint_tracker_filename(checkpoints_path):
@@ -370,7 +370,7 @@ def fix_query_key_value_ordering(model, checkpoint_version):
         print_rank_0(" succesfully fixed query-key-values ordering for"
                      " checkpoint version {}".format(checkpoint_version))

-def _load_base_checkpoint(load_dir, rank0=False):
+def _load_base_checkpoint(load_dir, use_distributed_optimizer, rank0=False):
     """ Load the base state_dict from the given directory

     If rank0 is true, just loads rank 0 checkpoint, ignoring arguments.
@@ -395,11 +395,11 @@ def _load_base_checkpoint(load_dir, rank0=False):
     # Checkpoint.
     if rank0:
-        checkpoint_names = find_checkpoint_rank_0(load_dir, iteration, args.use_distributed_optimizer,
+        checkpoint_names = find_checkpoint_rank_0(load_dir, iteration, use_distributed_optimizer,
                                                   release)
     else:
-        checkpoint_names = get_checkpoint_name(load_dir, iteration, args.use_distributed_optimizer,
+        checkpoint_names = get_checkpoint_names(load_dir, iteration, use_distributed_optimizer,
                                                 release)
     if release:
         print_rank_0(f' loading release checkpoint from {load_dir}')
     else:
@@ -410,7 +410,7 @@ def _load_base_checkpoint(load_dir, rank0=False):
     # Load the checkpoint.
     try:
         model_state_dict = torch.load(model_checkpoint_name, map_location='cpu')
-        if args.use_distributed_optimizer:
+        if use_distributed_optimizer:
             optim_state_dict = torch.load(optim_checkpoint_name, map_location='cpu')
         else:
             optim_state_dict = model_state_dict
@@ -450,18 +450,23 @@ def load_args_from_checkpoint(args, load_arg='load'):
     load_dir = getattr(args, load_arg)

     if load_dir is None:
+        print_rank_0('No load directory specified, using provided arguments.')
         return args

-    model_state_dict, optim_state_dict, release = _load_base_checkpoint(load_dir, rank0=True)
+    model_state_dict, optim_state_dict, release = \
+        _load_base_checkpoint(load_dir,
+                              use_distributed_optimizer=args.use_distributed_optimizer,
+                              rank0=True)

     # For args we only care about model state dict
     state_dict = model_state_dict

     if not state_dict:
+        print_rank_0('Checkpoint not found to provide arguments, using provided arguments.')
         return args

     if 'args' not in state_dict:
-        print('Checkpoint provided does not have arguments saved.')
+        print_rank_0('Checkpoint provided does not have arguments saved, using provided arguments.')
         return args

     checkpoint_args = state_dict['args']
@@ -511,7 +516,13 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri...
     model = unwrap_model(model)

-    model_state_dict, optim_state_dict, release = _load_base_checkpoint(load_dir, rank0=False)
+    model_state_dict, optim_state_dict, release = \
+        _load_base_checkpoint(load_dir,
+                              use_distributed_optimizer=args.use_distributed_optimizer,
+                              rank0=False)
+
+    if model_state_dict is None:
+        return 0

     # set checkpoint version
     set_checkpoint_version(model_state_dict.get('checkpoint_version', 0))
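For context, here is a minimal sketch of how a caller could use the updated functions after this change. It is not part of the commit: the import path, checkpoint directory, iteration value, and flag setting below are assumptions for illustration only.

    # Sketch only: assumes Megatron-LM is importable and that these functions live in a
    # megatron.checkpointing module (assumed path); directory and iteration are hypothetical.
    from megatron.checkpointing import find_checkpoint_rank_0, _load_base_checkpoint

    load_dir = '/path/to/checkpoints'    # hypothetical checkpoint directory
    use_distributed_optimizer = True     # would normally come from args.use_distributed_optimizer

    # find_checkpoint_rank_0 now takes the distributed-optimizer flag and returns a
    # (model_name, optim_name) pair, or (None, None) when no checkpoint is found.
    model_name, optim_name = find_checkpoint_rank_0(load_dir, 1000, use_distributed_optimizer)

    # _load_base_checkpoint now receives the flag explicitly instead of reading global args,
    # so rank-0 tooling can call it before the full argument setup; callers handle a missing
    # checkpoint by checking for None, as load_checkpoint does in the diff above.
    model_state_dict, optim_state_dict, release = _load_base_checkpoint(
        load_dir,
        use_distributed_optimizer=use_distributed_optimizer,
        rank0=True)
    if model_state_dict is None:
        print('No checkpoint found in', load_dir)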