Unverified commit cae9b638, authored by msbaines, committed by GitHub

[refactor] pipe: separate out Single and MultiProcess pipe (#326)

parent eab1551a
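For orientation before the diffs below: after this refactor, fairscale.nn.Pipe remains the single-process pipe, while the cross-process variant is constructed as MultiProcessPipe from fairscale.nn.pipe, which now also carries the MultiProcess and AsyncSchedule style constants. The sketch below illustrates the resulting call sites; it is illustrative only (the model and balance are made up, and MultiProcessPipe additionally needs the torch.distributed / RPC setup shown in the tutorial diff further down).

    import torch.nn as nn

    from fairscale.nn import Pipe                   # single-process pipe, unchanged entry point
    from fairscale.nn.pipe import MultiProcessPipe  # multi-process pipe, new name

    # Toy model and balance, purely for illustration.
    model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 4))

    # Single-process pipelining: partitions live on local devices.
    single_pipe = Pipe(model, balance=[2, 1], chunks=4, checkpoint="except_last")

    # Multi-process pipelining: one stage per rank; the style constants now live
    # on MultiProcessPipe rather than on Pipe, and a worker_map maps ranks to RPC names.
    multi_pipe = MultiProcessPipe(
        model,
        balance=[2, 1],
        style=MultiProcessPipe.AsyncSchedule,
        worker_map={0: "worker0", 1: "worker1"},  # rank -> RPC worker name
        chunks=4,
    )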
@@ -19,10 +19,9 @@ import torchtext
 from torchtext.data.utils import get_tokenizer
 from experimental.nn.ampnet_pipe import pipe
-from fairscale.nn import Pipe
 from fairscale.nn.model_parallel import initialize_model_parallel
 from fairscale.nn.model_parallel.initialize import get_pipeline_parallel_group
-from fairscale.nn.pipe import LazyModule
+from fairscale.nn.pipe import LazyModule, MultiProcessPipe
 from fairscale.optim import GradScaler
 from fairscale.utils.testing import dist_init, get_worker_map
@@ -421,7 +420,7 @@ def run_mp_worker(args, available_workers):
     p = pipe.AMPnetPipe(
         module=model,
         balance=balance,
-        style=Pipe.AsyncSchedule,
+        style=MultiProcessPipe.AsyncSchedule,
         chunks=args.chunks,
         worker_map=get_worker_map(),
         input_device=torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"),
...
@@ -25,7 +25,7 @@ from torch.optim import Adam
 from fairscale.nn import Pipe
 from fairscale.nn.model_parallel import initialize_model_parallel
 from fairscale.nn.model_parallel.initialize import get_data_parallel_group, get_pipeline_parallel_group
-from fairscale.nn.pipe import LazyModule, pipe
+from fairscale.nn.pipe import LazyModule, MultiProcessPipe
 from fairscale.optim.oss import OSS
 from fairscale.utils.testing import dist_init, get_worker_map
@@ -157,7 +157,7 @@ def dump_cuda_tensors():
 def log_number_of_parameters(model):
     num_params = reduce(operator.add, (reduce(operator.mul, x.size()) for x in model.parameters()))
-    if model.group:
+    if hasattr(model, "group"):
         total = torch.Tensor([num_params])
         if torch.cuda.is_available():
             total = total.cuda()
@@ -212,7 +212,7 @@ def train(model_config, model, benchmark_config, args):
     optimizer = optimizer(model.parameters())
-    pipe_group = model.group
+    pipe_group = model.group if hasattr(model, "group") else None
     if args.ddp_zero:
         model = DDP(
@@ -479,9 +479,7 @@ def benchmark_single_process(args):
     model = model_config["model"]
     balance = generate_balance(min(num_devices, 4), len(model))
-    pipe_model = pipe.Pipe(
-        model, balance, chunks=args.chunks, pipelined_backward=args.pipelined_backward, checkpoint=args.checkpoint
-    )
+    pipe_model = Pipe(model, balance, chunks=args.chunks, checkpoint=args.checkpoint)
     del model
     del model_config["model"]
@@ -498,10 +496,10 @@ def run_mp_worker(args, available_workers):
     model = model_config["model"]
     balance = generate_balance_weighted(get_pipeline_parallel_group().size(), len(model), 0.8)
-    pipe_model = pipe.Pipe(
+    pipe_model = MultiProcessPipe(
         model,
         balance,
-        style=Pipe.AsyncSchedule,
+        style=MultiProcessPipe.AsyncSchedule,
         chunks=args.chunks,
         worker_map=get_worker_map(),
         input_device=torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"),
...
@@ -6,8 +6,8 @@ import torch.distributed as dist
 import torch.multiprocessing as mp
 import torch.optim as optim
-import fairscale
 from fairscale.nn.model_parallel import initialize_model_parallel
+from fairscale.nn.pipe import MultiProcessPipe
 DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 RANK = 0  # example
@@ -27,10 +27,10 @@ def run(rank, world_size):
     device = torch.device("cuda", RANK) if DEVICE == "cuda" else torch.device("cpu")
-    model = fairscale.nn.Pipe(
+    model = MultiProcessPipe(
         model,
         balance=[2, 1],
-        style=fairscale.nn.Pipe.MultiProcess,
+        style=MultiProcessPipe.MultiProcess,
         worker_map={0: "worker0", 1: "worker1"},  # Needed to convert ranks to RPC worker names
         input_device=device,
     ).to(device)
...
@@ -11,7 +11,7 @@ from torch import nn
 from torch.optim.optimizer import Optimizer
 from torch.utils.data import DataLoader
-from fairscale.nn.pipe import Pipe
+from fairscale.nn.pipe import MultiProcessPipe
 from fairscale.nn.pipe.types import PipelineStyle
 from .ampnet import AsyncAMPnetEventLoop
@@ -19,9 +19,9 @@ from .ampnet import AsyncAMPnetEventLoop
 __all__ = ["AMPnetPipe"]
-class AMPnetPipe(Pipe):
+class AMPnetPipe(MultiProcessPipe):
     """
-    AMPnetPipe is the asynchronous version of the Pipe implementation
+    AMPnetPipe is the asynchronous version of the MultiProcessPipe implementation
     which avoids the bubble issue, by using stale weights and gradients.
     The implementation closely follows the paper: https://arxiv.org/abs/1705.09786
     """
@@ -39,7 +39,7 @@ class AMPnetPipe(Pipe):
         weight_prediction: bool = False,
     ) -> None:
-        partitions = self.mp_partitions
+        partitions = self.partitions
         n = len(partitions)
         # AMPnet implementation doesn't handle skip_trackers!
...
@@ -23,7 +23,7 @@ from torch.optim.optimizer import Optimizer
 from torch.utils.data import DataLoader, Dataset
 from experimental.nn.ampnet_pipe.pipe import AMPnetPipe
-from fairscale.nn.pipe import Pipe
+from fairscale.nn.pipe import MultiProcessPipe
 from fairscale.utils.testing import get_worker_map, torch_spawn
@@ -87,7 +87,7 @@ def async_event_loop_interleave_simple():
     pipe = AMPnetPipe(
         module=model,
         balance=[2, 2],
-        style=Pipe.AsyncSchedule,
+        style=MultiProcessPipe.AsyncSchedule,
         worker_map=get_worker_map(),
         chunks=10,
         checkpoint="never",
@@ -105,7 +105,7 @@ def async_event_loop_interleave_hard():
     pipe = AMPnetPipe(
         module=model,
         balance=[1, 1, 1, 1],
-        style=Pipe.AsyncSchedule,
+        style=MultiProcessPipe.AsyncSchedule,
         worker_map=get_worker_map(),
         chunks=10,
         checkpoint="never",
...
@@ -6,7 +6,7 @@
 from .data_parallel import ShardedDataParallel
 from .misc import FlattenParamsWrapper
 from .moe import MOELayer, Top2Gate
-from .pipe import LazyModule, Pipe, PipeRPCWrapper
+from .pipe import Pipe, PipeRPCWrapper
 __all__ = [
     "FlattenParamsWrapper",
...
@@ -19,7 +19,8 @@
 """A Pipe implementation in PyTorch."""
 from .checkpoint import is_checkpointing, is_recomputing
-from .pipe import LazyModule, Pipe
+from .multiprocess_pipe import LazyModule, MultiProcessPipe
+from .pipe import Pipe
 from .rpc import PipeRPCWrapper
 __all__ = ["Pipe", "is_checkpointing", "is_recomputing", "LazyModule"]
@@ -191,7 +191,7 @@ class AsyncEventLoop:
         """Actually run the forward pass for a given module, and send the result
         to the next stage in the pipeline if needed."""
         assert self.group
-        from .pipeline import create_task
+        from .multiprocess_pipeline import create_task
         task = create_task(
             PipelineStyle.AsyncSchedule,
@@ -201,7 +201,6 @@
             batch,
             partition.module,
             skip_trackers,
-            [],
         )
         result = task.compute()
         task.finalize(result)
...
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
# Copyright 2019 Kakao Brain
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The multiprocess pipeline parallelism of Pipe."""
import logging
import os
from queue import Empty as QueueEmpty
from queue import Queue
from threading import Event
from types import TracebackType
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Type, Union, cast
import torch
from torch import Tensor, nn
from torch.autograd.profiler import record_function
from fairscale.nn.model_parallel import get_pipeline_parallel_ranks
from .async_schedule import AsyncEventLoop, ModuleWrapper
from .checkpoint import Checkpointing
from .messages import MakeTransport, Transport
from .microbatch import Batch
from .skip import Namespace
from .skip.layout import SkipLayout
from .skip.tracker import SkipTrackerThroughPotals, use_skip_tracker
from .types import (
ACTIVATIONS_GRADS_QUEUE,
PORTAL_QUEUE,
SKIP_TENSOR_QUEUE,
PipelineStyle,
PipeMessage,
TensorOrTensors,
Tensors,
)
from .worker import Task
__all__: List[str] = []
ExcInfo = Tuple[Type[BaseException], BaseException, TracebackType]
class SendOperator(torch.autograd.Function):
"""Send activations to the next pipeline stage"""
@staticmethod
# type: ignore
def forward(ctx, src_rank, dst_rank, transport: Transport, input: List[Tensor], index: int) -> Tensors:
assert src_rank == torch.distributed.get_rank()
transport.send_message(
PipeMessage(src_rank, dst_rank, queue_name=ACTIVATIONS_GRADS_QUEUE, args=index, tensors=tuple(input)),
)
return ()
@staticmethod
# type: ignore
def backward(ctx, *grad: Tensor,) -> Tensors:
return tuple(grad)
class RecvOperator(torch.autograd.Function):
"""Receive activations to the previous pipeline stage"""
@staticmethod
# type: ignore
def forward(ctx, dst_rank: int, tensor: Tensor, input_device, transport: Transport, index: int) -> Tensors:
assert dst_rank == torch.distributed.get_rank()
ctx.transport = transport
ctx.index = index
result = transport.get_out_of_order(ACTIVATIONS_GRADS_QUEUE, index)
def maybe_requires_grad(t: Tensor) -> Tensor:
if t.dtype.is_floating_point:
return t.requires_grad_()
return t
return tuple(maybe_requires_grad(r) for r in result)
@staticmethod
# type: ignore
def backward(ctx, *grad: Tensor,) -> Tuple[Optional[Tensor], ...]:
ranks = get_pipeline_parallel_ranks()
this_rank = torch.distributed.get_rank()
ctx.transport.send_message(
PipeMessage(
this_rank,
ranks[ranks.index(this_rank) - 1],
queue_name=ACTIVATIONS_GRADS_QUEUE,
args=ctx.index,
tensors=tuple(grad),
),
)
return (None, None, None, None, None)
# Queue is generic only in stubs.
# https://mypy.readthedocs.io/en/latest/common_issues.html#using-classes-that-are-generic-in-stubs-but-not-at-runtime
if TYPE_CHECKING:
InQueue = Queue[Optional["Task"]]
OutQueue = Queue[Tuple[bool, Union[Tuple["Task", Batch], ExcInfo, None]]]
else:
InQueue = Queue
OutQueue = Queue
def create_task(
style: PipelineStyle,
checkpoint_stop: int,
i: int,
j: int,
batch: Batch,
partition: nn.Sequential,
skip_trackers: List[SkipTrackerThroughPotals],
) -> Task:
# Determine whether checkpointing or not.
if i < checkpoint_stop:
def function(
input: TensorOrTensors,
partition: nn.Sequential = partition,
skip_tracker: SkipTrackerThroughPotals = skip_trackers[i],
chunk_id: int = i,
part_id: int = j,
) -> TensorOrTensors:
with use_skip_tracker(skip_tracker), record_function("chunk%d-part%d" % (chunk_id, part_id)):
ret = partition(input)
# We do a check here because the backtrace from the checkpoint backward code path
# is very hard to make sense of; it is much easier to check at this earlier point.
assert type(ret) is not list, "Only Tensor or Tuple of Tensor output is supported"
return ret
chk = Checkpointing(function, batch)
task = Task(None, compute=chk.checkpoint, finalize=chk.recompute)
del function, chk # TODO(tom) maybe remove
else:
def compute(
batch: Batch = batch,
partition: nn.Sequential = partition,
skip_tracker: SkipTrackerThroughPotals = skip_trackers[i],
chunk_id: int = i,
part_id: int = j,
) -> Batch:
with use_skip_tracker(skip_tracker), record_function("chunk%d-part%d" % (chunk_id, part_id)):
return batch.call(partition)
task = Task(None, compute=compute, finalize=None)
del compute # TODO(tom) maybe remove
return task
class MultiProcessPipeline:
"""The multiprocess pipeline parallelism for Pipe."""
def __init__(
self,
partitions: List[nn.Sequential],
skip_layout: SkipLayout,
checkpoint_stop: int,
style: PipelineStyle,
group: Optional[torch.distributed.ProcessGroup] = None,
worker_map: Optional[Dict[int, str]] = None,
input_device: Union[None, int, str, torch.device] = None,
final_stage: bool = False,
) -> None:
self.partitions: List[ModuleWrapper] = cast(List[ModuleWrapper], partitions)
self.skip_layout = skip_layout
self.__checkpoint_stop = checkpoint_stop
self.style = style
self.group = group
self.training: bool
self.transport = MakeTransport(
use_rpc=("OMPI_COMM_WORLD_RANK" not in os.environ) or ("FORCE_RPC" in os.environ),
worker_map=worker_map,
input_device=input_device,
)
self.input_device = input_device
self.all_at_once = False
self.callcount = 0
self.final_stage = final_stage
@property
def checkpoint_stop(self) -> int:
# Disable checkpointing if in eval mode.
training = self.partitions[0].module.training
if not training:
return 0
return self.__checkpoint_stop
def run(self, training: bool, batches: List[Batch], event: Optional[Event]) -> None:
"""Runs pipeline parallelism.
It modifies the given batches in place.
"""
self.training = training
m = len(batches)
skip_trackers = [SkipTrackerThroughPotals(self.skip_layout, i) for i in range(len(batches))]
if self.style is PipelineStyle.MultiProcess:
assert self.group
schedule = [(i, self.group.rank()) for i in range(m)]
self.compute(batches, schedule, skip_trackers)
elif self.style is PipelineStyle.AsyncSchedule:
assert self.group
rank = self.group.rank()
event_loop = AsyncEventLoop(
self.partitions, self.group, self.transport, self.training, self.checkpoint_stop,
)
if rank == 0 and not self.final_stage:
logging.debug(f"{torch.distributed.get_rank()}: entered event head")
event_loop.event_loop_head(batches, skip_trackers, event)
logging.debug(f"{torch.distributed.get_rank()}: exited event head")
elif self.final_stage:
logging.debug(f"{torch.distributed.get_rank()}: entered event tail")
event_loop.event_loop_tail(batches, skip_trackers)
logging.debug(f"{torch.distributed.get_rank()}: exited event tail")
else:
logging.debug(f"{torch.distributed.get_rank()}: entered event loop")
event_loop.event_loop(len(batches), skip_trackers)
logging.debug(f"{torch.distributed.get_rank()}: exited event loop")
self.callcount += 1
def get_batch_from_previous_stage(
self, i: int, skip_trackers: List[SkipTrackerThroughPotals], batches: List[Batch]
) -> Batch:
phony = torch.empty(0, device=self.input_device, requires_grad=True)
result = RecvOperator.apply(torch.distributed.get_rank(), phony, self.input_device, self.transport, i)
if len(result) == 1:
batch = Batch(result[0], i)
else:
batch = Batch(result, i)
self.recv_skip_tensors(skip_trackers, batches)
return batch
def send_skip_tensors(
self, this_rank: int, ranks: List[int], batch: Batch, i: int, skip_trackers: List[SkipTrackerThroughPotals]
) -> None:
assert self.group
for next_j, ns, name in self.skip_layout.copy_policy_by_src(self.group.rank()):
life = skip_trackers[i].portals[(ns, name)].tensor_life
loaded = skip_trackers[i].load(batch, ns, name)
if loaded is not None:
tensors = tuple([loaded])
else:
tensors = tuple()
self.transport.send_message(
PipeMessage(
this_rank, ranks[next_j], queue_name=SKIP_TENSOR_QUEUE, args=(i, ns, name, life), tensors=tensors,
),
sync=True,
)
def recv_skip_tensors(self, skip_trackers: List[SkipTrackerThroughPotals], batches: List[Batch]) -> None:
while True:
try:
message = self.transport.recv_message(SKIP_TENSOR_QUEUE, nowait=True)
(si, ns, name, life) = message.args
value: Optional[TensorOrTensors] = message.tensors
assert isinstance(value, tuple)
if len(value) == 0:
value = None
else:
assert len(value) == 1
value = value[0]
skip_trackers[si].save(batches[si], ns, name, value)
old_life = skip_trackers[si].portals[(ns, name)].tensor_life
if life != 0:
skip_trackers[si].portals[(ns, name)].tensor_life = life
except QueueEmpty:
break
def execute_task(self, task: Task, i: int, skip_trackers: List[SkipTrackerThroughPotals]) -> Batch:
batch = task.compute()
assert self.group
rank = self.group.rank()
if self.style is PipelineStyle.MultiProcess and not self.final_stage:
ranks = get_pipeline_parallel_ranks()
this_rank = torch.distributed.get_rank()
self.send_skip_tensors(this_rank, ranks, batch, i, skip_trackers)
SendOperator.apply(this_rank, ranks[ranks.index(this_rank) + 1], self.transport, [*batch], i)
for portal in skip_trackers[i].portals.values():
portal.pipeline = self
task.finalize(batch)
return batch
def compute(
self, batches: List[Batch], schedule: List[Tuple[int, int]], skip_trackers: List[SkipTrackerThroughPotals]
) -> None:
"""Runs tasks with synchronization to copy streams."""
if self.style is PipelineStyle.MultiProcess:
assert self.group
n = self.group.size()
# With checkpointing, the autograd graph looks like this diagram:
# ┌─────┸──────┐
# │ Copy │
# └─────┰──────┘ (fence)
# ─ ─ ─ ╂ ─ ─ ─ ─ ─ ─ ─ ─ ─
# ┃ (compute)
# ┌─────┸──────┐
# │ Wait │ [1] Synchronize the current stream with the copy stream.
# └─────┰──────┘
# ┌─────┸──────┐
# │ Checkpoint │ [2] Compute a partition within checkpointing.
# └─────┰──────┘
# ┌─────┸──────┐
# │ Wait │ [3] Synchronize the copy stream with the current stream.
# └─────┰──────┘
# ┠ ─ ─ ─ ┐
# ┃ ┌─────┴─────┐
# ┃ │ Recompute │ [4] Schedule the recomputation at backpropagation.
# ┃ └─────┬─────┘
# ┠ ─ ─ ─ ┘
# ┃
# ─ ─ ─ ╂ ─ ─ ─ ─ ─ ─ ─ ─ ─
# ┌─────┸──────┐ (fence)
# │ Copy │
# └─────┰──────┘
for i, j in schedule:
batch = batches[i]
if self.style is PipelineStyle.MultiProcess:
assert len(self.partitions) == 1
partition = self.partitions[0]
assert self.group
if self.group.rank() != 0:
batch = self.get_batch_from_previous_stage(i, skip_trackers, batches)
task = create_task(self.style, self.checkpoint_stop, i, j, batch, partition.module, skip_trackers)
batches[i] = self.execute_task(task, i, skip_trackers)
def send_portal_grad(self, ns_name: Tuple[Namespace, str], index: int, grad: TensorOrTensors) -> None:
dest, src = self.skip_layout.by_ns_name.get(ns_name, (-1, -1))
if dest == src:
return
ranks = get_pipeline_parallel_ranks()
dst_rank = ranks[dest]
if dst_rank == torch.distributed.get_rank():
return
if isinstance(grad, Tensor):
grad = tuple([grad])
self.transport.send_message(
PipeMessage(ranks[src], dst_rank, queue_name=PORTAL_QUEUE, args=(ns_name, index), tensors=grad), sync=True,
)
def recv_portal_grad(self, expected_ns_name: Tuple[Namespace, str], expected_index: int) -> Tensor:
message = self.transport.recv_message(PORTAL_QUEUE)
(ns_name, index) = message.args
grad = message.tensors
assert len(grad) == 1
result = grad[0]
assert index == expected_index and ns_name == expected_ns_name
return result
def back_helper(self, output: List[Batch]) -> None:
if self.style == PipelineStyle.AsyncSchedule:
return
o = list(output)
tensors: Tensors
if self.all_at_once:
# FIXME(tom) allow specifying this branch when constructing Pipe(), add a test
grads = []
for i, batch in enumerate(o):
rank = torch.distributed.get_rank()
found = self.transport.get_out_of_order(ACTIVATIONS_GRADS_QUEUE, i)
assert len(found) == 1
grads.append(found[0])
tensors = tuple(x.tensor_or_tensors for x in o) # type: ignore
try:
torch.autograd.backward(tensors, grad_tensors=grads, retain_graph=True)
except Exception as e:
raise RuntimeError("Autograd failed") from e
else:
rank = torch.distributed.get_rank()
for batch in o:
found = self.transport.get_out_of_order(ACTIVATIONS_GRADS_QUEUE, batch.index)
if batch.atomic:
tensors = tuple([batch.tensor])
else:
tensors = batch.tensors
if len(found) != len(tensors):
raise RuntimeError("different number of tensors and gradients")
grads = []
final_tensors = []
for i, tensor in enumerate(tensors):
if tensor.requires_grad or getattr(tensor, "grad_fn", None) is not None:
grads.append(found[i])
final_tensors.append(tensor)
try:
torch.autograd.backward(final_tensors, grad_tensors=grads, retain_graph=True)
except Exception as e:
raise RuntimeError(f"Autograd failed on {torch.distributed.get_rank()}") from e
@@ -13,13 +13,13 @@ from torch.distributed.distributed_c10d import _get_global_rank
 from fairscale.nn.model_parallel.initialize import get_pipeline_parallel_group
-from . import Pipe
+from .multiprocess_pipe import MultiProcessPipe
 from .types import EVENT_LOOP_QUEUE, PipeMessage, TensorOrTensors
 DEFAULT_MAX_SOURCE_POSITIONS = 1024
 DEFAULT_MAX_TARGET_POSITIONS = 1024
-PipeModel: Pipe
+PipeModel: MultiProcessPipe
 PipeResult: TensorOrTensors
@@ -71,7 +71,7 @@ class PipeBackRedirect(torch.autograd.Function):
         return (None, None, None, None, None, None)
-def callback_with_model(callback: Callable[[Any, Pipe], None], ctx: Any) -> None:
+def callback_with_model(callback: Callable[[Any, MultiProcessPipe], None], ctx: Any) -> None:
     try:
         group = get_pipeline_parallel_group()  # FIXME(tom) handle dynamic group
         set_device_based_on_group(group)
@@ -105,10 +105,10 @@ class PipeRPCWrapper(nn.Module):
         else:
             kwargs["group"] = self.group
-        kwargs["style"] = Pipe.AsyncSchedule
+        kwargs["style"] = MultiProcessPipe.AsyncSchedule
         kwargs["input_device"] = torch.device("cuda", torch.cuda.current_device())
-        self.model = Pipe(*args, **kwargs)
+        self.model = MultiProcessPipe(*args, **kwargs)
         self.worker_map = kwargs["worker_map"]
         self._foreach_worker(self._register_remote_model, args=(args, kwargs))
         self.model.cuda()
@@ -121,7 +121,7 @@ class PipeRPCWrapper(nn.Module):
         futures = [f.wait() for f in futures]
     def foreach_worker(
-        self, callback: Callable[[Any, Pipe], None], ctx: Any = None, *, include_self: bool = False
+        self, callback: Callable[[Any, MultiProcessPipe], None], ctx: Any = None, *, include_self: bool = False
     ) -> None:
         """Call `callback` on each worker with the `ctx` and model local to that
         worker. e.g.
@@ -196,7 +196,9 @@ class PipeRPCWrapper(nn.Module):
         return self.model.final_stage
     @staticmethod
-    def _recv_result(model: Pipe, shapes: SizeOrSizes, dtypes: DtypeOrDtypes, message: PipeMessage) -> TensorOrTensors:
+    def _recv_result(
+        model: MultiProcessPipe, shapes: SizeOrSizes, dtypes: DtypeOrDtypes, message: PipeMessage
+    ) -> TensorOrTensors:
         group = get_pipeline_parallel_group()
         set_device_based_on_group(group)
@@ -243,7 +245,7 @@
         set_device_based_on_group(group)
         kwargs["group"] = group
         kwargs["input_device"] = torch.device("cuda", torch.cuda.current_device())
-        model = Pipe(*args, **kwargs)
+        model = MultiProcessPipe(*args, **kwargs)
         model.cuda()
         global PipeModel
         PipeModel = model
...
@@ -24,7 +24,6 @@ Tensors = Tuple[Tensor, ...]
 TensorOrTensors = Union[Tensor, Tensors]
 InputDevice = Union[None, int, str, torch.device]
-Schedule = List[Tuple[int, int]]
 class LazyModule:
...
@@ -5,7 +5,7 @@ from typing import Any, Callable, Optional, Tuple
 from torch import Tensor
 def spawn(
-    fn: Callable[[Any], Any],
+    fn: Callable[..., Any],
     args: Tuple[Optional[Any], ...] = (),
     nprocs: int = 1,
     join: bool = True,
...
@@ -31,7 +31,7 @@ from torch.nn.parameter import Parameter
 from fairscale.nn.model_parallel import initialize as mpu
 from fairscale.nn.model_parallel import layers
-from fairscale.nn.pipe import Pipe
+from fairscale.nn.pipe import MultiProcessPipe
 from fairscale.utils.testing import dist_init, get_world_sizes, set_random_seed, spawn_for_all_world_sizes, torch_spawn
@@ -319,7 +319,7 @@ def run_test_pipe(rank, world_size, filename, filename_rpc, skip_dist_init=False
     model_parallel_size = mpu.get_model_parallel_world_size()
     if torch.distributed.get_rank() == 0:
         print(
-            "> testing Sequential + Pipe with model parallel size: {}, pipe: {}".format(
+            "> testing Sequential + MultiProcessPipe with model parallel size: {}, pipe: {}".format(
                 model_parallel_size, pipe_world_size
             )
         )
@@ -431,13 +431,13 @@ def run_test_pipe(rank, world_size, filename, filename_rpc, skip_dist_init=False
         model[2].weight.data = saved_weight_2
     worker_map = {i: f"Test{i}" for i in range(torch.distributed.get_world_size())}
-    style = Pipe.MultiProcess  # Pipe.AsyncSchedule
+    style = MultiProcessPipe.MultiProcess  # MultiProcessPipe.AsyncSchedule
     if pipe_world_size == 2:
         print(f"actually doing pipe stuff now")
         assert torch.equal(saved_weight_0, model[0].weight.data)
         assert torch.equal(saved_weight_2, model[2].weight.data)
-        pipe_model = Pipe(
+        pipe_model = MultiProcessPipe(
             model,
             [2, 1],
             style=style,
@@ -507,7 +507,7 @@ def run_test_pipe(rank, world_size, filename, filename_rpc, skip_dist_init=False
         failed = False
         with torch.autograd.profiler.profile() as prof:
             try:
-                if style == Pipe.MultiProcess:
+                if style == MultiProcessPipe.MultiProcess:
                     pipe_model.back_helper(pipe_output)
             except Exception as e:
                 failed = True
...
@@ -23,7 +23,7 @@ import pytest
 import torch
 from torch import nn
-from fairscale.nn.pipe import LazyModule, Pipe
+from fairscale.nn.pipe import LazyModule, MultiProcessPipe
 from fairscale.nn.pipe.skip import pop, skippable, stash
 from fairscale.nn.pipe.skip.portal import PortalBlue, PortalCopy, PortalOrange
 from fairscale.utils.testing import get_worker_map, torch_spawn
@@ -33,12 +33,12 @@ from fairscale.utils.testing import get_worker_map, torch_spawn
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
 @pytest.mark.parametrize("balance", [[3], [1, 2], [2, 1], [1, 1, 1]], ids=["3", "1:2", "2:1", "1:1:1"])
 @pytest.mark.parametrize("checkpoint", ["never", "always", "except_last"])
-@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
+@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
 @pytest.mark.skipif("OMPI_COMM_WORLD_RANK" in os.environ, reason="broken on mpi")
 def x1to3(balance, checkpoint, pipeline_style):
     torch.manual_seed(0)
-    if pipeline_style == Pipe.AsyncSchedule and len(balance) > 1:
+    if pipeline_style == MultiProcessPipe.AsyncSchedule and len(balance) > 1:
         print(f"skipping yarg")
         pytest.skip("Skip tensors NYI for AsyncSchedule")
@@ -74,7 +74,7 @@ def x1to3(balance, checkpoint, pipeline_style):
             return output
     model = nn.Sequential(Layer1(), Layer2(), Layer3())
-    model = Pipe(
+    model = MultiProcessPipe(
         model,
         balance,
         chunks=3,
@@ -106,9 +106,10 @@ def x1to3(balance, checkpoint, pipeline_style):
 @torch_spawn([2])
 @pytest.mark.skipif("OMPI_COMM_WORLD_RANK" in os.environ, reason="broken on mpi")
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
-@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
+@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
+@pytest.mark.skip(reason="flaky test")
 def none_skip(pipeline_style):
-    if pipeline_style == Pipe.AsyncSchedule:
+    if pipeline_style == MultiProcessPipe.AsyncSchedule:
         pytest.skip("Skip tensors NYI for AsyncSchedule")
     @skippable(stash=["none"])
@@ -125,7 +126,7 @@ def none_skip(pipeline_style):
             return input
     model = nn.Sequential(Stash(), Pop())
-    model = Pipe(
+    model = MultiProcessPipe(
         model,
         [1, 1],
         style=pipeline_style,
@@ -160,7 +161,7 @@ def none_skip(pipeline_style):
 @torch_spawn([2])
-@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
+@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
 def lazy_skippable_error(pipeline_style):
     """Using skippable layers in combination with lazy construction is currently
     not supported, check that it raises an Exception"""
@@ -180,6 +181,6 @@ def lazy_skippable_error(pipeline_style):
     ]
     with pytest.raises(ValueError, match="Can't use Skippable layers with multi-process pipe and lazy construction"):
-        Pipe(
+        MultiProcessPipe(
             model, [2, 1], style=pipeline_style, worker_map=get_worker_map(),
         )
@@ -23,7 +23,7 @@ import pytest
 import torch
 from torch import nn
-from fairscale.nn.pipe import Pipe, is_checkpointing, is_recomputing
+from fairscale.nn.pipe import MultiProcessPipe, is_checkpointing, is_recomputing
 from fairscale.nn.pipe.skip import pop, skippable, stash
 from fairscale.nn.pipe.skip.tracker import current_skip_tracker
 from fairscale.utils.testing import get_worker_map, torch_spawn
@@ -46,7 +46,7 @@ class Pop(nn.Module):
 @torch_spawn([2])
 @pytest.mark.parametrize("train", [True, False], ids=["train", "eval"])
 @pytest.mark.parametrize("checkpoint", ["always", "except_last", "never"])
-@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
+@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
 @pytest.mark.skipif("OMPI_COMM_WORLD_RANK" in os.environ, reason="broken on mpi")
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
 def delete_portal_tensor(train, checkpoint, pipeline_style):
@@ -60,7 +60,7 @@ def delete_portal_tensor(train, checkpoint, pipeline_style):
     # | 3,blue,2 |--| 2,orange,1 |--| 1,orange,0 |--| 1,blue,0 |
     # +----------+ +------------+ +------------+ +----------+
-    if pipeline_style == Pipe.AsyncSchedule:
+    if pipeline_style == MultiProcessPipe.AsyncSchedule:
         pytest.skip("Skip tensors NYI for AsyncSchedule")
     def portal_tensor_life_is(tensor_life, skip_tracker=None):
@@ -114,7 +114,7 @@ def delete_portal_tensor(train, checkpoint, pipeline_style):
             return self.F.apply(input)
     model = nn.Sequential(NoPortalTensorAtBackward(), stash_, pop_)
-    model = Pipe(
+    model = MultiProcessPipe(
         model, balance=[2, 1], style=pipeline_style, worker_map=get_worker_map(), chunks=2, checkpoint=checkpoint,
     )
...
@@ -22,15 +22,15 @@ import torch
 from torch import nn
 import torch.nn.functional as F
-from fairscale.nn.pipe import Pipe
+from fairscale.nn.pipe import MultiProcessPipe
 from fairscale.utils.testing import get_worker_map, torch_spawn
 @torch_spawn([2])
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
-@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
+@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
 def python_autograd_function(pipeline_style):
-    # FIXME deadlock with Pipe.AsyncSchedule?
+    # FIXME deadlock with MultiProcessPipe.AsyncSchedule?
     # A Python autograd function might fail with this error:
     #
     #   RuntimeError: Returning Variables sharing storage with other Variables
@@ -57,7 +57,9 @@ def python_autograd_function(pipeline_style):
             return Identity.apply(input)
     model = nn.Sequential(M(), M())
-    model = Pipe(model, [1, 1], style=pipeline_style, worker_map=get_worker_map(), checkpoint="always").cuda()
+    model = MultiProcessPipe(
+        model, [1, 1], style=pipeline_style, worker_map=get_worker_map(), checkpoint="always"
+    ).cuda()
     model.eval()
     x = torch.rand(42)
@@ -71,7 +73,7 @@ def python_autograd_function(pipeline_style):
 @torch_spawn([3])
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
-@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
+@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
 def exception_no_hang(pipeline_style):
     # In v0.0.2, once a failed partition receives a normal message
     # (non-closing) for the next micro-batch, a hang occured. The reason was
@@ -90,7 +92,7 @@ def exception_no_hang(pipeline_style):
             raise ExpectedException()
     model = nn.Sequential(Pass(), Pass(), Raise())
-    model = Pipe(model, [1, 1, 1], style=pipeline_style, worker_map=get_worker_map(), chunks=3)
+    model = MultiProcessPipe(model, [1, 1, 1], style=pipeline_style, worker_map=get_worker_map(), chunks=3)
     model.eval()
     if model.group.rank() == 2:
@@ -104,7 +106,7 @@ def exception_no_hang(pipeline_style):
 @torch_spawn([2])
 @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="2 cuda devices required")
-@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
+@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
 def tuple_wait(cuda_sleep, pipeline_style):
     # In v0.0.3, Wait is applied to only the first tensor on a micro-batch.
     # Under this behavior, if checkpointing was disabled, there's a possibility
@@ -133,7 +135,7 @@ def tuple_wait(cuda_sleep, pipeline_style):
             return a + b + c
     model = nn.Sequential(Layer1(), Layer2())
-    model = Pipe(
+    model = MultiProcessPipe(
         model,
         [1, 1],
         style=pipeline_style,
@@ -158,7 +160,7 @@ def tuple_wait(cuda_sleep, pipeline_style):
 @torch_spawn([2])
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
-@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
+@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
 def parallel_randoms(pipeline_style):
     class Dropouts(nn.Module):
         def forward(self, x):
@@ -170,7 +172,7 @@ def parallel_randoms(pipeline_style):
     x = torch.rand(10, 10, requires_grad=True).cuda()
     x.retain_grad()
-    model = Pipe(
+    model = MultiProcessPipe(
         model,
         [1, 1],
         style=pipeline_style,
...
@@ -21,20 +21,20 @@ import pytest
 import torch
 from torch import nn
-from fairscale.nn.pipe import Pipe
+from fairscale.nn.pipe import MultiProcessPipe
 from fairscale.utils.testing import get_worker_map, torch_spawn
 @torch_spawn([2])
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
-@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
+@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
 def inplace_on_requires_grad(pipeline_style):
     model = nn.Sequential(nn.Linear(1, 1), nn.ReLU(inplace=True))
-    model = Pipe(model, [1, 1], style=pipeline_style, worker_map=get_worker_map(), checkpoint="always")
+    model = MultiProcessPipe(model, [1, 1], style=pipeline_style, worker_map=get_worker_map(), checkpoint="always")
     x = torch.rand(1)
-    if pipeline_style == Pipe.AsyncSchedule and model.group.rank() == 0:
+    if pipeline_style == MultiProcessPipe.AsyncSchedule and model.group.rank() == 0:
         # With AsyncSchedule, model will wait forever for gradients if not eval
         model.eval()
@@ -50,12 +50,12 @@ def inplace_on_requires_grad(pipeline_style):
 @torch_spawn([1])
 @pytest.mark.xfail(strict=True)
-@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
+@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
 def inplace_on_not_requires_grad(pipeline_style):
     # In-place operation on a tensor not requiring grad doesn't cause a
     # RuntimeError. Currently, we cannot detect this case.
     model = nn.Sequential(nn.ReLU(inplace=True))
-    model = Pipe(model, [1], style=pipeline_style, worker_map=get_worker_map(), checkpoint="always")
+    model = MultiProcessPipe(model, [1], style=pipeline_style, worker_map=get_worker_map(), checkpoint="always")
     x = torch.rand(1)
     y = model(x)
@@ -70,7 +70,7 @@ def inplace_on_not_requires_grad(pipeline_style):
 @torch_spawn([1])
 @pytest.mark.xfail(strict=True)
-@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
+@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
 def inplace_incorrect_grad(pipeline_style):
     class M(nn.Module):
         def forward(self, foo_bar):
@@ -88,7 +88,7 @@ def inplace_incorrect_grad(pipeline_style):
             return foo * bar
     model = nn.Sequential(M())
-    model = Pipe(model, [1], style=pipeline_style, worker_map=get_worker_map(), checkpoint="always")
+    model = MultiProcessPipe(model, [1], style=pipeline_style, worker_map=get_worker_map(), checkpoint="always")
     foo = torch.tensor([1.0], requires_grad=True)
     bar = torch.tensor([1.0])
...