"git@developer.sourcefind.cn:OpenDAS/torch-harmonics.git" did not exist on "942aa4ead5c02ddcfb154dbd82742da3ea00c6d0"
Unverified Commit 2cb96b40 authored by Rick Ho, committed via GitHub
Browse files

Merge pull request #141 from laekov/gate-fix

Diverge gshard gate
parents cdc140f1 a762d33c
......@@ -7,5 +7,6 @@ from .noisy_gate import NoisyGate
from .gshard_gate import GShardGate
from .switch_gate import SwitchGate
from .dc_gate import DCGate
from .swipe_gate import SwipeGate
r"""
Distributed Capacity gate, extended from GShard gate.
Instead of setting capacity based on local batch size and expert count,
the global load of each experts are calculated, and then the experts make
decisions of capacities on each worker.
"""
import math
import torch
import torch.nn.functional as F
from .naive_gate import NaiveGate
from .utils import limit_by_capacity
class DCGate(NaiveGate):
    r"""
    Distributed Capacity gate, extended from the GShard gate.

    Instead of setting capacity from the local batch size and expert count
    alone, the global load of every expert is computed, and each worker
    then decides its per-expert capacities from that global view.
    """

    def __init__(self, d_model, num_expert, world_size,
                 topk=2, capacity=(1.2, 2.4), random_routing=True):
        # The routing scheme below (top-1 load statistics plus optional
        # random dropping of the second expert) is only defined for top-2
        # gating, so any other `topk` is rejected up front.
        # NOTE: error message fixed — this is the DC gate, not gshard.
        assert topk == 2, 'topk should be 2 in dc_gate'
        super().__init__(d_model, num_expert, world_size, top_k=2)
        # (train_cap_rate, eval_cap_rate): multipliers applied to the local
        # batch size to obtain the per-step capacity bound.
        self.capacity = capacity
        self.random_routing = random_routing

    def forward(self, x):
        """
        Route the tokens in ``x`` and enforce expert capacity.

        Returns the (possibly capacity-pruned) top-2 expert indices and
        their gate values. As a side effect, records the GShard-style
        auxiliary load-balancing loss via ``self.set_loss``.
        """
        naive_outs = super().forward(x, return_all_scores=True)
        topk_idx, topk_val, gate_score = naive_outs

        S = gate_score.shape[0]
        top1_idx = topk_idx.view((-1, self.top_k))[:, 0]
        # c_e: fraction of tokens whose top-1 choice is expert e.
        c_e = torch.scatter_add(
            torch.zeros(self.tot_expert, device=top1_idx.device),
            0,
            top1_idx,
            torch.ones_like(top1_idx, dtype=torch.float),
        ) / S
        # m_e: mean softmax probability mass assigned to each expert.
        m_e = torch.mean(F.softmax(gate_score, dim=1), dim=0)
        # GShard auxiliary load-balancing loss; scaled by num_expert**2 as
        # in the GShard gate this class diverged from.
        loss = torch.mean(c_e * m_e) * (self.num_expert ** 2)
        self.set_loss(loss)

        # Capacity rate differs between training and evaluation.
        cap_rate = self.capacity[0 if self.training else 1]
        capacity = math.ceil(cap_rate * x.shape[0])
        _new_lec, _new_gec, topk_idx = limit_by_capacity(
                topk_idx, self.num_expert, self.world_size, capacity)

        if self.random_routing:
            # Randomly drop the second expert for tokens whose second gate
            # value is small: expert 2 is kept only if 2 * score >= U(0, 1).
            # Dropped slots are marked with index -1.
            rand_routing_prob = torch.rand(gate_score.size(0), device=x.device)
            mask = (2 * topk_val[:, 1] < rand_routing_prob)
            topk_idx[:, 1].masked_fill_(mask, -1)

        return topk_idx, topk_val
......@@ -6,6 +6,7 @@ import torch
import torch.nn.functional as F
from .naive_gate import NaiveGate
from .utils import limit_by_capacity
import fmoe_cuda as fmoe_native
class GShardGate(NaiveGate):
......@@ -34,8 +35,11 @@ class GShardGate(NaiveGate):
cap_rate = self.capacity[0 if self.training else 1]
capacity = math.ceil(cap_rate * x.shape[0])
_new_lec, _new_gec, topk_idx = limit_by_capacity(
topk_idx, self.num_expert, self.world_size, capacity)
capacity = capacity * self.top_k // (self.world_size * self.num_expert)
capacity = torch.ones(self.num_expert * self.world_size,
dtype=torch.int32, device=topk_idx.device) * capacity
topk_idx = fmoe_native.prune_gate_by_capacity(topk_idx, capacity,
self.num_expert, self.world_size)
if self.random_routing:
rand_routing_prob = torch.rand(gate_score.size(0), device=x.device)
......
......@@ -35,6 +35,7 @@ def _run_distributed(func, world_size, args: Dict, script=__file__, env=dict()):
env["MASTER_ADDR"] = "localhost"
env["MASTER_PORT"] = str(random.randint(50000, 60000))
env["OMPI_COMM_WORLD_SIZE"] = str(world_size)
env["LD_LIBRARY_PATH"] = os.environ.get("LD_LIBRARY_PATH")
for i in range(world_size):
env["OMPI_COMM_WORLD_RANK"] = str(i)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment