Unverified commit 63f7796a authored by Tom Birch, committed by GitHub

Multi-process pipe (#90)

Adds support for distributing pipeline stages across multiple processes (and therefore multiple machines):
* Adds a style argument to the Pipe constructor, defaulting to PipelineStyle.SingleProcess but also supporting PipelineStyle.MultiProcess (see the usage sketch below)
* Adds support for lazy construction of modules (see lazy_construction for an example)
* Adds two implementations of inter-process communication: one based on rpc with globally visible queues, one based on send/recv
* Copies all the relevant tests from tests/pipe to tests/pipe_process and modifies them to exercise PipelineStyle.MultiProcess
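
A minimal usage sketch of the multi-process style, distilled from the test code in this diff. The model, balance, worker names, and tensor shapes are illustrative placeholders, and it assumes torch.distributed and the RPC framework have already been initialized on every rank (as dist_init and get_worker_map do in the tests below):

```python
import torch
from torch import nn
from fairscale.nn.pipe import Pipe

# Two pipeline stages: the first two layers on rank 0, the last layer on rank 1.
model = nn.Sequential(nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 10))
pipe = Pipe(
    model,
    balance=[2, 1],
    style=Pipe.MultiProcess,                # the default remains single-process
    worker_map={0: "Test0", 1: "Test1"},    # rank -> RPC worker name (cf. get_worker_map)
    input_device=torch.cuda.current_device(),
    chunks=4,
).cuda()

output = pipe(torch.rand(8, 10).cuda())
# Only the final stage holds the real output and runs the usual backward;
# earlier stages drive their part of the backward pass via back_helper().
if pipe.group.rank() == pipe.group.size() - 1:
    output.mean().backward()
else:
    pipe.back_helper(output)
```

For lazy construction, the tests pass a list of callables instead of an nn.Sequential, e.g. Pipe([lambda: nn.Linear(10, 10), lambda: nn.Linear(10, 10)], [1, 1], style=Pipe.MultiProcess, worker_map=get_worker_map()), presumably so that each rank only instantiates the modules for its own stage.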
parent 49a198c9
......@@ -61,7 +61,11 @@ class Task:
"""
def __init__(
self, stream: AbstractStream, *, compute: Callable[[], Batch], finalize: Optional[Callable[[Batch], None]],
self,
stream: Optional[AbstractStream],
*,
compute: Callable[[], Batch],
finalize: Optional[Callable[[Batch], None]],
) -> None:
self.stream = stream
self._compute = compute
......
#!/bin/bash
# Run the multi-process pipe tests under MPI for world sizes 1 through 5.
set -e
for WORKERS in {1..5}; do
    mpirun -n $WORKERS python -m pytest tests/nn/pipe_process
done
This diff is collapsed.
......@@ -69,8 +69,8 @@ class stream:
def __enter__(self) -> None: ...
def __exit__(self, *args: Any) -> None: ...
def current_stream(device: Optional[_device_t]) -> Stream: ...
def default_stream(device: Optional[_device_t]) -> Stream: ...
def current_stream(device: Optional[_device_t] = None) -> Stream: ...
def default_stream(device: Optional[_device_t] = None) -> Stream: ...
#END
#
default_generators: Tuple[Any]
......@@ -4,8 +4,13 @@ from typing import Any, List, Union, Optional
from torch import Tensor
import datetime
from . import rpc as rpc
class Backend: ...
class ProcessGroup: ...
class ProcessGroup:
def size(self) -> int: ...
def rank(self) -> int: ...
class ReduceOp:
SUM: ReduceOp
......@@ -29,5 +34,12 @@ def new_group(ranks: List[int], timeout: datetime.timedelta = datetime.timedelta
def all_reduce(tensor: Tensor, op: ReduceOp = ReduceOp.SUM, group:Optional[ProcessGroup] = None, async_op: bool = False): ...
def all_gather(tensor_list: List[Tensor], tensor: Tensor, group:Optional[ProcessGroup] = None, async_op: bool = False): ...
def send(tensor: Tensor, dst: int, group: Optional[ProcessGroup] = None, tag: Optional[int] = None) -> None: ...
def isend(tensor: Tensor, dst: int, group: Optional[ProcessGroup] = None, tag: Optional[int] = None) -> None: ...
def recv(tensor: Tensor, src: Optional[int] = None, group: Optional[ProcessGroup] = None, tag: Optional[int] = None) -> int: ...
def irecv(tensor: Tensor, src: Optional[int] = None, group: Optional[ProcessGroup] = None, tag: Optional[int] = None) -> int: ...
class group(object):
WORLD: Any
class RRef: ...
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
from typing import Union, Callable, Optional
class RRef:
...
class WorkerInfo:
...
def rpc_async(
to: Union[str, WorkerInfo],
func: Callable,
args: Optional[tuple] = None,
kwargs: Optional[dict] = None,
timeout=-1.0,
) -> None:
...
def rpc_sync(
to: Union[str, WorkerInfo],
func: Callable,
args: Optional[tuple] = None,
kwargs: Optional[dict] = None,
timeout=-1.0,
) -> None:
...
......@@ -33,7 +33,7 @@ class Module(Generic[T_co]):
def apply(self: T, fn: Callable[['Module'], None]) -> T: ...
def cuda(self: T, device: Optional[Union[int, device]] = ...) -> T: ...
def cuda(self: T, device: Optional[Union[int, str, device]] = ...) -> T: ...
def cpu(self: T) -> T: ...
......
......@@ -19,14 +19,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import inspect
import os
import random
import numpy
from packaging import version
import pytest
import torch
import torch.distributed as dist
from torch.distributed import rpc
import torch.multiprocessing as mp
from fairscale.nn.model_parallel import initialize_model_parallel
from fairscale.nn.model_parallel.random import model_parallel_cuda_manual_seed
......@@ -39,7 +45,7 @@ class IdentityLayer(torch.nn.Module):
return self.weight
def set_random_seed(seed):
def set_random_seed(seed: int) -> None:
"""Set random seed for reproducability."""
random.seed(seed)
numpy.random.seed(seed)
......@@ -47,11 +53,40 @@ def set_random_seed(seed):
model_parallel_cuda_manual_seed(seed)
def dist_init(rank, world_size):
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "29501"
dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
torch.cuda.set_device(rank)
def dist_init(rank, world_size, hostname=None):
if hostname is None:
hostname = "localhost"
print(f"dist init r={rank}, world={world_size}, host={hostname}")
os.environ["MASTER_ADDR"] = hostname
os.environ["MASTER_PORT"] = "10638"
os.environ["WORLD_SIZE"] = str(world_size)
os.environ["RANK"] = str(rank)
if version.parse(torch.__version__).release >= (1, 6, 0):
init_method = f"tcp://{os.environ['MASTER_ADDR']}:{os.environ['MASTER_PORT']}"
torch.distributed.init_process_group(backend="nccl", rank=rank, world_size=world_size, init_method=init_method)
os.environ["MASTER_ADDR"] = hostname
os.environ["MASTER_PORT"] = "10639"
init_method = f"tcp://{os.environ['MASTER_ADDR']}:{os.environ['MASTER_PORT']}"
rpc.init_rpc(
f"Test{rank}",
rank=rank,
world_size=world_size,
backend=rpc.BackendType.TENSORPIPE,
rpc_backend_options=rpc.TensorPipeRpcBackendOptions(init_method=init_method),
)
else:
if world_size > 1:
rpc.init_rpc(f"Test{rank}", rank=rank, world_size=world_size)
else:
torch.distributed.init_process_group(backend="nccl", rank=rank, world_size=world_size)
if torch.cuda.is_available() and torch.cuda.device_count():
torch.cuda.set_device(rank % torch.cuda.device_count())
def get_worker_map():
return {rank: f"Test{rank}" for rank in range(dist.get_world_size())}
def get_world_sizes():
......@@ -59,6 +94,54 @@ def get_world_sizes():
return [x for x in [1, 2, 4, 8] if x <= limit]
def spawn_for_all_world_sizes(test_func, world_sizes=get_world_sizes()):
def spawn_for_all_world_sizes(test_func, world_sizes=get_world_sizes(), args=[]):
for world_size in world_sizes:
mp.spawn(test_func, args=(world_size,), nprocs=world_size, join=True)
mp.spawn(test_func, args=(world_size, *args), nprocs=world_size, join=True)
def helper(rank, world_size, func, args):
dist_init(rank, world_size)
initialize_model_parallel(1, world_size)
func(*args)
def torch_spawn(world_sizes=None):
if world_sizes is None:
world_sizes = get_world_sizes()
def fixer(func):
name = func.__name__
parameters = inspect.signature(func).parameters
if name.startswith("test"):
raise ValueError(
f"Tests marked with @torch_spawn (i.e. '{name}') should not have names beginning in 'test' as they will"
" be picked up by pytest without running the spawn wrapper"
)
@functools.wraps(func)
def replacement(*args, **kwargs):
assert args == tuple()
args = tuple(
kwargs[p] for p in parameters if p != "rank"
) # converting named parameters to positional parameters to pass to `spawn`
if "OMPI_COMM_WORLD_RANK" in os.environ:
torch.distributed.init_process_group("mpi")
world_size = torch.distributed.get_world_size()
initialize_model_parallel(1, world_size)
torch.cuda.set_device(torch.distributed.get_rank() % torch.cuda.device_count())
if world_size in world_sizes:
func(*args)
else:
pytest.skip(f"requested world size doesn't match current world size")
else:
spawn_for_all_world_sizes(helper, world_sizes, (func, args))
caller_module = inspect.getmodule(inspect.currentframe().f_back)
setattr(caller_module, f"test_{name}", replacement)
return func
return fixer
......@@ -125,10 +125,11 @@ def test_adjacency(monkeypatch):
for group in new_groups:
buckets[len(group)].append(group)
assert sorted(list(buckets.keys())) == [model_parallel_size, data_parallel_size]
assert sorted(list(buckets.keys())) == [model_parallel_size, pipeline_length, data_parallel_size]
assert len(buckets[model_parallel_size]) == pipeline_length * data_parallel_size
assert len(buckets[data_parallel_size]) == model_parallel_size * pipeline_length
assert len(buckets[pipeline_length]) == model_parallel_size * data_parallel_size
# Check that model_parallel groups are contiguous
for group in buckets[model_parallel_size]:
......
......@@ -19,15 +19,25 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import pytest
import torch
from torch import nn
from torch.distributed import rpc
import torch.nn.init as init
from torch.nn.parameter import Parameter
from fairscale.nn.model_parallel import initialize as mpu
from fairscale.nn.model_parallel import layers
from fairscale.nn.pipe import Pipe
from tests.nn.model_parallel.commons import dist_init, get_world_sizes, set_random_seed, spawn_for_all_world_sizes
from tests.nn.model_parallel.commons import (
dist_init,
get_world_sizes,
set_random_seed,
spawn_for_all_world_sizes,
torch_spawn,
)
def run_test_parallel_embedding(rank, model_parallel_size):
......@@ -297,33 +307,43 @@ def run_test_row_parallel_linear(rank, model_parallel_size):
print(" >> passed the test :-)")
def run_test_pipe(rank, model_parallel_size):
def run_test_pipe(rank, world_size, skip_dist_init=False):
pipe_world_size = 2
dist_init(rank, model_parallel_size)
mpu.initialize_model_parallel(model_parallel_size)
if world_size == 1:
return
if not skip_dist_init:
dist_init(rank, world_size)
else:
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "29502"
rpc.init_rpc(f"Test{rank}", rank=rank, world_size=world_size)
mpu.initialize_model_parallel(world_size / pipe_world_size, pipe_world_size)
model_parallel_size = mpu.get_model_parallel_world_size()
if torch.distributed.get_rank() == 0:
print(
"> testing Sequential + Pipe with model parallel size: {}, pipe: {}".format(
model_parallel_size, pipe_world_size
)
)
model_parallel_size = mpu.get_model_parallel_world_size()
chunk_size = 8
chunk_size = 4
seed = 12345
set_random_seed(seed)
input_size_coeff = 13
input_size_coeff = 3
input_size = input_size_coeff * model_parallel_size
output_size_coeff = 17
output_size_coeff = 7
output_size = output_size_coeff * model_parallel_size
batch_size = 7 * chunk_size
batch_size = 3 * chunk_size
target = torch.rand((batch_size, input_size), requires_grad=True).cuda()
print(f"target = {target}")
identity = IdentityLayer2D(batch_size, input_size).cuda()
pipeline_devices = mpu.get_pipeline_parallel_group()
if pipe_world_size == 2 and len(pipeline_devices) == 1:
pipeline_devices.append(pipeline_devices[0] + model_parallel_size)
set_random_seed(seed)
model = nn.Sequential(
......@@ -331,33 +351,196 @@ def run_test_pipe(rank, model_parallel_size):
nn.ReLU(),
layers.RowParallelLinear(output_size, input_size, keep_master_weight_for_test=True, bias=False).cuda(),
)
set_random_seed(seed)
reference = nn.Sequential(
reference = [
nn.Linear(input_size, output_size, bias=False).cuda(),
nn.ReLU(),
nn.Linear(output_size, input_size, bias=False).cuda(),
)
]
reference[0].weight.data = model[0].master_weight.cuda()
reference[-1].weight.data = model[-1].master_weight.cuda()
print(f"setup {reference[0].weight.size()}, {model[0].weight.size()}, {(input_size, output_size)}")
print(f"setup {reference[2].weight.size()}, {(output_size, input_size)}")
reference[0].weight = Parameter(model[0].get_master_weight().clone()).cuda()
reference[2].weight = Parameter(model[2].get_master_weight().clone()).cuda()
reference = nn.Sequential(*reference)
def grad_graph(depth, grad):
result = depth * " " + str(grad)
if grad:
for x in grad.next_functions:
result += "\n" + grad_graph(depth + 1, x[0])
return result
def check_weights(x, y, key: str, index=None):
for i in [2, 0]:
if index is not None and i != index:
continue
left = x[i].get_master_weight()
right = y[i].weight.data
if not torch.allclose(left, right, atol=1.0e-6) or index is not None:
print(f"check_weights {key}-{i}: left = {left}, \nright = {right}")
if not torch.equal(left, right):
print(f"check_weights NOT_EQUAL {key}-{i}: left = {left}, \nright = {right}")
assert torch.allclose(left, right, atol=1.0e-6)
def dump_opt_params(opt):
for i, group in enumerate(opt.param_groups):
for j, p in enumerate(group["params"]):
print(f"{torch.distributed.get_rank()}:param {(i,j)} = {p}")
print(f"{torch.distributed.get_rank()}:param.grad {(i,j)} = {p.grad}")
def forward_model(model_, target, step=False):
optimizer = torch.optim.SGD(model_.parameters(), lr=0.01, momentum=0.9)
optimizer.zero_grad()
model_.zero_grad()
output = model_(identity())
loss = nn.MSELoss()
model_.zero_grad()
if step:
loss(output, target).backward()
saved_weight_0 = model_[0].weight.data.clone()
saved_weight_2 = model_[2].weight.data.clone()
dump_opt_params(optimizer)
optimizer.step()
assert not torch.allclose(saved_weight_0, model_[0].weight.data, atol=1.0e-6)
assert not torch.allclose(saved_weight_2, model_[2].weight.data, atol=1.0e-6)
return output
output = forward_model(model, target)
reference_output = forward_model(reference, target)
loss_weight = torch.randn([batch_size, output_size]).cuda()
output = model(identity())
reference_output = reference(identity())
error = reference_output.sub(output).max()
torch.distributed.barrier()
assert error < 1.0e-6
output = forward_model(model, target)
error = reference_output.sub(output).max()
torch.distributed.barrier()
assert error < 1.0e-6
output = forward_model(model, target)
error = reference_output.sub(output).max()
torch.distributed.barrier()
assert error < 1.0e-6
check_weights(model, reference, "before")
saved_weight_0 = model[0].weight.data.clone()
saved_weight_2 = model[2].weight.data.clone()
output = forward_model(model, target, step=True)
error = reference_output.sub(output).max()
assert error < 1.0e-6
model[0].weight.data = saved_weight_0
model[2].weight.data = saved_weight_2
worker_map = {i: f"Test{i}" for i in range(torch.distributed.get_world_size())}
if pipe_world_size == 2:
pipe_model = Pipe(model, [2, 1], devices=pipeline_devices, chunks=chunk_size)
print(f"actually doing pipe stuff now")
assert torch.equal(saved_weight_0, model[0].weight.data)
assert torch.equal(saved_weight_2, model[2].weight.data)
pipe_model = Pipe(
model,
[2, 1],
style=Pipe.MultiProcess,
group=pipeline_devices,
worker_map=worker_map,
input_device=torch.cuda.current_device(),
chunks=chunk_size,
pipelined_backward=True,
).cuda()
torch.distributed.barrier()
pipe_rank = torch.distributed.get_rank(group=mpu.get_pipeline_parallel_group())
print(f"pipe rank is {pipe_rank}")
if pipe_rank == 0:
assert torch.equal(saved_weight_0, pipe_model[0].weight.data)
else:
if not torch.equal(saved_weight_2, pipe_model[0].weight.data):
print(f"ne {pipe_rank}: left\n{saved_weight_2}\nright:\n{pipe_model[0].weight.data}")
assert torch.equal(saved_weight_2, pipe_model[0].weight.data)
optimizer = torch.optim.SGD(pipe_model.parameters(), lr=0.01, momentum=0.9)
optimizer.zero_grad()
if pipe_rank == 0:
assert torch.equal(saved_weight_0, pipe_model[0].weight.data)
print(f"runner {rank}:\n{pipe_model[0].weight.data}")
else:
assert torch.equal(saved_weight_2, pipe_model[0].weight.data)
print(f"runner {rank}:\n{pipe_model[0].weight.data}")
if torch.distributed.get_rank(mpu.get_pipeline_parallel_group()) == 1:
check_weights(model, reference, "pre-pipe", index=2)
else:
check_weights(model, reference, "pre-pipe", index=0)
pipe_output = pipe_model(identity())
print(f"exited pipe for {rank}")
forward_model(reference, target, step=True)
print(f"pipe_output {rank} = {pipe_output}")
print(f"reference_output {rank} = {reference_output}")
error = reference_output.sub(pipe_output.cuda()).max()
torch.distributed.barrier()
assert error < 1.0e-6
if torch.distributed.get_rank(mpu.get_pipeline_parallel_group()) == 1:
error = reference_output.sub(pipe_output.cuda()).max()
if error >= 1.0e-6:
print(f"error bad {error}")
assert error < 1.0e-6
loss = nn.MSELoss()
failed = False
pipe_output.retain_grad()
with torch.autograd.profiler.profile() as prof:
try:
loss(pipe_output, target).backward()
except Exception as e:
failed = True
print(f"got {e} while doing backward, deadlock?")
if failed:
raise RuntimeError("failed somehow")
dump_opt_params(optimizer)
optimizer.step()
print(f"calling check_weights on master")
check_weights(model, reference, "pipe", index=2)
print(f"waiting for barrier on master, pid={os.getpid()}")
else:
print(f"calling backwards on slave, pid={os.getpid()}")
failed = False
with torch.autograd.profiler.profile() as prof:
try:
pipe_model.back_helper(pipe_output)
except Exception as e:
failed = True
print(f"got {e} while doing backward, deadlock?")
if failed:
raise RuntimeError("failed somehow")
dump_opt_params(optimizer)
print(f"calling step on slave")
optimizer.step()
print(f"calling check_weights on slave")
check_weights(model, reference, "pipe", index=0)
print(f"waiting for barrier on slave")
pipe_model.zero_grad()
torch.distributed.barrier()
pipe_output = pipe_model(identity())
updated_ref_output = forward_model(reference, target)
if torch.distributed.get_rank(mpu.get_pipeline_parallel_group()) == 1:
error = updated_ref_output.sub(pipe_output.cuda()).max()
print(f"outputs are ref:\n{updated_ref_output}\npipe:\n{pipe_output}")
assert error < 1.0e-6
torch.distributed.barrier()
print(f"finished waiting for barrier on, pid={os.getpid()}")
print(f"really exited pipe for {rank}")
rpc.shutdown()
torch.distributed.destroy_process_group()
torch.backends.cudnn.deterministic = True
......@@ -376,11 +559,29 @@ def test_column_parallel():
spawn_for_all_world_sizes(run_test_column_parallel_linear)
@pytest.mark.skipif("OMPI_COMM_WORLD_RANK" not in os.environ, reason="only works on mpi")
def test_row_parallel():
spawn_for_all_world_sizes(run_test_row_parallel_linear)
def test_pipe():
@torch_spawn([2])
@pytest.mark.skipif("OMPI_COMM_WORLD_RANK" not in os.environ, reason="only works on mpi")
@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
def mpi_pipe():
mpu.destroy_model_parallel()
run_test_pipe(torch.distributed.get_rank(), torch.distributed.get_world_size(), skip_dist_init=True)
@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
def test_pipe_layer():
world_sizes = [x for x in get_world_sizes() if x <= torch.cuda.device_count() / 2]
spawn_for_all_world_sizes(run_test_pipe, args=[False])
@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
@pytest.mark.skip(reason="potential deadlock in nccl with multiple processes using the same gpu")
def test_eight_pipe_layer():
world_sizes = [x for x in get_world_sizes() if x <= torch.cuda.device_count() / 2]
spawn_for_all_world_sizes(run_test_pipe, world_sizes)
spawn_for_all_world_sizes(run_test_pipe, [8])
......@@ -27,7 +27,7 @@ from fairscale.nn.pipe.stream import default_stream
@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
def test_copy_returns_on_next_device():
portal = Portal(torch.rand(1), tensor_life=1)
portal = Portal(torch.rand(1), tensor_life=1, index=0)
prev_stream = default_stream(torch.device("cpu"))
next_stream = default_stream(torch.device("cuda"))
......@@ -52,7 +52,7 @@ def test_blue_orange():
# tensor1 ------------ Join -- Fork --- Mul --- Add -- output
#
main = tensor1
portal = Portal(tensor2, tensor_life=2)
portal = Portal(tensor2, tensor_life=2, index=0)
phony = portal.blue()
main = join(main, phony)
main, phony = fork(main)
......@@ -78,7 +78,7 @@ def test_blue_orange_not_requires_grad():
# tensor1 ------------ Join -- Fork --- Mul --- Add -- output
#
main = tensor1
portal = Portal(tensor2, tensor_life=2)
portal = Portal(tensor2, tensor_life=2, index=0)
phony = portal.blue()
main = join(main, phony)
main, phony = fork(main)
......@@ -93,7 +93,7 @@ def test_blue_orange_not_requires_grad():
def test_use_grad():
tensor = torch.rand(1, requires_grad=True)
portal = Portal(tensor, tensor_life=1)
portal = Portal(tensor, tensor_life=1, index=0)
portal.put_grad(tensor)
assert portal.use_grad() is tensor
......@@ -111,7 +111,7 @@ class TestTensorLife:
def new_portal(tensor_life):
nonlocal portal
tensor = torch.rand(1, requires_grad=True)
portal = Portal(tensor, tensor_life)
portal = Portal(tensor, tensor_life, 0)
return portal, tensor
yield new_portal
......
......@@ -72,9 +72,9 @@ def test_default_skip_tracker_by_data_parallel():
def test_reuse_portal():
skip_layout = SkipLayout(num_partitions=2, skip_routes={(None, "test"): (0, 1)})
skip_tracker = SkipTrackerThroughPotals(skip_layout)
skip_tracker = SkipTrackerThroughPotals(skip_layout, 0)
batch = Batch(torch.tensor([1.0]))
batch = Batch(torch.tensor([1.0]), 0)
a = torch.tensor([2.0])
b = torch.tensor([2.0])
......@@ -87,9 +87,9 @@ def test_reuse_portal():
def test_no_copy_no_portal():
skip_layout = SkipLayout(num_partitions=2, skip_routes={(None, "copy"): (0, 1), (None, "not_copy"): (0, 0)})
skip_tracker = SkipTrackerThroughPotals(skip_layout)
skip_tracker = SkipTrackerThroughPotals(skip_layout, 0)
batch = Batch(torch.tensor([1.0]))
batch = Batch(torch.tensor([1.0]), 0)
a = torch.tensor([2.0])
b = torch.tensor([2.0])
......@@ -104,9 +104,9 @@ def test_no_copy_no_portal():
def test_tensor_life_without_checkpointing():
skip_layout = SkipLayout(num_partitions=2, skip_routes={(None, "test"): (0, 1)})
skip_tracker = SkipTrackerThroughPotals(skip_layout)
skip_tracker = SkipTrackerThroughPotals(skip_layout, 0)
batch = Batch(torch.tensor([1.0]))
batch = Batch(torch.tensor([1.0]), 0)
tensor = torch.tensor([2.0])
skip_tracker.save(batch, None, "test", tensor)
......@@ -118,9 +118,9 @@ def test_tensor_life_without_checkpointing():
def test_tensor_life_with_checkpointing():
skip_layout = SkipLayout(num_partitions=2, skip_routes={(None, "test"): (0, 1)})
skip_tracker = SkipTrackerThroughPotals(skip_layout)
skip_tracker = SkipTrackerThroughPotals(skip_layout, 0)
batch = Batch(torch.tensor([1.0]))
batch = Batch(torch.tensor([1.0]), 0)
tensor = torch.tensor([2.0])
with enable_checkpointing():
......
......@@ -24,7 +24,7 @@ import torch
from torch import nn
import torch.cuda
from fairscale.nn.pipe.checkpoint import Checkpointing, checkpoint, is_checkpointing, is_recomputing
from fairscale.nn.pipe.checkpoint import Checkpointing, Function, TensorOrTensors, is_checkpointing, is_recomputing
from fairscale.nn.pipe.dependency import fork, join
from fairscale.nn.pipe.microbatch import Batch
......@@ -33,6 +33,20 @@ if torch.cuda.is_available():
devices.append("cuda")
def make_checkpoint(function: Function, input: TensorOrTensors, index: int) -> TensorOrTensors:
"""Makes a checkpoint with a simple interface like
:func:`torch.utils.checkpoint.checkpoint`. It's only used to test or debug
:class:`Checkpoint` and :class:`Recompute` without boilerplate.
"""
batch = Batch(input, index)
chk = Checkpointing(function, batch)
batch = chk.checkpoint()
chk.recompute(batch)
return batch.tensor_or_tensors
@pytest.mark.parametrize("device", devices)
def test_serial_checkpoints(device):
# Copied from https://github.com/pytorch/pytorch/pull/18568.
......@@ -57,12 +71,12 @@ def test_serial_checkpoints(device):
# Increase the next function sequence number.
_ = a + 1 + 2 + 3 + 4 + 5
a = checkpoint(partial(Log.apply, "a"), a)
a = make_checkpoint(partial(Log.apply, "a"), a, 0)
a, phony = fork(a)
b = join(b, phony)
b = checkpoint(partial(Log.apply, "b"), b)
b = make_checkpoint(partial(Log.apply, "b"), b, 0)
c = torch.cat((a, b))
......@@ -79,7 +93,7 @@ def test_serial_checkpoints(device):
def test_not_requires_grad():
x = Batch(torch.rand(1, requires_grad=False))
x = Batch(torch.rand(1, requires_grad=False), 0)
assert not x[0].requires_grad
def f(x):
......@@ -102,7 +116,7 @@ def test_not_requires_grad_with_parameter():
def f(x):
return x * a
y = checkpoint(f, x)
y = make_checkpoint(f, x, 0)
y.backward()
assert a.grad is not None
......@@ -119,7 +133,7 @@ def test_random_in_checkpoint(device):
torch.manual_seed(0)
chk_x = torch.randn(3, 3, device=device, requires_grad=True)
chk_y = checkpoint(dropout, chk_x)
chk_y = make_checkpoint(dropout, chk_x, 0)
chk_y.norm().backward()
assert torch.allclose(x.grad, chk_x.grad)
......@@ -136,7 +150,7 @@ def test_detect_checkpointing_recomputing():
model = Detect()
input = torch.rand(1, requires_grad=True)
output = checkpoint(model, input)
output = make_checkpoint(model, input, 0)
output.backward()
assert logs == [(True, False), (False, True)]
......@@ -167,5 +181,5 @@ def test_non_grad_output():
model = ForkNonGrad()
input = torch.rand(1, requires_grad=True)
output = checkpoint(model, input)
output = make_checkpoint(model, input, 0)
output[0].backward()
......@@ -26,7 +26,7 @@ from fairscale.nn.pipe.microbatch import Batch, check, gather, scatter
def test_batch_atomic():
x = torch.tensor(42)
b = Batch(x)
b = Batch(x, 0)
assert b.atomic
......@@ -41,7 +41,7 @@ def test_batch_atomic():
def test_batch_non_atomic():
x, y = torch.tensor(42), torch.tensor(21)
b = Batch((x, y))
b = Batch((x, y), 0)
assert not b.atomic
......@@ -56,8 +56,8 @@ def test_batch_non_atomic():
def test_batch_call():
a = Batch(torch.tensor(42))
b = Batch((torch.tensor(42), torch.tensor(21)))
a = Batch(torch.tensor(42), 0)
b = Batch((torch.tensor(42), torch.tensor(21)), 0)
def f(x):
return x
......@@ -67,8 +67,8 @@ def test_batch_call():
def test_batch_setitem_by_index():
a = Batch(torch.tensor(42))
b = Batch((torch.tensor(42), torch.tensor(21)))
a = Batch(torch.tensor(42), 0)
b = Batch((torch.tensor(42), torch.tensor(21)), 0)
a[0] = torch.tensor(0)
b[0] = torch.tensor(0)
......@@ -83,8 +83,8 @@ def test_batch_setitem_by_index():
def test_batch_setitem_by_slice():
a = Batch(torch.tensor(42))
b = Batch((torch.tensor(42), torch.tensor(21)))
a = Batch(torch.tensor(42), 0)
b = Batch((torch.tensor(42), torch.tensor(21)), 0)
a[:] = (torch.tensor(0),)
b[:] = (torch.tensor(0),)
......@@ -115,7 +115,7 @@ def test_gather_tensors():
a = torch.zeros(1, 1)
b = torch.zeros(1, 1)
ab = gather([Batch(a), Batch(b)])
ab = gather([Batch(a, 0), Batch(b, 0)])
assert ab.size() == (2, 1)
......@@ -124,7 +124,7 @@ def test_gather_tuples():
a = (torch.zeros(1, 1), torch.zeros(2, 2))
b = (torch.zeros(1, 1), torch.zeros(2, 2))
ab = gather([Batch(a), Batch(b)])
ab = gather([Batch(a, 0), Batch(b, 0)])
assert isinstance(ab, tuple)
assert ab[0].size() == (2, 1)
......
......@@ -44,7 +44,7 @@ def test_join_running_workers():
nonlocal count
time.sleep(0.1)
count += 1
return Batch(())
return Batch((), 0)
with spawn_workers([fake_device() for _ in range(10)]) as (in_queues, out_queues):
......@@ -70,7 +70,7 @@ def test_join_running_workers_with_exception():
nonlocal count
time.sleep(0.1)
count += 1
return Batch(())
return Batch((), 0)
with pytest.raises(ExpectedException):
with spawn_workers([fake_device() for _ in range(10)]) as (in_queues, out_queues):
......@@ -96,7 +96,7 @@ def test_compute_multithreading():
def log_thread_id():
thread_id = threading.current_thread().ident
thread_ids.add(thread_id)
return Batch(())
return Batch((), 0)
with spawn_workers([fake_device() for _ in range(2)]) as (in_queues, out_queues):
for i in range(2):
......@@ -112,7 +112,7 @@ def test_compute_success():
"""Task.compute returns (True, (task, batch)) on success."""
def _42():
return Batch(torch.tensor(42))
return Batch(torch.tensor(42), 0)
with spawn_workers([torch.device("cpu")]) as (in_queues, out_queues):
t = Task(CPUStream, compute=_42, finalize=None)
......@@ -145,7 +145,7 @@ def test_compute_exception():
def test_grad_mode(grad_mode):
def detect_grad_enabled():
x = torch.rand(1, requires_grad=torch.is_grad_enabled())
return Batch(x)
return Batch(x, 0)
with torch.set_grad_enabled(grad_mode):
with spawn_workers([torch.device("cpu")]) as (in_queues, out_queues):
......
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
# Copyright 2019 Kakao Brain
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# tests/__init__.py lets pytest import the application without a custom sys.path or PYTHONPATH.
# See also: https://docs.pytest.org/en/latest/goodpractices.html
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
# Copyright 2019 Kakao Brain
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import os
import pytest
import torch
from fairscale.nn.model_parallel import destroy_model_parallel
@pytest.fixture(autouse=True)
def manual_seed_zero():
torch.manual_seed(0)
def cuda_sleep_impl(seconds, cycles_per_ms):
torch.cuda._sleep(int(seconds * cycles_per_ms * 1000))
@pytest.fixture(scope="session")
def cuda_sleep():
# Warm-up CUDA.
torch.empty(1, device="cuda")
# From test/test_cuda.py in PyTorch.
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
torch.cuda._sleep(1000000)
end.record()
end.synchronize()
cycles_per_ms = 1000000 / start.elapsed_time(end)
return functools.partial(cuda_sleep_impl, cycles_per_ms=cycles_per_ms)
def pytest_report_header():
return f"torch: {torch.__version__}"
def pytest_runtest_setup(item):
print(f"setup mpi function called")
def pytest_runtest_teardown(item):
if "OMPI_COMM_WORLD_RANK" in os.environ:
destroy_model_parallel()
if torch.distributed.is_initialized():
torch.distributed.destroy_process_group()
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
# Copyright 2019 Kakao Brain
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
# Copyright 2019 Kakao Brain
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import pytest
import torch
from torch import nn
from fairscale.nn.pipe import Pipe
from fairscale.nn.pipe.skip import pop, skippable, stash
from fairscale.nn.pipe.skip.portal import PortalBlue, PortalCopy, PortalOrange
from tests.nn.model_parallel.commons import get_worker_map, torch_spawn
@torch_spawn([3])
@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
@pytest.mark.parametrize("balance", [[3], [1, 2], [2, 1], [1, 1, 1]], ids=["3", "1:2", "2:1", "1:1:1"])
@pytest.mark.parametrize("checkpoint", ["never", "always", "except_last"])
@pytest.mark.skipif("OMPI_COMM_WORLD_RANK" in os.environ, reason="broken on mpi")
def x1to3(balance, checkpoint):
torch.manual_seed(0)
@skippable(stash=["1to3"])
class Layer1(nn.Module):
def __init__(self):
super().__init__()
self.conv = nn.Conv2d(3, 3, 1)
def forward(self, input):
yield stash("1to3", input)
output = self.conv(input)
return output
class Layer2(nn.Module):
def __init__(self):
super().__init__()
self.conv = nn.Conv2d(3, 3, 1)
def forward(self, input):
output = self.conv(input)
return output
@skippable(pop=["1to3"])
class Layer3(nn.Module):
def __init__(self):
super().__init__()
self.conv = nn.Conv2d(3, 3, 1)
def forward(self, input):
skip_1to3 = yield pop("1to3")
output = self.conv(input) + skip_1to3
return output
model = nn.Sequential(Layer1(), Layer2(), Layer3())
model = Pipe(
model,
balance,
chunks=3,
checkpoint=checkpoint,
input_device=torch.cuda.current_device(),
style=Pipe.MultiProcess,
worker_map=get_worker_map(),
pipelined_backward=False,
).cuda()
input = torch.rand(30, 3, 224, 224, requires_grad=True).cuda()
input.retain_grad()
output = model(input)
if model.group.rank() == len(balance) - 1:
loss = output.mean()
loss.backward()
elif model.group.rank() < len(balance) - 1:
model.back_helper(output)
if model.group.rank() == len(balance) - 1:
# TODO(tom) the single-process test uses 2e-1 but for some reason
# multi-process is more noisy, need to investigate why
assert torch.allclose(output.norm(), torch.tensor(1039.0).cuda(), atol=4e-1)
if model.group.rank() == 0:
assert torch.allclose(input.grad.norm(), torch.tensor(0.0004533053).cuda())
torch.distributed.barrier()
@torch_spawn([2])
@pytest.mark.skipif("OMPI_COMM_WORLD_RANK" in os.environ, reason="broken on mpi")
@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
def none_skip():
@skippable(stash=["none"])
class Stash(nn.Module):
def forward(self, input):
yield stash("none", None)
return input
@skippable(pop=["none"])
class Pop(nn.Module):
def forward(self, input):
none = yield pop("none")
assert none is None
return input
model = nn.Sequential(Stash(), Pop())
model = Pipe(
model,
[1, 1],
style=Pipe.MultiProcess,
worker_map=get_worker_map(),
input_device=torch.cuda.current_device(),
chunks=5,
).cuda()
input = torch.rand(10, requires_grad=True).cuda()
input.retain_grad()
output = model(input)
def assert_grad_fn_is_not_portal(grad_fn, visited=set()):
if grad_fn in visited or grad_fn is None:
return
assert not isinstance(grad_fn, PortalBlue._backward_cls)
assert not isinstance(grad_fn, PortalCopy._backward_cls)
assert not isinstance(grad_fn, PortalOrange._backward_cls)
visited.add(grad_fn)
for next_grad_fn, _ in grad_fn.next_functions:
assert_grad_fn_is_not_portal(next_grad_fn, visited)
if model.group.rank() == 1:
assert_grad_fn_is_not_portal(output.grad_fn)
output.sum().backward()
else:
model.back_helper(output)
assert input.grad.mean().item() == 1
@torch_spawn([2])
def lazy_skippable_error():
"""Using skippable layers in combination with lazy construction is currently
not supported, check that it raises an Exception"""
@skippable(stash=["1to3"])
class Layer1(nn.Linear):
pass
@skippable(pop=["1to3"])
class Layer3(nn.Linear):
pass
model = [lambda: Layer1(10, 10), lambda: nn.Linear(10, 10), lambda: Layer3(10, 10)]
with pytest.raises(ValueError, match="Can't use Skippable layers with multi-process pipe and lazy construction"):
Pipe(
model, [2, 1], style=Pipe.MultiProcess, worker_map=get_worker_map(),
)
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
# Copyright 2019 Kakao Brain
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import pytest
import torch
from torch import nn
from fairscale.nn.pipe import Pipe, is_checkpointing, is_recomputing
from fairscale.nn.pipe.skip import pop, skippable, stash
from fairscale.nn.pipe.skip.tracker import current_skip_tracker
from tests.nn.model_parallel.commons import get_worker_map, torch_spawn
@skippable(stash=["skip"])
class Stash(nn.Module):
def forward(self, input):
yield stash("skip", input)
return input
@skippable(pop=["skip"])
class Pop(nn.Module):
def forward(self, input):
skip = yield pop("skip")
return input + skip
@torch_spawn([2])
@pytest.mark.parametrize("train", [True, False], ids=["train", "eval"])
@pytest.mark.parametrize("checkpoint", ["always", "except_last", "never"])
@pytest.mark.skipif("OMPI_COMM_WORLD_RANK" in os.environ, reason="broken on mpi")
@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
def delete_portal_tensor(train, checkpoint):
# Without checkpointing:
# +- Stash --+ +--- Pop ----+ - - - layers
# | 2,blue,1 |--| 1,orange,0 | - - - tensor_life and portal function
# +----------+ +------------+
#
# With checkpointing:
# +- Stash --+ +--- Pop ----+ +--- Pop'----+ +- Stash'--+
# | 3,blue,2 |--| 2,orange,1 |--| 1,orange,0 |--| 1,blue,0 |
# +----------+ +------------+ +------------+ +----------+
def portal_tensor_life_is(tensor_life, skip_tracker=None):
if skip_tracker is None:
skip_tracker = current_skip_tracker()
# Get the current portal.
portal = list(skip_tracker.portals.values())[0]
if tensor_life == 0:
return portal.tensor_life == 0 and portal.tensor is None
else:
return portal.tensor_life == tensor_life and portal.tensor is not None
# Check the portal tensor after 'Stash'.
stash_ = Stash()
@stash_.register_forward_hook
def check_portal_tensor_after_stash(*_):
if is_checkpointing():
assert portal_tensor_life_is(2)
elif is_recomputing():
assert portal_tensor_life_is(0)
else:
assert portal_tensor_life_is(1)
pop_ = Pop()
@pop_.register_forward_hook
def check_portal_tensor_after_pop(*_):
if is_checkpointing():
assert portal_tensor_life_is(1)
elif is_recomputing():
assert portal_tensor_life_is(0)
else:
assert portal_tensor_life_is(0)
class NoPortalTensorAtBackward(nn.Module):
class F(torch.autograd.Function):
@staticmethod
def forward(ctx, input):
ctx.skip_tracker = current_skip_tracker()
return input.detach()
@staticmethod
def backward(ctx, grad):
assert portal_tensor_life_is(0, skip_tracker=ctx.skip_tracker)
return grad
def forward(self, input):
return self.F.apply(input)
model = nn.Sequential(NoPortalTensorAtBackward(), stash_, pop_)
model = Pipe(
model, balance=[2, 1], style=Pipe.MultiProcess, worker_map=get_worker_map(), chunks=2, checkpoint=checkpoint,
)
input = torch.rand(10, requires_grad=True)
if train:
model.train()
output = model(input)
if model.group.rank() == 1:
output.norm().backward()
else:
model.back_helper(output)
else:
model.eval()
with torch.no_grad():
model(input)
torch.distributed.barrier()