"vscode:/vscode.git/clone" did not exist on "ebdbfde0cee9d6adca2c0508f1a664c13d3cd65a"
Commit 0cd65242 authored by Mandeep Singh Baines's avatar Mandeep Singh Baines
Browse files

Initial commit

parents
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
# Copyright 2019 Kakao Brain
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A helper to roughly balance a sequential module.
Usage::
import torch
from fairscale.nn import Pipe
from fairscale.nn.balance import balance_by_time
sample = torch.empty(128, 3, 224, 224)
balance = balance_by_time(torch.cuda.device_count(), model, sample)
pipe = Pipe(model, balance, chunks=8)
"""
from typing import List, Tuple, Union
import torch
from torch import Tensor
import torch.nn as nn
from . import blockpartition
from .profile import profile_sizes, profile_times
__all__ = ["balance_by_time", "balance_by_size"]
Device = Union[torch.device, int, str]
Tensors = Tuple[Tensor, ...]
TensorOrTensors = Union[Tensor, Tensors]
def balance_cost(cost: List[int], partitions: int) -> List[int]:
partitioned = blockpartition.solve(cost, partitions)
return [len(p) for p in partitioned]
def balance_by_time(
partitions: int,
module: nn.Sequential,
sample: TensorOrTensors,
*,
timeout: float = 1.0,
device: Device = torch.device("cuda"),
) -> List[int]:
"""Naive automatic balancing by elapsed time per layer.
::
sample = torch.empty(128, 3, 224, 224)
balance = balance_by_time(torch.cuda.device_count(), model, sample)
pipe = Pipe(model, balance, chunks=8)
Args:
partitions (int):
intended number of partitions
module (torch.nn.Sequential):
sequential module to be partitioned
sample (torch.Tensor):
example input with arbitrary batch size
Keyword Args:
timeout (float):
profiling iterates again while the timeout (in seconds) is not exceeded
(default: ``1.0``)
device ('cpu' or 'cuda' device):
CPU or CUDA device where each layer is profiled (default: the
current CUDA device)
Returns:
A list of the number of layers in each partition. Use it for the `balance`
parameter of :class:`~fairscale.nn.Pipe`.
.. note::
`module` and `sample` must be placed on the same device.
"""
times = profile_times(module, sample, timeout, torch.device(device))
return balance_cost(times, partitions)
def balance_by_size(
partitions: int,
module: nn.Sequential,
input: TensorOrTensors,
*,
chunks: int = 1,
param_scale: float = 2.0,
device: Device = torch.device("cuda"),
) -> List[int]:
"""Naive automatic balancing by CUDA memory usage per layer.
During training, required memory for parameters depends on which optimizer
is used. Optimizers may use buffers for each parameter to track
optimization statistics internally, such as momentum buffer in SGD.
To get a more reliable size-based balance, you should specify `param_scale`
according to your optimizer. The default `param_scale` is 2 instead of 1 due
to gradient accumulation, which is required by every optimizer.
Follow this guide to choose the correct `param_scale` for typical optimizers:
========= ============= =========================================
Optimizer `param_scale` Internal State
========= ============= =========================================
SGD 2--3 (momentum_buffer)
Adam 4--5 exp_avg, exp_avg_sq, (max_exp_avg_sq)
Adadelta 4 square_avg, acc_delta
Adagrad 3 sum
RMSprop 3--5 square_avg, (momentum_buffer), (grad_avg)
========= ============= =========================================
Here's a simple example with the Adam optimizer::
balance = balance_by_size(
torch.cuda.device_count(),
model,
# Same size as the mini-batch to train
torch.empty(1024, 3, 224, 224),
# Number of micro-batches to train with Pipe
chunks=8,
# 4 for Adam
param_scale=4.0,
)
pipe = Pipe(model, balance, chunks=8)
adam = Adam(pipe.parameters())
Args:
partitions (int):
intended number of partitions
module (torch.nn.Sequential):
sequential module to be partitioned
input (torch.Tensor):
example mini-batch with the same size as used for training
Keyword Args:
chunks (int):
number of micro-batches that will be used to train (default: ``1``)
param_scale (float):
how many copies of the parameters are allocated for training. It
depends on the optimizer. See the above guide. (default: ``2.0``)
device ('cuda' device):
CUDA device where each layer is profiled (default: the current CUDA
device)
Returns:
A list of the number of layers in each partition. Use it for the `balance`
parameter of :class:`~fairscale.nn.Pipe`.
.. note::
`module` and `input` must be placed on the same CUDA device.
"""
sizes = profile_sizes(module, input, chunks, param_scale, torch.device(device))
return balance_cost(sizes, partitions)
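# A minimal, hypothetical smoke test of balance_by_time, guarded so it never
# runs on import. The toy model, sizes, and timeout below are assumptions
# chosen only to show the call shape; real balancing should profile the real
# model on the target device.
if __name__ == "__main__":
    toy_model = nn.Sequential(nn.Linear(32, 32), nn.ReLU(), nn.Linear(32, 32), nn.ReLU())
    toy_sample = torch.rand(16, 32)
    # Expect a list of per-partition layer counts summing to 4, e.g. [2, 2].
    print(balance_by_time(2, toy_model, toy_sample, timeout=0.1, device="cpu"))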
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
# Copyright 2019 Kakao Brain
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Implements "Block Partitions of Sequences" by Imre Bárány et al.
Paper: https://arxiv.org/pdf/1308.2452.pdf
"""
from typing import Iterator, List, Tuple
__all__ = ["solve"]
def solve(sequence: List[int], partitions: int = 1) -> List[List[int]]:
"""Splits a sequence into several partitions to minimize variance for each
partition.
The result might not be optimal. However, it can be computed in only O(kn³) time,
where k is the number of partitions and n is the length of the sequence.
"""
if partitions < 1:
raise ValueError(f"partitions must be a positive integer ({partitions} < 1)")
n = len(sequence)
if n < partitions:
raise ValueError(f"sequence is shorter than intended partitions ({n} < {partitions})")
# Normalize the sequence in [0, 1].
minimum = min(sequence)
maximum = max(sequence) - minimum
normal_sequence: List[float]
if maximum == 0:
normal_sequence = [0 for _ in sequence]
else:
normal_sequence = [(x - minimum) / maximum for x in sequence]
splits = [n // partitions * (x + 1) for x in range(partitions - 1)] + [n]
def block_size(i: int) -> float:
start = splits[i - 1] if i > 0 else 0
stop = splits[i]
return sum(normal_sequence[start:stop])
def leaderboard() -> Iterator[Tuple[float, int]]:
return ((block_size(i), i) for i in range(partitions))
while True:
"""
(1) Fix p ∈ [k] with M(P) = bp. So Bp is a maximal block of P.
"""
# max_size: M(P)
max_size, p = max(leaderboard())
while True:
"""
(2) If M(P) ≤ m(P) + 1, then stop.
"""
# min_size: m(P)
min_size, q = min(leaderboard())
if max_size <= min_size + 1:
return [sequence[i:j] for i, j in zip([0] + splits[:-1], splits)]
"""
(3) If M(P) > m(P) + 1, then let m(P) = bq for the q ∈ [k] which is
closest to p (ties broken arbitrarily). Thus Bq is a minimal block
of P. Let Bh be the block next to Bq between Bp and Bq. (Note that
Bh is a non-empty block: if it were, then m(P) = 0 and we should
have chosen Bh instead of Bq.)
"""
if p < q:
"""
So either p < q and then h = q−1 and we define P ∗ by moving
the last element from Bh = Bq−1 to Bq,
"""
h = q - 1
splits[h] -= 1
else:
"""
or q < p, and then h = q + 1 and P ∗ is obtained by moving the
first element of Bh = Bq+1 to Bq.
"""
h = q + 1
splits[q] += 1
"""
Set P = P ∗ . If p = h, then go to (1), else go to (2).
"""
if p == h:
break
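# A tiny demonstration of solve() with assumed costs, guarded so it only runs
# when this module is executed directly: it returns contiguous blocks whose
# sums are roughly balanced.
if __name__ == "__main__":
    # With these costs the two blocks sum to 40 and 60.
    print(solve([10, 30, 20, 40], partitions=2))  # -> [[10, 30], [20, 40]]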
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
# Copyright 2019 Kakao Brain
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Per-layer profilers."""
import copy
import time
from typing import Generator, List, Tuple, Union
import torch
from torch import Tensor
import torch.nn as nn
from ..microbatch import Batch
__all__: List[str] = []
Device = Union[torch.device, int, str]
Tensors = Tuple[Tensor, ...]
TensorOrTensors = Union[Tensor, Tensors]
def layerwise_sandbox(module: nn.Sequential, device: torch.device,) -> Generator[nn.Module, None, None]:
"""Copies layers for ease to profile. It doesn't modify the given
module.
"""
for layer in module:
layer_copy = copy.deepcopy(layer)
layer_copy.to(device)
layer_copy.train()
yield layer_copy
def detach(batch: Batch) -> None:
"""Detaches from autograd graph."""
for i, x in enumerate(batch):
batch[i] = x.detach().requires_grad_(x.requires_grad)
def profile_times(module: nn.Sequential, sample: TensorOrTensors, timeout: float, device: torch.device,) -> List[int]:
"""Profiles elapsed times per layer."""
if any(p.grad is not None for p in module.parameters()):
raise ValueError("some parameter already has gradient")
_batch = Batch(sample)
for i, x in enumerate(_batch):
_batch[i] = x.detach().to(device).requires_grad_(x.requires_grad)
time_bufs: List[List[float]] = [[] for _ in module]
begun_at = time.time()
while time.time() - begun_at < timeout:
batch = _batch
for i, layer in enumerate(layerwise_sandbox(module, device)):
detach(batch)
if device.type == "cuda":
torch.cuda.synchronize(device)
tick = time.time()
# Forward
batch = batch.call(layer)
# Backward
backward_tensors = tuple(y for y in batch if y.requires_grad)
if backward_tensors:
torch.autograd.backward(backward_tensors, backward_tensors)
if device.type == "cuda":
torch.cuda.synchronize(device)
tock = time.time()
time_bufs[i].append(tock - tick)
us = 1_000_000
return [sum(int(t * us) for t in buf) for buf in time_bufs]
def profile_sizes(
module: nn.Sequential, input: TensorOrTensors, chunks: int, param_scale: float, device: torch.device,
) -> List[int]:
"""Profiles CUDA memory usage per layer."""
if device.type != "cuda":
raise ValueError("size profiler supports only CUDA device")
batch = Batch(input)
sizes: List[int] = []
latent_scale = batch[0].size(0) / chunks
for i, x in enumerate(batch):
batch[i] = x[:1].detach().to(device).requires_grad_(x.requires_grad)
for layer in layerwise_sandbox(module, device):
detach(batch)
# Detect memory usage at forward.
memory_before = torch.cuda.memory_allocated(device)
batch = batch.call(layer)
memory_after = torch.cuda.memory_allocated(device)
latent_size = memory_after - memory_before
# Analyze size of parameters.
param_size = sum(p.storage().size() * p.storage().element_size() for p in layer.parameters())
# Combine size of parameters and activations with normalize scales.
size = latent_size * latent_scale + param_size * param_scale
sizes.append(int(size))
return sizes
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
# Copyright 2019 Kakao Brain
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tracks the running statistics per mini-batch instead of micro-batch."""
from typing import Optional, TypeVar, cast
import torch
from torch import Tensor, nn
import torch.nn.functional as F
from torch.nn.modules.batchnorm import _BatchNorm
from .checkpoint import is_recomputing
__all__ = ["DeferredBatchNorm"]
TModule = TypeVar("TModule", bound=nn.Module)
class DeferredBatchNorm(_BatchNorm):
"""A BatchNorm layer tracks multiple micro-batches to update running
statistics per mini-batch.
"""
sum: Tensor
sum_squares: Tensor
def __init__(
self,
num_features: int,
eps: float = 1e-5,
momentum: Optional[float] = 0.1,
affine: bool = True,
chunks: int = 1,
) -> None:
super().__init__(num_features, eps, momentum, affine, track_running_stats=True)
self.register_buffer("sum", torch.zeros_like(self.running_mean))
self.register_buffer("sum_squares", torch.zeros_like(self.running_var))
self.counter = 0
self.tracked = 0
self.chunks = chunks
def _check_input_dim(self, input: Tensor) -> None:
# It's the typical _check_input_dim() implementation in PyTorch.
if input.dim() <= 2:
raise ValueError("expected at least 3D input (got %dD input)" % input.dim())
def _track(self, input: Tensor) -> bool:
"""Tracks statistics of a micro-batch."""
# Dimensions except channel. For example, (0, 2, 3) is for BatchNorm2d.
dim = [0]
dim.extend(range(2, input.dim()))
with torch.no_grad():
self.sum += input.sum(dim)
self.sum_squares += (input ** 2).sum(dim)
size = input.size().numel() // input.size(1)
self.counter += size
self.tracked += 1
return self.tracked == self.chunks
def _commit(self) -> None:
"""Updates the running statistics of a mini-batch."""
exponential_average_factor = 0.0
self.num_batches_tracked += 1
if self.momentum is None: # use cumulative moving average
exponential_average_factor = 1.0 / float(self.num_batches_tracked)
else: # use exponential moving average
exponential_average_factor = self.momentum
mean = self.sum / self.counter
var = self.sum_squares / self.counter - mean ** 2
# Calculate the exponential moving average here.
m = exponential_average_factor
self.running_mean *= 1 - m
self.running_mean += mean * m
self.running_var *= 1 - m
self.running_var += var * m
self.sum.zero_()
self.sum_squares.zero_()
self.counter = 0
self.tracked = 0
def forward(self, input: Tensor) -> Tensor: # type: ignore
if not self.training:
# Don't train parameters on the evaluation mode.
return F.batch_norm(
input,
running_mean=self.running_mean,
running_var=self.running_var,
weight=self.weight,
bias=self.bias,
training=False,
momentum=0.0,
eps=self.eps,
)
if not is_recomputing():
# Track a micro-batch on the training mode
# but not under a recomputation.
tracked_enough = self._track(input)
# Update the running statistics for a mini-batch
# if it has tracked enough micro-batches.
if tracked_enough:
self._commit()
# Normalize a micro-batch and train the parameters.
return F.batch_norm(
input,
running_mean=None,
running_var=None,
weight=self.weight,
bias=self.bias,
training=True,
momentum=0.0,
eps=self.eps,
)
@classmethod
def convert_deferred_batch_norm(cls, module: TModule, chunks: int = 1) -> TModule:
"""Converts a :class:`nn.BatchNorm` or underlying
:class:`nn.BatchNorm`s into :class:`DeferredBatchNorm`::
from torchvision.models.resnet import resnet101
from torchpipe.batchnorm import DeferredBatchNorm
model = resnet101()
model = DeferredBatchNorm.convert_deferred_batch_norm(model)
"""
if isinstance(module, DeferredBatchNorm) and module.chunks == chunks:
return cast(TModule, module)
module_output: nn.Module = module
if isinstance(module, _BatchNorm) and module.track_running_stats:
module_output = DeferredBatchNorm(module.num_features, module.eps, module.momentum, module.affine, chunks)
if module.affine:
module_output.register_parameter("weight", module.weight)
module_output.register_parameter("bias", module.bias)
module_output.register_buffer("running_mean", module.running_mean)
module_output.register_buffer("running_var", module.running_var)
module_output.register_buffer("num_batches_tracked", module.num_batches_tracked)
for name, child in module.named_children():
module_output.add_module(name, cls.convert_deferred_batch_norm(child, chunks))
return cast(TModule, module_output)
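# A hypothetical conversion example, guarded so it never runs on import: every
# tracked BatchNorm submodule in a small assumed model is replaced by a
# DeferredBatchNorm that commits running statistics once per 4 micro-batches.
if __name__ == "__main__":
    demo = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
    demo = DeferredBatchNorm.convert_deferred_batch_norm(demo, chunks=4)
    print(type(demo[1]).__name__)  # -> DeferredBatchNorm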
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
# Copyright 2019 Kakao Brain
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Checkpointing with preceding recomputation.
PyTorch already provides the official checkpointing utilities in
:mod:`torch.utils.checkpoint`. The official checkpointing combines
recomputation and recursive backpropagation into one autograd function named
``CheckpointFunction``. Hence, the recomputation can be started only when the
gradients arrive at the function. In Pipe, the recomputation needs to precede
the gradient arrival to minimize the GPU idle time.
We solve this problem by introducing separate autograd functions named
:class:`Recompute` and :class:`Checkpoint`. They implement recomputation and
recursive backpropagation, respectively. With this pair of functions we can
manipulate the control flow from the perspective of both the autograd engine
and CUDA.
Specifically, we place CUDA stream synchronization between :class:`Recompute`
and :class:`Checkpoint` to delay only :class:`Checkpoint` until the gradient is
copied entirely.
"""
from collections import deque
from contextlib import contextmanager
import threading
from typing import TYPE_CHECKING, Deque, Generator, List, Optional, Tuple, Union
import torch
from torch import ByteTensor, Tensor
import torch.autograd
from .dependency import fork, join
from .microbatch import Batch
from .phony import get_phony
__all__ = ["is_checkpointing", "is_recomputing"]
Tensors = Tuple[Tensor, ...]
TensorOrTensors = Union[Tensor, Tensors]
# Types for shared memory between Checkpoint and Recompute.
Recomputed = Tuple[TensorOrTensors, Tensors] # (output, input_leaf)
RNGStates = Tuple[ByteTensor, Optional[ByteTensor]] # (cpu_rng_state, gpu_rng_state)
if TYPE_CHECKING:
from typing_extensions import Protocol
else:
Protocol = object
# Protocol with __call__ instead of Callable can be used as an attribute type.
# See: https://github.com/python/mypy/issues/708#issuecomment-561735949
class Function(Protocol):
def __call__(self, input: TensorOrTensors) -> TensorOrTensors:
...
def checkpoint(function: Function, input: TensorOrTensors) -> TensorOrTensors:
"""Makes a checkpoint with a simple interface like
:func:`torch.utils.checkpoint.checkpoint`. It's only used to test or debug
:class:`Checkpoint` and :class:`Recompute` without boilerplate.
"""
batch = Batch(input)
chk = Checkpointing(function, batch)
batch = chk.checkpoint()
chk.recompute(batch)
return batch.tensor_or_tensors
class Checkpointing:
"""Generates a pair of :class:`Checkpoint` and :class:`Recompute`."""
def __init__(self, function: Function, batch: Batch) -> None:
self.function = function
self.batch = batch
# Shared memory between Checkpoint and Recompute. 1-length deque is
# used for mutability and length limitation.
self.recomputed: Deque[Recomputed] = deque(maxlen=1)
self.rng_states: Deque[RNGStates] = deque(maxlen=1)
def checkpoint(self) -> Batch:
"""Returns a batch applied by :class:`Checkpoint`."""
input_atomic = self.batch.atomic
input = tuple(self.batch)
# Use a phony which requires grad to ensure that Checkpoint can be
# tracked by the autograd engine even when none of the input tensors
# require grad.
phony = get_phony(self.batch[0].device, requires_grad=True)
output = Checkpoint.apply(phony, self.recomputed, self.rng_states, self.function, input_atomic, *input)
# Gradients are only supported for float Tensors.
if isinstance(output, tuple):
output = tuple([x if x.is_floating_point() else x.detach() for x in output])
return Batch(output)
def recompute(self, batch: Batch) -> None:
"""Applies :class:`Recompute` to the batch in place."""
input_atomic = self.batch.atomic
input = tuple(self.batch)
# batch[0] always requires grad, because it has been passed through
# Checkpoint with a phony that requires grad.
batch[0], phony = fork(batch[0])
phony = Recompute.apply(phony, self.recomputed, self.rng_states, self.function, input_atomic, *input)
batch[0] = join(batch[0], phony)
class ThreadLocal(threading.local):
def __init__(self) -> None:
self.is_checkpointing = False
self.is_recomputing = False
thread_local = ThreadLocal()
@contextmanager
def enable_checkpointing() -> Generator[None, None, None]:
"""Makes :func:`is_checkpointing` return :data:`True` within a context."""
orig = thread_local.is_checkpointing
thread_local.is_checkpointing = True
try:
yield
finally:
thread_local.is_checkpointing = orig
@contextmanager
def enable_recomputing() -> Generator[None, None, None]:
"""Makes :func:`is_recomputing` return :data:`True` within a context."""
orig = thread_local.is_recomputing
thread_local.is_recomputing = True
try:
yield
finally:
thread_local.is_recomputing = orig
def is_checkpointing() -> bool:
"""Whether the current forward propagation is under checkpointing.
Returns:
bool: :data:`True` if it's under checkpointing.
"""
return thread_local.is_checkpointing
def is_recomputing() -> bool:
"""Whether the current forward propagation is under checkpoint
recomputation. Use this to prevent duplicated side-effects at forward
propagation::
class Counter(nn.Module):
def __init__(self):
super().__init__()
self.counter = 0
def forward(self, input):
if not is_recomputing():
self.counter += 1
return input
Returns:
bool: :data:`True` if it's under checkpoint recomputation.
.. seealso:: :ref:`Detecting Recomputation`
"""
return thread_local.is_recomputing
class Context:
"""The common interface between the :class:`Checkpoint` and
:class:`Recompute` context.
"""
recomputed: Deque[Recomputed]
rng_states: Deque[RNGStates]
function: Function
input_atomic: bool
saved_tensors: Tuple[Tensor, ...]
def save_for_backward(self, *tensors: Tensor) -> None: # pragma: no cover
pass
def save_rng_states(device: torch.device, rng_states: Deque[RNGStates],) -> None:
""":meth:`Checkpoint.forward` captures the current PyTorch's random number
generator states at CPU and GPU to reuse in :meth:`Recompute.backward`.
.. seealso:: :ref:`Referential Transparency`
"""
cpu_rng_state = torch.get_rng_state()
gpu_rng_state: Optional[ByteTensor]
if device.type == "cuda":
gpu_rng_state = torch.cuda.get_rng_state(device)
else:
gpu_rng_state = None
rng_states.append((cpu_rng_state, gpu_rng_state))
@contextmanager
def restore_rng_states(device: torch.device, rng_states: Deque[RNGStates],) -> Generator[None, None, None]:
""":meth:`Recompute.backward` restores the random number generator states
captured by :func:`save_rng_states` within its context.
.. seealso:: :ref:`Referential Transparency`
"""
cpu_rng_state, gpu_rng_state = rng_states.pop()
gpu_devices: List[torch.device] = []
if device.type == "cuda":
gpu_devices.append(device)
with torch.random.fork_rng(gpu_devices):
torch.set_rng_state(cpu_rng_state)
if gpu_rng_state is not None:
torch.cuda.set_rng_state(gpu_rng_state, device)
yield
class Checkpoint(torch.autograd.Function):
@staticmethod
# type: ignore
def forward(
ctx: Context,
phony: Tensor,
recomputed: Deque[Recomputed],
rng_states: Deque[RNGStates],
function: Function,
input_atomic: bool,
*input: Tensor,
) -> TensorOrTensors:
ctx.recomputed = recomputed
ctx.rng_states = rng_states
save_rng_states(input[0].device, ctx.rng_states)
ctx.function = function
ctx.input_atomic = input_atomic
ctx.save_for_backward(*input)
with torch.no_grad(), enable_checkpointing():
output = function(input[0] if input_atomic else input)
return output
@staticmethod
def backward(ctx: Context, *grad_output: Tensor,) -> Tuple[Optional[Tensor], ...]: # pragma: no cover
output, input_leaf = ctx.recomputed.pop()
if isinstance(output, tuple):
tensors = output
else:
tensors = (output,)
if any(y.requires_grad for y in tensors):
tensors = tuple([x for x in tensors if x.requires_grad])
torch.autograd.backward(tensors, grad_output)
grad_input: List[Optional[Tensor]] = [None, None, None, None, None]
grad_input.extend(x.grad for x in input_leaf)
return tuple(grad_input)
class Recompute(torch.autograd.Function):
@staticmethod
# type: ignore
def forward(
ctx: Context,
phony: Tensor,
recomputed: Deque[Recomputed],
rng_states: Deque[RNGStates],
function: Function,
input_atomic: bool,
*input: Tensor,
) -> Tensor:
ctx.recomputed = recomputed
ctx.rng_states = rng_states
ctx.function = function
ctx.input_atomic = input_atomic
ctx.save_for_backward(*input)
return phony
@staticmethod
def backward(ctx: Context, *grad_output: Tensor) -> Tuple[None, ...]: # pragma: no cover
input = ctx.saved_tensors
input_leaf = tuple(x.detach().requires_grad_(x.requires_grad) for x in input)
with restore_rng_states(input[0].device, ctx.rng_states):
with torch.enable_grad(), enable_recomputing():
output = ctx.function(input_leaf[0] if ctx.input_atomic else input_leaf)
ctx.recomputed.append((output, input_leaf))
grad_input: List[None] = [None, None, None, None, None]
grad_input.extend(None for _ in ctx.saved_tensors)
return tuple(grad_input)
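# A small sanity sketch of the checkpoint() helper defined above, assuming a
# CPU tensor and a lambda as the function. Recompute.backward repopulates the
# recomputed activations before Checkpoint.backward consumes them, so x ends
# up with the same gradient ordinary autograd would give.
if __name__ == "__main__":
    x = torch.randn(4, 4, requires_grad=True)
    y = checkpoint(lambda t: t * 2, x)
    y.sum().backward()
    print(torch.allclose(x.grad, torch.full_like(x, 2.0)))  # -> True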
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
# Copyright 2019 Kakao Brain
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Autograd functions for stream-aware CUDA copy. It is used to overlap copy
and computation on the same GPU.
"""
from collections import deque
from typing import Deque, List, Optional, Tuple
import torch
from torch import Tensor
from .stream import AbstractStream, current_stream, get_device, record_stream, use_stream, wait_stream
__all__: List[str] = []
Tensors = Tuple[Tensor, ...]
# Common interface between :class:`Copy` and :class:`Wait`.
class Context:
prev_stream: AbstractStream
next_stream: AbstractStream
class Copy(torch.autograd.Function):
"""Copies tensors on specific streams."""
@staticmethod
# type: ignore
def forward(ctx: Context, prev_stream: AbstractStream, next_stream: AbstractStream, *input: Tensor,) -> Tensors:
ctx.prev_stream = prev_stream
ctx.next_stream = next_stream
output = []
output_stream = current_stream(get_device(next_stream))
with use_stream(prev_stream), use_stream(next_stream):
for x in input:
y = x.to(get_device(next_stream), non_blocking=True)
output.append(y)
# 'prev_stream' is not where 'x' has been allocated.
record_stream(x, prev_stream)
# 'y' has been allocated on 'next_stream'.
# It might be used on the current stream captured as 'output_stream'.
record_stream(y, output_stream)
return tuple(output)
@staticmethod
def backward(ctx: Context, *grad_output: Tensor,) -> Tuple[Optional[Tensor], ...]:
prev_stream = ctx.prev_stream
next_stream = ctx.next_stream
grad_input: Deque[Tensor] = deque(maxlen=len(grad_output))
input_stream = current_stream(get_device(prev_stream))
with use_stream(prev_stream), use_stream(next_stream):
for x in reversed(grad_output):
y = x.to(get_device(prev_stream), non_blocking=True)
grad_input.appendleft(y)
# 'next_stream' is not where 'x' has been allocated.
record_stream(x, next_stream)
# 'y' has been allocated on 'prev_stream'.
# It might be used on the current stream captured as 'input_stream'.
record_stream(y, input_stream)
grad_streams: Tuple[Optional[Tensor], ...] = (None, None)
return grad_streams + tuple(grad_input)
class Wait(torch.autograd.Function):
"""Synchronizes a stream to another stream.
Place it just before you want to start an operation on the next stream,
provided that all operations on the previous stream are done.
"""
@staticmethod
# type: ignore
def forward(ctx: Context, prev_stream: AbstractStream, next_stream: AbstractStream, *input: Tensor,) -> Tensors:
ctx.prev_stream = prev_stream
ctx.next_stream = next_stream
wait_stream(next_stream, prev_stream)
return tuple(x.detach() for x in input)
@staticmethod
def backward(ctx: Context, *grad_input: Tensor,) -> Tuple[Optional[Tensor], ...]:
prev_stream = ctx.prev_stream
next_stream = ctx.next_stream
wait_stream(prev_stream, next_stream)
grad_streams: Tuple[Optional[Tensor], ...] = (None, None)
return grad_streams + grad_input
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
# Copyright 2019 Kakao Brain
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Arbitrary dependency between two autograd lanes."""
from typing import List, Tuple
import torch
from torch import Tensor
from .phony import get_phony
__all__: List[str] = []
def fork(input: Tensor) -> Tuple[Tensor, Tensor]:
"""Branches out from an autograd lane of the given tensor."""
if torch.is_grad_enabled() and input.requires_grad:
input, phony = Fork.apply(input)
else:
phony = get_phony(input.device, requires_grad=False)
return input, phony
class Fork(torch.autograd.Function):
@staticmethod
def forward(ctx: "Fork", input: Tensor) -> Tuple[Tensor, Tensor]: # type: ignore
phony = get_phony(input.device, requires_grad=False)
return input.detach(), phony.detach()
@staticmethod
def backward(ctx: "Fork", grad_input: Tensor, grad_grad: Tensor) -> Tensor: # type: ignore
return grad_input
def join(input: Tensor, phony: Tensor) -> Tensor:
"""Merges two autograd lanes."""
if torch.is_grad_enabled() and (input.requires_grad or phony.requires_grad):
input = Join.apply(input, phony)
return input
class Join(torch.autograd.Function):
@staticmethod
def forward(ctx: "Join", input: Tensor, phony: Tensor) -> Tensor: # type: ignore
return input.detach()
@staticmethod
def backward(ctx: "Join", grad_input: Tensor) -> Tuple[Tensor, None]: # type: ignore
return grad_input, None
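# An illustrative sketch (assumed CPU tensors) of how fork() and join() add an
# ordering-only edge between two autograd lanes without changing any values.
if __name__ == "__main__":
    a = torch.randn(3, requires_grad=True)
    b = torch.randn(3, requires_grad=True)
    a2, phony = fork(a)   # branch a phony off the lane of 'a'
    b2 = join(b, phony)   # make the lane of 'b' depend on 'a' via the phony
    (a2.sum() + b2.sum()).backward()
    print(a.grad, b.grad)  # both are tensors of ones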
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
# Copyright 2019 Kakao Brain
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Manipulation of micro-batches."""
import typing
from typing import Callable, Iterable, Iterator, List, Tuple, Union, cast
import torch
from torch import Tensor
import torch.cuda.comm
__all__: List[str] = []
Tensors = Tuple[Tensor, ...]
TensorOrTensors = Union[Tensor, Tensors]
Function = Callable[[TensorOrTensors], TensorOrTensors]
class Batch:
"""An abstraction of an atomic tensor or a tuple of tensors. This
eliminates every boilerplate code to classify an atomic tensor or a tuple
of tensors.
::
x = generate_tensor_or_tensors()
x = Batch(x)
# in-place update
x[0] = F.apply(x[0])
x[:] = F.apply(*x)
# f(x) if x is a tensor.
# f(*x) if x is a tuple of tensors.
# y is also a batch.
y = x.call(f)
"""
def __init__(self, value: TensorOrTensors) -> None:
self.value = value
self.atomic = torch.is_tensor(value)
@property
def tensor(self) -> Tensor:
"""Retrieves the underlying tensor."""
if not self.atomic:
raise AttributeError("not atomic batch")
return cast(Tensor, self.value)
@property
def tensors(self) -> Tensors:
"""Retrieves the underlying tensors."""
if self.atomic:
raise AttributeError("batch is atomic")
return cast(Tensors, self.value)
@property
def tensor_or_tensors(self) -> TensorOrTensors:
"""Retrieves the underlying tensor or tensors regardless of type."""
return self.value
def call(self, function: Function) -> "Batch":
"""Calls a function by the underlying tensor or tensors. It also wraps
the output with :class:`Batch`.
"""
return Batch(function(self.value))
def __repr__(self) -> str:
return f"Batch[atomic={self.atomic!r}]({self.value!r})"
def __iter__(self) -> Iterator[Tensor]:
if self.atomic:
yield self.tensor
else:
yield from self.tensors
def __len__(self) -> int:
return 1 if self.atomic else len(self.tensors)
def __getitem__(self, index: int) -> Tensor:
if not self.atomic:
return self.tensors[index]
if index != 0:
raise IndexError("atomic batch allows index 0 only")
return self.tensor
# NOTE(sublee): pyflakes can't detect "overload" instead of "typing.overload".
@typing.overload
def __setitem__(self, index: int, value: Tensor) -> None:
...
@typing.overload
def __setitem__(self, index: slice, value: Tensors) -> None:
...
def __setitem__(self, index: Union[int, slice], value: TensorOrTensors) -> None:
if isinstance(index, int):
value = cast(Tensor, value)
self._setitem_by_index(index, value)
else:
value = cast(Tensors, value)
self._setitem_by_slice(index, value)
def _setitem_by_index(self, index: int, value: Tensor) -> None:
if not self.atomic:
i = index
self.value = self.value[:i] + (value,) + self.value[i + 1 :]
return
if index != 0:
raise IndexError("atomic batch allows index 0 only")
self.value = value
def _setitem_by_slice(self, index: slice, value: Tensors) -> None:
if not (index.start is index.stop is index.step is None):
raise NotImplementedError("only slice [:] supported")
if not self.atomic:
self.value = value
return
if len(value) != 1:
raise IndexError("atomic batch cannot be replaced with multiple tensors")
self.value = value[0]
def check(input: TensorOrTensors) -> None:
"""Checks whether the input is a tensor or tensors.
Raises:
TypeError: input is not a tensor or tensors.
"""
if isinstance(input, tuple):
for x in input:
check(x)
return
if not isinstance(input, Tensor):
raise TypeError(f"expected Tensor, but got {input.__class__.__name__}")
def scatter(input: TensorOrTensors, chunks: int) -> List[Batch]:
"""Splits an input mini-batch into multiple micro-batches."""
inputs: Iterable[TensorOrTensors]
if isinstance(input, Tensor):
inputs = input.chunk(chunks)
else:
rotated: List[Tensors] = []
for tensor in input:
tensors = tensor.chunk(chunks)
rotated.append(cast(Tensors, tensors))
inputs = zip(*rotated)
return [Batch(x) for x in inputs]
def gather(outputs: List[Batch]) -> TensorOrTensors:
"""Concatenates output micro-batches into a mini-batch."""
output: TensorOrTensors
if outputs[0].atomic:
tensors = tuple(b.tensor for b in outputs)
output = torch.cat(tensors)
else:
rotated = [b.tensors for b in outputs]
output_buf = []
for tensors in zip(*rotated):
output_buf.append(torch.cat(tensors))
output = tuple(output_buf)
return output
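# A CPU-only round-trip demonstration of scatter() and gather(), with an
# assumed mini-batch of 8 rows split into 4 micro-batches. Guarded so it only
# runs when this module is executed directly.
if __name__ == "__main__":
    x = torch.arange(8.0).view(8, 1)
    micro = scatter(x, chunks=4)
    print(len(micro), micro[0].atomic)    # -> 4 True
    print(torch.equal(gather(micro), x))  # -> True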
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
# Copyright 2019 Kakao Brain
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Provides phony for arbitrary dependency in a autograd graph."""
from typing import Dict, List, Tuple
import torch
from torch import Tensor
from .stream import default_stream, use_stream
__all__: List[str] = []
_phonies: Dict[Tuple[torch.device, bool], Tensor] = {}
def get_phony(device: torch.device, *, requires_grad: bool) -> Tensor:
"""Gets a phony. Phony is tensor without space. It is useful to make
arbitrary dependency in a autograd graph because it doesn't require any
gradient accumulation.
.. note::
Phonies for each device are cached. If an autograd function gets a phony
internally, the phony must be detached to be returned. Otherwise, the
autograd engine will mutate the cached phony in-place::
class Phonify(torch.autograd.Function):
@staticmethod
def forward(ctx, input):
phony = get_phony(input.device, requires_grad=False)
return phony.detach() # detach() is necessary.
"""
key = (device, requires_grad)
try:
phony = _phonies[key]
except KeyError:
with use_stream(default_stream(device)):
phony = torch.empty(0, device=device, requires_grad=requires_grad)
_phonies[key] = phony
return phony
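# An illustrative check (assuming a CPU device) that a phony is an empty
# tensor and that phonies are cached per (device, requires_grad) key.
if __name__ == "__main__":
    p1 = get_phony(torch.device("cpu"), requires_grad=False)
    p2 = get_phony(torch.device("cpu"), requires_grad=False)
    print(p1.numel(), p1 is p2)  # -> 0 True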
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
# Copyright 2019 Kakao Brain
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The Pipe interface."""
from collections import OrderedDict
from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Tuple, Union, cast
import torch
from torch import Tensor, nn
import torch.autograd
import torch.cuda
from . import microbatch
from .batchnorm import DeferredBatchNorm
from .pipeline import Pipeline
from .skip.layout import inspect_skip_layout
from .skip.skippable import verify_skippables
from .stream import AbstractStream, new_stream
__all__ = ["Pipe"]
Device = Union[torch.device, int, str]
Devices = Union[Iterable[Device], List[Device]]
Tensors = Tuple[Tensor, ...]
TensorOrTensors = Union[Tensor, Tensors]
if TYPE_CHECKING:
Module = nn.Module[TensorOrTensors]
NamedModules = OrderedDict[str, Module]
else:
Module = nn.Module
NamedModules = OrderedDict
def recommend_auto_balance(message: str) -> str:
"""Expands a message with recommendation to :mod:`torchpipe.balance`."""
return f"""{message}
If your model is still under development, its optimal balance would change
frequently. In this case, we highly recommend 'fairscale.nn.pipe.balance' for
naive automatic balancing:
from fairscale.nn import Pipe
from fairscale.nn.pipe.balance import balance_by_time
partitions = torch.cuda.device_count()
sample = torch.empty(...)
balance = balance_by_time(partitions, model, sample)
model = Pipe(model, balance, ...)
"""
def verify_module(module: nn.Sequential) -> None:
if not isinstance(module, nn.Sequential):
raise TypeError("module must be nn.Sequential to be partitioned")
named_children = list(module.named_children())
if len(named_children) != len(module):
raise ValueError("module with duplicate children is not supported")
def verify_splitting(
module: nn.Sequential, partitions: List[nn.Sequential], balance: Iterable[int], devices: List[torch.device]
) -> None:
num_parameters = len(list(module.parameters()))
num_child_parameters = sum(len(list(child.parameters())) for child in module.children())
if num_parameters == num_child_parameters:
return
for i in range(len(partitions)):
for j in range(i + 1, len(partitions)):
parti = partitions[i]
partj = partitions[j]
if devices[i] == devices[j]:
continue
for p in parti.parameters():
for q in partj.parameters():
if p is q:
raise ValueError("module with duplicate parameters on distinct devices is not supported")
class BalanceError(ValueError):
pass
def split_module(
module: nn.Sequential, balance: Iterable[int], devices: List[torch.device],
) -> Tuple[List[nn.Sequential], List[int], List[torch.device]]:
"""Splits a module into multiple partitions.
Returns:
A tuple of (partitions, balance, devices).
Partitions are represented as a :class:`~torch.nn.ModuleList` whose
items are the partitions. All layers in a partition are placed on the
same device.
Raises:
BalanceError:
wrong balance
IndexError:
the number of devices is fewer than the number of partitions.
"""
balance = list(balance)
if len(module) != sum(balance):
raise BalanceError(
"module and sum of balance have different length "
f"(module: {len(module)}, sum of balance: {sum(balance)})"
)
if any(x <= 0 for x in balance):
raise BalanceError(f"all balance numbers must be positive integer (balance: {balance})")
if len(balance) > len(devices):
raise IndexError(
"too few devices to hold given partitions " f"(devices: {len(devices)}, partitions: {len(balance)})"
)
j = 0
partitions = []
layers: NamedModules = OrderedDict()
for name, layer in module.named_children():
layers[name] = layer
if len(layers) == balance[j]:
# Group buffered layers as a partition.
partition = nn.Sequential(layers)
device = devices[j]
partition.to(device)
partitions.append(partition)
# Prepare for the next partition.
layers.clear()
j += 1
partitions = cast(List[nn.Sequential], nn.ModuleList(partitions))
del devices[j:]
return partitions, balance, devices
MOVING_DENIED = TypeError("denied to move parameters and buffers, " "because Pipe should manage device placement")
class Pipe(Module):
"""Wraps an arbitrary :class:`nn.Sequential <torch.nn.Sequential>` module
to train on Pipe_. If the module requires lots of memory, Pipe will be
very efficient.
::
model = nn.Sequential(a, b, c, d)
model = Pipe(model, balance=[1, 1, 1, 1], chunks=8)
output = model(input)
.. _Pipe: https://arxiv.org/abs/1811.06965
Pipe combines pipeline parallelism with checkpointing to reduce peak
memory required to train while minimizing device under-utilization.
You should determine the balance when defining a :class:`Pipe` module, as
balancing will not be done automatically. The module will be partitioned
across multiple devices according to the given balance. You may rely on
heuristics to find your own optimal configuration.
Args:
module (torch.nn.Sequential):
sequential module to be parallelized
balance (ints):
list of number of layers in each partition
Keyword Args:
devices (iterable of devices):
devices to use (default: all CUDA devices)
chunks (int):
number of micro-batches (default: ``1``)
checkpoint (str):
when to enable checkpointing, one of ``'always'``,
``'except_last'``, or ``'never'`` (default: ``'except_last'``)
deferred_batch_norm (bool):
whether to use deferred BatchNorm moving statistics (default:
:data:`False`, see :ref:`Deferred Batch Normalization` for more
details)
Raises:
TypeError:
the module is not a :class:`nn.Sequential <torch.nn.Sequential>`.
ValueError:
invalid arguments, or wrong balance
IndexError:
the number of devices is fewer than the number of partitions.
"""
#: The number of layers in each partition.
balance: List[int] = []
# ^^
# The default value [] required for Sphinx's autoattribute.
#: The devices mapped to each partition.
#:
#: ``devices[-1]`` refers to the device of the last partition, which means
#: it is the output device. You will probably need it to move the target
#: there so the loss can be calculated without a device-mismatch
#: :exc:`RuntimeError`. For example::
#:
#: out_device = pipe.devices[-1]
#:
#: for input, target in loader:
#: target = target.to(out_device, non_blocking=True)
#: output = pipe(input)
#: loss = F.cross_entropy(output, target)
#:
devices: List[torch.device] = []
#: The number of micro-batches.
chunks: int = 1
#: The checkpoint mode to determine when to enable checkpointing. It is one
#: of ``'always'``, ``'except_last'``, or ``'never'``.
checkpoint: str = "except_last"
def __init__(
self,
module: nn.Sequential,
balance: Optional[Iterable[int]] = None,
*,
devices: Optional[Devices] = None,
chunks: int = chunks,
checkpoint: str = checkpoint,
deferred_batch_norm: bool = False,
) -> None:
super().__init__()
chunks = int(chunks)
checkpoint = str(checkpoint)
if balance is None:
raise ValueError(recommend_auto_balance("balance is required"))
if chunks <= 0:
raise ValueError("number of chunks must be positive integer")
if checkpoint not in ["always", "except_last", "never"]:
raise ValueError("checkpoint is not one of 'always', 'except_last', or 'never'")
verify_module(module)
# Verify if the underlying skippable modules satisfy integrity. The
# integrity can be verified before forward() because it is static.
verify_skippables(module)
self.chunks = chunks
self.checkpoint = checkpoint
if deferred_batch_norm:
module = DeferredBatchNorm.convert_deferred_batch_norm(module, chunks)
if devices is None:
devices = range(torch.cuda.device_count())
devices = [torch.device(d) for d in devices]
devices = cast(List[torch.device], devices)
try:
self.partitions, self.balance, self.devices = split_module(module, balance, devices)
except BalanceError as exc:
raise ValueError(recommend_auto_balance(str(exc)))
verify_splitting(module, self.partitions, self.balance, self.devices)
self._copy_streams: List[List[AbstractStream]] = []
self._skip_layout = inspect_skip_layout(self.partitions)
# Separate CUDA streams for copy.
copy_streams = self._ensure_copy_streams()
# The micro-batch index where the checkpointing stops.
checkpoint_stop = {"always": self.chunks, "except_last": self.chunks - 1, "never": 0}[self.checkpoint]
self.pipeline = Pipeline(self.partitions, self.devices, copy_streams, self._skip_layout, checkpoint_stop)
def __len__(self) -> int:
"""Counts the length of the underlying sequential module."""
return sum(len(p) for p in self.partitions)
def __getitem__(self, index: int) -> nn.Module:
"""Gets a layer in the underlying sequential module."""
partitions = self.partitions
if index < 0:
partitions = partitions[::-1]
for partition in partitions:
try:
return partition[index]
except IndexError:
pass
shift = len(partition)
if index < 0:
index += shift
else:
index -= shift
raise IndexError
def __iter__(self) -> Iterable[nn.Module]:
"""Iterates over children of the underlying sequential module."""
for partition in self.partitions:
yield from partition
# Pipe should manage the device of each partition.
# Deny cuda(), cpu(), and to() with device, by TypeError.
def cuda(self, device: Optional[Device] = None) -> "Pipe":
raise MOVING_DENIED
def cpu(self) -> "Pipe":
raise MOVING_DENIED
def to(self, *args: Any, **kwargs: Any) -> "Pipe":
# Deny these usages:
#
# - to(device[, dtype, non_blocking])
# - to(tensor[, non_blocking])
#
# But allow this:
#
# - to(dtype[, non_blocking])
#
if "device" in kwargs or "tensor" in kwargs:
raise MOVING_DENIED
if args:
if isinstance(args[0], (torch.device, int, str)):
raise MOVING_DENIED
if torch.is_tensor(args[0]):
raise MOVING_DENIED
return super().to(*args, **kwargs)
def _ensure_copy_streams(self) -> List[List[AbstractStream]]:
"""Ensures that :class:`Pipe` caches CUDA streams for copy.
It's worth caching CUDA streams even though PyTorch already manages a
pool of pre-allocated CUDA streams, because doing so may reduce GPU memory
fragmentation when the number of micro-batches is small.
"""
if not self._copy_streams:
for device in self.devices:
self._copy_streams.append([new_stream(device) for _ in range(self.chunks)])
return self._copy_streams
def forward(self, input: TensorOrTensors) -> TensorOrTensors: # type: ignore
""":class:`Pipe` is a fairly transparent module wrapper. It doesn't
modify the input and output signature of the underlying module. But there is
a type restriction: the input and output have to be a :class:`~torch.Tensor`
or a tuple of tensors. This restriction is applied at partition boundaries
too.
Args:
input (torch.Tensor or tensors): input mini-batch
Returns:
tensor or tensors: output mini-batch
Raises:
TypeError: input is not a tensor or tensors.
"""
microbatch.check(input)
if not self.devices:
# Empty sequential module is not illegal.
return input
# Divide a mini-batch into micro-batches.
batches = microbatch.scatter(input, self.chunks)
# Run pipeline parallelism.
self.pipeline.run(batches)
# Merge the micro-batches into one mini-batch.
output = microbatch.gather(batches)
return output
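# A minimal end-to-end training sketch, guarded so it never runs on import.
# It assumes at least two CUDA devices and uses a toy model, toy sizes, and a
# plain cross-entropy loss purely for illustration; note that the target is
# moved to pipe.devices[-1], the output device, to avoid a device mismatch.
if __name__ == "__main__" and torch.cuda.device_count() >= 2:
    seq = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 4))
    pipe = Pipe(seq, balance=[2, 1], chunks=4)
    data = torch.rand(32, 16, device=pipe.devices[0])
    target = torch.randint(0, 4, (32,), device=pipe.devices[-1])
    loss = nn.functional.cross_entropy(pipe(data), target)
    loss.backward()
    print(loss.item())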
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
# Copyright 2019 Kakao Brain
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The pipeline parallelism of Pipe."""
from queue import Queue
from types import TracebackType
from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple, Type, Union, cast
import torch
from torch import Tensor, nn
from torch.autograd.profiler import record_function
from .checkpoint import Checkpointing
from .copy import Copy, Wait
from .dependency import fork, join
from .microbatch import Batch
from .skip.layout import SkipLayout
from .skip.tracker import SkipTrackerThroughPotals, use_skip_tracker
from .stream import AbstractStream, current_stream, use_device
from .worker import Task, create_workers, join_workers
__all__: List[str] = []
Tensors = Tuple[Tensor, ...]
TensorOrTensors = Union[Tensor, Tensors]
ExcInfo = Tuple[Type[BaseException], BaseException, TracebackType]
# Queue is generic only in stubs.
# https://mypy.readthedocs.io/en/latest/common_issues.html#using-classes-that-are-generic-in-stubs-but-not-at-runtime
if TYPE_CHECKING:
InQueue = Queue[Optional["Task"]]
OutQueue = Queue[Tuple[bool, Union[Tuple["Task", Batch], ExcInfo, None]]]
else:
InQueue = Queue
OutQueue = Queue
def depend(fork_from: Batch, join_to: Batch) -> None:
fork_from[0], phony = fork(fork_from[0])
join_to[0] = join(join_to[0], phony)
def copy(batch: Batch, prev_stream: AbstractStream, next_stream: AbstractStream) -> None:
batch[:] = Copy.apply(prev_stream, next_stream, *batch)
# Gradients are only supported for float Tensors.
batch[:] = tuple([x if x.is_floating_point() else x.detach() for x in batch])
def wait(batch: Batch, prev_stream: AbstractStream, next_stream: AbstractStream) -> None:
batch[:] = Wait.apply(prev_stream, next_stream, *batch)
# Gradients are only supported for float Tensors.
batch[:] = tuple([x if x.is_floating_point() else x.detach() for x in batch])
def clock_cycles(m: int, n: int) -> Iterable[List[Tuple[int, int]]]:
"""Generates schedules for each clock cycle."""
# m: number of micro-batches
# n: number of partitions
# i: index of micro-batch
# j: index of partition
# k: clock number
#
# k (i,j) (i,j) (i,j)
# - ----- ----- -----
# 0 (0,0)
# 1 (1,0) (0,1)
# 2 (2,0) (1,1) (0,2)
# 3 (2,1) (1,2)
# 4 (2,2)
for k in range(m + n - 1):
yield [(k - j, j) for j in range(max(1 + k - m, 0), min(1 + k, n))]
class Pipeline:
"""The pipeline parallelism for Pipe."""
def __init__(
self,
partitions: List[nn.Sequential],
devices: List[torch.device],
copy_streams: List[List[AbstractStream]],
skip_layout: SkipLayout,
checkpoint_stop: int,
) -> None:
self.partitions = partitions
self.devices = devices
self.copy_streams = copy_streams
self.skip_layout = skip_layout
self.checkpoint_stop = checkpoint_stop
(self.in_queues, self.out_queues) = create_workers(devices)
def __del__(self) -> None:
join_workers(self.in_queues, self.out_queues)
def run(self, batches: List[Batch]) -> None:
"""Runs pipeline parallelism.
It modifies the given batches in place.
"""
partitions = self.partitions
devices = self.devices
skip_layout = self.skip_layout
m = len(batches)
n = len(partitions)
skip_trackers = [SkipTrackerThroughPotals(skip_layout) for _ in batches]
for schedule in clock_cycles(m, n):
self.fence(batches, schedule, skip_trackers)
self.compute(batches, schedule, skip_trackers)
def fence(
self, batches: List[Batch], schedule: List[Tuple[int, int]], skip_trackers: List[SkipTrackerThroughPotals],
) -> None:
"""Copies micro-batches after computation for the previous
micro-batches.
"""
copy_streams = self.copy_streams
skip_layout = self.skip_layout
for i, j in schedule:
# Ensure that batches[i-1] is executed after batches[i] in
# backpropagation by an explicit dependency.
if i != 0 and j != 0:
depend(batches[i - 1], batches[i])
next_stream = copy_streams[j][i]
for prev_j, ns, name in skip_layout.copy_policy(j):
prev_stream = copy_streams[prev_j][i]
skip_trackers[i].copy(batches[i], prev_stream, next_stream, ns, name)
if j != 0:
prev_stream = copy_streams[j - 1][i]
copy(batches[i], prev_stream, next_stream)
def compute(
self, batches: List[Batch], schedule: List[Tuple[int, int]], skip_trackers: List[SkipTrackerThroughPotals],
) -> None:
"""Runs tasks with synchronization to copy streams."""
partitions = self.partitions
devices = self.devices
copy_streams = self.copy_streams
checkpoint_stop = self.checkpoint_stop
# Disable checkpointing if in eval mode.
if not self.partitions[0].training:
checkpoint_stop = 0
n = len(partitions)
streams = [current_stream(d) for d in devices]
exc_info: Optional[ExcInfo] = None
# With checkpointing, the autograd graph looks like this diagram:
# ┌─────┸──────┐
# │ Copy │
# └─────┰──────┘ (fence)
# ─ ─ ─ ╂ ─ ─ ─ ─ ─ ─ ─ ─ ─
# ┃ (compute)
# ┌─────┸──────┐
# │ Wait │ [1] Synchronize the current stream with the copy stream.
# └─────┰──────┘
# ┌─────┸──────┐
# │ Checkpoint │ [2] Compute a partition within checkpointing.
# └─────┰──────┘
# ┌─────┸──────┐
# │ Wait │ [3] Synchronize the copy stream with the current stream.
# └─────┰──────┘
# ┠ ─ ─ ─ ┐
# ┃ ┌─────┴─────┐
# ┃ │ Recompute │ [4] Schedule the recomputation at backpropagation.
# ┃ └─────┬─────┘
# ┠ ─ ─ ─ ┘
# ┃
# ─ ─ ─ ╂ ─ ─ ─ ─ ─ ─ ─ ─ ─
# ┌─────┸──────┐ (fence)
# │ Copy │
# └─────┰──────┘
for i, j in schedule:
batch = batches[i]
partition = partitions[j]
# Synchronize with the copied input. ([1] in the diagram)
if j != 0:
wait(batch, copy_streams[j][i], streams[j])
# Determine whether checkpointing or not.
checkpoint = i < checkpoint_stop
if checkpoint:
def function(
input: TensorOrTensors,
partition: nn.Sequential = partition,
skip_tracker: SkipTrackerThroughPotals = skip_trackers[i],
chunk_id: int = i,
part_id: int = j,
) -> TensorOrTensors:
with use_skip_tracker(skip_tracker), record_function("chunk%d-part%d" % (chunk_id, part_id)):
return partition(input)
chk = Checkpointing(function, batch)
task = Task(streams[j], compute=chk.checkpoint, finalize=chk.recompute)
del function, chk
else:
def compute(
batch: Batch = batch,
partition: nn.Sequential = partition,
skip_tracker: SkipTrackerThroughPotals = skip_trackers[i],
chunk_id: int = i,
part_id: int = j,
) -> Batch:
with use_skip_tracker(skip_tracker), record_function("chunk%d-part%d" % (chunk_id, part_id)):
return batch.call(partition)
task = Task(streams[j], compute=compute, finalize=None)
del compute
# Compute tasks in parallel. ([2] in the diagram)
self.in_queues[j].put(task)
for i, j in schedule:
ok, payload = self.out_queues[j].get()
# Hold the first exception.
if exc_info is not None:
continue
elif not ok:
exc_info = cast(ExcInfo, payload)
continue
task, batch = cast(Tuple[Task, Batch], payload)
# The copy stream synchronizes to copy the output. ([3] in the
# diagram)
if j != n - 1:
wait(batch, streams[j], copy_streams[j][i])
# Finalize tasks. If checkpointing is enabled, here the
# recomputation is scheduled at backpropagation. ([4] in the
# diagram)
with use_device(devices[j]):
task.finalize(batch)
batches[i] = batch
# Fail at the first exception.
if exc_info is not None:
raise exc_info[0].with_traceback(exc_info[1], exc_info[2])
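# A minimal illustrative sketch of the diagonal schedule that run() above iterates
# over via clock_cycles: at clock tick k, every (micro-batch i, partition j) pair
# with i + j == k can run in parallel. The helper below is a hypothetical
# re-implementation for illustration only; the pipeline uses the project's own
# clock_cycles.
def _example_clock_cycles(m: int, n: int):
    """Yields one list of (micro-batch index, partition index) pairs per tick."""
    for k in range(m + n - 1):
        yield [(k - j, j) for j in range(max(1 + k - m, 0), min(1 + k, n))]


# For m=4 micro-batches over n=3 partitions the ticks form a diagonal wave:
#   [(0, 0)]
#   [(1, 0), (0, 1)]
#   [(2, 0), (1, 1), (0, 2)]
#   [(3, 0), (2, 1), (1, 2)]
#   [(3, 1), (2, 2)]
#   [(3, 2)]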
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
# Copyright 2019 Kakao Brain
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Supports efficiency with skip connections."""
from .namespace import Namespace
from .skippable import pop, skippable, stash, verify_skippables
__all__ = ["skippable", "stash", "pop", "verify_skippables", "Namespace"]
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
# Copyright 2019 Kakao Brain
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Static skip connection layout of ``@skippable`` modules."""
from typing import Dict, Iterable, List, Tuple
from torch import nn
from .namespace import Namespace
__all__: List[str] = []
class SkipLayout:
"""Represents a skip connection layout across partitions."""
# Skip routes indexed by 'ns, name': {(ns, name): (prev_j, next_j), ...}
by_ns_name: Dict[Tuple[Namespace, str], Tuple[int, int]]
# Skip routes indexed by partition number 'j': [[next_j]: [(prev_j, ns, name), ...], ...]
by_partition: List[List[Tuple[int, Namespace, str]]]
def __init__(self, num_partitions: int, skip_routes: Dict[Tuple[Namespace, str], Tuple[int, int]],) -> None:
# The skip routes are already indexed by 'ns, name'.
self.by_ns_name = skip_routes
# Index skip routes by partition number 'j'.
self.by_partition = [[] for _ in range(num_partitions)]
for (ns, name), (prev_j, next_j) in skip_routes.items():
self.by_partition[next_j].append((prev_j, ns, name))
for p in self.by_partition:
p.sort()
def copy_policy(self, next_j: int) -> Iterable[Tuple[int, Namespace, str]]:
"""Generates skip routes for the given destination partition number.
The skip routes are sorted by source partition number in ascending
order.
Yields:
Each tuple of (source partition number, namespace, name).
"""
for prev_j, ns, name in self.by_partition[next_j]:
if prev_j == next_j:
# This skip tensor will be popped at the same partition where
# it is stashed. In this case, copy is not required.
continue
yield (prev_j, ns, name)
def requires_copy(self, ns: Namespace, name: str) -> bool:
"""Whether the given namespace and name requires partition-to-partition
copy or not.
"""
prev_j, next_j = self.by_ns_name.get((ns, name), (-1, -1))
return prev_j != next_j
def inspect_skip_layout(partitions: List[nn.Sequential]) -> SkipLayout:
"""Inspects the skip connection layout in the given partitions."""
# NOTE(sublee): Hide circular import inside this subroutine. Circular
# import is not ideal but placing this logic near to SkipLayout may
# increase cohesion of code.
from .skippable import Skippable
skip_routes: Dict[Tuple[Namespace, str], Tuple[int, int]] = {}
stashed_at: Dict[Tuple[Namespace, str], int] = {}
for j, partition in enumerate(partitions):
for layer in partition:
if not isinstance(layer, Skippable):
continue
for ns, name in layer.stashable():
stashed_at[(ns, name)] = j
for ns, name in layer.poppable():
prev_j = stashed_at.pop((ns, name))
skip_routes[(ns, name)] = (prev_j, j)
return SkipLayout(len(partitions), skip_routes)
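# A minimal sketch showing how the two indices in SkipLayout relate. A skip named
# "1to3" (the name is purely illustrative) in the default ``None`` namespace is
# stashed in partition 0 and popped in partition 2, so it requires a copy, and
# only partition 2's copy policy lists it.
def _example_skip_layout() -> None:
    layout = SkipLayout(num_partitions=3, skip_routes={(None, "1to3"): (0, 2)})
    assert layout.requires_copy(None, "1to3")
    assert list(layout.copy_policy(2)) == [(0, None, "1to3")]
    assert list(layout.copy_policy(1)) == []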
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
# Copyright 2019 Kakao Brain
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Provides isolated namespace of skip tensors."""
import abc
from functools import total_ordering
from typing import Any
import uuid
__all__ = ["Namespace"]
@total_ordering
class Namespace(metaclass=abc.ABCMeta):
"""Namespace for isolating skip tensors used by :meth:`isolate()
<torchpipe.skip.skippable.Skippable.isolate>`.
"""
__slots__ = ("id",)
def __init__(self) -> None:
self.id = uuid.uuid4()
def __repr__(self) -> str:
return f"<Namespace '{self.id}'>"
def __hash__(self) -> int:
return hash(self.id)
# Namespaces should support ordering, since SkipLayout will sort tuples
# including a namespace. But actual order between namespaces is not
# important. That's why they are ordered by version 4 UUID which generates
# random numbers.
def __lt__(self, other: Any) -> bool:
if isinstance(other, Namespace):
return self.id < other.id
return False
def __eq__(self, other: Any) -> bool:
if isinstance(other, Namespace):
return self.id == other.id
return False
# 'None' is the default namespace,
# which means that 'isinstance(None, Namespace)' is 'True'.
Namespace.register(type(None))
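# A minimal sketch of the contract described above: distinct Namespace instances
# never compare equal, and the UUID-based ordering only has to be consistent so
# that SkipLayout can sort tuples containing namespaces.
def _example_namespace_ordering() -> None:
    ns1, ns2 = Namespace(), Namespace()
    assert ns1 != ns2
    assert (ns1 < ns2) != (ns2 < ns1)  # exactly one direction holds
    assert isinstance(None, Namespace)  # 'None' acts as the default namespace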
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
# Copyright 2019 Kakao Brain
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Portal keeps a tensor in the pocket plane. The tensor becomes hidden to the
autograd engine. The shared context of three functions (:class:`PortalBlue`,
:class:`PortalOrange`, and :class:`PortalCopy`) out of the computation graph is
one of the most important feature of :mod:`torchpipe.skip`.
The metaphor is inspired by Portal™ from Valve.
"""
from typing import List, Optional, Tuple
import torch
from torch import Tensor
from ..copy import Context as CopyContext
from ..copy import Copy
from ..phony import get_phony
from ..stream import AbstractStream, get_device
__all__: List[str] = []
class Portal:
"""A portal for a tensor."""
def __init__(self, tensor: Optional[Tensor], tensor_life: int) -> None:
self.put_tensor(tensor, tensor_life)
self.grad: Optional[Tensor] = None
def blue(self) -> Tensor:
"""Creates a :class:`PortalBlue` which hides the underlying tensor from
the autograd engine.
Join the returned phony to the main lane of the autograd graph to
ensure correct backpropagation::
PortalBlue --+
|
---------- Join --
"""
tensor = self.use_tensor()
if tensor is None:
return get_phony(torch.device("cpu"), requires_grad=False)
return PortalBlue.apply(self, tensor)
def orange(self, phony: Tensor) -> Optional[Tensor]:
"""Creates a :class:`PortalOrange` which retrieves the hidden tensor
without losing the ability to backpropagate.
Give a phony forked from the main lane of an autograd graph::
+-- PortalOrange --+
| |
-- Fork --------- f(a, b) --
"""
self.check_tensor_life()
if self.tensor is None:
return self.use_tensor()
return PortalOrange.apply(self, phony)
def copy(self, prev_stream: AbstractStream, next_stream: AbstractStream, phony: Tensor,) -> Tensor:
"""Copies the hidden tensor by a :class:`PortalCopy`.
Give a phony and use the returned phony to keep the backpropagation path::
+-- PortalCopy --+
| |
-- Fork ---------- Join --
"""
if self.tensor is None:
return get_phony(torch.device("cpu"), requires_grad=False)
return PortalCopy.apply(self, prev_stream, next_stream, phony)
def check_tensor_life(self) -> None:
if self.tensor_life <= 0:
raise RuntimeError("tensor in portal has been removed")
def put_tensor(self, tensor: Optional[Tensor], tensor_life: int) -> None:
"""Stores a tensor into this portal."""
# [Life of Tensor through Portal]
#
# The tensor can be retrieved by use_tensor() up to 'tensor_life'
# times. When the life becomes 0, the tensor will be deleted for
# deallocation in CUDA memory.
#
# The below events participate in a tensor through a portal.
# Note that [x] denotes the events which call use_tensor():
#
# 1. [x] blue()
# 2. [ ] PortalBlue.forward
# 3. [ ] copy()
# 4. [ ] PortalCopy.forward
# 5. [ ] orange()
# 6. [x] PortalOrange.forward
# - - - - - - - - - - - - - - - - - - - - - - - - - - -
# 7. [ ] orange() (recomputed)
# 8. [x] PortalOrange.forward (recomputed)
# 9. [ ] PortalOrange.backward
# 10. [ ] PortalCopy.backward
# 11. [x] blue() (recomputed)
# 12. [ ] PortalBlue.forward (recomputed)
# 13. [ ] PortalBlue.backward
#
self.tensor_life = tensor_life
if tensor_life > 0:
self.tensor = tensor
else:
self.tensor = None
def use_tensor(self) -> Optional[Tensor]:
"""Retrieves the underlying tensor and decreases the tensor life. When
the life becomes 0, the tensor will be removed.
"""
self.check_tensor_life()
tensor = self.tensor
self.tensor_life -= 1
if self.tensor_life <= 0:
self.tensor = None
return tensor
def put_grad(self, grad: Tensor) -> None:
"""Stores a gradient into this portal."""
self.grad = grad
def use_grad(self) -> Tensor:
"""Retrieves and removes the underlying gradient. The gradient is
always ephemeral.
"""
if self.grad is None:
raise RuntimeError("grad in portal has been removed or never set")
grad = self.grad
self.grad = None
return grad
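# A minimal sketch of the tensor-life bookkeeping in isolation: with tensor_life=2
# the tensor survives exactly two use_tensor() calls, after which the portal drops
# its reference and a further access raises.
def _example_tensor_life() -> None:
    t = torch.zeros(1)
    portal = Portal(t, tensor_life=2)
    assert portal.use_tensor() is t  # life 2 -> 1
    assert portal.use_tensor() is t  # life 1 -> 0, reference dropped
    try:
        portal.use_tensor()
    except RuntimeError:
        pass  # "tensor in portal has been removed"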
# Common interface between :class:`PortalBlue`, :class:`PortalOrange`, and
# :class:`PortalCopy`.
class Context(CopyContext):
portal: Portal
class PortalBlue(torch.autograd.Function):
"""Hides a tensor from the autograd engine by a :class:`Portal`."""
@staticmethod
# type: ignore
def forward(
ctx: Context,
portal: Portal,
# This tensor must be retrieved by portal.use_tensor().
tensor: Tensor,
) -> Tensor:
ctx.portal = portal
phony = get_phony(tensor.device, requires_grad=False)
return phony.detach()
@staticmethod
# type: ignore
def backward(ctx: Context, grad_phony: Tensor,) -> Tuple[None, Tensor]:
# The paired PortalOrange should keep the gradient.
grad = ctx.portal.use_grad()
return None, grad
class PortalOrange(torch.autograd.Function):
"""Retrieves the hidden tensor from a :class:`Portal`."""
@staticmethod
# type: ignore
def forward(ctx: Context, portal: Portal, phony: Tensor) -> Tensor:
ctx.portal = portal
tensor = portal.use_tensor()
assert tensor is not None
return tensor.detach()
@staticmethod
def backward(ctx: Context, grad: Tensor) -> Tuple[None, None]: # type: ignore
# The paired PortalBlue will use the gradient.
ctx.portal.put_grad(grad)
return None, None
class PortalCopy(torch.autograd.Function):
"""Copies the hidden tensor in a :class:`Portal`. It replaces the hidden
tensor with the copied one.
"""
@staticmethod
# type: ignore
def forward(
ctx: Context, portal: Portal, prev_stream: AbstractStream, next_stream: AbstractStream, phony: Tensor,
) -> Tensor:
ctx.portal = portal
assert portal.tensor is not None
(portal.tensor,) = Copy.forward(ctx, prev_stream, next_stream, portal.tensor)
phony = get_phony(get_device(next_stream), requires_grad=False)
return phony.detach()
@staticmethod
# type: ignore
def backward(ctx: Context, grad_phony: Tensor,) -> Tuple[None, None, None, None]:
portal = ctx.portal
assert portal.grad is not None
_, _, portal.grad = Copy.backward(ctx, portal.grad)
return None, None, None, None
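# A minimal sketch of the blue/orange round trip on CPU: blue() hides the tensor
# behind a phony, orange() retrieves it again, and the gradient hand-off goes
# through put_grad()/use_grad(). tensor_life=2 matches the two retrievals
# (blue() and PortalOrange.forward) of the non-checkpointed case.
def _example_portal_roundtrip() -> None:
    x = torch.randn(3, requires_grad=True)
    portal = Portal(x, tensor_life=2)
    phony = portal.blue()  # hides x from the autograd engine
    y = portal.orange(phony)  # retrieves x again through PortalOrange
    assert y is not None and torch.equal(y, x)
    portal.put_grad(torch.ones_like(x))
    assert torch.equal(portal.use_grad(), torch.ones_like(x))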
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
# Copyright 2019 Kakao Brain
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The user interface to define skip connections."""
from typing import (
TYPE_CHECKING,
Any,
Callable,
ClassVar,
Dict,
FrozenSet,
Generator,
Iterable,
List,
Optional,
Set,
Tuple,
Type,
TypeVar,
Union,
cast,
)
from torch import Tensor, nn
from ..microbatch import Batch
from .namespace import Namespace
from .tracker import current_skip_tracker
__all__ = ["skippable", "stash", "pop", "verify_skippables"]
Tensors = Tuple[Tensor, ...]
TensorOrTensors = Union[Tensor, Tensors]
StashPop = Union["stash", "pop"]
StashPopGenerator = Generator[StashPop, Optional[Tensor], TensorOrTensors]
if TYPE_CHECKING:
SkippableModule = nn.Module[Union[StashPopGenerator, TensorOrTensors]]
else:
SkippableModule = nn.Module
T = TypeVar("T", bound="Skippable")
class Skippable(nn.Module):
"""The base class for skippable modules.
Do not use this class directly. Define a subclass by :func:`skippable`
instead.
"""
module_cls: ClassVar[Type[SkippableModule]]
stashable_names: ClassVar[FrozenSet[str]]
poppable_names: ClassVar[FrozenSet[str]]
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__()
self.module = self.module_cls(*args, **kwargs) # type: ignore
self.namespaces: Dict[str, Namespace] = {}
def __repr__(self) -> str:
return f"@skippable({self.module})"
def namespaced(self, name: str) -> Tuple[Namespace, str]:
"""Prepends namespace for the given skip name."""
ns = self.namespaces.get(name)
ns = cast(Namespace, ns)
return (ns, name)
def stashable(self) -> Iterable[Tuple[Namespace, str]]:
"""Iterates over namespaced skip names to be stashed."""
for name in self.stashable_names:
yield self.namespaced(name)
def poppable(self) -> Iterable[Tuple[Namespace, str]]:
"""Iterates over namespaced skip names to be popped."""
for name in self.poppable_names:
yield self.namespaced(name)
def isolate(self: T, ns: Namespace, *, only: Optional[Iterable[str]] = None) -> T:
r"""Isolates a specified subset or the whole set of skip tensors into a
namespace. In a single sequential module, skip tensors with the same
name are not allowed unless they are isolated by different namespaces.
Here's an example using the same name for skip tensors twice. Each pair
of ``Layer1`` and ``Layer3`` is isolated with its own namespace ``ns1``
and ``ns2``. There is no conflict anymore::
ns1 = Namespace()
ns2 = Namespace()
model = nn.Sequential(
Layer1().isolate(ns1),
Layer1().isolate(ns2),
Layer2(),
Layer3().isolate(ns2),
Layer3().isolate(ns1),
)
When the `only` parameter is omitted, all skip tensors are isolated. You
can isolate a subset of skip tensors by passing the `only` parameter::
ns_alice = Namespace()
ns_bob = Namespace()
model = nn.Sequential(
...
StashStashPop().isolate(ns_alice, only=['alice']) \
.isolate(ns_bob, only=['bob']),
...
)
Args:
ns (Namespace):
namespace for isolation
Keyword Args:
only (iterable of strs):
names of specific skip tensors to be isolated (omit this option
to isolate all skip tensors declared in this module)
Returns:
this module itself
"""
names: Iterable[str]
if only is None:
names = self.stashable_names | self.poppable_names
else:
names = set(only)
for name in names:
self.namespaces[name] = ns
return self
def dispatch(
self,
input: TensorOrTensors,
handle_stash: Callable[[str, Optional[Tensor]], None],
handle_pop: Callable[[str], Optional[Tensor]],
) -> TensorOrTensors:
"""Dispatches :class:`stash` or :class:`pop` commands generated by the
module's ``forward()``.
"""
generator = self.module(input)
if not isinstance(generator, Generator):
# The underlying module returned output without any yield.
output = generator
return output
try:
op = next(generator)
while True:
if isinstance(op, stash):
handle_stash(op.name, op.tensor)
op = next(generator)
continue
if isinstance(op, pop):
tensor = handle_pop(op.name)
op = generator.send(tensor)
continue
raise TypeError("%r is not a command from @skippable" % op)
except StopIteration as stop:
output = stop.args[0]
return output
def forward(self, input: TensorOrTensors) -> TensorOrTensors: # type: ignore
"""Performs the forward propagation. :class:`stash` or :class:`pop`
commands will be handled by portals silently. The portals won't be
exposed to users.
Raises:
RuntimeError:
illegal 'stash' or 'pop' is found.
"""
skip_tracker = current_skip_tracker()
stashed_tensors: Dict[str, Optional[Tensor]] = {}
# Load skip tensors that might be popped.
poppable_tensors = {}
batch = Batch(input)
for ns, name in self.poppable():
try:
poppable_tensors[name] = skip_tracker.load(batch, ns, name)
except KeyError:
raise RuntimeError(f"'{name}' has not been stashed")
input = batch.tensor_or_tensors
# Handle skip commands.
def handle_stash(name: str, tensor: Optional[Tensor]) -> None:
if name not in self.stashable_names:
raise RuntimeError(f"'{name}' has not been declared as stashable")
stashed_tensors[name] = tensor
def handle_pop(name: str) -> Optional[Tensor]:
if name not in self.poppable_names:
raise RuntimeError(f"'{name}' has not been declared as poppable")
return poppable_tensors.pop(name)
output = self.dispatch(input, handle_stash, handle_pop)
# All declared skips must be stashed or popped.
not_stashed = self.stashable_names - stashed_tensors.keys()
if not_stashed:
comma_names = ", ".join("'%s'" % n for n in not_stashed)
raise RuntimeError(f"{comma_names} must be stashed but have not")
not_popped = poppable_tensors.keys()
if not_popped:
comma_names = ", ".join("'%s'" % n for n in not_popped)
raise RuntimeError(f"{comma_names} must be popped but have not")
# Save stashed skip tensors.
batch = Batch(output)
for ns, name in self.stashable():
tensor = stashed_tensors[name]
skip_tracker.save(batch, ns, name, tensor)
output = batch.tensor_or_tensors
return output
# TODO(sublee): Move to above of Skippable class for better read flow.
def skippable(
stash: Iterable[str] = (), pop: Iterable[str] = (),
) -> Callable[[Type[SkippableModule]], Type[Skippable]]:
"""The decorator to define a :class:`nn.Module <torch.nn.Module>` with skip
connections. Decorated modules are called "skippable". This functionality
works perfectly fine even when the module is not wrapped by
:class:`~torchpipe.Pipe`.
Each skip tensor is managed by its name. Before manipulating skip tensors,
a skippable module must statically declare the names of its skip tensors via
the `stash` and/or `pop` parameters. Skip tensors with a pre-declared name can be
stashed by ``yield stash(name, tensor)`` or popped by ``tensor = yield
pop(name)``.
Here is an example with three layers. A skip tensor named "1to3" is stashed
and popped at the first and last layer, respectively::
@skippable(stash=['1to3'])
class Layer1(nn.Module):
def forward(self, input):
yield stash('1to3', input)
return f1(input)
class Layer2(nn.Module):
def forward(self, input):
return f2(input)
@skippable(pop=['1to3'])
class Layer3(nn.Module):
def forward(self, input):
skip_1to3 = yield pop('1to3')
return f3(input) + skip_1to3
model = nn.Sequential(Layer1(), Layer2(), Layer3())
One skippable module can stash or pop multiple skip tensors::
@skippable(stash=['alice', 'bob'], pop=['carol'])
class StashStashPop(nn.Module):
def forward(self, input):
yield stash('alice', f_alice(input))
yield stash('bob', f_bob(input))
carol = yield pop('carol')
return input + carol
Every skip tensor must be associated with exactly one pair of `stash` and
`pop`. :class:`~torchpipe.Pipe` checks this restriction automatically
when wrapping a module. You can also check the restriction by
:func:`~torchpipe.skip.verify_skippables` without
:class:`~torchpipe.Pipe`.
.. note::
:func:`@skippable <skippable>` changes the type of the wrapped class.
But currently (mypy v0.740), mypy cannot understand class decorators
yet (`#3135 <https://github.com/python/mypy/issues/3135>`_).
There are two workarounds:
1. Naively ignore type errors by ``# type: ignore``.
2. Use ``skippable()()`` as a function instead of a decorator.
.. seealso:: :ref:`Long Skip Connections`
"""
stashable_names = frozenset(stash)
poppable_names = frozenset(pop)
def extend_skippable(module_cls: Type[SkippableModule]) -> Type[Skippable]:
name = module_cls.__name__
bases = (Skippable,)
attrs = {"module_cls": module_cls, "stashable_names": stashable_names, "poppable_names": poppable_names}
return type(name, bases, attrs)
return extend_skippable
class stash:
"""The command to stash a skip tensor.
::
def forward(self, input):
yield stash('name', input)
return f(input)
Args:
name (str): name of skip tensor
input (torch.Tensor or None): tensor to pass to the skip connection
"""
__slots__ = ("name", "tensor")
def __init__(self, name: str, tensor: Optional[Tensor]) -> None:
self.name = name
self.tensor = tensor
class pop:
"""The command to pop a skip tensor.
::
def forward(self, input):
skip = yield pop('name')
return f(input) + skip
Args:
name (str): name of skip tensor
Returns:
the skip tensor previously stashed by another layer under the same name
"""
__slots__ = ("name",)
def __init__(self, name: str) -> None:
self.name = name
def verify_skippables(module: nn.Sequential) -> None:
"""Verifies if the underlying skippable modules satisfy integrity.
Every skip tensor must have exactly one pair of `stash` and `pop`. If there
are one or more unmatched pairs, it will raise :exc:`TypeError` with
detailed messages.
Here are a few failure cases. :func:`verify_skippables` will report failure
for these cases::
# Layer1 stashes "1to3".
# Layer3 pops "1to3".
nn.Sequential(Layer1(), Layer2())
# └──── ?
nn.Sequential(Layer2(), Layer3())
# ? ────┘
nn.Sequential(Layer1(), Layer2(), Layer3(), Layer3())
# └───────────────────┘ ^^^^^^
nn.Sequential(Layer1(), Layer1(), Layer2(), Layer3())
# ^^^^^^ └───────────────────┘
To use the same name for multiple skip tensors, they must be isolated by
different namespaces. See :meth:`isolate()
<torchpipe.skip.skippable.Skippable.isolate>`.
Raises:
TypeError:
one or more pairs of `stash` and `pop` are not matched.
"""
stashed: Set[Tuple[Namespace, str]] = set()
popped: Set[Tuple[Namespace, str]] = set()
msgs: List[str] = []
for layer_name, layer in module.named_children():
if not isinstance(layer, Skippable):
continue
for name in layer.stashable_names & layer.poppable_names:
msg = f"'{layer_name}' declared '{name}' both as stashable and as poppable"
msgs.append(msg)
for ns, name in layer.stashable():
if name in layer.poppable_names:
continue
if (ns, name) in stashed:
msg = f"'{layer_name}' redeclared '{name}' as stashable " "but not isolated by namespace"
msgs.append(msg)
continue
stashed.add((ns, name))
for ns, name in layer.poppable():
if name in layer.stashable_names:
continue
if (ns, name) in popped:
msg = f"'{layer_name}' redeclared '{name}' as poppable " "but not isolated by namespace"
msgs.append(msg)
continue
if (ns, name) not in stashed:
msg = f"'{layer_name}' declared '{name}' as poppable but it was not stashed"
msgs.append(msg)
continue
popped.add((ns, name))
for (_, name) in stashed - popped:
msg = f"no module declared '{name}' as poppable but stashed"
msgs.append(msg)
if msgs:
raise TypeError(
"one or more pairs of stash and pop do not match:\n\n%s" "" % "\n".join("* %s" % x for x in msgs)
)
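# A minimal, runnable sketch that makes the "1to3" example from the docstring of
# skippable() concrete. The layer names and the factor of 2 are illustrative; no
# Pipe is involved, so the thread-local SkipTracker handles the stash/pop pair
# directly.
import torch  # only needed for this sketch


@skippable(stash=["1to3"])
class _ExampleStash(nn.Module):
    def forward(self, input):  # type: ignore
        yield stash("1to3", input)
        return input * 2


@skippable(pop=["1to3"])
class _ExamplePop(nn.Module):
    def forward(self, input):  # type: ignore
        skip = yield pop("1to3")
        return input + skip


def _example_skip_roundtrip() -> None:
    model = nn.Sequential(_ExampleStash(), nn.ReLU(), _ExamplePop())
    verify_skippables(model)  # every stash has exactly one matching pop
    out = model(torch.ones(2, 2))
    assert torch.equal(out, torch.ones(2, 2) * 3)  # relu(2 * x) + x with x = 1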
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
# Copyright 2019 Kakao Brain
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tracks skip tensors on a thread."""
from contextlib import contextmanager
import threading
from typing import Dict, Generator, List, Optional, Tuple
from torch import Tensor
from ..checkpoint import is_checkpointing
from ..dependency import fork, join
from ..microbatch import Batch
from ..stream import AbstractStream
from .layout import SkipLayout
from .namespace import Namespace
from .portal import Portal
__all__: List[str] = []
class SkipTracker:
"""Tracks saved skip tensors.
It will update the given micro-batch in place. This is because when it
manipulates the underlying skip tensors, the current micro-batch also has
to be connected with the skip tensors.
One thread has one skip tracker. Call :func:`current_skip_tracker` to get
the skip tracker on the current thread.
"""
def __init__(self) -> None:
self.tensors: Dict[Tuple[Namespace, str], Optional[Tensor]] = {}
def save(self, batch: Batch, ns: Namespace, name: str, tensor: Optional[Tensor]) -> None:
self.tensors[(ns, name)] = tensor
def load(self, batch: Batch, ns: Namespace, name: str) -> Optional[Tensor]:
return self.tensors.pop((ns, name))
def copy(
self, batch: Batch, prev_stream: AbstractStream, next_stream: AbstractStream, ns: Namespace, name: str,
) -> None:
raise TypeError("copy is not supported for non-portal skip tensors")
class SkipTrackerThroughPotals(SkipTracker):
"""Tracks saved skip tensors through portals. The skip tensors will be
hidden in portals so that the autograd engine does not need to track them.
This tracker is only used when the training or evaluating module is wrapped
with :class:`torchpipe.Pipe`.
"""
def __init__(self, skip_layout: SkipLayout) -> None:
super().__init__()
self.skip_layout = skip_layout
self.portals: Dict[Tuple[Namespace, str], Portal] = {}
def save(self, batch: Batch, ns: Namespace, name: str, tensor: Optional[Tensor]) -> None:
"""Saves the stashed skip tensor in a portal. The portal is then
connected to the given micro-batch with :class:`Join`.
"""
if not self.skip_layout.requires_copy(ns, name):
super().save(batch, ns, name, tensor)
return
# See [Life of Tensor through Portal] at Portal.put_tensor() to understand the
# below tensor_life values. Here are the selected events which retrieve
# the tensor in portal:
#
# 1. [x] blue()
# ...
# 6. [x] PortalOrange.forward
# ...
# 8. [x] PortalOrange.forward (recomputed)
# ...
# 11. [x] blue() (recomputed)
#
if (ns, name) not in self.portals:
if is_checkpointing():
# Under checkpointing, the tensor used by the first
# PortalOrange should be alive in the portal. This tensor will
# be used again by the second PortalOrange during the
# recomputation.
tensor_life = 3 # Delete at [8. PortalOrange.forward (recomputed)]
else:
tensor_life = 2 # Delete at [6. PortalOrange.forward]
portal = Portal(tensor, tensor_life)
self.portals[(ns, name)] = portal
else:
# Under recomputation, the portal already exists.
portal = self.portals[(ns, name)]
# The existing tensor life has already reached 0. Reset it to 1 so that
# the tensor is deleted immediately after the second PortalBlue.
tensor_life = 1 # Delete at [11. blue() (recomputed)]
portal.put_tensor(tensor, tensor_life)
phony = portal.blue()
batch[0] = join(batch[0], phony)
def load(self, batch: Batch, ns: Namespace, name: str) -> Optional[Tensor]:
"""Loads a skip tensor from the corresponding portal to pop. The given
micro-batch is connected to the portal with :class:`Fork`.
"""
if not self.skip_layout.requires_copy(ns, name):
tensor = super().load(batch, ns, name)
return tensor
portal = self.portals[(ns, name)]
batch[0], phony = fork(batch[0])
tensor = portal.orange(phony)
return tensor
def copy(
self, batch: Batch, prev_stream: AbstractStream, next_stream: AbstractStream, ns: Namespace, name: str,
) -> None:
"""Copies the skip tensor in the corresponding portal. The given
micro-batch and the portal will be tied with :class:`Fork` and
:class:`Join`.
"""
assert self.skip_layout.requires_copy(ns, name)
batch[0], phony = fork(batch[0])
portal = self.portals[(ns, name)]
phony = portal.copy(prev_stream, next_stream, phony)
batch[0] = join(batch[0], phony)
class ThreadLocal(threading.local):
def __init__(self) -> None:
self.skip_tracker: Optional[SkipTracker] = None
thread_local = ThreadLocal()
@contextmanager
def use_skip_tracker(skip_tracker: SkipTracker) -> Generator[None, None, None]:
"""Registers the given skip tracker on the current thread within a
context::
with use_skip_tracker(my_skip_tracker):
...
"""
orig = thread_local.skip_tracker
thread_local.skip_tracker = skip_tracker
try:
yield
finally:
thread_local.skip_tracker = orig
def current_skip_tracker() -> SkipTracker:
"""Gets the skip tracker on the current thread."""
skip_tracker = thread_local.skip_tracker
if skip_tracker is None:
skip_tracker = SkipTracker()
thread_local.skip_tracker = skip_tracker
return skip_tracker
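# A minimal sketch of the scoping above: current_skip_tracker() hands out one
# plain SkipTracker per thread, and use_skip_tracker() temporarily swaps in
# another tracker, which is how the pipeline gives each micro-batch its own
# SkipTrackerThroughPotals during compute().
def _example_tracker_scoping() -> None:
    default = current_skip_tracker()
    scoped = SkipTracker()
    with use_skip_tracker(scoped):
        assert current_skip_tracker() is scoped
    assert current_skip_tracker() is default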
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
# Copyright 2019 Kakao Brain
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for eliminating boilerplate code to handle abstract streams with
CPU device.
"""
from contextlib import contextmanager
from typing import Generator, List, Union, cast
import torch
__all__: List[str] = []
class CPUStreamType:
pass
# The placeholder used in place of CUDA streams for the CPU device.
CPUStream = CPUStreamType()
# It represents both CUDA streams and the CPU stream.
AbstractStream = Union[torch.cuda.Stream, CPUStreamType]
def new_stream(device: torch.device) -> AbstractStream:
"""Creates a new stream for either CPU or CUDA device."""
if device.type != "cuda":
return CPUStream
return torch.cuda.Stream(device)
def current_stream(device: torch.device) -> AbstractStream:
""":func:`torch.cuda.current_stream` for either CPU or CUDA device."""
if device.type != "cuda":
return CPUStream
return torch.cuda.current_stream(device)
def default_stream(device: torch.device) -> AbstractStream:
""":func:`torch.cuda.default_stream` for either CPU or CUDA device."""
if device.type != "cuda":
return CPUStream
return torch.cuda.default_stream(device)
@contextmanager
def use_device(device: torch.device) -> Generator[None, None, None]:
""":func:`torch.cuda.device` for either CPU or CUDA device."""
if device.type != "cuda":
yield
return
with torch.cuda.device(device):
yield
@contextmanager
def use_stream(stream: AbstractStream) -> Generator[None, None, None]:
""":func:`torch.cuda.stream` for either CPU or CUDA stream."""
if not is_cuda(stream):
yield
return
with torch.cuda.stream(as_cuda(stream)):
yield
def get_device(stream: AbstractStream) -> torch.device:
"""Gets the device from CPU or CUDA stream."""
if is_cuda(stream):
return as_cuda(stream).device
return torch.device("cpu")
def wait_stream(source: AbstractStream, target: AbstractStream) -> None:
""":meth:`torch.cuda.Stream.wait_stream` for either CPU or CUDA stream. It
makes the source stream wait until the target stream completes work queued.
"""
if is_cuda(target):
if is_cuda(source):
# A CUDA stream waits for another CUDA stream.
as_cuda(source).wait_stream(as_cuda(target))
else:
# The CPU waits for a CUDA stream.
as_cuda(target).synchronize()
# If the target is CPU, synchronization is not required.
def record_stream(tensor: torch.Tensor, stream: AbstractStream) -> None:
""":meth:`torch.Tensor.record_stream` for either CPU or CUDA stream."""
if is_cuda(stream):
# NOTE(sublee): record_stream() on a shifted view tensor throws
# RuntimeError in PyTorch 1.1.0, and does nothing in 1.2.0. To safely
# protect the tensor against unexpected reallocation, here we use a
# temporary tensor associated with the same storage without shifting as
# a workaround.
#
# Issue: https://github.com/pytorch/pytorch/issues/27366
#
tensor = tensor.new_empty([0]).set_(tensor.storage())
tensor.record_stream(as_cuda(stream))
def is_cuda(stream: AbstractStream) -> bool:
"""Returns ``True`` if the given stream is a valid CUDA stream."""
return stream is not CPUStream
def as_cuda(stream: AbstractStream) -> torch.cuda.Stream:
"""Casts the given stream as :class:`torch.cuda.Stream`."""
return cast(torch.cuda.Stream, stream)
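# A minimal sketch of the CPU fallback behavior described above: on a CPU device
# every helper degrades to the CPUStream placeholder and a no-op, while on CUDA
# the same calls create and synchronize real torch.cuda.Stream objects, so callers
# never branch on the device type themselves.
def _example_streams() -> None:
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    stream = new_stream(device)
    assert get_device(stream).type == device.type
    # The new stream waits for work queued on the current stream; this is a no-op
    # on CPU because there is nothing to synchronize.
    wait_stream(stream, current_stream(device))
    with use_device(device), use_stream(stream):
        pass  # kernels launched here would run on 'stream' (CUDA) or inline (CPU)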