Commit 0f2a9f19 authored by Lawrence McAfee

cleaned/commented checkpointing.py.

parent 9c86abd9
@@ -28,10 +28,6 @@ from megatron import (get_args,
update_num_microbatches,
utils)
# >>>
from lutil import pax
# <<<
_CHECKPOINT_VERSION = None
def set_checkpoint_version(value):
@@ -85,25 +81,6 @@ def ensure_directory_exists(filename):
os.makedirs(dirname)
# >>>
# def get_checkpoint_name(checkpoints_path, iteration,
# release=False):
# """A unified checkpoint name."""
# if release:
# directory = 'release'
# else:
# directory = 'iter_{:07d}'.format(iteration)
# # Use both the tensor and pipeline MP rank.
# if mpu.get_pipeline_model_parallel_world_size() == 1:
# return os.path.join(checkpoints_path, directory,
# 'mp_rank_{:02d}'.format(
# mpu.get_tensor_model_parallel_rank()),
# 'model_optim_rng.pt')
# return os.path.join(checkpoints_path, directory,
# 'mp_rank_{:02d}_{:03d}'.format(
# mpu.get_tensor_model_parallel_rank(),
# mpu.get_pipeline_model_parallel_rank()),
# 'model_optim_rng.pt')
def get_checkpoint_names(checkpoints_path, iteration, use_distributed_optimizer,
release=False):
"""A unified checkpoint name."""
@@ -111,7 +88,9 @@ def get_checkpoint_names(checkpoints_path, iteration, use_distributed_optimizer,
directory = 'release'
else:
directory = 'iter_{:07d}'.format(iteration)
# Use both the tensor and pipeline MP rank.
# Use both the tensor and pipeline MP rank. If using the distributed
# optimizer, then the optimizer's path must additionally include the
# data parallel rank.
common_path = os.path.join(
checkpoints_path,
directory,
@@ -126,7 +105,6 @@ def get_checkpoint_names(checkpoints_path, iteration, use_distributed_optimizer,
else:
optim_name = os.path.join(common_path, "optim.pt")
return model_name, optim_name
# <<<
def get_checkpoint_tracker_filename(checkpoints_path):
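For reference, a minimal standalone sketch of the naming scheme the `get_checkpoint_names` hunk above describes. The rank arguments are hard-coded stand-ins for the `mpu` queries in the real code, and the exact file name (`model_rng.pt`) and the data-parallel suffix format are assumptions beyond what the diff shows:

```python
import os

def sketch_checkpoint_names(checkpoints_path, iteration, use_distributed_optimizer,
                            tp_rank=0, pp_rank=0, dp_rank=0, release=False):
    # Illustrative only; the real code queries the ranks from mpu.
    directory = 'release' if release else 'iter_{:07d}'.format(iteration)
    # Both the tensor and pipeline MP ranks appear in the common path.
    common_path = os.path.join(checkpoints_path, directory,
                               'mp_rank_{:02d}_{:03d}'.format(tp_rank, pp_rank))
    model_name = os.path.join(common_path, 'model_rng.pt')  # assumed file name
    if use_distributed_optimizer:
        # The distributed optimizer shards its state across data-parallel
        # ranks, so the optimizer path must also carry the DP rank (suffix
        # format assumed here).
        optim_name = os.path.join(common_path + '_%03d' % dp_rank, 'optim.pt')
    else:
        optim_name = os.path.join(common_path, 'optim.pt')
    return model_name, optim_name

print(sketch_checkpoint_names('/ckpts', 1000, True, tp_rank=1, pp_rank=2, dp_rank=3))
# ('/ckpts/iter_0001000/mp_rank_01_002/model_rng.pt',
#  '/ckpts/iter_0001000/mp_rank_01_002_003/optim.pt')
```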
@@ -212,11 +190,6 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler):
model_checkpoint_name, optim_checkpoint_name = \
get_checkpoint_names(args.save, iteration, args.use_distributed_optimizer)
pax(0, {
"model_checkpoint_name" : model_checkpoint_name,
"optim_checkpoint_name" : optim_checkpoint_name,
})
# Save args, model, RNG.
if not torch.distributed.is_initialized() \
or mpu.get_data_parallel_rank() == 0:
@@ -233,15 +206,6 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler):
mpu.set_virtual_pipeline_model_parallel_rank(i)
state_dict['model%d' % i] = model[i].state_dict_for_save_checkpoint()
# >>>
# # Optimizer stuff.
# if not args.no_save_optim:
# if optimizer is not None:
# state_dict['optimizer'] = optimizer.state_dict()
# if opt_param_scheduler is not None:
# state_dict['opt_param_scheduler'] = opt_param_scheduler.state_dict()
# <<<
# RNG states.
if not args.no_save_rng:
state_dict["rng_state"] = rng_state
@@ -250,8 +214,8 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler):
ensure_directory_exists(model_checkpoint_name)
torch.save(state_dict, model_checkpoint_name)
# >>>
# Save optimizer state.
# Save optimizer state. (Optimizer is saved separately from the model, due
# to the conflicting data pattern when using the distributed optimizer.)
if not args.no_save_optim \
and (not torch.distributed.is_initialized()
or mpu.get_data_parallel_rank() == 0
@@ -267,14 +231,6 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler):
# Save.
ensure_directory_exists(optim_checkpoint_name)
torch.save(state_dict, optim_checkpoint_name)
# >>>
# pax({
# "model_checkpoint_name" : model_checkpoint_name,
# "optim_checkpoint_name" : optim_checkpoint_name,
# "state_dict" : state_dict,
# })
# <<<
# <<<
# Wait so everyone is done (necessary)
if torch.distributed.is_initialized():
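Putting the two hunks above together, a condensed sketch of the save flow: the model/RNG file and the optimizer file are written by separate `torch.save` calls, and a barrier keeps ranks in step afterwards. The rank condition (every data-parallel rank writes when the distributed optimizer shards its state, otherwise only DP rank 0) paraphrases the truncated `if` shown above, and the helper names here are illustrative:

```python
import os
import torch

def save_model_and_optimizer(model_state, optim_state,
                             model_checkpoint_name, optim_checkpoint_name,
                             is_dp_rank_0, use_distributed_optimizer):
    # Illustrative only; mirrors the structure of save_checkpoint() above.
    # Model (and RNG) state: written once per tensor/pipeline-parallel
    # group, i.e. by data-parallel rank 0.
    if is_dp_rank_0:
        os.makedirs(os.path.dirname(model_checkpoint_name), exist_ok=True)
        torch.save(model_state, model_checkpoint_name)

    # Optimizer state goes to a separate file because the distributed
    # optimizer shards it across data-parallel ranks, so each DP rank holds
    # a different piece and must write its own file.
    if use_distributed_optimizer or is_dp_rank_0:
        os.makedirs(os.path.dirname(optim_checkpoint_name), exist_ok=True)
        torch.save({'optimizer': optim_state}, optim_checkpoint_name)

    # Wait so everyone is done (necessary before anyone reads the files back).
    if torch.distributed.is_initialized():
        torch.distributed.barrier()
```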
@@ -415,11 +371,7 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
print_rank_0(e)
sys.exit()
# >>>
pax({"hi.": "there."})
# <<<
# set checkpoint version
# Set checkpoint version.
set_checkpoint_version(model_state_dict.get('checkpoint_version', 0))
# Set iteration.
@@ -464,19 +416,7 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
print_rank_0(f' checkpoint version {checkpoint_version}')
fix_query_key_value_ordering(model, checkpoint_version)
# >>>
# pax(0, {
# "model_state_dict" : model_state_dict,
# "optim_state_dict" : optim_state_dict,
# })
# <<<
# Optimizer.
pax({
"release" : release,
"finetune" : args.finetune,
"no_load_optim" : args.no_load_optim,
})
if not release and not args.finetune and not args.no_load_optim:
try:
if optimizer is not None:
@@ -530,7 +470,6 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
sys.exit()
# Some utilities want to load a checkpoint without distributed being initialized
# pax({"hi.": "there."})
if torch.distributed.is_initialized():
torch.distributed.barrier()
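On the load side, the guard in the last hunks reduces to roughly the following: optimizer and scheduler state are restored only when this is not a release checkpoint, not a finetuning run, and `--no-load-optim` was not given. The `'optimizer'`/`'opt_param_scheduler'` keys mirror the save sketch and are assumptions beyond what the diff shows:

```python
def maybe_load_optimizer_state(optimizer, opt_param_scheduler, optim_state_dict,
                               release, finetune, no_load_optim):
    # Illustrative only; key names are assumed.
    if release or finetune or no_load_optim:
        return  # start from fresh optimizer/scheduler state
    if optimizer is not None:
        optimizer.load_state_dict(optim_state_dict['optimizer'])
    if opt_param_scheduler is not None and 'opt_param_scheduler' in optim_state_dict:
        opt_param_scheduler.load_state_dict(optim_state_dict['opt_param_scheduler'])
```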