Unverified commit b1b9e0f8, authored by msbaines, committed by GitHub

[refactor] remove multiprocess dependency on async (#373)

parent 08c10993
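
The core interface change: MultiProcessPipeline now receives its partition as a plain nn.Sequential rather than a list of ModuleWrapper objects from async_schedule, which is what removes the multiprocess code's dependency on the async module. A minimal sketch of that shape change, using hypothetical stand-in classes rather than the real fairscale types:

from typing import List

import torch.nn as nn

# Hypothetical stand-ins (not the fairscale classes) sketching the interface
# change in this commit: the pipeline used to be handed a list of wrapper
# objects defined in async_schedule; it now takes the nn.Sequential directly.
class OldStylePipeline:
    def __init__(self, partitions: List[object]) -> None:
        # Old: [ModuleWrapper(partition, Location(rank, 0))] in the real code.
        self.partitions = partitions


class NewStylePipeline:
    def __init__(self, partition: nn.Sequential) -> None:
        # New: the plain partition module, no async_schedule types involved.
        self.partition = partition


partition = nn.Sequential(nn.Linear(2, 2), nn.ReLU())
pipeline = NewStylePipeline(partition)
assert isinstance(pipeline.partition, nn.Sequential)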
@@ -192,6 +192,8 @@ class AsyncPipe(Module):
             warnings.warn("More ranks than partitions, some ranks unused")
             self.partitions: List[ModuleWrapper] = []
             self.pipeline = None
+            # TODO(msb) remove this hack
+            self.partition = None
         else:
             self.partitions = self.instantiate_partition(module, self.balance, self.group)
             if deferred_batch_norm:
@@ -200,6 +202,8 @@ class AsyncPipe(Module):
             for name, part in enumerate(self.partitions):
                 self.add_module(str(name), part.module)
             self.create_pipeline()
+            # TODO(msb) remove this hack
+            self.partition = self.partitions[0].module

         del module
@@ -17,6 +17,7 @@ from fairscale.nn.model_parallel import get_pipeline_parallel_ranks
 from .messages import Transport
 from .microbatch import Batch
+from .multiprocess_pipeline import create_task
 from .skip.tracker import SkipTrackerThroughPotals
 from .types import EVENT_LOOP_QUEUE, PipeMessage, Tensors

@@ -191,10 +192,6 @@ class AsyncEventLoop:
         """Actually run the forward pass for a given module, and send the result
         to the next stage in the pipeline if needed."""

-        # We import here to avoid a cyclic dependency.
-        # TODO(msb) Break the cyclic dependency.
-        from .multiprocess_pipeline import create_task
-
         task = create_task(
             self.checkpoint_stop, batch.index, self.group.rank(), batch, partition.module, skip_trackers,
         )
@@ -31,7 +31,6 @@ import torch.cuda
 from fairscale.nn.model_parallel import get_model_parallel_world_size, get_pipeline_parallel_group

 from . import microbatch
-from .async_schedule import Location, ModuleWrapper
 from .batchnorm import DeferredBatchNorm
 from .multiprocess_pipeline import MultiProcessPipeline
 from .phony import get_phony

@@ -219,9 +218,6 @@ class MultiProcessPipe(Module):
             self.add_module(str(0), self.partition)
             self.create_pipeline()

-            # TODO(msb) Remove this hack at some point.
-            self.partitions = [ModuleWrapper(self.partition, Location(self.group.rank(), 0))]
-
         del module

     def create_pipeline(self) -> None:

@@ -229,7 +225,7 @@ class MultiProcessPipe(Module):
         checkpoint_stop = {"always": self.chunks, "except_last": self.chunks - 1, "never": 0}[self.checkpoint]

         self.pipeline = MultiProcessPipeline(
-            [ModuleWrapper(self.partition, Location(self.group.rank(), 0))],
+            self.partition,
             self._skip_layout,
             checkpoint_stop,
             group=self.group,
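
As a side note on the context line above, the checkpoint_stop lookup maps the checkpoint mode to the number of microbatches that get checkpointed. A small worked example, with a hypothetical chunks value chosen purely for illustration:

# Illustration of the checkpoint_stop mapping shown in the hunk above, with a
# hypothetical chunks value of 4. "always" checkpoints every microbatch,
# "except_last" all but the last, and "never" disables checkpointing.
chunks = 4
for checkpoint in ("always", "except_last", "never"):
    checkpoint_stop = {"always": chunks, "except_last": chunks - 1, "never": 0}[checkpoint]
    print(checkpoint, checkpoint_stop)  # always 4, except_last 3, never 0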
@@ -30,7 +30,6 @@ from torch.autograd.profiler import record_function
 from fairscale.nn.model_parallel import get_pipeline_parallel_ranks

-from .async_schedule import ModuleWrapper
 from .checkpoint import Checkpointing
 from .messages import MakeTransport, Transport
 from .microbatch import Batch

@@ -162,7 +161,7 @@ class MultiProcessPipeline:
     def __init__(
         self,
-        partitions: List[ModuleWrapper],
+        partition: nn.Sequential,
         skip_layout: SkipLayout,
         checkpoint_stop: int,
         group: torch.distributed.ProcessGroup,

@@ -171,7 +170,7 @@ class MultiProcessPipeline:
         input_device: Union[None, int, str, torch.device] = None,
         final_stage: bool = False,
     ) -> None:
-        self.partitions = partitions
+        self.partition = partition
         self.skip_layout = skip_layout
         self.__checkpoint_stop = checkpoint_stop
         self.group = group

@@ -187,7 +186,7 @@ class MultiProcessPipeline:
     @property
     def checkpoint_stop(self) -> int:
         # Disable checkpointing if in eval mode.
-        training = self.partitions[0].module.training
+        training = self.partition.training
         if not training:
             return 0
         return self.__checkpoint_stop

@@ -208,15 +207,12 @@ class MultiProcessPipeline:
             schedule = [(i, self.group.rank()) for i in range(m)]

         for i, j in schedule:
-            assert len(self.partitions) == 1
-            partition = self.partitions[0]
-
             if self.group.rank() != 0:
                 batch = self.get_batch_from_previous_stage(i, skip_trackers, batches)
             else:
                 batch = batches[i]

-            task = create_task(self.checkpoint_stop, i, j, batch, partition.module, skip_trackers)
+            task = create_task(self.checkpoint_stop, i, j, batch, self.partition, skip_trackers)

             batches[i] = self.execute_task(task, i, skip_trackers)
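
The checkpoint_stop property change above reads the training flag straight off the partition. A toy stand-in (not the real MultiProcessPipeline) showing the same eval-mode behaviour:

import torch.nn as nn

# Toy stand-in mirroring the property change above: checkpointing is disabled
# when the partition is in eval mode, and the flag is read directly from the
# nn.Sequential partition rather than from partitions[0].module.
class Stage:
    def __init__(self, partition: nn.Sequential, checkpoint_stop: int) -> None:
        self.partition = partition
        self.__checkpoint_stop = checkpoint_stop

    @property
    def checkpoint_stop(self) -> int:
        return self.__checkpoint_stop if self.partition.training else 0


stage = Stage(nn.Sequential(nn.Linear(2, 2)), checkpoint_stop=3)
print(stage.checkpoint_stop)  # 3 while in the default training mode
stage.partition.eval()
print(stage.checkpoint_stop)  # 0 once the partition is switched to eval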
@@ -366,8 +366,8 @@ def no_grad(pipe_class):
         nonlocal latent
         latent = output

-    partition = model.partitions[0]
-    partition.module.register_forward_hook(hook)
+    partition = model.partition
+    partition.register_forward_hook(hook)

     with torch.no_grad():
         model(input)

@@ -616,9 +616,7 @@ def partitions(pipe_class):
     model = nn.Sequential(a, b)
     model = pipe_class(model, [1, 1], worker_map=get_worker_map())

-    assert isinstance(model.partitions, list)
-    assert len(model) == 1
-    assert isinstance(model.partitions[0].module, nn.Sequential)
+    assert isinstance(model.partition, nn.Sequential)

     if model.group.rank() == 0:
         assert model[0].weight == a.weight
@@ -60,13 +60,13 @@ def simple_linears(pipe_class):
     if model.group.rank() == 1:
         loss = outputs.mean()
         loss.backward()
-        grad_with_pipe = sum_grad(model.pipeline.partitions[0].module.parameters())
+        grad_with_pipe = sum_grad(model.partition.parameters())

         # Both grads should be identical.
         assert torch.allclose(grad_with_pipe, grad_without_pipe[1])
     else:
         model.back_helper(outputs)
-        grad_with_pipe = sum_grad(model.pipeline.partitions[0].module.parameters())
+        grad_with_pipe = sum_grad(model.partition.parameters())

         # Both grads should be identical.
         assert torch.allclose(grad_with_pipe, grad_without_pipe[0])
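
The test now sums gradients over model.partition.parameters() directly. A self-contained sketch of that access pattern, with a hypothetical sum_grad helper and a plain nn.Sequential standing in for the pipe's partition:

import torch
import torch.nn as nn

# Hypothetical helper with the same role as the test's sum_grad: total of all
# gradient elements across the given parameters.
def sum_grad(parameters):
    return sum(p.grad.sum() for p in parameters if p.grad is not None)


# A plain nn.Sequential stands in for model.partition here; no pipe or process
# group is involved in this sketch.
partition = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 1))
output = partition(torch.randn(4, 8))
output.mean().backward()
print(sum_grad(partition.parameters()))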