Unverified commit 63f7796a authored by Tom Birch, committed by GitHub

Multi-process pipe (#90)

Adds support for distributing pipeline stages across multiple processes (and therefore multiple machines):
* Adds a style argument to the Pipe constructor, defaulting to PipelineStyle.SingleProcess but also supporting PipelineStyle.MultiProcess (a rough usage sketch follows below)
* Adds support for lazy construction of modules (see lazy_construction for an example)
* Adds two implementations of inter-process communication: one based on rpc with globally visible queues, one based on send/recv
* Copies all the relevant tests from tests/pipe to tests/pipe_process and modifies them to exercise PipelineStyle.MultiProcess
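A rough usage sketch of the new option (illustrative only: the import path and the balance/chunks arguments are assumptions, not confirmed by this diff; the style argument and the PipelineStyle values come from the commit message):

# Hedged sketch, not the exact API added by this PR.
import torch.nn as nn
from fairscale.nn.pipe import Pipe, PipelineStyle  # import path assumed

model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 8))

# Default behaviour is unchanged: everything runs in one process.
single = Pipe(model, balance=[2, 1], style=PipelineStyle.SingleProcess, chunks=4)

# New in this PR: each pipeline stage may live in its own process (and machine).
multi = Pipe(model, balance=[2, 1], style=PipelineStyle.MultiProcess, chunks=4)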
parent 49a198c9
......@@ -149,7 +149,7 @@ jobs:
- run:
name: Run type-checking (mypy)
command: |
mypy --pretty .
mypy --ignore-missing-imports --scripts-are-modules --pretty .
- <<: *run_flake8
......
[settings]
known_third_party =numpy,pytest,recommonmark,setuptools,torch,torchtext,torchvision
known_third_party =benchmark_dataset,dataclasses,numpy,packaging,pytest,recommonmark,setuptools,torch,torchtext,torchvision
......@@ -37,6 +37,7 @@ repos:
rev: 4.3.20
hooks:
- id: isort
exclude: README.md
additional_dependencies: [toml]
- repo: https://github.com/pre-commit/mirrors-mypy
......
import torch
from torch.utils.data import Dataset


def collate_sentences_lm(samples):
    """Collate a list of LM samples into a single batch dict."""
    if len(samples) == 0:
        return {}

    id = torch.LongTensor([s["id"] for s in samples])
    src_tokens = torch.stack([s["source"] for s in samples], 0)
    tgt_tokens = torch.stack([s["target"] for s in samples], 0)
    ntokens = len(samples) * len(samples[0]["target"])
    # All sources in this synthetic dataset share the same length.
    src_lengths = torch.LongTensor([len(samples[0]["source"])] * len(samples))

    batch = {
        "id": id,
        "nsentences": len(samples),
        "ntokens": ntokens,
        "input": src_tokens,
        "target": tgt_tokens,
    }
    return batch
class BenchmarkLMDataset(Dataset):
    """
    Dataset to benchmark a translation-like seq2seq task.

    Args:
        vocab_size (int, optional): size of the vocabulary (default: 10000).
        max_source_positions (int, optional): max number of tokens in the
            source sentence (default: 1024).
        total_samples (int, optional): the total number of rows in the
            dataset (default: 10000).
    """

    def __init__(
        self, vocab_size=10000, max_source_positions=1024, total_samples=10000,
    ):
        self.vocab_size = vocab_size
        self.max_source_positions = max_source_positions
        self.total_samples = total_samples
        self.sizes = [self.max_source_positions] * self.total_samples

    def __getitem__(self, index):
        length = self.sizes[index]
        source = torch.randint(1, self.vocab_size, (length,))
        target = source.clone()
        return {
            "id": index,
            "source": source,
            "target": target,
        }

    def __len__(self):
        return self.total_samples
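For context, a minimal sketch of how this benchmark dataset might be consumed (standard torch.utils.data API; the batch size is arbitrary):

# Usage sketch, not part of this diff.
from torch.utils.data import DataLoader

dataset = BenchmarkLMDataset()
loader = DataLoader(dataset, batch_size=8, collate_fn=collate_sentences_lm)

batch = next(iter(loader))
print(batch["input"].shape)  # torch.Size([8, 1024])
print(batch["ntokens"])      # 8 * 1024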
This diff is collapsed.
......@@ -13,6 +13,7 @@
#
import os
import sys
from typing import Any, List
# The theme to use for HTML and HTML Help pages. See the documentation for
# a list of builtin themes.
......@@ -46,7 +47,7 @@ templates_path = ["_templates"]
# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
# This pattern also affects html_static_path and html_extra_path.
exclude_patterns = []
exclude_patterns: List[Any] = []
# -- Options for HTML output -------------------------------------------------
......
......@@ -5,6 +5,7 @@
from .cross_entropy import vocab_parallel_cross_entropy
from .initialize import (
destroy_model_parallel,
get_data_parallel_group,
get_data_parallel_rank,
get_data_parallel_world_size,
......@@ -12,6 +13,8 @@ from .initialize import (
get_model_parallel_rank,
get_model_parallel_src_rank,
get_model_parallel_world_size,
get_pipeline_parallel_group,
get_pipeline_parallel_ranks,
initialize_model_parallel,
)
from .layers import ColumnParallelLinear, RowParallelLinear, VocabParallelEmbedding
......
......@@ -35,6 +35,8 @@ _DATA_PARALLEL_GROUP = None
# Pipeline parallel group that the current rank belongs to.
_PIPELINE_PARALLEL_GROUP = None
_PIPELINE_PARALLEL_RANKS = None
def initialize_model_parallel(model_parallel_size_: int, pipeline_length: int = 1) -> None:
"""
......@@ -93,7 +95,15 @@ def initialize_model_parallel(model_parallel_size_: int, pipeline_length: int =
global _PIPELINE_PARALLEL_GROUP
assert _PIPELINE_PARALLEL_GROUP is None, "pipeline parallel group is already initialized"
_PIPELINE_PARALLEL_GROUP = groups[found[0], :, found[2]].tolist()
global _PIPELINE_PARALLEL_RANKS
assert _PIPELINE_PARALLEL_RANKS is None, "pipeline parallel ranks are already initialized"
for i in range(data_parallel_size):
for k in range(model_parallel_size):
ranks = groups[i, :, k].tolist()
group = torch.distributed.new_group(ranks)
if i == found[0] and k == found[2]:
_PIPELINE_PARALLEL_GROUP = group
_PIPELINE_PARALLEL_RANKS = ranks
def model_parallel_is_initialized() -> bool:
......@@ -115,12 +125,18 @@ def get_data_parallel_group() -> torch.distributed.ProcessGroup:
return _DATA_PARALLEL_GROUP
def get_pipeline_parallel_group() -> List[int]:
def get_pipeline_parallel_group() -> torch.distributed.ProcessGroup:
"""Get the pipeline parallel group the caller rank belongs to."""
assert _PIPELINE_PARALLEL_GROUP is not None, "pipeline parallel group is not initialized"
return _PIPELINE_PARALLEL_GROUP
def get_pipeline_parallel_ranks() -> List[int]:
"""Get the pipeline parallel group the caller rank belongs to."""
assert _PIPELINE_PARALLEL_RANKS is not None, "pipeline parallel group is not initialized"
return _PIPELINE_PARALLEL_RANKS
def get_model_parallel_world_size() -> int:
"""Return world size for the model parallel group."""
return torch.distributed.get_world_size(group=get_model_parallel_group())
......@@ -157,3 +173,6 @@ def destroy_model_parallel() -> None:
_DATA_PARALLEL_GROUP = None
global _PIPELINE_PARALLEL_GROUP
_PIPELINE_PARALLEL_GROUP = None
global _PIPELINE_PARALLEL_RANKS
_PIPELINE_PARALLEL_RANKS = None
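A hedged sketch of how the new pipeline-parallel accessors fit together (illustrative only: the import path is assumed, and this must run inside an initialized torch.distributed job whose world size is divisible by model_parallel_size_ * pipeline_length):

# Illustrative: 8 ranks arranged as data_parallel=2 x pipeline_length=2 x model_parallel=2.
import torch.distributed

from fairscale.nn.model_parallel import (  # import path assumed
    get_pipeline_parallel_group,
    get_pipeline_parallel_ranks,
    initialize_model_parallel,
)

initialize_model_parallel(model_parallel_size_=2, pipeline_length=2)
group = get_pipeline_parallel_group()  # torch.distributed.ProcessGroup for this rank's pipeline
ranks = get_pipeline_parallel_ranks()  # global ranks of the stages in that pipeline
print(torch.distributed.get_world_size(group=group), ranks)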
......@@ -280,6 +280,9 @@ class ColumnParallelLinear(torch.nn.Module):
return_master_weight=keep_master_weight_for_test,
)
def get_master_weight(self) -> torch.Tensor:
return gather_from_model_parallel_region(self.weight.data.transpose(0, 1)).transpose_(0, 1)
def forward(self, input_: torch.Tensor) -> torch.Tensor: # type: ignore
# Set up backprop all-reduce.
input_parallel = copy_to_model_parallel_region(input_)
......@@ -364,6 +367,9 @@ class RowParallelLinear(torch.nn.Module):
return_master_weight=keep_master_weight_for_test,
)
def get_master_weight(self) -> torch.Tensor:
return gather_from_model_parallel_region(self.weight.data)
def forward(self, input_: torch.Tensor) -> torch.Tensor: # type:ignore
# Set up backprop all-reduce.
if self.input_is_parallel:
......
......@@ -19,21 +19,27 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
import torch
from .initialize import get_model_parallel_group
from .utils import split_tensor_along_last_dim
def _reduce(input_: torch.Tensor) -> torch.Tensor:
def _reduce(ctx: Any, input_: torch.Tensor) -> torch.Tensor:
"""All-reduce the the input tensor across model parallel group."""
group = get_model_parallel_group()
if ctx:
ctx.mark_dirty(input_)
# Bypass the function if we are using only 1 GPU.
if torch.distributed.get_world_size(group=group) == 1:
return input_
# All-reduce.
print(f"doing all_reduce on {torch.distributed.get_rank()}")
torch.distributed.all_reduce(input_, group=group)
return input_
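The new ctx parameter lets callers mark the tensor dirty, since all_reduce mutates its input in place. As a generic illustration of that autograd contract (a toy op, unrelated to model parallelism):

# Toy illustration of ctx.mark_dirty, not part of this diff.
import torch

class InPlaceDouble(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.mark_dirty(x)  # declare that x is modified in place
        return x.mul_(2)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output * 2

x = torch.ones(3, requires_grad=True).clone()  # non-leaf, so in-place is allowed
InPlaceDouble.apply(x).sum().backward()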
......@@ -87,11 +93,13 @@ class _CopyToModelParallelRegion(torch.autograd.Function):
@staticmethod
def forward(ctx, input_): # type: ignore
print(f"{torch.distributed.get_rank()}: _CopyToModelParallelRegion Forward")
return input_
@staticmethod
def backward(ctx, grad_output): # type: ignore
return _reduce(grad_output)
print(f"{torch.distributed.get_rank()}: _CopyToModelParallelRegion Backward")
return _reduce(None, grad_output)
class _ReduceFromModelParallelRegion(torch.autograd.Function):
......@@ -99,10 +107,12 @@ class _ReduceFromModelParallelRegion(torch.autograd.Function):
@staticmethod
def forward(ctx, input_): # type: ignore
return _reduce(input_)
print(f"{torch.distributed.get_rank()}: _ReduceFromModelParallelRegion Forward")
return _reduce(ctx, input_)
@staticmethod
def backward(ctx, grad_output): # type: ignore
print(f"{torch.distributed.get_rank()}: _ReduceFromModelParallelRegion Backward")
return grad_output
......@@ -111,10 +121,12 @@ class _ScatterToModelParallelRegion(torch.autograd.Function):
@staticmethod
def forward(ctx, input_): # type: ignore
print(f"{torch.distributed.get_rank()}: _ScatterToModelParallelRegion Forward")
return _split(input_)
@staticmethod
def backward(ctx, grad_output): # type: ignore
print(f"{torch.distributed.get_rank()}: _ScatterToModelParallelRegion Backward")
return _gather(grad_output)
......@@ -123,10 +135,12 @@ class _GatherFromModelParallelRegion(torch.autograd.Function):
@staticmethod
def forward(ctx, input_): # type: ignore
print(f"{torch.distributed.get_rank()}: _GatherFromModelParallelRegion Forward")
return _gather(input_)
@staticmethod
def backward(ctx, grad_output): # type: ignore
print(f"{torch.distributed.get_rank()}: _GatherFromModelParallelRegion Backward")
return _split(grad_output)
......
......@@ -59,7 +59,7 @@ def profile_times(module: nn.Sequential, sample: TensorOrTensors, timeout: float
if any(p.grad is not None for p in module.parameters()):
raise ValueError("some parameter already has gradient")
_batch = Batch(sample)
_batch = Batch(sample, 0)
for i, x in enumerate(_batch):
_batch[i] = x.detach().to(device).requires_grad_(x.requires_grad)
......@@ -101,7 +101,7 @@ def profile_sizes(
if device.type != "cuda":
raise ValueError("size profiler supports only CUDA device")
batch = Batch(input)
batch = Batch(input, 0)
sizes: List[int] = []
latent_scale = batch[0].size(0) / chunks
......
......@@ -5,7 +5,7 @@
# Copyright 2019 Kakao Brain
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
......@@ -74,20 +74,6 @@ class Function(Protocol):
...
def checkpoint(function: Function, input: TensorOrTensors) -> TensorOrTensors:
"""Makes a checkpoint with a simple interface like
:func:`torch.utils.checkpoint.checkpoint`. It's only used to test or debug
:class:`Checkpoint` and :class:`Recompute` without boilerplate.
"""
batch = Batch(input)
chk = Checkpointing(function, batch)
batch = chk.checkpoint()
chk.recompute(batch)
return batch.tensor_or_tensors
class Checkpointing:
"""Generates a pair of :class:`Checkpoint` and :class:`Recompute`."""
......@@ -116,7 +102,7 @@ class Checkpointing:
if isinstance(output, tuple):
output = tuple([x if x.is_floating_point() else x.detach() for x in output])
return Batch(output)
return Batch(output, self.batch.index)
def recompute(self, batch: Batch) -> None:
"""Applies :class:`Recompute` to the batch in place."""
......@@ -226,6 +212,7 @@ def save_rng_states(device: torch.device, rng_states: Deque[RNGStates],) -> None
else:
gpu_rng_state = None
rng_states.clear()
rng_states.append((cpu_rng_state, gpu_rng_state))
......@@ -237,7 +224,7 @@ def restore_rng_states(device: torch.device, rng_states: Deque[RNGStates],) -> G
.. seealso:: :ref:`Referential Transparency`
"""
cpu_rng_state, gpu_rng_state = rng_states.pop()
cpu_rng_state, gpu_rng_state = rng_states[0]
gpu_devices: List[torch.device] = []
if device.type == "cuda":
......
......@@ -53,9 +53,14 @@ class Batch:
"""
def __init__(self, value: TensorOrTensors) -> None:
def __init__(self, value: TensorOrTensors, index: int) -> None:
self.value = value
self.atomic = torch.is_tensor(value)
self.__index = index
@property
def index(self) -> int:
return self.__index
@property
def tensor(self) -> Tensor:
......@@ -80,7 +85,7 @@ class Batch:
"""Calls a function by the underlying tensor or tensors. It also wraps
the output with :class:`Batch`.
"""
return Batch(function(self.value))
return Batch(function(self.value), self.index)
def __repr__(self) -> str:
return f"Batch[atomic={self.atomic!r}]({self.value!r})"
......@@ -176,7 +181,7 @@ def scatter(input: TensorOrTensors, chunks: int) -> List[Batch]:
inputs = zip(*rotated)
return [Batch(x) for x in inputs]
return [Batch(x, i) for i, x in enumerate(inputs)]
def gather(outputs: List[Batch]) -> TensorOrTensors:
......
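With the index threaded through, an illustrative consequence (values are made up; scatter and Batch are the functions defined above):

# Illustration only: every micro-batch now remembers its position in the batch,
# which the multi-process pipeline uses to match tensors to micro-batches.
import torch

batches = scatter(torch.randn(8, 4), chunks=4)
for b in batches:
    print(b.index, b.tensor.shape)  # 0..3, each torch.Size([2, 4])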
This diff is collapsed.
This diff is collapsed.
......@@ -36,19 +36,41 @@ class SkipLayout:
# Skip routes indexed by partition number 'j': [[next_j]: [(prev_j, ns, name), ...], ...]
by_partition: List[List[Tuple[int, Namespace, str]]]
# Skip routes indexed by source partition number 'prev_j': [[prev_j]: [(next_j, ns, name), ...], ...]
by_src_partition: List[List[Tuple[int, Namespace, str]]]
def __init__(self, num_partitions: int, skip_routes: Dict[Tuple[Namespace, str], Tuple[int, int]],) -> None:
# The skip routes are already indexed by 'ns, name'.
self.by_ns_name = skip_routes
# Index skip routes by partition number 'j'.
self.by_partition = [[] for _ in range(num_partitions)]
self.by_src_partition = [[] for _ in range(num_partitions)]
for (ns, name), (prev_j, next_j) in skip_routes.items():
self.by_partition[next_j].append((prev_j, ns, name))
self.by_src_partition[prev_j].append((next_j, ns, name))
for p in self.by_partition:
p.sort()
def copy_policy_by_src(self, prev_j: int) -> Iterable[Tuple[int, Namespace, str]]:
"""Generates skip routes for the given destination partition number.
The skip routes are sorted by source partition number in ascending
order.
Yields:
Each tuple of (source partition number, namespace, name).
"""
for next_j, ns, name in self.by_src_partition[prev_j]:
if prev_j == next_j:
# This skip tensor will be popped at the same partition where
# it is stashed. In this case, copy is not required.
continue
yield (next_j, ns, name)
def copy_policy(self, next_j: int) -> Iterable[Tuple[int, Namespace, str]]:
"""Generates skip routes for the given destination partition number.
The skip routes are sorted by source partition number in ascending
......
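A toy illustration of the two indexes (assumed values; SkipLayout and Namespace as defined in the skip package):

# One skip stashed on partition 0 and popped on partition 2 appears in both
# indexes, keyed by opposite ends of the route.
ns = Namespace()
layout = SkipLayout(num_partitions=3, skip_routes={(ns, "skip"): (0, 2)})
print(list(layout.copy_policy(2)))         # [(0, ns, 'skip')]  keyed by destination
print(list(layout.copy_policy_by_src(0)))  # [(2, ns, 'skip')]  keyed by source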
......@@ -25,11 +25,12 @@ one of the most important features of :mod:`torchpipe.skip`.
The metaphor is inspired by Portal™ from Valve.
"""
from typing import List, Optional, Tuple
from typing import Any, List, Optional, Tuple
import torch
from torch import Tensor
from . import Namespace
from ..copy import Context as CopyContext
from ..copy import Copy
from ..phony import get_phony
......@@ -41,9 +42,16 @@ __all__: List[str] = []
class Portal:
"""A portal for a tensor."""
def __init__(self, tensor: Optional[Tensor], tensor_life: int) -> None:
def __init__(self, tensor: Optional[Tensor], tensor_life: int, index: int) -> None:
self.put_tensor(tensor, tensor_life)
self.grad: Optional[Tensor] = None
self.__index = index
self.ns_name: Optional[Tuple[Namespace, str]]
self.pipeline: Any
@property
def index(self) -> int:
return self.__index
def blue(self) -> Tensor:
"""Creates a :class:`PortalBlue` which hides the underlying tensor from
......@@ -151,12 +159,17 @@ class Portal:
def put_grad(self, grad: Tensor) -> None:
"""Stores a gradient into this portal."""
if hasattr(self, "pipeline"):
self.pipeline.send_portal_grad(self.ns_name, self.index, grad)
self.grad = grad
def use_grad(self) -> Tensor:
"""Retrieves and removes the underlying gradient. The gradient is
always ephemeral.
"""
if self.grad is None and hasattr(self, "pipeline"):
self.grad = self.pipeline.recv_portal_grad(self.ns_name, self.index)
if self.grad is None:
raise RuntimeError("grad in portal has been removed or never set")
......
......@@ -204,7 +204,7 @@ class Skippable(nn.Module):
# Load skip tensors that might be popped.
poppable_tensors = {}
batch = Batch(input)
batch = Batch(input, skip_tracker.index)
for ns, name in self.poppable():
try:
poppable_tensors[name] = skip_tracker.load(batch, ns, name)
......@@ -237,7 +237,7 @@ class Skippable(nn.Module):
raise RuntimeError(f"{comma_names} must be popped but have not")
# Save stashed skip tensors.
batch = Batch(output)
batch = Batch(output, skip_tracker.index)
for ns, name in self.stashable():
tensor = stashed_tensors[name]
skip_tracker.save(batch, ns, name, tensor)
......
......@@ -61,6 +61,10 @@ class SkipTracker:
) -> None:
raise TypeError("copy is not supported for non-portal skip tensors")
@property
def index(self) -> int:
return 0
class SkipTrackerThroughPotals(SkipTracker):
"""Tracks saved skip tensors through portals. The skip tensors will be
......@@ -71,10 +75,15 @@ class SkipTrackerThroughPotals(SkipTracker):
"""
def __init__(self, skip_layout: SkipLayout) -> None:
def __init__(self, skip_layout: SkipLayout, index: int) -> None:
super().__init__()
self.skip_layout = skip_layout
self.portals: Dict[Tuple[Namespace, str], Portal] = {}
self.__index = index
@property
def index(self) -> int:
return self.__index
def save(self, batch: Batch, ns: Namespace, name: str, tensor: Optional[Tensor]) -> None:
"""Saves the stashed skip tensor in a portal. The portal is then
......@@ -106,7 +115,9 @@ class SkipTrackerThroughPotals(SkipTracker):
else:
tensor_life = 2 # Delete at [6. PortalOrange.forward]
portal = Portal(tensor, tensor_life)
assert batch.index == self.index
portal = Portal(tensor, tensor_life, batch.index)
portal.ns_name = (ns, name)
self.portals[(ns, name)] = portal
else:
......
......@@ -21,7 +21,7 @@
CPU device.
"""
from contextlib import contextmanager
from typing import Generator, List, Union, cast
from typing import Generator, List, Optional, Union, cast
import torch
......@@ -72,8 +72,12 @@ def use_device(device: torch.device) -> Generator[None, None, None]:
@contextmanager
def use_stream(stream: AbstractStream) -> Generator[None, None, None]:
def use_stream(stream: Optional[AbstractStream]) -> Generator[None, None, None]:
""":func:`torch.cuda.stream` for either CPU or CUDA stream."""
if not stream:
yield
return
if not is_cuda(stream):
yield
return
......@@ -120,7 +124,7 @@ def record_stream(tensor: torch.Tensor, stream: AbstractStream) -> None:
tensor.record_stream(as_cuda(stream))
def is_cuda(stream: AbstractStream) -> bool:
def is_cuda(stream: Optional[AbstractStream]) -> bool:
"""Returns ``True`` if the given stream is a valid CUDA stream."""
return stream is not CPUStream
......
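A small sketch of what the Optional signature permits (illustrative only):

# A stage with no CUDA stream (for example, a CPU-only process in the
# multi-process pipeline) can now pass None and get a no-op context manager.
with use_stream(None):
    pass  # work that does not need a specific CUDA stream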