Commit 49a4678c authored by Jiezhong Qiu

save experts separately in each data parallel rank

parent 89d6c794
@@ -3,8 +3,11 @@ The adaptor to seamlessly enable FastMoE in Megatron-LM v2.0 with at most two
lines of modification.
See `examples/megatron` for usage instructions.
"""
import os
import math
import numpy as np
import random
from collections import OrderedDict
import torch
import torch.nn as nn
import torch.nn.functional as F
@@ -361,3 +364,105 @@ class DistributedDataParallel(DistributedGroupedDataParallel):
        Keep consistency with Megatron
        """
        return self.module.load_state_dict(*args, **kwargs)

def get_checkpoint_name(checkpoints_path, iteration,
                        release=False):
    """A unified checkpoint name."""
    from megatron import mpu
    if release:
        directory = 'release'
    else:
        directory = 'iter_{:07d}'.format(iteration)
    # Use both the tensor and pipeline MP rank.
    if mpu.get_pipeline_model_parallel_world_size() == 1:
        return os.path.join(checkpoints_path, directory,
                            'mp_rank_{:02d}_dp_rank_{:04d}'.format(
                                mpu.get_tensor_model_parallel_rank(),
                                mpu.get_data_parallel_rank()
                            ),
                            'model_optim_rng.pt')
    return os.path.join(checkpoints_path, directory,
                        'mp_rank_{:02d}_{:03d}_dp_rank_{:04d}'.format(
                            mpu.get_tensor_model_parallel_rank(),
                            mpu.get_pipeline_model_parallel_rank(),
                            mpu.get_data_parallel_rank()
                        ),
                        'model_optim_rng.pt')
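
# Note: the values below are illustrative, not from the original commit. With a
# single pipeline stage, iteration 5000, tensor model parallel rank 1, and data
# parallel rank 3, get_checkpoint_name() returns a path of the form
#   <checkpoints_path>/iter_0005000/mp_rank_01_dp_rank_0003/model_optim_rng.pt
# so every data parallel rank writes into its own checkpoint directory instead
# of sharing one per model-parallel rank as in vanilla Megatron.
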
def save_checkpoint(iteration, model, optimizer, lr_scheduler):
    """Save a model checkpoint, with expert parameters saved separately on
    each data parallel rank."""
    from megatron import get_args
    from megatron import mpu
    args = get_args()

    # Unlike vanilla Megatron, every data parallel rank writes a checkpoint:
    # rank zero saves the full model, the other ranks save only their experts.
    if isinstance(model, DistributedDataParallel):
        model = model.module

    if torch.distributed.get_rank() == 0:
        print('saving checkpoint at iteration {:7d} to {}'.format(
            iteration, args.save), flush=True)

    data_parallel_rank = mpu.get_data_parallel_rank()

    # Arguments, iteration, and model.
    state_dict = {}
    state_dict['args'] = args
    state_dict['checkpoint_version'] = 3.0
    state_dict['iteration'] = iteration
    # Non-zero data parallel ranks keep the Parameter objects (keep_vars=True)
    # so that the dp_comm attribute is still available for expert filtering.
    keep_vars = data_parallel_rank != 0
    state_dict['model'] = model.state_dict_for_save_checkpoint(keep_vars=keep_vars)
    if data_parallel_rank != 0:
        def extract_expert_param(state_dict, expert_dp_comm='none'):
            state_dict_new = state_dict.__class__()
            for k, v in state_dict.items():
                # megatron uses both dict and OrderedDict in its state_dict
                if isinstance(v, OrderedDict) or isinstance(v, dict):
                    v_new = extract_expert_param(v, expert_dp_comm)
                    if len(v_new):
                        state_dict_new[k] = v_new
                elif hasattr(v, 'dp_comm') and v.dp_comm == expert_dp_comm:
                    state_dict_new[k] = v.detach()
            return state_dict_new
        state_dict['model'] = extract_expert_param(state_dict['model'], 'none')
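        # At this point, state_dict['model'] on non-zero data parallel ranks
        # keeps only tensors whose dp_comm attribute is 'none', i.e. the expert
        # parameters that differ from rank to rank and are therefore saved
        # separately by every data parallel rank.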

    # Optimizer stuff.
    if not args.no_save_optim:
        if optimizer is not None:
            state_dict['optimizer'] = optimizer.state_dict()
        if lr_scheduler is not None:
            state_dict['lr_scheduler'] = lr_scheduler.state_dict()

    # RNG states.
    if not args.no_save_rng:
        state_dict['random_rng_state'] = random.getstate()
        state_dict['np_rng_state'] = np.random.get_state()
        state_dict['torch_rng_state'] = torch.get_rng_state()
        state_dict['cuda_rng_state'] = torch.cuda.get_rng_state()
        state_dict['rng_tracker_states'] \
            = mpu.get_cuda_rng_tracker().get_states()
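        # Because every data parallel rank writes its own file here, each
        # rank's RNG state is preserved individually rather than only rank 0's.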

    # Save.
    checkpoint_name = get_checkpoint_name(args.save, iteration)
    from megatron.checkpointing import ensure_directory_exists
    from megatron.checkpointing import get_checkpoint_tracker_filename
    ensure_directory_exists(checkpoint_name)
    torch.save(state_dict, checkpoint_name)

    # Wait so everyone is done (necessary)
    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(' successfully saved checkpoint at iteration {:7d} to {}'.format(
            iteration, args.save), flush=True)

    # And update the latest iteration
    if torch.distributed.get_rank() == 0:
        tracker_filename = get_checkpoint_tracker_filename(args.save)
        with open(tracker_filename, 'w') as f:
            f.write(str(iteration))

    # Wait so everyone is done (not necessary)
    torch.distributed.barrier()