Commit 38b334cc authored by Rich Ho

test switch gate

parent ddfaaf49
@@ -5,8 +5,7 @@ import math
 import torch
 import torch.nn.functional as F
 from .naive_gate import NaiveGate
-from fmoe.functions import count_by_gate
-import fmoe_cuda as fmoe_native
+from .utils import limit_by_capacity
 
 
 class GShardGate(NaiveGate):
@@ -32,21 +31,7 @@ class GShardGate(NaiveGate):
         self.set_loss(loss)
 
         cap_rate = self.capacity[0 if self.training else 1]
-        capacity = torch.ones(self.num_expert, dtype=torch.int32,
-                device=x.device)
-        capacity *= math.ceil(cap_rate * x.shape[0])
-
-        pos, lec, gec = count_by_gate(topk_idx.reshape(-1), self.num_expert,
-                self.world_size)
-        new_gec, = fmoe_native.limit_by_capacity(gec, capacity,
-                self.num_expert, self.world_size)
-        if self.world_size > 1:
-            new_lec = fmoe_native.expert_exchange(new_gec,
-                    self.num_expert, self.world_size)
-        else:
-            new_lec = new_gec
-
-        fmoe_native.prune_gate_by_capacity(topk_idx,
-                new_lec.to(torch.int32), self.num_expert, self.world_size)
+        capacity = math.ceil(cap_rate * x.shape[0])
+        limit_by_capacity(topk_idx, self.num_expert, self.world_size, capacity)
 
         return topk_idx, topk_val
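The refactored tail of GShardGate.forward now relies on a single contract: after limit_by_capacity returns, any token routed past an expert's capacity has its entry in topk_idx overwritten with -1, so every expert keeps at most math.ceil(cap_rate * x.shape[0]) tokens. A toy illustration of that contract (not from the commit; which overflow tokens get dropped is up to the CUDA kernel — here the later ones are dropped):

import torch

topk_idx = torch.tensor([0, 0, 0, 1])  # three tokens want expert 0
capacity = 2                           # each expert may keep at most 2 tokens

counts = {}
pruned = topk_idx.clone()
for i, e in enumerate(topk_idx.tolist()):
    counts[e] = counts.get(e, 0) + 1
    if counts[e] > capacity:
        pruned[i] = -1  # over-capacity token is dropped

print(pruned)  # tensor([ 0,  0, -1,  1])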
r"""
Balanced gate with Switch Transformer's policy (Google, 2021)
"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from .naive_gate import NaiveGate
from .utils import limit_by_capacity
class SwitchGate(NaiveGate):
r"""
@@ -13,7 +17,6 @@ class SwitchGate(NaiveGate):
     def __init__(self, d_model, num_expert, world_size,
             switch_eps=.1, capacity=(1.2, 2.4)):
         super().__init__(d_model, num_expert, world_size, top_k=1)
-        self.gate = nn.Linear(d_model, num_expert * world_size)
         self.switch_eps = switch_eps
         self.capacity = capacity
 
@@ -21,37 +24,35 @@ class SwitchGate(NaiveGate):
         r"""
         The switch first conducts a softmax, then calculates the top-1
         """
-        gate = super().forward(inp)
+        score = self.gate(inp)
 
         if self.training:
             # random uniform number from [1-eps, 1+eps]
-            noise = torch.rand_like(gate)
+            noise = torch.rand_like(score)
             noise = noise * 2 * self.switch_eps + 1.0 - self.switch_eps
-            gate += noise
+            score += noise
 
         # fp32 softmax for numerical stability
-        gate_score = F.softmax(gate.float(), dim=-1)
+        score = F.softmax(score.float(), dim=-1)
 
-        gate_score_top1, gate_idx_top1 = torch.topk(
-            gate_score_clip, k=1, dim=-1, largest=True
+        top1_score, top1_idx = torch.topk(
+            score, k=1, dim=-1, largest=True
         )  # [.. x top_k]
-        gate_score = gate_score.to(dtype=inp.dtype)
-        gate_score_top1 = gate_score_top1.to(dtype=inp.dtype)
-        gate_score_top1 = gate_score_top1.unsqueeze(1)
-        gate_idx_top1 = gate_idx_top1.view(-1)  # (BxLxtop_k)
+        top1_score = top1_score.to(dtype=inp.dtype)
 
-        # TODO: not tested, the following code is super dangerous!!!!!!
-        gate_updated = gate_idx_top1
-        gate_updated = gate_updated[gate_updated > -1]
+        # capacity limit
+        cap_rate = self.capacity[0 if self.training else 1]
+        capacity = math.ceil(cap_rate * inp.shape[0])
+        limit_by_capacity(top1_idx, self.num_expert, self.world_size, capacity)
+
+        valid_idx = top1_idx[top1_idx > -1]
         fraction_expert = torch.scatter_add(
-            torch.zeros(self.tot_expert, device=gate_updated.device),
+            torch.zeros(self.tot_expert, device=valid_idx.device),
            0,
-            gate_updated,
-            torch.ones_like(gate_updated, dtype=torch.float),
-        ) / gate_updated.view(-1).size(0)
-        prob_expert = gate_score.sum(dim=0) / gate_updated.view(-1).size(0)
-        switch_aux_loss = (fraction_expert * prob_expert).sum() * self.tot_expert
-        self.set_loss(switch_aux_loss)
-        return gate_idx_top1, gate_score_top1
+            valid_idx,
+            torch.ones_like(valid_idx, dtype=torch.float),
+        ) / valid_idx.numel()
+        prob_expert = score.sum(dim=0) / valid_idx.numel()
+        loss = (fraction_expert * prob_expert).sum() * self.tot_expert
+        self.set_loss(loss)
+        return top1_idx, top1_score
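The loss computed above is the Switch Transformer load-balancing loss, loss = num_expert * sum_e f_e * P_e, where f_e is the fraction of kept tokens dispatched to expert e and P_e is the router probability mass for e, normalized by the number of kept tokens. A self-contained restatement (function name and shapes are illustrative, not from the repo):

import torch

def switch_balance_loss(score, expert_idx, num_expert):
    # score: [batch, num_expert] softmax probabilities from the router
    # expert_idx: [batch] chosen expert per token; -1 marks a dropped token
    valid = expert_idx[expert_idx > -1]
    # f_e: fraction of kept tokens dispatched to each expert
    fraction = torch.bincount(valid, minlength=num_expert).float() / valid.numel()
    # P_e: router probability mass per expert, normalized as in the diff above
    prob = score.sum(dim=0) / valid.numel()
    return (fraction * prob).sum() * num_expert

score = torch.softmax(torch.randn(8, 4), dim=-1)
print(switch_balance_loss(score, score.argmax(dim=-1), 4))

Note that prob divides by the number of kept tokens rather than the batch size, mirroring the diff; when no tokens are dropped the two coincide.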
r"""
Utilities that may be used in the gates
"""
import torch
from fmoe.functions import count_by_gate
import fmoe_cuda as fmoe_native
def limit_by_capacity(topk_idx, num_expert, world_size, capacity):
capacity = torch.ones(num_expert, dtype=torch.int32,
device=topk_idx.device) * capacity
pos, lec, gec = count_by_gate(topk_idx.reshape(-1), num_expert, world_size)
new_gec, = fmoe_native.limit_by_capacity(gec, capacity,
num_expert, world_size)
if world_size > 1:
new_lec = fmoe_native.expert_exchange(new_gec, num_expert, world_size)
else:
new_lec = new_gec
fmoe_native.prune_gate_by_capacity(topk_idx,
new_lec.to(torch.int32), num_expert, world_size)
return new_lec, new_gec
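For intuition and single-process sanity checks, here is a hypothetical pure-PyTorch reference of what this helper computes when world_size == 1 (not part of the commit; the kernel's choice of which overflow tokens to drop is an implementation detail — this version keeps the earliest ones):

import torch

def limit_by_capacity_ref(topk_idx, num_expert, capacity):
    # Keep at most `capacity` tokens per expert and overwrite the overflow
    # with -1; return the per-expert counts after pruning (the analogue of
    # new_lec == new_gec in the single-process branch above).
    counts = torch.zeros(num_expert, dtype=torch.int32)
    flat = topk_idx.view(-1)  # shares storage, so pruning happens in place
    for i, e in enumerate(flat.tolist()):
        if e < 0:
            continue
        if counts[e] >= capacity:
            flat[i] = -1  # expert e is full, drop this token
        else:
            counts[e] += 1
    return counts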
@@ -3,7 +3,7 @@ import os
 import math
 import torch
 import torch.distributed as dist
-from fmoe.gates import GShardGate
+from fmoe.gates import GShardGate, SwitchGate
 
 
 def _ensure_initialized():
@@ -37,6 +37,26 @@ def test_gshard_gate(d_model, batch_size, n_expert, cap):
         assert(i <= real_cap)
 
 
+@pytest.mark.parametrize("d_model", [8, 1024])
+@pytest.mark.parametrize("batch_size", [16, 4096])
+@pytest.mark.parametrize("n_expert", [1, 4, 16])
+@pytest.mark.parametrize("cap", [.1, .5, 1.1])
+def test_switch_gate(d_model, batch_size, n_expert, cap):
+    _ensure_initialized()
+    gate = SwitchGate(d_model, n_expert, dist.get_world_size(),
+            capacity=(cap, cap)).cuda()
+    x = torch.rand(batch_size, d_model).cuda()
+    topk_idx, topk_val = gate(x)
+    counts = [0 for _ in range(n_expert)]
+    for v in topk_idx.cpu().view(-1).numpy():
+        if v != -1:
+            counts[v] += 1
+    real_cap = math.ceil(cap * batch_size)
+    for i in counts:
+        assert(i <= real_cap)
+
+
 if __name__ == '__main__':
     _ensure_initialized()
-    test_gshard_gate(4096, 1024, 4, .2)
+    # test_gshard_gate(4096, 1024, 4, .2)
+    test_switch_gate(4096, 1024, 4, .2)
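The __main__ block assumes a torch.distributed process group is available via _ensure_initialized, whose body is outside this diff. A hypothetical minimal single-GPU setup, using the standard torch.distributed environment variables:

import os
import torch.distributed as dist

os.environ.setdefault('MASTER_ADDR', 'localhost')
os.environ.setdefault('MASTER_PORT', '29500')
if not dist.is_initialized():
    dist.init_process_group('nccl', rank=0, world_size=1)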