Unverified commit eaee5976 authored by msbaines, committed by GitHub

[refactor] make AsyncPipe its own class (#341)
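The schedule is no longer selected with a style flag on MultiProcessPipe; call sites construct the new AsyncPipe subclass instead. A minimal before/after sketch of the call-site change (the model, balance, chunks, and worker_map values are illustrative, not taken from this diff):

    from fairscale.nn.pipe import AsyncPipe, MultiProcessPipe
    from fairscale.utils.testing import get_worker_map

    # before this change: the async schedule was selected with a style flag
    pipe = MultiProcessPipe(
        model,                                 # an nn.Sequential, split according to balance
        balance=[2, 2],                        # illustrative partition sizes
        style=MultiProcessPipe.AsyncSchedule,  # flag removed by this commit
        worker_map=get_worker_map(),
        chunks=4,                              # illustrative micro-batch count
    )

    # after this change: the async schedule is its own class
    pipe = AsyncPipe(
        model,
        balance=[2, 2],
        worker_map=get_worker_map(),
        chunks=4,
    )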

parent 51625eda
......@@ -21,7 +21,7 @@ from torchtext.data.utils import get_tokenizer
from experimental.nn.ampnet_pipe import pipe
from fairscale.nn.model_parallel import initialize_model_parallel
from fairscale.nn.model_parallel.initialize import get_pipeline_parallel_group
from fairscale.nn.pipe import LazyModule, MultiProcessPipe
from fairscale.nn.pipe import LazyModule
from fairscale.optim import GradScaler
from fairscale.utils.testing import dist_init, get_worker_map
......@@ -420,7 +420,6 @@ def run_mp_worker(args, available_workers):
p = pipe.AMPnetPipe(
module=model,
balance=balance,
style=MultiProcessPipe.AsyncSchedule,
chunks=args.chunks,
worker_map=get_worker_map(),
input_device=torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"),
......
......@@ -499,7 +499,6 @@ def run_mp_worker(args, available_workers):
pipe_model = MultiProcessPipe(
model,
balance,
style=MultiProcessPipe.AsyncSchedule,
chunks=args.chunks,
worker_map=get_worker_map(),
input_device=torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"),
......
......@@ -37,7 +37,6 @@ def create_task_without_skip_trackers(
checkpoint_stop: int, i: int, j: int, batch: Batch, partition: nn.Sequential,
) -> Task:
# Determine whether checkpointing or not.
# style is guaranteed to be PipelineStyle.AsyncSchedule
if i < checkpoint_stop:
def function(
......
......@@ -11,15 +11,14 @@ from torch import nn
from torch.optim.optimizer import Optimizer
from torch.utils.data import DataLoader
from fairscale.nn.pipe import MultiProcessPipe
from fairscale.nn.pipe.types import PipelineStyle
from fairscale.nn.pipe import AsyncPipe
from .ampnet import AsyncAMPnetEventLoop
__all__ = ["AMPnetPipe"]
class AMPnetPipe(MultiProcessPipe):
class AMPnetPipe(AsyncPipe):
"""
AMPnetPipe is the asynchronous version of the MultiProcessPipe implementation,
which avoids the bubble issue by using stale weights and gradients.
......@@ -44,7 +43,6 @@ class AMPnetPipe(MultiProcessPipe):
# AMPnet implementation doesn't handle skip_trackers!
assert self.pipeline.style is PipelineStyle.AsyncSchedule # type: ignore
assert self.group
rank = self.group.rank()
......
......@@ -23,7 +23,6 @@ from torch.optim.optimizer import Optimizer
from torch.utils.data import DataLoader, Dataset
from experimental.nn.ampnet_pipe.pipe import AMPnetPipe
from fairscale.nn.pipe import MultiProcessPipe
from fairscale.utils.testing import get_worker_map, torch_spawn
......@@ -84,14 +83,7 @@ class FakeDataset(Dataset):
@torch_spawn([2])
def async_event_loop_interleave_simple():
model = nn.Sequential(nn.Linear(10, 10), nn.ReLU(inplace=False), nn.Linear(10, 10), nn.ReLU(inplace=False))
pipe = AMPnetPipe(
module=model,
balance=[2, 2],
style=MultiProcessPipe.AsyncSchedule,
worker_map=get_worker_map(),
chunks=10,
checkpoint="never",
)
pipe = AMPnetPipe(module=model, balance=[2, 2], worker_map=get_worker_map(), chunks=10, checkpoint="never",)
fake_dataset = FakeDataset()
fake_dataloader = DataLoader(fake_dataset, batch_size=4, shuffle=True, num_workers=0)
loss = nn.MSELoss()
......@@ -102,14 +94,7 @@ def async_event_loop_interleave_simple():
@torch_spawn([4])
def async_event_loop_interleave_hard():
model = nn.Sequential(nn.Linear(10, 10), nn.Linear(10, 10), nn.Linear(10, 10), nn.Linear(10, 10))
pipe = AMPnetPipe(
module=model,
balance=[1, 1, 1, 1],
style=MultiProcessPipe.AsyncSchedule,
worker_map=get_worker_map(),
chunks=10,
checkpoint="never",
)
pipe = AMPnetPipe(module=model, balance=[1, 1, 1, 1], worker_map=get_worker_map(), chunks=10, checkpoint="never",)
fake_dataset = FakeDataset()
fake_dataloader = DataLoader(fake_dataset, batch_size=4, shuffle=True, num_workers=0)
loss = nn.MSELoss()
......
......@@ -18,6 +18,7 @@
# limitations under the License.
"""A Pipe implementation in PyTorch."""
from .async_pipe import AsyncPipe
from .checkpoint import is_checkpointing, is_recomputing
from .multiprocess_pipe import LazyModule, MultiProcessPipe
from .pipe import Pipe
......
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from .multiprocess_pipe import MultiProcessPipe
from .types import PipelineStyle
class AsyncPipe(MultiProcessPipe):
def __init__(self, *args, **kwargs) -> None: # type: ignore
super().__init__(*args, style=PipelineStyle.AsyncSchedule, **kwargs)
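The new class only pins the pipeline style; every other positional and keyword argument is forwarded to MultiProcessPipe unchanged. A rough sketch of the equivalence, using the names in this file (wm stands for any worker map, and the other arguments are illustrative):

    # these two calls build the same pipeline
    AsyncPipe(model, [2, 1], worker_map=wm, chunks=4)
    MultiProcessPipe(model, [2, 1], style=PipelineStyle.AsyncSchedule, worker_map=wm, chunks=4)

Because the style keyword is injected ahead of **kwargs, also passing style to AsyncPipe would raise a TypeError for a duplicated keyword, so call sites simply drop the argument.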
......@@ -386,9 +386,6 @@ class MultiProcessPipe(Module):
"""
MultiProcess: PipelineStyle = PipelineStyle.MultiProcess
AsyncSchedule: PipelineStyle = PipelineStyle.AsyncSchedule
#: The number of layers in each partition.
balance: List[int] = []
# ^^
......
......@@ -13,6 +13,7 @@ from torch.distributed.distributed_c10d import _get_global_rank
from fairscale.nn.model_parallel.initialize import get_pipeline_parallel_group
from .async_pipe import AsyncPipe
from .multiprocess_pipe import MultiProcessPipe
from .types import EVENT_LOOP_QUEUE, PipeMessage, TensorOrTensors
......@@ -105,10 +106,9 @@ class PipeRPCWrapper(nn.Module):
else:
kwargs["group"] = self.group
kwargs["style"] = MultiProcessPipe.AsyncSchedule
kwargs["input_device"] = torch.device("cuda", torch.cuda.current_device())
self.model = MultiProcessPipe(*args, **kwargs)
self.model = AsyncPipe(*args, **kwargs)
self.worker_map = kwargs["worker_map"]
self._foreach_worker(self._register_remote_model, args=(args, kwargs))
self.model.cuda()
......
......@@ -35,7 +35,6 @@ class LazyModule:
class PipelineStyle(Enum):
SingleProcess = auto()
MultiProcess = auto()
AsyncSchedule = auto()
......
......@@ -431,7 +431,6 @@ def run_test_pipe(rank, world_size, filename, filename_rpc, skip_dist_init=False
model[2].weight.data = saved_weight_2
worker_map = {i: f"Test{i}" for i in range(torch.distributed.get_world_size())}
style = MultiProcessPipe.MultiProcess # MultiProcessPipe.AsyncSchedule
if pipe_world_size == 2:
print(f"actually doing pipe stuff now")
......@@ -440,7 +439,6 @@ def run_test_pipe(rank, world_size, filename, filename_rpc, skip_dist_init=False
pipe_model = MultiProcessPipe(
model,
[2, 1],
style=style,
group=pipeline_devices,
worker_map=worker_map,
input_device=torch.cuda.current_device(),
......@@ -507,7 +505,6 @@ def run_test_pipe(rank, world_size, filename, filename_rpc, skip_dist_init=False
failed = False
with torch.autograd.profiler.profile() as prof:
try:
if style == MultiProcessPipe.MultiProcess:
pipe_model.back_helper(pipe_output)
except Exception as e:
failed = True
......
......@@ -23,7 +23,7 @@ import pytest
import torch
from torch import nn
from fairscale.nn.pipe import LazyModule, MultiProcessPipe
from fairscale.nn.pipe import AsyncPipe, LazyModule, MultiProcessPipe
from fairscale.nn.pipe.skip import pop, skippable, stash
from fairscale.nn.pipe.skip.portal import PortalBlue, PortalCopy, PortalOrange
from fairscale.utils.testing import get_worker_map, torch_spawn
......@@ -33,14 +33,14 @@ from fairscale.utils.testing import get_worker_map, torch_spawn
@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
@pytest.mark.parametrize("balance", [[3], [1, 2], [2, 1], [1, 1, 1]], ids=["3", "1:2", "2:1", "1:1:1"])
@pytest.mark.parametrize("checkpoint", ["never", "always", "except_last"])
@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
@pytest.mark.skipif("OMPI_COMM_WORLD_RANK" in os.environ, reason="broken on mpi")
def x1to3(balance, checkpoint, pipeline_style):
def x1to3(balance, checkpoint, pipe_class):
torch.manual_seed(0)
if pipeline_style == MultiProcessPipe.AsyncSchedule and len(balance) > 1:
if pipe_class == AsyncPipe and len(balance) > 1:
print(f"skipping yarg")
pytest.skip("Skip tensors NYI for AsyncSchedule")
pytest.skip("Skip tensors NYI for AsyncPipe")
@skippable(stash=["1to3"])
class Layer1(nn.Module):
......@@ -74,13 +74,12 @@ def x1to3(balance, checkpoint, pipeline_style):
return output
model = nn.Sequential(Layer1(), Layer2(), Layer3())
model = MultiProcessPipe(
model = pipe_class(
model,
balance,
chunks=3,
checkpoint=checkpoint,
input_device=torch.cuda.current_device(),
style=pipeline_style,
worker_map=get_worker_map(),
pipelined_backward=False,
).cuda()
......@@ -106,11 +105,11 @@ def x1to3(balance, checkpoint, pipeline_style):
@torch_spawn([2])
@pytest.mark.skipif("OMPI_COMM_WORLD_RANK" in os.environ, reason="broken on mpi")
@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
@pytest.mark.skip(reason="flaky test")
def none_skip(pipeline_style):
if pipeline_style == MultiProcessPipe.AsyncSchedule:
pytest.skip("Skip tensors NYI for AsyncSchedule")
def none_skip(pipe_class):
if pipe_class == AsyncPipe:
pytest.skip("Skip tensors NYI for AsyncPipe")
@skippable(stash=["none"])
class Stash(nn.Module):
......@@ -126,13 +125,8 @@ def none_skip(pipeline_style):
return input
model = nn.Sequential(Stash(), Pop())
model = MultiProcessPipe(
model,
[1, 1],
style=pipeline_style,
worker_map=get_worker_map(),
input_device=torch.cuda.current_device(),
chunks=5,
model = pipe_class(
model, [1, 1], worker_map=get_worker_map(), input_device=torch.cuda.current_device(), chunks=5,
).cuda()
input = torch.rand(10, requires_grad=True).cuda()
......@@ -161,8 +155,8 @@ def none_skip(pipeline_style):
@torch_spawn([2])
@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
def lazy_skippable_error(pipeline_style):
@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
def lazy_skippable_error(pipe_class):
"""Using skippable layers in combination with lazy construction is currently
not supported, check that it raises an Exception"""
......@@ -181,6 +175,6 @@ def lazy_skippable_error(pipeline_style):
]
with pytest.raises(ValueError, match="Can't use Skippable layers with multi-process pipe and lazy construction"):
MultiProcessPipe(
model, [2, 1], style=pipeline_style, worker_map=get_worker_map(),
pipe_class(
model, [2, 1], worker_map=get_worker_map(),
)
......@@ -23,7 +23,7 @@ import pytest
import torch
from torch import nn
from fairscale.nn.pipe import MultiProcessPipe, is_checkpointing, is_recomputing
from fairscale.nn.pipe import AsyncPipe, MultiProcessPipe, is_checkpointing, is_recomputing
from fairscale.nn.pipe.skip import pop, skippable, stash
from fairscale.nn.pipe.skip.tracker import current_skip_tracker
from fairscale.utils.testing import get_worker_map, torch_spawn
......@@ -46,10 +46,10 @@ class Pop(nn.Module):
@torch_spawn([2])
@pytest.mark.parametrize("train", [True, False], ids=["train", "eval"])
@pytest.mark.parametrize("checkpoint", ["always", "except_last", "never"])
@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
@pytest.mark.skipif("OMPI_COMM_WORLD_RANK" in os.environ, reason="broken on mpi")
@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
def delete_portal_tensor(train, checkpoint, pipeline_style):
def delete_portal_tensor(train, checkpoint, pipe_class):
# Without checkpointing:
# +- Stash --+ +--- Pop ----+ - - - layers
# | 2,blue,1 |--| 1,orange,0 | - - - tensor_life and portal function
......@@ -60,8 +60,8 @@ def delete_portal_tensor(train, checkpoint, pipeline_style):
# | 3,blue,2 |--| 2,orange,1 |--| 1,orange,0 |--| 1,blue,0 |
# +----------+ +------------+ +------------+ +----------+
if pipeline_style == MultiProcessPipe.AsyncSchedule:
pytest.skip("Skip tensors NYI for AsyncSchedule")
if pipe_class == AsyncPipe:
pytest.skip("Skip tensors NYI for AsyncPipe")
def portal_tensor_life_is(tensor_life, skip_tracker=None):
if skip_tracker is None:
......@@ -114,9 +114,7 @@ def delete_portal_tensor(train, checkpoint, pipeline_style):
return self.F.apply(input)
model = nn.Sequential(NoPortalTensorAtBackward(), stash_, pop_)
model = MultiProcessPipe(
model, balance=[2, 1], style=pipeline_style, worker_map=get_worker_map(), chunks=2, checkpoint=checkpoint,
)
model = pipe_class(model, balance=[2, 1], worker_map=get_worker_map(), chunks=2, checkpoint=checkpoint,)
input = torch.rand(10, requires_grad=True)
......
......@@ -22,15 +22,15 @@ import torch
from torch import nn
import torch.nn.functional as F
from fairscale.nn.pipe import MultiProcessPipe
from fairscale.nn.pipe import AsyncPipe, MultiProcessPipe
from fairscale.utils.testing import get_worker_map, torch_spawn
@torch_spawn([2])
@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
def python_autograd_function(pipeline_style):
# FIXME deadlock with MultiProcessPipe.AsyncSchedule?
@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
def python_autograd_function(pipe_class):
# FIXME deadlock with AsyncPipe?
# A Python autograd function might fail with this error:
#
# RuntimeError: Returning Variables sharing storage with other Variables
......@@ -57,9 +57,7 @@ def python_autograd_function(pipeline_style):
return Identity.apply(input)
model = nn.Sequential(M(), M())
model = MultiProcessPipe(
model, [1, 1], style=pipeline_style, worker_map=get_worker_map(), checkpoint="always"
).cuda()
model = pipe_class(model, [1, 1], worker_map=get_worker_map(), checkpoint="always").cuda()
model.eval()
x = torch.rand(42)
......@@ -73,8 +71,8 @@ def python_autograd_function(pipeline_style):
@torch_spawn([3])
@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
def exception_no_hang(pipeline_style):
@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
def exception_no_hang(pipe_class):
# In v0.0.2, once a failed partition receives a normal message
# (non-closing) for the next micro-batch, a hang occurred. The reason was
# that a failed partition didn't call in_queue.task_done() on a normal
......@@ -92,7 +90,7 @@ def exception_no_hang(pipeline_style):
raise ExpectedException()
model = nn.Sequential(Pass(), Pass(), Raise())
model = MultiProcessPipe(model, [1, 1, 1], style=pipeline_style, worker_map=get_worker_map(), chunks=3)
model = pipe_class(model, [1, 1, 1], worker_map=get_worker_map(), chunks=3)
model.eval()
if model.group.rank() == 2:
......@@ -106,8 +104,8 @@ def exception_no_hang(pipeline_style):
@torch_spawn([2])
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="2 cuda devices required")
@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
def tuple_wait(cuda_sleep, pipeline_style):
@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
def tuple_wait(cuda_sleep, pipe_class):
# In v0.0.3, Wait is applied to only the first tensor on a micro-batch.
# Under this behavior, if checkpointing was disabled, there's a possibility
# that gradient accumulations on other tensors are not synchronized
......@@ -135,10 +133,9 @@ def tuple_wait(cuda_sleep, pipeline_style):
return a + b + c
model = nn.Sequential(Layer1(), Layer2())
model = MultiProcessPipe(
model = pipe_class(
model,
[1, 1],
style=pipeline_style,
worker_map=get_worker_map(),
input_device=torch.cuda.current_device(),
chunks=32,
......@@ -160,8 +157,8 @@ def tuple_wait(cuda_sleep, pipeline_style):
@torch_spawn([2])
@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
def parallel_randoms(pipeline_style):
@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
def parallel_randoms(pipe_class):
class Dropouts(nn.Module):
def forward(self, x):
for _ in range(100):
......@@ -172,10 +169,9 @@ def parallel_randoms(pipeline_style):
x = torch.rand(10, 10, requires_grad=True).cuda()
x.retain_grad()
model = MultiProcessPipe(
model = pipe_class(
model,
[1, 1],
style=pipeline_style,
input_device=torch.cuda.current_device(),
worker_map=get_worker_map(),
chunks=10,
......
......@@ -21,21 +21,21 @@ import pytest
import torch
from torch import nn
from fairscale.nn.pipe import MultiProcessPipe
from fairscale.nn.pipe import AsyncPipe, MultiProcessPipe
from fairscale.utils.testing import get_worker_map, torch_spawn
@torch_spawn([2])
@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
def inplace_on_requires_grad(pipeline_style):
@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
def inplace_on_requires_grad(pipe_class):
model = nn.Sequential(nn.Linear(1, 1), nn.ReLU(inplace=True))
model = MultiProcessPipe(model, [1, 1], style=pipeline_style, worker_map=get_worker_map(), checkpoint="always")
model = pipe_class(model, [1, 1], worker_map=get_worker_map(), checkpoint="always")
x = torch.rand(1)
if pipeline_style == MultiProcessPipe.AsyncSchedule and model.group.rank() == 0:
# With AsyncSchedule, model will wait forever for gradients if not eval
if pipe_class == AsyncPipe and model.group.rank() == 0:
# With AsyncPipe, model will wait forever for gradients if not eval
model.eval()
y = model(x)
......@@ -50,12 +50,12 @@ def inplace_on_requires_grad(pipeline_style):
@torch_spawn([1])
@pytest.mark.xfail(strict=True)
@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
def inplace_on_not_requires_grad(pipeline_style):
@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
def inplace_on_not_requires_grad(pipe_class):
# In-place operation on a tensor not requiring grad doesn't cause a
# RuntimeError. Currently, we cannot detect this case.
model = nn.Sequential(nn.ReLU(inplace=True))
model = MultiProcessPipe(model, [1], style=pipeline_style, worker_map=get_worker_map(), checkpoint="always")
model = pipe_class(model, [1], worker_map=get_worker_map(), checkpoint="always")
x = torch.rand(1)
y = model(x)
......@@ -70,8 +70,8 @@ def inplace_on_not_requires_grad(pipeline_style):
@torch_spawn([1])
@pytest.mark.xfail(strict=True)
@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
def inplace_incorrect_grad(pipeline_style):
@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
def inplace_incorrect_grad(pipe_class):
class M(nn.Module):
def forward(self, foo_bar):
# 'foo' requires grad but 'bar' does not. In-place operation on
......@@ -88,7 +88,7 @@ def inplace_incorrect_grad(pipeline_style):
return foo * bar
model = nn.Sequential(M())
model = MultiProcessPipe(model, [1], style=pipeline_style, worker_map=get_worker_map(), checkpoint="always")
model = pipe_class(model, [1], worker_map=get_worker_map(), checkpoint="always")
foo = torch.tensor([1.0], requires_grad=True)
bar = torch.tensor([1.0])
......
......@@ -21,14 +21,14 @@ import pytest
import torch
from torch import nn
from fairscale.nn.pipe import MultiProcessPipe
from fairscale.nn.pipe import AsyncPipe, MultiProcessPipe
from fairscale.utils.testing import get_worker_map, set_random_seed, torch_spawn
@torch_spawn([2])
@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
def simple_linears(pipeline_style):
@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
def simple_linears(pipe_class):
def sum_grad(parameters):
return sum([p.grad.sum() for p in parameters if p.grad is not None])
......@@ -54,8 +54,7 @@ def simple_linears(pipeline_style):
zero_grad(model.parameters())
# With MultiProcessPipe
model = MultiProcessPipe(model, [2, 2], style=pipeline_style, worker_map=get_worker_map(), chunks=4)
model = pipe_class(model, [2, 2], worker_map=get_worker_map(), chunks=4)
outputs = model(inputs)
if model.group.rank() == 1:
......