Unverified Commit 339cf060 authored by msbaines's avatar msbaines Committed by GitHub
Browse files

[feat] moe: add support for multiple experts per device (#161)

parent 95ddbc19
......@@ -3,12 +3,12 @@
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from typing import TYPE_CHECKING, Any, Optional, Tuple
from typing import TYPE_CHECKING, Any, Optional, Tuple, Union, cast
import torch
from torch import Tensor
import torch.distributed as dist
from torch.nn import Module
from torch.nn import Module, ModuleList
if TYPE_CHECKING:
Base = Module[Tensor]
......@@ -55,13 +55,17 @@ class MOELayer(Base):
expert network
"""
def __init__(self, gate: Module, expert: Module, group: Optional[Any] = None) -> None:
def __init__(self, gate: Module, experts: Union[Module, ModuleList], group: Optional[Any] = None) -> None:
super().__init__()
self.gate = gate
self.expert = expert
if type(experts) == ModuleList:
self.experts = cast(ModuleList, experts)
else:
self.experts = ModuleList([experts])
self.group = group if group is not None else dist.group.WORLD
for p in expert.parameters():
p.expert = True # type: ignore
for expert in self.experts:
for p in experts.parameters():
p.expert = True # type: ignore
def all_to_all_dispatch(self, dispatch_mask: Tensor, input: Tensor) -> Tensor:
dispatched_input = torch.einsum("gsec,gsm->egcm", dispatch_mask.float(), input)
......@@ -74,6 +78,7 @@ class MOELayer(Base):
def forward(self, *input: Tensor, **kwargs: Any) -> Tensor:
assert len(input) == 1, "only single input Tensor supported"
assert len(input[0].shape) == 4, "input Tensor must have dimensions: (g)roup, (s)equence, (t)oken, (m)odel"
assert input[0].shape[0] == len(self.experts), "group dimension size must be equal to number of local experts"
# Implement Algorithm 2 from GShard paper.
shape = input[0].shape
......@@ -81,6 +86,11 @@ class MOELayer(Base):
reshaped_input = input[0].reshape(shape[0], -1, shape[3])
self.l_aux, combine_weights, dispatching_mask = self.gate(reshaped_input)
dispatched_input = self.all_to_all_dispatch(dispatching_mask, reshaped_input)
expert_output = self.expert(dispatched_input)
assert dispatched_input.shape[1] == len(self.experts)
chunks = dispatched_input.chunk(len(self.experts), dim=1)
expert_outputs = []
for chunk, expert in zip(chunks, self.experts):
expert_outputs += [expert(chunk)]
expert_output = torch.cat(expert_outputs, dim=1)
combined_output = self.all_to_all_combine(combine_weights, expert_output)
return combined_output.reshape(shape)
......@@ -61,7 +61,7 @@ def test_expert_params(device):
def test_forward(device):
model_dim = 8
num_experts = dist.get_world_size(dist.group.WORLD)
input = torch.randn(3, 4, 16, model_dim).to(device)
input = torch.randn(1, 4, 16, model_dim).to(device)
gate = Top2Gate(model_dim, num_experts)
expert = torch.nn.Linear(model_dim, model_dim, bias=False)
# Use identity matrix
......@@ -73,6 +73,30 @@ def test_forward(device):
assert torch.allclose(input, output)
@pytest.mark.mpi
@pytest.mark.parametrize("device", ["cpu"])
def test_forward_multi(device):
torch.set_printoptions(threshold=5000)
num_local_experts = 4
model_dim = 4
num_experts = dist.get_world_size(dist.group.WORLD) * num_local_experts
input = torch.randn(num_local_experts, 4, 16, model_dim).to(device)
gate = Top2Gate(model_dim, num_experts)
experts = []
for i in range(num_local_experts):
expert = torch.nn.Linear(model_dim, model_dim, bias=False)
# Use identity matrix
expert.weight = torch.nn.Parameter(torch.eye(model_dim))
experts += [expert]
moe = MOELayer(gate, torch.nn.ModuleList(experts)).to(device)
output = moe(input)
assert output.shape == input.shape
# 90% of the input should have gone to an expert
assert len(output.nonzero(as_tuple=False)) / output.numel() > 0.90
# Except for zeros, re-assembled output should match input due to identity expert.
assert torch.allclose(input, torch.where(output > 0, output, input))
# Test Gate which round-robin routes tokens to experts
class RoundRobinGate(torch.nn.Module):
def __init__(self, model_dim, num_experts):
......@@ -95,7 +119,7 @@ class RoundRobinGate(torch.nn.Module):
def test_forward_routing(device):
model_dim = 8
num_experts = dist.get_world_size()
input = torch.randn(3, 4, 16, model_dim).to(device)
input = torch.randn(1, 4, 16, model_dim).to(device)
gate = RoundRobinGate(model_dim, num_experts)
expert = torch.nn.Linear(model_dim, model_dim, bias=False)
# Use scaling matrix (each rank has a different scale)
......@@ -117,7 +141,7 @@ def test_backward(device):
loss = torch.nn.MSELoss()
model_dim = 8
num_experts = dist.get_world_size(dist.group.WORLD)
input = torch.randn(3, 4, 16, model_dim).to(device)
input = torch.randn(1, 4, 16, model_dim).to(device)
gate = Top2Gate(model_dim, num_experts)
expert = torch.nn.Linear(model_dim, model_dim, bias=False)
# Use identity matrix
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment