Unverified commit cae9b638, authored by msbaines, committed by GitHub

[refactor] pipe: separate out Single and MultiProcess pipe (#326)

parent eab1551a
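In short: after this change, `fairscale.nn.Pipe` keeps only the single-process, multi-device pipeline, and the multi-process path moves to a new `MultiProcessPipe` class under `fairscale.nn.pipe`. A minimal sketch of the two entry points after this refactor (the model, balance, and worker_map values are illustrative, not taken from this commit):

import torch
import torch.nn as nn

from fairscale.nn import Pipe                    # single process, partitions placed on local devices
from fairscale.nn.pipe import MultiProcessPipe   # one pipeline stage per process

model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 8), nn.ReLU())

# Single-process pipe: Pipe itself moves each partition onto a device.
# Assumes at least two CUDA devices (or pass devices=[...] explicitly).
single = Pipe(model, balance=[2, 2], chunks=4, checkpoint="except_last")

# Multi-process pipe: every rank holds one stage; torch.distributed and RPC
# must already be initialized, and worker_map maps ranks to RPC worker names
# (the names below are placeholders).
multi = MultiProcessPipe(
    model,
    balance=[2, 2],
    style=MultiProcessPipe.MultiProcess,
    worker_map={0: "worker0", 1: "worker1"},
    input_device=torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"),
    chunks=4,
)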
@@ -19,10 +19,9 @@ import torchtext
 from torchtext.data.utils import get_tokenizer
 from experimental.nn.ampnet_pipe import pipe
-from fairscale.nn import Pipe
 from fairscale.nn.model_parallel import initialize_model_parallel
 from fairscale.nn.model_parallel.initialize import get_pipeline_parallel_group
-from fairscale.nn.pipe import LazyModule
+from fairscale.nn.pipe import LazyModule, MultiProcessPipe
 from fairscale.optim import GradScaler
 from fairscale.utils.testing import dist_init, get_worker_map
@@ -421,7 +420,7 @@ def run_mp_worker(args, available_workers):
 p = pipe.AMPnetPipe(
 module=model,
 balance=balance,
-style=Pipe.AsyncSchedule,
+style=MultiProcessPipe.AsyncSchedule,
 chunks=args.chunks,
 worker_map=get_worker_map(),
 input_device=torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"),
......
@@ -25,7 +25,7 @@ from torch.optim import Adam
 from fairscale.nn import Pipe
 from fairscale.nn.model_parallel import initialize_model_parallel
 from fairscale.nn.model_parallel.initialize import get_data_parallel_group, get_pipeline_parallel_group
-from fairscale.nn.pipe import LazyModule, pipe
+from fairscale.nn.pipe import LazyModule, MultiProcessPipe
 from fairscale.optim.oss import OSS
 from fairscale.utils.testing import dist_init, get_worker_map
@@ -157,7 +157,7 @@ def dump_cuda_tensors():
 def log_number_of_parameters(model):
 num_params = reduce(operator.add, (reduce(operator.mul, x.size()) for x in model.parameters()))
-if model.group:
+if hasattr(model, "group"):
 total = torch.Tensor([num_params])
 if torch.cuda.is_available():
 total = total.cuda()
@@ -212,7 +212,7 @@ def train(model_config, model, benchmark_config, args):
 optimizer = optimizer(model.parameters())
-pipe_group = model.group
+pipe_group = model.group if hasattr(model, "group") else None
 if args.ddp_zero:
 model = DDP(
@@ -479,9 +479,7 @@ def benchmark_single_process(args):
 model = model_config["model"]
 balance = generate_balance(min(num_devices, 4), len(model))
-pipe_model = pipe.Pipe(
-model, balance, chunks=args.chunks, pipelined_backward=args.pipelined_backward, checkpoint=args.checkpoint
-)
+pipe_model = Pipe(model, balance, chunks=args.chunks, checkpoint=args.checkpoint)
 del model
 del model_config["model"]
@@ -498,10 +496,10 @@ def run_mp_worker(args, available_workers):
 model = model_config["model"]
 balance = generate_balance_weighted(get_pipeline_parallel_group().size(), len(model), 0.8)
-pipe_model = pipe.Pipe(
+pipe_model = MultiProcessPipe(
 model,
 balance,
-style=Pipe.AsyncSchedule,
+style=MultiProcessPipe.AsyncSchedule,
 chunks=args.chunks,
 worker_map=get_worker_map(),
 input_device=torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"),
......
@@ -6,8 +6,8 @@ import torch.distributed as dist
 import torch.multiprocessing as mp
 import torch.optim as optim
-import fairscale
 from fairscale.nn.model_parallel import initialize_model_parallel
+from fairscale.nn.pipe import MultiProcessPipe
 DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 RANK = 0 # example
@@ -27,10 +27,10 @@ def run(rank, world_size):
 device = torch.device("cuda", RANK) if DEVICE == "cuda" else torch.device("cpu")
-model = fairscale.nn.Pipe(
+model = MultiProcessPipe(
 model,
 balance=[2, 1],
-style=fairscale.nn.Pipe.MultiProcess,
+style=MultiProcessPipe.MultiProcess,
 worker_map={0: "worker0", 1: "worker1"}, # Needed to convert ranks to RPC worker names
 input_device=device,
 ).to(device)
......
@@ -11,7 +11,7 @@ from torch import nn
 from torch.optim.optimizer import Optimizer
 from torch.utils.data import DataLoader
-from fairscale.nn.pipe import Pipe
+from fairscale.nn.pipe import MultiProcessPipe
 from fairscale.nn.pipe.types import PipelineStyle
 from .ampnet import AsyncAMPnetEventLoop
@@ -19,9 +19,9 @@ from .ampnet import AsyncAMPnetEventLoop
 __all__ = ["AMPnetPipe"]
-class AMPnetPipe(Pipe):
+class AMPnetPipe(MultiProcessPipe):
 """
-AMPnetPipe is the asynchronous version of the Pipe implementation
+AMPnetPipe is the asynchronous version of the MultiProcessPipe implementation
 which avoids the bubble issue, by using stale weights and gradients.
 The implementation closely follows the paper: https://arxiv.org/abs/1705.09786
 """
@@ -39,7 +39,7 @@ class AMPnetPipe(Pipe):
 weight_prediction: bool = False,
 ) -> None:
-partitions = self.mp_partitions
+partitions = self.partitions
 n = len(partitions)
 # AMPnet implementation doesn't handle skip_trackers!
......
@@ -23,7 +23,7 @@ from torch.optim.optimizer import Optimizer
 from torch.utils.data import DataLoader, Dataset
 from experimental.nn.ampnet_pipe.pipe import AMPnetPipe
-from fairscale.nn.pipe import Pipe
+from fairscale.nn.pipe import MultiProcessPipe
 from fairscale.utils.testing import get_worker_map, torch_spawn
@@ -87,7 +87,7 @@ def async_event_loop_interleave_simple():
 pipe = AMPnetPipe(
 module=model,
 balance=[2, 2],
-style=Pipe.AsyncSchedule,
+style=MultiProcessPipe.AsyncSchedule,
 worker_map=get_worker_map(),
 chunks=10,
 checkpoint="never",
@@ -105,7 +105,7 @@ def async_event_loop_interleave_hard():
 pipe = AMPnetPipe(
 module=model,
 balance=[1, 1, 1, 1],
-style=Pipe.AsyncSchedule,
+style=MultiProcessPipe.AsyncSchedule,
 worker_map=get_worker_map(),
 chunks=10,
 checkpoint="never",
......
@@ -6,7 +6,7 @@
 from .data_parallel import ShardedDataParallel
 from .misc import FlattenParamsWrapper
 from .moe import MOELayer, Top2Gate
-from .pipe import LazyModule, Pipe, PipeRPCWrapper
+from .pipe import Pipe, PipeRPCWrapper
 __all__ = [
 "FlattenParamsWrapper",
......
@@ -19,7 +19,8 @@
 """A Pipe implementation in PyTorch."""
 from .checkpoint import is_checkpointing, is_recomputing
-from .pipe import LazyModule, Pipe
+from .multiprocess_pipe import LazyModule, MultiProcessPipe
+from .pipe import Pipe
 from .rpc import PipeRPCWrapper
 __all__ = ["Pipe", "is_checkpointing", "is_recomputing", "LazyModule"]
@@ -191,7 +191,7 @@ class AsyncEventLoop:
 """Actually run the forward pass for a given module, and send the result
 to the next stage in the pipeline if needed."""
 assert self.group
-from .pipeline import create_task
+from .multiprocess_pipeline import create_task
 task = create_task(
 PipelineStyle.AsyncSchedule,
@@ -201,7 +201,6 @@
 batch,
 partition.module,
 skip_trackers,
-[],
 )
 result = task.compute()
 task.finalize(result)
......
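For downstream code, the __init__.py hunks above reduce to the following import layout; this is a sketch of only what the diff shows, nothing beyond these names is implied:

# Import layout after this refactor, as exported by the __init__.py diffs above.
from fairscale.nn import Pipe                          # single-process pipe, same spelling as before
from fairscale.nn.pipe import Pipe, MultiProcessPipe   # both classes live under fairscale.nn.pipe
from fairscale.nn.pipe import LazyModule               # now defined in multiprocess_pipe.py

# Call sites that used `style=Pipe.AsyncSchedule` or `style=fairscale.nn.Pipe.MultiProcess`
# switch to the constants on the new class, e.g. `style=MultiProcessPipe.AsyncSchedule`.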
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
# Copyright 2019 Kakao Brain
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The MultiProcessPipe interface."""
from collections import OrderedDict
from dataclasses import dataclass, field
import itertools
import threading
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, Union, cast
import warnings
import torch
from torch import Tensor, nn
import torch.autograd
import torch.cuda
from fairscale.nn.model_parallel import get_model_parallel_world_size, get_pipeline_parallel_group
from . import microbatch
from .async_schedule import Invocation, Location, ModuleWrapper
from .batchnorm import DeferredBatchNorm
from .multiprocess_pipeline import MultiProcessPipeline
from .skip.layout import SkipLayout, inspect_skip_layout
from .skip.skippable import Skippable, verify_skippables
from .types import LazyModule, PipelineStyle
__all__ = ["MultiProcessPipe", "LazyModule"]
Device = Union[torch.device, int, str]
Devices = Union[Iterable[Device], List[Device]]
Tensors = Tuple[Tensor, ...]
TensorOrTensors = Union[Tensor, Tensors]
ListOfLazyModules = List[LazyModule]
if TYPE_CHECKING:
Module = nn.Module[TensorOrTensors]
NamedModules = OrderedDict[str, Module]
else:
Module = nn.Module
NamedModules = OrderedDict
def recommend_auto_balance(message: str) -> str:
"""Expands a message with recommendation to :mod:`torchpipe.balance`."""
return f"""{message}
If your model is still under development, its optimal balance would change
frequently. In this case, we highly recommend 'fairscale.nn.pipe.balance' for
naive automatic balancing:
from fairscale.nn.pipe import MultiProcessPipe
from fairscale.nn.pipe.balance import balance_by_time
partitions = torch.cuda.device_count()
sample = torch.empty(...)
balance = balance_by_time(partitions, model, sample)
model = MultiProcessPipe(model, balance, ...)
"""
# FIXME(tom) make this a valid way to call
def verify_list_of_callable(module: Union[nn.Sequential, list]) -> None:
for layer in module:
if isinstance(layer, nn.Module):
pass
elif isinstance(layer, LazyModule):
pass
else:
raise TypeError(f"layer {type(layer)} must be nn.Module or LazyModule to be partitioned")
def verify_module(module: Union[nn.Sequential, ListOfLazyModules]) -> None:
if isinstance(module, Iterable) and not isinstance(module, nn.Sequential):
verify_list_of_callable(module)
else:
if not isinstance(module, nn.Sequential):
raise TypeError("module must be nn.Sequential to be partitioned")
named_children = list(module.named_children())
if len(named_children) != len(module):
raise ValueError("module with duplicate children is not supported")
def verify_splitting(module: nn.Sequential, partitions: List[nn.Sequential], balance: Iterable[int],) -> None:
num_parameters = len(list(module.parameters()))
num_child_parameters = sum(len(list(child.parameters())) for child in module.children())
if num_parameters == num_child_parameters:
return
for i in range(len(partitions)):
for j in range(i + 1, len(partitions)):
parti = partitions[i]
partj = partitions[j]
for p in parti.parameters():
for q in partj.parameters():
if p is q:
raise ValueError("module with duplicate parameters on distinct devices is not supported")
class BalanceError(ValueError):
pass
def check_balance(module: Any, balance: Iterable[int], filter_unique: bool = False) -> None:
if filter_unique:
module_len = len(set(map(id, module)))
else:
module_len = len(module)
if module_len != sum(balance):
raise BalanceError(
f"module and sum of balance have different length (module: {len(module)}, sum of balance: {sum(balance)})"
)
if any(x <= 0 for x in balance):
raise BalanceError(f"all balance numbers must be positive integer (balance: {balance})")
@dataclass
class PartitionInfo:
location: Location
modules: "OrderedDict[str, nn.Module]"
invocations: List[Invocation] = field(default_factory=list)
def __len__(self) -> int:
return len(self.modules)
def instantiate_partition(
module: Union[nn.Sequential, ListOfLazyModules],
balance: Iterable[int],
group: torch.distributed.ProcessGroup,
style: PipelineStyle,
) -> List[ModuleWrapper]:
balance = list(balance)
check_balance(module, balance, True)
layers: NamedModules = OrderedDict()
def maybe_realize(layer: Any) -> nn.Module:
if isinstance(layer, nn.Module):
return layer
elif callable(layer):
return layer()
else:
raise TypeError(f"layer must be nn.Module or callable, is {type(layer)}")
def iterate_module(module: Union[nn.Sequential, list]) -> Iterable[Tuple[Any, nn.Module]]:
if isinstance(module, nn.Sequential):
yield from module.named_children()
else:
yield from ((str(k), v) for k, v in enumerate(module))
if style == PipelineStyle.AsyncSchedule:
module_ids = list(map(id, module))
index_of_first_use = [module_ids.index(x) for x in module_ids]
locations: List[Location] = []
module_iter = enumerate(iterate_module(module))
partitions: List[List[PartitionInfo]] = []
for bi, b in enumerate(balance):
modules_for_rank: List[PartitionInfo] = []
current_module: OrderedDict[str, nn.Module] = OrderedDict()
def current_location() -> Location:
return Location(bi, len(modules_for_rank))
def append_module(mod: "OrderedDict[str, nn.Module]") -> None:
modules_for_rank.append(PartitionInfo(current_location(), mod))
while sum(map(len, modules_for_rank)) + len(current_module) < b:
module_index, (name, layer) = next(module_iter)
if index_of_first_use[module_index] != module_index:
# Subsequent reuse of a module
locations.append(locations[index_of_first_use[module_index]])
continue
is_reused = index_of_first_use.count(index_of_first_use[module_index]) > 1
if is_reused and len(current_module) > 0:
append_module(current_module)
current_module = OrderedDict()
current_module[str(name)] = layer
locations.append(current_location())
if is_reused:
append_module(current_module)
current_module = OrderedDict()
if len(current_module) > 0:
append_module(current_module)
partitions.append(modules_for_rank)
filtered_locations: List[Optional[Location]] = [loc for loc, _ in itertools.groupby(locations)]
filtered_locations.append(None)
for i in range(len(filtered_locations) - 1):
loc = filtered_locations[i]
assert loc
if i == 0:
inv = Invocation(i, loc, None, filtered_locations[i + 1])
else:
inv = Invocation(i, loc, filtered_locations[i - 1], filtered_locations[i + 1])
partitions[loc.stage][loc.index].invocations.append(inv)
invocations = enumerate(iterate_module(module))
partition = partitions[group.rank()]
result: List[ModuleWrapper] = []
for partition_info in partition:
wrapper = ModuleWrapper(
nn.Sequential(OrderedDict((k, maybe_realize(m)) for k, m in partition_info.modules.items())),
partition_info.location,
partition_info.invocations,
)
if not isinstance(module, nn.Sequential):
for layer in wrapper.module:
if isinstance(layer, Skippable):
raise ValueError("Can't use Skippable layers with multi-process pipe and lazy construction")
result.append(wrapper)
return result
j = 0
for name, layer in iterate_module(module):
layers[name] = layer
if len(layers) == balance[j]:
if j == group.rank():
for key in layers:
layers[key] = maybe_realize(layers[key])
if not isinstance(module, nn.Sequential):
for layer in layers.values():
if isinstance(layer, Skippable):
raise ValueError("Can't use Skippable layers with multi-process pipe and lazy construction")
return [ModuleWrapper(nn.Sequential(layers), Location(j, 0))]
# Prepare for the next partition.
layers.clear()
j += 1
raise ValueError("Souldn't get here, more ranks than partitions")
def split_module(module: nn.Sequential, balance: Iterable[int],) -> Tuple[List[nn.Sequential], List[int]]:
"""Splits a module into multiple partitions.
Returns:
A tuple of (partitions, balance).
Partitions are represented as a :class:`~torch.nn.ModuleList` whose
item is a partition. All layers in a partition are placed in the
same device.
Raises:
BalanceError:
wrong balance
IndexError:
the number of devices is fewer than the number of partitions.
"""
balance = list(balance)
check_balance(module, balance)
j = 0
partitions = []
layers: NamedModules = OrderedDict()
for name, layer in module.named_children():
layers[name] = layer
if len(layers) == balance[j]:
# Group buffered layers as a partition.
partition = nn.Sequential(layers)
partitions.append(partition)
# Prepare for the next partition.
layers.clear()
j += 1
partitions = cast(List[nn.Sequential], nn.ModuleList(partitions))
return partitions, balance
MOVING_DENIED = TypeError("denied to move parameters and buffers, because Pipe should manage device placement")
class MultiProcessPipe(Module):
"""Wraps an arbitrary :class:`nn.Sequential <torch.nn.Sequential>` module
to train on Pipe_. If the module requires lots of memory, Pipe will be
very efficient.
::
model = nn.Sequential(a, b, c, d)
model = MultiProcessPipe(model, balance=[1, 1, 1, 1], chunks=8)
output = model(input)
.. _Pipe: https://arxiv.org/abs/1811.06965
Pipe combines pipeline parallelism with checkpointing to reduce peak
memory required to train while minimizing device under-utilization.
You should determine the balance when defining a :class:`MultiProcessPipe` module,
as balancing will not be done automatically. The module will be partitioned
across multiple devices according to the given balance. You may rely on
heuristics to find your own optimal configuration.
Args:
module (torch.nn.Sequential):
sequential module to be parallelized
balance (ints):
list of number of layers in each partition
Keyword Args:
style (PipelineStyle):
whether to use a single process for all pipeline stages or to assign
one stage per process
group (ProcessGroup):
specific to `style=MultiProcess`, the process group that all
pipeline stages are a member of. Defaults to
`get_pipeline_parallel_group()`
worker_map (Dict[int, str]):
a map from global rank (i.e. `torch.distributed.get_rank()`) to worker
name (the first argument to `torch.distributed.rpc.init_rpc`), needed
for pipeline stages to communicate with each other
input_device (device):
the device on which tensors should be located before being passed to
the first module in a given pipeline stage
chunks (int):
number of micro-batches (default: ``1``)
checkpoint (str):
when to enable checkpointing, one of ``'always'``,
``'except_last'``, or ``'never'`` (default: ``'except_last'``)
deferred_batch_norm (bool):
whether to use deferred BatchNorm moving statistics (default:
:data:`False`, see :class:`DeferredBatchNorm` for more
details)
pipelined_backward (bool, optional):
if True, call torch.autograd.backward once per microbatch on the
backward pass (instead of once for the whole batch). This works
around a potential deadlock in pytorch when using tensor parallelism
at the same time. Defaults to `True` if
`get_model_parallel_world_size() > 1`
(default: `None`)
retain_graph (bool):
The value passed to `torch.autograd.backward(..., retain_graph=<value>)`
(default: `False`)
Raises:
TypeError:
the module is not a :class:`nn.Sequential <torch.nn.Sequential>`.
ValueError:
invalid arguments, or wrong balance
IndexError:
the number of devices is fewer than the number of partitions.
"""
MultiProcess: PipelineStyle = PipelineStyle.MultiProcess
AsyncSchedule: PipelineStyle = PipelineStyle.AsyncSchedule
#: The number of layers in each partition.
balance: List[int] = []
# ^^
# The default value [] required for Sphinx's autoattribute.
#: The devices mapped to each partition.
#:
#: ``devices[-1]`` refers to the device of the last partition, which means
#: it is the output device. Probably, you need to use it to transfer the
#: target to calculate the loss without a device mismatch
#: :exc:`RuntimeError`. For example::
#:
#: out_device = pipe.devices[-1]
#:
#: for input, target in loader:
#: target = target.to(out_device, non_blocking=True)
#: output = pipe(input)
#: loss = F.cross_entropy(output, target)
#:
#: The number of micro-batches.
chunks: int = 1
#: The checkpoint mode to determine when to enable checkpointing. It is one
#: of ``'always'``, ``'except_last'``, or ``'never'``.
checkpoint: str = "except_last"
def __init__(
self,
module: Union[nn.Sequential, ListOfLazyModules],
balance: Optional[Iterable[int]] = None,
*,
style: PipelineStyle = PipelineStyle.MultiProcess,
group: Optional[torch.distributed.ProcessGroup] = None,
worker_map: Optional[Dict[int, str]] = None,
input_device: Union[None, int, str, torch.device] = None,
chunks: int = chunks,
checkpoint: str = checkpoint,
deferred_batch_norm: bool = False,
pipelined_backward: bool = None,
retain_graph: bool = False,
loss_fn: Optional[nn.Module] = None,
) -> None:
super().__init__()
chunks = int(chunks)
checkpoint = str(checkpoint)
if balance is None:
raise ValueError(recommend_auto_balance("balance is required"))
if chunks <= 0:
raise ValueError("number of chunks must be positive integer")
if checkpoint not in ["always", "except_last", "never"]:
raise ValueError("checkpoint is not one of 'always', 'except_last', or 'never'")
verify_module(module)
# Verify if the underlying skippable modules satisfy integrity. The
# integrity can be verified before forward() because it is static.
if isinstance(module, nn.Sequential):
verify_skippables(module)
self.chunks = chunks
self.checkpoint = checkpoint
self.pipelined_backward = pipelined_backward
self.retain_graph = retain_graph
self.pipeline: Optional[MultiProcessPipeline]
self.loss_fn = loss_fn
self.lock = threading.Lock()
self.group = group
self.worker_map = worker_map
self.input_device = input_device
# The micro-batch index where the checkpointing stops.
checkpoint_stop = {"always": self.chunks, "except_last": self.chunks - 1, "never": 0}[self.checkpoint]
if self.group is None:
self.group = get_pipeline_parallel_group()
assert self.group
self.balance = list(balance)
if self.group.size() < len(self.balance):
raise IndexError(
f"too few ranks to hold given partitions (ranks: {self.group.size()}, partitions:"
f" {len(self.balance)})"
)
try:
rank = self.group.rank()
if rank >= len(self.balance):
warnings.warn("More ranks than partitions, some ranks unused")
self.partitions: List[ModuleWrapper] = []
else:
self.partitions = instantiate_partition(module, balance, self.group, style)
if deferred_batch_norm:
for part in self.partitions:
part.module = DeferredBatchNorm.convert_deferred_batch_norm(part.module, chunks)
for name, part in enumerate(self.partitions):
self.add_module(str(name), part.module)
if isinstance(module, nn.Sequential):
local_partitions, _ = split_module(module, balance)
self._skip_layout = inspect_skip_layout(local_partitions)
else:
self._skip_layout = SkipLayout(len(module), {}) # FIXME(tom)
except BalanceError as exc:
raise ValueError(recommend_auto_balance(str(exc)))
rank = self.group.rank()
if rank >= len(self.balance):
self.pipeline = None
self.final_stage = False
else:
self.final_stage = rank == len(self.balance) - 1
assert loss_fn is None or self.final_stage
self.pipeline = MultiProcessPipeline(
cast(List[nn.Sequential], self.partitions),
self._skip_layout,
checkpoint_stop,
style=style,
group=self.group,
worker_map=self.worker_map,
input_device=self.input_device,
final_stage=self.final_stage,
)
del module
if self.pipelined_backward is None:
if get_model_parallel_world_size() > 1:
self.pipelined_backward = True
else:
self.pipelined_backward = False
def __len__(self) -> int:
"""Counts the length of the underlying sequential module."""
return sum(len(p) for p in self.partitions)
def __getitem__(self, index: int) -> nn.Module:
"""Gets a layer in the underlying sequential module."""
partitions: List[Any]
partitions = self.partitions
if index < 0:
partitions = partitions[::-1]
for partition in partitions:
try:
if isinstance(partition, ModuleWrapper):
return partition.module[index]
else:
return partition[index]
except IndexError:
pass
shift = len(partition)
if index < 0:
index += shift
else:
index -= shift
raise IndexError
def __iter__(self) -> Iterable[nn.Module]:
"""Iterates over children of the underlying sequential module."""
for partition in self.partitions:
yield from partition.module
def forward(self, input: TensorOrTensors, *, event=None) -> TensorOrTensors: # type: ignore
""":class:`MultiProcessPipe` is a fairly transparent module wrapper. It doesn't
modify the input and output signature of the underlying module. But
there is a type restriction: input and output have to be a
:class:`~torch.Tensor` or a tuple of tensors. This restriction is
applied at partition boundaries too.
Args:
input (torch.Tensor or tensors): input mini-batch
Returns:
tensor or tensors: output mini-batch
Raises:
TypeError: input is not a tensor or tensors.
"""
microbatch.check(input)
if not self.group:
# Empty sequential module is not illegal.
return input
if not self.pipeline:
# Having no local pipeline is legal when there are more ranks than partitions.
return input
# Divide a mini-batch into micro-batches.
batches = microbatch.scatter(input, self.chunks)
# Run pipeline parallelism.
with self.lock:
self.pipeline.run(self.training, batches, event)
if not self.final_stage:
# Don't merge micro-batches to avoid unnecessary edges in autograd
# graph
# FIXME(tom) should figure out a proper type here
return batches # type: ignore
else:
# Merge the micro-batches into one mini-batch.
if self.pipelined_backward:
with torch.no_grad():
output = microbatch.gather(batches)
from .phony import get_phony
phony = get_phony(
torch.device(torch.cuda.current_device() if torch.cuda.is_available() else "cpu"),
requires_grad=True,
)
output = PipelinedBackwardPass.apply(output, batches, phony, True) # self.retain_graph)
else:
output = microbatch.gather(batches)
return output
def back_helper(self, output: List[microbatch.Batch]) -> None:
if self.final_stage:
raise ValueError("back_helper should only be called on non-final stages")
if self.pipeline:
self.pipeline.back_helper(list(reversed(output)))
class PipelinedBackwardPass(torch.autograd.Function):
@staticmethod
# type: ignore
def forward(ctx, input: TensorOrTensors, batches, phony, retain_graph) -> TensorOrTensors:
ctx.batches = batches
ctx.retain_graph = retain_graph
return input
@staticmethod
# type: ignore
def backward(ctx, *grads) -> Tuple:
with torch.no_grad():
grad_batches = microbatch.scatter(grads, len(ctx.batches))
for grad, batch in reversed(list(zip(grad_batches, ctx.batches))):
for t in batch:
t.retain_grad()
torch.autograd.backward(batch.tensor_or_tensors, grad_tensors=(*grad,), retain_graph=ctx.retain_graph)
with torch.no_grad():
if ctx.batches[0].atomic:
tensors = tuple(b.tensor.grad for b in ctx.batches)
output: TensorOrTensors = torch.cat(tensors)
else:
rotated = [[t.grad for t in b.tensors] for b in ctx.batches]
output_buf = []
for tensors in zip(*rotated):
output_buf.append(torch.cat(tensors))
output = tuple(output_buf)
del ctx.batches
return (output, None, None, None)
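To make the forward()/back_helper() contract above concrete, here is an illustrative per-rank training step. It is not part of this commit; pipe_model (a MultiProcessPipe built as above), loss_fn, optimizer, batch and target are assumed to exist.

def train_step(pipe_model, batch, target, loss_fn, optimizer):
    optimizer.zero_grad()
    output = pipe_model(batch)
    if pipe_model.final_stage:
        # The final stage receives the gathered mini-batch and drives backward as usual.
        loss = loss_fn(output, target)
        loss.backward()
    else:
        # Non-final stages get the unmerged list of micro-batches back and hand it
        # to back_helper() so gradients arriving from the next stage are applied.
        pipe_model.back_helper(output)
    optimizer.step()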
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
# Copyright 2019 Kakao Brain
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The multiprocess pipeline parallelism of Pipe."""
import logging
import os
from queue import Empty as QueueEmpty
from queue import Queue
from threading import Event
from types import TracebackType
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Type, Union, cast
import torch
from torch import Tensor, nn
from torch.autograd.profiler import record_function
from fairscale.nn.model_parallel import get_pipeline_parallel_ranks
from .async_schedule import AsyncEventLoop, ModuleWrapper
from .checkpoint import Checkpointing
from .messages import MakeTransport, Transport
from .microbatch import Batch
from .skip import Namespace
from .skip.layout import SkipLayout
from .skip.tracker import SkipTrackerThroughPotals, use_skip_tracker
from .types import (
ACTIVATIONS_GRADS_QUEUE,
PORTAL_QUEUE,
SKIP_TENSOR_QUEUE,
PipelineStyle,
PipeMessage,
TensorOrTensors,
Tensors,
)
from .worker import Task
__all__: List[str] = []
ExcInfo = Tuple[Type[BaseException], BaseException, TracebackType]
class SendOperator(torch.autograd.Function):
"""Send activations to the next pipeline stage"""
@staticmethod
# type: ignore
def forward(ctx, src_rank, dst_rank, transport: Transport, input: List[Tensor], index: int) -> Tensors:
assert src_rank == torch.distributed.get_rank()
transport.send_message(
PipeMessage(src_rank, dst_rank, queue_name=ACTIVATIONS_GRADS_QUEUE, args=index, tensors=tuple(input)),
)
return ()
@staticmethod
# type: ignore
def backward(ctx, *grad: Tensor,) -> Tensors:
return tuple(grad)
class RecvOperator(torch.autograd.Function):
"""Receive activations to the previous pipeline stage"""
@staticmethod
# type: ignore
def forward(ctx, dst_rank: int, tensor: Tensor, input_device, transport: Transport, index: int) -> Tensors:
assert dst_rank == torch.distributed.get_rank()
ctx.transport = transport
ctx.index = index
result = transport.get_out_of_order(ACTIVATIONS_GRADS_QUEUE, index)
def maybe_requires_grad(t: Tensor) -> Tensor:
if t.dtype.is_floating_point:
return t.requires_grad_()
return t
return tuple(maybe_requires_grad(r) for r in result)
@staticmethod
# type: ignore
def backward(ctx, *grad: Tensor,) -> Tuple[Optional[Tensor], ...]:
ranks = get_pipeline_parallel_ranks()
this_rank = torch.distributed.get_rank()
ctx.transport.send_message(
PipeMessage(
this_rank,
ranks[ranks.index(this_rank) - 1],
queue_name=ACTIVATIONS_GRADS_QUEUE,
args=ctx.index,
tensors=tuple(grad),
),
)
return (None, None, None, None, None)
# Queue is generic only in stubs.
# https://mypy.readthedocs.io/en/latest/common_issues.html#using-classes-that-are-generic-in-stubs-but-not-at-runtime
if TYPE_CHECKING:
InQueue = Queue[Optional["Task"]]
OutQueue = Queue[Tuple[bool, Union[Tuple["Task", Batch], ExcInfo, None]]]
else:
InQueue = Queue
OutQueue = Queue
def create_task(
style: PipelineStyle,
checkpoint_stop: int,
i: int,
j: int,
batch: Batch,
partition: nn.Sequential,
skip_trackers: List[SkipTrackerThroughPotals],
) -> Task:
# Determine whether checkpointing or not.
if i < checkpoint_stop:
def function(
input: TensorOrTensors,
partition: nn.Sequential = partition,
skip_tracker: SkipTrackerThroughPotals = skip_trackers[i],
chunk_id: int = i,
part_id: int = j,
) -> TensorOrTensors:
with use_skip_tracker(skip_tracker), record_function("chunk%d-part%d" % (chunk_id, part_id)):
ret = partition(input)
# We do a check here because the backtrace from the checkpoint backward code path
# is very hard to make sense of. It is much easier to catch the problem earlier, at this point.
assert type(ret) is not list, "Only Tensor or Tuple of Tensor output is supported"
return ret
chk = Checkpointing(function, batch)
task = Task(None, compute=chk.checkpoint, finalize=chk.recompute)
del function, chk # TODO(tom) maybe remove
else:
def compute(
batch: Batch = batch,
partition: nn.Sequential = partition,
skip_tracker: SkipTrackerThroughPotals = skip_trackers[i],
chunk_id: int = i,
part_id: int = j,
) -> Batch:
with use_skip_tracker(skip_tracker), record_function("chunk%d-part%d" % (chunk_id, part_id)):
return batch.call(partition)
task = Task(None, compute=compute, finalize=None)
del compute # TODO(tom) maybe remove
return task
class MultiProcessPipeline:
"""The multiprocess pipeline parallelism for Pipe."""
def __init__(
self,
partitions: List[nn.Sequential],
skip_layout: SkipLayout,
checkpoint_stop: int,
style: PipelineStyle,
group: Optional[torch.distributed.ProcessGroup] = None,
worker_map: Optional[Dict[int, str]] = None,
input_device: Union[None, int, str, torch.device] = None,
final_stage: bool = False,
) -> None:
self.partitions: List[ModuleWrapper] = cast(List[ModuleWrapper], partitions)
self.skip_layout = skip_layout
self.__checkpoint_stop = checkpoint_stop
self.style = style
self.group = group
self.training: bool
self.transport = MakeTransport(
use_rpc=("OMPI_COMM_WORLD_RANK" not in os.environ) or ("FORCE_RPC" in os.environ),
worker_map=worker_map,
input_device=input_device,
)
self.input_device = input_device
self.all_at_once = False
self.callcount = 0
self.final_stage = final_stage
@property
def checkpoint_stop(self) -> int:
# Disable checkpointing if in eval mode.
training = self.partitions[0].module.training
if not training:
return 0
return self.__checkpoint_stop
def run(self, training: bool, batches: List[Batch], event: Optional[Event]) -> None:
"""Runs pipeline parallelism.
It modifies the given batches in place.
"""
self.training = training
m = len(batches)
skip_trackers = [SkipTrackerThroughPotals(self.skip_layout, i) for i in range(len(batches))]
if self.style is PipelineStyle.MultiProcess:
assert self.group
schedule = [(i, self.group.rank()) for i in range(m)]
self.compute(batches, schedule, skip_trackers)
elif self.style is PipelineStyle.AsyncSchedule:
assert self.group
rank = self.group.rank()
event_loop = AsyncEventLoop(
self.partitions, self.group, self.transport, self.training, self.checkpoint_stop,
)
if rank == 0 and not self.final_stage:
logging.debug(f"{torch.distributed.get_rank()}: entered event head")
event_loop.event_loop_head(batches, skip_trackers, event)
logging.debug(f"{torch.distributed.get_rank()}: exited event head")
elif self.final_stage:
logging.debug(f"{torch.distributed.get_rank()}: entered event tail")
event_loop.event_loop_tail(batches, skip_trackers)
logging.debug(f"{torch.distributed.get_rank()}: exited event tail")
else:
logging.debug(f"{torch.distributed.get_rank()}: entered event loop")
event_loop.event_loop(len(batches), skip_trackers)
logging.debug(f"{torch.distributed.get_rank()}: exited event loop")
self.callcount += 1
def get_batch_from_previous_stage(
self, i: int, skip_trackers: List[SkipTrackerThroughPotals], batches: List[Batch]
) -> Batch:
phony = torch.empty(0, device=self.input_device, requires_grad=True)
result = RecvOperator.apply(torch.distributed.get_rank(), phony, self.input_device, self.transport, i)
if len(result) == 1:
batch = Batch(result[0], i)
else:
batch = Batch(result, i)
self.recv_skip_tensors(skip_trackers, batches)
return batch
def send_skip_tensors(
self, this_rank: int, ranks: List[int], batch: Batch, i: int, skip_trackers: List[SkipTrackerThroughPotals]
) -> None:
assert self.group
for next_j, ns, name in self.skip_layout.copy_policy_by_src(self.group.rank()):
life = skip_trackers[i].portals[(ns, name)].tensor_life
loaded = skip_trackers[i].load(batch, ns, name)
if loaded is not None:
tensors = tuple([loaded])
else:
tensors = tuple()
self.transport.send_message(
PipeMessage(
this_rank, ranks[next_j], queue_name=SKIP_TENSOR_QUEUE, args=(i, ns, name, life), tensors=tensors,
),
sync=True,
)
def recv_skip_tensors(self, skip_trackers: List[SkipTrackerThroughPotals], batches: List[Batch]) -> None:
while True:
try:
message = self.transport.recv_message(SKIP_TENSOR_QUEUE, nowait=True)
(si, ns, name, life) = message.args
value: Optional[TensorOrTensors] = message.tensors
assert isinstance(value, tuple)
if len(value) == 0:
value = None
else:
assert len(value) == 1
value = value[0]
skip_trackers[si].save(batches[si], ns, name, value)
old_life = skip_trackers[si].portals[(ns, name)].tensor_life
if life != 0:
skip_trackers[si].portals[(ns, name)].tensor_life = life
except QueueEmpty:
break
def execute_task(self, task: Task, i: int, skip_trackers: List[SkipTrackerThroughPotals]) -> Batch:
batch = task.compute()
assert self.group
rank = self.group.rank()
if self.style is PipelineStyle.MultiProcess and not self.final_stage:
ranks = get_pipeline_parallel_ranks()
this_rank = torch.distributed.get_rank()
self.send_skip_tensors(this_rank, ranks, batch, i, skip_trackers)
SendOperator.apply(this_rank, ranks[ranks.index(this_rank) + 1], self.transport, [*batch], i)
for portal in skip_trackers[i].portals.values():
portal.pipeline = self
task.finalize(batch)
return batch
def compute(
self, batches: List[Batch], schedule: List[Tuple[int, int]], skip_trackers: List[SkipTrackerThroughPotals]
) -> None:
"""Runs tasks with synchronization to copy streams."""
if self.style is PipelineStyle.MultiProcess:
assert self.group
n = self.group.size()
# With checkpointing, the autograd graph looks like this diagram:
# ┌─────┸──────┐
# │ Copy │
# └─────┰──────┘ (fence)
# ─ ─ ─ ╂ ─ ─ ─ ─ ─ ─ ─ ─ ─
# ┃ (compute)
# ┌─────┸──────┐
# │ Wait │ [1] Synchronize the current stream with the copy stream.
# └─────┰──────┘
# ┌─────┸──────┐
# │ Checkpoint │ [2] Compute a partition within checkpointing.
# └─────┰──────┘
# ┌─────┸──────┐
# │ Wait │ [3] Synchronize the copy stream with the current stream.
# └─────┰──────┘
# ┠ ─ ─ ─ ┐
# ┃ ┌─────┴─────┐
# ┃ │ Recompute │ [4] Schedule the recomputation at backpropagation.
# ┃ └─────┬─────┘
# ┠ ─ ─ ─ ┘
# ┃
# ─ ─ ─ ╂ ─ ─ ─ ─ ─ ─ ─ ─ ─
# ┌─────┸──────┐ (fence)
# │ Copy │
# └─────┰──────┘
for i, j in schedule:
batch = batches[i]
if self.style is PipelineStyle.MultiProcess:
assert len(self.partitions) == 1
partition = self.partitions[0]
assert self.group
if self.group.rank() != 0:
batch = self.get_batch_from_previous_stage(i, skip_trackers, batches)
task = create_task(self.style, self.checkpoint_stop, i, j, batch, partition.module, skip_trackers)
batches[i] = self.execute_task(task, i, skip_trackers)
def send_portal_grad(self, ns_name: Tuple[Namespace, str], index: int, grad: TensorOrTensors) -> None:
dest, src = self.skip_layout.by_ns_name.get(ns_name, (-1, -1))
if dest == src:
return
ranks = get_pipeline_parallel_ranks()
dst_rank = ranks[dest]
if dst_rank == torch.distributed.get_rank():
return
if isinstance(grad, Tensor):
grad = tuple([grad])
self.transport.send_message(
PipeMessage(ranks[src], dst_rank, queue_name=PORTAL_QUEUE, args=(ns_name, index), tensors=grad), sync=True,
)
def recv_portal_grad(self, expected_ns_name: Tuple[Namespace, str], expected_index: int) -> Tensor:
message = self.transport.recv_message(PORTAL_QUEUE)
(ns_name, index) = message.args
grad = message.tensors
assert len(grad) == 1
result = grad[0]
assert index == expected_index and ns_name == expected_ns_name
return result
def back_helper(self, output: List[Batch]) -> None:
if self.style == PipelineStyle.AsyncSchedule:
return
o = list(output)
tensors: Tensors
if self.all_at_once:
# FIXME(tom) allow specifying this branch when constructing Pipe(), add a test
grads = []
for i, batch in enumerate(o):
rank = torch.distributed.get_rank()
found = self.transport.get_out_of_order(ACTIVATIONS_GRADS_QUEUE, i)
assert len(found) == 1
grads.append(found[0])
tensors = tuple(x.tensor_or_tensors for x in o) # type: ignore
try:
torch.autograd.backward(tensors, grad_tensors=grads, retain_graph=True)
except Exception as e:
raise RuntimeError("Autograd failed") from e
else:
rank = torch.distributed.get_rank()
for batch in o:
found = self.transport.get_out_of_order(ACTIVATIONS_GRADS_QUEUE, batch.index)
if batch.atomic:
tensors = tuple([batch.tensor])
else:
tensors = batch.tensors
if len(found) != len(tensors):
raise RuntimeError("different number of tensors and gradients")
grads = []
final_tensors = []
for i, tensor in enumerate(tensors):
if tensor.requires_grad or getattr(tensor, "grad_fn", None) is not None:
grads.append(found[i])
final_tensors.append(tensor)
try:
torch.autograd.backward(final_tensors, grad_tensors=grads, retain_graph=True)
except Exception as e:
raise RuntimeError(f"Autograd failed on {torch.distributed.get_rank()}") from e
@@ -19,29 +19,21 @@
 """The Pipe interface."""
 from collections import OrderedDict
-from dataclasses import dataclass, field
-import itertools
-import threading
-from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, Union, cast
-import warnings
+from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Tuple, Union, cast
 import torch
 from torch import Tensor, nn
 import torch.autograd
 import torch.cuda
-from fairscale.nn.model_parallel import get_model_parallel_world_size, get_pipeline_parallel_group
 from . import microbatch
-from .async_schedule import Invocation, Location, ModuleWrapper
 from .batchnorm import DeferredBatchNorm
 from .pipeline import Pipeline
-from .skip.layout import SkipLayout, inspect_skip_layout
-from .skip.skippable import Skippable, verify_skippables
+from .skip.layout import inspect_skip_layout
+from .skip.skippable import verify_skippables
 from .stream import AbstractStream, new_stream
-from .types import LazyModule, PipelineStyle
-__all__ = ["Pipe", "LazyModule"]
+__all__ = ["Pipe"]
 Device = Union[torch.device, int, str]
@@ -50,8 +42,6 @@ Devices = Union[Iterable[Device], List[Device]]
 Tensors = Tuple[Tensor, ...]
 TensorOrTensors = Union[Tensor, Tensors]
-ListOfLazyModules = List[LazyModule]
 if TYPE_CHECKING:
 Module = nn.Module[TensorOrTensors]
 NamedModules = OrderedDict[str, Module]
@@ -79,34 +69,17 @@ naive automatic balancing:
 """
-# FIXME(tom) make this a valid way to call
-def verify_list_of_callable(module: Union[nn.Sequential, list]) -> None:
-for layer in module:
-if isinstance(layer, nn.Module):
-pass
-elif isinstance(layer, LazyModule):
-pass
-else:
-raise TypeError(f"layer {type(layer)} must be nn.Module or LazyModule to be partitioned")
-def verify_module(module: Union[nn.Sequential, ListOfLazyModules]) -> None:
-if isinstance(module, Iterable) and not isinstance(module, nn.Sequential):
-verify_list_of_callable(module)
-else:
-if not isinstance(module, nn.Sequential):
-raise TypeError("module must be nn.Sequential to be partitioned")
-named_children = list(module.named_children())
-if len(named_children) != len(module):
-raise ValueError("module with duplicate children is not supported")
+def verify_module(module: nn.Sequential) -> None:
+if not isinstance(module, nn.Sequential):
+raise TypeError("module must be nn.Sequential to be partitioned")
+named_children = list(module.named_children())
+if len(named_children) != len(module):
+raise ValueError("module with duplicate children is not supported")
 def verify_splitting(
-module: nn.Sequential,
-partitions: List[nn.Sequential],
-balance: Iterable[int],
-devices: Optional[List[torch.device]],
+module: nn.Sequential, partitions: List[nn.Sequential], balance: Iterable[int], devices: List[torch.device]
 ) -> None:
 num_parameters = len(list(module.parameters()))
 num_child_parameters = sum(len(list(child.parameters())) for child in module.children())
@@ -117,7 +90,7 @@ def verify_splitting(
 for j in range(i + 1, len(partitions)):
 parti = partitions[i]
 partj = partitions[j]
-if devices and devices[i] == devices[j]:
+if devices[i] == devices[j]:
 continue
 for p in parti.parameters():
 for q in partj.parameters():
@@ -129,159 +102,9 @@ class BalanceError(ValueError):
 pass
-def check_balance(module: Any, balance: Iterable[int], filter_unique: bool = False) -> None:
-if filter_unique:
-module_len = len(set(map(id, module)))
-else:
-module_len = len(module)
-if module_len != sum(balance):
-raise BalanceError(
-f"module and sum of balance have different length (module: {len(module)}, sum of balance: {sum(balance)})"
-)
-if any(x <= 0 for x in balance):
-raise BalanceError(f"all balance numbers must be positive integer (balance: {balance})")
-@dataclass
-class PartitionInfo:
-location: Location
-modules: "OrderedDict[str, nn.Module]"
-invocations: List[Invocation] = field(default_factory=list)
-def __len__(self) -> int:
-return len(self.modules)
-def instantiate_partition(
-module: Union[nn.Sequential, ListOfLazyModules],
-balance: Iterable[int],
-group: torch.distributed.ProcessGroup,
-style: PipelineStyle,
-) -> List[ModuleWrapper]:
-balance = list(balance)
-check_balance(module, balance, True)
-layers: NamedModules = OrderedDict()
-def maybe_realize(layer: Any) -> nn.Module:
-if isinstance(layer, nn.Module):
-return layer
-elif callable(layer):
-return layer()
-else:
-raise TypeError(f"layer must be nn.Module or callable, is {type(layer)}")
-def iterate_module(module: Union[nn.Sequential, list]) -> Iterable[Tuple[Any, nn.Module]]:
-if isinstance(module, nn.Sequential):
-yield from module.named_children()
-else:
-yield from ((str(k), v) for k, v in enumerate(module))
-if style == PipelineStyle.AsyncSchedule:
-module_ids = list(map(id, module))
-index_of_first_use = [module_ids.index(x) for x in module_ids]
-locations: List[Location] = []
-module_iter = enumerate(iterate_module(module))
-partitions: List[List[PartitionInfo]] = []
-for bi, b in enumerate(balance):
-modules_for_rank: List[PartitionInfo] = []
-current_module: OrderedDict[str, nn.Module] = OrderedDict()
-def current_location() -> Location:
-return Location(bi, len(modules_for_rank))
-def append_module(mod: "OrderedDict[str, nn.Module]") -> None:
-modules_for_rank.append(PartitionInfo(current_location(), mod))
-while sum(map(len, modules_for_rank)) + len(current_module) < b:
-module_index, (name, layer) = next(module_iter)
-if index_of_first_use[module_index] != module_index:
-# Subsequent reuse of a module
-locations.append(locations[index_of_first_use[module_index]])
-continue
-is_reused = index_of_first_use.count(index_of_first_use[module_index]) > 1
-if is_reused and len(current_module) > 0:
-append_module(current_module)
-current_module = OrderedDict()
-current_module[str(name)] = layer
-locations.append(current_location())
-if is_reused:
-append_module(current_module)
-current_module = OrderedDict()
-if len(current_module) > 0:
-append_module(current_module)
-partitions.append(modules_for_rank)
-filtered_locations: List[Optional[Location]] = [loc for loc, _ in itertools.groupby(locations)]
-filtered_locations.append(None)
-for i in range(len(filtered_locations) - 1):
-loc = filtered_locations[i]
-assert loc
-if i == 0:
-inv = Invocation(i, loc, None, filtered_locations[i + 1])
-else:
-inv = Invocation(i, loc, filtered_locations[i - 1], filtered_locations[i + 1])
-partitions[loc.stage][loc.index].invocations.append(inv)
-invocations = enumerate(iterate_module(module))
-partition = partitions[group.rank()]
-result: List[ModuleWrapper] = []
-for partition_info in partition:
-wrapper = ModuleWrapper(
-nn.Sequential(OrderedDict((k, maybe_realize(m)) for k, m in partition_info.modules.items())),
-partition_info.location,
-partition_info.invocations,
-)
-if not isinstance(module, nn.Sequential):
-for layer in wrapper.module:
-if isinstance(layer, Skippable):
-raise ValueError("Can't use Skippable layers with multi-process pipe and lazy construction")
-result.append(wrapper)
-return result
-j = 0
-for name, layer in iterate_module(module):
-layers[name] = layer
-if len(layers) == balance[j]:
-if j == group.rank():
-for key in layers:
-layers[key] = maybe_realize(layers[key])
-if not isinstance(module, nn.Sequential):
-for layer in layers.values():
-if isinstance(layer, Skippable):
-raise ValueError("Can't use Skippable layers with multi-process pipe and lazy construction")
-return [ModuleWrapper(nn.Sequential(layers), Location(j, 0))]
-# Prepare for the next partition.
-layers.clear()
-j += 1
-raise ValueError("Souldn't get here, more ranks than partitions")
 def split_module(
-module: nn.Sequential, balance: Iterable[int], devices: Optional[List[torch.device]],
-) -> Tuple[List[nn.Sequential], List[int], Optional[List[torch.device]]]:
+module: nn.Sequential, balance: Iterable[int], devices: List[torch.device],
+) -> Tuple[List[nn.Sequential], List[int], List[torch.device]]:
 """Splits a module into multiple partitions.
 Returns:
@@ -300,11 +123,18 @@ def split_module(
 """
 balance = list(balance)
-check_balance(module, balance)
+if len(module) != sum(balance):
+raise BalanceError(
+"module and sum of balance have different length "
+f"(module: {len(module)}, sum of balance: {sum(balance)})"
+)
-if devices and len(balance) > len(devices):
+if any(x <= 0 for x in balance):
+raise BalanceError(f"all balance numbers must be positive integer (balance: {balance})")
+if len(balance) > len(devices):
 raise IndexError(
-f"too few devices to hold given partitions (devices: {len(devices)}, partitions: {len(balance)})"
+"too few devices to hold given partitions " f"(devices: {len(devices)}, partitions: {len(balance)})"
 )
 j = 0
@@ -318,9 +148,8 @@ def split_module(
 # Group buffered layers as a partition.
 partition = nn.Sequential(layers)
-if devices:
-device = devices[j]
-partition.to(device)
+device = devices[j]
+partition.to(device)
 partitions.append(partition)
@@ -329,13 +158,12 @@ def split_module(
 j += 1
 partitions = cast(List[nn.Sequential], nn.ModuleList(partitions))
-if devices:
-del devices[j:]
+del devices[j:]
 return partitions, balance, devices
-MOVING_DENIED = TypeError("denied to move parameters and buffers, because Pipe should manage device placement")
+MOVING_DENIED = TypeError("denied to move parameters and buffers, " "because Pipe should manage device placement")
 class Pipe(Module):
@@ -365,23 +193,8 @@ class Pipe(Module):
 list of number of layers in each partition
 Keyword Args:
-style (PipelineStyle):
-whether to use a single process for all pipeline stages or to assign
-one stage per process
 devices (iterable of devices):
 devices to use (default: all CUDA devices)
-group (ProcessGroup):
-specific to `style=MultiProcess`, the process group that all
-pipeline stages are a member of. Defaults to
-`get_pipeline_parallel_group()`
-worker_map (Dict[int, str]):
-a map from worker name (the first argument to
-`torch.distributed.rpc.init_rpc`) to global rank (i.e.
-`torch.distributed.get_rank()`) needed in order for pipeline stages
-to communicate with each other
-input_device (device):
-the device on which tensors should be located before being passed to
-the first module in a given pipeline stage
 chunks (int):
 number of micro-batches (default: ``1``)
 checkpoint (str):
@@ -389,18 +202,8 @@ class Pipe(Module):
 ``'except_last'``, or ``'never'`` (default: ``'except_last'``)
 deferred_batch_norm (bool):
 whether to use deferred BatchNorm moving statistics (default:
-:data:`False`, see :class:`DeferredBatchNorm` for more
+:data:`False`, see :ref:`Deferred Batch Normalization` for more
 details)
-pipelined_backward (bool, optional):
-if True, call torch.autograd.backward once per microbatch on the
-backward pass (instead of once for the whole batch). This works
-around a potential deadlock in pytorch when using tensor parallelism
-at the same time. Defaults to `True` if
-`get_model_parallel_world_size() > 1`
-(default: `None`)
-retain_graph (bool):
-The value passed to `torch.autograd.backwards(..., retain_graph=<value>)
-(default: = `True`)
 Raises:
 TypeError:
@@ -412,10 +215,6 @@ class Pipe(Module):
 """
-SingleProcess: PipelineStyle = PipelineStyle.SingleProcess
-MultiProcess: PipelineStyle = PipelineStyle.MultiProcess
-AsyncSchedule: PipelineStyle = PipelineStyle.AsyncSchedule
 #: The number of layers in each partition.
 balance: List[int] = []
 # ^^
@@ -435,7 +234,7 @@ class Pipe(Module):
 #: output = pipe(input)
 #: loss = F.cross_entropy(output, target)
 #:
-devices: Optional[List[torch.device]] = None
+devices: List[torch.device] = []
 #: The number of micro-batches.
 chunks: int = 1
@@ -446,20 +245,13 @@ class Pipe(Module):
 def __init__(
 self,
-module: Union[nn.Sequential, ListOfLazyModules],
+module: nn.Sequential,
 balance: Optional[Iterable[int]] = None,
 *,
-style: PipelineStyle = PipelineStyle.SingleProcess,
 devices: Optional[Devices] = None,
-group: Optional[torch.distributed.ProcessGroup] = None,
-worker_map: Optional[Dict[int, str]] = None,
-input_device: Union[None, int, str, torch.device] = None,
 chunks: int = chunks,
 checkpoint: str = checkpoint,
 deferred_batch_norm: bool = False,
-pipelined_backward: bool = None,
-retain_graph: bool = False,
-loss_fn: Optional[nn.Module] = None,
 ) -> None:
 super().__init__()
@@ -477,145 +269,50 @@ class Pipe(Module):
 # Verify if the underlying skippable modules satisfy integrity. The
 # integrity can be verified before forward() because it is static.
-if isinstance(module, nn.Sequential):
-verify_skippables(module)
+verify_skippables(module)
 self.chunks = chunks
 self.checkpoint = checkpoint
-self.pipelined_backward = pipelined_backward
-self.retain_graph = retain_graph
-self.pipeline: Optional[Pipeline]
-self.loss_fn = loss_fn
-self.lock = threading.Lock()
-self.group = group
-self.worker_map = worker_map
-self.input_device = input_device
-self._copy_streams: List[List[AbstractStream]] = []
-# The micro-batch index where the checkpointing stops.
-checkpoint_stop = {"always": self.chunks, "except_last": self.chunks - 1, "never": 0}[self.checkpoint]
-if style is PipelineStyle.SingleProcess:
-module = cast(nn.Sequential, module)
-if deferred_batch_norm:
-module = DeferredBatchNorm.convert_deferred_batch_norm(module, chunks)
-if input_device is not None:
-raise ValueError("'input_device' argument only applies to 'PipelineStyle.MultiProcess'")
+if deferred_batch_norm:
+module = DeferredBatchNorm.convert_deferred_batch_norm(module, chunks)
 if devices is None:
 devices = range(torch.cuda.device_count())
-devices = [torch.device(d) for d in devices]
-devices = cast(List[torch.device], devices)
-try:
-self.partitions, self.balance, self.devices = split_module(module, balance, devices)
-except BalanceError as exc:
-raise ValueError(recommend_auto_balance(str(exc)))
-verify_splitting(module, self.partitions, self.balance, self.devices)
-self._skip_layout = inspect_skip_layout(self.partitions)
-# Separate CUDA streams for copy.
-copy_streams = self._ensure_copy_streams()
-if self.pipelined_backward is None:
-self.pipelined_backward = False
-self.pipeline = Pipeline(
-self.partitions, self.devices, copy_streams, self._skip_layout, checkpoint_stop, style=style,
-)
-elif style in [PipelineStyle.MultiProcess, PipelineStyle.AsyncSchedule]:
-if self.group is None:
-self.group = get_pipeline_parallel_group()
+devices = [torch.device(d) for d in devices]
+devices = cast(List[torch.device], devices)
+try:
+self.partitions, self.balance, self.devices = split_module(module, balance, devices)
+except BalanceError as exc:
+raise ValueError(recommend_auto_balance(str(exc)))
+verify_splitting(module, self.partitions, self.balance, self.devices)
+self._copy_streams: List[List[AbstractStream]] = []
+self._skip_layout = inspect_skip_layout(self.partitions)
assert self.group
if devices is not None: # Separate CUDA streams for copy.
raise ValueError("'devices' argument only applies to 'PipelineStyle.SingleProcess'") copy_streams = self._ensure_copy_streams()
self.balance = list(balance) # The micro-batch index where the checkpointing stops.
checkpoint_stop = {"always": self.chunks, "except_last": self.chunks - 1, "never": 0}[self.checkpoint]
if self.group.size() < len(self.balance): self.pipeline = Pipeline(self.partitions, self.devices, copy_streams, self._skip_layout, checkpoint_stop)
raise IndexError(
f"too few ranks to hold given partitions (ranks: {self.group.size()}, partitions:"
f" {len(self.balance)})"
)
try:
rank = self.group.rank()
if rank >= len(self.balance):
warnings.warn("More ranks than partitions, some ranks unused")
self.mp_partitions: List[ModuleWrapper] = []
else:
self.mp_partitions = instantiate_partition(module, balance, self.group, style)
if deferred_batch_norm:
for part in self.mp_partitions:
part.module = DeferredBatchNorm.convert_deferred_batch_norm(part.module, chunks)
for name, part in enumerate(self.mp_partitions):
self.add_module(str(name), part.module)
self.devices = None
if isinstance(module, nn.Sequential):
local_partitions, _, _ = split_module(module, balance, None)
self._skip_layout = inspect_skip_layout(local_partitions)
else:
self._skip_layout = SkipLayout(len(module), {}) # FIXME(tom)
except BalanceError as exc:
raise ValueError(recommend_auto_balance(str(exc)))
rank = self.group.rank()
if rank >= len(self.balance):
self.pipeline = None
self.final_stage = False
else:
self.final_stage = rank == len(self.balance) - 1
assert loss_fn is None or self.final_stage
self.pipeline = Pipeline(
cast(List[nn.Sequential], self.mp_partitions),
None,
None,
self._skip_layout,
checkpoint_stop,
style=style,
group=self.group,
worker_map=self.worker_map,
input_device=self.input_device,
final_stage=self.final_stage,
)
del module
if self.pipelined_backward is None:
if get_model_parallel_world_size() > 1:
self.pipelined_backward = True
else:
self.pipelined_backward = False
def __len__(self) -> int: def __len__(self) -> int:
"""Counts the length of the underlying sequential module.""" """Counts the length of the underlying sequential module."""
if hasattr(self, "partitions"): return sum(len(p) for p in self.partitions)
return sum(len(p) for p in self.partitions)
else:
return sum(len(p) for p in self.mp_partitions)
def __getitem__(self, index: int) -> nn.Module: def __getitem__(self, index: int) -> nn.Module:
"""Gets a layer in the underlying sequential module.""" """Gets a layer in the underlying sequential module."""
partitions: List[Any] partitions = self.partitions
if hasattr(self, "partitions"):
partitions = self.partitions
else:
partitions = self.mp_partitions
if index < 0: if index < 0:
partitions = partitions[::-1] partitions = partitions[::-1]
for partition in partitions: for partition in partitions:
try: try:
if isinstance(partition, ModuleWrapper): return partition[index]
return partition.module[index]
else:
return partition[index]
except IndexError: except IndexError:
pass pass
...@@ -630,47 +327,35 @@ class Pipe(Module): ...@@ -630,47 +327,35 @@ class Pipe(Module):
def __iter__(self) -> Iterable[nn.Module]: def __iter__(self) -> Iterable[nn.Module]:
"""Iterates over children of the underlying sequential module.""" """Iterates over children of the underlying sequential module."""
if hasattr(self, "partitions"): for partition in self.partitions:
for partition in self.partitions: yield from partition
yield from partition
else:
for mp_partition in self.mp_partitions:
yield from mp_partition.module
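As a small illustration (using the hypothetical `pipe` from the construction sketch above), these container-style accessors operate on the flattened underlying sequence:

assert len(pipe) == 4                     # total number of layers across partitions
last = pipe[-1]                           # negative indexing walks the partitions in reverse
print([type(m).__name__ for m in pipe])   # ['Linear', 'ReLU', 'Linear', 'ReLU']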
# Pipe should manage the device of each partition. # Pipe should manage the device of each partition.
# Deny cuda(), cpu(), and to() with device, by TypeError. # Deny cuda(), cpu(), and to() with device, by TypeError.
def cuda(self, device: Optional[Device] = None) -> "Pipe": def cuda(self, device: Optional[Device] = None) -> "Pipe":
if self.devices: raise MOVING_DENIED
raise MOVING_DENIED
if device:
return super().cuda(device=device)
else:
return super().cuda()
def cpu(self) -> "Pipe": def cpu(self) -> "Pipe":
if self.devices: raise MOVING_DENIED
raise MOVING_DENIED
return super().cpu()
def to(self, *args: Any, **kwargs: Any) -> "Pipe": def to(self, *args: Any, **kwargs: Any) -> "Pipe":
"""Restrict .to() options. # Deny these usages:
#
Deny these usages: # - to(device[, dtype, non_blocking])
- to(device[, dtype, non_blocking]) # - to(tensor[, non_blocking])
- to(tensor[, non_blocking]) #
# But allow this:
#
# - to(dtype[, non_blocking])
#
if "device" in kwargs or "tensor" in kwargs:
raise MOVING_DENIED
But allow this: if args:
- to(dtype[, non_blocking]) if isinstance(args[0], (torch.device, int, str)):
""" raise MOVING_DENIED
if self.devices: if torch.is_tensor(args[0]):
if "device" in kwargs or "tensor" in kwargs:
raise MOVING_DENIED raise MOVING_DENIED
if args:
if isinstance(args[0], (torch.device, int, str)):
raise MOVING_DENIED
if torch.is_tensor(args[0]):
raise MOVING_DENIED
return super().to(*args, **kwargs) return super().to(*args, **kwargs)
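A quick sketch (again with the hypothetical `pipe` above) of the movement restrictions enforced here: dtype-only `.to()` calls pass through, while anything naming a device raises MOVING_DENIED:

pipe = pipe.to(torch.float16)   # allowed: to(dtype[, non_blocking])

try:
    pipe.to("cuda")             # denied: to(device[, ...])
except TypeError as exc:
    print(exc)                  # "denied to move parameters and buffers, ..."

try:
    pipe.cuda()                 # also denied; Pipe manages device placement itself
except TypeError as exc:
    print(exc)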
...@@ -683,13 +368,12 @@ class Pipe(Module): ...@@ -683,13 +368,12 @@ class Pipe(Module):
""" """
if not self._copy_streams: if not self._copy_streams:
assert self.devices is not None
for device in self.devices: for device in self.devices:
self._copy_streams.append([new_stream(device) for _ in range(self.chunks)]) self._copy_streams.append([new_stream(device) for _ in range(self.chunks)])
return self._copy_streams return self._copy_streams
def forward(self, input: TensorOrTensors, *, event=None) -> TensorOrTensors: # type: ignore def forward(self, input: TensorOrTensors) -> TensorOrTensors: # type: ignore
""":class:`Pipe` is a fairly transparent module wrapper. It doesn't """:class:`Pipe` is a fairly transparent module wrapper. It doesn't
modify the input and output signature of the underlying module. But modify the input and output signature of the underlying module. But
there's a type restriction. Input and output have to be a there's a type restriction. Input and output have to be a
...@@ -708,82 +392,16 @@ class Pipe(Module): ...@@ -708,82 +392,16 @@ class Pipe(Module):
""" """
microbatch.check(input) microbatch.check(input)
if not self.group and not self.devices: if not self.devices:
# Empty sequential module is not illegal. # Empty sequential module is not illegal.
return input return input
if not self.pipeline:
# No pipeline is not illegal, more ranks than partitions
return input
# Divide a mini-batch into micro-batches. # Divide a mini-batch into micro-batches.
batches = microbatch.scatter(input, self.chunks) batches = microbatch.scatter(input, self.chunks)
# Run pipeline parallelism. # Run pipeline parallelism.
with self.lock: self.pipeline.run(batches)
self.pipeline.run(self.training, batches, event)
if self.group and not self.final_stage:
# Don't merge micro-batches to avoid unnecessary edges in autograd
# graph
# FIXME(tom) should figure out a proper type here
return batches # type: ignore
else:
# Merge the micro-batches into one mini-batch.
if self.pipelined_backward:
with torch.no_grad():
output = microbatch.gather(batches)
from .phony import get_phony
phony = get_phony(
torch.device(torch.cuda.current_device() if torch.cuda.is_available() else "cpu"),
requires_grad=True,
)
output = PipelinedBackwardPass.apply(output, batches, phony, True) # self.retain_graph)
else:
output = microbatch.gather(batches)
return output
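The scatter/gather above is, in effect, a chunk-and-concatenate over the batch dimension; a stand-alone approximation with plain torch (not the internal microbatch module):

import torch

chunks = 4
mini_batch = torch.randn(16, 8)

micro_batches = torch.chunk(mini_batch, chunks, dim=0)  # "scatter": four (4, 8) tensors
# ... each micro-batch would flow through the pipeline partitions here ...
merged = torch.cat(micro_batches, dim=0)                # "gather": back to (16, 8)
assert merged.shape == mini_batch.shape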
def back_helper(self, output: List[microbatch.Batch]) -> None:
if self.final_stage:
raise ValueError("back_helper should only be called on non-final stages")
if self.pipeline:
self.pipeline.back_helper(list(reversed(output)))
class PipelinedBackwardPass(torch.autograd.Function):
@staticmethod
# type: ignore
def forward(ctx, input: TensorOrTensors, batches, phony, retain_graph) -> TensorOrTensors:
ctx.batches = batches
ctx.retain_graph = retain_graph
return input
@staticmethod
# type: ignore
def backward(ctx, *grads) -> Tuple:
with torch.no_grad():
grad_batches = microbatch.scatter(grads, len(ctx.batches))
for grad, batch in reversed(list(zip(grad_batches, ctx.batches))):
for t in batch:
t.retain_grad()
torch.autograd.backward(batch.tensor_or_tensors, grad_tensors=(*grad,), retain_graph=ctx.retain_graph)
with torch.no_grad():
if ctx.batches[0].atomic:
tensors = tuple(b.tensor.grad for b in ctx.batches)
output: TensorOrTensors = torch.cat(tensors)
else:
rotated = [[t.grad for t in b.tensors] for b in ctx.batches]
output_buf = []
for tensors in zip(*rotated):
output_buf.append(torch.cat(tensors))
output = tuple(output_buf)
del ctx.batches
return (output, None, None, None) # Merge the micro-batches into one mini-batch.
output = microbatch.gather(batches)
return output
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
# #
# This source code is licensed under the BSD license found in the # This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
...@@ -17,101 +18,30 @@ ...@@ -17,101 +18,30 @@
# limitations under the License. # limitations under the License.
"""The pipeline parallelism of Pipe.""" """The pipeline parallelism of Pipe."""
import logging
import os
from queue import Empty as QueueEmpty
from queue import Queue from queue import Queue
from threading import Event
from types import TracebackType from types import TracebackType
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, Type, Union, cast from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple, Type, Union, cast
import torch import torch
from torch import Tensor, nn from torch import Tensor, nn
from torch.autograd.profiler import record_function from torch.autograd.profiler import record_function
from fairscale.nn.model_parallel import get_pipeline_parallel_ranks
from .async_schedule import AsyncEventLoop, ModuleWrapper
from .checkpoint import Checkpointing from .checkpoint import Checkpointing
from .copy import Copy, Wait from .copy import Copy, Wait
from .dependency import fork, join from .dependency import fork, join
from .messages import MakeTransport, Transport
from .microbatch import Batch from .microbatch import Batch
from .skip import Namespace
from .skip.layout import SkipLayout from .skip.layout import SkipLayout
from .skip.tracker import SkipTrackerThroughPotals, use_skip_tracker from .skip.tracker import SkipTrackerThroughPotals, use_skip_tracker
from .stream import AbstractStream, current_stream, use_device from .stream import AbstractStream, current_stream, use_device
from .types import (
ACTIVATIONS_GRADS_QUEUE,
PORTAL_QUEUE,
SKIP_TENSOR_QUEUE,
PipelineStyle,
PipeMessage,
Schedule,
TensorOrTensors,
Tensors,
)
from .worker import Task, create_workers, join_workers from .worker import Task, create_workers, join_workers
__all__: List[str] = [] __all__: List[str] = []
ExcInfo = Tuple[Type[BaseException], BaseException, TracebackType]
class SendOperator(torch.autograd.Function):
"""Send activations to the next pipeline stage"""
@staticmethod
# type: ignore
def forward(ctx, src_rank, dst_rank, transport: Transport, input: List[Tensor], index: int) -> Tensors:
assert src_rank == torch.distributed.get_rank()
transport.send_message(
PipeMessage(src_rank, dst_rank, queue_name=ACTIVATIONS_GRADS_QUEUE, args=index, tensors=tuple(input)),
)
return ()
@staticmethod
# type: ignore
def backward(ctx, *grad: Tensor,) -> Tensors:
return tuple(grad)
Tensors = Tuple[Tensor, ...]
TensorOrTensors = Union[Tensor, Tensors]
class RecvOperator(torch.autograd.Function): ExcInfo = Tuple[Type[BaseException], BaseException, TracebackType]
"""Receive activations to the previous pipeline stage"""
@staticmethod
# type: ignore
def forward(ctx, dst_rank: int, tensor: Tensor, input_device, transport: Transport, index: int) -> Tensors:
assert dst_rank == torch.distributed.get_rank()
ctx.transport = transport
ctx.index = index
result = transport.get_out_of_order(ACTIVATIONS_GRADS_QUEUE, index)
def maybe_requires_grad(t: Tensor) -> Tensor:
if t.dtype.is_floating_point:
return t.requires_grad_()
return t
return tuple(maybe_requires_grad(r) for r in result)
@staticmethod
# type: ignore
def backward(ctx, *grad: Tensor,) -> Tuple[Optional[Tensor], ...]:
ranks = get_pipeline_parallel_ranks()
this_rank = torch.distributed.get_rank()
ctx.transport.send_message(
PipeMessage(
this_rank,
ranks[ranks.index(this_rank) - 1],
queue_name=ACTIVATIONS_GRADS_QUEUE,
args=ctx.index,
tensors=tuple(grad),
),
)
return (None, None, None, None, None)
# Queue is generic only in stubs. # Queue is generic only in stubs.
# https://mypy.readthedocs.io/en/latest/common_issues.html#using-classes-that-are-generic-in-stubs-but-not-at-runtime # https://mypy.readthedocs.io/en/latest/common_issues.html#using-classes-that-are-generic-in-stubs-but-not-at-runtime
...@@ -140,7 +70,7 @@ def wait(batch: Batch, prev_stream: AbstractStream, next_stream: AbstractStream) ...@@ -140,7 +70,7 @@ def wait(batch: Batch, prev_stream: AbstractStream, next_stream: AbstractStream)
batch[:] = tuple([x if x.is_floating_point() else x.detach() for x in batch]) batch[:] = tuple([x if x.is_floating_point() else x.detach() for x in batch])
def clock_cycles(m: int, n: int) -> Iterable[Schedule]: def clock_cycles(m: int, n: int) -> Iterable[List[Tuple[int, int]]]:
"""Generates schedules for each clock cycle.""" """Generates schedules for each clock cycle."""
# m: number of micro-batches # m: number of micro-batches
# n: number of partitions # n: number of partitions
...@@ -159,159 +89,45 @@ def clock_cycles(m: int, n: int) -> Iterable[Schedule]: ...@@ -159,159 +89,45 @@ def clock_cycles(m: int, n: int) -> Iterable[Schedule]:
yield [(k - j, j) for j in range(max(1 + k - m, 0), min(1 + k, n))] yield [(k - j, j) for j in range(max(1 + k - m, 0), min(1 + k, n))]
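For concreteness, the schedule the generator above produces for m=3 micro-batches and n=2 partitions; each yielded list holds the (micro-batch, partition) pairs that run in the same clock cycle:

def clock_cycles(m, n):
    # Same comprehension as above, reproduced here only to show its output.
    for k in range(m + n - 1):
        yield [(k - j, j) for j in range(max(1 + k - m, 0), min(1 + k, n))]

print(list(clock_cycles(3, 2)))
# [[(0, 0)], [(1, 0), (0, 1)], [(2, 0), (1, 1)], [(2, 1)]]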
def create_task(
style: PipelineStyle,
checkpoint_stop: int,
i: int,
j: int,
batch: Batch,
partition: nn.Sequential,
skip_trackers: List[SkipTrackerThroughPotals],
streams: List[AbstractStream],
) -> Task:
# Determine whether checkpointing or not.
if i < checkpoint_stop:
def function(
input: TensorOrTensors,
partition: nn.Sequential = partition,
skip_tracker: SkipTrackerThroughPotals = skip_trackers[i],
chunk_id: int = i,
part_id: int = j,
) -> TensorOrTensors:
with use_skip_tracker(skip_tracker), record_function("chunk%d-part%d" % (chunk_id, part_id)):
ret = partition(input)
# We do a check here because the backtrace from the checkpoint backward code path
# is very hard to make sense of. It would be much easier to check earlier at this point.
assert type(ret) is not list, "Only Tensor or Tuple of Tensor output is supported"
return ret
chk = Checkpointing(function, batch)
if style is PipelineStyle.SingleProcess:
task = Task(streams[j], compute=chk.checkpoint, finalize=chk.recompute)
elif style in [PipelineStyle.MultiProcess, PipelineStyle.AsyncSchedule]:
task = Task(None, compute=chk.checkpoint, finalize=chk.recompute)
del function, chk # TODO(tom) maybe remove
else:
def compute(
batch: Batch = batch,
partition: nn.Sequential = partition,
skip_tracker: SkipTrackerThroughPotals = skip_trackers[i],
chunk_id: int = i,
part_id: int = j,
) -> Batch:
with use_skip_tracker(skip_tracker), record_function("chunk%d-part%d" % (chunk_id, part_id)):
return batch.call(partition)
if style is PipelineStyle.SingleProcess:
task = Task(streams[j], compute=compute, finalize=None)
elif style in [PipelineStyle.MultiProcess, PipelineStyle.AsyncSchedule]:
task = Task(None, compute=compute, finalize=None)
del compute # TODO(tom) maybe remove
return task
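The `i < checkpoint_stop` test used above is what turns the `checkpoint` string into per-micro-batch behaviour; a small worked example of the mapping from `Pipe.__init__` (illustrative only):

chunks = 4
for mode in ("always", "except_last", "never"):
    checkpoint_stop = {"always": chunks, "except_last": chunks - 1, "never": 0}[mode]
    checkpointed = [i for i in range(chunks) if i < checkpoint_stop]
    print(mode, checkpoint_stop, checkpointed)
# always 4 [0, 1, 2, 3]
# except_last 3 [0, 1, 2]
# never 0 []
# (In eval mode checkpoint_stop is forced to 0, so nothing is checkpointed.)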
class Pipeline: class Pipeline:
"""The pipeline parallelism for Pipe.""" """The pipeline parallelism for Pipe."""
def __init__( def __init__(
self, self,
partitions: List[nn.Sequential], partitions: List[nn.Sequential],
devices: Optional[List[torch.device]], devices: List[torch.device],
copy_streams: Optional[List[List[AbstractStream]]], copy_streams: List[List[AbstractStream]],
skip_layout: SkipLayout, skip_layout: SkipLayout,
checkpoint_stop: int, checkpoint_stop: int,
style: PipelineStyle,
group: Optional[torch.distributed.ProcessGroup] = None,
worker_map: Optional[Dict[int, str]] = None,
input_device: Union[None, int, str, torch.device] = None,
final_stage: bool = False,
) -> None: ) -> None:
if style == PipelineStyle.SingleProcess: self.partitions = partitions
self.partitions = partitions
else:
self.mp_partitions: List[ModuleWrapper] = cast(List[ModuleWrapper], partitions)
self.devices = devices self.devices = devices
self.copy_streams = copy_streams self.copy_streams = copy_streams
self.skip_layout = skip_layout self.skip_layout = skip_layout
self.__checkpoint_stop = checkpoint_stop self.checkpoint_stop = checkpoint_stop
self.style = style (self.in_queues, self.out_queues) = create_workers(devices)
self.group = group
self.training: bool
if style in [PipelineStyle.MultiProcess, PipelineStyle.AsyncSchedule]:
self.transport = MakeTransport(
use_rpc=("OMPI_COMM_WORLD_RANK" not in os.environ) or ("FORCE_RPC" in os.environ),
worker_map=worker_map,
input_device=input_device,
)
self.input_device = input_device
self.all_at_once = False
self.callcount = 0
self.final_stage = final_stage
if self.style is PipelineStyle.SingleProcess:
assert self.devices is not None
(self.in_queues, self.out_queues) = create_workers(self.devices)
@property
def checkpoint_stop(self) -> int:
# Disable checkpointing if in eval mode.
if self.style == PipelineStyle.SingleProcess:
training = self.partitions[0].training
else:
training = self.mp_partitions[0].module.training
if not training:
return 0
return self.__checkpoint_stop
def __del__(self) -> None: def __del__(self) -> None:
if self.style is PipelineStyle.SingleProcess: join_workers(self.in_queues, self.out_queues)
join_workers(self.in_queues, self.out_queues)
def run(self, training: bool, batches: List[Batch], event: Optional[Event]) -> None:
def run(self, batches: List[Batch]) -> None:
"""Runs pipeline parallelism. """Runs pipeline parallelism.
It modifies the given batches in place. It modifies the given batches in place.
""" """
self.training = training partitions = self.partitions
devices = self.devices
skip_layout = self.skip_layout
m = len(batches) m = len(batches)
n = len(partitions)
skip_trackers = [SkipTrackerThroughPotals(self.skip_layout, i) for i in range(len(batches))] skip_trackers = [SkipTrackerThroughPotals(skip_layout, i) for i in range(m)]
if self.style is PipelineStyle.SingleProcess: for schedule in clock_cycles(m, n):
n = len(self.partitions) self.fence(batches, schedule, skip_trackers)
for schedule in clock_cycles(m, n):
self.fence(batches, schedule, skip_trackers)
self.compute(batches, schedule, skip_trackers)
elif self.style is PipelineStyle.MultiProcess:
assert self.group
schedule = [(i, self.group.rank()) for i in range(m)]
self.compute(batches, schedule, skip_trackers) self.compute(batches, schedule, skip_trackers)
elif self.style is PipelineStyle.AsyncSchedule:
assert self.group
rank = self.group.rank()
event_loop = AsyncEventLoop(
self.mp_partitions, self.group, self.transport, self.training, self.checkpoint_stop,
)
if rank == 0 and not self.final_stage:
logging.debug(f"{torch.distributed.get_rank()}: entered event head")
event_loop.event_loop_head(batches, skip_trackers, event)
logging.debug(f"{torch.distributed.get_rank()}: exited event head")
elif self.final_stage:
logging.debug(f"{torch.distributed.get_rank()}: entered event tail")
event_loop.event_loop_tail(batches, skip_trackers)
logging.debug(f"{torch.distributed.get_rank()}: exited event tail")
else:
logging.debug(f"{torch.distributed.get_rank()}: entered event loop")
event_loop.event_loop(len(batches), skip_trackers)
logging.debug(f"{torch.distributed.get_rank()}: exited event loop")
self.callcount += 1
def fence( def fence(
self, batches: List[Batch], schedule: List[Tuple[int, int]], skip_trackers: List[SkipTrackerThroughPotals], self, batches: List[Batch], schedule: List[Tuple[int, int]], skip_trackers: List[SkipTrackerThroughPotals],
...@@ -322,9 +138,6 @@ class Pipeline: ...@@ -322,9 +138,6 @@ class Pipeline:
copy_streams = self.copy_streams copy_streams = self.copy_streams
skip_layout = self.skip_layout skip_layout = self.skip_layout
assert copy_streams
assert skip_layout
for i, j in schedule: for i, j in schedule:
# Ensure that batches[i-1] is executed after batches[i] in # Ensure that batches[i-1] is executed after batches[i] in
# backpropagation by an explicit dependency. # backpropagation by an explicit dependency.
...@@ -341,91 +154,92 @@ class Pipeline: ...@@ -341,91 +154,92 @@ class Pipeline:
prev_stream = copy_streams[j - 1][i] prev_stream = copy_streams[j - 1][i]
copy(batches[i], prev_stream, next_stream) copy(batches[i], prev_stream, next_stream)
def get_batch_from_previous_stage( def compute(
self, i: int, skip_trackers: List[SkipTrackerThroughPotals], batches: List[Batch] self, batches: List[Batch], schedule: List[Tuple[int, int]], skip_trackers: List[SkipTrackerThroughPotals],
) -> Batch:
phony = torch.empty(0, device=self.input_device, requires_grad=True)
result = RecvOperator.apply(torch.distributed.get_rank(), phony, self.input_device, self.transport, i)
if len(result) == 1:
batch = Batch(result[0], i)
else:
batch = Batch(result, i)
self.recv_skip_tensors(skip_trackers, batches)
return batch
def send_skip_tensors(
self, this_rank: int, ranks: List[int], batch: Batch, i: int, skip_trackers: List[SkipTrackerThroughPotals]
) -> None: ) -> None:
assert self.group """Runs tasks with synchronization to copy streams."""
for next_j, ns, name in self.skip_layout.copy_policy_by_src(self.group.rank()): partitions = self.partitions
life = skip_trackers[i].portals[(ns, name)].tensor_life devices = self.devices
loaded = skip_trackers[i].load(batch, ns, name) copy_streams = self.copy_streams
if loaded is not None: checkpoint_stop = self.checkpoint_stop
tensors = tuple([loaded])
else:
tensors = tuple()
self.transport.send_message(
PipeMessage(
this_rank, ranks[next_j], queue_name=SKIP_TENSOR_QUEUE, args=(i, ns, name, life), tensors=tensors,
),
sync=True,
)
def recv_skip_tensors(self, skip_trackers: List[SkipTrackerThroughPotals], batches: List[Batch]) -> None:
while True:
try:
message = self.transport.recv_message(SKIP_TENSOR_QUEUE, nowait=True)
(si, ns, name, life) = message.args
value: Optional[TensorOrTensors] = message.tensors
assert isinstance(value, tuple)
if len(value) == 0:
value = None
else:
assert len(value) == 1
value = value[0]
skip_trackers[si].save(batches[si], ns, name, value) # Disable checkpointing if in eval mode.
old_life = skip_trackers[si].portals[(ns, name)].tensor_life if not self.partitions[0].training:
if life != 0: checkpoint_stop = 0
skip_trackers[si].portals[(ns, name)].tensor_life = life
except QueueEmpty:
break
def execute_task(self, task: Task, i: int, skip_trackers: List[SkipTrackerThroughPotals]) -> Batch: n = len(partitions)
batch = task.compute() streams = [current_stream(d) for d in devices]
exc_info: Optional[ExcInfo] = None
assert self.group # With checkpointing, the autograd graph looks like this diagram:
rank = self.group.rank() # ┌─────┸──────┐
# │ Copy │
# └─────┰──────┘ (fence)
# ─ ─ ─ ╂ ─ ─ ─ ─ ─ ─ ─ ─ ─
# ┃ (compute)
# ┌─────┸──────┐
# │ Wait │ [1] Synchronize the current stream with the copy stream.
# └─────┰──────┘
# ┌─────┸──────┐
# │ Checkpoint │ [2] Compute a partition within checkpointing.
# └─────┰──────┘
# ┌─────┸──────┐
# │ Wait │ [3] Synchronize the copy stream with the current stream.
# └─────┰──────┘
# ┠ ─ ─ ─ ┐
# ┃ ┌─────┴─────┐
# ┃ │ Recompute │ [4] Schedule the recomputation at backpropagation.
# ┃ └─────┬─────┘
# ┠ ─ ─ ─ ┘
# ┃
# ─ ─ ─ ╂ ─ ─ ─ ─ ─ ─ ─ ─ ─
# ┌─────┸──────┐ (fence)
# │ Copy │
# └─────┰──────┘
for i, j in schedule:
batch = batches[i]
partition = partitions[j]
if self.style is PipelineStyle.MultiProcess and not self.final_stage: # Synchronize with the copied input. ([1] in the diagram)
ranks = get_pipeline_parallel_ranks() if j != 0:
this_rank = torch.distributed.get_rank() wait(batch, copy_streams[j][i], streams[j])
# Determine whether checkpointing or not.
checkpoint = i < checkpoint_stop
if checkpoint:
def function(
input: TensorOrTensors,
partition: nn.Sequential = partition,
skip_tracker: SkipTrackerThroughPotals = skip_trackers[i],
chunk_id: int = i,
part_id: int = j,
) -> TensorOrTensors:
with use_skip_tracker(skip_tracker), record_function("chunk%d-part%d" % (chunk_id, part_id)):
return partition(input)
chk = Checkpointing(function, batch)
task = Task(streams[j], compute=chk.checkpoint, finalize=chk.recompute)
del function, chk
self.send_skip_tensors(this_rank, ranks, batch, i, skip_trackers) else:
SendOperator.apply(this_rank, ranks[ranks.index(this_rank) + 1], self.transport, [*batch], i)
for portal in skip_trackers[i].portals.values(): def compute(
portal.pipeline = self batch: Batch = batch,
partition: nn.Sequential = partition,
skip_tracker: SkipTrackerThroughPotals = skip_trackers[i],
chunk_id: int = i,
part_id: int = j,
) -> Batch:
with use_skip_tracker(skip_tracker), record_function("chunk%d-part%d" % (chunk_id, part_id)):
return batch.call(partition)
task.finalize(batch) task = Task(streams[j], compute=compute, finalize=None)
del compute
return batch # Compute tasks in parallel. ([2] in the diagram)
self.in_queues[j].put(task)
def finalize_tasks(
self,
n: int,
schedule: Schedule,
streams: List[AbstractStream],
copy_streams: List[List[AbstractStream]],
batches: List[Batch],
) -> None:
exc_info: Optional[ExcInfo] = None
for i, j in schedule: for i, j in schedule:
ok, payload = self.out_queues[j].get() ok, payload = self.out_queues[j].get()
...@@ -446,8 +260,7 @@ class Pipeline: ...@@ -446,8 +260,7 @@ class Pipeline:
# Finalize tasks. If checkpointing is enabled, here the # Finalize tasks. If checkpointing is enabled, here the
# recomputation is scheduled at backpropagation. ([4] in the # recomputation is scheduled at backpropagation. ([4] in the
# diagram) # diagram)
assert self.devices with use_device(devices[j]):
with use_device(self.devices[j]):
task.finalize(batch) task.finalize(batch)
batches[i] = batch batches[i] = batch
...@@ -455,147 +268,3 @@ class Pipeline: ...@@ -455,147 +268,3 @@ class Pipeline:
# Fail at the first exception. # Fail at the first exception.
if exc_info is not None: if exc_info is not None:
raise exc_info[0].with_traceback(exc_info[1], exc_info[2]) raise exc_info[0].with_traceback(exc_info[1], exc_info[2])
def compute(
self, batches: List[Batch], schedule: List[Tuple[int, int]], skip_trackers: List[SkipTrackerThroughPotals]
) -> None:
"""Runs tasks with synchronization to copy streams."""
devices = self.devices
copy_streams = self.copy_streams
if self.style is PipelineStyle.SingleProcess:
assert devices is not None
n = len(self.partitions)
streams = [current_stream(d) for d in devices]
elif self.style is PipelineStyle.MultiProcess:
assert self.group
n = self.group.size()
streams = []
# With checkpointing, the autograd graph looks like this diagram:
# ┌─────┸──────┐
# │ Copy │
# └─────┰──────┘ (fence)
# ─ ─ ─ ╂ ─ ─ ─ ─ ─ ─ ─ ─ ─
# ┃ (compute)
# ┌─────┸──────┐
# │ Wait │ [1] Synchronize the current stream with the copy stream.
# └─────┰──────┘
# ┌─────┸──────┐
# │ Checkpoint │ [2] Compute a partition within checkpointing.
# └─────┰──────┘
# ┌─────┸──────┐
# │ Wait │ [3] Synchronize the copy stream with the current stream.
# └─────┰──────┘
# ┠ ─ ─ ─ ┐
# ┃ ┌─────┴─────┐
# ┃ │ Recompute │ [4] Schedule the recomputation at backpropagation.
# ┃ └─────┬─────┘
# ┠ ─ ─ ─ ┘
# ┃
# ─ ─ ─ ╂ ─ ─ ─ ─ ─ ─ ─ ─ ─
# ┌─────┸──────┐ (fence)
# │ Copy │
# └─────┰──────┘
for i, j in schedule:
batch = batches[i]
if self.style is PipelineStyle.SingleProcess:
partition = self.partitions[j]
# Synchronize with the copied input. ([1] in the diagram)
assert copy_streams
if j != 0:
wait(batch, copy_streams[j][i], streams[j])
task = create_task(self.style, self.checkpoint_stop, i, j, batch, partition, skip_trackers, streams)
# Compute tasks in parallel. ([2] in the diagram)
self.in_queues[j].put(task)
elif self.style is PipelineStyle.MultiProcess:
assert len(self.mp_partitions) == 1
mp_partition = self.mp_partitions[0]
assert self.group
if self.group.rank() != 0:
batch = self.get_batch_from_previous_stage(i, skip_trackers, batches)
task = create_task(
self.style, self.checkpoint_stop, i, j, batch, mp_partition.module, skip_trackers, streams
)
batches[i] = self.execute_task(task, i, skip_trackers)
if self.style is PipelineStyle.SingleProcess:
assert copy_streams
self.finalize_tasks(n, schedule, streams, copy_streams, batches)
def send_portal_grad(self, ns_name: Tuple[Namespace, str], index: int, grad: TensorOrTensors) -> None:
dest, src = self.skip_layout.by_ns_name.get(ns_name, (-1, -1))
if dest == src:
return
ranks = get_pipeline_parallel_ranks()
dst_rank = ranks[dest]
if dst_rank == torch.distributed.get_rank():
return
if isinstance(grad, Tensor):
grad = tuple([grad])
self.transport.send_message(
PipeMessage(ranks[src], dst_rank, queue_name=PORTAL_QUEUE, args=(ns_name, index), tensors=grad), sync=True,
)
def recv_portal_grad(self, expected_ns_name: Tuple[Namespace, str], expected_index: int) -> Tensor:
message = self.transport.recv_message(PORTAL_QUEUE)
(ns_name, index) = message.args
grad = message.tensors
assert len(grad) == 1
result = grad[0]
assert index == expected_index and ns_name == expected_ns_name
return result
def back_helper(self, output: List[Batch]) -> None:
if self.style == PipelineStyle.AsyncSchedule:
return
o = list(output)
tensors: Tensors
if self.all_at_once:
# FIXME(tom) allow specifying this branch when constructing Pipe(), add a test
grads = []
for i, batch in enumerate(o):
rank = torch.distributed.get_rank()
found = self.transport.get_out_of_order(ACTIVATIONS_GRADS_QUEUE, i)
assert len(found) == 1
grads.append(found[0])
tensors = tuple(x.tensor_or_tensors for x in o) # type: ignore
try:
torch.autograd.backward(tensors, grad_tensors=grads, retain_graph=True)
except Exception as e:
raise RuntimeError("Autograd failed") from e
else:
rank = torch.distributed.get_rank()
for batch in o:
found = self.transport.get_out_of_order(ACTIVATIONS_GRADS_QUEUE, batch.index)
if batch.atomic:
tensors = tuple([batch.tensor])
else:
tensors = batch.tensors
if len(found) != len(tensors):
raise RuntimeError("different number of tensors and gradients")
grads = []
final_tensors = []
for i, tensor in enumerate(tensors):
if tensor.requires_grad or getattr(tensor, "grad_fn", None) is not None:
grads.append(found[i])
final_tensors.append(tensor)
try:
torch.autograd.backward(final_tensors, grad_tensors=grads, retain_graph=True)
except Exception as e:
raise RuntimeError(f"Autograd failed on {torch.distributed.get_rank()}") from e
...@@ -13,13 +13,13 @@ from torch.distributed.distributed_c10d import _get_global_rank ...@@ -13,13 +13,13 @@ from torch.distributed.distributed_c10d import _get_global_rank
from fairscale.nn.model_parallel.initialize import get_pipeline_parallel_group from fairscale.nn.model_parallel.initialize import get_pipeline_parallel_group
from . import Pipe from .multiprocess_pipe import MultiProcessPipe
from .types import EVENT_LOOP_QUEUE, PipeMessage, TensorOrTensors from .types import EVENT_LOOP_QUEUE, PipeMessage, TensorOrTensors
DEFAULT_MAX_SOURCE_POSITIONS = 1024 DEFAULT_MAX_SOURCE_POSITIONS = 1024
DEFAULT_MAX_TARGET_POSITIONS = 1024 DEFAULT_MAX_TARGET_POSITIONS = 1024
PipeModel: Pipe PipeModel: MultiProcessPipe
PipeResult: TensorOrTensors PipeResult: TensorOrTensors
...@@ -71,7 +71,7 @@ class PipeBackRedirect(torch.autograd.Function): ...@@ -71,7 +71,7 @@ class PipeBackRedirect(torch.autograd.Function):
return (None, None, None, None, None, None) return (None, None, None, None, None, None)
def callback_with_model(callback: Callable[[Any, Pipe], None], ctx: Any) -> None: def callback_with_model(callback: Callable[[Any, MultiProcessPipe], None], ctx: Any) -> None:
try: try:
group = get_pipeline_parallel_group() # FIXME(tom) handle dynamic group group = get_pipeline_parallel_group() # FIXME(tom) handle dynamic group
set_device_based_on_group(group) set_device_based_on_group(group)
...@@ -105,10 +105,10 @@ class PipeRPCWrapper(nn.Module): ...@@ -105,10 +105,10 @@ class PipeRPCWrapper(nn.Module):
else: else:
kwargs["group"] = self.group kwargs["group"] = self.group
kwargs["style"] = Pipe.AsyncSchedule kwargs["style"] = MultiProcessPipe.AsyncSchedule
kwargs["input_device"] = torch.device("cuda", torch.cuda.current_device()) kwargs["input_device"] = torch.device("cuda", torch.cuda.current_device())
self.model = Pipe(*args, **kwargs) self.model = MultiProcessPipe(*args, **kwargs)
self.worker_map = kwargs["worker_map"] self.worker_map = kwargs["worker_map"]
self._foreach_worker(self._register_remote_model, args=(args, kwargs)) self._foreach_worker(self._register_remote_model, args=(args, kwargs))
self.model.cuda() self.model.cuda()
...@@ -121,7 +121,7 @@ class PipeRPCWrapper(nn.Module): ...@@ -121,7 +121,7 @@ class PipeRPCWrapper(nn.Module):
futures = [f.wait() for f in futures] futures = [f.wait() for f in futures]
def foreach_worker( def foreach_worker(
self, callback: Callable[[Any, Pipe], None], ctx: Any = None, *, include_self: bool = False self, callback: Callable[[Any, MultiProcessPipe], None], ctx: Any = None, *, include_self: bool = False
) -> None: ) -> None:
"""Call `callback` on each worker with the `ctx` and model local to that """Call `callback` on each worker with the `ctx` and model local to that
worker. e.g. worker. e.g.
...@@ -196,7 +196,9 @@ class PipeRPCWrapper(nn.Module): ...@@ -196,7 +196,9 @@ class PipeRPCWrapper(nn.Module):
return self.model.final_stage return self.model.final_stage
@staticmethod @staticmethod
def _recv_result(model: Pipe, shapes: SizeOrSizes, dtypes: DtypeOrDtypes, message: PipeMessage) -> TensorOrTensors: def _recv_result(
model: MultiProcessPipe, shapes: SizeOrSizes, dtypes: DtypeOrDtypes, message: PipeMessage
) -> TensorOrTensors:
group = get_pipeline_parallel_group() group = get_pipeline_parallel_group()
set_device_based_on_group(group) set_device_based_on_group(group)
...@@ -243,7 +245,7 @@ class PipeRPCWrapper(nn.Module): ...@@ -243,7 +245,7 @@ class PipeRPCWrapper(nn.Module):
set_device_based_on_group(group) set_device_based_on_group(group)
kwargs["group"] = group kwargs["group"] = group
kwargs["input_device"] = torch.device("cuda", torch.cuda.current_device()) kwargs["input_device"] = torch.device("cuda", torch.cuda.current_device())
model = Pipe(*args, **kwargs) model = MultiProcessPipe(*args, **kwargs)
model.cuda() model.cuda()
global PipeModel global PipeModel
PipeModel = model PipeModel = model
......
...@@ -24,7 +24,6 @@ Tensors = Tuple[Tensor, ...] ...@@ -24,7 +24,6 @@ Tensors = Tuple[Tensor, ...]
TensorOrTensors = Union[Tensor, Tensors] TensorOrTensors = Union[Tensor, Tensors]
InputDevice = Union[None, int, str, torch.device] InputDevice = Union[None, int, str, torch.device]
Schedule = List[Tuple[int, int]]
class LazyModule: class LazyModule:
......
...@@ -5,7 +5,7 @@ from typing import Any, Callable, Optional, Tuple ...@@ -5,7 +5,7 @@ from typing import Any, Callable, Optional, Tuple
from torch import Tensor from torch import Tensor
def spawn( def spawn(
fn: Callable[[Any], Any], fn: Callable[..., Any],
args: Tuple[Optional[Any], ...] = (), args: Tuple[Optional[Any], ...] = (),
nprocs: int = 1, nprocs: int = 1,
join: bool = True, join: bool = True,
......
...@@ -31,7 +31,7 @@ from torch.nn.parameter import Parameter ...@@ -31,7 +31,7 @@ from torch.nn.parameter import Parameter
from fairscale.nn.model_parallel import initialize as mpu from fairscale.nn.model_parallel import initialize as mpu
from fairscale.nn.model_parallel import layers from fairscale.nn.model_parallel import layers
from fairscale.nn.pipe import Pipe from fairscale.nn.pipe import MultiProcessPipe
from fairscale.utils.testing import dist_init, get_world_sizes, set_random_seed, spawn_for_all_world_sizes, torch_spawn from fairscale.utils.testing import dist_init, get_world_sizes, set_random_seed, spawn_for_all_world_sizes, torch_spawn
...@@ -319,7 +319,7 @@ def run_test_pipe(rank, world_size, filename, filename_rpc, skip_dist_init=False ...@@ -319,7 +319,7 @@ def run_test_pipe(rank, world_size, filename, filename_rpc, skip_dist_init=False
model_parallel_size = mpu.get_model_parallel_world_size() model_parallel_size = mpu.get_model_parallel_world_size()
if torch.distributed.get_rank() == 0: if torch.distributed.get_rank() == 0:
print( print(
"> testing Sequential + Pipe with model parallel size: {}, pipe: {}".format( "> testing Sequential + MultiProcessPipe with model parallel size: {}, pipe: {}".format(
model_parallel_size, pipe_world_size model_parallel_size, pipe_world_size
) )
) )
...@@ -431,13 +431,13 @@ def run_test_pipe(rank, world_size, filename, filename_rpc, skip_dist_init=False ...@@ -431,13 +431,13 @@ def run_test_pipe(rank, world_size, filename, filename_rpc, skip_dist_init=False
model[2].weight.data = saved_weight_2 model[2].weight.data = saved_weight_2
worker_map = {i: f"Test{i}" for i in range(torch.distributed.get_world_size())} worker_map = {i: f"Test{i}" for i in range(torch.distributed.get_world_size())}
style = Pipe.MultiProcess # Pipe.AsyncSchedule style = MultiProcessPipe.MultiProcess # MultiProcessPipe.AsyncSchedule
if pipe_world_size == 2: if pipe_world_size == 2:
print(f"actually doing pipe stuff now") print(f"actually doing pipe stuff now")
assert torch.equal(saved_weight_0, model[0].weight.data) assert torch.equal(saved_weight_0, model[0].weight.data)
assert torch.equal(saved_weight_2, model[2].weight.data) assert torch.equal(saved_weight_2, model[2].weight.data)
pipe_model = Pipe( pipe_model = MultiProcessPipe(
model, model,
[2, 1], [2, 1],
style=style, style=style,
...@@ -507,7 +507,7 @@ def run_test_pipe(rank, world_size, filename, filename_rpc, skip_dist_init=False ...@@ -507,7 +507,7 @@ def run_test_pipe(rank, world_size, filename, filename_rpc, skip_dist_init=False
failed = False failed = False
with torch.autograd.profiler.profile() as prof: with torch.autograd.profiler.profile() as prof:
try: try:
if style == Pipe.MultiProcess: if style == MultiProcessPipe.MultiProcess:
pipe_model.back_helper(pipe_output) pipe_model.back_helper(pipe_output)
except Exception as e: except Exception as e:
failed = True failed = True
......
...@@ -23,7 +23,7 @@ import pytest ...@@ -23,7 +23,7 @@ import pytest
import torch import torch
from torch import nn from torch import nn
from fairscale.nn.pipe import LazyModule, Pipe from fairscale.nn.pipe import LazyModule, MultiProcessPipe
from fairscale.nn.pipe.skip import pop, skippable, stash from fairscale.nn.pipe.skip import pop, skippable, stash
from fairscale.nn.pipe.skip.portal import PortalBlue, PortalCopy, PortalOrange from fairscale.nn.pipe.skip.portal import PortalBlue, PortalCopy, PortalOrange
from fairscale.utils.testing import get_worker_map, torch_spawn from fairscale.utils.testing import get_worker_map, torch_spawn
...@@ -33,12 +33,12 @@ from fairscale.utils.testing import get_worker_map, torch_spawn ...@@ -33,12 +33,12 @@ from fairscale.utils.testing import get_worker_map, torch_spawn
@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required") @pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
@pytest.mark.parametrize("balance", [[3], [1, 2], [2, 1], [1, 1, 1]], ids=["3", "1:2", "2:1", "1:1:1"]) @pytest.mark.parametrize("balance", [[3], [1, 2], [2, 1], [1, 1, 1]], ids=["3", "1:2", "2:1", "1:1:1"])
@pytest.mark.parametrize("checkpoint", ["never", "always", "except_last"]) @pytest.mark.parametrize("checkpoint", ["never", "always", "except_last"])
@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule]) @pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
@pytest.mark.skipif("OMPI_COMM_WORLD_RANK" in os.environ, reason="broken on mpi") @pytest.mark.skipif("OMPI_COMM_WORLD_RANK" in os.environ, reason="broken on mpi")
def x1to3(balance, checkpoint, pipeline_style): def x1to3(balance, checkpoint, pipeline_style):
torch.manual_seed(0) torch.manual_seed(0)
if pipeline_style == Pipe.AsyncSchedule and len(balance) > 1: if pipeline_style == MultiProcessPipe.AsyncSchedule and len(balance) > 1:
print(f"skipping yarg") print(f"skipping yarg")
pytest.skip("Skip tensors NYI for AsyncSchedule") pytest.skip("Skip tensors NYI for AsyncSchedule")
...@@ -74,7 +74,7 @@ def x1to3(balance, checkpoint, pipeline_style): ...@@ -74,7 +74,7 @@ def x1to3(balance, checkpoint, pipeline_style):
return output return output
model = nn.Sequential(Layer1(), Layer2(), Layer3()) model = nn.Sequential(Layer1(), Layer2(), Layer3())
model = Pipe( model = MultiProcessPipe(
model, model,
balance, balance,
chunks=3, chunks=3,
...@@ -106,9 +106,10 @@ def x1to3(balance, checkpoint, pipeline_style): ...@@ -106,9 +106,10 @@ def x1to3(balance, checkpoint, pipeline_style):
@torch_spawn([2]) @torch_spawn([2])
@pytest.mark.skipif("OMPI_COMM_WORLD_RANK" in os.environ, reason="broken on mpi") @pytest.mark.skipif("OMPI_COMM_WORLD_RANK" in os.environ, reason="broken on mpi")
@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required") @pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule]) @pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
@pytest.mark.skip(reason="flaky test")
def none_skip(pipeline_style): def none_skip(pipeline_style):
if pipeline_style == Pipe.AsyncSchedule: if pipeline_style == MultiProcessPipe.AsyncSchedule:
pytest.skip("Skip tensors NYI for AsyncSchedule") pytest.skip("Skip tensors NYI for AsyncSchedule")
@skippable(stash=["none"]) @skippable(stash=["none"])
...@@ -125,7 +126,7 @@ def none_skip(pipeline_style): ...@@ -125,7 +126,7 @@ def none_skip(pipeline_style):
return input return input
model = nn.Sequential(Stash(), Pop()) model = nn.Sequential(Stash(), Pop())
model = Pipe( model = MultiProcessPipe(
model, model,
[1, 1], [1, 1],
style=pipeline_style, style=pipeline_style,
...@@ -160,7 +161,7 @@ def none_skip(pipeline_style): ...@@ -160,7 +161,7 @@ def none_skip(pipeline_style):
@torch_spawn([2]) @torch_spawn([2])
@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule]) @pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
def lazy_skippable_error(pipeline_style): def lazy_skippable_error(pipeline_style):
"""Using skippable layers in combination with lazy construction is currently """Using skippable layers in combination with lazy construction is currently
not supported, check that it raises an Exception""" not supported, check that it raises an Exception"""
...@@ -180,6 +181,6 @@ def lazy_skippable_error(pipeline_style): ...@@ -180,6 +181,6 @@ def lazy_skippable_error(pipeline_style):
] ]
with pytest.raises(ValueError, match="Can't use Skippable layers with multi-process pipe and lazy construction"): with pytest.raises(ValueError, match="Can't use Skippable layers with multi-process pipe and lazy construction"):
Pipe( MultiProcessPipe(
model, [2, 1], style=pipeline_style, worker_map=get_worker_map(), model, [2, 1], style=pipeline_style, worker_map=get_worker_map(),
) )
...@@ -23,7 +23,7 @@ import pytest ...@@ -23,7 +23,7 @@ import pytest
import torch import torch
from torch import nn from torch import nn
from fairscale.nn.pipe import Pipe, is_checkpointing, is_recomputing from fairscale.nn.pipe import MultiProcessPipe, is_checkpointing, is_recomputing
from fairscale.nn.pipe.skip import pop, skippable, stash from fairscale.nn.pipe.skip import pop, skippable, stash
from fairscale.nn.pipe.skip.tracker import current_skip_tracker from fairscale.nn.pipe.skip.tracker import current_skip_tracker
from fairscale.utils.testing import get_worker_map, torch_spawn from fairscale.utils.testing import get_worker_map, torch_spawn
...@@ -46,7 +46,7 @@ class Pop(nn.Module): ...@@ -46,7 +46,7 @@ class Pop(nn.Module):
@torch_spawn([2]) @torch_spawn([2])
@pytest.mark.parametrize("train", [True, False], ids=["train", "eval"]) @pytest.mark.parametrize("train", [True, False], ids=["train", "eval"])
@pytest.mark.parametrize("checkpoint", ["always", "except_last", "never"]) @pytest.mark.parametrize("checkpoint", ["always", "except_last", "never"])
@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule]) @pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
@pytest.mark.skipif("OMPI_COMM_WORLD_RANK" in os.environ, reason="broken on mpi") @pytest.mark.skipif("OMPI_COMM_WORLD_RANK" in os.environ, reason="broken on mpi")
@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required") @pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
def delete_portal_tensor(train, checkpoint, pipeline_style): def delete_portal_tensor(train, checkpoint, pipeline_style):
...@@ -60,7 +60,7 @@ def delete_portal_tensor(train, checkpoint, pipeline_style): ...@@ -60,7 +60,7 @@ def delete_portal_tensor(train, checkpoint, pipeline_style):
# | 3,blue,2 |--| 2,orange,1 |--| 1,orange,0 |--| 1,blue,0 | # | 3,blue,2 |--| 2,orange,1 |--| 1,orange,0 |--| 1,blue,0 |
# +----------+ +------------+ +------------+ +----------+ # +----------+ +------------+ +------------+ +----------+
if pipeline_style == Pipe.AsyncSchedule: if pipeline_style == MultiProcessPipe.AsyncSchedule:
pytest.skip("Skip tensors NYI for AsyncSchedule") pytest.skip("Skip tensors NYI for AsyncSchedule")
def portal_tensor_life_is(tensor_life, skip_tracker=None): def portal_tensor_life_is(tensor_life, skip_tracker=None):
...@@ -114,7 +114,7 @@ def delete_portal_tensor(train, checkpoint, pipeline_style): ...@@ -114,7 +114,7 @@ def delete_portal_tensor(train, checkpoint, pipeline_style):
return self.F.apply(input) return self.F.apply(input)
model = nn.Sequential(NoPortalTensorAtBackward(), stash_, pop_) model = nn.Sequential(NoPortalTensorAtBackward(), stash_, pop_)
model = Pipe( model = MultiProcessPipe(
model, balance=[2, 1], style=pipeline_style, worker_map=get_worker_map(), chunks=2, checkpoint=checkpoint, model, balance=[2, 1], style=pipeline_style, worker_map=get_worker_map(), chunks=2, checkpoint=checkpoint,
) )
......
...@@ -22,15 +22,15 @@ import torch ...@@ -22,15 +22,15 @@ import torch
from torch import nn from torch import nn
import torch.nn.functional as F import torch.nn.functional as F
from fairscale.nn.pipe import Pipe from fairscale.nn.pipe import MultiProcessPipe
from fairscale.utils.testing import get_worker_map, torch_spawn from fairscale.utils.testing import get_worker_map, torch_spawn
@torch_spawn([2]) @torch_spawn([2])
@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required") @pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule]) @pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
def python_autograd_function(pipeline_style): def python_autograd_function(pipeline_style):
# FIXME deadlock with Pipe.AsyncSchedule? # FIXME deadlock with MultiProcessPipe.AsyncSchedule?
# A Python autograd function might fail with this error: # A Python autograd function might fail with this error:
# #
# RuntimeError: Returning Variables sharing storage with other Variables # RuntimeError: Returning Variables sharing storage with other Variables
...@@ -57,7 +57,9 @@ def python_autograd_function(pipeline_style): ...@@ -57,7 +57,9 @@ def python_autograd_function(pipeline_style):
return Identity.apply(input) return Identity.apply(input)
model = nn.Sequential(M(), M()) model = nn.Sequential(M(), M())
model = Pipe(model, [1, 1], style=pipeline_style, worker_map=get_worker_map(), checkpoint="always").cuda() model = MultiProcessPipe(
model, [1, 1], style=pipeline_style, worker_map=get_worker_map(), checkpoint="always"
).cuda()
model.eval() model.eval()
x = torch.rand(42) x = torch.rand(42)
...@@ -71,7 +73,7 @@ def python_autograd_function(pipeline_style): ...@@ -71,7 +73,7 @@ def python_autograd_function(pipeline_style):
@torch_spawn([3]) @torch_spawn([3])
@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required") @pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule]) @pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
def exception_no_hang(pipeline_style): def exception_no_hang(pipeline_style):
# In v0.0.2, once a failed partition receives a normal message # In v0.0.2, once a failed partition receives a normal message
# (non-closing) for the next micro-batch, a hang occurred. The reason was # (non-closing) for the next micro-batch, a hang occurred. The reason was
...@@ -90,7 +92,7 @@ def exception_no_hang(pipeline_style): ...@@ -90,7 +92,7 @@ def exception_no_hang(pipeline_style):
raise ExpectedException() raise ExpectedException()
model = nn.Sequential(Pass(), Pass(), Raise()) model = nn.Sequential(Pass(), Pass(), Raise())
model = Pipe(model, [1, 1, 1], style=pipeline_style, worker_map=get_worker_map(), chunks=3) model = MultiProcessPipe(model, [1, 1, 1], style=pipeline_style, worker_map=get_worker_map(), chunks=3)
model.eval() model.eval()
if model.group.rank() == 2: if model.group.rank() == 2:
...@@ -104,7 +106,7 @@ def exception_no_hang(pipeline_style): ...@@ -104,7 +106,7 @@ def exception_no_hang(pipeline_style):
@torch_spawn([2]) @torch_spawn([2])
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="2 cuda devices required") @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="2 cuda devices required")
@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule]) @pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
def tuple_wait(cuda_sleep, pipeline_style): def tuple_wait(cuda_sleep, pipeline_style):
# In v0.0.3, Wait is applied to only the first tensor on a micro-batch. # In v0.0.3, Wait is applied to only the first tensor on a micro-batch.
# Under this behavior, if checkpointing was disabled, there's a possibility # Under this behavior, if checkpointing was disabled, there's a possibility
...@@ -133,7 +135,7 @@ def tuple_wait(cuda_sleep, pipeline_style): ...@@ -133,7 +135,7 @@ def tuple_wait(cuda_sleep, pipeline_style):
return a + b + c return a + b + c
model = nn.Sequential(Layer1(), Layer2()) model = nn.Sequential(Layer1(), Layer2())
model = Pipe( model = MultiProcessPipe(
model, model,
[1, 1], [1, 1],
style=pipeline_style, style=pipeline_style,
...@@ -158,7 +160,7 @@ def tuple_wait(cuda_sleep, pipeline_style): ...@@ -158,7 +160,7 @@ def tuple_wait(cuda_sleep, pipeline_style):
@torch_spawn([2]) @torch_spawn([2])
@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required") @pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule]) @pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
def parallel_randoms(pipeline_style): def parallel_randoms(pipeline_style):
class Dropouts(nn.Module): class Dropouts(nn.Module):
def forward(self, x): def forward(self, x):
...@@ -170,7 +172,7 @@ def parallel_randoms(pipeline_style): ...@@ -170,7 +172,7 @@ def parallel_randoms(pipeline_style):
x = torch.rand(10, 10, requires_grad=True).cuda() x = torch.rand(10, 10, requires_grad=True).cuda()
x.retain_grad() x.retain_grad()
model = Pipe( model = MultiProcessPipe(
model, model,
[1, 1], [1, 1],
style=pipeline_style, style=pipeline_style,
......
...@@ -21,20 +21,20 @@ import pytest ...@@ -21,20 +21,20 @@ import pytest
import torch import torch
from torch import nn from torch import nn
from fairscale.nn.pipe import Pipe from fairscale.nn.pipe import MultiProcessPipe
from fairscale.utils.testing import get_worker_map, torch_spawn from fairscale.utils.testing import get_worker_map, torch_spawn
@torch_spawn([2]) @torch_spawn([2])
@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required") @pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule]) @pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
def inplace_on_requires_grad(pipeline_style): def inplace_on_requires_grad(pipeline_style):
model = nn.Sequential(nn.Linear(1, 1), nn.ReLU(inplace=True)) model = nn.Sequential(nn.Linear(1, 1), nn.ReLU(inplace=True))
model = Pipe(model, [1, 1], style=pipeline_style, worker_map=get_worker_map(), checkpoint="always") model = MultiProcessPipe(model, [1, 1], style=pipeline_style, worker_map=get_worker_map(), checkpoint="always")
x = torch.rand(1) x = torch.rand(1)
if pipeline_style == Pipe.AsyncSchedule and model.group.rank() == 0: if pipeline_style == MultiProcessPipe.AsyncSchedule and model.group.rank() == 0:
# With AsyncSchedule, model will wait forever for gradients if not eval # With AsyncSchedule, model will wait forever for gradients if not eval
model.eval() model.eval()
...@@ -50,12 +50,12 @@ def inplace_on_requires_grad(pipeline_style): ...@@ -50,12 +50,12 @@ def inplace_on_requires_grad(pipeline_style):
@torch_spawn([1]) @torch_spawn([1])
@pytest.mark.xfail(strict=True) @pytest.mark.xfail(strict=True)
@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule]) @pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
def inplace_on_not_requires_grad(pipeline_style): def inplace_on_not_requires_grad(pipeline_style):
# In-place operation on a tensor not requiring grad doesn't cause a # In-place operation on a tensor not requiring grad doesn't cause a
# RuntimeError. Currently, we cannot detect this case. # RuntimeError. Currently, we cannot detect this case.
model = nn.Sequential(nn.ReLU(inplace=True)) model = nn.Sequential(nn.ReLU(inplace=True))
model = Pipe(model, [1], style=pipeline_style, worker_map=get_worker_map(), checkpoint="always") model = MultiProcessPipe(model, [1], style=pipeline_style, worker_map=get_worker_map(), checkpoint="always")
x = torch.rand(1) x = torch.rand(1)
y = model(x) y = model(x)
...@@ -70,7 +70,7 @@ def inplace_on_not_requires_grad(pipeline_style): ...@@ -70,7 +70,7 @@ def inplace_on_not_requires_grad(pipeline_style):
@torch_spawn([1]) @torch_spawn([1])
@pytest.mark.xfail(strict=True) @pytest.mark.xfail(strict=True)
@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule]) @pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
def inplace_incorrect_grad(pipeline_style): def inplace_incorrect_grad(pipeline_style):
class M(nn.Module): class M(nn.Module):
def forward(self, foo_bar): def forward(self, foo_bar):
...@@ -88,7 +88,7 @@ def inplace_incorrect_grad(pipeline_style): ...@@ -88,7 +88,7 @@ def inplace_incorrect_grad(pipeline_style):
return foo * bar return foo * bar
model = nn.Sequential(M()) model = nn.Sequential(M())
model = Pipe(model, [1], style=pipeline_style, worker_map=get_worker_map(), checkpoint="always") model = MultiProcessPipe(model, [1], style=pipeline_style, worker_map=get_worker_map(), checkpoint="always")
foo = torch.tensor([1.0], requires_grad=True) foo = torch.tensor([1.0], requires_grad=True)
bar = torch.tensor([1.0]) bar = torch.tensor([1.0])
......