Unverified Commit 7f6463f0 authored by Rick Ho's avatar Rick Ho Committed by GitHub
Browse files

Merge pull request #38 from GODVIX/gates

Update test_gates.py
parents ddaac5eb 26cc37cb
...@@ -7,6 +7,7 @@ import math ...@@ -7,6 +7,7 @@ import math
import torch import torch
import torch.distributed as dist import torch.distributed as dist
import torch.nn.functional as F
from fmoe.gates import GShardGate, SwitchGate from fmoe.gates import GShardGate, SwitchGate
from test_ddp import _run_distributed from test_ddp import _run_distributed
...@@ -54,6 +55,13 @@ def _test_gshard_gate(d_model, batch_size, n_expert, cap): ...@@ -54,6 +55,13 @@ def _test_gshard_gate(d_model, batch_size, n_expert, cap):
for i in counts: for i in counts:
assert(i <= real_cap) assert(i <= real_cap)
gate_score = gate.gate(x)
for i in range(batch_size):
for j in range(gate.top_k):
v = topk_idx[i, j]
if v != -1:
assert topk_val[i, j] == gate_score[i, v]
@pytest.mark.parametrize("d_model", [1024]) @pytest.mark.parametrize("d_model", [1024])
@pytest.mark.parametrize("batch_size", [4096]) @pytest.mark.parametrize("batch_size", [4096])
...@@ -77,6 +85,7 @@ def _test_switch_gate(d_model, batch_size, n_expert, cap): ...@@ -77,6 +85,7 @@ def _test_switch_gate(d_model, batch_size, n_expert, cap):
gate = SwitchGate(d_model, n_expert, dist.get_world_size(), gate = SwitchGate(d_model, n_expert, dist.get_world_size(),
capacity=(cap, cap)).cuda() capacity=(cap, cap)).cuda()
x = torch.rand(batch_size, d_model).cuda() x = torch.rand(batch_size, d_model).cuda()
rng = torch.cuda.get_rng_state() # save rng state
topk_idx, topk_val = gate(x) topk_idx, topk_val = gate(x)
counts = [0 for _ in range(n_expert * dist.get_world_size())] counts = [0 for _ in range(n_expert * dist.get_world_size())]
for v in topk_idx.cpu().view(-1).numpy(): for v in topk_idx.cpu().view(-1).numpy():
...@@ -86,6 +95,24 @@ def _test_switch_gate(d_model, batch_size, n_expert, cap): ...@@ -86,6 +95,24 @@ def _test_switch_gate(d_model, batch_size, n_expert, cap):
for i in counts: for i in counts:
assert(i <= real_cap) assert(i <= real_cap)
score = gate.gate(x)
if gate.training:
# reset rng state to make sure noise is the same as in gate.forward()
torch.cuda.set_rng_state(rng)
# random uniform number from [1-eps, 1+eps]
noise = torch.rand_like(score)
noise = noise * 2 * gate.switch_eps + 1.0 - gate.switch_eps
score += noise
# fp32 softmax for numerical stability
score = F.softmax(score.float(), dim=-1)
for i in range(batch_size):
v = topk_idx[i]
if v != -1:
assert topk_val[i] == score[i, topk_idx[i]]
if __name__ == '__main__': if __name__ == '__main__':
if len(sys.argv) >= 3: if len(sys.argv) >= 3:
...@@ -94,5 +121,5 @@ if __name__ == '__main__': ...@@ -94,5 +121,5 @@ if __name__ == '__main__':
else: else:
_ensure_initialized() _ensure_initialized()
# test_gshard_gate(4096, 1024, 4, .2) # test_gshard_gate(4096, 1024, 4, .2)
test_gshard_gate(8, 16, 1, .1) test_switch_gate(8, 16, 4, .1)
# test_switch_gate(4096, 1024, 4, .2) # test_switch_gate(4096, 1024, 4, .2)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment