Unverified Commit 6d802f5a authored by msbaines, committed by GitHub

[feat] moe: add all_to_all support (#134)

parent 177151e0
@@ -42,6 +42,7 @@ install_dep_15: &install_dep_15
   - run:
       name: Install Dependencies
       command: |
+        sudo apt-get install -y mpi-default-dev
         pip install --progress-bar off torch==1.5.1+cu101 -f https://download.pytorch.org/whl/torch_stable.html
         pip install --progress-bar off -r requirements-test.txt
         python -c 'import torch; print("Torch version:", torch.__version__)'
@@ -51,6 +52,7 @@ install_dep_16: &install_dep_16
   - run:
       name: Install Dependencies
       command: |
+        sudo apt-get install -y mpi-default-dev
         pip install --progress-bar off torch==1.6.0+cu101 -f https://download.pytorch.org/whl/torch_stable.html
         pip install --progress-bar off -r requirements-test.txt
         python -c 'import torch; print("Torch version:", torch.__version__)'
@@ -84,6 +86,12 @@ run_unittests: &run_unittests
       command: |
         pytest --junitxml=test-results/junit.xml --verbose

+run_mpi_unittests: &run_mpi_unittests
+  - run:
+      name: Run MPI Unit Tests
+      command: |
+        mpirun -n 4 python -m pytest --only-mpi --junitxml=test-results/junit.xml --verbose

 run_flake8: &run_flake8
   - run:
       name: Run Linter (flake8)
...
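The new `run_mpi_unittests` job only collects tests carrying pytest-mpi's `mpi` marker. As a minimal sketch (hypothetical test, not part of this commit; assumes the mpi4py pin from requirements-test.txt), such a test looks like:

# Hypothetical test picked up by `mpirun -n 4 python -m pytest --only-mpi`.
import pytest
from mpi4py import MPI

@pytest.mark.mpi
def test_ranks_see_full_world():
    # pytest-mpi skips mpi-marked tests unless --with-mpi/--only-mpi is given.
    comm = MPI.COMM_WORLD
    assert 0 <= comm.Get_rank() < comm.Get_size()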
@@ -3,10 +3,11 @@
 # This source code is licensed under the BSD license found in the
 # LICENSE file in the root directory of this source tree.

-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Any, Optional

 import torch
 from torch import Tensor
+import torch.distributed as dist
 from torch.nn import Module

 if TYPE_CHECKING:
@@ -24,7 +25,8 @@ class MOELayer(Base):
         gate = Top2Gate(model_dim, num_experts)
         moe = MOELayer(gate, expert)
-        l_aux, combine_weights, dispatch_mask = moe(input)
+        output = moe(input)
+        l_aux = moe.l_aux

     .. Gshard_: https://arxiv.org/pdf/2006.16668.pdf
@@ -35,24 +37,31 @@ class MOELayer(Base):
         expert network
     """

-    def __init__(self, gate: Module, expert: Module) -> None:
+    def __init__(self, gate: Module, expert: Module, group: Optional[Any] = None) -> None:
         super().__init__()
         self.gate = gate
         self.expert = expert
+        self.group = group if group is not None else dist.group.WORLD
+        self.world_size = dist.get_world_size(self.group)
     def all_to_all_dispatch(self, dispatch_mask: Tensor, input: Tensor) -> Tensor:
         dispatched_input = torch.einsum("gsec,gsm->egcm", dispatch_mask.float(), input)
-        # TODO(msb) all-to-all
-        dispatched_input = torch.squeeze(dispatched_input, 0)  # drop E dimension
+        dispatched_input = dispatched_input.contiguous()
+        chunks = list(dispatched_input.chunk(self.world_size))
+        dist.all_to_all(chunks, chunks, self.group)
         return dispatched_input

     def all_to_all_combine(self, combine_weights: Tensor, input: Tensor) -> Tensor:
-        # TODO(msb) all-to-all
-        expert_output = torch.unsqueeze(input, 1)  # add E dimension
-        output = torch.einsum("gsec,gecm->gsm", combine_weights, expert_output)
+        expert_output = input.contiguous()
+        chunks = list(expert_output.chunk(self.world_size))
+        dist.all_to_all(chunks, chunks, self.group)
+        output = torch.einsum("gsec,egcm->gsm", combine_weights, expert_output)
         return output
-    def forward(self, *input: Any, **kwargs: Any) -> Tensor:
+    def forward(self, *input: Tensor, **kwargs: Any) -> Tensor:
+        assert len(input) == 1, "only single input Tensor supported"
+        assert len(input[0].shape) == 4, "input Tensor must have dimensions: (g)roup, (s)equence, (t)oken, (m)odel"
         # Implement Algorithm 2 from GShard paper.
         shape = input[0].shape
         # Reshape into S tokens per group.
...
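To make the dispatch path above concrete, here is a standalone sketch (hypothetical shapes, no process group required) of the layout that the `gsec,gsm->egcm` einsum produces and why the result is chunked along the leading expert dimension before the all-to-all:

import torch

g, s, e, c, m = 1, 16, 4, 8, 8  # groups, tokens, experts, capacity, model dim
world_size = 4                  # assumed: one expert per rank
dispatch_mask = torch.zeros(g, s, e, c)  # one-hot (expert, slot) routing per token
input = torch.randn(g, s, m)

# The einsum scatters token embeddings into an (E, G, C, M) tensor, so
# chunking along dim 0 yields one per-expert slab to exchange with each rank.
dispatched = torch.einsum("gsec,gsm->egcm", dispatch_mask.float(), input)
assert dispatched.shape == (e, g, c, m)
chunks = list(dispatched.contiguous().chunk(world_size))
assert all(chunk.shape == (e // world_size, g, c, m) for chunk in chunks)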
 black == 19.10b0
 flake8 == 3.7.9
 isort == 4.3.21
+mpi4py == 3.0.3
 mypy == 0.770
 pytest == 5.4.1
 pytest-cov == 2.10.0
+pytest-mpi == 0.4
 torchtext == 0.6.0
 torch >= 1.5.1
 torchvision >= 0.6.0
...
@@ -6,7 +6,10 @@ import datetime

 from . import rpc as rpc

-class Backend: ...
+class Backend:
+    GLOO: str
+    MPI: str
+    NCCL: str

 class ProcessGroup:
     def size(self) -> int: ...
@@ -29,8 +32,10 @@ def broadcast(tensor: Tensor, src: Any, group: Any, async_op: Any = False): ...

 def is_initialized() -> bool: ...
+def init_process_group(backend: Union[str, Backend], timeout: datetime.timedelta = datetime.timedelta(0, 1800), rank: Optional[int] = None, world_size: Optional[int] = None): ...
 def new_group(ranks: List[int], timeout: datetime.timedelta = datetime.timedelta(0, 1800), backend: Union[None, str, Backend] = None): ...
+def all_to_all(output: List[Tensor], input: List[Tensor], group:Optional[ProcessGroup] = None, async_op: bool = False): ...
 def all_reduce(tensor: Tensor, op: ReduceOp = ReduceOp.SUM, group:Optional[ProcessGroup] = None, async_op: bool = False): ...
 def all_gather(tensor_list: List[Tensor], tensor: Tensor, group:Optional[ProcessGroup] = None, async_op: bool = False): ...
...
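For illustration, a sketch of the list-of-tensors convention this `all_to_all` stub declares (hypothetical helper; requires an initialized process group on a backend with all-to-all support, e.g. MPI or NCCL):

import torch
import torch.distributed as dist

def exchange(t: torch.Tensor) -> torch.Tensor:
    # Rank r sends inputs[i] to rank i and receives outputs[i] from rank i;
    # assumes t's first dimension is divisible by the world size.
    world_size = dist.get_world_size()
    inputs = list(t.chunk(world_size))
    outputs = [torch.empty_like(chunk) for chunk in inputs]
    dist.all_to_all(outputs, inputs)
    return torch.cat(outputs)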
@@ -3,34 +3,53 @@
 # This source code is licensed under the BSD license found in the
 # LICENSE file in the root directory of this source tree.

+import os
+
 import pytest
 import torch
+import torch.distributed as dist

 from fairscale.nn import MOELayer, Top2Gate

 skip_if_no_cuda = pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")

+BACKEND = dist.Backend.NCCL if torch.cuda.is_available() else dist.Backend.GLOO  # type: ignore
+
+if torch.cuda.is_available():
+    devices = ["cpu", "cuda"]
+else:
+    devices = ["cpu"]
+
+os.environ["MASTER_ADDR"] = "localhost"
+os.environ["MASTER_PORT"] = "29501"
+
+if "OMPI_COMM_WORLD_SIZE" in os.environ:
+    dist.init_process_group(backend=dist.Backend.MPI)
+
+def setup_module(module):
+    if "OMPI_COMM_WORLD_SIZE" not in os.environ:
+        dist.init_process_group(backend=BACKEND, rank=0, world_size=1)
+
+def teardown_module(module):
+    if "OMPI_COMM_WORLD_SIZE" not in os.environ:
+        torch.distributed.destroy_process_group()

-def test_create():
-    model_dim = 8
-    num_experts = 4
-    gate = Top2Gate(model_dim, num_experts)
-    expert = torch.nn.Linear(model_dim, model_dim)
-    moe = MOELayer(gate, expert)
-
-@skip_if_no_cuda
-def test_create_cuda():
+@pytest.mark.parametrize("device", devices)
+def test_create(device):
     model_dim = 8
     num_experts = 4
     gate = Top2Gate(model_dim, num_experts)
     expert = torch.nn.Linear(model_dim, model_dim)
-    moe = MOELayer(gate, expert).cuda()
+    moe = MOELayer(gate, expert).to(device)

-def do_test_forward(device):
+@pytest.mark.mpi
+@pytest.mark.parametrize("device", ["cpu"])
+def test_forward(device):
     model_dim = 8
-    num_experts = 1
+    num_experts = dist.get_world_size(dist.group.WORLD)
     input = torch.randn(3, 4, 16, model_dim).to(device)
     gate = Top2Gate(model_dim, num_experts)
     expert = torch.nn.Linear(model_dim, model_dim, bias=False)
@@ -38,16 +57,6 @@ def do_test_forward(device):
     expert.weight = torch.nn.Parameter(torch.eye(model_dim))
     moe = MOELayer(gate, expert).to(device)
     output = moe(input)
+    assert moe.l_aux.item() == 1.0
     assert output.shape == input.shape
     # Re-assembled output should match input due to identity expert.
-    assert torch.equal(input, output)
+    assert torch.allclose(input, output)
-
-def test_forward_cpu():
-    do_test_forward("cpu")
-
-@skip_if_no_cuda
-def test_forward_cuda():
-    do_test_forward("cuda")
@@ -29,6 +29,10 @@ def setup_module(module):
     dist.init_process_group(backend=BACKEND, rank=0, world_size=1)

+def teardown_module(module):
+    torch.distributed.destroy_process_group()
+
 def dist_init(rank, world_size):
     os.environ["MASTER_ADDR"] = "localhost"
     os.environ["MASTER_PORT"] = "29501"
...
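Pulling the pieces together, a single-process sketch of the new calling convention (assumed setup: a GLOO group of size one, mirroring setup_module above; names follow the docstring example):

import os
import torch
import torch.distributed as dist
from fairscale.nn import MOELayer, Top2Gate

# Single-process group so MOELayer's group defaults to dist.group.WORLD (size 1).
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "29501"
dist.init_process_group(backend="gloo", rank=0, world_size=1)

model_dim, num_experts = 8, 1  # one expert per rank, as in test_forward
gate = Top2Gate(model_dim, num_experts)
expert = torch.nn.Linear(model_dim, model_dim)
moe = MOELayer(gate, expert)

x = torch.randn(3, 4, 16, model_dim)  # (group, sequence, token, model)
out = moe(x)       # forward now returns only the output tensor
l_aux = moe.l_aux  # auxiliary load-balancing loss, read as an attribute
assert out.shape == x.shape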