"vscode:/vscode.git/clone" did not exist on "4cc1a6143387f41e2466536abcd6a2620b63a35b"
Unverified Commit 29c01fb1 authored by Sam Shleifer's avatar Sam Shleifer Committed by GitHub
Browse files

[test] Force overflow in top2gating test (#664)

parent 5739930f
......@@ -50,18 +50,18 @@ def test_forward_cuda():
# Verify that top gate is allocated capacity as per Algorithm 1 in GShard paper.
def test_expert1_overflow():
    """Check that top2gating drops tokens once expert 0 is over capacity.

    Every token is forced to choose expert 0 as its top-1 expert, so with
    ``num_tokens = 8`` and ``num_experts = 4`` the expert receives more
    tokens than its capacity (``2 * num_tokens // num_experts == 4``).
    The first ``capacity`` tokens must each occupy their own slot in
    expert 0; every later token must have no slot in expert 0 at all.
    """
    num_tokens = 8
    num_experts = 4
    logits = torch.randn(num_tokens, num_experts)
    # Raise column 0 above each row's current maximum so that expert 0 is
    # the arg-max for every token.
    logits[:, 0] = torch.max(logits, dim=1).values + 1  # Force overflow
    top1s = torch.argmax(logits, dim=1)
    # Sanity-check the setup before exercising the gate.
    assert top1s.eq(0).all(), top1s
    _, __, dispatch_mask = top2gating(logits)
    capacity = 2 * num_tokens // num_experts

    # dispatch_mask is indexed here as [token][expert][slot].
    for i in range(num_tokens):
        if i < capacity:
            # Tokens are expected to fill expert 0's slots in arrival order.
            assert dispatch_mask[i][0][i]
        else:
            # Overflowing tokens must not be routed to expert 0.
            assert not dispatch_mask[i][0].any()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment