Unverified Commit 53ecd24b authored by zms1999's avatar zms1999 Committed by GitHub
Browse files

Merge pull request #101 from laekov/faster-topo-gate

Faster topo gate
parents 14cb8545 49b5b5d6
r"""
The example topology-aware gate for two-layer tree-like topology, proposed by
the PPoPP'22 paper, FasterMoE. Limited number of tokens are sent across the
upper-level slow connection, and other ones are re-directed to experts in the
local network.
The number of GPUs to form such a local network is defined by an environment
variable `FMOE_TOPO_GPUS_PER_NODE`, and it is by default `8`.
The fraction of tokens that are allowed to be sent across nodes is defined by
an environement variable `FMOE_TOPO_OUTGOING_FRACTION`, and it is by default
`0.14`. Users are supposed to set the proper value in their own environemnt,
guided by some performance model, to achieve maximum throughput.
"""
from .naive_gate import NaiveGate
import os
import sys
import torch
import torch.nn.functional as F
from .utils import limit_by_capacity
import fmoe_cuda
from fmoe.functions import count_by_gate
nw_per_node = 8
try:
nw_per_node = int(os.environ['FMOE_TOPO_GPUS_PER_NODE'])
except Exception:
pass
class FasterGate(NaiveGate):
def __init__(self, d_model, n_expert, world_size, node_rank):
super().__init__(d_model, n_expert, world_size, top_k=2)
self.ne_per_node = nw_per_node * n_expert
self.ogn_ratio = .14
try:
self.ogn_ratio = float(os.environ['FMOE_TOPO_OUTGOING_FRACTION'])
except Exception:
pass
self.node_rank = node_rank
mask = [1] * world_size * n_expert
for i in range(n_expert * world_size):
if i // self.ne_per_node == self.node_rank:
mask[i] = 0
self.mask = torch.Tensor(mask).bool()
self.policy_fn = None
print('node rank {} mask {}'.format(node_rank, mask))
def forward(self, inp):
if self.mask.device != inp.device:
self.mask = self.mask.to(inp.device)
gate_score = self.gate(inp)
lim_mask = self.mask
top2_val, top2_idx = torch.topk(gate_score, k=2, dim=-1)
S = gate_score.shape[0]
top_k = 2
with torch.no_grad():
top1_idx = top2_idx.view((-1, top_k))[:, 0]
top1_val = top2_val.view((-1, top_k))[:, 0]
c_e = torch.scatter_add(
torch.zeros(self.tot_expert, device=top1_idx.device),
0,
top1_idx,
torch.ones_like(top1_idx, dtype=torch.float),
) / S
m_e = torch.mean(F.softmax(gate_score, dim=1), dim=0)
loss = torch.mean(c_e * m_e) * (self.num_expert ** 2)
self.set_loss(loss)
with torch.no_grad():
if self.policy_fn is None:
stored_models = torch.zeros(self.num_expert * self.world_size,
dtype=torch.bool)
else:
# TODO: Fix this after expert shadowing is ported
_, lec, aec, gec, agec = count_by_gate(top2_idx,
self.num_expert, self.world_size, require_pos=False)
stored_models = self.policy_fn(aec, agec,
self.num_expert, self.world_size, inp.shape[-1], True)
lim_mask = lim_mask & ~stored_models.view(-1).to(lim_mask.device)
ogn_mask = lim_mask[top1_idx]
ogn_thres = int(inp.shape[0] * self.ogn_ratio)
if ogn_mask.sum().item() < ogn_thres:
topk_val, topk_idx = torch.topk(gate_score, k=self.top_k)
topk_val = F.softmax(topk_val, dim=-1)
return topk_idx, topk_val
with torch.no_grad():
top1_val[~ogn_mask] = float('-inf')
_, top_ogn = torch.topk(top1_val.view(-1), k=ogn_thres)
cand = gate_score.clone()
cand[:, lim_mask] = float('-inf')
_, topk_idx = torch.topk(cand, k=self.top_k)
topk_idx[top_ogn, 1] = top1_idx.view(-1)[top_ogn]
idx_x = torch.arange(inp.shape[0], device=inp.device).repeat_interleave(2)
topk_val = gate_score[idx_x, topk_idx.view(-1)].view(-1, self.top_k)
topk_val = F.softmax(topk_val, dim=-1)
return topk_idx, topk_val
def gen_faster_gate(rank):
def _gen(d_model, n_expert, world_size, top_k=2):
assert top_k == 2
return FasterGate(d_model, n_expert, world_size, rank // nw_per_node)
return _gen
import json
import random
import os
import sys
from typing import Dict
......@@ -13,30 +14,34 @@ from test_numerical import _test_fmoe_local_ddp
def _ensure_initialized():
if not dist.is_initialized():
if 'RANK' not in os.environ:
os.environ["RANK"] = os.environ.get("OMPI_COMM_WORLD_RANK", "0")
os.environ["WORLD_SIZE"] = os.environ.get("OMPI_COMM_WORLD_SIZE", "1")
os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["RANK"]
os.environ["MASTER_ADDR"] = os.environ.get("MASTER_ADDR", "localhost")
os.environ["MASTER_PORT"] = os.environ.get("MASTER_PORT", "12211")
if not dist.is_initialized():
dist.init_process_group(backend="nccl")
def _run_distributed(func, world_size, args: Dict, script=__file__):
if torch.cuda.device_count() < world_size:
pytest.skip("No enough GPU")
def _run_distributed(func, world_size, args: Dict, script=__file__, env=dict()):
device_count = torch.cuda.device_count()
if device_count < world_size:
pytest.skip("No enough GPU, only {} found".format(device_count))
import subprocess
import os
ps = []
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "36666"
os.environ["OMPI_COMM_WORLD_SIZE"] = str(world_size)
env["MASTER_ADDR"] = "localhost"
env["MASTER_PORT"] = str(random.randint(50000, 60000))
env["OMPI_COMM_WORLD_SIZE"] = str(world_size)
for i in range(world_size):
os.environ["OMPI_COMM_WORLD_RANK"] = str(i)
env["OMPI_COMM_WORLD_RANK"] = str(i)
p = subprocess.Popen(
[sys.executable, script, func, json.dumps(args)], stdout=subprocess.PIPE
[sys.executable, script, func, json.dumps(args)],
stdout=subprocess.PIPE,
env=env
)
ps.append(p)
......
import pytest
import os
import sys
import json
import math
import torch
import torch.distributed as dist
import torch.nn.functional as F
from fmoe.gates.faster_gate import FasterGate
from test_ddp import _ensure_initialized, _run_distributed
@pytest.mark.parametrize("n_process", [8])
@pytest.mark.parametrize("d_model", [1024])
@pytest.mark.parametrize("batch_size", [16])
@pytest.mark.parametrize("n_expert", [1, 4])
@pytest.mark.parametrize("gpu_per_node", [2, 4, 8])
@pytest.mark.parametrize("frac", [.2])
def test_faster_gate(n_process, d_model, batch_size, n_expert, gpu_per_node, frac):
_run_distributed('_test_faster_gate',
n_process,
{
'd_model': d_model,
'batch_size': batch_size,
'n_expert': n_expert,
'gpu_per_node': gpu_per_node,
'frac': frac
},
script=__file__,
env=dict(
FMOE_TOPO_GPUS_PER_NODE=str(gpu_per_node),
FMOE_TOPO_OUTGOING_FRACTION=str(frac)
)
)
def _test_faster_gate(d_model, batch_size, n_expert, gpu_per_node, frac):
_ensure_initialized()
rank = dist.get_rank()
node_rank = rank // gpu_per_node
gate = FasterGate(d_model, n_expert, dist.get_world_size(), node_rank).cuda()
x = torch.rand(batch_size, d_model).cuda()
topk_idx, topk_val = gate(x)
cnto = 0
idxs = topk_idx[:, 0].cpu().view(-1).numpy()
for v in idxs:
assert(v != -1)
if v // n_expert // gpu_per_node != rank // gpu_per_node:
cnto += 1
assert(cnto <= math.ceil(batch_size * frac))
if __name__ == '__main__':
if len(sys.argv) >= 3:
args = json.loads(sys.argv[2])
locals()[sys.argv[1]](**args)
else:
test_faster_gate(8, 1024, 16, 1, 2, .2)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment