Unverified commit 63f7796a authored by Tom Birch, committed by GitHub

Multi-process pipe (#90)

Adds support for distributing pipeline stages across multiple processes (and therefore multiple machines)
* Adds a style argument to the Pipe constructor, defaulting to PipelineStyle.SingleProcess, but also supporting PipelineStyle.MultiProcess
* Added support for lazy construction of modules (see lazy_construction for an example)
* Added two implementations of inter-process communication: one based on rpc with globally visible queues, one based on send/recv
* Copied all the relevant tests from tests/pipe to tests/pipe_process and modified them to exercise PipelineStyle.MultiProcess
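For orientation, here is a minimal sketch of how the new options fit together, distilled from run_mp_worker in the benchmark below. The layer sizes, balance, chunk count, and worker map are placeholders, and the sketch assumes torch.distributed, torch.distributed.rpc, and initialize_model_parallel have already been set up on every rank (as done in run_worker / bench_mpi):

import torch
import torch.nn as nn
from fairscale.nn import Pipe

# Lazy construction: pass callables instead of modules so each pipeline
# stage only instantiates the layers assigned to it by `balance`.
layers = [
    lambda: nn.Linear(1024, 1024),
    lambda: nn.ReLU(),
    lambda: nn.Linear(1024, 1024),
    lambda: nn.ReLU(),
]

pipe = Pipe(
    layers,
    balance=[2, 2],                            # two layers per stage
    style=Pipe.MultiProcess,                   # new in this change
    worker_map={0: "worker0", 1: "worker1"},   # global rank -> rpc worker name (placeholder)
    input_device=torch.cuda.current_device(),
    chunks=4,
).cuda()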
parent 49a198c9
......@@ -149,7 +149,7 @@ jobs:
- run:
name: Run type-checking (mypy)
command: |
mypy --pretty .
mypy --ignore-missing-imports --scripts-are-modules --pretty .
- <<: *run_flake8
......
[settings]
known_third_party =numpy,pytest,recommonmark,setuptools,torch,torchtext,torchvision
known_third_party =benchmark_dataset,dataclasses,numpy,packaging,pytest,recommonmark,setuptools,torch,torchtext,torchvision
......@@ -37,6 +37,7 @@ repos:
rev: 4.3.20
hooks:
- id: isort
exclude: README.md
additional_dependencies: [toml]
- repo: https://github.com/pre-commit/mirrors-mypy
......
import torch
from torch.utils.data import Dataset
def collate_sentences_lm(samples):
if len(samples) == 0:
return {}
id = torch.LongTensor([s["id"] for s in samples])
src_tokens = torch.stack([s["source"] for s in samples], 0)
tgt_tokens = torch.stack([s["target"] for s in samples], 0)
ntokens = len(samples) * len(samples[0]["target"])
src_lengths = torch.LongTensor([len(samples[0]["source"])] * len(samples))
batch = {
"id": id,
"nsentences": len(samples),
"ntokens": ntokens,
"input": src_tokens,
"target": tgt_tokens,
}
return batch
class BenchmarkLMDataset(Dataset):
"""
    Dataset to benchmark a translation-like seq2seq task.
Args:
vocab_size (int, optional): size of the vocabulary (default 10000).
max_source_positions (int, optional): max number of tokens in the
source sentence (default: 1024).
total_samples (int, optional): the total number of rows in the
dataset (default: 10000).
"""
def __init__(
self, vocab_size=10000, max_source_positions=1024, total_samples=10000,
):
self.vocab_size = vocab_size
self.max_source_positions = max_source_positions
self.total_samples = total_samples
self.sizes = [self.max_source_positions] * self.total_samples
def __getitem__(self, index):
length = self.sizes[index]
source = torch.randint(1, self.vocab_size, (length,))
target = source.clone()
return {
"id": index,
"source": source,
"target": target,
}
def __len__(self):
return self.total_samples
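As a usage note, make_model_and_data in the benchmark below pairs this dataset with a DataLoader via the collate function above; a minimal sketch (the batch size is a placeholder):

from torch.utils.data import DataLoader

dataset = BenchmarkLMDataset()  # defaults: vocab_size=10000, 1024 tokens per sample
loader = DataLoader(dataset, batch_size=8, shuffle=True, num_workers=0, collate_fn=collate_sentences_lm)

batch = next(iter(loader))
# batch["input"] and batch["target"] are (batch_size, 1024) LongTensors of token ids;
# batch["ntokens"] is the number of target tokens in the batch.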
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import argparse
import math
import os
import time
import warnings
from benchmark_dataset import BenchmarkLMDataset, collate_sentences_lm
import torch
from torch.distributed import rpc
import torch.multiprocessing as mp
import torch.nn as nn
from torch.utils.data import DataLoader
import torchtext
from torchtext.data.utils import get_tokenizer
from fairscale.nn import Pipe
from fairscale.nn.model_parallel import initialize_model_parallel
from fairscale.nn.pipe import pipe
from fairscale.optim import GradScaler
from tests.nn.model_parallel.commons import dist_init, get_worker_map
try:
from fairscale.optim import Adam, Precision # type: ignore
from fairscale.optim import Adam # type: ignore
can_benchmark = True
except ImportError:
......@@ -21,6 +31,18 @@ except ImportError:
can_benchmark = False
def init_random_seed(seed: int):
import numpy
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
numpy.random.seed(seed)
PIPE_CHUNKS = 2
iteration_count = 0
class EmbeddingLayer(nn.Embedding):
def __init__(self, ntoken, ninp, initrange):
super().__init__(ntoken, ninp)
......@@ -63,6 +85,11 @@ class TransformerDecoderLayer(nn.TransformerEncoderLayer):
return mask
def forward(self, src):
global iteration_count
iteration_count += 1
# if iteration_count == 196:
# dump_cuda_tensors()
if self.src_mask is None or self.src_mask.size(0) != len(src):
device = src.device
mask = self._generate_square_subsequent_mask(len(src)).to(device)
......@@ -82,16 +109,20 @@ class TransformerLMSequntial(nn.Sequential):
"""A small language model based on the design of GPT-2 using nn.Sequeitnal
for compatability with Pipe"""
def __init__(self, ntokens, ninp, nhead, nhid, dropout, initrange):
super(TransformerLMSequntial, self).__init__(
def __init__(self, ntokens, ninp, nhead, nhid, dropout, initrange, ndecoder):
layers = [
EmbeddingLayer(ntokens, ninp, initrange),
PositionalEncodingLayer(ninp, dropout),
TransformerDecoderLayer(ninp, nhead, nhid, dropout),
LinearLayer(ninp, ntokens, initrange),
)
]
for _ in range(ndecoder):
layers.append(TransformerDecoderLayer(ninp, nhead, nhid, dropout))
layers.append(LinearLayer(ninp, ntokens, initrange))
super(TransformerLMSequntial, self).__init__(*layers)
def get_data(device):
with warnings.catch_warnings(record=True) as fjldska:
TEXT = torchtext.data.Field(
tokenize=get_tokenizer("basic_english"), init_token="<sos>", eos_token="<eos>", lower=True
)
......@@ -99,8 +130,8 @@ def get_data(device):
TEXT.build_vocab(train_txt)
ntokens = len(TEXT.vocab.stoi)
batch_size = 500
eval_batch_size = 200
batch_size = 20
eval_batch_size = 10
train_data = batchify(train_txt, batch_size, TEXT, device)
val_data = batchify(val_txt, eval_batch_size, TEXT, device)
test_data = batchify(test_txt, eval_batch_size, TEXT, device)
......@@ -123,71 +154,188 @@ def get_batch(source, i, bptt):
return data, target
def make_model(device, ntokens):
ninp = 50 # embedding dimension
nhid = 50 # the dimension of the feedforward network model in nn.TransformerEncoder
nhead = 2 # the number of heads in the multiheadattention models
def make_model(args, device, ntokens):
ninp = 2048 # embedding dimension
nhid = 2048 # the dimension of the feedforward network model in nn.TransformerEncoder
nhead = 32 # the number of heads in the multiheadattention models
dropout = 0
initrange = 0.1
model = TransformerLMSequntial(ntokens, ninp, nhead, nhid, dropout, initrange).half().to(device)
balance = generate_balance(min(num_devices, 4), len(model))
p = Pipe(model, balance, chunks=len(balance))
ndecoder = args.num_decoder_layers
if args.lazy_construction:
layers = [
lambda: EmbeddingLayer(ntokens, ninp, initrange),
lambda: PositionalEncodingLayer(ninp, dropout),
]
for _ in range(ndecoder):
layers.append(lambda: TransformerDecoderLayer(ninp, nhead, nhid, dropout))
layers.append(lambda: LinearLayer(ninp, ntokens, initrange))
model = layers
else:
model = TransformerLMSequntial(ntokens, ninp, nhead, nhid, dropout, initrange, ndecoder).to(device)
criterion = nn.CrossEntropyLoss()
lr = 0.001 # learning rate
lr = 0.01 # learning rate
try:
optimizer = Adam(p.parameters(), lr=lr, precision=Precision.PURE_FP16)
except NameError:
optimizer = Adam(p.parameters(), lr=lr)
def make_adam(model):
return Adam(model.parameters(), lr=lr)
optimizer = make_adam
scaler = GradScaler()
return p, criterion, optimizer, scaler
return model, criterion, optimizer, scaler
def get_tensors_by_size_bucket():
import gc
from collections import defaultdict
size_buckets = defaultdict(int)
for obj in gc.get_objects():
if not isinstance(obj, torch.Tensor):
continue
if obj.device.type == "cuda":
size_buckets[(*obj.size(),) + (obj.element_size(),)] += 1
return size_buckets
def dump_size_buckets(size_buckets, prefix=""):
import operator
from functools import reduce
total = 0
for key, value in size_buckets.items():
this = reduce(operator.mul, key) * value
total += this
print(prefix + f"{key} : {value}, {this}")
print(prefix + f"total = {total}")
def train(train_data, model, criterion, optimizer, scaler, bptt, ntokens):
last_size_buckets = None
once = True
def safe_rank():
try:
return torch.distributed.get_rank()
except AssertionError:
return 0
def check_size_buckets():
global last_size_buckets
global once
size_buckets = get_tensors_by_size_bucket()
if last_size_buckets is not None:
if size_buckets != last_size_buckets:
print(f"difference is oustanding tensors: {safe-rank()}")
dump_size_buckets(last_size_buckets, "old: ")
dump_size_buckets(size_buckets, "new: ")
if once:
print(f"dumping buckets for: {safe_rank()}")
dump_size_buckets(last_size_buckets, "old: ")
dump_size_buckets(size_buckets, "new: ")
once = False
else:
print(f"size buckets none on {safe_rank()}")
last_size_buckets = size_buckets
def dump_cuda_tensors():
    print("dumping cuda tensors...")
    from collections import defaultdict
    from functools import reduce
    import operator
    import gc

    # Bucket outstanding CUDA tensors by (shape..., element_size).
    size_buckets = defaultdict(int)
    for obj in gc.get_objects():
        if not isinstance(obj, torch.Tensor):
            continue
        if obj.device.type == "cuda":
            size_buckets[(*obj.size(),) + (obj.element_size(),)] += 1
print(f"outstanding cuda tensors:")
total = 0
for key, value in size_buckets.items():
this = reduce(operator.mul, key) * value
total += this
print(f"{key} : {value}, {this}")
print(f"total size = {total}")
import pprint
pprint.pprint(torch.cuda.memory_stats())
def train(lm_dataloader, model, criterion, optimizer, vocab_size, args):
model.train()
from functools import reduce
import operator
num_params = reduce(operator.add, (reduce(operator.mul, x.size()) for x in model.parameters()))
if model.group:
print(f"training model, #prams = {num_params}, group: {model.group.rank()}, sizes {model.group.size()}")
else:
print(f"training model, #prams = {num_params}")
vocab_size = 10000 # FIXME
total_loss = 0.0
start_time = time.time()
for batch, i in enumerate(range(0, train_data.size(0) - 1, bptt)):
data, targets = get_batch(train_data, i, bptt)
word_counter = 0
optimizer = optimizer(model)
def get_first_device(model):
if model.devices:
return model.devices[0]
else:
return torch.cuda.current_device()
def get_last_device(model):
if model.devices:
return model.devices[-1]
else:
return torch.cuda.current_device()
for i, batch in enumerate(lm_dataloader):
if args.max_batch and i > args.max_batch:
break
optimizer.zero_grad()
output = model(data)
output = output.to(targets.device)
output = model(batch["input"].to(get_first_device(model)))
if model.group is None or model.group.rank() == model.group.size() - 1:
target = batch["target"].to(get_last_device(model))
output = output.to(target.device)
loss = criterion(output.view(-1, vocab_size), target.view(-1))
loss.backward()
else:
model.back_helper(output)
del output
loss = criterion(output.view(-1, ntokens), targets)
scaler.scale(loss).backward()
        scaler.step(optimizer)  # scaler.step() automatically unscales the gradients if they have not been unscaled yet
scaler.update()
torch.nn.utils.clip_grad_value_(model.parameters(), 0.05)
optimizer.step()
if model.group is None or model.group.rank() == model.group.size() - 1:
total_loss += loss.item()
log_interval = 50
if batch % log_interval == 0 and batch > 0:
log_interval = 1
word_counter += batch["ntokens"]
if i % log_interval == 0 and i > 0:
cur_loss = total_loss / log_interval
elapsed = time.time() - start_time
try:
print(
"| {:5d}/{:5d} batches | ms/batch {:5.2f} | "
"loss {:5.2f} | ppl {:8.2f} | grad scale {:3d} | optim scale {:3d}".format(
batch,
len(train_data) // bptt,
elapsed * 1000 / log_interval,
cur_loss,
math.exp(cur_loss),
int(scaler.get_scale()),
int(optimizer._optim_scale),
)
)
except AttributeError:
print(
"| {:5d}/{:5d} batches | ms/batch {:5.2f} | "
"loss {:5.2f} | ppl {:8.2f}".format(
batch, len(train_data) // bptt, elapsed * 1000 / log_interval, cur_loss, math.exp(cur_loss)
"| batch {:5d} | wps {:5.2f} | loss {:5.2f} | ppl {:8.2f}".format(
i, word_counter / elapsed, cur_loss, math.exp(cur_loss)
)
)
word_counter = 0
total_loss = 0
start_time = time.time()
# if i >= 10:
# break
# torch.cuda.empty_cache()
# check_size_buckets()
def evaluate(eval_model, data_source, criterion, bptt, ntokens):
......@@ -207,7 +355,7 @@ def get_number_of_words(data):
return data.size()[0] * data.size()[1]
def benchmark_language_model(train_data, val_data, test_data, model, criterion, optimizer, scaler, ntokens):
def benchmark_language_model(train_data, val_data, test_data, model, criterion, optimizer, ntokens, args):
epoch = 1
bptt = 35
start_time = time.time()
......@@ -216,9 +364,9 @@ def benchmark_language_model(train_data, val_data, test_data, model, criterion,
print("| start of epoch {:1d}".format(epoch))
print("-" * 110)
epoch_start_time = time.time()
train(train_data, model, criterion, optimizer, scaler, bptt, ntokens)
val_loss = evaluate(model, val_data, criterion, bptt, ntokens)
print("-" * 110)
train(train_data, model, criterion, optimizer, bptt, ntokens, args)
val_loss = 1 # evaluate(model, val_data, criterion, bptt, ntokens)
print("-" * 89)
print(
"| end of epoch {:1d} | time: {:5.2f}s | valid loss {:5.2f} ".format(
epoch, (time.time() - epoch_start_time), val_loss
......@@ -230,8 +378,8 @@ def benchmark_language_model(train_data, val_data, test_data, model, criterion,
nwords = get_number_of_words(train_data) + get_number_of_words(val_data)
wps = nwords / elapsed_time
test_loss = evaluate(model, test_data, criterion, bptt, ntokens)
print("=" * 110)
test_loss = 1 # evaluate(model, test_data, criterion, bptt, ntokens)
print("=" * 89)
print(
"| end of training | test loss {:5.2f} \n| time: {:5.2f}s | words: {:3d} | wps: {:5.2f}".format(
test_loss, elapsed_time, nwords, wps
......@@ -272,13 +420,186 @@ def generate_balance(num_devices, num_layers):
return balance
if __name__ == "__main__":
def make_model_and_data(args, device, new_data: bool = True):
if new_data:
device = torch.device("cuda")
vocab_size = 10000
model, criterion, optimizer, scaler = make_model(args, device, vocab_size)
lm_dataset = BenchmarkLMDataset()
lm_dataloader = DataLoader(
lm_dataset, batch_size=args.batch_size, shuffle=True, num_workers=0, collate_fn=collate_sentences_lm
)
return {
"model": model,
"criterion": criterion,
"optimizer": optimizer,
"data": lm_dataloader,
"vocab_size": vocab_size,
}
else:
device = torch.device("cuda")
data = get_data(device)
ntokens, train_data, val_data, test_data = data
model, criterion, optimizer, scaler = make_model(args, device, ntokens)
return {
"model": model,
"criterion": criterion,
"optimizer": optimizer,
"data": data,
}
def bench_single_process(args):
num_devices = torch.cuda.device_count()
assert num_devices > 0
torch.manual_seed(0)
init_random_seed(0)
device = torch.device("cuda")
ntokens, train_data, val_data, test_data = get_data(device)
model, criterion, optimizer, scaler = make_model(device, ntokens)
benchmark_language_model(train_data, val_data, test_data, model, criterion, optimizer, scaler, ntokens)
new_data = True
blob = make_model_and_data(args, None, new_data=new_data)
model = blob["model"]
balance = generate_balance(min(num_devices, 8), len(model))
p = pipe.Pipe(
model, balance, chunks=args.chunks, pipelined_backward=args.pipelined_backward, checkpoint=args.checkpoint
)
del model
del blob["model"]
if new_data:
train(blob["data"], p, blob["criterion"], blob["optimizer"], blob["vocab_size"], args)
else:
ntokens, train_data, val_data, test_data = blob["data"]
benchmark_language_model(train_data, val_data, test_data, p, criterion, optimizer, ntokens, args)
def run_mp_worker(args, available_workers):
new_data = True
blob = make_model_and_data(args, None, new_data=new_data)
model = blob["model"]
balance = generate_balance(min(available_workers, 8), len(model))
p = pipe.Pipe(
model,
balance,
style=Pipe.MultiProcess,
chunks=args.chunks,
worker_map=get_worker_map(),
input_device=torch.cuda.current_device(),
pipelined_backward=args.pipelined_backward,
checkpoint=args.checkpoint,
).cuda()
if args.all_at_once and p.pipeline:
print(f"running all at once")
p.pipeline.all_at_once = True
if new_data:
train(blob["data"], p, blob["criterion"], blob["optimizer"], blob["vocab_size"], args)
else:
ntokens, train_data, val_data, test_data = blob["data"]
benchmark_language_model(train_data, val_data, test_data, p, criterion, optimizer, ntokens, args)
def run_worker(rank, world_size, args):
if args.world_size != 0:
world_size = args.world_size
dist_init(rank + args.rank_base, world_size, hostname=args.host)
initialize_model_parallel(1, world_size)
init_random_seed(0)
run_mp_worker(args, world_size)
rpc.shutdown()
torch.distributed.destroy_process_group()
def bench_multi_process(args, all_at_once=False):
if args.local_world_size != 0:
world_size = args.local_world_size
else:
world_size = min(torch.cuda.device_count(), 2)
mp.spawn(run_worker, args=(world_size, args), nprocs=world_size, join=True)
best_device_map = {
0: "mlx5_0:1",
1: "mlx5_0:1",
2: "mlx5_1:1",
3: "mlx5_1:1",
4: "mlx5_2:1",
5: "mlx5_2:1",
6: "mlx5_3:1",
7: "mlx5_3:1",
}
def bench_mpi(args):
guess_rank = int(os.environ["OMPI_COMM_WORLD_RANK"])
os.environ["UCX_NET_DEVICES"] = best_device_map[guess_rank]
torch.distributed.init_process_group(backend="mpi")
os.environ["MASTER_ADDR"] = args.host
os.environ["MASTER_PORT"] = "10639"
if args.socket_name:
os.environ["GLOO_SOCKET_IFNAME"] = args.socket_name
os.environ["TP_SOCKET_IFNAME"] = args.socket_name
init_method = f"tcp://{os.environ['MASTER_ADDR']}:{os.environ['MASTER_PORT']}"
rank = torch.distributed.get_rank()
world_size = torch.distributed.get_world_size()
torch.cuda.set_device(rank % torch.cuda.device_count())
rpc.init_rpc(
f"Test{rank}",
rank=rank,
world_size=world_size,
backend=rpc.BackendType.PROCESS_GROUP,
rpc_backend_options=rpc.ProcessGroupRpcBackendOptions(rpc_timeout=20, init_method=init_method),
)
initialize_model_parallel(1, world_size)
init_random_seed(0)
run_mp_worker(args, world_size)
rpc.shutdown()
torch.distributed.destroy_process_group()
parser = argparse.ArgumentParser(description="benchmark")
parser.add_argument("--local-world-size", "-l", type=int, default=0, help="local world size")
parser.add_argument("--world-size", "-w", type=int, default=0, help="world size")
parser.add_argument("--rank-base", "-r", type=int, help="rank base", default=0)
parser.add_argument("--host", "-o", type=str, default="localhost", help="hostname")
parser.add_argument("--no-mpi", action="store_true", default=False, help="disable mpi")
parser.add_argument("--chunks", type=int, default=1, help="number of microbatches per batch")
parser.add_argument("--batch-size", type=int, default=8, help="size of a batch")
parser.add_argument("--all-at-once", action="store_true", default=False, help="do backward pass on whole batch at once")
parser.add_argument("--max-batch", type=int, default=4, help="Max number of batches")
parser.add_argument("--socket-name", type=str, default=None, help="socket ifname for gloo/tp")
parser.add_argument("--num-decoder-layers", type=int, default=10, help="Number of decoder layers in the model")
parser.add_argument(
"--lazy-construction", action="store_true", default=False, help="Number of decoder layers in the model"
)
parser.add_argument(
"--checkpoint", default="never", choices=["always", "except_last", "never"], help="Checkpointing strategy for pipe"
)
parser.add_argument(
"--pipelined-backward", dest="pipelined_backward", action="store_true", help="Pipelined backward pass"
)
parser.add_argument(
"--no-pipelined-backward", dest="pipelined_backward", action="store_false", help="Pipelined backward pass"
)
parser.set_defaults(pipelined_backward=True)
if __name__ == "__main__":
args = parser.parse_args()
# bench_multi_process(args, all_at_once=True)
if args.no_mpi or "OMPI_COMM_WORLD_RANK" not in os.environ:
print(f"Running benchmark with args: {args}")
bench_single_process(args)
else:
if os.environ["OMPI_COMM_WORLD_RANK"] == "0":
print(f"Running benchmark with args: {args}")
bench_mpi(args)
......@@ -13,6 +13,7 @@
#
import os
import sys
from typing import Any, List
# The theme to use for HTML and HTML Help pages. See the documentation for
# a list of builtin themes.
......@@ -46,7 +47,7 @@ templates_path = ["_templates"]
# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
# This pattern also affects html_static_path and html_extra_path.
exclude_patterns = []
exclude_patterns: List[Any] = []
# -- Options for HTML output -------------------------------------------------
......
......@@ -5,6 +5,7 @@
from .cross_entropy import vocab_parallel_cross_entropy
from .initialize import (
destroy_model_parallel,
get_data_parallel_group,
get_data_parallel_rank,
get_data_parallel_world_size,
......@@ -12,6 +13,8 @@ from .initialize import (
get_model_parallel_rank,
get_model_parallel_src_rank,
get_model_parallel_world_size,
get_pipeline_parallel_group,
get_pipeline_parallel_ranks,
initialize_model_parallel,
)
from .layers import ColumnParallelLinear, RowParallelLinear, VocabParallelEmbedding
......
......@@ -35,6 +35,8 @@ _DATA_PARALLEL_GROUP = None
# Pipeline parallel group that the current rank belongs to.
_PIPELINE_PARALLEL_GROUP = None
_PIPELINE_PARALLEL_RANKS = None
def initialize_model_parallel(model_parallel_size_: int, pipeline_length: int = 1) -> None:
"""
......@@ -93,7 +95,15 @@ def initialize_model_parallel(model_parallel_size_: int, pipeline_length: int =
global _PIPELINE_PARALLEL_GROUP
assert _PIPELINE_PARALLEL_GROUP is None, "model parallel group is already initialized"
_PIPELINE_PARALLEL_GROUP = groups[found[0], :, found[2]].tolist()
global _PIPELINE_PARALLEL_RANKS
    assert _PIPELINE_PARALLEL_RANKS is None, "pipeline parallel group is already initialized"
for i in range(data_parallel_size):
for k in range(model_parallel_size):
ranks = groups[i, :, k].tolist()
group = torch.distributed.new_group(ranks)
if i == found[0] and k == found[2]:
_PIPELINE_PARALLEL_GROUP = group
_PIPELINE_PARALLEL_RANKS = ranks
def model_parallel_is_initialized() -> bool:
......@@ -115,12 +125,18 @@ def get_data_parallel_group() -> torch.distributed.ProcessGroup:
return _DATA_PARALLEL_GROUP
def get_pipeline_parallel_group() -> List[int]:
def get_pipeline_parallel_group() -> torch.distributed.ProcessGroup:
"""Get the pipeline parallel group the caller rank belongs to."""
assert _PIPELINE_PARALLEL_GROUP is not None, "pipeline parallel group is not initialized"
return _PIPELINE_PARALLEL_GROUP
def get_pipeline_parallel_ranks() -> List[int]:
"""Get the pipeline parallel group the caller rank belongs to."""
assert _PIPELINE_PARALLEL_RANKS is not None, "pipeline parallel group is not initialized"
return _PIPELINE_PARALLEL_RANKS
def get_model_parallel_world_size() -> int:
"""Return world size for the model parallel group."""
return torch.distributed.get_world_size(group=get_model_parallel_group())
......@@ -157,3 +173,6 @@ def destroy_model_parallel() -> None:
_DATA_PARALLEL_GROUP = None
global _PIPELINE_PARALLEL_GROUP
_PIPELINE_PARALLEL_GROUP = None
global _PIPELINE_PARALLEL_RANKS
_PIPELINE_PARALLEL_RANKS = None
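For reference, a minimal sketch of how the new pipeline-parallel accessors might be queried once torch.distributed is initialized; the initialize_model_parallel call mirrors run_worker in the benchmark above, and the stage-index bookkeeping is illustrative:

import torch.distributed as dist
from fairscale.nn.model_parallel import (
    get_pipeline_parallel_group,
    get_pipeline_parallel_ranks,
    initialize_model_parallel,
)

# Assumes dist.init_process_group(...) has already been called on every rank.
initialize_model_parallel(1, dist.get_world_size())

group = get_pipeline_parallel_group()  # ProcessGroup spanning this rank's pipeline
ranks = get_pipeline_parallel_ranks()  # global ranks of the pipeline stages, in order
stage = ranks.index(dist.get_rank())   # this rank's stage index within the pipeline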
......@@ -280,6 +280,9 @@ class ColumnParallelLinear(torch.nn.Module):
return_master_weight=keep_master_weight_for_test,
)
def get_master_weight(self) -> torch.Tensor:
return gather_from_model_parallel_region(self.weight.data.transpose(0, 1)).transpose_(0, 1)
def forward(self, input_: torch.Tensor) -> torch.Tensor: # type: ignore
# Set up backprop all-reduce.
input_parallel = copy_to_model_parallel_region(input_)
......@@ -364,6 +367,9 @@ class RowParallelLinear(torch.nn.Module):
return_master_weight=keep_master_weight_for_test,
)
def get_master_weight(self) -> torch.Tensor:
return gather_from_model_parallel_region(self.weight.data)
def forward(self, input_: torch.Tensor) -> torch.Tensor: # type:ignore
# Set up backprop all-reduce.
if self.input_is_parallel:
......
......@@ -19,21 +19,27 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
import torch
from .initialize import get_model_parallel_group
from .utils import split_tensor_along_last_dim
def _reduce(input_: torch.Tensor) -> torch.Tensor:
def _reduce(ctx: Any, input_: torch.Tensor) -> torch.Tensor:
"""All-reduce the the input tensor across model parallel group."""
group = get_model_parallel_group()
if ctx:
ctx.mark_dirty(input_)
# Bypass the function if we are using only 1 GPU.
if torch.distributed.get_world_size(group=group) == 1:
return input_
# All-reduce.
print(f"doing all_reduce on {torch.distributed.get_rank()}")
torch.distributed.all_reduce(input_, group=group)
return input_
......@@ -87,11 +93,13 @@ class _CopyToModelParallelRegion(torch.autograd.Function):
@staticmethod
def forward(ctx, input_): # type: ignore
print(f"{torch.distributed.get_rank()}: _CopyToModelParallelRegion Forward")
return input_
@staticmethod
def backward(ctx, grad_output): # type: ignore
return _reduce(grad_output)
print(f"{torch.distributed.get_rank()}: _CopyToModelParallelRegion Backward")
return _reduce(None, grad_output)
class _ReduceFromModelParallelRegion(torch.autograd.Function):
......@@ -99,10 +107,12 @@ class _ReduceFromModelParallelRegion(torch.autograd.Function):
@staticmethod
def forward(ctx, input_): # type: ignore
return _reduce(input_)
print(f"{torch.distributed.get_rank()}: _ReduceFromModelParallelRegion Forward")
return _reduce(ctx, input_)
@staticmethod
def backward(ctx, grad_output): # type: ignore
print(f"{torch.distributed.get_rank()}: _ReduceFromModelParallelRegion Backward")
return grad_output
......@@ -111,10 +121,12 @@ class _ScatterToModelParallelRegion(torch.autograd.Function):
@staticmethod
def forward(ctx, input_): # type: ignore
print(f"{torch.distributed.get_rank()}: _ScatterToModelParallelRegion Forward")
return _split(input_)
@staticmethod
def backward(ctx, grad_output): # type: ignore
print(f"{torch.distributed.get_rank()}: _ScatterToModelParallelRegion Backward")
return _gather(grad_output)
......@@ -123,10 +135,12 @@ class _GatherFromModelParallelRegion(torch.autograd.Function):
@staticmethod
def forward(ctx, input_): # type: ignore
print(f"{torch.distributed.get_rank()}: _GatherFromModelParallelRegion Forward")
return _gather(input_)
@staticmethod
def backward(ctx, grad_output): # type: ignore
print(f"{torch.distributed.get_rank()}: _GatherFromModelParallelRegion Backward")
return _split(grad_output)
......
......@@ -59,7 +59,7 @@ def profile_times(module: nn.Sequential, sample: TensorOrTensors, timeout: float
if any(p.grad is not None for p in module.parameters()):
raise ValueError("some parameter already has gradient")
_batch = Batch(sample)
_batch = Batch(sample, 0)
for i, x in enumerate(_batch):
_batch[i] = x.detach().to(device).requires_grad_(x.requires_grad)
......@@ -101,7 +101,7 @@ def profile_sizes(
if device.type != "cuda":
raise ValueError("size profiler supports only CUDA device")
batch = Batch(input)
batch = Batch(input, 0)
sizes: List[int] = []
latent_scale = batch[0].size(0) / chunks
......
......@@ -5,7 +5,7 @@
# Copyright 2019 Kakao Brain
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
......@@ -74,20 +74,6 @@ class Function(Protocol):
...
def checkpoint(function: Function, input: TensorOrTensors) -> TensorOrTensors:
"""Makes a checkpoint with a simple interface like
:func:`torch.utils.checkpoint.checkpoint`. It's only used to test or debug
:class:`Checkpoint` and :class:`Recompute` without boilerplate.
"""
batch = Batch(input)
chk = Checkpointing(function, batch)
batch = chk.checkpoint()
chk.recompute(batch)
return batch.tensor_or_tensors
class Checkpointing:
"""Generates a pair of :class:`Checkpoint` and :class:`Recompute`."""
......@@ -116,7 +102,7 @@ class Checkpointing:
if isinstance(output, tuple):
output = tuple([x if x.is_floating_point() else x.detach() for x in output])
return Batch(output)
return Batch(output, self.batch.index)
def recompute(self, batch: Batch) -> None:
"""Applies :class:`Recompute` to the batch in place."""
......@@ -226,6 +212,7 @@ def save_rng_states(device: torch.device, rng_states: Deque[RNGStates],) -> None
else:
gpu_rng_state = None
rng_states.clear()
rng_states.append((cpu_rng_state, gpu_rng_state))
......@@ -237,7 +224,7 @@ def restore_rng_states(device: torch.device, rng_states: Deque[RNGStates],) -> G
.. seealso:: :ref:`Referential Transparency`
"""
cpu_rng_state, gpu_rng_state = rng_states.pop()
cpu_rng_state, gpu_rng_state = rng_states[0]
gpu_devices: List[torch.device] = []
if device.type == "cuda":
......
......@@ -53,9 +53,14 @@ class Batch:
"""
def __init__(self, value: TensorOrTensors) -> None:
def __init__(self, value: TensorOrTensors, index: int) -> None:
self.value = value
self.atomic = torch.is_tensor(value)
self.__index = index
@property
def index(self) -> int:
return self.__index
@property
def tensor(self) -> Tensor:
......@@ -80,7 +85,7 @@ class Batch:
"""Calls a function by the underlying tensor or tensors. It also wraps
the output with :class:`Batch`.
"""
return Batch(function(self.value))
return Batch(function(self.value), self.index)
def __repr__(self) -> str:
return f"Batch[atomic={self.atomic!r}]({self.value!r})"
......@@ -176,7 +181,7 @@ def scatter(input: TensorOrTensors, chunks: int) -> List[Batch]:
inputs = zip(*rotated)
return [Batch(x) for x in inputs]
return [Batch(x, i) for i, x in enumerate(inputs)]
def gather(outputs: List[Batch]) -> TensorOrTensors:
......
......@@ -19,18 +19,21 @@
"""The Pipe interface."""
from collections import OrderedDict
from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Tuple, Union, cast
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional, Tuple, Union, cast
import warnings
import torch
from torch import Tensor, nn
import torch.autograd
import torch.cuda
from fairscale.nn.model_parallel import get_model_parallel_group, get_pipeline_parallel_group
from . import microbatch
from .batchnorm import DeferredBatchNorm
from .pipeline import Pipeline
from .skip.layout import inspect_skip_layout
from .skip.skippable import verify_skippables
from .pipeline import Pipeline, PipelineStyle
from .skip.layout import SkipLayout, inspect_skip_layout
from .skip.skippable import Skippable, verify_skippables
from .stream import AbstractStream, new_stream
__all__ = ["Pipe"]
......@@ -42,6 +45,8 @@ Devices = Union[Iterable[Device], List[Device]]
Tensors = Tuple[Tensor, ...]
TensorOrTensors = Union[Tensor, Tensors]
ListOfLazyModules = List[Callable[[], nn.Module]]
if TYPE_CHECKING:
Module = nn.Module[TensorOrTensors]
NamedModules = OrderedDict[str, Module]
......@@ -69,7 +74,21 @@ naive automatic balancing:
"""
def verify_module(module: nn.Sequential) -> None:
# FIXME(tom) make this a valid way to call
def verify_list_of_callable(module: Union[nn.Sequential, list]) -> None:
for layer in module:
if isinstance(layer, nn.Module):
pass
elif callable(layer):
pass
else:
raise TypeError(f"layer {type(layer)} must be nn.Module or callable to be partitioned")
def verify_module(module: Union[nn.Sequential, ListOfLazyModules]) -> None:
if isinstance(module, Iterable) and not isinstance(module, nn.Sequential):
verify_list_of_callable(module)
else:
if not isinstance(module, nn.Sequential):
raise TypeError("module must be nn.Sequential to be partitioned")
......@@ -79,7 +98,10 @@ def verify_module(module: nn.Sequential) -> None:
def verify_splitting(
module: nn.Sequential, partitions: List[nn.Sequential], balance: Iterable[int], devices: List[torch.device]
module: nn.Sequential,
partitions: List[nn.Sequential],
balance: Iterable[int],
devices: Optional[List[torch.device]],
) -> None:
num_parameters = len(list(module.parameters()))
num_child_parameters = sum(len(list(child.parameters())) for child in module.children())
......@@ -90,7 +112,7 @@ def verify_splitting(
for j in range(i + 1, len(partitions)):
parti = partitions[i]
partj = partitions[j]
if devices[i] == devices[j]:
if devices and devices[i] == devices[j]:
continue
for p in parti.parameters():
for q in partj.parameters():
......@@ -102,9 +124,65 @@ class BalanceError(ValueError):
pass
def check_balance(module: Any, balance: Iterable[int]) -> None:
if len(module) != sum(balance):
raise BalanceError(
f"module and sum of balance have different length (module: {len(module)}, sum of balance: {sum(balance)})"
)
if any(x <= 0 for x in balance):
raise BalanceError(f"all balance numbers must be positive integer (balance: {balance})")
def instantiate_partition(
module: Union[nn.Sequential, ListOfLazyModules], balance: Iterable[int], group: torch.distributed.ProcessGroup
) -> nn.Sequential:
balance = list(balance)
check_balance(module, balance)
layers: NamedModules = OrderedDict()
j = 0
def maybe_realize(layer: Any) -> nn.Module:
if isinstance(layer, nn.Module):
return layer
elif callable(layer):
return layer()
else:
raise TypeError(f"layer must be nn.Module or callable, is {type(layer)}")
def iterate_module(module: Union[nn.Sequential, list]) -> Iterable[Tuple[Any, nn.Module]]:
if isinstance(module, nn.Sequential):
yield from module.named_children()
else:
yield from enumerate(module)
for name, layer in iterate_module(module):
layers[name] = layer
if len(layers) == balance[j]:
if j == group.rank():
for key in layers:
layers[key] = maybe_realize(layers[key])
if not isinstance(module, nn.Sequential):
for layer in layers.values():
if isinstance(layer, Skippable):
raise ValueError("Can't use Skippable layers with multi-process pipe and lazy construction")
partition = nn.Sequential(*layers.values())
return partition
# Prepare for the next partition.
layers.clear()
j += 1
raise ValueError("Souldn't get here, more ranks than partitions")
def split_module(
module: nn.Sequential, balance: Iterable[int], devices: List[torch.device],
) -> Tuple[List[nn.Sequential], List[int], List[torch.device]]:
module: nn.Sequential, balance: Iterable[int], devices: Optional[List[torch.device]],
) -> Tuple[List[nn.Sequential], List[int], Optional[List[torch.device]]]:
"""Splits a module into multiple partitions.
Returns:
......@@ -123,18 +201,11 @@ def split_module(
"""
balance = list(balance)
if len(module) != sum(balance):
raise BalanceError(
"module and sum of balance have different length "
f"(module: {len(module)}, sum of balance: {sum(balance)})"
)
check_balance(module, balance)
if any(x <= 0 for x in balance):
raise BalanceError(f"all balance numbers must be positive integer (balance: {balance})")
if len(balance) > len(devices):
if devices and len(balance) > len(devices):
raise IndexError(
"too few devices to hold given partitions " f"(devices: {len(devices)}, partitions: {len(balance)})"
f"too few devices to hold given partitions (devices: {len(devices)}, partitions: {len(balance)})"
)
j = 0
......@@ -148,6 +219,7 @@ def split_module(
# Group buffered layers as a partition.
partition = nn.Sequential(layers)
if devices:
device = devices[j]
partition.to(device)
......@@ -158,12 +230,13 @@ def split_module(
j += 1
partitions = cast(List[nn.Sequential], nn.ModuleList(partitions))
if devices:
del devices[j:]
return partitions, balance, devices
MOVING_DENIED = TypeError("denied to move parameters and buffers, " "because Pipe should manage device placement")
MOVING_DENIED = TypeError("denied to move parameters and buffers, because Pipe should manage device placement")
class Pipe(Module):
......@@ -193,8 +266,23 @@ class Pipe(Module):
list of number of layers in each partition
Keyword Args:
style (PipelineStyle):
whether to use a single process for all pipeline stages or to assign
one stage per process
devices (iterable of devices):
devices to use (default: all CUDA devices)
group (ProcessGroup):
specific to `style=MultiProcess`, the process group that all
pipeline stages are a member of. Defaults to
`get_pipeline_parallel_group()`
worker_map (Dict[int, str]):
            a map from global rank (i.e. `torch.distributed.get_rank()`) to
            worker name (the first argument to `torch.distributed.rpc.init_rpc`),
            needed in order for pipeline stages to communicate with each other
input_device (device):
the device on which tensors should be located before being passed to
the first module in a given pipeline stage
chunks (int):
number of micro-batches (default: ``1``)
checkpoint (str):
......@@ -204,6 +292,16 @@ class Pipe(Module):
whether to use deferred BatchNorm moving statistics (default:
:data:`False`, see :ref:`Deferred Batch Normalization` for more
details)
pipelined_backward (bool, optional):
if True, call torch.autograd.backward once per microbatch on the
backward pass (instead of once for the whole batch). This works
around a potential deadlock in pytorch when using tensor parallelism
            at the same time. If left as `None` (the default), it is set to
            `True` when `get_model_parallel_group().size() > 1` and `False`
            otherwise.
retain_graph (bool):
            The value passed to `torch.autograd.backward(..., retain_graph=<value>)`
            (default: `False`)
Raises:
TypeError:
......@@ -215,6 +313,9 @@ class Pipe(Module):
"""
SingleProcess: PipelineStyle = PipelineStyle.SingleProcess
MultiProcess: PipelineStyle = PipelineStyle.MultiProcess
#: The number of layers in each partition.
balance: List[int] = []
# ^^
......@@ -234,7 +335,7 @@ class Pipe(Module):
#: output = pipe(input)
#: loss = F.cross_entropy(output, target)
#:
devices: List[torch.device] = []
devices: Optional[List[torch.device]] = None
#: The number of micro-batches.
chunks: int = 1
......@@ -245,13 +346,19 @@ class Pipe(Module):
def __init__(
self,
module: nn.Sequential,
module: Union[nn.Sequential, ListOfLazyModules],
balance: Optional[Iterable[int]] = None,
*,
style: PipelineStyle = PipelineStyle.SingleProcess,
devices: Optional[Devices] = None,
group: Optional[torch.distributed.ProcessGroup] = None,
worker_map: Optional[Dict[int, str]] = None,
input_device: Union[None, int, str, torch.device] = None,
chunks: int = chunks,
checkpoint: str = checkpoint,
deferred_batch_norm: bool = False,
pipelined_backward: bool = None,
retain_graph: bool = False,
) -> None:
super().__init__()
......@@ -269,16 +376,26 @@ class Pipe(Module):
# Verify if the underlying skippable modules satisfy integrity. The
# integrity can be verified before forward() because it is static.
if isinstance(module, nn.Sequential):
verify_skippables(module)
self.chunks = chunks
self.checkpoint = checkpoint
self.pipelined_backward = pipelined_backward
self.retain_graph = retain_graph
self.pipeline: Optional[Pipeline]
if style is PipelineStyle.SingleProcess:
module = cast(nn.Sequential, module)
if deferred_batch_norm:
module = DeferredBatchNorm.convert_deferred_batch_norm(module, chunks)
if input_device is not None:
raise ValueError("'input_device' argument only applies to 'PipelineStyle.MultiProcess'")
if devices is None:
devices = range(torch.cuda.device_count())
devices = [torch.device(d) for d in devices]
devices = cast(List[torch.device], devices)
......@@ -286,19 +403,83 @@ class Pipe(Module):
self.partitions, self.balance, self.devices = split_module(module, balance, devices)
except BalanceError as exc:
raise ValueError(recommend_auto_balance(str(exc)))
verify_splitting(module, self.partitions, self.balance, self.devices)
self._copy_streams: List[List[AbstractStream]] = []
self._skip_layout = inspect_skip_layout(self.partitions)
# Separate CUDA streams for copy.
copy_streams = self._ensure_copy_streams()
elif style is PipelineStyle.MultiProcess:
if group is None:
group = get_pipeline_parallel_group()
if devices is not None:
raise ValueError("'devices' argument only applies to 'PipelineStyle.SingleProcess'")
self.balance = list(balance)
if group.size() < len(self.balance):
raise IndexError(
f"too few ranks to hold given partitions (ranks: {group.size()}, partitions: {len(self.balance)})"
)
try:
rank = torch.distributed.get_rank(group)
if rank >= len(self.balance):
warnings.warn("More ranks than partitions, some ranks unused")
self.partitions = cast(List[nn.Sequential], nn.ModuleList([nn.Sequential()]))
else:
partition = instantiate_partition(module, balance, group)
if deferred_batch_norm:
partition = DeferredBatchNorm.convert_deferred_batch_norm(partition, chunks)
self.partitions = cast(List[nn.Sequential], nn.ModuleList([partition]))
self.devices = None
if isinstance(module, nn.Sequential):
local_partitions, _, _ = split_module(module, balance, None)
self._skip_layout = inspect_skip_layout(local_partitions)
else:
self._skip_layout = SkipLayout(len(module), {}) # FIXME(tom)
except BalanceError as exc:
raise ValueError(recommend_auto_balance(str(exc)))
self.group = group
self.worker_map = worker_map
self.input_device = input_device
self._copy_streams: List[List[AbstractStream]] = []
# The micro-batch index where the checkpointing stops.
checkpoint_stop = {"always": self.chunks, "except_last": self.chunks - 1, "never": 0}[self.checkpoint]
self.pipeline = Pipeline(self.partitions, self.devices, copy_streams, self._skip_layout, checkpoint_stop)
if style is PipelineStyle.SingleProcess:
# Separate CUDA streams for copy.
copy_streams = self._ensure_copy_streams()
if self.pipelined_backward is None:
self.pipelined_backward = False
self.pipeline = Pipeline(
self.partitions, self.devices, copy_streams, self._skip_layout, checkpoint_stop, style=style
)
elif style is PipelineStyle.MultiProcess:
rank = torch.distributed.get_rank(group)
if rank >= len(self.balance):
self.pipeline = None
else:
self.final_stage = rank == len(self.balance) - 1
self.pipeline = Pipeline(
self.partitions,
None,
None,
self._skip_layout,
checkpoint_stop,
style=style,
group=self.group,
worker_map=self.worker_map,
input_device=self.input_device,
)
del module
if self.pipelined_backward is None:
if get_model_parallel_group().size() > 1:
self.pipelined_backward = True
else:
self.pipelined_backward = False
def __len__(self) -> int:
"""Counts the length of the underlying sequential module."""
......@@ -333,10 +514,17 @@ class Pipe(Module):
# Pipe should manage the device of each partition.
# Deny cuda(), cpu(), and to() with device, by TypeError.
def cuda(self, device: Optional[Device] = None) -> "Pipe":
if self.devices:
raise MOVING_DENIED
if device:
return super().cuda(device=device)
else:
return super().cuda()
def cpu(self) -> "Pipe":
if self.devices:
raise MOVING_DENIED
return super().cpu()
def to(self, *args: Any, **kwargs: Any) -> "Pipe":
# Deny these usages:
......@@ -348,6 +536,7 @@ class Pipe(Module):
#
# - to(dtype[, non_blocking])
#
if self.devices:
if "device" in kwargs or "tensor" in kwargs:
raise MOVING_DENIED
......@@ -368,6 +557,7 @@ class Pipe(Module):
"""
if not self._copy_streams:
assert self.devices is not None
for device in self.devices:
self._copy_streams.append([new_stream(device) for _ in range(self.chunks)])
......@@ -392,16 +582,78 @@ class Pipe(Module):
"""
microbatch.check(input)
if not self.devices:
if not self.group and not self.devices:
# Empty sequential module is not illegal.
return input
if not self.pipeline:
            # Having no local pipeline is legal: there are more ranks than partitions
return input
# Divide a mini-batch into micro-batches.
batches = microbatch.scatter(input, self.chunks)
# Run pipeline parallelism.
self.pipeline.run(batches)
if self.group and not self.final_stage:
# Don't merge micro-batches to avoid unnecessary edges in autograd
# graph
# FIXME(tom) should figure out a proper type here
return batches # type: ignore
else:
# Merge the micro-batches into one mini-batch.
if self.pipelined_backward:
with torch.no_grad():
output = microbatch.gather(batches)
from .phony import get_phony
phony = get_phony(torch.device(torch.cuda.current_device()), requires_grad=True)
output = PipelinedBackwardPass.apply(output, batches, phony, True) # self.retain_graph)
else:
output = microbatch.gather(batches)
return output
def back_helper(self, output: List[microbatch.Batch]) -> None:
if self.final_stage:
raise ValueError("back_helper should only be called on non-final stages")
if self.pipeline:
self.pipeline.back_helper(list(reversed(output)))
class PipelinedBackwardPass(torch.autograd.Function):
@staticmethod
# type: ignore
def forward(ctx, input: TensorOrTensors, batches, phony, retain_graph) -> TensorOrTensors:
ctx.batches = batches
ctx.retain_graph = retain_graph
return input
@staticmethod
# type: ignore
def backward(ctx, *grads) -> Tuple:
with torch.no_grad():
grad_batches = microbatch.scatter(grads, len(ctx.batches))
for grad, batch in reversed(list(zip(grad_batches, ctx.batches))):
for t in batch:
t.retain_grad()
torch.autograd.backward(batch.tensor_or_tensors, grad_tensors=(*grad,), retain_graph=ctx.retain_graph)
with torch.no_grad():
if ctx.batches[0].atomic:
tensors = tuple(b.tensor.grad for b in ctx.batches)
output: TensorOrTensors = torch.cat(tensors)
else:
rotated = [[t.grad for t in b.tensors] for b in ctx.batches]
output_buf = []
for tensors in zip(*rotated):
output_buf.append(torch.cat(tensors))
output = tuple(output_buf)
del ctx.batches
return (output, None, None, None)
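To show how the multi-process forward/back_helper pair is meant to be driven, here is a hedged sketch modeled on train() in the benchmark above; `pipe`, `batch`, `criterion`, `optimizer`, and `vocab_size` are assumed to exist, and only the final stage computes a loss:

optimizer.zero_grad()
output = pipe(batch["input"].cuda())
if pipe.group is None or pipe.group.rank() == pipe.group.size() - 1:
    # Final pipeline stage: output is the gathered mini-batch, so compute the
    # loss and backpropagate as usual.
    target = batch["target"].cuda()
    loss = criterion(output.view(-1, vocab_size), target.view(-1))
    loss.backward()
else:
    # Non-final stages get the list of micro-batches back from forward() and
    # only need to take part in the backward pass.
    pipe.back_helper(output)
optimizer.step()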
......@@ -36,19 +36,41 @@ class SkipLayout:
# Skip routes indexed by partition number 'j': [[next_j]: [(prev_j, ns, name), ...], ...]
by_partition: List[List[Tuple[int, Namespace, str]]]
    # Skip routes indexed by source partition number 'j': [[prev_j]: [(next_j, ns, name), ...], ...]
by_src_partition: List[List[Tuple[int, Namespace, str]]]
def __init__(self, num_partitions: int, skip_routes: Dict[Tuple[Namespace, str], Tuple[int, int]],) -> None:
# The skip routes are already indexed by 'ns, name'.
self.by_ns_name = skip_routes
# Index skip routes by partition number 'j'.
self.by_partition = [[] for _ in range(num_partitions)]
self.by_src_partition = [[] for _ in range(num_partitions)]
for (ns, name), (prev_j, next_j) in skip_routes.items():
self.by_partition[next_j].append((prev_j, ns, name))
self.by_src_partition[prev_j].append((next_j, ns, name))
for p in self.by_partition:
p.sort()
def copy_policy_by_src(self, prev_j: int) -> Iterable[Tuple[int, Namespace, str]]:
"""Generates skip routes for the given destination partition number.
The skip routes are sorted by source partition number in ascending
order.
Yields:
Each tuple of (source partition number, namespace, name).
"""
for next_j, ns, name in self.by_src_partition[prev_j]:
if prev_j == next_j:
# This skip tensor will be popped at the same partition where
# it is stashed. In this case, copy is not required.
continue
yield (next_j, ns, name)
def copy_policy(self, next_j: int) -> Iterable[Tuple[int, Namespace, str]]:
"""Generates skip routes for the given destination partition number.
The skip routes are sorted by source partition number in ascending
......
......@@ -25,11 +25,12 @@ one of the most important feature of :mod:`torchpipe.skip`.
The metaphor is inspired by Portal™ from Valve.
"""
from typing import List, Optional, Tuple
from typing import Any, List, Optional, Tuple
import torch
from torch import Tensor
from . import Namespace
from ..copy import Context as CopyContext
from ..copy import Copy
from ..phony import get_phony
......@@ -41,9 +42,16 @@ __all__: List[str] = []
class Portal:
"""A portal for a tensor."""
def __init__(self, tensor: Optional[Tensor], tensor_life: int) -> None:
def __init__(self, tensor: Optional[Tensor], tensor_life: int, index: int) -> None:
self.put_tensor(tensor, tensor_life)
self.grad: Optional[Tensor] = None
self.__index = index
self.ns_name: Optional[Tuple[Namespace, str]]
self.pipeline: Any
@property
def index(self) -> int:
return self.__index
def blue(self) -> Tensor:
"""Creates a :class:`PortalBlue` which hides the underlying tensor from
......@@ -151,12 +159,17 @@ class Portal:
def put_grad(self, grad: Tensor) -> None:
"""Stores a gradient into this portal."""
if hasattr(self, "pipeline"):
self.pipeline.send_portal_grad(self.ns_name, self.index, grad)
self.grad = grad
def use_grad(self) -> Tensor:
"""Retrieves and removes the underlying gradient. The gradient is
always ephemeral.
"""
if self.grad is None and hasattr(self, "pipeline"):
self.grad = self.pipeline.recv_portal_grad(self.ns_name, self.index)
if self.grad is None:
raise RuntimeError("grad in portal has been removed or never set")
......
......@@ -204,7 +204,7 @@ class Skippable(nn.Module):
# Load skip tensors that might be popped.
poppable_tensors = {}
batch = Batch(input)
batch = Batch(input, skip_tracker.index)
for ns, name in self.poppable():
try:
poppable_tensors[name] = skip_tracker.load(batch, ns, name)
......@@ -237,7 +237,7 @@ class Skippable(nn.Module):
raise RuntimeError(f"{comma_names} must be popped but have not")
# Save stashed skip tensors.
batch = Batch(output)
batch = Batch(output, skip_tracker.index)
for ns, name in self.stashable():
tensor = stashed_tensors[name]
skip_tracker.save(batch, ns, name, tensor)
......
......@@ -61,6 +61,10 @@ class SkipTracker:
) -> None:
raise TypeError("copy is not supported for non-portal skip tensors")
@property
def index(self) -> int:
return 0
class SkipTrackerThroughPotals(SkipTracker):
"""Tracks saved skip tensors through portals. The skip tensors will be
......@@ -71,10 +75,15 @@ class SkipTrackerThroughPotals(SkipTracker):
"""
def __init__(self, skip_layout: SkipLayout) -> None:
def __init__(self, skip_layout: SkipLayout, index: int) -> None:
super().__init__()
self.skip_layout = skip_layout
self.portals: Dict[Tuple[Namespace, str], Portal] = {}
self.__index = index
@property
def index(self) -> int:
return self.__index
def save(self, batch: Batch, ns: Namespace, name: str, tensor: Optional[Tensor]) -> None:
"""Saves the stashed skip tensor in a portal. The portal is then
......@@ -106,7 +115,9 @@ class SkipTrackerThroughPotals(SkipTracker):
else:
tensor_life = 2 # Delete at [6. PortalOrange.forward]
portal = Portal(tensor, tensor_life)
assert batch.index == self.index
portal = Portal(tensor, tensor_life, batch.index)
portal.ns_name = (ns, name)
self.portals[(ns, name)] = portal
else:
......
......@@ -21,7 +21,7 @@
CPU device.
"""
from contextlib import contextmanager
from typing import Generator, List, Union, cast
from typing import Generator, List, Optional, Union, cast
import torch
......@@ -72,8 +72,12 @@ def use_device(device: torch.device) -> Generator[None, None, None]:
@contextmanager
def use_stream(stream: AbstractStream) -> Generator[None, None, None]:
def use_stream(stream: Optional[AbstractStream]) -> Generator[None, None, None]:
""":func:`torch.cuda.stream` for either CPU or CUDA stream."""
if not stream:
yield
return
if not is_cuda(stream):
yield
return
......@@ -120,7 +124,7 @@ def record_stream(tensor: torch.Tensor, stream: AbstractStream) -> None:
tensor.record_stream(as_cuda(stream))
def is_cuda(stream: AbstractStream) -> bool:
def is_cuda(stream: Optional[AbstractStream]) -> bool:
"""Returns ``True`` if the given stream is a valid CUDA stream."""
return stream is not CPUStream
......