Unverified Commit 2eb1b8ec authored by msbaines's avatar msbaines Committed by GitHub
Browse files

[cleanup] multiprocess_pipe: dead-code removal and simplification (#335)

parent 65ca68a9
......@@ -18,7 +18,7 @@ from fairscale.nn.model_parallel import get_pipeline_parallel_ranks
from .messages import Transport
from .microbatch import Batch
from .skip.tracker import SkipTrackerThroughPotals
from .types import EVENT_LOOP_QUEUE, PipelineStyle, PipeMessage, Tensors
from .types import EVENT_LOOP_QUEUE, PipeMessage, Tensors
@dataclass(frozen=True)
......@@ -190,17 +190,13 @@ class AsyncEventLoop:
) -> Batch:
"""Actually run the forward pass for a given module, and send the result
to the next stage in the pipeline if needed."""
assert self.group
# We import here to avoid a cyclic dependency.
# TODO(msb) Break the cyclic dependency.
from .multiprocess_pipeline import create_task
task = create_task(
PipelineStyle.AsyncSchedule,
self.checkpoint_stop,
batch.index,
self.group.rank(),
batch,
partition.module,
skip_trackers,
self.checkpoint_stop, batch.index, self.group.rank(), batch, partition.module, skip_trackers,
)
result = task.compute()
task.finalize(result)
......@@ -316,8 +312,6 @@ class AsyncEventLoop:
calculated. This also handles the first/only stage for the special
case of a 1-stage pipeline."""
assert self.group
invocations, activations = self.get_invocations_and_activations()
expected_invocations = len(invocations) * len(batches)
actual_invocations = 0
......@@ -379,7 +373,6 @@ class AsyncEventLoop:
def event_loop(self, num_microbatch: int, skip_trackers: List[SkipTrackerThroughPotals]) -> None:
"""The event loop for the "middle", i.e. neither the head nor the tail"""
assert self.group
invocations, activations = self.get_invocations_and_activations()
......
......@@ -36,6 +36,7 @@ from . import microbatch
from .async_schedule import Invocation, Location, ModuleWrapper
from .batchnorm import DeferredBatchNorm
from .multiprocess_pipeline import MultiProcessPipeline
from .phony import get_phony
from .skip.layout import SkipLayout, inspect_skip_layout
from .skip.skippable import Skippable, verify_skippables
from .types import LazyModule, PipelineStyle
......@@ -43,9 +44,6 @@ from .types import LazyModule, PipelineStyle
__all__ = ["MultiProcessPipe", "LazyModule"]
Device = Union[torch.device, int, str]
Devices = Union[Iterable[Device], List[Device]]
Tensors = Tuple[Tensor, ...]
TensorOrTensors = Union[Tensor, Tensors]
......@@ -579,10 +577,6 @@ class MultiProcessPipe(Module):
"""
microbatch.check(input)
if not self.group:
# Empty sequential module is not illegal.
return input
if not self.pipeline:
# No pipeline is not illegal, more ranks than partitions
return input
......@@ -594,19 +588,12 @@ class MultiProcessPipe(Module):
with self.lock:
self.pipeline.run(self.training, batches, event)
if not self.final_stage:
# Don't merge micro-batches to avoid unnecessary edges in autograd
# graph
# FIXME(tom) should figure out a proper type here
return batches # type: ignore
else:
if self.final_stage:
# Merge the micro-batches into one mini-batch.
if self.pipelined_backward:
with torch.no_grad():
output = microbatch.gather(batches)
from .phony import get_phony
phony = get_phony(
torch.device(torch.cuda.current_device() if torch.cuda.is_available() else "cpu"),
requires_grad=True,
......@@ -614,6 +601,11 @@ class MultiProcessPipe(Module):
output = PipelinedBackwardPass.apply(output, batches, phony, True) # self.retain_graph)
else:
output = microbatch.gather(batches)
else:
# Don't merge micro-batches to avoid unnecessary edges in autograd
# graph
# FIXME(tom) should figure out a proper type here
output = batches # type: ignore
return output
......@@ -622,7 +614,7 @@ class MultiProcessPipe(Module):
raise ValueError("back_helper should only be called on non-final stages")
if self.pipeline:
self.pipeline.back_helper(list(reversed(output)))
self.pipeline.back_helper(output)
class PipelinedBackwardPass(torch.autograd.Function):
......
......@@ -78,7 +78,7 @@ class RecvOperator(torch.autograd.Function):
@staticmethod
# type: ignore
def forward(ctx, dst_rank: int, tensor: Tensor, input_device, transport: Transport, index: int) -> Tensors:
def forward(ctx, dst_rank: int, tensor: Tensor, transport: Transport, index: int) -> Tensors:
assert dst_rank == torch.distributed.get_rank()
ctx.transport = transport
ctx.index = index
......@@ -120,7 +120,6 @@ else:
def create_task(
style: PipelineStyle,
checkpoint_stop: int,
i: int,
j: int,
......@@ -176,7 +175,7 @@ class MultiProcessPipeline:
skip_layout: SkipLayout,
checkpoint_stop: int,
style: PipelineStyle,
group: Optional[torch.distributed.ProcessGroup] = None,
group: torch.distributed.ProcessGroup,
worker_map: Optional[Dict[int, str]] = None,
input_device: Union[None, int, str, torch.device] = None,
final_stage: bool = False,
......@@ -193,7 +192,6 @@ class MultiProcessPipeline:
input_device=input_device,
)
self.input_device = input_device
self.all_at_once = False
self.callcount = 0
self.final_stage = final_stage
......@@ -219,11 +217,9 @@ class MultiProcessPipeline:
skip_trackers = [SkipTrackerThroughPotals(self.skip_layout, i) for i in range(len(batches))]
if self.style is PipelineStyle.MultiProcess:
assert self.group
schedule = [(i, self.group.rank()) for i in range(m)]
self.compute(batches, schedule, skip_trackers)
elif self.style is PipelineStyle.AsyncSchedule:
assert self.group
rank = self.group.rank()
event_loop = AsyncEventLoop(
self.partitions, self.group, self.transport, self.training, self.checkpoint_stop,
......@@ -248,7 +244,7 @@ class MultiProcessPipeline:
) -> Batch:
phony = torch.empty(0, device=self.input_device, requires_grad=True)
result = RecvOperator.apply(torch.distributed.get_rank(), phony, self.input_device, self.transport, i)
result = RecvOperator.apply(torch.distributed.get_rank(), phony, self.transport, i)
if len(result) == 1:
batch = Batch(result[0], i)
else:
......@@ -261,7 +257,6 @@ class MultiProcessPipeline:
def send_skip_tensors(
self, this_rank: int, ranks: List[int], batch: Batch, i: int, skip_trackers: List[SkipTrackerThroughPotals]
) -> None:
assert self.group
for next_j, ns, name in self.skip_layout.copy_policy_by_src(self.group.rank()):
life = skip_trackers[i].portals[(ns, name)].tensor_life
loaded = skip_trackers[i].load(batch, ns, name)
......@@ -302,7 +297,6 @@ class MultiProcessPipeline:
def execute_task(self, task: Task, i: int, skip_trackers: List[SkipTrackerThroughPotals]) -> Batch:
batch = task.compute()
assert self.group
rank = self.group.rank()
if self.style is PipelineStyle.MultiProcess and not self.final_stage:
......@@ -324,9 +318,7 @@ class MultiProcessPipeline:
) -> None:
"""Runs tasks with synchronization to copy streams."""
if self.style is PipelineStyle.MultiProcess:
assert self.group
n = self.group.size()
assert self.style is PipelineStyle.MultiProcess
# With checkpointing, the autograd graph looks like this diagram:
# ┌─────┸──────┐
......@@ -354,17 +346,15 @@ class MultiProcessPipeline:
# │ Copy │
# └─────┰──────┘
for i, j in schedule:
batch = batches[i]
if self.style is PipelineStyle.MultiProcess:
assert len(self.partitions) == 1
partition = self.partitions[0]
assert self.group
if self.group.rank() != 0:
batch = self.get_batch_from_previous_stage(i, skip_trackers, batches)
else:
batch = batches[i]
task = create_task(self.style, self.checkpoint_stop, i, j, batch, partition.module, skip_trackers)
task = create_task(self.checkpoint_stop, i, j, batch, partition.module, skip_trackers)
batches[i] = self.execute_task(task, i, skip_trackers)
......@@ -398,26 +388,10 @@ class MultiProcessPipeline:
if self.style == PipelineStyle.AsyncSchedule:
return
o = list(output)
tensors: Tensors
if self.all_at_once:
# FIXME(tom) allow specifying this branch when constructing Pipe(), add a test
grads = []
for i, batch in enumerate(o):
rank = torch.distributed.get_rank()
found = self.transport.get_out_of_order(ACTIVATIONS_GRADS_QUEUE, i)
assert len(found) == 1
grads.append(found[0])
tensors = tuple(x.tensor_or_tensors for x in o) # type: ignore
try:
torch.autograd.backward(tensors, grad_tensors=grads, retain_graph=True)
except Exception as e:
raise RuntimeError("Autograd failed") from e
else:
rank = torch.distributed.get_rank()
for batch in o:
for batch in reversed(output):
found = self.transport.get_out_of_order(ACTIVATIONS_GRADS_QUEUE, batch.index)
if batch.atomic:
tensors = tuple([batch.tensor])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment