Unverified Commit 317c0945 authored by msbaines, committed by GitHub

[fix] moe: fix bug for multiple experts per-gpu case (#184)

parent 89176e34
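What the fix addresses: with more than one expert per rank, the tensor coming back from the all-to-all has shape (num_experts, capacity, d_model) with rows grouped by source rank, so chunking it along dim 0 handed each local expert one rank's rows for all experts instead of every rank's rows for one expert. A minimal illustration of the two chunking schemes (dummy row labels rather than real activations; not code from the commit):

```python
import torch

world_size, num_local_experts, capacity, d_model = 2, 2, 1, 4
num_experts = world_size * num_local_experts

# After the all-to-all, rows arrive grouped by source rank:
# [rank0/exp0, rank0/exp1, rank1/exp0, rank1/exp1]
received = torch.stack(
    [torch.full((capacity, d_model), float(i)) for i in range(num_experts)]
)

# Old behavior: chunk along dim 0 gives local expert 0 the rows
# [rank0/exp0, rank0/exp1] -- rank 0's tokens for *both* experts.
old_chunks = received.chunk(num_local_experts, dim=0)

# Fixed behavior: reshape ecm -> gecm and chunk the expert dim, so local
# expert 0 gets [rank0/exp0, rank1/exp0] -- every rank's tokens for it.
gecm = received.reshape(world_size, num_local_experts, capacity, d_model)
new_chunks = gecm.chunk(num_local_experts, dim=1)

assert [c[..., 0].unique().tolist() for c in old_chunks] == [[0.0, 1.0], [2.0, 3.0]]
assert [c[..., 0].unique().tolist() for c in new_chunks] == [[0.0, 2.0], [1.0, 3.0]]
```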
@@ -24,7 +24,6 @@ class _AllToAll(torch.autograd.Function):
     @staticmethod
     def forward(ctx: Any, group: dist.ProcessGroup, input: Tensor) -> Tensor:  # type: ignore
         ctx.group = group
-        world_size = dist.get_world_size(group)
         input = input.contiguous()
         output = torch.empty_like(input)
         dist.all_to_all_single(output, input, group=group)
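For context, dist.all_to_all_single splits each rank's input into world_size equal chunks along dim 0 and gives rank r the r-th chunk from every rank, concatenated in source-rank order. A single-process reference sketch of that semantics (the helper name is made up; no process group needed):

```python
import torch

def all_to_all_reference(inputs):
    """Mimic dist.all_to_all_single across a list of per-rank tensors."""
    world_size = len(inputs)
    # Split each rank's tensor into world_size chunks along dim 0 ...
    parts = [x.chunk(world_size, dim=0) for x in inputs]
    # ... and hand rank r the r-th chunk from every rank, in source order.
    return [torch.cat([parts[g][r] for g in range(world_size)], dim=0)
            for r in range(world_size)]

# Two "ranks", each holding a (num_experts=4, capacity=2, d_model=3) tensor.
inputs = [torch.randn(4, 2, 3) for _ in range(2)]
outputs = all_to_all_reference(inputs)
assert outputs[0].shape == inputs[0].shape
# Applying the exchange twice returns the originals (all-to-all is an involution),
# which is why the combine step can reuse the same collective.
again = all_to_all_reference(outputs)
assert all(torch.equal(a, b) for a, b in zip(again, inputs))
```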
@@ -64,6 +63,8 @@ class MOELayer(Base):
         for expert in self.experts:
             for p in expert.parameters():
                 p.expert = True  # type: ignore
+        self.world_size = dist.get_world_size(self.group)
+        self.num_local_experts = len(self.experts)

     def all_to_all_dispatch(self, dispatch_mask: Tensor, input: Tensor) -> Tensor:
         dispatched_input = torch.einsum("sec,sm->ecm", dispatch_mask.float(), input)
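The p.expert flag marks parameters that are local to a rank (each rank holds different experts) so downstream wrappers can skip them when synchronizing gradients. A hedged sketch of how such a flag is typically consumed; the filtering below is illustrative, not a fairscale API:

```python
import torch

# Stand-in module; imagine it holds both shared and expert parameters.
model = torch.nn.Linear(8, 8)
for p in model.parameters():
    p.expert = True  # as MOELayer.__init__ does for expert parameters

# Shared params are replicated and must be all-reduced; expert params are
# sharded (different weights per rank) and must be left alone.
expert_params = [p for p in model.parameters() if getattr(p, "expert", False)]
shared_params = [p for p in model.parameters() if not getattr(p, "expert", False)]
```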
@@ -79,15 +80,19 @@ class MOELayer(Base):
         assert input[0].shape[0] % len(self.experts) == 0, "num tokens must be a multiple of the number of local experts"
         # Implement Algorithm 2 from GShard paper.
-        shape = input[0].shape
+        d_model = input[0].shape[2]
         # Reshape into S tokens by dropping sequence dimension.
-        reshaped_input = input[0].reshape(-1, shape[2])
+        reshaped_input = input[0].reshape(-1, d_model)
         self.l_aux, combine_weights, dispatching_mask = self.gate(reshaped_input)
         dispatched_input = self.all_to_all_dispatch(dispatching_mask, reshaped_input)
-        chunks = dispatched_input.chunk(len(self.experts), dim=0)
+        # Re-shape after all-to-all: ecm -> gecm
+        dispatched_input = dispatched_input.reshape(self.world_size, self.num_local_experts, -1, d_model)
+        chunks = dispatched_input.chunk(self.num_local_experts, dim=1)
         expert_outputs = []
         for chunk, expert in zip(chunks, self.experts):
             expert_outputs += [expert(chunk)]
-        expert_output = torch.cat(expert_outputs, dim=0)
+        expert_output = torch.cat(expert_outputs, dim=1)
+        # Re-shape back: gecm -> ecm
+        expert_output = expert_output.reshape(self.world_size * self.num_local_experts, -1, d_model)
         combined_output = self.all_to_all_combine(combine_weights, expert_output)
-        return combined_output.reshape(shape)
+        return combined_output.reshape(input[0].shape)
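The dispatch/combine pair around the expert computation is a pair of einsums over a one-hot (s, e, c) routing tensor; with the same weights on both sides and identity experts, combine inverts dispatch. A sketch under those assumptions (the "sec,ecm->sm" combine contraction is inferred from the dispatch einsum above, not quoted from the file):

```python
import torch

s, e, c, m = 8, 4, 2, 16  # tokens, experts, capacity, d_model
x = torch.randn(s, m)

# One-hot routing: token i -> expert i % e, capacity slot i // e.
mask = torch.zeros(s, e, c)
for i in range(s):
    mask[i, i % e, i // e] = 1.0

dispatched = torch.einsum("sec,sm->ecm", mask, x)         # as in all_to_all_dispatch
combined = torch.einsum("sec,ecm->sm", mask, dispatched)  # assumed combine contraction
assert torch.allclose(combined, x)  # identity experts: combine undoes dispatch
```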
@@ -135,6 +135,31 @@ def test_forward_routing(device):
         assert torch.allclose(input[:, i] * (expert + 1), output[:, i])


+@pytest.mark.mpi
+@pytest.mark.parametrize("device", ["cpu"])
+def test_forward_routing_multi(device):
+    model_dim = 8
+    num_local_experts = 4
+    num_experts = dist.get_world_size(dist.group.WORLD) * num_local_experts
+    input = torch.randn(4 * num_local_experts, 16, model_dim).to(device)
+    gate = RoundRobinGate(model_dim, num_experts)
+    experts = []
+    for i in range(num_local_experts):
+        expert = torch.nn.Linear(model_dim, model_dim, bias=False)
+        # Use a scaling matrix (each expert gets a globally unique scale)
+        scale = dist.get_rank() * num_local_experts + i + 1
+        expert.weight = torch.nn.Parameter(torch.eye(model_dim) * scale)
+        experts += [expert]
+    moe = MOELayer(gate, torch.nn.ModuleList(experts)).to(device)
+    output = moe(input)
+    assert output.shape == input.shape
+    # Verify that each token was sent to the correct expert by checking its scale.
+    t = input.shape[1]
+    for i in range(t):
+        expert = i % num_experts
+        assert torch.allclose(input[:, i] * (expert + 1), output[:, i])
+
+
 @pytest.mark.mpi
 @pytest.mark.parametrize("device", ["cpu"])
 def test_backward(device):
......
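The new test depends on RoundRobinGate, a deterministic helper gate from the test suite. A hypothetical stand-in consistent with the scale check above (token s routed to expert s % num_experts with weight 1; the real helper may differ), reusing the one-hot construction from the dispatch sketch:

```python
import torch

class RoundRobinGateSketch(torch.nn.Module):
    """Stand-in for the test helper: token s -> expert s % num_experts."""

    def __init__(self, model_dim: int, num_experts: int) -> None:
        super().__init__()
        self.num_experts = num_experts

    def forward(self, input: torch.Tensor):
        s = input.shape[0]
        capacity = (s + self.num_experts - 1) // self.num_experts
        combine = torch.zeros(s, self.num_experts, capacity)
        for i in range(s):
            combine[i, i % self.num_experts, i // self.num_experts] = 1.0
        # MOELayer.forward expects (l_aux, combine_weights, dispatch_mask);
        # a fixed routing needs no load-balancing loss, so l_aux is zero.
        return torch.zeros(1), combine, combine.bool()
```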