Unverified commit b5c3638f authored by msbaines, committed by GitHub

[refactor] pipe: move async-specific code out of MultiProcessPipeline (#345)

parent a8dd9254
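At a high level, the change replaces a runtime `PipelineStyle` check inside `MultiProcessPipeline` with ordinary subclassing: the async event-loop path moves into a new `AsyncPipeline` class that overrides `run()`, and `AsyncPipe` constructs `AsyncPipeline` directly. A minimal, self-contained sketch of that before/after shape (simplified stand-in classes, not the actual fairscale code):

# Before the refactor: one pipeline class branching on a style flag at runtime.
class PipelineBefore:
    def __init__(self, style: str) -> None:
        self.style = style

    def run(self, batches: list) -> None:
        if self.style == "multiprocess":
            print("synchronous per-microbatch schedule:", len(batches))
        else:  # "async_schedule"
            print("async head/body/tail event loop:", len(batches))


# After the refactor: the base class only knows the synchronous schedule;
# the async behaviour lives in a subclass that overrides run().
class PipelineAfter:
    def run(self, batches: list) -> None:
        print("synchronous per-microbatch schedule:", len(batches))


class AsyncPipelineSketch(PipelineAfter):
    def run(self, batches: list) -> None:
        print("async head/body/tail event loop:", len(batches))

The diff below follows this pattern: the `style` argument, the `PipelineStyle` enum, and the style branches all disappear.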
async_pipe.py

@@ -11,11 +11,11 @@ from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Tuple, Union
 import torch
 from torch import Tensor, nn
 
+from .async_pipeline import AsyncPipeline
 from .async_schedule import Invocation, Location, ModuleWrapper
 from .multiprocess_pipe import MultiProcessPipe, check_balance
-from .multiprocess_pipeline import MultiProcessPipeline
 from .skip.skippable import Skippable
-from .types import LazyModule, PipelineStyle
+from .types import LazyModule
 
 if TYPE_CHECKING:
     Module = nn.Module[TensorOrTensors]
@@ -43,11 +43,10 @@ class AsyncPipe(MultiProcessPipe):
         # The micro-batch index where the checkpointing stops.
         checkpoint_stop = {"always": self.chunks, "except_last": self.chunks - 1, "never": 0}[self.checkpoint]
 
-        self.pipeline = MultiProcessPipeline(
+        self.pipeline = AsyncPipeline(
             self.partitions,
             self._skip_layout,
             checkpoint_stop,
-            style=PipelineStyle.AsyncSchedule,
             group=self.group,
             worker_map=self.worker_map,
             input_device=self.input_device,
...
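Per the comment in the hunk above, `checkpoint_stop` is the micro-batch index at which checkpointing stops, so micro-batches with a lower index run under activation checkpointing. A small illustration of that semantics (the helper is hypothetical, not part of the diff):

def is_checkpointed(microbatch_index: int, checkpoint_stop: int) -> bool:
    # "always"      -> checkpoint_stop == chunks      -> every micro-batch checkpointed
    # "except_last" -> checkpoint_stop == chunks - 1  -> all but the last micro-batch
    # "never"       -> checkpoint_stop == 0           -> no checkpointing
    return microbatch_index < checkpoint_stop


assert is_checkpointed(0, checkpoint_stop=4)      # "always" with chunks == 4
assert not is_checkpointed(3, checkpoint_stop=3)  # "except_last" with chunks == 4: last micro-batch skipped
assert not is_checkpointed(0, checkpoint_stop=0)  # "never"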
async_pipeline.py (new file)

# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import logging
from threading import Event
from typing import List, Optional

import torch

from .async_schedule import AsyncEventLoop
from .microbatch import Batch
from .multiprocess_pipeline import MultiProcessPipeline
from .skip.tracker import SkipTrackerThroughPotals


class AsyncPipeline(MultiProcessPipeline):
    def run(self, training: bool, batches: List[Batch], event: Optional[Event]) -> None:
        """Runs pipeline parallelism.

        It modifies the given batches in place.
        """
        self.training = training

        skip_trackers = [SkipTrackerThroughPotals(self.skip_layout, i) for i in range(len(batches))]

        rank = self.group.rank()
        event_loop = AsyncEventLoop(
            self.partitions, self.group, self.transport, self.training, self.checkpoint_stop,
        )
        if rank == 0 and not self.final_stage:
            logging.debug(f"{torch.distributed.get_rank()}: entered event head")
            event_loop.event_loop_head(batches, skip_trackers, event)
            logging.debug(f"{torch.distributed.get_rank()}: exited event head")
        elif self.final_stage:
            logging.debug(f"{torch.distributed.get_rank()}: entered event tail")
            event_loop.event_loop_tail(batches, skip_trackers)
            logging.debug(f"{torch.distributed.get_rank()}: exited event tail")
        else:
            logging.debug(f"{torch.distributed.get_rank()}: entered event loop")
            event_loop.event_loop(len(batches), skip_trackers)
            logging.debug(f"{torch.distributed.get_rank()}: exited event loop")

    def back_helper(self, output: List[Batch]) -> None:
        pass
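The `if`/`elif`/`else` in `run()` gives each pipeline stage one of three event-loop roles: the first stage (rank 0, not final) runs the head loop, the final stage runs the tail loop, and every other stage runs the plain body loop. A hypothetical helper (not in the diff) spelling out that mapping:

def event_loop_role(rank: int, final_stage: bool) -> str:
    """Return which AsyncEventLoop entry point a pipeline stage would use."""
    if rank == 0 and not final_stage:
        return "head"  # first stage: feeds micro-batches into the pipeline
    if final_stage:
        return "tail"  # last stage: receives final outputs
    return "body"      # intermediate stages: relay activations and gradients


assert event_loop_role(0, final_stage=False) == "head"
assert event_loop_role(3, final_stage=True) == "tail"
assert event_loop_role(1, final_stage=False) == "body"

The empty `back_helper` mirrors the old behaviour, where the base-class implementation returned early for the async schedule (see the removed lines in `back_helper` further down).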
multiprocess_pipe.py

@@ -37,7 +37,7 @@ from .multiprocess_pipeline import MultiProcessPipeline
 from .phony import get_phony
 from .skip.layout import SkipLayout, inspect_skip_layout
 from .skip.skippable import Skippable, verify_skippables
-from .types import LazyModule, PipelineStyle
+from .types import LazyModule
 
 __all__ = ["MultiProcessPipe", "LazyModule"]
@@ -202,11 +202,8 @@ class MultiProcessPipe(Module):
             list of number of layers in each partition
 
     Keyword Args:
-        style (PipelineStyle):
-            whether to use a single process for all pipeline stages or to assign
-            one stage per process
         group (ProcessGroup):
-            specific to `style=MultiProcess`, the process group that all
+            the process group that all
             pipeline stages are a member of. Defaults to
             `get_pipeline_parallel_group()`
         worker_map (Dict[int, str]):
@@ -374,7 +371,6 @@ class MultiProcessPipe(Module):
            self.partitions,
            self._skip_layout,
            checkpoint_stop,
-           style=PipelineStyle.MultiProcess,
            group=self.group,
            worker_map=self.worker_map,
            input_device=self.input_device,
...
multiprocess_pipeline.py

@@ -17,7 +17,6 @@
 # limitations under the License.
 """The multiprocess pipeline parallelism of Pipe."""
 
-import logging
 import os
 from queue import Empty as QueueEmpty
 from queue import Queue
@@ -31,22 +30,14 @@ from torch.autograd.profiler import record_function
 
 from fairscale.nn.model_parallel import get_pipeline_parallel_ranks
 
-from .async_schedule import AsyncEventLoop, ModuleWrapper
+from .async_schedule import ModuleWrapper
 from .checkpoint import Checkpointing
 from .messages import MakeTransport, Transport
 from .microbatch import Batch
 from .skip import Namespace
 from .skip.layout import SkipLayout
 from .skip.tracker import SkipTrackerThroughPotals, use_skip_tracker
-from .types import (
-    ACTIVATIONS_GRADS_QUEUE,
-    PORTAL_QUEUE,
-    SKIP_TENSOR_QUEUE,
-    PipelineStyle,
-    PipeMessage,
-    TensorOrTensors,
-    Tensors,
-)
+from .types import ACTIVATIONS_GRADS_QUEUE, PORTAL_QUEUE, SKIP_TENSOR_QUEUE, PipeMessage, TensorOrTensors, Tensors
 from .worker import Task
 
 __all__: List[str] = []
@@ -174,8 +165,8 @@ class MultiProcessPipeline:
         partitions: List[ModuleWrapper],
         skip_layout: SkipLayout,
         checkpoint_stop: int,
-        style: PipelineStyle,
         group: torch.distributed.ProcessGroup,
+        *,
         worker_map: Optional[Dict[int, str]] = None,
         input_device: Union[None, int, str, torch.device] = None,
         final_stage: bool = False,
@@ -183,7 +174,6 @@ class MultiProcessPipeline:
         self.partitions = partitions
         self.skip_layout = skip_layout
         self.__checkpoint_stop = checkpoint_stop
-        self.style = style
         self.group = group
         self.training: bool
         self.transport = MakeTransport(
@@ -192,7 +182,6 @@ class MultiProcessPipeline:
             input_device=input_device,
         )
         self.input_device = input_device
-        self.callcount = 0
         self.final_stage = final_stage
 
     @property
@@ -214,30 +203,22 @@ class MultiProcessPipeline:
         m = len(batches)
-        skip_trackers = [SkipTrackerThroughPotals(self.skip_layout, i) for i in range(len(batches))]
+        skip_trackers = [SkipTrackerThroughPotals(self.skip_layout, i) for i in range(m)]
 
-        if self.style is PipelineStyle.MultiProcess:
-            schedule = [(i, self.group.rank()) for i in range(m)]
-            self.compute(batches, schedule, skip_trackers)
-        elif self.style is PipelineStyle.AsyncSchedule:
-            rank = self.group.rank()
-            event_loop = AsyncEventLoop(
-                self.partitions, self.group, self.transport, self.training, self.checkpoint_stop,
-            )
-            if rank == 0 and not self.final_stage:
-                logging.debug(f"{torch.distributed.get_rank()}: entered event head")
-                event_loop.event_loop_head(batches, skip_trackers, event)
-                logging.debug(f"{torch.distributed.get_rank()}: exited event head")
-            elif self.final_stage:
-                logging.debug(f"{torch.distributed.get_rank()}: entered event tail")
-                event_loop.event_loop_tail(batches, skip_trackers)
-                logging.debug(f"{torch.distributed.get_rank()}: exited event tail")
-            else:
-                logging.debug(f"{torch.distributed.get_rank()}: entered event loop")
-                event_loop.event_loop(len(batches), skip_trackers)
-                logging.debug(f"{torch.distributed.get_rank()}: exited event loop")
-
-        self.callcount += 1
+        schedule = [(i, self.group.rank()) for i in range(m)]
+
+        for i, j in schedule:
+            assert len(self.partitions) == 1
+            partition = self.partitions[0]
+
+            if self.group.rank() != 0:
+                batch = self.get_batch_from_previous_stage(i, skip_trackers, batches)
+            else:
+                batch = batches[i]
+
+            task = create_task(self.checkpoint_stop, i, j, batch, partition.module, skip_trackers)
+
+            batches[i] = self.execute_task(task, i, skip_trackers)
 
     def get_batch_from_previous_stage(
         self, i: int, skip_trackers: List[SkipTrackerThroughPotals], batches: List[Batch]
@@ -299,7 +280,7 @@ class MultiProcessPipeline:
         rank = self.group.rank()
 
-        if self.style is PipelineStyle.MultiProcess and not self.final_stage:
+        if not self.final_stage:
             ranks = get_pipeline_parallel_ranks()
             this_rank = torch.distributed.get_rank()
@@ -313,51 +294,6 @@ class MultiProcessPipeline:
 
         return batch
 
-    def compute(
-        self, batches: List[Batch], schedule: List[Tuple[int, int]], skip_trackers: List[SkipTrackerThroughPotals]
-    ) -> None:
-        """Runs tasks with synchronization to copy streams."""
-        assert self.style is PipelineStyle.MultiProcess
-
-        # With checkpointing, the autograd graph looks like this diagram:
-        # ┌─────┸──────┐
-        # │    Copy    │
-        # └─────┰──────┘   (fence)
-        # ─ ─ ─ ╂ ─ ─ ─ ─ ─ ─ ─ ─ ─
-        #       ┃          (compute)
-        # ┌─────┸──────┐
-        # │    Wait    │ [1] Synchronize the current stream with the copy stream.
-        # └─────┰──────┘
-        # ┌─────┸──────┐
-        # │ Checkpoint │ [2] Compute a partition within checkpointing.
-        # └─────┰──────┘
-        # ┌─────┸──────┐
-        # │    Wait    │ [3] Synchronize the copy stream with the current stream.
-        # └─────┰──────┘
-        #       ┠ ─ ─ ─ ┐
-        #       ┃ ┌─────┴─────┐
-        #       ┃ │ Recompute │ [4] Schedule the recomputation at backpropagation.
-        #       ┃ └─────┬─────┘
-        #       ┠ ─ ─ ─ ┘
-        #       ┃
-        # ─ ─ ─ ╂ ─ ─ ─ ─ ─ ─ ─ ─ ─
-        # ┌─────┸──────┐   (fence)
-        # │    Copy    │
-        # └─────┰──────┘
-        for i, j in schedule:
-            assert len(self.partitions) == 1
-            partition = self.partitions[0]
-
-            if self.group.rank() != 0:
-                batch = self.get_batch_from_previous_stage(i, skip_trackers, batches)
-            else:
-                batch = batches[i]
-
-            task = create_task(self.checkpoint_stop, i, j, batch, partition.module, skip_trackers)
-
-            batches[i] = self.execute_task(task, i, skip_trackers)
-
     def send_portal_grad(self, ns_name: Tuple[Namespace, str], index: int, grad: TensorOrTensors) -> None:
         dest, src = self.skip_layout.by_ns_name.get(ns_name, (-1, -1))
         if dest == src:
@@ -385,9 +321,6 @@ class MultiProcessPipeline:
         return result
 
     def back_helper(self, output: List[Batch]) -> None:
-        if self.style == PipelineStyle.AsyncSchedule:
-            return
-
         tensors: Tensors
 
         rank = torch.distributed.get_rank()
...
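One small API note on the constructor hunk above: with `style` removed, a bare `*` now follows `group`, making `worker_map`, `input_device`, and `final_stage` keyword-only. A generic illustration of that Python rule (the function below is made up, not fairscale API):

def build(partitions, skip_layout, checkpoint_stop, group, *, worker_map=None, input_device=None, final_stage=False):
    # Everything after the bare "*" must be passed by keyword.
    return worker_map, input_device, final_stage


build([], None, 0, None, worker_map={0: "worker0"})  # OK: passed by keyword
# build([], None, 0, None, {0: "worker0"})           # TypeError: too many positional arguments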
types.py

@@ -4,7 +4,6 @@
 # LICENSE file in the root directory of this source tree.
 
 from dataclasses import dataclass
-from enum import Enum, auto
 from typing import Any, Callable, List, Optional, Tuple, Union
 
 import torch
@@ -34,11 +33,6 @@ class LazyModule:
         return self.function()
 
 
-class PipelineStyle(Enum):
-    MultiProcess = auto()
-    AsyncSchedule = auto()
-
-
 @dataclass(init=False)
 class PipeMessage:
     src: int
...