Unverified Commit 527e66af authored by Rick Ho's avatar Rick Ho Committed by GitHub
Browse files

Merge pull request #115 from laekov/gate-test-fix

Fix gshard gate test
parents d3f71f21 e01c8b5f
......@@ -9,6 +9,7 @@ import torch
import torch.distributed as dist
import torch.nn.functional as F
from fmoe.gates import GShardGate, SwitchGate
from fmoe.functions import ensure_comm
from test_ddp import _ensure_initialized, _run_distributed
......@@ -33,9 +34,11 @@ def test_gshard_gate(d_model, batch_size, n_expert, cap):
def _test_gshard_gate(d_model, batch_size, n_expert, cap):
_ensure_initialized()
rank = torch.distributed.get_rank()
gate = GShardGate(d_model, n_expert, dist.get_world_size(),
capacity=(cap, cap)).cuda()
x = torch.rand(batch_size, d_model).cuda()
ensure_comm(x, None)
topk_idx, topk_val = gate(x)
counts = [0 for _ in range(n_expert * dist.get_world_size())]
for v in topk_idx.cpu().view(-1).numpy():
......@@ -46,11 +49,17 @@ def _test_gshard_gate(d_model, batch_size, n_expert, cap):
assert(i <= real_cap)
gate_score = gate.gate(x)
gate_top_k_val, gate_top_k_idx = torch.topk(
gate_score, k=gate.top_k, dim=-1, largest=True, sorted=False
)
gate_top_k_val = gate_top_k_val.view(-1, gate.top_k)
gate_score = F.softmax(gate_top_k_val, dim=-1)
for i in range(batch_size):
for j in range(gate.top_k):
v = topk_idx[i, j]
if v != -1:
assert topk_val[i, j] == gate_score[i, v]
assert topk_val[i, j] == gate_score[i, j]
@pytest.mark.parametrize("d_model", [1024])
......@@ -109,7 +118,7 @@ if __name__ == '__main__':
args = json.loads(sys.argv[2])
locals()[sys.argv[1]](**args)
else:
_ensure_initialized()
# _ensure_initialized()
# test_gshard_gate(4096, 1024, 4, .2)
test_switch_gate(8, 16, 4, .1)
_test_gshard_gate(8, 16, 4, .1)
# test_switch_gate(4096, 1024, 4, .2)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment