Commit d9dba929 authored by Rick Ho

tidy up fmoe python code

parent d2392de2
......@@ -2,6 +2,7 @@ r"""
The fmoe package contains MoE Layers only.
"""
from .layers import FMoELinear, FMoE
from .layers import FMoE
from .linear import FMoELinear
from .transformer import FMoETransformerMLP
from .distributed import DistributedGroupedDataParallel
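After this change `FMoELinear` lives in its own `fmoe/linear.py` module, but the package-level imports are unchanged. A minimal sketch of the public imports, assuming the package and its `fmoe_cuda` extension are installed:

```python
# Sketch only: the package still re-exports everything at the top level,
# so downstream code importing directly from `fmoe` is unaffected.
from fmoe import FMoE, FMoELinear, FMoETransformerMLP
from fmoe import DistributedGroupedDataParallel
```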
......@@ -25,11 +25,8 @@ class DistributedGroupedDataParallel(nn.Module):
def __init__(
self,
module,
mp_group=None,
dp_group=None,
moe_group=None,
world_group=None,
auto_allreduce=False,
**kwargs
):
assert not auto_allreduce, "Automatic all-reduce is not implemented yet"
......@@ -37,20 +34,12 @@ class DistributedGroupedDataParallel(nn.Module):
self.module = module
self.comms = dict()
if mp_group is not None:
self.comms["mp"] = mp_group
if dp_group is not None:
self.comms["dp"] = dp_group
else:
self.comms["dp"] = get_torch_default_comm()
if moe_group is not None:
self.comms["moe"] = moe_group
else:
self.comms["moe"] = get_torch_default_comm()
if world_group is None:
self.comms["world"] = get_torch_default_comm()
else:
self.comms["world"] = world_group
for k in kwargs:
if k.endswith('_group'):
self.comms[k[:-6]] = kwargs[k]
for k in ['dp', 'gate', 'moe', 'world']:
if k not in self.comms:
self.comms[k] = get_torch_default_comm()
def allreduce_params(no_scale=False,
reduce_after=False, fp32_allreduce=False):
......
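The constructor no longer takes fixed `mp_group`/`dp_group`/`moe_group`/`world_group` arguments; any keyword ending in `_group` is collected into `self.comms` under its prefix, and the `dp`, `gate`, `moe`, and `world` entries fall back to the default communicator when not supplied. Below is a self-contained sketch of that collection logic, with `get_torch_default_comm()` replaced by a placeholder string for illustration only:

```python
def collect_comms(**kwargs):
    # Mirrors the kwargs handling shown above: `xxx_group=...` -> comms["xxx"].
    comms = dict()
    for k in kwargs:
        if k.endswith('_group'):
            comms[k[:-6]] = kwargs[k]
    # Missing groups default to the process-wide communicator
    # (get_torch_default_comm() in the real code).
    for k in ['dp', 'gate', 'moe', 'world']:
        if k not in comms:
            comms[k] = 'default'
    return comms

print(collect_comms(dp_group='my_dp', gate_group='my_gate'))
# -> {'dp': 'my_dp', 'gate': 'my_gate', 'moe': 'default', 'world': 'default'}
```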
......@@ -132,34 +132,6 @@ class MOEScatter(Function):
grad_in = _local_gather(local_grad_in, pos, inp_batch_size)
return grad_in, None, None, None, None, None
class MOELinear(Function):
r"""
Computes linear operators within one GPU on different experts simultaneously.
"""
@staticmethod
def forward(ctx, global_input_buf, fwd_expert_count, weight, bias=None):
global_output_buf = fmoe_cuda.linear_forward(
global_input_buf, fwd_expert_count, weight, bias
)
variables = (global_input_buf, fwd_expert_count, weight, bias)
ctx.save_for_backward(*variables)
return global_output_buf
@staticmethod
def backward(ctx, grad_out):
(input_buf, fwd_expert_count, weight, bias) = ctx.saved_tensors
grad_inp_buf, grad_weight, grad_bias = fmoe_cuda.linear_backward(
grad_out, input_buf, fwd_expert_count, weight, bias
)
if not torch.is_tensor(bias):
grad_bias = None
return grad_inp_buf, None, grad_weight, grad_bias
class MOEGather(Function):
r"""
Gather output samples from contiguous alone experts back to [batch x
......
r"""
Layers that FMoE provides to users
FMoE core layer
"""
import torch
import torch.nn as nn
import math
from .functions import prepare_forward, ensure_comm
from .functions import MOEScatter, MOEGather, MOELinear
from .functions import MOEScatter, MOEGather
from .functions import AllGather, Slice
from .gates import NaiveGate
class FMoELinear(nn.Module):
r"""
A linear layer that contains multiple experts.
As multiple experts can be placed on the same worker, their computation can be
performed in parallel to increase performance.
The FMoELinear module provides this function.
"""
def __init__(
self,
num_expert: int,
in_feat: int,
out_feat: int,
bias: bool = True,
rank: int = 0,
):
super().__init__()
self.num_expert = num_expert
self.in_feat = in_feat
self.out_feat = out_feat
self.rank = rank
self.weight = nn.Parameter(torch.Tensor(num_expert, out_feat, in_feat))
if bias:
self.bias = nn.Parameter(torch.zeros(num_expert, out_feat))
else:
self.register_parameter("bias", None)
self.reset_parameters()
def forward(self, inp, fwd_expert_count):
r"""
Call MOE function
"""
x = MOELinear.apply(inp, fwd_expert_count, self.weight, self.bias)
return x
def extra_repr(self) -> str:
return "num_expert={}, in_features={}, \
out_features={}, bias={}, rank={}".format(
self.num_expert,
self.in_feat,
self.out_feat,
self.bias is not None,
self.rank,
)
def reset_parameters(self):
# Approach is the same as in torch.nn.Linear
# https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/linear.py#L88
# bias is left as zero, similar to Megatron
torch.nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
def mark_module_parallel_comm(module, comm):
r"""
......@@ -121,11 +67,12 @@ class FMoE(nn.Module):
* `num_expert` stands for the number of experts on **each** worker.
* `world_size` stands for the total number of workers that contains
different experts.
* `mp_group` can be a torch communication group, indicating that model
parallelism is applied across the group, which means that workers in the group
hold the same copy of the input feature and demand the same copy of the
output. FMoE saves computation by slicing the input in the mp group and
performing an all-gather after the MLP computation.
* `slice_group` can be a torch communication group, indicating that a
specific model-parallel scheme is applied across the group: workers in the
group hold the same copy of the input feature and require the same copy of
the output. For each worker, FMoE only computes the output of a certain
slice of the input batch, and all-gathers the outputs after the
computation.
* `top_k` stands for the number of experts each token is routed to.
* `gate` is a gate class, which can be found in `fmoe.gates`.
* `expert` can be specified as a module class, it is used to generate
......@@ -137,7 +84,8 @@ class FMoE(nn.Module):
num_expert=32,
d_model=1024,
world_size=1,
mp_group=None,
mp_group=None, # being deprecated
slice_group=None,
moe_group=None,
top_k=2,
gate=NaiveGate,
......@@ -150,13 +98,18 @@ class FMoE(nn.Module):
self.num_expert = num_expert
self.d_model = d_model
self.world_size = world_size
self.mp_group = mp_group
if mp_group is None:
self.mp_size = 1
self.mp_rank = 0
self.slice_group = slice_group
if mp_group is not None:
print('[Warning] mp_group is being deprecated')
self.slice_group = mp_group
if self.slice_group is None:
self.slice_size = 1
self.slice_rank = 0
else:
self.mp_size = mp_group.size()
self.mp_rank = mp_group.rank()
self.slice_size = slice_group.size()
self.slice_rank = slice_group.rank()
self.top_k = top_k
if type(expert) is list:
self.experts = nn.ModuleList([e(d_model) for e in expert])
......@@ -168,6 +121,7 @@ class FMoE(nn.Module):
self.experts_fused = False
else:
self.experts_fused = True
self.gate = gate(d_model, num_expert, world_size, top_k)
self.gate_hook = gate_hook
self.mask = mask
......@@ -203,7 +157,7 @@ class FMoE(nn.Module):
mark_module_parallel_comm(e, comm)
else:
mark_module_parallel_comm(self.experts, comm)
mark_module_parallel_comm(self.gate, "moe")
mark_module_parallel_comm(self.gate, "gate")
def forward(self, inp):
r"""
......@@ -213,8 +167,9 @@ class FMoE(nn.Module):
"""
if self.world_size > 1:
ensure_comm(inp, self.moe_group)
if self.mp_size > 1:
inp = Slice.apply(inp, self.mp_rank, self.mp_size, self.mp_group)
if self.slice_size > 1:
inp = Slice.apply(inp, self.slice_rank,
self.slice_size, self.slice_group)
gate_top_k_idx, gate_score = self.gate(inp)
......@@ -249,6 +204,7 @@ class FMoE(nn.Module):
gate_score = gate_score.view(x.shape[0], 1, self.top_k)
x = torch.bmm(gate_score, x).reshape(-1, self.d_model)
if self.mp_size > 1:
x = AllGather.apply(x, self.mp_rank, self.mp_size, self.mp_group)
if self.slice_size > 1:
x = AllGather.apply(x, self.slice_rank,
self.slice_size, self.slice_group)
return x
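The `mp_*` attributes are renamed to `slice_*`: each worker in the slice group processes only its slice of the batch and the results are all-gathered afterwards. A rough single-process illustration of that slice-then-all-gather pattern with plain tensor ops (not FMoE's autograd-aware `Slice`/`AllGather` functions):

```python
import torch

# Hypothetical sizes: a batch of 8 tokens and a slice group of size 4.
batch, d_model, slice_size = 8, 16, 4
inp = torch.randn(batch, d_model)

# Slice.apply: each rank keeps only its slice of the batch.
slices = inp.chunk(slice_size, dim=0)
# Stand-in for the per-rank MoE computation on its slice.
outs = [s * 2.0 for s in slices]
# AllGather.apply: the per-rank outputs are reassembled into the full batch.
full = torch.cat(outs, dim=0)
assert full.shape == inp.shape
```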
r"""
FMoE's parallel linear layer
"""
import torch
import torch.nn as nn
from torch.autograd import Function
import math
import fmoe_cuda
class MOELinear(Function):
r"""
Computes linear operators within one GPU on different experts simultaneously.
"""
@staticmethod
def forward(ctx, global_input_buf, fwd_expert_count, weight, bias=None):
global_output_buf = fmoe_cuda.linear_forward(
global_input_buf, fwd_expert_count, weight, bias
)
variables = (global_input_buf, fwd_expert_count, weight, bias)
ctx.save_for_backward(*variables)
return global_output_buf
@staticmethod
def backward(ctx, grad_out):
(input_buf, fwd_expert_count, weight, bias) = ctx.saved_tensors
grad_inp_buf, grad_weight, grad_bias = fmoe_cuda.linear_backward(
grad_out, input_buf, fwd_expert_count, weight, bias
)
if not torch.is_tensor(bias):
grad_bias = None
return grad_inp_buf, None, grad_weight, grad_bias
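For reference, here is a plain-PyTorch sketch of the computation the fused `fmoe_cuda.linear_forward` kernel is expected to perform: input rows arrive grouped by expert, and the `fwd_expert_count[i]` consecutive rows for expert `i` are multiplied by that expert's weight. This is an illustrative equivalent, not the actual CUDA implementation:

```python
import torch

def moe_linear_reference(inp, fwd_expert_count, weight, bias=None):
    # inp:    [sum(fwd_expert_count), in_feat], rows sorted by expert
    # weight: [num_expert, out_feat, in_feat]
    # bias:   [num_expert, out_feat] or None
    outputs, start = [], 0
    for i, count in enumerate(fwd_expert_count.tolist()):
        rows = inp[start:start + count]          # rows routed to expert i
        out = rows @ weight[i].t()
        if bias is not None:
            out = out + bias[i]
        outputs.append(out)
        start += count
    return torch.cat(outputs, dim=0)
```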
class FMoELinear(nn.Module):
r"""
A linear layer that contains multiple experts.
As multiple experts can be placed on the same worker, their computation can be
performed in parallel to increase performance.
The FMoELinear module provides this function.
"""
def __init__(
self,
num_expert: int,
in_feat: int,
out_feat: int,
bias: bool = True,
rank: int = 0,
):
super().__init__()
self.num_expert = num_expert
self.in_feat = in_feat
self.out_feat = out_feat
self.rank = rank
self.weight = nn.Parameter(torch.Tensor(num_expert, out_feat, in_feat))
if bias:
self.bias = nn.Parameter(torch.zeros(num_expert, out_feat))
else:
self.register_parameter("bias", None)
self.reset_parameters()
def forward(self, inp, fwd_expert_count):
r"""
Call MOE function
"""
x = MOELinear.apply(inp, fwd_expert_count, self.weight, self.bias)
return x
def extra_repr(self) -> str:
return "num_expert={}, in_features={}, \
out_features={}, bias={}, rank={}".format(
self.num_expert,
self.in_feat,
self.out_feat,
self.bias is not None,
self.rank,
)
def reset_parameters(self):
# Approach is the same as in torch.nn.Linear
# https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/linear.py#L88
# bias is left as zero, similar to Megatron
torch.nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
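A hypothetical single-worker usage of `FMoELinear` might look as follows; this sketch assumes the compiled `fmoe_cuda` extension and a GPU, and the exact device/dtype expectations for `fwd_expert_count` may differ:

```python
import torch
from fmoe.linear import FMoELinear

layer = FMoELinear(num_expert=4, in_feat=16, out_feat=32).cuda()

# 10 input rows already sorted by expert: 3 go to expert 0, 2 to expert 1, etc.
fwd_expert_count = torch.tensor([3, 2, 4, 1], dtype=torch.long)
inp = torch.randn(10, 16, device="cuda")

out = layer(inp, fwd_expert_count)   # expected shape: [10, 32]
```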
......@@ -3,8 +3,8 @@ Adaption to act as the MLP layer using an MoE MLP layer in transformer.
"""
import torch
import torch.nn as nn
from .gates import NaiveGate
from .layers import FMoE, FMoELinear
from .layers import FMoE
from .linear import FMoELinear
class _Expert(nn.Module):
......@@ -42,31 +42,14 @@ class FMoETransformerMLP(FMoE):
num_expert=32,
d_model=1024,
d_hidden=4096,
world_size=1,
mp_group=None,
moe_group=None,
activation=torch.nn.GELU(),
gate=NaiveGate,
top_k=2,
expert_dp_comm="none",
gate_hook=None,
mask=None,
mask_dict=None,
expert_rank=0,
**kwargs
):
super().__init__(
num_expert=num_expert,
d_model=d_model,
gate=gate,
top_k=top_k,
world_size=world_size,
mp_group=mp_group,
moe_group=moe_group,
gate_hook=gate_hook,
mask=mask,
mask_dict=mask_dict
)
super().__init__(num_expert=num_expert, d_model=d_model, **kwargs)
self.experts = _Expert(
num_expert, d_model, d_hidden, activation, rank=self.mp_rank
num_expert, d_model, d_hidden, activation, rank=expert_rank
)
self.mark_parallel_comm(expert_dp_comm)
......
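Since the wrapper now forwards everything else to `FMoE` via `**kwargs`, arguments such as `world_size`, `top_k`, or `slice_group` are passed straight through. A sketch of a single-worker instantiation, again assuming a GPU and the compiled `fmoe_cuda` extension:

```python
import torch
from fmoe import FMoETransformerMLP

# world_size and top_k are no longer named parameters of FMoETransformerMLP;
# they are forwarded to FMoE through **kwargs.
mlp = FMoETransformerMLP(
    num_expert=4,
    d_model=512,
    d_hidden=2048,
    world_size=1,
    top_k=2,
).cuda()

x = torch.randn(16, 512, device="cuda")
y = mlp(x)   # output keeps the input shape: [16, 512]
```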