"src/vscode:/vscode.git/clone" did not exist on "53bc30dd45cf2e4f85bd42c9b69c50347d6bde6d"
Commit 37ca7859 authored by Lawrence McAfee

started updating checkpointing.py.

parent f6811e28
@@ -81,7 +81,26 @@ def ensure_directory_exists(filename):
         os.makedirs(dirname)
 
 
-def get_checkpoint_name(checkpoints_path, iteration,
-                        release=False):
+# >>>
+# def get_checkpoint_name(checkpoints_path, iteration,
+#                         release=False):
+#     """A unified checkpoint name."""
+#     if release:
+#         directory = 'release'
+#     else:
+#         directory = 'iter_{:07d}'.format(iteration)
+#     # Use both the tensor and pipeline MP rank.
+#     if mpu.get_pipeline_model_parallel_world_size() == 1:
+#         return os.path.join(checkpoints_path, directory,
+#                             'mp_rank_{:02d}'.format(
+#                                 mpu.get_tensor_model_parallel_rank()),
+#                             'model_optim_rng.pt')
+#     return os.path.join(checkpoints_path, directory,
+#                         'mp_rank_{:02d}_{:03d}'.format(
+#                             mpu.get_tensor_model_parallel_rank(),
+#                             mpu.get_pipeline_model_parallel_rank()),
+#                         'model_optim_rng.pt')
+def get_checkpoint_names(checkpoints_path, iteration,
+                         release=False):
     """A unified checkpoint name."""
     if release:
@@ -89,16 +108,17 @@ def get_checkpoint_name(checkpoints_path, iteration,
     else:
         directory = 'iter_{:07d}'.format(iteration)
     # Use both the tensor and pipeline MP rank.
-    if mpu.get_pipeline_model_parallel_world_size() == 1:
-        return os.path.join(checkpoints_path, directory,
-                            'mp_rank_{:02d}'.format(
-                                mpu.get_tensor_model_parallel_rank()),
-                            'model_optim_rng.pt')
-    return os.path.join(checkpoints_path, directory,
-                        'mp_rank_{:02d}_{:03d}'.format(
-                            mpu.get_tensor_model_parallel_rank(),
-                            mpu.get_pipeline_model_parallel_rank()),
-                        'model_optim_rng.pt')
+    common_path = os.path.join(
+        checkpoints_path,
+        directory,
+        "mp_rank_%02d_%03d_%03d" % (
+            mpu.get_tensor_model_parallel_rank(),
+            mpu.get_pipeline_model_parallel_rank(),
+            mpu.get_data_parallel_rank()))
+    model_name = os.path.join(common_path, "model_rng.pt")
+    optim_name = os.path.join(common_path, "optim.pt")
+    return model_name, optim_name
+# <<<
 
 
 def get_checkpoint_tracker_filename(checkpoints_path):
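
For illustration only, not part of the commit: a minimal standalone sketch of the directory layout the new get_checkpoint_names() produces, with the mpu rank queries replaced by plain arguments so it runs outside Megatron. The name sketch_checkpoint_names is hypothetical.

import os

def sketch_checkpoint_names(checkpoints_path, iteration, release=False,
                            tp_rank=0, pp_rank=0, dp_rank=0):
    # Mirrors the new layout: one directory per (tensor, pipeline, data) rank,
    # containing separate model and optimizer files.
    directory = 'release' if release else 'iter_{:07d}'.format(iteration)
    common_path = os.path.join(
        checkpoints_path, directory,
        "mp_rank_%02d_%03d_%03d" % (tp_rank, pp_rank, dp_rank))
    return (os.path.join(common_path, "model_rng.pt"),
            os.path.join(common_path, "optim.pt"))

print(sketch_checkpoint_names("checkpoints", 100, tp_rank=0, pp_rank=1, dp_rank=2))
# ('checkpoints/iter_0000100/mp_rank_00_001_002/model_rng.pt',
#  'checkpoints/iter_0000100/mp_rank_00_001_002/optim.pt')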
@@ -177,10 +197,16 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler):
     print_rank_0('saving checkpoint at iteration {:7d} to {}'.format(
         iteration, args.save))
 
-    # collect rng state across data parallel ranks
+    # Collect rng state across data parallel ranks.
     rng_state = get_rng_state()
 
-    if not torch.distributed.is_initialized() or mpu.get_data_parallel_rank() == 0:
+    # Checkpoint file names.
+    model_checkpoint_name, optim_checkpoint_name = \
+        get_checkpoint_names(args.save, iteration)
+
+    # Save args, model, RNG.
+    if not torch.distributed.is_initialized() \
+       or mpu.get_data_parallel_rank() == 0:
 
         # Arguments, iteration, and model.
         state_dict = {}
@@ -194,21 +220,49 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler):
                 mpu.set_virtual_pipeline_model_parallel_rank(i)
                 state_dict['model%d' % i] = model[i].state_dict_for_save_checkpoint()
 
-        # Optimizer stuff.
-        if not args.no_save_optim:
-            if optimizer is not None:
-                state_dict['optimizer'] = optimizer.state_dict()
-            if opt_param_scheduler is not None:
-                state_dict['opt_param_scheduler'] = opt_param_scheduler.state_dict()
+        # >>>
+        # # Optimizer stuff.
+        # if not args.no_save_optim:
+        #     if optimizer is not None:
+        #         state_dict['optimizer'] = optimizer.state_dict()
+        #     if opt_param_scheduler is not None:
+        #         state_dict['opt_param_scheduler'] = opt_param_scheduler.state_dict()
+        # <<<
 
         # RNG states.
         if not args.no_save_rng:
             state_dict["rng_state"] = rng_state
 
         # Save.
-        checkpoint_name = get_checkpoint_name(args.save, iteration)
-        ensure_directory_exists(checkpoint_name)
-        torch.save(state_dict, checkpoint_name)
+        ensure_directory_exists(model_checkpoint_name)
+        torch.save(state_dict, model_checkpoint_name)
+
+    # >>>
+    # Save optimizer state.
+    if not args.no_save_optim \
+       and (not torch.distributed.is_initialized()
+            or mpu.get_data_parallel_rank() == 0
+            or args.use_distributed_optimizer):
+
+        # Optimizer stuff.
+        state_dict = {}
+        if optimizer is not None:
+            state_dict['optimizer'] = optimizer.state_dict()
+        if opt_param_scheduler is not None:
+            state_dict['opt_param_scheduler'] = opt_param_scheduler.state_dict()
+
+        # Save.
+        ensure_directory_exists(optim_checkpoint_name)
+        torch.save(state_dict, optim_checkpoint_name)
+        # >>>
+        # from lutil import pax
+        # pax({
+        #     "model_checkpoint_name" : model_checkpoint_name,
+        #     "optim_checkpoint_name" : optim_checkpoint_name,
+        #     "state_dict" : state_dict,
+        # })
+        # <<<
+    # <<<
 
     # Wait so everyone is done (necessary)
     if torch.distributed.is_initialized():
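
For illustration only, not part of the commit: with the split files, data-parallel rank 0 writes model_rng.pt, while optim.pt is written by DP rank 0, or by every DP rank when args.use_distributed_optimizer shards optimizer state across data-parallel ranks. A plain-Python sketch of that gating, with the torch.distributed/mpu/args lookups passed in as ordinary values:

def should_save_model(dist_initialized, dp_rank):
    # model_rng.pt: args, iteration, model weights and RNG, written by DP rank 0.
    return (not dist_initialized) or dp_rank == 0

def should_save_optim(no_save_optim, dist_initialized, dp_rank,
                      use_distributed_optimizer):
    # optim.pt: optimizer + scheduler state, written by DP rank 0, or by every
    # DP rank when the distributed optimizer shards state across DP ranks.
    return (not no_save_optim) and (
        (not dist_initialized) or dp_rank == 0 or use_distributed_optimizer)

assert should_save_model(dist_initialized=True, dp_rank=0)
assert not should_save_model(dist_initialized=True, dp_rank=1)
assert should_save_optim(False, True, dp_rank=3, use_distributed_optimizer=True)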
@@ -322,12 +376,14 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', strict=True):
     iteration, release = read_metadata(tracker_filename)
 
     # Checkpoint.
-    checkpoint_name = get_checkpoint_name(load_dir, iteration, release)
+    model_checkpoint_name, optim_checkpoint_name = \
+        get_checkpoint_names(load_dir, iteration, release)
     print_rank_0(f' loading checkpoint from {args.load} at iteration {iteration}')
 
     # Load the checkpoint.
     try:
-        state_dict = torch.load(checkpoint_name, map_location='cpu')
+        model_state_dict = torch.load(model_checkpoint_name, map_location='cpu')
+        optim_state_dict = torch.load(optim_checkpoint_name, map_location='cpu')
     except ModuleNotFoundError:
         from megatron.fp16_deprecated import loss_scaler
         # For backward compatibility.
@@ -336,7 +392,8 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', strict=True):
             'megatron.fp16_deprecated.loss_scaler']
         sys.modules['megatron.fp16.loss_scaler'] = sys.modules[
             'megatron.fp16_deprecated.loss_scaler']
-        state_dict = torch.load(checkpoint_name, map_location='cpu')
+        model_state_dict = torch.load(model_checkpoint_name, map_location='cpu')
+        optim_state_dict = torch.load(optim_checkpoint_name, map_location='cpu')
         sys.modules.pop('fp16.loss_scaler', None)
         sys.modules.pop('megatron.fp16.loss_scaler', None)
     except BaseException as e:
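
For illustration only, not part of the commit: a throwaway round-trip of the two-file layout using a temporary directory and dummy state dicts, just to show that model state and optimizer state now go through separate torch.save/torch.load calls.

import os, tempfile
import torch

with tempfile.TemporaryDirectory() as tmp:
    common = os.path.join(tmp, "iter_0000100", "mp_rank_00_000_000")
    os.makedirs(common)
    model_name = os.path.join(common, "model_rng.pt")
    optim_name = os.path.join(common, "optim.pt")
    # Dummy contents; the real dicts hold args, model, RNG, optimizer, scheduler.
    torch.save({"iteration": 100, "model0": {}, "rng_state": []}, model_name)
    torch.save({"optimizer": {"state": {}, "param_groups": []}}, optim_name)
    model_state_dict = torch.load(model_name, map_location='cpu')
    optim_state_dict = torch.load(optim_name, map_location='cpu')
    assert "rng_state" in model_state_dict and "optimizer" in optim_state_dict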
...
@@ -295,12 +295,64 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
     def get_main_grad(self, group_index):
         return self.get_main_param(group_index).grad
 
-    def load_state_dict(self):
-        raise Exception("hi.")
+    # def load_state_dict(self):
+    #     raise Exception("hi.")
 
-    def reload_model_params(self):
-        raise Exception("hi.")
+    # # def reload_model_params(self): # ... done in MixedPrecisionOptimizer
+    # #     raise Exception("hi.")
+
+    # def state_dict(self):
+    #     raise Exception("hi.")
 
     def state_dict(self):
-        raise Exception("hi.")
+        state_dict = {}
+        state_dict['optimizer'] = self.optimizer.state_dict()
+        if self.grad_scaler:
+            state_dict['grad_scaler'] = self.grad_scaler.state_dict()
+        state_dict['params'] = \
+            [ p for g in self.optimizer.param_groups for p in g["params"] ]
+        # pax(0, { # ... only called on model rank 0
+        #     # "optimizer" : self.optimizer,
+        #     "state_dict" : state_dict,
+        #     "state_dict / param_groups" : state_dict["optimizer"]["param_groups"],
+        #     "optimizer / groups" : self.optimizer.param_groups,
+        #     "state_dict / params" : [ p.shape for p in state_dict["params"] ],
+        #     "optimizer / params" :
+        #     [ p.shape for g in self.optimizer.param_groups for p in g["params"] ],
+        # })
+        return state_dict
+
+    def load_state_dict(self, state_dict):
+
+        # Optimizer.
+        optimizer_key = 'optimizer'
+        if optimizer_key not in state_dict:
+            optimizer_key = 'optimizer_state_dict'
+            print_rank_0('***WARNING*** loading optimizer from '
+                         'an old checkpoint ...')
+        self.optimizer.load_state_dict(state_dict[optimizer_key])
+
+        pax(0, {
+            "state_dict" : state_dict,
+            "params" : state_dict["params"],
+        })
+
+        # Grad scaler.
+        if 'grad_scaler' not in state_dict:
+            print_rank_0('***WARNING*** found an old checkpoint, will not '
+                         'load grad scaler ...')
+        else:
+            if self.grad_scaler:
+                self.grad_scaler.load_state_dict(state_dict['grad_scaler'])
+            else:
+                print_rank_0('***WARNING*** found the grad scaler in the '
+                             'checkpoint but it is None in the class. '
+                             'Skipping loading grad scaler ...')
+
+        # Copy data for the main params.
+        params_key = 'params'
+        assert params_key in state_dict, "key 'params' not in state_dict."
+        for current_group, saved_group in zip(
+                self.fp32_from_float16_groups,
+                state_dict[params_key]):
+            for current_param, saved_param in zip(current_group, saved_group):
+                current_param.data.copy_(saved_param.data)
 
     def zero_grad(self, set_to_none=True):
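
For illustration only, not part of the commit: the new state_dict()/load_state_dict() pair packs the inner optimizer state, the grad scaler state, and a flat list of the parameters themselves, then copies saved tensors back into the live parameters on load. A self-contained sketch of that save-and-copy-back idea, using a plain torch.optim.SGD in place of Megatron's wrapped optimizer, a detached copy of the params where the commit stores live references, and a flat zip where the commit iterates group by group:

import torch

# Two "trained" parameters split across two param groups.
trained = [torch.nn.Parameter(torch.randn(3)) for _ in range(2)]
opt = torch.optim.SGD([{"params": trained[:1]}, {"params": trained[1:]}], lr=0.1)

# Save: inner optimizer state plus a flat list of the parameter values
# (mirrors state_dict['optimizer'] and state_dict['params'] above).
saved_state = {
    "optimizer": opt.state_dict(),
    "params": [p.detach().clone()
               for g in opt.param_groups for p in g["params"]],
}

# Restore into freshly initialized params, mirroring the copy loop at the
# end of the new load_state_dict().
fresh = [torch.nn.Parameter(torch.zeros(3)) for _ in range(2)]
new_opt = torch.optim.SGD([{"params": fresh[:1]}, {"params": fresh[1:]}], lr=0.1)
new_opt.load_state_dict(saved_state["optimizer"])
for current_param, saved_param in zip(fresh, saved_state["params"]):
    current_param.data.copy_(saved_param.data)

assert all(torch.equal(a, b) for a, b in zip(fresh, trained))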
...