Commit 9c86abd9 authored by Lawrence McAfee

more specific formatting of model/optim checkpoint paths.

parent af2b136f
@@ -28,6 +28,10 @@ from megatron import (get_args,
                       update_num_microbatches,
                       utils)

+# >>>
+from lutil import pax
+# <<<
+
 _CHECKPOINT_VERSION = None

 def set_checkpoint_version(value):
@@ -100,8 +104,8 @@ def ensure_directory_exists(filename):
 #                              mpu.get_tensor_model_parallel_rank(),
 #                              mpu.get_pipeline_model_parallel_rank()),
 #                          'model_optim_rng.pt')
-def get_checkpoint_names(checkpoints_path, iteration,
-                         release=False):
+def get_checkpoint_names(checkpoints_path, iteration, use_distributed_optimizer,
+                         release=False):
     """A unified checkpoint name."""
     if release:
         directory = 'release'
@@ -111,12 +115,16 @@ def get_checkpoint_names(checkpoints_path, iteration,
     common_path = os.path.join(
         checkpoints_path,
         directory,
-        "mp_rank_%02d_%03d_%03d" % (
+        "mp_rank_%02d_%03d" % (
             mpu.get_tensor_model_parallel_rank(),
-            mpu.get_pipeline_model_parallel_rank(),
-            mpu.get_data_parallel_rank()))

     model_name = os.path.join(common_path, "model_rng.pt")
-    optim_name = os.path.join(common_path, "optim.pt")
+            mpu.get_pipeline_model_parallel_rank()))
+
+    if use_distributed_optimizer:
+        optim_name = os.path.join(
+            common_path + "_%03d" % mpu.get_data_parallel_rank(),
+            "optim.pt")
+    else:
+        optim_name = os.path.join(common_path, "optim.pt")
     return model_name, optim_name
 # <<<
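Taken together, the two hunks above split the checkpoint into a model/RNG file and an optimizer file: the model half keeps the tensor-/pipeline-parallel-rank path, while with the distributed optimizer each data-parallel rank gets its own optimizer directory. A standalone sketch of the resulting naming scheme follows (not part of the diff); the ranks are plain arguments here instead of megatron.mpu lookups, and the iter_{:07d} directory format is assumed from upstream Megatron rather than shown in this hunk.

import os

def sketch_checkpoint_names(checkpoints_path, iteration, use_distributed_optimizer,
                            tp_rank=0, pp_rank=0, dp_rank=0, release=False):
    # Mirrors the new get_checkpoint_names(); ranks are passed explicitly here
    # instead of being read from megatron.mpu.
    directory = 'release' if release else 'iter_{:07d}'.format(iteration)
    common_path = os.path.join(checkpoints_path, directory,
                               "mp_rank_%02d_%03d" % (tp_rank, pp_rank))
    model_name = os.path.join(common_path, "model_rng.pt")
    if use_distributed_optimizer:
        # Optimizer shards live in a per-data-parallel-rank directory.
        optim_name = os.path.join(common_path + "_%03d" % dp_rank, "optim.pt")
    else:
        optim_name = os.path.join(common_path, "optim.pt")
    return model_name, optim_name

# e.g. ('ckpts/iter_0001000/mp_rank_00_000/model_rng.pt',
#       'ckpts/iter_0001000/mp_rank_00_000_002/optim.pt')
print(sketch_checkpoint_names('ckpts', 1000, True, dp_rank=2))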
@@ -202,7 +210,12 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler):
     # Checkpoint file names.
     model_checkpoint_name, optim_checkpoint_name = \
-        get_checkpoint_names(args.save, iteration)
+        get_checkpoint_names(args.save, iteration, args.use_distributed_optimizer)
+
+    pax(0, {
+        "model_checkpoint_name" : model_checkpoint_name,
+        "optim_checkpoint_name" : optim_checkpoint_name,
+    })

     # Save args, model, RNG.
     if not torch.distributed.is_initialized() \
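pax comes from the author's lutil package and is used throughout this commit as a temporary debugging hook that dumps the named values; its exact behaviour is not part of this diff, so the stand-in below is only a guess to make the calls readable (the signature and the exit-after-printing behaviour are assumptions).

# Hypothetical stand-in for lutil.pax, for reading this diff only.
import sys

def pax(rank_or_dict, maybe_dict=None):
    # Accepts both pax(d) and pax(rank, d), matching the call patterns above;
    # the rank argument is ignored in this stand-in.
    d = maybe_dict if maybe_dict is not None else rank_or_dict
    for key, value in d.items():
        print("%s : %s" % (key, value))
    # Stop here so the dumped values can be inspected (the real helper's
    # behaviour is unknown; exiting is an assumption).
    sys.exit(0)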
@@ -255,7 +268,6 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler):
         ensure_directory_exists(optim_checkpoint_name)
         torch.save(state_dict, optim_checkpoint_name)

     # >>>
-    # from lutil import pax
     # pax({
     #     "model_checkpoint_name" : model_checkpoint_name,
     #     "optim_checkpoint_name" : optim_checkpoint_name,
@@ -377,7 +389,9 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
     # Checkpoint.
     model_checkpoint_name, optim_checkpoint_name = \
-        get_checkpoint_names(load_dir, iteration, release)
+        get_checkpoint_names(load_dir, iteration,
+                             args.use_distributed_optimizer,
+                             release)

     print_rank_0(f' loading checkpoint from {args.load} at iteration {iteration}')

     # Load the checkpoint.
@@ -401,6 +415,10 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
         print_rank_0(e)
         sys.exit()

+    # >>>
+    pax({"hi.": "there."})
+    # <<<
+
     # set checkpoint version
     set_checkpoint_version(model_state_dict.get('checkpoint_version', 0))
@@ -446,13 +464,25 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
     print_rank_0(f' checkpoint version {checkpoint_version}')
     fix_query_key_value_ordering(model, checkpoint_version)

+    # >>>
+    # pax(0, {
+    #     "model_state_dict" : model_state_dict,
+    #     "optim_state_dict" : optim_state_dict,
+    # })
+    # <<<
+
     # Optimizer.
+    pax({
+        "release" : release,
+        "finetune" : args.finetune,
+        "no_load_optim" : args.no_load_optim,
+    })
     if not release and not args.finetune and not args.no_load_optim:
         try:
             if optimizer is not None:
                 optimizer.load_state_dict(optim_state_dict['optimizer'])
             if opt_param_scheduler is not None:
-                if 'lr_scheduler' in state_dict: # backward compatbility
+                if 'lr_scheduler' in optim_state_dict: # backward compatbility
                     opt_param_scheduler.load_state_dict(optim_state_dict['lr_scheduler'])
                 else:
                     opt_param_scheduler.load_state_dict(optim_state_dict['opt_param_scheduler'])
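The hunks above assume that model_state_dict and optim_state_dict have already been produced from the two separate files returned by get_checkpoint_names; the lines doing that torch.load are not visible in this diff. A hedged sketch (not part of the diff) of that split load plus the optimizer/scheduler restore shown above:

import torch

def sketch_load_split_checkpoint(model_checkpoint_name, optim_checkpoint_name,
                                 optimizer, opt_param_scheduler):
    # Assumption: each half of the checkpoint is a plain torch.save'd dict,
    # as implied by the separate model/optim file names above.
    model_state_dict = torch.load(model_checkpoint_name, map_location='cpu')
    optim_state_dict = torch.load(optim_checkpoint_name, map_location='cpu')

    if optimizer is not None:
        optimizer.load_state_dict(optim_state_dict['optimizer'])
    if opt_param_scheduler is not None:
        # Older checkpoints stored the scheduler state under 'lr_scheduler'.
        key = 'lr_scheduler' if 'lr_scheduler' in optim_state_dict \
            else 'opt_param_scheduler'
        opt_param_scheduler.load_state_dict(optim_state_dict[key])
    return model_state_dict, optim_state_dict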
@@ -466,13 +496,13 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
     # rng states.
     if not release and not args.finetune and not args.no_load_rng:
         try:
-            if 'rng_state' in state_dict:
+            if 'rng_state' in model_state_dict:
                 # access rng_state for data parallel rank
                 if args.data_parallel_random_init:
-                    rng_state = state_dict['rng_state'][mpu.get_data_parallel_rank()]
+                    rng_state = model_state_dict['rng_state'][mpu.get_data_parallel_rank()]
                 else:
-                    rng_state = state_dict['rng_state'][0]
+                    rng_state = model_state_dict['rng_state'][0]
                 random.setstate(rng_state['random_rng_state'])
                 np.random.set_state(rng_state['np_rng_state'])
                 torch.set_rng_state(rng_state['torch_rng_state'])
@@ -483,15 +513,15 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
                 mpu.get_cuda_rng_tracker().set_states(
                     rng_state['rng_tracker_states'])
             else: # backward compatability
-                random.setstate(state_dict['random_rng_state'])
-                np.random.set_state(state_dict['np_rng_state'])
-                torch.set_rng_state(state_dict['torch_rng_state'])
-                torch.cuda.set_rng_state(state_dict['cuda_rng_state'])
+                random.setstate(model_state_dict['random_rng_state'])
+                np.random.set_state(model_state_dict['np_rng_state'])
+                torch.set_rng_state(model_state_dict['torch_rng_state'])
+                torch.cuda.set_rng_state(model_state_dict['cuda_rng_state'])
                 # Check for empty states array
-                if not state_dict['rng_tracker_states']:
+                if not model_state_dict['rng_tracker_states']:
                     raise KeyError
                 mpu.get_cuda_rng_tracker().set_states(
-                    state_dict['rng_tracker_states'])
+                    model_state_dict['rng_tracker_states'])
         except KeyError:
             print_rank_0('Unable to load rng state from checkpoint {}. '
                          'Specify --no-load-rng or --finetune to prevent '
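With this change the RNG state is read from the model checkpoint (model_state_dict) rather than the old combined state_dict. For reference, a minimal round trip (not part of the diff) of the four generator states the backward-compatibility branch touches (Python, NumPy, Torch CPU, Torch CUDA), leaving out Megatron's mpu rng tracker:

import random
import numpy as np
import torch

def capture_rng_state():
    # Same four generator states the backward-compatibility branch restores.
    return {
        'random_rng_state': random.getstate(),
        'np_rng_state': np.random.get_state(),
        'torch_rng_state': torch.get_rng_state(),
        'cuda_rng_state': torch.cuda.get_rng_state() if torch.cuda.is_available() else None,
    }

def restore_rng_state(rng_state):
    random.setstate(rng_state['random_rng_state'])
    np.random.set_state(rng_state['np_rng_state'])
    torch.set_rng_state(rng_state['torch_rng_state'])
    if rng_state['cuda_rng_state'] is not None:
        torch.cuda.set_rng_state(rng_state['cuda_rng_state'])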
@@ -500,6 +530,7 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
             sys.exit()

     # Some utilities want to load a checkpoint without distributed being initialized
+    # pax({"hi.": "there."})
     if torch.distributed.is_initialized():
         torch.distributed.barrier()
@@ -526,12 +557,14 @@ def load_biencoder_checkpoint(model, only_query_model=False,
     with open(tracker_filename, 'r') as f:
         iteration = int(f.read().strip())

-    checkpoint_name = get_checkpoint_name(load_path, iteration, False)
+    checkpoint_name, _ = get_checkpoint_names(load_path, iteration,
+                                              args.use_distributed_optimizer,
+                                              False)
     if mpu.get_data_parallel_rank() == 0:
         print('global rank {} is loading checkpoint {}'.format(
             torch.distributed.get_rank(), checkpoint_name))

-    state_dict = torch.load(checkpoint_name, map_location='cpu')
+    state_dict = torch.load(model_checkpoint_name, map_location='cpu')
     ret_state_dict = state_dict['model']

     if only_query_model:
@@ -308,7 +308,7 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
             state_dict['grad_scaler'] = self.grad_scaler.state_dict()
         # state_dict['params'] = \
         #     [ p for g in self.optimizer.param_groups for p in g["params"] ]
-        state_dict['groups'] = [g["params"] for g in self.optimizer.param_groups]
+        state_dict['groups'] = [g['params'] for g in self.optimizer.param_groups]
         # pax(0, { # ... only called on model rank 0
         # #     "optimizer" : self.optimizer,
         #     "state_dict" : state_dict,
@@ -348,20 +348,17 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
         # Copy data for the main params.
         current_groups = [ g["params"] for g in self.optimizer.param_groups ]
-        params_key = 'params'
-        assert params_key in state_dict, "key 'params' not in state_dict."
+        assert "groups" in state_dict, "key 'groups' not in state_dict."

         # pax(0, {
         #     "state_dict" : state_dict,
         #     "current_groups" : current_groups,
         #     "saved_groups" : state_dict[params_key],
         # })

-        for current_group, saved_group in zip(
-                current_groups,
-                state_dict[params_key]):
-            pax(0, {
-                "current_group" : current_group,
-                "saved_group" : saved_group,
-            })
+        for current_group, saved_group in zip(current_groups, state_dict["groups"]):
+            # pax(0, {
+            #     "current_group" : current_group,
+            #     "saved_group" : saved_group,
+            # })
             for current_param, saved_param in zip(current_group, saved_group):
                 current_param.data.copy_(saved_param.data)
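The two distrib_optimizer hunks pair up: state_dict() now records each param group's tensors under 'groups', and load_state_dict() copies the saved values back group by group. A self-contained round-trip sketch (not part of the diff) using a plain torch.optim.SGD stand-in; the values are cloned here so the restore is observable, whereas the real code stores the live main-param tensors:

import torch

# Two toy parameters in each of two param groups.
params_a = [torch.nn.Parameter(torch.randn(4)) for _ in range(2)]
params_b = [torch.nn.Parameter(torch.randn(4)) for _ in range(2)]
optimizer = torch.optim.SGD([{'params': params_a}, {'params': params_b}], lr=0.1)

# Save side: record the group tensors, as in state_dict()['groups'] above.
saved = {'groups': [[p.detach().clone() for p in g['params']]
                    for g in optimizer.param_groups]}

# ... parameters change during training ...
for p in params_a + params_b:
    p.data.add_(1.0)

# Load side: copy saved values back group by group, element by element.
current_groups = [g['params'] for g in optimizer.param_groups]
assert 'groups' in saved, "key 'groups' not in state_dict."
for current_group, saved_group in zip(current_groups, saved['groups']):
    for current_param, saved_param in zip(current_group, saved_group):
        current_param.data.copy_(saved_param.data)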