Unverified Commit 84e0de84 authored by msbaines, committed by GitHub

[feat] experimental.nn.multiprocess_pipe: re-implemented using rpc (#519)

parent f7e6680b
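A note on the API introduced by the RPC re-implementation (the new module, fairscale/experimental/nn/multiprocess_pipe.py, follows). Because each partition is constructed on its owning RPC worker (see _create_sequential below), the pipe takes a list of (name, module class, args, kwargs) layer specifications rather than a pre-built nn.Sequential. A minimal sketch of the format, mirroring the tests at the bottom of this commit (names and shapes are illustrative):

import torch.nn as nn

# Each entry is (name, nn.Module subclass, positional args, keyword args);
# the layer class is instantiated later, on the remote worker that owns its partition.
layer_spec = [
    ("linear1", nn.Linear, (4, 4), {}),
    ("relu", nn.ReLU, (), {}),
]

# Pairing used throughout the tests: one layer per partition, and a
# "<worker>/<device>" string telling each partition where to live.
balance = [1, 1]
devices = ["worker0/cpu", "worker1/cpu"]
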
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from collections import OrderedDict
import itertools
from typing import TYPE_CHECKING, Callable, Dict, List, Tuple, Type, Union
import torch
from torch import Tensor
import torch.distributed.rpc as rpc
import torch.nn as nn
Tensors = Tuple[Tensor, ...]
TensorOrTensors = Union[Tensor, Tensors]
LayerSpec = Tuple[str, Type[nn.Module], Tuple, Dict]
if TYPE_CHECKING:
Module = nn.Module[TensorOrTensors]
else:
Module = nn.Module
# When True, remote tensors are moved to CPU on their owning worker before being
# fetched over RPC (see _ToHere.forward and rloss below).
BOUNCE_TENSORS = True
def _verify_module(module: List[LayerSpec]) -> None:
    if not isinstance(module, list):
raise TypeError("module must be a list")
for elem in module:
if not isinstance(elem, tuple):
raise TypeError("module must be a list of tuple")
if len(elem) != 4:
raise TypeError("each module tuple must contain (name, nn.module, args, kwargs)")
name, layer, args, kwargs = elem
if not (
isinstance(name, str)
and issubclass(layer, nn.Module)
and isinstance(args, tuple)
and isinstance(kwargs, dict)
):
raise TypeError("each module tuple must contain (name, nn.module, args, kwargs)")
class _ToHere(Module):
def __init__(self, device: str):
super().__init__()
self.device = device
def forward(self, x_rref: rpc.RRef) -> Tensor: # type: ignore
if BOUNCE_TENSORS:
return x_rref.remote().cpu().to_here().to(self.device)
else:
return x_rref.to_here().to(self.device)
def _create_sequential(layer_spec: List[LayerSpec], device: str) -> Module:
layers = [(name, layer(*args, **kwargs)) for name, layer, args, kwargs in layer_spec] # type: ignore
layers.insert(0, ("to_here", _ToHere(device)))
return nn.Sequential(OrderedDict(layers)).to(device)
def _rcat(tensors: List) -> Tensor:
return torch.cat([t.local_value() for t in tensors])
def _parameter_rrefs(module: rpc.RRef) -> List[rpc.RRef]:
return [rpc.RRef(p) for p in module.to_here().parameters()]
def rloss(loss_func: Callable, input_rref: rpc.RRef, target_rref: rpc.RRef) -> rpc.RRef:
if BOUNCE_TENSORS:
return loss_func(input_rref.remote().cpu().to_here(), target_rref.remote().cpu().to_here())
else:
return loss_func(input_rref.to_here(), target_rref.to_here())
def DistributedLoss(loss: Type[nn.Module], *args: Tuple, **kwargs: Dict) -> Callable:
loss_func = loss(*args, **kwargs)
def dloss(input_rref: rpc.RRef, target_rref: rpc.RRef) -> rpc.RRef:
return rpc.remote(input_rref.owner(), rloss, args=(loss_func, input_rref, target_rref))
return dloss
class MultiProcessPipe(Module):
"""Wraps an arbitrary :class:`nn.Sequential <torch.nn.Sequential>` module
to train on Pipe_. If the module requires lots of memory, Pipe will be
very efficient.
::
model = nn.Sequential(a, b, c, d)
model = Pipe(model, balance=[1, 1, 1, 1], chunks=8)
output = model(input)
.. _Pipe: https://arxiv.org/abs/1811.06965
Pipe combines pipeline parallelism with checkpointing to reduce peak
memory required to train while minimizing device under-utilization.
You should determine the balance when defining a :class:`Pipe` module, as
balancing will not be done automatically. The module will be partitioned
into multiple devices according to the given balance. You may rely on
heuristics to find your own optimal configuration.
Args:
module (torch.nn.Sequential):
sequential module to be parallelized
balance (ints):
list of number of layers in each partition
Keyword Args:
devices (iterable of devices):
devices to use (default: all CUDA devices)
chunks (int):
number of micro-batches (default: ``1``)
checkpoint (str):
when to enable checkpointing, one of ``'always'``,
``'except_last'``, or ``'never'`` (default: ``'except_last'``)
deferred_batch_norm (bool):
whether to use deferred BatchNorm moving statistics (default:
:data:`False`, see :class:`Deferred Batch Normalization <DeferredBatchNorm>` for more
details)
Raises:
TypeError:
the module is not a :class:`nn.Sequential <torch.nn.Sequential>`.
ValueError:
invalid arguments, or wrong balance
IndexError:
the number of devices is fewer than the number of partitions.
"""
#: The number of micro-batches.
chunks: int = 1
    #: The checkpoint mode to determine when to enable checkpointing. Only
    #: ``'never'`` is currently supported.
checkpoint: str = "never"
def __init__(
self,
module: List[LayerSpec],
*,
balance: List[int],
devices: List[str],
chunks: int = chunks,
checkpoint: str = checkpoint,
deferred_batch_norm: bool = False,
) -> None:
super().__init__()
if type(chunks) is not int or chunks <= 0:
raise ValueError("number of chunks must be positive integer")
if checkpoint not in ["never"]:
raise ValueError("checkpoint is not yet implemented")
if deferred_batch_norm:
raise ValueError("deferred_batch_norm is not yet implemented")
if len(balance) != len(devices):
raise ValueError("balance and devices lists must be the same size")
if len(module) != sum(balance):
raise ValueError("number of layers must match aggregate balance")
_verify_module(module)
index = 0
rmodule = []
workers = []
        # Build each partition remotely: slice the layer specs according to balance and
        # instantiate them as an nn.Sequential on the named worker/device.
        for num_layers, device_spec in zip(balance, devices):
worker, device = device_spec.split("/")
next_index = index + num_layers
rlayer = rpc.remote(worker, _create_sequential, args=(module[index:next_index], device))
index = next_index
workers.append(worker)
rmodule.append(rlayer)
self.chunks = chunks
self.checkpoint = checkpoint
self.module = module
self.workers = workers
self.rmodule = rmodule
    def forward(self, x: Tensor) -> rpc.RRef:  # type: ignore
        # Split the input into micro-batches and push each one through the chain of remote
        # partitions asynchronously; the chunks are concatenated on the worker that owns
        # the final partition's outputs.
        outputs = []
for chunk in x.chunk(self.chunks):
output = rpc.RRef(chunk)
for rlayer in self.rmodule:
output = rlayer.remote().forward(output)
outputs.append(output)
return rpc.remote(outputs[0].owner(), _rcat, args=(outputs,))
def parameter_rrefs(self) -> List[rpc.RRef]:
rrefs_list_of_lists = [rpc.rpc_sync(l.owner(), _parameter_rrefs, args=(l,)) for l in self.rmodule]
return list(itertools.chain(*rrefs_list_of_lists))
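For orientation, here is a minimal end-to-end training sketch assembled from the module above and the test file tests/experimental/nn/test_multiprocess_pipe.py shown further below. The helper name, worker names, and hyperparameters are illustrative, and the sketch assumes torch.distributed.rpc has already been initialized on every participating process (as rpc_worker does in the tests):

import torch
import torch.distributed.autograd as dist_autograd
import torch.distributed.rpc as rpc
import torch.nn as nn
from torch.distributed.optim import DistributedOptimizer

from fairscale.experimental.nn.multiprocess_pipe import DistributedLoss, MultiProcessPipe

def train_step_on_rank0():
    # Two single-layer partitions, each constructed remotely on its own worker.
    model = [("linear1", nn.Linear, (4, 4), {}), ("relu", nn.ReLU, (), {})]
    pipe = MultiProcessPipe(model, balance=[1, 1], chunks=4, devices=["worker0/cpu", "worker1/cpu"])

    # The loss module is also evaluated remotely, on the worker that owns the pipe output.
    criterion = DistributedLoss(torch.nn.MSELoss)
    opt = DistributedOptimizer(torch.optim.SGD, pipe.parameter_rrefs(), lr=0.05)

    x = torch.randn(8, 4)
    with dist_autograd.context() as context_id:
        y_rref = pipe(x)                    # forward returns an RRef to the concatenated output
        loss_rref = criterion(y_rref, rpc.RRef(x))
        loss_rref.backward(context_id)      # distributed autograd accumulates grads per worker
        opt.step(context_id)
    return loss_rref.to_here()

Note that forward() returns an RRef rather than a Tensor, so the loss is also computed remotely via DistributedLoss and gradients flow through the distributed autograd context rather than through tensor.backward().
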
@@ -3,7 +3,12 @@
 from typing import Union, Callable, Optional, Any
 from torch.futures import Future
 class RRef:
     def __init__(self, t: Any) -> None: ...
     def owner(self) -> WorkerInfo: ...
     def remote(self) -> Any: ...
     def rpc_sync(self) -> Any: ...
     def to_here(self) -> Any: ...
 class WorkerInfo: ...
 class BackendType:
@@ -12,6 +17,13 @@ class BackendType:
 def TensorPipeRpcBackendOptions(init_method: str) -> Any: ...
 def ProcessGroupRpcBackendOptions(init_method: str) -> Any: ...
 def remote(
     to: Union[str, WorkerInfo],
     func: Callable,
     args: Optional[tuple] = None,
     kwargs: Optional[dict] = None,
     timeout=-1.0,
 ) -> RRef: ...
 def rpc_async(
     to: Union[str, WorkerInfo],
     func: Callable,
@@ -25,7 +37,7 @@ def rpc_sync(
     args: Optional[tuple] = None,
     kwargs: Optional[dict] = None,
     timeout=-1.0,
 ) -> Any: ...
 def init_rpc(
     name: str,
     backend: Optional[Any] = None,
@@ -31,3 +31,4 @@ tests/nn/pipe/test_phony.py
 tests/nn/pipe/test_deferred_batch_norm.py
 tests/nn/pipe/test_dependency.py
 tests/nn/pipe/test_stream.py
tests/experimental/nn/test_multiprocess_pipe.py
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
"""
Testing MultiProcessPipe Module
"""
import functools
import tempfile
import pytest
import torch
import torch.distributed.autograd as dist_autograd
from torch.distributed.optim import DistributedOptimizer
import torch.distributed.rpc as rpc
import torch.multiprocessing as mp
import torch.nn as nn
from fairscale.experimental.nn.multiprocess_pipe import DistributedLoss, MultiProcessPipe
from fairscale.utils.testing import torch_version
BOUNCE_TENSORS = True
CPU_DEVICES = ["worker0/cpu", "worker1/cpu"]
GPU_DEVICES = ["worker0/cuda:0", "worker1/cuda:1"]
if torch.cuda.is_available():
DEVICES = [CPU_DEVICES, GPU_DEVICES]
else:
DEVICES = [CPU_DEVICES]
# Requiring CUDA here is a workaround for https://github.com/pytorch/pytorch/issues/54266
pytestmark = pytest.mark.skipif(
not torch.cuda.is_available() or torch_version() < (1, 8, 0), reason="requires torch version >= 1.8.0 and cuda"
)
def rpc_worker(rank, world_size, init_file, func, *args):
# Workaround for https://github.com/pytorch/pytorch/issues/54266
if not torch.cuda.is_available():
options = rpc.ProcessGroupRpcBackendOptions(init_method="file://" + init_file)
rpc.init_rpc(
"worker" + str(rank),
rank=rank,
world_size=world_size,
backend=rpc.BackendType.PROCESS_GROUP,
rpc_backend_options=options,
)
else:
# Workaround for https://github.com/pytorch/pytorch/issues/53844
if torch_version() == (1, 8, 0):
options = rpc.TensorPipeRpcBackendOptions(init_method="file://" + init_file, _transports=["ibv", "uv"])
else:
options = rpc.TensorPipeRpcBackendOptions(init_method="file://" + init_file)
rpc.init_rpc(
"worker" + str(rank),
rank=rank,
world_size=world_size,
backend=rpc.BackendType.TENSORPIPE,
rpc_backend_options=options,
)
if rank == 0:
func(*args)
rpc.shutdown()
def rpc_test(world_size=1):
    """Register a pytest test that runs the decorated function on rank 0 of a spawned RPC group."""
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # Spawn `world_size` processes; each runs rpc_worker, which initializes RPC
            # and executes the decorated function on rank 0 only.
            mp.spawn(rpc_worker, args=(world_size, tempfile.mkstemp()[1], func, *kwargs.values()), nprocs=world_size)
        # Expose the wrapper under a "test_" name so pytest collects the spawning wrapper
        # instead of the undecorated function returned below.
        globals()["test_" + func.__name__] = wrapper
        return func
    return decorator
@rpc_test()
@pytest.mark.parametrize("devices", DEVICES)
def create(devices):
model = [("linear", nn.Linear, (4, 4), {})]
pipe = MultiProcessPipe(model, balance=[1], chunks=1, devices=devices[:1])
@rpc_test()
def create_multiple_layers():
model = [("linear1", nn.Linear, (4, 4), {}), ("relu", nn.ReLU, (), {})]
pipe = MultiProcessPipe(model, balance=[1, 1], chunks=1, devices=["worker0/cpu", "worker0/cpu"])
@rpc_test(world_size=2)
@pytest.mark.parametrize("devices", DEVICES)
def create_multiple_workers(devices):
model = [("linear1", nn.Linear, (4, 4), {}), ("relu", nn.ReLU, (), {})]
pipe = MultiProcessPipe(model, balance=[1, 1], chunks=1, devices=devices[:2])
@rpc_test(world_size=2)
@pytest.mark.parametrize("devices", DEVICES)
def parameter_rrefs(devices):
model = [("linear1", nn.Linear, (4, 4), {}), ("relu", nn.ReLU, (), {})]
pipe = MultiProcessPipe(model, balance=[1, 1], chunks=1, devices=devices[:2])
parameter_rrefs = pipe.parameter_rrefs()
assert len(parameter_rrefs) == 2
@rpc_test(world_size=1)
@pytest.mark.parametrize("devices", DEVICES)
def forward(devices):
yh = torch.tensor([1.0, 0.0])
x = torch.tensor([1.0, -1.0])
model = [("relu", nn.ReLU, (), {})]
pipe = MultiProcessPipe(model, balance=[1], chunks=1, devices=devices[:1])
y = pipe(x).to_here().cpu()
assert torch.equal(y, yh), f"{y} != {yh}"
@rpc_test(world_size=1)
@pytest.mark.parametrize("devices", DEVICES)
def forward_chunks(devices):
yh = torch.tensor([1.0, 0.0, 2.0, 0.0, 3.0, 0.0, 4.0, 0.0])
x = torch.tensor([1.0, -1.0, 2.0, -2.0, 3.0, -3.0, 4.0, -4.0])
model = [("relu", nn.ReLU, (), {})]
pipe = MultiProcessPipe(model, balance=[1], chunks=4, devices=devices[:1])
y = pipe(x).to_here().cpu()
assert torch.equal(y, yh), f"{y} != {yh}"
@rpc_test(world_size=2)
@pytest.mark.parametrize("devices", DEVICES)
def forward_multi(devices):
torch.random.manual_seed(3)
torch.cuda.manual_seed_all(3)
x = torch.randn(8, 4)
model = [("linear1", nn.Linear, (4, 4), {}), ("relu", nn.ReLU, (), {})]
pipe = MultiProcessPipe(model, balance=[1, 1], chunks=4, devices=devices[:2])
if BOUNCE_TENSORS:
y = pipe(x).remote().cpu().to_here()
else:
y = pipe(x).to_here()
expected_sum = torch.tensor(5.0615)
assert y.shape == torch.Size([8, 4])
assert y.requires_grad is True
assert torch.allclose(y.sum(), expected_sum), f"{y.sum()} != {expected_sum}"
@rpc_test(world_size=2)
@pytest.mark.parametrize("devices", DEVICES)
def backward(devices):
torch.random.manual_seed(3)
criterion = DistributedLoss(torch.nn.MSELoss)
x = torch.randn(8, 4)
model = [("linear1", nn.Linear, (4, 4), {}), ("relu", nn.ReLU, (), {})]
pipe = MultiProcessPipe(model, balance=[1, 1], chunks=4, devices=devices[:2])
with dist_autograd.context() as context_id:
y = pipe(x)
loss = criterion(y, rpc.RRef(x))
loss.backward(context_id)
grads = dist_autograd.get_gradients(context_id)
assert len(grads) == 2
@rpc_test(world_size=2)
@pytest.mark.parametrize("devices", DEVICES)
def update(devices):
torch.random.manual_seed(3)
criterion = DistributedLoss(torch.nn.MSELoss)
x = torch.randn(8, 4)
model = [("linear1", nn.Linear, (4, 4), {}), ("relu", nn.ReLU, (), {})]
pipe = MultiProcessPipe(model, balance=[1, 1], chunks=4, devices=devices[:2])
    params = pipe.parameter_rrefs()
    opt = DistributedOptimizer(torch.optim.SGD, params, lr=0.05)
losses = []
for i in range(2):
with dist_autograd.context() as context_id:
y = pipe(x)
loss = criterion(y, rpc.RRef(x))
losses.append(loss)
loss.backward(context_id)
opt.step(context_id)
losses = [l.to_here() for l in losses]
assert losses[0] > losses[1], f"{losses[0]} !> {losses[1]}"