Commit b3380ec2 authored by Rick Ho

support arbitrary module as expert

parent 8328c794
fmoe/__init__.py
@@ -2,5 +2,6 @@ r"""
 The fmoe package contains MoE Layers only.
 """
-from .layers import FMoELinear, FMoETransformerMLP
+from .layers import FMoELinear, FMoE
+from .transformer import FMoETransformerMLP
 from .distributed import DistributedGroupedDataParallel
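With this change, `FMoE` lives in `fmoe.layers` while `FMoETransformerMLP` moves to the new `fmoe.transformer` module; both remain re-exported at the package root. A minimal import sketch, assuming the package is installed under its usual name `fmoe`:

    from fmoe import FMoE, FMoETransformerMLP, DistributedGroupedDataParallel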
fmoe/layers.py
@@ -41,15 +41,23 @@ class FMoELinear(nn.Module):
         return MOELinear.apply(inp, self.weight, fwd_expert_count)


-def _fmoe_full_forward(inp, gate, linears, activation, num_expert, world_size):
+def mark_module_parallel_comm(module, comm):
+    r'''
+    Mark all parameters in `module` as doing data parallel in `comm`, where
+    `comm` may be one of `'world', 'dp', 'none'`.
+    '''
+    for p in module.parameters():
+        setattr(p, 'dp_comm', comm)
+
+
+def _fmoe_general_global_forward(inp, gate, expert_fn, num_expert, world_size):
     r'''
     A private function that performs the following steps to complete the MoE
     computation.
     * Count the number of tokens from each worker to each expert.
     * Send the features to their target position so that input features to each
     expert are contiguous in memory.
-    * Perform the MLP of the experts by applying MoELinear and the activation in
-    turns.
+    * Perform the forward computation of the experts using `expert_fn`.
     * Gather the output features of experts back, and reorder them as sentences.
     Intermediate results like expert counts are hidden from users by this
     function.
@@ -62,19 +70,18 @@ def _fmoe_full_forward(inp, gate, linears, activation, num_expert, world_size):
         inp, pos, local_expert_count, global_expert_count, fwd_batch_size,
         world_size
     )
-    for i, l in enumerate(linears):
-        if i:
-            x = activation(x)
-        x = l(x, fwd_expert_count)
+    x = expert_fn(x, fwd_expert_count)
     x = MOEGather.apply(
         x, pos, local_expert_count, global_expert_count, inp.shape[0], world_size
     )
     return x


-class FMoETransformerMLP(nn.Module):
+class FMoE(nn.Module):
     r'''
-    A complete MoE MLP module in a Transformer block.
+    A general MoE implementation that supports an arbitrary module as the
+    expert. Either `expert` or `expert_fn` is required.
     * `num_expert` stands for the number of experts on **each** worker.
     * `world_size` stands for the total number of workers that contain
     different experts.
@@ -83,25 +90,19 @@ class FMoETransformerMLP(nn.Module):
     hold the same copy of the input feature, and demands the same copy of the
     output. FMoE saves computation by slicing the input in the mp group and
     performing all-gather after the MLP computation.
-    * `activation` is the activation function to be used in MLP in each expert.
     * `top_k` stands for the number of experts each token is going to.
+    * `gate` is a gate class which can be found in `fmoe.gates`.
+    * `expert` can be specified as a module class; it is used to generate
+    `num_expert` expert modules.
+    * `expert_fn` can be specified as a callable object or a function; it is
+    called during forward with the (contiguous) input tensor and the array of
+    the number of input tokens of each expert.
     '''
-    def __init__(
-        self,
-        num_expert=32,
-        d_model=1024,
-        d_hidden=4096,
-        world_size=1,
-        mp_group=None,
-        activation=torch.nn.functional.gelu,
-        gate=NaiveGate,
-        top_k=2,
-        pre_lnorm=False
-    ):
+    def __init__(self, num_expert=32, d_model=1024, world_size=1, mp_group=None,
+            top_k=2, gate=NaiveGate, expert=None, expert_fn=None):
         super().__init__()
         self.num_expert = num_expert
         self.d_model = d_model
-        self.d_hidden = d_hidden
         self.world_size = world_size
         self.mp_group = mp_group
         if mp_group is None:
@@ -110,37 +111,46 @@ class FMoETransformerMLP(nn.Module):
         else:
             self.mp_size = mp_group.size()
             self.mp_rank = mp_group.rank()
-        self.activation = activation
-        self.pre_lnorm = pre_lnorm
         self.top_k = top_k
-        self.htoh4 = FMoELinear(num_expert, d_model, d_hidden)
-        self.h4toh = FMoELinear(num_expert, d_hidden, d_model)
-        if self.world_size > self.mp_size:
-            for p in self.htoh4.parameters():
-                setattr(p, 'dp_comm', 'none')
-            for p in self.h4toh.parameters():
-                setattr(p, 'dp_comm', 'none')
         self.gate = gate(d_model, num_expert, world_size, top_k)
-        for p in self.gate.parameters():
-            setattr(p, 'dp_comm', 'world')
-        self.layer_norm = nn.LayerNorm(d_model)
-        self.bias = torch.nn.parameter.Parameter(
-            torch.zeros(d_model, dtype=torch.float32)
-        )
-
-    def forward(self, inp: torch.Tensor):
-        r'''
-        The FMoETransformerMLP module automatically performs reshape and layer
-        normalization. The score of the selected gate given by the expert is
-        multiplied to the experts' output tensors as a weight.
-        '''
-        original_shape = inp.shape
-        inp = inp.reshape(-1, self.d_model)
+        if expert_fn is None:
+            assert expert is not None, 'Either expert or expert_fn should be set'
+            self.experts = [expert(d_model) for _ in range(num_expert)]
+
+            def expert_fn(inp, fwd_expert_count):
+                outputs = []
+                base_idx = 0
+                for i in range(self.num_expert):
+                    batch_size = fwd_expert_count[i].item()
+                    inp_slice = inp[base_idx:base_idx + batch_size]
+                    outputs.append(self.experts[i](inp_slice))
+                    base_idx += batch_size
+                return torch.cat(outputs, dim=0)
+        self.expert_fn = expert_fn
+
+    def mark_parallel_comm(self):
+        r'''
+        Automatically mark the data parallel comms of the parameters within the
+        module. This is typically called at the end of the __init__ function in
+        child classes.
+        '''
+        if self.experts is not None:
+            if self.world_size > self.mp_size:
+                comm = 'none'
+            else:
+                comm = 'dp'
+            if isinstance(self.experts, list):
+                for e in self.experts:
+                    mark_module_parallel_comm(e, comm)
+            else:
+                mark_module_parallel_comm(self.experts, comm)
+        mark_module_parallel_comm(self.gate, 'world')
+
+    def forward(self, inp):
+        r'''
+        The FMoE module first computes the gate output, and then conducts the
+        MoE forward pass according to the gate. The scores of the selected
+        experts, given by the gate, are multiplied with the experts' output
+        tensors as weights.
+        '''
         if self.mp_size > 1:
             B: int = inp.shape[0]
             local_batch_size = B // self.mp_size
@@ -148,35 +158,17 @@ class FMoETransformerMLP(nn.Module):
             batch_end = min(batch_start + local_batch_size, B)
             inp = inp[batch_start:batch_end]
-        residual = inp
-        if self.pre_lnorm:
-            inp = self.layer_norm(inp)
         gate_top_k_idx, gate_score = self.gate(inp)
         # to: (BxLxtop_k) x d_model
         inp = inp.repeat_interleave(repeats=self.top_k, dim=0)
-        x = _fmoe_full_forward(
-            inp,
-            gate_top_k_idx,
-            [self.htoh4, self.h4toh],
-            self.activation,
-            self.num_expert,
-            self.world_size,
-        )
+        x = _fmoe_general_global_forward(inp, gate_top_k_idx, self.expert_fn,
+                self.num_expert, self.world_size)
         # to: (BxL) x top_k x d_model
-        core_out = x.view(-1, self.top_k, self.d_model)
-        # to: (BxL) x 1 x d_model
-        core_out = torch.bmm(gate_score, core_out)
-        output = core_out.reshape(residual.shape) + residual
-        if not self.pre_lnorm:
-            output = self.layer_norm(output)
+        x = x.view(-1, self.top_k, self.d_model)
+        # to: (BxL) x d_model
+        x = torch.bmm(gate_score, x).reshape(-1, self.d_model)
         if self.mp_size > 1:
-            output = AllGather.apply(output,
+            x = AllGather.apply(x,
                 self.mp_rank, self.mp_size, self.mp_group)
-        return output.reshape(original_shape), self.bias
+        return x
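The new `expert` argument is the main extension point added by this commit: `FMoE` instantiates `num_expert` copies via `expert(d_model)`, and the default `expert_fn` feeds each copy only its contiguous slice of tokens. A hedged sketch of plugging in a custom expert, assuming a single worker, the default `NaiveGate`, a CUDA build of the fmoe kernels, and a hypothetical `MyExpert` module that is not part of this commit:

    import torch
    import torch.nn as nn
    from fmoe.layers import FMoE

    class MyExpert(nn.Module):
        # hypothetical expert: any module mapping (tokens, d_model) -> (tokens, d_model)
        def __init__(self, d_model):
            super().__init__()
            self.fc1 = nn.Linear(d_model, 4 * d_model)
            self.fc2 = nn.Linear(4 * d_model, d_model)

        def forward(self, inp):
            return self.fc2(torch.relu(self.fc1(inp)))

    moe = FMoE(num_expert=4, d_model=64, top_k=2, expert=MyExpert)
    moe = moe.cuda()
    for e in moe.experts:           # the experts are kept in a plain list, so move them explicitly
        e.cuda()

    x = torch.randn(16, 64).cuda()  # (tokens, d_model); FMoE itself does not reshape
    y = moe(x)                      # (tokens, d_model), top-k expert outputs mixed by gate scores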
fmoe/megatron.py
@@ -3,33 +3,35 @@ The adaptor to seamlessly enable FastMoE in Megatron-LM v2.0 with at most two
 lines of modification.
 See `examples/megatron` for usage instructions.
 '''
-from .layers import FMoETransformerMLP
+import torch
+from .transformer import FMoETransformerMLP
 from .distributed import DistributedGroupedDataParallel
 from .utils import get_torch_default_comm


-def _create_moe_mlp(args, group):
+class MegatronMLP(FMoETransformerMLP):
     r'''
     Make the FMoETransformerMLP layer that distributes experts across
     communication group `group` to replace the original MLP layer in Megatron.
     '''
-    assert (args.seq_length * args.micro_batch_size
-            % args.tensor_model_parallel_size == 0
-            ), "Batch size x sequence length should be multiple of mp size"
-    if not args.distributed_experts:
-        world_size = 1
-    else:
-        world_size = args.world_size
-    fmoe = FMoETransformerMLP(
-        args.num_experts,
-        d_model=args.hidden_size,
-        d_hidden=args.hidden_size * 4,
-        world_size=world_size,
-        mp_group=group
-    )
-    for p in fmoe.gate.parameters():
-        setattr(p, 'shared', True)
-    return fmoe
+    def __init__(self, args, group):
+        assert (args.seq_length * args.micro_batch_size
+                % args.tensor_model_parallel_size == 0
+                ), "Batch size x sequence length should be multiple of mp size"
+        if not args.distributed_experts:
+            world_size = 1
+        else:
+            world_size = args.world_size
+        super().__init__(args.num_experts,
+                d_model=args.hidden_size, d_hidden=args.hidden_size * 4,
+                world_size=world_size, mp_group=group)
+        self.bias = torch.nn.parameter.Parameter(
+            torch.zeros(args.hidden_size, dtype=torch.float32)
+        )
+
+    def forward(self, inp):
+        return super().forward(inp), self.bias


 def fmoefy(model, num_experts=None, distributed_experts=True):
@@ -60,7 +62,7 @@ def fmoefy(model, num_experts=None, distributed_experts=True):
     args.distributed_experts = distributed_experts
     for l in model.language_model.transformer.layers:
-        l.mlp = _create_moe_mlp(args, get_torch_default_comm())
+        l.mlp = MegatronMLP(args, get_torch_default_comm())
     return model
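For Megatron users the entry point is unchanged: `fmoefy` swaps every transformer MLP for a `MegatronMLP`, whose forward returns the `(output, bias)` pair Megatron expects. A hedged usage sketch, assuming a Megatron-LM v2.0 training script in which `model` has already been built (the surrounding Megatron setup is omitted):

    from fmoe.megatron import fmoefy

    # num_experts is the number of experts per worker; distributed_experts=True
    # places different experts on different data-parallel workers.
    model = fmoefy(model, num_experts=4, distributed_experts=True)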
fmoe/transformer.py (new file)

r'''
Adaptation to act as the MLP layer of a Transformer block using an MoE MLP layer.
'''
import torch
import torch.nn as nn

from .gates import NaiveGate
from .layers import FMoE, FMoELinear


class _Expert(nn.Module):
    r'''
    An expert using 2 FMoELinear modules to speed up the computation of experts
    within one worker.
    '''
    def __init__(self, num_expert, d_model, d_hidden, activation):
        super().__init__()
        self.htoh4 = FMoELinear(num_expert, d_model, d_hidden)
        self.h4toh = FMoELinear(num_expert, d_hidden, d_model)
        self.activation = activation

    def forward(self, inp, fwd_expert_count):
        r'''
        First expand the input to 4h (the hidden size is variable, but is
        called h4 for convenience). Then perform the activation. Finally
        shrink back to h.
        '''
        x = self.htoh4(inp, fwd_expert_count)
        x = self.activation(x)
        x = self.h4toh(x, fwd_expert_count)
        return x


class FMoETransformerMLP(FMoE):
    r'''
    A complete MoE MLP module in a Transformer block.
    * `activation` is the activation function to be used in MLP in each expert.
    * `d_hidden` is the dimension of the MLP layer.
    '''
    def __init__(
        self,
        num_expert=32,
        d_model=1024,
        d_hidden=4096,
        world_size=1,
        mp_group=None,
        activation=torch.nn.functional.gelu,
        gate=NaiveGate,
        top_k=2,
        pre_lnorm=False
    ):
        def expert_fn(inp, fwd_expert_count):
            return self.experts(inp, fwd_expert_count)
        super().__init__(num_expert=num_expert, d_model=d_model, gate=gate,
                top_k=top_k, world_size=world_size, mp_group=mp_group,
                expert_fn=expert_fn)
        self.experts = _Expert(num_expert, d_model, d_hidden, activation)
        self.pre_lnorm = pre_lnorm
        self.layer_norm = nn.LayerNorm(d_model)
        self.mark_parallel_comm()

    def forward(self, inp: torch.Tensor):
        r'''
        This module wraps up the FMoE module with reshape, residual and layer
        normalization.
        '''
        original_shape = inp.shape
        inp = inp.reshape(-1, self.d_model)
        if self.pre_lnorm:
            inp = self.layer_norm(inp)
        output = super().forward(inp) + inp
        if not self.pre_lnorm:
            output = self.layer_norm(output)
        return output.reshape(original_shape)
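A hedged usage sketch of the rewritten `FMoETransformerMLP`, assuming a CUDA build of the fmoe kernels and a single worker; the module flattens any leading dimensions to `d_model`-sized features, applies the residual connection and layer normalization internally, and restores the original shape on output:

    import torch
    from fmoe.transformer import FMoETransformerMLP

    mlp = FMoETransformerMLP(num_expert=4, d_model=256, d_hidden=1024, top_k=2).cuda()
    x = torch.randn(8, 16, 256).cuda()   # (batch, seq_len, d_model)
    y = mlp(x)                           # same shape as x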