Unverified Commit 63f7796a authored by Tom Birch, committed by GitHub

Multi-process pipe (#90)

Adds support for distributing pipeline stages across multiple processes (and therefore multiple machines):
* Adds a style argument to the Pipe constructor, defaulting to PipelineStyle.SingleProcess but also supporting PipelineStyle.MultiProcess (a usage sketch follows below)
* Adds support for lazy construction of modules (see lazy_construction for an example)
* Adds two implementations of inter-process communication: one based on RPC with globally visible queues, and one based on send/recv
* Copies the relevant tests from tests/pipe to tests/pipe_process and modifies them to exercise PipelineStyle.MultiProcess
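For orientation, a minimal sketch of how the new style argument might be used. The import path of PipelineStyle and the remaining constructor arguments (balance, chunks) are assumptions based on the existing Pipe API, not taken from this diff:

```python
import torch.nn as nn
from fairscale.nn import Pipe
from fairscale.nn.pipe import PipelineStyle  # assumed import location

model = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 16))

# Single-process pipeline: the previous behaviour, still the default.
pipe = Pipe(model, balance=[2, 1], chunks=4, style=PipelineStyle.SingleProcess)

# Multi-process pipeline: each rank in the pipeline-parallel group runs only
# its own partitions, so stages may live in different processes or machines
# (torch.distributed / RPC must already be initialized on every rank).
pipe = Pipe(model, balance=[2, 1], chunks=4, style=PipelineStyle.MultiProcess)
```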
parent 49a198c9
@@ -149,7 +149,7 @@ jobs:
      - run:
          name: Run type-checking (mypy)
          command: |
-           mypy --pretty .
+           mypy --ignore-missing-imports --scripts-are-modules --pretty .
      - <<: *run_flake8
...
[settings]
-known_third_party =numpy,pytest,recommonmark,setuptools,torch,torchtext,torchvision
+known_third_party =benchmark_dataset,dataclasses,numpy,packaging,pytest,recommonmark,setuptools,torch,torchtext,torchvision
@@ -37,6 +37,7 @@ repos:
    rev: 4.3.20
    hooks:
      - id: isort
+       exclude: README.md
        additional_dependencies: [toml]
  - repo: https://github.com/pre-commit/mirrors-mypy
...
import torch
from torch.utils.data import Dataset
def collate_sentences_lm(samples):
if len(samples) == 0:
return {}
id = torch.LongTensor([s["id"] for s in samples])
src_tokens = torch.stack([s["source"] for s in samples], 0)
tgt_tokens = torch.stack([s["target"] for s in samples], 0)
ntokens = len(samples) * len(samples[0]["target"])
src_lengths = torch.LongTensor([len(samples[0]["source"])] * len(samples))
batch = {
"id": id,
"nsentences": len(samples),
"ntokens": ntokens,
"input": src_tokens,
"target": tgt_tokens,
}
return batch
class BenchmarkLMDataset(Dataset):
"""
Dataset to benchmark a translation-like seq2seq task.
Args:
vocab_size (int, optional): size of the vocabulary (default 10000).
max_source_positions (int, optional): max number of tokens in the
source sentence (default: 1024).
total_samples (int, optional): the total number of rows in the
dataset (default: 10000).
"""
def __init__(
self, vocab_size=10000, max_source_positions=1024, total_samples=10000,
):
self.vocab_size = vocab_size
self.max_source_positions = max_source_positions
self.total_samples = total_samples
self.sizes = [self.max_source_positions] * self.total_samples
def __getitem__(self, index):
length = self.sizes[index]
source = torch.randint(1, self.vocab_size, (length,))
target = source.clone()
return {
"id": index,
"source": source,
"target": target,
}
def __len__(self):
return self.total_samples
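A short usage example for the dataset and collate function defined above (illustrative, not part of the diff):

```python
from torch.utils.data import DataLoader

dataset = BenchmarkLMDataset(vocab_size=10000, max_source_positions=1024, total_samples=10000)
loader = DataLoader(dataset, batch_size=8, collate_fn=collate_sentences_lm)

batch = next(iter(loader))
print(batch["input"].shape)   # torch.Size([8, 1024]) -- random token ids
print(batch["target"].shape)  # torch.Size([8, 1024]) -- identical to the input
print(batch["ntokens"])       # 8 * 1024 target tokens in this batch
```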
This diff is collapsed.
@@ -13,6 +13,7 @@
 #
 import os
 import sys
+from typing import Any, List

 # The theme to use for HTML and HTML Help pages. See the documentation for
 # a list of builtin themes.
@@ -46,7 +47,7 @@ templates_path = ["_templates"]
 # List of patterns, relative to source directory, that match files and
 # directories to ignore when looking for source files.
 # This pattern also affects html_static_path and html_extra_path.
-exclude_patterns = []
+exclude_patterns: List[Any] = []

 # -- Options for HTML output -------------------------------------------------
...
@@ -5,6 +5,7 @@
 from .cross_entropy import vocab_parallel_cross_entropy
 from .initialize import (
+    destroy_model_parallel,
     get_data_parallel_group,
     get_data_parallel_rank,
     get_data_parallel_world_size,
@@ -12,6 +13,8 @@ from .initialize import (
     get_model_parallel_rank,
     get_model_parallel_src_rank,
     get_model_parallel_world_size,
+    get_pipeline_parallel_group,
+    get_pipeline_parallel_ranks,
     initialize_model_parallel,
 )
 from .layers import ColumnParallelLinear, RowParallelLinear, VocabParallelEmbedding
...
@@ -35,6 +35,8 @@ _DATA_PARALLEL_GROUP = None
 # Pipeline parallel group that the current rank belongs to.
 _PIPELINE_PARALLEL_GROUP = None
+_PIPELINE_PARALLEL_RANKS = None

 def initialize_model_parallel(model_parallel_size_: int, pipeline_length: int = 1) -> None:
     """
@@ -93,7 +95,15 @@ def initialize_model_parallel(model_parallel_size_: int, pipeline_length: int =
     global _PIPELINE_PARALLEL_GROUP
     assert _PIPELINE_PARALLEL_GROUP is None, "model parallel group is already initialized"
-    _PIPELINE_PARALLEL_GROUP = groups[found[0], :, found[2]].tolist()
+    global _PIPELINE_PARALLEL_RANKS
+    assert _PIPELINE_PARALLEL_RANKS is None, "model parallel group is already initialized"
+    for i in range(data_parallel_size):
+        for k in range(model_parallel_size):
+            ranks = groups[i, :, k].tolist()
+            group = torch.distributed.new_group(ranks)
+            if i == found[0] and k == found[2]:
+                _PIPELINE_PARALLEL_GROUP = group
+                _PIPELINE_PARALLEL_RANKS = ranks

 def model_parallel_is_initialized() -> bool:
@@ -115,12 +125,18 @@ def get_data_parallel_group() -> torch.distributed.ProcessGroup:
     return _DATA_PARALLEL_GROUP

-def get_pipeline_parallel_group() -> List[int]:
+def get_pipeline_parallel_group() -> torch.distributed.ProcessGroup:
     """Get the pipeline parallel group the caller rank belongs to."""
     assert _PIPELINE_PARALLEL_GROUP is not None, "pipeline parallel group is not initialized"
     return _PIPELINE_PARALLEL_GROUP
+def get_pipeline_parallel_ranks() -> List[int]:
+    """Get the list of global ranks in the pipeline parallel group the caller rank belongs to."""
+    assert _PIPELINE_PARALLEL_RANKS is not None, "pipeline parallel group is not initialized"
+    return _PIPELINE_PARALLEL_RANKS
 def get_model_parallel_world_size() -> int:
     """Return world size for the model parallel group."""
     return torch.distributed.get_world_size(group=get_model_parallel_group())
@@ -157,3 +173,6 @@ def destroy_model_parallel() -> None:
     _DATA_PARALLEL_GROUP = None
     global _PIPELINE_PARALLEL_GROUP
     _PIPELINE_PARALLEL_GROUP = None
+    global _PIPELINE_PARALLEL_RANKS
+    _PIPELINE_PARALLEL_RANKS = None
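For context, a hedged sketch of how a pipeline stage might use the two new accessors to address its neighbouring stages. The import path is assumed from the package layout shown above; the actual multi-process transport added by this PR lives in the pipe implementation, not here:

```python
import torch
import torch.distributed as dist
from fairscale.nn.model_parallel import (  # import path assumed
    get_pipeline_parallel_group,
    get_pipeline_parallel_ranks,
)

def send_to_next_stage(activation: torch.Tensor) -> None:
    """Send an activation tensor to the next pipeline stage, if any."""
    ranks = get_pipeline_parallel_ranks()        # e.g. [0, 2, 4] for three stages
    my_pos = ranks.index(dist.get_rank())
    if my_pos + 1 < len(ranks):
        dist.send(activation, dst=ranks[my_pos + 1])

def wait_for_pipeline() -> None:
    """Barrier across only the ranks of this pipeline."""
    dist.barrier(group=get_pipeline_parallel_group())
```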
@@ -280,6 +280,9 @@ class ColumnParallelLinear(torch.nn.Module):
             return_master_weight=keep_master_weight_for_test,
         )

+    def get_master_weight(self) -> torch.Tensor:
+        return gather_from_model_parallel_region(self.weight.data.transpose(0, 1)).transpose_(0, 1)
+
     def forward(self, input_: torch.Tensor) -> torch.Tensor:  # type: ignore
         # Set up backprop all-reduce.
         input_parallel = copy_to_model_parallel_region(input_)
@@ -364,6 +367,9 @@ class RowParallelLinear(torch.nn.Module):
             return_master_weight=keep_master_weight_for_test,
         )

+    def get_master_weight(self) -> torch.Tensor:
+        return gather_from_model_parallel_region(self.weight.data)
+
     def forward(self, input_: torch.Tensor) -> torch.Tensor:  # type:ignore
         # Set up backprop all-reduce.
         if self.input_is_parallel:
...
@@ -19,21 +19,27 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from typing import Any
+
 import torch

 from .initialize import get_model_parallel_group
 from .utils import split_tensor_along_last_dim

-def _reduce(input_: torch.Tensor) -> torch.Tensor:
+def _reduce(ctx: Any, input_: torch.Tensor) -> torch.Tensor:
     """All-reduce the input tensor across the model parallel group."""
     group = get_model_parallel_group()

+    if ctx:
+        ctx.mark_dirty(input_)
+
     # Bypass the function if we are using only 1 GPU.
     if torch.distributed.get_world_size(group=group) == 1:
         return input_

     # All-reduce.
+    print(f"doing all_reduce on {torch.distributed.get_rank()}")
     torch.distributed.all_reduce(input_, group=group)

     return input_
@@ -87,11 +93,13 @@ class _CopyToModelParallelRegion(torch.autograd.Function):
     @staticmethod
     def forward(ctx, input_):  # type: ignore
+        print(f"{torch.distributed.get_rank()}: _CopyToModelParallelRegion Forward")
         return input_

     @staticmethod
     def backward(ctx, grad_output):  # type: ignore
-        return _reduce(grad_output)
+        print(f"{torch.distributed.get_rank()}: _CopyToModelParallelRegion Backward")
+        return _reduce(None, grad_output)

 class _ReduceFromModelParallelRegion(torch.autograd.Function):
@@ -99,10 +107,12 @@ class _ReduceFromModelParallelRegion(torch.autograd.Function):
     @staticmethod
     def forward(ctx, input_):  # type: ignore
-        return _reduce(input_)
+        print(f"{torch.distributed.get_rank()}: _ReduceFromModelParallelRegion Forward")
+        return _reduce(ctx, input_)

     @staticmethod
     def backward(ctx, grad_output):  # type: ignore
+        print(f"{torch.distributed.get_rank()}: _ReduceFromModelParallelRegion Backward")
         return grad_output

@@ -111,10 +121,12 @@ class _ScatterToModelParallelRegion(torch.autograd.Function):
     @staticmethod
     def forward(ctx, input_):  # type: ignore
+        print(f"{torch.distributed.get_rank()}: _ScatterToModelParallelRegion Forward")
         return _split(input_)

     @staticmethod
     def backward(ctx, grad_output):  # type: ignore
+        print(f"{torch.distributed.get_rank()}: _ScatterToModelParallelRegion Backward")
         return _gather(grad_output)

@@ -123,10 +135,12 @@ class _GatherFromModelParallelRegion(torch.autograd.Function):
     @staticmethod
     def forward(ctx, input_):  # type: ignore
+        print(f"{torch.distributed.get_rank()}: _GatherFromModelParallelRegion Forward")
         return _gather(input_)

     @staticmethod
     def backward(ctx, grad_output):  # type: ignore
+        print(f"{torch.distributed.get_rank()}: _GatherFromModelParallelRegion Backward")
         return _split(grad_output)
...
@@ -59,7 +59,7 @@ def profile_times(module: nn.Sequential, sample: TensorOrTensors, timeout: float
     if any(p.grad is not None for p in module.parameters()):
         raise ValueError("some parameter already has gradient")

-    _batch = Batch(sample)
+    _batch = Batch(sample, 0)

     for i, x in enumerate(_batch):
         _batch[i] = x.detach().to(device).requires_grad_(x.requires_grad)
@@ -101,7 +101,7 @@ def profile_sizes(
     if device.type != "cuda":
         raise ValueError("size profiler supports only CUDA device")

-    batch = Batch(input)
+    batch = Batch(input, 0)

     sizes: List[int] = []
     latent_scale = batch[0].size(0) / chunks
...
@@ -5,7 +5,7 @@
 # Copyright 2019 Kakao Brain
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
 #
@@ -74,20 +74,6 @@ class Function(Protocol):
     ...

-def checkpoint(function: Function, input: TensorOrTensors) -> TensorOrTensors:
-    """Makes a checkpoint with a simple interface like
-    :func:`torch.utils.checkpoint.checkpoint`. It's only used to test or debug
-    :class:`Checkpoint` and :class:`Recompute` without boilerplate.
-    """
-    batch = Batch(input)
-    chk = Checkpointing(function, batch)
-    batch = chk.checkpoint()
-    chk.recompute(batch)
-    return batch.tensor_or_tensors
-
 class Checkpointing:
     """Generates a pair of :class:`Checkpoint` and :class:`Recompute`."""
@@ -116,7 +102,7 @@ class Checkpointing:
         if isinstance(output, tuple):
             output = tuple([x if x.is_floating_point() else x.detach() for x in output])

-        return Batch(output)
+        return Batch(output, self.batch.index)

     def recompute(self, batch: Batch) -> None:
         """Applies :class:`Recompute` to the batch in place."""
@@ -226,6 +212,7 @@ def save_rng_states(device: torch.device, rng_states: Deque[RNGStates],) -> None
     else:
         gpu_rng_state = None

+    rng_states.clear()
     rng_states.append((cpu_rng_state, gpu_rng_state))
@@ -237,7 +224,7 @@ def restore_rng_states(device: torch.device, rng_states: Deque[RNGStates],) -> G
     .. seealso:: :ref:`Referential Transparency`
     """
-    cpu_rng_state, gpu_rng_state = rng_states.pop()
+    cpu_rng_state, gpu_rng_state = rng_states[0]

     gpu_devices: List[torch.device] = []
     if device.type == "cuda":
...
@@ -53,9 +53,14 @@ class Batch:
     """

-    def __init__(self, value: TensorOrTensors) -> None:
+    def __init__(self, value: TensorOrTensors, index: int) -> None:
         self.value = value
         self.atomic = torch.is_tensor(value)
+        self.__index = index
+
+    @property
+    def index(self) -> int:
+        return self.__index

     @property
     def tensor(self) -> Tensor:
@@ -80,7 +85,7 @@ class Batch:
         """Calls a function by the underlying tensor or tensors. It also wraps
         the output with :class:`Batch`.
         """
-        return Batch(function(self.value))
+        return Batch(function(self.value), self.index)

     def __repr__(self) -> str:
         return f"Batch[atomic={self.atomic!r}]({self.value!r})"
@@ -176,7 +181,7 @@ def scatter(input: TensorOrTensors, chunks: int) -> List[Batch]:
     inputs = zip(*rotated)

-    return [Batch(x) for x in inputs]
+    return [Batch(x, i) for i, x in enumerate(inputs)]

 def gather(outputs: List[Batch]) -> TensorOrTensors:
...
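The index threaded through Batch above is what lets the multi-process pipeline match skip tensors and portals to the right micro-batch across processes. A standalone illustration of the idea (not library code):

```python
import torch

full_input = torch.randn(8, 16)
chunks = 4

# Roughly what scatter() does: split the mini-batch and remember each chunk's position.
micro_batches = [(i, x) for i, x in enumerate(full_input.chunk(chunks))]

for index, chunk in micro_batches:
    # Every Batch derived from this chunk (e.g. via Batch.call) keeps the same index,
    # so a stage running in another process can refer to "micro-batch #index".
    print(index, tuple(chunk.shape))
```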
This diff is collapsed.
This diff is collapsed.
@@ -36,19 +36,41 @@ class SkipLayout:
     # Skip routes indexed by partition number 'j': [[next_j]: [(prev_j, ns, name), ...], ...]
     by_partition: List[List[Tuple[int, Namespace, str]]]

+    # Skip routes indexed by source partition number 'j': [[prev_j]: [(next_j, ns, name), ...], ...]
+    by_src_partition: List[List[Tuple[int, Namespace, str]]]
+
     def __init__(self, num_partitions: int, skip_routes: Dict[Tuple[Namespace, str], Tuple[int, int]],) -> None:
         # The skip routes are already indexed by 'ns, name'.
         self.by_ns_name = skip_routes

         # Index skip routes by partition number 'j'.
         self.by_partition = [[] for _ in range(num_partitions)]
+        self.by_src_partition = [[] for _ in range(num_partitions)]

         for (ns, name), (prev_j, next_j) in skip_routes.items():
             self.by_partition[next_j].append((prev_j, ns, name))
+            self.by_src_partition[prev_j].append((next_j, ns, name))

         for p in self.by_partition:
             p.sort()

+    def copy_policy_by_src(self, prev_j: int) -> Iterable[Tuple[int, Namespace, str]]:
+        """Generates skip routes stashed at the given source partition number.
+
+        Yields:
+            Each tuple of (destination partition number, namespace, name).
+        """
+        for next_j, ns, name in self.by_src_partition[prev_j]:
+            if prev_j == next_j:
+                # This skip tensor will be popped at the same partition where
+                # it is stashed. In this case, copy is not required.
+                continue
+
+            yield (next_j, ns, name)
+
     def copy_policy(self, next_j: int) -> Iterable[Tuple[int, Namespace, str]]:
         """Generates skip routes for the given destination partition number.
         The skip routes are sorted by source partition number in ascending
...
@@ -25,11 +25,12 @@ one of the most important feature of :mod:`torchpipe.skip`.
 The metaphor is inspired by Portal™ from Valve.
 """
-from typing import List, Optional, Tuple
+from typing import Any, List, Optional, Tuple

 import torch
 from torch import Tensor

+from . import Namespace
 from ..copy import Context as CopyContext
 from ..copy import Copy
 from ..phony import get_phony
@@ -41,9 +42,16 @@ __all__: List[str] = []

 class Portal:
     """A portal for a tensor."""

-    def __init__(self, tensor: Optional[Tensor], tensor_life: int) -> None:
+    def __init__(self, tensor: Optional[Tensor], tensor_life: int, index: int) -> None:
         self.put_tensor(tensor, tensor_life)
         self.grad: Optional[Tensor] = None
+        self.__index = index
+        self.ns_name: Optional[Tuple[Namespace, str]]
+        self.pipeline: Any
+
+    @property
+    def index(self) -> int:
+        return self.__index

     def blue(self) -> Tensor:
         """Creates a :class:`PortalBlue` which hides the underlying tensor from
@@ -151,12 +159,17 @@ class Portal:
     def put_grad(self, grad: Tensor) -> None:
         """Stores a gradient into this portal."""
+        if hasattr(self, "pipeline"):
+            self.pipeline.send_portal_grad(self.ns_name, self.index, grad)
         self.grad = grad

     def use_grad(self) -> Tensor:
         """Retrieves and removes the underlying gradient. The gradient is
         always ephemeral.
         """
+        if self.grad is None and hasattr(self, "pipeline"):
+            self.grad = self.pipeline.recv_portal_grad(self.ns_name, self.index)
+
         if self.grad is None:
             raise RuntimeError("grad in portal has been removed or never set")
...
@@ -204,7 +204,7 @@ class Skippable(nn.Module):
         # Load skip tensors that might be popped.
         poppable_tensors = {}
-        batch = Batch(input)
+        batch = Batch(input, skip_tracker.index)
         for ns, name in self.poppable():
             try:
                 poppable_tensors[name] = skip_tracker.load(batch, ns, name)
@@ -237,7 +237,7 @@ class Skippable(nn.Module):
             raise RuntimeError(f"{comma_names} must be popped but have not")

         # Save stashed skip tensors.
-        batch = Batch(output)
+        batch = Batch(output, skip_tracker.index)
         for ns, name in self.stashable():
             tensor = stashed_tensors[name]
             skip_tracker.save(batch, ns, name, tensor)
...
@@ -61,6 +61,10 @@ class SkipTracker:
     ) -> None:
         raise TypeError("copy is not supported for non-portal skip tensors")

+    @property
+    def index(self) -> int:
+        return 0
+
 class SkipTrackerThroughPotals(SkipTracker):
     """Tracks saved skip tensors through portals. The skip tensors will be
@@ -71,10 +75,15 @@ class SkipTrackerThroughPotals(SkipTracker):
     """

-    def __init__(self, skip_layout: SkipLayout) -> None:
+    def __init__(self, skip_layout: SkipLayout, index: int) -> None:
         super().__init__()
         self.skip_layout = skip_layout
         self.portals: Dict[Tuple[Namespace, str], Portal] = {}
+        self.__index = index
+
+    @property
+    def index(self) -> int:
+        return self.__index

     def save(self, batch: Batch, ns: Namespace, name: str, tensor: Optional[Tensor]) -> None:
         """Saves the stashed skip tensor in a portal. The portal is then
@@ -106,7 +115,9 @@ class SkipTrackerThroughPotals(SkipTracker):
             else:
                 tensor_life = 2  # Delete at [6. PortalOrange.forward]

-            portal = Portal(tensor, tensor_life)
+            assert batch.index == self.index
+            portal = Portal(tensor, tensor_life, batch.index)
+            portal.ns_name = (ns, name)
             self.portals[(ns, name)] = portal

         else:
...
@@ -21,7 +21,7 @@
 CPU device.
 """
 from contextlib import contextmanager
-from typing import Generator, List, Union, cast
+from typing import Generator, List, Optional, Union, cast

 import torch
@@ -72,8 +72,12 @@ def use_device(device: torch.device) -> Generator[None, None, None]:

 @contextmanager
-def use_stream(stream: AbstractStream) -> Generator[None, None, None]:
+def use_stream(stream: Optional[AbstractStream]) -> Generator[None, None, None]:
     """:func:`torch.cuda.stream` for either CPU or CUDA stream."""
+    if not stream:
+        yield
+        return
+
     if not is_cuda(stream):
         yield
         return
@@ -120,7 +124,7 @@ def record_stream(tensor: torch.Tensor, stream: AbstractStream) -> None:
     tensor.record_stream(as_cuda(stream))

-def is_cuda(stream: AbstractStream) -> bool:
+def is_cuda(stream: Optional[AbstractStream]) -> bool:
     """Returns ``True`` if the given stream is a valid CUDA stream."""
     return stream is not CPUStream
...
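The Optional stream handling above follows a simple pattern: treat a missing stream as a no-op so stages without a CUDA device can share the same code path. A minimal standalone sketch of that pattern (using plain CUDA streams rather than the library's AbstractStream machinery):

```python
from contextlib import contextmanager
from typing import Generator, Optional

import torch

@contextmanager
def use_stream_sketch(stream: Optional[torch.cuda.Stream]) -> Generator[None, None, None]:
    # No stream at all (e.g. a stage that owns no CUDA device): do nothing.
    if stream is None:
        yield
        return
    # Otherwise run the block with the given CUDA stream made current.
    with torch.cuda.stream(stream):
        yield
```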