Unverified commit c6d9be79, authored by msbaines and committed by GitHub
Browse files

[test] moe: add a more thorough MOELayer routing test (#151)

parent 49a3d9bc
...@@ -73,6 +73,44 @@ def test_forward(device): ...@@ -73,6 +73,44 @@ def test_forward(device):
assert torch.allclose(input, output) assert torch.allclose(input, output)
# Test Gate which round-robin routes tokens to experts
class RoundRobinGate(torch.nn.Module):
    """Deterministic test gate that routes tokens to experts round-robin.

    Token ``i`` is assigned to expert ``i % num_experts`` and occupies
    capacity slot ``i // num_experts`` of that expert, identically for every
    group in the batch.
    """

    def __init__(self, model_dim, num_experts):
        """Store the gate configuration.

        Args:
            model_dim: Kept as an attribute only; ``forward`` does not read
                it (presumably the size of the last input dimension --
                confirm against the caller).
            num_experts: Number of experts tokens are distributed over.
        """
        super().__init__()
        self.model_dim = model_dim
        self.num_experts = num_experts

    def forward(self, input):
        """Build the round-robin routing tensors for ``input``.

        Args:
            input: 3-D tensor whose shape is unpacked as ``(g, s, _)``; only
                its shape, dtype and device are used, never its values.

        Returns:
            A tuple ``(0.0, weights, mask)``. ``weights`` has shape
            ``(g, s, num_experts, capacity)`` with
            ``capacity = 2 * s // num_experts``, the dtype and device of
            ``input``, and holds exactly one ``1.0`` per token.  ``mask`` is
            ``weights.bool()``.  The leading ``0.0`` is a constant
            placeholder (presumably the auxiliary loss slot of the gate
            interface -- confirm against MOELayer).

        Raises:
            AssertionError: If ``s`` is not a multiple of ``num_experts``.
        """
        g, s, _ = input.shape
        assert s % self.num_experts == 0, (
            f"token count {s} is not divisible by num_experts {self.num_experts}"
        )
        capacity = 2 * s // self.num_experts
        output = torch.zeros(g, s, self.num_experts, capacity, dtype=input.dtype, device=input.device)
        # One vectorized write instead of a Python loop over tokens: for each
        # token index i, set output[:, i, i % num_experts, i // num_experts].
        token = torch.arange(s, device=input.device)
        output[:, token, token % self.num_experts, token // self.num_experts] = 1.0
        return 0.0, output, output.bool()
@pytest.mark.mpi
@pytest.mark.parametrize("device", ["cpu"])
def test_forward_routing(device):
    """Check that MOELayer sends each token to the expert chosen by the gate.

    The local expert multiplies its input by ``rank + 1``, so the scale seen
    on an output token identifies which rank's expert processed it.
    """
    model_dim = 8
    num_experts = dist.get_world_size()
    tokens_in = torch.randn(3, 4, 16, model_dim).to(device)
    gate = RoundRobinGate(model_dim, num_experts)

    # Use scaling matrix (each rank has a different scale)
    rank_scale = dist.get_rank() + 1
    expert = torch.nn.Linear(model_dim, model_dim, bias=False)
    expert.weight = torch.nn.Parameter(torch.eye(model_dim) * rank_scale)

    moe = MOELayer(gate, expert).to(device)
    tokens_out = moe(tokens_in)
    assert tokens_out.shape == tokens_in.shape

    # Verify that each token was sent to the correct expert by checking its scale.
    num_tokens = tokens_in.shape[2]
    for token_idx in range(num_tokens):
        # Round-robin gate: token i goes to expert i % num_experts, whose
        # scale is (expert index + 1).
        expected_scale = token_idx % num_experts + 1
        assert torch.allclose(tokens_in[:, :, token_idx] * expected_scale, tokens_out[:, :, token_idx])
@pytest.mark.mpi @pytest.mark.mpi
@pytest.mark.parametrize("device", ["cpu"]) @pytest.mark.parametrize("device", ["cpu"])
def test_backward(device): def test_backward(device):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment