Unverified commit cae9b638, authored by msbaines and committed by GitHub

[refactor] pipe: separate out Single and MultiProcess pipe (#326)

parent eab1551a
......@@ -19,10 +19,9 @@ import torchtext
from torchtext.data.utils import get_tokenizer
from experimental.nn.ampnet_pipe import pipe
from fairscale.nn import Pipe
from fairscale.nn.model_parallel import initialize_model_parallel
from fairscale.nn.model_parallel.initialize import get_pipeline_parallel_group
from fairscale.nn.pipe import LazyModule
from fairscale.nn.pipe import LazyModule, MultiProcessPipe
from fairscale.optim import GradScaler
from fairscale.utils.testing import dist_init, get_worker_map
......@@ -421,7 +420,7 @@ def run_mp_worker(args, available_workers):
p = pipe.AMPnetPipe(
module=model,
balance=balance,
style=Pipe.AsyncSchedule,
style=MultiProcessPipe.AsyncSchedule,
chunks=args.chunks,
worker_map=get_worker_map(),
input_device=torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"),
......
......@@ -25,7 +25,7 @@ from torch.optim import Adam
from fairscale.nn import Pipe
from fairscale.nn.model_parallel import initialize_model_parallel
from fairscale.nn.model_parallel.initialize import get_data_parallel_group, get_pipeline_parallel_group
from fairscale.nn.pipe import LazyModule, pipe
from fairscale.nn.pipe import LazyModule, MultiProcessPipe
from fairscale.optim.oss import OSS
from fairscale.utils.testing import dist_init, get_worker_map
......@@ -157,7 +157,7 @@ def dump_cuda_tensors():
def log_number_of_parameters(model):
num_params = reduce(operator.add, (reduce(operator.mul, x.size()) for x in model.parameters()))
if model.group:
if hasattr(model, "group"):
total = torch.Tensor([num_params])
if torch.cuda.is_available():
total = total.cuda()
......@@ -212,7 +212,7 @@ def train(model_config, model, benchmark_config, args):
optimizer = optimizer(model.parameters())
pipe_group = model.group
pipe_group = model.group if hasattr(model, "group") else None
if args.ddp_zero:
model = DDP(
......@@ -479,9 +479,7 @@ def benchmark_single_process(args):
model = model_config["model"]
balance = generate_balance(min(num_devices, 4), len(model))
pipe_model = pipe.Pipe(
model, balance, chunks=args.chunks, pipelined_backward=args.pipelined_backward, checkpoint=args.checkpoint
)
pipe_model = Pipe(model, balance, chunks=args.chunks, checkpoint=args.checkpoint)
del model
del model_config["model"]
......@@ -498,10 +496,10 @@ def run_mp_worker(args, available_workers):
model = model_config["model"]
balance = generate_balance_weighted(get_pipeline_parallel_group().size(), len(model), 0.8)
pipe_model = pipe.Pipe(
pipe_model = MultiProcessPipe(
model,
balance,
style=Pipe.AsyncSchedule,
style=MultiProcessPipe.AsyncSchedule,
chunks=args.chunks,
worker_map=get_worker_map(),
input_device=torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"),
......
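For orientation, a minimal sketch (not taken from this diff) of the two entry points this refactor separates. The toy model and sizes are placeholders; only one of the two wrappers would be used in a real script, and MultiProcessPipe additionally needs torch.distributed, RPC, and the pipeline-parallel group initialized beforehand.

import torch.nn as nn
from fairscale.nn import Pipe                   # single-process pipeline
from fairscale.nn.pipe import MultiProcessPipe  # one pipeline stage per process
from fairscale.utils.testing import get_worker_map

model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 8), nn.ReLU())

# (a) single process, partitions spread over local CUDA devices:
pipe_model = Pipe(model, balance=[2, 2], chunks=4, checkpoint="except_last")

# (b) multi-process, one stage per rank, activations sent over RPC:
pipe_model = MultiProcessPipe(
    model,
    balance=[2, 2],
    style=MultiProcessPipe.AsyncSchedule,
    chunks=4,
    worker_map=get_worker_map(),
)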
......@@ -6,8 +6,8 @@ import torch.distributed as dist
import torch.multiprocessing as mp
import torch.optim as optim
import fairscale
from fairscale.nn.model_parallel import initialize_model_parallel
from fairscale.nn.pipe import MultiProcessPipe
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
RANK = 0 # example
......@@ -27,10 +27,10 @@ def run(rank, world_size):
device = torch.device("cuda", RANK) if DEVICE == "cuda" else torch.device("cpu")
model = fairscale.nn.Pipe(
model = MultiProcessPipe(
model,
balance=[2, 1],
style=fairscale.nn.Pipe.MultiProcess,
style=MultiProcessPipe.MultiProcess,
worker_map={0: "worker0", 1: "worker1"}, # Needed to convert ranks to RPC worker names
input_device=device,
).to(device)
......
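The hunk above only shows the construction call that changed; below is a rough sketch of the per-rank scaffolding such an example assumes. The backend, ports, worker function name, and the `initialize_model_parallel(1, world_size)` arguments are assumptions, not taken from this diff.

import os
import torch.distributed as dist
import torch.distributed.rpc as rpc
import torch.multiprocessing as mp
from fairscale.nn.model_parallel import initialize_model_parallel

def run_worker(rank: int, world_size: int) -> None:
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "10638"
    # Collectives (including the pipeline-parallel group) go over the process group...
    dist.init_process_group(backend="gloo", rank=rank, world_size=world_size)
    # ...while activations between stages travel over RPC; these worker names are
    # what the `worker_map` in the snippet above refers to.
    os.environ["MASTER_PORT"] = "10639"
    rpc.init_rpc(f"worker{rank}", rank=rank, world_size=world_size)
    # Assumption: model-parallel size 1, pipeline length equal to world_size.
    initialize_model_parallel(1, world_size)
    # ... build `model`, wrap it in MultiProcessPipe as shown above, train ...
    rpc.shutdown()

if __name__ == "__main__":
    world_size = 2
    mp.spawn(run_worker, args=(world_size,), nprocs=world_size)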
......@@ -11,7 +11,7 @@ from torch import nn
from torch.optim.optimizer import Optimizer
from torch.utils.data import DataLoader
from fairscale.nn.pipe import Pipe
from fairscale.nn.pipe import MultiProcessPipe
from fairscale.nn.pipe.types import PipelineStyle
from .ampnet import AsyncAMPnetEventLoop
......@@ -19,9 +19,9 @@ from .ampnet import AsyncAMPnetEventLoop
__all__ = ["AMPnetPipe"]
class AMPnetPipe(Pipe):
class AMPnetPipe(MultiProcessPipe):
"""
AMPnetPipe is the asynchronous version of the Pipe implementation
AMPnetPipe is the asynchronous version of the MultiProcessPipe implementation
which avoids the bubble issue by using stale weights and gradients.
The implementation closely follows the paper: https://arxiv.org/abs/1705.09786
"""
......@@ -39,7 +39,7 @@ class AMPnetPipe(Pipe):
weight_prediction: bool = False,
) -> None:
partitions = self.mp_partitions
partitions = self.partitions
n = len(partitions)
# AMPnet implementation doesn't handle skip_trackers!
......
......@@ -23,7 +23,7 @@ from torch.optim.optimizer import Optimizer
from torch.utils.data import DataLoader, Dataset
from experimental.nn.ampnet_pipe.pipe import AMPnetPipe
from fairscale.nn.pipe import Pipe
from fairscale.nn.pipe import MultiProcessPipe
from fairscale.utils.testing import get_worker_map, torch_spawn
......@@ -87,7 +87,7 @@ def async_event_loop_interleave_simple():
pipe = AMPnetPipe(
module=model,
balance=[2, 2],
style=Pipe.AsyncSchedule,
style=MultiProcessPipe.AsyncSchedule,
worker_map=get_worker_map(),
chunks=10,
checkpoint="never",
......@@ -105,7 +105,7 @@ def async_event_loop_interleave_hard():
pipe = AMPnetPipe(
module=model,
balance=[1, 1, 1, 1],
style=Pipe.AsyncSchedule,
style=MultiProcessPipe.AsyncSchedule,
worker_map=get_worker_map(),
chunks=10,
checkpoint="never",
......
......@@ -6,7 +6,7 @@
from .data_parallel import ShardedDataParallel
from .misc import FlattenParamsWrapper
from .moe import MOELayer, Top2Gate
from .pipe import LazyModule, Pipe, PipeRPCWrapper
from .pipe import Pipe, PipeRPCWrapper
__all__ = [
"FlattenParamsWrapper",
......
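Per this hunk, `LazyModule` is no longer re-exported from the top-level `fairscale.nn` namespace; as the benchmark changes above already show, callers now pull it (and the new multi-process class) from the pipe subpackage:

from fairscale.nn import Pipe, PipeRPCWrapper               # single-process pipe
from fairscale.nn.pipe import LazyModule, MultiProcessPipe  # multi-process pipe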
......@@ -19,7 +19,8 @@
"""A Pipe implementation in PyTorch."""
from .checkpoint import is_checkpointing, is_recomputing
from .pipe import LazyModule, Pipe
from .multiprocess_pipe import LazyModule, MultiProcessPipe
from .pipe import Pipe
from .rpc import PipeRPCWrapper
__all__ = ["Pipe", "is_checkpointing", "is_recomputing", "LazyModule"]
......@@ -191,7 +191,7 @@ class AsyncEventLoop:
"""Actually run the forward pass for a given module, and send the result
to the next stage in the pipeline if needed."""
assert self.group
from .pipeline import create_task
from .multiprocess_pipeline import create_task
task = create_task(
PipelineStyle.AsyncSchedule,
......@@ -201,7 +201,6 @@ class AsyncEventLoop:
batch,
partition.module,
skip_trackers,
[],
)
result = task.compute()
task.finalize(result)
......
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
# Copyright 2019 Kakao Brain
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The MultiProcessPipe interface."""
from collections import OrderedDict
from dataclasses import dataclass, field
import itertools
import threading
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, Union, cast
import warnings
import torch
from torch import Tensor, nn
import torch.autograd
import torch.cuda
from fairscale.nn.model_parallel import get_model_parallel_world_size, get_pipeline_parallel_group
from . import microbatch
from .async_schedule import Invocation, Location, ModuleWrapper
from .batchnorm import DeferredBatchNorm
from .multiprocess_pipeline import MultiProcessPipeline
from .skip.layout import SkipLayout, inspect_skip_layout
from .skip.skippable import Skippable, verify_skippables
from .types import LazyModule, PipelineStyle
__all__ = ["MultiProcessPipe", "LazyModule"]
Device = Union[torch.device, int, str]
Devices = Union[Iterable[Device], List[Device]]
Tensors = Tuple[Tensor, ...]
TensorOrTensors = Union[Tensor, Tensors]
ListOfLazyModules = List[LazyModule]
if TYPE_CHECKING:
Module = nn.Module[TensorOrTensors]
NamedModules = OrderedDict[str, Module]
else:
Module = nn.Module
NamedModules = OrderedDict
def recommend_auto_balance(message: str) -> str:
"""Expands a message with recommendation to :mod:`torchpipe.balance`."""
return f"""{message}
If your model is still under development, its optimal balance would change
frequently. In this case, we highly recommend 'fairscale.nn.pipe.balance' for
naive automatic balancing:
from fairscale.nn.pipe import MultiProcessPipe
from fairscale.nn.pipe.balance import balance_by_time
partitions = torch.cuda.device_count()
sample = torch.empty(...)
balance = balance_by_time(partitions, model, sample)
model = MultiProcessPipe(model, balance, ...)
"""
# FIXME(tom) make this a valid way to call
def verify_list_of_callable(module: Union[nn.Sequential, list]) -> None:
for layer in module:
if isinstance(layer, nn.Module):
pass
elif isinstance(layer, LazyModule):
pass
else:
raise TypeError(f"layer {type(layer)} must be nn.Module or LazyModule to be partitioned")
def verify_module(module: Union[nn.Sequential, ListOfLazyModules]) -> None:
if isinstance(module, Iterable) and not isinstance(module, nn.Sequential):
verify_list_of_callable(module)
else:
if not isinstance(module, nn.Sequential):
raise TypeError("module must be nn.Sequential to be partitioned")
named_children = list(module.named_children())
if len(named_children) != len(module):
raise ValueError("module with duplicate children is not supported")
def verify_splitting(module: nn.Sequential, partitions: List[nn.Sequential], balance: Iterable[int],) -> None:
num_parameters = len(list(module.parameters()))
num_child_parameters = sum(len(list(child.parameters())) for child in module.children())
if num_parameters == num_child_parameters:
return
for i in range(len(partitions)):
for j in range(i + 1, len(partitions)):
parti = partitions[i]
partj = partitions[j]
for p in parti.parameters():
for q in partj.parameters():
if p is q:
raise ValueError("module with duplicate parameters on distinct devices is not supported")
class BalanceError(ValueError):
pass
def check_balance(module: Any, balance: Iterable[int], filter_unique: bool = False) -> None:
if filter_unique:
module_len = len(set(map(id, module)))
else:
module_len = len(module)
if module_len != sum(balance):
raise BalanceError(
f"module and sum of balance have different length (module: {len(module)}, sum of balance: {sum(balance)})"
)
if any(x <= 0 for x in balance):
raise BalanceError(f"all balance numbers must be positive integer (balance: {balance})")
@dataclass
class PartitionInfo:
location: Location
modules: "OrderedDict[str, nn.Module]"
invocations: List[Invocation] = field(default_factory=list)
def __len__(self) -> int:
return len(self.modules)
def instantiate_partition(
module: Union[nn.Sequential, ListOfLazyModules],
balance: Iterable[int],
group: torch.distributed.ProcessGroup,
style: PipelineStyle,
) -> List[ModuleWrapper]:
balance = list(balance)
check_balance(module, balance, True)
layers: NamedModules = OrderedDict()
def maybe_realize(layer: Any) -> nn.Module:
if isinstance(layer, nn.Module):
return layer
elif callable(layer):
return layer()
else:
raise TypeError(f"layer must be nn.Module or callable, is {type(layer)}")
def iterate_module(module: Union[nn.Sequential, list]) -> Iterable[Tuple[Any, nn.Module]]:
if isinstance(module, nn.Sequential):
yield from module.named_children()
else:
yield from ((str(k), v) for k, v in enumerate(module))
if style == PipelineStyle.AsyncSchedule:
module_ids = list(map(id, module))
index_of_first_use = [module_ids.index(x) for x in module_ids]
locations: List[Location] = []
module_iter = enumerate(iterate_module(module))
partitions: List[List[PartitionInfo]] = []
for bi, b in enumerate(balance):
modules_for_rank: List[PartitionInfo] = []
current_module: OrderedDict[str, nn.Module] = OrderedDict()
def current_location() -> Location:
return Location(bi, len(modules_for_rank))
def append_module(mod: "OrderedDict[str, nn.Module]") -> None:
modules_for_rank.append(PartitionInfo(current_location(), mod))
while sum(map(len, modules_for_rank)) + len(current_module) < b:
module_index, (name, layer) = next(module_iter)
if index_of_first_use[module_index] != module_index:
# Subsequent reuse of a module
locations.append(locations[index_of_first_use[module_index]])
continue
is_reused = index_of_first_use.count(index_of_first_use[module_index]) > 1
if is_reused and len(current_module) > 0:
append_module(current_module)
current_module = OrderedDict()
current_module[str(name)] = layer
locations.append(current_location())
if is_reused:
append_module(current_module)
current_module = OrderedDict()
if len(current_module) > 0:
append_module(current_module)
partitions.append(modules_for_rank)
filtered_locations: List[Optional[Location]] = [loc for loc, _ in itertools.groupby(locations)]
filtered_locations.append(None)
for i in range(len(filtered_locations) - 1):
loc = filtered_locations[i]
assert loc
if i == 0:
inv = Invocation(i, loc, None, filtered_locations[i + 1])
else:
inv = Invocation(i, loc, filtered_locations[i - 1], filtered_locations[i + 1])
partitions[loc.stage][loc.index].invocations.append(inv)
invocations = enumerate(iterate_module(module))
partition = partitions[group.rank()]
result: List[ModuleWrapper] = []
for partition_info in partition:
wrapper = ModuleWrapper(
nn.Sequential(OrderedDict((k, maybe_realize(m)) for k, m in partition_info.modules.items())),
partition_info.location,
partition_info.invocations,
)
if not isinstance(module, nn.Sequential):
for layer in wrapper.module:
if isinstance(layer, Skippable):
raise ValueError("Can't use Skippable layers with multi-process pipe and lazy construction")
result.append(wrapper)
return result
j = 0
for name, layer in iterate_module(module):
layers[name] = layer
if len(layers) == balance[j]:
if j == group.rank():
for key in layers:
layers[key] = maybe_realize(layers[key])
if not isinstance(module, nn.Sequential):
for layer in layers.values():
if isinstance(layer, Skippable):
raise ValueError("Can't use Skippable layers with multi-process pipe and lazy construction")
return [ModuleWrapper(nn.Sequential(layers), Location(j, 0))]
# Prepare for the next partition.
layers.clear()
j += 1
raise ValueError("Souldn't get here, more ranks than partitions")
def split_module(module: nn.Sequential, balance: Iterable[int],) -> Tuple[List[nn.Sequential], List[int]]:
"""Splits a module into multiple partitions.
Returns:
A tuple of (partitions, balance).
Partitions are represented as a :class:`~torch.nn.ModuleList` whose
item is a partition. All layers in a partition are placed in the
same device.
Raises:
BalanceError:
wrong balance
IndexError:
the number of devices is fewer than the number of partitions.
"""
balance = list(balance)
check_balance(module, balance)
j = 0
partitions = []
layers: NamedModules = OrderedDict()
for name, layer in module.named_children():
layers[name] = layer
if len(layers) == balance[j]:
# Group buffered layers as a partition.
partition = nn.Sequential(layers)
partitions.append(partition)
# Prepare for the next partition.
layers.clear()
j += 1
partitions = cast(List[nn.Sequential], nn.ModuleList(partitions))
return partitions, balance
MOVING_DENIED = TypeError("denied to move parameters and buffers, because Pipe should manage device placement")
class MultiProcessPipe(Module):
"""Wraps an arbitrary :class:`nn.Sequential <torch.nn.Sequential>` module
to train on Pipe_. If the module requires lots of memory, Pipe will be
very efficient.
::
model = nn.Sequential(a, b, c, d)
model = MultiProcessPipe(model, balance=[1, 1, 1, 1], chunks=8)
output = model(input)
.. _Pipe: https://arxiv.org/abs/1811.06965
Pipe combines pipeline parallelism with checkpointing to reduce peak
memory required to train while minimizing device under-utilization.
You should determine the balance when defining a :class:`Pipe` module, as
balancing will not be done automatically. The module will be partitioned
into multiple devices according to the given balance. You may rely on
heuristics to find your own optimal configuration.
Args:
module (torch.nn.Sequential):
sequential module to be parallelized
balance (ints):
list of number of layers in each partition
Keyword Args:
style (PipelineStyle):
whether to use a single process for all pipeline stages or to assign
one stage per process
group (ProcessGroup):
specific to `style=MultiProcess`, the process group that all
pipeline stages are a member of. Defaults to
`get_pipeline_parallel_group()`
worker_map (Dict[int, str]):
a map from global rank (i.e. `torch.distributed.get_rank()`) to worker
name (the first argument to `torch.distributed.rpc.init_rpc`), needed
for pipeline stages to communicate with each other
input_device (device):
the device on which tensors should be located before being passed to
the first module in a given pipeline stage
chunks (int):
number of micro-batches (default: ``1``)
checkpoint (str):
when to enable checkpointing, one of ``'always'``,
``'except_last'``, or ``'never'`` (default: ``'except_last'``)
deferred_batch_norm (bool):
whether to use deferred BatchNorm moving statistics (default:
:data:`False`, see :class:`DeferredBatchNorm` for more
details)
pipelined_backward (bool, optional):
if True, call torch.autograd.backward once per microbatch on the
backward pass (instead of once for the whole batch). This works
around a potential deadlock in pytorch when using tensor parallelism
at the same time. Defaults to `True` if
`get_model_parallel_world_size() > 1`
(default: `None`)
retain_graph (bool):
the value passed to `torch.autograd.backward(..., retain_graph=<value>)`
(default: `False`)
Raises:
TypeError:
the module is not a :class:`nn.Sequential <torch.nn.Sequential>`.
ValueError:
invalid arguments, or wrong balance
IndexError:
the number of devices is fewer than the number of partitions.
"""
MultiProcess: PipelineStyle = PipelineStyle.MultiProcess
AsyncSchedule: PipelineStyle = PipelineStyle.AsyncSchedule
#: The number of layers in each partition.
balance: List[int] = []
# ^^
# The default value [] required for Sphinx's autoattribute.
#: The devices mapped to each partition.
#:
#: ``devices[-1]`` refers to the device of the last partition, which means
#: it is the output device. Probably, you need to use it to transfer the
#: target to calculate the loss without a device mismatch
#: :exc:`RuntimeError`. For example::
#:
#: out_device = pipe.devices[-1]
#:
#: for input, target in loader:
#: target = target.to(out_device, non_blocking=True)
#: output = pipe(input)
#: loss = F.cross_entropy(output, target)
#:
#: The number of micro-batches.
chunks: int = 1
#: The checkpoint mode to determine when to enable checkpointing. It is one
#: of ``'always'``, ``'except_last'``, or ``'never'``.
checkpoint: str = "except_last"
def __init__(
self,
module: Union[nn.Sequential, ListOfLazyModules],
balance: Optional[Iterable[int]] = None,
*,
style: PipelineStyle = PipelineStyle.MultiProcess,
group: Optional[torch.distributed.ProcessGroup] = None,
worker_map: Optional[Dict[int, str]] = None,
input_device: Union[None, int, str, torch.device] = None,
chunks: int = chunks,
checkpoint: str = checkpoint,
deferred_batch_norm: bool = False,
pipelined_backward: bool = None,
retain_graph: bool = False,
loss_fn: Optional[nn.Module] = None,
) -> None:
super().__init__()
chunks = int(chunks)
checkpoint = str(checkpoint)
if balance is None:
raise ValueError(recommend_auto_balance("balance is required"))
if chunks <= 0:
raise ValueError("number of chunks must be positive integer")
if checkpoint not in ["always", "except_last", "never"]:
raise ValueError("checkpoint is not one of 'always', 'except_last', or 'never'")
verify_module(module)
# Verify if the underlying skippable modules satisfy integrity. The
# integrity can be verified before forward() because it is static.
if isinstance(module, nn.Sequential):
verify_skippables(module)
self.chunks = chunks
self.checkpoint = checkpoint
self.pipelined_backward = pipelined_backward
self.retain_graph = retain_graph
self.pipeline: Optional[MultiProcessPipeline]
self.loss_fn = loss_fn
self.lock = threading.Lock()
self.group = group
self.worker_map = worker_map
self.input_device = input_device
# The micro-batch index where the checkpointing stops.
checkpoint_stop = {"always": self.chunks, "except_last": self.chunks - 1, "never": 0}[self.checkpoint]
if self.group is None:
self.group = get_pipeline_parallel_group()
assert self.group
self.balance = list(balance)
if self.group.size() < len(self.balance):
raise IndexError(
f"too few ranks to hold given partitions (ranks: {self.group.size()}, partitions:"
f" {len(self.balance)})"
)
try:
rank = self.group.rank()
if rank >= len(self.balance):
warnings.warn("More ranks than partitions, some ranks unused")
self.partitions: List[ModuleWrapper] = []
else:
self.partitions = instantiate_partition(module, balance, self.group, style)
if deferred_batch_norm:
for part in self.partitions:
part.module = DeferredBatchNorm.convert_deferred_batch_norm(part.module, chunks)
for name, part in enumerate(self.partitions):
self.add_module(str(name), part.module)
if isinstance(module, nn.Sequential):
local_partitions, _ = split_module(module, balance)
self._skip_layout = inspect_skip_layout(local_partitions)
else:
self._skip_layout = SkipLayout(len(module), {}) # FIXME(tom)
except BalanceError as exc:
raise ValueError(recommend_auto_balance(str(exc)))
rank = self.group.rank()
if rank >= len(self.balance):
self.pipeline = None
self.final_stage = False
else:
self.final_stage = rank == len(self.balance) - 1
assert loss_fn is None or self.final_stage
self.pipeline = MultiProcessPipeline(
cast(List[nn.Sequential], self.partitions),
self._skip_layout,
checkpoint_stop,
style=style,
group=self.group,
worker_map=self.worker_map,
input_device=self.input_device,
final_stage=self.final_stage,
)
del module
if self.pipelined_backward is None:
if get_model_parallel_world_size() > 1:
self.pipelined_backward = True
else:
self.pipelined_backward = False
def __len__(self) -> int:
"""Counts the length of the underlying sequential module."""
return sum(len(p) for p in self.partitions)
def __getitem__(self, index: int) -> nn.Module:
"""Gets a layer in the underlying sequential module."""
partitions: List[Any]
partitions = self.partitions
if index < 0:
partitions = partitions[::-1]
for partition in partitions:
try:
if isinstance(partition, ModuleWrapper):
return partition.module[index]
else:
return partition[index]
except IndexError:
pass
shift = len(partition)
if index < 0:
index += shift
else:
index -= shift
raise IndexError
def __iter__(self) -> Iterable[nn.Module]:
"""Iterates over children of the underlying sequential module."""
for partition in self.partitions:
yield from partition.module
def forward(self, input: TensorOrTensors, *, event=None) -> TensorOrTensors: # type: ignore
""":class:`MultiProcessPipe` is a fairly transparent module wrapper. It doesn't
modify the input and output signature of the underlying module. But
there's type restriction. Input and output have to be a
:class:`~torch.Tensor` or a tuple of tensors. This restriction is
applied at partition boundaries too.
Args:
input (torch.Tensor or tensors): input mini-batch
Returns:
tensor or tensors: output mini-batch
Raises:
TypeError: input is not a tensor or tensors.
"""
microbatch.check(input)
if not self.group:
# Empty sequential module is not illegal.
return input
if not self.pipeline:
# No pipeline is not illegal, more ranks than partitions
return input
# Divide a mini-batch into micro-batches.
batches = microbatch.scatter(input, self.chunks)
# Run pipeline parallelism.
with self.lock:
self.pipeline.run(self.training, batches, event)
if not self.final_stage:
# Don't merge micro-batches to avoid unnecessary edges in autograd
# graph
# FIXME(tom) should figure out a proper type here
return batches # type: ignore
else:
# Merge the micro-batches into one mini-batch.
if self.pipelined_backward:
with torch.no_grad():
output = microbatch.gather(batches)
from .phony import get_phony
phony = get_phony(
torch.device(torch.cuda.current_device() if torch.cuda.is_available() else "cpu"),
requires_grad=True,
)
output = PipelinedBackwardPass.apply(output, batches, phony, True) # self.retain_graph)
else:
output = microbatch.gather(batches)
return output
def back_helper(self, output: List[microbatch.Batch]) -> None:
if self.final_stage:
raise ValueError("back_helper should only be called on non-final stages")
if self.pipeline:
self.pipeline.back_helper(list(reversed(output)))
class PipelinedBackwardPass(torch.autograd.Function):
@staticmethod
# type: ignore
def forward(ctx, input: TensorOrTensors, batches, phony, retain_graph) -> TensorOrTensors:
ctx.batches = batches
ctx.retain_graph = retain_graph
return input
@staticmethod
# type: ignore
def backward(ctx, *grads) -> Tuple:
with torch.no_grad():
grad_batches = microbatch.scatter(grads, len(ctx.batches))
for grad, batch in reversed(list(zip(grad_batches, ctx.batches))):
for t in batch:
t.retain_grad()
torch.autograd.backward(batch.tensor_or_tensors, grad_tensors=(*grad,), retain_graph=ctx.retain_graph)
with torch.no_grad():
if ctx.batches[0].atomic:
tensors = tuple(b.tensor.grad for b in ctx.batches)
output: TensorOrTensors = torch.cat(tensors)
else:
rotated = [[t.grad for t in b.tensors] for b in ctx.batches]
output_buf = []
for tensors in zip(*rotated):
output_buf.append(torch.cat(tensors))
output = tuple(output_buf)
del ctx.batches
return (output, None, None, None)
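Reading the class docstring and `forward` together, here is a minimal sketch of how one rank drives its stage. The toy model, sizes, and worker names are placeholders, and distributed/RPC/pipeline-group initialization is assumed to have already happened.

import torch
import torch.nn as nn
from fairscale.nn.pipe import MultiProcessPipe

model = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 4))
pipe = MultiProcessPipe(
    model,
    balance=[2, 1],                           # layers per rank; must sum to len(model)
    style=MultiProcessPipe.MultiProcess,      # or MultiProcessPipe.AsyncSchedule
    chunks=4,                                 # micro-batches per mini-batch
    worker_map={0: "worker0", 1: "worker1"},  # rank -> RPC worker name
    checkpoint="except_last",
)

x = torch.randn(8, 16)  # every rank calls forward with the mini-batch
out = pipe(x)
if pipe.final_stage:
    # Only the last rank gets the gathered mini-batch output back.
    out.sum().backward()
else:
    # Earlier ranks get their list of micro-batches back and run their part of
    # the backward pass via back_helper().
    pipe.back_helper(out)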
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
# Copyright 2019 Kakao Brain
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The multiprocess pipeline parallelism of Pipe."""
import logging
import os
from queue import Empty as QueueEmpty
from queue import Queue
from threading import Event
from types import TracebackType
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Type, Union, cast
import torch
from torch import Tensor, nn
from torch.autograd.profiler import record_function
from fairscale.nn.model_parallel import get_pipeline_parallel_ranks
from .async_schedule import AsyncEventLoop, ModuleWrapper
from .checkpoint import Checkpointing
from .messages import MakeTransport, Transport
from .microbatch import Batch
from .skip import Namespace
from .skip.layout import SkipLayout
from .skip.tracker import SkipTrackerThroughPotals, use_skip_tracker
from .types import (
ACTIVATIONS_GRADS_QUEUE,
PORTAL_QUEUE,
SKIP_TENSOR_QUEUE,
PipelineStyle,
PipeMessage,
TensorOrTensors,
Tensors,
)
from .worker import Task
__all__: List[str] = []
ExcInfo = Tuple[Type[BaseException], BaseException, TracebackType]
class SendOperator(torch.autograd.Function):
"""Send activations to the next pipeline stage"""
@staticmethod
# type: ignore
def forward(ctx, src_rank, dst_rank, transport: Transport, input: List[Tensor], index: int) -> Tensors:
assert src_rank == torch.distributed.get_rank()
transport.send_message(
PipeMessage(src_rank, dst_rank, queue_name=ACTIVATIONS_GRADS_QUEUE, args=index, tensors=tuple(input)),
)
return ()
@staticmethod
# type: ignore
def backward(ctx, *grad: Tensor,) -> Tensors:
return tuple(grad)
class RecvOperator(torch.autograd.Function):
"""Receive activations to the previous pipeline stage"""
@staticmethod
# type: ignore
def forward(ctx, dst_rank: int, tensor: Tensor, input_device, transport: Transport, index: int) -> Tensors:
assert dst_rank == torch.distributed.get_rank()
ctx.transport = transport
ctx.index = index
result = transport.get_out_of_order(ACTIVATIONS_GRADS_QUEUE, index)
def maybe_requires_grad(t: Tensor) -> Tensor:
if t.dtype.is_floating_point:
return t.requires_grad_()
return t
return tuple(maybe_requires_grad(r) for r in result)
@staticmethod
# type: ignore
def backward(ctx, *grad: Tensor,) -> Tuple[Optional[Tensor], ...]:
ranks = get_pipeline_parallel_ranks()
this_rank = torch.distributed.get_rank()
ctx.transport.send_message(
PipeMessage(
this_rank,
ranks[ranks.index(this_rank) - 1],
queue_name=ACTIVATIONS_GRADS_QUEUE,
args=ctx.index,
tensors=tuple(grad),
),
)
return (None, None, None, None, None)
# Queue is generic only in stubs.
# https://mypy.readthedocs.io/en/latest/common_issues.html#using-classes-that-are-generic-in-stubs-but-not-at-runtime
if TYPE_CHECKING:
InQueue = Queue[Optional["Task"]]
OutQueue = Queue[Tuple[bool, Union[Tuple["Task", Batch], ExcInfo, None]]]
else:
InQueue = Queue
OutQueue = Queue
def create_task(
style: PipelineStyle,
checkpoint_stop: int,
i: int,
j: int,
batch: Batch,
partition: nn.Sequential,
skip_trackers: List[SkipTrackerThroughPotals],
) -> Task:
# Determine whether checkpointing or not.
if i < checkpoint_stop:
def function(
input: TensorOrTensors,
partition: nn.Sequential = partition,
skip_tracker: SkipTrackerThroughPotals = skip_trackers[i],
chunk_id: int = i,
part_id: int = j,
) -> TensorOrTensors:
with use_skip_tracker(skip_tracker), record_function("chunk%d-part%d" % (chunk_id, part_id)):
ret = partition(input)
# We check here because the backtrace from the checkpoint backward code path
# is very hard to make sense of; it is much easier to catch the problem at this earlier point.
assert type(ret) is not list, "Only Tensor or Tuple of Tensor output is supported"
return ret
chk = Checkpointing(function, batch)
task = Task(None, compute=chk.checkpoint, finalize=chk.recompute)
del function, chk # TODO(tom) maybe remove
else:
def compute(
batch: Batch = batch,
partition: nn.Sequential = partition,
skip_tracker: SkipTrackerThroughPotals = skip_trackers[i],
chunk_id: int = i,
part_id: int = j,
) -> Batch:
with use_skip_tracker(skip_tracker), record_function("chunk%d-part%d" % (chunk_id, part_id)):
return batch.call(partition)
task = Task(None, compute=compute, finalize=None)
del compute # TODO(tom) maybe remove
return task
class MultiProcessPipeline:
"""The multiprocess pipeline parallelism for Pipe."""
def __init__(
self,
partitions: List[nn.Sequential],
skip_layout: SkipLayout,
checkpoint_stop: int,
style: PipelineStyle,
group: Optional[torch.distributed.ProcessGroup] = None,
worker_map: Optional[Dict[int, str]] = None,
input_device: Union[None, int, str, torch.device] = None,
final_stage: bool = False,
) -> None:
self.partitions: List[ModuleWrapper] = cast(List[ModuleWrapper], partitions)
self.skip_layout = skip_layout
self.__checkpoint_stop = checkpoint_stop
self.style = style
self.group = group
self.training: bool
self.transport = MakeTransport(
use_rpc=("OMPI_COMM_WORLD_RANK" not in os.environ) or ("FORCE_RPC" in os.environ),
worker_map=worker_map,
input_device=input_device,
)
self.input_device = input_device
self.all_at_once = False
self.callcount = 0
self.final_stage = final_stage
@property
def checkpoint_stop(self) -> int:
# Disable checkpointing if in eval mode.
training = self.partitions[0].module.training
if not training:
return 0
return self.__checkpoint_stop
def run(self, training: bool, batches: List[Batch], event: Optional[Event]) -> None:
"""Runs pipeline parallelism.
It modifies the given batches in place.
"""
self.training = training
m = len(batches)
skip_trackers = [SkipTrackerThroughPotals(self.skip_layout, i) for i in range(len(batches))]
if self.style is PipelineStyle.MultiProcess:
assert self.group
schedule = [(i, self.group.rank()) for i in range(m)]
self.compute(batches, schedule, skip_trackers)
elif self.style is PipelineStyle.AsyncSchedule:
assert self.group
rank = self.group.rank()
event_loop = AsyncEventLoop(
self.partitions, self.group, self.transport, self.training, self.checkpoint_stop,
)
if rank == 0 and not self.final_stage:
logging.debug(f"{torch.distributed.get_rank()}: entered event head")
event_loop.event_loop_head(batches, skip_trackers, event)
logging.debug(f"{torch.distributed.get_rank()}: exited event head")
elif self.final_stage:
logging.debug(f"{torch.distributed.get_rank()}: entered event tail")
event_loop.event_loop_tail(batches, skip_trackers)
logging.debug(f"{torch.distributed.get_rank()}: exited event tail")
else:
logging.debug(f"{torch.distributed.get_rank()}: entered event loop")
event_loop.event_loop(len(batches), skip_trackers)
logging.debug(f"{torch.distributed.get_rank()}: exited event loop")
self.callcount += 1
def get_batch_from_previous_stage(
self, i: int, skip_trackers: List[SkipTrackerThroughPotals], batches: List[Batch]
) -> Batch:
phony = torch.empty(0, device=self.input_device, requires_grad=True)
result = RecvOperator.apply(torch.distributed.get_rank(), phony, self.input_device, self.transport, i)
if len(result) == 1:
batch = Batch(result[0], i)
else:
batch = Batch(result, i)
self.recv_skip_tensors(skip_trackers, batches)
return batch
def send_skip_tensors(
self, this_rank: int, ranks: List[int], batch: Batch, i: int, skip_trackers: List[SkipTrackerThroughPotals]
) -> None:
assert self.group
for next_j, ns, name in self.skip_layout.copy_policy_by_src(self.group.rank()):
life = skip_trackers[i].portals[(ns, name)].tensor_life
loaded = skip_trackers[i].load(batch, ns, name)
if loaded is not None:
tensors = tuple([loaded])
else:
tensors = tuple()
self.transport.send_message(
PipeMessage(
this_rank, ranks[next_j], queue_name=SKIP_TENSOR_QUEUE, args=(i, ns, name, life), tensors=tensors,
),
sync=True,
)
def recv_skip_tensors(self, skip_trackers: List[SkipTrackerThroughPotals], batches: List[Batch]) -> None:
while True:
try:
message = self.transport.recv_message(SKIP_TENSOR_QUEUE, nowait=True)
(si, ns, name, life) = message.args
value: Optional[TensorOrTensors] = message.tensors
assert isinstance(value, tuple)
if len(value) == 0:
value = None
else:
assert len(value) == 1
value = value[0]
skip_trackers[si].save(batches[si], ns, name, value)
old_life = skip_trackers[si].portals[(ns, name)].tensor_life
if life != 0:
skip_trackers[si].portals[(ns, name)].tensor_life = life
except QueueEmpty:
break
def execute_task(self, task: Task, i: int, skip_trackers: List[SkipTrackerThroughPotals]) -> Batch:
batch = task.compute()
assert self.group
rank = self.group.rank()
if self.style is PipelineStyle.MultiProcess and not self.final_stage:
ranks = get_pipeline_parallel_ranks()
this_rank = torch.distributed.get_rank()
self.send_skip_tensors(this_rank, ranks, batch, i, skip_trackers)
SendOperator.apply(this_rank, ranks[ranks.index(this_rank) + 1], self.transport, [*batch], i)
for portal in skip_trackers[i].portals.values():
portal.pipeline = self
task.finalize(batch)
return batch
def compute(
self, batches: List[Batch], schedule: List[Tuple[int, int]], skip_trackers: List[SkipTrackerThroughPotals]
) -> None:
"""Runs tasks with synchronization to copy streams."""
if self.style is PipelineStyle.MultiProcess:
assert self.group
n = self.group.size()
# With checkpointing, the autograd graph looks like this diagram:
# ┌─────┸──────┐
# │ Copy │
# └─────┰──────┘ (fence)
# ─ ─ ─ ╂ ─ ─ ─ ─ ─ ─ ─ ─ ─
# ┃ (compute)
# ┌─────┸──────┐
# │ Wait │ [1] Synchronize the current stream with the copy stream.
# └─────┰──────┘
# ┌─────┸──────┐
# │ Checkpoint │ [2] Compute a partition within checkpointing.
# └─────┰──────┘
# ┌─────┸──────┐
# │ Wait │ [3] Synchronize the copy stream with the current stream.
# └─────┰──────┘
# ┠ ─ ─ ─ ┐
# ┃ ┌─────┴─────┐
# ┃ │ Recompute │ [4] Schedule the recomputation at backpropagation.
# ┃ └─────┬─────┘
# ┠ ─ ─ ─ ┘
# ┃
# ─ ─ ─ ╂ ─ ─ ─ ─ ─ ─ ─ ─ ─
# ┌─────┸──────┐ (fence)
# │ Copy │
# └─────┰──────┘
for i, j in schedule:
batch = batches[i]
if self.style is PipelineStyle.MultiProcess:
assert len(self.partitions) == 1
partition = self.partitions[0]
assert self.group
if self.group.rank() != 0:
batch = self.get_batch_from_previous_stage(i, skip_trackers, batches)
task = create_task(self.style, self.checkpoint_stop, i, j, batch, partition.module, skip_trackers)
batches[i] = self.execute_task(task, i, skip_trackers)
def send_portal_grad(self, ns_name: Tuple[Namespace, str], index: int, grad: TensorOrTensors) -> None:
dest, src = self.skip_layout.by_ns_name.get(ns_name, (-1, -1))
if dest == src:
return
ranks = get_pipeline_parallel_ranks()
dst_rank = ranks[dest]
if dst_rank == torch.distributed.get_rank():
return
if isinstance(grad, Tensor):
grad = tuple([grad])
self.transport.send_message(
PipeMessage(ranks[src], dst_rank, queue_name=PORTAL_QUEUE, args=(ns_name, index), tensors=grad), sync=True,
)
def recv_portal_grad(self, expected_ns_name: Tuple[Namespace, str], expected_index: int) -> Tensor:
message = self.transport.recv_message(PORTAL_QUEUE)
(ns_name, index) = message.args
grad = message.tensors
assert len(grad) == 1
result = grad[0]
assert index == expected_index and ns_name == expected_ns_name
return result
def back_helper(self, output: List[Batch]) -> None:
if self.style == PipelineStyle.AsyncSchedule:
return
o = list(output)
tensors: Tensors
if self.all_at_once:
# FIXME(tom) allow specifying this branch when constructing Pipe(), add a test
grads = []
for i, batch in enumerate(o):
rank = torch.distributed.get_rank()
found = self.transport.get_out_of_order(ACTIVATIONS_GRADS_QUEUE, i)
assert len(found) == 1
grads.append(found[0])
tensors = tuple(x.tensor_or_tensors for x in o) # type: ignore
try:
torch.autograd.backward(tensors, grad_tensors=grads, retain_graph=True)
except Exception as e:
raise RuntimeError("Autograd failed") from e
else:
rank = torch.distributed.get_rank()
for batch in o:
found = self.transport.get_out_of_order(ACTIVATIONS_GRADS_QUEUE, batch.index)
if batch.atomic:
tensors = tuple([batch.tensor])
else:
tensors = batch.tensors
if len(found) != len(tensors):
raise RuntimeError("different number of tensors and gradients")
grads = []
final_tensors = []
for i, tensor in enumerate(tensors):
if tensor.requires_grad or getattr(tensor, "grad_fn", None) is not None:
grads.append(found[i])
final_tensors.append(tensor)
try:
torch.autograd.backward(final_tensors, grad_tensors=grads, retain_graph=True)
except Exception as e:
raise RuntimeError(f"Autograd failed on {torch.distributed.get_rank()}") from e
......@@ -19,29 +19,21 @@
"""The Pipe interface."""
from collections import OrderedDict
from dataclasses import dataclass, field
import itertools
import threading
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, Union, cast
import warnings
from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Tuple, Union, cast
import torch
from torch import Tensor, nn
import torch.autograd
import torch.cuda
from fairscale.nn.model_parallel import get_model_parallel_world_size, get_pipeline_parallel_group
from . import microbatch
from .async_schedule import Invocation, Location, ModuleWrapper
from .batchnorm import DeferredBatchNorm
from .pipeline import Pipeline
from .skip.layout import SkipLayout, inspect_skip_layout
from .skip.skippable import Skippable, verify_skippables
from .skip.layout import inspect_skip_layout
from .skip.skippable import verify_skippables
from .stream import AbstractStream, new_stream
from .types import LazyModule, PipelineStyle
__all__ = ["Pipe", "LazyModule"]
__all__ = ["Pipe"]
Device = Union[torch.device, int, str]
......@@ -50,8 +42,6 @@ Devices = Union[Iterable[Device], List[Device]]
Tensors = Tuple[Tensor, ...]
TensorOrTensors = Union[Tensor, Tensors]
ListOfLazyModules = List[LazyModule]
if TYPE_CHECKING:
Module = nn.Module[TensorOrTensors]
NamedModules = OrderedDict[str, Module]
......@@ -79,34 +69,17 @@ naive automatic balancing:
"""
# FIXME(tom) make this a valid way to call
def verify_list_of_callable(module: Union[nn.Sequential, list]) -> None:
for layer in module:
if isinstance(layer, nn.Module):
pass
elif isinstance(layer, LazyModule):
pass
else:
raise TypeError(f"layer {type(layer)} must be nn.Module or LazyModule to be partitioned")
def verify_module(module: nn.Sequential) -> None:
if not isinstance(module, nn.Sequential):
raise TypeError("module must be nn.Sequential to be partitioned")
def verify_module(module: Union[nn.Sequential, ListOfLazyModules]) -> None:
if isinstance(module, Iterable) and not isinstance(module, nn.Sequential):
verify_list_of_callable(module)
else:
if not isinstance(module, nn.Sequential):
raise TypeError("module must be nn.Sequential to be partitioned")
named_children = list(module.named_children())
if len(named_children) != len(module):
raise ValueError("module with duplicate children is not supported")
named_children = list(module.named_children())
if len(named_children) != len(module):
raise ValueError("module with duplicate children is not supported")
def verify_splitting(
module: nn.Sequential,
partitions: List[nn.Sequential],
balance: Iterable[int],
devices: Optional[List[torch.device]],
module: nn.Sequential, partitions: List[nn.Sequential], balance: Iterable[int], devices: List[torch.device]
) -> None:
num_parameters = len(list(module.parameters()))
num_child_parameters = sum(len(list(child.parameters())) for child in module.children())
......@@ -117,7 +90,7 @@ def verify_splitting(
for j in range(i + 1, len(partitions)):
parti = partitions[i]
partj = partitions[j]
if devices and devices[i] == devices[j]:
if devices[i] == devices[j]:
continue
for p in parti.parameters():
for q in partj.parameters():
......@@ -129,159 +102,9 @@ class BalanceError(ValueError):
pass
def check_balance(module: Any, balance: Iterable[int], filter_unique: bool = False) -> None:
if filter_unique:
module_len = len(set(map(id, module)))
else:
module_len = len(module)
if module_len != sum(balance):
raise BalanceError(
f"module and sum of balance have different length (module: {len(module)}, sum of balance: {sum(balance)})"
)
if any(x <= 0 for x in balance):
raise BalanceError(f"all balance numbers must be positive integer (balance: {balance})")
@dataclass
class PartitionInfo:
location: Location
modules: "OrderedDict[str, nn.Module]"
invocations: List[Invocation] = field(default_factory=list)
def __len__(self) -> int:
return len(self.modules)
def instantiate_partition(
module: Union[nn.Sequential, ListOfLazyModules],
balance: Iterable[int],
group: torch.distributed.ProcessGroup,
style: PipelineStyle,
) -> List[ModuleWrapper]:
balance = list(balance)
check_balance(module, balance, True)
layers: NamedModules = OrderedDict()
def maybe_realize(layer: Any) -> nn.Module:
if isinstance(layer, nn.Module):
return layer
elif callable(layer):
return layer()
else:
raise TypeError(f"layer must be nn.Module or callable, is {type(layer)}")
def iterate_module(module: Union[nn.Sequential, list]) -> Iterable[Tuple[Any, nn.Module]]:
if isinstance(module, nn.Sequential):
yield from module.named_children()
else:
yield from ((str(k), v) for k, v in enumerate(module))
if style == PipelineStyle.AsyncSchedule:
module_ids = list(map(id, module))
index_of_first_use = [module_ids.index(x) for x in module_ids]
locations: List[Location] = []
module_iter = enumerate(iterate_module(module))
partitions: List[List[PartitionInfo]] = []
for bi, b in enumerate(balance):
modules_for_rank: List[PartitionInfo] = []
current_module: OrderedDict[str, nn.Module] = OrderedDict()
def current_location() -> Location:
return Location(bi, len(modules_for_rank))
def append_module(mod: "OrderedDict[str, nn.Module]") -> None:
modules_for_rank.append(PartitionInfo(current_location(), mod))
while sum(map(len, modules_for_rank)) + len(current_module) < b:
module_index, (name, layer) = next(module_iter)
if index_of_first_use[module_index] != module_index:
# Subsequent reuse of a module
locations.append(locations[index_of_first_use[module_index]])
continue
is_reused = index_of_first_use.count(index_of_first_use[module_index]) > 1
if is_reused and len(current_module) > 0:
append_module(current_module)
current_module = OrderedDict()
current_module[str(name)] = layer
locations.append(current_location())
if is_reused:
append_module(current_module)
current_module = OrderedDict()
if len(current_module) > 0:
append_module(current_module)
partitions.append(modules_for_rank)
filtered_locations: List[Optional[Location]] = [loc for loc, _ in itertools.groupby(locations)]
filtered_locations.append(None)
for i in range(len(filtered_locations) - 1):
loc = filtered_locations[i]
assert loc
if i == 0:
inv = Invocation(i, loc, None, filtered_locations[i + 1])
else:
inv = Invocation(i, loc, filtered_locations[i - 1], filtered_locations[i + 1])
partitions[loc.stage][loc.index].invocations.append(inv)
invocations = enumerate(iterate_module(module))
partition = partitions[group.rank()]
result: List[ModuleWrapper] = []
for partition_info in partition:
wrapper = ModuleWrapper(
nn.Sequential(OrderedDict((k, maybe_realize(m)) for k, m in partition_info.modules.items())),
partition_info.location,
partition_info.invocations,
)
if not isinstance(module, nn.Sequential):
for layer in wrapper.module:
if isinstance(layer, Skippable):
raise ValueError("Can't use Skippable layers with multi-process pipe and lazy construction")
result.append(wrapper)
return result
j = 0
for name, layer in iterate_module(module):
layers[name] = layer
if len(layers) == balance[j]:
if j == group.rank():
for key in layers:
layers[key] = maybe_realize(layers[key])
if not isinstance(module, nn.Sequential):
for layer in layers.values():
if isinstance(layer, Skippable):
raise ValueError("Can't use Skippable layers with multi-process pipe and lazy construction")
return [ModuleWrapper(nn.Sequential(layers), Location(j, 0))]
# Prepare for the next partition.
layers.clear()
j += 1
raise ValueError("Souldn't get here, more ranks than partitions")
def split_module(
module: nn.Sequential, balance: Iterable[int], devices: Optional[List[torch.device]],
) -> Tuple[List[nn.Sequential], List[int], Optional[List[torch.device]]]:
module: nn.Sequential, balance: Iterable[int], devices: List[torch.device],
) -> Tuple[List[nn.Sequential], List[int], List[torch.device]]:
"""Splits a module into multiple partitions.
Returns:
......@@ -300,11 +123,18 @@ def split_module(
"""
balance = list(balance)
check_balance(module, balance)
if len(module) != sum(balance):
raise BalanceError(
"module and sum of balance have different length "
f"(module: {len(module)}, sum of balance: {sum(balance)})"
)
if devices and len(balance) > len(devices):
if any(x <= 0 for x in balance):
raise BalanceError(f"all balance numbers must be positive integer (balance: {balance})")
if len(balance) > len(devices):
raise IndexError(
f"too few devices to hold given partitions (devices: {len(devices)}, partitions: {len(balance)})"
"too few devices to hold given partitions " f"(devices: {len(devices)}, partitions: {len(balance)})"
)
j = 0
......@@ -318,9 +148,8 @@ def split_module(
# Group buffered layers as a partition.
partition = nn.Sequential(layers)
if devices:
device = devices[j]
partition.to(device)
device = devices[j]
partition.to(device)
partitions.append(partition)
......@@ -329,13 +158,12 @@ def split_module(
j += 1
partitions = cast(List[nn.Sequential], nn.ModuleList(partitions))
if devices:
del devices[j:]
del devices[j:]
return partitions, balance, devices
MOVING_DENIED = TypeError("denied to move parameters and buffers, because Pipe should manage device placement")
MOVING_DENIED = TypeError("denied to move parameters and buffers, " "because Pipe should manage device placement")
class Pipe(Module):
......@@ -365,23 +193,8 @@ class Pipe(Module):
list of number of layers in each partition
Keyword Args:
style (PipelineStyle):
whether to use a single process for all pipeline stages or to assign
one stage per process
devices (iterable of devices):
devices to use (default: all CUDA devices)
group (ProcessGroup):
specific to `style=MultiProcess`, the process group that all
pipeline stages are a member of. Defaults to
`get_pipeline_parallel_group()`
worker_map (Dict[int, str]):
a map from worker name (the first argument to
`torch.distributed.rpc.init_rpc`) to global rank (i.e.
`torch.distributed.get_rank()`) needed in order for pipeline stages
to communicate with each other
input_device (device):
the device on which tensors should be located before being passed to
the first module in a given pipeline stage
chunks (int):
number of micro-batches (default: ``1``)
checkpoint (str):
......@@ -389,18 +202,8 @@ class Pipe(Module):
``'except_last'``, or ``'never'`` (default: ``'except_last'``)
deferred_batch_norm (bool):
whether to use deferred BatchNorm moving statistics (default:
:data:`False`, see :class:`DeferredBatchNorm` for more
:data:`False`, see :ref:`Deferred Batch Normalization` for more
details)
pipelined_backward (bool, optional):
if True, call torch.autograd.backward once per microbatch on the
backward pass (instead of once for the whole batch). This works
around a potential deadlock in pytorch when using tensor parallelism
at the same time. Defaults to `True` if
`get_model_parallel_world_size() > 1`
(default: `None`)
retain_graph (bool):
The value passed to `torch.autograd.backwards(..., retain_graph=<value>)
(default: = `True`)
Raises:
TypeError:
......@@ -412,10 +215,6 @@ class Pipe(Module):
"""
SingleProcess: PipelineStyle = PipelineStyle.SingleProcess
MultiProcess: PipelineStyle = PipelineStyle.MultiProcess
AsyncSchedule: PipelineStyle = PipelineStyle.AsyncSchedule
#: The number of layers in each partition.
balance: List[int] = []
# ^^
......@@ -435,7 +234,7 @@ class Pipe(Module):
#: output = pipe(input)
#: loss = F.cross_entropy(output, target)
#:
devices: Optional[List[torch.device]] = None
devices: List[torch.device] = []
#: The number of micro-batches.
chunks: int = 1
......@@ -446,20 +245,13 @@ class Pipe(Module):
def __init__(
self,
module: Union[nn.Sequential, ListOfLazyModules],
module: nn.Sequential,
balance: Optional[Iterable[int]] = None,
*,
style: PipelineStyle = PipelineStyle.SingleProcess,
devices: Optional[Devices] = None,
group: Optional[torch.distributed.ProcessGroup] = None,
worker_map: Optional[Dict[int, str]] = None,
input_device: Union[None, int, str, torch.device] = None,
chunks: int = chunks,
checkpoint: str = checkpoint,
deferred_batch_norm: bool = False,
pipelined_backward: bool = None,
retain_graph: bool = False,
loss_fn: Optional[nn.Module] = None,
) -> None:
super().__init__()
......@@ -477,145 +269,50 @@ class Pipe(Module):
# Verify if the underlying skippable modules satisfy integrity. The
# integrity can be verified before forward() because it is static.
if isinstance(module, nn.Sequential):
verify_skippables(module)
verify_skippables(module)
self.chunks = chunks
self.checkpoint = checkpoint
self.pipelined_backward = pipelined_backward
self.retain_graph = retain_graph
self.pipeline: Optional[Pipeline]
self.loss_fn = loss_fn
self.lock = threading.Lock()
self.group = group
self.worker_map = worker_map
self.input_device = input_device
self._copy_streams: List[List[AbstractStream]] = []
# The micro-batch index where the checkpointing stops.
checkpoint_stop = {"always": self.chunks, "except_last": self.chunks - 1, "never": 0}[self.checkpoint]
if style is PipelineStyle.SingleProcess:
module = cast(nn.Sequential, module)
if deferred_batch_norm:
module = DeferredBatchNorm.convert_deferred_batch_norm(module, chunks)
if input_device is not None:
raise ValueError("'input_device' argument only applies to 'PipelineStyle.MultiProcess'")
if deferred_batch_norm:
module = DeferredBatchNorm.convert_deferred_batch_norm(module, chunks)
if devices is None:
devices = range(torch.cuda.device_count())
if devices is None:
devices = range(torch.cuda.device_count())
devices = [torch.device(d) for d in devices]
devices = cast(List[torch.device], devices)
devices = [torch.device(d) for d in devices]
devices = cast(List[torch.device], devices)
try:
self.partitions, self.balance, self.devices = split_module(module, balance, devices)
except BalanceError as exc:
raise ValueError(recommend_auto_balance(str(exc)))
try:
self.partitions, self.balance, self.devices = split_module(module, balance, devices)
except BalanceError as exc:
raise ValueError(recommend_auto_balance(str(exc)))
verify_splitting(module, self.partitions, self.balance, self.devices)
self._skip_layout = inspect_skip_layout(self.partitions)
# Separate CUDA streams for copy.
copy_streams = self._ensure_copy_streams()
if self.pipelined_backward is None:
self.pipelined_backward = False
self.pipeline = Pipeline(
self.partitions, self.devices, copy_streams, self._skip_layout, checkpoint_stop, style=style,
)
verify_splitting(module, self.partitions, self.balance, self.devices)
elif style in [PipelineStyle.MultiProcess, PipelineStyle.AsyncSchedule]:
if self.group is None:
self.group = get_pipeline_parallel_group()
assert self.group
self._copy_streams: List[List[AbstractStream]] = []
self._skip_layout = inspect_skip_layout(self.partitions)
if devices is not None:
raise ValueError("'devices' argument only applies to 'PipelineStyle.SingleProcess'")
# Separate CUDA streams for copy.
copy_streams = self._ensure_copy_streams()
self.balance = list(balance)
# The micro-batch index where the checkpointing stops.
checkpoint_stop = {"always": self.chunks, "except_last": self.chunks - 1, "never": 0}[self.checkpoint]
if self.group.size() < len(self.balance):
raise IndexError(
f"too few ranks to hold given partitions (ranks: {self.group.size()}, partitions:"
f" {len(self.balance)})"
)
try:
rank = self.group.rank()
if rank >= len(self.balance):
warnings.warn("More ranks than partitions, some ranks unused")
self.mp_partitions: List[ModuleWrapper] = []
else:
self.mp_partitions = instantiate_partition(module, balance, self.group, style)
if deferred_batch_norm:
for part in self.mp_partitions:
part.module = DeferredBatchNorm.convert_deferred_batch_norm(part.module, chunks)
for name, part in enumerate(self.mp_partitions):
self.add_module(str(name), part.module)
self.devices = None
if isinstance(module, nn.Sequential):
local_partitions, _, _ = split_module(module, balance, None)
self._skip_layout = inspect_skip_layout(local_partitions)
else:
self._skip_layout = SkipLayout(len(module), {}) # FIXME(tom)
except BalanceError as exc:
raise ValueError(recommend_auto_balance(str(exc)))
rank = self.group.rank()
if rank >= len(self.balance):
self.pipeline = None
self.final_stage = False
else:
self.final_stage = rank == len(self.balance) - 1
assert loss_fn is None or self.final_stage
self.pipeline = Pipeline(
cast(List[nn.Sequential], self.mp_partitions),
None,
None,
self._skip_layout,
checkpoint_stop,
style=style,
group=self.group,
worker_map=self.worker_map,
input_device=self.input_device,
final_stage=self.final_stage,
)
del module
if self.pipelined_backward is None:
if get_model_parallel_world_size() > 1:
self.pipelined_backward = True
else:
self.pipelined_backward = False
self.pipeline = Pipeline(self.partitions, self.devices, copy_streams, self._skip_layout, checkpoint_stop)
def __len__(self) -> int:
"""Counts the length of the underlying sequential module."""
if hasattr(self, "partitions"):
return sum(len(p) for p in self.partitions)
else:
return sum(len(p) for p in self.mp_partitions)
return sum(len(p) for p in self.partitions)
def __getitem__(self, index: int) -> nn.Module:
"""Gets a layer in the underlying sequential module."""
partitions: List[Any]
if hasattr(self, "partitions"):
partitions = self.partitions
else:
partitions = self.mp_partitions
partitions = self.partitions
if index < 0:
partitions = partitions[::-1]
for partition in partitions:
try:
if isinstance(partition, ModuleWrapper):
return partition.module[index]
else:
return partition[index]
return partition[index]
except IndexError:
pass
......@@ -630,47 +327,35 @@ class Pipe(Module):
def __iter__(self) -> Iterable[nn.Module]:
"""Iterates over children of the underlying sequential module."""
if hasattr(self, "partitions"):
for partition in self.partitions:
yield from partition
else:
for mp_partition in self.mp_partitions:
yield from mp_partition.module
for partition in self.partitions:
yield from partition
# Pipe should manage the device of each partition.
# Deny cuda(), cpu(), and to() with device, by TypeError.
def cuda(self, device: Optional[Device] = None) -> "Pipe":
if self.devices:
raise MOVING_DENIED
if device:
return super().cuda(device=device)
else:
return super().cuda()
raise MOVING_DENIED
def cpu(self) -> "Pipe":
if self.devices:
raise MOVING_DENIED
return super().cpu()
raise MOVING_DENIED
def to(self, *args: Any, **kwargs: Any) -> "Pipe":
"""Restrict .to() options.
Deny these usages:
- to(device[, dtype, non_blocking])
- to(tensor[, non_blocking])
# Deny these usages:
#
# - to(device[, dtype, non_blocking])
# - to(tensor[, non_blocking])
#
# But allow this:
#
# - to(dtype[, non_blocking])
#
if "device" in kwargs or "tensor" in kwargs:
raise MOVING_DENIED
But allow this:
- to(dtype[, non_blocking])
"""
if self.devices:
if "device" in kwargs or "tensor" in kwargs:
if args:
if isinstance(args[0], (torch.device, int, str)):
raise MOVING_DENIED
if torch.is_tensor(args[0]):
raise MOVING_DENIED
if args:
if isinstance(args[0], (torch.device, int, str)):
raise MOVING_DENIED
if torch.is_tensor(args[0]):
raise MOVING_DENIED
return super().to(*args, **kwargs)
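# Illustrative usage (names assumed, not part of this diff):
#   pipe_model.to(torch.float16)            # allowed: dtype-only conversion
#   pipe_model.to(torch.device("cuda", 0))  # denied: Pipe manages partition placement
#   pipe_model.cuda(); pipe_model.cpu()     # likewise denied when Pipe manages the devices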
......@@ -683,13 +368,12 @@ class Pipe(Module):
"""
if not self._copy_streams:
assert self.devices is not None
for device in self.devices:
self._copy_streams.append([new_stream(device) for _ in range(self.chunks)])
return self._copy_streams
def forward(self, input: TensorOrTensors, *, event=None) -> TensorOrTensors: # type: ignore
def forward(self, input: TensorOrTensors) -> TensorOrTensors: # type: ignore
""":class:`Pipe` is a fairly transparent module wrapper. It doesn't
modify the input and output signature of the underlying module. But
there's type restriction. Input and output have to be a
......@@ -708,82 +392,16 @@ class Pipe(Module):
"""
microbatch.check(input)
if not self.group and not self.devices:
if not self.devices:
# Empty sequential module is not illegal.
return input
if not self.pipeline:
# Having no pipeline here is legal: there are more ranks than partitions
return input
# Divide a mini-batch into micro-batches.
batches = microbatch.scatter(input, self.chunks)
# Run pipeline parallelism.
with self.lock:
self.pipeline.run(self.training, batches, event)
if self.group and not self.final_stage:
# Don't merge the micro-batches, to avoid adding unnecessary edges to the
# autograd graph
# FIXME(tom) should figure out a proper type here
return batches # type: ignore
else:
# Merge the micro-batches into one mini-batch.
if self.pipelined_backward:
with torch.no_grad():
output = microbatch.gather(batches)
from .phony import get_phony
phony = get_phony(
torch.device(torch.cuda.current_device() if torch.cuda.is_available() else "cpu"),
requires_grad=True,
)
output = PipelinedBackwardPass.apply(output, batches, phony, True) # self.retain_graph)
else:
output = microbatch.gather(batches)
return output
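# Sketch of the forward() contract (illustrative, `pipe_model` is assumed):
# the mini-batch is scattered into `chunks` micro-batches, the pipeline runs
# them, and the final stage gathers them back into one output; non-final
# stages return the raw micro-batches so no extra autograd edges are created.
#   output = pipe_model(torch.randn(64, d_in))  # hypothetical call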
def back_helper(self, output: List[microbatch.Batch]) -> None:
if self.final_stage:
raise ValueError("back_helper should only be called on non-final stages")
if self.pipeline:
self.pipeline.back_helper(list(reversed(output)))
class PipelinedBackwardPass(torch.autograd.Function):
@staticmethod
# type: ignore
def forward(ctx, input: TensorOrTensors, batches, phony, retain_graph) -> TensorOrTensors:
ctx.batches = batches
ctx.retain_graph = retain_graph
return input
@staticmethod
# type: ignore
def backward(ctx, *grads) -> Tuple:
with torch.no_grad():
grad_batches = microbatch.scatter(grads, len(ctx.batches))
for grad, batch in reversed(list(zip(grad_batches, ctx.batches))):
for t in batch:
t.retain_grad()
torch.autograd.backward(batch.tensor_or_tensors, grad_tensors=(*grad,), retain_graph=ctx.retain_graph)
with torch.no_grad():
if ctx.batches[0].atomic:
tensors = tuple(b.tensor.grad for b in ctx.batches)
output: TensorOrTensors = torch.cat(tensors)
else:
rotated = [[t.grad for t in b.tensors] for b in ctx.batches]
output_buf = []
for tensors in zip(*rotated):
output_buf.append(torch.cat(tensors))
output = tuple(output_buf)
del ctx.batches
self.pipeline.run(batches)
return (output, None, None, None)
# Merge the micro-batches into one mini-batch.
output = microbatch.gather(batches)
return output
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
......@@ -17,101 +18,30 @@
# limitations under the License.
"""The pipeline parallelism of Pipe."""
import logging
import os
from queue import Empty as QueueEmpty
from queue import Queue
from threading import Event
from types import TracebackType
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, Type, Union, cast
from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple, Type, Union, cast
import torch
from torch import Tensor, nn
from torch.autograd.profiler import record_function
from fairscale.nn.model_parallel import get_pipeline_parallel_ranks
from .async_schedule import AsyncEventLoop, ModuleWrapper
from .checkpoint import Checkpointing
from .copy import Copy, Wait
from .dependency import fork, join
from .messages import MakeTransport, Transport
from .microbatch import Batch
from .skip import Namespace
from .skip.layout import SkipLayout
from .skip.tracker import SkipTrackerThroughPotals, use_skip_tracker
from .stream import AbstractStream, current_stream, use_device
from .types import (
ACTIVATIONS_GRADS_QUEUE,
PORTAL_QUEUE,
SKIP_TENSOR_QUEUE,
PipelineStyle,
PipeMessage,
Schedule,
TensorOrTensors,
Tensors,
)
from .worker import Task, create_workers, join_workers
__all__: List[str] = []
ExcInfo = Tuple[Type[BaseException], BaseException, TracebackType]
class SendOperator(torch.autograd.Function):
"""Send activations to the next pipeline stage"""
@staticmethod
# type: ignore
def forward(ctx, src_rank, dst_rank, transport: Transport, input: List[Tensor], index: int) -> Tensors:
assert src_rank == torch.distributed.get_rank()
transport.send_message(
PipeMessage(src_rank, dst_rank, queue_name=ACTIVATIONS_GRADS_QUEUE, args=index, tensors=tuple(input)),
)
return ()
@staticmethod
# type: ignore
def backward(ctx, *grad: Tensor,) -> Tensors:
return tuple(grad)
Tensors = Tuple[Tensor, ...]
TensorOrTensors = Union[Tensor, Tensors]
class RecvOperator(torch.autograd.Function):
"""Receive activations to the previous pipeline stage"""
@staticmethod
# type: ignore
def forward(ctx, dst_rank: int, tensor: Tensor, input_device, transport: Transport, index: int) -> Tensors:
assert dst_rank == torch.distributed.get_rank()
ctx.transport = transport
ctx.index = index
result = transport.get_out_of_order(ACTIVATIONS_GRADS_QUEUE, index)
def maybe_requires_grad(t: Tensor) -> Tensor:
if t.dtype.is_floating_point:
return t.requires_grad_()
return t
return tuple(maybe_requires_grad(r) for r in result)
@staticmethod
# type: ignore
def backward(ctx, *grad: Tensor,) -> Tuple[Optional[Tensor], ...]:
ranks = get_pipeline_parallel_ranks()
this_rank = torch.distributed.get_rank()
ctx.transport.send_message(
PipeMessage(
this_rank,
ranks[ranks.index(this_rank) - 1],
queue_name=ACTIVATIONS_GRADS_QUEUE,
args=ctx.index,
tensors=tuple(grad),
),
)
return (None, None, None, None, None)
ExcInfo = Tuple[Type[BaseException], BaseException, TracebackType]
# Queue is generic only in stubs.
# https://mypy.readthedocs.io/en/latest/common_issues.html#using-classes-that-are-generic-in-stubs-but-not-at-runtime
......@@ -140,7 +70,7 @@ def wait(batch: Batch, prev_stream: AbstractStream, next_stream: AbstractStream)
batch[:] = tuple([x if x.is_floating_point() else x.detach() for x in batch])
def clock_cycles(m: int, n: int) -> Iterable[Schedule]:
def clock_cycles(m: int, n: int) -> Iterable[List[Tuple[int, int]]]:
"""Generates schedules for each clock cycle."""
# m: number of micro-batches
# n: number of partitions
......@@ -159,159 +89,45 @@ def clock_cycles(m: int, n: int) -> Iterable[Schedule]:
yield [(k - j, j) for j in range(max(1 + k - m, 0), min(1 + k, n))]
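# Illustrative schedule (not part of the diff): clock_cycles(m=3, n=2) yields
#   [(0, 0)], [(1, 0), (0, 1)], [(2, 0), (1, 1)], [(2, 1)]
# i.e. micro-batch i reaches partition j during clock cycle i + j.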
def create_task(
style: PipelineStyle,
checkpoint_stop: int,
i: int,
j: int,
batch: Batch,
partition: nn.Sequential,
skip_trackers: List[SkipTrackerThroughPotals],
streams: List[AbstractStream],
) -> Task:
# Determine whether to checkpoint this micro-batch.
if i < checkpoint_stop:
def function(
input: TensorOrTensors,
partition: nn.Sequential = partition,
skip_tracker: SkipTrackerThroughPotals = skip_trackers[i],
chunk_id: int = i,
part_id: int = j,
) -> TensorOrTensors:
with use_skip_tracker(skip_tracker), record_function("chunk%d-part%d" % (chunk_id, part_id)):
ret = partition(input)
# We do a check here because the backtrace from the checkpoint backward code
# path is very hard to make sense of; it is much easier to catch the problem
# earlier, at this point.
assert type(ret) is not list, "Only Tensor or Tuple of Tensor output is supported"
return ret
chk = Checkpointing(function, batch)
if style is PipelineStyle.SingleProcess:
task = Task(streams[j], compute=chk.checkpoint, finalize=chk.recompute)
elif style in [PipelineStyle.MultiProcess, PipelineStyle.AsyncSchedule]:
task = Task(None, compute=chk.checkpoint, finalize=chk.recompute)
del function, chk # TODO(tom) maybe remove
else:
def compute(
batch: Batch = batch,
partition: nn.Sequential = partition,
skip_tracker: SkipTrackerThroughPotals = skip_trackers[i],
chunk_id: int = i,
part_id: int = j,
) -> Batch:
with use_skip_tracker(skip_tracker), record_function("chunk%d-part%d" % (chunk_id, part_id)):
return batch.call(partition)
if style is PipelineStyle.SingleProcess:
task = Task(streams[j], compute=compute, finalize=None)
elif style in [PipelineStyle.MultiProcess, PipelineStyle.AsyncSchedule]:
task = Task(None, compute=compute, finalize=None)
del compute # TODO(tom) maybe remove
return task
class Pipeline:
"""The pipeline parallelism for Pipe."""
def __init__(
self,
partitions: List[nn.Sequential],
devices: Optional[List[torch.device]],
copy_streams: Optional[List[List[AbstractStream]]],
devices: List[torch.device],
copy_streams: List[List[AbstractStream]],
skip_layout: SkipLayout,
checkpoint_stop: int,
style: PipelineStyle,
group: Optional[torch.distributed.ProcessGroup] = None,
worker_map: Optional[Dict[int, str]] = None,
input_device: Union[None, int, str, torch.device] = None,
final_stage: bool = False,
) -> None:
if style == PipelineStyle.SingleProcess:
self.partitions = partitions
else:
self.mp_partitions: List[ModuleWrapper] = cast(List[ModuleWrapper], partitions)
self.partitions = partitions
self.devices = devices
self.copy_streams = copy_streams
self.skip_layout = skip_layout
self.__checkpoint_stop = checkpoint_stop
self.style = style
self.group = group
self.training: bool
if style in [PipelineStyle.MultiProcess, PipelineStyle.AsyncSchedule]:
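# RPC transport is used unless the process was launched by OpenMPI
# (OMPI_COMM_WORLD_RANK set); setting FORCE_RPC forces RPC even under MPI.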
self.transport = MakeTransport(
use_rpc=("OMPI_COMM_WORLD_RANK" not in os.environ) or ("FORCE_RPC" in os.environ),
worker_map=worker_map,
input_device=input_device,
)
self.input_device = input_device
self.all_at_once = False
self.callcount = 0
self.final_stage = final_stage
if self.style is PipelineStyle.SingleProcess:
assert self.devices is not None
(self.in_queues, self.out_queues) = create_workers(self.devices)
@property
def checkpoint_stop(self) -> int:
# Disable checkpointing if in eval mode.
if self.style == PipelineStyle.SingleProcess:
training = self.partitions[0].training
else:
training = self.mp_partitions[0].module.training
if not training:
return 0
return self.__checkpoint_stop
self.checkpoint_stop = checkpoint_stop
(self.in_queues, self.out_queues) = create_workers(devices)
def __del__(self) -> None:
if self.style is PipelineStyle.SingleProcess:
join_workers(self.in_queues, self.out_queues)
def run(self, training: bool, batches: List[Batch], event: Optional[Event]) -> None:
join_workers(self.in_queues, self.out_queues)
def run(self, batches: List[Batch]) -> None:
"""Runs pipeline parallelism.
It modifies the given batches in place.
"""
self.training = training
partitions = self.partitions
devices = self.devices
skip_layout = self.skip_layout
m = len(batches)
n = len(partitions)
skip_trackers = [SkipTrackerThroughPotals(self.skip_layout, i) for i in range(len(batches))]
skip_trackers = [SkipTrackerThroughPotals(skip_layout, i) for i in range(m)]
if self.style is PipelineStyle.SingleProcess:
n = len(self.partitions)
for schedule in clock_cycles(m, n):
self.fence(batches, schedule, skip_trackers)
self.compute(batches, schedule, skip_trackers)
elif self.style is PipelineStyle.MultiProcess:
assert self.group
schedule = [(i, self.group.rank()) for i in range(m)]
for schedule in clock_cycles(m, n):
self.fence(batches, schedule, skip_trackers)
self.compute(batches, schedule, skip_trackers)
elif self.style is PipelineStyle.AsyncSchedule:
assert self.group
rank = self.group.rank()
event_loop = AsyncEventLoop(
self.mp_partitions, self.group, self.transport, self.training, self.checkpoint_stop,
)
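# With AsyncSchedule, rank 0 (when it is not also the final stage) drives the
# head event loop, the final stage drives the tail event loop, and every
# intermediate rank runs the generic event loop.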
if rank == 0 and not self.final_stage:
logging.debug(f"{torch.distributed.get_rank()}: entered event head")
event_loop.event_loop_head(batches, skip_trackers, event)
logging.debug(f"{torch.distributed.get_rank()}: exited event head")
elif self.final_stage:
logging.debug(f"{torch.distributed.get_rank()}: entered event tail")
event_loop.event_loop_tail(batches, skip_trackers)
logging.debug(f"{torch.distributed.get_rank()}: exited event tail")
else:
logging.debug(f"{torch.distributed.get_rank()}: entered event loop")
event_loop.event_loop(len(batches), skip_trackers)
logging.debug(f"{torch.distributed.get_rank()}: exited event loop")
self.callcount += 1
def fence(
self, batches: List[Batch], schedule: List[Tuple[int, int]], skip_trackers: List[SkipTrackerThroughPotals],
......@@ -322,9 +138,6 @@ class Pipeline:
copy_streams = self.copy_streams
skip_layout = self.skip_layout
assert copy_streams
assert skip_layout
for i, j in schedule:
# Ensure that batches[i-1] is executed after batches[i] in
# backpropagation by an explicit dependency.
......@@ -341,91 +154,92 @@ class Pipeline:
prev_stream = copy_streams[j - 1][i]
copy(batches[i], prev_stream, next_stream)
def get_batch_from_previous_stage(
self, i: int, skip_trackers: List[SkipTrackerThroughPotals], batches: List[Batch]
) -> Batch:
phony = torch.empty(0, device=self.input_device, requires_grad=True)
result = RecvOperator.apply(torch.distributed.get_rank(), phony, self.input_device, self.transport, i)
if len(result) == 1:
batch = Batch(result[0], i)
else:
batch = Batch(result, i)
self.recv_skip_tensors(skip_trackers, batches)
return batch
def send_skip_tensors(
self, this_rank: int, ranks: List[int], batch: Batch, i: int, skip_trackers: List[SkipTrackerThroughPotals]
def compute(
self, batches: List[Batch], schedule: List[Tuple[int, int]], skip_trackers: List[SkipTrackerThroughPotals],
) -> None:
assert self.group
for next_j, ns, name in self.skip_layout.copy_policy_by_src(self.group.rank()):
life = skip_trackers[i].portals[(ns, name)].tensor_life
loaded = skip_trackers[i].load(batch, ns, name)
if loaded is not None:
tensors = tuple([loaded])
else:
tensors = tuple()
self.transport.send_message(
PipeMessage(
this_rank, ranks[next_j], queue_name=SKIP_TENSOR_QUEUE, args=(i, ns, name, life), tensors=tensors,
),
sync=True,
)
def recv_skip_tensors(self, skip_trackers: List[SkipTrackerThroughPotals], batches: List[Batch]) -> None:
while True:
try:
message = self.transport.recv_message(SKIP_TENSOR_QUEUE, nowait=True)
(si, ns, name, life) = message.args
value: Optional[TensorOrTensors] = message.tensors
assert isinstance(value, tuple)
if len(value) == 0:
value = None
else:
assert len(value) == 1
value = value[0]
"""Runs tasks with synchronization to copy streams."""
partitions = self.partitions
devices = self.devices
copy_streams = self.copy_streams
checkpoint_stop = self.checkpoint_stop
skip_trackers[si].save(batches[si], ns, name, value)
old_life = skip_trackers[si].portals[(ns, name)].tensor_life
if life != 0:
skip_trackers[si].portals[(ns, name)].tensor_life = life
except QueueEmpty:
break
# Disable checkpointing if in eval mode.
if not self.partitions[0].training:
checkpoint_stop = 0
def execute_task(self, task: Task, i: int, skip_trackers: List[SkipTrackerThroughPotals]) -> Batch:
batch = task.compute()
n = len(partitions)
streams = [current_stream(d) for d in devices]
exc_info: Optional[ExcInfo] = None
assert self.group
rank = self.group.rank()
# With checkpointing, the autograd graph looks like this diagram:
# ┌─────┸──────┐
# │ Copy │
# └─────┰──────┘ (fence)
# ─ ─ ─ ╂ ─ ─ ─ ─ ─ ─ ─ ─ ─
# ┃ (compute)
# ┌─────┸──────┐
# │ Wait │ [1] Synchronize the current stream with the copy stream.
# └─────┰──────┘
# ┌─────┸──────┐
# │ Checkpoint │ [2] Compute a partition within checkpointing.
# └─────┰──────┘
# ┌─────┸──────┐
# │ Wait │ [3] Synchronize the copy stream with the current stream.
# └─────┰──────┘
# ┠ ─ ─ ─ ┐
# ┃ ┌─────┴─────┐
# ┃ │ Recompute │ [4] Schedule the recomputation at backpropagation.
# ┃ └─────┬─────┘
# ┠ ─ ─ ─ ┘
# ┃
# ─ ─ ─ ╂ ─ ─ ─ ─ ─ ─ ─ ─ ─
# ┌─────┸──────┐ (fence)
# │ Copy │
# └─────┰──────┘
for i, j in schedule:
batch = batches[i]
partition = partitions[j]
if self.style is PipelineStyle.MultiProcess and not self.final_stage:
ranks = get_pipeline_parallel_ranks()
this_rank = torch.distributed.get_rank()
# Synchronize with the copied input. ([1] in the diagram)
if j != 0:
wait(batch, copy_streams[j][i], streams[j])
# Determine whether to checkpoint this micro-batch.
checkpoint = i < checkpoint_stop
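# Only the first `checkpoint_stop` micro-batches are checkpointed; e.g. with
# checkpoint="except_last" and chunks=4, micro-batches 0-2 are checkpointed
# while micro-batch 3 is not.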
if checkpoint:
def function(
input: TensorOrTensors,
partition: nn.Sequential = partition,
skip_tracker: SkipTrackerThroughPotals = skip_trackers[i],
chunk_id: int = i,
part_id: int = j,
) -> TensorOrTensors:
with use_skip_tracker(skip_tracker), record_function("chunk%d-part%d" % (chunk_id, part_id)):
return partition(input)
chk = Checkpointing(function, batch)
task = Task(streams[j], compute=chk.checkpoint, finalize=chk.recompute)
del function, chk
self.send_skip_tensors(this_rank, ranks, batch, i, skip_trackers)
SendOperator.apply(this_rank, ranks[ranks.index(this_rank) + 1], self.transport, [*batch], i)
else:
for portal in skip_trackers[i].portals.values():
portal.pipeline = self
def compute(
batch: Batch = batch,
partition: nn.Sequential = partition,
skip_tracker: SkipTrackerThroughPotals = skip_trackers[i],
chunk_id: int = i,
part_id: int = j,
) -> Batch:
with use_skip_tracker(skip_tracker), record_function("chunk%d-part%d" % (chunk_id, part_id)):
return batch.call(partition)
task.finalize(batch)
task = Task(streams[j], compute=compute, finalize=None)
del compute
return batch
# Compute tasks in parallel. ([2] in the diagram)
self.in_queues[j].put(task)
def finalize_tasks(
self,
n: int,
schedule: Schedule,
streams: List[AbstractStream],
copy_streams: List[List[AbstractStream]],
batches: List[Batch],
) -> None:
exc_info: Optional[ExcInfo] = None
for i, j in schedule:
ok, payload = self.out_queues[j].get()
......@@ -446,8 +260,7 @@ class Pipeline:
# Finalize tasks. If checkpointing is enabled, here the
# recomputation is scheduled at backpropagation. ([4] in the
# diagram)
assert self.devices
with use_device(self.devices[j]):
with use_device(devices[j]):
task.finalize(batch)
batches[i] = batch
......@@ -455,147 +268,3 @@ class Pipeline:
# Fail at the first exception.
if exc_info is not None:
raise exc_info[0].with_traceback(exc_info[1], exc_info[2])
def compute(
self, batches: List[Batch], schedule: List[Tuple[int, int]], skip_trackers: List[SkipTrackerThroughPotals]
) -> None:
"""Runs tasks with synchronization to copy streams."""
devices = self.devices
copy_streams = self.copy_streams
if self.style is PipelineStyle.SingleProcess:
assert devices is not None
n = len(self.partitions)
streams = [current_stream(d) for d in devices]
elif self.style is PipelineStyle.MultiProcess:
assert self.group
n = self.group.size()
streams = []
# With checkpointing, the autograd graph looks like this diagram:
# ┌─────┸──────┐
# │ Copy │
# └─────┰──────┘ (fence)
# ─ ─ ─ ╂ ─ ─ ─ ─ ─ ─ ─ ─ ─
# ┃ (compute)
# ┌─────┸──────┐
# │ Wait │ [1] Synchronize the current stream with the copy stream.
# └─────┰──────┘
# ┌─────┸──────┐
# │ Checkpoint │ [2] Compute a partition within checkpointing.
# └─────┰──────┘
# ┌─────┸──────┐
# │ Wait │ [3] Synchronize the copy stream with the current stream.
# └─────┰──────┘
# ┠ ─ ─ ─ ┐
# ┃ ┌─────┴─────┐
# ┃ │ Recompute │ [4] Schedule the recomputation at backpropagation.
# ┃ └─────┬─────┘
# ┠ ─ ─ ─ ┘
# ┃
# ─ ─ ─ ╂ ─ ─ ─ ─ ─ ─ ─ ─ ─
# ┌─────┸──────┐ (fence)
# │ Copy │
# └─────┰──────┘
for i, j in schedule:
batch = batches[i]
if self.style is PipelineStyle.SingleProcess:
partition = self.partitions[j]
# Synchronize with the copied input. ([1] in the diagram)
assert copy_streams
if j != 0:
wait(batch, copy_streams[j][i], streams[j])
task = create_task(self.style, self.checkpoint_stop, i, j, batch, partition, skip_trackers, streams)
# Compute tasks in parallel. ([2] in the diagram)
self.in_queues[j].put(task)
elif self.style is PipelineStyle.MultiProcess:
assert len(self.mp_partitions) == 1
mp_partition = self.mp_partitions[0]
assert self.group
if self.group.rank() != 0:
batch = self.get_batch_from_previous_stage(i, skip_trackers, batches)
task = create_task(
self.style, self.checkpoint_stop, i, j, batch, mp_partition.module, skip_trackers, streams
)
batches[i] = self.execute_task(task, i, skip_trackers)
if self.style is PipelineStyle.SingleProcess:
assert copy_streams
self.finalize_tasks(n, schedule, streams, copy_streams, batches)
def send_portal_grad(self, ns_name: Tuple[Namespace, str], index: int, grad: TensorOrTensors) -> None:
dest, src = self.skip_layout.by_ns_name.get(ns_name, (-1, -1))
if dest == src:
return
ranks = get_pipeline_parallel_ranks()
dst_rank = ranks[dest]
if dst_rank == torch.distributed.get_rank():
return
if isinstance(grad, Tensor):
grad = tuple([grad])
self.transport.send_message(
PipeMessage(ranks[src], dst_rank, queue_name=PORTAL_QUEUE, args=(ns_name, index), tensors=grad), sync=True,
)
def recv_portal_grad(self, expected_ns_name: Tuple[Namespace, str], expected_index: int) -> Tensor:
message = self.transport.recv_message(PORTAL_QUEUE)
(ns_name, index) = message.args
grad = message.tensors
assert len(grad) == 1
result = grad[0]
assert index == expected_index and ns_name == expected_ns_name
return result
def back_helper(self, output: List[Batch]) -> None:
if self.style == PipelineStyle.AsyncSchedule:
return
o = list(output)
tensors: Tensors
if self.all_at_once:
# FIXME(tom) allow specifying this branch when constructing Pipe(), add a test
grads = []
for i, batch in enumerate(o):
rank = torch.distributed.get_rank()
found = self.transport.get_out_of_order(ACTIVATIONS_GRADS_QUEUE, i)
assert len(found) == 1
grads.append(found[0])
tensors = tuple(x.tensor_or_tensors for x in o) # type: ignore
try:
torch.autograd.backward(tensors, grad_tensors=grads, retain_graph=True)
except Exception as e:
raise RuntimeError("Autograd failed") from e
else:
rank = torch.distributed.get_rank()
for batch in o:
found = self.transport.get_out_of_order(ACTIVATIONS_GRADS_QUEUE, batch.index)
if batch.atomic:
tensors = tuple([batch.tensor])
else:
tensors = batch.tensors
if len(found) != len(tensors):
raise RuntimeError("different number of tensors and gradients")
grads = []
final_tensors = []
for i, tensor in enumerate(tensors):
if tensor.requires_grad or getattr(tensor, "grad_fn", None) is not None:
grads.append(found[i])
final_tensors.append(tensor)
try:
torch.autograd.backward(final_tensors, grad_tensors=grads, retain_graph=True)
except Exception as e:
raise RuntimeError(f"Autograd failed on {torch.distributed.get_rank()}") from e
......@@ -13,13 +13,13 @@ from torch.distributed.distributed_c10d import _get_global_rank
from fairscale.nn.model_parallel.initialize import get_pipeline_parallel_group
from . import Pipe
from .multiprocess_pipe import MultiProcessPipe
from .types import EVENT_LOOP_QUEUE, PipeMessage, TensorOrTensors
DEFAULT_MAX_SOURCE_POSITIONS = 1024
DEFAULT_MAX_TARGET_POSITIONS = 1024
PipeModel: Pipe
PipeModel: MultiProcessPipe
PipeResult: TensorOrTensors
......@@ -71,7 +71,7 @@ class PipeBackRedirect(torch.autograd.Function):
return (None, None, None, None, None, None)
def callback_with_model(callback: Callable[[Any, Pipe], None], ctx: Any) -> None:
def callback_with_model(callback: Callable[[Any, MultiProcessPipe], None], ctx: Any) -> None:
try:
group = get_pipeline_parallel_group() # FIXME(tom) handle dynamic group
set_device_based_on_group(group)
......@@ -105,10 +105,10 @@ class PipeRPCWrapper(nn.Module):
else:
kwargs["group"] = self.group
kwargs["style"] = Pipe.AsyncSchedule
kwargs["style"] = MultiProcessPipe.AsyncSchedule
kwargs["input_device"] = torch.device("cuda", torch.cuda.current_device())
self.model = Pipe(*args, **kwargs)
self.model = MultiProcessPipe(*args, **kwargs)
self.worker_map = kwargs["worker_map"]
self._foreach_worker(self._register_remote_model, args=(args, kwargs))
self.model.cuda()
......@@ -121,7 +121,7 @@ class PipeRPCWrapper(nn.Module):
futures = [f.wait() for f in futures]
def foreach_worker(
self, callback: Callable[[Any, Pipe], None], ctx: Any = None, *, include_self: bool = False
self, callback: Callable[[Any, MultiProcessPipe], None], ctx: Any = None, *, include_self: bool = False
) -> None:
"""Call `callback` on each worker with the `ctx` and model local to that
worker. e.g.
......@@ -196,7 +196,9 @@ class PipeRPCWrapper(nn.Module):
return self.model.final_stage
@staticmethod
def _recv_result(model: Pipe, shapes: SizeOrSizes, dtypes: DtypeOrDtypes, message: PipeMessage) -> TensorOrTensors:
def _recv_result(
model: MultiProcessPipe, shapes: SizeOrSizes, dtypes: DtypeOrDtypes, message: PipeMessage
) -> TensorOrTensors:
group = get_pipeline_parallel_group()
set_device_based_on_group(group)
......@@ -243,7 +245,7 @@ class PipeRPCWrapper(nn.Module):
set_device_based_on_group(group)
kwargs["group"] = group
kwargs["input_device"] = torch.device("cuda", torch.cuda.current_device())
model = Pipe(*args, **kwargs)
model = MultiProcessPipe(*args, **kwargs)
model.cuda()
global PipeModel
PipeModel = model
......
......@@ -24,7 +24,6 @@ Tensors = Tuple[Tensor, ...]
TensorOrTensors = Union[Tensor, Tensors]
InputDevice = Union[None, int, str, torch.device]
Schedule = List[Tuple[int, int]]
class LazyModule:
......
......@@ -5,7 +5,7 @@ from typing import Any, Callable, Optional, Tuple
from torch import Tensor
def spawn(
fn: Callable[[Any], Any],
fn: Callable[..., Any],
args: Tuple[Optional[Any], ...] = (),
nprocs: int = 1,
join: bool = True,
......
......@@ -31,7 +31,7 @@ from torch.nn.parameter import Parameter
from fairscale.nn.model_parallel import initialize as mpu
from fairscale.nn.model_parallel import layers
from fairscale.nn.pipe import Pipe
from fairscale.nn.pipe import MultiProcessPipe
from fairscale.utils.testing import dist_init, get_world_sizes, set_random_seed, spawn_for_all_world_sizes, torch_spawn
......@@ -319,7 +319,7 @@ def run_test_pipe(rank, world_size, filename, filename_rpc, skip_dist_init=False
model_parallel_size = mpu.get_model_parallel_world_size()
if torch.distributed.get_rank() == 0:
print(
"> testing Sequential + Pipe with model parallel size: {}, pipe: {}".format(
"> testing Sequential + MultiProcessPipe with model parallel size: {}, pipe: {}".format(
model_parallel_size, pipe_world_size
)
)
......@@ -431,13 +431,13 @@ def run_test_pipe(rank, world_size, filename, filename_rpc, skip_dist_init=False
model[2].weight.data = saved_weight_2
worker_map = {i: f"Test{i}" for i in range(torch.distributed.get_world_size())}
style = Pipe.MultiProcess # Pipe.AsyncSchedule
style = MultiProcessPipe.MultiProcess # MultiProcessPipe.AsyncSchedule
if pipe_world_size == 2:
print(f"actually doing pipe stuff now")
assert torch.equal(saved_weight_0, model[0].weight.data)
assert torch.equal(saved_weight_2, model[2].weight.data)
pipe_model = Pipe(
pipe_model = MultiProcessPipe(
model,
[2, 1],
style=style,
......@@ -507,7 +507,7 @@ def run_test_pipe(rank, world_size, filename, filename_rpc, skip_dist_init=False
failed = False
with torch.autograd.profiler.profile() as prof:
try:
if style == Pipe.MultiProcess:
if style == MultiProcessPipe.MultiProcess:
pipe_model.back_helper(pipe_output)
except Exception as e:
failed = True
......
......@@ -23,7 +23,7 @@ import pytest
import torch
from torch import nn
from fairscale.nn.pipe import LazyModule, Pipe
from fairscale.nn.pipe import LazyModule, MultiProcessPipe
from fairscale.nn.pipe.skip import pop, skippable, stash
from fairscale.nn.pipe.skip.portal import PortalBlue, PortalCopy, PortalOrange
from fairscale.utils.testing import get_worker_map, torch_spawn
......@@ -33,12 +33,12 @@ from fairscale.utils.testing import get_worker_map, torch_spawn
@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
@pytest.mark.parametrize("balance", [[3], [1, 2], [2, 1], [1, 1, 1]], ids=["3", "1:2", "2:1", "1:1:1"])
@pytest.mark.parametrize("checkpoint", ["never", "always", "except_last"])
@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
@pytest.mark.skipif("OMPI_COMM_WORLD_RANK" in os.environ, reason="broken on mpi")
def x1to3(balance, checkpoint, pipeline_style):
torch.manual_seed(0)
if pipeline_style == Pipe.AsyncSchedule and len(balance) > 1:
if pipeline_style == MultiProcessPipe.AsyncSchedule and len(balance) > 1:
print(f"skipping yarg")
pytest.skip("Skip tensors NYI for AsyncSchedule")
......@@ -74,7 +74,7 @@ def x1to3(balance, checkpoint, pipeline_style):
return output
model = nn.Sequential(Layer1(), Layer2(), Layer3())
model = Pipe(
model = MultiProcessPipe(
model,
balance,
chunks=3,
......@@ -106,9 +106,10 @@ def x1to3(balance, checkpoint, pipeline_style):
@torch_spawn([2])
@pytest.mark.skipif("OMPI_COMM_WORLD_RANK" in os.environ, reason="broken on mpi")
@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
@pytest.mark.skip(reason="flaky test")
def none_skip(pipeline_style):
if pipeline_style == Pipe.AsyncSchedule:
if pipeline_style == MultiProcessPipe.AsyncSchedule:
pytest.skip("Skip tensors NYI for AsyncSchedule")
@skippable(stash=["none"])
......@@ -125,7 +126,7 @@ def none_skip(pipeline_style):
return input
model = nn.Sequential(Stash(), Pop())
model = Pipe(
model = MultiProcessPipe(
model,
[1, 1],
style=pipeline_style,
......@@ -160,7 +161,7 @@ def none_skip(pipeline_style):
@torch_spawn([2])
@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
def lazy_skippable_error(pipeline_style):
"""Using skippable layers in combination with lazy construction is currently
not supported; check that it raises an exception"""
......@@ -180,6 +181,6 @@ def lazy_skippable_error(pipeline_style):
]
with pytest.raises(ValueError, match="Can't use Skippable layers with multi-process pipe and lazy construction"):
Pipe(
MultiProcessPipe(
model, [2, 1], style=pipeline_style, worker_map=get_worker_map(),
)
......@@ -23,7 +23,7 @@ import pytest
import torch
from torch import nn
from fairscale.nn.pipe import Pipe, is_checkpointing, is_recomputing
from fairscale.nn.pipe import MultiProcessPipe, is_checkpointing, is_recomputing
from fairscale.nn.pipe.skip import pop, skippable, stash
from fairscale.nn.pipe.skip.tracker import current_skip_tracker
from fairscale.utils.testing import get_worker_map, torch_spawn
......@@ -46,7 +46,7 @@ class Pop(nn.Module):
@torch_spawn([2])
@pytest.mark.parametrize("train", [True, False], ids=["train", "eval"])
@pytest.mark.parametrize("checkpoint", ["always", "except_last", "never"])
@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
@pytest.mark.skipif("OMPI_COMM_WORLD_RANK" in os.environ, reason="broken on mpi")
@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
def delete_portal_tensor(train, checkpoint, pipeline_style):
......@@ -60,7 +60,7 @@ def delete_portal_tensor(train, checkpoint, pipeline_style):
# | 3,blue,2 |--| 2,orange,1 |--| 1,orange,0 |--| 1,blue,0 |
# +----------+ +------------+ +------------+ +----------+
if pipeline_style == Pipe.AsyncSchedule:
if pipeline_style == MultiProcessPipe.AsyncSchedule:
pytest.skip("Skip tensors NYI for AsyncSchedule")
def portal_tensor_life_is(tensor_life, skip_tracker=None):
......@@ -114,7 +114,7 @@ def delete_portal_tensor(train, checkpoint, pipeline_style):
return self.F.apply(input)
model = nn.Sequential(NoPortalTensorAtBackward(), stash_, pop_)
model = Pipe(
model = MultiProcessPipe(
model, balance=[2, 1], style=pipeline_style, worker_map=get_worker_map(), chunks=2, checkpoint=checkpoint,
)
......
......@@ -22,15 +22,15 @@ import torch
from torch import nn
import torch.nn.functional as F
from fairscale.nn.pipe import Pipe
from fairscale.nn.pipe import MultiProcessPipe
from fairscale.utils.testing import get_worker_map, torch_spawn
@torch_spawn([2])
@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
def python_autograd_function(pipeline_style):
# FIXME deadlock with Pipe.AsyncSchedule?
# FIXME deadlock with MultiProcessPipe.AsyncSchedule?
# A Python autograd function might fail with this error:
#
# RuntimeError: Returning Variables sharing storage with other Variables
......@@ -57,7 +57,9 @@ def python_autograd_function(pipeline_style):
return Identity.apply(input)
model = nn.Sequential(M(), M())
model = Pipe(model, [1, 1], style=pipeline_style, worker_map=get_worker_map(), checkpoint="always").cuda()
model = MultiProcessPipe(
model, [1, 1], style=pipeline_style, worker_map=get_worker_map(), checkpoint="always"
).cuda()
model.eval()
x = torch.rand(42)
......@@ -71,7 +73,7 @@ def python_autograd_function(pipeline_style):
@torch_spawn([3])
@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
def exception_no_hang(pipeline_style):
# In v0.0.2, once a failed partition receives a normal message
# (non-closing) for the next micro-batch, a hang occurred. The reason was
......@@ -90,7 +92,7 @@ def exception_no_hang(pipeline_style):
raise ExpectedException()
model = nn.Sequential(Pass(), Pass(), Raise())
model = Pipe(model, [1, 1, 1], style=pipeline_style, worker_map=get_worker_map(), chunks=3)
model = MultiProcessPipe(model, [1, 1, 1], style=pipeline_style, worker_map=get_worker_map(), chunks=3)
model.eval()
if model.group.rank() == 2:
......@@ -104,7 +106,7 @@ def exception_no_hang(pipeline_style):
@torch_spawn([2])
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="2 cuda devices required")
@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
def tuple_wait(cuda_sleep, pipeline_style):
# In v0.0.3, Wait is applied to only the first tensor on a micro-batch.
# Under this behavior, if checkpointing was disabled, there's a possibility
......@@ -133,7 +135,7 @@ def tuple_wait(cuda_sleep, pipeline_style):
return a + b + c
model = nn.Sequential(Layer1(), Layer2())
model = Pipe(
model = MultiProcessPipe(
model,
[1, 1],
style=pipeline_style,
......@@ -158,7 +160,7 @@ def tuple_wait(cuda_sleep, pipeline_style):
@torch_spawn([2])
@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
def parallel_randoms(pipeline_style):
class Dropouts(nn.Module):
def forward(self, x):
......@@ -170,7 +172,7 @@ def parallel_randoms(pipeline_style):
x = torch.rand(10, 10, requires_grad=True).cuda()
x.retain_grad()
model = Pipe(
model = MultiProcessPipe(
model,
[1, 1],
style=pipeline_style,
......
......@@ -21,20 +21,20 @@ import pytest
import torch
from torch import nn
from fairscale.nn.pipe import Pipe
from fairscale.nn.pipe import MultiProcessPipe
from fairscale.utils.testing import get_worker_map, torch_spawn
@torch_spawn([2])
@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
def inplace_on_requires_grad(pipeline_style):
model = nn.Sequential(nn.Linear(1, 1), nn.ReLU(inplace=True))
model = Pipe(model, [1, 1], style=pipeline_style, worker_map=get_worker_map(), checkpoint="always")
model = MultiProcessPipe(model, [1, 1], style=pipeline_style, worker_map=get_worker_map(), checkpoint="always")
x = torch.rand(1)
if pipeline_style == Pipe.AsyncSchedule and model.group.rank() == 0:
if pipeline_style == MultiProcessPipe.AsyncSchedule and model.group.rank() == 0:
# With AsyncSchedule, the model will wait forever for gradients unless it is in eval mode
model.eval()
......@@ -50,12 +50,12 @@ def inplace_on_requires_grad(pipeline_style):
@torch_spawn([1])
@pytest.mark.xfail(strict=True)
@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
def inplace_on_not_requires_grad(pipeline_style):
# In-place operation on a tensor not requiring grad doesn't cause a
# RuntimeError. Currently, we cannot detect this case.
model = nn.Sequential(nn.ReLU(inplace=True))
model = Pipe(model, [1], style=pipeline_style, worker_map=get_worker_map(), checkpoint="always")
model = MultiProcessPipe(model, [1], style=pipeline_style, worker_map=get_worker_map(), checkpoint="always")
x = torch.rand(1)
y = model(x)
......@@ -70,7 +70,7 @@ def inplace_on_not_requires_grad(pipeline_style):
@torch_spawn([1])
@pytest.mark.xfail(strict=True)
@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
def inplace_incorrect_grad(pipeline_style):
class M(nn.Module):
def forward(self, foo_bar):
......@@ -88,7 +88,7 @@ def inplace_incorrect_grad(pipeline_style):
return foo * bar
model = nn.Sequential(M())
model = Pipe(model, [1], style=pipeline_style, worker_map=get_worker_map(), checkpoint="always")
model = MultiProcessPipe(model, [1], style=pipeline_style, worker_map=get_worker_map(), checkpoint="always")
foo = torch.tensor([1.0], requires_grad=True)
bar = torch.tensor([1.0])
......