Commit 0f091a1d authored by Sugon_ldc's avatar Sugon_ldc
Browse files

add fastmoe project

parents
Pipeline #263 failed with stages
in 0 seconds
r"""
Supportive modules to conduct distributed training
"""
import torch
import torch.nn as nn
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from .utils import get_torch_default_comm
class DistributedGroupedDataParallel(nn.Module):
r"""
A customized DDP module to support different all-reduce regions in the
model. The all-reduce region is defined as an attribution `dp_comm` in the
weight object.
The grads of the weights are identified to be reduced in different groups
according to the weigths' `dp_comm` attribute.
If it is set to `dp`, it will only be reduced across the data-parallel
groups, which means that in the model parallel group, they are not
synchronized.
If it is set to `world`, the gradients is synchronized across all workers,
regardless their model or data parallel group. This is extremely useful for
shared layers like the gate.
"""
def __init__(
self,
module,
auto_allreduce=False,
**kwargs
):
assert not auto_allreduce, "Automatic all-reduce is not implemented yet"
super().__init__()
self.module = module
self.comms = dict()
for k in kwargs:
if k.endswith('_group'):
self.comms[k[:-6]] = kwargs[k]
for k in ['dp', 'gate', 'moe', 'world']:
if k not in self.comms:
self.comms[k] = get_torch_default_comm()
def allreduce_params(no_scale=False,
reduce_after=False, fp32_allreduce=False):
groups = dict()
for p in self.module.parameters():
if not p.requires_grad or p.grad is None:
continue
if hasattr(p, "dp_comm"):
dp_comm = p.dp_comm
else:
dp_comm = "dp"
group_key = (dp_comm, p.dtype)
if group_key not in groups:
groups[group_key] = [p]
else:
groups[group_key].append(p)
for (dp_comm, dtype), group in groups.items():
if dp_comm not in self.comms:
continue
comm = self.comms[dp_comm]
grads = [p.grad.data for p in group]
coalesced = _flatten_dense_tensors(grads)
if fp32_allreduce and dtype != torch.float32:
coalesced = coalesced.float()
if not no_scale and not reduce_after:
coalesced /= comm.size()
torch.distributed.all_reduce(coalesced, group=comm)
torch.cuda.synchronize()
if not no_scale and reduce_after:
coalesced /= comm.size()
synced = _unflatten_dense_tensors(coalesced, grads)
for g, s in zip(grads, synced):
g.copy_(s)
self.allreduce_params = allreduce_params
self._sync_params()
def _sync_params(self):
groups = dict()
for p in self.module.parameters():
if not p.requires_grad or p.grad is None:
continue
if hasattr(p, "dp_comm"):
dp_comm = p.dp_comm
else:
dp_comm = "dp"
group_key = (dp_comm, p.dtype)
if group_key not in groups:
groups[group_key] = [p]
else:
groups[group_key].append(p)
for (dp_comm, _), group in groups.items():
if dp_comm not in self.comms:
continue
comm = self.comms[dp_comm]
datas = [p.data for p in group]
coalesced = _flatten_dense_tensors(datas)
torch.distributed.broadcast(coalesced, 0, group=comm)
torch.cuda.synchronize()
synced = _unflatten_dense_tensors(coalesced, datas)
for d, s in zip(datas, synced):
d.copy_(s)
def forward(self, *args, **kwargs):
r"""
Directly call the module's forward function.
"""
return self.module(*args, **kwargs)
r"""
The fmoe.functions module contains functions that are directly warped up from
C/CUDA functions to complete distributed communication, computation and gradient
computation.
"""
import torch
from torch.autograd import Function
import fmoe_cuda
from .utils import get_torch_default_comm
def ensure_comm(t, comm):
if comm is None:
comm = get_torch_default_comm()
fmoe_cuda.ensure_nccl(comm, t)
def count_by_gate(gate, num_expert, world_size, require_pos=True):
with torch.no_grad():
local_expert_count = torch.zeros(
num_expert * world_size, device=gate.device, dtype=torch.int32
)
fmoe_cuda.expert_count(gate, local_expert_count)
local_expert_count = local_expert_count.long()
if world_size > 1:
global_expert_count = fmoe_cuda.expert_exchange(
local_expert_count, num_expert, world_size
)
else:
global_expert_count = local_expert_count
if not require_pos:
pos = None
else:
lec_cum = torch.cumsum(local_expert_count, dim=0).int()
pos_size = lec_cum[-1].item()
pos = torch.empty((pos_size,), device=gate.device, dtype=torch.long)
fmoe_cuda.assign_pos(lec_cum, gate, pos)
return pos, local_expert_count, global_expert_count
def prepare_forward(gate, num_expert, world_size):
r"""
Prepare necessary information from gate output for MoE computation.
Args:
gate: a 1-d Long Tensor representing the target expert of each input
sample.
num_expert: number of experts on each worker.
world_size: number of workers that hold different experts.
comm: the communicator of all workers in the expert-parallel group.
"""
pos, local_expert_count, global_expert_count = count_by_gate(gate,
num_expert, world_size)
with torch.no_grad():
fwd_expert_count = global_expert_count.view(world_size,
num_expert).sum(dim=0)
fwd_batch_size = int(fwd_expert_count.sum().item())
return (
pos,
local_expert_count.cpu(),
global_expert_count.cpu(),
fwd_expert_count.cpu(),
fwd_batch_size,
)
def _local_scatter(inp, pos):
inp_buf = torch.index_select(inp, 0, pos)
return inp_buf
def _local_gather(inp, pos, out_batch_size, maybe_overlap=True):
inp_buf = torch.zeros(out_batch_size, inp.shape[-1],
dtype=inp.dtype, device=inp.device)
if maybe_overlap:
inp_buf.index_add_(0, pos, inp)
else:
inp_buf.index_copy_(0, pos, inp)
return inp_buf
class MOEScatter(Function):
r"""
Scatter input samples from [batch x sequences] to contiguous alone experts.
If `world_size` is greater than 1, the samples will first be locally
scattered, and then exchanged across workers.
"""
@staticmethod
def forward(
ctx,
inp,
pos,
local_expert_count,
global_expert_count,
fwd_batch_size,
world_size,
):
local_input_buf = _local_scatter(inp, pos)
if world_size > 1:
global_input_buf = fmoe_cuda.global_scatter(
local_input_buf,
local_expert_count,
global_expert_count,
fwd_batch_size,
world_size,
)
else:
global_input_buf = local_input_buf
ctx.moe_args = inp.shape[0], pos.shape[0], world_size
variables = (pos, local_expert_count, global_expert_count)
ctx.save_for_backward(*variables)
return global_input_buf
@staticmethod
def backward(ctx, global_grad_in):
(pos, local_expert_count, global_expert_count) = ctx.saved_tensors
(inp_batch_size, buf_batch_size, world_size) = ctx.moe_args
if world_size > 1:
local_grad_in = fmoe_cuda.global_gather(
global_grad_in,
local_expert_count,
global_expert_count,
buf_batch_size,
world_size,
)
else:
local_grad_in = global_grad_in
grad_in = _local_gather(local_grad_in, pos, inp_batch_size)
return grad_in, None, None, None, None, None
class MOEGather(Function):
r"""
Gather output samples from contiguous alone experts back to [batch x
sequences]. Works symmetrically with MOEScatter.
"""
@staticmethod
def forward(
ctx,
global_output_buf,
pos,
local_expert_count,
global_expert_count,
local_batch_size,
world_size,
):
if world_size > 1:
local_output_buf = fmoe_cuda.global_gather(
global_output_buf,
local_expert_count,
global_expert_count,
pos.shape[0],
world_size,
)
else:
local_output_buf = global_output_buf
output = _local_gather(local_output_buf, pos, local_batch_size,
maybe_overlap=False)
ctx.moe_args = (global_output_buf.shape[0], world_size)
variables = (pos, local_expert_count, global_expert_count)
ctx.save_for_backward(*variables)
return output
@staticmethod
def backward(ctx, grad_out):
pos, local_expert_count, global_expert_count = ctx.saved_tensors
fwd_batch_size, world_size = ctx.moe_args
grad_out_buf = _local_scatter(grad_out.contiguous(), pos)
if world_size > 1:
global_grad_out_buf = fmoe_cuda.global_scatter(
grad_out_buf,
local_expert_count,
global_expert_count,
fwd_batch_size,
world_size,
)
else:
global_grad_out_buf = grad_out_buf
return global_grad_out_buf, None, None, None, None, None
class AllGather(Function):
r"""
A wrapper for the All-Gather function to support auto-differentiation.
"""
@staticmethod
def forward(ctx, inp, rank, world_size, group):
tensor_list = [torch.empty_like(inp) for _ in range(world_size)]
torch.distributed.all_gather(tensor_list, inp, group=group)
torch.cuda.synchronize()
output = torch.cat(tensor_list, dim=0)
ctx.args = rank, inp.shape[0]
return output
@staticmethod
def backward(ctx, grad_out):
rank, dim0 = ctx.args
return grad_out[rank * dim0 : (rank + 1) * dim0], None, None, None
class Slice(Function):
r"""
A wrapper for the Slice function to support auto-differentiation.
"""
@staticmethod
def forward(ctx, inp, rank, world_size, group):
B: int = inp.shape[0]
local_batch_size = B // world_size
batch_start = local_batch_size * rank
batch_end = min(batch_start + local_batch_size, B)
inp = inp[batch_start:batch_end]
ctx.args = world_size, group
return inp
@staticmethod
def backward(ctx, grad_out):
world_size, group = ctx.args
tensor_list = [torch.empty_like(grad_out) for _ in range(world_size)]
torch.distributed.all_gather(tensor_list, grad_out, group=group)
torch.cuda.synchronize()
grad_out = torch.cat(tensor_list, dim=0)
return grad_out, None, None, None
r"""
Different implementations of the Gate are located in separate files here.
"""
from .zero_gate import ZeroGate
from .naive_gate import NaiveGate
from .noisy_gate import NoisyGate
from .gshard_gate import GShardGate
from .switch_gate import SwitchGate
from .swipe_gate import SwipeGate
r"""
Base gate with standard interface
"""
import torch.nn as nn
class BaseGate(nn.Module):
def __init__(self, num_expert, world_size):
super().__init__()
self.world_size = world_size
self.num_expert = num_expert
self.tot_expert = world_size * num_expert
self.loss = None
def forward(self, x):
raise NotImplementedError('Base gate cannot be directly used for fwd')
def set_loss(self, loss):
self.loss = loss
def get_loss(self, clear=True):
loss = self.loss
if clear:
self.loss = None
return loss
@property
def has_loss(self):
return self.loss is not None
r"""
Balanced gate with GShard's policy (Google, 2020)
"""
import math
import torch
import torch.nn.functional as F
from .naive_gate import NaiveGate
from .utils import limit_by_capacity
class GShardGate(NaiveGate):
def __init__(self, d_model, num_expert, world_size,
topk=2, capacity=(1.2, 2.4), random_routing=True):
assert topk == 2, 'topk should be 2 in gshard'
super().__init__(d_model, num_expert, world_size, top_k=2)
self.capacity = capacity
self.random_routing = True
def forward(self, x):
naive_outs = super().forward(x, return_all_scores=True)
topk_idx, topk_val, gate_score = naive_outs
S = gate_score.shape[0]
top_k = topk_idx.shape[0] // gate_score.shape[0]
top1_idx = topk_idx.view((-1, top_k))[:, 0]
c_e = torch.scatter_add(
torch.zeros(self.tot_expert, device=top1_idx.device),
0,
top1_idx,
torch.ones_like(top1_idx, dtype=torch.float),
) / S
m_e = torch.mean(F.softmax(gate_score, dim=1), dim=0)
loss = torch.mean(c_e * m_e) * (self.num_expert ** 2)
self.set_loss(loss)
cap_rate = self.capacity[0 if self.training else 1]
capacity = math.ceil(cap_rate * x.shape[0])
_new_lec, _new_gec, topk_idx = limit_by_capacity(
topk_idx, self.num_expert, self.world_size, capacity)
if self.random_routing:
rand_routing_prob = torch.rand(gate_score.size(0), device=x.device)
mask = (2 * topk_val[:, 1] < rand_routing_prob)
topk_idx[:, 1].masked_fill_(mask, -1)
return topk_idx, topk_val
r"""
Naive gate
"""
from .base_gate import BaseGate
import torch
import torch.nn as nn
import torch.nn.functional as F
class NaiveGate(BaseGate):
r"""
A naive gate implementation that defines the standard behavior of the gate
which determines which experts the tokens are going to.
Both the indicies and the score, or confidence, are output to the parent
module.
The load-balance strategies are also designed to be implemented within the
`Gate` module.
"""
def __init__(self, d_model, num_expert, world_size, top_k=2):
super().__init__(num_expert, world_size)
self.gate = nn.Linear(d_model, self.tot_expert)
self.top_k = top_k
def forward(self, inp, return_all_scores=False):
r"""
The naive implementation simply calculates the top-k of a linear layer's
output.
"""
gate = self.gate(inp)
gate_top_k_val, gate_top_k_idx = torch.topk(
gate, k=self.top_k, dim=-1, largest=True, sorted=False
) # [.. x top_k]
gate_top_k_val = gate_top_k_val.view(-1, self.top_k)
# (BxL) x 1 x top_k
gate_score = F.softmax(gate_top_k_val, dim=-1)
if return_all_scores:
return gate_top_k_idx, gate_top_k_val, gate
return gate_top_k_idx, gate_top_k_val
r"""
Noisy gate for gshard and switch
"""
from .base_gate import BaseGate
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions.normal import Normal
import math
class NoisyGate(BaseGate):
def __init__(self, d_model, num_expert, world_size, top_k=2):
super().__init__(num_expert, world_size)
self.w_gate = nn.Parameter(
torch.zeros(d_model, self.tot_expert), requires_grad=True
)
self.w_noise = nn.Parameter(
torch.zeros(d_model, self.tot_expert), requires_grad=True
)
self.top_k = top_k
self.softplus = nn.Softplus()
self.softmax = nn.Softmax(1)
self.noise_epsilon = 1e-2
self.reset_parameters()
def reset_parameters(self):
# Approach is the same as in torch.nn.Linear
# https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/linear.py#L88
torch.nn.init.kaiming_uniform_(self.w_gate, a=math.sqrt(5))
torch.nn.init.kaiming_uniform_(self.w_noise, a=math.sqrt(5))
def _gates_to_load(self, gates):
"""Compute the true load per expert, given the gates.
The load is the number of examples for which the corresponding gate is >0.
Args:
gates: a `Tensor` of shape [batch_size, n]
Returns:
a float32 `Tensor` of shape [n]
"""
return (gates > 0).sum(0)
def _prob_in_top_k(
self, clean_values, noisy_values, noise_stddev, noisy_top_values
):
"""Helper function to NoisyTopKGating.
Computes the probability that value is in top k, given different random noise.
This gives us a way of backpropagating from a loss that balances the number
of times each expert is in the top k experts per example.
In the case of no noise, pass in None for noise_stddev, and the result will
not be differentiable.
Args:
clean_values: a `Tensor` of shape [batch, n].
noisy_values: a `Tensor` of shape [batch, n]. Equal to clean values plus
normally distributed noise with standard deviation noise_stddev.
noise_stddev: a `Tensor` of shape [batch, n], or None
noisy_top_values: a `Tensor` of shape [batch, m].
"values" Output of tf.top_k(noisy_top_values, m). m >= k+1
Returns:
a `Tensor` of shape [batch, n].
"""
batch = clean_values.size(0)
m = noisy_top_values.size(1)
top_values_flat = noisy_top_values.flatten()
threshold_positions_if_in = (
torch.arange(batch, device=clean_values.device) * m + self.top_k
)
threshold_if_in = torch.unsqueeze(
torch.gather(top_values_flat, 0, threshold_positions_if_in), 1
)
is_in = torch.gt(noisy_values, threshold_if_in)
threshold_positions_if_out = threshold_positions_if_in - 1
threshold_if_out = torch.unsqueeze(
torch.gather(top_values_flat, 0, threshold_positions_if_out), 1
)
# is each value currently in the top k.
normal = Normal(
torch.tensor([0.0], device=clean_values.device),
torch.tensor([1.0], device=clean_values.device),
)
prob_if_in = normal.cdf((clean_values - threshold_if_in) / noise_stddev)
prob_if_out = normal.cdf((clean_values - threshold_if_out) / noise_stddev)
prob = torch.where(is_in, prob_if_in, prob_if_out)
return prob
def cv_squared(self, x):
"""The squared coefficient of variation of a sample.
Useful as a loss to encourage a positive distribution to be more uniform.
Epsilons added for numerical stability.
Returns 0 for an empty Tensor.
Args:
x: a `Tensor`.
Returns:
a `Scalar`.
"""
eps = 1e-10
# if only num_expert = 1
if x.shape[0] == 1:
return torch.Tensor([0])
return x.float().var() / (x.float().mean() ** 2 + eps)
def forward(self, inp):
clean_logits = inp @ self.w_gate
raw_noise_stddev = inp @ self.w_noise
noise_stddev = (
self.softplus(raw_noise_stddev) + self.noise_epsilon
) * self.training
noisy_logits = clean_logits + (torch.randn_like(clean_logits) * noise_stddev)
logits = noisy_logits
# calculate topk + 1 that will be needed for the noisy gates
top_logits, top_indices = logits.topk(
min(self.top_k + 1, self.tot_expert), dim=1
)
top_k_logits = top_logits[:, : self.top_k]
top_k_indices = top_indices[:, : self.top_k]
top_k_gates = self.softmax(top_k_logits)
zeros = torch.zeros_like(logits, requires_grad=True)
gates = zeros.scatter(1, top_k_indices, top_k_gates)
if self.top_k < self.tot_expert:
load = (
self._prob_in_top_k(
clean_logits, noisy_logits, noise_stddev, top_logits
)
).sum(0)
else:
load = self._gates_to_load(gates)
importance = gates.sum(0)
loss = self.cv_squared(importance) + self.cv_squared(load)
self.set_loss(loss)
return (
top_k_indices.contiguous().view(-1),
top_k_gates.contiguous().unsqueeze(1),
)
r"""
Balanced gate using SWIPE algorithm
"""
import math
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from .naive_gate import NaiveGate
from fmoe.functions import count_by_gate
import fmoe_cuda as fmoe_native
class SwipeGate(NaiveGate):
def __init__(self, d_model, num_expert, world_size, top_k=2):
super().__init__(d_model, num_expert, world_size, top_k)
def swipe_once(self, idx, capacity, bias):
with torch.no_grad():
idx_new, capacity = fmoe_native.swipe_once(idx, capacity,
self.num_expert, self.world_size, bias)
idx_new = idx_new.to(idx.device)
return idx_new, capacity
def forward(self, inp):
score = self.gate(inp)
orig_score, orig_idx = torch.topk(score, k=self.top_k, dim=-1)
if not self.training:
topk_val = F.softmax(orig_score, dim=-1)
return orig_idx, topk_val
capacity = torch.scalar_tensor(inp.shape[0] * self.top_k,
dtype=torch.long)
topk_idxs = []
topk_vals = []
idx_x = torch.arange(inp.shape[0], device=inp.device)
for k in range(self.top_k):
idx, capacity = self.swipe_once(orig_idx[:, k], capacity,
k % self.num_expert)
topk_vals.append(score[idx_x, idx])
topk_idxs.append(idx)
topk_idx = torch.stack(topk_idxs).transpose(0, 1)
topk_val = torch.stack(topk_vals).transpose(0, 1)
topk_val = F.softmax(topk_val, dim=-1)
return topk_idx, topk_val
r"""
Balanced gate with Switch Transformer's policy (Google, 2021)
"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from .naive_gate import NaiveGate
from .utils import limit_by_capacity
class SwitchGate(NaiveGate):
r"""
A switch gate implementation
"""
def __init__(self, d_model, num_expert, world_size, topk=1,
switch_eps=.1, capacity=(1.2, 2.4)):
assert topk == 1, 'topk should be 1 in switch'
super().__init__(d_model, num_expert, world_size, top_k=1)
self.switch_eps = switch_eps
self.capacity = capacity
def forward(self, inp):
r"""
The switch firstly conduct softmax and then calculates the top-1
"""
score = self.gate(inp)
if self.training:
# random uniform number from [1-eps, 1+eps]
noise = torch.rand_like(score)
noise = noise * 2 * self.switch_eps + 1.0 - self.switch_eps
score += noise
# fp32 softmax for numerical stability
score = F.softmax(score.float(), dim=-1)
top1_score, top1_idx = torch.topk(
score, k=1, dim=-1, largest=True
) # [.. x top_k]
top1_score = top1_score.to(dtype=inp.dtype)
cap_rate = self.capacity[0 if self.training else 1]
capacity = math.ceil(cap_rate * inp.shape[0])
_new_lec, _new_gec, top1_idx = limit_by_capacity(
top1_idx, self.num_expert, self.world_size, capacity)
valid_idx = top1_idx[top1_idx > -1]
fraction_expert = torch.scatter_add(
torch.zeros(self.tot_expert, device=valid_idx.device),
0,
valid_idx,
torch.ones_like(valid_idx, dtype=torch.float),
) / valid_idx.numel()
prob_expert = score.sum(dim=0) / valid_idx.numel()
loss = (fraction_expert * prob_expert).sum() * self.tot_expert
self.set_loss(loss)
return top1_idx, top1_score
r"""
Utilities that may be used in the gates
"""
import torch
from fmoe.functions import count_by_gate
import fmoe_cuda as fmoe_native
def limit_by_capacity(topk_idx, num_expert, world_size, capacity):
with torch.no_grad():
capacity = torch.ones(num_expert, dtype=torch.int32,
device=topk_idx.device) * capacity
pos, lec, gec = count_by_gate(topk_idx, num_expert, world_size,
require_pos=False)
new_gec = fmoe_native.limit_by_capacity(gec, capacity,
num_expert, world_size)
if world_size > 1:
new_lec = fmoe_native.expert_exchange(new_gec, num_expert,
world_size)
else:
new_lec = new_gec
topk_idx = fmoe_native.prune_gate_by_capacity(topk_idx,
new_lec.to(torch.int32), num_expert, world_size)
return new_lec, new_gec, topk_idx
r"""
Zero gate that direct all input to gate 0
"""
from .base_gate import BaseGate
import torch
import torch.nn as nn
import torch.nn.functional as F
class ZeroGate(BaseGate):
r"""
Guide all input samples to gate 0.
"""
def __init__(self, _1, num_expert, world_size, top_k=2):
super().__init__(num_expert, world_size)
self.top_k = top_k
def forward(self, inp):
r"""
All output to expert 1
"""
idx = torch.zeros(
inp.shape[0] * self.top_k, dtype=torch.int64, device=inp.device
)
gate_score = (
torch.ones(inp.shape[0] * self.top_k, device=inp.device) / self.top_k
)
return idx, gate_score.reshape(-1, 1, self.top_k)
r"""
FMoE core layer
"""
import torch
import torch.nn as nn
from .functions import prepare_forward, ensure_comm
from .functions import MOEScatter, MOEGather
from .functions import AllGather, Slice
from .gates import NaiveGate
def mark_module_parallel_comm(module, comm):
r"""
Mark all parameters in `module` as doing data parallel in `comm`, where
`comm` may be one of `'world', 'dp', 'none'`.
"""
for p in module.parameters():
setattr(p, "dp_comm", comm)
def _fmoe_general_global_forward(inp, gate, expert_fn, num_expert, world_size):
r"""
A private function that performs the following steps to complete the MoE
computation.
* Count the number of tokens from each worker to each expert.
* Send the features to their target position so that input features to each
expert are contiguous in memory.
* Perform the forward computation of the experts using `expert_fn`
* Gather the output features of experts back, and reorder them as sentences.
Intermediate results like expert counts are hidden from users by this
function.
"""
(
pos,
local_expert_count,
global_expert_count,
fwd_expert_count,
fwd_batch_size,
) = prepare_forward(gate, num_expert, world_size)
topk = 1
if len(gate.shape) == 2:
topk = gate.shape[1]
x = MOEScatter.apply(
inp, pos // topk,
local_expert_count, global_expert_count, fwd_batch_size, world_size
)
x = expert_fn(x, fwd_expert_count)
out_batch_size = inp.shape[0]
if len(gate.shape) == 2:
out_batch_size *= gate.shape[1]
x = MOEGather.apply(
x, pos,
local_expert_count, global_expert_count,
out_batch_size, world_size
)
return x
class FMoE(nn.Module):
r"""
A general moe implementation that supports an arbitrary module as the
expert.
* `num_expert` stands for the number of experts on **each** worker.
* `world_size` stands for the total number of workers that contains
different experts.
* `slice_group` can be a torch's communication group, indicating that
specific model parallel is applied across the group, and workers in the
group hold the same copy of input feature, and requires the same copy of
the output. For each worker, FMoE only computes the output of a certain
slice of the input batch, and will all-gather the outputs after
computation.
* `top_k` stands for the number of experts each token is going to.
* `gate` is a gate class which can found in `fmoe.gates`.
* `expert` can be specified as a module class, it is used to generate
`num_expert` expert modules.
"""
def __init__(
self,
num_expert=32,
d_model=1024,
world_size=1,
mp_group=None, # being deprecated
slice_group=None,
moe_group=None,
top_k=2,
gate=NaiveGate,
expert=None,
gate_hook=None,
mask=None,
mask_dict=None,
):
super().__init__()
self.num_expert = num_expert
self.d_model = d_model
self.world_size = world_size
self.slice_group = slice_group
if mp_group is not None:
print('[Warning] mp_group is being deprecated')
self.slice_group = mp_group
if self.slice_group is None:
self.slice_size = 1
self.slice_rank = 0
else:
self.slice_size = self.slice_group.size()
self.slice_rank = self.slice_group.rank()
self.top_k = top_k
if type(expert) is list:
self.experts = nn.ModuleList([e(d_model) for e in expert])
self.experts_fused = False
self.num_expert = num_expert = len(expert)
elif expert is not None:
self.experts = nn.ModuleList([expert(d_model)
for _ in range(num_expert)])
self.experts_fused = False
else:
self.experts_fused = True
self.gate = gate(d_model, num_expert, world_size, top_k)
self.gate_hook = gate_hook
self.mask = mask
self.mask_dict = mask_dict
self.moe_group = moe_group
def expert_fn(self, inp, fwd_expert_count):
r"""
The default expert function which either calls the experts as a whole
or as separate experts.
"""
if self.experts_fused:
return self.experts(inp, fwd_expert_count)
outputs = []
base_idx = 0
for i in range(self.num_expert):
batch_size = fwd_expert_count[i].item()
inp_slice = inp[base_idx : base_idx + batch_size]
outputs.append(self.experts[i](inp_slice))
base_idx += batch_size
return torch.cat(outputs, dim=0)
def mark_parallel_comm(self, expert_dp_comm="none"):
r"""
Automatically mark the data parallel comms of the parameters within the
module. This can be typically called at the end of the __init__ function
in child classes.
"""
if self.experts is not None:
comm = expert_dp_comm
if isinstance(self.experts, list):
for e in self.experts:
mark_module_parallel_comm(e, comm)
else:
mark_module_parallel_comm(self.experts, comm)
mark_module_parallel_comm(self.gate, "gate")
def forward(self, inp):
r"""
The FMoE module first computes gate output, and then conduct MoE forward
according to the gate. The score of the selected gate given by the
expert is multiplied to the experts' output tensors as a weight.
"""
if self.world_size > 1:
ensure_comm(inp, self.moe_group)
if self.slice_size > 1:
inp = Slice.apply(inp, self.slice_rank,
self.slice_size, self.slice_group)
gate_top_k_idx, gate_score = self.gate(inp)
if self.gate_hook is not None:
self.gate_hook(gate_top_k_idx, gate_score, None)
# delete masked tensors
if self.mask is not None and self.mask_dict is not None:
mask = self.mask.view(-1)
# to: (BxL') x d_model
inp = inp[mask == 0, :]
gate_top_k_idx = gate_top_k_idx[mask == 0, :]
fwd = _fmoe_general_global_forward(
inp, gate_top_k_idx,
self.expert_fn, self.num_expert, self.world_size
)
# recover deleted tensors
if self.mask is not None and self.mask_dict is not None:
# to: (BxL') x top_k x d_model
fwd = fwd.view(-1, self.top_k, self.d_model)
# to: (BxL) x top_k x d_model
x = torch.zeros(mask.shape[0], self.top_k, self.d_model, device=fwd.device, dtype=fwd.dtype)
# recover
x[mask == 0] = fwd
for k, v in self.mask_dict.items():
x[mask == k] = v
else:
x = fwd.view(-1, self.top_k, self.d_model)
gate_score = gate_score.view(x.shape[0], 1, self.top_k)
x = torch.bmm(gate_score, x).reshape(-1, self.d_model)
if self.slice_size > 1:
x = AllGather.apply(x, self.slice_rank,
self.slice_size, self.slice_group)
return x
r"""
FMoE's parallel linear layer
"""
import torch
import torch.nn as nn
from torch.autograd import Function
import math
import fmoe_cuda
class MOELinear(Function):
r"""
Computes linear operators within one GPU on different experts simutaneously.
"""
@staticmethod
def forward(ctx, global_input_buf, fwd_expert_count, weight, bias=None):
global_output_buf = fmoe_cuda.linear_forward(
global_input_buf, fwd_expert_count, weight, bias
)
variables = (global_input_buf, fwd_expert_count, weight, bias)
ctx.save_for_backward(*variables)
return global_output_buf
@staticmethod
def backward(ctx, grad_out):
(input_buf, fwd_expert_count, weight, bias) = ctx.saved_tensors
grad_inp_buf, grad_weight, grad_bias = fmoe_cuda.linear_backward(
grad_out, input_buf, fwd_expert_count, weight, bias
)
if not torch.is_tensor(bias):
grad_bias = None
return grad_inp_buf, None, grad_weight, grad_bias
class FMoELinear(nn.Module):
r"""
A linear layer that contains multiple experts.
As multiple experts can be placed on the same worker, the computation can be
performed in parallel to increase the performance.
The FMoELinear module provides such function.
"""
def __init__(
self,
num_expert: int,
in_feat: int,
out_feat: int,
bias: bool = True,
rank: int = 0,
):
super().__init__()
self.num_expert = num_expert
self.in_feat = in_feat
self.out_feat = out_feat
self.rank = rank
self.weight = nn.Parameter(torch.Tensor(num_expert, out_feat, in_feat))
if bias:
self.bias = nn.Parameter(torch.zeros(num_expert, out_feat))
else:
self.register_parameter("bias", None)
self.reset_parameters()
def forward(self, inp, fwd_expert_count):
r"""
Call MOE function
"""
x = MOELinear.apply(inp, fwd_expert_count, self.weight, self.bias)
return x
def extra_repr(self) -> str:
return "num_expert={}, in_features={}, \
out_features={}, bias={}, rank={}".format(
self.num_expert,
self.in_feat,
self.out_feat,
self.bias is not None,
self.rank,
)
def reset_parameters(self):
# Approach is the same as in torch.nn.Linear
# https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/linear.py#L88
# bias is left to zero, similar as megatron
torch.nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
Part of our code in megatron.py is copied from NVIDIA's Megatron-LM
codebase with modification.
------------- LICENSE FOR NVIDIA Megatron-LM --------------
The following applies to all files unless otherwise noted:
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of NVIDIA CORPORATION nor the names of its
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
--
This repository also contains code from Hugging Face Inc., Google Research,
and Facebook (from their Fairseq project). Files from these
organizations have notices at the top of each file. Below are licenses
used in those files, as indicated.
------------- LICENSE FOR huggingface and Google Research code --------------
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
------------- LICENSE FOR Facebook Fairseq code --------------
MIT License
Copyright (c) Facebook, Inc. and its affiliates.
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
r"""
A set of modules to plugin into Megatron-LM with FastMoE
"""
from .utils import add_fmoe_args
from .layers import MegatronMLP
from .layers import fmoefy
from .checkpoint import save_checkpoint
from .checkpoint import load_checkpoint
from .distributed import DistributedDataParallel
from .balance import reset_gate_hook
from .balance import get_balance_profile
from .balance import generate_megatron_gate_hook
from .balance import add_balance_log
from .patch import patch_forward_step
from .patch import patch_model_provider
r"""
Support for monitoring loss in Megatron
"""
import torch
from fmoe.balance import reset_balance_profile
from fmoe.balance import update_balance_profile
from fmoe.utils import get_torch_default_comm
balance_dict = {}
num_layers = 0
def reset_gate_hook(_num_layers=None):
from megatron import get_args
global balance_dict, num_layers
if _num_layers is not None:
num_layers = _num_layers
reset_balance_profile(balance_dict, num_layers, get_args().balance_strategy)
def get_balance_profile():
global balance_dict
return balance_dict
def generate_megatron_gate_hook(layer_idx, num_expert_global):
from megatron import get_args
balance_strategy = get_args().balance_strategy
def megatron_gate_hook(gate_top_k_idx, gate_score_top_k, gate_context):
global balance_dict
update_balance_profile(
balance_dict,
gate_top_k_idx,
gate_score_top_k,
gate_context,
layer_idx,
num_expert_global,
balance_strategy,
)
return megatron_gate_hook
def add_balance_log(model, writer, iteration):
from megatron import is_last_rank
while hasattr(model, 'module'):
model = model.module
losses = [l.mlp.gate.get_loss(clear=True)
for l in model.language_model.transformer.layers
if l.mlp.gate.has_loss]
if len(losses) == 0:
return
balance_dict_tensor = torch.vstack(losses).detach()
world_group = get_torch_default_comm()
world_size = torch.distributed.get_world_size(group=world_group)
torch.distributed.all_reduce(balance_dict_tensor, group=world_group)
balance_dict_tensor /= world_size
if writer and is_last_rank():
for idx, metric_name in enumerate(balance_dict):
for layer_id, val in enumerate(balance_dict_tensor[idx]):
writer.add_scalar(
f"balance-{metric_name}/layer-{layer_id}", val.item(), iteration
)
writer.add_scalar(
f"balance-{metric_name}/all",
balance_dict_tensor[idx].mean().item(),
iteration,
)
r"""
Support for Megatron to enable saving parameters of different experts on
different ranks.
"""
import os
import sys
import random
from collections import OrderedDict
import numpy as np
import torch
def get_fmoe_checkpoint_name(
checkpoints_path, iteration, release=False, data_parallel_rank=-1
):
"""A unified checkpoint name, allowing specifying a data parallel rank"""
from megatron import mpu
from megatron.checkpointing import get_checkpoint_name
if data_parallel_rank == -1:
data_parallel_rank = mpu.get_data_parallel_rank()
if data_parallel_rank == 0:
return get_checkpoint_name(checkpoints_path, iteration, release)
if release:
directory = "release"
else:
directory = "iter_{:07d}".format(iteration)
# Use both the tensor and pipeline MP rank.
if mpu.get_pipeline_model_parallel_world_size() == 1:
return os.path.join(
checkpoints_path,
directory,
"mp_rank_{:02d}_dp_rank_{:04d}".format(
mpu.get_tensor_model_parallel_rank(), data_parallel_rank
),
"model_optim_rng.pt",
)
return os.path.join(
checkpoints_path,
directory,
"mp_rank_{:02d}_{:03d}_dp_rank_{:04d}".format(
mpu.get_tensor_model_parallel_rank(),
mpu.get_pipeline_model_parallel_rank(),
data_parallel_rank,
),
"model_optim_rng.pt",
)
def save_checkpoint(iteration, model, optimizer, lr_scheduler):
"""Save a model checkpoint with expert parallel """
# TODO: update patch
from megatron import get_args
from megatron import mpu
from megatron import print_rank_last
expert_dp_comm = "none"
if mpu.get_data_parallel_rank() == 0:
# at dp rank 0, we still follows the native load_checkpoint by megatron
from megatron.checkpointing import save_checkpoint as save_checkpoint_native
save_checkpoint_native(iteration, model, optimizer, lr_scheduler)
return
args = get_args()
# Only rank zero of the data parallel writes to the disk.
if hasattr(model, 'module'):
model = model.module
print_rank_last(
"saving checkpoint at iteration {:7d} to {}".format(iteration, args.save)
)
# Arguments, iteration, and model.
state_dict = {}
state_dict["model"] = model.state_dict_for_save_checkpoint(
keep_vars=(mpu.get_data_parallel_rank() > 0)
)
def extract_expert_param(state_dict, expert_dp_comm="none"):
state_dict_new = state_dict.__class__()
for k, v in state_dict.items():
# megatron uses both dict and OrderedDict in its state_dict
if isinstance(v, (OrderedDict, dict)):
v_new = extract_expert_param(v, expert_dp_comm)
if len(v_new) > 0:
state_dict_new[k] = v_new
elif hasattr(v, "dp_comm") and v.dp_comm == expert_dp_comm:
state_dict_new[k] = v.detach()
return state_dict_new
state_dict["model"] = extract_expert_param(state_dict["model"], expert_dp_comm)
# Optimizer stuff.
if not args.no_save_optim:
if optimizer is not None:
state_dict["optimizer"] = optimizer.state_dict()
param_global_idx = 0
for param_group in optimizer.optimizer.param_groups:
for param in param_group["params"]:
if not (
hasattr(param, "dp_comm") and param.dp_comm == expert_dp_comm
):
# this parameter is not an expert parameter
# thus there is no need to save its state in current rank
# since it has been saved by data parallel rank 0
if args.fp16:
# fp16 optimizer may have empty state due to overflow
state_dict["optimizer"]["optimizer"]["state"].pop(
param_global_idx, None
)
else:
state_dict["optimizer"]["state"].pop(param_global_idx)
param_global_idx += 1
if args.fp16:
state_dict["optimizer"]["optimizer"].pop("param_groups")
# fp32_from_fp16_params in state_dict is not a copy
# but a reference to optimizer.fp32_from_fp16_params,
# changing it in state_dict will change
# optimizer.fp32_from_fp16_params as well
# thus we create an empty fp32_from_fp16_params in state_dict
# and only insert expert parameters.
fp32_from_fp16_params = state_dict["optimizer"]["fp32_from_fp16_params"]
state_dict["optimizer"]["fp32_from_fp16_params"] = []
for param_group in fp32_from_fp16_params:
param_group_copy = []
for param in param_group:
param_copy = (
param
if hasattr(param, "dp_comm")
and param.dp_comm == expert_dp_comm
else None
)
param_group_copy.append(param_copy)
state_dict["optimizer"]["fp32_from_fp16_params"].append(
param_group_copy
)
else:
state_dict["optimizer"].pop("param_groups")
# Save.
checkpoint_name = get_fmoe_checkpoint_name(args.save, iteration)
from megatron.checkpointing import ensure_directory_exists
from megatron.checkpointing import get_checkpoint_tracker_filename
ensure_directory_exists(checkpoint_name)
torch.save(state_dict, checkpoint_name)
# Wait so everyone is done (necessary)
torch.distributed.barrier()
if torch.distributed.get_rank() == 0:
print(
" successfully saved checkpoint at iteration {:7d} to {}".format(
iteration, args.save
),
flush=True,
)
# And update the latest iteration
if torch.distributed.get_rank() == 0:
tracker_filename = get_checkpoint_tracker_filename(args.save)
with open(tracker_filename, "w") as f:
f.write(str(iteration))
# Wait so everyone is done (not necessary)
torch.distributed.barrier()
def merge_state_dict(state_dict_rank0, state_dict_local, fp16):
"""merge two state dicts, one from data parallel rank 0,
another only contains expert states"""
# from megatron import print_rank_last
def merge_model(state_dict_rank0, state_dict_local):
for k, v in state_dict_local.items():
# megatron uses both dict and OrderedDict in its state_dict
if isinstance(v, (OrderedDict, dict)):
merge_model(state_dict_rank0[k], v)
else:
state_dict_rank0[k] = v
merge_model(state_dict_rank0["model"], state_dict_local["model"])
optimizer_rank0 = (
state_dict_rank0["optimizer"]["optimizer"]
if fp16
else state_dict_rank0["optimizer"]
)
optimizer_local = (
state_dict_local["optimizer"]["optimizer"]
if fp16
else state_dict_local["optimizer"]
)
for k, v in optimizer_local["state"].items():
optimizer_rank0["state"][k] = v
if fp16:
for group_idx, param_group in enumerate(
state_dict_local["optimizer"]["fp32_from_fp16_params"]
):
for param_in_group_idx, param in enumerate(param_group):
if param is not None:
state_dict_rank0["optimizer"]["fp32_from_fp16_params"][group_idx][
param_in_group_idx
] = param
return state_dict_rank0
def load_checkpoint(model, optimizer, lr_scheduler, load_arg="load"):
"""Load a model checkpoint and return the iteration."""
from megatron import get_args
from megatron import mpu
from megatron import print_rank_last
from megatron.checkpointing import get_checkpoint_tracker_filename
from megatron.checkpointing import set_checkpoint_version
from megatron.checkpointing import check_checkpoint_args
from megatron.checkpointing import update_num_microbatches
if mpu.get_data_parallel_rank() == 0:
# at dp rank 0, we still follow the native load_checkpoint by megatron
from megatron.checkpointing import load_checkpoint as load_checkpoint_native
return load_checkpoint_native(model, optimizer, lr_scheduler, load_arg)
args = get_args()
load_dir = getattr(args, load_arg)
if hasattr(model, 'module'):
model = model.module
# Read the tracker file and set the iteration.
tracker_filename = get_checkpoint_tracker_filename(load_dir)
# If no tracker file, return iretation zero.
if not os.path.isfile(tracker_filename):
print_rank_last(
"WARNING: could not find the metadata file {} ".format(tracker_filename)
)
print_rank_last(
" will not load any checkpoints and will start from " "random"
)
return 0
# Otherwise, read the tracker file and either set the iteration or
# mark it as a release checkpoint.
iteration = 0
release = False
with open(tracker_filename, "r") as f:
metastring = f.read().strip()
try:
iteration = int(metastring)
except ValueError:
release = metastring == "release"
if not release:
print_rank_last(
"ERROR: Invalid metadata file {}. Exiting".format(tracker_filename)
)
sys.exit()
assert iteration > 0 or release, "error parsing metadata file {}".format(
tracker_filename
)
# Checkpoint.
checkpoint_name_rank0 = get_fmoe_checkpoint_name(load_dir, iteration, release, 0)
checkpoint_name_local = get_fmoe_checkpoint_name(
load_dir, iteration, release, mpu.get_data_parallel_rank()
)
print_rank_last(
" loading checkpoint at rank 0 from {} and rank {} from {} at iteration {}, will merge them later".format(
checkpoint_name_rank0,
mpu.get_data_parallel_rank(),
checkpoint_name_local,
iteration,
)
)
# Load the checkpoint.
def load_state_dict(checkpoint_name):
try:
state_dict = torch.load(checkpoint_name, map_location="cpu")
except ModuleNotFoundError:
from megatron.fp16_deprecated import loss_scaler
# For backward compatibility.
print_rank_last(" > deserializing using the old code structure ...")
sys.modules["fp16.loss_scaler"] = sys.modules[
"megatron.fp16_deprecated.loss_scaler"
]
sys.modules["megatron.fp16.loss_scaler"] = sys.modules[
"megatron.fp16_deprecated.loss_scaler"
]
state_dict = torch.load(checkpoint_name, map_location="cpu")
sys.modules.pop("fp16.loss_scaler", None)
sys.modules.pop("megatron.fp16.loss_scaler", None)
return state_dict
state_dict_rank0 = load_state_dict(checkpoint_name_rank0)
state_dict_local = load_state_dict(checkpoint_name_local)
state_dict = merge_state_dict(state_dict_rank0, state_dict_local, args.fp16)
# set checkpoint version
set_checkpoint_version(state_dict.get("checkpoint_version", 0))
# Set iteration.
if args.finetune or release:
iteration = 0
else:
try:
iteration = state_dict["iteration"]
except KeyError:
try: # Backward compatible with older checkpoints
iteration = state_dict["total_iters"]
except KeyError:
print_rank_last(
"A metadata file exists but unable to load "
"iteration from checkpoint {}, exiting".format(
checkpoint_name_local
)
)
sys.exit()
# Check arguments.
assert args.consumed_train_samples == 0
assert args.consumed_valid_samples == 0
if "args" in state_dict:
checkpoint_args = state_dict["args"]
check_checkpoint_args(checkpoint_args)
args.consumed_train_samples = getattr(
checkpoint_args, "consumed_train_samples", 0
)
update_num_microbatches(consumed_samples=args.consumed_train_samples)
args.consumed_valid_samples = getattr(
checkpoint_args, "consumed_valid_samples", 0
)
else:
print_rank_last("could not find arguments in the checkpoint ...")
# Model.
model.load_state_dict(state_dict["model"])
# Optimizer.
if not release and not args.finetune and not args.no_load_optim:
try:
if optimizer is not None:
optimizer.load_state_dict(state_dict["optimizer"])
if lr_scheduler is not None:
lr_scheduler.load_state_dict(state_dict["lr_scheduler"])
except KeyError:
print_rank_last(
"Unable to load optimizer from checkpoint {}. "
"Specify --no-load-optim or --finetune to prevent "
"attempting to load the optimizer state, "
"exiting ...".format(checkpoint_name_local)
)
sys.exit()
# rng states.
if not release and not args.finetune and not args.no_load_rng:
try:
random.setstate(state_dict["random_rng_state"])
np.random.set_state(state_dict["np_rng_state"])
torch.set_rng_state(state_dict["torch_rng_state"])
torch.cuda.set_rng_state(state_dict["cuda_rng_state"])
mpu.get_cuda_rng_tracker().set_states(state_dict["rng_tracker_states"])
except KeyError:
print_rank_last(
"Unable to load optimizer from checkpoint {}. "
"Specify --no-load-rng or --finetune to prevent "
"attempting to load the optimizer state, "
"exiting ...".format(checkpoint_name_local)
)
sys.exit()
torch.distributed.barrier()
print_rank_last(
" successfully loaded checkpoint (with expert parametes updated) from {} at iteration {}".format(
args.load, iteration
)
)
return iteration
r"""
distributed support for Megatron
"""
import torch
from fmoe.distributed import DistributedGroupedDataParallel
_groups = None
def _set_groups(**kwargs):
global _groups
_groups = kwargs
def _init():
from megatron import get_args
from megatron import mpu
args = get_args()
# Create a comm prependicular to the pipeline group as gate group
stage_size = args.world_size // args.pipeline_model_parallel_size
for i in range(0, args.world_size, stage_size):
ranks = range(i, i + stage_size)
group = torch.distributed.new_group(ranks)
if args.rank in ranks:
gate_group = group
_set_groups(
dp_group=mpu.get_data_parallel_group(),
moe_group=mpu.get_data_parallel_group(),
gate_group=gate_group)
class DistributedDataParallel(DistributedGroupedDataParallel):
r"""
A wrapper that is used to replace the DDP module provided by Megatron, which
is adapted to enable the sophiscated parallel and reduction strategies in
Fast MoE.
"""
def __init__(self, module):
if _groups is None:
_init()
super().__init__(module, **_groups)
def state_dict(self, *args, **kwargs):
r"""
Keep consitency with Megatron
"""
return self.module.state_dict(*args, **kwargs)
def state_dict_for_save_checkpoint(self, *args, **kwargs):
r"""
Keep consitency with Megatron
"""
return self.module.state_dict_for_save_checkpoint(*args, **kwargs)
def load_state_dict(self, *args, **kwargs):
r"""
Keep consitency with Megatron
"""
return self.module.load_state_dict(*args, **kwargs)
r"""
nn modules to replace Megatron's native ones
"""
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from fmoe.transformer import FMoETransformerMLP
from .balance import reset_gate_hook
from .balance import generate_megatron_gate_hook
class _FakeMegatronMLP(nn.Module):
r"""
A fake mlp without model parallelism for correctness testing
"""
def __init__(self, args, _):
super().__init__()
self.fc1 = nn.Linear(args.hidden_size, args.hidden_hidden_size)
self.fc2 = nn.Linear(args.hidden_hidden_size, args.hidden_size)
def forward(self, x):
r"""
Directly use GeLU
"""
x = self.fc1(x)
x = F.gelu(x)
x = self.fc2(x)
return x, torch.zeros_like(x)
def _megatron_init_method(self, rng, sigma):
r"""
Init method based on N(0, sigma).
Copied from Megatron-LM
"""
device = self.weight.device
dtype = self.weight.dtype
weight = rng.normal(loc=0.0, scale=sigma, size=tuple(self.weight.size()))
self.weight.data = torch.from_numpy(weight).to(dtype=dtype, device=device)
if self.bias is not None:
# Always initialize bias to zero.
with torch.no_grad():
self.bias.zero_()
def _random_init_weight(self, rng):
r"""
Copied from torch.nn.init.kaiming_uniform_
"""
fan = nn.init._calculate_correct_fan(self.weight[0], "fan_in")
gain = nn.init.calculate_gain("leaky_relu", math.sqrt(5))
std = gain / math.sqrt(fan)
bound = math.sqrt(3.0) * std
device = self.weight.device
dtype = self.weight.dtype
weight = rng.uniform(-bound, bound, size=tuple(self.weight.size()))
self.weight.data = torch.from_numpy(weight).to(dtype=dtype, device=device)
if self.bias is not None:
fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight[0])
bound = 1 / math.sqrt(fan_in)
bias = rng.uniform(-bound, bound, size=tuple(self.bias.size()))
self.bias.data = torch.from_numpy(bias).to(dtype=dtype, device=device)
class MegatronMLP(FMoETransformerMLP):
r"""
Make the FMoETransformerMLP layer that distributes experts across
communication group `group` to replace the original MLP layer in Megatron.
"""
def __init__(self, args, layer_idx, gate=None):
if not args.distributed_experts:
world_size = 1
moe_group = None
else:
world_size = args.data_parallel_size
from megatron.mpu import get_data_parallel_group
moe_group = get_data_parallel_group()
if not args.balance_strategy or args.balance_strategy == "naive":
from fmoe.gates import NaiveGate
gate = NaiveGate
elif args.balance_strategy == "noisy":
from fmoe.gates import NoisyGate
gate = NoisyGate
elif args.balance_strategy == "gshard":
from fmoe.gates import GShardGate
gate = GShardGate
elif args.balance_strategy == "switch":
from fmoe.gates import SwitchGate
gate = SwitchGate
elif args.balance_strategy == "swipe":
from fmoe.gates import SwipeGate
gate = SwipeGate
elif gate is None:
assert False, "Undefined balance strategy {}" % (args.balance_strategy)
super().__init__(
args.num_experts,
top_k=args.top_k,
d_model=args.hidden_size,
d_hidden=args.hidden_hidden_size,
world_size=world_size,
moe_group=moe_group,
expert_dp_comm="none" if args.distributed_experts else "dp",
gate_hook=generate_megatron_gate_hook(
layer_idx, args.num_experts * world_size
),
gate=gate,
)
self.hidden_size = args.hidden_size
if args.distributed_experts:
self.rank = args.rank
else:
self.rank = 0
self.sigma = args.init_method_std
self.num_layers = args.num_layers
self.reset_parameters()
def reset_parameters(self):
r"""
Initialize the weight as linear layers.
As megatron is using fixed random seed for some nasty stuff, an
additional numpy rng is used.
"""
rng = np.random.default_rng(np.random.randint(2048) + self.rank)
_megatron_init_method(self.experts.htoh4, rng, self.sigma)
std = self.sigma / math.sqrt(2.0 * self.num_layers)
_megatron_init_method(self.experts.h4toh, rng, std)
def forward(self, inp):
from megatron import mpu
x = super().forward(inp)
x = mpu.reduce_from_tensor_model_parallel_region(x)
return (
x,
torch.zeros(self.hidden_size, dtype=inp.dtype, device=inp.device),
)
def fmoefy(
model,
num_experts=None,
distributed_experts=True,
hidden_hidden_size=None,
top_k=None,
gate=None,
):
r"""
Replace MLP layers in a transformer-based model in Megatron by MoE.
* `model` should be a standard Megatron model that has
`model.language_model.transformer.layers` as transformer layers, which is an
array of transformer blocks that contain an `mlp` member.
* `distributed_expert` is set to True if different experts are located in
different workers. Otherwise, the experts on the workers are identical, and
they are trained in data-parallel mode. This can be useful when testing on
small models that do not require high training throughput or large parameter
capacity.
Note that pipeline parallel is not supported yet. When distributed experts
are enabled, their communicator should be Megatron's
tensor_model_parall_comm x data_parallel_comm, which is not created.
"""
from megatron import get_args
args = get_args()
# Set distributed_experts to None to use default setting in args
if distributed_experts is not None:
args.distributed_experts = distributed_experts
if num_experts is not None:
args.num_experts = num_experts
assert (
"num_experts" in args
), "num_experts should be specified in arguments or fmoefy function"
if top_k is not None:
args.top_k = top_k
elif not hasattr(args, "top_k"):
args.top_k = 2
args.hidden_hidden_size = hidden_hidden_size
for idx, l in enumerate(model.language_model.transformer.layers):
l.mlp = MegatronMLP(args, idx, gate=gate)
# initialize gate hook
num_layers = len(model.language_model.transformer.layers)
reset_gate_hook(num_layers)
return model
r"""
Patching some of Megatron-LM's functions to create an MoE model
"""
import torch
def patch_forward_step(forward_step_func):
r"""
Patch model's forward_step_func to support balance loss
"""
from megatron.mpu import is_pipeline_last_stage
from megatron.mpu import get_tensor_model_parallel_group
from megatron import get_args
if not get_args().balance_strategy:
return forward_step_func
def forward_step_with_balance_loss(data_iterator, model, input_tensor):
args = get_args()
output = forward_step_func(data_iterator, model, input_tensor)
if not is_pipeline_last_stage() or not args.balance_strategy:
return output
while hasattr(model, 'module'):
model = model.module
loss_list = [l.mlp.gate.get_loss(clear=False).view(1)
for l in model.language_model.transformer.layers
if l.mlp.gate.has_loss]
if len(loss_list) == 0:
return output
loss_name = args.balance_strategy + "_loss"
(loss, state_dict), bal_loss = (
output,
torch.cat(loss_list).mean() * args.balance_loss_weight
)
# avarage across moe group
moe_group = get_tensor_model_parallel_group()
world_size = torch.distributed.get_world_size(group=moe_group)
averaged_bal_loss = bal_loss.clone().detach()
torch.distributed.all_reduce(averaged_bal_loss, group=moe_group)
averaged_bal_loss /= world_size
loss += bal_loss
state_dict[loss_name] = averaged_bal_loss
return loss, state_dict
return forward_step_with_balance_loss
def patch_model_provider(model_provider, gate=None):
from megatron import get_args
def fmoefied_model_provider():
from .layers import fmoefy
args = get_args()
hhs = args.hidden_size * 4
assert hhs % args.top_k == 0
hhs = hhs // args.top_k
assert hhs % args.tensor_model_parallel_size == 0
hhs = hhs // args.tensor_model_parallel_size
return fmoefy(
model_provider(),
num_experts=args.num_experts,
hidden_hidden_size=hhs,
top_k=args.top_k,
gate=gate
)
return fmoefied_model_provider
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment