Unverified Commit 2d3d5a7b authored by msbaines, committed by GitHub

[feat] remove old MultiProcessPipe (#563)

parent e141a93e
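
With MultiProcessPipe gone, only the single-process path built on fairscale.nn.Pipe remains in the benchmark. A minimal sketch of that surviving usage is given below; it assumes the torchgpipe-style Pipe signature (balance, devices, chunks, checkpoint) and the pipe.devices attribute of this fairscale release, and the toy model is illustrative rather than taken from the diff.

import torch
import torch.nn as nn

from fairscale.nn import Pipe

# Illustrative four-layer model; the real benchmark builds a wikitext-2 transformer.
model = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 16), nn.ReLU())

# Two partitions of two layers each; fall back to CPU placement when fewer than two GPUs are visible.
devices = [0, 1] if torch.cuda.device_count() >= 2 else ["cpu", "cpu"]
pipe = Pipe(model, balance=[2, 2], devices=devices, chunks=4, checkpoint="except_last")

x = torch.randn(8, 16).to(pipe.devices[0])
out = pipe(x)  # the output lands on pipe.devices[-1]
print(out.shape)
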
@@ -197,12 +197,6 @@ run_pipe_benchmark: &run_pipe_benchmark
      command: |
        python benchmarks/pipe.py
-run_mp_pipe_benchmark: &run_mp_pipe_benchmark
-  - run:
-      name: Run Multiprocess Pipe Benchmark
-      command: |
-        python benchmarks/pipe.py --multiprocess --lazy-construction
run_oss_benchmark: &run_oss_benchmark
  - run:
      name: Run OSS Benchmark
@@ -578,8 +572,6 @@ jobs:
      - <<: *run_pipe_benchmark
-      - <<: *run_mp_pipe_benchmark
      - <<: *run_oss_amp
      - <<: *run_oss_for_each
...
@@ -80,19 +80,12 @@ class Pipe:
        "criterion": nn.CrossEntropyLoss(),
    }
-def get_golden_real_stats(multiprocess=False):
-    if not multiprocess:
-        return {
-            "avg_wps": 703.778,
-            "std_dev_wps": 5.732,
-            "peak_mem_usage": [2320996352, 1396742144, 1396742144, 2340010496],
-        }
-    else:
-        return {
-            "avg_wps": 647.404,
-            "std_dev_wps": 14.51,
-            "peak_mem_usage": [3305007616, 2578692608, 3304524288, 2578692608],
-        }
+def get_golden_real_stats():
+    return {
+        "avg_wps": 703.778,
+        "std_dev_wps": 5.732,
+        "peak_mem_usage": [2320996352, 1396742144, 1396742144, 2340010496],
+    }
def get_golden_synthetic_stats():
    # TODO(anj-s): Add support for synthetic regression benchmarks
...
@@ -16,16 +16,13 @@ import numpy as np
import torch
import torch.distributed as dist
from torch.distributed import rpc
-import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Adam
from benchmarks.golden_configs.lm_wikitext2 import Pipe as lm_wikitext2
from fairscale.nn import Pipe
from fairscale.nn.model_parallel import initialize_model_parallel
-from fairscale.nn.model_parallel.initialize import get_pipeline_parallel_group
-from fairscale.nn.pipe import LazyModule, MultiProcessPipe
-from fairscale.utils.testing import dist_init, get_worker_map
+from fairscale.utils.testing import dist_init
MPI_PORT = 29500
RPC_PORT = 29501
@@ -211,7 +208,7 @@ def train(model_config, model, benchmark_config, model_specs, args):
        if i % log_interval == 0 and i > 0:
            cur_loss = total_loss / log_interval
            elapsed = time.time() - start_time
-            if not args.multiprocess or dist.get_rank() == dist.get_world_size() - 1:
+            if dist.get_rank() == dist.get_world_size() - 1:
                logging.debug(
                    "| batch {:5d} | wps {:5.2f} | loss {:5.2f} | ppl {:8.2f}".format(
                        i, total_tokens_per_log_interval / elapsed, cur_loss, math.exp(cur_loss)
@@ -227,7 +224,7 @@ def train(model_config, model, benchmark_config, model_specs, args):
        raise RuntimeError(
            "Unable to benchmark on a single batch. Increase the size " " of the dataset and rerun the benchmark."
        )
-    if not args.multiprocess or dist.get_rank() == dist.get_world_size() - 1:
+    if dist.get_rank() == dist.get_world_size() - 1:
        return wps, loss.item()
    else:
        return 0.0, 0.0
@@ -276,8 +273,6 @@ def verify_peak_memory(rank, golden_config, std_dev):
def verify_lm_run(wps, golden_config, args):
    """Verify that words per second for a given benchmark run matches the golden data."""
-    # Verify wps only on the last rank in multiprocess pipe
-    if not args.multiprocess or dist.get_rank() == dist.get_world_size() - 1:
+    if dist.get_rank() == dist.get_world_size() - 1:
        # Assert that words per second is within 3 standard deviations of the average
        # of five golden runs
        logging.info("Throughput(wps) is {:.2f}.".format(wps))
@@ -289,11 +285,8 @@ def verify_lm_run(wps, golden_config, args):
            )
        )
-    if args.multiprocess:
-        verify_peak_memory(dist.get_rank(), golden_config, 1.5)
-    else:
-        for i in range(4):
-            verify_peak_memory(i, golden_config, 1.1)
+    for i in range(4):
+        verify_peak_memory(i, golden_config, 1.1)
def benchmark_language_model(model_config, model, benchmark_config, model_specs, args):
@@ -400,7 +393,7 @@ def get_golden_config(model_name, args):
    """Return a dict with the golden data for throughput and memory usage."""
    if model_name == "lm":
-        return lm_wikitext2.get_golden_real_stats(args.multiprocess)
+        return lm_wikitext2.get_golden_real_stats()
    else:
        raise RuntimeError("Unrecognized args.model_mame " % args.model_name)
@@ -431,32 +424,6 @@ def benchmark_single_process(args):
    benchmark_language_model(model_config, pipe_model, benchmark_config, model_specs, args)
def run_mp_worker(args, available_workers):
benchmark_config = create_benchmark_config(args.model_name)
model_specs = get_model_specs(args.model_name)
model_config = create_model_config(args, benchmark_config=benchmark_config, model_specs=model_specs)
model = model_config["model"]
balance = generate_balance(get_pipeline_parallel_group().size(), len(model))
pipe_model = MultiProcessPipe(
model,
balance,
chunks=args.chunks,
worker_map=get_worker_map(),
input_device=torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"),
checkpoint=args.checkpoint,
# TODO(anj-s): Do we need to comment this out? loss_fn=benchmark_config["criterion"],
)
if torch.cuda.is_available():
pipe_model = pipe_model.cuda()
if args.dry_run:
train(model_config, pipe_model, benchmark_config, model_specs, args)
else:
benchmark_language_model(model_config, pipe_model, benchmark_config, model_specs, args)
def run_worker(rank, world_size, args):
    if args.world_size != 0:
        world_size = args.world_size
@@ -469,35 +436,7 @@ def run_worker(rank, world_size, args):
    torch.distributed.destroy_process_group()
def benchmark_multiprocess(rank, world_size, args):
init_method_pgroup = "tcp://localhost:{}".format(MPI_PORT)
# TODO(anj-s): Add regression benchmarks for nccl as well.
torch.distributed.init_process_group(
backend="gloo", rank=rank, world_size=world_size, init_method=init_method_pgroup
)
torch.cuda.set_device(rank % torch.cuda.device_count())
# TODO(anj-s): Move to TensorPipeRpcBackendOptions.
rpc.init_rpc(
f"Test{rank}",
rank=rank,
world_size=world_size,
backend=rpc.BackendType.PROCESS_GROUP,
rpc_backend_options=rpc.ProcessGroupRpcBackendOptions(
rpc_timeout=20, init_method="tcp://localhost:{}".format(RPC_PORT)
),
)
initialize_model_parallel(1, world_size)
init_random_seed(0)
run_mp_worker(args, world_size)
rpc.shutdown()
torch.distributed.destroy_process_group()
parser = argparse.ArgumentParser(description="benchmark")
-parser.add_argument("--multiprocess", action="store_true", help="Runs single process benchmarks.")
parser.add_argument("--host", "-o", type=str, default="localhost", help="hostname")
parser.add_argument("--chunks", type=int, default=1, help="number of microbatches per batch")
parser.add_argument("--batch-size", type=int, default=8, help="size of a batch")
@@ -522,10 +461,5 @@ if __name__ == "__main__":
    args = parser.parse_args()
    logging.basicConfig(level=logging.INFO if not args.debug else logging.DEBUG)
-    if not args.multiprocess:
-        logging.info(f"Running single process benchmark with args: {args}")
-        benchmark_single_process(args)
-    else:
-        world_size = max(torch.cuda.device_count(), 1)
-        logging.info(f"Running multiprocess benchmark with args: {args}")
-        mp.spawn(benchmark_multiprocess, args=(world_size, args), nprocs=world_size, join=True)
+    logging.info(f"Running single process benchmark with args: {args}")
+    benchmark_single_process(args)
import os
from helpers import dist_init, get_data, get_loss_fun, get_model
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.optim as optim
from fairscale.nn.model_parallel import initialize_model_parallel
from fairscale.nn.pipe import MultiProcessPipe
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
def run(rank, world_size):
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "10638"
dist_init(rank, world_size)
os.environ["MASTER_PORT"] = "10639"
dist.rpc.init_rpc(f"worker{rank}", rank=rank, world_size=world_size)
initialize_model_parallel(1, world_size)
model = get_model()
data, target = get_data()[0]
loss_fn = get_loss_fun()
device = torch.device("cuda", rank) if DEVICE == "cuda" else torch.device("cpu")
model = MultiProcessPipe(
model,
balance=[2, 1],
worker_map={0: "worker0", 1: "worker1"}, # Needed to convert ranks to RPC worker names
input_device=device,
).to(device)
# define optimizer and loss function
optimizer = optim.SGD(model.parameters(), lr=0.001)
# zero the parameter gradients
optimizer.zero_grad()
# outputs and target need to be on the same device
# forward step
outputs = model(data.to(device))
# compute loss
if rank == 1:
loss = loss_fn(outputs.to(device), target.to(device))
# backward + optimize
loss.backward()
optimizer.step()
else:
model.back_helper(outputs)
print(f"Finished Training Step on {rank}")
dist.rpc.shutdown()
del model
if __name__ == "__main__":
world_size = 2
mp.spawn(run, args=(world_size,), nprocs=world_size, join=True)
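
Since the tutorial above is deleted by this commit, a rough single-process equivalent built on the remaining fairscale.nn.Pipe could look like the sketch below. It reuses the same helpers module as the removed example and assumes the torchgpipe-style balance/devices/chunks arguments; treat it as an illustration, not an official replacement tutorial.

import torch
import torch.optim as optim

from fairscale.nn import Pipe
from helpers import get_data, get_loss_fun, get_model

model = get_model()              # an nn.Sequential, as in the removed example
data, target = get_data()[0]
loss_fn = get_loss_fun()

# One process drives both partitions, so no RPC setup or worker_map is needed.
devices = [0, 1] if torch.cuda.device_count() >= 2 else ["cpu", "cpu"]
model = Pipe(model, balance=[2, 1], devices=devices, chunks=1)

optimizer = optim.SGD(model.parameters(), lr=0.001)
optimizer.zero_grad()

outputs = model(data.to(model.devices[0]))
# The last partition's output is returned directly, so there is no back_helper():
# the usual loss / backward / step sequence applies on this single process.
loss = loss_fn(outputs, target.to(outputs.device))
loss.backward()
optimizer.step()
print("Finished training step")
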
@@ -20,7 +20,6 @@
"""A Pipe implementation in PyTorch."""
from .async_pipe import AsyncPipe
from .checkpoint import is_checkpointing, is_recomputing
-from .multiprocess_pipe import LazyModule, MultiProcessPipe
from .pipe import Pipe
from .rpc import PipeRPCWrapper
...
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
# Copyright 2019 Kakao Brain
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The MultiProcessPipe interface."""
from collections import OrderedDict
import threading
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, Union
import warnings
import torch
from torch import Tensor, nn
import torch.autograd
import torch.cuda
from fairscale.nn.model_parallel import get_model_parallel_world_size, get_pipeline_parallel_group
from . import microbatch
from .batchnorm import DeferredBatchNorm
from .multiprocess_pipeline import MultiProcessPipeline
from .phony import get_phony
from .skip.layout import SkipLayout
from .types import LazyModule
__all__ = ["MultiProcessPipe", "LazyModule"]
Tensors = Tuple[Tensor, ...]
TensorOrTensors = Union[Tensor, Tensors]
if TYPE_CHECKING:
Module = nn.Module[TensorOrTensors]
NamedModules = OrderedDict[str, Module]
else:
Module = nn.Module
NamedModules = OrderedDict
def verify_module(module: Union[nn.Sequential, List[LazyModule]]) -> None:
if len(set(map(id, module))) != len(module):
raise ValueError("module with duplicate children is not supported")
def check_balance(module: Union[nn.Sequential, List[LazyModule]], balance: List[int]) -> None:
if len(module) != sum(balance):
raise ValueError(
f"module and sum of balance have different length (module: {len(module)}, sum of balance: {sum(balance)})"
)
if any(x <= 0 for x in balance):
raise ValueError(f"all balance numbers must be positive integer (balance: {balance})")
MOVING_DENIED = TypeError("denied to move parameters and buffers, because Pipe should manage device placement")
class MultiProcessPipe(Module):
"""Wraps an arbitrary :class:`nn.Sequential <torch.nn.Sequential>` module
to train on Pipe_. If the module requires lots of memory, Pipe will be
very efficient.
::
model = nn.Sequential(a, b, c, d)
model = Pipe(model, balance=[1, 1, 1, 1], chunks=8)
output = model(input)
.. _Pipe: https://arxiv.org/abs/1811.06965
Pipe combines pipeline parallelism with checkpointing to reduce peak
memory required to train while minimizing device under-utilization.
You should determine the balance when defining a :class:`Pipe` module, as
balancing will not be done automatically. The module will be partitioned
into multiple devices according to the given balance. You may rely on
heuristics to find your own optimal configuration.
Args:
module (torch.nn.Sequential):
sequential module to be parallelized
balance (ints):
list of number of layers in each partition
Keyword Args:
group (ProcessGroup):
the process group that all
pipeline stages are a member of. Defaults to
`get_pipeline_parallel_group()`
worker_map (Dict[int, str]):
a map from worker name (the first argument to
`torch.distributed.rpc.init_rpc`) to global rank (i.e.
`torch.distributed.get_rank()`) needed in order for pipeline stages
to communicate with each other
input_device (device):
the device on which tensors should be located before being passed to
the first module in a given pipeline stage
chunks (int):
number of micro-batches (default: ``1``)
checkpoint (str):
when to enable checkpointing, one of ``'always'``,
``'except_last'``, or ``'never'`` (default: ``'except_last'``)
deferred_batch_norm (bool):
whether to use deferred BatchNorm moving statistics (default:
:data:`False`, see :class:`DeferredBatchNorm` for more
details)
Raises:
TypeError:
the module is not a :class:`nn.Sequential <torch.nn.Sequential>`.
ValueError:
invalid arguments, or wrong balance
IndexError:
the number of devices is fewer than the number of partitions.
"""
#: The number of layers in each partition.
balance: List[int] = []
# ^^
# The default value [] required for Sphinx's autoattribute.
#: The devices mapped to each partition.
#:
#: ``devices[-1]`` refers to the device of the last partition, which means
#: it is the output device. Probably, you need to use it to transfer the
#: target to calculate the loss without a device mismatch
#: :exc:`RuntimeError`. For example::
#:
#: out_device = pipe.devices[-1]
#:
#: for input, target in loader:
#: target = target.to(out_device, non_blocking=True)
#: output = pipe(input)
#: loss = F.cross_entropy(output, target)
#:
#: The number of micro-batches.
chunks: int = 1
#: The checkpoint mode to determine when to enable checkpointing. It is one
#: of ``'always'``, ``'except_last'``, or ``'never'``.
checkpoint: str = "except_last"
def __init__(
self,
module: Union[nn.Sequential, List[LazyModule]],
balance: Iterable[int],
*,
group: Optional[torch.distributed.ProcessGroup] = None,
worker_map: Optional[Dict[int, str]] = None,
input_device: Union[None, int, str, torch.device] = None,
chunks: int = chunks,
checkpoint: str = checkpoint,
deferred_batch_norm: bool = False,
) -> None:
super().__init__()
if chunks <= 0:
raise ValueError("number of chunks must be positive integer")
if checkpoint not in ["always", "except_last", "never"]:
raise ValueError("checkpoint is not one of 'always', 'except_last', or 'never'")
if get_model_parallel_world_size() > 1:
self.pipelined_backward = True
else:
self.pipelined_backward = False
self.balance = list(balance)
verify_module(module)
check_balance(module, self.balance)
self.chunks = chunks
self.checkpoint = checkpoint
self.pipeline: Optional[MultiProcessPipeline]
self.lock = threading.Lock()
self.worker_map = worker_map
self.input_device = input_device
self.group: torch.distributed.ProcessGroup
if group is None:
self.group = get_pipeline_parallel_group()
else:
self.group = group
if self.group.size() < len(self.balance):
raise IndexError(
f"too few ranks to hold given partitions (ranks: {self.group.size()}, partitions:"
f" {len(self.balance)})"
)
self._skip_layout = SkipLayout(len(module), {}) # FIXME(tom)
rank = self.group.rank()
self.final_stage = rank == len(self.balance) - 1
if rank >= len(self.balance):
warnings.warn("More ranks than partitions, some ranks unused")
self.partition = nn.Sequential()
self.pipeline = None
else:
self.partition = self.instantiate_partition(module, self.balance, self.group)
if deferred_batch_norm:
self.partitition = DeferredBatchNorm.convert_deferred_batch_norm(self.partition, chunks)
self.add_module(str(0), self.partition)
self.create_pipeline()
del module
def create_pipeline(self) -> None:
# The micro-batch index where the checkpointing stops.
checkpoint_stop = {"always": self.chunks, "except_last": self.chunks - 1, "never": 0}[self.checkpoint]
self.pipeline = MultiProcessPipeline(
self.partition,
self._skip_layout,
checkpoint_stop,
group=self.group,
worker_map=self.worker_map,
input_device=self.input_device,
final_stage=self.final_stage,
)
def instantiate_partition(
self, module: Union[nn.Sequential, List[LazyModule]], balance: List[int], group: torch.distributed.ProcessGroup,
) -> nn.Sequential:
rank = group.rank()
first_layer = sum(balance[:rank])
num_layers = balance[rank]
layers = module[first_layer : first_layer + num_layers]
instantiated_layers = [l if isinstance(l, nn.Module) else l() for l in layers]
return nn.Sequential(*instantiated_layers)
def __len__(self) -> int:
"""Counts the length of the underlying sequential module."""
return self.partition.__len__()
def __getitem__(self, index: int) -> nn.Module:
"""Gets a layer in the underlying sequential module."""
return self.partition.__getitem__(index)
def __iter__(self) -> Iterable[nn.Module]:
"""Iterates over children of the underlying sequential module."""
return self.partition.__iter__()
def forward(self, input: TensorOrTensors) -> TensorOrTensors: # type: ignore
""":class:`MultiProcessPipe` is a fairly transparent module wrapper. It doesn't
modify the input and output signature of the underlying module. But
there's type restriction. Input and output have to be a
:class:`~torch.Tensor` or a tuple of tensors. This restriction is
applied at partition boundaries too.
Args:
input (torch.Tensor or tensors): input mini-batch
Returns:
tensor or tensors: output mini-batch
Raises:
TypeError: input is not a tensor or tensors.
"""
microbatch.check(input)
if not self.pipeline:
# No pipeline is not illegal, more ranks than partitions
return input
# Divide a mini-batch into micro-batches.
batches = microbatch.scatter(input, self.chunks)
# Run pipeline parallelism.
with self.lock:
self.pipeline.run(self.training, batches)
if self.final_stage:
# Merge the micro-batches into one mini-batch.
if self.pipelined_backward:
with torch.no_grad():
output = microbatch.gather(batches)
phony = get_phony(
torch.device(torch.cuda.current_device() if torch.cuda.is_available() else "cpu"),
requires_grad=True,
)
output = PipelinedBackwardPass.apply(output, batches, phony)
else:
output = microbatch.gather(batches)
else:
# Don't merge micro-batches to avoid unnecessary edges in autograd
# graph
# FIXME(tom) should figure out a proper type here
output = batches # type: ignore
return output
def back_helper(self, output: List[microbatch.Batch]) -> None:
if self.final_stage:
raise ValueError("back_helper should only be called on non-final stages")
if self.pipeline:
self.pipeline.back_helper(output)
class PipelinedBackwardPass(torch.autograd.Function):
@staticmethod
# type: ignore
def forward(ctx, input: TensorOrTensors, batches, phony) -> TensorOrTensors:
ctx.batches = batches
return input
@staticmethod
# type: ignore
def backward(ctx, *grads) -> Tuple:
with torch.no_grad():
grad_batches = microbatch.scatter(grads, len(ctx.batches))
for grad, batch in reversed(list(zip(grad_batches, ctx.batches))):
for t in batch:
t.retain_grad()
torch.autograd.backward(batch.tensor_or_tensors, grad_tensors=(*grad,), retain_graph=True)
with torch.no_grad():
if ctx.batches[0].atomic:
tensors = tuple(b.tensor.grad for b in ctx.batches)
output: TensorOrTensors = torch.cat(tensors)
else:
rotated = [[t.grad for t in b.tensors] for b in ctx.batches]
output_buf = []
for tensors in zip(*rotated):
output_buf.append(torch.cat(tensors))
output = tuple(output_buf)
del ctx.batches
return (output, None, None, None)
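
PipelinedBackwardPass above is an instance of a small autograd pattern: an identity Function that takes an extra phony tensor with requires_grad=True, so the node is guaranteed to sit on the backward path and custom per-micro-batch work can run inside backward(). A stripped-down, self-contained illustration of that pattern (not fairscale code, just a sketch) follows.

import torch


class BackwardHook(torch.autograd.Function):
    """Identity in forward; runs extra work in backward, like PipelinedBackwardPass."""

    @staticmethod
    def forward(ctx, output, phony):
        # Detach to avoid aliasing the input; gradients flow through our backward().
        return output.detach()

    @staticmethod
    def backward(ctx, grad):
        # In PipelinedBackwardPass, this is where the incoming grad is scattered
        # into per-micro-batch grads and torch.autograd.backward() is called on
        # each chunk in reverse order.
        print("custom backward work runs here")
        return grad, None


x = torch.randn(4, requires_grad=True)
phony = torch.empty(0, requires_grad=True)  # only exists to keep this node in the graph
y = BackwardHook.apply(2 * x, phony)
y.sum().backward()
print(x.grad)  # tensor([2., 2., 2., 2.])
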
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
# Copyright 2019 Kakao Brain
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The multiprocess pipeline parallelism of Pipe."""
import os
from queue import Empty as QueueEmpty
from queue import Queue
from types import TracebackType
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Type, Union
import torch
from torch import Tensor, nn
from torch.autograd.profiler import record_function
from fairscale.nn.model_parallel import get_pipeline_parallel_ranks
from .checkpoint import Checkpointing
from .messages import MakeTransport, Transport
from .microbatch import Batch
from .skip import Namespace
from .skip.layout import SkipLayout
from .skip.tracker import SkipTrackerThroughPotals, use_skip_tracker
from .types import ACTIVATIONS_GRADS_QUEUE, PORTAL_QUEUE, SKIP_TENSOR_QUEUE, PipeMessage, TensorOrTensors, Tensors
from .worker import Task
# Queue is generic only in stubs.
# https://mypy.readthedocs.io/en/latest/common_issues.html#using-classes-that-are-generic-in-stubs-but-not-at-runtime
if TYPE_CHECKING:
InQueue = Queue[Optional[Task]]
OutQueue = Queue[Tuple[bool, Union[Tuple[Task, Batch], ExcInfo, None]]]
else:
InQueue = Queue
OutQueue = Queue
__all__: List[str] = []
ExcInfo = Tuple[Type[BaseException], BaseException, TracebackType]
class SendOperator(torch.autograd.Function):
"""Send activations to the next pipeline stage"""
@staticmethod
# type: ignore
def forward(ctx, transport: Transport, input: List[Tensor], index: int) -> Tensors:
ranks = get_pipeline_parallel_ranks()
src_rank = torch.distributed.get_rank()
dst_rank = ranks[ranks.index(src_rank) + 1]
transport.send_message(
PipeMessage(src_rank, dst_rank, queue_name=ACTIVATIONS_GRADS_QUEUE, args=index, tensors=tuple(input)),
)
return ()
@staticmethod
# type: ignore
def backward(ctx, *grad: Tensor,) -> Tensors:
return tuple(grad)
class RecvOperator(torch.autograd.Function):
"""Receive activations to the previous pipeline stage"""
@staticmethod
# type: ignore
def forward(ctx, tensor: Tensor, transport: Transport, index: int) -> Tensors:
ctx.transport = transport
ctx.index = index
result = transport.get_out_of_order(ACTIVATIONS_GRADS_QUEUE, index)
def maybe_requires_grad(t: Tensor) -> Tensor:
if t.dtype.is_floating_point:
return t.requires_grad_()
return t
return tuple(maybe_requires_grad(r) for r in result)
@staticmethod
# type: ignore
def backward(ctx, *grad: Tensor,) -> Tuple[Optional[Tensor], ...]:
ranks = get_pipeline_parallel_ranks()
src_rank = torch.distributed.get_rank()
dst_rank = ranks[ranks.index(src_rank) - 1]
ctx.transport.send_message(
PipeMessage(src_rank, dst_rank, queue_name=ACTIVATIONS_GRADS_QUEUE, args=ctx.index, tensors=tuple(grad),),
)
return (None, None, None, None)
class MultiProcessPipeline:
"""The multiprocess pipeline parallelism for Pipe."""
def __init__(
self,
partition: nn.Sequential,
skip_layout: SkipLayout,
checkpoint_stop: int,
group: torch.distributed.ProcessGroup,
*,
worker_map: Optional[Dict[int, str]] = None,
input_device: Union[None, int, str, torch.device] = None,
final_stage: bool = False,
) -> None:
self.partition = partition
self.skip_layout = skip_layout
self.__checkpoint_stop = checkpoint_stop
self.group = group
self.training: bool
self.transport = MakeTransport(
use_rpc=("OMPI_COMM_WORLD_RANK" not in os.environ) or ("FORCE_RPC" in os.environ),
worker_map=worker_map,
input_device=input_device,
)
self.input_device = input_device
self.final_stage = final_stage
@property
def checkpoint_stop(self) -> int:
# Disable checkpointing if in eval mode.
training = self.partition.training
if not training:
return 0
return self.__checkpoint_stop
def run(self, training: bool, batches: List[Batch]) -> None:
"""Runs pipeline parallelism.
It modifies the given batches in place.
"""
self.training = training
m = len(batches)
skip_trackers = [SkipTrackerThroughPotals(self.skip_layout, i) for i in range(m)]
rank = self.group.rank()
for i in range(m):
if rank != 0:
batch = self.get_batch_from_previous_stage(i, skip_trackers, batches)
else:
batch = batches[i]
with use_skip_tracker(skip_trackers[i]), record_function("chunk%d-part%d" % (i, rank)):
if i < self.checkpoint_stop:
chk = Checkpointing(self.partition, batch)
batch = chk.checkpoint()
else:
batch = batch.call(self.partition)
if not self.final_stage:
self.send_skip_tensors(batch, i, skip_trackers)
SendOperator.apply(self.transport, [*batch], i)
for portal in skip_trackers[i].portals.values():
portal.pipeline = self
if i < self.checkpoint_stop:
chk.recompute(batch)
batches[i] = batch
def get_batch_from_previous_stage(
self, i: int, skip_trackers: List[SkipTrackerThroughPotals], batches: List[Batch]
) -> Batch:
phony = torch.empty(0, device=self.input_device, requires_grad=True)
result = RecvOperator.apply(phony, self.transport, i)
if len(result) == 1:
batch = Batch(result[0], i)
else:
batch = Batch(result, i)
self.recv_skip_tensors(skip_trackers, batches)
return batch
def send_skip_tensors(self, batch: Batch, i: int, skip_trackers: List[SkipTrackerThroughPotals]) -> None:
ranks = get_pipeline_parallel_ranks()
this_rank = torch.distributed.get_rank()
for next_j, ns, name in self.skip_layout.copy_policy_by_src(self.group.rank()):
life = skip_trackers[i].portals[(ns, name)].tensor_life
loaded = skip_trackers[i].load(batch, ns, name)
if loaded is not None:
tensors = tuple([loaded])
else:
tensors = tuple()
self.transport.send_message(
PipeMessage(
this_rank, ranks[next_j], queue_name=SKIP_TENSOR_QUEUE, args=(i, ns, name, life), tensors=tensors,
),
sync=True,
)
def recv_skip_tensors(self, skip_trackers: List[SkipTrackerThroughPotals], batches: List[Batch]) -> None:
while True:
try:
message = self.transport.recv_message(SKIP_TENSOR_QUEUE, nowait=True)
(si, ns, name, life) = message.args
value: Optional[TensorOrTensors] = message.tensors
assert isinstance(value, tuple)
if len(value) == 0:
value = None
else:
assert len(value) == 1
value = value[0]
skip_trackers[si].save(batches[si], ns, name, value)
old_life = skip_trackers[si].portals[(ns, name)].tensor_life
if life != 0:
skip_trackers[si].portals[(ns, name)].tensor_life = life
except QueueEmpty:
break
def send_portal_grad(self, ns_name: Tuple[Namespace, str], index: int, grad: TensorOrTensors) -> None:
dest, src = self.skip_layout.by_ns_name.get(ns_name, (-1, -1))
if dest == src:
return
ranks = get_pipeline_parallel_ranks()
dst_rank = ranks[dest]
if dst_rank == torch.distributed.get_rank():
return
if isinstance(grad, Tensor):
grad = tuple([grad])
self.transport.send_message(
PipeMessage(ranks[src], dst_rank, queue_name=PORTAL_QUEUE, args=(ns_name, index), tensors=grad), sync=True,
)
def recv_portal_grad(self, expected_ns_name: Tuple[Namespace, str], expected_index: int) -> Tensor:
message = self.transport.recv_message(PORTAL_QUEUE)
(ns_name, index) = message.args
grad = message.tensors
assert len(grad) == 1
result = grad[0]
assert index == expected_index and ns_name == expected_ns_name
return result
def back_helper(self, output: List[Batch]) -> None:
tensors: Tensors
rank = torch.distributed.get_rank()
for batch in reversed(output):
found = self.transport.get_out_of_order(ACTIVATIONS_GRADS_QUEUE, batch.index)
if batch.atomic:
tensors = tuple([batch.tensor])
else:
tensors = batch.tensors
if len(found) != len(tensors):
raise RuntimeError("different number of tensors and gradients")
grads = []
final_tensors = []
for i, tensor in enumerate(tensors):
if tensor.requires_grad or getattr(tensor, "grad_fn", None) is not None:
grads.append(found[i])
final_tensors.append(tensor)
try:
torch.autograd.backward(final_tensors, grad_tensors=grads, retain_graph=True)
except Exception as e:
raise RuntimeError(f"Autograd failed on {torch.distributed.get_rank()}") from e
@@ -20,19 +20,15 @@
# limitations under the License.
import os
-import tempfile
import pytest
import torch
-from torch import nn
-from torch.distributed import rpc
import torch.nn.init as init
from torch.nn.parameter import Parameter
from fairscale.nn.model_parallel import initialize as mpu
from fairscale.nn.model_parallel import layers
-from fairscale.nn.pipe import MultiProcessPipe
-from fairscale.utils.testing import dist_init, get_world_sizes, set_random_seed, spawn_for_all_world_sizes, torch_spawn
+from fairscale.utils.testing import dist_init, set_random_seed, spawn_for_all_world_sizes
def run_test_parallel_embedding(rank, model_parallel_size, filename, filename_rpc):
@@ -302,241 +298,6 @@ def run_test_row_parallel_linear(rank, model_parallel_size, filename, filename_rpc):
    print(" >> passed the test :-)")
def run_test_pipe(rank, world_size, filename, filename_rpc, skip_dist_init=False):
pipe_world_size = 2
if world_size == 1:
return
if not skip_dist_init:
dist_init(rank, world_size, filename, filename_rpc)
else:
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "29502"
rpc.init_rpc(f"Test{rank}", rank=rank, world_size=world_size)
mpu.initialize_model_parallel(world_size / pipe_world_size, pipe_world_size)
model_parallel_size = mpu.get_model_parallel_world_size()
if torch.distributed.get_rank() == 0:
print(
"> testing Sequential + MultiProcessPipe with model parallel size: {}, pipe: {}".format(
model_parallel_size, pipe_world_size
)
)
chunk_size = 4
seed = 12345
set_random_seed(seed)
input_size_coeff = 3
input_size = input_size_coeff * model_parallel_size
output_size_coeff = 7
output_size = output_size_coeff * model_parallel_size
batch_size = 3 * chunk_size
target = torch.rand((batch_size, input_size), requires_grad=True).cuda()
print(f"target = {target}")
identity = IdentityLayer2D(batch_size, input_size).cuda()
pipeline_devices = mpu.get_pipeline_parallel_group()
set_random_seed(seed)
model = nn.Sequential(
layers.ColumnParallelLinear(input_size, output_size, keep_master_weight_for_test=True, bias=False).cuda(),
nn.ReLU(),
layers.RowParallelLinear(output_size, input_size, keep_master_weight_for_test=True, bias=False).cuda(),
)
set_random_seed(seed)
reference = [
nn.Linear(input_size, output_size, bias=False).cuda(),
nn.ReLU(),
nn.Linear(output_size, input_size, bias=False).cuda(),
]
print(f"setup {reference[0].weight.size()}, {model[0].weight.size()}, {(input_size, output_size)}")
print(f"setup {reference[2].weight.size()}, {(output_size, input_size)}")
reference[0].weight = Parameter(model[0].get_master_weight().clone()).cuda()
reference[2].weight = Parameter(model[2].get_master_weight().clone()).cuda()
reference = nn.Sequential(*reference)
def grad_graph(depth, grad):
result = depth * " " + str(grad)
if grad:
for x in grad.next_functions:
result += "\n" + grad_graph(depth + 1, x[0])
return result
def check_weights(x, y, key: str, index=None):
for i in [2, 0]:
if index is not None and i != index:
continue
left = x[i].get_master_weight()
right = y[i].weight.data
if not torch.allclose(left, right, atol=1.0e-6) or index is not None:
print(f"check_weights {key}-{i}: left = {left}, \nright = {right}")
if not torch.equal(left, right):
print(f"check_weights NOT_EQUAL {key}-{i}: left = {left}, \nright = {right}")
assert torch.allclose(left, right, atol=1.0e-6)
def dump_opt_params(opt):
for i, group in enumerate(opt.param_groups):
for j, p in enumerate(group["params"]):
print(f"{torch.distributed.get_rank()}:param {(i,j)} = {p}")
print(f"{torch.distributed.get_rank()}:param.grad {(i,j)} = {p.grad}")
def forward_model(model_, target, step=False):
optimizer = torch.optim.SGD(model_.parameters(), lr=0.01, momentum=0.9)
optimizer.zero_grad()
model_.zero_grad()
output = model_(identity())
loss = nn.MSELoss()
model_.zero_grad()
if step:
loss(output, target).backward()
saved_weight_0 = model_[0].weight.data.clone()
saved_weight_2 = model_[2].weight.data.clone()
dump_opt_params(optimizer)
optimizer.step()
assert not torch.allclose(saved_weight_0, model_[0].weight.data, atol=1.0e-6)
assert not torch.allclose(saved_weight_2, model_[2].weight.data, atol=1.0e-6)
return output
output = forward_model(model, target)
reference_output = forward_model(reference, target)
error = reference_output.sub(output).max()
torch.distributed.barrier()
assert error < 1.0e-6
output = forward_model(model, target)
error = reference_output.sub(output).max()
torch.distributed.barrier()
assert error < 1.0e-6
output = forward_model(model, target)
error = reference_output.sub(output).max()
torch.distributed.barrier()
assert error < 1.0e-6
check_weights(model, reference, "before")
saved_weight_0 = model[0].weight.data.clone()
saved_weight_2 = model[2].weight.data.clone()
output = forward_model(model, target, step=True)
error = reference_output.sub(output).max()
assert error < 1.0e-6
model[0].weight.data = saved_weight_0
model[2].weight.data = saved_weight_2
worker_map = {i: f"Test{i}" for i in range(torch.distributed.get_world_size())}
if pipe_world_size == 2:
print("actually doing pipe stuff now")
assert torch.equal(saved_weight_0, model[0].weight.data)
assert torch.equal(saved_weight_2, model[2].weight.data)
pipe_model = MultiProcessPipe(
model,
[2, 1],
group=pipeline_devices,
worker_map=worker_map,
input_device=torch.cuda.current_device(),
chunks=chunk_size,
).cuda()
torch.distributed.barrier()
pipe_rank = torch.distributed.get_rank(group=mpu.get_pipeline_parallel_group())
print(f"pipe rank is {pipe_rank}")
if pipe_rank == 0:
assert torch.equal(saved_weight_0, pipe_model[0].weight.data)
else:
if not torch.equal(saved_weight_2, pipe_model[0].weight.data):
print(f"ne {pipe_rank}: left\n{saved_weight_2}\nright:\n{pipe_model[0].weight.data}")
assert torch.equal(saved_weight_2, pipe_model[0].weight.data)
optimizer = torch.optim.SGD(pipe_model.parameters(), lr=0.01, momentum=0.9)
optimizer.zero_grad()
if pipe_rank == 0:
assert torch.equal(saved_weight_0, pipe_model[0].weight.data)
print(f"runner {rank}:\n{pipe_model[0].weight.data}")
else:
assert torch.equal(saved_weight_2, pipe_model[0].weight.data)
print(f"runner {rank}:\n{pipe_model[0].weight.data}")
if torch.distributed.get_rank(mpu.get_pipeline_parallel_group()) == 1:
check_weights(model, reference, "pre-pipe", index=2)
else:
check_weights(model, reference, "pre-pipe", index=0)
pipe_output = pipe_model(identity())
print(f"exited pipe for {rank}")
forward_model(reference, target, step=True)
print(f"pipe_output {rank} = {pipe_output}")
print(f"reference_output {rank} = {reference_output}")
torch.distributed.barrier()
if torch.distributed.get_rank(mpu.get_pipeline_parallel_group()) == 1:
error = reference_output.sub(pipe_output.cuda()).max()
if error >= 1.0e-6:
print(f"error bad {error}")
assert error < 1.0e-6
loss = nn.MSELoss()
failed = False
pipe_output.retain_grad()
with torch.autograd.profiler.profile() as prof:
try:
loss(pipe_output, target).backward()
except Exception as e:
failed = True
print(f"got {e} while doing backward, deadlock?")
if failed:
raise RuntimeError("failed somehow")
dump_opt_params(optimizer)
optimizer.step()
print("calling check_weights on master")
check_weights(model, reference, "pipe", index=2)
print(f"waiting for barrier on master, pid={os.getpid()}")
else:
print(f"calling backwards on slave, pid={os.getpid()}")
failed = False
with torch.autograd.profiler.profile() as prof:
try:
pipe_model.back_helper(pipe_output)
except Exception as e:
failed = True
print(f"got {e} while doing backward, deadlock?")
if failed:
raise RuntimeError("failed somehow")
dump_opt_params(optimizer)
print("calling step on slave")
optimizer.step()
print("calling check_weights on slave")
check_weights(model, reference, "pipe", index=0)
print("waiting for barrier on slave")
pipe_model.zero_grad()
torch.distributed.barrier()
pipe_model.eval()
pipe_output = pipe_model(identity())
updated_ref_output = forward_model(reference, target)
if torch.distributed.get_rank(mpu.get_pipeline_parallel_group()) == 1:
error = updated_ref_output.sub(pipe_output.cuda()).max()
print(f"outputs are ref:\n{updated_ref_output}\npipe:\n{pipe_output}")
assert error < 1.0e-6
torch.distributed.barrier()
print(f"finished waiting for barrier on, pid={os.getpid()}")
print(f"really exited pipe for {rank}")
rpc.shutdown()
torch.distributed.destroy_process_group()
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
@@ -556,35 +317,3 @@ def test_column_parallel():
@pytest.mark.skipif("OMPI_COMM_WORLD_RANK" not in os.environ, reason="only works on mpi")
def test_row_parallel():
    spawn_for_all_world_sizes(run_test_row_parallel_linear)
@torch_spawn([2])
@pytest.mark.skipif("OMPI_COMM_WORLD_RANK" not in os.environ, reason="only works on mpi")
@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
def mpi_pipe():
mpu.destroy_model_parallel()
_, tempfile_init = tempfile.mkstemp()
_, tempfile_rpc_init = tempfile.mkstemp()
run_test_pipe(
torch.distributed.get_rank(),
torch.distributed.get_world_size(),
tempfile_init,
tempfile_rpc_init,
skip_dist_init=True,
)
@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
def test_pipe_layer():
world_sizes = [x for x in get_world_sizes() if x <= torch.cuda.device_count() / 2]
spawn_for_all_world_sizes(run_test_pipe, args=[False])
@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
@pytest.mark.skip(reason="potential deadlock in nccl with multiple processes using the same gpu")
def test_eight_pipe_layer():
world_sizes = [x for x in get_world_sizes() if x <= torch.cuda.device_count() / 2]
spawn_for_all_world_sizes(run_test_pipe, [8])
@@ -22,13 +22,13 @@ import torch
from torch import nn
import torch.nn.functional as F
-from fairscale.nn.pipe import AsyncPipe, MultiProcessPipe
+from fairscale.nn.pipe import AsyncPipe
from fairscale.utils.testing import get_worker_map, torch_spawn
@torch_spawn([2])
@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
-@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
+@pytest.mark.parametrize("pipe_class", [AsyncPipe])
def python_autograd_function(pipe_class):
    # FIXME deadlock with AsyncPipe?
    # A Python autograd function might fail with this error:
@@ -71,7 +71,7 @@ def python_autograd_function(pipe_class):
@torch_spawn([3])
@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
-@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
+@pytest.mark.parametrize("pipe_class", [AsyncPipe])
def exception_no_hang(pipe_class):
    # In v0.0.2, once a failed partition receives a normal message
    # (non-closing) for the next micro-batch, a hang occured. The reason was
@@ -104,7 +104,7 @@ def exception_no_hang(pipe_class):
@torch_spawn([2])
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="2 cuda devices required")
-@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
+@pytest.mark.parametrize("pipe_class", [AsyncPipe])
def tuple_wait(cuda_sleep, pipe_class):
    # In v0.0.3, Wait is applied to only the first tensor on a micro-batch.
    # Under this behavior, if checkpointing was disabled, there's a possibility
@@ -157,7 +157,7 @@ def tuple_wait(cuda_sleep, pipe_class):
@torch_spawn([2])
@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
-@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
+@pytest.mark.parametrize("pipe_class", [AsyncPipe])
def parallel_randoms(pipe_class):
    class Dropouts(nn.Module):
        def forward(self, x):
...
@@ -21,13 +21,13 @@ import pytest
import torch
from torch import nn
-from fairscale.nn.pipe import AsyncPipe, MultiProcessPipe
+from fairscale.nn.pipe import AsyncPipe
from fairscale.utils.testing import get_worker_map, torch_spawn
@torch_spawn([2])
@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
-@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
+@pytest.mark.parametrize("pipe_class", [AsyncPipe])
def inplace_on_requires_grad(pipe_class):
    model = nn.Sequential(nn.Linear(1, 1), nn.ReLU(inplace=True))
    model = pipe_class(model, [1, 1], worker_map=get_worker_map(), checkpoint="always")
@@ -50,7 +50,7 @@ def inplace_on_requires_grad(pipe_class):
@torch_spawn([1])
@pytest.mark.xfail(strict=True)
-@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
+@pytest.mark.parametrize("pipe_class", [AsyncPipe])
def inplace_on_not_requires_grad(pipe_class):
    # In-place operation on a tensor not requiring grad doesn't cause a
    # RuntimeError. Currently, we cannot detect this case.
@@ -70,7 +70,7 @@ def inplace_on_not_requires_grad(pipe_class):
@torch_spawn([1])
@pytest.mark.xfail(strict=True)
-@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
+@pytest.mark.parametrize("pipe_class", [AsyncPipe])
def inplace_incorrect_grad(pipe_class):
    class M(nn.Module):
        def forward(self, foo_bar):
...
@@ -26,17 +26,14 @@ import pytest
import torch
from torch import nn
-from fairscale.nn.model_parallel.initialize import (
-    destroy_model_parallel,
-    get_pipeline_parallel_group,
-    initialize_model_parallel,
-)
-from fairscale.nn.pipe import AsyncPipe, LazyModule, MultiProcessPipe
+from fairscale.nn.model_parallel.initialize import get_pipeline_parallel_group
+from fairscale.nn.pipe import AsyncPipe
+from fairscale.nn.pipe.types import LazyModule
from fairscale.utils.testing import get_worker_map, torch_spawn, torch_version
@torch_spawn([2])
-@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
+@pytest.mark.parametrize("pipe_class", [AsyncPipe])
def parameters(pipe_class):
    model = nn.Sequential(nn.Linear(1, 1))
    pipe = pipe_class(model, balance=[1], worker_map=get_worker_map(), chunks=1)
@@ -107,7 +104,7 @@ def mpi():
@torch_spawn([1])
-@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
+@pytest.mark.parametrize("pipe_class", [AsyncPipe])
def public_attrs(pipe_class):
    model = nn.Sequential(nn.Linear(1, 1))
@@ -122,7 +119,7 @@ def public_attrs(pipe_class):
@torch_spawn([2])
@pytest.mark.parametrize("balance", [[2], [1, 1]])
-@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
+@pytest.mark.parametrize("pipe_class", [AsyncPipe])
def sequential_like(balance, pipe_class):
    a = nn.Linear(1, 1)
    b = nn.Linear(1, 1)
@@ -161,7 +158,7 @@ def sequential_like(balance, pipe_class):
@torch_spawn([1])
-@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
+@pytest.mark.parametrize("pipe_class", [AsyncPipe])
def balance_wrong_length(pipe_class):
    a = nn.Linear(1, 1)
    b = nn.Linear(1, 1)
@@ -176,7 +173,7 @@ def balance_wrong_length(pipe_class):
@torch_spawn([2])
-@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
+@pytest.mark.parametrize("pipe_class", [AsyncPipe])
def balance_less_than_1(pipe_class):
    a = nn.Linear(1, 1)
    b = nn.Linear(1, 1)
@@ -191,7 +188,7 @@ def balance_less_than_1(pipe_class):
@torch_spawn([1])
-@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
+@pytest.mark.parametrize("pipe_class", [AsyncPipe])
def chunks_less_than_1(pipe_class):
    model = nn.Sequential(nn.Linear(1, 1))
@@ -203,7 +200,7 @@ def chunks_less_than_1(pipe_class):
@torch_spawn([1])
-@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
+@pytest.mark.parametrize("pipe_class", [AsyncPipe])
def too_few_devices(pipe_class):
    model = nn.Sequential(nn.Linear(1, 1), nn.Linear(1, 1), nn.Linear(1, 1), nn.Linear(1, 1))
@@ -213,7 +210,7 @@ def too_few_devices(pipe_class):
@torch_spawn([1])
-@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
+@pytest.mark.parametrize("pipe_class", [AsyncPipe])
def batch_size_indivisible(pipe_class):
    model = nn.Sequential(nn.Linear(1, 1))
    model = pipe_class(model, balance=[1], worker_map=get_worker_map(), chunks=4)
@@ -226,7 +223,7 @@ def batch_size_indivisible(pipe_class):
@torch_spawn([1])
-@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
+@pytest.mark.parametrize("pipe_class", [AsyncPipe])
def batch_size_small(pipe_class):
    model = nn.Sequential(nn.Linear(1, 1))
    model = pipe_class(model, balance=[1], worker_map=get_worker_map(), chunks=4)
@@ -239,7 +236,7 @@ def batch_size_small(pipe_class):
@torch_spawn([1])
-@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
+@pytest.mark.parametrize("pipe_class", [AsyncPipe])
def checkpoint_mode(pipe_class):
    def count_grad_fn(grad_fn, name, visited=set()):
        if grad_fn in visited:
@@ -273,7 +270,7 @@ def checkpoint_mode(pipe_class):
@torch_spawn([1])
-@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
+@pytest.mark.parametrize("pipe_class", [AsyncPipe])
def checkpoint_mode_invalid(pipe_class):
    model = nn.Sequential(nn.Linear(1, 1))
@@ -284,7 +281,7 @@ def checkpoint_mode_invalid(pipe_class):
@torch_spawn([1])
-@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
+@pytest.mark.parametrize("pipe_class", [AsyncPipe])
def checkpoint_mode_when_chunks_1(pipe_class):
    model = nn.Sequential(nn.Linear(1, 1))
@@ -297,7 +294,7 @@ def checkpoint_mode_when_chunks_1(pipe_class):
@torch_spawn([1])
-@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
+@pytest.mark.parametrize("pipe_class", [AsyncPipe])
def checkpoint_eval(pipe_class):
    model = nn.Sequential(nn.Linear(1, 1))
    model = pipe_class(model, balance=[1], worker_map=get_worker_map(), chunks=2,)
@@ -326,7 +323,7 @@ def checkpoint_eval(pipe_class):
@torch_spawn([2])
@pytest.mark.xfail(torch_version() < (1, 6, 0), reason="Doesn't work on torch < 1.6.0", strict=True)
-@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
+@pytest.mark.parametrize("pipe_class", [AsyncPipe])
def checkpoint_non_float_input(pipe_class):
    class ForkNonFloat(nn.Module):
        def forward(self, input):
@@ -344,14 +341,12 @@ def checkpoint_non_float_input(pipe_class):
    if model.group.rank() == 1:
        # with torch.autograd.detect_anomaly():
        output.backward()
-    elif pipe_class == MultiProcessPipe:
-        model.back_helper(output)
    torch.distributed.barrier()
@torch_spawn([1])
-@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
+@pytest.mark.parametrize("pipe_class", [AsyncPipe])
def no_grad(pipe_class):
    model = nn.Sequential(nn.Linear(1, 1))
    model = pipe_class(model, balance=[1], worker_map=get_worker_map(), chunks=2)
@@ -376,7 +371,7 @@ def no_grad(pipe_class):
@torch_spawn([1])
-@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
+@pytest.mark.parametrize("pipe_class", [AsyncPipe])
def exception(pipe_class):
    class ExpectedException(Exception):
        pass
@@ -396,7 +391,7 @@ def exception(pipe_class):
@torch_spawn([4])
@pytest.mark.skipif(torch.cuda.is_available() and torch.cuda.device_count() < 4, reason="Not enough GPUs")
@pytest.mark.xfail(strict=True)
-@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
+@pytest.mark.parametrize("pipe_class", [AsyncPipe])
def exception_early_stop_asap(pipe_class):
    """Even the first partitions have finished to process, the partition before
    the failed partition hould be killed as soon as possible.
@@ -435,7 +430,7 @@ def exception_early_stop_asap(pipe_class):
@torch_spawn([1])
-@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
+@pytest.mark.parametrize("pipe_class", [AsyncPipe])
def input_pair(pipe_class):
    class Two(nn.Module):
        def __init__(self):
@@ -462,7 +457,7 @@ def input_pair(pipe_class):
@torch_spawn([1])
-@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
+@pytest.mark.parametrize("pipe_class", [AsyncPipe])
def input_singleton(pipe_class):
    class One(nn.Module):
        def __init__(self):
@@ -487,7 +482,7 @@ def input_singleton(pipe_class):
@torch_spawn([1])
-@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
+@pytest.mark.parametrize("pipe_class", [AsyncPipe])
def input_varargs(pipe_class):
    model = nn.Sequential(nn.Linear(1, 1))
    model = pipe_class(model, balance=[1], worker_map=get_worker_map())
@@ -501,7 +496,7 @@ def input_varargs(pipe_class):
@torch_spawn([1])
-@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
+@pytest.mark.parametrize("pipe_class", [AsyncPipe])
def non_tensor(pipe_class):
    class NonTensor(nn.Module):
        def forward(self, _):
@@ -521,7 +516,7 @@ def non_tensor(pipe_class):
@torch_spawn([1])
-@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
+@pytest.mark.parametrize("pipe_class", [AsyncPipe])
def non_tensor_tuple(pipe_class):
    class NonTensorTuple(nn.Module):
        def forward(self, x):
@@ -543,7 +538,7 @@ def non_tensor_tuple(pipe_class):
@torch_spawn([1])
@pytest.mark.parametrize("checkpoint", ["never", "always", "except_last"])
@pytest.mark.parametrize("lazy", [True, False])
-@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
+@pytest.mark.parametrize("pipe_class", [AsyncPipe])
def deferred_batch_norm(checkpoint, lazy, pipe_class):
    bn = nn.BatchNorm2d(3)
    pipe_bn = deepcopy(bn)
@@ -567,7 +562,7 @@ def deferred_batch_norm(checkpoint, lazy, pipe_class):
@torch_spawn([1])
@pytest.mark.parametrize("checkpoint", ["never", "always"])
@pytest.mark.parametrize("lazy", [True, False])
-@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
+@pytest.mark.parametrize("pipe_class", [AsyncPipe])
def deferred_batch_norm_params(checkpoint, lazy, pipe_class):
    bn = nn.BatchNorm2d(3)
    pipe_bn = deepcopy(bn)
...@@ -592,7 +587,7 @@ def deferred_batch_norm_params(checkpoint, lazy, pipe_class): ...@@ -592,7 +587,7 @@ def deferred_batch_norm_params(checkpoint, lazy, pipe_class):
@torch_spawn([4]) @torch_spawn([4])
@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe]) @pytest.mark.parametrize("pipe_class", [AsyncPipe])
def devices(pipe_class): def devices(pipe_class):
a = nn.Linear(1, 1) a = nn.Linear(1, 1)
b = nn.Linear(1, 1) b = nn.Linear(1, 1)
...@@ -608,7 +603,7 @@ def devices(pipe_class): ...@@ -608,7 +603,7 @@ def devices(pipe_class):
@torch_spawn([2]) @torch_spawn([2])
@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe]) @pytest.mark.parametrize("pipe_class", [AsyncPipe])
def partitions(pipe_class): def partitions(pipe_class):
a = nn.Linear(1, 1) a = nn.Linear(1, 1)
b = nn.Linear(1, 1) b = nn.Linear(1, 1)
...@@ -626,7 +621,7 @@ def partitions(pipe_class): ...@@ -626,7 +621,7 @@ def partitions(pipe_class):
@torch_spawn([2]) @torch_spawn([2])
@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required") @pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe]) @pytest.mark.parametrize("pipe_class", [AsyncPipe])
def deny_moving(pipe_class): def deny_moving(pipe_class):
a = nn.Linear(1, 1) a = nn.Linear(1, 1)
b = nn.Linear(1, 1) b = nn.Linear(1, 1)
...@@ -650,7 +645,7 @@ def deny_moving(pipe_class): ...@@ -650,7 +645,7 @@ def deny_moving(pipe_class):
@torch_spawn([1]) @torch_spawn([1])
@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe]) @pytest.mark.parametrize("pipe_class", [AsyncPipe])
def empty_module(pipe_class): def empty_module(pipe_class):
# Empty sequential module is not illegal. # Empty sequential module is not illegal.
model = nn.Sequential() model = nn.Sequential()
...@@ -666,7 +661,7 @@ def empty_module(pipe_class): ...@@ -666,7 +661,7 @@ def empty_module(pipe_class):
@torch_spawn([2]) @torch_spawn([2])
@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe]) @pytest.mark.parametrize("pipe_class", [AsyncPipe])
@pytest.mark.skip(reason="TODO(msb) handle named_children") @pytest.mark.skip(reason="TODO(msb) handle named_children")
def named_children(pipe_class): def named_children(pipe_class):
a = nn.Linear(1, 1) a = nn.Linear(1, 1)
...@@ -688,7 +683,7 @@ def named_children(pipe_class): ...@@ -688,7 +683,7 @@ def named_children(pipe_class):
@torch_spawn([1]) @torch_spawn([1])
@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe]) @pytest.mark.parametrize("pipe_class", [AsyncPipe])
def recommend_auto_balance(pipe_class): def recommend_auto_balance(pipe_class):
with pytest.raises(ValueError): with pytest.raises(ValueError):
# module and sum of balance have different lengths (module: 0, sum of balance: 1) # module and sum of balance have different lengths (module: 0, sum of balance: 1)
...@@ -700,7 +695,7 @@ def recommend_auto_balance(pipe_class): ...@@ -700,7 +695,7 @@ def recommend_auto_balance(pipe_class):
@torch_spawn([2]) @torch_spawn([2])
@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe]) @pytest.mark.parametrize("pipe_class", [AsyncPipe])
def lazy_construction(pipe_class): def lazy_construction(pipe_class):
init_count = 0 init_count = 0
...@@ -730,7 +725,7 @@ def lazy_construction(pipe_class): ...@@ -730,7 +725,7 @@ def lazy_construction(pipe_class):
@torch_spawn([2]) @torch_spawn([2])
@pytest.mark.skipif("OMPI_COMM_WORLD_RANK" in os.environ, reason="doesn't apply to mpi") @pytest.mark.skipif("OMPI_COMM_WORLD_RANK" in os.environ, reason="doesn't apply to mpi")
@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe]) @pytest.mark.parametrize("pipe_class", [AsyncPipe])
def missing_worker_map(pipe_class): def missing_worker_map(pipe_class):
model = nn.Sequential(nn.ReLU(), nn.ReLU()) model = nn.Sequential(nn.ReLU(), nn.ReLU())
...@@ -740,7 +735,7 @@ def missing_worker_map(pipe_class): ...@@ -740,7 +735,7 @@ def missing_worker_map(pipe_class):
@torch_spawn([2]) @torch_spawn([2])
@pytest.mark.skip(reason="currently broken") @pytest.mark.skip(reason="currently broken")
@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe]) @pytest.mark.parametrize("pipe_class", [AsyncPipe])
def verify_module_duplicate_parameters_on_distinct_partitions(pipe_class): def verify_module_duplicate_parameters_on_distinct_partitions(pipe_class):
class Surrogate(nn.Module): class Surrogate(nn.Module):
def __init__(self, module): def __init__(self, module):
...@@ -755,24 +750,6 @@ def verify_module_duplicate_parameters_on_distinct_partitions(pipe_class): ...@@ -755,24 +750,6 @@ def verify_module_duplicate_parameters_on_distinct_partitions(pipe_class):
pipe_class(model, [1, 1], worker_map=get_worker_map()) pipe_class(model, [1, 1], worker_map=get_worker_map())
@torch_spawn([4])
@pytest.mark.parametrize("pipe_class", [MultiProcessPipe])
def pipelined_backward(pipe_class):
model = nn.Sequential(nn.ReLU(), nn.ReLU())
destroy_model_parallel()
initialize_model_parallel(1, 4)
pipe = pipe_class(model, [1, 1], worker_map=get_worker_map())
assert pipe.pipelined_backward is False
destroy_model_parallel()
initialize_model_parallel(2, 2)
pipe = pipe_class(model, [1, 1], worker_map=get_worker_map())
assert pipe.pipelined_backward is True
@torch_spawn([4]) @torch_spawn([4])
def async_event_loop(): def async_event_loop():
......
...@@ -21,13 +21,13 @@ import pytest ...@@ -21,13 +21,13 @@ import pytest
import torch import torch
from torch import nn from torch import nn
from fairscale.nn.pipe import AsyncPipe, MultiProcessPipe from fairscale.nn.pipe import AsyncPipe
from fairscale.utils.testing import get_worker_map, set_random_seed, torch_spawn from fairscale.utils.testing import get_worker_map, set_random_seed, torch_spawn
@torch_spawn([2]) @torch_spawn([2])
@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required") @pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe]) @pytest.mark.parametrize("pipe_class", [AsyncPipe])
def simple_linears(pipe_class): def simple_linears(pipe_class):
def sum_grad(parameters): def sum_grad(parameters):
return sum([p.grad.sum() for p in parameters if p.grad is not None]) return sum([p.grad.sum() for p in parameters if p.grad is not None])
......