"docs/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "ded3db164bb3c090871647f30ff9988c9c17fd83"
Unverified Commit 29c01fb1 authored by Sam Shleifer, committed by GitHub
Browse files

[test] Force overflow in top2gating test (#664)

parent 5739930f
...@@ -50,18 +50,18 @@ def test_forward_cuda(): ...@@ -50,18 +50,18 @@ def test_forward_cuda():
# Verify that top gate is allocated capacity as per Algorithm 1 in GShard paper. # Verify that top gate is allocated capacity as per Algorithm 1 in GShard paper.
def test_top1s(): def test_expert1_overflow():
num_tokens = 8 num_tokens = 8
num_experts = 4 num_experts = 4
logits = torch.randn(num_tokens, num_experts) logits = torch.randn(num_tokens, num_experts)
l_aux, _, dispatch_mask = top2gating(logits) logits[:, 0] = torch.max(logits, dim=1).values + 1 # Force overflow
top1s = torch.argmax(logits, dim=1) top1s = torch.argmax(logits, dim=1)
assert top1s.eq(0).all(), top1s
_, __, dispatch_mask = top2gating(logits)
capacity = 2 * num_tokens // num_experts capacity = 2 * num_tokens // num_experts
ce = [0] * num_experts
locations = [0] * num_tokens for i in range(num_tokens):
for i, s in enumerate(top1s): if i < capacity:
e = s.item() assert dispatch_mask[i][0][i]
loc = ce[e] else:
ce[e] = loc + 1 assert not dispatch_mask[i][0].any()
if ce[e] < capacity:
assert dispatch_mask[i][e][loc]
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment