Commit 6e1fcca1 authored by Rick Ho

faster gate develop

parent 98e24e24
r"""
The example topology-aware gate for two-layer tree-like topology, proposed by
the PPoPP'22 paper, FasterMoE. Limited number of tokens are sent across the
upper-level slow connection, and other ones are re-directed to experts in the
local network.
The number of GPUs to form such a local network is defined by an environment
variable `FMOE_TOPO_GPUS_PER_NODE`, and it is by default `8`.
The fraction of tokens that are allowed to be sent across nodes is defined by
an environement variable `FMOE_TOPO_OUTGOING_FRACTION`, and it is by default
`0.14`. Users are supposed to set the proper value in their own environemnt,
guided by some performance model, to achieve maximum throughput.
"""
from .naive_gate import NaiveGate
import os
import sys
import torch
import torch.nn.functional as F
from .utils import limit_by_capacity
import fmoe_cuda
from fmoe.functions import count_by_gate
# Number of GPUs (workers) that form a local fast network within one node.
nw_per_node = 8
try:
    nw_per_node = int(os.environ['FMOE_TOPO_GPUS_PER_NODE'])
except Exception:
    pass
class FasterGate(NaiveGate):
    def __init__(self, d_model, n_expert, world_size, node_rank):
        super().__init__(d_model, n_expert, world_size, top_k=2)
        self.ne_per_node = nw_per_node * n_expert
        # Fraction of tokens that may be sent to experts outside this node.
        self.ogn_ratio = .14
        try:
            self.ogn_ratio = float(os.environ['FMOE_TOPO_OUTGOING_FRACTION'])
        except Exception:
            pass
        self.node_rank = node_rank

        # Boolean mask over all experts: True marks a remote expert (subject to
        # the outgoing limit), False marks an expert on the local node. For
        # example, with world_size=16, n_expert=1 and nw_per_node=8, node_rank
        # 0 keeps experts 0-7 local and marks experts 8-15 as remote.
        mask = [1] * world_size * n_expert
        for i in range(n_expert * world_size):
            if i // self.ne_per_node == self.node_rank:
                mask[i] = 0
        self.mask = torch.Tensor(mask).bool()
        self.policy_fn = None
        print('node rank {} mask {}'.format(node_rank, mask))

    def forward(self, inp):
        if self.mask.device != inp.device:
            self.mask = self.mask.to(inp.device)

        gate_score = self.gate(inp)
        lim_mask = self.mask

        top2_val, top2_idx = torch.topk(gate_score, k=2, dim=-1)
        S = gate_score.shape[0]
        top_k = 2
        with torch.no_grad():
            top1_idx = top2_idx.view((-1, top_k))[:, 0]
            top1_val = top2_val.view((-1, top_k))[:, 0]

        # Balance loss: the fraction of tokens whose top-1 choice is each
        # expert, times the mean softmax score of that expert.
        c_e = torch.scatter_add(
            torch.zeros(self.tot_expert, device=top1_idx.device),
            0,
            top1_idx,
            torch.ones_like(top1_idx, dtype=torch.float),
        ) / S
        m_e = torch.mean(F.softmax(gate_score, dim=1), dim=0)
        loss = torch.mean(c_e * m_e) * (self.num_expert ** 2)
        self.set_loss(loss)
        with torch.no_grad():
            if self.policy_fn is None:
                stored_models = torch.zeros(self.num_expert * self.world_size,
                        dtype=torch.bool)
            else:
                # TODO: Fix this after expert shadowing is ported
                _, lec, aec, gec, agec = count_by_gate(top2_idx,
                        self.num_expert, self.world_size, require_pos=False)
                stored_models = self.policy_fn(aec, agec,
                        self.num_expert, self.world_size, inp.shape[-1], True)
            # Experts that are shadowed locally no longer count as remote.
            lim_mask = lim_mask & ~stored_models.view(-1).to(lim_mask.device)

            # Tokens whose top-1 expert is remote are candidates for the
            # cross-node connection; at most ogn_thres of them may go out.
            ogn_mask = lim_mask[top1_idx]
            ogn_thres = int(inp.shape[0] * self.ogn_ratio)

        # Within budget: behave exactly like the plain top-2 gate.
        if ogn_mask.sum().item() < ogn_thres:
            topk_val, topk_idx = torch.topk(gate_score, k=self.top_k)
            topk_val = F.softmax(topk_val, dim=-1)
            return topk_idx, topk_val

        with torch.no_grad():
            # Keep the ogn_thres outgoing tokens with the highest top-1 scores;
            # every other token is re-routed to local (unmasked) experts.
            top1_val[~ogn_mask] = float('-inf')
            _, top_ogn = torch.topk(top1_val.view(-1), k=ogn_thres)
            cand = gate_score.clone()
            cand[:, lim_mask] = float('-inf')
            _, topk_idx = torch.topk(cand, k=self.top_k)
            # Restore the original remote top-1 expert for the selected tokens.
            topk_idx[top_ogn, 1] = top1_idx.view(-1)[top_ogn]

        # Gather the gate scores of the chosen experts and renormalize.
        idx_x = torch.arange(inp.shape[0], device=inp.device).repeat_interleave(2)
        topk_val = gate_score[idx_x, topk_idx.view(-1)].view(-1, self.top_k)
        topk_val = F.softmax(topk_val, dim=-1)
        return topk_idx, topk_val

def gen_faster_gate(rank):
    # Build a gate constructor bound to the node rank that is derived from the
    # global process rank.
    def _gen(d_model, n_expert, world_size, top_k=2):
        assert top_k == 2
        return FasterGate(d_model, n_expert, world_size, rank // nw_per_node)
    return _gen
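
# A minimal usage sketch, assuming fmoe's `FMoETransformerMLP` layer and an
# initialized torch.distributed process group; the hyper-parameter values are
# illustrative:
#
#     import torch.distributed as dist
#     from fmoe.transformer import FMoETransformerMLP
#
#     gate = gen_faster_gate(dist.get_rank())
#     moe_layer = FMoETransformerMLP(num_expert=2, d_model=1024, d_hidden=4096,
#                                    world_size=dist.get_world_size(),
#                                    gate=gate)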