Unverified Commit 6d802f5a authored by msbaines's avatar msbaines Committed by GitHub
Browse files

[feat] moe: add all_to_all support (#134)

parent 177151e0
......@@ -42,6 +42,7 @@ install_dep_15: &install_dep_15
- run:
name: Install Dependencies
command: |
sudo apt-get install -y mpi-default-dev
pip install --progress-bar off torch==1.5.1+cu101 -f https://download.pytorch.org/whl/torch_stable.html
pip install --progress-bar off -r requirements-test.txt
python -c 'import torch; print("Torch version:", torch.__version__)'
......@@ -51,6 +52,7 @@ install_dep_16: &install_dep_16
- run:
name: Install Dependencies
command: |
sudo apt-get install -y mpi-default-dev
pip install --progress-bar off torch==1.6.0+cu101 -f https://download.pytorch.org/whl/torch_stable.html
pip install --progress-bar off -r requirements-test.txt
python -c 'import torch; print("Torch version:", torch.__version__)'
......@@ -84,6 +86,12 @@ run_unittests: &run_unittests
command: |
pytest --junitxml=test-results/junit.xml --verbose
run_mpi_unittests: &run_mpi_unittests
- run:
name: Run MPI Unit Tests
command: |
mpirun -n4 python -m pytest -only-mpi --junitxml=test-results/junit.xml --verbose
run_flake8: &run_flake8
- run:
name: Run Linter (flake8)
......
......@@ -3,10 +3,11 @@
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, Optional
import torch
from torch import Tensor
import torch.distributed as dist
from torch.nn import Module
if TYPE_CHECKING:
......@@ -24,7 +25,8 @@ class MOELayer(Base):
gate = Top2Gate(model_dim, num_experts)
moe = MOELayer(gate, expert)
l_aux, combine_weights, dispatch_mask = moe(input)
output = moe(input)
l_aux = moe.l_aux
.. Gshard_: https://arxiv.org/pdf/2006.16668.pdf
......@@ -35,24 +37,31 @@ class MOELayer(Base):
expert network
"""
def __init__(self, gate: Module, expert: Module) -> None:
def __init__(self, gate: Module, expert: Module, group: Optional[Any] = None) -> None:
super().__init__()
self.gate = gate
self.expert = expert
self.group = group if group is not None else dist.group.WORLD
self.world_size = dist.get_world_size(self.group)
def all_to_all_dispatch(self, dispatch_mask: Tensor, input: Tensor) -> Tensor:
dispatched_input = torch.einsum("gsec,gsm->egcm", dispatch_mask.float(), input)
# TODO(msb) all-to-all
dispatched_input = torch.squeeze(dispatched_input, 0) # drop E dimension
dispatched_input = dispatched_input.contiguous()
chunks = list(dispatched_input.chunk(self.world_size))
dist.all_to_all(chunks, chunks, self.group)
return dispatched_input
def all_to_all_combine(self, combine_weights: Tensor, input: Tensor) -> Tensor:
# TODO(msb) all-to-all
expert_output = torch.unsqueeze(input, 1) # add E dimension
output = torch.einsum("gsec,gecm->gsm", combine_weights, expert_output)
expert_output = input.contiguous()
chunks = list(expert_output.chunk(self.world_size))
dist.all_to_all(chunks, chunks, self.group)
output = torch.einsum("gsec,egcm->gsm", combine_weights, expert_output)
return output
def forward(self, *input: Any, **kwargs: Any) -> Tensor:
def forward(self, *input: Tensor, **kwargs: Any) -> Tensor:
assert len(input) == 1, "only single input Tensor supported"
assert len(input[0].shape) == 4, "input Tensor must have dimensions: (g)roup, (s)equence, (t)oken, (m)odel"
# Implement Algorithm 2 from GShard paper.
shape = input[0].shape
# Reshape into S tokens per group.
......
black == 19.10b0
flake8 == 3.7.9
isort == 4.3.21
mpi4py == 3.0.3
mypy == 0.770
pytest == 5.4.1
pytest-cov == 2.10.0
pytest-mpi == 0.4
torchtext == 0.6.0
torch >= 1.5.1
torchvision >= 0.6.0
......
......@@ -6,7 +6,10 @@ import datetime
from . import rpc as rpc
class Backend: ...
class Backend:
GLOO: str
MPI: str
NCCL: str
class ProcessGroup:
def size(self) -> int: ...
......@@ -29,8 +32,10 @@ def broadcast(tensor: Tensor, src: Any, group: Any, async_op: Any = False): ...
def is_initialized() -> bool: ...
def init_process_group(backend: Union[str, Backend], timeout: datetime.timedelta = datetime.timedelta(0, 1800), rank: Optional[int] = None, world_size: Optional[int] = None): ...
def new_group(ranks: List[int], timeout: datetime.timedelta = datetime.timedelta(0, 1800), backend: Union[None, str, Backend] = None): ...
def all_to_all(output: List[Tensor], intput: List[Tensor], group:Optional[ProcessGroup] = None, async_op: bool = False): ...
def all_reduce(tensor: Tensor, op: ReduceOp = ReduceOp.SUM, group:Optional[ProcessGroup] = None, async_op: bool = False): ...
def all_gather(tensor_list: List[Tensor], tensor: Tensor, group:Optional[ProcessGroup] = None, async_op: bool = False): ...
......
......@@ -3,34 +3,53 @@
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import os
import pytest
import torch
import torch.distributed as dist
from fairscale.nn import MOELayer, Top2Gate
skip_if_no_cuda = pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
BACKEND = dist.Backend.NCCL if torch.cuda.is_available() else dist.Backend.GLOO # type: ignore
def test_create():
model_dim = 8
num_experts = 4
gate = Top2Gate(model_dim, num_experts)
expert = torch.nn.Linear(model_dim, model_dim)
moe = MOELayer(gate, expert)
if torch.cuda.is_available():
devices = ["cpu", "cuda"]
else:
devices = ["cpu"]
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "29501"
if "OMPI_COMM_WORLD_SIZE" in os.environ:
dist.init_process_group(backend=dist.Backend.MPI)
def setup_module(module):
if "OMPI_COMM_WORLD_SIZE" not in os.environ:
dist.init_process_group(backend=BACKEND, rank=0, world_size=1)
@skip_if_no_cuda
def test_create_cuda():
def teardown_module(module):
if "OMPI_COMM_WORLD_SIZE" not in os.environ:
torch.distributed.destroy_process_group()
@pytest.mark.parametrize("device", devices)
def test_create(device):
model_dim = 8
num_experts = 4
gate = Top2Gate(model_dim, num_experts)
expert = torch.nn.Linear(model_dim, model_dim)
moe = MOELayer(gate, expert).cuda()
moe = MOELayer(gate, expert).to(device)
def do_test_forward(device):
@pytest.mark.mpi
@pytest.mark.parametrize("device", ["cpu"])
def test_forward(device):
model_dim = 8
num_experts = 1
num_experts = dist.get_world_size(dist.group.WORLD)
input = torch.randn(3, 4, 16, model_dim).to(device)
gate = Top2Gate(model_dim, num_experts)
expert = torch.nn.Linear(model_dim, model_dim, bias=False)
......@@ -38,16 +57,6 @@ def do_test_forward(device):
expert.weight = torch.nn.Parameter(torch.eye(model_dim))
moe = MOELayer(gate, expert).to(device)
output = moe(input)
assert moe.l_aux.item() == 1.0
assert output.shape == input.shape
# Re-assembled output should match input due to identity expert.
assert torch.equal(input, output)
def test_forward_cpu():
do_test_forward("cpu")
@skip_if_no_cuda
def test_forward_cuda():
do_test_forward("cuda")
assert torch.allclose(input, output)
......@@ -29,6 +29,10 @@ def setup_module(module):
dist.init_process_group(backend=BACKEND, rank=0, world_size=1)
def teardown_module(module):
torch.distributed.destroy_process_group()
def dist_init(rank, world_size):
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "29501"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment