Unverified Commit 8f77255b authored by msbaines, committed by GitHub

[refactor] multiprocess_pipe: avoid unnecessary use of create_task and other cleanup (#456)

parent d2924670
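The core of the change is that the pipeline run loop no longer builds a Task through create_task; it checkpoints or executes each micro-batch directly. A minimal, self-contained sketch of that control flow, where partition, batches, and checkpoint_stop stand in for the corresponding MultiProcessPipeline attributes and torch.utils.checkpoint replaces the internal Checkpointing/recompute pair used in the real code:

# Simplified sketch only: stand-ins for the real pipeline attributes and helpers.
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

partition = nn.Sequential(nn.Linear(8, 8), nn.ReLU())
batches = [torch.randn(4, 8, requires_grad=True) for _ in range(3)]
checkpoint_stop = 2  # checkpoint only the first two micro-batches

for i, batch in enumerate(batches):
    if i < checkpoint_stop:
        # Activations are recomputed during backward instead of being stored.
        batches[i] = checkpoint(partition, batch)
    else:
        batches[i] = partition(batch)

sum(b.sum() for b in batches).backward()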
@@ -11,15 +11,64 @@ from typing import Dict, Iterable, List, Optional, Tuple
import torch
from torch import Tensor, nn
from torch.autograd.profiler import record_function
from torch.distributed import ProcessGroup
from fairscale.nn.model_parallel import get_pipeline_parallel_ranks
from .checkpoint import Checkpointing
from .messages import Transport
from .microbatch import Batch
from .multiprocess_pipeline import create_task
from .skip.tracker import SkipTrackerThroughPotals
from .types import EVENT_LOOP_QUEUE, PipeMessage, Tensors
from .skip.tracker import SkipTrackerThroughPotals, use_skip_tracker
from .types import EVENT_LOOP_QUEUE, PipeMessage, TensorOrTensors, Tensors
from .worker import Task
def create_task(
    checkpoint_stop: int,
    chunk_id: int,
    part_id: int,
    batch: Batch,
    partition: nn.Sequential,
    skip_trackers: List[SkipTrackerThroughPotals],
) -> Task:
    # Determine whether checkpointing or not.
    if chunk_id < checkpoint_stop:

        def function(
            input: TensorOrTensors,
            partition: nn.Sequential = partition,
            skip_tracker: SkipTrackerThroughPotals = skip_trackers[chunk_id],
            chunk_id: int = chunk_id,
            part_id: int = part_id,
        ) -> TensorOrTensors:
            with use_skip_tracker(skip_tracker), record_function("chunk%d-part%d" % (chunk_id, part_id)):
                ret = partition(input)
                # We do a check here because the backtrace from the checkpoint backward code path
                # is very hard to make sense of. It would be much easier to check earlier at this point.
                assert type(ret) is not list, "Only Tensor or Tuple of Tensor output is supported"
                return ret

        chk = Checkpointing(function, batch)
        task = Task(None, compute=chk.checkpoint, finalize=chk.recompute)
        del function, chk  # TODO(tom) maybe remove
    else:

        def compute(
            batch: Batch = batch,
            partition: nn.Sequential = partition,
            skip_tracker: SkipTrackerThroughPotals = skip_trackers[chunk_id],
            chunk_id: int = chunk_id,
            part_id: int = part_id,
        ) -> Batch:
            with use_skip_tracker(skip_tracker), record_function("chunk%d-part%d" % (chunk_id, part_id)):
                return batch.call(partition)

        task = Task(None, compute=compute, finalize=None)
        del compute  # TODO(tom) maybe remove

    return task
@dataclass(frozen=True)
......
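The closures inside create_task rely on keyword-argument defaults to bind per-chunk values at definition time. A standalone illustration of that binding pattern (not fairscale code):

# Without the keyword default, every closure would observe the final loop value of
# `chunk_id`; binding it as a default freezes the value for each iteration.
from typing import Callable, List

tasks: List[Callable[[], str]] = []
for chunk_id in range(3):

    def compute(chunk_id: int = chunk_id) -> str:
        return "chunk%d" % chunk_id

    tasks.append(compute)

print([t() for t in tasks])  # ['chunk0', 'chunk1', 'chunk2']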
@@ -256,7 +256,7 @@ class MultiProcessPipe(Module):
        """Iterates over children of the underlying sequential module."""
        return self.partition.__iter__()

    def forward(self, input: TensorOrTensors, *, event=None) -> TensorOrTensors:  # type: ignore
    def forward(self, input: TensorOrTensors) -> TensorOrTensors:  # type: ignore
        """:class:`MultiProcessPipe` is a fairly transparent module wrapper. It doesn't
        modify the input and output signature of the underlying module. But
        there's a type restriction. Input and output have to be a
@@ -284,7 +284,7 @@ class MultiProcessPipe(Module):
        # Run pipeline parallelism.
        with self.lock:
            self.pipeline.run(self.training, batches, event)
            self.pipeline.run(self.training, batches)

            if self.final_stage:
                # Merge the micro-batches into one mini-batch.
......
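For readers unfamiliar with the scatter/gather flow around pipeline.run, here is a rough sketch using plain torch.chunk/torch.cat rather than fairscale's microbatch helpers; module and chunks are illustrative stand-ins:

# Sketch: split a mini-batch into micro-batches, run them, and merge the results,
# as the final pipeline stage does after pipeline.run returns.
import torch
import torch.nn as nn

module = nn.Linear(8, 8)
chunks = 4

def forward(input: torch.Tensor) -> torch.Tensor:
    batches = list(torch.chunk(input, chunks, dim=0))  # scatter into micro-batches
    outputs = [module(b) for b in batches]             # stand-in for pipeline.run
    return torch.cat(outputs, dim=0)                   # gather into one mini-batch

print(forward(torch.randn(16, 8)).shape)  # torch.Size([16, 8])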
@@ -20,7 +20,6 @@
import os
from queue import Empty as QueueEmpty
from queue import Queue
from threading import Event
from types import TracebackType
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Type, Union
@@ -39,6 +38,16 @@ from .skip.tracker import SkipTrackerThroughPotals, use_skip_tracker
from .types import ACTIVATIONS_GRADS_QUEUE, PORTAL_QUEUE, SKIP_TENSOR_QUEUE, PipeMessage, TensorOrTensors, Tensors
from .worker import Task
# Queue is generic only in stubs.
# https://mypy.readthedocs.io/en/latest/common_issues.html#using-classes-that-are-generic-in-stubs-but-not-at-runtime
if TYPE_CHECKING:
    InQueue = Queue[Optional[Task]]
    OutQueue = Queue[Tuple[bool, Union[Tuple[Task, Batch], ExcInfo, None]]]
else:
    InQueue = Queue
    OutQueue = Queue
__all__: List[str] = []
ExcInfo = Tuple[Type[BaseException], BaseException, TracebackType]
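The InQueue/OutQueue aliases above use the standard workaround for classes that are generic only in typeshed stubs. A self-contained version of the same pattern:

# queue.Queue is subscriptable for the type checker but not at runtime on older
# Pythons, so the alias is parameterised only under TYPE_CHECKING.
from queue import Queue
from typing import TYPE_CHECKING, Optional

if TYPE_CHECKING:
    IntQueue = Queue[Optional[int]]  # what mypy sees
else:
    IntQueue = Queue                 # what actually runs

q: "IntQueue" = Queue()
q.put(1)
q.put(None)
print(q.get(), q.get())  # 1 None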
@@ -49,8 +58,10 @@ class SendOperator(torch.autograd.Function):
    @staticmethod
    # type: ignore
    def forward(ctx, src_rank, dst_rank, transport: Transport, input: List[Tensor], index: int) -> Tensors:
        assert src_rank == torch.distributed.get_rank()
    def forward(ctx, transport: Transport, input: List[Tensor], index: int) -> Tensors:
        ranks = get_pipeline_parallel_ranks()
        src_rank = torch.distributed.get_rank()
        dst_rank = ranks[ranks.index(src_rank) + 1]

        transport.send_message(
            PipeMessage(src_rank, dst_rank, queue_name=ACTIVATIONS_GRADS_QUEUE, args=index, tensors=tuple(input)),
@@ -68,8 +79,7 @@ class RecvOperator(torch.autograd.Function):
    @staticmethod
    # type: ignore
    def forward(ctx, dst_rank: int, tensor: Tensor, transport: Transport, index: int) -> Tensors:
        assert dst_rank == torch.distributed.get_rank()
    def forward(ctx, tensor: Tensor, transport: Transport, index: int) -> Tensors:
        ctx.transport = transport
        ctx.index = index
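Both operators now derive their peer rank inside forward/backward instead of taking it as an argument. The arithmetic is a positional lookup in the pipeline group's rank list; a small sketch with illustrative stand-ins for get_pipeline_parallel_ranks() and torch.distributed.get_rank():

# The group's ranks need not be contiguous, so the neighbour is found by position
# in the list, not by adding or subtracting one from the rank value itself.
from typing import List

def next_rank(ranks: List[int], this_rank: int) -> int:
    return ranks[ranks.index(this_rank) + 1]

def prev_rank(ranks: List[int], this_rank: int) -> int:
    return ranks[ranks.index(this_rank) - 1]

ranks = [0, 2, 4, 6]  # example pipeline group
assert next_rank(ranks, 2) == 4
assert prev_rank(ranks, 2) == 0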
@@ -86,74 +96,12 @@ class RecvOperator(torch.autograd.Function):
    # type: ignore
    def backward(ctx, *grad: Tensor,) -> Tuple[Optional[Tensor], ...]:
        ranks = get_pipeline_parallel_ranks()
        this_rank = torch.distributed.get_rank()
        src_rank = torch.distributed.get_rank()
        dst_rank = ranks[ranks.index(src_rank) - 1]

        ctx.transport.send_message(
            PipeMessage(
                this_rank,
                ranks[ranks.index(this_rank) - 1],
                queue_name=ACTIVATIONS_GRADS_QUEUE,
                args=ctx.index,
                tensors=tuple(grad),
            ),
            PipeMessage(src_rank, dst_rank, queue_name=ACTIVATIONS_GRADS_QUEUE, args=ctx.index, tensors=tuple(grad),),
        )
        return (None, None, None, None, None)
# Queue is generic only in stubs.
# https://mypy.readthedocs.io/en/latest/common_issues.html#using-classes-that-are-generic-in-stubs-but-not-at-runtime
if TYPE_CHECKING:
    InQueue = Queue[Optional["Task"]]
    OutQueue = Queue[Tuple[bool, Union[Tuple["Task", Batch], ExcInfo, None]]]
else:
    InQueue = Queue
    OutQueue = Queue


def create_task(
    checkpoint_stop: int,
    i: int,
    j: int,
    batch: Batch,
    partition: nn.Sequential,
    skip_trackers: List[SkipTrackerThroughPotals],
) -> Task:
    # Determine whether checkpointing or not.
    if i < checkpoint_stop:

        def function(
            input: TensorOrTensors,
            partition: nn.Sequential = partition,
            skip_tracker: SkipTrackerThroughPotals = skip_trackers[i],
            chunk_id: int = i,
            part_id: int = j,
        ) -> TensorOrTensors:
            with use_skip_tracker(skip_tracker), record_function("chunk%d-part%d" % (chunk_id, part_id)):
                ret = partition(input)
                # We do a check here because the backtrace from the checkpoint backward code path
                # is very hard to make sense of. It would be much easier to check earlier at this point.
                assert type(ret) is not list, "Only Tensor or Tuple of Tensor output is supported"
                return ret

        chk = Checkpointing(function, batch)
        task = Task(None, compute=chk.checkpoint, finalize=chk.recompute)
        del function, chk  # TODO(tom) maybe remove
    else:

        def compute(
            batch: Batch = batch,
            partition: nn.Sequential = partition,
            skip_tracker: SkipTrackerThroughPotals = skip_trackers[i],
            chunk_id: int = i,
            part_id: int = j,
        ) -> Batch:
            with use_skip_tracker(skip_tracker), record_function("chunk%d-part%d" % (chunk_id, part_id)):
                return batch.call(partition)

        task = Task(None, compute=compute, finalize=None)
        del compute  # TODO(tom) maybe remove

    return task
        return (None, None, None, None)
class MultiProcessPipeline:
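RecvOperator relies on an empty "phony" tensor with requires_grad=True so that autograd invokes its backward even though the received payload carries no local history; backward then ships the incoming gradient to the previous stage and returns None for every forward argument. A toy, single-process sketch of that pattern (FakeRecv is hypothetical, not part of fairscale):

# Toy sketch: forward "receives" a payload; backward is where the gradient would be
# sent over the transport. One None is returned per forward argument.
import torch

class FakeRecv(torch.autograd.Function):
    @staticmethod
    def forward(ctx, phony: torch.Tensor, payload: torch.Tensor, index: int) -> torch.Tensor:
        ctx.index = index
        return payload.clone()

    @staticmethod
    def backward(ctx, grad: torch.Tensor):
        # A real implementation would send `grad` back to the previous stage here.
        print("would send grad for chunk", ctx.index, "with shape", tuple(grad.shape))
        return (None, None, None)

phony = torch.empty(0, requires_grad=True)
out = FakeRecv.apply(phony, torch.randn(2, 2), 0)
out.sum().backward()  # backward fires via the phony tensor's grad requirement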
@@ -191,7 +139,7 @@ class MultiProcessPipeline:
            return 0
        return self.__checkpoint_stop

    def run(self, training: bool, batches: List[Batch], event: Optional[Event]) -> None:
    def run(self, training: bool, batches: List[Batch]) -> None:
        """Runs pipeline parallelism.
@@ -204,24 +152,39 @@
        skip_trackers = [SkipTrackerThroughPotals(self.skip_layout, i) for i in range(m)]

        schedule = [(i, self.group.rank()) for i in range(m)]
        rank = self.group.rank()

        for i, j in schedule:
            if self.group.rank() != 0:
        for i in range(m):
            if rank != 0:
                batch = self.get_batch_from_previous_stage(i, skip_trackers, batches)
            else:
                batch = batches[i]

            task = create_task(self.checkpoint_stop, i, j, batch, self.partition, skip_trackers)
            with use_skip_tracker(skip_trackers[i]), record_function("chunk%d-part%d" % (i, rank)):
                if i < self.checkpoint_stop:
                    chk = Checkpointing(self.partition, batch)
                    batch = chk.checkpoint()
                else:
                    batch = batch.call(self.partition)

            if not self.final_stage:
                self.send_skip_tensors(batch, i, skip_trackers)
                SendOperator.apply(self.transport, [*batch], i)

                for portal in skip_trackers[i].portals.values():
                    portal.pipeline = self

            batches[i] = self.execute_task(task, i, skip_trackers)
            if i < self.checkpoint_stop:
                chk.recompute(batch)

            batches[i] = batch

    def get_batch_from_previous_stage(
        self, i: int, skip_trackers: List[SkipTrackerThroughPotals], batches: List[Batch]
    ) -> Batch:
        phony = torch.empty(0, device=self.input_device, requires_grad=True)
        result = RecvOperator.apply(torch.distributed.get_rank(), phony, self.transport, i)
        result = RecvOperator.apply(phony, self.transport, i)
        if len(result) == 1:
            batch = Batch(result[0], i)
        else:
@@ -231,9 +194,10 @@
        return batch

    def send_skip_tensors(
        self, this_rank: int, ranks: List[int], batch: Batch, i: int, skip_trackers: List[SkipTrackerThroughPotals]
    ) -> None:
    def send_skip_tensors(self, batch: Batch, i: int, skip_trackers: List[SkipTrackerThroughPotals]) -> None:
        ranks = get_pipeline_parallel_ranks()
        this_rank = torch.distributed.get_rank()
        for next_j, ns, name in self.skip_layout.copy_policy_by_src(self.group.rank()):
            life = skip_trackers[i].portals[(ns, name)].tensor_life
            loaded = skip_trackers[i].load(batch, ns, name)
@@ -271,25 +235,6 @@
            except QueueEmpty:
                break

    def execute_task(self, task: Task, i: int, skip_trackers: List[SkipTrackerThroughPotals]) -> Batch:
        batch = task.compute()

        rank = self.group.rank()

        if not self.final_stage:
            ranks = get_pipeline_parallel_ranks()
            this_rank = torch.distributed.get_rank()

            self.send_skip_tensors(this_rank, ranks, batch, i, skip_trackers)
            SendOperator.apply(this_rank, ranks[ranks.index(this_rank) + 1], self.transport, [*batch], i)

            for portal in skip_trackers[i].portals.values():
                portal.pipeline = self

        task.finalize(batch)

        return batch

    def send_portal_grad(self, ns_name: Tuple[Namespace, str], index: int, grad: TensorOrTensors) -> None:
        dest, src = self.skip_layout.by_ns_name.get(ns_name, (-1, -1))
        if dest == src:
......