Unverified commit d99c445a authored by msbaines, committed by GitHub

[feat] moe: add all_to_all backward support (#137)

parent 1e6c547a
@@ -42,7 +42,7 @@ install_dep_15: &install_dep_15
   - run:
       name: Install Dependencies
       command: |
-        sudo apt-get install -y mpi-default-dev
+        sudo apt-get install -y libopenmpi-dev
         pip install --progress-bar off torch==1.5.1+cu101 -f https://download.pytorch.org/whl/torch_stable.html
         pip install --progress-bar off -r requirements-test.txt
         python -c 'import torch; print("Torch version:", torch.__version__)'
@@ -52,7 +52,7 @@ install_dep_16: &install_dep_16
   - run:
       name: Install Dependencies
       command: |
-        sudo apt-get install -y mpi-default-dev
+        sudo apt-get install -y libopenmpi-dev
         pip install --progress-bar off torch==1.6.0+cu101 -f https://download.pytorch.org/whl/torch_stable.html
         pip install --progress-bar off -r requirements-test.txt
         python -c 'import torch; print("Torch version:", torch.__version__)'
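
The switch from mpi-default-dev to libopenmpi-dev keeps an MPI runtime available for the new MPI-marked test below. Note that stock PyTorch pip wheels are typically built without the MPI backend, so a quick availability check (an illustrative sketch, not part of this commit) can save a confusing failure:

import torch.distributed as dist

# Illustrative check: the MPI backend is only present when torch was
# built against MPI (pip wheels typically are not), so the mpi-marked
# tests need an environment where this prints True.
print("MPI backend available:", dist.is_mpi_available())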
@@ -3,7 +3,7 @@
 # This source code is licensed under the BSD license found in the
 # LICENSE file in the root directory of this source tree.
-from typing import TYPE_CHECKING, Any, Optional
+from typing import TYPE_CHECKING, Any, Optional, Tuple

 import torch
 from torch import Tensor
@@ -19,6 +19,24 @@ else:
 # See https://arxiv.org/pdf/2006.16668.pdf for details.
+# Based on https://github.com/pytorch/pytorch/pull/40762
+class _AllToAll(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx: Any, group: dist.ProcessGroup, input: Tensor) -> Tensor:  # type: ignore
+        ctx.group = group
+        world_size = dist.get_world_size(group)
+        input = input.contiguous()
+        output = torch.empty_like(input)
+        input_chunks = list(input.chunk(world_size))
+        output_chunks = list(output.chunk(world_size))
+        dist.all_to_all(output_chunks, input_chunks, group=group)
+        return output
+
+    @staticmethod
+    def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor]:
+        return (None, _AllToAll.apply(ctx.group, *grad_output))
+
+
 class MOELayer(Base):
     """MOELayer module which implements MixtureOfExperts as described in Gshard_.

     ::
@@ -42,21 +60,14 @@ class MOELayer(Base):
         self.gate = gate
         self.expert = expert
         self.group = group if group is not None else dist.group.WORLD
-        self.world_size = dist.get_world_size(self.group)

     def all_to_all_dispatch(self, dispatch_mask: Tensor, input: Tensor) -> Tensor:
         dispatched_input = torch.einsum("gsec,gsm->egcm", dispatch_mask.float(), input)
-        dispatched_input = dispatched_input.contiguous()
-        chunks = list(dispatched_input.chunk(self.world_size))
-        dist.all_to_all(chunks, chunks, self.group)
-        return dispatched_input
+        return _AllToAll.apply(self.group, dispatched_input)

     def all_to_all_combine(self, combine_weights: Tensor, input: Tensor) -> Tensor:
-        expert_output = input.contiguous()
-        chunks = list(expert_output.chunk(self.world_size))
-        dist.all_to_all(chunks, chunks, self.group)
-        output = torch.einsum("gsec,egcm->gsm", combine_weights, expert_output)
-        return output
+        expert_output = _AllToAll.apply(self.group, input)
+        return torch.einsum("gsec,egcm->gsm", combine_weights, expert_output)

     def forward(self, *input: Tensor, **kwargs: Any) -> Tensor:
         assert len(input) == 1, "only single input Tensor supported"
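
The key point the new _AllToAll class encodes is that all_to_all is self-adjoint: with W ranks, chunk c of rank r lands as chunk r of rank c, i.e. a transpose of the (rank, chunk) block layout, so routing gradients backward is just another all_to_all. A single-process sketch of that property (simulated ranks, no process group; the helper below is illustrative, not part of the commit):

import torch

world_size, chunk = 4, 2
# Row r simulates rank r's local tensor, split into `world_size` chunks.
x = torch.arange(world_size * world_size * chunk).reshape(world_size, world_size, chunk)

def all_to_all_sim(t: torch.Tensor) -> torch.Tensor:
    # Rank r's chunk c ends up as rank c's chunk r: a (rank, chunk) transpose.
    return t.transpose(0, 1).contiguous()

# The permutation is an involution, so applying it to incoming gradients
# undoes the forward shuffle -- exactly what backward() does.
assert torch.equal(all_to_all_sim(all_to_all_sim(x)), x)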
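A shape-level sketch of the dispatch einsum, with dimension names assumed from the GShard paper (g = groups, s = tokens per group, e = experts, c = expert capacity, m = model dim; the concrete sizes below are made up for illustration):

import torch

g, s, e, c, m = 1, 16, 4, 4, 8
dispatch_mask = torch.zeros(g, s, e, c)   # one-hot routing decisions
input = torch.randn(g, s, m)              # token embeddings

# "gsec,gsm->egcm" scatters each token into its (expert, capacity) slot;
# the leading e dimension is what _AllToAll chunks across ranks.
dispatched = torch.einsum("gsec,gsm->egcm", dispatch_mask.float(), input)
assert dispatched.shape == (e, g, c, m)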
@@ -60,3 +60,22 @@ def test_forward(device):
     assert output.shape == input.shape
     # Re-assembled output should match input due to identity expert.
     assert torch.allclose(input, output)
+
+
+@pytest.mark.mpi
+@pytest.mark.parametrize("device", ["cpu"])
+def test_backward(device):
+    loss = torch.nn.MSELoss()
+    model_dim = 8
+    num_experts = dist.get_world_size(dist.group.WORLD)
+    input = torch.randn(3, 4, 16, model_dim).to(device)
+    gate = Top2Gate(model_dim, num_experts)
+    expert = torch.nn.Linear(model_dim, model_dim, bias=False)
+    # Use an identity matrix so the expert reproduces its input exactly.
+    expert.weight = torch.nn.Parameter(torch.eye(model_dim))
+    moe = MOELayer(gate, expert).to(device)
+    output = moe(input)
+    assert output.shape == input.shape
+    output = loss(output, input)
+    output.backward()
+    assert torch.allclose(expert.weight.grad, torch.zeros_like(expert.weight))
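
Why the test can assert a zero weight gradient: with an identity expert the layer reproduces its input, so MSELoss(output, input) is exactly zero and no gradient flows. The same effect in isolation, without MPI or the MoE machinery (a standalone sketch, not from the commit):

import torch

model_dim = 8
expert = torch.nn.Linear(model_dim, model_dim, bias=False)
expert.weight = torch.nn.Parameter(torch.eye(model_dim))

x = torch.randn(5, model_dim)
loss = torch.nn.functional.mse_loss(expert(x), x)  # identity => loss == 0
loss.backward()
assert torch.allclose(expert.weight.grad, torch.zeros_like(expert.weight))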