Unverified Commit 5d4f50fb authored by Tom Birch, committed by GitHub

Single-process control via PipeRPCWrapper (#156)

Adds support for:
* Reused layers (e.g. for weight sharing)
* Lazily-constructed layers
* Single-process control via PipeRPCWrapper
* PipelineStyle.AsyncSchedule, which lays the foundation for asynchronous pipeline work by introducing an event loop on each rank/worker that processes activations or gradients as they arrive

Also added examples for multi-process Pipe and PipeRPCWrapper; a minimal usage sketch is included below.
parent 543d5693
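As a minimal sketch (not part of the commit itself) of how the new pieces fit together, assuming the distributed process group, RPC, and fairscale model-parallel groups are already initialized on every rank as in the examples below, and that all ranks other than 0 simply call `torch.distributed.rpc.shutdown()`:

import torch
import torch.nn as nn

import fairscale
from fairscale.nn.pipe import LazyModule

# Layers can be declared lazily; each LazyModule is materialized only on the
# rank that ends up owning that slice of the pipeline.
layers = [
    LazyModule(lambda: nn.Linear(10, 10)),
    LazyModule(lambda: nn.ReLU()),
    LazyModule(lambda: nn.Linear(10, 5)),
]

# Rank 0 drives the whole pipeline; PipeRPCWrapper constructs the remote
# stages over RPC and uses the AsyncSchedule style internally.
model = fairscale.nn.PipeRPCWrapper(
    layers,
    balance=[2, 1],
    worker_map={0: "worker0", 1: "worker1"},  # rank -> RPC worker name
)

output = model(torch.randn(20, 10).cuda())

Gradients flow back to rank 0 transparently, so `output` can be fed to a loss and `loss.backward()` called as usual; per-worker state such as optimizers is managed through `foreach_worker`, as shown in the PipeRPCWrapper example in this diff.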
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import argparse
import logging
import math
import os
import time
......@@ -11,14 +12,17 @@ import torch
from torch.distributed import rpc
import torch.multiprocessing as mp
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader
import torchtext
from torchtext.data.utils import get_tokenizer
from fairscale.nn import Pipe
from fairscale.nn.model_parallel import initialize_model_parallel
from fairscale.nn.pipe import pipe
from fairscale.nn.model_parallel.initialize import get_data_parallel_group, get_pipeline_parallel_group
from fairscale.nn.pipe import LazyModule, pipe
from fairscale.optim import GradScaler
from fairscale.optim.oss import OSS
from tests.nn.model_parallel.commons import dist_init, get_worker_map
try:
......@@ -164,13 +168,13 @@ def make_model(args, device, ntokens):
if args.lazy_construction:
layers = [
lambda: EmbeddingLayer(ntokens, ninp, initrange),
lambda: PositionalEncodingLayer(ninp, dropout),
LazyModule(lambda: EmbeddingLayer(ntokens, ninp, initrange)),
LazyModule(lambda: PositionalEncodingLayer(ninp, dropout)),
]
for _ in range(ndecoder):
layers.append(lambda: TransformerDecoderLayer(ninp, nhead, nhid, dropout))
layers.append(LazyModule(lambda: TransformerDecoderLayer(ninp, nhead, nhid, dropout)))
layers.append(lambda: LinearLayer(ninp, ntokens, initrange))
layers.append(LazyModule(lambda: LinearLayer(ninp, ntokens, initrange)))
model = layers
else:
model = TransformerLMSequntial(ntokens, ninp, nhead, nhid, dropout, initrange, ndecoder).to(device)
......@@ -179,7 +183,10 @@ def make_model(args, device, ntokens):
lr = 0.01 # learning rate
def make_adam(model):
return Adam(model.parameters(), lr=lr)
if args.ddp_zero:
return OSS(params=model.parameters(), optim=Adam, group=get_data_parallel_group(), lr=lr)
else:
return Adam(model.parameters(), lr=lr)
optimizer = make_adam
scaler = GradScaler()
......@@ -276,9 +283,17 @@ def train(lm_dataloader, model, criterion, optimizer, vocab_size, args):
num_params = reduce(operator.add, (reduce(operator.mul, x.size()) for x in model.parameters()))
if model.group:
print(f"training model, #prams = {num_params}, group: {model.group.rank()}, sizes {model.group.size()}")
total = torch.Tensor([num_params]).cuda()
torch.distributed.all_reduce(total, group=model.group)
logging.info(
f"training model, #prams = {num_params}, group: {model.group.rank()}, grank:"
f" {torch.distributed.get_rank()}, sizes {model.group.size()}"
)
torch.distributed.barrier()
if model.group.rank() == 0:
logging.info(f"total #prams = {total.item()}")
else:
print(f"training model, #prams = {num_params}")
logging.info(f"training model, #prams = {num_params}")
vocab_size = 10000 # FIXME
total_loss = 0.0
start_time = time.time()
......@@ -287,37 +302,81 @@ def train(lm_dataloader, model, criterion, optimizer, vocab_size, args):
optimizer = optimizer(model)
def get_first_device(model):
if isinstance(model, DDP):
model = model.module
if model.devices:
return model.devices[0]
else:
return torch.cuda.current_device()
def get_last_device(model):
if isinstance(model, DDP):
model = model.module
if model.devices:
return model.devices[-1]
else:
return torch.cuda.current_device()
pipe_group = model.group
if args.ddp_zero:
model = DDP(
model,
device_ids=[torch.cuda.current_device()],
process_group=get_data_parallel_group(),
find_unused_parameters=False,
)
if pipe_group and pipe_group.rank() != 0 and pipe_group.rank() != (pipe_group.size() - 1):
thing = {"input": torch.zeros(args.batch_size)}
class FakeDataset:
def __getitem__(self, index):
return thing
def __len__(self):
return len(lm_dataloader)
lm_dataloader = FakeDataset()
for i, batch in enumerate(lm_dataloader):
bi = batch["input"]
if args.max_batch and i > args.max_batch:
break
optimizer.zero_grad()
output = model(batch["input"].to(get_first_device(model)))
if model.group is None or model.group.rank() == model.group.size() - 1:
try:
if (pipe_group is None or pipe_group.rank() == 0) and not args.ddp_zero:
tmp = batch["input"].to(get_first_device(model))
output = model(tmp)
else:
output = model(batch["input"])
except Exception as e:
raise RuntimeError(f"training failed on {torch.distributed.get_rank()}") from e
if pipe_group is None or pipe_group.rank() == pipe_group.size() - 1:
target = batch["target"].to(get_last_device(model))
output = output.to(target.device)
loss = criterion(output.view(-1, vocab_size), target.view(-1))
if args.ddp_zero:
ddp_group = get_data_parallel_group()
torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.SUM, group=ddp_group)
loss /= ddp_group.size()
loss.backward()
del target
else:
model.back_helper(output)
if args.ddp_zero:
model.module.back_helper(output)
else:
model.back_helper(output)
del output
torch.nn.utils.clip_grad_value_(model.parameters(), 0.05)
optimizer.step()
if model.group is None or model.group.rank() == model.group.size() - 1:
if pipe_group is None or pipe_group.rank() == pipe_group.size() - 1:
total_loss += loss.item()
log_interval = 1
word_counter += batch["ntokens"]
......@@ -406,6 +465,17 @@ def benchmark_language_model(train_data, val_data, test_data, model, criterion,
print("No regression detected")
def generate_balance_weighted(num_devices, num_layers, fraction=0.5):
balance = []
layers_assigned = 0
average_count = num_layers / num_devices
last_layers = int(average_count * fraction)
balance = generate_balance(num_devices - 1, num_layers - last_layers)
balance.append(last_layers)
return balance
def generate_balance(num_devices, num_layers):
balance = []
layers_assigned = 0
......@@ -460,7 +530,7 @@ def bench_single_process(args):
blob = make_model_and_data(args, None, new_data=new_data)
model = blob["model"]
balance = generate_balance(min(num_devices, 8), len(model))
balance = generate_balance(min(num_devices, 4), len(model))
p = pipe.Pipe(
model, balance, chunks=args.chunks, pipelined_backward=args.pipelined_backward, checkpoint=args.checkpoint
)
......@@ -480,16 +550,17 @@ def run_mp_worker(args, available_workers):
blob = make_model_and_data(args, None, new_data=new_data)
model = blob["model"]
balance = generate_balance(min(available_workers, 8), len(model))
balance = generate_balance_weighted(get_pipeline_parallel_group().size(), len(model), 0.8)
p = pipe.Pipe(
model,
balance,
style=Pipe.MultiProcess,
style=Pipe.AsyncSchedule,
chunks=args.chunks,
worker_map=get_worker_map(),
input_device=torch.cuda.current_device(),
pipelined_backward=args.pipelined_backward,
checkpoint=args.checkpoint,
# loss_fn=blob["criterion"],
).cuda()
if args.all_at_once and p.pipeline:
......@@ -537,18 +608,24 @@ best_device_map = {
def bench_mpi(args):
guess_rank = int(os.environ["OMPI_COMM_WORLD_RANK"])
os.environ["UCX_NET_DEVICES"] = best_device_map[guess_rank]
world_size = int(os.environ["OMPI_COMM_WORLD_SIZE"])
local_rank = int(os.environ["OMPI_COMM_WORLD_LOCAL_RANK"])
os.environ["UCX_NET_DEVICES"] = best_device_map[local_rank]
torch.distributed.init_process_group(backend="mpi")
os.environ["MASTER_ADDR"] = args.host
os.environ["MASTER_PORT"] = "10639"
os.environ["MASTER_PORT"] = "10638"
if args.socket_name:
os.environ["GLOO_SOCKET_IFNAME"] = args.socket_name
os.environ["TP_SOCKET_IFNAME"] = args.socket_name
torch.distributed.init_process_group(backend="gloo", rank=guess_rank, world_size=world_size)
os.environ["MASTER_ADDR"] = args.host
os.environ["MASTER_PORT"] = "10639"
init_method = f"tcp://{os.environ['MASTER_ADDR']}:{os.environ['MASTER_PORT']}"
rank = torch.distributed.get_rank()
world_size = torch.distributed.get_world_size()
torch.cuda.set_device(rank % torch.cuda.device_count())
torch.cuda.set_device(local_rank % torch.cuda.device_count())
rpc.init_rpc(
f"Test{rank}",
......@@ -558,7 +635,12 @@ def bench_mpi(args):
rpc_backend_options=rpc.ProcessGroupRpcBackendOptions(rpc_timeout=20, init_method=init_method),
)
initialize_model_parallel(1, world_size)
backends = {"model_parallel_backend": "nccl", "pipeline_backend": "mpi", "ddp_backend": "nccl"}
if args.ddp_zero:
initialize_model_parallel(1, 4, **backends)
else:
initialize_model_parallel(1, world_size, **backends)
init_random_seed(0)
run_mp_worker(args, world_size)
......@@ -579,6 +661,7 @@ parser.add_argument("--all-at-once", action="store_true", default=False, help="d
parser.add_argument("--max-batch", type=int, default=4, help="Max number of batches")
parser.add_argument("--socket-name", type=str, default=None, help="socket ifname for gloo/tp")
parser.add_argument("--num-decoder-layers", type=int, default=10, help="Number of decoder layers in the model")
parser.add_argument("--ddp-zero", action="store_true", default=False, help="enable ddp")
parser.add_argument(
"--lazy-construction", action="store_true", default=False, help="Number of decoder layers in the model"
)
......
......@@ -12,7 +12,10 @@ BUILDDIR = build
help:
@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
.PHONY: help Makefile
setup:
pip install -r requirements.txt
.PHONY: help Makefile setup
# Catch-all target: route all unknown targets to Sphinx using the new
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
......
import os
import torch
import torch.multiprocessing as mp
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import fairscale
from fairscale.nn.model_parallel import initialize_model_parallel
def run(rank, world_size):
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "10638"
torch.distributed.init_process_group(backend="nccl", rank=rank, world_size=world_size)
os.environ["MASTER_PORT"] = "10639"
torch.distributed.rpc.init_rpc(f"worker{rank}", rank=rank, world_size=world_size)
initialize_model_parallel(1, world_size)
model = nn.Sequential(torch.nn.Linear(10, 10), torch.nn.ReLU(), torch.nn.Linear(10, 5))
target = torch.randint(0, 2, size=(20, 1)).squeeze()
data = torch.randn(20, 10)
loss_fn = F.nll_loss
device = torch.device("cuda", rank)
model = fairscale.nn.Pipe(
model,
balance=[2, 1],
style=fairscale.nn.Pipe.MultiProcess,
worker_map={0: "worker0", 1: "worker1"}, # Needed to convert ranks to RPC worker names
input_device=device,
).to(device)
# define optimizer and loss function
optimizer = optim.SGD(model.parameters(), lr=0.001)
# zero the parameter gradients
optimizer.zero_grad()
# outputs and target need to be on the same device
# forward step
outputs = model(data.to(device))
# compute loss
if rank == 1:
loss = loss_fn(outputs.to(device), target.to(device))
# backward + optimize
loss.backward()
optimizer.step()
else:
model.back_helper(outputs)
print(f"Finished Training Step on {rank}")
del model
if __name__ == "__main__":
world_size = 2
mp.spawn(run, args=(world_size,), nprocs=world_size, join=True)
# run with:
# mpirun -np 2 --host localhost:2 -x PYTHONPATH=$PWD python examples/tutorial_pipe_rpc.py
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch_pg
import fairscale
from fairscale.nn.model_parallel import initialize_model_parallel
def register_optimizer(ctx, model):
# Set the optimizer as an attribute on the model so we can access it later
model.optimizer = optim.SGD(model.parameters(), **ctx)
# zero the parameter gradients
model.optimizer.zero_grad()
def run_optimizer(ctx, model):
model.optimizer.step()
def run(rank, world_size):
torch_pg.init_mpi()
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "10638"
torch.distributed.init_process_group(backend="nccl", rank=rank, world_size=world_size)
os.environ["MASTER_PORT"] = "10639"
torch.distributed.rpc.init_rpc(f"worker{rank}", rank=rank, world_size=world_size)
initialize_model_parallel(1, world_size, pipeline_backend="mpi")
if rank == 1:
# For RPC, all ranks other than 0 just need to call rpc.shutdown()
torch.distributed.rpc.shutdown()
return
model = nn.Sequential(torch.nn.Linear(10, 10), torch.nn.ReLU(), torch.nn.Linear(10, 5))
target = torch.randint(0, 2, size=(20, 1)).squeeze()
data = torch.randn(20, 10)
loss_fn = F.nll_loss
device = torch.device("cuda", rank)
model = fairscale.nn.PipeRPCWrapper(
model,
balance=[2, 1],
worker_map={0: "worker0", 1: "worker1"}, # Needed to convert ranks to RPC worker names
input_device=device,
).to(device)
# We can't directly access the model on each worker, so we need to call
# foreach_worker with a callback to setup the optimizer
model.foreach_worker(register_optimizer, {"lr": 0.001}, include_self=True)
outputs = model(data.to(device))
loss = loss_fn(outputs.to(device), target.to(device))
loss.backward()
# Same as earlier, use foreach_worker to step the optimizer on each rank
model.foreach_worker(run_optimizer, include_self=True)
print(f"Finished Training Step on {rank}")
torch.distributed.rpc.shutdown()
del model
if __name__ == "__main__":
rank = int(os.environ["OMPI_COMM_WORLD_RANK"])
world_size = int(os.environ["OMPI_COMM_WORLD_SIZE"])
run(rank, world_size)
......@@ -4,6 +4,6 @@
# LICENSE file in the root directory of this source tree.
from .moe import MOELayer, Top2Gate
from .pipe import Pipe
from .pipe import LazyModule, Pipe, PipeRPCWrapper
__all__ = ["Pipe", "Top2Gate"]
__all__ = ["Pipe", "PipeRPCWrapper", "Top2Gate", "LazyModule"]
......@@ -22,7 +22,7 @@
"""Model and data parallel groups."""
from typing import List
from typing import List, Optional
import torch
......@@ -38,7 +38,14 @@ _PIPELINE_PARALLEL_GROUP = None
_PIPELINE_PARALLEL_RANKS = None
def initialize_model_parallel(model_parallel_size_: int, pipeline_length: int = 1) -> None:
def initialize_model_parallel(
model_parallel_size_: int,
pipeline_length: int = 1,
*,
model_parallel_backend: Optional[str] = None,
pipeline_backend: Optional[str] = None,
ddp_backend: Optional[str] = None
) -> None:
"""
Initialize model data parallel groups.
......@@ -57,8 +64,6 @@ def initialize_model_parallel(model_parallel_size_: int, pipeline_length: int =
with a total of 16 GPUs, rank 0 to 7 belong to the first box and
ranks 8 to 15 belong to the second box.
"""
if torch.distributed.get_rank() == 0:
print("> initializing model parallel with size {}".format(model_parallel_size_))
# Get world size and rank. Ensure some consistencies.
assert torch.distributed.is_initialized()
world_size = torch.distributed.get_world_size()
......@@ -69,6 +74,11 @@ def initialize_model_parallel(model_parallel_size_: int, pipeline_length: int =
data_parallel_size = int(world_size / (model_parallel_size * pipeline_length))
if torch.distributed.get_rank() == 0:
print("> initializing model parallel with size {}".format(model_parallel_size_))
print("> initializing ddp with size {}".format(data_parallel_size))
print("> initializing pipeline with size {}".format(pipeline_length))
groups = torch.LongTensor(range(world_size)).reshape(data_parallel_size, pipeline_length, model_parallel_size)
found = torch.where(groups == rank)
......@@ -80,7 +90,7 @@ def initialize_model_parallel(model_parallel_size_: int, pipeline_length: int =
assert _DATA_PARALLEL_GROUP is None, "data parallel group is already initialized"
for j in range(pipeline_length):
for k in range(model_parallel_size):
group = torch.distributed.new_group(groups[:, j, k].tolist())
group = torch.distributed.new_group(groups[:, j, k].tolist(), backend=ddp_backend)
if j == found[1] and k == found[2]:
_DATA_PARALLEL_GROUP = group
......@@ -89,7 +99,7 @@ def initialize_model_parallel(model_parallel_size_: int, pipeline_length: int =
assert _MODEL_PARALLEL_GROUP is None, "model parallel group is already initialized"
for i in range(data_parallel_size):
for j in range(pipeline_length):
group = torch.distributed.new_group(groups[i, j, :].tolist())
group = torch.distributed.new_group(groups[i, j, :].tolist(), backend=model_parallel_backend)
if i == found[0] and j == found[1]:
_MODEL_PARALLEL_GROUP = group
......@@ -100,7 +110,7 @@ def initialize_model_parallel(model_parallel_size_: int, pipeline_length: int =
for i in range(data_parallel_size):
for k in range(model_parallel_size):
ranks = groups[i, :, k].tolist()
group = torch.distributed.new_group(ranks)
group = torch.distributed.new_group(ranks, backend=pipeline_backend)
if i == found[0] and k == found[2]:
_PIPELINE_PARALLEL_GROUP = group
_PIPELINE_PARALLEL_RANKS = ranks
......
......@@ -39,7 +39,6 @@ def _reduce(ctx: Any, input_: torch.Tensor) -> torch.Tensor:
return input_
# All-reduce.
print(f"doing all_reduce on {torch.distributed.get_rank()}")
torch.distributed.all_reduce(input_, group=group)
return input_
......@@ -93,12 +92,10 @@ class _CopyToModelParallelRegion(torch.autograd.Function):
@staticmethod
def forward(ctx, input_): # type: ignore
print(f"{torch.distributed.get_rank()}: _CopyToModelParallelRegion Forward")
return input_
@staticmethod
def backward(ctx, grad_output): # type: ignore
print(f"{torch.distributed.get_rank()}: _CopyToModelParallelRegion Backward")
return _reduce(None, grad_output)
......@@ -107,12 +104,10 @@ class _ReduceFromModelParallelRegion(torch.autograd.Function):
@staticmethod
def forward(ctx, input_): # type: ignore
print(f"{torch.distributed.get_rank()}: _ReduceFromModelParallelRegion Forward")
return _reduce(ctx, input_)
@staticmethod
def backward(ctx, grad_output): # type: ignore
print(f"{torch.distributed.get_rank()}: _ReduceFromModelParallelRegion Backward")
return grad_output
......@@ -121,12 +116,10 @@ class _ScatterToModelParallelRegion(torch.autograd.Function):
@staticmethod
def forward(ctx, input_): # type: ignore
print(f"{torch.distributed.get_rank()}: _ScatterToModelParallelRegion Forward")
return _split(input_)
@staticmethod
def backward(ctx, grad_output): # type: ignore
print(f"{torch.distributed.get_rank()}: _ScatterToModelParallelRegion Backward")
return _gather(grad_output)
......@@ -135,12 +128,10 @@ class _GatherFromModelParallelRegion(torch.autograd.Function):
@staticmethod
def forward(ctx, input_): # type: ignore
print(f"{torch.distributed.get_rank()}: _GatherFromModelParallelRegion Forward")
return _gather(input_)
@staticmethod
def backward(ctx, grad_output): # type: ignore
print(f"{torch.distributed.get_rank()}: _GatherFromModelParallelRegion Backward")
return _split(grad_output)
......
......@@ -182,11 +182,12 @@ def model_parallel_cuda_manual_seed(seed: int) -> None:
),
flush=True,
)
_CUDA_RNG_STATE_TRACKER.reset()
# Set the default state.
torch.cuda.manual_seed(data_parallel_seed)
# and model parallel state.
_CUDA_RNG_STATE_TRACKER.add(_MODEL_PARALLEL_RNG_TRACKER_NAME, model_parallel_seed)
if torch.cuda.is_available():
_CUDA_RNG_STATE_TRACKER.reset()
# Set the default state.
torch.cuda.manual_seed(data_parallel_seed)
# and model parallel state.
_CUDA_RNG_STATE_TRACKER.add(_MODEL_PARALLEL_RNG_TRACKER_NAME, model_parallel_seed)
class CheckpointFunction(torch.autograd.Function):
......
......@@ -19,6 +19,7 @@
"""A Pipe implementation in PyTorch."""
from .checkpoint import is_checkpointing, is_recomputing
from .pipe import Pipe
from .pipe import LazyModule, Pipe
from .rpc import PipeRPCWrapper
__all__ = ["Pipe", "is_checkpointing", "is_recomputing"]
__all__ = ["Pipe", "is_checkpointing", "is_recomputing", "LazyModule"]
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from collections import OrderedDict
from enum import Enum, auto
from threading import Event
from typing import Dict, Iterable, List, Optional, Tuple
from dataclasses import dataclass
import torch
from torch import Tensor, nn
from torch.distributed import ProcessGroup
from fairscale.nn.model_parallel import get_pipeline_parallel_ranks
from .messages import Transport
from .microbatch import Batch
from .skip.tracker import SkipTrackerThroughPotals
from .types import EVENT_LOOP_QUEUE, PipelineStyle, PipeMessage, Tensors
@dataclass(frozen=True)
class Location:
stage: int
index: int
def __repr__(self) -> str:
return f"{self.stage}@{self.index}"
@dataclass(frozen=True)
class Invocation:
order: int
this: Location
source: Optional[Location]
dest: Optional[Location]
Activations = Dict[int, Dict[int, Dict[int, Batch]]]
Invocations = Dict[int, Invocation]
@dataclass(frozen=True)
class TailBackwardContext:
activations: Activations
invocations: Invocations
count_per_order: Dict[int, int]
expected_gradients: int
class ModuleWrapper:
def __init__(self, module: nn.Sequential, location: Location, invocations: Optional[List[Invocation]] = None):
self.module: nn.Sequential = module
self.location: Location = location
self.invocations: List[Invocation] = invocations or []
def __repr__(self) -> str:
return f"{self.location}:\n" + "\n".join(map(str, self.invocations)) + "\n\t" + str(self.module)
def __len__(self) -> int:
return len(self.module)
def __iter__(self) -> Iterable:
yield from self.module
class AsyncMessageType(Enum):
Activations = auto()
Gradients = auto()
@dataclass(frozen=True)
class AsyncMessageBody:
message_type: AsyncMessageType
microbatch_index: int
source: Location
dest: Location
order: int
class AutogradWithoutActivations(torch.autograd.Function):
"""A helper class to add another edge in the autograd graph which allows us
to delete the potentially large activations and still perform a backward
pass. Returns a phony tensor which is connected to the graph."""
@staticmethod
# type: ignore
def forward(ctx, *x):
return torch.tensor(1.0)
@staticmethod
# type: ignore
def backward(ctx, grad):
assert ctx.grad_from_pipeline is not None
return ctx.grad_from_pipeline
class AsyncRecvOperator(torch.autograd.Function):
"""Receive activations to the previous pipeline stage"""
@staticmethod
# type: ignore
def forward(ctx, phony: Tensor, transport: Transport, message: PipeMessage) -> Tensors:
ctx.transport = transport
ctx.index = message.args.microbatch_index
result = transport.recv_message_tensors(message)
ctx.args = result.args
def maybe_requires_grad(t: Tensor) -> Tensor:
if t.dtype.is_floating_point:
return t.requires_grad_()
return t
return tuple(maybe_requires_grad(r) for r in result.tensors)
@staticmethod
# type: ignore
def backward(ctx, *grad: Tensor,) -> Tuple[Optional[Tensor], ...]:
ranks = get_pipeline_parallel_ranks()
this_rank = torch.distributed.get_rank()
body = AsyncMessageBody(
AsyncMessageType.Gradients, ctx.index, source=ctx.args.dest, dest=ctx.args.source, order=ctx.args.order - 1
)
ctx.transport.send_message(
PipeMessage(
this_rank, ranks[ctx.args.source.stage], queue_name=EVENT_LOOP_QUEUE, args=body, tensors=tuple(grad),
),
sync=True,
)
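# On the final (tail) stage, prepare_tail_backward attaches a TailBackwardContext
# to the last invocation's grad_fn; if it is present here, keep draining incoming
# gradient messages and run their backward passes, since no event loop is running
# on the tail while loss.backward() executes.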
tail_ctx = getattr(ctx, "tail_ctx", None)
if tail_ctx:
expected_gradients = tail_ctx.expected_gradients
while expected_gradients > 0:
message = ctx.transport.recv_message_header(EVENT_LOOP_QUEUE)
args: AsyncMessageBody = message.args
assert args.message_type is AsyncMessageType.Gradients
invocation = tail_ctx.invocations[args.order]
expected_gradients -= tail_ctx.count_per_order[invocation.order]
AsyncEventLoop.perform_backward_for_invocation(ctx.transport, message, tail_ctx.activations, invocation)
return (None, None, None, None, None)
class AsyncEventLoop:
def __init__(
self,
partitions: List[ModuleWrapper],
group: ProcessGroup,
transport: Transport,
training: bool,
checkpoint_stop: int,
):
self.training = training
self.checkpoint_stop = checkpoint_stop
self.transport = transport
self.group = group
self.partitions: List[ModuleWrapper] = partitions
def send_async_message(self, dst_rank: int, result: Batch, invocation: Invocation) -> Batch:
"""Send batch to dst_rank, and use AutogradWithoutActivations to delete
the activations since we no longer need them"""
assert invocation.dest
src_rank = torch.distributed.get_rank()
body = AsyncMessageBody(
AsyncMessageType.Activations, result.index, invocation.this, invocation.dest, invocation.order + 1
)
self.transport.send_message(
PipeMessage(src_rank, dst_rank, queue_name=EVENT_LOOP_QUEUE, args=body, tensors=tuple([*result])),
sync=True,
)
phony = AutogradWithoutActivations.apply(*result)
return Batch(phony, result.index)
def run_invocation(
self,
batch: Batch,
partition: ModuleWrapper,
skip_trackers: List[SkipTrackerThroughPotals],
invocation: Invocation,
) -> Batch:
"""Actually run the forward pass for a given module, and send the result
to the next stage in the pipeline if needed."""
assert self.group
from .pipeline import create_task
task = create_task(
PipelineStyle.AsyncSchedule,
self.checkpoint_stop,
batch.index,
self.group.rank(),
batch,
partition.module,
skip_trackers,
[],
)
result = task.compute()
task.finalize(result)
if invocation.dest and invocation.dest.stage != invocation.this.stage:
ranks = get_pipeline_parallel_ranks()
dst_rank = ranks[invocation.dest.stage]
result = self.send_async_message(dst_rank, result, invocation)
return result
@staticmethod
def perform_backward_for_invocation(
transport: Transport, message: PipeMessage, activations: Activations, invocation: Invocation
) -> None:
"""Perform the backward pass by looking up the appropriate `Batch` and
then calling `backward` on the tensor"""
recvd_grads = transport.recv_message_tensors(message)
batch: Batch = activations[invocation.this.index][invocation.order][message.args.microbatch_index]
# All batches saved in `activations` are generated by AutogradWithoutActivations,
# so we store the gradients in `grad_from_pipeline` so it will be used
# during the backward pass
batch.tensor.grad_fn.grad_from_pipeline = tuple(recvd_grads.tensors) # type: ignore
batch.tensor.backward(retain_graph=True)
def run_invocations_on_batch(
self,
batch: Batch,
invocations: Invocations,
order: int,
skip_trackers: List[SkipTrackerThroughPotals],
activations: Activations,
) -> Tuple[int, int]:
"""Run invocations on the batch until we hit one that receives its input
from a different stage (i.e. another process)"""
invocations_handled = 0
last_order = 0
for invocation in invocations.values():
if invocation.order < order:
continue
pi = invocation.this.index
partition = self.partitions[pi]
if invocation.order == order:
invocations_handled += 1
last_order = invocation.order
activations[pi][invocation.order][batch.index] = self.run_invocation(
batch, partition, skip_trackers, invocation
)
elif invocation.source and invocation.source.stage == self.group.rank():
invocations_handled += 1
last_order = invocation.order
batch = activations[invocation.source.index][invocation.order - 1][batch.index]
activations[pi][invocation.order][batch.index] = self.run_invocation(
batch, partition, skip_trackers, invocation
)
del activations[invocation.source.index][invocation.order - 1][batch.index]
elif invocation.source and invocation.source.stage != self.group.rank():
break
return (invocations_handled, last_order)
def event_loop_head(
self, batches: List[Batch], skip_trackers: List[SkipTrackerThroughPotals], event: Optional[Event]
) -> None:
"""The event loop for the "head", which first performs the forward pass
on any applicable layers for this stage, and then enters the common
`event_loop_inner`"""
invocations, activations = self.get_invocations_and_activations()
expected_invocations = len(invocations) * len(batches)
actual_invocations = 0
count_per_order = dict()
for batch in batches:
inv_count, last_order = self.run_invocations_on_batch(batch, invocations, 0, skip_trackers, activations)
actual_invocations += inv_count
count_per_order[last_order] = inv_count
if actual_invocations < expected_invocations or self.training:
self.event_loop_inner(
expected_invocations,
skip_trackers,
activations,
invocations,
count_per_order,
already_received=actual_invocations,
event=event,
)
def get_batch_from_message(self, message: PipeMessage) -> Batch:
"""Get the tensor(s) wrapped in a `Batch` from a `PipeMessage`, applying
AsyncRecvOperator so we can intercept the backward pass"""
microbatch_index = message.args.microbatch_index
phony = torch.empty(0, device=self.transport.input_device, requires_grad=True)
result = AsyncRecvOperator.apply(phony, self.transport, message)
if len(result) == 1:
batch = Batch(result[0], microbatch_index)
else:
batch = Batch(result, microbatch_index)
return batch
def event_loop_tail(self, batches: List[Batch], skip_trackers: List[SkipTrackerThroughPotals]) -> None:
"""The event loop for the "tail", or final stage which only processes
activations and then returns to the caller so that the loss can be
calculated. This also handles the first/only stage for the special
case of a 1-stage pipeline."""
assert self.group
invocations, activations = self.get_invocations_and_activations()
expected_invocations = len(invocations) * len(batches)
actual_invocations = 0
rank = self.group.rank()
count_per_order = dict()
for batch in batches:
if rank == 0:
order = 0
else:
message = self.transport.recv_message_header(EVENT_LOOP_QUEUE)
args: AsyncMessageBody = message.args
batch = self.get_batch_from_message(message)
order = args.order
inv_count, last_order = self.run_invocations_on_batch(batch, invocations, order, skip_trackers, activations)
actual_invocations += inv_count
count_per_order[last_order] = inv_count
if invocations[last_order].dest is None:
self.prepare_tail_backward(
batch, activations, invocations, count_per_order, len(invocations) - inv_count
)
if actual_invocations < expected_invocations:
expected_gradients = 0 # (len(invocations) - 1) * len(batches)
self.event_loop_inner(
expected_invocations,
skip_trackers,
activations,
invocations,
count_per_order,
already_received=actual_invocations,
ignore_gradients=True,
tail=True,
)
_, last_invocation = invocations.popitem()
for index, batch in activations[len(self.partitions) - 1][last_invocation.order].items():
batches[index] = batch
def get_invocations_and_activations(self) -> Tuple[Invocations, Activations]:
activations: Activations = dict()
invocations: Invocations = OrderedDict()
for pi, partition in enumerate(self.partitions):
activations[pi] = dict()
for invocation in partition.invocations:
activations[pi][invocation.order] = dict()
invocations[invocation.order] = invocation
invocations = OrderedDict(sorted(invocations.items(), key=lambda entry: entry[0]))
return (invocations, activations)
def event_loop(self, num_microbatch: int, skip_trackers: List[SkipTrackerThroughPotals]) -> None:
"""The event loop for the "middle", i.e. neither the head nor the tail"""
assert self.group
invocations, activations = self.get_invocations_and_activations()
expected_invocations = len(invocations) * num_microbatch
self.event_loop_inner(expected_invocations, skip_trackers, activations, invocations, dict())
def event_loop_inner(
self,
expected_invocations: int,
skip_trackers: List[SkipTrackerThroughPotals],
activations: Activations,
invocations: Invocations,
count_per_order: Dict[int, int],
*,
already_received: int = 0,
ignore_gradients: bool = False,
event: Optional[Event] = None,
tail: bool = False,
) -> None:
"""The common event loop shared by all stages. This processses
activations for the forward pass, and if `self.training` is true,
processes gradients for the backward pass."""
num_activations = already_received
if self.training and not ignore_gradients:
num_gradients = 0
else:
num_gradients = expected_invocations
while num_activations < expected_invocations or num_gradients < expected_invocations:
if num_activations == expected_invocations and num_gradients == 0 and event is not None:
# We are ready to do the backward pass, but must wait for
# PipeRPCWrapper to signal that it is safe to proceed, otherwise
# deadlock
event.wait()
message = self.transport.recv_message_header(EVENT_LOOP_QUEUE)
args: AsyncMessageBody = message.args
invocation = invocations[args.order]
# FIXME(tom) for combining pipeline with megatron, I currently don't
# control the order of received activations or gradients, so it is
# possible for a reused ColumnParallelLinear for example to receive
# a different order of activations w.r.t. the sending stage, which
# would result in incorrect values being used for the all_gather
if args.message_type is AsyncMessageType.Activations:
batch = self.get_batch_from_message(message)
inv_count, last_order = self.run_invocations_on_batch(
batch, invocations, args.order, skip_trackers, activations
)
count_per_order[last_order] = inv_count
num_activations += inv_count
if tail and invocations[last_order].dest is None:
self.prepare_tail_backward(
batch, activations, invocations, count_per_order, len(invocations) - inv_count
)
assert num_activations <= expected_invocations
elif args.message_type is AsyncMessageType.Gradients:
num_gradients += count_per_order[invocation.order]
self.perform_backward_for_invocation(self.transport, message, activations, invocation)
@staticmethod
def prepare_tail_backward(
batch: Batch,
activations: Activations,
invocations: Invocations,
count_per_order: Dict[int, int],
expected_gradients: int,
) -> None:
if expected_gradients > 0:
grad_fn = next(b.grad_fn for b in batch if b.requires_grad)
assert grad_fn
grad_fn.tail_ctx = TailBackwardContext(activations, invocations, count_per_order, expected_gradients)
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from abc import ABC
from queue import Empty as QueueEmpty
from queue import Queue
from typing import Dict, List, Optional
from dataclasses import dataclass
import torch
from fairscale.nn.model_parallel import get_pipeline_parallel_group
from fairscale.utils.object import pyobject_to_tensor, tensor_to_pyobject
from .types import MESSAGE_GENERATION_START, InputDevice, PipeMessage, Tensors
MESSAGE_TENSOR_SIZE = 1024
MessageQueues: List[Queue] = [Queue() for _ in range(MESSAGE_GENERATION_START)]
def to_input_device(tensors: Tensors, input_device: InputDevice) -> Tensors:
if input_device is None:
return tensors
else:
return tuple(t.to(input_device) for t in tensors)
def rpc_push_queue(message: PipeMessage) -> None:
globals()["MessageQueues"][message.queue_name].put(message)
@dataclass(frozen=True)
class Transport(ABC):
worker_map: Optional[Dict[int, str]]
input_device: InputDevice
def recv_message(self, queue_name: int, *, nowait: bool = False) -> PipeMessage:
message = self.recv_message_header(queue_name, nowait)
return self.recv_message_tensors(message)
def recv_message_header(self, queue_name: int, nowait: bool = False) -> PipeMessage:
...
def recv_message_tensors(self, message: PipeMessage) -> PipeMessage:
...
def send_message(self, message: PipeMessage, sync: bool = False, skip_header: bool = False) -> None:
...
def get_out_of_order(self, queue_name: int, index: int) -> Tensors:
...
def MakeTransport(use_rpc: bool, worker_map: Optional[Dict[int, str]], input_device: InputDevice) -> Transport:
if use_rpc:
if worker_map is None:
raise ValueError("'RpcTransport' requires 'worker_map' to be set")
return RpcTransport(worker_map, input_device)
else:
return SendRecvTransport(worker_map, input_device)
class RpcTransport(Transport):
def send_message(self, message: PipeMessage, sync: bool = False, skip_header: bool = False) -> None:
message.tensors = tuple(t.cpu() for t in message.tensors)
assert self.worker_map
name = self.worker_map[message.dest]
if sync:
torch.distributed.rpc.rpc_sync(name, rpc_push_queue, args=(message,))
else:
torch.distributed.rpc.rpc_async(name, rpc_push_queue, args=(message,))
def recv_message_header(self, queue_name: int, nowait: bool = False) -> PipeMessage:
queue = MessageQueues[queue_name]
if nowait:
result = queue.get_nowait()
else:
result = queue.get()
result.tensors = to_input_device(result.tensors, self.input_device)
return result
def recv_message_tensors(self, message: PipeMessage) -> PipeMessage:
# Tensors already contained within message
message.tensors = to_input_device(message.tensors, self.input_device)
return message
def get_out_of_order(self, queue_name: int, index: int) -> Tensors:
"""Receive a message with a known microbatch index, and handle out-of-order
messages by placing them back on the queue"""
queue = globals()["MessageQueues"][queue_name]
out_of_order: List[PipeMessage] = []
while True:
message = self.recv_message(queue_name)
got_index = message.args
value = message.tensors
if got_index == index:
for b in out_of_order:
queue.put(b)
return value
else:
out_of_order.append(message)
class SendRecvTransport(Transport):
def send_message(self, message: PipeMessage, sync: bool = False, skip_header: bool = False) -> None:
tensors = message.tensors
message.tensors = tuple()
torch.cuda.current_stream().synchronize()
if not skip_header:
message.tensor_shapes = [t.size() for t in tensors]
message.tensor_dtypes = [t.dtype for t in tensors]
torch.distributed.send(
pyobject_to_tensor(message, MESSAGE_TENSOR_SIZE).cuda(),
message.dest,
tag=message.queue_name,
group=get_pipeline_parallel_group(),
)
for index, t in enumerate(tensors):
if t.device.type == "cpu":
t = t.cuda()
torch.distributed.send(
t.contiguous(), message.dest, tag=message.tag + index, group=get_pipeline_parallel_group()
)
def recv_message_header(self, queue_name: int, nowait: bool = False) -> PipeMessage:
# FIXME(handle nowait)
if nowait:
raise QueueEmpty
tensor = torch.empty(MESSAGE_TENSOR_SIZE, dtype=torch.uint8, device=self.input_device)
torch.cuda.current_stream().synchronize()
torch.distributed.recv(tensor, src=None, tag=queue_name, group=get_pipeline_parallel_group())
torch.cuda.current_stream().synchronize()
return tensor_to_pyobject(tensor)
def recv_message_tensors(self, message: PipeMessage) -> PipeMessage:
torch.cuda.current_stream().synchronize()
message_tensors = []
for index, (shape, dtype) in enumerate(zip(message.tensor_shapes, message.tensor_dtypes)):
t = torch.empty(*shape, dtype=dtype, device=self.input_device)
torch.distributed.recv(t, message.src, tag=message.tag + index, group=get_pipeline_parallel_group())
message_tensors.append(t)
message.tensors = tuple(message_tensors)
torch.cuda.current_stream().synchronize()
return message
def get_out_of_order(self, queue_name: int, index: int) -> Tensors:
"""Receive a message with a known microbatch index, and handle out-of-order
messages by placing them back on the queue"""
message = self.recv_message(queue_name)
assert message.args == index
return message.tensors
......@@ -19,24 +19,29 @@
"""The Pipe interface."""
from collections import OrderedDict
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional, Tuple, Union, cast
import itertools
import threading
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, Union, cast
import warnings
from dataclasses import dataclass, field
import torch
from torch import Tensor, nn
import torch.autograd
import torch.cuda
from fairscale.nn.model_parallel import get_model_parallel_group, get_pipeline_parallel_group
from fairscale.nn.model_parallel import get_model_parallel_world_size, get_pipeline_parallel_group
from . import microbatch
from .async_schedule import Invocation, Location, ModuleWrapper
from .batchnorm import DeferredBatchNorm
from .pipeline import Pipeline, PipelineStyle
from .pipeline import Pipeline
from .skip.layout import SkipLayout, inspect_skip_layout
from .skip.skippable import Skippable, verify_skippables
from .stream import AbstractStream, new_stream
from .types import LazyModule, PipelineStyle
__all__ = ["Pipe"]
__all__ = ["Pipe", "LazyModule"]
Device = Union[torch.device, int, str]
......@@ -45,7 +50,7 @@ Devices = Union[Iterable[Device], List[Device]]
Tensors = Tuple[Tensor, ...]
TensorOrTensors = Union[Tensor, Tensors]
ListOfLazyModules = List[Callable[[], nn.Module]]
ListOfLazyModules = List[LazyModule]
if TYPE_CHECKING:
Module = nn.Module[TensorOrTensors]
......@@ -79,10 +84,10 @@ def verify_list_of_callable(module: Union[nn.Sequential, list]) -> None:
for layer in module:
if isinstance(layer, nn.Module):
pass
elif callable(layer):
elif isinstance(layer, LazyModule):
pass
else:
raise TypeError(f"layer {type(layer)} must be nn.Module or callable to be partitioned")
raise TypeError(f"layer {type(layer)} must be nn.Module or LazyModule to be partitioned")
def verify_module(module: Union[nn.Sequential, ListOfLazyModules]) -> None:
......@@ -124,8 +129,14 @@ class BalanceError(ValueError):
pass
def check_balance(module: Any, balance: Iterable[int]) -> None:
if len(module) != sum(balance):
def check_balance(module: Any, balance: Iterable[int], filter_unique: bool = False) -> None:
if filter_unique:
module_len = len(set(map(id, module)))
else:
module_len = len(module)
if module_len != sum(balance):
raise BalanceError(
f"module and sum of balance have different length (module: {len(module)}, sum of balance: {sum(balance)})"
)
......@@ -134,16 +145,27 @@ def check_balance(module: Any, balance: Iterable[int]) -> None:
raise BalanceError(f"all balance numbers must be positive integer (balance: {balance})")
@dataclass
class PartitionInfo:
location: Location
modules: "OrderedDict[str, nn.Module]"
invocations: List[Invocation] = field(default_factory=list)
def __len__(self) -> int:
return len(self.modules)
def instantiate_partition(
module: Union[nn.Sequential, ListOfLazyModules], balance: Iterable[int], group: torch.distributed.ProcessGroup
) -> nn.Sequential:
module: Union[nn.Sequential, ListOfLazyModules],
balance: Iterable[int],
group: torch.distributed.ProcessGroup,
style: PipelineStyle,
) -> List[ModuleWrapper]:
balance = list(balance)
check_balance(module, balance)
check_balance(module, balance, True)
layers: NamedModules = OrderedDict()
j = 0
def maybe_realize(layer: Any) -> nn.Module:
if isinstance(layer, nn.Module):
return layer
......@@ -156,7 +178,85 @@ def instantiate_partition(
if isinstance(module, nn.Sequential):
yield from module.named_children()
else:
yield from enumerate(module)
yield from ((str(k), v) for k, v in enumerate(module))
if style == PipelineStyle.AsyncSchedule:
module_ids = list(map(id, module))
index_of_first_use = [module_ids.index(x) for x in module_ids]
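# index_of_first_use[i] == i for a module's first occurrence; a smaller value
# marks a reused (weight-shared) layer, which is assigned the location of its
# first use below rather than being partitioned again.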
locations: List[Location] = []
module_iter = enumerate(iterate_module(module))
partitions: List[List[PartitionInfo]] = []
for bi, b in enumerate(balance):
modules_for_rank: List[PartitionInfo] = []
current_module: OrderedDict[str, nn.Module] = OrderedDict()
def current_location() -> Location:
return Location(bi, len(modules_for_rank))
def append_module(mod: "OrderedDict[str, nn.Module]") -> None:
modules_for_rank.append(PartitionInfo(current_location(), mod))
while sum(map(len, modules_for_rank)) + len(current_module) < b:
module_index, (name, layer) = next(module_iter)
if index_of_first_use[module_index] != module_index:
# Subsequent reuse of a module
locations.append(locations[index_of_first_use[module_index]])
continue
is_reused = index_of_first_use.count(index_of_first_use[module_index]) > 1
if is_reused and len(current_module) > 0:
append_module(current_module)
current_module = OrderedDict()
current_module[str(name)] = layer
locations.append(current_location())
if is_reused:
append_module(current_module)
current_module = OrderedDict()
if len(current_module) > 0:
append_module(current_module)
partitions.append(modules_for_rank)
filtered_locations: List[Optional[Location]] = [loc for loc, _ in itertools.groupby(locations)]
filtered_locations.append(None)
for i in range(len(filtered_locations) - 1):
loc = filtered_locations[i]
assert loc
if i == 0:
inv = Invocation(i, loc, None, filtered_locations[i + 1])
else:
inv = Invocation(i, loc, filtered_locations[i - 1], filtered_locations[i + 1])
partitions[loc.stage][loc.index].invocations.append(inv)
invocations = enumerate(iterate_module(module))
partition = partitions[group.rank()]
result: List[ModuleWrapper] = []
for partition_info in partition:
wrapper = ModuleWrapper(
nn.Sequential(OrderedDict((k, maybe_realize(m)) for k, m in partition_info.modules.items())),
partition_info.location,
partition_info.invocations,
)
if not isinstance(module, nn.Sequential):
for layer in wrapper.module:
if isinstance(layer, Skippable):
raise ValueError("Can't use Skippable layers with multi-process pipe and lazy construction")
result.append(wrapper)
return result
j = 0
for name, layer in iterate_module(module):
layers[name] = layer
......@@ -170,8 +270,7 @@ def instantiate_partition(
if isinstance(layer, Skippable):
raise ValueError("Can't use Skippable layers with multi-process pipe and lazy construction")
partition = nn.Sequential(*layers.values())
return partition
return [ModuleWrapper(nn.Sequential(layers), Location(j, 0))]
# Prepare for the next partition.
layers.clear()
......@@ -297,7 +396,7 @@ class Pipe(Module):
backward pass (instead of once for the whole batch). This works
around a potential deadlock in pytorch when using tensor parallelism
at the same time. Defaults to `True` if
`get_model_parallel_group.size() > 1`
`get_model_parallel_world_size() > 1`
(default: `None`)
retain_graph (bool):
The value passed to `torch.autograd.backwards(..., retain_graph=<value>)
......@@ -315,6 +414,7 @@ class Pipe(Module):
SingleProcess: PipelineStyle = PipelineStyle.SingleProcess
MultiProcess: PipelineStyle = PipelineStyle.MultiProcess
AsyncSchedule: PipelineStyle = PipelineStyle.AsyncSchedule
#: The number of layers in each partition.
balance: List[int] = []
......@@ -359,6 +459,7 @@ class Pipe(Module):
deferred_batch_norm: bool = False,
pipelined_backward: bool = None,
retain_graph: bool = False,
loss_fn: Optional[nn.Module] = None,
) -> None:
super().__init__()
......@@ -384,6 +485,17 @@ class Pipe(Module):
self.pipelined_backward = pipelined_backward
self.retain_graph = retain_graph
self.pipeline: Optional[Pipeline]
self.loss_fn = loss_fn
self.lock = threading.Lock()
self.group = group
self.worker_map = worker_map
self.input_device = input_device
self._copy_streams: List[List[AbstractStream]] = []
# The micro-batch index where the checkpointing stops.
checkpoint_stop = {"always": self.chunks, "except_last": self.chunks - 1, "never": 0}[self.checkpoint]
if style is PipelineStyle.SingleProcess:
module = cast(nn.Sequential, module)
......@@ -407,29 +519,42 @@ class Pipe(Module):
self._skip_layout = inspect_skip_layout(self.partitions)
elif style is PipelineStyle.MultiProcess:
if group is None:
group = get_pipeline_parallel_group()
# Separate CUDA streams for copy.
copy_streams = self._ensure_copy_streams()
if self.pipelined_backward is None:
self.pipelined_backward = False
self.pipeline = Pipeline(
self.partitions, self.devices, copy_streams, self._skip_layout, checkpoint_stop, style=style,
)
elif style in [PipelineStyle.MultiProcess, PipelineStyle.AsyncSchedule]:
if self.group is None:
self.group = get_pipeline_parallel_group()
assert self.group
if devices is not None:
raise ValueError("'devices' argument only applies to 'PipelineStyle.SingleProcess'")
self.balance = list(balance)
if group.size() < len(self.balance):
if self.group.size() < len(self.balance):
raise IndexError(
f"too few ranks to hold given partitions (ranks: {group.size()}, partitions: {len(self.balance)})"
f"too few ranks to hold given partitions (ranks: {self.group.size()}, partitions:"
f" {len(self.balance)})"
)
try:
rank = torch.distributed.get_rank(group)
rank = self.group.rank()
if rank >= len(self.balance):
warnings.warn("More ranks than partitions, some ranks unused")
self.partitions = cast(List[nn.Sequential], nn.ModuleList([nn.Sequential()]))
self.mp_partitions: List[ModuleWrapper] = []
else:
partition = instantiate_partition(module, balance, group)
self.mp_partitions = instantiate_partition(module, balance, self.group, style)
if deferred_batch_norm:
partition = DeferredBatchNorm.convert_deferred_batch_norm(partition, chunks)
self.partitions = cast(List[nn.Sequential], nn.ModuleList([partition]))
for part in self.mp_partitions:
part.module = DeferredBatchNorm.convert_deferred_batch_norm(part.module, chunks)
for name, part in enumerate(self.mp_partitions):
self.add_module(str(name), part.module)
self.devices = None
if isinstance(module, nn.Sequential):
local_partitions, _, _ = split_module(module, balance, None)
......@@ -440,31 +565,16 @@ class Pipe(Module):
except BalanceError as exc:
raise ValueError(recommend_auto_balance(str(exc)))
self.group = group
self.worker_map = worker_map
self.input_device = input_device
self._copy_streams: List[List[AbstractStream]] = []
# The micro-batch index where the checkpointing stops.
checkpoint_stop = {"always": self.chunks, "except_last": self.chunks - 1, "never": 0}[self.checkpoint]
if style is PipelineStyle.SingleProcess:
# Separate CUDA streams for copy.
copy_streams = self._ensure_copy_streams()
if self.pipelined_backward is None:
self.pipelined_backward = False
self.pipeline = Pipeline(
self.partitions, self.devices, copy_streams, self._skip_layout, checkpoint_stop, style=style
)
elif style is PipelineStyle.MultiProcess:
rank = torch.distributed.get_rank(group)
rank = self.group.rank()
if rank >= len(self.balance):
self.pipeline = None
self.final_stage = False
else:
self.final_stage = rank == len(self.balance) - 1
assert loss_fn is None or self.final_stage
self.pipeline = Pipeline(
self.partitions,
cast(List[nn.Sequential], self.mp_partitions),
None,
None,
self._skip_layout,
......@@ -473,27 +583,39 @@ class Pipe(Module):
group=self.group,
worker_map=self.worker_map,
input_device=self.input_device,
final_stage=self.final_stage,
)
del module
if self.pipelined_backward is None:
if get_model_parallel_group().size() > 1:
if get_model_parallel_world_size() > 1:
self.pipelined_backward = True
else:
self.pipelined_backward = False
def __len__(self) -> int:
"""Counts the length of the underlying sequential module."""
return sum(len(p) for p in self.partitions)
if hasattr(self, "partitions"):
return sum(len(p) for p in self.partitions)
else:
return sum(len(p) for p in self.mp_partitions)
def __getitem__(self, index: int) -> nn.Module:
"""Gets a layer in the underlying sequential module."""
partitions = self.partitions
partitions: List[Any]
if hasattr(self, "partitions"):
partitions = self.partitions
else:
partitions = self.mp_partitions
if index < 0:
partitions = partitions[::-1]
for partition in partitions:
try:
return partition[index]
if isinstance(partition, ModuleWrapper):
return partition.module[index]
else:
return partition[index]
except IndexError:
pass
......@@ -508,8 +630,12 @@ class Pipe(Module):
def __iter__(self) -> Iterable[nn.Module]:
"""Iterates over children of the underlying sequential module."""
for partition in self.partitions:
yield from partition
if hasattr(self, "partitions"):
for partition in self.partitions:
yield from partition
else:
for mp_partition in self.mp_partitions:
yield from mp_partition.module
# Pipe should manage the device of each partition.
# Deny cuda(), cpu(), and to() with device, by TypeError.
......@@ -527,7 +653,7 @@ class Pipe(Module):
return super().cpu()
def to(self, *args: Any, **kwargs: Any) -> "Pipe":
""" Restrict .to() options.
"""Restrict .to() options.
Deny these usages:
- to(device[, dtype, non_blocking])
......@@ -563,7 +689,7 @@ class Pipe(Module):
return self._copy_streams
def forward(self, input: TensorOrTensors) -> TensorOrTensors: # type: ignore
def forward(self, input: TensorOrTensors, *, event=None) -> TensorOrTensors: # type: ignore
""":class:`Pipe` is a fairly transparent module wrapper. It doesn't
modify the input and output signature of the underlying module. But
there's type restriction. Input and output have to be a
......@@ -594,25 +720,26 @@ class Pipe(Module):
batches = microbatch.scatter(input, self.chunks)
# Run pipeline parallelism.
self.pipeline.run(batches)
if self.group and not self.final_stage:
# Don't merge micro-batches to avoid unnecessary edges in autograd
# graph
# FIXME(tom) should figure out a proper type here
return batches # type: ignore
else:
# Merge the micro-batches into one mini-batch.
if self.pipelined_backward:
with torch.no_grad():
output = microbatch.gather(batches)
with self.lock:
self.pipeline.run(self.training, batches, event)
if self.group and not self.final_stage:
# Don't merge micro-batches to avoid unnecessary edges in autograd
# graph
# FIXME(tom) should figure out a proper type here
return batches # type: ignore
else:
# Merge the micro-batches into one mini-batch.
if self.pipelined_backward:
with torch.no_grad():
output = microbatch.gather(batches)
from .phony import get_phony
from .phony import get_phony
phony = get_phony(torch.device(torch.cuda.current_device()), requires_grad=True)
output = PipelinedBackwardPass.apply(output, batches, phony, True) # self.retain_graph)
else:
output = microbatch.gather(batches)
phony = get_phony(torch.device(torch.cuda.current_device()), requires_grad=True)
output = PipelinedBackwardPass.apply(output, batches, phony, True) # self.retain_graph)
else:
output = microbatch.gather(batches)
return output
......
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from threading import Event, Lock, Thread
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast
import torch
from torch import nn
from torch.distributed import ProcessGroup, rpc
from torch.distributed.distributed_c10d import _get_global_rank
from fairscale.nn.model_parallel.initialize import get_pipeline_parallel_group
from . import Pipe
from .types import EVENT_LOOP_QUEUE, PipeMessage, TensorOrTensors
DEFAULT_MAX_SOURCE_POSITIONS = 1024
DEFAULT_MAX_TARGET_POSITIONS = 1024
PipeModel: Pipe
PipeResult: TensorOrTensors
SizeOrSizes = Union[torch.Size, List[torch.Size]]
DtypeOrDtypes = Union[torch.dtype, List[torch.dtype]]
def set_device_based_on_group(group: ProcessGroup) -> None:
# torch.cuda.set_device(group.rank() % torch.cuda.device_count())
torch.cuda.set_device(torch.distributed.get_rank() % torch.cuda.device_count())
def get_shapes(tensor: TensorOrTensors) -> SizeOrSizes:
if isinstance(tensor, torch.Tensor):
return tensor.shape
else:
return [t.shape for t in tensor]
def get_dtype(tensor: TensorOrTensors) -> DtypeOrDtypes:
if isinstance(tensor, torch.Tensor):
return tensor.dtype
else:
return [t.dtype for t in tensor]
def get_global_ranks_from_group(group: ProcessGroup) -> List[int]:
return [_get_global_rank(group, r) for r in range(group.size())]
class PipeBackRedirect(torch.autograd.Function):
@staticmethod
# type: ignore
def forward(ctx, inputs, dest, event, message, transport, futures):
ctx.dest = dest
ctx.event = event
ctx.message = message
ctx.transport = transport
ctx.futures = futures
return inputs
@staticmethod
# type: ignore
def backward(ctx, *grad):
ctx.message.tensors = tuple(grad)
ctx.transport.send_message(ctx.message, sync=False, skip_header=True)
ctx.event.set()
# torch.futures.wait_all(ctx.futures)
return (None, None, None, None, None, None)
def callback_with_model(callback: Callable[[Any, Pipe], None], ctx: Any) -> None:
try:
group = get_pipeline_parallel_group() # FIXME(tom) handle dynamic group
set_device_based_on_group(group)
with PipeModel.lock:
callback(ctx, PipeModel)
except Exception as e:
print(f"callback_with_model got {e}")
class PipeRPCWrapper(nn.Module):
"""A wrapper for Pipe to control the entire pipeline from a single process.
The typical use case has rank 0 construct `PipeRPCWrapper` and run the
training loop as normal, while all other ranks call
`torch.distributed.rpc.shutdown()`.
To run code on each worker, e.g. to run the optimizer, use `foreach_worker`.
"""
def __init__(self, *args: Any, **kwargs: Any):
super().__init__()
self.group = cast(ProcessGroup, kwargs.get("group")) or get_pipeline_parallel_group()
assert self.group.rank() == 0
self.lock = Lock()
if True:
assert (
self.group == get_pipeline_parallel_group()
), "Can't pickle groups, so group must be `get_pipeline_parallel_group()`"
kwargs["group"] = None
else:
kwargs["group"] = self.group
kwargs["style"] = Pipe.AsyncSchedule
kwargs["input_device"] = torch.device("cuda", torch.cuda.current_device())
self.model = Pipe(*args, **kwargs)
self.worker_map = kwargs["worker_map"]
self._foreach_worker(self._register_remote_model, args=(args, kwargs))
self.model.cuda()
def _get_rpc_name(self, rank: int) -> str:
return self.worker_map[_get_global_rank(self.group, rank)]
def _foreach_worker(self, callback: Callable, args: Any = None) -> None:
futures = [rpc.rpc_async(self._get_rpc_name(rank), callback, args=args) for rank in range(1, self.group.size())]
futures = [f.wait() for f in futures]
def foreach_worker(
self, callback: Callable[[Any, Pipe], None], ctx: Any = None, *, include_self: bool = False
) -> None:
"""Call `callback` on each worker with the `ctx` and model local to that
worker. e.g.
def register_optimizer(ctx, model):
args, kwargs = ctx
model.optimizer = torch.optim.SGD(model.parameters(), *args, **kwargs)
pipe_model = PipeRPCWrapper( ... )
pipe_model.foreach_worker(
register_optimizer,
([], {"lr" : 0.01, "momentum" : 0.9})
)
"""
self._foreach_worker(callback_with_model, args=(callback, ctx))
if include_self:
with self.model.lock:
callback(ctx, self.model)
def forward(self, tensor: TensorOrTensors) -> TensorOrTensors: # type: ignore
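        # Rank-0 control flow: ask every other rank (via RPC) to run its stage,
        # then either run the whole model locally if this rank also hosts the
        # final stage, or run the first stage on a background thread, receive
        # the final-stage activations over the pipeline transport, and route the
        # result through PipeBackRedirect so backward() sends the gradients back
        # to the final stage.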
shape = get_shapes(tensor)
dtype = get_dtype(tensor)
if isinstance(tensor, torch.Tensor):
num_tensors = 1
else:
num_tensors = len(tensor)
futures = [
rpc.rpc_async(self._get_rpc_name(rank), self._model_forward, args=(self.model.training, shape, dtype))
for rank in range(1, self.group.size())
]
if self.model.final_stage:
return self.model(tensor)
else:
event = Event()
t = Thread(target=self._model_forward_first_stage, args=(tensor, event))
t.start()
shape, dtype = futures.pop().wait()
dest_rank = self.group.size() - 1
dest = self._get_rpc_name(dest_rank)
dest_global_rank = _get_global_rank(self.group, dest_rank)
src_global_rank = torch.distributed.get_rank()
queue = EVENT_LOOP_QUEUE
activations = PipeMessage(dest_global_rank, src_global_rank, queue_name=queue, tensor_count=num_tensors)
grads = PipeMessage(src_global_rank, dest_global_rank, queue_name=queue, tensor_count=num_tensors)
back_fut = rpc.rpc_async(
dest, self._send_result_and_do_backwards, args=(self.model.training, activations, grads)
)
futures.append(back_fut)
result = self._recv_result(self.model, shape, dtype, activations)
if isinstance(result, torch.Tensor):
result.requires_grad_()
else:
for r in result:
r.requires_grad_()
assert self.model.pipeline
return PipeBackRedirect.apply(
result, dest_global_rank, event, grads, self.model.pipeline.transport, futures
)
@property
def final_stage(self) -> bool:
return self.model.final_stage
@staticmethod
def _recv_result(model: Pipe, shapes: SizeOrSizes, dtypes: DtypeOrDtypes, message: PipeMessage) -> TensorOrTensors:
group = get_pipeline_parallel_group()
set_device_based_on_group(group)
assert model.pipeline
transport = model.pipeline.transport
if isinstance(shapes, torch.Size):
message.tensor_shapes = [cast(torch.Size, shapes)]
message.tensor_dtypes = [cast(torch.dtype, dtypes)]
message = transport.recv_message_tensors(message)
return message.tensors[0]
else:
message.tensor_shapes = cast(List[torch.Size], shapes)
message.tensor_dtypes = cast(List[torch.dtype], dtypes)
message = transport.recv_message_tensors(message)
return message.tensors
@staticmethod
def _send_result_and_do_backwards(training: bool, message: PipeMessage, grads_message: PipeMessage) -> None:
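        """Runs on the final pipeline stage: send that stage's output back to
        rank 0 over the transport and, when training, wait for the matching
        gradients and drive the local backward pass."""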
group = get_pipeline_parallel_group()
set_device_based_on_group(group)
result = PipeResult
model = PipeModel
if isinstance(result, torch.Tensor):
result = tuple([result])
message.tensors = tuple(result)
assert model.pipeline
transport = model.pipeline.transport
transport.send_message(message, sync=False, skip_header=True)
if training:
grads_message.tensor_shapes = [r.shape for r in result]
grads_message.tensor_dtypes = [r.dtype for r in result]
grads_message = transport.recv_message_tensors(grads_message)
with model.lock:
torch.autograd.backward(result, grads_message.tensors, retain_graph=True)
@staticmethod
def _register_remote_model(args: List[Any], kwargs: Dict[str, Any]) -> None:
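        """Runs on each non-zero rank: build the local `Pipe` with the same
        arguments rank 0 used and store it in the module-level `PipeModel`
        global so later RPC calls can find it."""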
group = get_pipeline_parallel_group() # FIXME(tom) handle dynamic group
set_device_based_on_group(group)
kwargs["group"] = group
kwargs["input_device"] = torch.device("cuda", torch.cuda.current_device())
model = Pipe(*args, **kwargs)
model.cuda()
global PipeModel
PipeModel = model
@staticmethod
def _model_forward(
training: bool, shape: torch.Size, dtype: torch.dtype
) -> Optional[Tuple[SizeOrSizes, DtypeOrDtypes]]:
try:
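            # The tensor built below is only a shape/dtype placeholder: ranks
            # other than the first stage receive their real activations over the
            # pipeline transport.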
if isinstance(shape, torch.Size):
tensor = torch.empty(shape, dtype=dtype)
else:
tensor = tuple([torch.empty(s, dtype=d) for s, d in zip(shape, dtype)])
model = PipeModel
assert model.group
set_device_based_on_group(model.group)
model.train(training)
result = model(tensor)
if model.final_stage:
global PipeResult
PipeResult = result
return (get_shapes(result), get_dtype(result))
return None
except Exception as e:
print(f"_model_forward got {e}")
raise e
def _model_forward_first_stage(self, tensor: TensorOrTensors, event: Event) -> None:
try:
assert self.model.group
set_device_based_on_group(self.model.group)
self.model(tensor, event=event)
except Exception as e:
print(f"_model_forward got {e}")
raise e
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from enum import Enum, auto
from typing import Any, Callable, List, Optional, Tuple, Union
from dataclasses import dataclass
import torch
from torch import Tensor, nn
ACTIVATIONS_GRADS_QUEUE = 0
SKIP_TENSOR_QUEUE = 1
PORTAL_QUEUE = 2
EVENT_LOOP_QUEUE = 3
MESSAGE_GENERATION_START = 4
MessageGeneration = MESSAGE_GENERATION_START
Tensors = Tuple[Tensor, ...]
TensorOrTensors = Union[Tensor, Tensors]
InputDevice = Union[None, int, str, torch.device]
Schedule = List[Tuple[int, int]]
class LazyModule:
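    """Wraps a zero-argument factory so a pipeline layer can be constructed
    lazily (e.g. only once it is known which rank owns it) rather than being
    materialized on every process up front."""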
def __init__(self, function: Callable[[], nn.Module]):
self.function = function
def __call__(self) -> nn.Module:
return self.function()
class PipelineStyle(Enum):
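    """Execution mode for Pipe: entirely within a single process, across
    multiple processes, or across multiple processes driven by the asynchronous
    per-rank schedule."""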
SingleProcess = auto()
MultiProcess = auto()
AsyncSchedule = auto()
@dataclass(init=False)
class PipeMessage:
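    """Envelope for tensors exchanged between pipeline ranks over a named queue.

    `tag` is taken from a global generation counter (`MessageGeneration`), which
    advances by the number of tensors carried (or expected, via `tensor_count`),
    so successive messages and the tensors within them get distinct, predictable
    tags.
    """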
src: int
dest: int
queue_name: int
args: Any
tensors: Tensors
tensor_shapes: List[torch.Size]
tensor_dtypes: List[torch.dtype]
tag: int = 0
def __init__(
self,
src: int,
dest: int,
queue_name: int,
args: Any = None,
tensors: Optional[Tensors] = None,
tensor_count: int = 0,
):
self.src = src
self.dest = dest
self.queue_name = queue_name
self.args = args
self.tensors = tensors or tuple()
self.tensor_shapes = []
self.tensor_dtypes = []
global MessageGeneration
self.tag = MessageGeneration
if tensors is None:
MessageGeneration += tensor_count
else:
MessageGeneration += len(self.tensors)
......@@ -422,7 +422,7 @@ class OSS(Optimizer):
if group is dist.group.WORLD:
return rank
else:
global_rank = dist.distributed_c10d._get_global_rank(group, rank) # type: ignore
global_rank = dist.distributed_c10d._get_global_rank(group, rank)
return global_rank
def _broadcast_params(self, buffers: List[torch.Tensor], per_rank_params: List[List[Parameter]]) -> None:
......
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import pickle
from typing import Any
import torch
def pyobject_to_tensor(obj: Any, fixed_buffer_size: int = 0) -> torch.Tensor:
pickled = pickle.dumps(obj)
result: torch.Tensor = torch.ByteTensor(bytearray(pickled))
if fixed_buffer_size:
delta = fixed_buffer_size - len(result)
if delta < 0:
raise ValueError(
f"message too big to send, increase `fixed_buffer_size`? - {len(result)} > {fixed_buffer_size}"
)
elif delta > 0:
result = torch.cat((result, torch.zeros(delta, dtype=torch.uint8)))
return result
def tensor_to_pyobject(tensor: torch.Tensor) -> Any:
nparray = tensor.cpu().numpy()
return pickle.loads(nparray.tobytes())
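# A quick round-trip sketch (illustrative only): pad the pickled payload to a
# fixed size so it fits a preallocated receive buffer, then decode it again.
#
#     payload = {"queue": 3, "shapes": [(4, 4)]}
#     buf = pyobject_to_tensor(payload, fixed_buffer_size=256)
#     assert tensor_to_pyobject(buf) == payload  # pickle ignores the zero padding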
......@@ -28,4 +28,4 @@ use_parentheses = true
skip_glob = ["build/*", "stubs/*"]
# Don't split "import" and "from".
force_sort_within_sections = true
known_third_party = ["benchmark_dataset", "dataclasses", "numpy", "packaging", "pytest", "recommonmark", "setuptools", "torch", "torchtext", "torchvision"]
known_third_party = ["benchmark_dataset", "dataclasses", "numpy", "packaging", "pytest", "recommonmark", "setuptools", "torch", "torch_pg", "torchtext", "torchvision"]
#!/bin/bash
set -e
rpc_tests=$(pytest --collect-only | grep 'Function.*rpc' | cut -d' ' -f 6 | tr -d '>')

for WORKERS in {1..6}; do
    mpirun -n $WORKERS -mca orte_base_help_aggregate 0 python -m pytest tests/nn/pipe_process -k "not rpc"
    for test_name in $rpc_tests; do
        mpirun -n $WORKERS -mca orte_base_help_aggregate 0 python -m pytest tests/nn/pipe_process -k $test_name
    done
done
......@@ -35,7 +35,7 @@ from . import version
#END
class dtype:
is_floating_point: bool
is_floating_point: builtins.bool
class layout: ...
......@@ -277,7 +277,7 @@ class Tensor:
def atan2(self, other: Tensor) -> Tensor: ...
def atan2_(self, other: Tensor) -> Tensor: ...
def atan_(self) -> Tensor: ...
def backward(self, gradient: Optional[Tensor]=None, keep_graph: _bool=False, create_graph: _bool=False) -> None: ...
def backward(self, gradient: Optional[Tensor]=None, retain_graph: _bool=False, create_graph: _bool=False) -> None: ...
def baddbmm(self, batch1: Tensor, batch2: Tensor, *, beta: Number=1, alpha: Number=1) -> Tensor: ...
def baddbmm_(self, batch1: Tensor, batch2: Tensor, *, beta: Number=1, alpha: Number=1) -> Tensor: ...
@overload
......