Unverified Commit b5c3638f authored by msbaines, Committed by GitHub

[refactor] pipe: move async-specific code out of MultiProcessPipeline (#345)

parent a8dd9254
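
In short, the PipelineStyle switch that MultiProcessPipeline.run() used to perform at runtime is replaced by inheritance: MultiProcessPipeline keeps only the synchronous multiprocess schedule, and a new AsyncPipeline subclass overrides run() with the async event-loop path. A minimal sketch of the resulting shape (the method bodies here are illustrative placeholders, not the fairscale implementation):

# Sketch of the class structure after this refactor; bodies are placeholders.
from typing import Any, List


class MultiProcessPipeline:
    """Always runs the synchronous multiprocess schedule."""

    def run(self, training: bool, batches: List[Any], event: Any = None) -> None:
        # Previously this method branched on a `style` argument
        # (PipelineStyle.MultiProcess vs PipelineStyle.AsyncSchedule).
        self._sync_schedule(batches)

    def _sync_schedule(self, batches: List[Any]) -> None:  # placeholder
        ...


class AsyncPipeline(MultiProcessPipeline):
    """Async variant: the schedule choice is now encoded by the subclass."""

    def run(self, training: bool, batches: List[Any], event: Any = None) -> None:
        self._async_event_loop(batches, event)

    def _async_event_loop(self, batches: List[Any], event: Any) -> None:  # placeholder
        ...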
@@ -11,11 +11,11 @@ from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Tuple, Union
 import torch
 from torch import Tensor, nn
 
+from .async_pipeline import AsyncPipeline
 from .async_schedule import Invocation, Location, ModuleWrapper
 from .multiprocess_pipe import MultiProcessPipe, check_balance
-from .multiprocess_pipeline import MultiProcessPipeline
 from .skip.skippable import Skippable
-from .types import LazyModule, PipelineStyle
+from .types import LazyModule
 
 if TYPE_CHECKING:
     Module = nn.Module[TensorOrTensors]
@@ -43,11 +43,10 @@ class AsyncPipe(MultiProcessPipe):
         # The micro-batch index where the checkpointing stops.
         checkpoint_stop = {"always": self.chunks, "except_last": self.chunks - 1, "never": 0}[self.checkpoint]
-        self.pipeline = MultiProcessPipeline(
+        self.pipeline = AsyncPipeline(
             self.partitions,
             self._skip_layout,
             checkpoint_stop,
-            style=PipelineStyle.AsyncSchedule,
             group=self.group,
             worker_map=self.worker_map,
             input_device=self.input_device,
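As a side note on the checkpoint_stop value computed above: it is the micro-batch index at which checkpointing stops, and by the convention in this code base (an assumption stated here, not part of the diff) micro-batches with index below checkpoint_stop run under activation checkpointing. A small worked example with chunks = 4:

# Worked example of the checkpoint_stop mapping shown above, assuming chunks = 4
# and that micro-batches with index < checkpoint_stop are checkpointed.
chunks = 4
for checkpoint in ("always", "except_last", "never"):
    checkpoint_stop = {"always": chunks, "except_last": chunks - 1, "never": 0}[checkpoint]
    checkpointed = [i for i in range(chunks) if i < checkpoint_stop]
    print(f"{checkpoint}: checkpoint_stop={checkpoint_stop}, checkpointed={checkpointed}")
# always: [0, 1, 2, 3]; except_last: [0, 1, 2]; never: []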
+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+#
+# This source code is licensed under the BSD license found in the
+# LICENSE file in the root directory of this source tree.
+
+import logging
+from threading import Event
+from typing import List, Optional
+
+import torch
+
+from .async_schedule import AsyncEventLoop
+from .microbatch import Batch
+from .multiprocess_pipeline import MultiProcessPipeline
+from .skip.tracker import SkipTrackerThroughPotals
+
+
+class AsyncPipeline(MultiProcessPipeline):
+    def run(self, training: bool, batches: List[Batch], event: Optional[Event]) -> None:
+        """Runs pipeline parallelism.
+
+        It modifies the given batches in place.
+        """
+        self.training = training
+
+        skip_trackers = [SkipTrackerThroughPotals(self.skip_layout, i) for i in range(len(batches))]
+
+        rank = self.group.rank()
+        event_loop = AsyncEventLoop(
+            self.partitions, self.group, self.transport, self.training, self.checkpoint_stop,
+        )
+        if rank == 0 and not self.final_stage:
+            logging.debug(f"{torch.distributed.get_rank()}: entered event head")
+            event_loop.event_loop_head(batches, skip_trackers, event)
+            logging.debug(f"{torch.distributed.get_rank()}: exited event head")
+        elif self.final_stage:
+            logging.debug(f"{torch.distributed.get_rank()}: entered event tail")
+            event_loop.event_loop_tail(batches, skip_trackers)
+            logging.debug(f"{torch.distributed.get_rank()}: exited event tail")
+        else:
+            logging.debug(f"{torch.distributed.get_rank()}: entered event loop")
+            event_loop.event_loop(len(batches), skip_trackers)
+            logging.debug(f"{torch.distributed.get_rank()}: exited event loop")
+
+    def back_helper(self, output: List[Batch]) -> None:
+        pass
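
The branching in run() above amounts to a rank-to-role mapping: rank 0 (when it is not also the final stage) drives the head event loop, the final stage drives the tail loop, and every rank in between runs the plain event loop. A standalone illustration (event_loop_role is a hypothetical helper written for this note, not part of the commit):

# Hypothetical helper mirroring the dispatch in AsyncPipeline.run() above.
def event_loop_role(rank: int, final_stage: bool) -> str:
    if rank == 0 and not final_stage:
        return "head"    # event_loop_head(batches, skip_trackers, event)
    elif final_stage:
        return "tail"    # event_loop_tail(batches, skip_trackers)
    else:
        return "middle"  # event_loop(len(batches), skip_trackers)


# For a 4-stage pipeline with one stage per rank:
for rank in range(4):
    print(rank, event_loop_role(rank, final_stage=(rank == 3)))
# 0 head, 1 middle, 2 middle, 3 tail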
@@ -37,7 +37,7 @@ from .multiprocess_pipeline import MultiProcessPipeline
 from .phony import get_phony
 from .skip.layout import SkipLayout, inspect_skip_layout
 from .skip.skippable import Skippable, verify_skippables
-from .types import LazyModule, PipelineStyle
+from .types import LazyModule
 
 __all__ = ["MultiProcessPipe", "LazyModule"]
@@ -202,11 +202,8 @@ class MultiProcessPipe(Module):
            list of number of layers in each partition
 
    Keyword Args:
-        style (PipelineStyle):
-            whether to use a single process for all pipeline stages or to assign
-            one stage per process
        group (ProcessGroup):
-            specific to `style=MultiProcess`, the process group that all
+            the process group that all
            pipeline stages are a member of. Defaults to
            `get_pipeline_parallel_group()`
        worker_map (Dict[int, str]):
@@ -374,7 +371,6 @@ class MultiProcessPipe(Module):
             self.partitions,
             self._skip_layout,
             checkpoint_stop,
-            style=PipelineStyle.MultiProcess,
             group=self.group,
             worker_map=self.worker_map,
             input_device=self.input_device,
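From a user's perspective nothing changes here: MultiProcessPipe simply always builds the synchronous pipeline now, and the group keyword keeps its documented default of get_pipeline_parallel_group(). A rough usage sketch follows; the import path and the balance/chunks/checkpoint keywords are assumptions taken from the surrounding docstring and code rather than a verified signature, and an initialized pipeline-parallel process group is required to actually run it:

# Rough usage sketch only; argument names are inferred from the docstring and
# hunks above and may not match the exact MultiProcessPipe signature.
# Assumes torch.distributed and fairscale's pipeline-parallel group are
# already initialized on every participating process.
import torch.nn as nn

from fairscale.nn.pipe import MultiProcessPipe  # assumed import path

model = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 8))

# balance: "list of number of layers in each partition" (see docstring above);
# group is omitted, so it defaults to get_pipeline_parallel_group().
pipe = MultiProcessPipe(model, balance=[2, 1], chunks=4, checkpoint="except_last")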
@@ -17,7 +17,6 @@
 # limitations under the License.
 """The multiprocess pipeline parallelism of Pipe."""
 
 import logging
-import os
 from queue import Empty as QueueEmpty
 from queue import Queue
@@ -31,22 +30,14 @@ from torch.autograd.profiler import record_function
 
 from fairscale.nn.model_parallel import get_pipeline_parallel_ranks
 
-from .async_schedule import AsyncEventLoop, ModuleWrapper
+from .async_schedule import ModuleWrapper
 from .checkpoint import Checkpointing
 from .messages import MakeTransport, Transport
 from .microbatch import Batch
 from .skip import Namespace
 from .skip.layout import SkipLayout
 from .skip.tracker import SkipTrackerThroughPotals, use_skip_tracker
-from .types import (
-    ACTIVATIONS_GRADS_QUEUE,
-    PORTAL_QUEUE,
-    SKIP_TENSOR_QUEUE,
-    PipelineStyle,
-    PipeMessage,
-    TensorOrTensors,
-    Tensors,
-)
+from .types import ACTIVATIONS_GRADS_QUEUE, PORTAL_QUEUE, SKIP_TENSOR_QUEUE, PipeMessage, TensorOrTensors, Tensors
 from .worker import Task
 
 __all__: List[str] = []
@@ -174,8 +165,8 @@ class MultiProcessPipeline:
         partitions: List[ModuleWrapper],
         skip_layout: SkipLayout,
         checkpoint_stop: int,
-        style: PipelineStyle,
         group: torch.distributed.ProcessGroup,
+        *,
         worker_map: Optional[Dict[int, str]] = None,
         input_device: Union[None, int, str, torch.device] = None,
         final_stage: bool = False,
@@ -183,7 +174,6 @@
         self.partitions = partitions
         self.skip_layout = skip_layout
         self.__checkpoint_stop = checkpoint_stop
-        self.style = style
         self.group = group
         self.training: bool
         self.transport = MakeTransport(
@@ -192,7 +182,6 @@
             input_device=input_device,
         )
         self.input_device = input_device
-        self.callcount = 0
         self.final_stage = final_stage
 
     @property
@@ -214,30 +203,22 @@
         m = len(batches)
-        skip_trackers = [SkipTrackerThroughPotals(self.skip_layout, i) for i in range(len(batches))]
+        skip_trackers = [SkipTrackerThroughPotals(self.skip_layout, i) for i in range(m)]
 
-        if self.style is PipelineStyle.MultiProcess:
-            schedule = [(i, self.group.rank()) for i in range(m)]
-            self.compute(batches, schedule, skip_trackers)
-        elif self.style is PipelineStyle.AsyncSchedule:
-            rank = self.group.rank()
-            event_loop = AsyncEventLoop(
-                self.partitions, self.group, self.transport, self.training, self.checkpoint_stop,
-            )
-            if rank == 0 and not self.final_stage:
-                logging.debug(f"{torch.distributed.get_rank()}: entered event head")
-                event_loop.event_loop_head(batches, skip_trackers, event)
-                logging.debug(f"{torch.distributed.get_rank()}: exited event head")
-            elif self.final_stage:
-                logging.debug(f"{torch.distributed.get_rank()}: entered event tail")
-                event_loop.event_loop_tail(batches, skip_trackers)
-                logging.debug(f"{torch.distributed.get_rank()}: exited event tail")
-            else:
-                logging.debug(f"{torch.distributed.get_rank()}: entered event loop")
-                event_loop.event_loop(len(batches), skip_trackers)
-                logging.debug(f"{torch.distributed.get_rank()}: exited event loop")
-
-        self.callcount += 1
+        schedule = [(i, self.group.rank()) for i in range(m)]
+
+        for i, j in schedule:
+            assert len(self.partitions) == 1
+            partition = self.partitions[0]
+
+            if self.group.rank() != 0:
+                batch = self.get_batch_from_previous_stage(i, skip_trackers, batches)
+            else:
+                batch = batches[i]
+
+            task = create_task(self.checkpoint_stop, i, j, batch, partition.module, skip_trackers)
+
+            batches[i] = self.execute_task(task, i, skip_trackers)
 
     def get_batch_from_previous_stage(
         self, i: int, skip_trackers: List[SkipTrackerThroughPotals], batches: List[Batch]
@@ -299,7 +280,7 @@
         rank = self.group.rank()
-        if self.style is PipelineStyle.MultiProcess and not self.final_stage:
+        if not self.final_stage:
             ranks = get_pipeline_parallel_ranks()
             this_rank = torch.distributed.get_rank()
@@ -313,51 +294,6 @@
         return batch
 
-    def compute(
-        self, batches: List[Batch], schedule: List[Tuple[int, int]], skip_trackers: List[SkipTrackerThroughPotals]
-    ) -> None:
-        """Runs tasks with synchronization to copy streams."""
-        assert self.style is PipelineStyle.MultiProcess
-
-        # With checkpointing, the autograd graph looks like this diagram:
-        # ┌─────┸──────┐
-        # │    Copy    │
-        # └─────┰──────┘   (fence)
-        # ─ ─ ─ ╂ ─ ─ ─ ─ ─ ─ ─ ─ ─
-        #       ┃          (compute)
-        # ┌─────┸──────┐
-        # │    Wait    │ [1] Synchronize the current stream with the copy stream.
-        # └─────┰──────┘
-        # ┌─────┸──────┐
-        # │ Checkpoint │ [2] Compute a partition within checkpointing.
-        # └─────┰──────┘
-        # ┌─────┸──────┐
-        # │    Wait    │ [3] Synchronize the copy stream with the current stream.
-        # └─────┰──────┘
-        #       ┠ ─ ─ ─ ┐
-        #       ┃ ┌─────┴─────┐
-        #       ┃ │ Recompute │ [4] Schedule the recomputation at backpropagation.
-        #       ┃ └─────┬─────┘
-        #       ┠ ─ ─ ─ ┘
-        #       ┃
-        # ─ ─ ─ ╂ ─ ─ ─ ─ ─ ─ ─ ─ ─
-        # ┌─────┸──────┐   (fence)
-        # │    Copy    │
-        # └─────┰──────┘
-        for i, j in schedule:
-            assert len(self.partitions) == 1
-            partition = self.partitions[0]
-
-            if self.group.rank() != 0:
-                batch = self.get_batch_from_previous_stage(i, skip_trackers, batches)
-            else:
-                batch = batches[i]
-
-            task = create_task(self.checkpoint_stop, i, j, batch, partition.module, skip_trackers)
-
-            batches[i] = self.execute_task(task, i, skip_trackers)
-
     def send_portal_grad(self, ns_name: Tuple[Namespace, str], index: int, grad: TensorOrTensors) -> None:
         dest, src = self.skip_layout.by_ns_name.get(ns_name, (-1, -1))
         if dest == src:
@@ -385,9 +321,6 @@
         return result
 
     def back_helper(self, output: List[Batch]) -> None:
-        if self.style == PipelineStyle.AsyncSchedule:
-            return
-
         tensors: Tensors
 
         rank = torch.distributed.get_rank()
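With the async branch gone, MultiProcessPipeline.run() is reduced to the purely synchronous schedule: each rank walks its micro-batches in order, pulls each batch from the previous stage (except on rank 0), and executes one task per micro-batch, writing the result back in place. A self-contained sketch of that control flow, with simplified stand-ins for the get_batch_from_previous_stage / create_task / execute_task machinery:

# Self-contained sketch of the synchronous schedule that MultiProcessPipeline.run()
# now always uses; receive/compute are simplified stand-ins for the real helpers.
from typing import Callable, List


def run_sync_schedule(
    rank: int,
    batches: List,
    receive_from_prev_stage: Callable[[int], object],
    compute: Callable[[int, object], object],
) -> None:
    # One (micro-batch index, rank) entry per micro-batch, executed in order.
    schedule = [(i, rank) for i in range(len(batches))]
    for i, _j in schedule:
        # Rank 0 feeds its own input; later stages wait on the previous stage.
        batch = batches[i] if rank == 0 else receive_from_prev_stage(i)
        # Results are written back in place, mirroring run()'s batches[i] update.
        batches[i] = compute(i, batch)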
@@ -4,7 +4,6 @@
 # LICENSE file in the root directory of this source tree.
 
 from dataclasses import dataclass
-from enum import Enum, auto
 from typing import Any, Callable, List, Optional, Tuple, Union
 
 import torch
@@ -34,11 +33,6 @@ class LazyModule:
         return self.function()
 
 
-class PipelineStyle(Enum):
-    MultiProcess = auto()
-    AsyncSchedule = auto()
-
-
 @dataclass(init=False)
 class PipeMessage:
     src: int