Unverified commit 8f77255b, authored by msbaines, committed by GitHub

[refactor] multiprocess_pipe: avoid unnecessary use of create_task and other cleanup (#456)

parent d2924670
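For context: the first hunk below belongs to the pipe wrapper module, which previously imported create_task from .multiprocess_pipeline and now defines it locally; the remaining hunks belong to the pipeline module, whose run() loop no longer builds a Task per micro-batch but checkpoints or calls the partition inline. The commit also drops the event argument from forward()/run() and the threading.Event import. A minimal sketch of the run()-loop change, using simplified stand-ins (the Task class and the doubling "partition" here are placeholders, not the fairscale types):

# Hypothetical, simplified illustration of the refactor; `Task` and the
# doubling "partition" are stand-ins, not the real fairscale types.
from typing import Callable, List, Optional


class Task:
    def __init__(self, compute: Callable[[], float], finalize: Optional[Callable[[float], None]] = None) -> None:
        self.compute = compute
        self.finalize = finalize


def run_with_tasks(chunks: List[float]) -> List[float]:
    # Old shape: wrap each micro-batch in a Task, then execute it elsewhere.
    outputs = []
    for x in chunks:
        task = Task(compute=lambda x=x: 2.0 * x)
        out = task.compute()
        if task.finalize is not None:
            task.finalize(out)
        outputs.append(out)
    return outputs


def run_inline(chunks: List[float]) -> List[float]:
    # New shape: do the per-chunk work directly in the loop, no Task indirection.
    return [2.0 * x for x in chunks]


assert run_with_tasks([1.0, 2.0]) == run_inline([1.0, 2.0]) == [2.0, 4.0]

Both shapes compute the same result; the inline form just removes one layer of indirection per micro-batch.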
@@ -11,15 +11,64 @@ from typing import Dict, Iterable, List, Optional, Tuple
 import torch
 from torch import Tensor, nn
+from torch.autograd.profiler import record_function
 from torch.distributed import ProcessGroup
 
 from fairscale.nn.model_parallel import get_pipeline_parallel_ranks
 
+from .checkpoint import Checkpointing
 from .messages import Transport
 from .microbatch import Batch
-from .multiprocess_pipeline import create_task
-from .skip.tracker import SkipTrackerThroughPotals
-from .types import EVENT_LOOP_QUEUE, PipeMessage, Tensors
+from .skip.tracker import SkipTrackerThroughPotals, use_skip_tracker
+from .types import EVENT_LOOP_QUEUE, PipeMessage, TensorOrTensors, Tensors
+from .worker import Task
+
+
+def create_task(
+    checkpoint_stop: int,
+    chunk_id: int,
+    part_id: int,
+    batch: Batch,
+    partition: nn.Sequential,
+    skip_trackers: List[SkipTrackerThroughPotals],
+) -> Task:
+    # Determine whether checkpointing or not.
+    if chunk_id < checkpoint_stop:
+
+        def function(
+            input: TensorOrTensors,
+            partition: nn.Sequential = partition,
+            skip_tracker: SkipTrackerThroughPotals = skip_trackers[chunk_id],
+            chunk_id: int = chunk_id,
+            part_id: int = part_id,
+        ) -> TensorOrTensors:
+            with use_skip_tracker(skip_tracker), record_function("chunk%d-part%d" % (chunk_id, part_id)):
+                ret = partition(input)
+                # We do a check here because the backtrace from the checkpoint backward code path
+                # is very hard to make sense. It would be much easier to check earlier at this point.
+                assert type(ret) is not list, "Only Tensor or Tuple of Tensor output is supported"
+                return ret
+
+        chk = Checkpointing(function, batch)
+        task = Task(None, compute=chk.checkpoint, finalize=chk.recompute)
+        del function, chk  # TODO(tom) maybe remove
+    else:
+
+        def compute(
+            batch: Batch = batch,
+            partition: nn.Sequential = partition,
+            skip_tracker: SkipTrackerThroughPotals = skip_trackers[chunk_id],
+            chunk_id: int = chunk_id,
+            part_id: int = part_id,
+        ) -> Batch:
+            with use_skip_tracker(skip_tracker), record_function("chunk%d-part%d" % (chunk_id, part_id)):
+                return batch.call(partition)
+
+        task = Task(None, compute=compute, finalize=None)
+        del compute  # TODO(tom) maybe remove
+
+    return task
 
 
 @dataclass(frozen=True)
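A note on the partition=partition, skip_tracker=skip_trackers[chunk_id], ... defaults in create_task above: binding values through default arguments pins them to the closure at definition time, the usual Python early-binding idiom. A small standalone illustration, unrelated to the fairscale types:

def make_late(n):
    # Late binding: every closure reads `i` when called, after the loop has ended.
    return [lambda: i for i in range(n)]


def make_early(n):
    # Early binding via a default argument: each closure captures its own `i`.
    return [lambda i=i: i for i in range(n)]


assert [f() for f in make_late(3)] == [2, 2, 2]
assert [f() for f in make_early(3)] == [0, 1, 2]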
@@ -256,7 +256,7 @@ class MultiProcessPipe(Module):
         """Iterates over children of the underlying sequential module."""
         return self.partition.__iter__()
 
-    def forward(self, input: TensorOrTensors, *, event=None) -> TensorOrTensors:  # type: ignore
+    def forward(self, input: TensorOrTensors) -> TensorOrTensors:  # type: ignore
         """:class:`MultiProcessPipe` is a fairly transparent module wrapper. It doesn't
         modify the input and output signature of the underlying module. But
         there's type restriction. Input and output have to be a
@@ -284,7 +284,7 @@ class MultiProcessPipe(Module):
 
         # Run pipeline parallelism.
         with self.lock:
-            self.pipeline.run(self.training, batches, event)
+            self.pipeline.run(self.training, batches)
 
            if self.final_stage:
                 # Merge the micro-batches into one mini-batch.
@@ -20,7 +20,6 @@
 import os
 from queue import Empty as QueueEmpty
 from queue import Queue
-from threading import Event
 from types import TracebackType
 from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Type, Union
@@ -39,6 +38,16 @@ from .skip.tracker import SkipTrackerThroughPotals, use_skip_tracker
 from .types import ACTIVATIONS_GRADS_QUEUE, PORTAL_QUEUE, SKIP_TENSOR_QUEUE, PipeMessage, TensorOrTensors, Tensors
 from .worker import Task
 
+# Queue is generic only in stubs.
+# https://mypy.readthedocs.io/en/latest/common_issues.html#using-classes-that-are-generic-in-stubs-but-not-at-runtime
+if TYPE_CHECKING:
+    InQueue = Queue[Optional[Task]]
+    OutQueue = Queue[Tuple[bool, Union[Tuple[Task, Batch], ExcInfo, None]]]
+else:
+    InQueue = Queue
+    OutQueue = Queue
+
 __all__: List[str] = []
 
 ExcInfo = Tuple[Type[BaseException], BaseException, TracebackType]
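For readers unfamiliar with the InQueue/OutQueue aliases being moved in this hunk: queue.Queue is generic only in the typeshed stubs, so subscripting it at runtime fails on older Python versions, which is why the subscripted form is guarded behind TYPE_CHECKING. A self-contained sketch of the same pattern, with an illustrative alias name:

from queue import Queue
from typing import TYPE_CHECKING, Optional

if TYPE_CHECKING:
    # Only the type checker evaluates this branch; Queue[...] is valid in the stubs.
    IntQueue = Queue[Optional[int]]
else:
    # At runtime the alias is just the plain, unsubscripted class.
    IntQueue = Queue


def make_queue() -> "IntQueue":
    return Queue()


q = make_queue()
q.put(1)
assert q.get() == 1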
@@ -49,8 +58,10 @@ class SendOperator(torch.autograd.Function):
     @staticmethod
     # type: ignore
-    def forward(ctx, src_rank, dst_rank, transport: Transport, input: List[Tensor], index: int) -> Tensors:
-        assert src_rank == torch.distributed.get_rank()
+    def forward(ctx, transport: Transport, input: List[Tensor], index: int) -> Tensors:
+        ranks = get_pipeline_parallel_ranks()
+        src_rank = torch.distributed.get_rank()
+        dst_rank = ranks[ranks.index(src_rank) + 1]
 
         transport.send_message(
             PipeMessage(src_rank, dst_rank, queue_name=ACTIVATIONS_GRADS_QUEUE, args=index, tensors=tuple(input)),
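With the explicit src_rank/dst_rank parameters removed, SendOperator (and, symmetrically, RecvOperator.backward below) derive the peer rank from the ordered pipeline rank list: the next entry receives activations, the previous one receives gradients. A toy illustration of that lookup with plain integers; the real code takes ranks from get_pipeline_parallel_ranks() and the current rank from torch.distributed.get_rank():

from typing import List


def next_rank(ranks: List[int], this_rank: int) -> int:
    # Downstream pipeline stage: where forward activations are sent.
    return ranks[ranks.index(this_rank) + 1]


def prev_rank(ranks: List[int], this_rank: int) -> int:
    # Upstream pipeline stage: where gradients are sent during backward.
    return ranks[ranks.index(this_rank) - 1]


ranks = [0, 2, 4, 6]  # e.g. the global ranks that form one pipeline group
assert next_rank(ranks, 2) == 4
assert prev_rank(ranks, 4) == 2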
@@ -68,8 +79,7 @@ class RecvOperator(torch.autograd.Function):
     @staticmethod
     # type: ignore
-    def forward(ctx, dst_rank: int, tensor: Tensor, transport: Transport, index: int) -> Tensors:
-        assert dst_rank == torch.distributed.get_rank()
+    def forward(ctx, tensor: Tensor, transport: Transport, index: int) -> Tensors:
         ctx.transport = transport
         ctx.index = index
@@ -86,74 +96,12 @@ class RecvOperator(torch.autograd.Function):
     # type: ignore
     def backward(ctx, *grad: Tensor,) -> Tuple[Optional[Tensor], ...]:
         ranks = get_pipeline_parallel_ranks()
-        this_rank = torch.distributed.get_rank()
+        src_rank = torch.distributed.get_rank()
+        dst_rank = ranks[ranks.index(src_rank) - 1]
         ctx.transport.send_message(
-            PipeMessage(
-                this_rank,
-                ranks[ranks.index(this_rank) - 1],
-                queue_name=ACTIVATIONS_GRADS_QUEUE,
-                args=ctx.index,
-                tensors=tuple(grad),
-            ),
+            PipeMessage(src_rank, dst_rank, queue_name=ACTIVATIONS_GRADS_QUEUE, args=ctx.index, tensors=tuple(grad),),
         )
-        return (None, None, None, None, None)
+        return (None, None, None, None)
-
-
-# Queue is generic only in stubs.
-# https://mypy.readthedocs.io/en/latest/common_issues.html#using-classes-that-are-generic-in-stubs-but-not-at-runtime
-if TYPE_CHECKING:
-    InQueue = Queue[Optional["Task"]]
-    OutQueue = Queue[Tuple[bool, Union[Tuple["Task", Batch], ExcInfo, None]]]
-else:
-    InQueue = Queue
-    OutQueue = Queue
-
-
-def create_task(
-    checkpoint_stop: int,
-    i: int,
-    j: int,
-    batch: Batch,
-    partition: nn.Sequential,
-    skip_trackers: List[SkipTrackerThroughPotals],
-) -> Task:
-    # Determine whether checkpointing or not.
-    if i < checkpoint_stop:
-
-        def function(
-            input: TensorOrTensors,
-            partition: nn.Sequential = partition,
-            skip_tracker: SkipTrackerThroughPotals = skip_trackers[i],
-            chunk_id: int = i,
-            part_id: int = j,
-        ) -> TensorOrTensors:
-            with use_skip_tracker(skip_tracker), record_function("chunk%d-part%d" % (chunk_id, part_id)):
-                ret = partition(input)
-                # We do a check here because the backtrace from the checkpoint backward code path
-                # is very hard to make sense. It would be much easier to check earlier at this point.
-                assert type(ret) is not list, "Only Tensor or Tuple of Tensor output is supported"
-                return ret
-
-        chk = Checkpointing(function, batch)
-        task = Task(None, compute=chk.checkpoint, finalize=chk.recompute)
-        del function, chk  # TODO(tom) maybe remove
-    else:
-
-        def compute(
-            batch: Batch = batch,
-            partition: nn.Sequential = partition,
-            skip_tracker: SkipTrackerThroughPotals = skip_trackers[i],
-            chunk_id: int = i,
-            part_id: int = j,
-        ) -> Batch:
-            with use_skip_tracker(skip_tracker), record_function("chunk%d-part%d" % (chunk_id, part_id)):
-                return batch.call(partition)
-
-        task = Task(None, compute=compute, finalize=None)
-        del compute  # TODO(tom) maybe remove
-
-    return task
 
 
 class MultiProcessPipeline:
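On the shrinking tuple of Nones in RecvOperator.backward above: a custom torch.autograd.Function must return one gradient per forward() argument (excluding ctx), with None for inputs that cannot receive a gradient, so removing a forward() parameter removes the matching None. A minimal self-contained example of that contract:

import torch


class Scale(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x: torch.Tensor, factor: float) -> torch.Tensor:
        ctx.factor = factor
        return x * factor

    @staticmethod
    def backward(ctx, grad: torch.Tensor):
        # One entry for `x`, one (None) for the non-tensor `factor`.
        return grad * ctx.factor, None


x = torch.ones(3, requires_grad=True)
Scale.apply(x, 2.0).sum().backward()
assert torch.allclose(x.grad, torch.full((3,), 2.0))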
@@ -191,7 +139,7 @@ class MultiProcessPipeline:
             return 0
         return self.__checkpoint_stop
 
-    def run(self, training: bool, batches: List[Batch], event: Optional[Event]) -> None:
+    def run(self, training: bool, batches: List[Batch]) -> None:
         """Runs pipeline parallelism.
@@ -204,24 +152,39 @@ class MultiProcessPipeline:
 
         skip_trackers = [SkipTrackerThroughPotals(self.skip_layout, i) for i in range(m)]
 
-        schedule = [(i, self.group.rank()) for i in range(m)]
+        rank = self.group.rank()
 
-        for i, j in schedule:
-            if self.group.rank() != 0:
+        for i in range(m):
+            if rank != 0:
                 batch = self.get_batch_from_previous_stage(i, skip_trackers, batches)
             else:
                 batch = batches[i]
 
-            task = create_task(self.checkpoint_stop, i, j, batch, self.partition, skip_trackers)
+            with use_skip_tracker(skip_trackers[i]), record_function("chunk%d-part%d" % (i, rank)):
+                if i < self.checkpoint_stop:
+                    chk = Checkpointing(self.partition, batch)
+                    batch = chk.checkpoint()
+                else:
+                    batch = batch.call(self.partition)
+
+            if not self.final_stage:
+                self.send_skip_tensors(batch, i, skip_trackers)
+                SendOperator.apply(self.transport, [*batch], i)
+
+            for portal in skip_trackers[i].portals.values():
+                portal.pipeline = self
 
-            batches[i] = self.execute_task(task, i, skip_trackers)
+            if i < self.checkpoint_stop:
+                chk.recompute(batch)
+
+            batches[i] = batch
 
     def get_batch_from_previous_stage(
         self, i: int, skip_trackers: List[SkipTrackerThroughPotals], batches: List[Batch]
     ) -> Batch:
         phony = torch.empty(0, device=self.input_device, requires_grad=True)
-        result = RecvOperator.apply(torch.distributed.get_rank(), phony, self.transport, i)
+        result = RecvOperator.apply(phony, self.transport, i)
         if len(result) == 1:
             batch = Batch(result[0], i)
         else:
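The rewritten loop above decides per micro-batch whether to checkpoint: chunks with index below checkpoint_stop run under activation checkpointing (recomputation is triggered afterwards via chk.recompute), while the rest call the partition directly. A simplified sketch of just that control flow, using torch.utils.checkpoint as a stand-in for fairscale's Checkpointing helper and a plain Linear layer as the partition:

import torch
from torch.utils.checkpoint import checkpoint


def run_chunks(chunks, checkpoint_stop, partition):
    outputs = []
    for i, chunk in enumerate(chunks):
        if i < checkpoint_stop:
            # Activations are recomputed during backward to save memory.
            outputs.append(checkpoint(partition, chunk))
        else:
            # Plain forward pass; activations stay in memory.
            outputs.append(partition(chunk))
    return outputs


partition = torch.nn.Linear(4, 4)
chunks = list(torch.randn(6, 4, requires_grad=True).chunk(3))
outputs = run_chunks(chunks, checkpoint_stop=1, partition=partition)
torch.cat(outputs).sum().backward()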
@@ -231,9 +194,10 @@ class MultiProcessPipeline:
 
         return batch
 
-    def send_skip_tensors(
-        self, this_rank: int, ranks: List[int], batch: Batch, i: int, skip_trackers: List[SkipTrackerThroughPotals]
-    ) -> None:
+    def send_skip_tensors(self, batch: Batch, i: int, skip_trackers: List[SkipTrackerThroughPotals]) -> None:
+        ranks = get_pipeline_parallel_ranks()
+        this_rank = torch.distributed.get_rank()
         for next_j, ns, name in self.skip_layout.copy_policy_by_src(self.group.rank()):
             life = skip_trackers[i].portals[(ns, name)].tensor_life
             loaded = skip_trackers[i].load(batch, ns, name)
@@ -271,25 +235,6 @@ class MultiProcessPipeline:
             except QueueEmpty:
                 break
 
-    def execute_task(self, task: Task, i: int, skip_trackers: List[SkipTrackerThroughPotals]) -> Batch:
-        batch = task.compute()
-
-        rank = self.group.rank()
-
-        if not self.final_stage:
-            ranks = get_pipeline_parallel_ranks()
-            this_rank = torch.distributed.get_rank()
-            self.send_skip_tensors(this_rank, ranks, batch, i, skip_trackers)
-            SendOperator.apply(this_rank, ranks[ranks.index(this_rank) + 1], self.transport, [*batch], i)
-
-        for portal in skip_trackers[i].portals.values():
-            portal.pipeline = self
-
-        task.finalize(batch)
-
-        return batch
-
     def send_portal_grad(self, ns_name: Tuple[Namespace, str], index: int, grad: TensorOrTensors) -> None:
         dest, src = self.skip_layout.by_ns_name.get(ns_name, (-1, -1))
         if dest == src: