Commit 49a4678c authored by Jiezhong Qiu

save experts separately in each data parallel rank

parent 89d6c794
@@ -3,8 +3,11 @@ The adaptor to seamlessly enable FastMoE in Megatron-LM v2.0 with at most two
lines of modification.
See `examples/megatron` for usage instructions.
"""
import os
import math
import numpy as np
import random
from collections import OrderedDict
import torch
import torch.nn as nn
import torch.nn.functional as F
@@ -361,3 +364,105 @@ class DistributedDataParallel(DistributedGroupedDataParallel):
        Keep consistency with Megatron
        """
        return self.module.load_state_dict(*args, **kwargs)
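

# Override of Megatron-LM's get_checkpoint_name: the checkpoint directory
# additionally encodes the data parallel rank, so that every data parallel
# rank can save its expert parameters to its own file.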
def get_checkpoint_name(checkpoints_path, iteration,
                        release=False):
    """A unified checkpoint name."""
    from megatron import mpu
    if release:
        directory = 'release'
    else:
        directory = 'iter_{:07d}'.format(iteration)
    # Use both the tensor and pipeline MP rank.
    if mpu.get_pipeline_model_parallel_world_size() == 1:
        return os.path.join(checkpoints_path, directory,
                            'mp_rank_{:02d}_dp_rank_{:04d}'.format(
                                mpu.get_tensor_model_parallel_rank(),
                                mpu.get_data_parallel_rank()
                            ),
                            'model_optim_rng.pt')
    return os.path.join(checkpoints_path, directory,
                        'mp_rank_{:02d}_{:03d}_dp_rank_{:04d}'.format(
                            mpu.get_tensor_model_parallel_rank(),
                            mpu.get_pipeline_model_parallel_rank(),
                            mpu.get_data_parallel_rank()
                        ),
                        'model_optim_rng.pt')
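
# For example, with pipeline parallelism disabled, iteration 1000 on tensor
# parallel rank 0 and data parallel rank 2 is written to
#   <save>/iter_0001000/mp_rank_00_dp_rank_0002/model_optim_rng.pt
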
def save_checkpoint(iteration, model, optimizer, lr_scheduler):
    """Save a model checkpoint with expert parallelism."""
    from megatron import get_args
    from megatron import mpu
    args = get_args()

    # Unlike stock Megatron-LM, every data parallel rank writes to disk:
    # rank 0 saves the full model, while the other ranks save only their
    # own expert parameters.
    if isinstance(model, DistributedDataParallel):
        model = model.module

    if torch.distributed.get_rank() == 0:
        print('saving checkpoint at iteration {:7d} to {}'.format(
            iteration, args.save), flush=True)

    data_parallel_rank = mpu.get_data_parallel_rank()

    # Arguments, iteration, and model.
    state_dict = {}
    state_dict['args'] = args
    state_dict['checkpoint_version'] = 3.0
    state_dict['iteration'] = iteration
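    # On non-zero data parallel ranks the state dict is fetched with
    # keep_vars=True, so that the parameter objects (and the dp_comm
    # attribute FastMoE attaches to them) survive and can be inspected
    # below when filtering out everything except the expert weights.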
    keep_vars = (data_parallel_rank != 0)
    state_dict['model'] = model.state_dict_for_save_checkpoint(keep_vars=keep_vars)
    if data_parallel_rank != 0:
        def extract_expert_param(state_dict, expert_dp_comm='none'):
            state_dict_new = state_dict.__class__()
            for k, v in state_dict.items():
                # megatron uses both dict and OrderedDict in its state_dict
                if isinstance(v, (OrderedDict, dict)):
                    v_new = extract_expert_param(v, expert_dp_comm)
                    if len(v_new) > 0:
                        state_dict_new[k] = v_new
                elif hasattr(v, 'dp_comm') and v.dp_comm == expert_dp_comm:
                    state_dict_new[k] = v.detach()
            return state_dict_new
        state_dict['model'] = extract_expert_param(state_dict['model'], 'none')
    # Optimizer stuff.
    if not args.no_save_optim:
        if optimizer is not None:
            state_dict['optimizer'] = optimizer.state_dict()
        if lr_scheduler is not None:
            state_dict['lr_scheduler'] = lr_scheduler.state_dict()

    # RNG states.
    if not args.no_save_rng:
        state_dict['random_rng_state'] = random.getstate()
        state_dict['np_rng_state'] = np.random.get_state()
        state_dict['torch_rng_state'] = torch.get_rng_state()
        state_dict['cuda_rng_state'] = torch.cuda.get_rng_state()
        state_dict['rng_tracker_states'] \
            = mpu.get_cuda_rng_tracker().get_states()

    # Save.
    checkpoint_name = get_checkpoint_name(args.save, iteration)
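    # checkpoint_name embeds the data parallel rank (see get_checkpoint_name
    # above), so concurrent writers on different data parallel ranks never
    # collide on the same file.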
    from megatron.checkpointing import ensure_directory_exists
    from megatron.checkpointing import get_checkpoint_tracker_filename
    ensure_directory_exists(checkpoint_name)
    torch.save(state_dict, checkpoint_name)

    # Wait so everyone is done (necessary).
    torch.distributed.barrier()

    if torch.distributed.get_rank() == 0:
        print('  successfully saved checkpoint at iteration {:7d} to {}'.format(
            iteration, args.save), flush=True)

    # And update the latest iteration.
    if torch.distributed.get_rank() == 0:
        tracker_filename = get_checkpoint_tracker_filename(args.save)
        with open(tracker_filename, 'w') as f:
            f.write(str(iteration))

    # Wait so everyone is done (not necessary).
    torch.distributed.barrier()
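
For context, here is a minimal, self-contained sketch of the expert filtering. The parameter names and dp_comm values below are illustrative assumptions; the 'none' tag matches the expert_dp_comm='none' default used in the diff above.

import torch
from collections import OrderedDict

def extract_expert_param(state_dict, expert_dp_comm='none'):
    # Standalone copy of the helper defined inside save_checkpoint above.
    state_dict_new = state_dict.__class__()
    for k, v in state_dict.items():
        if isinstance(v, (OrderedDict, dict)):
            v_new = extract_expert_param(v, expert_dp_comm)
            if len(v_new) > 0:
                state_dict_new[k] = v_new
        elif hasattr(v, 'dp_comm') and v.dp_comm == expert_dp_comm:
            state_dict_new[k] = v.detach()
    return state_dict_new

# Hypothetical model state dict: one replicated weight, one expert weight.
shared = torch.nn.Parameter(torch.zeros(4))
shared.dp_comm = 'dp'    # replicated across data parallel ranks; rank 0 saves it
expert = torch.nn.Parameter(torch.zeros(4))
expert.dp_comm = 'none'  # expert weight; every data parallel rank saves its own

state = OrderedDict(model=OrderedDict(
    [('shared.weight', shared), ('expert.weight', expert)]))

# Only the tensor tagged dp_comm == 'none' survives the filtering, which is
# what the non-zero data parallel ranks write to their per-rank checkpoint.
filtered = extract_expert_param(state, 'none')
assert list(filtered['model'].keys()) == ['expert.weight']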