Unverified commit 527e66af, authored by Rick Ho and committed by GitHub
Browse files

Merge pull request #115 from laekov/gate-test-fix

Fix gshard gate test
parents d3f71f21 e01c8b5f
...@@ -9,6 +9,7 @@ import torch ...@@ -9,6 +9,7 @@ import torch
import torch.distributed as dist import torch.distributed as dist
import torch.nn.functional as F import torch.nn.functional as F
from fmoe.gates import GShardGate, SwitchGate from fmoe.gates import GShardGate, SwitchGate
from fmoe.functions import ensure_comm
from test_ddp import _ensure_initialized, _run_distributed from test_ddp import _ensure_initialized, _run_distributed
...@@ -33,9 +34,11 @@ def test_gshard_gate(d_model, batch_size, n_expert, cap): ...@@ -33,9 +34,11 @@ def test_gshard_gate(d_model, batch_size, n_expert, cap):
def _test_gshard_gate(d_model, batch_size, n_expert, cap): def _test_gshard_gate(d_model, batch_size, n_expert, cap):
_ensure_initialized() _ensure_initialized()
rank = torch.distributed.get_rank()
gate = GShardGate(d_model, n_expert, dist.get_world_size(), gate = GShardGate(d_model, n_expert, dist.get_world_size(),
capacity=(cap, cap)).cuda() capacity=(cap, cap)).cuda()
x = torch.rand(batch_size, d_model).cuda() x = torch.rand(batch_size, d_model).cuda()
ensure_comm(x, None)
topk_idx, topk_val = gate(x) topk_idx, topk_val = gate(x)
counts = [0 for _ in range(n_expert * dist.get_world_size())] counts = [0 for _ in range(n_expert * dist.get_world_size())]
for v in topk_idx.cpu().view(-1).numpy(): for v in topk_idx.cpu().view(-1).numpy():
...@@ -46,11 +49,17 @@ def _test_gshard_gate(d_model, batch_size, n_expert, cap): ...@@ -46,11 +49,17 @@ def _test_gshard_gate(d_model, batch_size, n_expert, cap):
assert(i <= real_cap) assert(i <= real_cap)
gate_score = gate.gate(x) gate_score = gate.gate(x)
gate_top_k_val, gate_top_k_idx = torch.topk(
gate_score, k=gate.top_k, dim=-1, largest=True, sorted=False
)
gate_top_k_val = gate_top_k_val.view(-1, gate.top_k)
gate_score = F.softmax(gate_top_k_val, dim=-1)
for i in range(batch_size): for i in range(batch_size):
for j in range(gate.top_k): for j in range(gate.top_k):
v = topk_idx[i, j] v = topk_idx[i, j]
if v != -1: if v != -1:
assert topk_val[i, j] == gate_score[i, v] assert topk_val[i, j] == gate_score[i, j]
@pytest.mark.parametrize("d_model", [1024]) @pytest.mark.parametrize("d_model", [1024])
...@@ -109,7 +118,7 @@ if __name__ == '__main__': ...@@ -109,7 +118,7 @@ if __name__ == '__main__':
args = json.loads(sys.argv[2]) args = json.loads(sys.argv[2])
locals()[sys.argv[1]](**args) locals()[sys.argv[1]](**args)
else: else:
_ensure_initialized() # _ensure_initialized()
# test_gshard_gate(4096, 1024, 4, .2) # test_gshard_gate(4096, 1024, 4, .2)
test_switch_gate(8, 16, 4, .1) _test_gshard_gate(8, 16, 4, .1)
# test_switch_gate(4096, 1024, 4, .2) # test_switch_gate(4096, 1024, 4, .2)
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment