Commit 4c90e6e8 authored by Rick Ho's avatar Rick Ho
Browse files

add gshard and switch gate with loss

parent 82103edb
...@@ -5,3 +5,5 @@ from .zero_gate import ZeroGate ...@@ -5,3 +5,5 @@ from .zero_gate import ZeroGate
from .naive_gate import NaiveGate from .naive_gate import NaiveGate
from .noisy_gate import NoisyGate from .noisy_gate import NoisyGate
from .gshard_gate import GShardGate
from .switch_gate import SwitchGate
r"""
Balanced gate with GShard's policy (Google, 2020)
"""
import torch
import torch.nn.functional as F
from .naive_gate import NaiveGate
class GShardGate(NaiveGate):
def __init__(self, d_model, num_expert, world_size, capacity=(1.2, 2.4)):
super().__init__(d_model, num_expert, world_size, top_k=2)
self.capacity = capacity
def forward(self, x):
topk_idx, topk_val, gate_score = super().forward(x)
S = gate_score.shape[0]
top_k = topk_idx.shape[0] // gate_score.shape[0]
top1_idx = topk_idx.view((-1, top_k))[:, 0]
c_e = torch.scatter_add(
torch.zeros(self.num_expert, device=gate_top_1_idx.device),
0,
top1_idx,
torch.ones_like(top1_idx, dtype=torch.float),
) / S
m_e = torch.mean(F.softmax(gate_score, dim=1), dim=0)
loss = torch.mean(c_e * m_e) * (self.num_expert ** 2)
self.set_loss(loss)
# TODO: capacity limit
return topk_idx, topk_val
...@@ -19,7 +19,7 @@ class NaiveGate(BaseGate): ...@@ -19,7 +19,7 @@ class NaiveGate(BaseGate):
""" """
def __init__(self, d_model, num_expert, world_size, top_k=2): def __init__(self, d_model, num_expert, world_size, top_k=2):
super().__init__() super().__init__(num_expert, world_size)
self.gate = nn.Linear(d_model, self.tot_expert) self.gate = nn.Linear(d_model, self.tot_expert)
self.top_k = top_k self.top_k = top_k
...@@ -38,5 +38,6 @@ class NaiveGate(BaseGate): ...@@ -38,5 +38,6 @@ class NaiveGate(BaseGate):
gate_score = F.softmax(gate_top_k_val, dim=-1).unsqueeze(1) gate_score = F.softmax(gate_top_k_val, dim=-1).unsqueeze(1)
gate_top_k_idx = gate_top_k_idx.view(-1) # (BxLxtop_k) gate_top_k_idx = gate_top_k_idx.view(-1) # (BxLxtop_k)
return gate_top_k_idx, gate_score # TODO: capacity
return gate_top_k_idx, gate_score
r"""
Balanced gate with Switch Transformer's policy (Google, 2021)
"""
import torch
import torch.nn.functional as F
from .naive_gate import NaiveGate
class SwitchGate(NaiveGate):
r"""
A switch gate implementation
"""
def __init__(self, d_model, num_expert, world_size,
switch_eps=.1, capacity=(1.2, 2.4)):
super().__init__(d_model, num_expert, world_size, top_k=1)
self.gate = nn.Linear(d_model, num_expert * world_size)
self.switch_eps = switch_eps
self.capacity = capacity
def forward(self, inp):
r"""
The switch firstly conduct softmax and then calculates the top-1
"""
gate = super().forward(inp)
if self.training:
# random uniform number from [1-eps, 1+eps]
noise = torch.rand_like(gate)
noise = noise * 2 * self.switch_eps + 1.0 - self.switch_eps
gate += noise
# fp32 softmax for numerical stability
gate_score = F.softmax(gate.float(), dim=-1)
gate_score_top1, gate_idx_top1 = torch.topk(
gate_score_clip, k=1, dim=-1, largest=True
) # [.. x top_k]
gate_score = gate_score.to(dtype=inp.dtype)
gate_score_top1 = gate_score_top1.to(dtype=inp.dtype)
gate_score_top1 = gate_score_top1.unsqueeze(1)
gate_idx_top1 = gate_idx_top1.view(-1) # (BxLxtop_k)
# TODO: capacity limit
# TODO: not testd, the following code is super dangerous!!!!!!
gate_updated = gate_idx_top1
gate_updated = gate_updated[gate_updated > -1]
fraction_expert = torch.scatter_add(
torch.zeros(self.tot_expert, device=gate_updated.device),
0,
gate_updated,
torch.ones_like(gate_updated, dtype=torch.float),
) / gate_updated.view(-1).size(0)
prob_expert = gate_score.sum(dim=0) / gate_updated.view(-1).size(0)
switch_aux_loss = (fraction_expert * prob_expert).sum() * self.tot_expert
self.set_loss(switch_aux_loss)
return gate_idx_top1, gate_score_top1
...@@ -38,7 +38,7 @@ def _perform_forward( ...@@ -38,7 +38,7 @@ def _perform_forward(
inp.requires_grad = True inp.requires_grad = True
inp_raw.requires_grad = True inp_raw.requires_grad = True
gate_idx, gate_score, _ = moe.gate(inp_raw) gate_idx, gate_score = moe.gate(inp_raw)
inp_repeated = inp_raw.repeat_interleave(repeats=top_k, dim=0) inp_repeated = inp_raw.repeat_interleave(repeats=top_k, dim=0)
moe_out = moe(inp) moe_out = moe(inp)
raw_out = moe_raw(inp_repeated, gate_idx, gate_score) raw_out = moe_raw(inp_repeated, gate_idx, gate_score)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment