Unverified commit 2eb1b8ec authored by msbaines, committed by GitHub

[cleanup] multiprocess_pipe: dead-code removal and simplification (#335)

parent 65ca68a9
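Reader's note (not part of the original commit message): the diff below drops the PipelineStyle argument from create_task, makes the process group a required constructor argument of MultiProcessPipeline (which is why the scattered `assert self.group` guards disappear), deletes the unused all_at_once flag together with its branch in back_helper, and removes the dead Device/Devices aliases. As an illustration, here is a stub with the post-commit shape of create_task; the parameters after `j` are inferred from the two call sites in the hunks below, not from the (truncated) signature hunk:

```python
# Illustrative stub only, not the fairscale implementation.  It mirrors the
# argument list that both call sites pass after this commit.
from typing import Any, List


def create_task(
    checkpoint_stop: int, i: int, j: int, batch: Any, module: Any, skip_trackers: List[Any]
) -> Any:
    """Stand-in for multiprocess_pipeline.create_task after the cleanup."""
    raise NotImplementedError("stub for illustration")


# Call-site shape after the change, as it appears in the diff:
#   task = create_task(self.checkpoint_stop, batch.index, self.group.rank(),
#                      batch, partition.module, skip_trackers)
```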
@@ -18,7 +18,7 @@ from fairscale.nn.model_parallel import get_pipeline_parallel_ranks
 from .messages import Transport
 from .microbatch import Batch
 from .skip.tracker import SkipTrackerThroughPotals
-from .types import EVENT_LOOP_QUEUE, PipelineStyle, PipeMessage, Tensors
+from .types import EVENT_LOOP_QUEUE, PipeMessage, Tensors


 @dataclass(frozen=True)
@@ -190,17 +190,13 @@ class AsyncEventLoop:
     ) -> Batch:
         """Actually run the forward pass for a given module, and send the result
         to the next stage in the pipeline if needed."""
-        assert self.group
+
+        # We import here to avoid a cyclic dependency.
+        # TODO(msb) Break the cyclic dependency.
         from .multiprocess_pipeline import create_task

         task = create_task(
-            PipelineStyle.AsyncSchedule,
-            self.checkpoint_stop,
-            batch.index,
-            self.group.rank(),
-            batch,
-            partition.module,
-            skip_trackers,
+            self.checkpoint_stop, batch.index, self.group.rank(), batch, partition.module, skip_trackers,
         )
         result = task.compute()
         task.finalize(result)
@@ -316,8 +312,6 @@ class AsyncEventLoop:
         calculated. This also handles the first/only stage for the special
         case of a 1-stage pipeline."""
-        assert self.group
-
         invocations, activations = self.get_invocations_and_activations()
         expected_invocations = len(invocations) * len(batches)
         actual_invocations = 0
@@ -379,7 +373,6 @@ class AsyncEventLoop:
     def event_loop(self, num_microbatch: int, skip_trackers: List[SkipTrackerThroughPotals]) -> None:
         """The event loop for the "middle", i.e. neither the head nor the tail"""
-        assert self.group
         invocations, activations = self.get_invocations_and_activations()
@@ -36,6 +36,7 @@ from . import microbatch
 from .async_schedule import Invocation, Location, ModuleWrapper
 from .batchnorm import DeferredBatchNorm
 from .multiprocess_pipeline import MultiProcessPipeline
+from .phony import get_phony
 from .skip.layout import SkipLayout, inspect_skip_layout
 from .skip.skippable import Skippable, verify_skippables
 from .types import LazyModule, PipelineStyle
@@ -43,9 +44,6 @@ from .types import LazyModule, PipelineStyle
 __all__ = ["MultiProcessPipe", "LazyModule"]

-Device = Union[torch.device, int, str]
-Devices = Union[Iterable[Device], List[Device]]
-
 Tensors = Tuple[Tensor, ...]
 TensorOrTensors = Union[Tensor, Tensors]
@@ -579,10 +577,6 @@ class MultiProcessPipe(Module):
         """
         microbatch.check(input)

-        if not self.group:
-            # Empty sequential module is not illegal.
-            return input
-
         if not self.pipeline:
             # No pipeline is not illegal, more ranks than partitions
             return input
@@ -594,19 +588,12 @@ class MultiProcessPipe(Module):
         with self.lock:
             self.pipeline.run(self.training, batches, event)

-            if not self.final_stage:
-                # Don't merge micro-batches to avoid unnecessary edges in autograd
-                # graph
-                # FIXME(tom) should figure out a proper type here
-                return batches  # type: ignore
-            else:
+            if self.final_stage:
                 # Merge the micro-batches into one mini-batch.
                 if self.pipelined_backward:
                     with torch.no_grad():
                         output = microbatch.gather(batches)

-                    from .phony import get_phony
-
                     phony = get_phony(
                         torch.device(torch.cuda.current_device() if torch.cuda.is_available() else "cpu"),
                         requires_grad=True,
@@ -614,6 +601,11 @@ class MultiProcessPipe(Module):
                     output = PipelinedBackwardPass.apply(output, batches, phony, True)  # self.retain_graph)
                 else:
                     output = microbatch.gather(batches)
+            else:
+                # Don't merge micro-batches to avoid unnecessary edges in autograd
+                # graph
+                # FIXME(tom) should figure out a proper type here
+                output = batches  # type: ignore

         return output
@@ -622,7 +614,7 @@ class MultiProcessPipe(Module):
             raise ValueError("back_helper should only be called on non-final stages")

         if self.pipeline:
-            self.pipeline.back_helper(list(reversed(output)))
+            self.pipeline.back_helper(output)


 class PipelinedBackwardPass(torch.autograd.Function):
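Editor's aside (not in the commit): the call site above no longer reverses the batches because, in the final hunk of this commit, MultiProcessPipeline.back_helper itself iterates with reversed(output). A tiny self-contained sketch of why the two contracts are equivalent, using plain lists rather than fairscale batches:

```python
# Minimal sketch, not fairscale code: reversing at the caller (old contract)
# and reversing inside the callee (new contract) visit items in the same order.
def helper_old(batches):
    return [b.upper() for b in batches]            # expects a pre-reversed list

def helper_new(batches):
    return [b.upper() for b in reversed(batches)]  # reverses internally

output = ["batch0", "batch1", "batch2"]
assert helper_old(list(reversed(output))) == helper_new(output)
```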
@@ -78,7 +78,7 @@ class RecvOperator(torch.autograd.Function):
     @staticmethod
     # type: ignore
-    def forward(ctx, dst_rank: int, tensor: Tensor, input_device, transport: Transport, index: int) -> Tensors:
+    def forward(ctx, dst_rank: int, tensor: Tensor, transport: Transport, index: int) -> Tensors:
         assert dst_rank == torch.distributed.get_rank()
         ctx.transport = transport
         ctx.index = index
@@ -120,7 +120,6 @@ else:
 def create_task(
-    style: PipelineStyle,
     checkpoint_stop: int,
     i: int,
     j: int,
@@ -176,7 +175,7 @@ class MultiProcessPipeline:
         skip_layout: SkipLayout,
         checkpoint_stop: int,
         style: PipelineStyle,
-        group: Optional[torch.distributed.ProcessGroup] = None,
+        group: torch.distributed.ProcessGroup,
         worker_map: Optional[Dict[int, str]] = None,
         input_device: Union[None, int, str, torch.device] = None,
         final_stage: bool = False,
@@ -193,7 +192,6 @@ class MultiProcessPipeline:
             input_device=input_device,
         )
         self.input_device = input_device
-        self.all_at_once = False
         self.callcount = 0
         self.final_stage = final_stage
@@ -219,11 +217,9 @@ class MultiProcessPipeline:
         skip_trackers = [SkipTrackerThroughPotals(self.skip_layout, i) for i in range(len(batches))]

         if self.style is PipelineStyle.MultiProcess:
-            assert self.group
             schedule = [(i, self.group.rank()) for i in range(m)]
             self.compute(batches, schedule, skip_trackers)
         elif self.style is PipelineStyle.AsyncSchedule:
-            assert self.group
             rank = self.group.rank()
             event_loop = AsyncEventLoop(
                 self.partitions, self.group, self.transport, self.training, self.checkpoint_stop,
@@ -248,7 +244,7 @@ class MultiProcessPipeline:
     ) -> Batch:
         phony = torch.empty(0, device=self.input_device, requires_grad=True)
-        result = RecvOperator.apply(torch.distributed.get_rank(), phony, self.input_device, self.transport, i)
+        result = RecvOperator.apply(torch.distributed.get_rank(), phony, self.transport, i)
         if len(result) == 1:
             batch = Batch(result[0], i)
         else:
def send_skip_tensors( def send_skip_tensors(
self, this_rank: int, ranks: List[int], batch: Batch, i: int, skip_trackers: List[SkipTrackerThroughPotals] self, this_rank: int, ranks: List[int], batch: Batch, i: int, skip_trackers: List[SkipTrackerThroughPotals]
) -> None: ) -> None:
assert self.group
for next_j, ns, name in self.skip_layout.copy_policy_by_src(self.group.rank()): for next_j, ns, name in self.skip_layout.copy_policy_by_src(self.group.rank()):
life = skip_trackers[i].portals[(ns, name)].tensor_life life = skip_trackers[i].portals[(ns, name)].tensor_life
loaded = skip_trackers[i].load(batch, ns, name) loaded = skip_trackers[i].load(batch, ns, name)
...@@ -302,7 +297,6 @@ class MultiProcessPipeline: ...@@ -302,7 +297,6 @@ class MultiProcessPipeline:
def execute_task(self, task: Task, i: int, skip_trackers: List[SkipTrackerThroughPotals]) -> Batch: def execute_task(self, task: Task, i: int, skip_trackers: List[SkipTrackerThroughPotals]) -> Batch:
batch = task.compute() batch = task.compute()
assert self.group
rank = self.group.rank() rank = self.group.rank()
if self.style is PipelineStyle.MultiProcess and not self.final_stage: if self.style is PipelineStyle.MultiProcess and not self.final_stage:
...@@ -324,9 +318,7 @@ class MultiProcessPipeline: ...@@ -324,9 +318,7 @@ class MultiProcessPipeline:
) -> None: ) -> None:
"""Runs tasks with synchronization to copy streams.""" """Runs tasks with synchronization to copy streams."""
if self.style is PipelineStyle.MultiProcess: assert self.style is PipelineStyle.MultiProcess
assert self.group
n = self.group.size()
# With checkpointing, the autograd graph looks like this diagram: # With checkpointing, the autograd graph looks like this diagram:
# ┌─────┸──────┐ # ┌─────┸──────┐
@@ -354,19 +346,17 @@ class MultiProcessPipeline:
         # │ Copy │
         # └─────┰──────┘

         for i, j in schedule:
-            batch = batches[i]
-
-            if self.style is PipelineStyle.MultiProcess:
-                assert len(self.partitions) == 1
-                partition = self.partitions[0]
+            assert len(self.partitions) == 1
+            partition = self.partitions[0]

-                assert self.group
-                if self.group.rank() != 0:
-                    batch = self.get_batch_from_previous_stage(i, skip_trackers, batches)
+            if self.group.rank() != 0:
+                batch = self.get_batch_from_previous_stage(i, skip_trackers, batches)
+            else:
+                batch = batches[i]

-                task = create_task(self.style, self.checkpoint_stop, i, j, batch, partition.module, skip_trackers)
+            task = create_task(self.checkpoint_stop, i, j, batch, partition.module, skip_trackers)

             batches[i] = self.execute_task(task, i, skip_trackers)

     def send_portal_grad(self, ns_name: Tuple[Namespace, str], index: int, grad: TensorOrTensors) -> None:
         dest, src = self.skip_layout.by_ns_name.get(ns_name, (-1, -1))
@@ -398,43 +388,27 @@ class MultiProcessPipeline:
         if self.style == PipelineStyle.AsyncSchedule:
             return

-        o = list(output)
-
         tensors: Tensors

-        if self.all_at_once:
-            # FIXME(tom) allow specifying this branch when constructing Pipe(), add a test
-            grads = []
-            for i, batch in enumerate(o):
-                rank = torch.distributed.get_rank()
-                found = self.transport.get_out_of_order(ACTIVATIONS_GRADS_QUEUE, i)
-                assert len(found) == 1
-                grads.append(found[0])
-            tensors = tuple(x.tensor_or_tensors for x in o)  # type: ignore
-            try:
-                torch.autograd.backward(tensors, grad_tensors=grads, retain_graph=True)
-            except Exception as e:
-                raise RuntimeError("Autograd failed") from e
-        else:
-            rank = torch.distributed.get_rank()
-            for batch in o:
-                found = self.transport.get_out_of_order(ACTIVATIONS_GRADS_QUEUE, batch.index)
-                if batch.atomic:
-                    tensors = tuple([batch.tensor])
-                else:
-                    tensors = batch.tensors
-
-                if len(found) != len(tensors):
-                    raise RuntimeError("different number of tensors and gradients")
-
-                grads = []
-                final_tensors = []
-                for i, tensor in enumerate(tensors):
-                    if tensor.requires_grad or getattr(tensor, "grad_fn", None) is not None:
-                        grads.append(found[i])
-                        final_tensors.append(tensor)
-
-                try:
-                    torch.autograd.backward(final_tensors, grad_tensors=grads, retain_graph=True)
-                except Exception as e:
-                    raise RuntimeError(f"Autograd failed on {torch.distributed.get_rank()}") from e
+        rank = torch.distributed.get_rank()
+        for batch in reversed(output):
+            found = self.transport.get_out_of_order(ACTIVATIONS_GRADS_QUEUE, batch.index)
+            if batch.atomic:
+                tensors = tuple([batch.tensor])
+            else:
+                tensors = batch.tensors
+
+            if len(found) != len(tensors):
+                raise RuntimeError("different number of tensors and gradients")
+
+            grads = []
+            final_tensors = []
+            for i, tensor in enumerate(tensors):
+                if tensor.requires_grad or getattr(tensor, "grad_fn", None) is not None:
+                    grads.append(found[i])
+                    final_tensors.append(tensor)
+
+            try:
+                torch.autograd.backward(final_tensors, grad_tensors=grads, retain_graph=True)
+            except Exception as e:
+                raise RuntimeError(f"Autograd failed on {torch.distributed.get_rank()}") from e