"python/vscode:/vscode.git/clone" did not exist on "d22dbec28b2dcc026b7c19a57ed71ce1ea9ed1b2"
Unverified commit 5d4f50fb authored by Tom Birch, committed by GitHub

Single-process control via PipeRPCWrapper (#156)

Adds support for:
* Reused layers (e.g. for weight sharing)
* Lazily-constructed layers
* Single-process control via PipeRPCWrapper
* PipelineStyle.AsyncSchedule, which lays the foundation for asynchronous pipeline work by introducing an event loop on each rank/worker that processes activations or gradients as they arrive

Also adds examples for multi-process usage and for PipeRPCWrapper; a minimal usage sketch follows.
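For orientation, here is a minimal sketch of the lazy-construction, layer-reuse, and AsyncSchedule pieces, assembled from the tests added in this diff (reuse_lazy and async_event_loop). It assumes the distributed process group and model-parallel state are already initialized on every rank (the tests do this via torch_spawn/dist_init), and it reuses the tests' worker-name convention; the balance and chunks values are illustrative only. PipeRPCWrapper itself is not shown in this excerpt, so its API is not sketched here.

```python
import torch
from torch import nn

from fairscale.nn.pipe import LazyModule, Pipe

# Run on every rank of an already-initialized torch.distributed process group.
# Worker names follow the convention used by the tests in this diff.
worker_map = {i: f"Test{i}" for i in range(torch.distributed.get_world_size())}

# LazyModule defers construction until the owning rank builds its partition;
# reusing the same LazyModule instance expresses weight sharing across stages.
reused = LazyModule(lambda: nn.Linear(10, 10))
model = [reused, nn.Linear(10, 10), nn.ReLU(), reused, nn.ReLU()]

pipe = Pipe(
    model,
    balance=[3, 2],                # illustrative split across two ranks
    style=Pipe.AsyncSchedule,      # event-loop based schedule added by this PR
    worker_map=worker_map,
    chunks=10,
)

output = pipe(torch.rand(100, 10))
if pipe.final_stage:               # only the last stage holds the real output
    output.mean().backward()
```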
parent 543d5693
...@@ -29,7 +29,7 @@ _device_t = Union[_device, int, str] ...@@ -29,7 +29,7 @@ _device_t = Union[_device, int, str]
def check_error(res: int) -> None: ... def check_error(res: int) -> None: ...
def device_count() -> int: ... def device_count() -> int: ...
def empty_cache() -> None: ... def empty_cache() -> None: ...
def synchronize(device: _device_t) -> None: ... def synchronize(device: Optional[_device_t]=None) -> None: ...
def set_device(device: _device_t) -> None: ... def set_device(device: _device_t) -> None: ...
def get_device_capability(device: Optional[_device_t]=...) -> Tuple[int, int]: ... def get_device_capability(device: Optional[_device_t]=...) -> Tuple[int, int]: ...
def get_device_name(device: Optional[_device_t]=...) -> str: ... def get_device_name(device: Optional[_device_t]=...) -> str: ...
......
...@@ -5,6 +5,7 @@ from torch import Tensor ...@@ -5,6 +5,7 @@ from torch import Tensor
import datetime import datetime
from . import rpc as rpc from . import rpc as rpc
from . import distributed_c10d as distributed_c10d
class Backend: class Backend:
GLOO: str GLOO: str
......
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
from typing import Any, List, Union, Optional
from . import ProcessGroup
def _get_global_rank(group: ProcessGroup, rank: int) -> int: ...
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
from typing import Union, Callable, Optional from typing import Union, Callable, Optional
from torch.futures import Future
class RRef: class RRef:
...@@ -17,7 +18,7 @@ def rpc_async( ...@@ -17,7 +18,7 @@ def rpc_async(
args: Optional[tuple] = None, args: Optional[tuple] = None,
kwargs: Optional[dict] = None, kwargs: Optional[dict] = None,
timeout=-1.0, timeout=-1.0,
) -> None: ) -> Future:
... ...
......
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
from typing import Any
class Future:
def wait(self) -> Any: ...
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
from .modules import * from .modules import *
from .parameter import Parameter as Parameter from .parameter import Parameter as Parameter
from .parallel import DataParallel as DataParallel from .parallel import DataParallel as DataParallel
from . import functional as functional from . import functional as functional
...@@ -21,6 +21,7 @@ ...@@ -21,6 +21,7 @@
import functools import functools
import inspect import inspect
import multiprocessing
import os import os
import random import random
...@@ -100,17 +101,32 @@ def spawn_for_all_world_sizes(test_func, world_sizes=get_world_sizes(), args=[]) ...@@ -100,17 +101,32 @@ def spawn_for_all_world_sizes(test_func, world_sizes=get_world_sizes(), args=[])
mp.spawn(test_func, args=(world_size, *args), nprocs=world_size, join=True) mp.spawn(test_func, args=(world_size, *args), nprocs=world_size, join=True)
def helper(rank, world_size, func, args): def worker_process(rank, world_size, func, args, error_queue):
"""Main function for unit tests launced with torch_spawn"""
dist_init(rank, world_size) dist_init(rank, world_size)
initialize_model_parallel(1, world_size) kwargs = {}
func(*args) if "OMPI_COMM_WORLD_RANK" not in os.environ:
kwargs["pipeline_backend"] = "gloo"
initialize_model_parallel(1, world_size, **kwargs)
try:
func(*args)
except BaseException as e:
# If the function raises 'Skipped', this indicates pytest.skip(), so
# forward it to parent so we can call pytest.skip() there
if e.__class__.__name__ == "Skipped":
error_queue.put(str(e))
return
raise e
def torch_spawn(world_sizes=None): def torch_spawn(world_sizes=None):
if world_sizes is None: if world_sizes is None:
world_sizes = get_world_sizes() world_sizes = get_world_sizes()
def fixer(func): def prepare_test(func):
"""Function called with the test function as the argument. Generates a
replacement which serves as the actual test function."""
name = func.__name__ name = func.__name__
parameters = inspect.signature(func).parameters parameters = inspect.signature(func).parameters
...@@ -128,21 +144,39 @@ def torch_spawn(world_sizes=None): ...@@ -128,21 +144,39 @@ def torch_spawn(world_sizes=None):
kwargs[p] for p in parameters if p != "rank" kwargs[p] for p in parameters if p != "rank"
) # converting named parameters to positional parameters to pass to `spawn` ) # converting named parameters to positional parameters to pass to `spawn`
error_queue = multiprocessing.get_context("spawn").SimpleQueue()
if "OMPI_COMM_WORLD_RANK" in os.environ: if "OMPI_COMM_WORLD_RANK" in os.environ:
os.environ["RANK"] = os.environ["OMPI_COMM_WORLD_RANK"]
os.environ["WORLD_SIZE"] = os.environ["OMPI_COMM_WORLD_SIZE"]
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "10638"
torch.distributed.init_process_group("mpi") torch.distributed.init_process_group("mpi")
world_size = torch.distributed.get_world_size() world_size = torch.distributed.get_world_size()
initialize_model_parallel(1, world_size) initialize_model_parallel(1, world_size)
torch.cuda.set_device(torch.distributed.get_rank() % torch.cuda.device_count()) torch.cuda.set_device(torch.distributed.get_rank() % torch.cuda.device_count())
if world_size in world_sizes: if world_size in world_sizes:
func(*args) try:
func(*args)
except BaseException as e:
print(f"got exception {e} from test")
import traceback
print(f"{traceback.format_exc()}")
raise e
else: else:
pytest.skip(f"requested world size doesn't match current world size") pytest.skip(f"requested world size doesn't match current world size")
else: else:
spawn_for_all_world_sizes(helper, world_sizes, (func, args)) spawn_for_all_world_sizes(worker_process, world_sizes, (func, args, error_queue))
if not error_queue.empty():
msg = error_queue.get()
pytest.skip(msg)
# Register a function with the same name, prefixed with "test_" in the
# calling module, so it will be picked up by pytest
caller_module = inspect.getmodule(inspect.currentframe().f_back) caller_module = inspect.getmodule(inspect.currentframe().f_back)
setattr(caller_module, f"test_{name}", replacement) setattr(caller_module, f"test_{name}", replacement)
return func return func
return fixer return prepare_test
...@@ -110,7 +110,7 @@ def test_adjacency(monkeypatch): ...@@ -110,7 +110,7 @@ def test_adjacency(monkeypatch):
def get_world_size(self): def get_world_size(self):
return data_parallel_size * pipeline_length * model_parallel_size return data_parallel_size * pipeline_length * model_parallel_size
def new_group(self, args): def new_group(self, args, backend=None):
new_groups.append(args.copy()) new_groups.append(args.copy())
return () return ()
......
...@@ -436,6 +436,7 @@ def run_test_pipe(rank, world_size, skip_dist_init=False): ...@@ -436,6 +436,7 @@ def run_test_pipe(rank, world_size, skip_dist_init=False):
model[2].weight.data = saved_weight_2 model[2].weight.data = saved_weight_2
worker_map = {i: f"Test{i}" for i in range(torch.distributed.get_world_size())} worker_map = {i: f"Test{i}" for i in range(torch.distributed.get_world_size())}
style = Pipe.MultiProcess # Pipe.AsyncSchedule
if pipe_world_size == 2: if pipe_world_size == 2:
print(f"actually doing pipe stuff now") print(f"actually doing pipe stuff now")
...@@ -444,7 +445,7 @@ def run_test_pipe(rank, world_size, skip_dist_init=False): ...@@ -444,7 +445,7 @@ def run_test_pipe(rank, world_size, skip_dist_init=False):
pipe_model = Pipe( pipe_model = Pipe(
model, model,
[2, 1], [2, 1],
style=Pipe.MultiProcess, style=style,
group=pipeline_devices, group=pipeline_devices,
worker_map=worker_map, worker_map=worker_map,
input_device=torch.cuda.current_device(), input_device=torch.cuda.current_device(),
...@@ -511,7 +512,8 @@ def run_test_pipe(rank, world_size, skip_dist_init=False): ...@@ -511,7 +512,8 @@ def run_test_pipe(rank, world_size, skip_dist_init=False):
failed = False failed = False
with torch.autograd.profiler.profile() as prof: with torch.autograd.profiler.profile() as prof:
try: try:
pipe_model.back_helper(pipe_output) if style == Pipe.MultiProcess:
pipe_model.back_helper(pipe_output)
except Exception as e: except Exception as e:
failed = True failed = True
print(f"got {e} while doing backward, deadlock?") print(f"got {e} while doing backward, deadlock?")
...@@ -527,6 +529,7 @@ def run_test_pipe(rank, world_size, skip_dist_init=False): ...@@ -527,6 +529,7 @@ def run_test_pipe(rank, world_size, skip_dist_init=False):
pipe_model.zero_grad() pipe_model.zero_grad()
torch.distributed.barrier() torch.distributed.barrier()
pipe_model.eval()
pipe_output = pipe_model(identity()) pipe_output = pipe_model(identity())
updated_ref_output = forward_model(reference, target) updated_ref_output = forward_model(reference, target)
if torch.distributed.get_rank(mpu.get_pipeline_parallel_group()) == 1: if torch.distributed.get_rank(mpu.get_pipeline_parallel_group()) == 1:
......
...@@ -23,17 +23,18 @@ else: ...@@ -23,17 +23,18 @@ else:
os.environ["MASTER_ADDR"] = "localhost" os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "29501" os.environ["MASTER_PORT"] = "29501"
if "OMPI_COMM_WORLD_SIZE" in os.environ: if "OMPI_COMM_WORLD_SIZE" in os.environ:
dist.init_process_group(backend=dist.Backend.MPI) pass # dist.init_process_group(backend=dist.Backend.MPI)
def setup_module(module): def setup_module(module):
if "OMPI_COMM_WORLD_SIZE" not in os.environ: if "OMPI_COMM_WORLD_SIZE" not in os.environ:
dist.init_process_group(backend=BACKEND, rank=0, world_size=1) dist.init_process_group(backend=BACKEND, rank=0, world_size=1)
else:
dist.init_process_group(backend=dist.Backend.MPI)
def teardown_module(module): def teardown_module(module):
if "OMPI_COMM_WORLD_SIZE" not in os.environ: torch.distributed.destroy_process_group()
torch.distributed.destroy_process_group()
@pytest.mark.parametrize("device", devices) @pytest.mark.parametrize("device", devices)
......
...@@ -65,3 +65,7 @@ def pytest_runtest_teardown(item): ...@@ -65,3 +65,7 @@ def pytest_runtest_teardown(item):
destroy_model_parallel() destroy_model_parallel()
if torch.distributed.is_initialized(): if torch.distributed.is_initialized():
torch.distributed.destroy_process_group() torch.distributed.destroy_process_group()
try:
torch.distributed.rpc.shutdown()
except Exception:
pass
...@@ -23,7 +23,7 @@ import pytest ...@@ -23,7 +23,7 @@ import pytest
import torch import torch
from torch import nn from torch import nn
from fairscale.nn.pipe import Pipe from fairscale.nn.pipe import LazyModule, Pipe
from fairscale.nn.pipe.skip import pop, skippable, stash from fairscale.nn.pipe.skip import pop, skippable, stash
from fairscale.nn.pipe.skip.portal import PortalBlue, PortalCopy, PortalOrange from fairscale.nn.pipe.skip.portal import PortalBlue, PortalCopy, PortalOrange
from tests.nn.model_parallel.commons import get_worker_map, torch_spawn from tests.nn.model_parallel.commons import get_worker_map, torch_spawn
...@@ -33,10 +33,15 @@ from tests.nn.model_parallel.commons import get_worker_map, torch_spawn ...@@ -33,10 +33,15 @@ from tests.nn.model_parallel.commons import get_worker_map, torch_spawn
@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required") @pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
@pytest.mark.parametrize("balance", [[3], [1, 2], [2, 1], [1, 1, 1]], ids=["3", "1:2", "2:1", "1:1:1"]) @pytest.mark.parametrize("balance", [[3], [1, 2], [2, 1], [1, 1, 1]], ids=["3", "1:2", "2:1", "1:1:1"])
@pytest.mark.parametrize("checkpoint", ["never", "always", "except_last"]) @pytest.mark.parametrize("checkpoint", ["never", "always", "except_last"])
@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
@pytest.mark.skipif("OMPI_COMM_WORLD_RANK" in os.environ, reason="broken on mpi") @pytest.mark.skipif("OMPI_COMM_WORLD_RANK" in os.environ, reason="broken on mpi")
def x1to3(balance, checkpoint): def x1to3(balance, checkpoint, pipeline_style):
torch.manual_seed(0) torch.manual_seed(0)
if pipeline_style == Pipe.AsyncSchedule and len(balance) > 1:
print(f"skipping yarg")
pytest.skip("Skip tensors NYI for AsyncSchedule")
@skippable(stash=["1to3"]) @skippable(stash=["1to3"])
class Layer1(nn.Module): class Layer1(nn.Module):
def __init__(self): def __init__(self):
...@@ -75,7 +80,7 @@ def x1to3(balance, checkpoint): ...@@ -75,7 +80,7 @@ def x1to3(balance, checkpoint):
chunks=3, chunks=3,
checkpoint=checkpoint, checkpoint=checkpoint,
input_device=torch.cuda.current_device(), input_device=torch.cuda.current_device(),
style=Pipe.MultiProcess, style=pipeline_style,
worker_map=get_worker_map(), worker_map=get_worker_map(),
pipelined_backward=False, pipelined_backward=False,
).cuda() ).cuda()
...@@ -101,7 +106,11 @@ def x1to3(balance, checkpoint): ...@@ -101,7 +106,11 @@ def x1to3(balance, checkpoint):
@torch_spawn([2]) @torch_spawn([2])
@pytest.mark.skipif("OMPI_COMM_WORLD_RANK" in os.environ, reason="broken on mpi") @pytest.mark.skipif("OMPI_COMM_WORLD_RANK" in os.environ, reason="broken on mpi")
@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required") @pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
def none_skip(): @pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
def none_skip(pipeline_style):
if pipeline_style == Pipe.AsyncSchedule:
pytest.skip("Skip tensors NYI for AsyncSchedule")
@skippable(stash=["none"]) @skippable(stash=["none"])
class Stash(nn.Module): class Stash(nn.Module):
def forward(self, input): def forward(self, input):
...@@ -119,7 +128,7 @@ def none_skip(): ...@@ -119,7 +128,7 @@ def none_skip():
model = Pipe( model = Pipe(
model, model,
[1, 1], [1, 1],
style=Pipe.MultiProcess, style=pipeline_style,
worker_map=get_worker_map(), worker_map=get_worker_map(),
input_device=torch.cuda.current_device(), input_device=torch.cuda.current_device(),
chunks=5, chunks=5,
...@@ -151,7 +160,8 @@ def none_skip(): ...@@ -151,7 +160,8 @@ def none_skip():
@torch_spawn([2]) @torch_spawn([2])
def lazy_skippable_error(): @pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
def lazy_skippable_error(pipeline_style):
"""Using skippable layers in combination with lazy construction is currently """Using skippable layers in combination with lazy construction is currently
not supported, check that it raises an Exception""" not supported, check that it raises an Exception"""
...@@ -163,9 +173,13 @@ def lazy_skippable_error(): ...@@ -163,9 +173,13 @@ def lazy_skippable_error():
class Layer3(nn.Linear): class Layer3(nn.Linear):
pass pass
model = [lambda: Layer1(10, 10), lambda: nn.Linear(10, 10), lambda: Layer3(10, 10)] model = [
LazyModule(lambda: Layer1(10, 10)),
LazyModule(lambda: nn.Linear(10, 10)),
LazyModule(lambda: Layer3(10, 10)),
]
with pytest.raises(ValueError, match="Can't use Skippable layers with multi-process pipe and lazy construction"): with pytest.raises(ValueError, match="Can't use Skippable layers with multi-process pipe and lazy construction"):
Pipe( Pipe(
model, [2, 1], style=Pipe.MultiProcess, worker_map=get_worker_map(), model, [2, 1], style=pipeline_style, worker_map=get_worker_map(),
) )
...@@ -46,9 +46,10 @@ class Pop(nn.Module): ...@@ -46,9 +46,10 @@ class Pop(nn.Module):
@torch_spawn([2]) @torch_spawn([2])
@pytest.mark.parametrize("train", [True, False], ids=["train", "eval"]) @pytest.mark.parametrize("train", [True, False], ids=["train", "eval"])
@pytest.mark.parametrize("checkpoint", ["always", "except_last", "never"]) @pytest.mark.parametrize("checkpoint", ["always", "except_last", "never"])
@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
@pytest.mark.skipif("OMPI_COMM_WORLD_RANK" in os.environ, reason="broken on mpi") @pytest.mark.skipif("OMPI_COMM_WORLD_RANK" in os.environ, reason="broken on mpi")
@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required") @pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
def delete_portal_tensor(train, checkpoint): def delete_portal_tensor(train, checkpoint, pipeline_style):
# Without checkpointing: # Without checkpointing:
# +- Stash --+ +--- Pop ----+ - - - layers # +- Stash --+ +--- Pop ----+ - - - layers
# | 2,blue,1 |--| 1,orange,0 | - - - tensor_life and portal function # | 2,blue,1 |--| 1,orange,0 | - - - tensor_life and portal function
...@@ -59,6 +60,9 @@ def delete_portal_tensor(train, checkpoint): ...@@ -59,6 +60,9 @@ def delete_portal_tensor(train, checkpoint):
# | 3,blue,2 |--| 2,orange,1 |--| 1,orange,0 |--| 1,blue,0 | # | 3,blue,2 |--| 2,orange,1 |--| 1,orange,0 |--| 1,blue,0 |
# +----------+ +------------+ +------------+ +----------+ # +----------+ +------------+ +------------+ +----------+
if pipeline_style == Pipe.AsyncSchedule:
pytest.skip("Skip tensors NYI for AsyncSchedule")
def portal_tensor_life_is(tensor_life, skip_tracker=None): def portal_tensor_life_is(tensor_life, skip_tracker=None):
if skip_tracker is None: if skip_tracker is None:
skip_tracker = current_skip_tracker() skip_tracker = current_skip_tracker()
...@@ -111,7 +115,7 @@ def delete_portal_tensor(train, checkpoint): ...@@ -111,7 +115,7 @@ def delete_portal_tensor(train, checkpoint):
model = nn.Sequential(NoPortalTensorAtBackward(), stash_, pop_) model = nn.Sequential(NoPortalTensorAtBackward(), stash_, pop_)
model = Pipe( model = Pipe(
model, balance=[2, 1], style=Pipe.MultiProcess, worker_map=get_worker_map(), chunks=2, checkpoint=checkpoint, model, balance=[2, 1], style=pipeline_style, worker_map=get_worker_map(), chunks=2, checkpoint=checkpoint,
) )
input = torch.rand(10, requires_grad=True) input = torch.rand(10, requires_grad=True)
......
...@@ -28,7 +28,9 @@ from tests.nn.model_parallel.commons import get_worker_map, torch_spawn ...@@ -28,7 +28,9 @@ from tests.nn.model_parallel.commons import get_worker_map, torch_spawn
@torch_spawn([2]) @torch_spawn([2])
@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required") @pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
def python_autograd_function(): @pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
def python_autograd_function(pipeline_style):
# FIXME deadlock with Pipe.AsyncSchedule?
# A Python autograd function might fail with this error: # A Python autograd function might fail with this error:
# #
# RuntimeError: Returning Variables sharing storage with other Variables # RuntimeError: Returning Variables sharing storage with other Variables
...@@ -55,7 +57,8 @@ def python_autograd_function(): ...@@ -55,7 +57,8 @@ def python_autograd_function():
return Identity.apply(input) return Identity.apply(input)
model = nn.Sequential(M(), M()) model = nn.Sequential(M(), M())
model = Pipe(model, [1, 1], style=Pipe.MultiProcess, worker_map=get_worker_map(), checkpoint="always").cuda() model = Pipe(model, [1, 1], style=pipeline_style, worker_map=get_worker_map(), checkpoint="always").cuda()
model.eval()
x = torch.rand(42) x = torch.rand(42)
y = model(x) y = model(x)
...@@ -67,7 +70,8 @@ def python_autograd_function(): ...@@ -67,7 +70,8 @@ def python_autograd_function():
@torch_spawn([3]) @torch_spawn([3])
@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required") @pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
def exception_no_hang(): @pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
def exception_no_hang(pipeline_style):
# In v0.0.2, once a failed partition receives a normal message # In v0.0.2, once a failed partition receives a normal message
# (non-closing) for the next micro-batch, a hang occured. The reason was # (non-closing) for the next micro-batch, a hang occured. The reason was
# that a failed partition didn't call in_queue.task_done() on a normal # that a failed partition didn't call in_queue.task_done() on a normal
...@@ -85,7 +89,8 @@ def exception_no_hang(): ...@@ -85,7 +89,8 @@ def exception_no_hang():
raise ExpectedException() raise ExpectedException()
model = nn.Sequential(Pass(), Pass(), Raise()) model = nn.Sequential(Pass(), Pass(), Raise())
model = Pipe(model, [1, 1, 1], style=Pipe.MultiProcess, worker_map=get_worker_map(), chunks=3) model = Pipe(model, [1, 1, 1], style=pipeline_style, worker_map=get_worker_map(), chunks=3)
model.eval()
if model.group.rank() == 2: if model.group.rank() == 2:
with pytest.raises(ExpectedException): with pytest.raises(ExpectedException):
...@@ -98,7 +103,8 @@ def exception_no_hang(): ...@@ -98,7 +103,8 @@ def exception_no_hang():
@torch_spawn([2]) @torch_spawn([2])
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="2 cuda devices required") @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="2 cuda devices required")
def tuple_wait(cuda_sleep): @pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
def tuple_wait(cuda_sleep, pipeline_style):
# In v0.0.3, Wait is applied to only the first tensor on a micro-batch. # In v0.0.3, Wait is applied to only the first tensor on a micro-batch.
# Under this behavior, if checkpointing was disabled, there's a possibility # Under this behavior, if checkpointing was disabled, there's a possibility
# that gradient accumulations on other tensors are not synchronized # that gradient accumulations on other tensors are not synchronized
...@@ -129,7 +135,7 @@ def tuple_wait(cuda_sleep): ...@@ -129,7 +135,7 @@ def tuple_wait(cuda_sleep):
model = Pipe( model = Pipe(
model, model,
[1, 1], [1, 1],
style=Pipe.MultiProcess, style=pipeline_style,
worker_map=get_worker_map(), worker_map=get_worker_map(),
input_device=torch.cuda.current_device(), input_device=torch.cuda.current_device(),
chunks=32, chunks=32,
...@@ -151,7 +157,8 @@ def tuple_wait(cuda_sleep): ...@@ -151,7 +157,8 @@ def tuple_wait(cuda_sleep):
@torch_spawn([2]) @torch_spawn([2])
@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required") @pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
def parallel_randoms(): @pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
def parallel_randoms(pipeline_style):
class Dropouts(nn.Module): class Dropouts(nn.Module):
def forward(self, x): def forward(self, x):
for _ in range(100): for _ in range(100):
...@@ -165,7 +172,7 @@ def parallel_randoms(): ...@@ -165,7 +172,7 @@ def parallel_randoms():
model = Pipe( model = Pipe(
model, model,
[1, 1], [1, 1],
style=Pipe.MultiProcess, style=pipeline_style,
input_device=torch.cuda.current_device(), input_device=torch.cuda.current_device(),
worker_map=get_worker_map(), worker_map=get_worker_map(),
chunks=10, chunks=10,
......
...@@ -27,11 +27,17 @@ from tests.nn.model_parallel.commons import get_worker_map, torch_spawn ...@@ -27,11 +27,17 @@ from tests.nn.model_parallel.commons import get_worker_map, torch_spawn
@torch_spawn([2]) @torch_spawn([2])
@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required") @pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
def inplace_on_requires_grad(): @pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
def inplace_on_requires_grad(pipeline_style):
model = nn.Sequential(nn.Linear(1, 1), nn.ReLU(inplace=True)) model = nn.Sequential(nn.Linear(1, 1), nn.ReLU(inplace=True))
model = Pipe(model, [1, 1], style=Pipe.MultiProcess, worker_map=get_worker_map(), checkpoint="always") model = Pipe(model, [1, 1], style=pipeline_style, worker_map=get_worker_map(), checkpoint="always")
x = torch.rand(1) x = torch.rand(1)
if pipeline_style == Pipe.AsyncSchedule and model.group.rank() == 0:
# With AsyncSchedule, model will wait forever for gradients if not eval
model.eval()
y = model(x) y = model(x)
message = r"a leaf Variable that requires grad .* used in an in-place operation." message = r"a leaf Variable that requires grad .* used in an in-place operation."
...@@ -44,11 +50,12 @@ def inplace_on_requires_grad(): ...@@ -44,11 +50,12 @@ def inplace_on_requires_grad():
@torch_spawn([1]) @torch_spawn([1])
@pytest.mark.xfail(strict=True) @pytest.mark.xfail(strict=True)
def inplace_on_not_requires_grad(): @pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
def inplace_on_not_requires_grad(pipeline_style):
# In-place operation on a tensor not requiring grad doesn't cause a # In-place operation on a tensor not requiring grad doesn't cause a
# RuntimeError. Currently, we cannot detect this case. # RuntimeError. Currently, we cannot detect this case.
model = nn.Sequential(nn.ReLU(inplace=True)) model = nn.Sequential(nn.ReLU(inplace=True))
model = Pipe(model, [1], style=Pipe.MultiProcess, worker_map=get_worker_map(), checkpoint="always") model = Pipe(model, [1], style=pipeline_style, worker_map=get_worker_map(), checkpoint="always")
x = torch.rand(1) x = torch.rand(1)
y = model(x) y = model(x)
...@@ -63,7 +70,8 @@ def inplace_on_not_requires_grad(): ...@@ -63,7 +70,8 @@ def inplace_on_not_requires_grad():
@torch_spawn([1]) @torch_spawn([1])
@pytest.mark.xfail(strict=True) @pytest.mark.xfail(strict=True)
def inplace_incorrect_grad(): @pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
def inplace_incorrect_grad(pipeline_style):
class M(nn.Module): class M(nn.Module):
def forward(self, foo_bar): def forward(self, foo_bar):
# 'foo' requires grad but 'bar' does not. In-place operation on # 'foo' requires grad but 'bar' does not. In-place operation on
...@@ -80,7 +88,7 @@ def inplace_incorrect_grad(): ...@@ -80,7 +88,7 @@ def inplace_incorrect_grad():
return foo * bar return foo * bar
model = nn.Sequential(M()) model = nn.Sequential(M())
model = Pipe(model, [1], style=Pipe.MultiProcess, worker_map=get_worker_map(), checkpoint="always") model = Pipe(model, [1], style=pipeline_style, worker_map=get_worker_map(), checkpoint="always")
foo = torch.tensor([1.0], requires_grad=True) foo = torch.tensor([1.0], requires_grad=True)
bar = torch.tensor([1.0]) bar = torch.tensor([1.0])
......
...@@ -21,21 +21,27 @@ from collections import OrderedDict ...@@ -21,21 +21,27 @@ from collections import OrderedDict
from copy import deepcopy from copy import deepcopy
import os import os
import time import time
from typing import Tuple
from packaging import version from packaging import version
import pytest import pytest
import torch import torch
from torch import nn from torch import nn
from fairscale.nn.model_parallel.initialize import destroy_model_parallel, initialize_model_parallel from fairscale.nn.model_parallel.initialize import (
from fairscale.nn.pipe import Pipe destroy_model_parallel,
from tests.nn.model_parallel.commons import get_worker_map, torch_spawn get_pipeline_parallel_group,
initialize_model_parallel,
)
from fairscale.nn.pipe import LazyModule, Pipe
from tests.nn.model_parallel.commons import get_worker_map, set_random_seed, torch_spawn
@torch_spawn([2]) @torch_spawn([2])
def parameters(): @pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
def parameters(pipeline_style):
model = nn.Sequential(nn.Linear(1, 1)) model = nn.Sequential(nn.Linear(1, 1))
pipe = Pipe(model, balance=[1], style=Pipe.MultiProcess, worker_map=get_worker_map(), chunks=1) pipe = Pipe(model, balance=[1], style=pipeline_style, worker_map=get_worker_map(), chunks=1)
if torch.distributed.get_rank() == 0: if torch.distributed.get_rank() == 0:
assert list(pipe.parameters()) != [] assert list(pipe.parameters()) != []
else: else:
...@@ -62,10 +68,10 @@ def infiniband(): ...@@ -62,10 +68,10 @@ def infiniband():
def infiniband2(): def infiniband2():
if torch.distributed.get_rank() == 0: if torch.distributed.get_rank() == 0:
t = torch.Tensor(range(100)).cuda() t = torch.Tensor(range(100)).cuda()
torch.distributed.send(t, 1) torch.distributed.send(t, 1, group=get_pipeline_parallel_group())
else: else:
t = torch.empty(100).cuda() t = torch.empty(100).cuda()
torch.distributed.recv(t, 0) torch.distributed.recv(t, 0, group=get_pipeline_parallel_group())
assert torch.equal(t, torch.Tensor(range(100)).cuda()) assert torch.equal(t, torch.Tensor(range(100)).cuda())
print(f"t on {torch.distributed.get_rank()} is {t}") print(f"t on {torch.distributed.get_rank()} is {t}")
...@@ -87,7 +93,6 @@ def mpi(): ...@@ -87,7 +93,6 @@ def mpi():
torch.cuda.manual_seed(seed) torch.cuda.manual_seed(seed)
torch.distributed.barrier() torch.distributed.barrier()
group = torch.distributed.new_group([0, 1])
tensor_size = (1024, 1024, 10) tensor_size = (1024, 1024, 10)
torch.cuda.set_device(torch.distributed.get_rank()) # need to pin device or ucx gets unhappy torch.cuda.set_device(torch.distributed.get_rank()) # need to pin device or ucx gets unhappy
...@@ -104,7 +109,8 @@ def mpi(): ...@@ -104,7 +109,8 @@ def mpi():
@torch_spawn([1]) @torch_spawn([1])
def public_attrs(): @pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
def public_attrs(pipeline_style):
class MyString: class MyString:
def __init__(self, value): def __init__(self, value):
self.value = value self.value = value
...@@ -117,7 +123,7 @@ def public_attrs(): ...@@ -117,7 +123,7 @@ def public_attrs():
pipe = Pipe( pipe = Pipe(
model, model,
balance=(1,), balance=(1,),
style=Pipe.MultiProcess, style=pipeline_style,
worker_map=get_worker_map(), worker_map=get_worker_map(),
chunks=42.000, chunks=42.000,
checkpoint=MyString("always"), checkpoint=MyString("always"),
...@@ -134,12 +140,13 @@ def public_attrs(): ...@@ -134,12 +140,13 @@ def public_attrs():
@torch_spawn([2]) @torch_spawn([2])
@pytest.mark.parametrize("balance", [[2], [1, 1]]) @pytest.mark.parametrize("balance", [[2], [1, 1]])
def sequential_like(balance): @pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
def sequential_like(balance, pipeline_style):
a = nn.Linear(1, 1) a = nn.Linear(1, 1)
b = nn.Linear(1, 1) b = nn.Linear(1, 1)
model = nn.Sequential(a, b) model = nn.Sequential(a, b)
model = Pipe(model, balance, style=Pipe.MultiProcess, worker_map=get_worker_map()) model = Pipe(model, balance, style=pipeline_style, worker_map=get_worker_map())
if balance == [2]: if balance == [2]:
if torch.distributed.get_rank() == 0: if torch.distributed.get_rank() == 0:
...@@ -172,57 +179,62 @@ def sequential_like(balance): ...@@ -172,57 +179,62 @@ def sequential_like(balance):
@torch_spawn([1]) @torch_spawn([1])
def balance_wrong_length(): @pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
def balance_wrong_length(pipeline_style):
a = nn.Linear(1, 1) a = nn.Linear(1, 1)
b = nn.Linear(1, 1) b = nn.Linear(1, 1)
model = nn.Sequential(a, b) model = nn.Sequential(a, b)
with pytest.raises(ValueError): with pytest.raises(ValueError):
Pipe(model, balance=[1], style=Pipe.MultiProcess, worker_map=get_worker_map()) Pipe(model, balance=[1], style=pipeline_style, worker_map=get_worker_map())
with pytest.raises(ValueError): with pytest.raises(ValueError):
Pipe(model, balance=[3], style=Pipe.MultiProcess, worker_map=get_worker_map()) Pipe(model, balance=[3], style=pipeline_style, worker_map=get_worker_map())
@torch_spawn([2]) @torch_spawn([2])
def balance_less_than_1(): @pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
def balance_less_than_1(pipeline_style):
a = nn.Linear(1, 1) a = nn.Linear(1, 1)
b = nn.Linear(1, 1) b = nn.Linear(1, 1)
model = nn.Sequential(a, b) model = nn.Sequential(a, b)
with pytest.raises(ValueError): with pytest.raises(ValueError):
Pipe(model, balance=[0, 2], style=Pipe.MultiProcess, worker_map=get_worker_map()) Pipe(model, balance=[0, 2], style=pipeline_style, worker_map=get_worker_map())
with pytest.raises(ValueError): with pytest.raises(ValueError):
Pipe(model, balance=[-1, 3], style=Pipe.MultiProcess, worker_map=get_worker_map()) Pipe(model, balance=[-1, 3], style=pipeline_style, worker_map=get_worker_map())
@torch_spawn([1]) @torch_spawn([1])
def chunks_less_than_1(): @pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
def chunks_less_than_1(pipeline_style):
model = nn.Sequential(nn.Linear(1, 1)) model = nn.Sequential(nn.Linear(1, 1))
with pytest.raises(ValueError): with pytest.raises(ValueError):
Pipe(model, balance=[1], style=Pipe.MultiProcess, worker_map=get_worker_map(), chunks=0) Pipe(model, balance=[1], style=pipeline_style, worker_map=get_worker_map(), chunks=0)
with pytest.raises(ValueError): with pytest.raises(ValueError):
Pipe(model, balance=[1], style=Pipe.MultiProcess, worker_map=get_worker_map(), chunks=-1) Pipe(model, balance=[1], style=pipeline_style, worker_map=get_worker_map(), chunks=-1)
@torch_spawn([1]) @torch_spawn([1])
def too_few_devices(): @pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
def too_few_devices(pipeline_style):
model = nn.Sequential(nn.Linear(1, 1), nn.Linear(1, 1), nn.Linear(1, 1), nn.Linear(1, 1)) model = nn.Sequential(nn.Linear(1, 1), nn.Linear(1, 1), nn.Linear(1, 1), nn.Linear(1, 1))
with pytest.raises(IndexError): with pytest.raises(IndexError):
# len(balance) > len(group.size()) # len(balance) > len(group.size())
model = Pipe(model, balance=[1, 1, 1, 1], style=Pipe.MultiProcess, worker_map=get_worker_map()) model = Pipe(model, balance=[1, 1, 1, 1], style=pipeline_style, worker_map=get_worker_map())
@torch_spawn([1]) @torch_spawn([1])
def batch_size_indivisible(): @pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
def batch_size_indivisible(pipeline_style):
model = nn.Sequential(nn.Linear(1, 1)) model = nn.Sequential(nn.Linear(1, 1))
model = Pipe(model, balance=[1], style=Pipe.MultiProcess, worker_map=get_worker_map(), chunks=4) model = Pipe(model, balance=[1], style=pipeline_style, worker_map=get_worker_map(), chunks=4)
with pytest.warns(None) as record: with pytest.warns(None) as record:
model(torch.rand(7, 1)) model(torch.rand(7, 1))
...@@ -232,9 +244,10 @@ def batch_size_indivisible(): ...@@ -232,9 +244,10 @@ def batch_size_indivisible():
@torch_spawn([1]) @torch_spawn([1])
def batch_size_small(): @pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
def batch_size_small(pipeline_style):
model = nn.Sequential(nn.Linear(1, 1)) model = nn.Sequential(nn.Linear(1, 1))
model = Pipe(model, balance=[1], style=Pipe.MultiProcess, worker_map=get_worker_map(), chunks=4) model = Pipe(model, balance=[1], style=pipeline_style, worker_map=get_worker_map(), chunks=4)
with pytest.warns(None) as record: with pytest.warns(None) as record:
model(torch.rand(2, 1)) model(torch.rand(2, 1))
...@@ -244,7 +257,8 @@ def batch_size_small(): ...@@ -244,7 +257,8 @@ def batch_size_small():
@torch_spawn([1]) @torch_spawn([1])
def checkpoint_mode(): @pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
def checkpoint_mode(pipeline_style):
def count_grad_fn(grad_fn, name, visited=set()): def count_grad_fn(grad_fn, name, visited=set()):
if grad_fn in visited: if grad_fn in visited:
return 0 return 0
...@@ -266,7 +280,7 @@ def checkpoint_mode(): ...@@ -266,7 +280,7 @@ def checkpoint_mode():
always = Pipe( always = Pipe(
model, model,
balance=[1], balance=[1],
style=Pipe.MultiProcess, style=pipeline_style,
worker_map=get_worker_map(), worker_map=get_worker_map(),
chunks=2, chunks=2,
checkpoint="always", checkpoint="always",
...@@ -275,7 +289,7 @@ def checkpoint_mode(): ...@@ -275,7 +289,7 @@ def checkpoint_mode():
except_last = Pipe( except_last = Pipe(
model, model,
balance=[1], balance=[1],
style=Pipe.MultiProcess, style=pipeline_style,
worker_map=get_worker_map(), worker_map=get_worker_map(),
chunks=2, chunks=2,
checkpoint="except_last", checkpoint="except_last",
...@@ -284,7 +298,7 @@ def checkpoint_mode(): ...@@ -284,7 +298,7 @@ def checkpoint_mode():
never = Pipe( never = Pipe(
model, model,
balance=[1], balance=[1],
style=Pipe.MultiProcess, style=pipeline_style,
worker_map=get_worker_map(), worker_map=get_worker_map(),
chunks=2, chunks=2,
checkpoint="never", checkpoint="never",
...@@ -301,14 +315,15 @@ def checkpoint_mode(): ...@@ -301,14 +315,15 @@ def checkpoint_mode():
@torch_spawn([1]) @torch_spawn([1])
def checkpoint_mode_invalid(): @pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
def checkpoint_mode_invalid(pipeline_style):
model = nn.Sequential(nn.Linear(1, 1)) model = nn.Sequential(nn.Linear(1, 1))
with pytest.raises(ValueError, match="checkpoint is not one of 'always', 'except_last', or 'never'"): with pytest.raises(ValueError, match="checkpoint is not one of 'always', 'except_last', or 'never'"):
Pipe( Pipe(
model, model,
balance=[1], balance=[1],
style=Pipe.MultiProcess, style=pipeline_style,
worker_map=get_worker_map(), worker_map=get_worker_map(),
chunks=2, chunks=2,
checkpoint="INVALID_CHECKPOINT", checkpoint="INVALID_CHECKPOINT",
...@@ -316,22 +331,24 @@ def checkpoint_mode_invalid(): ...@@ -316,22 +331,24 @@ def checkpoint_mode_invalid():
@torch_spawn([1]) @torch_spawn([1])
def checkpoint_mode_when_chunks_1(): @pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
def checkpoint_mode_when_chunks_1(pipeline_style):
model = nn.Sequential(nn.Linear(1, 1)) model = nn.Sequential(nn.Linear(1, 1))
# All checkpoint modes are fine. # All checkpoint modes are fine.
Pipe( Pipe(
model, balance=[1], style=Pipe.MultiProcess, worker_map=get_worker_map(), chunks=1, checkpoint="except_last", model, balance=[1], style=pipeline_style, worker_map=get_worker_map(), chunks=1, checkpoint="except_last",
) )
Pipe(model, balance=[1], style=Pipe.MultiProcess, worker_map=get_worker_map(), chunks=1, checkpoint="always") Pipe(model, balance=[1], style=pipeline_style, worker_map=get_worker_map(), chunks=1, checkpoint="always")
Pipe(model, balance=[1], style=Pipe.MultiProcess, worker_map=get_worker_map(), chunks=1, checkpoint="never") Pipe(model, balance=[1], style=pipeline_style, worker_map=get_worker_map(), chunks=1, checkpoint="never")
@torch_spawn([1]) @torch_spawn([1])
def checkpoint_eval(): @pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
def checkpoint_eval(pipeline_style):
model = nn.Sequential(nn.Linear(1, 1)) model = nn.Sequential(nn.Linear(1, 1))
model = Pipe( model = Pipe(
model, balance=[1], style=Pipe.MultiProcess, worker_map=get_worker_map(), chunks=2, pipelined_backward=False, model, balance=[1], style=pipeline_style, worker_map=get_worker_map(), chunks=2, pipelined_backward=False,
) )
input = torch.rand(2, 1) input = torch.rand(2, 1)
...@@ -356,11 +373,16 @@ def checkpoint_eval(): ...@@ -356,11 +373,16 @@ def checkpoint_eval():
assert not find_grad_fn(eval_output.grad_fn, "RecomputeBackward") assert not find_grad_fn(eval_output.grad_fn, "RecomputeBackward")
def torch_version() -> Tuple[int, ...]:
result = version.parse(torch.__version__).release
assert result
return result
@torch_spawn([2]) @torch_spawn([2])
@pytest.mark.xfail( @pytest.mark.xfail(torch_version() < (1, 6, 0), reason="Doesn't work on torch < 1.6.0", strict=True)
version.parse(torch.__version__) < version.parse("1.6.0"), reason="Doesn't work on torch < 1.6.0", strict=True @pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
) def checkpoint_non_float_input(pipeline_style):
def checkpoint_non_float_input():
class ForkNonFloat(nn.Module): class ForkNonFloat(nn.Module):
def forward(self, input): def forward(self, input):
return (input * 2, torch.tensor([False])) return (input * 2, torch.tensor([False]))
...@@ -373,7 +395,7 @@ def checkpoint_non_float_input(): ...@@ -373,7 +395,7 @@ def checkpoint_non_float_input():
model = Pipe( model = Pipe(
model, model,
balance=[1, 1], balance=[1, 1],
style=Pipe.MultiProcess, style=pipeline_style,
worker_map=get_worker_map(), worker_map=get_worker_map(),
chunks=1, chunks=1,
checkpoint="always", checkpoint="always",
...@@ -385,14 +407,17 @@ def checkpoint_non_float_input(): ...@@ -385,14 +407,17 @@ def checkpoint_non_float_input():
if model.group.rank() == 1: if model.group.rank() == 1:
# with torch.autograd.detect_anomaly(): # with torch.autograd.detect_anomaly():
output.backward() output.backward()
else: elif pipeline_style == Pipe.MultiProcess:
model.back_helper(output) model.back_helper(output)
torch.distributed.barrier()
@torch_spawn([1]) @torch_spawn([1])
def no_grad(): @pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
def no_grad(pipeline_style):
model = nn.Sequential(nn.Linear(1, 1)) model = nn.Sequential(nn.Linear(1, 1))
model = Pipe(model, balance=[1], style=Pipe.MultiProcess, worker_map=get_worker_map(), chunks=2) model = Pipe(model, balance=[1], style=pipeline_style, worker_map=get_worker_map(), chunks=2)
input = torch.rand(2, 1) input = torch.rand(2, 1)
latent = None latent = None
...@@ -404,8 +429,8 @@ def no_grad(): ...@@ -404,8 +429,8 @@ def no_grad():
nonlocal latent nonlocal latent
latent = output latent = output
partition = model.partitions[0] partition = model.mp_partitions[0]
partition.register_forward_hook(hook) partition.module.register_forward_hook(hook)
with torch.no_grad(): with torch.no_grad():
model(input) model(input)
...@@ -414,7 +439,8 @@ def no_grad(): ...@@ -414,7 +439,8 @@ def no_grad():
@torch_spawn([1]) @torch_spawn([1])
def exception(): @pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
def exception(pipeline_style):
class ExpectedException(Exception): class ExpectedException(Exception):
pass pass
...@@ -423,7 +449,7 @@ def exception(): ...@@ -423,7 +449,7 @@ def exception():
raise ExpectedException() raise ExpectedException()
model = nn.Sequential(Raise()) model = nn.Sequential(Raise())
model = Pipe(model, balance=[1], style=Pipe.MultiProcess, worker_map=get_worker_map(), chunks=1) model = Pipe(model, balance=[1], style=pipeline_style, worker_map=get_worker_map(), chunks=1)
with pytest.raises(ExpectedException): with pytest.raises(ExpectedException):
model(torch.rand(1)) model(torch.rand(1))
...@@ -432,7 +458,8 @@ def exception(): ...@@ -432,7 +458,8 @@ def exception():
# FIXME(tom) should probably signal to all hosts in group to stop # FIXME(tom) should probably signal to all hosts in group to stop
@torch_spawn([4]) @torch_spawn([4])
@pytest.mark.xfail(strict=True) @pytest.mark.xfail(strict=True)
def exception_early_stop_asap(): @pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
def exception_early_stop_asap(pipeline_style):
"""Even the first partitions have finished to process, the partition before """Even the first partitions have finished to process, the partition before
the failed partition should be killed as soon as possible. the failed partition should be killed as soon as possible.
""" """
...@@ -460,7 +487,7 @@ def exception_early_stop_asap(): ...@@ -460,7 +487,7 @@ def exception_early_stop_asap():
raise ExpectedException() raise ExpectedException()
model = nn.Sequential(Pass(), Pass(), Counter(), Raise()) model = nn.Sequential(Pass(), Pass(), Counter(), Raise())
model = Pipe(model, [1, 1, 1, 1], style=Pipe.MultiProcess, worker_map=get_worker_map(), chunks=3) model = Pipe(model, [1, 1, 1, 1], style=pipeline_style, worker_map=get_worker_map(), chunks=3)
with pytest.raises(ExpectedException): with pytest.raises(ExpectedException):
model(torch.rand(3)) model(torch.rand(3))
...@@ -470,7 +497,8 @@ def exception_early_stop_asap(): ...@@ -470,7 +497,8 @@ def exception_early_stop_asap():
@torch_spawn([1]) @torch_spawn([1])
def input_pair(): @pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
def input_pair(pipeline_style):
class Two(nn.Module): class Two(nn.Module):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
...@@ -483,7 +511,7 @@ def input_pair(): ...@@ -483,7 +511,7 @@ def input_pair():
model = nn.Sequential(Two()) model = nn.Sequential(Two())
model = Pipe( model = Pipe(
model, balance=[1], style=Pipe.MultiProcess, worker_map=get_worker_map(), chunks=2, pipelined_backward=False, model, balance=[1], style=pipeline_style, worker_map=get_worker_map(), chunks=2, pipelined_backward=False,
) )
a = torch.rand(10, 1, requires_grad=True) a = torch.rand(10, 1, requires_grad=True)
...@@ -498,7 +526,8 @@ def input_pair(): ...@@ -498,7 +526,8 @@ def input_pair():
@torch_spawn([1]) @torch_spawn([1])
def input_singleton(): @pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
def input_singleton(pipeline_style):
class One(nn.Module): class One(nn.Module):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
...@@ -510,7 +539,7 @@ def input_singleton(): ...@@ -510,7 +539,7 @@ def input_singleton():
model = nn.Sequential(One()) model = nn.Sequential(One())
model = Pipe( model = Pipe(
model, balance=[1], style=Pipe.MultiProcess, worker_map=get_worker_map(), chunks=2, pipelined_backward=False, model, balance=[1], style=pipeline_style, worker_map=get_worker_map(), chunks=2, pipelined_backward=False,
) )
a = torch.rand(10, 1, requires_grad=True) a = torch.rand(10, 1, requires_grad=True)
...@@ -524,9 +553,10 @@ def input_singleton(): ...@@ -524,9 +553,10 @@ def input_singleton():
@torch_spawn([1]) @torch_spawn([1])
def input_varargs(): @pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
def input_varargs(pipeline_style):
model = nn.Sequential(nn.Linear(1, 1)) model = nn.Sequential(nn.Linear(1, 1))
model = Pipe(model, balance=[1], style=Pipe.MultiProcess, worker_map=get_worker_map()) model = Pipe(model, balance=[1], style=pipeline_style, worker_map=get_worker_map())
a = torch.rand(1) a = torch.rand(1)
b = torch.rand(1) b = torch.rand(1)
...@@ -537,13 +567,14 @@ def input_varargs(): ...@@ -537,13 +567,14 @@ def input_varargs():
@torch_spawn([1]) @torch_spawn([1])
def non_tensor(): @pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
def non_tensor(pipeline_style):
class NonTensor(nn.Module): class NonTensor(nn.Module):
def forward(self, _): def forward(self, _):
return "hello" return "hello"
model = nn.Sequential(NonTensor()) model = nn.Sequential(NonTensor())
model = Pipe(model, balance=[1], style=Pipe.MultiProcess, worker_map=get_worker_map()) model = Pipe(model, balance=[1], style=pipeline_style, worker_map=get_worker_map())
x = torch.rand(1) x = torch.rand(1)
# TypeError: expected Tensor as element 0 in argument 0, but got str # TypeError: expected Tensor as element 0 in argument 0, but got str
...@@ -556,13 +587,14 @@ def non_tensor(): ...@@ -556,13 +587,14 @@ def non_tensor():
@torch_spawn([1]) @torch_spawn([1])
def non_tensor_tuple(): @pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
def non_tensor_tuple(pipeline_style):
class NonTensorTuple(nn.Module): class NonTensorTuple(nn.Module):
def forward(self, x): def forward(self, x):
return (x, "hello") return (x, "hello")
model = nn.Sequential(NonTensorTuple()) model = nn.Sequential(NonTensorTuple())
model = Pipe(model, balance=[1], style=Pipe.MultiProcess, worker_map=get_worker_map()) model = Pipe(model, balance=[1], style=pipeline_style, worker_map=get_worker_map())
x = torch.rand(1) x = torch.rand(1)
# TypeError: CheckpointBackward.forward: expected Variable (got str) for return value 1 # TypeError: CheckpointBackward.forward: expected Variable (got str) for return value 1
...@@ -577,18 +609,19 @@ def non_tensor_tuple(): ...@@ -577,18 +609,19 @@ def non_tensor_tuple():
@torch_spawn([1]) @torch_spawn([1])
@pytest.mark.parametrize("checkpoint", ["never", "always", "except_last"]) @pytest.mark.parametrize("checkpoint", ["never", "always", "except_last"])
@pytest.mark.parametrize("lazy", [True, False]) @pytest.mark.parametrize("lazy", [True, False])
def deferred_batch_norm(checkpoint, lazy): @pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
def deferred_batch_norm(checkpoint, lazy, pipeline_style):
bn = nn.BatchNorm2d(3) bn = nn.BatchNorm2d(3)
pipe_bn = deepcopy(bn) pipe_bn = deepcopy(bn)
pipe_fn = lambda: pipe_bn # noqa: E731 pipe_fn = lambda: pipe_bn # noqa: E731
if lazy: if lazy:
model = [pipe_fn] model = [LazyModule(pipe_fn)]
else: else:
model = nn.Sequential(pipe_bn) model = nn.Sequential(pipe_bn)
pipe = Pipe( pipe = Pipe(
model, model,
balance=[1], balance=[1],
style=Pipe.MultiProcess, style=pipeline_style,
worker_map=get_worker_map(), worker_map=get_worker_map(),
chunks=2, chunks=2,
checkpoint=checkpoint, checkpoint=checkpoint,
...@@ -606,18 +639,19 @@ def deferred_batch_norm(checkpoint, lazy): ...@@ -606,18 +639,19 @@ def deferred_batch_norm(checkpoint, lazy):
@torch_spawn([1]) @torch_spawn([1])
@pytest.mark.parametrize("checkpoint", ["never", "always"]) @pytest.mark.parametrize("checkpoint", ["never", "always"])
@pytest.mark.parametrize("lazy", [True, False]) @pytest.mark.parametrize("lazy", [True, False])
def deferred_batch_norm_params(checkpoint, lazy): @pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
def deferred_batch_norm_params(checkpoint, lazy, pipeline_style):
bn = nn.BatchNorm2d(3) bn = nn.BatchNorm2d(3)
pipe_bn = deepcopy(bn) pipe_bn = deepcopy(bn)
pipe_fn = lambda: pipe_bn # noqa: E731 pipe_fn = lambda: pipe_bn # noqa: E731
if lazy: if lazy:
model = [pipe_fn] model = [LazyModule(pipe_fn)]
else: else:
model = nn.Sequential(pipe_bn) model = nn.Sequential(pipe_bn)
pipe = Pipe( pipe = Pipe(
model, model,
balance=[1], balance=[1],
style=Pipe.MultiProcess, style=pipeline_style,
worker_map=get_worker_map(), worker_map=get_worker_map(),
chunks=1, chunks=1,
checkpoint=checkpoint, checkpoint=checkpoint,
...@@ -636,14 +670,15 @@ def deferred_batch_norm_params(checkpoint, lazy): ...@@ -636,14 +670,15 @@ def deferred_batch_norm_params(checkpoint, lazy):
@torch_spawn([4]) @torch_spawn([4])
def devices(): @pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
def devices(pipeline_style):
a = nn.Linear(1, 1) a = nn.Linear(1, 1)
b = nn.Linear(1, 1) b = nn.Linear(1, 1)
c = nn.Linear(1, 1) c = nn.Linear(1, 1)
# There are extra two ranks. # There are extra two ranks.
model = nn.Sequential(a, b, c) model = nn.Sequential(a, b, c)
model = Pipe(model, [1, 1, 1], style=Pipe.MultiProcess, worker_map=get_worker_map()) model = Pipe(model, [1, 1, 1], style=pipeline_style, worker_map=get_worker_map())
# Extra devices must be discarded. # Extra devices must be discarded.
if model.group.rank() == 3: if model.group.rank() == 3:
...@@ -651,28 +686,33 @@ def devices(): ...@@ -651,28 +686,33 @@ def devices():
@torch_spawn([2]) @torch_spawn([2])
def partitions(): @pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
def partitions(pipeline_style):
a = nn.Linear(1, 1) a = nn.Linear(1, 1)
b = nn.Linear(1, 1) b = nn.Linear(1, 1)
model = nn.Sequential(a, b) model = nn.Sequential(a, b)
model = Pipe(model, [1, 1], style=Pipe.MultiProcess, worker_map=get_worker_map()) model = Pipe(model, [1, 1], style=pipeline_style, worker_map=get_worker_map())
assert isinstance(model.partitions, nn.ModuleList) assert isinstance(model.mp_partitions, list)
assert len(model) == 1 assert len(model) == 1
assert isinstance(model.partitions[0], nn.Sequential) assert isinstance(model.mp_partitions[0].module, nn.Sequential)
assert "partitions.0.0.weight" in model.state_dict() if model.group.rank() == 0:
assert "0.0.weight" in model.state_dict()
else:
assert "0.1.weight" in model.state_dict()
@torch_spawn([2]) @torch_spawn([2])
@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required") @pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
def deny_moving(): @pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
def deny_moving(pipeline_style):
a = nn.Linear(1, 1) a = nn.Linear(1, 1)
b = nn.Linear(1, 1) b = nn.Linear(1, 1)
model = nn.Sequential(a, b) model = nn.Sequential(a, b)
model = Pipe(model, [1, 1], style=Pipe.MultiProcess, worker_map=get_worker_map()) model = Pipe(model, [1, 1], style=pipeline_style, worker_map=get_worker_map())
model.cuda() model.cuda()
model.cpu() model.cpu()
...@@ -690,10 +730,11 @@ def deny_moving(): ...@@ -690,10 +730,11 @@ def deny_moving():
@torch_spawn([1]) @torch_spawn([1])
def empty_module(): @pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
def empty_module(pipeline_style):
# Empty sequential module is not illegal. # Empty sequential module is not illegal.
model = nn.Sequential() model = nn.Sequential()
model = Pipe(model, [], style=Pipe.MultiProcess, worker_map=get_worker_map()) model = Pipe(model, [], style=pipeline_style, worker_map=get_worker_map())
assert model(torch.tensor([42])) == torch.tensor([42]) assert model(torch.tensor([42])) == torch.tensor([42])
assert model((torch.tensor([42]),)) == (torch.tensor([42]),) assert model((torch.tensor([42]),)) == (torch.tensor([42]),)
...@@ -705,16 +746,19 @@ def empty_module(): ...@@ -705,16 +746,19 @@ def empty_module():
@torch_spawn([2]) @torch_spawn([2])
def named_children(): @pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
def named_children(pipeline_style):
a = nn.Linear(1, 1) a = nn.Linear(1, 1)
b = nn.Linear(1, 1) b = nn.Linear(1, 1)
model = nn.Sequential(OrderedDict([("a", a), ("b", b)])) model = nn.Sequential(OrderedDict([("a", a), ("b", b)]))
model = Pipe(model, [1, 1], devices=["cpu", "cpu"]) model = Pipe(model, [1, 1], style=pipeline_style, worker_map=get_worker_map())
names = set(n for n, _ in model.named_modules()) names = set(n for n, _ in model.named_modules())
assert "partitions.0.a" in names if model.group.rank() == 0:
assert "partitions.1.b" in names assert "0.a" in names
else:
assert "0.b" in names
# Pipe doesn't support __getattr__. Unlike nn.Sequential, Pipe requires # Pipe doesn't support __getattr__. Unlike nn.Sequential, Pipe requires
# several methods in its namespace. # several methods in its namespace.
...@@ -723,7 +767,8 @@ def named_children(): ...@@ -723,7 +767,8 @@ def named_children():
@torch_spawn([1]) @torch_spawn([1])
def recommend_auto_balance(): @pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
def recommend_auto_balance(pipeline_style):
with pytest.raises(ValueError, match="fairscale.nn.pipe.balance"): with pytest.raises(ValueError, match="fairscale.nn.pipe.balance"):
# balance is required # balance is required
Pipe(nn.Sequential()) Pipe(nn.Sequential())
...@@ -737,23 +782,9 @@ def recommend_auto_balance(): ...@@ -737,23 +782,9 @@ def recommend_auto_balance():
Pipe(nn.Sequential(nn.Linear(1, 1), nn.Linear(1, 1)), [1]) Pipe(nn.Sequential(nn.Linear(1, 1), nn.Linear(1, 1)), [1])
@torch_spawn([1])
def verify_module_non_sequential():
with pytest.raises(TypeError, match="module must be nn.Sequential to be partitioned"):
Pipe(nn.Module(), [1])
@torch_spawn([1])
def verify_module_duplicate_children():
conv = nn.Conv2d(3, 3, 1)
model = nn.Sequential(conv, conv)
with pytest.raises(ValueError, match="module with duplicate children is not supported"):
Pipe(model, [1, 1])
@torch_spawn([2]) @torch_spawn([2])
def lazy_construction(): @pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
def lazy_construction(pipeline_style):
init_count = 0 init_count = 0
class Custom(nn.Module): class Custom(nn.Module):
...@@ -766,13 +797,13 @@ def lazy_construction(): ...@@ -766,13 +797,13 @@ def lazy_construction():
return x return x
model = [ model = [
lambda: Custom(), LazyModule(lambda: Custom()),
lambda: Custom(), LazyModule(lambda: Custom()),
lambda: Custom(), LazyModule(lambda: Custom()),
lambda: Custom(), LazyModule(lambda: Custom()),
] ]
pipe = Pipe(model, balance=[2, 2], style=Pipe.MultiProcess, worker_map=get_worker_map()) pipe = Pipe(model, balance=[2, 2], style=pipeline_style, worker_map=get_worker_map())
assert isinstance(pipe[0], Custom) assert isinstance(pipe[0], Custom)
assert isinstance(pipe[1], Custom) assert isinstance(pipe[1], Custom)
...@@ -780,18 +811,20 @@ def lazy_construction(): ...@@ -780,18 +811,20 @@ def lazy_construction():
assert init_count == 2 assert init_count == 2
-@pytest.mark.skipif("OMPI_COMM_WORLD_RANK" in os.environ, reason="doesn't apply to mpi")
 @torch_spawn([2])
-def missing_worker_map():
+@pytest.mark.skipif("OMPI_COMM_WORLD_RANK" in os.environ, reason="doesn't apply to mpi")
+@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
+def missing_worker_map(pipeline_style):
     model = nn.Sequential(nn.ReLU(), nn.ReLU())
-    with pytest.raises(ValueError, match="'PipelineStyle.MultiProcess' requires 'worker_map' to be set"):
-        Pipe(model, [1, 1], style=Pipe.MultiProcess)
+    with pytest.raises(ValueError, match="'RpcTransport' requires 'worker_map' to be set"):
+        Pipe(model, [1, 1], style=pipeline_style)
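# get_worker_map() comes from the shared test helpers; conceptually it maps each
# global rank to the RPC worker name used by the transport. A minimal sketch of
# such a mapping (assuming the "Test<rank>" naming convention used by init_rpc
# further below):
import torch.distributed

def sketch_worker_map():
    return {rank: f"Test{rank}" for rank in range(torch.distributed.get_world_size())}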
 @torch_spawn([2])
 @pytest.mark.skip(reason="currently broken")
-def verify_module_duplicate_parameters_on_distinct_partitions():
+@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
+def verify_module_duplicate_parameters_on_distinct_partitions(pipeline_style):
     class Surrogate(nn.Module):
         def __init__(self, module):
             super().__init__()

@@ -802,21 +835,205 @@ def verify_module_duplicate_parameters_on_distinct_partitions():
     # FIXME(tom) can't have duplicate params with separate processes
     with pytest.raises(ValueError, match="module with duplicate parameters on distinct devices is not supported"):
-        Pipe(model, [1, 1], style=Pipe.MultiProcess, worker_map=get_worker_map())
+        Pipe(model, [1, 1], style=pipeline_style, worker_map=get_worker_map())
 @torch_spawn([4])
-def pipelined_backward():
+@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
+def pipelined_backward(pipeline_style):
     model = nn.Sequential(nn.ReLU(), nn.ReLU())

     destroy_model_parallel()
     initialize_model_parallel(1, 4)
-    pipe = Pipe(model, [1, 1], style=Pipe.MultiProcess, worker_map=get_worker_map())
+    pipe = Pipe(model, [1, 1], style=pipeline_style, worker_map=get_worker_map())

     assert pipe.pipelined_backward is False

     destroy_model_parallel()
     initialize_model_parallel(2, 2)
-    pipe = Pipe(model, [1, 1], style=Pipe.MultiProcess, worker_map=get_worker_map())
+    pipe = Pipe(model, [1, 1], style=pipeline_style, worker_map=get_worker_map())

     assert pipe.pipelined_backward is True
@torch_spawn([4])
def async_event_loop():
model = nn.Sequential(nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 10), nn.ReLU())
pipe = Pipe(model, [1, 1, 1, 1], style=Pipe.AsyncSchedule, worker_map=get_worker_map(), chunks=10)
inputs = torch.rand(100, 10)
output = pipe(inputs)
if pipe.final_stage:
loss = output.mean()
loss.backward()
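# A quick sketch of the microbatching implied by chunks=10 above, assuming Pipe
# scatters the input along the batch dimension: 100 rows become ten microbatches
# of 10 rows that flow through the four stages concurrently.
import torch

def sketch_microbatches():
    inputs = torch.rand(100, 10)
    microbatches = torch.chunk(inputs, 10, dim=0)
    assert len(microbatches) == 10
    assert microbatches[0].shape == (10, 10)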
@torch_spawn([4])
def reuse_lazy():
if False: # speed
reused = LazyModule(lambda: nn.Linear(10, 10))
model = [reused, nn.Linear(10, 10), nn.ReLU(), reused, nn.ReLU(), reused, nn.ReLU()]
# model = [reused, reused, nn.Linear(10, 10), nn.ReLU(), reused, reused, nn.ReLU(), reused, reused, nn.ReLU()]
pipe = Pipe(model, [3, 1, 1], style=Pipe.AsyncSchedule, worker_map=get_worker_map())
pipe.eval()
output = pipe(torch.rand(10))
print(f"output on {pipe.group.rank()}, {output}")
torch.distributed.barrier()
set_random_seed(1234)
# test both forward and backward
reused = nn.Linear(10, 10)
layers = [reused, nn.Linear(10, 10), nn.ReLU(), reused, nn.ReLU(), reused, nn.ReLU()]
model = nn.Sequential(*layers)
model.eval()
set_random_seed(1234)
# ensure identical weights but no sharing between model and pipe
reused = nn.Linear(10, 10)
layers = [reused, nn.Linear(10, 10), nn.ReLU(), reused, nn.ReLU(), reused, nn.ReLU()]
pipe = Pipe(layers, [3, 1, 1], style=Pipe.AsyncSchedule, worker_map=get_worker_map())
pipe.eval()
model_optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
pipe_optimizer = torch.optim.SGD(pipe.parameters(), lr=0.01, momentum=0.9) if len(list(pipe.parameters())) else None
inputs = torch.rand(10)
if False: # speed
model_out = model(inputs)
pipe_out = pipe(inputs)
torch.distributed.barrier()
if pipe.final_stage:
assert torch.equal(model_out, pipe_out)
model.train()
pipe.train()
model_out = model(inputs)
pipe_out = pipe(inputs)
if pipe.final_stage:
pipe_loss = pipe_out.mean()
pipe_loss.backward()
model_loss = model_out.mean()
model_loss.backward()
model_optimizer.step()
if pipe_optimizer:
pipe_optimizer.step()
model.eval()
pipe.eval()
model_out = model(inputs)
pipe_out = pipe(inputs)
print(f"before barrier on {torch.distributed.get_rank()}")
torch.distributed.barrier()
print(f"after barrier on {torch.distributed.get_rank()}")
if pipe.final_stage:
assert torch.equal(model_out, pipe_out)
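# The test above re-seeds before building `model` and `pipe` so both start from
# identical weights without sharing storage. The same pattern in plain PyTorch
# (using torch.manual_seed in place of the set_random_seed test helper):
import torch
from torch import nn

def sketch_identical_but_unshared():
    def build():
        torch.manual_seed(1234)
        reused = nn.Linear(10, 10)
        return nn.Sequential(reused, nn.Linear(10, 10), nn.ReLU(), reused, nn.ReLU(), reused, nn.ReLU())

    a, b = build(), build()
    # Identical initial values...
    assert all(torch.equal(p, q) for p, q in zip(a.state_dict().values(), b.state_dict().values()))
    # ...but distinct storage, so training one does not update the other.
    assert a[0].weight.data_ptr() != b[0].weight.data_ptr()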
def test_instantiate_partition():
from fairscale.nn.pipe.async_schedule import Location
from fairscale.nn.pipe.pipe import instantiate_partition
class FakeGroup:
def __init__(self, rank, size):
self._rank = rank
self._size = size
def rank(self):
return self._rank
def size(self):
return self._size
def check_partitions(model, balance, expected_order, expected_ranks):
"""Check the instantiated model matches expectation of order and rank
model: a list of modules or an nn.Sequential
balance: the balance argument to Pipe
expected_order: the index of modules in `model` in the order they will
be executed, grouped by nn.Sequential
expected_rank: the rank that each module will be executed on
"""
invocations = []
invocation_wrapper = dict()
# Collect `Invocation` and `Invocation` -> `ModuleWrapper` mapping from
# instantiated model
for rank in range(len(balance)):
instantiated = instantiate_partition(model, balance, FakeGroup(rank, len(balance)), Pipe.AsyncSchedule)
for part in instantiated:
assert isinstance(part.module, nn.Sequential)
for inv in part.invocations:
invocations.append(inv)
invocation_wrapper[inv] = part
modules = []
prev = None
current = Location(0, 0)
ranks = []
for order, inv in enumerate(sorted(invocations, key=lambda x: x.order)):
# Check integrity of Location chain
assert inv.order == order
assert inv.source == prev
assert inv.this == current
prev = inv.this
current = inv.dest
modules.append(list(invocation_wrapper[inv].module.children()))
ranks.append(inv.this.stage)
# assert len(modules) == len(expected_order)
for left, right in zip(modules, expected_order):
assert len(left) == len(right), f"{right}"
assert list(map(id, left)) == list(map(id, (model[e] for e in right))), f"{right}"
assert ranks == expected_ranks
reused = nn.Linear(20, 20)
model = [reused, nn.Linear(10, 10), nn.ReLU(), reused, nn.ReLU(), reused, nn.ReLU()]
balance = [3, 1, 1]
check_partitions(
model, balance, expected_order=[[0], [1, 2], [0], [4], [0], [6]], expected_ranks=[0, 0, 0, 1, 0, 2]
)
reused2 = nn.Linear(5, 5)
model = [reused, reused2, nn.Linear(10, 10), nn.ReLU(), reused, reused2, nn.ReLU(), reused, reused2, nn.ReLU()]
balance = [4, 1, 1]
check_partitions(
model,
balance,
expected_order=[[0], [1], [2, 3], [0], [1], [6], [0], [1], [9]],
expected_ranks=[0, 0, 0, 0, 0, 1, 0, 0, 2],
)
reused2 = nn.Linear(5, 5)
model = [
nn.Linear(10, 10),
reused,
nn.Linear(10, 10),
nn.ReLU(),
reused,
reused2,
nn.ReLU(),
reused,
reused2,
nn.ReLU(),
]
# expected_order indices per module: 0 1 2 3 1 5 6 1 5 9 (a reused module maps back to its first index)
balance = [4, 2, 1]
check_partitions(
model,
balance,
expected_order=[[0], [1], [2, 3], [1], [5], [6], [1], [5], [9]],
expected_ranks=[0, 0, 0, 0, 1, 1, 0, 1, 2],
)
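# The expected_order lists above index into `model`, and a reused layer is always
# referred to by the index of its first occurrence, which is why [0] (and later
# [1] and [5]) recur. A small check of that identity, using the same construction
# as the test:
from torch import nn

def sketch_reuse_indices():
    reused = nn.Linear(20, 20)
    model = [reused, nn.Linear(10, 10), nn.ReLU(), reused, nn.ReLU(), reused, nn.ReLU()]
    assert model[0] is model[3] is model[5]
    assert len({id(m) for m in model}) == 5  # seven slots, five distinct modules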
import copy
import os
import pytest
import torch
from torch import nn
from torch.distributed import rpc
from fairscale.nn.model_parallel.initialize import get_pipeline_parallel_group
from fairscale.nn.pipe import PipeRPCWrapper
from tests.nn.model_parallel.commons import get_worker_map, torch_spawn
def init_rpc():
os.environ["MASTER_PORT"] = "10639"
init_method = f"tcp://{os.environ['MASTER_ADDR']}:{os.environ['MASTER_PORT']}"
rpc.init_rpc(
f"Test{torch.distributed.get_rank()}",
rank=torch.distributed.get_rank(),
world_size=torch.distributed.get_world_size(),
backend=rpc.BackendType.TENSORPIPE,
rpc_backend_options=rpc.TensorPipeRpcBackendOptions(init_method=init_method),
)
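# init_rpc above names each process "Test<rank>" and assumes MASTER_ADDR and the
# default process group were already set up by the spawning helper; only the port
# is chosen here. Once every rank has joined, ordinary RPC calls between those
# worker names also work. Hypothetical usage, not part of these tests:
def sketch_remote_call():
    # e.g. from rank 0, invoke a function on the worker named "Test1".
    return rpc.rpc_sync("Test1", torch.add, args=(torch.ones(2), torch.ones(2)))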
@torch_spawn([2])
@pytest.mark.skipif("OMPI_COMM_WORLD_RANK" not in os.environ, reason="mpi required")
def basic_rpc():
init_rpc()
if torch.distributed.get_rank() != 0:
rpc.shutdown()
torch.distributed.barrier()
return
model = [nn.Linear(10, 10), nn.ReLU()]
pipe = PipeRPCWrapper(model, [1, 1], input_device=torch.cuda.current_device(), worker_map=get_worker_map())
pipe.foreach_worker(register_optimizer, include_self=True)
inputs = torch.rand(10).cuda()
output = pipe(inputs)
loss = output.mean()
loss.backward()
pipe.foreach_worker(step_optimizer, include_self=True)
pipe.eval()
rpc.shutdown()
torch.distributed.barrier()
def register_optimizer(ctx, model):
if len(list(model.parameters())) > 0:
model.optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
else:
model.optimizer = None
def step_optimizer(ctx, model):
if model.optimizer:
model.optimizer.step()
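# register_optimizer and step_optimizer above are broadcast to every worker via
# pipe.foreach_worker(fn, include_self=True) and receive (ctx, model), where
# `model` is that worker's local partition. Any callback with the same shape can
# be dispatched this way; a hypothetical helper to clear gradients between steps:
def zero_grad_optimizer(ctx, model):
    for p in model.parameters():
        p.grad = None
# usage (hypothetical): pipe.foreach_worker(zero_grad_optimizer, include_self=True)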
def check_pipe_against_reference(balance, model_constructor, checkpoint="except_last", custom_inputs=None):
model = model_constructor()
reference_model = model_constructor()
for src, dst in zip(model, reference_model):
dst.load_state_dict(copy.deepcopy(src.state_dict()))
reference_model = nn.Sequential(*reference_model).cuda()
pipe = PipeRPCWrapper(
model, balance, input_device=torch.cuda.current_device(), worker_map=get_worker_map(), checkpoint=checkpoint,
)
pipe.foreach_worker(register_optimizer, include_self=True)
register_optimizer(None, reference_model)
inputs = torch.rand(10).cuda()
target = torch.rand(10).cuda()
cloned = inputs.clone()
output = pipe(inputs)
ref_out = reference_model(inputs)
assert torch.equal(ref_out.cpu(), output.cpu())
for out in output, ref_out:
target = target.to(out.device)
loss = nn.MSELoss()(out, target)
loss.backward()
pipe.foreach_worker(step_optimizer, include_self=True)
step_optimizer(None, reference_model.cuda())
pipe.eval()
reference_model.eval()
final_output = pipe(inputs)
final_ref = reference_model(inputs.cuda())
assert torch.equal(final_output.cpu(), final_ref.cpu())
@torch_spawn([3])
@pytest.mark.skipif("OMPI_COMM_WORLD_RANK" not in os.environ, reason="mpi required")
def rpc_optimizer():
init_rpc()
if torch.distributed.get_rank() != 0:
rpc.shutdown()
torch.distributed.barrier()
return
def model_with_reuse():
reused_1 = nn.Linear(10, 10)
return [reused_1, nn.ReLU(), reused_1, nn.ReLU(), reused_1, nn.ReLU()]
check_pipe_against_reference(
[2, 2, 2], lambda: [nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 10), nn.ReLU()],
)
check_pipe_against_reference([2, 1, 1], model_with_reuse)
rpc.shutdown()
torch.distributed.barrier()
@torch_spawn([6])
@pytest.mark.skipif("OMPI_COMM_WORLD_RANK" not in os.environ, reason="mpi required")
def rpc_megatron_reuse():
from fairscale.nn.model_parallel import layers
from fairscale.nn.model_parallel.initialize import destroy_model_parallel, initialize_model_parallel
def make_model_simple():
return [
layers.ColumnParallelLinear(10, 10),
nn.ReLU(),
layers.RowParallelLinear(10, 10),
nn.ReLU(),
layers.ColumnParallelLinear(10, 10),
nn.ReLU(),
layers.RowParallelLinear(10, 10),
nn.ReLU(),
nn.Linear(10, 10),
nn.ReLU(),
]
def make_model_with_reuse():
column = layers.ColumnParallelLinear(10, 10)
row = layers.RowParallelLinear(10, 10)
return [
column,
nn.ReLU(),
row,
nn.ReLU(),
column,
nn.ReLU(),
row,
nn.ReLU(),
nn.Linear(10, 10),
nn.ReLU(),
]
destroy_model_parallel()
torch.distributed.destroy_process_group()
torch.distributed.init_process_group("gloo", rank=int(os.environ["RANK"]), world_size=int(os.environ["WORLD_SIZE"]))
initialize_model_parallel(2, 3, model_parallel_backend="nccl", pipeline_backend="mpi")
init_rpc()
if get_pipeline_parallel_group().rank() != 0:
rpc.shutdown()
torch.distributed.barrier()
return
check_pipe_against_reference([4, 4, 2], make_model_simple, "always")
check_pipe_against_reference([4, 2, 2], make_model_with_reuse)
rpc.shutdown()
torch.distributed.barrier()
@torch_spawn([3])
@pytest.mark.skipif("OMPI_COMM_WORLD_RANK" not in os.environ, reason="mpi required")
def rpc_reuse_in_final_stage():
# 'reused' and 'reused2' are located on stage 2, so the backward pass for
# the final stage will need to first send gradients to stage 2, then receive
# gradients from stage 2. This tests custom logic to handle reuse of layers
# in the final stage of the pipeline.
reused = nn.Linear(10, 10)
reused2 = nn.Linear(10, 10)
model = [
nn.Linear(10, 10),
nn.ReLU(),
nn.Linear(10, 10),
reused2,
nn.ReLU(),
reused,
nn.ReLU(),
reused,
reused2,
nn.ReLU(),
reused,
nn.ReLU(),
]
balance = [2, 3, 4]
init_rpc()
if torch.distributed.get_rank() != 0:
rpc.shutdown()
torch.distributed.barrier()
return
pipe = PipeRPCWrapper(model, balance, worker_map=get_worker_map())
inputs = torch.rand(10).cuda()
target = torch.rand(10).cuda()
output = pipe(inputs)
nn.MSELoss()(output, target).backward()
output = pipe(inputs)
nn.MSELoss()(output, target).backward()
rpc.shutdown()
torch.distributed.barrier()
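# The comment at the top of rpc_reuse_in_final_stage describes gradients being
# exchanged with the stage that owns a reused layer. In single-process PyTorch
# the same effect is plain gradient accumulation on the shared parameters:
import torch
from torch import nn

def sketch_shared_param_grads():
    reused = nn.Linear(10, 10)
    x = torch.rand(10)

    reused(x).sum().backward()
    single_use = reused.weight.grad.clone()

    reused.weight.grad = None
    reused.bias.grad = None

    # Using the same layer twice in one graph accumulates both contributions.
    (reused(x) + reused(x)).sum().backward()
    assert torch.allclose(reused.weight.grad, 2 * single_use)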
@torch_spawn([3])
@pytest.mark.skipif("OMPI_COMM_WORLD_RANK" not in os.environ, reason="mpi required")
def rpc_multiple_tensors():
class FuseTwo(nn.Module):
def forward(self, left, right):
return left + right
class SplitTwo(nn.Module):
def forward(self, inputs):
return (inputs, 2 * inputs)
@torch_spawn([2])
@pytest.mark.skipif("OMPI_COMM_WORLD_RANK" in os.environ, reason="no mpi")
@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
def construct_only_rank_zero():
model = [nn.Linear(10, 10), nn.ReLU()]
if torch.distributed.get_rank() == 0:
PipeRPCWrapper(model, [1, 1], worker_map=get_worker_map())
rpc.shutdown()
else:
# Must enter the rpc loop to complete the PipeRPCWrapper constructor above
rpc.shutdown()
with pytest.raises(AssertionError):
PipeRPCWrapper(model, [1, 1], worker_map=get_worker_map())
@@ -27,7 +27,8 @@ from tests.nn.model_parallel.commons import get_worker_map, set_random_seed, tor
 @torch_spawn([2])
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
-def simple_linears():
+@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
+def simple_linears(pipeline_style):
     def sum_grad(parameters):
         return sum([p.grad.sum() for p in parameters if p.grad is not None])

@@ -54,19 +55,19 @@ def simple_linears():
     zero_grad(model.parameters())

     # With Pipe
-    model = Pipe(model, [2, 2], style=Pipe.MultiProcess, worker_map=get_worker_map(), chunks=4)
+    model = Pipe(model, [2, 2], style=pipeline_style, worker_map=get_worker_map(), chunks=4)
     outputs = model(inputs)
     if model.group.rank() == 1:
         loss = outputs.mean()
         loss.backward()
-        grad_with_pipe = sum_grad(model.pipeline.partitions[0].parameters())
+        grad_with_pipe = sum_grad(model.pipeline.mp_partitions[0].module.parameters())

         # Both grads should be identical.
         assert torch.allclose(grad_with_pipe, grad_without_pipe[1])
     else:
         model.back_helper(outputs)
-        grad_with_pipe = sum_grad(model.pipeline.partitions[0].parameters())
+        grad_with_pipe = sum_grad(model.pipeline.mp_partitions[0].module.parameters())

         # Both grads should be identical.
         assert torch.allclose(grad_with_pipe, grad_without_pipe[0])
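# simple_linears compares the summed gradients of a partition trained through Pipe
# with the same layers trained directly. Stripped of Pipe, the comparison pattern
# reduces to summing per-parameter gradients so two runs can be checked with a
# single allclose:
import torch
from torch import nn

def sketch_grad_comparison():
    def sum_grad(parameters):
        return sum([p.grad.sum() for p in parameters if p.grad is not None])

    model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8))
    inputs = torch.rand(4, 8)

    model(inputs).mean().backward()
    reference = sum_grad(model.parameters())

    # Clearing gradients and repeating the identical computation reproduces the sum.
    for p in model.parameters():
        p.grad = None
    model(inputs).mean().backward()
    assert torch.allclose(reference, sum_grad(model.parameters()))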