"vscode:/vscode.git/clone" did not exist on "d9effbd1d0da393b46ee4524e8ce8f52245e9bba"
Unverified Commit eaee5976 authored by msbaines, committed by GitHub

[refactor] make AsyncPipe its own class (#341)

parent 51625eda
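Context for reviewers (not part of the diff): a minimal before/after sketch of how call sites change with this refactor, based on the call sites updated below. The model, balance, and worker_map values here are placeholders, not code from the repository.

# Before this commit: the async schedule was selected with a style argument on MultiProcessPipe.
from fairscale.nn.pipe import MultiProcessPipe
pipe = MultiProcessPipe(model, balance, style=MultiProcessPipe.AsyncSchedule, worker_map=worker_map, chunks=4)

# After this commit: AsyncPipe is its own class and the style argument disappears.
from fairscale.nn.pipe import AsyncPipe
pipe = AsyncPipe(model, balance, worker_map=worker_map, chunks=4)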
@@ -21,7 +21,7 @@ from torchtext.data.utils import get_tokenizer
 from experimental.nn.ampnet_pipe import pipe
 from fairscale.nn.model_parallel import initialize_model_parallel
 from fairscale.nn.model_parallel.initialize import get_pipeline_parallel_group
-from fairscale.nn.pipe import LazyModule, MultiProcessPipe
+from fairscale.nn.pipe import LazyModule
 from fairscale.optim import GradScaler
 from fairscale.utils.testing import dist_init, get_worker_map
@@ -420,7 +420,6 @@ def run_mp_worker(args, available_workers):
     p = pipe.AMPnetPipe(
         module=model,
         balance=balance,
-        style=MultiProcessPipe.AsyncSchedule,
         chunks=args.chunks,
         worker_map=get_worker_map(),
         input_device=torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"),
...
@@ -499,7 +499,6 @@ def run_mp_worker(args, available_workers):
     pipe_model = MultiProcessPipe(
         model,
         balance,
-        style=MultiProcessPipe.AsyncSchedule,
         chunks=args.chunks,
         worker_map=get_worker_map(),
         input_device=torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"),
...
@@ -37,7 +37,6 @@ def create_task_without_skip_trackers(
     checkpoint_stop: int, i: int, j: int, batch: Batch, partition: nn.Sequential,
 ) -> Task:
     # Determine whether checkpointing or not.
-    # style is guaranteed to be PipelineStyle.AsyncSchedule
     if i < checkpoint_stop:
         def function(
...
@@ -11,15 +11,14 @@ from torch import nn
 from torch.optim.optimizer import Optimizer
 from torch.utils.data import DataLoader
-from fairscale.nn.pipe import MultiProcessPipe
-from fairscale.nn.pipe.types import PipelineStyle
+from fairscale.nn.pipe import AsyncPipe
 from .ampnet import AsyncAMPnetEventLoop
 __all__ = ["AMPnetPipe"]
-class AMPnetPipe(MultiProcessPipe):
+class AMPnetPipe(AsyncPipe):
     """
     AMPnetPipe is the asynchronous version of the MultiProcessPipe implementation
     which avoids the bubble issue, by using stale weights and gradients.
@@ -44,7 +43,6 @@ class AMPnetPipe(MultiProcessPipe):
         # AMPnet implementation doesn't handle skip_trackers!
-        assert self.pipeline.style is PipelineStyle.AsyncSchedule  # type: ignore
         assert self.group
         rank = self.group.rank()
...
@@ -23,7 +23,6 @@ from torch.optim.optimizer import Optimizer
 from torch.utils.data import DataLoader, Dataset
 from experimental.nn.ampnet_pipe.pipe import AMPnetPipe
-from fairscale.nn.pipe import MultiProcessPipe
 from fairscale.utils.testing import get_worker_map, torch_spawn
@@ -84,14 +83,7 @@ class FakeDataset(Dataset):
 @torch_spawn([2])
 def async_event_loop_interleave_simple():
     model = nn.Sequential(nn.Linear(10, 10), nn.ReLU(inplace=False), nn.Linear(10, 10), nn.ReLU(inplace=False))
-    pipe = AMPnetPipe(
-        module=model,
-        balance=[2, 2],
-        style=MultiProcessPipe.AsyncSchedule,
-        worker_map=get_worker_map(),
-        chunks=10,
-        checkpoint="never",
-    )
+    pipe = AMPnetPipe(module=model, balance=[2, 2], worker_map=get_worker_map(), chunks=10, checkpoint="never",)
     fake_dataset = FakeDataset()
     fake_dataloader = DataLoader(fake_dataset, batch_size=4, shuffle=True, num_workers=0)
     loss = nn.MSELoss()
@@ -102,14 +94,7 @@ def async_event_loop_interleave_simple():
 @torch_spawn([4])
 def async_event_loop_interleave_hard():
     model = nn.Sequential(nn.Linear(10, 10), nn.Linear(10, 10), nn.Linear(10, 10), nn.Linear(10, 10))
-    pipe = AMPnetPipe(
-        module=model,
-        balance=[1, 1, 1, 1],
-        style=MultiProcessPipe.AsyncSchedule,
-        worker_map=get_worker_map(),
-        chunks=10,
-        checkpoint="never",
-    )
+    pipe = AMPnetPipe(module=model, balance=[1, 1, 1, 1], worker_map=get_worker_map(), chunks=10, checkpoint="never",)
     fake_dataset = FakeDataset()
     fake_dataloader = DataLoader(fake_dataset, batch_size=4, shuffle=True, num_workers=0)
     loss = nn.MSELoss()
...
@@ -18,6 +18,7 @@
 # limitations under the License.
 """A Pipe implementation in PyTorch."""
+from .async_pipe import AsyncPipe
 from .checkpoint import is_checkpointing, is_recomputing
 from .multiprocess_pipe import LazyModule, MultiProcessPipe
 from .pipe import Pipe
...
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

from .multiprocess_pipe import MultiProcessPipe
from .types import PipelineStyle


class AsyncPipe(MultiProcessPipe):
    def __init__(self, *args, **kwargs) -> None:  # type: ignore
        super().__init__(*args, style=PipelineStyle.AsyncSchedule, **kwargs)
@@ -386,9 +386,6 @@ class MultiProcessPipe(Module):
     """
-    MultiProcess: PipelineStyle = PipelineStyle.MultiProcess
-    AsyncSchedule: PipelineStyle = PipelineStyle.AsyncSchedule
     #: The number of layers in each partition.
     balance: List[int] = []
     # ^^
...
@@ -13,6 +13,7 @@ from torch.distributed.distributed_c10d import _get_global_rank
 from fairscale.nn.model_parallel.initialize import get_pipeline_parallel_group
+from .async_pipe import AsyncPipe
 from .multiprocess_pipe import MultiProcessPipe
 from .types import EVENT_LOOP_QUEUE, PipeMessage, TensorOrTensors
@@ -105,10 +106,9 @@ class PipeRPCWrapper(nn.Module):
         else:
             kwargs["group"] = self.group
-        kwargs["style"] = MultiProcessPipe.AsyncSchedule
         kwargs["input_device"] = torch.device("cuda", torch.cuda.current_device())
-        self.model = MultiProcessPipe(*args, **kwargs)
+        self.model = AsyncPipe(*args, **kwargs)
         self.worker_map = kwargs["worker_map"]
         self._foreach_worker(self._register_remote_model, args=(args, kwargs))
         self.model.cuda()
...
@@ -35,7 +35,6 @@ class LazyModule:
 class PipelineStyle(Enum):
-    SingleProcess = auto()
     MultiProcess = auto()
     AsyncSchedule = auto()
...
@@ -431,7 +431,6 @@ def run_test_pipe(rank, world_size, filename, filename_rpc, skip_dist_init=False
     model[2].weight.data = saved_weight_2
     worker_map = {i: f"Test{i}" for i in range(torch.distributed.get_world_size())}
-    style = MultiProcessPipe.MultiProcess  # MultiProcessPipe.AsyncSchedule
     if pipe_world_size == 2:
         print(f"actually doing pipe stuff now")
@@ -440,7 +439,6 @@ def run_test_pipe(rank, world_size, filename, filename_rpc, skip_dist_init=False
         pipe_model = MultiProcessPipe(
             model,
             [2, 1],
-            style=style,
             group=pipeline_devices,
             worker_map=worker_map,
             input_device=torch.cuda.current_device(),
@@ -507,7 +505,6 @@ def run_test_pipe(rank, world_size, filename, filename_rpc, skip_dist_init=False
     failed = False
     with torch.autograd.profiler.profile() as prof:
         try:
-            if style == MultiProcessPipe.MultiProcess:
-                pipe_model.back_helper(pipe_output)
+            pipe_model.back_helper(pipe_output)
         except Exception as e:
             failed = True
...
@@ -23,7 +23,7 @@ import pytest
 import torch
 from torch import nn
-from fairscale.nn.pipe import LazyModule, MultiProcessPipe
+from fairscale.nn.pipe import AsyncPipe, LazyModule, MultiProcessPipe
 from fairscale.nn.pipe.skip import pop, skippable, stash
 from fairscale.nn.pipe.skip.portal import PortalBlue, PortalCopy, PortalOrange
 from fairscale.utils.testing import get_worker_map, torch_spawn
@@ -33,14 +33,14 @@ from fairscale.utils.testing import get_worker_map, torch_spawn
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
 @pytest.mark.parametrize("balance", [[3], [1, 2], [2, 1], [1, 1, 1]], ids=["3", "1:2", "2:1", "1:1:1"])
 @pytest.mark.parametrize("checkpoint", ["never", "always", "except_last"])
-@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
+@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
 @pytest.mark.skipif("OMPI_COMM_WORLD_RANK" in os.environ, reason="broken on mpi")
-def x1to3(balance, checkpoint, pipeline_style):
+def x1to3(balance, checkpoint, pipe_class):
     torch.manual_seed(0)
-    if pipeline_style == MultiProcessPipe.AsyncSchedule and len(balance) > 1:
+    if pipe_class == AsyncPipe and len(balance) > 1:
         print(f"skipping yarg")
-        pytest.skip("Skip tensors NYI for AsyncSchedule")
+        pytest.skip("Skip tensors NYI for AsyncPipe")
     @skippable(stash=["1to3"])
     class Layer1(nn.Module):
@@ -74,13 +74,12 @@ def x1to3(balance, checkpoint, pipeline_style):
         return output
     model = nn.Sequential(Layer1(), Layer2(), Layer3())
-    model = MultiProcessPipe(
+    model = pipe_class(
         model,
         balance,
         chunks=3,
         checkpoint=checkpoint,
         input_device=torch.cuda.current_device(),
-        style=pipeline_style,
         worker_map=get_worker_map(),
         pipelined_backward=False,
     ).cuda()
@@ -106,11 +105,11 @@ def x1to3(balance, checkpoint, pipeline_style):
 @torch_spawn([2])
 @pytest.mark.skipif("OMPI_COMM_WORLD_RANK" in os.environ, reason="broken on mpi")
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
-@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
+@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
 @pytest.mark.skip(reason="flaky test")
-def none_skip(pipeline_style):
-    if pipeline_style == MultiProcessPipe.AsyncSchedule:
-        pytest.skip("Skip tensors NYI for AsyncSchedule")
+def none_skip(pipe_class):
+    if pipe_class == AsyncPipe:
+        pytest.skip("Skip tensors NYI for AsyncPipe")
     @skippable(stash=["none"])
     class Stash(nn.Module):
@@ -126,13 +125,8 @@ def none_skip(pipeline_style):
         return input
     model = nn.Sequential(Stash(), Pop())
-    model = MultiProcessPipe(
-        model,
-        [1, 1],
-        style=pipeline_style,
-        worker_map=get_worker_map(),
-        input_device=torch.cuda.current_device(),
-        chunks=5,
+    model = pipe_class(
+        model, [1, 1], worker_map=get_worker_map(), input_device=torch.cuda.current_device(), chunks=5,
     ).cuda()
     input = torch.rand(10, requires_grad=True).cuda()
@@ -161,8 +155,8 @@ def none_skip(pipeline_style):
 @torch_spawn([2])
-@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
-def lazy_skippable_error(pipeline_style):
+@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
+def lazy_skippable_error(pipe_class):
     """Using skippable layers in combination with lazy construction is currently
     not supported, check that it raises an Exception"""
@@ -181,6 +175,6 @@ def lazy_skippable_error(pipeline_style):
     ]
     with pytest.raises(ValueError, match="Can't use Skippable layers with multi-process pipe and lazy construction"):
-        MultiProcessPipe(
-            model, [2, 1], style=pipeline_style, worker_map=get_worker_map(),
+        pipe_class(
+            model, [2, 1], worker_map=get_worker_map(),
         )
@@ -23,7 +23,7 @@ import pytest
 import torch
 from torch import nn
-from fairscale.nn.pipe import MultiProcessPipe, is_checkpointing, is_recomputing
+from fairscale.nn.pipe import AsyncPipe, MultiProcessPipe, is_checkpointing, is_recomputing
 from fairscale.nn.pipe.skip import pop, skippable, stash
 from fairscale.nn.pipe.skip.tracker import current_skip_tracker
 from fairscale.utils.testing import get_worker_map, torch_spawn
@@ -46,10 +46,10 @@ class Pop(nn.Module):
 @torch_spawn([2])
 @pytest.mark.parametrize("train", [True, False], ids=["train", "eval"])
 @pytest.mark.parametrize("checkpoint", ["always", "except_last", "never"])
-@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
+@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
 @pytest.mark.skipif("OMPI_COMM_WORLD_RANK" in os.environ, reason="broken on mpi")
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
-def delete_portal_tensor(train, checkpoint, pipeline_style):
+def delete_portal_tensor(train, checkpoint, pipe_class):
     # Without checkpointing:
     # +- Stash --+ +--- Pop ----+ - - - layers
     # | 2,blue,1 |--| 1,orange,0 | - - - tensor_life and portal function
@@ -60,8 +60,8 @@ def delete_portal_tensor(train, checkpoint, pipeline_style):
     # | 3,blue,2 |--| 2,orange,1 |--| 1,orange,0 |--| 1,blue,0 |
     # +----------+ +------------+ +------------+ +----------+
-    if pipeline_style == MultiProcessPipe.AsyncSchedule:
-        pytest.skip("Skip tensors NYI for AsyncSchedule")
+    if pipe_class == AsyncPipe:
+        pytest.skip("Skip tensors NYI for AsyncPipe")
     def portal_tensor_life_is(tensor_life, skip_tracker=None):
         if skip_tracker is None:
@@ -114,9 +114,7 @@ def delete_portal_tensor(train, checkpoint, pipeline_style):
             return self.F.apply(input)
     model = nn.Sequential(NoPortalTensorAtBackward(), stash_, pop_)
-    model = MultiProcessPipe(
-        model, balance=[2, 1], style=pipeline_style, worker_map=get_worker_map(), chunks=2, checkpoint=checkpoint,
-    )
+    model = pipe_class(model, balance=[2, 1], worker_map=get_worker_map(), chunks=2, checkpoint=checkpoint,)
     input = torch.rand(10, requires_grad=True)
...
@@ -22,15 +22,15 @@ import torch
 from torch import nn
 import torch.nn.functional as F
-from fairscale.nn.pipe import MultiProcessPipe
+from fairscale.nn.pipe import AsyncPipe, MultiProcessPipe
 from fairscale.utils.testing import get_worker_map, torch_spawn
 @torch_spawn([2])
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
-@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
-def python_autograd_function(pipeline_style):
-    # FIXME deadlock with MultiProcessPipe.AsyncSchedule?
+@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
+def python_autograd_function(pipe_class):
+    # FIXME deadlock with AsyncPipe?
     # A Python autograd function might fail with this error:
     #
     # RuntimeError: Returning Variables sharing storage with other Variables
@@ -57,9 +57,7 @@ def python_autograd_function(pipeline_style):
         return Identity.apply(input)
     model = nn.Sequential(M(), M())
-    model = MultiProcessPipe(
-        model, [1, 1], style=pipeline_style, worker_map=get_worker_map(), checkpoint="always"
-    ).cuda()
+    model = pipe_class(model, [1, 1], worker_map=get_worker_map(), checkpoint="always").cuda()
     model.eval()
     x = torch.rand(42)
@@ -73,8 +71,8 @@ def python_autograd_function(pipeline_style):
 @torch_spawn([3])
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
-@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
-def exception_no_hang(pipeline_style):
+@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
+def exception_no_hang(pipe_class):
     # In v0.0.2, once a failed partition receives a normal message
     # (non-closing) for the next micro-batch, a hang occured. The reason was
     # that a failed partition didn't call in_queue.task_done() on a normal
@@ -92,7 +90,7 @@ def exception_no_hang(pipeline_style):
             raise ExpectedException()
     model = nn.Sequential(Pass(), Pass(), Raise())
-    model = MultiProcessPipe(model, [1, 1, 1], style=pipeline_style, worker_map=get_worker_map(), chunks=3)
+    model = pipe_class(model, [1, 1, 1], worker_map=get_worker_map(), chunks=3)
     model.eval()
     if model.group.rank() == 2:
@@ -106,8 +104,8 @@ def exception_no_hang(pipeline_style):
 @torch_spawn([2])
 @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="2 cuda devices required")
-@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
-def tuple_wait(cuda_sleep, pipeline_style):
+@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
+def tuple_wait(cuda_sleep, pipe_class):
     # In v0.0.3, Wait is applied to only the first tensor on a micro-batch.
     # Under this behavior, if checkpointing was disabled, there's a possibility
     # that gradient accumulations on other tensors are not synchronized
@@ -135,10 +133,9 @@ def tuple_wait(cuda_sleep, pipeline_style):
         return a + b + c
     model = nn.Sequential(Layer1(), Layer2())
-    model = MultiProcessPipe(
+    model = pipe_class(
         model,
         [1, 1],
-        style=pipeline_style,
         worker_map=get_worker_map(),
         input_device=torch.cuda.current_device(),
         chunks=32,
@@ -160,8 +157,8 @@ def tuple_wait(cuda_sleep, pipeline_style):
 @torch_spawn([2])
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
-@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
-def parallel_randoms(pipeline_style):
+@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
+def parallel_randoms(pipe_class):
     class Dropouts(nn.Module):
         def forward(self, x):
             for _ in range(100):
@@ -172,10 +169,9 @@ def parallel_randoms(pipeline_style):
     x = torch.rand(10, 10, requires_grad=True).cuda()
     x.retain_grad()
-    model = MultiProcessPipe(
+    model = pipe_class(
         model,
         [1, 1],
-        style=pipeline_style,
         input_device=torch.cuda.current_device(),
         worker_map=get_worker_map(),
         chunks=10,
...
@@ -21,21 +21,21 @@ import pytest
 import torch
 from torch import nn
-from fairscale.nn.pipe import MultiProcessPipe
+from fairscale.nn.pipe import AsyncPipe, MultiProcessPipe
 from fairscale.utils.testing import get_worker_map, torch_spawn
 @torch_spawn([2])
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
-@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
-def inplace_on_requires_grad(pipeline_style):
+@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
+def inplace_on_requires_grad(pipe_class):
     model = nn.Sequential(nn.Linear(1, 1), nn.ReLU(inplace=True))
-    model = MultiProcessPipe(model, [1, 1], style=pipeline_style, worker_map=get_worker_map(), checkpoint="always")
+    model = pipe_class(model, [1, 1], worker_map=get_worker_map(), checkpoint="always")
     x = torch.rand(1)
-    if pipeline_style == MultiProcessPipe.AsyncSchedule and model.group.rank() == 0:
-        # With AsyncSchedule, model will wait forever for gradients if not eval
+    if pipe_class == AsyncPipe and model.group.rank() == 0:
+        # With AsyncPipe, model will wait forever for gradients if not eval
         model.eval()
     y = model(x)
@@ -50,12 +50,12 @@ def inplace_on_requires_grad(pipeline_style):
 @torch_spawn([1])
 @pytest.mark.xfail(strict=True)
-@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
-def inplace_on_not_requires_grad(pipeline_style):
+@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
+def inplace_on_not_requires_grad(pipe_class):
     # In-place operation on a tensor not requiring grad doesn't cause a
     # RuntimeError. Currently, we cannot detect this case.
     model = nn.Sequential(nn.ReLU(inplace=True))
-    model = MultiProcessPipe(model, [1], style=pipeline_style, worker_map=get_worker_map(), checkpoint="always")
+    model = pipe_class(model, [1], worker_map=get_worker_map(), checkpoint="always")
     x = torch.rand(1)
     y = model(x)
@@ -70,8 +70,8 @@ def inplace_on_not_requires_grad(pipeline_style):
 @torch_spawn([1])
 @pytest.mark.xfail(strict=True)
-@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
-def inplace_incorrect_grad(pipeline_style):
+@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
+def inplace_incorrect_grad(pipe_class):
     class M(nn.Module):
         def forward(self, foo_bar):
             # 'foo' requires grad but 'bar' does not. In-place operation on
@@ -88,7 +88,7 @@ def inplace_incorrect_grad(pipeline_style):
             return foo * bar
     model = nn.Sequential(M())
-    model = MultiProcessPipe(model, [1], style=pipeline_style, worker_map=get_worker_map(), checkpoint="always")
+    model = pipe_class(model, [1], worker_map=get_worker_map(), checkpoint="always")
     foo = torch.tensor([1.0], requires_grad=True)
     bar = torch.tensor([1.0])
...
@@ -31,15 +31,16 @@ from fairscale.nn.model_parallel.initialize import (
     get_pipeline_parallel_group,
     initialize_model_parallel,
 )
-from fairscale.nn.pipe import LazyModule, MultiProcessPipe
+from fairscale.nn.pipe import AsyncPipe, LazyModule, MultiProcessPipe
+from fairscale.nn.pipe.types import PipelineStyle
 from fairscale.utils.testing import get_worker_map, set_random_seed, torch_spawn, torch_version
 @torch_spawn([2])
-@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
-def parameters(pipeline_style):
+@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
+def parameters(pipe_class):
     model = nn.Sequential(nn.Linear(1, 1))
-    pipe = MultiProcessPipe(model, balance=[1], style=pipeline_style, worker_map=get_worker_map(), chunks=1)
+    pipe = pipe_class(model, balance=[1], worker_map=get_worker_map(), chunks=1)
     if torch.distributed.get_rank() == 0:
         assert list(pipe.parameters()) != []
     else:
@@ -107,8 +108,8 @@ def mpi():
 @torch_spawn([1])
-@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
-def public_attrs(pipeline_style):
+@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
+def public_attrs(pipe_class):
     class MyString:
         def __init__(self, value):
             self.value = value
@@ -118,14 +119,7 @@ def public_attrs(pipeline_style):
     model = nn.Sequential(nn.Linear(1, 1))
-    pipe = MultiProcessPipe(
-        model,
-        balance=(1,),
-        style=pipeline_style,
-        worker_map=get_worker_map(),
-        chunks=42.000,
-        checkpoint=MyString("always"),
-    )
+    pipe = pipe_class(model, balance=(1,), worker_map=get_worker_map(), chunks=42.000, checkpoint=MyString("always"),)
     assert pipe.balance == [1]
     assert pipe.chunks == 42
@@ -136,13 +130,13 @@ def public_attrs(pipeline_style):
 @torch_spawn([2])
 @pytest.mark.parametrize("balance", [[2], [1, 1]])
-@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
-def sequential_like(balance, pipeline_style):
+@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
+def sequential_like(balance, pipe_class):
     a = nn.Linear(1, 1)
     b = nn.Linear(1, 1)
     model = nn.Sequential(a, b)
-    model = MultiProcessPipe(model, balance, style=pipeline_style, worker_map=get_worker_map())
+    model = pipe_class(model, balance, worker_map=get_worker_map())
     if balance == [2]:
         if torch.distributed.get_rank() == 0:
@@ -175,62 +169,62 @@ def sequential_like(balance, pipeline_style):
 @torch_spawn([1])
-@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
-def balance_wrong_length(pipeline_style):
+@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
+def balance_wrong_length(pipe_class):
     a = nn.Linear(1, 1)
     b = nn.Linear(1, 1)
     model = nn.Sequential(a, b)
     with pytest.raises(ValueError):
-        MultiProcessPipe(model, balance=[1], style=pipeline_style, worker_map=get_worker_map())
+        pipe_class(model, balance=[1], worker_map=get_worker_map())
     with pytest.raises(ValueError):
-        MultiProcessPipe(model, balance=[3], style=pipeline_style, worker_map=get_worker_map())
+        pipe_class(model, balance=[3], worker_map=get_worker_map())
 @torch_spawn([2])
-@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
-def balance_less_than_1(pipeline_style):
+@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
+def balance_less_than_1(pipe_class):
     a = nn.Linear(1, 1)
     b = nn.Linear(1, 1)
     model = nn.Sequential(a, b)
     with pytest.raises(ValueError):
-        MultiProcessPipe(model, balance=[0, 2], style=pipeline_style, worker_map=get_worker_map())
+        pipe_class(model, balance=[0, 2], worker_map=get_worker_map())
     with pytest.raises(ValueError):
-        MultiProcessPipe(model, balance=[-1, 3], style=pipeline_style, worker_map=get_worker_map())
+        pipe_class(model, balance=[-1, 3], worker_map=get_worker_map())
 @torch_spawn([1])
-@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
-def chunks_less_than_1(pipeline_style):
+@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
+def chunks_less_than_1(pipe_class):
     model = nn.Sequential(nn.Linear(1, 1))
     with pytest.raises(ValueError):
-        MultiProcessPipe(model, balance=[1], style=pipeline_style, worker_map=get_worker_map(), chunks=0)
+        pipe_class(model, balance=[1], worker_map=get_worker_map(), chunks=0)
     with pytest.raises(ValueError):
-        MultiProcessPipe(model, balance=[1], style=pipeline_style, worker_map=get_worker_map(), chunks=-1)
+        pipe_class(model, balance=[1], worker_map=get_worker_map(), chunks=-1)
 @torch_spawn([1])
-@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
-def too_few_devices(pipeline_style):
+@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
+def too_few_devices(pipe_class):
     model = nn.Sequential(nn.Linear(1, 1), nn.Linear(1, 1), nn.Linear(1, 1), nn.Linear(1, 1))
     with pytest.raises(IndexError):
         # len(balance) > len(group.size())
-        model = MultiProcessPipe(model, balance=[1, 1, 1, 1], style=pipeline_style, worker_map=get_worker_map())
+        model = pipe_class(model, balance=[1, 1, 1, 1], worker_map=get_worker_map())
 @torch_spawn([1])
-@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
-def batch_size_indivisible(pipeline_style):
+@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
+def batch_size_indivisible(pipe_class):
     model = nn.Sequential(nn.Linear(1, 1))
-    model = MultiProcessPipe(model, balance=[1], style=pipeline_style, worker_map=get_worker_map(), chunks=4)
+    model = pipe_class(model, balance=[1], worker_map=get_worker_map(), chunks=4)
     with pytest.warns(None) as record:
         model(torch.rand(7, 1))
@@ -240,10 +234,10 @@ def batch_size_indivisible(pipeline_style):
 @torch_spawn([1])
-@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
-def batch_size_small(pipeline_style):
+@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
+def batch_size_small(pipe_class):
     model = nn.Sequential(nn.Linear(1, 1))
-    model = MultiProcessPipe(model, balance=[1], style=pipeline_style, worker_map=get_worker_map(), chunks=4)
+    model = pipe_class(model, balance=[1], worker_map=get_worker_map(), chunks=4)
     with pytest.warns(None) as record:
         model(torch.rand(2, 1))
@@ -253,8 +247,8 @@ def batch_size_small(pipeline_style):
 @torch_spawn([1])
-@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
-def checkpoint_mode(pipeline_style):
+@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
+def checkpoint_mode(pipe_class):
     def count_grad_fn(grad_fn, name, visited=set()):
         if grad_fn in visited:
             return 0
@@ -273,32 +267,14 @@ def checkpoint_mode(pipeline_style):
     model = nn.Sequential(nn.Linear(1, 1))
     input = torch.rand(2, 1)
-    always = MultiProcessPipe(
-        model,
-        balance=[1],
-        style=pipeline_style,
-        worker_map=get_worker_map(),
-        chunks=2,
-        checkpoint="always",
-        pipelined_backward=False,
+    always = pipe_class(
+        model, balance=[1], worker_map=get_worker_map(), chunks=2, checkpoint="always", pipelined_backward=False,
     )
-    except_last = MultiProcessPipe(
-        model,
-        balance=[1],
-        style=pipeline_style,
-        worker_map=get_worker_map(),
-        chunks=2,
-        checkpoint="except_last",
-        pipelined_backward=False,
+    except_last = pipe_class(
+        model, balance=[1], worker_map=get_worker_map(), chunks=2, checkpoint="except_last", pipelined_backward=False,
    )
-    never = MultiProcessPipe(
-        model,
-        balance=[1],
-        style=pipeline_style,
-        worker_map=get_worker_map(),
-        chunks=2,
-        checkpoint="never",
-        pipelined_backward=False,
+    never = pipe_class(
+        model, balance=[1], worker_map=get_worker_map(), chunks=2, checkpoint="never", pipelined_backward=False,
     )
     always_output = always(input)
@@ -311,45 +287,34 @@ def checkpoint_mode(pipeline_style):
 @torch_spawn([1])
-@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
-def checkpoint_mode_invalid(pipeline_style):
+@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
+def checkpoint_mode_invalid(pipe_class):
     model = nn.Sequential(nn.Linear(1, 1))
     with pytest.raises(ValueError, match="checkpoint is not one of 'always', 'except_last', or 'never'"):
-        MultiProcessPipe(
-            model,
-            balance=[1],
-            style=pipeline_style,
-            worker_map=get_worker_map(),
-            chunks=2,
-            checkpoint="INVALID_CHECKPOINT",
+        pipe_class(
+            model, balance=[1], worker_map=get_worker_map(), chunks=2, checkpoint="INVALID_CHECKPOINT",
         )
 @torch_spawn([1])
-@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
-def checkpoint_mode_when_chunks_1(pipeline_style):
+@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
+def checkpoint_mode_when_chunks_1(pipe_class):
     model = nn.Sequential(nn.Linear(1, 1))
     # All checkpoint modes are fine.
-    MultiProcessPipe(
-        model, balance=[1], style=pipeline_style, worker_map=get_worker_map(), chunks=1, checkpoint="except_last",
-    )
-    MultiProcessPipe(
-        model, balance=[1], style=pipeline_style, worker_map=get_worker_map(), chunks=1, checkpoint="always"
-    )
-    MultiProcessPipe(
-        model, balance=[1], style=pipeline_style, worker_map=get_worker_map(), chunks=1, checkpoint="never"
-    )
+    pipe_class(
+        model, balance=[1], worker_map=get_worker_map(), chunks=1, checkpoint="except_last",
+    )
+    pipe_class(model, balance=[1], worker_map=get_worker_map(), chunks=1, checkpoint="always")
+    pipe_class(model, balance=[1], worker_map=get_worker_map(), chunks=1, checkpoint="never")
 @torch_spawn([1])
-@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
-def checkpoint_eval(pipeline_style):
+@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
+def checkpoint_eval(pipe_class):
     model = nn.Sequential(nn.Linear(1, 1))
-    model = MultiProcessPipe(
-        model, balance=[1], style=pipeline_style, worker_map=get_worker_map(), chunks=2, pipelined_backward=False,
-    )
+    model = pipe_class(model, balance=[1], worker_map=get_worker_map(), chunks=2, pipelined_backward=False,)
     input = torch.rand(2, 1)
     def find_grad_fn(grad_fn, name):
@@ -375,8 +340,8 @@ def checkpoint_eval(pipeline_style):
 @torch_spawn([2])
 @pytest.mark.xfail(torch_version() < (1, 6, 0), reason="Doesn't work on torch < 1.6.0", strict=True)
-@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
-def checkpoint_non_float_input(pipeline_style):
+@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
+def checkpoint_non_float_input(pipe_class):
     class ForkNonFloat(nn.Module):
         def forward(self, input):
             return (input * 2, torch.tensor([False]))
@@ -386,14 +351,8 @@ def checkpoint_non_float_input(pipeline_style):
             return input[0] * 2
     model = nn.Sequential(ForkNonFloat(), JoinNonFloat())
-    model = MultiProcessPipe(
-        model,
-        balance=[1, 1],
-        style=pipeline_style,
-        worker_map=get_worker_map(),
-        chunks=1,
-        checkpoint="always",
-        pipelined_backward=False,
+    model = pipe_class(
+        model, balance=[1, 1], worker_map=get_worker_map(), chunks=1, checkpoint="always", pipelined_backward=False,
     )
     input = torch.rand(1, requires_grad=True)
@@ -401,17 +360,17 @@ def checkpoint_non_float_input(pipeline_style):
     if model.group.rank() == 1:
         # with torch.autograd.detect_anomaly():
         output.backward()
-    elif pipeline_style == MultiProcessPipe.MultiProcess:
+    elif pipe_class == MultiProcessPipe:
         model.back_helper(output)
     torch.distributed.barrier()
 @torch_spawn([1])
-@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
-def no_grad(pipeline_style):
+@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
+def no_grad(pipe_class):
     model = nn.Sequential(nn.Linear(1, 1))
-    model = MultiProcessPipe(model, balance=[1], style=pipeline_style, worker_map=get_worker_map(), chunks=2)
+    model = pipe_class(model, balance=[1], worker_map=get_worker_map(), chunks=2)
     input = torch.rand(2, 1)
     latent = None
@@ -433,8 +392,8 @@ def no_grad(pipeline_style):
 @torch_spawn([1])
-@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
-def exception(pipeline_style):
+@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
+def exception(pipe_class):
     class ExpectedException(Exception):
         pass
@@ -443,7 +402,7 @@ def exception(pipeline_style):
             raise ExpectedException()
     model = nn.Sequential(Raise())
-    model = MultiProcessPipe(model, balance=[1], style=pipeline_style, worker_map=get_worker_map(), chunks=1)
+    model = pipe_class(model, balance=[1], worker_map=get_worker_map(), chunks=1)
     with pytest.raises(ExpectedException):
         model(torch.rand(1))
@@ -453,8 +412,8 @@ def exception(pipeline_style):
 @torch_spawn([4])
 @pytest.mark.skipif(torch.cuda.is_available() and torch.cuda.device_count() < 4, reason="Not enough GPUs")
 @pytest.mark.xfail(strict=True)
-@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
-def exception_early_stop_asap(pipeline_style):
+@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
+def exception_early_stop_asap(pipe_class):
     """Even the first partitions have finished to process, the partition before
     the failed partition hould be killed as soon as possible.
     """
@@ -482,7 +441,7 @@ def exception_early_stop_asap(pipeline_style):
             raise ExpectedException()
     model = nn.Sequential(Pass(), Pass(), Counter(), Raise())
-    model = MultiProcessPipe(model, [1, 1, 1, 1], style=pipeline_style, worker_map=get_worker_map(), chunks=3)
+    model = pipe_class(model, [1, 1, 1, 1], worker_map=get_worker_map(), chunks=3)
     with pytest.raises(ExpectedException):
         model(torch.rand(3))
@@ -492,8 +451,8 @@ def exception_early_stop_asap(pipeline_style):
 @torch_spawn([1])
-@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
-def input_pair(pipeline_style):
+@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
+def input_pair(pipe_class):
     class Two(nn.Module):
         def __init__(self):
             super().__init__()
@@ -505,9 +464,7 @@ def input_pair(pipeline_style):
             return (self.fc_a(a), self.fc_b(b))
     model = nn.Sequential(Two())
-    model = MultiProcessPipe(
-        model, balance=[1], style=pipeline_style, worker_map=get_worker_map(), chunks=2, pipelined_backward=False,
-    )
+    model = pipe_class(model, balance=[1], worker_map=get_worker_map(), chunks=2, pipelined_backward=False,)
     a = torch.rand(10, 1, requires_grad=True)
     b = torch.rand(10, 1, requires_grad=True)
@@ -521,8 +478,8 @@ def input_pair(pipeline_style):
 @torch_spawn([1])
-@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
-def input_singleton(pipeline_style):
+@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
+def input_singleton(pipe_class):
     class One(nn.Module):
         def __init__(self):
             super().__init__()
@@ -533,9 +490,7 @@ def input_singleton(pipeline_style):
             return (self.fc(a),)
     model = nn.Sequential(One())
-    model = MultiProcessPipe(
-        model, balance=[1], style=pipeline_style, worker_map=get_worker_map(), chunks=2, pipelined_backward=False,
-    )
+    model = pipe_class(model, balance=[1], worker_map=get_worker_map(), chunks=2, pipelined_backward=False,)
     a = torch.rand(10, 1, requires_grad=True)
@@ -548,10 +503,10 @@ def input_singleton(pipeline_style):
 @torch_spawn([1])
-@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
-def input_varargs(pipeline_style):
+@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
+def input_varargs(pipe_class):
     model = nn.Sequential(nn.Linear(1, 1))
-    model = MultiProcessPipe(model, balance=[1], style=pipeline_style, worker_map=get_worker_map())
+    model = pipe_class(model, balance=[1], worker_map=get_worker_map())
     a = torch.rand(1)
     b = torch.rand(1)
@@ -562,14 +517,14 @@ def input_varargs(pipeline_style):
 @torch_spawn([1])
-@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
-def non_tensor(pipeline_style):
+@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
+def non_tensor(pipe_class):
     class NonTensor(nn.Module):
         def forward(self, _):
             return "hello"
     model = nn.Sequential(NonTensor())
-    model = MultiProcessPipe(model, balance=[1], style=pipeline_style, worker_map=get_worker_map())
+    model = pipe_class(model, balance=[1], worker_map=get_worker_map())
     x = torch.rand(1)
     # TypeError: expected Tensor as element 0 in argument 0, but got str
@@ -582,14 +537,14 @@ def non_tensor(pipeline_style):
 @torch_spawn([1])
-@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
-def non_tensor_tuple(pipeline_style):
+@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
+def non_tensor_tuple(pipe_class):
     class NonTensorTuple(nn.Module):
         def forward(self, x):
             return (x, "hello")
     model = nn.Sequential(NonTensorTuple())
-    model = MultiProcessPipe(model, balance=[1], style=pipeline_style, worker_map=get_worker_map())
+    model = pipe_class(model, balance=[1], worker_map=get_worker_map())
     x = torch.rand(1)
     # TypeError: CheckpointBackward.forward: expected Variable (got str) for return value 1
@@ -604,8 +559,8 @@ def non_tensor_tuple(pipeline_style):
 @torch_spawn([1])
 @pytest.mark.parametrize("checkpoint", ["never", "always", "except_last"])
 @pytest.mark.parametrize("lazy", [True, False])
-@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
-def deferred_batch_norm(checkpoint, lazy, pipeline_style):
+@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
+def deferred_batch_norm(checkpoint, lazy, pipe_class):
     bn = nn.BatchNorm2d(3)
     pipe_bn = deepcopy(bn)
     pipe_fn = lambda: pipe_bn  # noqa: E731
@@ -613,14 +568,8 @@ def deferred_batch_norm(checkpoint, lazy, pipeline_style):
         model = [LazyModule(pipe_fn)]
     else:
         model = nn.Sequential(pipe_bn)
-    pipe = MultiProcessPipe(
-        model,
-        balance=[1],
-        style=pipeline_style,
-        worker_map=get_worker_map(),
-        chunks=2,
-        checkpoint=checkpoint,
-        deferred_batch_norm=True,
+    pipe = pipe_class(
+        model, balance=[1], worker_map=get_worker_map(), chunks=2, checkpoint=checkpoint, deferred_batch_norm=True,
    )
     x = torch.rand(4, 3, 10, 10)
@@ -634,8 +583,8 @@ def deferred_batch_norm(checkpoint, lazy, pipeline_style):
 @torch_spawn([1])
 @pytest.mark.parametrize("checkpoint", ["never", "always"])
 @pytest.mark.parametrize("lazy", [True, False])
-@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
-def deferred_batch_norm_params(checkpoint, lazy, pipeline_style):
+@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
+def deferred_batch_norm_params(checkpoint, lazy, pipe_class):
     bn = nn.BatchNorm2d(3)
     pipe_bn = deepcopy(bn)
     pipe_fn = lambda: pipe_bn  # noqa: E731
@@ -643,14 +592,8 @@ def deferred_batch_norm_params(checkpoint, lazy, pipeline_style):
         model = [LazyModule(pipe_fn)]
     else:
         model = nn.Sequential(pipe_bn)
-    pipe = MultiProcessPipe(
-        model,
-        balance=[1],
-        style=pipeline_style,
-        worker_map=get_worker_map(),
-        chunks=1,
-        checkpoint=checkpoint,
-        deferred_batch_norm=True,
+    pipe = pipe_class(
+        model, balance=[1], worker_map=get_worker_map(), chunks=1, checkpoint=checkpoint, deferred_batch_norm=True,
    )
     x = torch.rand(4, 3, 10, 10)
@@ -665,15 +608,15 @@ def deferred_batch_norm_params(checkpoint, lazy, pipeline_style):
 @torch_spawn([4])
-@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
-def devices(pipeline_style):
+@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
+def devices(pipe_class):
     a = nn.Linear(1, 1)
     b = nn.Linear(1, 1)
     c = nn.Linear(1, 1)
     # There are extra two ranks.
     model = nn.Sequential(a, b, c)
-    model = MultiProcessPipe(model, [1, 1, 1], style=pipeline_style, worker_map=get_worker_map())
+    model = pipe_class(model, [1, 1, 1], worker_map=get_worker_map())
     # Extra devices must be discarded.
     if model.group.rank() == 3:
...@@ -681,13 +624,13 @@ def devices(pipeline_style): ...@@ -681,13 +624,13 @@ def devices(pipeline_style):
@torch_spawn([2]) @torch_spawn([2])
@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule]) @pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
def partitions(pipeline_style): def partitions(pipe_class):
a = nn.Linear(1, 1) a = nn.Linear(1, 1)
b = nn.Linear(1, 1) b = nn.Linear(1, 1)
model = nn.Sequential(a, b) model = nn.Sequential(a, b)
model = MultiProcessPipe(model, [1, 1], style=pipeline_style, worker_map=get_worker_map()) model = pipe_class(model, [1, 1], worker_map=get_worker_map())
assert isinstance(model.partitions, list) assert isinstance(model.partitions, list)
assert len(model) == 1 assert len(model) == 1
...@@ -701,13 +644,13 @@ def partitions(pipeline_style): ...@@ -701,13 +644,13 @@ def partitions(pipeline_style):
@torch_spawn([2]) @torch_spawn([2])
@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required") @pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule]) @pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
def deny_moving(pipeline_style): def deny_moving(pipe_class):
a = nn.Linear(1, 1) a = nn.Linear(1, 1)
b = nn.Linear(1, 1) b = nn.Linear(1, 1)
model = nn.Sequential(a, b) model = nn.Sequential(a, b)
model = MultiProcessPipe(model, [1, 1], style=pipeline_style, worker_map=get_worker_map()) model = pipe_class(model, [1, 1], worker_map=get_worker_map())
model.cuda() model.cuda()
model.cpu() model.cpu()
...@@ -725,11 +668,11 @@ def deny_moving(pipeline_style): ...@@ -725,11 +668,11 @@ def deny_moving(pipeline_style):
@torch_spawn([1]) @torch_spawn([1])
@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule]) @pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
def empty_module(pipeline_style): def empty_module(pipe_class):
# Empty sequential module is not illegal. # Empty sequential module is not illegal.
model = nn.Sequential() model = nn.Sequential()
model = MultiProcessPipe(model, [], style=pipeline_style, worker_map=get_worker_map()) model = pipe_class(model, [], worker_map=get_worker_map())
assert model(torch.tensor([42])) == torch.tensor([42]) assert model(torch.tensor([42])) == torch.tensor([42])
assert model((torch.tensor([42]),)) == (torch.tensor([42]),) assert model((torch.tensor([42]),)) == (torch.tensor([42]),)
...@@ -741,13 +684,13 @@ def empty_module(pipeline_style): ...@@ -741,13 +684,13 @@ def empty_module(pipeline_style):
@torch_spawn([2]) @torch_spawn([2])
@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule]) @pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
def named_children(pipeline_style): def named_children(pipe_class):
a = nn.Linear(1, 1) a = nn.Linear(1, 1)
b = nn.Linear(1, 1) b = nn.Linear(1, 1)
model = nn.Sequential(OrderedDict([("a", a), ("b", b)])) model = nn.Sequential(OrderedDict([("a", a), ("b", b)]))
model = MultiProcessPipe(model, [1, 1], style=pipeline_style, worker_map=get_worker_map()) model = pipe_class(model, [1, 1], worker_map=get_worker_map())
names = set(n for n, _ in model.named_modules()) names = set(n for n, _ in model.named_modules())
if model.group.rank() == 0: if model.group.rank() == 0:
...@@ -762,24 +705,24 @@ def named_children(pipeline_style): ...@@ -762,24 +705,24 @@ def named_children(pipeline_style):
@torch_spawn([1]) @torch_spawn([1])
@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule]) @pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
def recommend_auto_balance(pipeline_style): def recommend_auto_balance(pipe_class):
with pytest.raises(ValueError, match="fairscale.nn.pipe.balance"): with pytest.raises(ValueError, match="fairscale.nn.pipe.balance"):
# balance is required # balance is required
MultiProcessPipe(nn.Sequential()) pipe_class(nn.Sequential())
with pytest.raises(ValueError, match="fairscale.nn.pipe.balance"): with pytest.raises(ValueError, match="fairscale.nn.pipe.balance"):
# module and sum of balance have different length (module: 0, sum of balance: 1) # module and sum of balance have different length (module: 0, sum of balance: 1)
MultiProcessPipe(nn.Sequential(), [1]) pipe_class(nn.Sequential(), [1])
with pytest.raises(ValueError, match="fairscale.nn.pipe.balance"): with pytest.raises(ValueError, match="fairscale.nn.pipe.balance"):
# module and sum of balance have different length (module: 2, sum of balance: 1) # module and sum of balance have different length (module: 2, sum of balance: 1)
MultiProcessPipe(nn.Sequential(nn.Linear(1, 1), nn.Linear(1, 1)), [1]) pipe_class(nn.Sequential(nn.Linear(1, 1), nn.Linear(1, 1)), [1])
@torch_spawn([2]) @torch_spawn([2])
@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule]) @pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
def lazy_construction(pipeline_style): def lazy_construction(pipe_class):
init_count = 0 init_count = 0
class Custom(nn.Module): class Custom(nn.Module):
...@@ -798,7 +741,7 @@ def lazy_construction(pipeline_style): ...@@ -798,7 +741,7 @@ def lazy_construction(pipeline_style):
LazyModule(lambda: Custom()), LazyModule(lambda: Custom()),
] ]
pipe = MultiProcessPipe(model, balance=[2, 2], style=pipeline_style, worker_map=get_worker_map()) pipe = pipe_class(model, balance=[2, 2], worker_map=get_worker_map())
assert isinstance(pipe[0], Custom) assert isinstance(pipe[0], Custom)
assert isinstance(pipe[1], Custom) assert isinstance(pipe[1], Custom)
...@@ -808,18 +751,18 @@ def lazy_construction(pipeline_style): ...@@ -808,18 +751,18 @@ def lazy_construction(pipeline_style):
@torch_spawn([2]) @torch_spawn([2])
@pytest.mark.skipif("OMPI_COMM_WORLD_RANK" in os.environ, reason="doesn't apply to mpi") @pytest.mark.skipif("OMPI_COMM_WORLD_RANK" in os.environ, reason="doesn't apply to mpi")
@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule]) @pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
def missing_worker_map(pipeline_style): def missing_worker_map(pipe_class):
model = nn.Sequential(nn.ReLU(), nn.ReLU()) model = nn.Sequential(nn.ReLU(), nn.ReLU())
with pytest.raises(ValueError, match="'RpcTransport' requires 'worker_map' to be set"): with pytest.raises(ValueError, match="'RpcTransport' requires 'worker_map' to be set"):
MultiProcessPipe(model, [1, 1], style=pipeline_style) pipe_class(model, [1, 1])
@torch_spawn([2]) @torch_spawn([2])
@pytest.mark.skip(reason="currently broken") @pytest.mark.skip(reason="currently broken")
@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule]) @pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
def verify_module_duplicate_parameters_on_distinct_partitions(pipeline_style): def verify_module_duplicate_parameters_on_distinct_partitions(pipe_class):
class Surrogate(nn.Module): class Surrogate(nn.Module):
def __init__(self, module): def __init__(self, module):
super().__init__() super().__init__()
...@@ -830,23 +773,23 @@ def verify_module_duplicate_parameters_on_distinct_partitions(pipeline_style): ...@@ -830,23 +773,23 @@ def verify_module_duplicate_parameters_on_distinct_partitions(pipeline_style):
# FIXME(tom) can't have duplicate params with separate processes # FIXME(tom) can't have duplicate params with separate processes
with pytest.raises(ValueError, match="module with duplicate parameters on distinct devices is not supported"): with pytest.raises(ValueError, match="module with duplicate parameters on distinct devices is not supported"):
MultiProcessPipe(model, [1, 1], style=pipeline_style, worker_map=get_worker_map()) pipe_class(model, [1, 1], worker_map=get_worker_map())
@torch_spawn([4]) @torch_spawn([4])
@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule]) @pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
def pipelined_backward(pipeline_style): def pipelined_backward(pipe_class):
model = nn.Sequential(nn.ReLU(), nn.ReLU()) model = nn.Sequential(nn.ReLU(), nn.ReLU())
destroy_model_parallel() destroy_model_parallel()
initialize_model_parallel(1, 4) initialize_model_parallel(1, 4)
pipe = MultiProcessPipe(model, [1, 1], style=pipeline_style, worker_map=get_worker_map()) pipe = pipe_class(model, [1, 1], worker_map=get_worker_map())
assert pipe.pipelined_backward is False assert pipe.pipelined_backward is False
destroy_model_parallel() destroy_model_parallel()
initialize_model_parallel(2, 2) initialize_model_parallel(2, 2)
pipe = MultiProcessPipe(model, [1, 1], style=pipeline_style, worker_map=get_worker_map()) pipe = pipe_class(model, [1, 1], worker_map=get_worker_map())
assert pipe.pipelined_backward is True assert pipe.pipelined_backward is True
...@@ -855,9 +798,7 @@ def pipelined_backward(pipeline_style): ...@@ -855,9 +798,7 @@ def pipelined_backward(pipeline_style):
def async_event_loop(): def async_event_loop():
model = nn.Sequential(nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 10), nn.ReLU()) model = nn.Sequential(nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 10), nn.ReLU())
pipe = MultiProcessPipe( pipe = AsyncPipe(model, [1, 1, 1, 1], worker_map=get_worker_map(), chunks=10)
model, [1, 1, 1, 1], style=MultiProcessPipe.AsyncSchedule, worker_map=get_worker_map(), chunks=10
)
inputs = torch.rand(100, 10) inputs = torch.rand(100, 10)
...@@ -873,7 +814,7 @@ def reuse_lazy(): ...@@ -873,7 +814,7 @@ def reuse_lazy():
reused = LazyModule(lambda: nn.Linear(10, 10)) reused = LazyModule(lambda: nn.Linear(10, 10))
model = [reused, nn.Linear(10, 10), nn.ReLU(), reused, nn.ReLU(), reused, nn.ReLU()] model = [reused, nn.Linear(10, 10), nn.ReLU(), reused, nn.ReLU(), reused, nn.ReLU()]
# model = [reused, reused, nn.Linear(10, 10), nn.ReLU(), reused, reused, nn.ReLU(), reused, reused, nn.ReLU()] # model = [reused, reused, nn.Linear(10, 10), nn.ReLU(), reused, reused, nn.ReLU(), reused, reused, nn.ReLU()]
pipe = MultiProcessPipe(model, [3, 1, 1], style=MultiProcessPipe.AsyncSchedule, worker_map=get_worker_map()) pipe = AsyncPipe(model, [3, 1, 1], worker_map=get_worker_map())
pipe.eval() pipe.eval()
output = pipe(torch.rand(10)) output = pipe(torch.rand(10))
...@@ -891,7 +832,7 @@ def reuse_lazy(): ...@@ -891,7 +832,7 @@ def reuse_lazy():
# ensure identical weights but no sharing between model and pipe # ensure identical weights but no sharing between model and pipe
reused = nn.Linear(10, 10) reused = nn.Linear(10, 10)
layers = [reused, nn.Linear(10, 10), nn.ReLU(), reused, nn.ReLU(), reused, nn.ReLU()] layers = [reused, nn.Linear(10, 10), nn.ReLU(), reused, nn.ReLU(), reused, nn.ReLU()]
pipe = MultiProcessPipe(layers, [3, 1, 1], style=MultiProcessPipe.AsyncSchedule, worker_map=get_worker_map()) pipe = AsyncPipe(layers, [3, 1, 1], worker_map=get_worker_map())
pipe.eval() pipe.eval()
model_optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9) model_optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
pipe_optimizer = torch.optim.SGD(pipe.parameters(), lr=0.01, momentum=0.9) if len(list(pipe.parameters())) else None pipe_optimizer = torch.optim.SGD(pipe.parameters(), lr=0.01, momentum=0.9) if len(list(pipe.parameters())) else None
...@@ -964,7 +905,7 @@ def test_instantiate_partition(): ...@@ -964,7 +905,7 @@ def test_instantiate_partition():
# instantiated model # instantiated model
for rank in range(len(balance)): for rank in range(len(balance)):
instantiated = instantiate_partition( instantiated = instantiate_partition(
model, balance, FakeGroup(rank, len(balance)), MultiProcessPipe.AsyncSchedule model, balance, FakeGroup(rank, len(balance)), PipelineStyle.AsyncSchedule
) )
for part in instantiated: for part in instantiated:
assert isinstance(part.module, nn.Sequential) assert isinstance(part.module, nn.Sequential)
......
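The same mechanical change repeats throughout this test file: the `pipeline_style` parametrization over `MultiProcessPipe.MultiProcess` and `MultiProcessPipe.AsyncSchedule` becomes a `pipe_class` parametrization over the two pipe classes, and the `style=` keyword is dropped from every constructor call. Below is a minimal sketch of the resulting test shape, assuming the `torch_spawn` and `get_worker_map` helpers from `fairscale.utils.testing`; the test name and body are illustrative, not taken from the diff.

```python
import pytest
from torch import nn

from fairscale.nn.pipe import AsyncPipe, MultiProcessPipe
from fairscale.utils.testing import get_worker_map, torch_spawn


@torch_spawn([2])
@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
def two_linear_partitions(pipe_class):
    # The schedule is now implied by the class under test; no style= keyword.
    model = nn.Sequential(nn.Linear(1, 1), nn.Linear(1, 1))
    model = pipe_class(model, [1, 1], worker_map=get_worker_map())
    # With balance [1, 1] over two ranks, each rank keeps a single partition.
    assert isinstance(model.partitions, list)
    assert len(model) == 1
```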
...@@ -21,14 +21,14 @@ import pytest ...@@ -21,14 +21,14 @@ import pytest
import torch import torch
from torch import nn from torch import nn
from fairscale.nn.pipe import MultiProcessPipe from fairscale.nn.pipe import AsyncPipe, MultiProcessPipe
from fairscale.utils.testing import get_worker_map, set_random_seed, torch_spawn from fairscale.utils.testing import get_worker_map, set_random_seed, torch_spawn
@torch_spawn([2]) @torch_spawn([2])
@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required") @pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule]) @pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
def simple_linears(pipeline_style): def simple_linears(pipe_class):
def sum_grad(parameters): def sum_grad(parameters):
return sum([p.grad.sum() for p in parameters if p.grad is not None]) return sum([p.grad.sum() for p in parameters if p.grad is not None])
...@@ -54,8 +54,7 @@ def simple_linears(pipeline_style): ...@@ -54,8 +54,7 @@ def simple_linears(pipeline_style):
zero_grad(model.parameters()) zero_grad(model.parameters())
# With MultiProcessPipe model = pipe_class(model, [2, 2], worker_map=get_worker_map(), chunks=4)
model = MultiProcessPipe(model, [2, 2], style=pipeline_style, worker_map=get_worker_map(), chunks=4)
outputs = model(inputs) outputs = model(inputs)
if model.group.rank() == 1: if model.group.rank() == 1:
......
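Taken together, the diff amounts to a single API change for callers: the asynchronous schedule is no longer requested via a `style=` argument on `MultiProcessPipe` but by constructing `AsyncPipe` directly, with otherwise identical arguments. A minimal before/after sketch, assuming the process group has already been initialized the way these tests do it (via `torch_spawn`/`dist_init`):

```python
import torch
from torch import nn

from fairscale.nn.pipe import AsyncPipe
from fairscale.utils.testing import get_worker_map

model = nn.Sequential(nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 10), nn.ReLU())

# Before this refactor (removed in this diff):
#   pipe = MultiProcessPipe(model, [1, 1, 1, 1], style=MultiProcessPipe.AsyncSchedule,
#                           worker_map=get_worker_map(), chunks=10)

# After: AsyncPipe is its own class, so the style flag is gone.
pipe = AsyncPipe(model, [1, 1, 1, 1], worker_map=get_worker_map(), chunks=10)

outputs = pipe(torch.rand(100, 10))
```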