"src/graph/vscode:/vscode.git/clone" did not exist on "44089c8b4d4db4ca71e816e0de50dca972dbabdb"
Commit ddfaaf49 authored by Rich Ho's avatar Rich Ho
Browse files

gshard gate test

parent 5a0ba835
...@@ -15,7 +15,8 @@ class GShardGate(NaiveGate): ...@@ -15,7 +15,8 @@ class GShardGate(NaiveGate):
self.capacity = capacity self.capacity = capacity
def forward(self, x): def forward(self, x):
topk_idx, gate_score = super().forward(x) naive_outs = super().forward(x, return_all_scores=True)
topk_idx, topk_val, gate_score = naive_outs
S = gate_score.shape[0] S = gate_score.shape[0]
top_k = topk_idx.shape[0] // gate_score.shape[0] top_k = topk_idx.shape[0] // gate_score.shape[0]
...@@ -31,22 +32,19 @@ class GShardGate(NaiveGate): ...@@ -31,22 +32,19 @@ class GShardGate(NaiveGate):
self.set_loss(loss) self.set_loss(loss)
cap_rate = self.capacity[0 if self.training else 1] cap_rate = self.capacity[0 if self.training else 1]
capacity = torch.ones(self.num_expert, dtype=torch.int32) capacity = torch.ones(self.num_expert, dtype=torch.int32,
device=x.device)
capacity *= math.ceil(cap_rate * x.shape[0]) capacity *= math.ceil(cap_rate * x.shape[0])
print(topk_idx) pos, lec, gec = count_by_gate(topk_idx.reshape(-1), self.num_expert,
pos, lec, gec = count_by_gate(gate_score, self.num_expert,
self.world_size) self.world_size)
print(topk_idx)
new_gec, = fmoe_native.limit_by_capacity(gec, capacity, new_gec, = fmoe_native.limit_by_capacity(gec, capacity,
self.num_expert, self.world_size) self.num_expert, self.world_size)
print(topk_idx)
if self.world_size > 1: if self.world_size > 1:
new_lec = fmoe_native.expert_exchange(new_gec, new_lec = fmoe_native.expert_exchange(new_gec,
self.num_expert, self.world_size) self.num_expert, self.world_size)
else: else:
new_lec = new_gec new_lec = new_gec
print(topk_idx)
fmoe_native.prune_gate_by_capacity(topk_idx, fmoe_native.prune_gate_by_capacity(topk_idx,
new_lec.to(torch.int32), self.num_expert, self.world_size) new_lec.to(torch.int32), self.num_expert, self.world_size)
......
...@@ -23,7 +23,7 @@ class NaiveGate(BaseGate): ...@@ -23,7 +23,7 @@ class NaiveGate(BaseGate):
self.gate = nn.Linear(d_model, self.tot_expert) self.gate = nn.Linear(d_model, self.tot_expert)
self.top_k = top_k self.top_k = top_k
def forward(self, inp): def forward(self, inp, return_all_scores=False):
r""" r"""
The naive implementation simply calculates the top-k of a linear layer's The naive implementation simply calculates the top-k of a linear layer's
output. output.
...@@ -38,4 +38,6 @@ class NaiveGate(BaseGate): ...@@ -38,4 +38,6 @@ class NaiveGate(BaseGate):
gate_score = F.softmax(gate_top_k_val, dim=-1) gate_score = F.softmax(gate_top_k_val, dim=-1)
gate_top_k_idx = gate_top_k_idx.view(-1) # (BxLxtop_k) gate_top_k_idx = gate_top_k_idx.view(-1) # (BxLxtop_k)
return gate_top_k_idx, gate if return_all_scores:
return gate_top_k_idx, gate_top_k_val, gate
return gate_top_k_idx, gate_top_k_val
...@@ -225,6 +225,7 @@ class FMoE(nn.Module): ...@@ -225,6 +225,7 @@ class FMoE(nn.Module):
# to: (BxL) x top_k x d_model # to: (BxL) x top_k x d_model
x = x.view(-1, self.top_k, self.d_model) x = x.view(-1, self.top_k, self.d_model)
# to: (BxL) x d_model # to: (BxL) x d_model
gate_score = gate_score.unsqueeze(1)
x = torch.bmm(gate_score, x).reshape(-1, self.d_model) x = torch.bmm(gate_score, x).reshape(-1, self.d_model)
if self.mp_size > 1: if self.mp_size > 1:
......
...@@ -40,6 +40,7 @@ class BruteForceMoELinear(nn.Module): ...@@ -40,6 +40,7 @@ class BruteForceMoELinear(nn.Module):
x = x @ self.weight_h4toh[i].t() x = x @ self.weight_h4toh[i].t()
x = x + self.bias_h4toh[i] x = x + self.bias_h4toh[i]
o[idx] = x o[idx] = x
gate_score = gate_score.unsqueeze(1)
x = torch.bmm(gate_score, o.view(-1, self.top_k, self.d_model)).reshape( x = torch.bmm(gate_score, o.view(-1, self.top_k, self.d_model)).reshape(
-1, self.d_model -1, self.d_model
) )
...@@ -60,6 +61,7 @@ class BruteForceMoE(nn.Module): ...@@ -60,6 +61,7 @@ class BruteForceMoE(nn.Module):
x = inp.new_zeros((batch_size, self.d_model)) x = inp.new_zeros((batch_size, self.d_model))
for i in range(batch_size): for i in range(batch_size):
x[i] = self.experts[gate_long[i]](inp[i]) x[i] = self.experts[gate_long[i]](inp[i])
gate_score = gate_score.unsqueeze(1)
x = torch.bmm(gate_score, x.view(-1, self.top_k, self.d_model)).reshape( x = torch.bmm(gate_score, x.view(-1, self.top_k, self.d_model)).reshape(
-1, self.d_model -1, self.d_model
) )
......
import pytest
import os import os
import math
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from fmoe.gates import GShardGate from fmoe.gates import GShardGate
def _ensure_initialized():
    """Lazily create the (possibly single-process) NCCL process group.

    Reads OpenMPI rank/size environment variables when present so the test
    also works under ``mpirun``; otherwise falls back to a one-process group
    on localhost. Idempotent: does nothing if a group already exists.
    """
    if not dist.is_initialized():
        os.environ["RANK"] = os.environ.get("OMPI_COMM_WORLD_RANK", "0")
        os.environ["WORLD_SIZE"] = os.environ.get("OMPI_COMM_WORLD_SIZE", "1")
        # Pin each rank to its own GPU before any CUDA context is created.
        os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["RANK"]
        os.environ["MASTER_ADDR"] = os.environ.get("MASTER_ADDR", "localhost")
        os.environ["MASTER_PORT"] = os.environ.get("MASTER_PORT", "12211")
        dist.init_process_group(backend="nccl")
@pytest.mark.parametrize("d_model", [8, 1024])
@pytest.mark.parametrize("batch_size", [16, 4096])
@pytest.mark.parametrize("n_expert", [1, 4, 16])
@pytest.mark.parametrize("cap", [.1, .5, 1.1])
def test_gshard_gate(d_model, batch_size, n_expert, cap):
    """Check that GShardGate never routes more tokens to an expert than
    its capacity ``ceil(cap * batch_size)`` allows.

    Pruned assignments are reported as index -1 and must be ignored
    when counting per-expert load.
    """
    _ensure_initialized()
    if dist.get_world_size() * n_expert < 2:
        pytest.skip("Not enough experts")
    gate = GShardGate(d_model, n_expert, dist.get_world_size(),
            capacity=(cap, cap)).cuda()
    x = torch.rand(batch_size, d_model).cuda()
    topk_idx, topk_val = gate(x)
    # Tally how many tokens were routed to each (global) expert index.
    counts = [0] * n_expert
    for v in topk_idx.cpu().view(-1).numpy():
        if v != -1:  # -1 marks an assignment dropped by capacity pruning
            counts[v] += 1
    real_cap = math.ceil(cap * batch_size)
    for c in counts:
        assert c <= real_cap
if __name__ == '__main__':
    # Standalone entry point (e.g. under mpirun) bypassing pytest's
    # parametrization; runs a single mid-sized configuration.
    _ensure_initialized()
    test_gshard_gate(4096, 1024, 4, .2)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment