Commit 9c86abd9 authored by Lawrence McAfee

more specific formatting of model/optim checkpoint paths.

parent af2b136f
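For orientation, a minimal standalone sketch of the path scheme the updated get_checkpoint_names() in this commit produces. The helper name, the "iter_X" directory label, and the rank values below are illustrative assumptions, not part of the commit; only the "mp_rank_%02d_%03d" / "_%03d" formatting and the model_rng.pt / optim.pt file names come from the diff.

import os

# Sketch only: mirrors the new naming logic, with tensor-, pipeline- and
# data-parallel ranks passed in explicitly instead of read from mpu.
def sketch_checkpoint_names(save_dir, iter_dir, tp_rank, pp_rank, dp_rank,
                            use_distributed_optimizer):
    common_path = os.path.join(save_dir, iter_dir,
                               "mp_rank_%02d_%03d" % (tp_rank, pp_rank))
    model_name = os.path.join(common_path, "model_rng.pt")
    if use_distributed_optimizer:
        # Optimizer state gets its own directory per data-parallel rank.
        optim_name = os.path.join(common_path + "_%03d" % dp_rank, "optim.pt")
    else:
        optim_name = os.path.join(common_path, "optim.pt")
    return model_name, optim_name

# e.g. ('ckpts/iter_X/mp_rank_01_002/model_rng.pt',
#       'ckpts/iter_X/mp_rank_01_002_003/optim.pt')
print(sketch_checkpoint_names("ckpts", "iter_X", 1, 2, 3, True))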
@@ -28,6 +28,10 @@ from megatron import (get_args,
                       update_num_microbatches,
                       utils)
 
+# >>>
+from lutil import pax
+# <<<
+
 _CHECKPOINT_VERSION = None
 
 def set_checkpoint_version(value):
@@ -100,8 +104,8 @@ def ensure_directory_exists(filename):
 #                         mpu.get_tensor_model_parallel_rank(),
 #                         mpu.get_pipeline_model_parallel_rank()),
 #                        'model_optim_rng.pt')
-def get_checkpoint_names(checkpoints_path, iteration,
-                         release=False):
+def get_checkpoint_names(checkpoints_path, iteration, use_distributed_optimizer,
+                         release=False):
     """A unified checkpoint name."""
     if release:
         directory = 'release'
@@ -111,12 +115,16 @@ def get_checkpoint_names(checkpoints_path, iteration,
     common_path = os.path.join(
         checkpoints_path,
         directory,
-        "mp_rank_%02d_%03d_%03d" % (
+        "mp_rank_%02d_%03d" % (
             mpu.get_tensor_model_parallel_rank(),
-            mpu.get_pipeline_model_parallel_rank(),
-            mpu.get_data_parallel_rank()))
+            mpu.get_pipeline_model_parallel_rank()))
     model_name = os.path.join(common_path, "model_rng.pt")
-    optim_name = os.path.join(common_path, "optim.pt")
+    if use_distributed_optimizer:
+        optim_name = os.path.join(
+            common_path + "_%03d" % mpu.get_data_parallel_rank(),
+            "optim.pt")
+    else:
+        optim_name = os.path.join(common_path, "optim.pt")
 
     return model_name, optim_name
     # <<<
@@ -202,7 +210,12 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler):
     # Checkpoint file names.
     model_checkpoint_name, optim_checkpoint_name = \
-        get_checkpoint_names(args.save, iteration)
+        get_checkpoint_names(args.save, iteration, args.use_distributed_optimizer)
+
+    pax(0, {
+        "model_checkpoint_name" : model_checkpoint_name,
+        "optim_checkpoint_name" : optim_checkpoint_name,
+    })
 
     # Save args, model, RNG.
     if not torch.distributed.is_initialized() \
@@ -255,7 +268,6 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler):
         ensure_directory_exists(optim_checkpoint_name)
         torch.save(state_dict, optim_checkpoint_name)
     # >>>
-    # from lutil import pax
     # pax({
     #     "model_checkpoint_name" : model_checkpoint_name,
     #     "optim_checkpoint_name" : optim_checkpoint_name,
@@ -377,7 +389,9 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
     # Checkpoint.
     model_checkpoint_name, optim_checkpoint_name = \
-        get_checkpoint_names(load_dir, iteration, release)
+        get_checkpoint_names(load_dir, iteration,
+                             args.use_distributed_optimizer,
+                             release)
     print_rank_0(f' loading checkpoint from {args.load} at iteration {iteration}')
 
     # Load the checkpoint.
@@ -401,6 +415,10 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
             print_rank_0(e)
             sys.exit()
 
+    # >>>
+    pax({"hi.": "there."})
+    # <<<
+
     # set checkpoint version
     set_checkpoint_version(model_state_dict.get('checkpoint_version', 0))
@@ -446,13 +464,25 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
     print_rank_0(f' checkpoint version {checkpoint_version}')
     fix_query_key_value_ordering(model, checkpoint_version)
 
+    # >>>
+    # pax(0, {
+    #     "model_state_dict" : model_state_dict,
+    #     "optim_state_dict" : optim_state_dict,
+    # })
+    # <<<
+
     # Optimizer.
+    pax({
+        "release" : release,
+        "finetune" : args.finetune,
+        "no_load_optim" : args.no_load_optim,
+    })
     if not release and not args.finetune and not args.no_load_optim:
         try:
             if optimizer is not None:
                 optimizer.load_state_dict(optim_state_dict['optimizer'])
             if opt_param_scheduler is not None:
-                if 'lr_scheduler' in state_dict: # backward compatbility
+                if 'lr_scheduler' in optim_state_dict: # backward compatbility
                     opt_param_scheduler.load_state_dict(optim_state_dict['lr_scheduler'])
                 else:
                     opt_param_scheduler.load_state_dict(optim_state_dict['opt_param_scheduler'])
@@ -466,13 +496,13 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
     # rng states.
     if not release and not args.finetune and not args.no_load_rng:
         try:
-            if 'rng_state' in state_dict:
+            if 'rng_state' in model_state_dict:
                 # access rng_state for data parallel rank
                 if args.data_parallel_random_init:
-                    rng_state = state_dict['rng_state'][mpu.get_data_parallel_rank()]
+                    rng_state = model_state_dict['rng_state'][mpu.get_data_parallel_rank()]
                 else:
-                    rng_state = state_dict['rng_state'][0]
+                    rng_state = model_state_dict['rng_state'][0]
                 random.setstate(rng_state['random_rng_state'])
                 np.random.set_state(rng_state['np_rng_state'])
                 torch.set_rng_state(rng_state['torch_rng_state'])
@@ -483,15 +513,15 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
                 mpu.get_cuda_rng_tracker().set_states(
                     rng_state['rng_tracker_states'])
             else:  # backward compatability
-                random.setstate(state_dict['random_rng_state'])
-                np.random.set_state(state_dict['np_rng_state'])
-                torch.set_rng_state(state_dict['torch_rng_state'])
-                torch.cuda.set_rng_state(state_dict['cuda_rng_state'])
+                random.setstate(model_state_dict['random_rng_state'])
+                np.random.set_state(model_state_dict['np_rng_state'])
+                torch.set_rng_state(model_state_dict['torch_rng_state'])
+                torch.cuda.set_rng_state(model_state_dict['cuda_rng_state'])
                 # Check for empty states array
-                if not state_dict['rng_tracker_states']:
+                if not model_state_dict['rng_tracker_states']:
                     raise KeyError
                 mpu.get_cuda_rng_tracker().set_states(
-                    state_dict['rng_tracker_states'])
+                    model_state_dict['rng_tracker_states'])
         except KeyError:
             print_rank_0('Unable to load rng state from checkpoint {}. '
                          'Specify --no-load-rng or --finetune to prevent '
@@ -500,6 +530,7 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
             sys.exit()
 
     # Some utilities want to load a checkpoint without distributed being initialized
+    # pax({"hi.": "there."})
     if torch.distributed.is_initialized():
         torch.distributed.barrier()
@@ -526,12 +557,14 @@ def load_biencoder_checkpoint(model, only_query_model=False,
     with open(tracker_filename, 'r') as f:
         iteration = int(f.read().strip())
 
-    checkpoint_name = get_checkpoint_name(load_path, iteration, False)
+    checkpoint_name, _ = get_checkpoint_names(load_path, iteration,
+                                              args.use_distributed_optimizer,
+                                              False)
     if mpu.get_data_parallel_rank() == 0:
         print('global rank {} is loading checkpoint {}'.format(
             torch.distributed.get_rank(), checkpoint_name))
 
-    state_dict = torch.load(checkpoint_name, map_location='cpu')
+    state_dict = torch.load(model_checkpoint_name, map_location='cpu')
     ret_state_dict = state_dict['model']
 
     if only_query_model:
...
@@ -308,7 +308,7 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
             state_dict['grad_scaler'] = self.grad_scaler.state_dict()
 
         # state_dict['params'] = \
         #     [ p for g in self.optimizer.param_groups for p in g["params"] ]
-        state_dict['groups'] = [g["params"] for g in self.optimizer.param_groups]
+        state_dict['groups'] = [g['params'] for g in self.optimizer.param_groups]
         # pax(0, { # ... only called on model rank 0
         #     # "optimizer" : self.optimizer,
         #     "state_dict" : state_dict,
@@ -348,20 +348,17 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
 
         # Copy data for the main params.
         current_groups = [ g["params"] for g in self.optimizer.param_groups ]
-        params_key = 'params'
-        assert params_key in state_dict, "key 'params' not in state_dict."
+        assert "groups" in state_dict, "key 'groups' not in state_dict."
         # pax(0, {
         #     "state_dict" : state_dict,
         #     "current_groups" : current_groups,
         #     "saved_groups" : state_dict[params_key],
         # })
-        for current_group, saved_group in zip(
-                current_groups,
-                state_dict[params_key]):
-            pax(0, {
-                "current_group" : current_group,
-                "saved_group" : saved_group,
-            })
+        for current_group, saved_group in zip(current_groups, state_dict["groups"]):
+            # pax(0, {
+            #     "current_group" : current_group,
+            #     "saved_group" : saved_group,
+            # })
             for current_param, saved_param in zip(current_group, saved_group):
                 current_param.data.copy_(saved_param.data)
...
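As a side note, a minimal toy sketch of the 'groups' save/load round trip shown in the DistributedOptimizer hunks above. The class name and tensor shapes are made up for illustration; only the state_dict key and the nested copy loop mirror the diff.

import torch

class ToyGroupsHolder:
    """Toy stand-in for an optimizer that round-trips its param groups."""
    def __init__(self, param_groups):
        self.param_groups = param_groups  # list of {"params": [tensors]}

    def state_dict(self):
        # Save each group's parameter tensors under the 'groups' key.
        return {"groups": [g["params"] for g in self.param_groups]}

    def load_state_dict(self, state_dict):
        assert "groups" in state_dict, "key 'groups' not in state_dict."
        current_groups = [g["params"] for g in self.param_groups]
        for current_group, saved_group in zip(current_groups, state_dict["groups"]):
            for current_param, saved_param in zip(current_group, saved_group):
                current_param.data.copy_(saved_param.data)

# Round trip: the zero tensor is overwritten with the saved ones tensor.
holder = ToyGroupsHolder([{"params": [torch.zeros(2)]}])
holder.load_state_dict({"groups": [[torch.ones(2)]]})
print(holder.param_groups[0]["params"][0])  # tensor([1., 1.])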