Commit d9dba929 authored by Rick Ho

Tidy up fmoe python code

parent d2392de2
fmoe/__init__.py
@@ -2,6 +2,7 @@ r"""
 The fmoe package contains MoE Layers only.
 """
-from .layers import FMoELinear, FMoE
+from .layers import FMoE
+from .linear import FMoELinear
 from .transformer import FMoETransformerMLP
 from .distributed import DistributedGroupedDataParallel
fmoe/distributed.py
@@ -25,11 +25,8 @@ class DistributedGroupedDataParallel(nn.Module):
     def __init__(
         self,
         module,
-        mp_group=None,
-        dp_group=None,
-        moe_group=None,
-        world_group=None,
         auto_allreduce=False,
+        **kwargs
     ):
         assert not auto_allreduce, "Automatic all-reduce is not implemented yet"
@@ -37,20 +34,12 @@ class DistributedGroupedDataParallel(nn.Module):
         self.module = module
         self.comms = dict()
-        if mp_group is not None:
-            self.comms["mp"] = mp_group
-        if dp_group is not None:
-            self.comms["dp"] = dp_group
-        else:
-            self.comms["dp"] = get_torch_default_comm()
-        if moe_group is not None:
-            self.comms["moe"] = moe_group
-        else:
-            self.comms["moe"] = get_torch_default_comm()
-        if world_group is None:
-            self.comms["world"] = get_torch_default_comm()
-        else:
-            self.comms["world"] = world_group
+        for k in kwargs:
+            if k.endswith('_group'):
+                self.comms[k[:-6]] = kwargs[k]
+        for k in ['dp', 'gate', 'moe', 'world']:
+            if k not in self.comms:
+                self.comms[k] = get_torch_default_comm()

         def allreduce_params(no_scale=False,
                 reduce_after=False, fp32_allreduce=False):
...
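Note: the snippet below is an editor-added sketch, not part of the commit. It isolates the new group-registration logic: any keyword argument ending in '_group' is stored under the name with that suffix stripped, and the 'dp', 'gate', 'moe' and 'world' entries fall back to the default communicator when not supplied. The build_comms helper and the string placeholders are illustrative only.

# Editor-added sketch of the kwargs-based communicator mapping (not part of the commit).
def build_comms(default_comm, **kwargs):
    comms = {}
    for k, v in kwargs.items():
        if k.endswith('_group'):
            comms[k[:-6]] = v                  # strip the 6-character '_group' suffix
    for k in ['dp', 'gate', 'moe', 'world']:
        comms.setdefault(k, default_comm)      # fall back to the default communicator
    return comms

# Only moe_group and mp_group are passed explicitly; the rest default.
comms = build_comms("DEFAULT", moe_group="MOE", mp_group="MP")
assert comms == {"moe": "MOE", "mp": "MP", "dp": "DEFAULT",
                 "gate": "DEFAULT", "world": "DEFAULT"}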
fmoe/functions.py
@@ -132,34 +132,6 @@ class MOEScatter(Function):
         grad_in = _local_gather(local_grad_in, pos, inp_batch_size)
         return grad_in, None, None, None, None, None


-class MOELinear(Function):
-    r"""
-    Computes linear operators within one GPU on different experts simutaneously.
-    """
-
-    @staticmethod
-    def forward(ctx, global_input_buf, fwd_expert_count, weight, bias=None):
-        global_output_buf = fmoe_cuda.linear_forward(
-            global_input_buf, fwd_expert_count, weight, bias
-        )
-        variables = (global_input_buf, fwd_expert_count, weight, bias)
-        ctx.save_for_backward(*variables)
-        return global_output_buf
-
-    @staticmethod
-    def backward(ctx, grad_out):
-        (input_buf, fwd_expert_count, weight, bias) = ctx.saved_tensors
-        grad_inp_buf, grad_weight, grad_bias = fmoe_cuda.linear_backward(
-            grad_out, input_buf, fwd_expert_count, weight, bias
-        )
-
-        if not torch.is_tensor(bias):
-            grad_bias = None
-        return grad_inp_buf, None, grad_weight, grad_bias
-
-
 class MOEGather(Function):
     r"""
     Gather output samples from contiguous alone experts back to [batch x
...
r""" r"""
Layers that FMoE provides to users FMoE core layer
""" """
import torch import torch
import torch.nn as nn import torch.nn as nn
import math
from .functions import prepare_forward, ensure_comm from .functions import prepare_forward, ensure_comm
from .functions import MOEScatter, MOEGather, MOELinear from .functions import MOEScatter, MOEGather
from .functions import AllGather, Slice from .functions import AllGather, Slice
from .gates import NaiveGate from .gates import NaiveGate
class FMoELinear(nn.Module):
r"""
A linear layer that contains multiple experts.
As multiple experts can be placed on the same worker, the computation can be
performed in parallel to increase the performance.
The FMoELinear module provides such function.
"""
def __init__(
self,
num_expert: int,
in_feat: int,
out_feat: int,
bias: bool = True,
rank: int = 0,
):
super().__init__()
self.num_expert = num_expert
self.in_feat = in_feat
self.out_feat = out_feat
self.rank = rank
self.weight = nn.Parameter(torch.Tensor(num_expert, out_feat, in_feat))
if bias:
self.bias = nn.Parameter(torch.zeros(num_expert, out_feat))
else:
self.register_parameter("bias", None)
self.reset_parameters()
def forward(self, inp, fwd_expert_count):
r"""
Call MOE function
"""
x = MOELinear.apply(inp, fwd_expert_count, self.weight, self.bias)
return x
def extra_repr(self) -> str:
return "num_expert={}, in_features={}, \
out_features={}, bias={}, rank={}".format(
self.num_expert,
self.in_feat,
self.out_feat,
self.bias is not None,
self.rank,
)
def reset_parameters(self):
# Approach is the same as in torch.nn.Linear
# https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/linear.py#L88
# bias is left to zero, similar as megatron
torch.nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
def mark_module_parallel_comm(module, comm): def mark_module_parallel_comm(module, comm):
r""" r"""
...@@ -121,11 +67,12 @@ class FMoE(nn.Module): ...@@ -121,11 +67,12 @@ class FMoE(nn.Module):
* `num_expert` stands for the number of experts on **each** worker. * `num_expert` stands for the number of experts on **each** worker.
* `world_size` stands for the total number of workers that contains * `world_size` stands for the total number of workers that contains
different experts. different experts.
* `mp_group` can be a torch's communication group, indicating that model * `slice_group` can be a torch's communication group, indicating that
parallel is applied across the group, which means that workers in the group specific model parallel is applied across the group, and workers in the
hold the same copy of the input feature, and demands the same copy of the group hold the same copy of input feature, and requires the same copy of
output. FMoE saves computation by slicing the input in the mp group and the output. For each worker, FMoE only computes the output of a certain
performing all-gather after the MLP computation. slice of the input batch, and will all-gather the outputs after
computation.
* `top_k` stands for the number of experts each token is going to. * `top_k` stands for the number of experts each token is going to.
* `gate` is a gate class which can found in `fmoe.gates`. * `gate` is a gate class which can found in `fmoe.gates`.
* `expert` can be specified as a module class, it is used to generate * `expert` can be specified as a module class, it is used to generate
...@@ -137,7 +84,8 @@ class FMoE(nn.Module): ...@@ -137,7 +84,8 @@ class FMoE(nn.Module):
num_expert=32, num_expert=32,
d_model=1024, d_model=1024,
world_size=1, world_size=1,
mp_group=None, mp_group=None, # being deprecated
slice_group=None,
moe_group=None, moe_group=None,
top_k=2, top_k=2,
gate=NaiveGate, gate=NaiveGate,
...@@ -150,13 +98,18 @@ class FMoE(nn.Module): ...@@ -150,13 +98,18 @@ class FMoE(nn.Module):
self.num_expert = num_expert self.num_expert = num_expert
self.d_model = d_model self.d_model = d_model
self.world_size = world_size self.world_size = world_size
self.mp_group = mp_group
if mp_group is None: self.slice_group = slice_group
self.mp_size = 1 if mp_group is not None:
self.mp_rank = 0 print('[Warning] mp_group is being deprecated')
self.slice_group = mp_group
if self.slice_group is None:
self.slice_size = 1
self.slice_rank = 0
else: else:
self.mp_size = mp_group.size() self.slice_size = slice_group.size()
self.mp_rank = mp_group.rank() self.slice_rank = slice_group.rank()
self.top_k = top_k self.top_k = top_k
if type(expert) is list: if type(expert) is list:
self.experts = nn.ModuleList([e(d_model) for e in expert]) self.experts = nn.ModuleList([e(d_model) for e in expert])
...@@ -168,6 +121,7 @@ class FMoE(nn.Module): ...@@ -168,6 +121,7 @@ class FMoE(nn.Module):
self.experts_fused = False self.experts_fused = False
else: else:
self.experts_fused = True self.experts_fused = True
self.gate = gate(d_model, num_expert, world_size, top_k) self.gate = gate(d_model, num_expert, world_size, top_k)
self.gate_hook = gate_hook self.gate_hook = gate_hook
self.mask = mask self.mask = mask
...@@ -203,7 +157,7 @@ class FMoE(nn.Module): ...@@ -203,7 +157,7 @@ class FMoE(nn.Module):
mark_module_parallel_comm(e, comm) mark_module_parallel_comm(e, comm)
else: else:
mark_module_parallel_comm(self.experts, comm) mark_module_parallel_comm(self.experts, comm)
mark_module_parallel_comm(self.gate, "moe") mark_module_parallel_comm(self.gate, "gate")
def forward(self, inp): def forward(self, inp):
r""" r"""
...@@ -213,8 +167,9 @@ class FMoE(nn.Module): ...@@ -213,8 +167,9 @@ class FMoE(nn.Module):
""" """
if self.world_size > 1: if self.world_size > 1:
ensure_comm(inp, self.moe_group) ensure_comm(inp, self.moe_group)
if self.mp_size > 1: if self.slice_size > 1:
inp = Slice.apply(inp, self.mp_rank, self.mp_size, self.mp_group) inp = Slice.apply(inp, self.slice_rank,
self.slice_size, self.slice_group)
gate_top_k_idx, gate_score = self.gate(inp) gate_top_k_idx, gate_score = self.gate(inp)
...@@ -249,6 +204,7 @@ class FMoE(nn.Module): ...@@ -249,6 +204,7 @@ class FMoE(nn.Module):
gate_score = gate_score.view(x.shape[0], 1, self.top_k) gate_score = gate_score.view(x.shape[0], 1, self.top_k)
x = torch.bmm(gate_score, x).reshape(-1, self.d_model) x = torch.bmm(gate_score, x).reshape(-1, self.d_model)
if self.mp_size > 1: if self.slice_size > 1:
x = AllGather.apply(x, self.mp_rank, self.mp_size, self.mp_group) x = AllGather.apply(x, self.slice_rank,
self.slice_size, self.slice_group)
return x return x
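Note: the snippet below is an editor-added sketch, not part of the commit. It emulates on a single process, with plain torch ops, what the slice_group path in FMoE.forward does: each rank of the slice group runs the computation on its share of the batch (the Slice step) and the per-rank outputs are concatenated afterwards (the AllGather step). The Linear module is only a stand-in for the actual MoE computation.

# Editor-added, single-process emulation of the Slice / AllGather data flow.
import torch

batch, d_model, slice_size = 8, 4, 2
inp = torch.randn(batch, d_model)
mlp = torch.nn.Linear(d_model, d_model)               # stand-in for the MoE computation

outs = []
for slice_rank in range(slice_size):
    shard = inp.chunk(slice_size, dim=0)[slice_rank]  # Slice: keep only this rank's rows
    outs.append(mlp(shard))                           # per-rank computation
full = torch.cat(outs, dim=0)                         # AllGather: reassemble the batch

assert torch.allclose(full, mlp(inp), atol=1e-6)      # same result as the unsliced run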
r"""
FMoE's parallel linear layer
"""
import torch
import torch.nn as nn
from torch.autograd import Function
import math
import fmoe_cuda
class MOELinear(Function):
r"""
Computes linear operators within one GPU on different experts simutaneously.
"""
@staticmethod
def forward(ctx, global_input_buf, fwd_expert_count, weight, bias=None):
global_output_buf = fmoe_cuda.linear_forward(
global_input_buf, fwd_expert_count, weight, bias
)
variables = (global_input_buf, fwd_expert_count, weight, bias)
ctx.save_for_backward(*variables)
return global_output_buf
@staticmethod
def backward(ctx, grad_out):
(input_buf, fwd_expert_count, weight, bias) = ctx.saved_tensors
grad_inp_buf, grad_weight, grad_bias = fmoe_cuda.linear_backward(
grad_out, input_buf, fwd_expert_count, weight, bias
)
if not torch.is_tensor(bias):
grad_bias = None
return grad_inp_buf, None, grad_weight, grad_bias
class FMoELinear(nn.Module):
r"""
A linear layer that contains multiple experts.
As multiple experts can be placed on the same worker, the computation can be
performed in parallel to increase the performance.
The FMoELinear module provides such function.
"""
def __init__(
self,
num_expert: int,
in_feat: int,
out_feat: int,
bias: bool = True,
rank: int = 0,
):
super().__init__()
self.num_expert = num_expert
self.in_feat = in_feat
self.out_feat = out_feat
self.rank = rank
self.weight = nn.Parameter(torch.Tensor(num_expert, out_feat, in_feat))
if bias:
self.bias = nn.Parameter(torch.zeros(num_expert, out_feat))
else:
self.register_parameter("bias", None)
self.reset_parameters()
def forward(self, inp, fwd_expert_count):
r"""
Call MOE function
"""
x = MOELinear.apply(inp, fwd_expert_count, self.weight, self.bias)
return x
def extra_repr(self) -> str:
return "num_expert={}, in_features={}, \
out_features={}, bias={}, rank={}".format(
self.num_expert,
self.in_feat,
self.out_feat,
self.bias is not None,
self.rank,
)
def reset_parameters(self):
# Approach is the same as in torch.nn.Linear
# https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/linear.py#L88
# bias is left to zero, similar as megatron
torch.nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
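Note: the snippet below is an editor-added reference, not part of the commit. It spells out in plain torch what the shapes above suggest MOELinear delegates to fmoe_cuda.linear_forward: tokens routed to expert i occupy fwd_expert_count[i] contiguous rows of the input buffer and are multiplied by that expert's [out_feat, in_feat] weight. The function name below is hypothetical.

# Editor-added pure-PyTorch reference for the fused expert linear (assumed semantics).
import torch

def expert_linear_reference(inp, fwd_expert_count, weight, bias=None):
    outputs, base = [], 0
    for i, cnt in enumerate(fwd_expert_count.tolist()):
        chunk = inp[base:base + cnt]          # rows routed to expert i
        out = chunk @ weight[i].t()           # weight[i] has shape [out_feat, in_feat]
        if bias is not None:
            out = out + bias[i]
        outputs.append(out)
        base += cnt
    return torch.cat(outputs, dim=0)

# Example: 3 experts, 6 tokens routed as [2, 1, 3].
inp = torch.randn(6, 8)
weight = torch.randn(3, 16, 8)
bias = torch.zeros(3, 16)
out = expert_linear_reference(inp, torch.tensor([2, 1, 3]), weight, bias)
assert out.shape == (6, 16)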
fmoe/transformer.py
@@ -3,8 +3,8 @@ Adaption to act as the MLP layer using an MoE MLP layer in transformer.
 """
 import torch
 import torch.nn as nn
-from .gates import NaiveGate
-from .layers import FMoE, FMoELinear
+from .layers import FMoE
+from .linear import FMoELinear


 class _Expert(nn.Module):
@@ -42,31 +42,14 @@ class FMoETransformerMLP(FMoE):
         num_expert=32,
         d_model=1024,
         d_hidden=4096,
-        world_size=1,
-        mp_group=None,
-        moe_group=None,
         activation=torch.nn.GELU(),
-        gate=NaiveGate,
-        top_k=2,
         expert_dp_comm="none",
-        gate_hook=None,
-        mask=None,
-        mask_dict=None,
+        expert_rank=0,
+        **kwargs
     ):
-        super().__init__(
-            num_expert=num_expert,
-            d_model=d_model,
-            gate=gate,
-            top_k=top_k,
-            world_size=world_size,
-            mp_group=mp_group,
-            moe_group=moe_group,
-            gate_hook=gate_hook,
-            mask=mask,
-            mask_dict=mask_dict
-        )
+        super().__init__(num_expert=num_expert, d_model=d_model, **kwargs)
         self.experts = _Expert(
-            num_expert, d_model, d_hidden, activation, rank=self.mp_rank
+            num_expert, d_model, d_hidden, activation, rank=expert_rank
         )
         self.mark_parallel_comm(expert_dp_comm)
...
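Note: the snippet below is an editor-added usage sketch, not part of the commit, and assumes the fastmoe package with its fmoe_cuda extension is built and a GPU is available. After this change, only the MLP-specific options (d_hidden, activation, expert_rank) remain explicit arguments of FMoETransformerMLP; MoE options such as top_k, world_size, moe_group or gate are forwarded to FMoE through **kwargs.

# Editor-added usage sketch (assumes fastmoe with its CUDA extension is installed).
import torch
from fmoe import FMoETransformerMLP

mlp = FMoETransformerMLP(num_expert=4, d_model=64, d_hidden=128,  # MLP-specific options
                         top_k=2, world_size=1).cuda()            # forwarded to FMoE via **kwargs
y = mlp(torch.randn(8, 64, device='cuda'))                        # assumes a [batch, d_model] input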