Commit 6cb550fd authored by Rick Ho's avatar Rick Ho
Browse files

diverge gshard gate

parent cdc140f1
......@@ -7,5 +7,6 @@ from .noisy_gate import NoisyGate
from .gshard_gate import GShardGate
from .switch_gate import SwitchGate
from .dc_gate import DCGate
from .swipe_gate import SwipeGate
r"""
Distributed Capacity gate, extended from GShard gate.
Instead of setting capacity based on local batch size and expert count,
the global load of each experts are calculated, and then the experts make
decisions of capacities on each worker.
"""
import math
import torch
import torch.nn.functional as F
from .naive_gate import NaiveGate
from .utils import limit_by_capacity
class DCGate(NaiveGate):
def __init__(self, d_model, num_expert, world_size,
topk=2, capacity=(1.2, 2.4), random_routing=True):
assert topk == 2, 'topk should be 2 in gshard'
super().__init__(d_model, num_expert, world_size, top_k=2)
self.capacity = capacity
self.random_routing = random_routing
def forward(self, x):
naive_outs = super().forward(x, return_all_scores=True)
topk_idx, topk_val, gate_score = naive_outs
S = gate_score.shape[0]
top1_idx = topk_idx.view((-1, self.top_k))[:, 0]
c_e = torch.scatter_add(
torch.zeros(self.tot_expert, device=top1_idx.device),
0,
top1_idx,
torch.ones_like(top1_idx, dtype=torch.float),
) / S
m_e = torch.mean(F.softmax(gate_score, dim=1), dim=0)
loss = torch.mean(c_e * m_e) * (self.num_expert ** 2)
self.set_loss(loss)
cap_rate = self.capacity[0 if self.training else 1]
capacity = math.ceil(cap_rate * x.shape[0])
_new_lec, _new_gec, topk_idx = limit_by_capacity(
topk_idx, self.num_expert, self.world_size, capacity)
if self.random_routing:
rand_routing_prob = torch.rand(gate_score.size(0), device=x.device)
mask = (2 * topk_val[:, 1] < rand_routing_prob)
topk_idx[:, 1].masked_fill_(mask, -1)
return topk_idx, topk_val
......@@ -6,6 +6,7 @@ import torch
import torch.nn.functional as F
from .naive_gate import NaiveGate
from .utils import limit_by_capacity
import fmoe_cuda as fmoe_native
class GShardGate(NaiveGate):
......@@ -33,9 +34,11 @@ class GShardGate(NaiveGate):
self.set_loss(loss)
cap_rate = self.capacity[0 if self.training else 1]
capacity = math.ceil(cap_rate * x.shape[0])
_new_lec, _new_gec, topk_idx = limit_by_capacity(
topk_idx, self.num_expert, self.world_size, capacity)
capacity = math.ceil(cap_rate * x.shape[0]) // self.world_size
capacity = torch.ones(self.num_expert * self.world_size,
dtype=torch.int32, device=topk_idx.device) * capacity
topk_idx = fmoe_native.prune_gate_by_capacity(topk_idx, capacity,
self.num_expert, self.world_size)
if self.random_routing:
rand_routing_prob = torch.rand(gate_score.size(0), device=x.device)
......
......@@ -35,6 +35,7 @@ def _run_distributed(func, world_size, args: Dict, script=__file__, env=dict()):
env["MASTER_ADDR"] = "localhost"
env["MASTER_PORT"] = str(random.randint(50000, 60000))
env["OMPI_COMM_WORLD_SIZE"] = str(world_size)
env["LD_LIBRARY_PATH"] = os.environ.get("LD_LIBRARY_PATH")
for i in range(world_size):
env["OMPI_COMM_WORLD_RANK"] = str(i)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment