"examples/pytorch/git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "b2b531e0412f1207a8d17ea24cfbece77490053e"
Commit 6ca840b5 authored by Jiezhong Qiu

load checkpoint for expert parallel

parent 1c69da9c
@@ -4,6 +4,7 @@ lines of modification.
 See `examples/megatron` for usage instructions.
 """
 import os
+import sys
 import math
 import random
 from collections import OrderedDict
@@ -207,10 +208,15 @@ class DistributedDataParallel(DistributedGroupedDataParallel):
         """
         return self.module.load_state_dict(*args, **kwargs)
 
 
-def get_checkpoint_name(checkpoints_path, iteration,
-                        release=False):
-    """A unified checkpoint name."""
+def get_fmoe_checkpoint_name(checkpoints_path, iteration,
+                             release=False, data_parallel_rank=-1):
+    """A unified checkpoint name, allowing specifying a data parallel rank"""
     from megatron import mpu
+    from megatron.checkpointing import get_checkpoint_name
+    if data_parallel_rank == -1:
+        data_parallel_rank = mpu.get_data_parallel_rank()
+    if data_parallel_rank == 0:
+        return get_checkpoint_name(checkpoints_path, iteration, release)
     if release:
         directory = 'release'
@@ -221,14 +227,14 @@ def get_checkpoint_name(checkpoints_path, iteration,
         return os.path.join(checkpoints_path, directory,
                             'mp_rank_{:02d}_dp_rank_{:04d}'.format(
                                 mpu.get_tensor_model_parallel_rank(),
-                                mpu.get_data_parallel_rank()
+                                data_parallel_rank
                             ),
                             'model_optim_rng.pt')
     return os.path.join(checkpoints_path, directory,
                         'mp_rank_{:02d}_{:03d}_dp_rank_{:04d}'.format(
                             mpu.get_tensor_model_parallel_rank(),
                             mpu.get_pipeline_model_parallel_rank(),
-                            mpu.get_data_parallel_rank()
+                            data_parallel_rank
                         ),
                         'model_optim_rng.pt')
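
For reference, the naming scheme above gives every non-zero data parallel rank its own checkpoint sub-directory (dp rank 0 falls back to Megatron's native name without the dp_rank suffix). Below is a minimal, standalone sketch of the resulting layout; the helper name and the example arguments are illustrative only, not part of the patch:

import os

def sketch_checkpoint_path(checkpoints_path, iteration, tp_rank, pp_rank,
                           dp_rank, pp_world_size=1, release=False):
    # Mirrors the naming scheme above: one directory per iteration (or 'release'),
    # then one sub-directory per (tensor, pipeline, data) parallel rank tuple.
    # (dp rank 0 would instead use Megatron's native, dp-rank-free name.)
    directory = 'release' if release else 'iter_{:07d}'.format(iteration)
    if pp_world_size == 1:
        sub = 'mp_rank_{:02d}_dp_rank_{:04d}'.format(tp_rank, dp_rank)
    else:
        sub = 'mp_rank_{:02d}_{:03d}_dp_rank_{:04d}'.format(tp_rank, pp_rank, dp_rank)
    return os.path.join(checkpoints_path, directory, sub, 'model_optim_rng.pt')

# Prints: checkpoints/iter_0001000/mp_rank_00_dp_rank_0002/model_optim_rng.pt
print(sketch_checkpoint_path('checkpoints', 1000, tp_rank=0, pp_rank=0, dp_rank=2))
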
@@ -238,8 +244,13 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler, expert_dp_comm='none'):
     from megatron import get_args
     from megatron import mpu
 
-    args = get_args()
+    if mpu.get_data_parallel_rank() == 0:
+        # at dp rank 0, we still follow the native save_checkpoint by megatron
+        from megatron.checkpointing import save_checkpoint as save_checkpoint_native
+        save_checkpoint_native(iteration, model, optimizer, lr_scheduler)
+        return
+
+    args = get_args()
 
     # Only rank zero of the data parallel writes to the disk.
     if isinstance(model, DistributedDataParallel):
@@ -257,29 +268,26 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler, expert_dp_comm='none'):
     state_dict['model'] = model.state_dict_for_save_checkpoint(
         keep_vars=(mpu.get_data_parallel_rank() > 0))
 
-    if mpu.get_data_parallel_rank() > 0:
-
-        def extract_expert_param(state_dict, expert_dp_comm='none'):
-            state_dict_new = state_dict.__class__()
-            for k, v in state_dict.items():
-                # megatron uses both dict and OrderedDict in its state_dict
-                if isinstance(v, (OrderedDict, dict)):
-                    v_new = extract_expert_param(v, expert_dp_comm)
-                    if len(v_new) > 0:
-                        state_dict_new[k] = v_new
-                elif hasattr(v, 'dp_comm') and v.dp_comm == expert_dp_comm:
-                    state_dict_new[k] = v.detach()
-            return state_dict_new
-
-        state_dict['model'] = extract_expert_param(
-            state_dict['model'],
-            expert_dp_comm)
+    def extract_expert_param(state_dict, expert_dp_comm='none'):
+        state_dict_new = state_dict.__class__()
+        for k, v in state_dict.items():
+            # megatron uses both dict and OrderedDict in its state_dict
+            if isinstance(v, (OrderedDict, dict)):
+                v_new = extract_expert_param(v, expert_dp_comm)
+                if len(v_new) > 0:
+                    state_dict_new[k] = v_new
+            elif hasattr(v, 'dp_comm') and v.dp_comm == expert_dp_comm:
+                state_dict_new[k] = v.detach()
+        return state_dict_new
+
+    state_dict['model'] = extract_expert_param(
+        state_dict['model'],
+        expert_dp_comm)
 
     # Optimizer stuff.
     if not args.no_save_optim:
         if optimizer is not None:
             state_dict['optimizer'] = optimizer.state_dict()
-            if mpu.get_data_parallel_rank() > 0:
-                index = 0
-                for param_group in optimizer.optimizer.param_groups:
-                    for param in param_group['params']:
+            index = 0
+            for param_group in optimizer.optimizer.param_groups:
+                for param in param_group['params']:
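
The extract_expert_param helper above walks a possibly nested state dict and keeps only the tensors whose dp_comm attribute equals expert_dp_comm, which is how the non-zero data parallel ranks reduce their checkpoint to expert-only state. Here is a small self-contained illustration of that filtering idea; the toy parameters below are made up for demonstration:

from collections import OrderedDict

import torch

def filter_by_dp_comm(state_dict, expert_dp_comm='none'):
    # Same idea as extract_expert_param: recurse into nested (Ordered)dicts and
    # keep only tensors tagged with the requested dp_comm attribute.
    state_dict_new = state_dict.__class__()
    for k, v in state_dict.items():
        if isinstance(v, (OrderedDict, dict)):
            sub = filter_by_dp_comm(v, expert_dp_comm)
            if len(sub) > 0:
                state_dict_new[k] = sub
        elif getattr(v, 'dp_comm', None) == expert_dp_comm:
            state_dict_new[k] = v.detach()
    return state_dict_new

# Toy example: only 'experts.weight' carries the dp_comm tag that FastMoE
# attaches to expert parameters; the dense weight is an ordinary parameter.
expert_w = torch.nn.Parameter(torch.zeros(2, 2))
expert_w.dp_comm = 'none'
dense_w = torch.nn.Parameter(torch.zeros(2, 2))
toy = OrderedDict(model=OrderedDict([('experts.weight', expert_w),
                                     ('dense.weight', dense_w)]))
print(list(filter_by_dp_comm(toy)['model'].keys()))  # ['experts.weight']
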
@@ -304,7 +312,7 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler, expert_dp_comm='none'):
             = mpu.get_cuda_rng_tracker().get_states()
 
     # Save.
-    checkpoint_name = get_checkpoint_name(args.save, iteration)
+    checkpoint_name = get_fmoe_checkpoint_name(args.save, iteration)
     from megatron.checkpointing import ensure_directory_exists
     from megatron.checkpointing import get_checkpoint_tracker_filename
     ensure_directory_exists(checkpoint_name)
@@ -322,3 +330,182 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler, expert_dp_comm='none'):
         f.write(str(iteration))
 
     # Wait so everyone is done (not necessary)
     torch.distributed.barrier()
+
+
+def merge_state_dict(state_dict_rank0, state_dict_local):
+    """Merge two state dicts: one from data parallel rank 0,
+    the other containing only expert states."""
+    from megatron import print_rank_last
+
+    def merge_model(state_dict_rank0, state_dict_local):
+        for k, v in state_dict_local.items():
+            # megatron uses both dict and OrderedDict in its state_dict
+            if isinstance(v, (OrderedDict, dict)):
+                print_rank_last("[merge model] go recursively to {}".format(k))
+                merge_model(state_dict_rank0[k], v)
+            else:
+                before = state_dict_rank0[k].sum().item()
+                state_dict_rank0[k] = v
+                after = state_dict_rank0[k].sum().item()
+                print_rank_last("[merge model] copy parameter {}, "
+                                "before.sum={:7f}, after.sum={:7f}".format(
+                                    k, before, after))
+
+    merge_model(state_dict_rank0['model'], state_dict_local['model'])
+
+    for k, v in state_dict_local['optimizer']['state'].items():
+        before = {kk: vv.sum().item()
+                  for kk, vv in state_dict_rank0['optimizer']['state'][k].items()}
+        state_dict_rank0['optimizer']['state'][k] = v
+        after = {kk: vv.sum().item()
+                 for kk, vv in state_dict_rank0['optimizer']['state'][k].items()}
+        print_rank_last("[merge optimizer] copy {}, "
+                        "before.sum={}, after.sum={}".format(
+                            k, str(before), str(after)))
+
+    return state_dict_rank0
+
+
+def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load'):
+    """Load a model checkpoint and return the iteration."""
+    from megatron import get_args
+    from megatron import mpu
+    from megatron import print_rank_last
+    from megatron.checkpointing import get_checkpoint_tracker_filename, \
+        set_checkpoint_version, check_checkpoint_args, update_num_microbatches
+
+    if mpu.get_data_parallel_rank() == 0:
+        # at dp rank 0, we still follow the native load_checkpoint by megatron
+        from megatron.checkpointing import load_checkpoint as load_checkpoint_native
+        return load_checkpoint_native(model, optimizer, lr_scheduler, load_arg)
+
+    args = get_args()
+    load_dir = getattr(args, load_arg)
+
+    if isinstance(model, DistributedDataParallel):
+        model = model.module
+
+    # Read the tracker file and set the iteration.
+    tracker_filename = get_checkpoint_tracker_filename(load_dir)
+
+    # If no tracker file, return iteration zero.
+    if not os.path.isfile(tracker_filename):
+        print_rank_last('WARNING: could not find the metadata file {} '.format(
+            tracker_filename))
+        print_rank_last('    will not load any checkpoints and will start from '
+                        'random')
+        return 0
+
+    # Otherwise, read the tracker file and either set the iteration or
+    # mark it as a release checkpoint.
+    iteration = 0
+    release = False
+    with open(tracker_filename, 'r') as f:
+        metastring = f.read().strip()
+        try:
+            iteration = int(metastring)
+        except ValueError:
+            release = metastring == 'release'
+            if not release:
+                print_rank_last('ERROR: Invalid metadata file {}. Exiting'.format(
+                    tracker_filename))
+                sys.exit()
+
+    assert iteration > 0 or release, 'error parsing metadata file {}'.format(
+        tracker_filename)
+
+    # Checkpoint.
+    checkpoint_name_rank0 = get_fmoe_checkpoint_name(
+        load_dir, iteration, release, 0)
+    checkpoint_name_local = get_fmoe_checkpoint_name(
+        load_dir, iteration, release, mpu.get_data_parallel_rank())
+    print_rank_last(' loading checkpoint at rank 0 from {} and rank {} from {} '
+                    'at iteration {}, will merge them later'.format(
+                        checkpoint_name_rank0, mpu.get_data_parallel_rank(),
+                        checkpoint_name_local, iteration))
+
+    # Load the checkpoint.
+    def load_state_dict(checkpoint_name):
+        try:
+            state_dict = torch.load(checkpoint_name, map_location='cpu')
+        except ModuleNotFoundError:
+            from megatron.fp16_deprecated import loss_scaler
+            # For backward compatibility.
+            print_rank_last(' > deserializing using the old code structure ...')
+            sys.modules['fp16.loss_scaler'] = sys.modules[
+                'megatron.fp16_deprecated.loss_scaler']
+            sys.modules['megatron.fp16.loss_scaler'] = sys.modules[
+                'megatron.fp16_deprecated.loss_scaler']
+            state_dict = torch.load(checkpoint_name, map_location='cpu')
+            sys.modules.pop('fp16.loss_scaler', None)
+            sys.modules.pop('megatron.fp16.loss_scaler', None)
+        except BaseException:
+            print_rank_last('could not load the checkpoint')
+            sys.exit()
+        return state_dict
+
+    state_dict_rank0 = load_state_dict(checkpoint_name_rank0)
+    state_dict_local = load_state_dict(checkpoint_name_local)
+
+    state_dict = merge_state_dict(state_dict_rank0, state_dict_local)
+
+    # Set checkpoint version.
+    set_checkpoint_version(state_dict.get('checkpoint_version', 0))
+
+    # Set iteration.
+    if args.finetune or release:
+        iteration = 0
+    else:
+        try:
+            iteration = state_dict['iteration']
+        except KeyError:
+            try:  # Backward compatible with older checkpoints
+                iteration = state_dict['total_iters']
+            except KeyError:
+                print_rank_last('A metadata file exists but unable to load '
+                                'iteration from checkpoint {}, exiting'.format(
+                                    checkpoint_name_local))
+                sys.exit()
+
+    # Check arguments.
+    assert args.consumed_train_samples == 0
+    assert args.consumed_valid_samples == 0
+    if 'args' in state_dict:
+        checkpoint_args = state_dict['args']
+        check_checkpoint_args(checkpoint_args)
+        args.consumed_train_samples = getattr(checkpoint_args,
+                                              'consumed_train_samples', 0)
+        update_num_microbatches(consumed_samples=args.consumed_train_samples)
+        args.consumed_valid_samples = getattr(checkpoint_args,
+                                              'consumed_valid_samples', 0)
+    else:
+        print_rank_last('could not find arguments in the checkpoint ...')
+
+    # Model.
+    model.load_state_dict(state_dict['model'])
+
+    # Optimizer.
+    if not release and not args.finetune and not args.no_load_optim:
+        try:
+            if optimizer is not None:
+                optimizer.load_state_dict(state_dict['optimizer'])
+            if lr_scheduler is not None:
+                lr_scheduler.load_state_dict(state_dict['lr_scheduler'])
+        except KeyError:
+            print_rank_last('Unable to load optimizer from checkpoint {}. '
+                            'Specify --no-load-optim or --finetune to prevent '
+                            'attempting to load the optimizer state, '
+                            'exiting ...'.format(checkpoint_name_local))
+            sys.exit()
+
+    # rng states.
+    if not release and not args.finetune and not args.no_load_rng:
+        try:
+            random.setstate(state_dict['random_rng_state'])
+            np.random.set_state(state_dict['np_rng_state'])
+            torch.set_rng_state(state_dict['torch_rng_state'])
+            torch.cuda.set_rng_state(state_dict['cuda_rng_state'])
+            mpu.get_cuda_rng_tracker().set_states(
+                state_dict['rng_tracker_states'])
+        except KeyError:
+            print_rank_last('Unable to load rng states from checkpoint {}. '
+                            'Specify --no-load-rng or --finetune to prevent '
+                            'attempting to load the rng states, '
+                            'exiting ...'.format(checkpoint_name_local))
+            sys.exit()
+
+    torch.distributed.barrier()
+    print_rank_last('  successfully loaded checkpoint (with expert parameters '
+                    'updated) from {} at iteration {}'.format(
+                        args.load, iteration))
+
+    return iteration
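
Taken together, the patched pair is meant to replace Megatron's native save/load calls in a training script: dp rank 0 keeps writing and reading the full native checkpoint, while every other data parallel rank writes only its expert state and, on resume, merges that expert-only file into rank 0's checkpoint before restoring. The sketch below shows that intended flow; the import path and the args fields are assumptions about the surrounding training code, not something this commit defines.

# A rough usage sketch, assuming the module is importable as shown and that
# Megatron's argument parsing and model/optimizer setup have already run.
from fmoe.megatron.checkpoint import save_checkpoint, load_checkpoint

def train(model, optimizer, lr_scheduler, args):
    # Resume: dp rank 0 loads the full native checkpoint; every other data
    # parallel rank loads rank 0's file plus its own expert-only file and
    # merges the two before model.load_state_dict().
    iteration = load_checkpoint(model, optimizer, lr_scheduler, load_arg='load')

    while iteration < args.train_iters:
        # ... forward / backward / optimizer step ...
        iteration += 1
        if args.save and iteration % args.save_interval == 0:
            # dp rank 0 writes the usual Megatron checkpoint; the other dp
            # ranks write only expert parameters and their optimizer state.
            save_checkpoint(iteration, model, optimizer, lr_scheduler)
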