"vscode:/vscode.git/clone" did not exist on "be6f6c2927dc03b6103af8d48a961562dd5d68d5"
Commit 5a0ba835 authored by Rick Ho's avatar Rick Ho
Browse files

Add a test for GShardGate, but it does not pass yet

parent 56cb8c15
......@@ -7,6 +7,8 @@
std::vector<torch::Tensor> _limit_by_capacity(
torch::Tensor expert_count, torch::Tensor capacity,
long n_expert, long n_worker) {
CHECK_INPUT(expert_count);
CHECK_INPUT(capacity);
auto expert_count_ack = torch::empty_like(expert_count);
auto smgr = getCudaStreamManager(expert_count.device().index());
fmoe_cuda_limit_by_capacity_impl(
......
......@@ -10,7 +10,7 @@ import fmoe_cuda
from .utils import get_torch_default_comm
def count_by_gate(gate, num_expert, world_size, comm):
def count_by_gate(gate, num_expert, world_size):
# TODO: support -1 in gate, which means ignore this input
with torch.no_grad():
_, pos = torch.sort(gate)
......
r"""
Balanced gate with GShard's policy (Google, 2020)
"""
import math
import torch
import torch.nn.functional as F
from .naive_gate import NaiveGate
......@@ -14,13 +15,13 @@ class GShardGate(NaiveGate):
self.capacity = capacity
def forward(self, x):
topk_idx, topk_val, gate_score = super().forward(x)
topk_idx, gate_score = super().forward(x)
S = gate_score.shape[0]
top_k = topk_idx.shape[0] // gate_score.shape[0]
top1_idx = topk_idx.view((-1, top_k))[:, 0]
c_e = torch.scatter_add(
torch.zeros(self.num_expert, device=gate_top_1_idx.device),
torch.zeros(self.num_expert, device=top1_idx.device),
0,
top1_idx,
torch.ones_like(top1_idx, dtype=torch.float),
......@@ -33,14 +34,19 @@ class GShardGate(NaiveGate):
capacity = torch.ones(self.num_expert, dtype=torch.int32)
capacity *= math.ceil(cap_rate * x.shape[0])
pos, lec, gec = count_by_gate(gate, self.num_expert, self.world_size)
print(topk_idx)
pos, lec, gec = count_by_gate(gate_score, self.num_expert,
self.world_size)
print(topk_idx)
new_gec, = fmoe_native.limit_by_capacity(gec, capacity,
self.num_expert, self.world_size)
print(topk_idx)
if self.world_size > 1:
new_lec = fmoe_native.expert_exchange(new_gec,
self.num_expert, self.world_size)
else:
new_lec = new_gec
print(topk_idx)
fmoe_native.prune_gate_by_capacity(topk_idx,
new_lec.to(torch.int32), self.num_expert, self.world_size)
......
......@@ -35,9 +35,7 @@ class NaiveGate(BaseGate):
gate_top_k_val = gate_top_k_val.view(-1, self.top_k)
# (BxL) x 1 x top_k
gate_score = F.softmax(gate_top_k_val, dim=-1).unsqueeze(1)
gate_score = F.softmax(gate_top_k_val, dim=-1)
gate_top_k_idx = gate_top_k_idx.view(-1) # (BxLxtop_k)
# TODO: capacity
return gate_top_k_idx, gate_score
return gate_top_k_idx, gate
import os
import torch
import torch.distributed as dist
from fmoe.gates import GShardGate
def test_gshard_gate(d_model, batch_size, n_expert):
    """Manual smoke test for GShardGate: one forward pass, printed per rank.

    NOTE(review): this prints instead of asserting, and it requires CUDA
    plus an already-initialized torch.distributed process group.
    """
    world_size = dist.get_world_size()
    gate = GShardGate(d_model, n_expert, world_size).cuda()
    inp = torch.rand(batch_size, d_model).cuda()
    idx, val = gate(inp)
    rank = dist.get_rank()
    print('rank {} idx {}'.format(rank, idx))
    print('rank {} val {}'.format(rank, val))
if __name__ == '__main__':
    # Map Open MPI's launcher env vars onto the names torch.distributed
    # reads, defaulting to a single-process run when launched without mpirun.
    rank = os.environ.get("OMPI_COMM_WORLD_RANK", "0")
    os.environ["RANK"] = rank
    os.environ["WORLD_SIZE"] = os.environ.get("OMPI_COMM_WORLD_SIZE", "1")
    # Pin each process to the GPU whose index matches its rank.
    os.environ["CUDA_VISIBLE_DEVICES"] = rank
    torch.distributed.init_process_group(backend="nccl")
    test_gshard_gate(4096, 1024, 4)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment