Commit 24541d6a authored by Jiezhong Qiu

load checkpoint for expert parallel

parent f5a5d31a
@@ -4,6 +4,7 @@ lines of modification.
 See `examples/megatron` for usage instructions.
 """
 import os
+import sys
 import math
 import random
 from collections import OrderedDict
@@ -365,10 +366,15 @@ class DistributedDataParallel(DistributedGroupedDataParallel):
         """
         return self.module.load_state_dict(*args, **kwargs)


-def get_checkpoint_name(checkpoints_path, iteration,
-                        release=False):
-    """A unified checkpoint name."""
+def get_fmoe_checkpoint_name(checkpoints_path, iteration,
+                             release=False, data_parallel_rank=-1):
+    """A unified checkpoint name, allowing the data parallel rank to be specified."""
     from megatron import mpu
+    from megatron.checkpointing import get_checkpoint_name
+    if data_parallel_rank == -1:
+        data_parallel_rank = mpu.get_data_parallel_rank()
+    if data_parallel_rank == 0:
+        return get_checkpoint_name(checkpoints_path, iteration, release)
     if release:
         directory = 'release'
@@ -379,14 +385,14 @@ def get_checkpoint_name(checkpoints_path, iteration,
     if mpu.get_pipeline_model_parallel_world_size() == 1:
         return os.path.join(checkpoints_path, directory,
                             'mp_rank_{:02d}_dp_rank_{:04d}'.format(
                                 mpu.get_tensor_model_parallel_rank(),
-                                mpu.get_data_parallel_rank()
+                                data_parallel_rank
                             ),
                             'model_optim_rng.pt')
     return os.path.join(checkpoints_path, directory,
                         'mp_rank_{:02d}_{:03d}_dp_rank_{:04d}'.format(
                             mpu.get_tensor_model_parallel_rank(),
                             mpu.get_pipeline_model_parallel_rank(),
-                            mpu.get_data_parallel_rank()
+                            data_parallel_rank
                         ),
                         'model_optim_rng.pt')
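
For orientation, here is roughly what the naming scheme above produces on disk. This is an illustrative sketch, not part of the commit: it assumes a save path of ./checkpoints, iteration 1000, no pipeline parallelism, and Megatron's usual iter_XXXXXXX directory layout. Data parallel rank 0 keeps the native Megatron name, while every other data parallel rank gets an extra dp_rank suffix and stores only its expert shard:

    checkpoints/
        iter_0001000/
            mp_rank_00/model_optim_rng.pt                  # dp rank 0: full model + optimizer
            mp_rank_00_dp_rank_0001/model_optim_rng.pt     # dp rank 1: expert states only
            mp_rank_00_dp_rank_0002/model_optim_rng.pt     # dp rank 2: expert states only
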
@@ -396,8 +402,13 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler, expert_dp_comm='n
     from megatron import get_args
     from megatron import mpu
-    args = get_args()
+    if mpu.get_data_parallel_rank() == 0:
+        # at dp rank 0, we still follow the native save_checkpoint from Megatron
+        from megatron.checkpointing import save_checkpoint as save_checkpoint_native
+        save_checkpoint_native(iteration, model, optimizer, lr_scheduler)
+        return
+    args = get_args()

     # Only rank zero of the data parallel writes to the disk.
     if isinstance(model, DistributedDataParallel):
@@ -415,29 +426,26 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler, expert_dp_comm='n
     state_dict['model'] = model.state_dict_for_save_checkpoint(
         keep_vars=(mpu.get_data_parallel_rank() > 0))

-    if mpu.get_data_parallel_rank() > 0:
-        def extract_expert_param(state_dict, expert_dp_comm='none'):
-            state_dict_new = state_dict.__class__()
-            for k, v in state_dict.items():
-                # megatron uses both dict and OrderedDict in its state_dict
-                if isinstance(v, (OrderedDict, dict)):
-                    v_new = extract_expert_param(v, expert_dp_comm)
-                    if len(v_new) > 0:
-                        state_dict_new[k] = v_new
-                elif hasattr(v, 'dp_comm') and v.dp_comm == expert_dp_comm:
-                    state_dict_new[k] = v.detach()
-            return state_dict_new
-        state_dict['model'] = extract_expert_param(
-            state_dict['model'],
-            expert_dp_comm)
+    def extract_expert_param(state_dict, expert_dp_comm='none'):
+        state_dict_new = state_dict.__class__()
+        for k, v in state_dict.items():
+            # megatron uses both dict and OrderedDict in its state_dict
+            if isinstance(v, (OrderedDict, dict)):
+                v_new = extract_expert_param(v, expert_dp_comm)
+                if len(v_new) > 0:
+                    state_dict_new[k] = v_new
+            elif hasattr(v, 'dp_comm') and v.dp_comm == expert_dp_comm:
+                state_dict_new[k] = v.detach()
+        return state_dict_new
+    state_dict['model'] = extract_expert_param(
+        state_dict['model'],
+        expert_dp_comm)

     # Optimizer stuff.
     if not args.no_save_optim:
         if optimizer is not None:
             state_dict['optimizer'] = optimizer.state_dict()
-        if mpu.get_data_parallel_rank() > 0:
             index = 0
             for param_group in optimizer.optimizer.param_groups:
                 for param in param_group['params']:
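
The filtering above hinges on parameters carrying a dp_comm attribute, which fastmoe sets elsewhere. A minimal toy illustration of what extract_expert_param would keep (not part of the commit; the parameter names and the manual tagging below are made up for illustration):

    import torch
    from collections import OrderedDict

    # Hypothetical parameters: 'none' marks a tensor that is not synchronized
    # across the data parallel group, i.e. an expert weight that every dp rank
    # must save on its own.
    expert_w = torch.nn.Parameter(torch.randn(4, 4))
    expert_w.dp_comm = 'none'
    shared_w = torch.nn.Parameter(torch.randn(4, 4))   # replicated, saved by dp rank 0 only

    state = OrderedDict(layer0=OrderedDict(expert=expert_w, shared=shared_w))
    # extract_expert_param(state, 'none') keeps only state['layer0']['expert']
    # (detached); 'shared' has no dp_comm tag and is dropped from the local file.
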
@@ -462,7 +470,7 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler, expert_dp_comm='n
             = mpu.get_cuda_rng_tracker().get_states()

     # Save.
-    checkpoint_name = get_checkpoint_name(args.save, iteration)
+    checkpoint_name = get_fmoe_checkpoint_name(args.save, iteration)
     from megatron.checkpointing import ensure_directory_exists
     from megatron.checkpointing import get_checkpoint_tracker_filename
     ensure_directory_exists(checkpoint_name)
@@ -480,3 +488,182 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler, expert_dp_comm='n
             f.write(str(iteration))
     # Wait so everyone is done (not necessary)
     torch.distributed.barrier()
+
+
+def merge_state_dict(state_dict_rank0, state_dict_local):
+    """Merge two state dicts: the one saved by data parallel rank 0 and the
+    local one that contains only expert states."""
+    from megatron import print_rank_last
+
+    def merge_model(state_dict_rank0, state_dict_local):
+        for k, v in state_dict_local.items():
+            # megatron uses both dict and OrderedDict in its state_dict
+            if isinstance(v, (OrderedDict, dict)):
+                print_rank_last("[merge model] go recursively to {}".format(k))
+                merge_model(state_dict_rank0[k], v)
+            else:
+                before = state_dict_rank0[k].sum().item()
+                state_dict_rank0[k] = v
+                after = state_dict_rank0[k].sum().item()
+                print_rank_last("[merge model] copy parameter {}, "
+                                "before.sum={:7f}, after.sum={:7f}".format(
+                                    k, before, after))
+
+    merge_model(state_dict_rank0['model'], state_dict_local['model'])
+
+    for k, v in state_dict_local['optimizer']['state'].items():
+        before = {kk: vv.sum().item()
+                  for kk, vv in state_dict_rank0['optimizer']['state'][k].items()}
+        state_dict_rank0['optimizer']['state'][k] = v
+        after = {kk: vv.sum().item()
+                 for kk, vv in state_dict_rank0['optimizer']['state'][k].items()}
+        print_rank_last("[merge optimizer] copy {}, "
+                        "before.sum={}, after.sum={}".format(k, str(before), str(after)))
+
+    return state_dict_rank0
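
# Editor's sketch (not part of the commit): a toy illustration of the merge
# semantics above. merge_model walks the local (expert-only) dict recursively
# and overwrites matching leaves in the rank-0 dict; keys present only in the
# rank-0 dict are left untouched. The names below are illustrative only.
#
#     rank0 = {'model': OrderedDict([('expert', torch.zeros(2)),
#                                    ('shared', torch.ones(2))])}
#     local = {'model': OrderedDict([('expert', torch.full((2,), 5.0))])}
#     # after the inner merge_model(rank0['model'], local['model']) step:
#     #   rank0['model']['expert'] -> tensor([5., 5.])  (taken from local)
#     #   rank0['model']['shared'] -> tensor([1., 1.])  (kept from rank 0)
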
+
+
+def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load'):
+    """Load a model checkpoint and return the iteration."""
+    from megatron import get_args
+    from megatron import mpu
+    from megatron import print_rank_last
+    from megatron.checkpointing import get_checkpoint_tracker_filename, set_checkpoint_version, check_checkpoint_args, update_num_microbatches
+
+    if mpu.get_data_parallel_rank() == 0:
+        # at dp rank 0, we still follow the native load_checkpoint from Megatron
+        from megatron.checkpointing import load_checkpoint as load_checkpoint_native
+        return load_checkpoint_native(model, optimizer, lr_scheduler, load_arg)
+
+    args = get_args()
+    load_dir = getattr(args, load_arg)
+
+    if isinstance(model, DistributedDataParallel):
+        model = model.module
+
+    # Read the tracker file and set the iteration.
+    tracker_filename = get_checkpoint_tracker_filename(load_dir)
+
+    # If no tracker file, return iteration zero.
+    if not os.path.isfile(tracker_filename):
+        print_rank_last('WARNING: could not find the metadata file {} '.format(
+            tracker_filename))
+        print_rank_last('    will not load any checkpoints and will start from '
+                        'random')
+        return 0
+
+    # Otherwise, read the tracker file and either set the iteration or
+    # mark it as a release checkpoint.
+    iteration = 0
+    release = False
+    with open(tracker_filename, 'r') as f:
+        metastring = f.read().strip()
+        try:
+            iteration = int(metastring)
+        except ValueError:
+            release = metastring == 'release'
+            if not release:
+                print_rank_last('ERROR: Invalid metadata file {}. Exiting'.format(
+                    tracker_filename))
+                sys.exit()
+
+    assert iteration > 0 or release, 'error parsing metadata file {}'.format(
+        tracker_filename)
+
+    # Checkpoint.
+    checkpoint_name_rank0 = get_fmoe_checkpoint_name(
+        load_dir, iteration, release, 0)
+    checkpoint_name_local = get_fmoe_checkpoint_name(
+        load_dir, iteration, release, mpu.get_data_parallel_rank())
+    print_rank_last(' loading checkpoint at rank 0 from {} and rank {} from {} '
+                    'at iteration {}, will merge them later'.format(
+                        checkpoint_name_rank0, mpu.get_data_parallel_rank(),
+                        checkpoint_name_local, iteration))
+
+    # Load the checkpoint.
+    def load_state_dict(checkpoint_name):
+        try:
+            state_dict = torch.load(checkpoint_name, map_location='cpu')
+        except ModuleNotFoundError:
+            from megatron.fp16_deprecated import loss_scaler
+            # For backward compatibility.
+            print_rank_last(' > deserializing using the old code structure ...')
+            sys.modules['fp16.loss_scaler'] = sys.modules[
+                'megatron.fp16_deprecated.loss_scaler']
+            sys.modules['megatron.fp16.loss_scaler'] = sys.modules[
+                'megatron.fp16_deprecated.loss_scaler']
+            state_dict = torch.load(checkpoint_name, map_location='cpu')
+            sys.modules.pop('fp16.loss_scaler', None)
+            sys.modules.pop('megatron.fp16.loss_scaler', None)
+        except BaseException:
+            print_rank_last('could not load the checkpoint')
+            sys.exit()
+        return state_dict
+
+    state_dict_rank0 = load_state_dict(checkpoint_name_rank0)
+    state_dict_local = load_state_dict(checkpoint_name_local)
+
+    state_dict = merge_state_dict(state_dict_rank0, state_dict_local)
+
+    # set checkpoint version
+    set_checkpoint_version(state_dict.get('checkpoint_version', 0))
+
+    # Set iteration.
+    if args.finetune or release:
+        iteration = 0
+    else:
+        try:
+            iteration = state_dict['iteration']
+        except KeyError:
+            try:  # Backward compatible with older checkpoints
+                iteration = state_dict['total_iters']
+            except KeyError:
+                print_rank_last('A metadata file exists but unable to load '
+                                'iteration from checkpoint {}, exiting'.format(
+                                    checkpoint_name_local))
+                sys.exit()
+
+    # Check arguments.
+    assert args.consumed_train_samples == 0
+    assert args.consumed_valid_samples == 0
+    if 'args' in state_dict:
+        checkpoint_args = state_dict['args']
+        check_checkpoint_args(checkpoint_args)
+        args.consumed_train_samples = getattr(checkpoint_args,
+                                              'consumed_train_samples', 0)
+        update_num_microbatches(consumed_samples=args.consumed_train_samples)
+        args.consumed_valid_samples = getattr(checkpoint_args,
+                                              'consumed_valid_samples', 0)
+    else:
+        print_rank_last('could not find arguments in the checkpoint ...')
+
+    # Model.
+    model.load_state_dict(state_dict['model'])
+
+    # Optimizer.
+    if not release and not args.finetune and not args.no_load_optim:
+        try:
+            if optimizer is not None:
+                optimizer.load_state_dict(state_dict['optimizer'])
+            if lr_scheduler is not None:
+                lr_scheduler.load_state_dict(state_dict['lr_scheduler'])
+        except KeyError:
+            print_rank_last('Unable to load optimizer from checkpoint {}. '
+                            'Specify --no-load-optim or --finetune to prevent '
+                            'attempting to load the optimizer state, '
+                            'exiting ...'.format(checkpoint_name_local))
+            sys.exit()
+
+    # rng states.
+    if not release and not args.finetune and not args.no_load_rng:
+        try:
+            random.setstate(state_dict['random_rng_state'])
+            np.random.set_state(state_dict['np_rng_state'])
+            torch.set_rng_state(state_dict['torch_rng_state'])
+            torch.cuda.set_rng_state(state_dict['cuda_rng_state'])
+            mpu.get_cuda_rng_tracker().set_states(
+                state_dict['rng_tracker_states'])
+        except KeyError:
+            print_rank_last('Unable to load rng state from checkpoint {}. '
+                            'Specify --no-load-rng or --finetune to prevent '
+                            'attempting to load the rng state, '
+                            'exiting ...'.format(checkpoint_name_local))
+            sys.exit()
+
+    torch.distributed.barrier()
+    print_rank_last('  successfully loaded checkpoint (with expert parameters '
+                    'updated) from {} at iteration {}'.format(
+                        args.load, iteration))
+
+    return iteration
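
To close the loop, a hedged sketch of how this save/load pair is meant to be driven from a Megatron training script. The import path below is an assumption for illustration; see examples/megatron in this repository for the supported wiring:

    # Assumed import path; the real entry point may differ.
    from fmoe.megatron import save_checkpoint, load_checkpoint

    # Every data parallel rank calls both functions. On load, rank 0 follows
    # the native Megatron path, while ranks > 0 read both the rank-0 file and
    # their own dp_rank file and merge them, so each rank recovers its expert
    # weights and optimizer state.
    iteration = load_checkpoint(model, optimizer, lr_scheduler)  # 0 if nothing to resume
    # ... train ...
    save_checkpoint(iteration, model, optimizer, lr_scheduler)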