Unverified commit 63f7796a authored by Tom Birch, committed by GitHub

Multi-process pipe (#90)

Adds support for distributing pipeline stages across multiple processes (and therefore multiple machines)
* Adds a style argument to the Pipe constructor, defaulting to PipelineStyle.SingleProcess but also supporting PipelineStyle.MultiProcess (see the usage sketch below)
* Adds support for lazy construction of modules (see the --lazy-construction option in the pipe benchmark for an example)
* Adds two implementations of inter-process communication: one based on torch.distributed.rpc with globally visible queues, and one based on torch.distributed send/recv
* Copies the relevant tests from tests/pipe to tests/pipe_process and modifies them to exercise PipelineStyle.MultiProcess
parent 49a198c9
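The new multi-process style and lazy construction can be combined roughly as follows. This is a hedged sketch distilled from the benchmark and test changes in this commit, not part of the change itself: the worker names, layer sizes, balance, and the assumption that torch.distributed is already initialized with one CUDA device per rank are all illustrative.

import torch
import torch.nn as nn
from torch.distributed import rpc

from fairscale.nn import Pipe
from fairscale.nn.model_parallel import initialize_model_parallel

# Per-rank setup (sketch): torch.distributed.init_process_group() is assumed
# to have been called already, with one CUDA device available per rank.
rank = torch.distributed.get_rank()
world_size = torch.distributed.get_world_size()
torch.cuda.set_device(rank % torch.cuda.device_count())

# rpc must be initialized before constructing the Pipe; worker_map maps
# global rank -> rpc worker name so that pipeline stages can reach each other.
rpc.init_rpc(f"worker{rank}", rank=rank, world_size=world_size)
worker_map = {r: f"worker{r}" for r in range(world_size)}
initialize_model_parallel(1, world_size)

# Lazy construction: pass a list of callables instead of an nn.Sequential,
# so each rank only instantiates the layers it actually owns.
layers = [
    lambda: nn.Linear(32, 32),
    lambda: nn.ReLU(),
    lambda: nn.Linear(32, 10),
]

model = Pipe(
    layers,
    balance=[2, 1],  # assumes at least two ranks: two layers on rank 0, one on rank 1
    style=Pipe.MultiProcess,
    worker_map=worker_map,
    input_device=torch.cuda.current_device(),
    chunks=4,
).cuda()

On non-final stages forward() returns the un-gathered micro-batches and back_helper() must be called during the backward pass; the updated benchmark train loop in this commit shows that pattern.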
......@@ -149,7 +149,7 @@ jobs:
- run:
name: Run type-checking (mypy)
command: |
mypy --pretty .
mypy --ignore-missing-imports --scripts-are-modules --pretty .
- <<: *run_flake8
......
[settings]
known_third_party =numpy,pytest,recommonmark,setuptools,torch,torchtext,torchvision
known_third_party =benchmark_dataset,dataclasses,numpy,packaging,pytest,recommonmark,setuptools,torch,torchtext,torchvision
......@@ -37,6 +37,7 @@ repos:
rev: 4.3.20
hooks:
- id: isort
exclude: README.md
additional_dependencies: [toml]
- repo: https://github.com/pre-commit/mirrors-mypy
......
import torch
from torch.utils.data import Dataset
def collate_sentences_lm(samples):
if len(samples) == 0:
return {}
id = torch.LongTensor([s["id"] for s in samples])
src_tokens = torch.stack([s["source"] for s in samples], 0)
tgt_tokens = torch.stack([s["target"] for s in samples], 0)
ntokens = len(samples) * len(samples[0]["target"])
src_lengths = torch.LongTensor([len(samples[0]["source"])] * len(samples))
batch = {
"id": id,
"nsentences": len(samples),
"ntokens": ntokens,
"input": src_tokens,
"target": tgt_tokens,
}
return batch
class BenchmarkLMDataset(Dataset):
"""
    Dataset to benchmark a translation-like seq2seq task.
Args:
vocab_size (int, optional): size of the vocabulary (default 10000).
max_source_positions (int, optional): max number of tokens in the
source sentence (default: 1024).
total_samples (int, optional): the total number of rows in the
dataset (default: 10000).
"""
def __init__(
self, vocab_size=10000, max_source_positions=1024, total_samples=10000,
):
self.vocab_size = vocab_size
self.max_source_positions = max_source_positions
self.total_samples = total_samples
self.sizes = [self.max_source_positions] * self.total_samples
def __getitem__(self, index):
length = self.sizes[index]
source = torch.randint(1, self.vocab_size, (length,))
target = source.clone()
return {
"id": index,
"source": source,
"target": target,
}
def __len__(self):
return self.total_samples
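A brief usage sketch of the dataset and collate function above, mirroring how make_model_and_data wires them together later in this commit (the batch size is an arbitrary example value):

from torch.utils.data import DataLoader

lm_dataset = BenchmarkLMDataset(vocab_size=10000, max_source_positions=1024, total_samples=10000)
lm_dataloader = DataLoader(
    lm_dataset, batch_size=8, shuffle=True, num_workers=0, collate_fn=collate_sentences_lm
)
for batch in lm_dataloader:
    # Each batch is a dict with "id", "nsentences", "ntokens", "input" and "target".
    src_tokens, tgt_tokens = batch["input"], batch["target"]
    break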
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import argparse
import math
import os
import time
import warnings
from benchmark_dataset import BenchmarkLMDataset, collate_sentences_lm
import torch
from torch.distributed import rpc
import torch.multiprocessing as mp
import torch.nn as nn
from torch.utils.data import DataLoader
import torchtext
from torchtext.data.utils import get_tokenizer
from fairscale.nn import Pipe
from fairscale.nn.model_parallel import initialize_model_parallel
from fairscale.nn.pipe import pipe
from fairscale.optim import GradScaler
from tests.nn.model_parallel.commons import dist_init, get_worker_map
try:
from fairscale.optim import Adam, Precision # type: ignore
from fairscale.optim import Adam # type: ignore
can_benchmark = True
except ImportError:
......@@ -21,6 +31,18 @@ except ImportError:
can_benchmark = False
def init_random_seed(seed: int):
import numpy
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
numpy.random.seed(seed)
PIPE_CHUNKS = 2
iteration_count = 0
class EmbeddingLayer(nn.Embedding):
def __init__(self, ntoken, ninp, initrange):
super().__init__(ntoken, ninp)
......@@ -51,7 +73,7 @@ class PositionalEncodingLayer(nn.Module):
class TransformerDecoderLayer(nn.TransformerEncoderLayer):
"""Though this class inherits from torch.nn.TransformerEncoderLayer,
it functions as a decoder in this model"""
it functions as a decoder in this model"""
    def __init__(self, ninp, nhead, nhid, dropout):
        super().__init__(ninp, nhead, nhid, dropout)
......@@ -63,6 +85,11 @@ class TransformerDecoderLayer(nn.TransformerEncoderLayer):
return mask
def forward(self, src):
global iteration_count
iteration_count += 1
# if iteration_count == 196:
# dump_cuda_tensors()
if self.src_mask is None or self.src_mask.size(0) != len(src):
device = src.device
mask = self._generate_square_subsequent_mask(len(src)).to(device)
......@@ -80,32 +107,36 @@ class LinearLayer(nn.Linear):
class TransformerLMSequntial(nn.Sequential):
"""A small language model based on the design of GPT-2 using nn.Sequeitnal
for compatability with Pipe"""
for compatability with Pipe"""
def __init__(self, ntokens, ninp, nhead, nhid, dropout, initrange):
super(TransformerLMSequntial, self).__init__(
def __init__(self, ntokens, ninp, nhead, nhid, dropout, initrange, ndecoder):
layers = [
EmbeddingLayer(ntokens, ninp, initrange),
PositionalEncodingLayer(ninp, dropout),
TransformerDecoderLayer(ninp, nhead, nhid, dropout),
LinearLayer(ninp, ntokens, initrange),
)
]
for _ in range(ndecoder):
layers.append(TransformerDecoderLayer(ninp, nhead, nhid, dropout))
layers.append(LinearLayer(ninp, ntokens, initrange))
super(TransformerLMSequntial, self).__init__(*layers)
def get_data(device):
TEXT = torchtext.data.Field(
tokenize=get_tokenizer("basic_english"), init_token="<sos>", eos_token="<eos>", lower=True
)
train_txt, val_txt, test_txt = torchtext.datasets.WikiText2.splits(TEXT)
TEXT.build_vocab(train_txt)
ntokens = len(TEXT.vocab.stoi)
with warnings.catch_warnings(record=True) as fjldska:
TEXT = torchtext.data.Field(
tokenize=get_tokenizer("basic_english"), init_token="<sos>", eos_token="<eos>", lower=True
)
train_txt, val_txt, test_txt = torchtext.datasets.WikiText2.splits(TEXT)
TEXT.build_vocab(train_txt)
ntokens = len(TEXT.vocab.stoi)
batch_size = 500
eval_batch_size = 200
train_data = batchify(train_txt, batch_size, TEXT, device)
val_data = batchify(val_txt, eval_batch_size, TEXT, device)
test_data = batchify(test_txt, eval_batch_size, TEXT, device)
batch_size = 20
eval_batch_size = 10
train_data = batchify(train_txt, batch_size, TEXT, device)
val_data = batchify(val_txt, eval_batch_size, TEXT, device)
test_data = batchify(test_txt, eval_batch_size, TEXT, device)
return ntokens, train_data, val_data, test_data
return ntokens, train_data, val_data, test_data
def batchify(data, bsz, TEXT, device):
......@@ -123,71 +154,188 @@ def get_batch(source, i, bptt):
return data, target
def make_model(device, ntokens):
ninp = 50 # embedding dimension
nhid = 50 # the dimension of the feedforward network model in nn.TransformerEncoder
nhead = 2 # the number of heads in the multiheadattention models
def make_model(args, device, ntokens):
ninp = 2048 # embedding dimension
nhid = 2048 # the dimension of the feedforward network model in nn.TransformerEncoder
nhead = 32 # the number of heads in the multiheadattention models
dropout = 0
initrange = 0.1
ndecoder = args.num_decoder_layers
model = TransformerLMSequntial(ntokens, ninp, nhead, nhid, dropout, initrange).half().to(device)
balance = generate_balance(min(num_devices, 4), len(model))
p = Pipe(model, balance, chunks=len(balance))
if args.lazy_construction:
layers = [
lambda: EmbeddingLayer(ntokens, ninp, initrange),
lambda: PositionalEncodingLayer(ninp, dropout),
]
for _ in range(ndecoder):
layers.append(lambda: TransformerDecoderLayer(ninp, nhead, nhid, dropout))
layers.append(lambda: LinearLayer(ninp, ntokens, initrange))
model = layers
else:
model = TransformerLMSequntial(ntokens, ninp, nhead, nhid, dropout, initrange, ndecoder).to(device)
criterion = nn.CrossEntropyLoss()
lr = 0.001 # learning rate
lr = 0.01 # learning rate
try:
optimizer = Adam(p.parameters(), lr=lr, precision=Precision.PURE_FP16)
except NameError:
optimizer = Adam(p.parameters(), lr=lr)
def make_adam(model):
return Adam(model.parameters(), lr=lr)
optimizer = make_adam
scaler = GradScaler()
return p, criterion, optimizer, scaler
return model, criterion, optimizer, scaler
def train(train_data, model, criterion, optimizer, scaler, bptt, ntokens):
def get_tensors_by_size_bucket():
import gc
from collections import defaultdict
size_buckets = defaultdict(int)
for obj in gc.get_objects():
if not isinstance(obj, torch.Tensor):
continue
if obj.device.type == "cuda":
size_buckets[(*obj.size(),) + (obj.element_size(),)] += 1
return size_buckets
def dump_size_buckets(size_buckets, prefix=""):
import operator
from functools import reduce
total = 0
for key, value in size_buckets.items():
this = reduce(operator.mul, key) * value
total += this
print(prefix + f"{key} : {value}, {this}")
print(prefix + f"total = {total}")
last_size_buckets = None
once = True
def safe_rank():
try:
return torch.distributed.get_rank()
except AssertionError:
return 0
def check_size_buckets():
global last_size_buckets
global once
size_buckets = get_tensors_by_size_bucket()
if last_size_buckets is not None:
if size_buckets != last_size_buckets:
print(f"difference is oustanding tensors: {safe-rank()}")
dump_size_buckets(last_size_buckets, "old: ")
dump_size_buckets(size_buckets, "new: ")
if once:
print(f"dumping buckets for: {safe_rank()}")
dump_size_buckets(last_size_buckets, "old: ")
dump_size_buckets(size_buckets, "new: ")
once = False
else:
print(f"size buckets none on {safe_rank()}")
last_size_buckets = size_buckets
def dump_cuda_tensors():
print(f"dumping cuda tensors...")
    from collections import defaultdict
    from functools import reduce
    import operator
    import gc
    size_buckets = defaultdict(int)
for obj in gc.get_objects():
if not isinstance(obj, torch.Tensor):
continue
if obj.device.type == "cuda":
size_buckets[(*obj.size(),) + (obj.element_size(),)] += 1
print(f"outstanding cuda tensors:")
total = 0
for key, value in size_buckets.items():
this = reduce(operator.mul, key) * value
total += this
print(f"{key} : {value}, {this}")
print(f"total size = {total}")
import pprint
pprint.pprint(torch.cuda.memory_stats())
def train(lm_dataloader, model, criterion, optimizer, vocab_size, args):
model.train()
from functools import reduce
import operator
num_params = reduce(operator.add, (reduce(operator.mul, x.size()) for x in model.parameters()))
if model.group:
print(f"training model, #prams = {num_params}, group: {model.group.rank()}, sizes {model.group.size()}")
else:
print(f"training model, #prams = {num_params}")
vocab_size = 10000 # FIXME
total_loss = 0.0
start_time = time.time()
for batch, i in enumerate(range(0, train_data.size(0) - 1, bptt)):
data, targets = get_batch(train_data, i, bptt)
word_counter = 0
optimizer = optimizer(model)
def get_first_device(model):
if model.devices:
return model.devices[0]
else:
return torch.cuda.current_device()
def get_last_device(model):
if model.devices:
return model.devices[-1]
else:
return torch.cuda.current_device()
for i, batch in enumerate(lm_dataloader):
if args.max_batch and i > args.max_batch:
break
optimizer.zero_grad()
output = model(data)
output = output.to(targets.device)
loss = criterion(output.view(-1, ntokens), targets)
scaler.scale(loss).backward()
scaler.step(optimizer) # scaler.step automatically unscale if unscale has not yet been performed
scaler.update()
total_loss += loss.item()
log_interval = 50
if batch % log_interval == 0 and batch > 0:
cur_loss = total_loss / log_interval
elapsed = time.time() - start_time
try:
print(
"| {:5d}/{:5d} batches | ms/batch {:5.2f} | "
"loss {:5.2f} | ppl {:8.2f} | grad scale {:3d} | optim scale {:3d}".format(
batch,
len(train_data) // bptt,
elapsed * 1000 / log_interval,
cur_loss,
math.exp(cur_loss),
int(scaler.get_scale()),
int(optimizer._optim_scale),
)
)
except AttributeError:
output = model(batch["input"].to(get_first_device(model)))
if model.group is None or model.group.rank() == model.group.size() - 1:
target = batch["target"].to(get_last_device(model))
output = output.to(target.device)
loss = criterion(output.view(-1, vocab_size), target.view(-1))
loss.backward()
else:
model.back_helper(output)
del output
torch.nn.utils.clip_grad_value_(model.parameters(), 0.05)
optimizer.step()
if model.group is None or model.group.rank() == model.group.size() - 1:
total_loss += loss.item()
log_interval = 1
word_counter += batch["ntokens"]
if i % log_interval == 0 and i > 0:
cur_loss = total_loss / log_interval
elapsed = time.time() - start_time
print(
"| {:5d}/{:5d} batches | ms/batch {:5.2f} | "
"loss {:5.2f} | ppl {:8.2f}".format(
batch, len(train_data) // bptt, elapsed * 1000 / log_interval, cur_loss, math.exp(cur_loss)
"| batch {:5d} | wps {:5.2f} | loss {:5.2f} | ppl {:8.2f}".format(
i, word_counter / elapsed, cur_loss, math.exp(cur_loss)
)
)
total_loss = 0
start_time = time.time()
word_counter = 0
total_loss = 0
start_time = time.time()
# if i >= 10:
# break
# torch.cuda.empty_cache()
# check_size_buckets()
def evaluate(eval_model, data_source, criterion, bptt, ntokens):
......@@ -207,7 +355,7 @@ def get_number_of_words(data):
return data.size()[0] * data.size()[1]
def benchmark_language_model(train_data, val_data, test_data, model, criterion, optimizer, scaler, ntokens):
def benchmark_language_model(train_data, val_data, test_data, model, criterion, optimizer, ntokens, args):
epoch = 1
bptt = 35
start_time = time.time()
......@@ -216,9 +364,9 @@ def benchmark_language_model(train_data, val_data, test_data, model, criterion,
print("| start of epoch {:1d}".format(epoch))
print("-" * 110)
epoch_start_time = time.time()
train(train_data, model, criterion, optimizer, scaler, bptt, ntokens)
val_loss = evaluate(model, val_data, criterion, bptt, ntokens)
print("-" * 110)
train(train_data, model, criterion, optimizer, bptt, ntokens, args)
val_loss = 1 # evaluate(model, val_data, criterion, bptt, ntokens)
print("-" * 89)
print(
"| end of epoch {:1d} | time: {:5.2f}s | valid loss {:5.2f} ".format(
epoch, (time.time() - epoch_start_time), val_loss
......@@ -230,8 +378,8 @@ def benchmark_language_model(train_data, val_data, test_data, model, criterion,
nwords = get_number_of_words(train_data) + get_number_of_words(val_data)
wps = nwords / elapsed_time
test_loss = evaluate(model, test_data, criterion, bptt, ntokens)
print("=" * 110)
test_loss = 1 # evaluate(model, test_data, criterion, bptt, ntokens)
print("=" * 89)
print(
"| end of training | test loss {:5.2f} \n| time: {:5.2f}s | words: {:3d} | wps: {:5.2f}".format(
test_loss, elapsed_time, nwords, wps
......@@ -272,13 +420,186 @@ def generate_balance(num_devices, num_layers):
return balance
if __name__ == "__main__":
def make_model_and_data(args, device, new_data: bool = True):
if new_data:
device = torch.device("cuda")
vocab_size = 10000
model, criterion, optimizer, scaler = make_model(args, device, vocab_size)
lm_dataset = BenchmarkLMDataset()
lm_dataloader = DataLoader(
lm_dataset, batch_size=args.batch_size, shuffle=True, num_workers=0, collate_fn=collate_sentences_lm
)
return {
"model": model,
"criterion": criterion,
"optimizer": optimizer,
"data": lm_dataloader,
"vocab_size": vocab_size,
}
else:
device = torch.device("cuda")
data = get_data(device)
ntokens, train_data, val_data, test_data = data
model, criterion, optimizer, scaler = make_model(args, device, ntokens)
return {
"model": model,
"criterion": criterion,
"optimizer": optimizer,
"data": data,
}
def bench_single_process(args):
num_devices = torch.cuda.device_count()
assert num_devices > 0
torch.manual_seed(0)
init_random_seed(0)
device = torch.device("cuda")
ntokens, train_data, val_data, test_data = get_data(device)
model, criterion, optimizer, scaler = make_model(device, ntokens)
benchmark_language_model(train_data, val_data, test_data, model, criterion, optimizer, scaler, ntokens)
new_data = True
blob = make_model_and_data(args, None, new_data=new_data)
model = blob["model"]
balance = generate_balance(min(num_devices, 8), len(model))
p = pipe.Pipe(
model, balance, chunks=args.chunks, pipelined_backward=args.pipelined_backward, checkpoint=args.checkpoint
)
del model
del blob["model"]
if new_data:
train(blob["data"], p, blob["criterion"], blob["optimizer"], blob["vocab_size"], args)
else:
ntokens, train_data, val_data, test_data = blob["data"]
        benchmark_language_model(train_data, val_data, test_data, p, blob["criterion"], blob["optimizer"], ntokens, args)
def run_mp_worker(args, available_workers):
new_data = True
blob = make_model_and_data(args, None, new_data=new_data)
model = blob["model"]
balance = generate_balance(min(available_workers, 8), len(model))
p = pipe.Pipe(
model,
balance,
style=Pipe.MultiProcess,
chunks=args.chunks,
worker_map=get_worker_map(),
input_device=torch.cuda.current_device(),
pipelined_backward=args.pipelined_backward,
checkpoint=args.checkpoint,
).cuda()
if args.all_at_once and p.pipeline:
print(f"running all at once")
p.pipeline.all_at_once = True
if new_data:
train(blob["data"], p, blob["criterion"], blob["optimizer"], blob["vocab_size"], args)
else:
ntokens, train_data, val_data, test_data = blob["data"]
        benchmark_language_model(train_data, val_data, test_data, p, blob["criterion"], blob["optimizer"], ntokens, args)
def run_worker(rank, world_size, args):
if args.world_size != 0:
world_size = args.world_size
dist_init(rank + args.rank_base, world_size, hostname=args.host)
initialize_model_parallel(1, world_size)
init_random_seed(0)
run_mp_worker(args, world_size)
rpc.shutdown()
torch.distributed.destroy_process_group()
def bench_multi_process(args, all_at_once=False):
if args.local_world_size != 0:
world_size = args.local_world_size
else:
world_size = min(torch.cuda.device_count(), 2)
mp.spawn(run_worker, args=(world_size, args), nprocs=world_size, join=True)
best_device_map = {
0: "mlx5_0:1",
1: "mlx5_0:1",
2: "mlx5_1:1",
3: "mlx5_1:1",
4: "mlx5_2:1",
5: "mlx5_2:1",
6: "mlx5_3:1",
7: "mlx5_3:1",
}
def bench_mpi(args):
guess_rank = int(os.environ["OMPI_COMM_WORLD_RANK"])
os.environ["UCX_NET_DEVICES"] = best_device_map[guess_rank]
torch.distributed.init_process_group(backend="mpi")
os.environ["MASTER_ADDR"] = args.host
os.environ["MASTER_PORT"] = "10639"
if args.socket_name:
os.environ["GLOO_SOCKET_IFNAME"] = args.socket_name
os.environ["TP_SOCKET_IFNAME"] = args.socket_name
init_method = f"tcp://{os.environ['MASTER_ADDR']}:{os.environ['MASTER_PORT']}"
rank = torch.distributed.get_rank()
world_size = torch.distributed.get_world_size()
torch.cuda.set_device(rank % torch.cuda.device_count())
rpc.init_rpc(
f"Test{rank}",
rank=rank,
world_size=world_size,
backend=rpc.BackendType.PROCESS_GROUP,
rpc_backend_options=rpc.ProcessGroupRpcBackendOptions(rpc_timeout=20, init_method=init_method),
)
initialize_model_parallel(1, world_size)
init_random_seed(0)
run_mp_worker(args, world_size)
rpc.shutdown()
torch.distributed.destroy_process_group()
parser = argparse.ArgumentParser(description="benchmark")
parser.add_argument("--local-world-size", "-l", type=int, default=0, help="local world size")
parser.add_argument("--world-size", "-w", type=int, default=0, help="world size")
parser.add_argument("--rank-base", "-r", type=int, help="rank base", default=0)
parser.add_argument("--host", "-o", type=str, default="localhost", help="hostname")
parser.add_argument("--no-mpi", action="store_true", default=False, help="disable mpi")
parser.add_argument("--chunks", type=int, default=1, help="number of microbatches per batch")
parser.add_argument("--batch-size", type=int, default=8, help="size of a batch")
parser.add_argument("--all-at-once", action="store_true", default=False, help="do backward pass on whole batch at once")
parser.add_argument("--max-batch", type=int, default=4, help="Max number of batches")
parser.add_argument("--socket-name", type=str, default=None, help="socket ifname for gloo/tp")
parser.add_argument("--num-decoder-layers", type=int, default=10, help="Number of decoder layers in the model")
parser.add_argument(
"--lazy-construction", action="store_true", default=False, help="Number of decoder layers in the model"
)
parser.add_argument(
"--checkpoint", default="never", choices=["always", "except_last", "never"], help="Checkpointing strategy for pipe"
)
parser.add_argument(
"--pipelined-backward", dest="pipelined_backward", action="store_true", help="Pipelined backward pass"
)
parser.add_argument(
"--no-pipelined-backward", dest="pipelined_backward", action="store_false", help="Pipelined backward pass"
)
parser.set_defaults(pipelined_backward=True)
if __name__ == "__main__":
args = parser.parse_args()
# bench_multi_process(args, all_at_once=True)
if args.no_mpi or "OMPI_COMM_WORLD_RANK" not in os.environ:
print(f"Running benchmark with args: {args}")
bench_single_process(args)
else:
if os.environ["OMPI_COMM_WORLD_RANK"] == "0":
print(f"Running benchmark with args: {args}")
bench_mpi(args)
......@@ -13,6 +13,7 @@
#
import os
import sys
from typing import Any, List
# The theme to use for HTML and HTML Help pages. See the documentation for
# a list of builtin themes.
......@@ -46,7 +47,7 @@ templates_path = ["_templates"]
# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
# This pattern also affects html_static_path and html_extra_path.
exclude_patterns = []
exclude_patterns: List[Any] = []
# -- Options for HTML output -------------------------------------------------
......
......@@ -5,6 +5,7 @@
from .cross_entropy import vocab_parallel_cross_entropy
from .initialize import (
destroy_model_parallel,
get_data_parallel_group,
get_data_parallel_rank,
get_data_parallel_world_size,
......@@ -12,6 +13,8 @@ from .initialize import (
get_model_parallel_rank,
get_model_parallel_src_rank,
get_model_parallel_world_size,
get_pipeline_parallel_group,
get_pipeline_parallel_ranks,
initialize_model_parallel,
)
from .layers import ColumnParallelLinear, RowParallelLinear, VocabParallelEmbedding
......
......@@ -35,6 +35,8 @@ _DATA_PARALLEL_GROUP = None
# Pipeline parallel group that the current rank belongs to.
_PIPELINE_PARALLEL_GROUP = None
_PIPELINE_PARALLEL_RANKS = None
def initialize_model_parallel(model_parallel_size_: int, pipeline_length: int = 1) -> None:
"""
......@@ -93,7 +95,15 @@ def initialize_model_parallel(model_parallel_size_: int, pipeline_length: int =
global _PIPELINE_PARALLEL_GROUP
assert _PIPELINE_PARALLEL_GROUP is None, "model parallel group is already initialized"
_PIPELINE_PARALLEL_GROUP = groups[found[0], :, found[2]].tolist()
global _PIPELINE_PARALLEL_RANKS
assert _PIPELINE_PARALLEL_RANKS is None, "model parallel group is already initialized"
for i in range(data_parallel_size):
for k in range(model_parallel_size):
ranks = groups[i, :, k].tolist()
group = torch.distributed.new_group(ranks)
if i == found[0] and k == found[2]:
_PIPELINE_PARALLEL_GROUP = group
_PIPELINE_PARALLEL_RANKS = ranks
def model_parallel_is_initialized() -> bool:
......@@ -115,12 +125,18 @@ def get_data_parallel_group() -> torch.distributed.ProcessGroup:
return _DATA_PARALLEL_GROUP
def get_pipeline_parallel_group() -> List[int]:
def get_pipeline_parallel_group() -> torch.distributed.ProcessGroup:
"""Get the pipeline parallel group the caller rank belongs to."""
assert _PIPELINE_PARALLEL_GROUP is not None, "pipeline parallel group is not initialized"
return _PIPELINE_PARALLEL_GROUP
def get_pipeline_parallel_ranks() -> List[int]:
"""Get the pipeline parallel group the caller rank belongs to."""
assert _PIPELINE_PARALLEL_RANKS is not None, "pipeline parallel group is not initialized"
return _PIPELINE_PARALLEL_RANKS
def get_model_parallel_world_size() -> int:
"""Return world size for the model parallel group."""
return torch.distributed.get_world_size(group=get_model_parallel_group())
......@@ -157,3 +173,6 @@ def destroy_model_parallel() -> None:
_DATA_PARALLEL_GROUP = None
global _PIPELINE_PARALLEL_GROUP
_PIPELINE_PARALLEL_GROUP = None
global _PIPELINE_PARALLEL_RANKS
_PIPELINE_PARALLEL_RANKS = None
......@@ -280,6 +280,9 @@ class ColumnParallelLinear(torch.nn.Module):
return_master_weight=keep_master_weight_for_test,
)
def get_master_weight(self) -> torch.Tensor:
return gather_from_model_parallel_region(self.weight.data.transpose(0, 1)).transpose_(0, 1)
def forward(self, input_: torch.Tensor) -> torch.Tensor: # type: ignore
# Set up backprop all-reduce.
input_parallel = copy_to_model_parallel_region(input_)
......@@ -364,6 +367,9 @@ class RowParallelLinear(torch.nn.Module):
return_master_weight=keep_master_weight_for_test,
)
def get_master_weight(self) -> torch.Tensor:
return gather_from_model_parallel_region(self.weight.data)
def forward(self, input_: torch.Tensor) -> torch.Tensor: # type:ignore
# Set up backprop all-reduce.
if self.input_is_parallel:
......
......@@ -19,21 +19,27 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
import torch
from .initialize import get_model_parallel_group
from .utils import split_tensor_along_last_dim
def _reduce(input_: torch.Tensor) -> torch.Tensor:
def _reduce(ctx: Any, input_: torch.Tensor) -> torch.Tensor:
"""All-reduce the the input tensor across model parallel group."""
group = get_model_parallel_group()
if ctx:
ctx.mark_dirty(input_)
# Bypass the function if we are using only 1 GPU.
if torch.distributed.get_world_size(group=group) == 1:
return input_
# All-reduce.
print(f"doing all_reduce on {torch.distributed.get_rank()}")
torch.distributed.all_reduce(input_, group=group)
return input_
......@@ -87,11 +93,13 @@ class _CopyToModelParallelRegion(torch.autograd.Function):
@staticmethod
def forward(ctx, input_): # type: ignore
print(f"{torch.distributed.get_rank()}: _CopyToModelParallelRegion Forward")
return input_
@staticmethod
def backward(ctx, grad_output): # type: ignore
return _reduce(grad_output)
print(f"{torch.distributed.get_rank()}: _CopyToModelParallelRegion Backward")
return _reduce(None, grad_output)
class _ReduceFromModelParallelRegion(torch.autograd.Function):
......@@ -99,10 +107,12 @@ class _ReduceFromModelParallelRegion(torch.autograd.Function):
@staticmethod
def forward(ctx, input_): # type: ignore
return _reduce(input_)
print(f"{torch.distributed.get_rank()}: _ReduceFromModelParallelRegion Forward")
return _reduce(ctx, input_)
@staticmethod
def backward(ctx, grad_output): # type: ignore
print(f"{torch.distributed.get_rank()}: _ReduceFromModelParallelRegion Backward")
return grad_output
......@@ -111,10 +121,12 @@ class _ScatterToModelParallelRegion(torch.autograd.Function):
@staticmethod
def forward(ctx, input_): # type: ignore
print(f"{torch.distributed.get_rank()}: _ScatterToModelParallelRegion Forward")
return _split(input_)
@staticmethod
def backward(ctx, grad_output): # type: ignore
print(f"{torch.distributed.get_rank()}: _ScatterToModelParallelRegion Backward")
return _gather(grad_output)
......@@ -123,10 +135,12 @@ class _GatherFromModelParallelRegion(torch.autograd.Function):
@staticmethod
def forward(ctx, input_): # type: ignore
print(f"{torch.distributed.get_rank()}: _GatherFromModelParallelRegion Forward")
return _gather(input_)
@staticmethod
def backward(ctx, grad_output): # type: ignore
print(f"{torch.distributed.get_rank()}: _GatherFromModelParallelRegion Backward")
return _split(grad_output)
......
......@@ -59,7 +59,7 @@ def profile_times(module: nn.Sequential, sample: TensorOrTensors, timeout: float
if any(p.grad is not None for p in module.parameters()):
raise ValueError("some parameter already has gradient")
_batch = Batch(sample)
_batch = Batch(sample, 0)
for i, x in enumerate(_batch):
_batch[i] = x.detach().to(device).requires_grad_(x.requires_grad)
......@@ -101,7 +101,7 @@ def profile_sizes(
if device.type != "cuda":
raise ValueError("size profiler supports only CUDA device")
batch = Batch(input)
batch = Batch(input, 0)
sizes: List[int] = []
latent_scale = batch[0].size(0) / chunks
......
......@@ -5,7 +5,7 @@
# Copyright 2019 Kakao Brain
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
......@@ -74,20 +74,6 @@ class Function(Protocol):
...
def checkpoint(function: Function, input: TensorOrTensors) -> TensorOrTensors:
"""Makes a checkpoint with a simple interface like
:func:`torch.utils.checkpoint.checkpoint`. It's only used to test or debug
:class:`Checkpoint` and :class:`Recompute` without boilerplate.
"""
batch = Batch(input)
chk = Checkpointing(function, batch)
batch = chk.checkpoint()
chk.recompute(batch)
return batch.tensor_or_tensors
class Checkpointing:
"""Generates a pair of :class:`Checkpoint` and :class:`Recompute`."""
......@@ -116,7 +102,7 @@ class Checkpointing:
if isinstance(output, tuple):
output = tuple([x if x.is_floating_point() else x.detach() for x in output])
return Batch(output)
return Batch(output, self.batch.index)
def recompute(self, batch: Batch) -> None:
"""Applies :class:`Recompute` to the batch in place."""
......@@ -226,6 +212,7 @@ def save_rng_states(device: torch.device, rng_states: Deque[RNGStates],) -> None
else:
gpu_rng_state = None
rng_states.clear()
rng_states.append((cpu_rng_state, gpu_rng_state))
......@@ -237,7 +224,7 @@ def restore_rng_states(device: torch.device, rng_states: Deque[RNGStates],) -> G
.. seealso:: :ref:`Referential Transparency`
"""
cpu_rng_state, gpu_rng_state = rng_states.pop()
cpu_rng_state, gpu_rng_state = rng_states[0]
gpu_devices: List[torch.device] = []
if device.type == "cuda":
......
......@@ -53,9 +53,14 @@ class Batch:
"""
def __init__(self, value: TensorOrTensors) -> None:
def __init__(self, value: TensorOrTensors, index: int) -> None:
self.value = value
self.atomic = torch.is_tensor(value)
self.__index = index
@property
def index(self) -> int:
return self.__index
@property
def tensor(self) -> Tensor:
......@@ -80,7 +85,7 @@ class Batch:
"""Calls a function by the underlying tensor or tensors. It also wraps
the output with :class:`Batch`.
"""
return Batch(function(self.value))
return Batch(function(self.value), self.index)
def __repr__(self) -> str:
return f"Batch[atomic={self.atomic!r}]({self.value!r})"
......@@ -176,7 +181,7 @@ def scatter(input: TensorOrTensors, chunks: int) -> List[Batch]:
inputs = zip(*rotated)
return [Batch(x) for x in inputs]
return [Batch(x, i) for i, x in enumerate(inputs)]
def gather(outputs: List[Batch]) -> TensorOrTensors:
......
......@@ -19,18 +19,21 @@
"""The Pipe interface."""
from collections import OrderedDict
from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Tuple, Union, cast
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional, Tuple, Union, cast
import warnings
import torch
from torch import Tensor, nn
import torch.autograd
import torch.cuda
from fairscale.nn.model_parallel import get_model_parallel_group, get_pipeline_parallel_group
from . import microbatch
from .batchnorm import DeferredBatchNorm
from .pipeline import Pipeline
from .skip.layout import inspect_skip_layout
from .skip.skippable import verify_skippables
from .pipeline import Pipeline, PipelineStyle
from .skip.layout import SkipLayout, inspect_skip_layout
from .skip.skippable import Skippable, verify_skippables
from .stream import AbstractStream, new_stream
__all__ = ["Pipe"]
......@@ -42,6 +45,8 @@ Devices = Union[Iterable[Device], List[Device]]
Tensors = Tuple[Tensor, ...]
TensorOrTensors = Union[Tensor, Tensors]
ListOfLazyModules = List[Callable[[], nn.Module]]
if TYPE_CHECKING:
Module = nn.Module[TensorOrTensors]
NamedModules = OrderedDict[str, Module]
......@@ -69,17 +74,34 @@ naive automatic balancing:
"""
def verify_module(module: nn.Sequential) -> None:
if not isinstance(module, nn.Sequential):
raise TypeError("module must be nn.Sequential to be partitioned")
# FIXME(tom) make this a valid way to call
def verify_list_of_callable(module: Union[nn.Sequential, list]) -> None:
for layer in module:
if isinstance(layer, nn.Module):
pass
elif callable(layer):
pass
else:
raise TypeError(f"layer {type(layer)} must be nn.Module or callable to be partitioned")
def verify_module(module: Union[nn.Sequential, ListOfLazyModules]) -> None:
if isinstance(module, Iterable) and not isinstance(module, nn.Sequential):
verify_list_of_callable(module)
else:
if not isinstance(module, nn.Sequential):
raise TypeError("module must be nn.Sequential to be partitioned")
named_children = list(module.named_children())
if len(named_children) != len(module):
raise ValueError("module with duplicate children is not supported")
named_children = list(module.named_children())
if len(named_children) != len(module):
raise ValueError("module with duplicate children is not supported")
def verify_splitting(
module: nn.Sequential, partitions: List[nn.Sequential], balance: Iterable[int], devices: List[torch.device]
module: nn.Sequential,
partitions: List[nn.Sequential],
balance: Iterable[int],
devices: Optional[List[torch.device]],
) -> None:
num_parameters = len(list(module.parameters()))
num_child_parameters = sum(len(list(child.parameters())) for child in module.children())
......@@ -90,7 +112,7 @@ def verify_splitting(
for j in range(i + 1, len(partitions)):
parti = partitions[i]
partj = partitions[j]
if devices[i] == devices[j]:
if devices and devices[i] == devices[j]:
continue
for p in parti.parameters():
for q in partj.parameters():
......@@ -102,9 +124,65 @@ class BalanceError(ValueError):
pass
def check_balance(module: Any, balance: Iterable[int]) -> None:
if len(module) != sum(balance):
raise BalanceError(
f"module and sum of balance have different length (module: {len(module)}, sum of balance: {sum(balance)})"
)
if any(x <= 0 for x in balance):
raise BalanceError(f"all balance numbers must be positive integer (balance: {balance})")
def instantiate_partition(
module: Union[nn.Sequential, ListOfLazyModules], balance: Iterable[int], group: torch.distributed.ProcessGroup
) -> nn.Sequential:
balance = list(balance)
check_balance(module, balance)
layers: NamedModules = OrderedDict()
j = 0
def maybe_realize(layer: Any) -> nn.Module:
if isinstance(layer, nn.Module):
return layer
elif callable(layer):
return layer()
else:
raise TypeError(f"layer must be nn.Module or callable, is {type(layer)}")
def iterate_module(module: Union[nn.Sequential, list]) -> Iterable[Tuple[Any, nn.Module]]:
if isinstance(module, nn.Sequential):
yield from module.named_children()
else:
yield from enumerate(module)
for name, layer in iterate_module(module):
layers[name] = layer
if len(layers) == balance[j]:
if j == group.rank():
for key in layers:
layers[key] = maybe_realize(layers[key])
if not isinstance(module, nn.Sequential):
for layer in layers.values():
if isinstance(layer, Skippable):
raise ValueError("Can't use Skippable layers with multi-process pipe and lazy construction")
partition = nn.Sequential(*layers.values())
return partition
# Prepare for the next partition.
layers.clear()
j += 1
raise ValueError("Souldn't get here, more ranks than partitions")
def split_module(
module: nn.Sequential, balance: Iterable[int], devices: List[torch.device],
) -> Tuple[List[nn.Sequential], List[int], List[torch.device]]:
module: nn.Sequential, balance: Iterable[int], devices: Optional[List[torch.device]],
) -> Tuple[List[nn.Sequential], List[int], Optional[List[torch.device]]]:
"""Splits a module into multiple partitions.
Returns:
......@@ -123,18 +201,11 @@ def split_module(
"""
balance = list(balance)
if len(module) != sum(balance):
raise BalanceError(
"module and sum of balance have different length "
f"(module: {len(module)}, sum of balance: {sum(balance)})"
)
if any(x <= 0 for x in balance):
raise BalanceError(f"all balance numbers must be positive integer (balance: {balance})")
check_balance(module, balance)
if len(balance) > len(devices):
if devices and len(balance) > len(devices):
raise IndexError(
"too few devices to hold given partitions " f"(devices: {len(devices)}, partitions: {len(balance)})"
f"too few devices to hold given partitions (devices: {len(devices)}, partitions: {len(balance)})"
)
j = 0
......@@ -148,8 +219,9 @@ def split_module(
# Group buffered layers as a partition.
partition = nn.Sequential(layers)
device = devices[j]
partition.to(device)
if devices:
device = devices[j]
partition.to(device)
partitions.append(partition)
......@@ -158,12 +230,13 @@ def split_module(
j += 1
partitions = cast(List[nn.Sequential], nn.ModuleList(partitions))
del devices[j:]
if devices:
del devices[j:]
return partitions, balance, devices
MOVING_DENIED = TypeError("denied to move parameters and buffers, " "because Pipe should manage device placement")
MOVING_DENIED = TypeError("denied to move parameters and buffers, because Pipe should manage device placement")
class Pipe(Module):
......@@ -193,8 +266,23 @@ class Pipe(Module):
list of number of layers in each partition
Keyword Args:
style (PipelineStyle):
whether to use a single process for all pipeline stages or to assign
one stage per process
devices (iterable of devices):
devices to use (default: all CUDA devices)
group (ProcessGroup):
specific to `style=MultiProcess`, the process group that all
pipeline stages are a member of. Defaults to
`get_pipeline_parallel_group()`
        worker_map (Dict[int, str]):
            a map from global rank (i.e. `torch.distributed.get_rank()`) to
            worker name (the first argument to `torch.distributed.rpc.init_rpc`),
            needed so that pipeline stages can communicate with each other
input_device (device):
the device on which tensors should be located before being passed to
the first module in a given pipeline stage
chunks (int):
number of micro-batches (default: ``1``)
checkpoint (str):
......@@ -204,6 +292,16 @@ class Pipe(Module):
whether to use deferred BatchNorm moving statistics (default:
:data:`False`, see :ref:`Deferred Batch Normalization` for more
details)
pipelined_backward (bool, optional):
if True, call torch.autograd.backward once per microbatch on the
backward pass (instead of once for the whole batch). This works
around a potential deadlock in pytorch when using tensor parallelism
at the same time. Defaults to `True` if
            `get_model_parallel_group().size() > 1`
            (default: `None`)
        retain_graph (bool):
            The value passed to `torch.autograd.backward(..., retain_graph=<value>)`
            (default: `False`)
Raises:
TypeError:
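To illustrate how the keyword arguments documented above fit together, here is a compact, hedged sketch of a multi-process construction. It assumes the per-rank torch.distributed, rpc and initialize_model_parallel setup shown after the commit message; the module, balance and worker names are placeholders rather than part of this change.

import torch
import torch.nn as nn

from fairscale.nn import Pipe
from fairscale.nn.model_parallel import get_pipeline_parallel_group

pipe = Pipe(
    nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 4)),
    balance=[2, 1],
    style=Pipe.MultiProcess,
    group=get_pipeline_parallel_group(),      # also the default when group is None
    worker_map={0: "worker0", 1: "worker1"},  # global rank -> rpc worker name
    input_device=torch.cuda.current_device(),
    chunks=2,
    checkpoint="except_last",
    pipelined_backward=None,  # resolved at construction from the model-parallel group size
    retain_graph=False,
).cuda()

With the default style=Pipe.SingleProcess the existing devices argument is used instead, and input_device does not apply.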
......@@ -215,6 +313,9 @@ class Pipe(Module):
"""
SingleProcess: PipelineStyle = PipelineStyle.SingleProcess
MultiProcess: PipelineStyle = PipelineStyle.MultiProcess
#: The number of layers in each partition.
balance: List[int] = []
# ^^
......@@ -234,7 +335,7 @@ class Pipe(Module):
#: output = pipe(input)
#: loss = F.cross_entropy(output, target)
#:
devices: List[torch.device] = []
devices: Optional[List[torch.device]] = None
#: The number of micro-batches.
chunks: int = 1
......@@ -245,13 +346,19 @@ class Pipe(Module):
def __init__(
self,
module: nn.Sequential,
module: Union[nn.Sequential, ListOfLazyModules],
balance: Optional[Iterable[int]] = None,
*,
style: PipelineStyle = PipelineStyle.SingleProcess,
devices: Optional[Devices] = None,
group: Optional[torch.distributed.ProcessGroup] = None,
worker_map: Optional[Dict[int, str]] = None,
input_device: Union[None, int, str, torch.device] = None,
chunks: int = chunks,
checkpoint: str = checkpoint,
deferred_batch_norm: bool = False,
pipelined_backward: bool = None,
retain_graph: bool = False,
) -> None:
super().__init__()
......@@ -269,36 +376,110 @@ class Pipe(Module):
# Verify if the underlying skippable modules satisfy integrity. The
# integrity can be verified before forward() because it is static.
verify_skippables(module)
if isinstance(module, nn.Sequential):
verify_skippables(module)
self.chunks = chunks
self.checkpoint = checkpoint
self.pipelined_backward = pipelined_backward
self.retain_graph = retain_graph
self.pipeline: Optional[Pipeline]
if deferred_batch_norm:
module = DeferredBatchNorm.convert_deferred_batch_norm(module, chunks)
if style is PipelineStyle.SingleProcess:
module = cast(nn.Sequential, module)
if deferred_batch_norm:
module = DeferredBatchNorm.convert_deferred_batch_norm(module, chunks)
if devices is None:
devices = range(torch.cuda.device_count())
devices = [torch.device(d) for d in devices]
devices = cast(List[torch.device], devices)
if input_device is not None:
raise ValueError("'input_device' argument only applies to 'PipelineStyle.MultiProcess'")
try:
self.partitions, self.balance, self.devices = split_module(module, balance, devices)
except BalanceError as exc:
raise ValueError(recommend_auto_balance(str(exc)))
if devices is None:
devices = range(torch.cuda.device_count())
verify_splitting(module, self.partitions, self.balance, self.devices)
devices = [torch.device(d) for d in devices]
devices = cast(List[torch.device], devices)
self._copy_streams: List[List[AbstractStream]] = []
self._skip_layout = inspect_skip_layout(self.partitions)
try:
self.partitions, self.balance, self.devices = split_module(module, balance, devices)
except BalanceError as exc:
raise ValueError(recommend_auto_balance(str(exc)))
verify_splitting(module, self.partitions, self.balance, self.devices)
# Separate CUDA streams for copy.
copy_streams = self._ensure_copy_streams()
self._skip_layout = inspect_skip_layout(self.partitions)
elif style is PipelineStyle.MultiProcess:
if group is None:
group = get_pipeline_parallel_group()
if devices is not None:
raise ValueError("'devices' argument only applies to 'PipelineStyle.SingleProcess'")
self.balance = list(balance)
if group.size() < len(self.balance):
raise IndexError(
f"too few ranks to hold given partitions (ranks: {group.size()}, partitions: {len(self.balance)})"
)
try:
rank = torch.distributed.get_rank(group)
if rank >= len(self.balance):
warnings.warn("More ranks than partitions, some ranks unused")
self.partitions = cast(List[nn.Sequential], nn.ModuleList([nn.Sequential()]))
else:
partition = instantiate_partition(module, balance, group)
if deferred_batch_norm:
partition = DeferredBatchNorm.convert_deferred_batch_norm(partition, chunks)
self.partitions = cast(List[nn.Sequential], nn.ModuleList([partition]))
self.devices = None
if isinstance(module, nn.Sequential):
local_partitions, _, _ = split_module(module, balance, None)
self._skip_layout = inspect_skip_layout(local_partitions)
else:
self._skip_layout = SkipLayout(len(module), {}) # FIXME(tom)
except BalanceError as exc:
raise ValueError(recommend_auto_balance(str(exc)))
self.group = group
self.worker_map = worker_map
self.input_device = input_device
self._copy_streams: List[List[AbstractStream]] = []
# The micro-batch index where the checkpointing stops.
checkpoint_stop = {"always": self.chunks, "except_last": self.chunks - 1, "never": 0}[self.checkpoint]
self.pipeline = Pipeline(self.partitions, self.devices, copy_streams, self._skip_layout, checkpoint_stop)
if style is PipelineStyle.SingleProcess:
# Separate CUDA streams for copy.
copy_streams = self._ensure_copy_streams()
if self.pipelined_backward is None:
self.pipelined_backward = False
self.pipeline = Pipeline(
self.partitions, self.devices, copy_streams, self._skip_layout, checkpoint_stop, style=style
)
elif style is PipelineStyle.MultiProcess:
rank = torch.distributed.get_rank(group)
if rank >= len(self.balance):
self.pipeline = None
else:
self.final_stage = rank == len(self.balance) - 1
self.pipeline = Pipeline(
self.partitions,
None,
None,
self._skip_layout,
checkpoint_stop,
style=style,
group=self.group,
worker_map=self.worker_map,
input_device=self.input_device,
)
del module
if self.pipelined_backward is None:
if get_model_parallel_group().size() > 1:
self.pipelined_backward = True
else:
self.pipelined_backward = False
def __len__(self) -> int:
"""Counts the length of the underlying sequential module."""
......@@ -333,10 +514,17 @@ class Pipe(Module):
# Pipe should manage the device of each partition.
# Deny cuda(), cpu(), and to() with device, by TypeError.
def cuda(self, device: Optional[Device] = None) -> "Pipe":
raise MOVING_DENIED
if self.devices:
raise MOVING_DENIED
if device:
return super().cuda(device=device)
else:
return super().cuda()
def cpu(self) -> "Pipe":
raise MOVING_DENIED
if self.devices:
raise MOVING_DENIED
return super().cpu()
def to(self, *args: Any, **kwargs: Any) -> "Pipe":
# Deny these usages:
......@@ -348,15 +536,16 @@ class Pipe(Module):
#
# - to(dtype[, non_blocking])
#
if "device" in kwargs or "tensor" in kwargs:
raise MOVING_DENIED
if args:
if isinstance(args[0], (torch.device, int, str)):
raise MOVING_DENIED
if torch.is_tensor(args[0]):
if self.devices:
if "device" in kwargs or "tensor" in kwargs:
raise MOVING_DENIED
if args:
if isinstance(args[0], (torch.device, int, str)):
raise MOVING_DENIED
if torch.is_tensor(args[0]):
raise MOVING_DENIED
return super().to(*args, **kwargs)
def _ensure_copy_streams(self) -> List[List[AbstractStream]]:
......@@ -368,6 +557,7 @@ class Pipe(Module):
"""
if not self._copy_streams:
assert self.devices is not None
for device in self.devices:
self._copy_streams.append([new_stream(device) for _ in range(self.chunks)])
......@@ -392,16 +582,78 @@ class Pipe(Module):
"""
microbatch.check(input)
if not self.devices:
if not self.group and not self.devices:
# Empty sequential module is not illegal.
return input
if not self.pipeline:
# No pipeline is not illegal, more ranks than partitions
return input
# Divide a mini-batch into micro-batches.
batches = microbatch.scatter(input, self.chunks)
# Run pipeline parallelism.
self.pipeline.run(batches)
# Merge the micro-batches into one mini-batch.
output = microbatch.gather(batches)
return output
if self.group and not self.final_stage:
# Don't merge micro-batches to avoid unnecessary edges in autograd
# graph
# FIXME(tom) should figure out a proper type here
return batches # type: ignore
else:
# Merge the micro-batches into one mini-batch.
if self.pipelined_backward:
with torch.no_grad():
output = microbatch.gather(batches)
from .phony import get_phony
phony = get_phony(torch.device(torch.cuda.current_device()), requires_grad=True)
output = PipelinedBackwardPass.apply(output, batches, phony, True) # self.retain_graph)
else:
output = microbatch.gather(batches)
return output
def back_helper(self, output: List[microbatch.Batch]) -> None:
if self.final_stage:
raise ValueError("back_helper should only be called on non-final stages")
if self.pipeline:
self.pipeline.back_helper(list(reversed(output)))
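The forward()/back_helper() contract above translates into a per-rank training step along these lines, modeled on the updated benchmark train loop; pipe, batch, criterion, optimizer and vocab_size are assumed to exist, and the calling rank is assumed to own a pipeline stage.

optimizer.zero_grad()
output = pipe(batch["input"].cuda())
if pipe.final_stage:
    # Only the last stage sees the gathered output and computes the loss.
    target = batch["target"].cuda()
    loss = criterion(output.view(-1, vocab_size), target.view(-1))
    loss.backward()
else:
    # Earlier stages get their un-gathered micro-batches back and must
    # explicitly kick off their part of the backward pass.
    pipe.back_helper(output)
optimizer.step()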
class PipelinedBackwardPass(torch.autograd.Function):
@staticmethod
# type: ignore
def forward(ctx, input: TensorOrTensors, batches, phony, retain_graph) -> TensorOrTensors:
ctx.batches = batches
ctx.retain_graph = retain_graph
return input
@staticmethod
# type: ignore
def backward(ctx, *grads) -> Tuple:
with torch.no_grad():
grad_batches = microbatch.scatter(grads, len(ctx.batches))
for grad, batch in reversed(list(zip(grad_batches, ctx.batches))):
for t in batch:
t.retain_grad()
torch.autograd.backward(batch.tensor_or_tensors, grad_tensors=(*grad,), retain_graph=ctx.retain_graph)
with torch.no_grad():
if ctx.batches[0].atomic:
tensors = tuple(b.tensor.grad for b in ctx.batches)
output: TensorOrTensors = torch.cat(tensors)
else:
rotated = [[t.grad for t in b.tensors] for b in ctx.batches]
output_buf = []
for tensors in zip(*rotated):
output_buf.append(torch.cat(tensors))
output = tuple(output_buf)
del ctx.batches
return (output, None, None, None)
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
......@@ -18,18 +17,27 @@
# limitations under the License.
"""The pipeline parallelism of Pipe."""
from enum import Enum, auto
import os
import pickle
from queue import Empty as QueueEmpty
from queue import Queue
from types import TracebackType
from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple, Type, Union, cast
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, Type, Union, cast
from dataclasses import dataclass
import numpy as np
import torch
from torch import Tensor, nn
from torch.autograd.profiler import record_function
from fairscale.nn.model_parallel import get_pipeline_parallel_ranks
from .checkpoint import Checkpointing
from .copy import Copy, Wait
from .dependency import fork, join
from .microbatch import Batch
from .skip import Namespace
from .skip.layout import SkipLayout
from .skip.tracker import SkipTrackerThroughPotals, use_skip_tracker
from .stream import AbstractStream, current_stream, use_device
......@@ -41,8 +49,229 @@ __all__: List[str] = []
Tensors = Tuple[Tensor, ...]
TensorOrTensors = Union[Tensor, Tensors]
InputDevice = Union[None, int, str, torch.device]
Schedule = List[Tuple[int, int]]
ExcInfo = Tuple[Type[BaseException], BaseException, TracebackType]
MessageQueues: List[Queue] = [Queue(), Queue(), Queue()]
ACTIVATIONS_GRADS_QUEUE = 0
SKIP_TENSOR_QUEUE = 1
PORTAL_QUEUE = 2
MESSAGE_GENERATION_START = 3
# FIXME Why is 256 ok for training but not for tests?
MESSAGE_TENSOR_SIZE = 512 # 256
MessageGeneration = MESSAGE_GENERATION_START
class PipelineStyle(Enum):
SingleProcess = auto()
MultiProcess = auto()
@dataclass(frozen=True)
class TransportConfig:
use_rpc: bool
worker_map: Optional[Dict[int, str]]
@dataclass
class PipeMessage:
src: int
dest: int
queue_name: int
args: Any
tensors: Tensors
tensor_shapes: List[torch.Size]
tensor_dtypes: List[torch.dtype]
tag: int = 0
def __init__(self, src: int, dest: int, queue_name: int, args: Any, tensors: Tensors):
self.src = src
self.dest = dest
self.queue_name = queue_name
self.args = args
self.tensors = tensors
global MessageGeneration
self.tag = MessageGeneration
MessageGeneration += len(tensors)
def rpc_push_queue(message: PipeMessage) -> None:
globals()["MessageQueues"][message.queue_name].put(message)
def pyobject_to_tensor(obj: Any) -> Tensor:
pickled = pickle.dumps(obj)
nparray = np.frombuffer(pickled, dtype=np.uint8).copy()
nparray.setflags(write=True)
result = torch.from_numpy(nparray)
delta = MESSAGE_TENSOR_SIZE - len(result)
if delta < 0:
raise ValueError(
f"message too big to send, increase MESSAGE_TENSOR_SIZE? - {len(result)} > {MESSAGE_TENSOR_SIZE}"
)
elif delta > 0:
result = torch.cat((result, torch.zeros(delta, dtype=torch.uint8)))
return result.cuda()
def tensor_to_pyobject(tensor: Tensor) -> Any:
nparray = tensor.numpy()
return pickle.loads(nparray.tobytes())
def send_message(config: TransportConfig, message: PipeMessage, sync: bool = False) -> None:
if config.use_rpc:
message.tensors = tuple(t.cpu() for t in message.tensors)
assert config.worker_map
name = config.worker_map[message.dest]
if sync:
torch.distributed.rpc.rpc_sync(name, rpc_push_queue, args=(message,))
else:
torch.distributed.rpc.rpc_async(name, rpc_push_queue, args=(message,))
else:
tensors = message.tensors
message.tensors = tuple()
message.tensor_shapes = [t.size() for t in tensors]
message.tensor_dtypes = [t.dtype for t in tensors]
torch.cuda.current_stream().synchronize()
torch.distributed.send(pyobject_to_tensor(message), message.dest, tag=0)
for index, t in enumerate(tensors):
if t.device.type == "cpu":
t = t.cuda()
torch.distributed.send(t, message.dest, tag=message.tag + index)
def recv_message(
config: TransportConfig, queue_name: int, *, nowait: bool = False, input_device: InputDevice = None
) -> PipeMessage:
if config.use_rpc:
queue = globals()["MessageQueues"][queue_name]
if nowait:
result = queue.get_nowait()
else:
result = queue.get()
result.tensors = to_input_device(result.tensors, input_device)
return result
else:
# FIXME(handle nowait)
if nowait:
raise QueueEmpty
tensor = torch.empty(MESSAGE_TENSOR_SIZE, dtype=torch.uint8, device=input_device)
torch.distributed.recv(tensor, src=-1, tag=queue_name)
message = tensor_to_pyobject(tensor.cpu())
torch.cuda.current_stream().synchronize()
message_tensors = []
for index, (shape, dtype) in enumerate(zip(message.tensor_shapes, message.tensor_dtypes)):
t = torch.empty(*shape, dtype=dtype, device=input_device)
torch.distributed.recv(t, message.src, tag=message.tag + index)
message_tensors.append(t)
message.tensors = tuple(message_tensors)
torch.cuda.current_stream().synchronize()
return message
def get_out_of_order(config: TransportConfig, queue_name: int, index: int, *, input_device: InputDevice) -> Tensors:
"""Receive a message with a known microbatch index, and handle out-of-order
messages by placing them back on the queue"""
if config.use_rpc:
queue = globals()["MessageQueues"][queue_name]
out_of_order: List[PipeMessage] = []
while True:
message = recv_message(config, queue_name, input_device=input_device)
got_index = message.args
value = message.tensors
if got_index == index:
for b in out_of_order:
queue.put(b)
return value
else:
out_of_order.append(message)
else:
message = recv_message(config, queue_name, input_device=input_device)
assert message.args == index
return message.tensors
def to_input_device(tensors: TensorOrTensors, input_device: InputDevice) -> TensorOrTensors:
if input_device is None:
return tensors
else:
if isinstance(tensors, Tensor):
return tensors.to(input_device)
else:
return tuple(t.to(input_device) for t in tensors)
class SendOperator(torch.autograd.Function):
"""Send activations to the next pipeline stage"""
@staticmethod
# type: ignore
def forward(ctx, src_rank, dst_rank, config: TransportConfig, input: List[Tensor], index: int) -> Tensors:
assert src_rank == torch.distributed.get_rank()
send_message(
config,
PipeMessage(src_rank, dst_rank, queue_name=ACTIVATIONS_GRADS_QUEUE, args=index, tensors=tuple(input)),
)
return ()
@staticmethod
# type: ignore
def backward(ctx, *grad: Tensor,) -> Tensors:
return tuple(grad)
class RecvOperator(torch.autograd.Function):
"""Receive activations to the previous pipeline stage"""
@staticmethod
# type: ignore
def forward(ctx, dst_rank: int, tensor: Tensor, input_device, config: TransportConfig, index: int) -> Tensors:
assert dst_rank == torch.distributed.get_rank()
ctx.config = config
ctx.index = index
result = get_out_of_order(config, ACTIVATIONS_GRADS_QUEUE, index, input_device=input_device)
def maybe_requires_grad(t: Tensor) -> Tensor:
if t.dtype.is_floating_point:
return t.requires_grad_()
return t
return tuple(maybe_requires_grad(r) for r in result)
@staticmethod
# type: ignore
def backward(ctx, *grad: Tensor,) -> Tuple[Optional[Tensor], ...]:
ranks = get_pipeline_parallel_ranks()
this_rank = torch.distributed.get_rank()
send_message(
ctx.config,
PipeMessage(
this_rank,
ranks[ranks.index(this_rank) - 1],
queue_name=ACTIVATIONS_GRADS_QUEUE,
args=ctx.index,
tensors=tuple(grad),
),
)
return (None, None, None, None, None)
# Queue is generic only in stubs.
# https://mypy.readthedocs.io/en/latest/common_issues.html#using-classes-that-are-generic-in-stubs-but-not-at-runtime
if TYPE_CHECKING:
......@@ -70,7 +299,7 @@ def wait(batch: Batch, prev_stream: AbstractStream, next_stream: AbstractStream)
batch[:] = tuple([x if x.is_floating_point() else x.detach() for x in batch])
def clock_cycles(m: int, n: int) -> Iterable[List[Tuple[int, int]]]:
def clock_cycles(m: int, n: int) -> Iterable[Schedule]:
"""Generates schedules for each clock cycle."""
# m: number of micro-batches
# n: number of partitions
......@@ -95,22 +324,46 @@ class Pipeline:
def __init__(
self,
partitions: List[nn.Sequential],
devices: List[torch.device],
copy_streams: List[List[AbstractStream]],
devices: Optional[List[torch.device]],
copy_streams: Optional[List[List[AbstractStream]]],
skip_layout: SkipLayout,
checkpoint_stop: int,
style: PipelineStyle,
group: Optional[torch.distributed.ProcessGroup] = None,
worker_map: Optional[Dict[int, str]] = None,
input_device: Union[None, int, str, torch.device] = None,
) -> None:
self.partitions = partitions
self.devices = devices
self.copy_streams = copy_streams
self.skip_layout = skip_layout
self.checkpoint_stop = checkpoint_stop
(self.in_queues, self.out_queues) = create_workers(devices)
self.style = style
self.group = group
self.transport_config = TransportConfig(
use_rpc=("OMPI_COMM_WORLD_RANK" not in os.environ), worker_map=worker_map
)
self.input_device = input_device
self.all_at_once = False
self.callcount = 0
if self.style is PipelineStyle.SingleProcess:
assert self.devices is not None
(self.in_queues, self.out_queues) = create_workers(self.devices)
if (
self.style is PipelineStyle.MultiProcess
and self.transport_config.worker_map is None
and self.transport_config.use_rpc is True
):
raise ValueError("'PipelineStyle.MultiProcess' requires 'worker_map' to be set")
def __del__(self) -> None:
join_workers(self.in_queues, self.out_queues)
if self.style is PipelineStyle.SingleProcess:
join_workers(self.in_queues, self.out_queues)
def run(self, batches: List[Batch]) -> None:
"""Runs pipeline parallelism.
It modifies the given batches in place.
......@@ -118,17 +371,23 @@ class Pipeline:
"""
partitions = self.partitions
devices = self.devices
skip_layout = self.skip_layout
m = len(batches)
n = len(partitions)
skip_trackers = [SkipTrackerThroughPotals(skip_layout) for _ in batches]
skip_trackers = [SkipTrackerThroughPotals(self.skip_layout, i) for i in range(len(batches))]
for schedule in clock_cycles(m, n):
self.fence(batches, schedule, skip_trackers)
if self.style is PipelineStyle.SingleProcess:
for schedule in clock_cycles(m, n):
self.fence(batches, schedule, skip_trackers)
self.compute(batches, schedule, skip_trackers)
elif self.style is PipelineStyle.MultiProcess:
assert self.group
schedule = [(i, self.group.rank()) for i in range(m)]
self.compute(batches, schedule, skip_trackers)
self.callcount += 1
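# Example of the MultiProcess schedule built above (assuming 4 micro-batches on
# the rank-1 stage): [(0, 1), (1, 1), (2, 1), (3, 1)]. Unlike SingleProcess,
# cross-stage ordering comes from the blocking send/recv inside compute(), not
# from clock_cycles().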
def fence(
self, batches: List[Batch], schedule: List[Tuple[int, int]], skip_trackers: List[SkipTrackerThroughPotals],
) -> None:
......@@ -138,6 +397,9 @@ class Pipeline:
copy_streams = self.copy_streams
skip_layout = self.skip_layout
assert copy_streams
assert skip_layout
for i, j in schedule:
# Ensure that batches[i-1] is executed after batches[i] in
# backpropagation by an explicit dependency.
......@@ -154,6 +416,174 @@ class Pipeline:
prev_stream = copy_streams[j - 1][i]
copy(batches[i], prev_stream, next_stream)
def get_batch_from_previous_stage(
self, i: int, skip_trackers: List[SkipTrackerThroughPotals], batches: List[Batch]
) -> Batch:
phony = torch.empty(0, device=self.input_device, requires_grad=True)
result = RecvOperator.apply(torch.distributed.get_rank(), phony, self.input_device, self.transport_config, i)
if len(result) == 1:
batch = Batch(result[0], i)
else:
batch = Batch(result, i)
self.recv_skip_tensors(skip_trackers, batches)
return batch
def send_skip_tensors(
self, this_rank: int, ranks: List[int], batch: Batch, i: int, skip_trackers: List[SkipTrackerThroughPotals]
) -> None:
assert self.group
for next_j, ns, name in self.skip_layout.copy_policy_by_src(self.group.rank()):
life = skip_trackers[i].portals[(ns, name)].tensor_life
loaded = skip_trackers[i].load(batch, ns, name)
if loaded is not None:
tensors = tuple([loaded])
else:
tensors = tuple()
send_message(
self.transport_config,
PipeMessage(
this_rank, ranks[next_j], queue_name=SKIP_TENSOR_QUEUE, args=(i, ns, name, life), tensors=tensors,
),
sync=True,
)
def recv_skip_tensors(self, skip_trackers: List[SkipTrackerThroughPotals], batches: List[Batch]) -> None:
while True:
try:
message = recv_message(
self.transport_config, SKIP_TENSOR_QUEUE, nowait=True, input_device=self.input_device
)
(si, ns, name, life) = message.args
value: Optional[TensorOrTensors] = message.tensors
assert isinstance(value, tuple)
if len(value) == 0:
value = None
else:
assert len(value) == 1
value = value[0]
skip_trackers[si].save(batches[si], ns, name, value)
old_life = skip_trackers[si].portals[(ns, name)].tensor_life
if life != 0:
skip_trackers[si].portals[(ns, name)].tensor_life = life
except QueueEmpty:
break
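# Protocol note for the two helpers above: each skip tensor travels on
# SKIP_TENSOR_QUEUE with args=(micro_batch_index, namespace, name, tensor_life)
# and either exactly one tensor or an empty tuple; the empty tuple is mapped
# back to None before being saved into the destination tracker, and a non-zero
# tensor_life overrides the life of the destination portal.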
def execute_task(self, task: Task, i: int, skip_trackers: List[SkipTrackerThroughPotals]) -> Batch:
batch = task.compute()
assert self.group
rank = self.group.rank()
if rank != self.group.size() - 1:
ranks = get_pipeline_parallel_ranks()
this_rank = torch.distributed.get_rank()
self.send_skip_tensors(this_rank, ranks, batch, i, skip_trackers)
SendOperator.apply(this_rank, ranks[ranks.index(this_rank) + 1], self.transport_config, [*batch], i)
for portal in skip_trackers[i].portals.values():
portal.pipeline = self
task.finalize(batch)
return batch
def finalize_tasks(
self,
n: int,
schedule: Schedule,
streams: List[AbstractStream],
copy_streams: List[List[AbstractStream]],
batches: List[Batch],
) -> None:
exc_info: Optional[ExcInfo] = None
for i, j in schedule:
ok, payload = self.out_queues[j].get()
# Hold the first exception.
if exc_info is not None:
continue
elif not ok:
exc_info = cast(ExcInfo, payload)
continue
task, batch = cast(Tuple[Task, Batch], payload)
# The copy stream synchronizes to copy the output. ([3] in the
# diagram)
if j != n - 1:
wait(batch, streams[j], copy_streams[j][i])
# Finalize tasks. If checkpointing is enabled, here the
# recomputation is scheduled at backpropagation. ([4] in the
# diagram)
assert self.devices
with use_device(self.devices[j]):
task.finalize(batch)
batches[i] = batch
# Fail at the first exception.
if exc_info is not None:
raise exc_info[0].with_traceback(exc_info[1], exc_info[2])
def create_task(
self,
i: int,
j: int,
batch: Batch,
checkpoint_stop: int,
partition: nn.Sequential,
skip_trackers: List[SkipTrackerThroughPotals],
streams: List[AbstractStream],
) -> Task:
# Determine whether to checkpoint this micro-batch.
if i < checkpoint_stop:
def function(
input: TensorOrTensors,
partition: nn.Sequential = partition,
skip_tracker: SkipTrackerThroughPotals = skip_trackers[i],
chunk_id: int = i,
part_id: int = j,
) -> TensorOrTensors:
with use_skip_tracker(skip_tracker), record_function("chunk%d-part%d" % (chunk_id, part_id)):
return partition(input)
chk = Checkpointing(function, batch)
if self.style is PipelineStyle.SingleProcess:
task = Task(streams[j], compute=chk.checkpoint, finalize=chk.recompute)
elif self.style is PipelineStyle.MultiProcess:
task = Task(None, compute=chk.checkpoint, finalize=chk.recompute)
del function, chk # TODO(tom) maybe remove
else:
def compute(
batch: Batch = batch,
partition: nn.Sequential = partition,
skip_tracker: SkipTrackerThroughPotals = skip_trackers[i],
chunk_id: int = i,
part_id: int = j,
) -> Batch:
with use_skip_tracker(skip_tracker), record_function("chunk%d-part%d" % (chunk_id, part_id)):
return batch.call(partition)
if self.style is PipelineStyle.SingleProcess:
task = Task(streams[j], compute=compute, finalize=None)
elif self.style is PipelineStyle.MultiProcess:
task = Task(None, compute=compute, finalize=None)
del compute # TODO(tom) maybe remove
return task
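# Hedged sketch of how checkpoint_stop usually relates to the familiar
# 'always' / 'except_last' / 'never' checkpoint modes (m = number of
# micro-batches); only chunks with i < checkpoint_stop are checkpointed, and
# eval mode forces it to 0 as seen in compute() below. The helper name is
# hypothetical.
#
#     def _checkpoint_stop_for(mode: str, m: int) -> int:
#         return {"always": m, "except_last": m - 1, "never": 0}[mode]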
def compute(
self, batches: List[Batch], schedule: List[Tuple[int, int]], skip_trackers: List[SkipTrackerThroughPotals],
) -> None:
......@@ -167,9 +597,14 @@ class Pipeline:
if not self.partitions[0].training:
checkpoint_stop = 0
n = len(partitions)
streams = [current_stream(d) for d in devices]
exc_info: Optional[ExcInfo] = None
if self.style is PipelineStyle.SingleProcess:
assert devices is not None
n = len(partitions)
streams = [current_stream(d) for d in devices]
elif self.style is PipelineStyle.MultiProcess:
assert self.group
n = self.group.size()
streams = []
# With checkpointing, the autograd graph looks like this diagram:
# ┌─────┸──────┐
......@@ -198,73 +633,103 @@ class Pipeline:
# └─────┰──────┘
for i, j in schedule:
batch = batches[i]
partition = partitions[j]
# Synchronize with the copied input. ([1] in the diagram)
if j != 0:
wait(batch, copy_streams[j][i], streams[j])
# Determine whether checkpointing or not.
checkpoint = i < checkpoint_stop
if checkpoint:
def function(
input: TensorOrTensors,
partition: nn.Sequential = partition,
skip_tracker: SkipTrackerThroughPotals = skip_trackers[i],
chunk_id: int = i,
part_id: int = j,
) -> TensorOrTensors:
with use_skip_tracker(skip_tracker), record_function("chunk%d-part%d" % (chunk_id, part_id)):
return partition(input)
chk = Checkpointing(function, batch)
task = Task(streams[j], compute=chk.checkpoint, finalize=chk.recompute)
del function, chk
else:
def compute(
batch: Batch = batch,
partition: nn.Sequential = partition,
skip_tracker: SkipTrackerThroughPotals = skip_trackers[i],
chunk_id: int = i,
part_id: int = j,
) -> Batch:
with use_skip_tracker(skip_tracker), record_function("chunk%d-part%d" % (chunk_id, part_id)):
return batch.call(partition)
task = Task(streams[j], compute=compute, finalize=None)
del compute
# Compute tasks in parallel. ([2] in the diagram)
self.in_queues[j].put(task)
for i, j in schedule:
ok, payload = self.out_queues[j].get()
# Hold the first exception.
if exc_info is not None:
continue
elif not ok:
exc_info = cast(ExcInfo, payload)
continue
task, batch = cast(Tuple[Task, Batch], payload)
# The copy stream synchronizes to copy the output. ([3] in the
# diagram)
if j != n - 1:
wait(batch, streams[j], copy_streams[j][i])
# Finalize tasks. If checkpointing is enabled, here the
# recomputation is scheduled at backpropagation. ([4] in the
# diagram)
with use_device(devices[j]):
task.finalize(batch)
batches[i] = batch
# Fail at the first exception.
if exc_info is not None:
raise exc_info[0].with_traceback(exc_info[1], exc_info[2])
if self.style is PipelineStyle.SingleProcess:
partition = partitions[j]
# Synchronize with the copied input. ([1] in the diagram)
assert copy_streams
if j != 0:
wait(batch, copy_streams[j][i], streams[j])
elif self.style is PipelineStyle.MultiProcess:
assert len(self.partitions) == 1
partition = self.partitions[0]
assert self.group
if self.group.rank() != 0:
batch = self.get_batch_from_previous_stage(i, skip_trackers, batches)
task = self.create_task(i, j, batch, checkpoint_stop, partition, skip_trackers, streams)
if self.style is PipelineStyle.SingleProcess:
# Compute tasks in parallel. ([2] in the diagram)
self.in_queues[j].put(task)
elif self.style is PipelineStyle.MultiProcess:
batches[i] = self.execute_task(task, i, skip_trackers)
if self.style is PipelineStyle.SingleProcess:
assert copy_streams
self.finalize_tasks(n, schedule, streams, copy_streams, batches)
def send_portal_grad(self, ns_name: Tuple[Namespace, str], index: int, grad: TensorOrTensors) -> None:
dest, src = self.skip_layout.by_ns_name.get(ns_name, (-1, -1))
if dest == src:
return
ranks = get_pipeline_parallel_ranks()
dst_rank = ranks[dest]
if dst_rank == torch.distributed.get_rank():
return
if isinstance(grad, Tensor):
grad = tuple([grad])
send_message(
self.transport_config,
PipeMessage(ranks[src], dst_rank, queue_name=PORTAL_QUEUE, args=(ns_name, index), tensors=grad),
sync=True,
)
def recv_portal_grad(self, expected_ns_name: Tuple[Namespace, str], expected_index: int) -> Tensor:
message = recv_message(self.transport_config, PORTAL_QUEUE, input_device=self.input_device)
(ns_name, index) = message.args
grad = message.tensors
assert len(grad) == 1
result = grad[0]
assert index == expected_index and ns_name == expected_ns_name
return result
def back_helper(self, output: List[Batch]) -> None:
o = list(output)
tensors: Tensors
if self.all_at_once:
# FIXME(tom) allow specifying this branch when constructing Pipe(), add a test
grads = []
for i, batch in enumerate(o):
rank = torch.distributed.get_rank()
found = get_out_of_order(
self.transport_config, ACTIVATIONS_GRADS_QUEUE, i, input_device=self.input_device
)
assert len(found) == 1
grads.append(found[0])
tensors = tuple(x.tensor_or_tensors for x in o) # type: ignore
try:
torch.autograd.backward(tensors, grad_tensors=grads, retain_graph=True)
except Exception as e:
raise RuntimeError("Autograd failed") from e
else:
rank = torch.distributed.get_rank()
for batch in o:
found = get_out_of_order(
self.transport_config, ACTIVATIONS_GRADS_QUEUE, batch.index, input_device=self.input_device
)
if batch.atomic:
tensors = tuple([batch.tensor])
else:
tensors = batch.tensors
if len(found) != len(tensors):
raise RuntimeError("different number of tensors and gradients")
grads = []
final_tensors = []
for i, tensor in enumerate(tensors):
if tensor.requires_grad or getattr(tensor, "grad_fn", None) is not None:
grads.append(found[i])
final_tensors.append(tensor)
try:
torch.autograd.backward(final_tensors, grad_tensors=grads, retain_graph=True)
except Exception as e:
raise RuntimeError("Autograd failed") from e
......@@ -36,19 +36,41 @@ class SkipLayout:
# Skip routes indexed by partition number 'j': [[next_j]: [(prev_j, ns, name), ...], ...]
by_partition: List[List[Tuple[int, Namespace, str]]]
# Skip routes indexed by source partition number 'prev_j': [[prev_j]: [(next_j, ns, name), ...], ...]
by_src_partition: List[List[Tuple[int, Namespace, str]]]
def __init__(self, num_partitions: int, skip_routes: Dict[Tuple[Namespace, str], Tuple[int, int]],) -> None:
# The skip routes are already indexed by 'ns, name'.
self.by_ns_name = skip_routes
# Index skip routes by partition number 'j'.
self.by_partition = [[] for _ in range(num_partitions)]
self.by_src_partition = [[] for _ in range(num_partitions)]
for (ns, name), (prev_j, next_j) in skip_routes.items():
self.by_partition[next_j].append((prev_j, ns, name))
self.by_src_partition[prev_j].append((next_j, ns, name))
for p in self.by_partition:
p.sort()
def copy_policy_by_src(self, prev_j: int) -> Iterable[Tuple[int, Namespace, str]]:
"""Generates skip routes for the given destination partition number.
The skip routes are sorted by source partition number in ascending
order.
Yields:
Each tuple of (source partition number, namespace, name).
"""
for next_j, ns, name in self.by_src_partition[prev_j]:
if prev_j == next_j:
# This skip tensor will be popped at the same partition where
# it is stashed. In this case, copy is not required.
continue
yield (next_j, ns, name)
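# Worked example of the two indices, assuming a Namespace instance `ns` and a
# single skip route stashed on partition 0 and popped on partition 2:
#
#     layout = SkipLayout(num_partitions=3, skip_routes={(ns, "x"): (0, 2)})
#     list(layout.copy_policy(2))         # [(0, ns, "x")]  -- indexed by destination
#     list(layout.copy_policy_by_src(0))  # [(2, ns, "x")]  -- indexed by source
#
# Routes whose source and destination partition coincide are filtered out by
# both policies, since no copy is needed.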
def copy_policy(self, next_j: int) -> Iterable[Tuple[int, Namespace, str]]:
"""Generates skip routes for the given destination partition number.
The skip routes are sorted by source partition number in ascending
......
......@@ -25,11 +25,12 @@ one of the most important feature of :mod:`torchpipe.skip`.
The metaphor is inspired by Portal™ from Valve.
"""
from typing import List, Optional, Tuple
from typing import Any, List, Optional, Tuple
import torch
from torch import Tensor
from . import Namespace
from ..copy import Context as CopyContext
from ..copy import Copy
from ..phony import get_phony
......@@ -41,9 +42,16 @@ __all__: List[str] = []
class Portal:
"""A portal for a tensor."""
def __init__(self, tensor: Optional[Tensor], tensor_life: int) -> None:
def __init__(self, tensor: Optional[Tensor], tensor_life: int, index: int) -> None:
self.put_tensor(tensor, tensor_life)
self.grad: Optional[Tensor] = None
self.__index = index
self.ns_name: Optional[Tuple[Namespace, str]]
self.pipeline: Any
@property
def index(self) -> int:
return self.__index
def blue(self) -> Tensor:
"""Creates a :class:`PortalBlue` which hides the underlying tensor from
......@@ -151,12 +159,17 @@ class Portal:
def put_grad(self, grad: Tensor) -> None:
"""Stores a gradient into this portal."""
if hasattr(self, "pipeline"):
self.pipeline.send_portal_grad(self.ns_name, self.index, grad)
self.grad = grad
def use_grad(self) -> Tensor:
"""Retrieves and removes the underlying gradient. The gradient is
always ephemeral.
"""
if self.grad is None and hasattr(self, "pipeline"):
self.grad = self.pipeline.recv_portal_grad(self.ns_name, self.index)
if self.grad is None:
raise RuntimeError("grad in portal has been removed or never set")
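# Flow sketch for the two hooks above (as wired up in Pipeline.execute_task,
# which sets portal.pipeline): on the stage that computes the gradient,
# put_grad() forwards it to the stashing stage via send_portal_grad(); on the
# stashing stage, use_grad() falls back to recv_portal_grad() when no gradient
# was stored locally, so the skip connection's backward crosses process
# boundaries transparently.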
......
......@@ -204,7 +204,7 @@ class Skippable(nn.Module):
# Load skip tensors that might be popped.
poppable_tensors = {}
batch = Batch(input)
batch = Batch(input, skip_tracker.index)
for ns, name in self.poppable():
try:
poppable_tensors[name] = skip_tracker.load(batch, ns, name)
......@@ -237,7 +237,7 @@ class Skippable(nn.Module):
raise RuntimeError(f"{comma_names} must be popped but have not")
# Save stashed skip tensors.
batch = Batch(output)
batch = Batch(output, skip_tracker.index)
for ns, name in self.stashable():
tensor = stashed_tensors[name]
skip_tracker.save(batch, ns, name, tensor)
......
......@@ -61,6 +61,10 @@ class SkipTracker:
) -> None:
raise TypeError("copy is not supported for non-portal skip tensors")
@property
def index(self) -> int:
return 0
class SkipTrackerThroughPotals(SkipTracker):
"""Tracks saved skip tensors through portals. The skip tensors will be
......@@ -71,10 +75,15 @@ class SkipTrackerThroughPotals(SkipTracker):
"""
def __init__(self, skip_layout: SkipLayout) -> None:
def __init__(self, skip_layout: SkipLayout, index: int) -> None:
super().__init__()
self.skip_layout = skip_layout
self.portals: Dict[Tuple[Namespace, str], Portal] = {}
self.__index = index
@property
def index(self) -> int:
return self.__index
def save(self, batch: Batch, ns: Namespace, name: str, tensor: Optional[Tensor]) -> None:
"""Saves the stashed skip tensor in a portal. The portal is then
......@@ -106,7 +115,9 @@ class SkipTrackerThroughPotals(SkipTracker):
else:
tensor_life = 2 # Delete at [6. PortalOrange.forward]
portal = Portal(tensor, tensor_life)
assert batch.index == self.index
portal = Portal(tensor, tensor_life, batch.index)
portal.ns_name = (ns, name)
self.portals[(ns, name)] = portal
else:
......
......@@ -21,7 +21,7 @@
CPU device.
"""
from contextlib import contextmanager
from typing import Generator, List, Union, cast
from typing import Generator, List, Optional, Union, cast
import torch
......@@ -72,8 +72,12 @@ def use_device(device: torch.device) -> Generator[None, None, None]:
@contextmanager
def use_stream(stream: AbstractStream) -> Generator[None, None, None]:
def use_stream(stream: Optional[AbstractStream]) -> Generator[None, None, None]:
""":func:`torch.cuda.stream` for either CPU or CUDA stream."""
if not stream:
yield
return
if not is_cuda(stream):
yield
return
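# Minimal usage sketch: with the Optional stream, MultiProcess partitions
# (which build no per-device streams) can pass stream=None and use_stream
# degrades to a no-op context manager.
#
#     with use_stream(None):
#         ...  # runs without selecting any CUDA stream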
......@@ -120,7 +124,7 @@ def record_stream(tensor: torch.Tensor, stream: AbstractStream) -> None:
tensor.record_stream(as_cuda(stream))
def is_cuda(stream: AbstractStream) -> bool:
def is_cuda(stream: Optional[AbstractStream]) -> bool:
"""Returns ``True`` if the given stream is a valid CUDA stream."""
return stream is not CPUStream
......