Commit 38b334cc authored by Rich Ho

test switch gate

parent ddfaaf49
fmoe/gates/gshard_gate.py

@@ -5,8 +5,7 @@ import math
 import torch
 import torch.nn.functional as F
 from .naive_gate import NaiveGate
-from fmoe.functions import count_by_gate
-import fmoe_cuda as fmoe_native
+from .utils import limit_by_capacity


 class GShardGate(NaiveGate):
@@ -32,21 +31,7 @@ class GShardGate(NaiveGate):
         self.set_loss(loss)

         cap_rate = self.capacity[0 if self.training else 1]
-        capacity = torch.ones(self.num_expert, dtype=torch.int32,
-                device=x.device)
-        capacity *= math.ceil(cap_rate * x.shape[0])
-        pos, lec, gec = count_by_gate(topk_idx.reshape(-1), self.num_expert,
-                self.world_size)
-        new_gec, = fmoe_native.limit_by_capacity(gec, capacity,
-                self.num_expert, self.world_size)
-        if self.world_size > 1:
-            new_lec = fmoe_native.expert_exchange(new_gec,
-                    self.num_expert, self.world_size)
-        else:
-            new_lec = new_gec
-        fmoe_native.prune_gate_by_capacity(topk_idx,
-                new_lec.to(torch.int32), self.num_expert, self.world_size)
+        capacity = math.ceil(cap_rate * x.shape[0])
+        limit_by_capacity(topk_idx, self.num_expert, self.world_size, capacity)

         return topk_idx, topk_val
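For a sense of scale, the per-expert budget is just the capacity rate times the number of tokens in the batch, rounded up. A standalone arithmetic sketch (values illustrative; (1.2, 2.4) is the default SwitchGate declares below, GShardGate's own default is not shown in this diff):

import math

cap_rate = 1.2                   # capacity[0], the training-time rate
tokens = 4096                    # x.shape[0], tokens in the current batch
capacity = math.ceil(cap_rate * tokens)
print(capacity)                  # 4916 slots available on each expert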
r""" r"""
Balanced gate with Switch Transformer's policy (Google, 2021) Balanced gate with Switch Transformer's policy (Google, 2021)
""" """
import math
import torch import torch
import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from .naive_gate import NaiveGate from .naive_gate import NaiveGate
from .utils import limit_by_capacity
class SwitchGate(NaiveGate): class SwitchGate(NaiveGate):
r""" r"""
...@@ -13,7 +17,6 @@ class SwitchGate(NaiveGate): ...@@ -13,7 +17,6 @@ class SwitchGate(NaiveGate):
def __init__(self, d_model, num_expert, world_size, def __init__(self, d_model, num_expert, world_size,
switch_eps=.1, capacity=(1.2, 2.4)): switch_eps=.1, capacity=(1.2, 2.4)):
super().__init__(d_model, num_expert, world_size, top_k=1) super().__init__(d_model, num_expert, world_size, top_k=1)
self.gate = nn.Linear(d_model, num_expert * world_size)
self.switch_eps = switch_eps self.switch_eps = switch_eps
self.capacity = capacity self.capacity = capacity
...@@ -21,37 +24,35 @@ class SwitchGate(NaiveGate): ...@@ -21,37 +24,35 @@ class SwitchGate(NaiveGate):
r""" r"""
The switch firstly conduct softmax and then calculates the top-1 The switch firstly conduct softmax and then calculates the top-1
""" """
gate = super().forward(inp) score = self.gate(inp)
if self.training: if self.training:
# random uniform number from [1-eps, 1+eps] # random uniform number from [1-eps, 1+eps]
noise = torch.rand_like(gate) noise = torch.rand_like(score)
noise = noise * 2 * self.switch_eps + 1.0 - self.switch_eps noise = noise * 2 * self.switch_eps + 1.0 - self.switch_eps
gate += noise score += noise
# fp32 softmax for numerical stability # fp32 softmax for numerical stability
gate_score = F.softmax(gate.float(), dim=-1) score = F.softmax(score.float(), dim=-1)
gate_score_top1, gate_idx_top1 = torch.topk( top1_score, top1_idx = torch.topk(
gate_score_clip, k=1, dim=-1, largest=True score, k=1, dim=-1, largest=True
) # [.. x top_k] ) # [.. x top_k]
gate_score = gate_score.to(dtype=inp.dtype) top1_score = top1_score.to(dtype=inp.dtype)
gate_score_top1 = gate_score_top1.to(dtype=inp.dtype) top1_score = top1_score.to(dtype=inp.dtype)
gate_score_top1 = gate_score_top1.unsqueeze(1)
gate_idx_top1 = gate_idx_top1.view(-1) # (BxLxtop_k)
# TODO: capacity limit cap_rate = self.capacity[0 if self.training else 1]
capacity = math.ceil(cap_rate * inp.shape[0])
limit_by_capacity(top1_idx, self.num_expert, self.world_size, capacity)
# TODO: not testd, the following code is super dangerous!!!!!! valid_idx = top1_idx[top1_idx > -1]
gate_updated = gate_idx_top1
gate_updated = gate_updated[gate_updated > -1]
fraction_expert = torch.scatter_add( fraction_expert = torch.scatter_add(
torch.zeros(self.tot_expert, device=gate_updated.device), torch.zeros(self.tot_expert, device=valid_idx.device),
0, 0,
gate_updated, valid_idx,
torch.ones_like(gate_updated, dtype=torch.float), torch.ones_like(valid_idx, dtype=torch.float),
) / gate_updated.view(-1).size(0) ) / valid_idx.numel()
prob_expert = gate_score.sum(dim=0) / gate_updated.view(-1).size(0) prob_expert = score.sum(dim=0) / valid_idx.numel()
switch_aux_loss = (fraction_expert * prob_expert).sum() * self.tot_expert loss = (fraction_expert * prob_expert).sum() * self.tot_expert
self.set_loss(switch_aux_loss) self.set_loss(loss)
return gate_idx_top1, gate_score_top1 return top1_idx, top1_score
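The auxiliary term at the end of forward is the Switch Transformer load-balancing loss: with f_i the fraction of tokens dispatched to expert i and P_i the mean router probability for expert i, it computes tot_expert * sum_i f_i * P_i. A minimal standalone sketch with made-up scores (illustrative names and values; four tokens, two experts, nothing dropped, so no -1 entries):

import torch

score = torch.tensor([[0.9, 0.1],
                      [0.8, 0.2],
                      [0.3, 0.7],
                      [0.6, 0.4]])   # post-softmax router scores
top1_idx = score.argmax(dim=-1)      # tensor([0, 0, 1, 0])
tot_expert = 2

valid_idx = top1_idx[top1_idx > -1]
# f_i: fraction of tokens routed to expert i -> tensor([0.7500, 0.2500])
fraction_expert = torch.scatter_add(
    torch.zeros(tot_expert), 0, valid_idx,
    torch.ones_like(valid_idx, dtype=torch.float),
) / valid_idx.numel()
# P_i: mean router probability of expert i -> tensor([0.6500, 0.3500])
prob_expert = score.sum(dim=0) / valid_idx.numel()
loss = (fraction_expert * prob_expert).sum() * tot_expert  # 1.15

The loss equals 1.0 under perfectly uniform routing (f_i = P_i = 1/N), so values above 1 indicate imbalance.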
r"""
Utilities that may be used in the gates
"""
import torch
from fmoe.functions import count_by_gate
import fmoe_cuda as fmoe_native
def limit_by_capacity(topk_idx, num_expert, world_size, capacity):
capacity = torch.ones(num_expert, dtype=torch.int32,
device=topk_idx.device) * capacity
pos, lec, gec = count_by_gate(topk_idx.reshape(-1), num_expert, world_size)
new_gec, = fmoe_native.limit_by_capacity(gec, capacity,
num_expert, world_size)
if world_size > 1:
new_lec = fmoe_native.expert_exchange(new_gec, num_expert, world_size)
else:
new_lec = new_gec
fmoe_native.prune_gate_by_capacity(topk_idx,
new_lec.to(torch.int32), num_expert, world_size)
return new_lec, new_gec
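Because the actual work happens inside fused fmoe_cuda kernels, the semantics can be hard to read off. Below is a plain-PyTorch, single-process sketch of what the pruning is assumed to do; it infers from the tests that over-capacity tokens get their expert index overwritten with -1, and is a readability aid, not the real implementation:

import torch

def limit_by_capacity_ref(topk_idx, num_expert, capacity):
    # first-come-first-served: count arrivals per expert, drop the overflow
    counts = torch.zeros(num_expert, dtype=torch.int32)
    flat = topk_idx.view(-1)             # shares storage with topk_idx
    for i in range(flat.numel()):
        e = int(flat[i])
        if e < 0:
            continue                     # already dropped
        if counts[e] >= capacity:
            flat[i] = -1                 # expert e is full: drop in place
        else:
            counts[e] += 1
    return counts                        # tokens actually kept per expert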
tests/test_gates.py

@@ -3,7 +3,7 @@ import os
 import math
 import torch
 import torch.distributed as dist
-from fmoe.gates import GShardGate
+from fmoe.gates import GShardGate, SwitchGate


 def _ensure_initialized():
@@ -37,6 +37,26 @@ def test_gshard_gate(d_model, batch_size, n_expert, cap):
         assert(i <= real_cap)


+@pytest.mark.parametrize("d_model", [8, 1024])
+@pytest.mark.parametrize("batch_size", [16, 4096])
+@pytest.mark.parametrize("n_expert", [1, 4, 16])
+@pytest.mark.parametrize("cap", [.1, .5, 1.1])
+def test_switch_gate(d_model, batch_size, n_expert, cap):
+    _ensure_initialized()
+    gate = SwitchGate(d_model, n_expert, dist.get_world_size(),
+            capacity=(cap, cap)).cuda()
+    x = torch.rand(batch_size, d_model).cuda()
+    topk_idx, topk_val = gate(x)
+    counts = [0 for _ in range(n_expert)]
+    for v in topk_idx.cpu().view(-1).numpy():
+        if v != -1:
+            counts[v] += 1
+    real_cap = math.ceil(cap * batch_size)
+    for i in counts:
+        assert(i <= real_cap)
+
+
 if __name__ == '__main__':
     _ensure_initialized()
-    test_gshard_gate(4096, 1024, 4, .2)
+    # test_gshard_gate(4096, 1024, 4, .2)
+    test_switch_gate(4096, 1024, 4, .2)
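Running the __main__ path directly checks one configuration: 1024 tokens over 4 experts with cap .2 gives a hard limit of ceil(.2 * 1024) = 205 tokens per expert, and every over-budget token must come back as -1. The test also needs CUDA and an initialized default process group, which _ensure_initialized() presumably provides; a hypothetical single-GPU bootstrap (addresses and port below are placeholders, not from this commit) would look like:

import os
import torch.distributed as dist

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
if not dist.is_initialized():
    dist.init_process_group("nccl", rank=0, world_size=1)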