Commit 82103edb authored by Rick Ho

reconstruct gate structure

parent 26824495
r"""
Different implementations of the Gate are located in separate files here.
"""
from .zero_gate import ZeroGate
from .naive_gate import NaiveGate
from .noisy_gate import NoisyGate
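With these re-exports, downstream code can pull any gate from the package root rather than from the individual modules, e.g.:

from fmoe.gates import NaiveGate, NoisyGate, ZeroGate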
r"""
Base gate with standard interface
"""
import torch.nn as nn
class BaseGate(nn.Module):
def __init__(self, num_expert, world_size):
super().__init__()
self.world_size = world_size
self.num_expert = num_expert
self.tot_expert = world_size * num_expert
self.loss = None
def forward(self, x):
raise NotImplementedError('Base gate cannot be directly used for fwd')
def set_loss(self, loss):
self.loss = loss
def get_loss(self, clear=True):
loss = self.loss
if clear:
self.loss = None
return loss
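`BaseGate` fixes the interface the rest of the stack relies on: a subclass computes routing in `forward` and stashes any auxiliary balance loss via `set_loss`, and the training loop drains it with `get_loss`. A minimal sketch of that round trip; `DummyGate` is a hypothetical subclass for illustration, not part of this commit:

import torch
from fmoe.gates.base_gate import BaseGate

class DummyGate(BaseGate):
    def forward(self, x):
        # route every token to expert 0 and report a fake balance penalty
        self.set_loss(torch.tensor(0.1))
        idx = torch.zeros(x.shape[0], dtype=torch.int64)
        score = torch.ones(x.shape[0], 1, 1)
        return idx, score

gate = DummyGate(num_expert=4, world_size=1)
idx, score = gate(torch.randn(8, 16))
aux = gate.get_loss()           # returns the stashed loss
assert gate.get_loss() is None  # and clears it by default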
r"""
Naive gate
"""
from .base_gate import BaseGate
import torch
import torch.nn as nn
import torch.nn.functional as F
class NaiveGate(BaseGate):
r"""
A naive gate implementation that defines the standard behavior of the gate
which determines which experts the tokens are going to.
Both the indices and the score, or confidence, are output to the parent
module.
The load-balance strategies are also designed to be implemented within the
`Gate` module.
"""
def __init__(self, d_model, num_expert, world_size, top_k=2):
super().__init__(num_expert, world_size)
self.gate = nn.Linear(d_model, self.tot_expert)
self.top_k = top_k
def forward(self, inp):
r"""
The naive implementation simply calculates the top-k of a linear layer's
output.
"""
gate = self.gate(inp)
gate_top_k_val, gate_top_k_idx = torch.topk(
gate, k=self.top_k, dim=-1, largest=True, sorted=False
) # [.. x top_k]
gate_top_k_val = gate_top_k_val.view(-1, self.top_k)
# (BxL) x 1 x top_k
gate_score = F.softmax(gate_top_k_val, dim=-1).unsqueeze(1)
gate_top_k_idx = gate_top_k_idx.view(-1) # (BxLxtop_k)
return gate_top_k_idx, gate_score
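The shapes here are the contract with the parent `FMoE` layer: the indices flatten to (tokens * top_k,) and the softmaxed scores come back as (tokens, 1, top_k). A quick shape check, assuming a plain 2-D (tokens, d_model) input:

import torch
from fmoe.gates.naive_gate import NaiveGate

gate = NaiveGate(d_model=16, num_expert=4, world_size=1, top_k=2)
idx, score = gate(torch.randn(8, 16))
print(idx.shape)    # torch.Size([16])      -> (tokens * top_k,)
print(score.shape)  # torch.Size([8, 1, 2]) -> (tokens, 1, top_k)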
r"""
-Different implementations of the Gate are located here.
-The `NaiveGate` is the reference to implement any other gate.
+Noisy gate for gshard and switch
"""
+from .base_gate import BaseGate
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions.normal import Normal
-class ZeroGate(nn.Module):
-r"""
-Guide all input samples to gate 0.
-"""
-def __init__(self, _1, num_expert, _3, top_k=2):
-super().__init__()
-self.num_expert = num_expert
-self.top_k = top_k
-def forward(self, inp):
-r"""
-All output to expert 1
-"""
-idx = torch.zeros(
-inp.shape[0] * self.top_k, dtype=torch.int64, device=inp.device
-)
-gate_score = (
-torch.ones(inp.shape[0] * self.top_k, device=inp.device) / self.top_k
-)
-gate_score_all = torch.zeros(inp.shape[0], self.num_expert, device=inp.device)
-gate_score_all[:, 0] = 1
-return idx, gate_score.reshape(-1, 1, self.top_k), gate_score_all
-class NaiveGate(nn.Module):
-r"""
-A naive gate implementation that defines the standard behavior of the gate
-which determines which experts the tokens are going to.
-Both the indecies and the score, or confidence, are output to the parent
-module.
-The load-balance strategies are also designed to be implemented within the
-`Gate` module.
-"""
-def __init__(self, d_model, num_expert, world_size, top_k=2):
-super().__init__()
-self.gate = nn.Linear(d_model, num_expert * world_size)
-self.top_k = top_k
-def forward(self, inp):
-r"""
-The naive implementation simply calculates the top-k of a linear layer's
-output.
-"""
-gate = self.gate(inp)
-gate_top_k_val, gate_top_k_idx = torch.topk(
-gate, k=self.top_k, dim=-1, largest=True, sorted=False
-) # [.. x top_k]
-gate_top_k_val = gate_top_k_val.view(-1, self.top_k)
-# (BxL) x 1 x top_k
-gate_score = F.softmax(gate_top_k_val, dim=-1).unsqueeze(1)
-gate_top_k_idx = gate_top_k_idx.view(-1) # (BxLxtop_k)
-return gate_top_k_idx, gate_score, gate
-class NoisyGate(nn.Module):
+class NoisyGate(BaseGate):
def __init__(self, d_model, num_expert, world_size, top_k=2):
-super().__init__()
-self.num_expert = num_expert * world_size
+super().__init__(num_expert, world_size)
self.w_gate = nn.Parameter(
-torch.zeros(d_model, num_expert * world_size), requires_grad=True
+torch.zeros(d_model, self.tot_expert), requires_grad=True
)
self.w_noise = nn.Parameter(
-torch.zeros(d_model, num_expert * world_size), requires_grad=True
+torch.zeros(d_model, self.tot_expert), requires_grad=True
)
self.top_k = top_k
self.softplus = nn.Softplus()
@@ -163,7 +105,7 @@ class NoisyGate(nn.Module):
# calculate topk + 1 that will be needed for the noisy gates
top_logits, top_indices = logits.topk(
-min(self.top_k + 1, self.num_expert), dim=1
+min(self.top_k + 1, self.tot_expert), dim=1
)
top_k_logits = top_logits[:, : self.top_k]
top_k_indices = top_indices[:, : self.top_k]
@@ -172,7 +114,7 @@ class NoisyGate(nn.Module):
zeros = torch.zeros_like(logits, requires_grad=True)
gates = zeros.scatter(1, top_k_indices, top_k_gates)
-if self.top_k < self.num_expert:
+if self.top_k < self.tot_expert:
load = (
self._prob_in_top_k(
clean_logits, noisy_logits, noise_stddev, top_logits
@@ -183,9 +125,9 @@ class NoisyGate(nn.Module):
importance = gates.sum(0)
loss = self.cv_squared(importance) + self.cv_squared(load)
+self.set_loss(loss)
return (
top_k_indices.contiguous().view(-1),
top_k_gates.contiguous().unsqueeze(1),
-loss,
)
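The collapsed lines above appear to follow the noisy top-k gating scheme of Shazeer et al. (2017): clean logits get input-dependent Gaussian noise during training, and the balance loss is the squared coefficient of variation of expert importance and load. A standalone sketch of those two pieces; the helper names and epsilon values are assumptions, not the hidden lines themselves:

import torch
import torch.nn.functional as F

def cv_squared(x, eps=1e-10):
    # squared coefficient of variation: small when experts are used evenly
    x = x.float()
    return x.var() / (x.mean() ** 2 + eps)

def noisy_logits(inp, w_gate, w_noise, noise_eps=1e-2, training=True):
    clean = inp @ w_gate
    if not training:
        return clean
    stddev = F.softplus(inp @ w_noise) + noise_eps
    return clean + torch.randn_like(clean) * stddev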
r"""
Zero gate that direct all input to gate 0
"""
from .base_gate import BaseGate
import torch
import torch.nn as nn
import torch.nn.functional as F
class ZeroGate(BaseGate):
r"""
Guide all input samples to gate 0.
"""
def __init__(self, _1, num_expert, world_size, top_k=2):
super().__init__(num_expert, world_size)
self.top_k = top_k
def forward(self, inp):
r"""
All output to expert 0
"""
idx = torch.zeros(
inp.shape[0] * self.top_k, dtype=torch.int64, device=inp.device
)
gate_score = (
torch.ones(inp.shape[0] * self.top_k, device=inp.device) / self.top_k
)
gate_score_all = torch.zeros(inp.shape[0], self.num_expert, device=inp.device)
gate_score_all[:, 0] = 1
return idx, gate_score.reshape(-1, 1, self.top_k), gate_score_all
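`ZeroGate` is mainly a debugging aid: every token is sent to expert 0, with the score split uniformly across the top_k slots. A small demonstration of the return values, assuming the three-tuple shown above:

import torch
from fmoe.gates.zero_gate import ZeroGate

gate = ZeroGate(None, num_expert=4, world_size=1, top_k=2)
idx, score, score_all = gate(torch.randn(3, 16))
print(idx)           # tensor([0, 0, 0, 0, 0, 0]) -- expert 0 for every slot
print(score.shape)   # torch.Size([3, 1, 2]), every entry 0.5
print(score_all[0])  # tensor([1., 0., 0., 0.])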
@@ -214,10 +214,10 @@ class FMoE(nn.Module):
if self.mp_size > 1:
inp = Slice.apply(inp, self.mp_rank, self.mp_size, self.mp_group)
-gate_top_k_idx, gate_score, gate_state_dict = self.gate(inp)
-if self.gate_hook:
-self.gate_hook(gate_top_k_idx, gate_score, gate_state_dict)
+gate_top_k_idx, gate_score = self.gate(inp)
# to: (BxLxtop_k) x d_model
# TODO: remove repeat_interleave
inp = inp.repeat_interleave(repeats=self.top_k, dim=0)
x = _fmoe_general_global_forward(
inp, gate_top_k_idx, self.expert_fn, self.num_expert, self.world_size
......
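Since each token can be dispatched to top_k experts, the `repeat_interleave` call duplicates every row top_k times so that the flattened `gate_top_k_idx` lines up one-to-one with the expanded input. In isolation:

import torch

inp = torch.tensor([[1., 1.], [2., 2.]])       # two tokens, d_model = 2
dup = inp.repeat_interleave(repeats=2, dim=0)  # rows: t0, t0, t1, t1
# dup is tensor([[1., 1.], [1., 1.], [2., 2.], [2., 2.]]);
# a gate index like tensor([0, 3, 1, 3]) now names one expert per row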
@@ -20,7 +20,7 @@ if __name__ == '__main__':
author_email='hja20@mails.tsinghua.edu.cn',
license='Apache-2',
url='https://github.com/laekov/fastmoe',
-packages=['fmoe', 'fmoe.megatron'],
+packages=['fmoe', 'fmoe.megatron', 'fmoe.gates'],
ext_modules=[
CUDAExtension(
name='fmoe_cuda',
......
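Without the new `fmoe.gates` entry in `packages`, a source install would silently omit the subpackage. A quick post-install sanity check, assuming a standard pip install of this tree:

# run after `pip install .`
from fmoe.gates import NaiveGate, NoisyGate, ZeroGate
print('fmoe.gates packaged correctly')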