"git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "766721b1dfd7a92130146a549c4fcca15cc069b2"
Unverified commit 339cf060 authored by msbaines, committed by GitHub

[feat] moe: add support for multiple experts per device (#161)

parent 95ddbc19
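For orientation before the diff: MOELayer now accepts either a single expert module or a torch.nn.ModuleList holding one module per local expert, and the group dimension of the input must match the number of local experts. A minimal usage sketch, assuming the fairscale.nn import path and an initialized torch.distributed process group (as in the MPI tests below); the dimensions are illustrative:

    import torch
    import torch.distributed as dist
    from torch.nn import Linear, ModuleList
    from fairscale.nn import MOELayer, Top2Gate  # import path assumed

    # assumes dist.init_process_group(...) has already run on every rank
    num_local_experts, model_dim = 2, 8
    num_experts = dist.get_world_size() * num_local_experts
    gate = Top2Gate(model_dim, num_experts)
    experts = ModuleList(Linear(model_dim, model_dim, bias=False) for _ in range(num_local_experts))
    moe = MOELayer(gate, experts)  # a bare Module still works for the single-expert case
    x = torch.randn(num_local_experts, 4, 16, model_dim)  # group dim == num_local_experts
    y = moe(x)  # output has the same shape as x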
@@ -3,12 +3,12 @@
 # This source code is licensed under the BSD license found in the
 # LICENSE file in the root directory of this source tree.

-from typing import TYPE_CHECKING, Any, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Optional, Tuple, Union, cast

 import torch
 from torch import Tensor
 import torch.distributed as dist
-from torch.nn import Module
+from torch.nn import Module, ModuleList

 if TYPE_CHECKING:
     Base = Module[Tensor]
@@ -55,13 +55,17 @@ class MOELayer(Base):
         expert network
     """

-    def __init__(self, gate: Module, expert: Module, group: Optional[Any] = None) -> None:
+    def __init__(self, gate: Module, experts: Union[Module, ModuleList], group: Optional[Any] = None) -> None:
         super().__init__()
         self.gate = gate
-        self.expert = expert
+        if type(experts) == ModuleList:
+            self.experts = cast(ModuleList, experts)
+        else:
+            self.experts = ModuleList([experts])
         self.group = group if group is not None else dist.group.WORLD
-        for p in expert.parameters():
-            p.expert = True  # type: ignore
+        for expert in self.experts:
+            for p in expert.parameters():
+                p.expert = True  # type: ignore

     def all_to_all_dispatch(self, dispatch_mask: Tensor, input: Tensor) -> Tensor:
         dispatched_input = torch.einsum("gsec,gsm->egcm", dispatch_mask.float(), input)
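Note that every expert parameter is still tagged with p.expert = True, so code outside this layer can separate expert (model-parallel) parameters from shared (data-parallel) ones, for example to exempt them from gradient all-reduce. A hedged sketch of such a split, given a moe instance as in the sketch above (the split itself is not part of this diff):

    expert_params = [p for p in moe.parameters() if getattr(p, "expert", False)]
    shared_params = [p for p in moe.parameters() if not getattr(p, "expert", False)]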
@@ -74,6 +78,7 @@ class MOELayer(Base):
     def forward(self, *input: Tensor, **kwargs: Any) -> Tensor:
         assert len(input) == 1, "only single input Tensor supported"
         assert len(input[0].shape) == 4, "input Tensor must have dimensions: (g)roup, (s)equence, (t)oken, (m)odel"
+        assert input[0].shape[0] == len(self.experts), "group dimension size must be equal to number of local experts"

         # Implement Algorithm 2 from GShard paper.
         shape = input[0].shape
@@ -81,6 +86,11 @@ class MOELayer(Base):
         reshaped_input = input[0].reshape(shape[0], -1, shape[3])
         self.l_aux, combine_weights, dispatching_mask = self.gate(reshaped_input)
         dispatched_input = self.all_to_all_dispatch(dispatching_mask, reshaped_input)
-        expert_output = self.expert(dispatched_input)
+        assert dispatched_input.shape[1] == len(self.experts)
+        chunks = dispatched_input.chunk(len(self.experts), dim=1)
+        expert_outputs = []
+        for chunk, expert in zip(chunks, self.experts):
+            expert_outputs += [expert(chunk)]
+        expert_output = torch.cat(expert_outputs, dim=1)
         combined_output = self.all_to_all_combine(combine_weights, expert_output)
         return combined_output.reshape(shape)
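The new forward path splits the dispatched tensor along dim 1 into one chunk per local expert and concatenates the per-expert outputs back in the same order, so shapes round-trip exactly. A standalone sketch of that pattern with stand-in experts (shapes illustrative, not from this commit):

    import torch

    e, g, c, m = 8, 4, 16, 4  # (e)xperts worldwide, (g)roups == local experts, (c)apacity, (m)odel dim
    dispatched = torch.randn(e, g, c, m)
    experts = [torch.nn.Linear(m, m) for _ in range(g)]
    chunks = dispatched.chunk(g, dim=1)  # g chunks of shape (e, 1, c, m)
    out = torch.cat([f(ch) for f, ch in zip(experts, chunks)], dim=1)
    assert out.shape == dispatched.shape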
@@ -61,7 +61,7 @@ def test_expert_params(device):
 def test_forward(device):
     model_dim = 8
     num_experts = dist.get_world_size(dist.group.WORLD)
-    input = torch.randn(3, 4, 16, model_dim).to(device)
+    input = torch.randn(1, 4, 16, model_dim).to(device)
     gate = Top2Gate(model_dim, num_experts)
     expert = torch.nn.Linear(model_dim, model_dim, bias=False)
     # Use identity matrix
@@ -73,6 +73,30 @@ def test_forward(device):
     assert torch.allclose(input, output)


+@pytest.mark.mpi
+@pytest.mark.parametrize("device", ["cpu"])
+def test_forward_multi(device):
+    torch.set_printoptions(threshold=5000)
+    num_local_experts = 4
+    model_dim = 4
+    num_experts = dist.get_world_size(dist.group.WORLD) * num_local_experts
+    input = torch.randn(num_local_experts, 4, 16, model_dim).to(device)
+    gate = Top2Gate(model_dim, num_experts)
+    experts = []
+    for i in range(num_local_experts):
+        expert = torch.nn.Linear(model_dim, model_dim, bias=False)
+        # Use identity matrix
+        expert.weight = torch.nn.Parameter(torch.eye(model_dim))
+        experts += [expert]
+    moe = MOELayer(gate, torch.nn.ModuleList(experts)).to(device)
+    output = moe(input)
+    assert output.shape == input.shape
+    # 90% of the input should have gone to an expert
+    assert len(output.nonzero(as_tuple=False)) / output.numel() > 0.90
+    # Except for zeros, re-assembled output should match input due to identity expert.
+    assert torch.allclose(input, torch.where(output > 0, output, input))
+
+
 # Test Gate which round-robin routes tokens to experts
 class RoundRobinGate(torch.nn.Module):
     def __init__(self, model_dim, num_experts):
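test_forward_multi above leans on identity experts: any token that was actually routed comes back unchanged, while tokens dropped by the capacity limit come back as zeros, so the torch.where comparison checks only the routed positions. A toy illustration of that final assertion (tensors invented for clarity):

    import torch

    inp = torch.tensor([1.0, 2.0, 3.0])
    out = torch.tensor([1.0, 0.0, 3.0])  # middle token dropped at capacity -> zero
    # nonzero output positions must equal the input; dropped slots are ignored
    assert torch.allclose(inp, torch.where(out > 0, out, inp))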
@@ -95,7 +119,7 @@ class RoundRobinGate(torch.nn.Module):
 def test_forward_routing(device):
     model_dim = 8
     num_experts = dist.get_world_size()
-    input = torch.randn(3, 4, 16, model_dim).to(device)
+    input = torch.randn(1, 4, 16, model_dim).to(device)
     gate = RoundRobinGate(model_dim, num_experts)
     expert = torch.nn.Linear(model_dim, model_dim, bias=False)
     # Use scaling matrix (each rank has a different scale)
@@ -117,7 +141,7 @@ def test_backward(device):
     loss = torch.nn.MSELoss()
     model_dim = 8
     num_experts = dist.get_world_size(dist.group.WORLD)
-    input = torch.randn(3, 4, 16, model_dim).to(device)
+    input = torch.randn(1, 4, 16, model_dim).to(device)
     gate = Top2Gate(model_dim, num_experts)
     expert = torch.nn.Linear(model_dim, model_dim, bias=False)
     # Use identity matrix
...