Unverified commit b1b9e0f8, authored by msbaines, committed by GitHub
Browse files

[refactor] remove multiprocess dependency on async (#373)

parent 08c10993
......@@ -192,6 +192,8 @@ class AsyncPipe(Module):
warnings.warn("More ranks than partitions, some ranks unused")
self.partitions: List[ModuleWrapper] = []
self.pipeline = None
# TODO(msb) remove this hack
self.partition = None
else:
self.partitions = self.instantiate_partition(module, self.balance, self.group)
if deferred_batch_norm:
......@@ -200,6 +202,8 @@ class AsyncPipe(Module):
for name, part in enumerate(self.partitions):
self.add_module(str(name), part.module)
self.create_pipeline()
# TODO(msb) remove this hack
self.partition = self.partitions[0].module
del module
......
......@@ -17,6 +17,7 @@ from fairscale.nn.model_parallel import get_pipeline_parallel_ranks
from .messages import Transport
from .microbatch import Batch
from .multiprocess_pipeline import create_task
from .skip.tracker import SkipTrackerThroughPotals
from .types import EVENT_LOOP_QUEUE, PipeMessage, Tensors
......@@ -191,10 +192,6 @@ class AsyncEventLoop:
"""Actually run the forward pass for a given module, and send the result
to the next stage in the pipeline if needed."""
# We import here to avoid a cyclic dependency.
# TODO(msb) Break the cyclic dependency.
from .multiprocess_pipeline import create_task
task = create_task(
self.checkpoint_stop, batch.index, self.group.rank(), batch, partition.module, skip_trackers,
)
......
......@@ -31,7 +31,6 @@ import torch.cuda
from fairscale.nn.model_parallel import get_model_parallel_world_size, get_pipeline_parallel_group
from . import microbatch
from .async_schedule import Location, ModuleWrapper
from .batchnorm import DeferredBatchNorm
from .multiprocess_pipeline import MultiProcessPipeline
from .phony import get_phony
......@@ -219,9 +218,6 @@ class MultiProcessPipe(Module):
self.add_module(str(0), self.partition)
self.create_pipeline()
# TODO(msb) Remove this hack at some point.
self.partitions = [ModuleWrapper(self.partition, Location(self.group.rank(), 0))]
del module
def create_pipeline(self) -> None:
......@@ -229,7 +225,7 @@ class MultiProcessPipe(Module):
checkpoint_stop = {"always": self.chunks, "except_last": self.chunks - 1, "never": 0}[self.checkpoint]
self.pipeline = MultiProcessPipeline(
[ModuleWrapper(self.partition, Location(self.group.rank(), 0))],
self.partition,
self._skip_layout,
checkpoint_stop,
group=self.group,
......
......@@ -30,7 +30,6 @@ from torch.autograd.profiler import record_function
from fairscale.nn.model_parallel import get_pipeline_parallel_ranks
from .async_schedule import ModuleWrapper
from .checkpoint import Checkpointing
from .messages import MakeTransport, Transport
from .microbatch import Batch
......@@ -162,7 +161,7 @@ class MultiProcessPipeline:
def __init__(
self,
partitions: List[ModuleWrapper],
partition: nn.Sequential,
skip_layout: SkipLayout,
checkpoint_stop: int,
group: torch.distributed.ProcessGroup,
......@@ -171,7 +170,7 @@ class MultiProcessPipeline:
input_device: Union[None, int, str, torch.device] = None,
final_stage: bool = False,
) -> None:
self.partitions = partitions
self.partition = partition
self.skip_layout = skip_layout
self.__checkpoint_stop = checkpoint_stop
self.group = group
......@@ -187,7 +186,7 @@ class MultiProcessPipeline:
@property
def checkpoint_stop(self) -> int:
# Disable checkpointing if in eval mode.
training = self.partitions[0].module.training
training = self.partition.training
if not training:
return 0
return self.__checkpoint_stop
......@@ -208,15 +207,12 @@ class MultiProcessPipeline:
schedule = [(i, self.group.rank()) for i in range(m)]
for i, j in schedule:
assert len(self.partitions) == 1
partition = self.partitions[0]
if self.group.rank() != 0:
batch = self.get_batch_from_previous_stage(i, skip_trackers, batches)
else:
batch = batches[i]
task = create_task(self.checkpoint_stop, i, j, batch, partition.module, skip_trackers)
task = create_task(self.checkpoint_stop, i, j, batch, self.partition, skip_trackers)
batches[i] = self.execute_task(task, i, skip_trackers)
......
......@@ -366,8 +366,8 @@ def no_grad(pipe_class):
nonlocal latent
latent = output
partition = model.partitions[0]
partition.module.register_forward_hook(hook)
partition = model.partition
partition.register_forward_hook(hook)
with torch.no_grad():
model(input)
......@@ -616,9 +616,7 @@ def partitions(pipe_class):
model = nn.Sequential(a, b)
model = pipe_class(model, [1, 1], worker_map=get_worker_map())
assert isinstance(model.partitions, list)
assert len(model) == 1
assert isinstance(model.partitions[0].module, nn.Sequential)
assert isinstance(model.partition, nn.Sequential)
if model.group.rank() == 0:
assert model[0].weight == a.weight
......
......@@ -60,13 +60,13 @@ def simple_linears(pipe_class):
if model.group.rank() == 1:
loss = outputs.mean()
loss.backward()
grad_with_pipe = sum_grad(model.pipeline.partitions[0].module.parameters())
grad_with_pipe = sum_grad(model.partition.parameters())
# Both grads should be identical.
assert torch.allclose(grad_with_pipe, grad_without_pipe[1])
else:
model.back_helper(outputs)
grad_with_pipe = sum_grad(model.pipeline.partitions[0].module.parameters())
grad_with_pipe = sum_grad(model.partition.parameters())
# Both grads should be identical.
assert torch.allclose(grad_with_pipe, grad_without_pipe[0])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment