Unverified Commit 63f7796a authored by Tom Birch, committed by GitHub

Multi-process pipe (#90)

Adds support for distributing pipeline stages across multiple processes (and therefore across multiple machines).
* Adds a style argument to the Pipe constructor, defaulting to PipelineStyle.SingleProcess but also supporting PipelineStyle.MultiProcess (see the usage sketch below)
* Adds support for lazy construction of modules, so that each layer is only built on the rank that owns it (see the lazy_construction flag in the benchmark for an example)
* Adds two implementations of inter-process communication: one based on rpc with globally visible queues, and one based on send/recv
* Copies the relevant tests from tests/pipe to tests/pipe_process and modifies them to exercise PipelineStyle.MultiProcess
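A minimal, untested sketch of the multi-process style, pieced together from the benchmark and test changes in this diff. dist_init and get_worker_map are the test helpers the benchmark imports from tests/nn/model_parallel/commons; the layer sizes, balance, chunk count, and world size below are illustrative assumptions, not values from the change itself.

import torch
import torch.multiprocessing as mp
import torch.nn as nn
from torch.distributed import rpc

from fairscale.nn import Pipe
from fairscale.nn.model_parallel import initialize_model_parallel
from tests.nn.model_parallel.commons import dist_init, get_worker_map


def run_worker(rank, world_size):
    # One process per pipeline stage; dist_init sets up torch.distributed and rpc.
    dist_init(rank, world_size)
    initialize_model_parallel(1, world_size)
    torch.cuda.set_device(rank % torch.cuda.device_count())

    # Lazy construction: pass callables instead of modules, so a layer is only
    # materialized on the rank that owns its partition.
    layers = [
        lambda: nn.Linear(32, 32),
        lambda: nn.ReLU(),
        lambda: nn.Linear(32, 32),
        lambda: nn.ReLU(),
    ]
    model = Pipe(
        layers,
        balance=[2, 2],                       # two layers per stage
        style=Pipe.MultiProcess,
        worker_map=get_worker_map(),          # global rank -> rpc worker name
        input_device=torch.cuda.current_device(),
        chunks=2,
    ).cuda()

    output = model(torch.rand(8, 32).cuda())
    if model.final_stage:
        output.sum().backward()
    else:
        # Non-final stages hand their activations back so the backward pass
        # can flow through them.
        model.back_helper(output)

    rpc.shutdown()
    torch.distributed.destroy_process_group()


if __name__ == "__main__":
    world_size = 2
    mp.spawn(run_worker, args=(world_size,), nprocs=world_size, join=True)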
parent 49a198c9
...@@ -149,7 +149,7 @@ jobs: ...@@ -149,7 +149,7 @@ jobs:
- run: - run:
name: Run type-checking (mypy) name: Run type-checking (mypy)
command: | command: |
mypy --pretty . mypy --ignore-missing-imports --scripts-are-modules --pretty .
- <<: *run_flake8 - <<: *run_flake8
......
[settings] [settings]
known_third_party =numpy,pytest,recommonmark,setuptools,torch,torchtext,torchvision known_third_party =benchmark_dataset,dataclasses,numpy,packaging,pytest,recommonmark,setuptools,torch,torchtext,torchvision
...@@ -37,6 +37,7 @@ repos: ...@@ -37,6 +37,7 @@ repos:
rev: 4.3.20 rev: 4.3.20
hooks: hooks:
- id: isort - id: isort
exclude: README.md
additional_dependencies: [toml] additional_dependencies: [toml]
- repo: https://github.com/pre-commit/mirrors-mypy - repo: https://github.com/pre-commit/mirrors-mypy
......
import torch
from torch.utils.data import Dataset
def collate_sentences_lm(samples):
if len(samples) == 0:
return {}
id = torch.LongTensor([s["id"] for s in samples])
src_tokens = torch.stack([s["source"] for s in samples], 0)
tgt_tokens = torch.stack([s["target"] for s in samples], 0)
ntokens = len(samples) * len(samples[0]["target"])
src_lengths = torch.LongTensor([len(samples[0]["source"])] * len(samples))
batch = {
"id": id,
"nsentences": len(samples),
"ntokens": ntokens,
"input": src_tokens,
"target": tgt_tokens,
}
return batch
class BenchmarkLMDataset(Dataset):
"""
Dataset to benchmark a translation like seq2seq task.
Args:
vocab_size (int, optional): size of the vocabulary (default 10000).
max_source_positions (int, optional): max number of tokens in the
source sentence (default: 1024).
total_samples (int, optional): the total number of rows in the
dataset (default: 10000).
"""
def __init__(
self, vocab_size=10000, max_source_positions=1024, total_samples=10000,
):
self.vocab_size = vocab_size
self.max_source_positions = max_source_positions
self.total_samples = total_samples
self.sizes = [self.max_source_positions] * self.total_samples
def __getitem__(self, index):
length = self.sizes[index]
source = torch.randint(1, self.vocab_size, (length,))
target = source.clone()
return {
"id": index,
"source": source,
"target": target,
}
def __len__(self):
return self.total_samples
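For reference, the benchmark below consumes this dataset roughly as follows (a short sketch mirroring make_model_and_data; the batch size is an illustrative value):

from torch.utils.data import DataLoader

from benchmark_dataset import BenchmarkLMDataset, collate_sentences_lm

# Each sample is a dict with "id", "source" and "target"; collate_sentences_lm
# stacks a list of samples into a batch dict with "input", "target", "ntokens",
# "nsentences" and "id".
dataset = BenchmarkLMDataset(vocab_size=10000, max_source_positions=1024, total_samples=10000)
loader = DataLoader(dataset, batch_size=8, shuffle=True, num_workers=0, collate_fn=collate_sentences_lm)

batch = next(iter(loader))
tokens = batch["input"]    # LongTensor of shape (batch_size, max_source_positions)
targets = batch["target"]  # same shape; the target equals the source sequence here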
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import argparse
import math import math
import os
import time import time
import warnings
from benchmark_dataset import BenchmarkLMDataset, collate_sentences_lm
import torch import torch
from torch.distributed import rpc
import torch.multiprocessing as mp
import torch.nn as nn import torch.nn as nn
from torch.utils.data import DataLoader
import torchtext import torchtext
from torchtext.data.utils import get_tokenizer from torchtext.data.utils import get_tokenizer
from fairscale.nn import Pipe from fairscale.nn import Pipe
from fairscale.nn.model_parallel import initialize_model_parallel
from fairscale.nn.pipe import pipe
from fairscale.optim import GradScaler from fairscale.optim import GradScaler
from tests.nn.model_parallel.commons import dist_init, get_worker_map
try: try:
from fairscale.optim import Adam, Precision # type: ignore from fairscale.optim import Adam # type: ignore
can_benchmark = True can_benchmark = True
except ImportError: except ImportError:
...@@ -21,6 +31,18 @@ except ImportError: ...@@ -21,6 +31,18 @@ except ImportError:
can_benchmark = False can_benchmark = False
def init_random_seed(seed: int):
import numpy
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
numpy.random.seed(seed)
PIPE_CHUNKS = 2
iteration_count = 0
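# iteration_count is incremented in TransformerDecoderLayer.forward so that the
# (commented-out) CUDA tensor dump further down can be triggered at a specific iteration.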
class EmbeddingLayer(nn.Embedding): class EmbeddingLayer(nn.Embedding):
def __init__(self, ntoken, ninp, initrange): def __init__(self, ntoken, ninp, initrange):
super().__init__(ntoken, ninp) super().__init__(ntoken, ninp)
...@@ -63,6 +85,11 @@ class TransformerDecoderLayer(nn.TransformerEncoderLayer): ...@@ -63,6 +85,11 @@ class TransformerDecoderLayer(nn.TransformerEncoderLayer):
return mask return mask
def forward(self, src): def forward(self, src):
global iteration_count
iteration_count += 1
# if iteration_count == 196:
# dump_cuda_tensors()
if self.src_mask is None or self.src_mask.size(0) != len(src): if self.src_mask is None or self.src_mask.size(0) != len(src):
device = src.device device = src.device
mask = self._generate_square_subsequent_mask(len(src)).to(device) mask = self._generate_square_subsequent_mask(len(src)).to(device)
...@@ -82,16 +109,20 @@ class TransformerLMSequntial(nn.Sequential): ...@@ -82,16 +109,20 @@ class TransformerLMSequntial(nn.Sequential):
"""A small language model based on the design of GPT-2 using nn.Sequeitnal """A small language model based on the design of GPT-2 using nn.Sequeitnal
for compatability with Pipe""" for compatability with Pipe"""
def __init__(self, ntokens, ninp, nhead, nhid, dropout, initrange): def __init__(self, ntokens, ninp, nhead, nhid, dropout, initrange, ndecoder):
super(TransformerLMSequntial, self).__init__( layers = [
EmbeddingLayer(ntokens, ninp, initrange), EmbeddingLayer(ntokens, ninp, initrange),
PositionalEncodingLayer(ninp, dropout), PositionalEncodingLayer(ninp, dropout),
TransformerDecoderLayer(ninp, nhead, nhid, dropout), ]
LinearLayer(ninp, ntokens, initrange), for _ in range(ndecoder):
) layers.append(TransformerDecoderLayer(ninp, nhead, nhid, dropout))
layers.append(LinearLayer(ninp, ntokens, initrange))
super(TransformerLMSequntial, self).__init__(*layers)
def get_data(device): def get_data(device):
with warnings.catch_warnings(record=True) as fjldska:
TEXT = torchtext.data.Field( TEXT = torchtext.data.Field(
tokenize=get_tokenizer("basic_english"), init_token="<sos>", eos_token="<eos>", lower=True tokenize=get_tokenizer("basic_english"), init_token="<sos>", eos_token="<eos>", lower=True
) )
...@@ -99,8 +130,8 @@ def get_data(device): ...@@ -99,8 +130,8 @@ def get_data(device):
TEXT.build_vocab(train_txt) TEXT.build_vocab(train_txt)
ntokens = len(TEXT.vocab.stoi) ntokens = len(TEXT.vocab.stoi)
batch_size = 500 batch_size = 20
eval_batch_size = 200 eval_batch_size = 10
train_data = batchify(train_txt, batch_size, TEXT, device) train_data = batchify(train_txt, batch_size, TEXT, device)
val_data = batchify(val_txt, eval_batch_size, TEXT, device) val_data = batchify(val_txt, eval_batch_size, TEXT, device)
test_data = batchify(test_txt, eval_batch_size, TEXT, device) test_data = batchify(test_txt, eval_batch_size, TEXT, device)
...@@ -123,71 +154,188 @@ def get_batch(source, i, bptt): ...@@ -123,71 +154,188 @@ def get_batch(source, i, bptt):
return data, target return data, target
def make_model(device, ntokens): def make_model(args, device, ntokens):
ninp = 50 # embedding dimension ninp = 2048 # embedding dimension
nhid = 50 # the dimension of the feedforward network model in nn.TransformerEncoder nhid = 2048 # the dimension of the feedforward network model in nn.TransformerEncoder
nhead = 2 # the number of heads in the multiheadattention models nhead = 32 # the number of heads in the multiheadattention models
dropout = 0 dropout = 0
initrange = 0.1 initrange = 0.1
ndecoder = args.num_decoder_layers
model = TransformerLMSequntial(ntokens, ninp, nhead, nhid, dropout, initrange).half().to(device)
balance = generate_balance(min(num_devices, 4), len(model)) if args.lazy_construction:
p = Pipe(model, balance, chunks=len(balance)) layers = [
lambda: EmbeddingLayer(ntokens, ninp, initrange),
lambda: PositionalEncodingLayer(ninp, dropout),
]
for _ in range(ndecoder):
layers.append(lambda: TransformerDecoderLayer(ninp, nhead, nhid, dropout))
layers.append(lambda: LinearLayer(ninp, ntokens, initrange))
model = layers
else:
model = TransformerLMSequntial(ntokens, ninp, nhead, nhid, dropout, initrange, ndecoder).to(device)
criterion = nn.CrossEntropyLoss() criterion = nn.CrossEntropyLoss()
lr = 0.001 # learning rate lr = 0.01 # learning rate
try: def make_adam(model):
optimizer = Adam(p.parameters(), lr=lr, precision=Precision.PURE_FP16) return Adam(model.parameters(), lr=lr)
except NameError:
optimizer = Adam(p.parameters(), lr=lr) optimizer = make_adam
scaler = GradScaler() scaler = GradScaler()
return p, criterion, optimizer, scaler return model, criterion, optimizer, scaler
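# Debugging helpers: the functions below bucket live CUDA tensors by
# (shape, element_size) and compare the buckets between iterations, to help track
# down tensors that leak across pipeline steps.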
def get_tensors_by_size_bucket():
import gc
from collections import defaultdict
size_buckets = defaultdict(int)
for obj in gc.get_objects():
if not isinstance(obj, torch.Tensor):
continue
if obj.device.type == "cuda":
size_buckets[(*obj.size(),) + (obj.element_size(),)] += 1
return size_buckets
def dump_size_buckets(size_buckets, prefix=""):
import operator
from functools import reduce
total = 0
for key, value in size_buckets.items():
this = reduce(operator.mul, key) * value
total += this
print(prefix + f"{key} : {value}, {this}")
print(prefix + f"total = {total}")
def train(train_data, model, criterion, optimizer, scaler, bptt, ntokens):
last_size_buckets = None
once = True
def safe_rank():
try:
return torch.distributed.get_rank()
except AssertionError:
return 0
def check_size_buckets():
global last_size_buckets
global once
size_buckets = get_tensors_by_size_bucket()
if last_size_buckets is not None:
if size_buckets != last_size_buckets:
print(f"difference is oustanding tensors: {safe-rank()}")
dump_size_buckets(last_size_buckets, "old: ")
dump_size_buckets(size_buckets, "new: ")
if once:
print(f"dumping buckets for: {safe_rank()}")
dump_size_buckets(last_size_buckets, "old: ")
dump_size_buckets(size_buckets, "new: ")
once = False
else:
print(f"size buckets none on {safe_rank()}")
last_size_buckets = size_buckets
def dump_cuda_tensors():
print(f"dumping cuda tensors...")
from functools import reduce
import operator
import gc
from collections import defaultdict

# Bucket live CUDA tensors by (shape, element_size).
size_buckets = defaultdict(int)
for obj in gc.get_objects():
if not isinstance(obj, torch.Tensor):
continue
if obj.device.type == "cuda":
size_buckets[(*obj.size(),) + (obj.element_size(),)] += 1
print(f"outstanding cuda tensors:")
total = 0
for key, value in size_buckets.items():
this = reduce(operator.mul, key) * value
total += this
print(f"{key} : {value}, {this}")
print(f"total size = {total}")
import pprint
pprint.pprint(torch.cuda.memory_stats())
def train(lm_dataloader, model, criterion, optimizer, vocab_size, args):
model.train() model.train()
from functools import reduce
import operator
num_params = reduce(operator.add, (reduce(operator.mul, x.size()) for x in model.parameters()))
if model.group:
print(f"training model, #prams = {num_params}, group: {model.group.rank()}, sizes {model.group.size()}")
else:
print(f"training model, #prams = {num_params}")
vocab_size = 10000 # FIXME
total_loss = 0.0 total_loss = 0.0
start_time = time.time() start_time = time.time()
for batch, i in enumerate(range(0, train_data.size(0) - 1, bptt)): word_counter = 0
data, targets = get_batch(train_data, i, bptt)
optimizer = optimizer(model)
def get_first_device(model):
if model.devices:
return model.devices[0]
else:
return torch.cuda.current_device()
def get_last_device(model):
if model.devices:
return model.devices[-1]
else:
return torch.cuda.current_device()
for i, batch in enumerate(lm_dataloader):
if args.max_batch and i > args.max_batch:
break
optimizer.zero_grad() optimizer.zero_grad()
output = model(data) output = model(batch["input"].to(get_first_device(model)))
output = output.to(targets.device)
if model.group is None or model.group.rank() == model.group.size() - 1:
target = batch["target"].to(get_last_device(model))
output = output.to(target.device)
loss = criterion(output.view(-1, vocab_size), target.view(-1))
loss.backward()
else:
model.back_helper(output)
del output
loss = criterion(output.view(-1, ntokens), targets) torch.nn.utils.clip_grad_value_(model.parameters(), 0.05)
scaler.scale(loss).backward() optimizer.step()
scaler.step(optimizer) # scaler.step automatically unscale if unscale has not yet been performed
scaler.update()
if model.group is None or model.group.rank() == model.group.size() - 1:
total_loss += loss.item() total_loss += loss.item()
log_interval = 50 log_interval = 1
if batch % log_interval == 0 and batch > 0: word_counter += batch["ntokens"]
if i % log_interval == 0 and i > 0:
cur_loss = total_loss / log_interval cur_loss = total_loss / log_interval
elapsed = time.time() - start_time elapsed = time.time() - start_time
try:
print( print(
"| {:5d}/{:5d} batches | ms/batch {:5.2f} | " "| batch {:5d} | wps {:5.2f} | loss {:5.2f} | ppl {:8.2f}".format(
"loss {:5.2f} | ppl {:8.2f} | grad scale {:3d} | optim scale {:3d}".format( i, word_counter / elapsed, cur_loss, math.exp(cur_loss)
batch,
len(train_data) // bptt,
elapsed * 1000 / log_interval,
cur_loss,
math.exp(cur_loss),
int(scaler.get_scale()),
int(optimizer._optim_scale),
)
)
except AttributeError:
print(
"| {:5d}/{:5d} batches | ms/batch {:5.2f} | "
"loss {:5.2f} | ppl {:8.2f}".format(
batch, len(train_data) // bptt, elapsed * 1000 / log_interval, cur_loss, math.exp(cur_loss)
) )
) )
word_counter = 0
total_loss = 0 total_loss = 0
start_time = time.time() start_time = time.time()
# if i >= 10:
# break
# torch.cuda.empty_cache()
# check_size_buckets()
def evaluate(eval_model, data_source, criterion, bptt, ntokens): def evaluate(eval_model, data_source, criterion, bptt, ntokens):
...@@ -207,7 +355,7 @@ def get_number_of_words(data): ...@@ -207,7 +355,7 @@ def get_number_of_words(data):
return data.size()[0] * data.size()[1] return data.size()[0] * data.size()[1]
def benchmark_language_model(train_data, val_data, test_data, model, criterion, optimizer, scaler, ntokens): def benchmark_language_model(train_data, val_data, test_data, model, criterion, optimizer, ntokens, args):
epoch = 1 epoch = 1
bptt = 35 bptt = 35
start_time = time.time() start_time = time.time()
...@@ -216,9 +364,9 @@ def benchmark_language_model(train_data, val_data, test_data, model, criterion, ...@@ -216,9 +364,9 @@ def benchmark_language_model(train_data, val_data, test_data, model, criterion,
print("| start of epoch {:1d}".format(epoch)) print("| start of epoch {:1d}".format(epoch))
print("-" * 110) print("-" * 110)
epoch_start_time = time.time() epoch_start_time = time.time()
train(train_data, model, criterion, optimizer, scaler, bptt, ntokens) train(train_data, model, criterion, optimizer, bptt, ntokens, args)
val_loss = evaluate(model, val_data, criterion, bptt, ntokens) val_loss = 1 # evaluate(model, val_data, criterion, bptt, ntokens)
print("-" * 110) print("-" * 89)
print( print(
"| end of epoch {:1d} | time: {:5.2f}s | valid loss {:5.2f} ".format( "| end of epoch {:1d} | time: {:5.2f}s | valid loss {:5.2f} ".format(
epoch, (time.time() - epoch_start_time), val_loss epoch, (time.time() - epoch_start_time), val_loss
...@@ -230,8 +378,8 @@ def benchmark_language_model(train_data, val_data, test_data, model, criterion, ...@@ -230,8 +378,8 @@ def benchmark_language_model(train_data, val_data, test_data, model, criterion,
nwords = get_number_of_words(train_data) + get_number_of_words(val_data) nwords = get_number_of_words(train_data) + get_number_of_words(val_data)
wps = nwords / elapsed_time wps = nwords / elapsed_time
test_loss = evaluate(model, test_data, criterion, bptt, ntokens) test_loss = 1 # evaluate(model, test_data, criterion, bptt, ntokens)
print("=" * 110) print("=" * 89)
print( print(
"| end of training | test loss {:5.2f} \n| time: {:5.2f}s | words: {:3d} | wps: {:5.2f}".format( "| end of training | test loss {:5.2f} \n| time: {:5.2f}s | words: {:3d} | wps: {:5.2f}".format(
test_loss, elapsed_time, nwords, wps test_loss, elapsed_time, nwords, wps
...@@ -272,13 +420,186 @@ def generate_balance(num_devices, num_layers): ...@@ -272,13 +420,186 @@ def generate_balance(num_devices, num_layers):
return balance return balance
if __name__ == "__main__": def make_model_and_data(args, device, new_data: bool = True):
if new_data:
device = torch.device("cuda")
vocab_size = 10000
model, criterion, optimizer, scaler = make_model(args, device, vocab_size)
lm_dataset = BenchmarkLMDataset()
lm_dataloader = DataLoader(
lm_dataset, batch_size=args.batch_size, shuffle=True, num_workers=0, collate_fn=collate_sentences_lm
)
return {
"model": model,
"criterion": criterion,
"optimizer": optimizer,
"data": lm_dataloader,
"vocab_size": vocab_size,
}
else:
device = torch.device("cuda")
data = get_data(device)
ntokens, train_data, val_data, test_data = data
model, criterion, optimizer, scaler = make_model(args, device, ntokens)
return {
"model": model,
"criterion": criterion,
"optimizer": optimizer,
"data": data,
}
def bench_single_process(args):
num_devices = torch.cuda.device_count() num_devices = torch.cuda.device_count()
assert num_devices > 0 assert num_devices > 0
init_random_seed(0)
torch.manual_seed(0)
device = torch.device("cuda") device = torch.device("cuda")
ntokens, train_data, val_data, test_data = get_data(device)
model, criterion, optimizer, scaler = make_model(device, ntokens) new_data = True
benchmark_language_model(train_data, val_data, test_data, model, criterion, optimizer, scaler, ntokens)
blob = make_model_and_data(args, None, new_data=new_data)
model = blob["model"]
balance = generate_balance(min(num_devices, 8), len(model))
p = pipe.Pipe(
model, balance, chunks=args.chunks, pipelined_backward=args.pipelined_backward, checkpoint=args.checkpoint
)
del model del model
del blob["model"]
if new_data:
train(blob["data"], p, blob["criterion"], blob["optimizer"], blob["vocab_size"], args)
else:
ntokens, train_data, val_data, test_data = blob["data"]
benchmark_language_model(train_data, val_data, test_data, p, criterion, optimizer, ntokens, args)
def run_mp_worker(args, available_workers):
new_data = True
blob = make_model_and_data(args, None, new_data=new_data)
model = blob["model"]
balance = generate_balance(min(available_workers, 8), len(model))
p = pipe.Pipe(
model,
balance,
style=Pipe.MultiProcess,
chunks=args.chunks,
worker_map=get_worker_map(),
input_device=torch.cuda.current_device(),
pipelined_backward=args.pipelined_backward,
checkpoint=args.checkpoint,
).cuda()
if args.all_at_once and p.pipeline:
print(f"running all at once")
p.pipeline.all_at_once = True
if new_data:
train(blob["data"], p, blob["criterion"], blob["optimizer"], blob["vocab_size"], args)
else:
ntokens, train_data, val_data, test_data = blob["data"]
benchmark_language_model(train_data, val_data, test_data, p, criterion, optimizer, ntokens, args)
def run_worker(rank, world_size, args):
if args.world_size != 0:
world_size = args.world_size
dist_init(rank + args.rank_base, world_size, hostname=args.host)
initialize_model_parallel(1, world_size)
init_random_seed(0)
run_mp_worker(args, world_size)
rpc.shutdown()
torch.distributed.destroy_process_group()
def bench_multi_process(args, all_at_once=False):
if args.local_world_size != 0:
world_size = args.local_world_size
else:
world_size = min(torch.cuda.device_count(), 2)
mp.spawn(run_worker, args=(world_size, args), nprocs=world_size, join=True)
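# best_device_map (below) pins each guessed MPI rank to a UCX/InfiniBand NIC via
# UCX_NET_DEVICES; it assumes an 8-GPU host with four mlx5 adapters, two ranks per NIC.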
best_device_map = {
0: "mlx5_0:1",
1: "mlx5_0:1",
2: "mlx5_1:1",
3: "mlx5_1:1",
4: "mlx5_2:1",
5: "mlx5_2:1",
6: "mlx5_3:1",
7: "mlx5_3:1",
}
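# bench_mpi (below) is meant to be launched under mpirun: it discovers ranks through
# torch.distributed's MPI backend, then brings up torch.distributed.rpc over TCP
# (MASTER_ADDR:10639) so that pipeline stages can exchange activations and gradients.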
def bench_mpi(args):
guess_rank = int(os.environ["OMPI_COMM_WORLD_RANK"])
os.environ["UCX_NET_DEVICES"] = best_device_map[guess_rank]
torch.distributed.init_process_group(backend="mpi")
os.environ["MASTER_ADDR"] = args.host
os.environ["MASTER_PORT"] = "10639"
if args.socket_name:
os.environ["GLOO_SOCKET_IFNAME"] = args.socket_name
os.environ["TP_SOCKET_IFNAME"] = args.socket_name
init_method = f"tcp://{os.environ['MASTER_ADDR']}:{os.environ['MASTER_PORT']}"
rank = torch.distributed.get_rank()
world_size = torch.distributed.get_world_size()
torch.cuda.set_device(rank % torch.cuda.device_count())
rpc.init_rpc(
f"Test{rank}",
rank=rank,
world_size=world_size,
backend=rpc.BackendType.PROCESS_GROUP,
rpc_backend_options=rpc.ProcessGroupRpcBackendOptions(rpc_timeout=20, init_method=init_method),
)
initialize_model_parallel(1, world_size)
init_random_seed(0)
run_mp_worker(args, world_size)
rpc.shutdown()
torch.distributed.destroy_process_group()
parser = argparse.ArgumentParser(description="benchmark")
parser.add_argument("--local-world-size", "-l", type=int, default=0, help="local world size")
parser.add_argument("--world-size", "-w", type=int, default=0, help="world size")
parser.add_argument("--rank-base", "-r", type=int, help="rank base", default=0)
parser.add_argument("--host", "-o", type=str, default="localhost", help="hostname")
parser.add_argument("--no-mpi", action="store_true", default=False, help="disable mpi")
parser.add_argument("--chunks", type=int, default=1, help="number of microbatches per batch")
parser.add_argument("--batch-size", type=int, default=8, help="size of a batch")
parser.add_argument("--all-at-once", action="store_true", default=False, help="do backward pass on whole batch at once")
parser.add_argument("--max-batch", type=int, default=4, help="Max number of batches")
parser.add_argument("--socket-name", type=str, default=None, help="socket ifname for gloo/tp")
parser.add_argument("--num-decoder-layers", type=int, default=10, help="Number of decoder layers in the model")
parser.add_argument(
"--lazy-construction", action="store_true", default=False, help="Number of decoder layers in the model"
)
parser.add_argument(
"--checkpoint", default="never", choices=["always", "except_last", "never"], help="Checkpointing strategy for pipe"
)
parser.add_argument(
"--pipelined-backward", dest="pipelined_backward", action="store_true", help="Pipelined backward pass"
)
parser.add_argument(
"--no-pipelined-backward", dest="pipelined_backward", action="store_false", help="Pipelined backward pass"
)
parser.set_defaults(pipelined_backward=True)
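# A hypothetical single-process invocation (the script path is illustrative):
#   python benchmarks/pipe.py --no-mpi --num-decoder-layers 4 --chunks 4 --batch-size 8 --max-batch 10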
if __name__ == "__main__":
args = parser.parse_args()
# bench_multi_process(args, all_at_once=True)
if args.no_mpi or "OMPI_COMM_WORLD_RANK" not in os.environ:
print(f"Running benchmark with args: {args}")
bench_single_process(args)
else:
if os.environ["OMPI_COMM_WORLD_RANK"] == "0":
print(f"Running benchmark with args: {args}")
bench_mpi(args)
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
# #
import os import os
import sys import sys
from typing import Any, List
# The theme to use for HTML and HTML Help pages. See the documentation for # The theme to use for HTML and HTML Help pages. See the documentation for
# a list of builtin themes. # a list of builtin themes.
...@@ -46,7 +47,7 @@ templates_path = ["_templates"] ...@@ -46,7 +47,7 @@ templates_path = ["_templates"]
# List of patterns, relative to source directory, that match files and # List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files. # directories to ignore when looking for source files.
# This pattern also affects html_static_path and html_extra_path. # This pattern also affects html_static_path and html_extra_path.
exclude_patterns = [] exclude_patterns: List[Any] = []
# -- Options for HTML output ------------------------------------------------- # -- Options for HTML output -------------------------------------------------
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
from .cross_entropy import vocab_parallel_cross_entropy from .cross_entropy import vocab_parallel_cross_entropy
from .initialize import ( from .initialize import (
destroy_model_parallel,
get_data_parallel_group, get_data_parallel_group,
get_data_parallel_rank, get_data_parallel_rank,
get_data_parallel_world_size, get_data_parallel_world_size,
...@@ -12,6 +13,8 @@ from .initialize import ( ...@@ -12,6 +13,8 @@ from .initialize import (
get_model_parallel_rank, get_model_parallel_rank,
get_model_parallel_src_rank, get_model_parallel_src_rank,
get_model_parallel_world_size, get_model_parallel_world_size,
get_pipeline_parallel_group,
get_pipeline_parallel_ranks,
initialize_model_parallel, initialize_model_parallel,
) )
from .layers import ColumnParallelLinear, RowParallelLinear, VocabParallelEmbedding from .layers import ColumnParallelLinear, RowParallelLinear, VocabParallelEmbedding
......
...@@ -35,6 +35,8 @@ _DATA_PARALLEL_GROUP = None ...@@ -35,6 +35,8 @@ _DATA_PARALLEL_GROUP = None
# Pipeline parallel group that the current rank belongs to. # Pipeline parallel group that the current rank belongs to.
_PIPELINE_PARALLEL_GROUP = None _PIPELINE_PARALLEL_GROUP = None
_PIPELINE_PARALLEL_RANKS = None
def initialize_model_parallel(model_parallel_size_: int, pipeline_length: int = 1) -> None: def initialize_model_parallel(model_parallel_size_: int, pipeline_length: int = 1) -> None:
""" """
...@@ -93,7 +95,15 @@ def initialize_model_parallel(model_parallel_size_: int, pipeline_length: int = ...@@ -93,7 +95,15 @@ def initialize_model_parallel(model_parallel_size_: int, pipeline_length: int =
global _PIPELINE_PARALLEL_GROUP global _PIPELINE_PARALLEL_GROUP
assert _PIPELINE_PARALLEL_GROUP is None, "model parallel group is already initialized" assert _PIPELINE_PARALLEL_GROUP is None, "model parallel group is already initialized"
_PIPELINE_PARALLEL_GROUP = groups[found[0], :, found[2]].tolist() global _PIPELINE_PARALLEL_RANKS
assert _PIPELINE_PARALLEL_RANKS is None, "pipeline parallel ranks are already initialized"
for i in range(data_parallel_size):
for k in range(model_parallel_size):
ranks = groups[i, :, k].tolist()
group = torch.distributed.new_group(ranks)
if i == found[0] and k == found[2]:
_PIPELINE_PARALLEL_GROUP = group
_PIPELINE_PARALLEL_RANKS = ranks
def model_parallel_is_initialized() -> bool: def model_parallel_is_initialized() -> bool:
...@@ -115,12 +125,18 @@ def get_data_parallel_group() -> torch.distributed.ProcessGroup: ...@@ -115,12 +125,18 @@ def get_data_parallel_group() -> torch.distributed.ProcessGroup:
return _DATA_PARALLEL_GROUP return _DATA_PARALLEL_GROUP
def get_pipeline_parallel_group() -> List[int]: def get_pipeline_parallel_group() -> torch.distributed.ProcessGroup:
"""Get the pipeline parallel group the caller rank belongs to.""" """Get the pipeline parallel group the caller rank belongs to."""
assert _PIPELINE_PARALLEL_GROUP is not None, "pipeline parallel group is not initialized" assert _PIPELINE_PARALLEL_GROUP is not None, "pipeline parallel group is not initialized"
return _PIPELINE_PARALLEL_GROUP return _PIPELINE_PARALLEL_GROUP
def get_pipeline_parallel_ranks() -> List[int]:
"""Get the pipeline parallel group the caller rank belongs to."""
assert _PIPELINE_PARALLEL_RANKS is not None, "pipeline parallel group is not initialized"
return _PIPELINE_PARALLEL_RANKS
def get_model_parallel_world_size() -> int: def get_model_parallel_world_size() -> int:
"""Return world size for the model parallel group.""" """Return world size for the model parallel group."""
return torch.distributed.get_world_size(group=get_model_parallel_group()) return torch.distributed.get_world_size(group=get_model_parallel_group())
...@@ -157,3 +173,6 @@ def destroy_model_parallel() -> None: ...@@ -157,3 +173,6 @@ def destroy_model_parallel() -> None:
_DATA_PARALLEL_GROUP = None _DATA_PARALLEL_GROUP = None
global _PIPELINE_PARALLEL_GROUP global _PIPELINE_PARALLEL_GROUP
_PIPELINE_PARALLEL_GROUP = None _PIPELINE_PARALLEL_GROUP = None
global _PIPELINE_PARALLEL_RANKS
_PIPELINE_PARALLEL_RANKS = None
...@@ -280,6 +280,9 @@ class ColumnParallelLinear(torch.nn.Module): ...@@ -280,6 +280,9 @@ class ColumnParallelLinear(torch.nn.Module):
return_master_weight=keep_master_weight_for_test, return_master_weight=keep_master_weight_for_test,
) )
def get_master_weight(self) -> torch.Tensor:
return gather_from_model_parallel_region(self.weight.data.transpose(0, 1)).transpose_(0, 1)
def forward(self, input_: torch.Tensor) -> torch.Tensor: # type: ignore def forward(self, input_: torch.Tensor) -> torch.Tensor: # type: ignore
# Set up backprop all-reduce. # Set up backprop all-reduce.
input_parallel = copy_to_model_parallel_region(input_) input_parallel = copy_to_model_parallel_region(input_)
...@@ -364,6 +367,9 @@ class RowParallelLinear(torch.nn.Module): ...@@ -364,6 +367,9 @@ class RowParallelLinear(torch.nn.Module):
return_master_weight=keep_master_weight_for_test, return_master_weight=keep_master_weight_for_test,
) )
def get_master_weight(self) -> torch.Tensor:
return gather_from_model_parallel_region(self.weight.data)
def forward(self, input_: torch.Tensor) -> torch.Tensor: # type:ignore def forward(self, input_: torch.Tensor) -> torch.Tensor: # type:ignore
# Set up backprop all-reduce. # Set up backprop all-reduce.
if self.input_is_parallel: if self.input_is_parallel:
......
...@@ -19,21 +19,27 @@ ...@@ -19,21 +19,27 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Any
import torch import torch
from .initialize import get_model_parallel_group from .initialize import get_model_parallel_group
from .utils import split_tensor_along_last_dim from .utils import split_tensor_along_last_dim
def _reduce(input_: torch.Tensor) -> torch.Tensor: def _reduce(ctx: Any, input_: torch.Tensor) -> torch.Tensor:
"""All-reduce the the input tensor across model parallel group.""" """All-reduce the the input tensor across model parallel group."""
group = get_model_parallel_group() group = get_model_parallel_group()
if ctx:
ctx.mark_dirty(input_)
# Bypass the function if we are using only 1 GPU. # Bypass the function if we are using only 1 GPU.
if torch.distributed.get_world_size(group=group) == 1: if torch.distributed.get_world_size(group=group) == 1:
return input_ return input_
# All-reduce. # All-reduce.
print(f"doing all_reduce on {torch.distributed.get_rank()}")
torch.distributed.all_reduce(input_, group=group) torch.distributed.all_reduce(input_, group=group)
return input_ return input_
...@@ -87,11 +93,13 @@ class _CopyToModelParallelRegion(torch.autograd.Function): ...@@ -87,11 +93,13 @@ class _CopyToModelParallelRegion(torch.autograd.Function):
@staticmethod @staticmethod
def forward(ctx, input_): # type: ignore def forward(ctx, input_): # type: ignore
print(f"{torch.distributed.get_rank()}: _CopyToModelParallelRegion Forward")
return input_ return input_
@staticmethod @staticmethod
def backward(ctx, grad_output): # type: ignore def backward(ctx, grad_output): # type: ignore
return _reduce(grad_output) print(f"{torch.distributed.get_rank()}: _CopyToModelParallelRegion Backward")
return _reduce(None, grad_output)
class _ReduceFromModelParallelRegion(torch.autograd.Function): class _ReduceFromModelParallelRegion(torch.autograd.Function):
...@@ -99,10 +107,12 @@ class _ReduceFromModelParallelRegion(torch.autograd.Function): ...@@ -99,10 +107,12 @@ class _ReduceFromModelParallelRegion(torch.autograd.Function):
@staticmethod @staticmethod
def forward(ctx, input_): # type: ignore def forward(ctx, input_): # type: ignore
return _reduce(input_) print(f"{torch.distributed.get_rank()}: _ReduceFromModelParallelRegion Forward")
return _reduce(ctx, input_)
@staticmethod @staticmethod
def backward(ctx, grad_output): # type: ignore def backward(ctx, grad_output): # type: ignore
print(f"{torch.distributed.get_rank()}: _ReduceFromModelParallelRegion Backward")
return grad_output return grad_output
...@@ -111,10 +121,12 @@ class _ScatterToModelParallelRegion(torch.autograd.Function): ...@@ -111,10 +121,12 @@ class _ScatterToModelParallelRegion(torch.autograd.Function):
@staticmethod @staticmethod
def forward(ctx, input_): # type: ignore def forward(ctx, input_): # type: ignore
print(f"{torch.distributed.get_rank()}: _ScatterToModelParallelRegion Forward")
return _split(input_) return _split(input_)
@staticmethod @staticmethod
def backward(ctx, grad_output): # type: ignore def backward(ctx, grad_output): # type: ignore
print(f"{torch.distributed.get_rank()}: _ScatterToModelParallelRegion Backward")
return _gather(grad_output) return _gather(grad_output)
...@@ -123,10 +135,12 @@ class _GatherFromModelParallelRegion(torch.autograd.Function): ...@@ -123,10 +135,12 @@ class _GatherFromModelParallelRegion(torch.autograd.Function):
@staticmethod @staticmethod
def forward(ctx, input_): # type: ignore def forward(ctx, input_): # type: ignore
print(f"{torch.distributed.get_rank()}: _GatherFromModelParallelRegion Forward")
return _gather(input_) return _gather(input_)
@staticmethod @staticmethod
def backward(ctx, grad_output): # type: ignore def backward(ctx, grad_output): # type: ignore
print(f"{torch.distributed.get_rank()}: _GatherFromModelParallelRegion Backward")
return _split(grad_output) return _split(grad_output)
......
...@@ -59,7 +59,7 @@ def profile_times(module: nn.Sequential, sample: TensorOrTensors, timeout: float ...@@ -59,7 +59,7 @@ def profile_times(module: nn.Sequential, sample: TensorOrTensors, timeout: float
if any(p.grad is not None for p in module.parameters()): if any(p.grad is not None for p in module.parameters()):
raise ValueError("some parameter already has gradient") raise ValueError("some parameter already has gradient")
_batch = Batch(sample) _batch = Batch(sample, 0)
for i, x in enumerate(_batch): for i, x in enumerate(_batch):
_batch[i] = x.detach().to(device).requires_grad_(x.requires_grad) _batch[i] = x.detach().to(device).requires_grad_(x.requires_grad)
...@@ -101,7 +101,7 @@ def profile_sizes( ...@@ -101,7 +101,7 @@ def profile_sizes(
if device.type != "cuda": if device.type != "cuda":
raise ValueError("size profiler supports only CUDA device") raise ValueError("size profiler supports only CUDA device")
batch = Batch(input) batch = Batch(input, 0)
sizes: List[int] = [] sizes: List[int] = []
latent_scale = batch[0].size(0) / chunks latent_scale = batch[0].size(0) / chunks
......
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
# Copyright 2019 Kakao Brain # Copyright 2019 Kakao Brain
# #
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # You may obtain a copy of the License at
# #
...@@ -74,20 +74,6 @@ class Function(Protocol): ...@@ -74,20 +74,6 @@ class Function(Protocol):
... ...
def checkpoint(function: Function, input: TensorOrTensors) -> TensorOrTensors:
"""Makes a checkpoint with a simple interface like
:func:`torch.utils.checkpoint.checkpoint`. It's only used to test or debug
:class:`Checkpoint` and :class:`Recompute` without boilerplate.
"""
batch = Batch(input)
chk = Checkpointing(function, batch)
batch = chk.checkpoint()
chk.recompute(batch)
return batch.tensor_or_tensors
class Checkpointing: class Checkpointing:
"""Generates a pair of :class:`Checkpoint` and :class:`Recompute`.""" """Generates a pair of :class:`Checkpoint` and :class:`Recompute`."""
...@@ -116,7 +102,7 @@ class Checkpointing: ...@@ -116,7 +102,7 @@ class Checkpointing:
if isinstance(output, tuple): if isinstance(output, tuple):
output = tuple([x if x.is_floating_point() else x.detach() for x in output]) output = tuple([x if x.is_floating_point() else x.detach() for x in output])
return Batch(output) return Batch(output, self.batch.index)
def recompute(self, batch: Batch) -> None: def recompute(self, batch: Batch) -> None:
"""Applies :class:`Recompute` to the batch in place.""" """Applies :class:`Recompute` to the batch in place."""
...@@ -226,6 +212,7 @@ def save_rng_states(device: torch.device, rng_states: Deque[RNGStates],) -> None ...@@ -226,6 +212,7 @@ def save_rng_states(device: torch.device, rng_states: Deque[RNGStates],) -> None
else: else:
gpu_rng_state = None gpu_rng_state = None
rng_states.clear()
rng_states.append((cpu_rng_state, gpu_rng_state)) rng_states.append((cpu_rng_state, gpu_rng_state))
...@@ -237,7 +224,7 @@ def restore_rng_states(device: torch.device, rng_states: Deque[RNGStates],) -> G ...@@ -237,7 +224,7 @@ def restore_rng_states(device: torch.device, rng_states: Deque[RNGStates],) -> G
.. seealso:: :ref:`Referential Transparency` .. seealso:: :ref:`Referential Transparency`
""" """
cpu_rng_state, gpu_rng_state = rng_states.pop() cpu_rng_state, gpu_rng_state = rng_states[0]
gpu_devices: List[torch.device] = [] gpu_devices: List[torch.device] = []
if device.type == "cuda": if device.type == "cuda":
......
...@@ -53,9 +53,14 @@ class Batch: ...@@ -53,9 +53,14 @@ class Batch:
""" """
def __init__(self, value: TensorOrTensors) -> None: def __init__(self, value: TensorOrTensors, index: int) -> None:
self.value = value self.value = value
self.atomic = torch.is_tensor(value) self.atomic = torch.is_tensor(value)
self.__index = index
@property
def index(self) -> int:
return self.__index
@property @property
def tensor(self) -> Tensor: def tensor(self) -> Tensor:
...@@ -80,7 +85,7 @@ class Batch: ...@@ -80,7 +85,7 @@ class Batch:
"""Calls a function by the underlying tensor or tensors. It also wraps """Calls a function by the underlying tensor or tensors. It also wraps
the output with :class:`Batch`. the output with :class:`Batch`.
""" """
return Batch(function(self.value)) return Batch(function(self.value), self.index)
def __repr__(self) -> str: def __repr__(self) -> str:
return f"Batch[atomic={self.atomic!r}]({self.value!r})" return f"Batch[atomic={self.atomic!r}]({self.value!r})"
...@@ -176,7 +181,7 @@ def scatter(input: TensorOrTensors, chunks: int) -> List[Batch]: ...@@ -176,7 +181,7 @@ def scatter(input: TensorOrTensors, chunks: int) -> List[Batch]:
inputs = zip(*rotated) inputs = zip(*rotated)
return [Batch(x) for x in inputs] return [Batch(x, i) for i, x in enumerate(inputs)]
def gather(outputs: List[Batch]) -> TensorOrTensors: def gather(outputs: List[Batch]) -> TensorOrTensors:
......
...@@ -19,18 +19,21 @@ ...@@ -19,18 +19,21 @@
"""The Pipe interface.""" """The Pipe interface."""
from collections import OrderedDict from collections import OrderedDict
from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Tuple, Union, cast from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional, Tuple, Union, cast
import warnings
import torch import torch
from torch import Tensor, nn from torch import Tensor, nn
import torch.autograd import torch.autograd
import torch.cuda import torch.cuda
from fairscale.nn.model_parallel import get_model_parallel_group, get_pipeline_parallel_group
from . import microbatch from . import microbatch
from .batchnorm import DeferredBatchNorm from .batchnorm import DeferredBatchNorm
from .pipeline import Pipeline from .pipeline import Pipeline, PipelineStyle
from .skip.layout import inspect_skip_layout from .skip.layout import SkipLayout, inspect_skip_layout
from .skip.skippable import verify_skippables from .skip.skippable import Skippable, verify_skippables
from .stream import AbstractStream, new_stream from .stream import AbstractStream, new_stream
__all__ = ["Pipe"] __all__ = ["Pipe"]
...@@ -42,6 +45,8 @@ Devices = Union[Iterable[Device], List[Device]] ...@@ -42,6 +45,8 @@ Devices = Union[Iterable[Device], List[Device]]
Tensors = Tuple[Tensor, ...] Tensors = Tuple[Tensor, ...]
TensorOrTensors = Union[Tensor, Tensors] TensorOrTensors = Union[Tensor, Tensors]
ListOfLazyModules = List[Callable[[], nn.Module]]
if TYPE_CHECKING: if TYPE_CHECKING:
Module = nn.Module[TensorOrTensors] Module = nn.Module[TensorOrTensors]
NamedModules = OrderedDict[str, Module] NamedModules = OrderedDict[str, Module]
...@@ -69,7 +74,21 @@ naive automatic balancing: ...@@ -69,7 +74,21 @@ naive automatic balancing:
""" """
def verify_module(module: nn.Sequential) -> None: # FIXME(tom) make this a valid way to call
def verify_list_of_callable(module: Union[nn.Sequential, list]) -> None:
for layer in module:
if isinstance(layer, nn.Module):
pass
elif callable(layer):
pass
else:
raise TypeError(f"layer {type(layer)} must be nn.Module or callable to be partitioned")
def verify_module(module: Union[nn.Sequential, ListOfLazyModules]) -> None:
if isinstance(module, Iterable) and not isinstance(module, nn.Sequential):
verify_list_of_callable(module)
else:
if not isinstance(module, nn.Sequential): if not isinstance(module, nn.Sequential):
raise TypeError("module must be nn.Sequential to be partitioned") raise TypeError("module must be nn.Sequential to be partitioned")
...@@ -79,7 +98,10 @@ def verify_module(module: nn.Sequential) -> None: ...@@ -79,7 +98,10 @@ def verify_module(module: nn.Sequential) -> None:
def verify_splitting( def verify_splitting(
module: nn.Sequential, partitions: List[nn.Sequential], balance: Iterable[int], devices: List[torch.device] module: nn.Sequential,
partitions: List[nn.Sequential],
balance: Iterable[int],
devices: Optional[List[torch.device]],
) -> None: ) -> None:
num_parameters = len(list(module.parameters())) num_parameters = len(list(module.parameters()))
num_child_parameters = sum(len(list(child.parameters())) for child in module.children()) num_child_parameters = sum(len(list(child.parameters())) for child in module.children())
...@@ -90,7 +112,7 @@ def verify_splitting( ...@@ -90,7 +112,7 @@ def verify_splitting(
for j in range(i + 1, len(partitions)): for j in range(i + 1, len(partitions)):
parti = partitions[i] parti = partitions[i]
partj = partitions[j] partj = partitions[j]
if devices[i] == devices[j]: if devices and devices[i] == devices[j]:
continue continue
for p in parti.parameters(): for p in parti.parameters():
for q in partj.parameters(): for q in partj.parameters():
...@@ -102,9 +124,65 @@ class BalanceError(ValueError): ...@@ -102,9 +124,65 @@ class BalanceError(ValueError):
pass pass
def check_balance(module: Any, balance: Iterable[int]) -> None:
if len(module) != sum(balance):
raise BalanceError(
f"module and sum of balance have different length (module: {len(module)}, sum of balance: {sum(balance)})"
)
if any(x <= 0 for x in balance):
raise BalanceError(f"all balance numbers must be positive integer (balance: {balance})")
def instantiate_partition(
module: Union[nn.Sequential, ListOfLazyModules], balance: Iterable[int], group: torch.distributed.ProcessGroup
) -> nn.Sequential:
balance = list(balance)
check_balance(module, balance)
layers: NamedModules = OrderedDict()
j = 0
def maybe_realize(layer: Any) -> nn.Module:
if isinstance(layer, nn.Module):
return layer
elif callable(layer):
return layer()
else:
raise TypeError(f"layer must be nn.Module or callable, is {type(layer)}")
def iterate_module(module: Union[nn.Sequential, list]) -> Iterable[Tuple[Any, nn.Module]]:
if isinstance(module, nn.Sequential):
yield from module.named_children()
else:
yield from enumerate(module)
for name, layer in iterate_module(module):
layers[name] = layer
if len(layers) == balance[j]:
if j == group.rank():
for key in layers:
layers[key] = maybe_realize(layers[key])
if not isinstance(module, nn.Sequential):
for layer in layers.values():
if isinstance(layer, Skippable):
raise ValueError("Can't use Skippable layers with multi-process pipe and lazy construction")
partition = nn.Sequential(*layers.values())
return partition
# Prepare for the next partition.
layers.clear()
j += 1
raise ValueError("Souldn't get here, more ranks than partitions")
def split_module( def split_module(
module: nn.Sequential, balance: Iterable[int], devices: List[torch.device], module: nn.Sequential, balance: Iterable[int], devices: Optional[List[torch.device]],
) -> Tuple[List[nn.Sequential], List[int], List[torch.device]]: ) -> Tuple[List[nn.Sequential], List[int], Optional[List[torch.device]]]:
"""Splits a module into multiple partitions. """Splits a module into multiple partitions.
Returns: Returns:
...@@ -123,18 +201,11 @@ def split_module( ...@@ -123,18 +201,11 @@ def split_module(
""" """
balance = list(balance) balance = list(balance)
if len(module) != sum(balance): check_balance(module, balance)
raise BalanceError(
"module and sum of balance have different length "
f"(module: {len(module)}, sum of balance: {sum(balance)})"
)
if any(x <= 0 for x in balance): if devices and len(balance) > len(devices):
raise BalanceError(f"all balance numbers must be positive integer (balance: {balance})")
if len(balance) > len(devices):
raise IndexError( raise IndexError(
"too few devices to hold given partitions " f"(devices: {len(devices)}, partitions: {len(balance)})" f"too few devices to hold given partitions (devices: {len(devices)}, partitions: {len(balance)})"
) )
j = 0 j = 0
...@@ -148,6 +219,7 @@ def split_module( ...@@ -148,6 +219,7 @@ def split_module(
# Group buffered layers as a partition. # Group buffered layers as a partition.
partition = nn.Sequential(layers) partition = nn.Sequential(layers)
if devices:
device = devices[j] device = devices[j]
partition.to(device) partition.to(device)
...@@ -158,12 +230,13 @@ def split_module( ...@@ -158,12 +230,13 @@ def split_module(
j += 1 j += 1
partitions = cast(List[nn.Sequential], nn.ModuleList(partitions)) partitions = cast(List[nn.Sequential], nn.ModuleList(partitions))
if devices:
del devices[j:] del devices[j:]
return partitions, balance, devices return partitions, balance, devices
MOVING_DENIED = TypeError("denied to move parameters and buffers, " "because Pipe should manage device placement") MOVING_DENIED = TypeError("denied to move parameters and buffers, because Pipe should manage device placement")
class Pipe(Module): class Pipe(Module):
...@@ -193,8 +266,23 @@ class Pipe(Module): ...@@ -193,8 +266,23 @@ class Pipe(Module):
list of number of layers in each partition list of number of layers in each partition
Keyword Args: Keyword Args:
style (PipelineStyle):
whether to use a single process for all pipeline stages or to assign
one stage per process
devices (iterable of devices): devices (iterable of devices):
devices to use (default: all CUDA devices) devices to use (default: all CUDA devices)
group (ProcessGroup):
specific to `style=MultiProcess`, the process group that all
pipeline stages are a member of. Defaults to
`get_pipeline_parallel_group()`
worker_map (Dict[int, str]):
a map from global rank (i.e. `torch.distributed.get_rank()`) to worker
name (the first argument to `torch.distributed.rpc.init_rpc`), needed
for pipeline stages to communicate with each other
input_device (device):
the device on which tensors should be located before being passed to
the first module in a given pipeline stage
chunks (int): chunks (int):
number of micro-batches (default: ``1``) number of micro-batches (default: ``1``)
checkpoint (str): checkpoint (str):
...@@ -204,6 +292,16 @@ class Pipe(Module): ...@@ -204,6 +292,16 @@ class Pipe(Module):
whether to use deferred BatchNorm moving statistics (default: whether to use deferred BatchNorm moving statistics (default:
:data:`False`, see :ref:`Deferred Batch Normalization` for more :data:`False`, see :ref:`Deferred Batch Normalization` for more
details) details)
pipelined_backward (bool, optional):
if True, call torch.autograd.backward once per microbatch on the
backward pass (instead of once for the whole batch). This works
around a potential deadlock in pytorch when using tensor parallelism
at the same time. Defaults to `True` if
`get_model_parallel_group().size() > 1`
(default: `None`)
retain_graph (bool):
The value passed to `torch.autograd.backward(..., retain_graph=<value>)`
(default: `False`)
Raises: Raises:
TypeError: TypeError:
...@@ -215,6 +313,9 @@ class Pipe(Module): ...@@ -215,6 +313,9 @@ class Pipe(Module):
""" """
SingleProcess: PipelineStyle = PipelineStyle.SingleProcess
MultiProcess: PipelineStyle = PipelineStyle.MultiProcess
#: The number of layers in each partition. #: The number of layers in each partition.
balance: List[int] = [] balance: List[int] = []
# ^^ # ^^
...@@ -234,7 +335,7 @@ class Pipe(Module): ...@@ -234,7 +335,7 @@ class Pipe(Module):
#: output = pipe(input) #: output = pipe(input)
#: loss = F.cross_entropy(output, target) #: loss = F.cross_entropy(output, target)
#: #:
devices: List[torch.device] = [] devices: Optional[List[torch.device]] = None
#: The number of micro-batches. #: The number of micro-batches.
chunks: int = 1 chunks: int = 1
...@@ -245,13 +346,19 @@ class Pipe(Module): ...@@ -245,13 +346,19 @@ class Pipe(Module):
def __init__( def __init__(
self, self,
module: nn.Sequential, module: Union[nn.Sequential, ListOfLazyModules],
balance: Optional[Iterable[int]] = None, balance: Optional[Iterable[int]] = None,
*, *,
style: PipelineStyle = PipelineStyle.SingleProcess,
devices: Optional[Devices] = None, devices: Optional[Devices] = None,
group: Optional[torch.distributed.ProcessGroup] = None,
worker_map: Optional[Dict[int, str]] = None,
input_device: Union[None, int, str, torch.device] = None,
chunks: int = chunks, chunks: int = chunks,
checkpoint: str = checkpoint, checkpoint: str = checkpoint,
deferred_batch_norm: bool = False, deferred_batch_norm: bool = False,
pipelined_backward: bool = None,
retain_graph: bool = False,
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -269,16 +376,26 @@ class Pipe(Module): ...@@ -269,16 +376,26 @@ class Pipe(Module):
# Verify if the underlying skippable modules satisfy integrity. The # Verify if the underlying skippable modules satisfy integrity. The
# integrity can be verified before forward() because it is static. # integrity can be verified before forward() because it is static.
if isinstance(module, nn.Sequential):
verify_skippables(module) verify_skippables(module)
self.chunks = chunks self.chunks = chunks
self.checkpoint = checkpoint self.checkpoint = checkpoint
self.pipelined_backward = pipelined_backward
self.retain_graph = retain_graph
self.pipeline: Optional[Pipeline]
if style is PipelineStyle.SingleProcess:
module = cast(nn.Sequential, module)
if deferred_batch_norm: if deferred_batch_norm:
module = DeferredBatchNorm.convert_deferred_batch_norm(module, chunks) module = DeferredBatchNorm.convert_deferred_batch_norm(module, chunks)
if input_device is not None:
raise ValueError("'input_device' argument only applies to 'PipelineStyle.MultiProcess'")
if devices is None: if devices is None:
devices = range(torch.cuda.device_count()) devices = range(torch.cuda.device_count())
devices = [torch.device(d) for d in devices] devices = [torch.device(d) for d in devices]
devices = cast(List[torch.device], devices) devices = cast(List[torch.device], devices)
...@@ -286,19 +403,83 @@ class Pipe(Module): ...@@ -286,19 +403,83 @@ class Pipe(Module):
self.partitions, self.balance, self.devices = split_module(module, balance, devices) self.partitions, self.balance, self.devices = split_module(module, balance, devices)
except BalanceError as exc: except BalanceError as exc:
raise ValueError(recommend_auto_balance(str(exc))) raise ValueError(recommend_auto_balance(str(exc)))
verify_splitting(module, self.partitions, self.balance, self.devices) verify_splitting(module, self.partitions, self.balance, self.devices)
self._copy_streams: List[List[AbstractStream]] = []
self._skip_layout = inspect_skip_layout(self.partitions) self._skip_layout = inspect_skip_layout(self.partitions)
# Separate CUDA streams for copy. elif style is PipelineStyle.MultiProcess:
copy_streams = self._ensure_copy_streams() if group is None:
group = get_pipeline_parallel_group()
if devices is not None:
raise ValueError("'devices' argument only applies to 'PipelineStyle.SingleProcess'")
self.balance = list(balance)
if group.size() < len(self.balance):
raise IndexError(
f"too few ranks to hold given partitions (ranks: {group.size()}, partitions: {len(self.balance)})"
)
try:
rank = torch.distributed.get_rank(group)
if rank >= len(self.balance):
warnings.warn("More ranks than partitions, some ranks unused")
self.partitions = cast(List[nn.Sequential], nn.ModuleList([nn.Sequential()]))
else:
partition = instantiate_partition(module, balance, group)
if deferred_batch_norm:
partition = DeferredBatchNorm.convert_deferred_batch_norm(partition, chunks)
self.partitions = cast(List[nn.Sequential], nn.ModuleList([partition]))
self.devices = None
if isinstance(module, nn.Sequential):
local_partitions, _, _ = split_module(module, balance, None)
self._skip_layout = inspect_skip_layout(local_partitions)
else:
self._skip_layout = SkipLayout(len(module), {}) # FIXME(tom)
except BalanceError as exc:
raise ValueError(recommend_auto_balance(str(exc)))
self.group = group
self.worker_map = worker_map
self.input_device = input_device
self._copy_streams: List[List[AbstractStream]] = []
# The micro-batch index where the checkpointing stops. # The micro-batch index where the checkpointing stops.
checkpoint_stop = {"always": self.chunks, "except_last": self.chunks - 1, "never": 0}[self.checkpoint] checkpoint_stop = {"always": self.chunks, "except_last": self.chunks - 1, "never": 0}[self.checkpoint]
self.pipeline = Pipeline(self.partitions, self.devices, copy_streams, self._skip_layout, checkpoint_stop) if style is PipelineStyle.SingleProcess:
# Separate CUDA streams for copy.
copy_streams = self._ensure_copy_streams()
if self.pipelined_backward is None:
self.pipelined_backward = False
self.pipeline = Pipeline(
self.partitions, self.devices, copy_streams, self._skip_layout, checkpoint_stop, style=style
)
elif style is PipelineStyle.MultiProcess:
rank = torch.distributed.get_rank(group)
if rank >= len(self.balance):
self.pipeline = None
else:
self.final_stage = rank == len(self.balance) - 1
self.pipeline = Pipeline(
self.partitions,
None,
None,
self._skip_layout,
checkpoint_stop,
style=style,
group=self.group,
worker_map=self.worker_map,
input_device=self.input_device,
)
del module
if self.pipelined_backward is None:
if get_model_parallel_group().size() > 1:
self.pipelined_backward = True
else:
self.pipelined_backward = False
def __len__(self) -> int: def __len__(self) -> int:
"""Counts the length of the underlying sequential module.""" """Counts the length of the underlying sequential module."""
...@@ -333,10 +514,17 @@ class Pipe(Module): ...@@ -333,10 +514,17 @@ class Pipe(Module):
# Pipe should manage the device of each partition. # Pipe should manage the device of each partition.
# Deny cuda(), cpu(), and to() with device, by TypeError. # Deny cuda(), cpu(), and to() with device, by TypeError.
def cuda(self, device: Optional[Device] = None) -> "Pipe": def cuda(self, device: Optional[Device] = None) -> "Pipe":
if self.devices:
raise MOVING_DENIED raise MOVING_DENIED
if device:
return super().cuda(device=device)
else:
return super().cuda()
def cpu(self) -> "Pipe": def cpu(self) -> "Pipe":
if self.devices:
raise MOVING_DENIED raise MOVING_DENIED
return super().cpu()
def to(self, *args: Any, **kwargs: Any) -> "Pipe": def to(self, *args: Any, **kwargs: Any) -> "Pipe":
# Deny these usages: # Deny these usages:
...@@ -348,6 +536,7 @@ class Pipe(Module): ...@@ -348,6 +536,7 @@ class Pipe(Module):
# #
# - to(dtype[, non_blocking]) # - to(dtype[, non_blocking])
# #
if self.devices:
if "device" in kwargs or "tensor" in kwargs: if "device" in kwargs or "tensor" in kwargs:
raise MOVING_DENIED raise MOVING_DENIED
...@@ -368,6 +557,7 @@ class Pipe(Module): ...@@ -368,6 +557,7 @@ class Pipe(Module):
""" """
if not self._copy_streams: if not self._copy_streams:
assert self.devices is not None
for device in self.devices: for device in self.devices:
self._copy_streams.append([new_stream(device) for _ in range(self.chunks)]) self._copy_streams.append([new_stream(device) for _ in range(self.chunks)])
...@@ -392,16 +582,78 @@ class Pipe(Module): ...@@ -392,16 +582,78 @@ class Pipe(Module):
""" """
microbatch.check(input) microbatch.check(input)
if not self.devices: if not self.group and not self.devices:
# Empty sequential module is not illegal. # Empty sequential module is not illegal.
return input return input
if not self.pipeline:
# Having no pipeline on this rank is not illegal: there are more ranks than partitions
return input
# Divide a mini-batch into micro-batches. # Divide a mini-batch into micro-batches.
batches = microbatch.scatter(input, self.chunks) batches = microbatch.scatter(input, self.chunks)
# Run pipeline parallelism. # Run pipeline parallelism.
self.pipeline.run(batches) self.pipeline.run(batches)
if self.group and not self.final_stage:
# Don't merge micro-batches to avoid unnecessary edges in autograd
# graph
# FIXME(tom) should figure out a proper type here
return batches # type: ignore
else:
# Merge the micro-batches into one mini-batch. # Merge the micro-batches into one mini-batch.
if self.pipelined_backward:
with torch.no_grad():
output = microbatch.gather(batches)
from .phony import get_phony
phony = get_phony(torch.device(torch.cuda.current_device()), requires_grad=True)
output = PipelinedBackwardPass.apply(output, batches, phony, True) # self.retain_graph)
else:
output = microbatch.gather(batches)
return output
def back_helper(self, output: List[microbatch.Batch]) -> None:
if self.final_stage:
raise ValueError("back_helper should only be called on non-final stages")
if self.pipeline:
self.pipeline.back_helper(list(reversed(output)))
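In the multi-process style only the final stage sees a gathered mini-batch; earlier stages get the list of micro-batches back and later pass it to back_helper() so their part of the autograd graph participates in the backward pass. A rough per-rank training step could look like this (criterion and the surrounding loop are illustrative):

def train_step(pipe, batch, target, criterion):
    output = pipe(batch)
    if pipe.final_stage:
        # The final stage holds the gathered mini-batch: compute the loss and
        # kick off backward as usual.
        loss = criterion(output, target)
        loss.backward()
    else:
        # Non-final stages keep the raw micro-batches and only help propagate
        # the gradients that arrive from the next stage.
        pipe.back_helper(output)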
class PipelinedBackwardPass(torch.autograd.Function):
@staticmethod
# type: ignore
def forward(ctx, input: TensorOrTensors, batches, phony, retain_graph) -> TensorOrTensors:
ctx.batches = batches
ctx.retain_graph = retain_graph
return input
@staticmethod
# type: ignore
def backward(ctx, *grads) -> Tuple:
with torch.no_grad():
grad_batches = microbatch.scatter(grads, len(ctx.batches))
for grad, batch in reversed(list(zip(grad_batches, ctx.batches))):
for t in batch:
t.retain_grad()
torch.autograd.backward(batch.tensor_or_tensors, grad_tensors=(*grad,), retain_graph=ctx.retain_graph)
with torch.no_grad():
if ctx.batches[0].atomic:
tensors = tuple(b.tensor.grad for b in ctx.batches)
output: TensorOrTensors = torch.cat(tensors)
else:
rotated = [[t.grad for t in b.tensors] for b in ctx.batches]
output_buf = []
for tensors in zip(*rotated):
output_buf.append(torch.cat(tensors))
output = tuple(output_buf)
del ctx.batches
return (output, None, None, None)
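PipelinedBackwardPass is a pass-through autograd.Function: forward() returns the gathered output unchanged, and backward() re-scatters the incoming gradient into per-micro-batch chunks and replays autograd over each saved micro-batch in reverse order before concatenating the resulting gradients. The pass-through pattern in isolation looks roughly like this (a toy example, not the class above):

import torch

class Passthrough(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, state):
        ctx.state = state          # side-channel state consumed only in backward
        return x

    @staticmethod
    def backward(ctx, grad):
        # Custom work would happen here (e.g. per-chunk backward); the incoming
        # gradient is forwarded unchanged for the pass-through input.
        del ctx.state
        return grad, None

x = torch.randn(3, requires_grad=True)
y = Passthrough.apply(x, ["any", "state"])
y.sum().backward()                 # runs backward() above, then grads reach x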
@@ -36,19 +36,41 @@ class SkipLayout:
# Skip routes indexed by partition number 'j': [[next_j]: [(prev_j, ns, name), ...], ...]
by_partition: List[List[Tuple[int, Namespace, str]]]
# Skip routes indexed by source partition number 'prev_j': [[prev_j]: [(next_j, ns, name), ...], ...]
by_src_partition: List[List[Tuple[int, Namespace, str]]]
def __init__(self, num_partitions: int, skip_routes: Dict[Tuple[Namespace, str], Tuple[int, int]],) -> None:
# The skip routes are already indexed by 'ns, name'.
self.by_ns_name = skip_routes
# Index skip routes by partition number 'j'.
self.by_partition = [[] for _ in range(num_partitions)]
self.by_src_partition = [[] for _ in range(num_partitions)]
for (ns, name), (prev_j, next_j) in skip_routes.items():
self.by_partition[next_j].append((prev_j, ns, name))
self.by_src_partition[prev_j].append((next_j, ns, name))
for p in self.by_partition:
p.sort()
def copy_policy_by_src(self, prev_j: int) -> Iterable[Tuple[int, Namespace, str]]:
"""Generates skip routes for the given destination partition number.
The skip routes are sorted by source partition number in ascending
order.
Yields:
Each tuple of (source partition number, namespace, name).
"""
for next_j, ns, name in self.by_src_partition[prev_j]:
if prev_j == next_j:
# This skip tensor will be popped at the same partition where
# it is stashed. In this case, copy is not required.
continue
yield (next_j, ns, name)
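As a quick sanity check of the two indices (values made up), a route stashed on partition 0 and popped on partition 2 shows up in both views:

from fairscale.nn.pipe.skip import Namespace          # assumed import path
from fairscale.nn.pipe.skip.layout import SkipLayout  # assumed import path

ns = Namespace()
layout = SkipLayout(num_partitions=3, skip_routes={(ns, "x"): (0, 2)})

assert list(layout.copy_policy(2)) == [(0, ns, "x")]         # destination view
assert list(layout.copy_policy_by_src(0)) == [(2, ns, "x")]  # source view

Routes whose stash and pop land on the same partition are filtered out by copy_policy_by_src(), since no inter-partition copy is needed.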
def copy_policy(self, next_j: int) -> Iterable[Tuple[int, Namespace, str]]:
"""Generates skip routes for the given destination partition number.
The skip routes are sorted by source partition number in ascending
...
@@ -25,11 +25,12 @@ one of the most important feature of :mod:`torchpipe.skip`.
The metaphor is inspired by Portal™ from Valve.
"""
from typing import Any, List, Optional, Tuple
import torch
from torch import Tensor
from . import Namespace
from ..copy import Context as CopyContext
from ..copy import Copy
from ..phony import get_phony
@@ -41,9 +42,16 @@ __all__: List[str] = []
class Portal:
"""A portal for a tensor."""
def __init__(self, tensor: Optional[Tensor], tensor_life: int, index: int) -> None:
self.put_tensor(tensor, tensor_life)
self.grad: Optional[Tensor] = None
self.__index = index
self.ns_name: Optional[Tuple[Namespace, str]]
self.pipeline: Any
@property
def index(self) -> int:
return self.__index
def blue(self) -> Tensor:
"""Creates a :class:`PortalBlue` which hides the underlying tensor from
@@ -151,12 +159,17 @@ class Portal:
def put_grad(self, grad: Tensor) -> None:
"""Stores a gradient into this portal."""
if hasattr(self, "pipeline"):
self.pipeline.send_portal_grad(self.ns_name, self.index, grad)
self.grad = grad
def use_grad(self) -> Tensor:
"""Retrieves and removes the underlying gradient. The gradient is
always ephemeral.
"""
if self.grad is None and hasattr(self, "pipeline"):
self.grad = self.pipeline.recv_portal_grad(self.ns_name, self.index)
if self.grad is None:
raise RuntimeError("grad in portal has been removed or never set")
...
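The two hooks above let the stage that pops a skip tensor push its gradient back to the stage that stashed it, with use_grad() lazily pulling it when nothing is stored locally. For a portal used purely locally (no pipeline attribute set), the round trip still behaves as before; a small sketch with the new constructor signature (import paths assumed):

import torch
from fairscale.nn.pipe.skip import Namespace       # assumed import path
from fairscale.nn.pipe.skip.portal import Portal   # assumed import path

t = torch.randn(2, 2)
portal = Portal(t, tensor_life=2, index=0)   # index identifies the micro-batch
portal.ns_name = (Namespace(), "x")

g = torch.ones_like(t)
portal.put_grad(g)       # no self.pipeline set, so the gradient is stored locally
assert portal.use_grad() is g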
@@ -204,7 +204,7 @@ class Skippable(nn.Module):
# Load skip tensors that might be popped.
poppable_tensors = {}
batch = Batch(input, skip_tracker.index)
for ns, name in self.poppable():
try:
poppable_tensors[name] = skip_tracker.load(batch, ns, name)
@@ -237,7 +237,7 @@
raise RuntimeError(f"{comma_names} must be popped but have not")
# Save stashed skip tensors.
batch = Batch(output, skip_tracker.index)
for ns, name in self.stashable():
tensor = stashed_tensors[name]
skip_tracker.save(batch, ns, name, tensor)
...
@@ -61,6 +61,10 @@ class SkipTracker:
) -> None:
raise TypeError("copy is not supported for non-portal skip tensors")
@property
def index(self) -> int:
return 0
class SkipTrackerThroughPotals(SkipTracker):
"""Tracks saved skip tensors through portals. The skip tensors will be
@@ -71,10 +75,15 @@ class SkipTrackerThroughPotals(SkipTracker):
"""
def __init__(self, skip_layout: SkipLayout, index: int) -> None:
super().__init__()
self.skip_layout = skip_layout
self.portals: Dict[Tuple[Namespace, str], Portal] = {}
self.__index = index
@property
def index(self) -> int:
return self.__index
def save(self, batch: Batch, ns: Namespace, name: str, tensor: Optional[Tensor]) -> None:
"""Saves the stashed skip tensor in a portal. The portal is then
@@ -106,7 +115,9 @@ class SkipTrackerThroughPotals(SkipTracker):
else:
tensor_life = 2 # Delete at [6. PortalOrange.forward]
assert batch.index == self.index
portal = Portal(tensor, tensor_life, batch.index)
portal.ns_name = (ns, name)
self.portals[(ns, name)] = portal
else:
...
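The index ties a tracker, its batches, and any portals it creates to a single micro-batch position, which is what lets gradients be routed per micro-batch across processes. A rough sketch of how the pieces line up (constructor arguments as shown in the hunks above, import paths assumed):

import torch
from fairscale.nn.pipe.microbatch import Batch                       # assumed path
from fairscale.nn.pipe.skip.layout import SkipLayout                 # assumed path
from fairscale.nn.pipe.skip.tracker import SkipTrackerThroughPotals  # assumed path

layout = SkipLayout(num_partitions=2, skip_routes={})
tracker = SkipTrackerThroughPotals(layout, index=3)  # tracker for micro-batch #3
batch = Batch(torch.randn(4, 4), tracker.index)      # the batch carries the same index

assert batch.index == tracker.index == 3             # save() asserts this too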
@@ -21,7 +21,7 @@
CPU device.
"""
from contextlib import contextmanager
from typing import Generator, List, Optional, Union, cast
import torch
@@ -72,8 +72,12 @@ def use_device(device: torch.device) -> Generator[None, None, None]:
@contextmanager
def use_stream(stream: Optional[AbstractStream]) -> Generator[None, None, None]:
""":func:`torch.cuda.stream` for either CPU or CUDA stream."""
if not stream:
yield
return
if not is_cuda(stream):
yield
return
@@ -120,7 +124,7 @@ def record_stream(tensor: torch.Tensor, stream: AbstractStream) -> None:
tensor.record_stream(as_cuda(stream))
def is_cuda(stream: Optional[AbstractStream]) -> bool:
"""Returns ``True`` if the given stream is a valid CUDA stream."""
return stream is not CPUStream
...
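With these changes the stream helpers tolerate stages that have no per-partition stream of their own (the multi-process path passes None instead). A quick illustration, assuming the helpers live in the pipe's stream module:

import torch
from fairscale.nn.pipe.stream import is_cuda, new_stream, use_stream  # assumed path

with use_stream(None):
    # Nothing to switch to: the context manager simply yields.
    pass

cpu_stream = new_stream(torch.device("cpu"))
assert not is_cuda(cpu_stream)           # the CPUStream sentinel is not a CUDA stream

if torch.cuda.is_available():
    cuda_stream = new_stream(torch.device("cuda"))
    with use_stream(cuda_stream):        # a real CUDA stream is made current here
        assert is_cuda(cuda_stream)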