Commit 6cb550fd authored by Rick Ho

diverge gshard gate

parent cdc140f1
@@ -7,5 +7,6 @@ from .noisy_gate import NoisyGate
 from .gshard_gate import GShardGate
 from .switch_gate import SwitchGate
+from .dc_gate import DCGate
 from .swipe_gate import SwipeGate
r"""
Distributed Capacity gate, extended from GShard gate.
Instead of setting capacity based on local batch size and expert count,
the global load of each experts are calculated, and then the experts make
decisions of capacities on each worker.
"""
import math

import torch
import torch.nn.functional as F

from .naive_gate import NaiveGate
from .utils import limit_by_capacity


class DCGate(NaiveGate):
    def __init__(self, d_model, num_expert, world_size,
            topk=2, capacity=(1.2, 2.4), random_routing=True):
        assert topk == 2, 'topk should be 2 in gshard'
        super().__init__(d_model, num_expert, world_size, top_k=2)
        self.capacity = capacity
        self.random_routing = random_routing

    def forward(self, x):
        naive_outs = super().forward(x, return_all_scores=True)
        topk_idx, topk_val, gate_score = naive_outs

        S = gate_score.shape[0]
        top1_idx = topk_idx.view((-1, self.top_k))[:, 0]
        # c_e: fraction of tokens whose top-1 choice is each expert
        c_e = torch.scatter_add(
                torch.zeros(self.tot_expert, device=top1_idx.device),
                0,
                top1_idx,
                torch.ones_like(top1_idx, dtype=torch.float),
                ) / S
        # m_e: mean softmax gate probability per expert
        m_e = torch.mean(F.softmax(gate_score, dim=1), dim=0)
        # GShard-style load-balancing auxiliary loss
        loss = torch.mean(c_e * m_e) * (self.num_expert ** 2)
        self.set_loss(loss)

        # cap the number of tokens per expert; limit_by_capacity prunes
        # assignments that overflow the budget
        cap_rate = self.capacity[0 if self.training else 1]
        capacity = math.ceil(cap_rate * x.shape[0])
        _new_lec, _new_gec, topk_idx = limit_by_capacity(
                topk_idx, self.num_expert, self.world_size, capacity)

        if self.random_routing:
            # keep the second expert with probability min(1, 2 * its gate score)
            rand_routing_prob = torch.rand(gate_score.size(0), device=x.device)
            mask = (2 * topk_val[:, 1] < rand_routing_prob)
            topk_idx[:, 1].masked_fill_(mask, -1)

        return topk_idx, topk_val
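The auxiliary loss above follows the GShard recipe: c_e is the fraction of tokens whose top-1 pick is expert e, m_e is the mean softmax probability assigned to expert e, and their product is scaled by the square of the expert count. Below is a self-contained sketch of that arithmetic with made-up shapes (single worker, so tot_expert equals num_expert; no fmoe runtime required):

# Standalone sketch of the load-balancing loss used in DCGate/GShardGate.
# Shapes and values are illustrative only.
import torch
import torch.nn.functional as F

S, tot_expert = 8, 4                      # 8 tokens, 4 experts (single worker)
gate_score = torch.randn(S, tot_expert)   # stand-in for the gate's raw scores

top1_idx = gate_score.argmax(dim=1)       # top-1 expert chosen by each token

# c_e: fraction of tokens whose top-1 choice is each expert
c_e = torch.scatter_add(
    torch.zeros(tot_expert),
    0,
    top1_idx,
    torch.ones_like(top1_idx, dtype=torch.float),
) / S

# m_e: mean softmax gate probability per expert
m_e = F.softmax(gate_score, dim=1).mean(dim=0)

# auxiliary loss as computed in DCGate.forward (num_expert == tot_expert here)
loss = torch.mean(c_e * m_e) * (tot_expert ** 2)
print(c_e, m_e, loss)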
@@ -6,6 +6,7 @@ import torch
 import torch.nn.functional as F
 from .naive_gate import NaiveGate
 from .utils import limit_by_capacity
+import fmoe_cuda as fmoe_native


 class GShardGate(NaiveGate):
@@ -33,9 +34,11 @@ class GShardGate(NaiveGate):
         self.set_loss(loss)

         cap_rate = self.capacity[0 if self.training else 1]
-        capacity = math.ceil(cap_rate * x.shape[0])
-        _new_lec, _new_gec, topk_idx = limit_by_capacity(
-                topk_idx, self.num_expert, self.world_size, capacity)
+        capacity = math.ceil(cap_rate * x.shape[0]) // self.world_size
+        capacity = torch.ones(self.num_expert * self.world_size,
+                dtype=torch.int32, device=topk_idx.device) * capacity
+        topk_idx = fmoe_native.prune_gate_by_capacity(topk_idx, capacity,
+                self.num_expert, self.world_size)

         if self.random_routing:
             rand_routing_prob = torch.rand(gate_score.size(0), device=x.device)
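The GShardGate change above splits the capacity budget across workers and hands a per-expert int32 tensor directly to the fmoe_cuda kernel instead of going through the Python-side limit_by_capacity helper. A small sketch of just that capacity arithmetic, with illustrative numbers (the real code uses x.shape[0] and places the tensor on topk_idx.device):

# Sketch of the new per-expert capacity computation in GShardGate.forward.
import math
import torch

cap_rate = 1.2        # training-time capacity factor, self.capacity[0]
batch_size = 1024     # tokens on this worker, x.shape[0]
num_expert = 4        # experts per worker
world_size = 8        # number of workers

# per-worker share of the capacity budget
capacity = math.ceil(cap_rate * batch_size) // world_size

# one int32 entry per (worker, expert) slot, in the form passed to
# fmoe_native.prune_gate_by_capacity(topk_idx, capacity, num_expert, world_size)
capacity = torch.ones(num_expert * world_size, dtype=torch.int32) * capacity
print(capacity)       # 32 entries, each 153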
@@ -35,6 +35,7 @@ def _run_distributed(func, world_size, args: Dict, script=__file__, env=dict()):
     env["MASTER_ADDR"] = "localhost"
     env["MASTER_PORT"] = str(random.randint(50000, 60000))
     env["OMPI_COMM_WORLD_SIZE"] = str(world_size)
+    env["LD_LIBRARY_PATH"] = os.environ.get("LD_LIBRARY_PATH", "")
     for i in range(world_size):
         env["OMPI_COMM_WORLD_RANK"] = str(i)
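Forwarding LD_LIBRARY_PATH lets the spawned ranks resolve the shared libraries that fmoe_cuda links against. Below is a hedged sketch of how such an environment dict is typically consumed; subprocess here is an assumption for illustration, not the helper's actual launch path, and every value must be a string, hence the "" fallback when the variable is unset:

# Sketch of launching one simulated rank with an explicit environment.
# env= replaces the child's environment entirely, so all values must be str.
import os
import subprocess
import sys

env = {
    "MASTER_ADDR": "localhost",
    "MASTER_PORT": "55555",
    "OMPI_COMM_WORLD_SIZE": "2",
    "OMPI_COMM_WORLD_RANK": "0",
    "LD_LIBRARY_PATH": os.environ.get("LD_LIBRARY_PATH", ""),
}

subprocess.run(
    [sys.executable, "-c", "import os; print(os.environ['OMPI_COMM_WORLD_RANK'])"],
    env=env,
    check=True,
)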