Unverified commit cae9b638 authored by msbaines, committed by GitHub

[refactor] pipe: separate out Single and MultiProcess pipe (#326)

parent eab1551a
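The practical upshot of this refactor, visible throughout the hunks below, is that the single-process pipeline keeps the `Pipe` name while the multi-process (torch.distributed/RPC) pipeline moves to `MultiProcessPipe`. A minimal sketch of the two call sites after the change, mirroring the benchmark updates in this diff; argument values are illustrative, and the multi-process construction additionally assumes the process-group and RPC setup the benchmarks perform elsewhere:

```python
import torch
import torch.nn as nn

from fairscale.nn import Pipe                    # single-process pipeline
from fairscale.nn.pipe import MultiProcessPipe   # multi-process pipeline
from fairscale.utils.testing import get_worker_map

model = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 16))
balance = [2, 1]  # two layers on the first partition, one on the second

# Single-process: all partitions live in one process (e.g. across local GPUs).
single = Pipe(model, balance, chunks=4, checkpoint="except_last")

# Multi-process: one partition per rank; needs a schedule style, a rank->RPC-name
# map, and an input device. Run on every rank after distributed/RPC initialization.
multi = MultiProcessPipe(
    model,
    balance,
    style=MultiProcessPipe.AsyncSchedule,
    chunks=4,
    worker_map=get_worker_map(),
    input_device=torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"),
)
```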
......@@ -19,10 +19,9 @@ import torchtext
from torchtext.data.utils import get_tokenizer
from experimental.nn.ampnet_pipe import pipe
from fairscale.nn import Pipe
from fairscale.nn.model_parallel import initialize_model_parallel
from fairscale.nn.model_parallel.initialize import get_pipeline_parallel_group
from fairscale.nn.pipe import LazyModule
from fairscale.nn.pipe import LazyModule, MultiProcessPipe
from fairscale.optim import GradScaler
from fairscale.utils.testing import dist_init, get_worker_map
......@@ -421,7 +420,7 @@ def run_mp_worker(args, available_workers):
p = pipe.AMPnetPipe(
module=model,
balance=balance,
style=Pipe.AsyncSchedule,
style=MultiProcessPipe.AsyncSchedule,
chunks=args.chunks,
worker_map=get_worker_map(),
input_device=torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"),
......
......@@ -25,7 +25,7 @@ from torch.optim import Adam
from fairscale.nn import Pipe
from fairscale.nn.model_parallel import initialize_model_parallel
from fairscale.nn.model_parallel.initialize import get_data_parallel_group, get_pipeline_parallel_group
from fairscale.nn.pipe import LazyModule, pipe
from fairscale.nn.pipe import LazyModule, MultiProcessPipe
from fairscale.optim.oss import OSS
from fairscale.utils.testing import dist_init, get_worker_map
......@@ -157,7 +157,7 @@ def dump_cuda_tensors():
def log_number_of_parameters(model):
num_params = reduce(operator.add, (reduce(operator.mul, x.size()) for x in model.parameters()))
if model.group:
if hasattr(model, "group"):
total = torch.Tensor([num_params])
if torch.cuda.is_available():
total = total.cuda()
......@@ -212,7 +212,7 @@ def train(model_config, model, benchmark_config, args):
optimizer = optimizer(model.parameters())
pipe_group = model.group
pipe_group = model.group if hasattr(model, "group") else None
if args.ddp_zero:
model = DDP(
......@@ -479,9 +479,7 @@ def benchmark_single_process(args):
model = model_config["model"]
balance = generate_balance(min(num_devices, 4), len(model))
pipe_model = pipe.Pipe(
model, balance, chunks=args.chunks, pipelined_backward=args.pipelined_backward, checkpoint=args.checkpoint
)
pipe_model = Pipe(model, balance, chunks=args.chunks, checkpoint=args.checkpoint)
del model
del model_config["model"]
......@@ -498,10 +496,10 @@ def run_mp_worker(args, available_workers):
model = model_config["model"]
balance = generate_balance_weighted(get_pipeline_parallel_group().size(), len(model), 0.8)
pipe_model = pipe.Pipe(
pipe_model = MultiProcessPipe(
model,
balance,
style=Pipe.AsyncSchedule,
style=MultiProcessPipe.AsyncSchedule,
chunks=args.chunks,
worker_map=get_worker_map(),
input_device=torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"),
......
......@@ -6,8 +6,8 @@ import torch.distributed as dist
import torch.multiprocessing as mp
import torch.optim as optim
import fairscale
from fairscale.nn.model_parallel import initialize_model_parallel
from fairscale.nn.pipe import MultiProcessPipe
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
RANK = 0 # example
......@@ -27,10 +27,10 @@ def run(rank, world_size):
device = torch.device("cuda", RANK) if DEVICE == "cuda" else torch.device("cpu")
model = fairscale.nn.Pipe(
model = MultiProcessPipe(
model,
balance=[2, 1],
style=fairscale.nn.Pipe.MultiProcess,
style=MultiProcessPipe.MultiProcess,
worker_map={0: "worker0", 1: "worker1"}, # Needed to convert ranks to RPC worker names
input_device=device,
).to(device)
......
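The `worker_map` in the snippet above ties each distributed rank to the name that rank registered with `torch.distributed.rpc`. A hedged sketch of the per-process setup such a mapping presupposes; the names, backend, and addresses below are illustrative assumptions, and the benchmarks in this diff build the same mapping via `fairscale.utils.testing.get_worker_map()`:

```python
import torch.distributed as dist
import torch.distributed.rpc as rpc


def init_pipeline_process(rank: int, world_size: int) -> dict:
    # Collectives (e.g. the pipeline process group) go through the regular backend...
    dist.init_process_group("gloo", init_method="tcp://localhost:29500", rank=rank, world_size=world_size)
    # ...while MultiProcessPipe's control and activation messages can ride on RPC, so
    # each rank also registers an RPC worker name. worker_map must use these exact names.
    rpc.init_rpc(
        f"worker{rank}",
        rank=rank,
        world_size=world_size,
        rpc_backend_options=rpc.TensorPipeRpcBackendOptions(init_method="tcp://localhost:29501"),
    )
    return {r: f"worker{r}" for r in range(world_size)}  # e.g. {0: "worker0", 1: "worker1"}
```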
......@@ -11,7 +11,7 @@ from torch import nn
from torch.optim.optimizer import Optimizer
from torch.utils.data import DataLoader
from fairscale.nn.pipe import Pipe
from fairscale.nn.pipe import MultiProcessPipe
from fairscale.nn.pipe.types import PipelineStyle
from .ampnet import AsyncAMPnetEventLoop
......@@ -19,9 +19,9 @@ from .ampnet import AsyncAMPnetEventLoop
__all__ = ["AMPnetPipe"]
class AMPnetPipe(Pipe):
class AMPnetPipe(MultiProcessPipe):
"""
AMPnetPipe is the asynchronous version of the Pipe implementation
AMPnetPipe is the asynchronous version of the MultiProcessPipe implementation
which avoids the bubble issue by using stale weights and gradients.
The implementation closely follows the paper: https://arxiv.org/abs/1705.09786
"""
......@@ -39,7 +39,7 @@ class AMPnetPipe(Pipe):
weight_prediction: bool = False,
) -> None:
partitions = self.mp_partitions
partitions = self.partitions
n = len(partitions)
# AMPnet implementation doesn't handle skip_trackers!
......
......@@ -23,7 +23,7 @@ from torch.optim.optimizer import Optimizer
from torch.utils.data import DataLoader, Dataset
from experimental.nn.ampnet_pipe.pipe import AMPnetPipe
from fairscale.nn.pipe import Pipe
from fairscale.nn.pipe import MultiProcessPipe
from fairscale.utils.testing import get_worker_map, torch_spawn
......@@ -87,7 +87,7 @@ def async_event_loop_interleave_simple():
pipe = AMPnetPipe(
module=model,
balance=[2, 2],
style=Pipe.AsyncSchedule,
style=MultiProcessPipe.AsyncSchedule,
worker_map=get_worker_map(),
chunks=10,
checkpoint="never",
......@@ -105,7 +105,7 @@ def async_event_loop_interleave_hard():
pipe = AMPnetPipe(
module=model,
balance=[1, 1, 1, 1],
style=Pipe.AsyncSchedule,
style=MultiProcessPipe.AsyncSchedule,
worker_map=get_worker_map(),
chunks=10,
checkpoint="never",
......
......@@ -6,7 +6,7 @@
from .data_parallel import ShardedDataParallel
from .misc import FlattenParamsWrapper
from .moe import MOELayer, Top2Gate
from .pipe import LazyModule, Pipe, PipeRPCWrapper
from .pipe import Pipe, PipeRPCWrapper
__all__ = [
"FlattenParamsWrapper",
......
......@@ -19,7 +19,8 @@
"""A Pipe implementation in PyTorch."""
from .checkpoint import is_checkpointing, is_recomputing
from .pipe import LazyModule, Pipe
from .multiprocess_pipe import LazyModule, MultiProcessPipe
from .pipe import Pipe
from .rpc import PipeRPCWrapper
__all__ = ["Pipe", "is_checkpointing", "is_recomputing", "LazyModule"]
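For downstream code, the import surface after this hunk looks as follows (note that `MultiProcessPipe` is importable from `fairscale.nn.pipe` even though the `__all__` shown here does not list it, and that `LazyModule` is no longer re-exported from `fairscale.nn`):

```python
# Single-process pipeline and the RPC wrapper stay at the package top level.
from fairscale.nn import Pipe, PipeRPCWrapper

# The multi-process pipeline and LazyModule now come from fairscale.nn.pipe.
from fairscale.nn.pipe import LazyModule, MultiProcessPipe
```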
......@@ -191,7 +191,7 @@ class AsyncEventLoop:
"""Actually run the forward pass for a given module, and send the result
to the next stage in the pipeline if needed."""
assert self.group
from .pipeline import create_task
from .multiprocess_pipeline import create_task
task = create_task(
PipelineStyle.AsyncSchedule,
......@@ -201,7 +201,6 @@ class AsyncEventLoop:
batch,
partition.module,
skip_trackers,
[],
)
result = task.compute()
task.finalize(result)
......
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
# Copyright 2019 Kakao Brain
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The multiprocess pipeline parallelism of Pipe."""
import logging
import os
from queue import Empty as QueueEmpty
from queue import Queue
from threading import Event
from types import TracebackType
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Type, Union, cast
import torch
from torch import Tensor, nn
from torch.autograd.profiler import record_function
from fairscale.nn.model_parallel import get_pipeline_parallel_ranks
from .async_schedule import AsyncEventLoop, ModuleWrapper
from .checkpoint import Checkpointing
from .messages import MakeTransport, Transport
from .microbatch import Batch
from .skip import Namespace
from .skip.layout import SkipLayout
from .skip.tracker import SkipTrackerThroughPotals, use_skip_tracker
from .types import (
ACTIVATIONS_GRADS_QUEUE,
PORTAL_QUEUE,
SKIP_TENSOR_QUEUE,
PipelineStyle,
PipeMessage,
TensorOrTensors,
Tensors,
)
from .worker import Task
__all__: List[str] = []
ExcInfo = Tuple[Type[BaseException], BaseException, TracebackType]
class SendOperator(torch.autograd.Function):
"""Send activations to the next pipeline stage"""
@staticmethod
# type: ignore
def forward(ctx, src_rank, dst_rank, transport: Transport, input: List[Tensor], index: int) -> Tensors:
assert src_rank == torch.distributed.get_rank()
transport.send_message(
PipeMessage(src_rank, dst_rank, queue_name=ACTIVATIONS_GRADS_QUEUE, args=index, tensors=tuple(input)),
)
return ()
@staticmethod
# type: ignore
def backward(ctx, *grad: Tensor,) -> Tensors:
return tuple(grad)
class RecvOperator(torch.autograd.Function):
"""Receive activations to the previous pipeline stage"""
@staticmethod
# type: ignore
def forward(ctx, dst_rank: int, tensor: Tensor, input_device, transport: Transport, index: int) -> Tensors:
assert dst_rank == torch.distributed.get_rank()
ctx.transport = transport
ctx.index = index
result = transport.get_out_of_order(ACTIVATIONS_GRADS_QUEUE, index)
def maybe_requires_grad(t: Tensor) -> Tensor:
if t.dtype.is_floating_point:
return t.requires_grad_()
return t
return tuple(maybe_requires_grad(r) for r in result)
@staticmethod
# type: ignore
def backward(ctx, *grad: Tensor,) -> Tuple[Optional[Tensor], ...]:
ranks = get_pipeline_parallel_ranks()
this_rank = torch.distributed.get_rank()
ctx.transport.send_message(
PipeMessage(
this_rank,
ranks[ranks.index(this_rank) - 1],
queue_name=ACTIVATIONS_GRADS_QUEUE,
args=ctx.index,
tensors=tuple(grad),
),
)
return (None, None, None, None, None)
# Queue is generic only in stubs.
# https://mypy.readthedocs.io/en/latest/common_issues.html#using-classes-that-are-generic-in-stubs-but-not-at-runtime
if TYPE_CHECKING:
InQueue = Queue[Optional["Task"]]
OutQueue = Queue[Tuple[bool, Union[Tuple["Task", Batch], ExcInfo, None]]]
else:
InQueue = Queue
OutQueue = Queue
def create_task(
style: PipelineStyle,
checkpoint_stop: int,
i: int,
j: int,
batch: Batch,
partition: nn.Sequential,
skip_trackers: List[SkipTrackerThroughPotals],
) -> Task:
# Determine whether to use checkpointing for this chunk or not.
if i < checkpoint_stop:
def function(
input: TensorOrTensors,
partition: nn.Sequential = partition,
skip_tracker: SkipTrackerThroughPotals = skip_trackers[i],
chunk_id: int = i,
part_id: int = j,
) -> TensorOrTensors:
with use_skip_tracker(skip_tracker), record_function("chunk%d-part%d" % (chunk_id, part_id)):
ret = partition(input)
# We check the output type here because the backtrace from the checkpoint backward
# code path is very hard to make sense of. It is much easier to catch the problem at this point.
assert type(ret) is not list, "Only Tensor or Tuple of Tensor output is supported"
return ret
chk = Checkpointing(function, batch)
task = Task(None, compute=chk.checkpoint, finalize=chk.recompute)
del function, chk # TODO(tom) maybe remove
else:
def compute(
batch: Batch = batch,
partition: nn.Sequential = partition,
skip_tracker: SkipTrackerThroughPotals = skip_trackers[i],
chunk_id: int = i,
part_id: int = j,
) -> Batch:
with use_skip_tracker(skip_tracker), record_function("chunk%d-part%d" % (chunk_id, part_id)):
return batch.call(partition)
task = Task(None, compute=compute, finalize=None)
del compute # TODO(tom) maybe remove
return task
class MultiProcessPipeline:
"""The multiprocess pipeline parallelism for Pipe."""
def __init__(
self,
partitions: List[nn.Sequential],
skip_layout: SkipLayout,
checkpoint_stop: int,
style: PipelineStyle,
group: Optional[torch.distributed.ProcessGroup] = None,
worker_map: Optional[Dict[int, str]] = None,
input_device: Union[None, int, str, torch.device] = None,
final_stage: bool = False,
) -> None:
self.partitions: List[ModuleWrapper] = cast(List[ModuleWrapper], partitions)
self.skip_layout = skip_layout
self.__checkpoint_stop = checkpoint_stop
self.style = style
self.group = group
self.training: bool
self.transport = MakeTransport(
use_rpc=("OMPI_COMM_WORLD_RANK" not in os.environ) or ("FORCE_RPC" in os.environ),
worker_map=worker_map,
input_device=input_device,
)
self.input_device = input_device
self.all_at_once = False
self.callcount = 0
self.final_stage = final_stage
@property
def checkpoint_stop(self) -> int:
# Disable checkpointing if in eval mode.
training = self.partitions[0].module.training
if not training:
return 0
return self.__checkpoint_stop
def run(self, training: bool, batches: List[Batch], event: Optional[Event]) -> None:
"""Runs pipeline parallelism.
It modifies the given batches in place.
"""
self.training = training
m = len(batches)
skip_trackers = [SkipTrackerThroughPotals(self.skip_layout, i) for i in range(len(batches))]
if self.style is PipelineStyle.MultiProcess:
assert self.group
schedule = [(i, self.group.rank()) for i in range(m)]
self.compute(batches, schedule, skip_trackers)
elif self.style is PipelineStyle.AsyncSchedule:
assert self.group
rank = self.group.rank()
event_loop = AsyncEventLoop(
self.partitions, self.group, self.transport, self.training, self.checkpoint_stop,
)
if rank == 0 and not self.final_stage:
logging.debug(f"{torch.distributed.get_rank()}: entered event head")
event_loop.event_loop_head(batches, skip_trackers, event)
logging.debug(f"{torch.distributed.get_rank()}: exited event head")
elif self.final_stage:
logging.debug(f"{torch.distributed.get_rank()}: entered event tail")
event_loop.event_loop_tail(batches, skip_trackers)
logging.debug(f"{torch.distributed.get_rank()}: exited event tail")
else:
logging.debug(f"{torch.distributed.get_rank()}: entered event loop")
event_loop.event_loop(len(batches), skip_trackers)
logging.debug(f"{torch.distributed.get_rank()}: exited event loop")
self.callcount += 1
def get_batch_from_previous_stage(
self, i: int, skip_trackers: List[SkipTrackerThroughPotals], batches: List[Batch]
) -> Batch:
phony = torch.empty(0, device=self.input_device, requires_grad=True)
result = RecvOperator.apply(torch.distributed.get_rank(), phony, self.input_device, self.transport, i)
if len(result) == 1:
batch = Batch(result[0], i)
else:
batch = Batch(result, i)
self.recv_skip_tensors(skip_trackers, batches)
return batch
def send_skip_tensors(
self, this_rank: int, ranks: List[int], batch: Batch, i: int, skip_trackers: List[SkipTrackerThroughPotals]
) -> None:
assert self.group
for next_j, ns, name in self.skip_layout.copy_policy_by_src(self.group.rank()):
life = skip_trackers[i].portals[(ns, name)].tensor_life
loaded = skip_trackers[i].load(batch, ns, name)
if loaded is not None:
tensors = tuple([loaded])
else:
tensors = tuple()
self.transport.send_message(
PipeMessage(
this_rank, ranks[next_j], queue_name=SKIP_TENSOR_QUEUE, args=(i, ns, name, life), tensors=tensors,
),
sync=True,
)
def recv_skip_tensors(self, skip_trackers: List[SkipTrackerThroughPotals], batches: List[Batch]) -> None:
while True:
try:
message = self.transport.recv_message(SKIP_TENSOR_QUEUE, nowait=True)
(si, ns, name, life) = message.args
value: Optional[TensorOrTensors] = message.tensors
assert isinstance(value, tuple)
if len(value) == 0:
value = None
else:
assert len(value) == 1
value = value[0]
skip_trackers[si].save(batches[si], ns, name, value)
old_life = skip_trackers[si].portals[(ns, name)].tensor_life
if life != 0:
skip_trackers[si].portals[(ns, name)].tensor_life = life
except QueueEmpty:
break
def execute_task(self, task: Task, i: int, skip_trackers: List[SkipTrackerThroughPotals]) -> Batch:
batch = task.compute()
assert self.group
rank = self.group.rank()
if self.style is PipelineStyle.MultiProcess and not self.final_stage:
ranks = get_pipeline_parallel_ranks()
this_rank = torch.distributed.get_rank()
self.send_skip_tensors(this_rank, ranks, batch, i, skip_trackers)
SendOperator.apply(this_rank, ranks[ranks.index(this_rank) + 1], self.transport, [*batch], i)
for portal in skip_trackers[i].portals.values():
portal.pipeline = self
task.finalize(batch)
return batch
def compute(
self, batches: List[Batch], schedule: List[Tuple[int, int]], skip_trackers: List[SkipTrackerThroughPotals]
) -> None:
"""Runs tasks with synchronization to copy streams."""
if self.style is PipelineStyle.MultiProcess:
assert self.group
n = self.group.size()
# With checkpointing, the autograd graph looks like this diagram:
# ┌─────┸──────┐
# │ Copy │
# └─────┰──────┘ (fence)
# ─ ─ ─ ╂ ─ ─ ─ ─ ─ ─ ─ ─ ─
# ┃ (compute)
# ┌─────┸──────┐
# │ Wait │ [1] Synchronize the current stream with the copy stream.
# └─────┰──────┘
# ┌─────┸──────┐
# │ Checkpoint │ [2] Compute a partition within checkpointing.
# └─────┰──────┘
# ┌─────┸──────┐
# │ Wait │ [3] Synchronize the copy stream with the current stream.
# └─────┰──────┘
# ┠ ─ ─ ─ ┐
# ┃ ┌─────┴─────┐
# ┃ │ Recompute │ [4] Schedule the recomputation at backpropagation.
# ┃ └─────┬─────┘
# ┠ ─ ─ ─ ┘
# ┃
# ─ ─ ─ ╂ ─ ─ ─ ─ ─ ─ ─ ─ ─
# ┌─────┸──────┐ (fence)
# │ Copy │
# └─────┰──────┘
for i, j in schedule:
batch = batches[i]
if self.style is PipelineStyle.MultiProcess:
assert len(self.partitions) == 1
partition = self.partitions[0]
assert self.group
if self.group.rank() != 0:
batch = self.get_batch_from_previous_stage(i, skip_trackers, batches)
task = create_task(self.style, self.checkpoint_stop, i, j, batch, partition.module, skip_trackers)
batches[i] = self.execute_task(task, i, skip_trackers)
def send_portal_grad(self, ns_name: Tuple[Namespace, str], index: int, grad: TensorOrTensors) -> None:
dest, src = self.skip_layout.by_ns_name.get(ns_name, (-1, -1))
if dest == src:
return
ranks = get_pipeline_parallel_ranks()
dst_rank = ranks[dest]
if dst_rank == torch.distributed.get_rank():
return
if isinstance(grad, Tensor):
grad = tuple([grad])
self.transport.send_message(
PipeMessage(ranks[src], dst_rank, queue_name=PORTAL_QUEUE, args=(ns_name, index), tensors=grad), sync=True,
)
def recv_portal_grad(self, expected_ns_name: Tuple[Namespace, str], expected_index: int) -> Tensor:
message = self.transport.recv_message(PORTAL_QUEUE)
(ns_name, index) = message.args
grad = message.tensors
assert len(grad) == 1
result = grad[0]
assert index == expected_index and ns_name == expected_ns_name
return result
def back_helper(self, output: List[Batch]) -> None:
if self.style == PipelineStyle.AsyncSchedule:
return
o = list(output)
tensors: Tensors
if self.all_at_once:
# FIXME(tom) allow specifying this branch when constructing Pipe(), add a test
grads = []
for i, batch in enumerate(o):
rank = torch.distributed.get_rank()
found = self.transport.get_out_of_order(ACTIVATIONS_GRADS_QUEUE, i)
assert len(found) == 1
grads.append(found[0])
tensors = tuple(x.tensor_or_tensors for x in o) # type: ignore
try:
torch.autograd.backward(tensors, grad_tensors=grads, retain_graph=True)
except Exception as e:
raise RuntimeError("Autograd failed") from e
else:
rank = torch.distributed.get_rank()
for batch in o:
found = self.transport.get_out_of_order(ACTIVATIONS_GRADS_QUEUE, batch.index)
if batch.atomic:
tensors = tuple([batch.tensor])
else:
tensors = batch.tensors
if len(found) != len(tensors):
raise RuntimeError("different number of tensors and gradients")
grads = []
final_tensors = []
for i, tensor in enumerate(tensors):
if tensor.requires_grad or getattr(tensor, "grad_fn", None) is not None:
grads.append(found[i])
final_tensors.append(tensor)
try:
torch.autograd.backward(final_tensors, grad_tensors=grads, retain_graph=True)
except Exception as e:
raise RuntimeError(f"Autograd failed on {torch.distributed.get_rank()}") from e
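The `SendOperator`/`RecvOperator` pair above is the piece that stitches per-rank partitions into one autograd graph: the forward pass moves activations to the next rank, and the matching backward routes the incoming gradient back to the previous rank. The stripped-down sketch below illustrates the same pattern with plain `torch.distributed` point-to-point ops rather than fairscale's `Transport`; shapes, dtypes, and the calling convention are simplified assumptions (CPU tensors with the gloo backend), not the library's API:

```python
import torch
import torch.distributed as dist


class SendToNext(torch.autograd.Function):
    """Forward: ship an activation to the next rank and return a scalar handle.
    Backward: wait for the next rank's gradient and hand it to local autograd."""

    @staticmethod
    def forward(ctx, activation: torch.Tensor, next_rank: int) -> torch.Tensor:
        ctx.next_rank = next_rank
        ctx.shape, ctx.dtype = activation.shape, activation.dtype
        dist.send(activation.detach().contiguous(), dst=next_rank)
        # 0-dim handle; calling .backward() on it later drives the gradient exchange.
        return torch.zeros(())

    @staticmethod
    def backward(ctx, _handle_grad: torch.Tensor):
        grad = torch.empty(ctx.shape, dtype=ctx.dtype)
        dist.recv(grad, src=ctx.next_rank)  # gradient computed by the next stage
        return grad, None                   # becomes the local gradient of `activation`


class RecvFromPrev(torch.autograd.Function):
    """Forward: receive an activation from the previous rank.
    Backward: send the gradient of that activation back upstream."""

    @staticmethod
    def forward(ctx, phony: torch.Tensor, prev_rank: int, shape, dtype) -> torch.Tensor:
        # `phony` is an empty requires_grad tensor whose only job is to make autograd
        # record this node -- the same trick RecvOperator.apply relies on above.
        ctx.prev_rank = prev_rank
        activation = torch.empty(shape, dtype=dtype)
        dist.recv(activation, src=prev_rank)
        return activation

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        dist.send(grad_output.contiguous(), dst=ctx.prev_rank)
        return None, None, None, None


# Sketch of use on two adjacent ranks (process group already initialized):
#   rank 0:  handle = SendToNext.apply(stage0(x), 1)
#            handle.backward()                    # blocks until rank 1 sends the gradient
#   rank 1:  phony = torch.empty(0, requires_grad=True)
#            y = RecvFromPrev.apply(phony, 0, (batch, features), torch.float32)  # shape is hypothetical
#            stage1(y).sum().backward()           # RecvFromPrev.backward ships the grad upstream
```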
......@@ -13,13 +13,13 @@ from torch.distributed.distributed_c10d import _get_global_rank
from fairscale.nn.model_parallel.initialize import get_pipeline_parallel_group
from . import Pipe
from .multiprocess_pipe import MultiProcessPipe
from .types import EVENT_LOOP_QUEUE, PipeMessage, TensorOrTensors
DEFAULT_MAX_SOURCE_POSITIONS = 1024
DEFAULT_MAX_TARGET_POSITIONS = 1024
PipeModel: Pipe
PipeModel: MultiProcessPipe
PipeResult: TensorOrTensors
......@@ -71,7 +71,7 @@ class PipeBackRedirect(torch.autograd.Function):
return (None, None, None, None, None, None)
def callback_with_model(callback: Callable[[Any, Pipe], None], ctx: Any) -> None:
def callback_with_model(callback: Callable[[Any, MultiProcessPipe], None], ctx: Any) -> None:
try:
group = get_pipeline_parallel_group() # FIXME(tom) handle dynamic group
set_device_based_on_group(group)
......@@ -105,10 +105,10 @@ class PipeRPCWrapper(nn.Module):
else:
kwargs["group"] = self.group
kwargs["style"] = Pipe.AsyncSchedule
kwargs["style"] = MultiProcessPipe.AsyncSchedule
kwargs["input_device"] = torch.device("cuda", torch.cuda.current_device())
self.model = Pipe(*args, **kwargs)
self.model = MultiProcessPipe(*args, **kwargs)
self.worker_map = kwargs["worker_map"]
self._foreach_worker(self._register_remote_model, args=(args, kwargs))
self.model.cuda()
......@@ -121,7 +121,7 @@ class PipeRPCWrapper(nn.Module):
futures = [f.wait() for f in futures]
def foreach_worker(
self, callback: Callable[[Any, Pipe], None], ctx: Any = None, *, include_self: bool = False
self, callback: Callable[[Any, MultiProcessPipe], None], ctx: Any = None, *, include_self: bool = False
) -> None:
"""Call `callback` on each worker with the `ctx` and model local to that
worker. e.g.
......@@ -196,7 +196,9 @@ class PipeRPCWrapper(nn.Module):
return self.model.final_stage
@staticmethod
def _recv_result(model: Pipe, shapes: SizeOrSizes, dtypes: DtypeOrDtypes, message: PipeMessage) -> TensorOrTensors:
def _recv_result(
model: MultiProcessPipe, shapes: SizeOrSizes, dtypes: DtypeOrDtypes, message: PipeMessage
) -> TensorOrTensors:
group = get_pipeline_parallel_group()
set_device_based_on_group(group)
......@@ -243,7 +245,7 @@ class PipeRPCWrapper(nn.Module):
set_device_based_on_group(group)
kwargs["group"] = group
kwargs["input_device"] = torch.device("cuda", torch.cuda.current_device())
model = Pipe(*args, **kwargs)
model = MultiProcessPipe(*args, **kwargs)
model.cuda()
global PipeModel
PipeModel = model
......
......@@ -24,7 +24,6 @@ Tensors = Tuple[Tensor, ...]
TensorOrTensors = Union[Tensor, Tensors]
InputDevice = Union[None, int, str, torch.device]
Schedule = List[Tuple[int, int]]
class LazyModule:
......
......@@ -5,7 +5,7 @@ from typing import Any, Callable, Optional, Tuple
from torch import Tensor
def spawn(
fn: Callable[[Any], Any],
fn: Callable[..., Any],
args: Tuple[Optional[Any], ...] = (),
nprocs: int = 1,
join: bool = True,
......
......@@ -31,7 +31,7 @@ from torch.nn.parameter import Parameter
from fairscale.nn.model_parallel import initialize as mpu
from fairscale.nn.model_parallel import layers
from fairscale.nn.pipe import Pipe
from fairscale.nn.pipe import MultiProcessPipe
from fairscale.utils.testing import dist_init, get_world_sizes, set_random_seed, spawn_for_all_world_sizes, torch_spawn
......@@ -319,7 +319,7 @@ def run_test_pipe(rank, world_size, filename, filename_rpc, skip_dist_init=False
model_parallel_size = mpu.get_model_parallel_world_size()
if torch.distributed.get_rank() == 0:
print(
"> testing Sequential + Pipe with model parallel size: {}, pipe: {}".format(
"> testing Sequential + MultiProcessPipe with model parallel size: {}, pipe: {}".format(
model_parallel_size, pipe_world_size
)
)
......@@ -431,13 +431,13 @@ def run_test_pipe(rank, world_size, filename, filename_rpc, skip_dist_init=False
model[2].weight.data = saved_weight_2
worker_map = {i: f"Test{i}" for i in range(torch.distributed.get_world_size())}
style = Pipe.MultiProcess # Pipe.AsyncSchedule
style = MultiProcessPipe.MultiProcess # MultiProcessPipe.AsyncSchedule
if pipe_world_size == 2:
print(f"actually doing pipe stuff now")
assert torch.equal(saved_weight_0, model[0].weight.data)
assert torch.equal(saved_weight_2, model[2].weight.data)
pipe_model = Pipe(
pipe_model = MultiProcessPipe(
model,
[2, 1],
style=style,
......@@ -507,7 +507,7 @@ def run_test_pipe(rank, world_size, filename, filename_rpc, skip_dist_init=False
failed = False
with torch.autograd.profiler.profile() as prof:
try:
if style == Pipe.MultiProcess:
if style == MultiProcessPipe.MultiProcess:
pipe_model.back_helper(pipe_output)
except Exception as e:
failed = True
......
......@@ -23,7 +23,7 @@ import pytest
import torch
from torch import nn
from fairscale.nn.pipe import LazyModule, Pipe
from fairscale.nn.pipe import LazyModule, MultiProcessPipe
from fairscale.nn.pipe.skip import pop, skippable, stash
from fairscale.nn.pipe.skip.portal import PortalBlue, PortalCopy, PortalOrange
from fairscale.utils.testing import get_worker_map, torch_spawn
......@@ -33,12 +33,12 @@ from fairscale.utils.testing import get_worker_map, torch_spawn
@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
@pytest.mark.parametrize("balance", [[3], [1, 2], [2, 1], [1, 1, 1]], ids=["3", "1:2", "2:1", "1:1:1"])
@pytest.mark.parametrize("checkpoint", ["never", "always", "except_last"])
@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
@pytest.mark.skipif("OMPI_COMM_WORLD_RANK" in os.environ, reason="broken on mpi")
def x1to3(balance, checkpoint, pipeline_style):
torch.manual_seed(0)
if pipeline_style == Pipe.AsyncSchedule and len(balance) > 1:
if pipeline_style == MultiProcessPipe.AsyncSchedule and len(balance) > 1:
print(f"skipping yarg")
pytest.skip("Skip tensors NYI for AsyncSchedule")
......@@ -74,7 +74,7 @@ def x1to3(balance, checkpoint, pipeline_style):
return output
model = nn.Sequential(Layer1(), Layer2(), Layer3())
model = Pipe(
model = MultiProcessPipe(
model,
balance,
chunks=3,
......@@ -106,9 +106,10 @@ def x1to3(balance, checkpoint, pipeline_style):
@torch_spawn([2])
@pytest.mark.skipif("OMPI_COMM_WORLD_RANK" in os.environ, reason="broken on mpi")
@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
@pytest.mark.skip(reason="flaky test")
def none_skip(pipeline_style):
if pipeline_style == Pipe.AsyncSchedule:
if pipeline_style == MultiProcessPipe.AsyncSchedule:
pytest.skip("Skip tensors NYI for AsyncSchedule")
@skippable(stash=["none"])
......@@ -125,7 +126,7 @@ def none_skip(pipeline_style):
return input
model = nn.Sequential(Stash(), Pop())
model = Pipe(
model = MultiProcessPipe(
model,
[1, 1],
style=pipeline_style,
......@@ -160,7 +161,7 @@ def none_skip(pipeline_style):
@torch_spawn([2])
@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
def lazy_skippable_error(pipeline_style):
"""Using skippable layers in combination with lazy construction is currently
not supported, check that it raises an Exception"""
......@@ -180,6 +181,6 @@ def lazy_skippable_error(pipeline_style):
]
with pytest.raises(ValueError, match="Can't use Skippable layers with multi-process pipe and lazy construction"):
Pipe(
MultiProcessPipe(
model, [2, 1], style=pipeline_style, worker_map=get_worker_map(),
)
......@@ -23,7 +23,7 @@ import pytest
import torch
from torch import nn
from fairscale.nn.pipe import Pipe, is_checkpointing, is_recomputing
from fairscale.nn.pipe import MultiProcessPipe, is_checkpointing, is_recomputing
from fairscale.nn.pipe.skip import pop, skippable, stash
from fairscale.nn.pipe.skip.tracker import current_skip_tracker
from fairscale.utils.testing import get_worker_map, torch_spawn
......@@ -46,7 +46,7 @@ class Pop(nn.Module):
@torch_spawn([2])
@pytest.mark.parametrize("train", [True, False], ids=["train", "eval"])
@pytest.mark.parametrize("checkpoint", ["always", "except_last", "never"])
@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
@pytest.mark.skipif("OMPI_COMM_WORLD_RANK" in os.environ, reason="broken on mpi")
@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
def delete_portal_tensor(train, checkpoint, pipeline_style):
......@@ -60,7 +60,7 @@ def delete_portal_tensor(train, checkpoint, pipeline_style):
# | 3,blue,2 |--| 2,orange,1 |--| 1,orange,0 |--| 1,blue,0 |
# +----------+ +------------+ +------------+ +----------+
if pipeline_style == Pipe.AsyncSchedule:
if pipeline_style == MultiProcessPipe.AsyncSchedule:
pytest.skip("Skip tensors NYI for AsyncSchedule")
def portal_tensor_life_is(tensor_life, skip_tracker=None):
......@@ -114,7 +114,7 @@ def delete_portal_tensor(train, checkpoint, pipeline_style):
return self.F.apply(input)
model = nn.Sequential(NoPortalTensorAtBackward(), stash_, pop_)
model = Pipe(
model = MultiProcessPipe(
model, balance=[2, 1], style=pipeline_style, worker_map=get_worker_map(), chunks=2, checkpoint=checkpoint,
)
......
......@@ -22,15 +22,15 @@ import torch
from torch import nn
import torch.nn.functional as F
from fairscale.nn.pipe import Pipe
from fairscale.nn.pipe import MultiProcessPipe
from fairscale.utils.testing import get_worker_map, torch_spawn
@torch_spawn([2])
@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
def python_autograd_function(pipeline_style):
# FIXME deadlock with Pipe.AsyncSchedule?
# FIXME deadlock with MultiProcessPipe.AsyncSchedule?
# A Python autograd function might fail with this error:
#
# RuntimeError: Returning Variables sharing storage with other Variables
......@@ -57,7 +57,9 @@ def python_autograd_function(pipeline_style):
return Identity.apply(input)
model = nn.Sequential(M(), M())
model = Pipe(model, [1, 1], style=pipeline_style, worker_map=get_worker_map(), checkpoint="always").cuda()
model = MultiProcessPipe(
model, [1, 1], style=pipeline_style, worker_map=get_worker_map(), checkpoint="always"
).cuda()
model.eval()
x = torch.rand(42)
......@@ -71,7 +73,7 @@ def python_autograd_function(pipeline_style):
@torch_spawn([3])
@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
def exception_no_hang(pipeline_style):
# In v0.0.2, once a failed partition receives a normal message
# (non-closing) for the next micro-batch, a hang occurred. The reason was
......@@ -90,7 +92,7 @@ def exception_no_hang(pipeline_style):
raise ExpectedException()
model = nn.Sequential(Pass(), Pass(), Raise())
model = Pipe(model, [1, 1, 1], style=pipeline_style, worker_map=get_worker_map(), chunks=3)
model = MultiProcessPipe(model, [1, 1, 1], style=pipeline_style, worker_map=get_worker_map(), chunks=3)
model.eval()
if model.group.rank() == 2:
......@@ -104,7 +106,7 @@ def exception_no_hang(pipeline_style):
@torch_spawn([2])
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="2 cuda devices required")
@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
def tuple_wait(cuda_sleep, pipeline_style):
# In v0.0.3, Wait is applied to only the first tensor on a micro-batch.
# Under this behavior, if checkpointing was disabled, there's a possibility
......@@ -133,7 +135,7 @@ def tuple_wait(cuda_sleep, pipeline_style):
return a + b + c
model = nn.Sequential(Layer1(), Layer2())
model = Pipe(
model = MultiProcessPipe(
model,
[1, 1],
style=pipeline_style,
......@@ -158,7 +160,7 @@ def tuple_wait(cuda_sleep, pipeline_style):
@torch_spawn([2])
@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
def parallel_randoms(pipeline_style):
class Dropouts(nn.Module):
def forward(self, x):
......@@ -170,7 +172,7 @@ def parallel_randoms(pipeline_style):
x = torch.rand(10, 10, requires_grad=True).cuda()
x.retain_grad()
model = Pipe(
model = MultiProcessPipe(
model,
[1, 1],
style=pipeline_style,
......
......@@ -21,20 +21,20 @@ import pytest
import torch
from torch import nn
from fairscale.nn.pipe import Pipe
from fairscale.nn.pipe import MultiProcessPipe
from fairscale.utils.testing import get_worker_map, torch_spawn
@torch_spawn([2])
@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
def inplace_on_requires_grad(pipeline_style):
model = nn.Sequential(nn.Linear(1, 1), nn.ReLU(inplace=True))
model = Pipe(model, [1, 1], style=pipeline_style, worker_map=get_worker_map(), checkpoint="always")
model = MultiProcessPipe(model, [1, 1], style=pipeline_style, worker_map=get_worker_map(), checkpoint="always")
x = torch.rand(1)
if pipeline_style == Pipe.AsyncSchedule and model.group.rank() == 0:
if pipeline_style == MultiProcessPipe.AsyncSchedule and model.group.rank() == 0:
# With AsyncSchedule, model will wait forever for gradients if not eval
model.eval()
......@@ -50,12 +50,12 @@ def inplace_on_requires_grad(pipeline_style):
@torch_spawn([1])
@pytest.mark.xfail(strict=True)
@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
def inplace_on_not_requires_grad(pipeline_style):
# In-place operation on a tensor not requiring grad doesn't cause a
# RuntimeError. Currently, we cannot detect this case.
model = nn.Sequential(nn.ReLU(inplace=True))
model = Pipe(model, [1], style=pipeline_style, worker_map=get_worker_map(), checkpoint="always")
model = MultiProcessPipe(model, [1], style=pipeline_style, worker_map=get_worker_map(), checkpoint="always")
x = torch.rand(1)
y = model(x)
......@@ -70,7 +70,7 @@ def inplace_on_not_requires_grad(pipeline_style):
@torch_spawn([1])
@pytest.mark.xfail(strict=True)
@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
def inplace_incorrect_grad(pipeline_style):
class M(nn.Module):
def forward(self, foo_bar):
......@@ -88,7 +88,7 @@ def inplace_incorrect_grad(pipeline_style):
return foo * bar
model = nn.Sequential(M())
model = Pipe(model, [1], style=pipeline_style, worker_map=get_worker_map(), checkpoint="always")
model = MultiProcessPipe(model, [1], style=pipeline_style, worker_map=get_worker_map(), checkpoint="always")
foo = torch.tensor([1.0], requires_grad=True)
bar = torch.tensor([1.0])
......