Unverified commit d99c445a authored by msbaines, committed by GitHub

[feat] moe: add all_to_all backward support (#137)

parent 1e6c547a
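Background for the change: globally, all_to_all sends chunk j of rank i to slot i on rank j, i.e. it transposes the (rank, chunk) grid, and a transpose is its own inverse, so the vector-Jacobian product of the collective is simply the same exchange applied to the incoming gradients. A minimal single-process sketch of that argument (no process group involved; the transpose stands in for the collective, and `w` is a hypothetical world size):

```python
import torch

w = 4                                          # hypothetical world size
x = torch.randn(w, w, 5, requires_grad=True)   # x[i, j]: chunk j held by rank i
y = x.transpose(0, 1)                          # all_to_all: chunk j of rank i -> chunk i of rank j
g = torch.randn_like(y)                        # upstream gradient w.r.t. y
y.backward(g)
# The vector-Jacobian product is the same exchange applied to the gradient.
assert torch.equal(x.grad, g.transpose(0, 1))
```

This is exactly what the new `_AllToAll.backward` below does: it re-applies `_AllToAll.apply` to the incoming gradient.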
@@ -42,7 +42,7 @@ install_dep_15: &install_dep_15
   - run:
       name: Install Dependencies
       command: |
-        sudo apt-get install -y mpi-default-dev
+        sudo apt-get install -y libopenmpi-dev
         pip install --progress-bar off torch==1.5.1+cu101 -f https://download.pytorch.org/whl/torch_stable.html
         pip install --progress-bar off -r requirements-test.txt
         python -c 'import torch; print("Torch version:", torch.__version__)'

@@ -52,7 +52,7 @@ install_dep_16: &install_dep_16
   - run:
       name: Install Dependencies
       command: |
-        sudo apt-get install -y mpi-default-dev
+        sudo apt-get install -y libopenmpi-dev
         pip install --progress-bar off torch==1.6.0+cu101 -f https://download.pytorch.org/whl/torch_stable.html
         pip install --progress-bar off -r requirements-test.txt
         python -c 'import torch; print("Torch version:", torch.__version__)'
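(The swap from `mpi-default-dev` to `libopenmpi-dev` pins CI to OpenMPI; presumably this is to match the MPI-launched test added below, though the commit message does not say so explicitly.)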
@@ -3,7 +3,7 @@
 # This source code is licensed under the BSD license found in the
 # LICENSE file in the root directory of this source tree.

-from typing import TYPE_CHECKING, Any, Optional
+from typing import TYPE_CHECKING, Any, Optional, Tuple

 import torch
 from torch import Tensor
@@ -19,6 +19,24 @@ else:


 # See https://arxiv.org/pdf/2006.16668.pdf for details.
+# Based on https://github.com/pytorch/pytorch/pull/40762
+class _AllToAll(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx: Any, group: dist.ProcessGroup, input: Tensor) -> Tensor:  # type: ignore
+        ctx.group = group
+        world_size = dist.get_world_size(group)
+        input = input.contiguous()
+        output = torch.empty_like(input)
+        input_chunks = list(input.chunk(world_size))
+        output_chunks = list(output.chunk(world_size))
+        dist.all_to_all(output_chunks, input_chunks, group=group)
+        return output
+
+    @staticmethod
+    def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor]:
+        return (None, _AllToAll.apply(ctx.group, *grad_output))
+
+
 class MOELayer(Base):
     """MOELayer module which implements MixtureOfExperts as described in Gshard_.
     ::
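Note that `_AllToAll` chunks along dim 0 into `world_size` pieces, so callers must supply a tensor whose leading dimension equals the world size; in this layer that is the experts dimension `e` produced by the dispatch einsum, and the test below accordingly uses one expert per rank.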
@@ -42,21 +60,14 @@ class MOELayer(Base):
         self.gate = gate
         self.expert = expert
         self.group = group if group is not None else dist.group.WORLD
-        self.world_size = dist.get_world_size(self.group)

     def all_to_all_dispatch(self, dispatch_mask: Tensor, input: Tensor) -> Tensor:
         dispatched_input = torch.einsum("gsec,gsm->egcm", dispatch_mask.float(), input)
-        dispatched_input = dispatched_input.contiguous()
-        chunks = list(dispatched_input.chunk(self.world_size))
-        dist.all_to_all(chunks, chunks, self.group)
-        return dispatched_input
+        return _AllToAll.apply(self.group, dispatched_input)

     def all_to_all_combine(self, combine_weights: Tensor, input: Tensor) -> Tensor:
-        expert_output = input.contiguous()
-        chunks = list(expert_output.chunk(self.world_size))
-        dist.all_to_all(chunks, chunks, self.group)
-        output = torch.einsum("gsec,egcm->gsm", combine_weights, expert_output)
-        return output
+        expert_output = _AllToAll.apply(self.group, input)
+        return torch.einsum("gsec,egcm->gsm", combine_weights, expert_output)

     def forward(self, *input: Tensor, **kwargs: Any) -> Tensor:
         assert len(input) == 1, "only single input Tensor supported"
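Unlike a bare `dist.all_to_all` call, which autograd cannot see through, routing the exchange via `_AllToAll.apply` records it on the autograd tape. A hedged usage sketch, assuming a torch build with the MPI backend and that the module path below matches this repo's layout (launch with something like `mpirun -np 2 python sketch.py`):

```python
import torch
import torch.distributed as dist

from fairscale.nn.moe.moe_layer import _AllToAll  # module path assumed

dist.init_process_group("mpi")             # rank/world size come from the MPI launcher
w = dist.get_world_size()
x = torch.randn(w, 4, requires_grad=True)  # one row destined for each rank
y = _AllToAll.apply(dist.group.WORLD, x)   # differentiable exchange
y.sum().backward()                         # backward runs a second all_to_all
assert x.grad.shape == x.shape             # gradients routed back to their source rank
```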
@@ -60,3 +60,22 @@ def test_forward(device):
     assert output.shape == input.shape
     # Re-assembled output should match input due to identity expert.
     assert torch.allclose(input, output)
+
+
+@pytest.mark.mpi
+@pytest.mark.parametrize("device", ["cpu"])
+def test_backward(device):
+    loss = torch.nn.MSELoss()
+    model_dim = 8
+    num_experts = dist.get_world_size(dist.group.WORLD)
+    input = torch.randn(3, 4, 16, model_dim).to(device)
+    gate = Top2Gate(model_dim, num_experts)
+    expert = torch.nn.Linear(model_dim, model_dim, bias=False)
+    # Use identity matrix
+    expert.weight = torch.nn.Parameter(torch.eye(model_dim))
+    moe = MOELayer(gate, expert).to(device)
+    output = moe(input)
+    assert output.shape == input.shape
+    output = loss(output, input)
+    output.backward()
+    assert torch.allclose(expert.weight.grad, torch.zeros_like(expert.weight))
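Because the new test is gated behind the `mpi` marker, it needs a multi-rank MPI launch rather than a plain pytest invocation; something like `mpirun -np 2 python -m pytest -m mpi <path to test file>` should work, with the exact path and marker registration depending on the repo's pytest configuration. With the identity expert, the output equals the input, the MSE loss is zero, and the expert's weight gradient is all zeros, which is what the final assertion checks.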