Commit 8328c794 authored by Rick Ho

separate gates file

parent 9c92be55
fmoe/__init__.py
@@ -2,5 +2,5 @@ r"""
 The fmoe package contains MoE Layers only.
 """
-from .layers import FMoELinear, FMoENaiveGate, FMoETransformerMLP
+from .layers import FMoELinear, FMoETransformerMLP
 from .distributed import DistributedGroupedDataParallel
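As a consequence of this change, the gate class is no longer re-exported from the package root under the name FMoENaiveGate; it now lives in its own module. A minimal sketch of the import before and after, assuming the new file is importable as fmoe.gates (the path is inferred from the `from .gates import NaiveGate` line further down in this commit):

# before this commit
from fmoe import FMoENaiveGate

# after this commit
from fmoe.gates import NaiveGate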
fmoe/gates.py (new file)

r'''
Different implementations of the Gate are located here.
The `NaiveGate` is the reference to implement any other gate.
'''
import torch
import torch.nn as nn
import torch.nn.functional as F


class NaiveGate(nn.Module):
    r'''
    A naive gate implementation that defines the standard behavior of the gate
    which determines which experts the tokens are going to.
    Both the indices and the score, or confidence, are output to the parent
    module.
    The load-balance strategies are also designed to be implemented within the
    `Gate` module.
    '''
    def __init__(self, d_model, num_expert, world_size, top_k=2):
        super().__init__()
        self.gate = nn.Linear(d_model, num_expert * world_size)
        self.top_k = top_k

    def forward(self, inp):
        r'''
        The naive implementation simply calculates the top-k of a linear
        layer's output.
        '''
        gate = self.gate(inp)
        gate_top_k_val, gate_top_k_idx = torch.topk(
            gate, k=self.top_k, dim=-1, largest=True, sorted=False
        )  # [.. x top_k]
        gate_top_k_val = gate_top_k_val.view(-1, self.top_k)
        # (BxL) x 1 x top_k
        gate_score = F.softmax(gate_top_k_val, dim=-1).unsqueeze(1)
        gate_top_k_idx = gate_top_k_idx.view(-1)  # (BxLxtop_k)
        return gate_top_k_idx, gate_score
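Since `NaiveGate` is documented as the reference for any other gate, its interface is what a custom gate has to reproduce: construction with (d_model, num_expert, world_size, top_k) and a forward pass that returns the flattened expert indices together with the softmaxed confidence scores. A minimal sketch of calling it directly; the module path is assumed from the `from .gates import NaiveGate` line below, and the concrete sizes are made up for illustration:

import torch
from fmoe.gates import NaiveGate  # module path assumed from the layers.py import

gate = NaiveGate(d_model=16, num_expert=4, world_size=1, top_k=2)
inp = torch.randn(8, 16)          # 8 tokens, batch and sequence already flattened
idx, score = gate(inp)

print(idx.shape)    # torch.Size([16])      -> (BxL * top_k) flattened expert indices
print(score.shape)  # torch.Size([8, 1, 2]) -> (BxL) x 1 x top_k softmax scores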
fmoe/layers.py
@@ -3,11 +3,11 @@ Layers that FMoE provides to users
 '''
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
 from .functions import moe_prepare_forward
 from .functions import MOEScatter, MOEGather, MOELinear
 from .functions import AllGather
+from .gates import NaiveGate

 class FMoELinear(nn.Module):
@@ -41,38 +41,6 @@ class FMoELinear(nn.Module):
         return MOELinear.apply(inp, self.weight, fwd_expert_count)

-class FMoENaiveGate(nn.Module):
-    r'''
-    A naive gate implementation that defines the standard behavior of the gate
-    which determines which experts the tokens are going to.
-    Both the indecies and the score, or confidence, are output to the parent
-    module.
-    The load-balance strategies are also designed to be implemented within the
-    `Gate` module.
-    '''
-    def __init__(self, d_model, num_expert, world_size, top_k=2):
-        super().__init__()
-        self.gate = nn.Linear(d_model, num_expert * world_size)
-        self.top_k = top_k
-
-    def forward(self, inp):
-        r'''
-        The naive implementation simply calculates the top-k of a linear
-        layer's output.
-        '''
-        gate = self.gate(inp)
-        gate_top_k_val, gate_top_k_idx = torch.topk(
-            gate, k=self.top_k, dim=-1, largest=True, sorted=False
-        )  # [.. x top_k]
-        gate_top_k_val = gate_top_k_val.view(-1, self.top_k)
-        # (BxL) x 1 x top_k
-        gate_score = F.softmax(gate_top_k_val, dim=-1).unsqueeze(1)
-        gate_top_k_idx = gate_top_k_idx.view(-1)  # (BxLxtop_k)
-        return gate_top_k_idx, gate_score

 def _fmoe_full_forward(inp, gate, linears, activation, num_expert, world_size):
     r'''
     A private function that performs the following steps to complete the MoE
@@ -126,6 +94,7 @@ class FMoETransformerMLP(nn.Module):
         world_size=1,
         mp_group=None,
         activation=torch.nn.functional.gelu,
+        gate=NaiveGate,
         top_k=2,
         pre_lnorm=False
     ):
@@ -154,7 +123,7 @@ class FMoETransformerMLP(nn.Module):
         for p in self.h4toh.parameters():
             setattr(p, 'dp_comm', 'none')
-        self.gate = FMoENaiveGate(d_model, num_expert, world_size, top_k)
+        self.gate = gate(d_model, num_expert, world_size, top_k)
         for p in self.gate.parameters():
             setattr(p, 'dp_comm', 'world')
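With the new `gate` argument, FMoETransformerMLP instantiates whatever gate class it is handed as gate(d_model, num_expert, world_size, top_k), defaulting to NaiveGate. A hedged sketch of selecting the gate from the caller's side; the leading constructor arguments are not visible in this diff, so the names num_expert, d_model and d_hidden below are assumptions used only for illustration:

from fmoe.layers import FMoETransformerMLP  # module path assumed from the diff
from fmoe.gates import NaiveGate

# Any class that accepts (d_model, num_expert, world_size, top_k) and returns
# (flattened expert indices, gate scores) can be substituted for NaiveGate.
moe = FMoETransformerMLP(
    num_expert=4,       # assumed argument name, not shown in this diff
    d_model=1024,       # assumed argument name, not shown in this diff
    d_hidden=4096,      # assumed argument name, not shown in this diff
    world_size=1,
    gate=NaiveGate,     # the default; shown here only to make the hook explicit
    top_k=2,
)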