Commit 0f2a9f19 authored by Lawrence McAfee

cleaned/commented checkpointing.py.

parent 9c86abd9
@@ -28,10 +28,6 @@ from megatron import (get_args,
                       update_num_microbatches,
                       utils)
-# >>>
-from lutil import pax
-# <<<

 _CHECKPOINT_VERSION = None

 def set_checkpoint_version(value):
@@ -85,25 +81,6 @@ def ensure_directory_exists(filename):
         os.makedirs(dirname)

-# >>>
-# def get_checkpoint_name(checkpoints_path, iteration,
-#                         release=False):
-#     """A unified checkpoint name."""
-#     if release:
-#         directory = 'release'
-#     else:
-#         directory = 'iter_{:07d}'.format(iteration)
-#     # Use both the tensor and pipeline MP rank.
-#     if mpu.get_pipeline_model_parallel_world_size() == 1:
-#         return os.path.join(checkpoints_path, directory,
-#                             'mp_rank_{:02d}'.format(
-#                                 mpu.get_tensor_model_parallel_rank()),
-#                             'model_optim_rng.pt')
-#     return os.path.join(checkpoints_path, directory,
-#                         'mp_rank_{:02d}_{:03d}'.format(
-#                             mpu.get_tensor_model_parallel_rank(),
-#                             mpu.get_pipeline_model_parallel_rank()),
-#                         'model_optim_rng.pt')

 def get_checkpoint_names(checkpoints_path, iteration, use_distributed_optimizer,
                          release=False):
     """A unified checkpoint name."""
@@ -111,7 +88,9 @@ def get_checkpoint_names(checkpoints_path, iteration, use_distributed_optimizer,
         directory = 'release'
     else:
         directory = 'iter_{:07d}'.format(iteration)
-    # Use both the tensor and pipeline MP rank.
+    # Use both the tensor and pipeline MP rank. If using the distributed
+    # optimizer, then the optimizer's path must additionally include the
+    # data parallel rank.
     common_path = os.path.join(
         checkpoints_path,
         directory,
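
To make the new comment in the hunk above concrete, here is a minimal sketch of the naming scheme it describes. The `mp_rank_*` directory format and the `optim.pt` file name come from the hunks visible on this page; the model file name and the exact suffix used for the distributed-optimizer path are not shown here, so they are assumptions.

import os

def sketch_checkpoint_names(checkpoints_path, iteration, use_distributed_optimizer,
                            tp_rank, pp_rank, pp_world_size, dp_rank, release=False):
    """Hypothetical re-creation of the naming scheme, for illustration only."""
    directory = 'release' if release else 'iter_{:07d}'.format(iteration)
    # Both tensor- and pipeline-parallel ranks identify a model shard.
    if pp_world_size == 1:
        mp_dir = 'mp_rank_{:02d}'.format(tp_rank)
    else:
        mp_dir = 'mp_rank_{:02d}_{:03d}'.format(tp_rank, pp_rank)
    common_path = os.path.join(checkpoints_path, directory, mp_dir)
    model_name = os.path.join(common_path, "model_rng.pt")  # file name assumed
    if use_distributed_optimizer:
        # Each data-parallel rank owns a shard of optimizer state, so the
        # optimizer path must also carry the DP rank (suffix format assumed).
        optim_name = os.path.join(common_path + '_{:03d}'.format(dp_rank), "optim.pt")
    else:
        optim_name = os.path.join(common_path, "optim.pt")
    return model_name, optim_name

# Example with hypothetical ranks: iteration=1000, tp_rank=1, pp_rank=0,
# pp_world_size=2, dp_rank=3 and the distributed optimizer enabled would yield
#   .../iter_0001000/mp_rank_01_000/model_rng.pt
#   .../iter_0001000/mp_rank_01_000_003/optim.pt
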
@@ -126,7 +105,6 @@ def get_checkpoint_names(checkpoints_path, iteration, use_distributed_optimizer,
     else:
         optim_name = os.path.join(common_path, "optim.pt")
     return model_name, optim_name
-# <<<

 def get_checkpoint_tracker_filename(checkpoints_path):
@@ -212,11 +190,6 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler):
     model_checkpoint_name, optim_checkpoint_name = \
         get_checkpoint_names(args.save, iteration, args.use_distributed_optimizer)

-    pax(0, {
-        "model_checkpoint_name" : model_checkpoint_name,
-        "optim_checkpoint_name" : optim_checkpoint_name,
-    })

     # Save args, model, RNG.
     if not torch.distributed.is_initialized() \
        or mpu.get_data_parallel_rank() == 0:
@@ -233,15 +206,6 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler):
                 mpu.set_virtual_pipeline_model_parallel_rank(i)
                 state_dict['model%d' % i] = model[i].state_dict_for_save_checkpoint()

-        # >>>
-        # # Optimizer stuff.
-        # if not args.no_save_optim:
-        #     if optimizer is not None:
-        #         state_dict['optimizer'] = optimizer.state_dict()
-        #     if opt_param_scheduler is not None:
-        #         state_dict['opt_param_scheduler'] = opt_param_scheduler.state_dict()
-        # <<<

         # RNG states.
         if not args.no_save_rng:
             state_dict["rng_state"] = rng_state
@@ -250,8 +214,8 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler):
         ensure_directory_exists(model_checkpoint_name)
         torch.save(state_dict, model_checkpoint_name)

-    # >>>
-    # Save optimizer state.
+    # Save optimizer state. (Optimizer is saved separately from the model, due
+    # to the conflicting data pattern when using the distributed optimizer.)
    if not args.no_save_optim \
       and (not torch.distributed.is_initialized()
            or mpu.get_data_parallel_rank() == 0
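
A minimal sketch of the save split implied by the hunk above: model weights are replicated across data-parallel ranks, so one rank can write them, while the distributed optimizer shards its state across data-parallel ranks, so each rank must write its own optimizer file. The `args`/`mpu` helpers and save callables below are placeholders, and the `use_distributed_optimizer` term in the optimizer condition is an assumption, since the continuation of that condition is not visible on this page.

def sketch_save(args, mpu, dist_initialized, save_model_ckpt, save_optim_ckpt):
    """Illustrative only; not the code added by this commit."""
    dp_rank = mpu.get_data_parallel_rank() if dist_initialized else 0

    # Model weights (and RNG state) are replicated across data-parallel ranks,
    # so only DP rank 0 writes the model checkpoint.
    if not dist_initialized or dp_rank == 0:
        save_model_ckpt()

    # The distributed optimizer shards its state across data-parallel ranks,
    # so in that case every DP rank must write its own optimizer file;
    # otherwise DP rank 0 alone writes it. This conflicting data pattern is
    # why the optimizer goes to a separate file from the model.
    if not args.no_save_optim and (not dist_initialized
                                   or dp_rank == 0
                                   or args.use_distributed_optimizer):  # assumed
        save_optim_ckpt()
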
@@ -267,14 +231,6 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler):
         # Save.
         ensure_directory_exists(optim_checkpoint_name)
         torch.save(state_dict, optim_checkpoint_name)
-        # >>>
-        # pax({
-        #     "model_checkpoint_name" : model_checkpoint_name,
-        #     "optim_checkpoint_name" : optim_checkpoint_name,
-        #     "state_dict" : state_dict,
-        # })
-        # <<<
-        # <<<

     # Wait so everyone is done (necessary)
     if torch.distributed.is_initialized():
@@ -415,11 +371,7 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
         print_rank_0(e)
         sys.exit()

-    # >>>
-    pax({"hi.": "there."})
-    # <<<
-
-    # set checkpoint version
+    # Set checkpoint version.
     set_checkpoint_version(model_state_dict.get('checkpoint_version', 0))

     # Set iteration.
@@ -464,19 +416,7 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
     print_rank_0(f' checkpoint version {checkpoint_version}')
     fix_query_key_value_ordering(model, checkpoint_version)

-    # >>>
-    # pax(0, {
-    #     "model_state_dict" : model_state_dict,
-    #     "optim_state_dict" : optim_state_dict,
-    # })
-    # <<<

     # Optimizer.
-    pax({
-        "release" : release,
-        "finetune" : args.finetune,
-        "no_load_optim" : args.no_load_optim,
-    })
     if not release and not args.finetune and not args.no_load_optim:
         try:
             if optimizer is not None:
@@ -530,7 +470,6 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
             sys.exit()

     # Some utilities want to load a checkpoint without distributed being initialized
-    # pax({"hi.": "there."})
     if torch.distributed.is_initialized():
         torch.distributed.barrier()