Commit b3380ec2 authored by Rick Ho

support arbitrary module as expert

parent 8328c794
......@@ -2,5 +2,6 @@ r"""
The fmoe package contains MoE Layers only.
"""
from .layers import FMoELinear, FMoETransformerMLP
from .layers import FMoELinear, FMoE
from .transformer import FMoETransformerMLP
from .distributed import DistributedGroupedDataParallel
......@@ -41,15 +41,23 @@ class FMoELinear(nn.Module):
return MOELinear.apply(inp, self.weight, fwd_expert_count)
def _fmoe_full_forward(inp, gate, linears, activation, num_expert, world_size):
def mark_module_parallel_comm(module, comm):
r'''
Mark all parameters in `module` as performing data-parallel communication
in the group `comm`, where `comm` may be one of `'world', 'dp', 'none'`.
'''
for p in module.parameters():
setattr(p, 'dp_comm', comm)
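As a quick illustration of this helper, the hedged sketch below tags the parameters of a stand-in `nn.Linear` module (the module choice is only an example, not part of this commit) and reads the `dp_comm` attribute back; this attribute is what `DistributedGroupedDataParallel` is meant to inspect when grouping parameters for gradient synchronization.

# Usage sketch (not part of this commit); the nn.Linear module is illustrative.
import torch.nn as nn

toy = nn.Linear(16, 16)
mark_module_parallel_comm(toy, 'dp')
for p in toy.parameters():
    # Each parameter now carries the communication-group label.
    assert p.dp_comm == 'dp'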
def _fmoe_general_global_forward(inp, gate, expert_fn, num_expert, world_size):
r'''
A private function that performs the following steps to complete the MoE
computation.
* Count the number of tokens from each worker to each expert.
* Send the features to their target position so that input features to each
expert are contiguous in memory.
* Perform the MLP of the experts by applying MoELinear and the activation in
turns.
* Perform the forward computation of the experts using `expert_fn`.
* Gather the output features of experts back, and reorder them as sentences.
Intermediate results like expert counts are hidden from users by this
function.
......@@ -62,19 +70,18 @@ def _fmoe_full_forward(inp, gate, linears, activation, num_expert, world_size):
inp, pos, local_expert_count, global_expert_count, fwd_batch_size,
world_size
)
for i, l in enumerate(linears):
if i:
x = activation(x)
x = l(x, fwd_expert_count)
x = expert_fn(x, fwd_expert_count)
x = MOEGather.apply(
x, pos, local_expert_count, global_expert_count, inp.shape[0], world_size
)
return x
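For intuition, here is a hedged, single-worker sketch of the same count, scatter, compute, gather pattern in plain PyTorch. It omits the custom `MOEScatter`/`MOEGather` autograd functions and all cross-worker communication; the name `naive_moe_forward` and the identity expert function are illustrative only, not part of this commit.

# Single-worker sketch of the scatter/compute/gather pattern (not part of this commit).
import torch

def naive_moe_forward(inp, gate, expert_fn, num_expert):
    # Count how many tokens are routed to each expert.
    fwd_expert_count = torch.bincount(gate, minlength=num_expert)
    # Reorder tokens so that the inputs of each expert are contiguous in memory.
    pos = torch.argsort(gate)
    x = expert_fn(inp[pos], fwd_expert_count)
    # Put each token's output back to its original position.
    out = torch.empty_like(x)
    out[pos] = x
    return out

# Example call: 8 tokens of width 4, routed to 2 experts, with an identity expert.
tokens = torch.randn(8, 4)
gate = torch.randint(0, 2, (8,))
out = naive_moe_forward(tokens, gate, lambda x, cnt: x, num_expert=2)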
class FMoETransformerMLP(nn.Module):
class FMoE(nn.Module):
r'''
A complete MoE MLP module in a Transformer block.
A general MoE implementation that supports an arbitrary module as the expert.
Either `expert` or `expert_fn` is required.
* `num_expert` stands for the number of experts on **each** worker.
* `world_size` stands for the total number of workers that contain
different experts.
......@@ -83,25 +90,19 @@ class FMoETransformerMLP(nn.Module):
hold the same copy of the input feature, and demands the same copy of the
output. FMoE saves computation by slicing the input in the mp group and
performing all-gather after the MLP computation.
* `activation` is the activation function to be used in MLP in each expert.
* `top_k` stands for the number of experts each token is going to.
* `gate` is a gate class which can be found in `fmoe.gates`.
* `expert` can be specified as a module class; it is used to generate
`num_expert` expert modules.
* `expert_fn` can be specified as a callable object or a function; it is
called during the forward pass, taking the (contiguous) input tensor and the
array of the number of input features for each expert as its inputs.
'''
def __init__(
self,
num_expert=32,
d_model=1024,
d_hidden=4096,
world_size=1,
mp_group=None,
activation=torch.nn.functional.gelu,
gate=NaiveGate,
top_k=2,
pre_lnorm=False
):
def __init__(self, num_expert=32, d_model=1024, world_size=1, mp_group=None,
top_k=2, gate=NaiveGate, expert=None, expert_fn=None):
super().__init__()
self.num_expert = num_expert
self.d_model = d_model
self.d_hidden = d_hidden
self.world_size = world_size
self.mp_group = mp_group
if mp_group is None:
......@@ -110,37 +111,46 @@ class FMoETransformerMLP(nn.Module):
else:
self.mp_size = mp_group.size()
self.mp_rank = mp_group.rank()
self.activation = activation
self.pre_lnorm = pre_lnorm
self.top_k = top_k
self.htoh4 = FMoELinear(num_expert, d_model, d_hidden)
self.h4toh = FMoELinear(num_expert, d_hidden, d_model)
if self.world_size > self.mp_size:
for p in self.htoh4.parameters():
setattr(p, 'dp_comm', 'none')
for p in self.h4toh.parameters():
setattr(p, 'dp_comm', 'none')
self.gate = gate(d_model, num_expert, world_size, top_k)
for p in self.gate.parameters():
setattr(p, 'dp_comm', 'world')
self.layer_norm = nn.LayerNorm(d_model)
self.bias = torch.nn.parameter.Parameter(
torch.zeros(d_model, dtype=torch.float32)
)
def forward(self, inp: torch.Tensor):
if expert_fn is None:
    assert expert is not None, 'Either expert or expert_fn should be set'
    # Use nn.ModuleList so that the experts' parameters are registered.
    self.experts = nn.ModuleList([expert(d_model) for _ in range(num_expert)])

    def expert_fn(inp, fwd_expert_count):
        # Default expert function: feed each expert its contiguous slice of
        # the input and concatenate the outputs in the same order.
        outputs = []
        base_idx = 0
        for i in range(self.num_expert):
            batch_size = fwd_expert_count[i].item()
            inp_slice = inp[base_idx:base_idx + batch_size]
            outputs.append(self.experts[i](inp_slice))
            base_idx += batch_size
        return torch.cat(outputs, dim=0)
self.expert_fn = expert_fn
def mark_parallel_comm(self):
r'''
The FMoETransformerMLP module automatically performs reshape and layer
normalization. The score of the selected gate given by the expert is
multiplied to the experts' output tensors as a weight.
Automatically mark the data-parallel communication groups of the parameters
within the module. This can typically be called at the end of the `__init__`
function in child classes.
'''
original_shape = inp.shape
inp = inp.reshape(-1, self.d_model)
if self.experts is not None:
if self.world_size > self.mp_size:
comm = 'none'
else:
comm = 'dp'
if isinstance(self.experts, list):
for e in self.experts:
mark_module_parallel_comm(e, comm)
else:
mark_module_parallel_comm(self.experts, comm)
mark_module_parallel_comm(self.gate, 'world')
def forward(self, inp):
r'''
The FMoE module first computes the gate output, and then conducts the MoE
forward pass according to the gate. The score that the gate assigns to each
selected expert is multiplied with that expert's output tensor as a weight.
'''
if self.mp_size > 1:
B: int = inp.shape[0]
local_batch_size = B // self.mp_size
......@@ -148,35 +158,17 @@ class FMoETransformerMLP(nn.Module):
batch_end = min(batch_start + local_batch_size, B)
inp = inp[batch_start:batch_end]
residual = inp
if self.pre_lnorm:
inp = self.layer_norm(inp)
gate_top_k_idx, gate_score = self.gate(inp)
# to: (BxLxtop_k) x d_model
inp = inp.repeat_interleave(repeats=self.top_k, dim=0)
x = _fmoe_full_forward(
inp,
gate_top_k_idx,
[self.htoh4, self.h4toh],
self.activation,
self.num_expert,
self.world_size,
)
x = _fmoe_general_global_forward(inp, gate_top_k_idx, self.expert_fn,
self.num_expert, self.world_size)
# to: (BxL) x top_k x d_model
core_out = x.view(-1, self.top_k, self.d_model)
# to: (BxL) x 1 x d_model
core_out = torch.bmm(gate_score, core_out)
output = core_out.reshape(residual.shape) + residual
if not self.pre_lnorm:
output = self.layer_norm(output)
x = x.view(-1, self.top_k, self.d_model)
# to: (BxL) x d_model
x = torch.bmm(gate_score, x).reshape(-1, self.d_model)
if self.mp_size > 1:
output = AllGather.apply(output,
x = AllGather.apply(x,
self.mp_rank, self.mp_size, self.mp_group)
return output.reshape(original_shape), self.bias
return x
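A hedged usage sketch of the new `FMoE` class follows. The `TinyExpert` module is hypothetical (any `nn.Module` that maps `d_model`-sized features to `d_model`-sized features should do), single-worker defaults (`world_size=1`, `mp_group=None`) are assumed, and in practice the module and input would live on a GPU, since the underlying scatter/gather operators are CUDA kernels.

# Usage sketch (not part of this commit); TinyExpert is a hypothetical expert module.
import torch
import torch.nn as nn
from fmoe import FMoE

class TinyExpert(nn.Module):
    def __init__(self, d_model):
        super().__init__()
        self.fc = nn.Linear(d_model, d_model)

    def forward(self, x):
        return self.fc(x)

moe = FMoE(num_expert=4, d_model=64, top_k=2, expert=TinyExpert)
moe.mark_parallel_comm()          # tag parameters for the distributed wrapper
x = torch.randn(16, 64)           # (tokens, d_model)
y = moe(x)                        # (tokens, d_model), top-2 expert mixture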
......@@ -3,16 +3,19 @@ The adaptor to seamlessly enable FastMoE in Megatron-LM v2.0 with at most two
lines of modification.
See `examples/megatron` for usage instructions.
'''
from .layers import FMoETransformerMLP
import torch
from .transformer import FMoETransformerMLP
from .distributed import DistributedGroupedDataParallel
from .utils import get_torch_default_comm
def _create_moe_mlp(args, group):
class MegatronMLP(FMoETransformerMLP):
r'''
An FMoETransformerMLP layer that distributes experts across the
communication group `group`, replacing the original MLP layer in Megatron.
'''
def __init__(self, args, group):
assert (args.seq_length * args.micro_batch_size
% args.tensor_model_parallel_size == 0
), "Batch size x sequence length should be multiple of mp size"
......@@ -20,16 +23,15 @@ def _create_moe_mlp(args, group):
world_size = 1
else:
world_size = args.world_size
fmoe = FMoETransformerMLP(
args.num_experts,
d_model=args.hidden_size,
d_hidden=args.hidden_size * 4,
world_size=world_size,
mp_group=group
super().__init__(args.num_experts,
d_model=args.hidden_size, d_hidden=args.hidden_size * 4,
world_size=world_size, mp_group=group)
self.bias = torch.nn.parameter.Parameter(
torch.zeros(args.hidden_size, dtype=torch.float32)
)
for p in fmoe.gate.parameters():
setattr(p, 'shared', True)
return fmoe
def forward(self, inp):
return super().forward(inp), self.bias
def fmoefy(model, num_experts=None, distributed_experts=True):
......@@ -60,7 +62,7 @@ def fmoefy(model, num_experts=None, distributed_experts=True):
args.distributed_experts = distributed_experts
for l in model.language_model.transformer.layers:
l.mlp = _create_moe_mlp(args, get_torch_default_comm())
l.mlp = MegatronMLP(args, get_torch_default_comm())
return model
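For context, a hedged sketch of the intended two-line change in a Megatron-LM v2.0 pretraining script is shown below; the module path `fmoe.megatron`, the `num_experts` value, and the `build_megatron_gpt_model` helper are assumptions for illustration, and the surrounding Megatron setup code is not shown.

# Hypothetical excerpt from a Megatron-LM v2.0 pretraining script (not part of this commit).
from fmoe.megatron import fmoefy

model = build_megatron_gpt_model()   # placeholder for the usual Megatron model setup
model = fmoefy(model, num_experts=8) # replace every MLP with a MegatronMLP MoE layer
# Training then proceeds with fmoe's DistributedGroupedDataParallel wrapper,
# whose communication-group arguments depend on the parallel configuration.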
......
r'''
Adaptation of FMoE to act as the MLP layer of a Transformer block, using an MoE MLP layer.
'''
import torch
import torch.nn as nn
from .gates import NaiveGate
from .layers import FMoE, FMoELinear
class _Expert(nn.Module):
r'''
An expert using 2 FMoELinear modules to speed up the computation of experts
within one worker.
'''
def __init__(self, num_expert, d_model, d_hidden, activation):
super().__init__()
self.htoh4 = FMoELinear(num_expert, d_model, d_hidden)
self.h4toh = FMoELinear(num_expert, d_hidden, d_model)
self.activation = activation
def forward(self, inp, fwd_expert_count):
r'''
First expand the input to 4h (the hidden size is variable, but is called h4
for convenience). Then apply the activation. Finally shrink it back to h.
'''
x = self.htoh4(inp, fwd_expert_count)
x = self.activation(x)
x = self.h4toh(x, fwd_expert_count)
return x
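For intuition only, the per-expert computation above is equivalent to the hedged sketch below, written with plain `nn.Linear` layers for a single expert; the fused `FMoELinear` modules batch this computation across all local experts, using `fwd_expert_count` to delimit each expert's rows.

# Single-expert equivalent of _Expert.forward (not part of this commit).
import torch
import torch.nn as nn

d_model, d_hidden = 64, 256
htoh4 = nn.Linear(d_model, d_hidden)
h4toh = nn.Linear(d_hidden, d_model)

x = torch.randn(10, d_model)                # tokens routed to this one expert
h = torch.nn.functional.gelu(htoh4(x))      # expand from h to h4
y = h4toh(h)                                # shrink back from h4 to h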
class FMoETransformerMLP(FMoE):
r'''
A complete MoE MLP module in a Transformer block.
* `activation` is the activation function to be used in the MLP of each expert.
* `d_hidden` is the hidden dimension of the MLP in each expert.
'''
def __init__(
self,
num_expert=32,
d_model=1024,
d_hidden=4096,
world_size=1,
mp_group=None,
activation=torch.nn.functional.gelu,
gate=NaiveGate,
top_k=2,
pre_lnorm=False
):
def expert_fn(inp, fwd_expert_count):
    # The second argument is the per-expert token count, not the gate output.
    return self.experts(inp, fwd_expert_count)
super().__init__(num_expert=num_expert, d_model=d_model, gate=gate,
world_size=world_size, mp_group=mp_group, expert_fn=expert_fn)
self.experts = _Expert(num_expert, d_model, d_hidden, activation)
self.pre_lnorm = pre_lnorm
self.layer_norm = nn.LayerNorm(d_model)
self.mark_parallel_comm()
def forward(self, inp: torch.Tensor):
r'''
This module wraps up the FMoE module with reshape, residual and layer
normalization.
'''
original_shape = inp.shape
inp = inp.reshape(-1, self.d_model)
if self.pre_lnorm:
inp = self.layer_norm(inp)
output = super().forward(inp) + inp
if not self.pre_lnorm:
output = self.layer_norm(output)
return output.reshape(original_shape)
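Finally, a hedged usage sketch of the refactored `FMoETransformerMLP`: it is a drop-in MLP for a Transformer block, so any input whose last dimension is `d_model` is reshaped internally and returned with its original shape. Single-worker defaults are assumed, and in practice the module and input would be moved to a GPU because the MoE scatter/gather operators are CUDA kernels.

# Usage sketch (not part of this commit); single-worker defaults assumed.
import torch
from fmoe import FMoETransformerMLP

mlp = FMoETransformerMLP(num_expert=4, d_model=64, d_hidden=256, top_k=2)
x = torch.randn(8, 2, 64)   # (seq_len, batch, d_model), as inside a Transformer block
y = mlp(x)                  # same shape; residual and layer norm applied inside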