Commit ddfaaf49 authored by Rich Ho

gshard gate test

parent 5a0ba835
@@ -15,7 +15,8 @@ class GShardGate(NaiveGate):
         self.capacity = capacity
 
     def forward(self, x):
-        topk_idx, gate_score = super().forward(x)
+        naive_outs = super().forward(x, return_all_scores=True)
+        topk_idx, topk_val, gate_score = naive_outs
         S = gate_score.shape[0]
         top_k = topk_idx.shape[0] // gate_score.shape[0]
@@ -31,22 +32,19 @@
         self.set_loss(loss)
 
         cap_rate = self.capacity[0 if self.training else 1]
-        capacity = torch.ones(self.num_expert, dtype=torch.int32)
+        capacity = torch.ones(self.num_expert, dtype=torch.int32,
+                device=x.device)
         capacity *= math.ceil(cap_rate * x.shape[0])
-        print(topk_idx)
-        pos, lec, gec = count_by_gate(gate_score, self.num_expert,
+        pos, lec, gec = count_by_gate(topk_idx.reshape(-1), self.num_expert,
                 self.world_size)
-        print(topk_idx)
         new_gec, = fmoe_native.limit_by_capacity(gec, capacity,
                 self.num_expert, self.world_size)
-        print(topk_idx)
         if self.world_size > 1:
             new_lec = fmoe_native.expert_exchange(new_gec,
                     self.num_expert, self.world_size)
         else:
             new_lec = new_gec
-        print(topk_idx)
         fmoe_native.prune_gate_by_capacity(topk_idx,
                 new_lec.to(torch.int32), self.num_expert, self.world_size)
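
The rewritten forward turns the fractional capacity rate into a per-expert token budget (math.ceil(cap_rate * x.shape[0])) and defers the actual dropping to the fmoe_native kernels. As rough intuition for what limit_by_capacity and prune_gate_by_capacity achieve together, here is a hedged single-process sketch in plain PyTorch; the helper name and loop are illustrative only, not the library's API, and the real kernels also handle the multi-worker exchange:

# Hedged, single-process sketch of capacity pruning; illustrative only.
import math
import torch

def prune_by_capacity(topk_idx, num_expert, cap_rate, n_tokens):
    budget = [math.ceil(cap_rate * n_tokens)] * num_expert
    pruned = topk_idx.clone().view(-1)
    for i, e in enumerate(pruned.tolist()):
        if budget[e] > 0:
            budget[e] -= 1      # token fits within this expert's capacity
        else:
            pruned[i] = -1      # overflow token is dropped, marked with -1
    return pruned.view(topk_idx.shape)

idx = torch.tensor([0, 1, 0, 0, 1, 0])
print(prune_by_capacity(idx, num_expert=2, cap_rate=0.5, n_tokens=6))
# tensor([ 0,  1,  0,  0,  1, -1])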
@@ -23,7 +23,7 @@ class NaiveGate(BaseGate):
         self.gate = nn.Linear(d_model, self.tot_expert)
         self.top_k = top_k
 
-    def forward(self, inp):
+    def forward(self, inp, return_all_scores=False):
         r"""
         The naive implementation simply calculates the top-k of a linear layer's
         output.
@@ -38,4 +38,6 @@ class NaiveGate(BaseGate):
         gate_score = F.softmax(gate_top_k_val, dim=-1)
         gate_top_k_idx = gate_top_k_idx.view(-1)  # (BxLxtop_k)
 
-        return gate_top_k_idx, gate
+        if return_all_scores:
+            return gate_top_k_idx, gate_top_k_val, gate
+        return gate_top_k_idx, gate_top_k_val
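
The default call keeps returning two values, so existing call sites still work; passing return_all_scores=True additionally exposes the raw per-expert score matrix that GShardGate now consumes. A quick usage sketch, assuming NaiveGate is importable the same way as GShardGate and sizes chosen for illustration:

# Usage sketch for the new keyword argument; sizes are illustrative.
import torch
from fmoe.gates import NaiveGate

g = NaiveGate(16, 4, 1, top_k=2)   # (d_model, num_expert, world_size)
x = torch.rand(8, 16)

topk_idx, topk_val = g(x)          # old two-value form still works
topk_idx, topk_val, all_scores = g(x, return_all_scores=True)
print(topk_idx.shape, topk_val.shape, all_scores.shape)
# expected: torch.Size([16]) torch.Size([8, 2]) torch.Size([8, 4])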
@@ -225,6 +225,7 @@ class FMoE(nn.Module):
         # to: (BxL) x top_k x d_model
         x = x.view(-1, self.top_k, self.d_model)
         # to: (BxL) x d_model
+        gate_score = gate_score.unsqueeze(1)
         x = torch.bmm(gate_score, x).reshape(-1, self.d_model)
 
         if self.mp_size > 1:
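
torch.bmm only accepts 3-D batches, so the (BxL) x top_k score matrix has to become a batch of 1 x top_k row vectors before it can multiply the (BxL) x top_k x d_model expert outputs; that is all the added unsqueeze(1) does, turning the combine into a per-token weighted sum. A minimal shape check with illustrative sizes:

# Minimal shape check for the unsqueeze(1) + bmm combine used above.
import torch

n_tok, top_k, d_model = 8, 2, 16
gate_score = torch.softmax(torch.rand(n_tok, top_k), dim=-1)   # (8, 2)
x = torch.rand(n_tok, top_k, d_model)                          # (8, 2, 16)

out = torch.bmm(gate_score.unsqueeze(1), x)   # (8, 1, 2) @ (8, 2, 16) -> (8, 1, 16)
out = out.reshape(-1, d_model)                # (8, 16): weighted sum over top_k
print(out.shape)                              # torch.Size([8, 16])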
@@ -40,6 +40,7 @@ class BruteForceMoELinear(nn.Module):
             x = x @ self.weight_h4toh[i].t()
             x = x + self.bias_h4toh[i]
             o[idx] = x
+        gate_score = gate_score.unsqueeze(1)
         x = torch.bmm(gate_score, o.view(-1, self.top_k, self.d_model)).reshape(
             -1, self.d_model
         )
@@ -60,6 +61,7 @@ class BruteForceMoE(nn.Module):
         x = inp.new_zeros((batch_size, self.d_model))
         for i in range(batch_size):
             x[i] = self.experts[gate_long[i]](inp[i])
+        gate_score = gate_score.unsqueeze(1)
         x = torch.bmm(gate_score, x.view(-1, self.top_k, self.d_model)).reshape(
             -1, self.d_model
         )
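
Both brute-force reference modules receive the same fix, since they compute the ground truth the fused kernels are compared against: run every repeated token through its expert serially, then combine with the identical unsqueeze(1) + bmm. A self-contained sketch of that reference pattern; names and sizes are illustrative, not the test helpers' API:

# Self-contained sketch of the brute-force reference pattern above.
import torch

n_tok, top_k, d_model, n_expert = 6, 2, 4, 3
experts = [torch.nn.Linear(d_model, d_model) for _ in range(n_expert)]

inp = torch.rand(n_tok * top_k, d_model)               # each token repeated top_k times
gate_long = torch.randint(n_expert, (n_tok * top_k,))  # expert id per token copy
gate_score = torch.softmax(torch.rand(n_tok, top_k), dim=-1)

out = inp.new_zeros(inp.shape)
for i in range(inp.shape[0]):
    out[i] = experts[gate_long[i]](inp[i])             # dispatch serially

combined = torch.bmm(gate_score.unsqueeze(1),
                     out.view(-1, top_k, d_model)).reshape(-1, d_model)
print(combined.shape)                                  # torch.Size([6, 4])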
 import pytest
 import os
 import math
 
 import torch
 import torch.distributed as dist
 
 from fmoe.gates import GShardGate
 
-def test_gshard_gate(d_model, batch_size, n_expert):
-    gate = GShardGate(d_model, n_expert, dist.get_world_size()).cuda()
+def _ensure_initialized():
+    if not dist.is_initialized():
+        os.environ["RANK"] = os.environ.get("OMPI_COMM_WORLD_RANK", "0")
+        os.environ["WORLD_SIZE"] = os.environ.get("OMPI_COMM_WORLD_SIZE", "1")
+        os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["RANK"]
+        os.environ["MASTER_ADDR"] = os.environ.get("MASTER_ADDR", "localhost")
+        os.environ["MASTER_PORT"] = os.environ.get("MASTER_PORT", "12211")
+        dist.init_process_group(backend="nccl")
+
+
+@pytest.mark.parametrize("d_model", [8, 1024])
+@pytest.mark.parametrize("batch_size", [16, 4096])
+@pytest.mark.parametrize("n_expert", [1, 4, 16])
+@pytest.mark.parametrize("cap", [.1, .5, 1.1])
+def test_gshard_gate(d_model, batch_size, n_expert, cap):
+    _ensure_initialized()
+    if dist.get_world_size() * n_expert < 2:
+        pytest.skip("Not enough experts")
+    gate = GShardGate(d_model, n_expert, dist.get_world_size(),
+            capacity=(cap, cap)).cuda()
     x = torch.rand(batch_size, d_model).cuda()
     topk_idx, topk_val = gate(x)
+    print('rank {} idx {}'.format(dist.get_rank(), topk_idx))
+    print('rank {} val {}'.format(dist.get_rank(), topk_val))
+    counts = [0 for _ in range(n_expert)]
+    for v in topk_idx.cpu().view(-1).numpy():
+        if v != -1:
+            counts[v] += 1
+    real_cap = math.ceil(cap * batch_size)
+    for i in counts:
+        assert i <= real_cap
 
 if __name__ == '__main__':
-    os.environ["RANK"] = os.environ.get("OMPI_COMM_WORLD_RANK", "0")
-    os.environ["WORLD_SIZE"] = os.environ.get("OMPI_COMM_WORLD_SIZE", "1")
-    os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["RANK"]
-    torch.distributed.init_process_group(backend="nccl")
-    test_gshard_gate(4096, 1024, 4)
+    _ensure_initialized()
+    test_gshard_gate(4096, 1024, 4, .2)
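
The _ensure_initialized shim reads OpenMPI's OMPI_COMM_WORLD_* variables with single-process fallbacks, so the file presumably runs both as a plain pytest invocation and under an MPI launcher. The per-expert counting loop could equally be written with torch.bincount; a tiny equivalent sketch with toy values:

# Equivalent capacity assertion via bincount; -1 marks tokens dropped by the
# capacity pruning, exactly as the loop above treats them. Toy values only.
import torch

topk_idx = torch.tensor([0, 1, -1, 0, 2, -1])
kept = topk_idx[topk_idx != -1]
counts = torch.bincount(kept, minlength=3)   # per-expert assignment counts
assert (counts <= 2).all()                   # real_cap == 2 in this toy case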