Unverified commit a8dd9254 authored by msbaines, committed by GitHub

[refactor] pipe: move async-specific code out of MultiProcessPipe (#344)

parent e348806b
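For context, a minimal usage sketch (not part of the diff): after this refactor the execution schedule is picked by choosing the class rather than by passing a `style=` argument, because `AsyncPipe` now overrides `create_pipeline()` and `instantiate_partition()` instead of forwarding `PipelineStyle.AsyncSchedule` to the base class. The construction below mirrors the test near the end of this diff; it assumes the pipeline-parallel process group is already initialized and that `worker_map` maps ranks to RPC worker names (the names here are placeholders).

```python
import torch.nn as nn
from fairscale.nn.pipe import AsyncPipe, MultiProcessPipe

# Toy model; balance=[1] keeps the single layer on one pipeline rank.
model = nn.Sequential(nn.Linear(1, 1))

# worker_map: {global rank -> RPC worker name}; placeholder value for illustration.
worker_map = {0: "worker0"}

# Synchronous schedule: MultiProcessPipe.create_pipeline() now hard-codes
# PipelineStyle.MultiProcess internally.
sync_pipe = MultiProcessPipe(model, balance=[1], worker_map=worker_map, chunks=1)

# Async schedule: same constructor arguments, different subclass.
async_pipe = AsyncPipe(model, balance=[1], worker_map=worker_map, chunks=1)
```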
@@ -3,10 +3,153 @@
 # This source code is licensed under the BSD license found in the
 # LICENSE file in the root directory of this source tree.

-from .multiprocess_pipe import MultiProcessPipe
-from .types import PipelineStyle
+from collections import OrderedDict
+from dataclasses import dataclass, field
+import itertools
+from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Tuple, Union
+
+import torch
+from torch import Tensor, nn
+
+from .async_schedule import Invocation, Location, ModuleWrapper
+from .multiprocess_pipe import MultiProcessPipe, check_balance
+from .multiprocess_pipeline import MultiProcessPipeline
+from .skip.skippable import Skippable
+from .types import LazyModule, PipelineStyle
+
+if TYPE_CHECKING:
+    Module = nn.Module[TensorOrTensors]
+    NamedModules = OrderedDict[str, Module]
+else:
+    Module = nn.Module
+    NamedModules = OrderedDict
+
+Tensors = Tuple[Tensor, ...]
+TensorOrTensors = Union[Tensor, Tensors]
+
+
+@dataclass
+class PartitionInfo:
+    location: Location
+    modules: "OrderedDict[str, nn.Module]"
+    invocations: List[Invocation] = field(default_factory=list)
+
+    def __len__(self) -> int:
+        return len(self.modules)
+
+
 class AsyncPipe(MultiProcessPipe):
-    def __init__(self, *args, **kwargs) -> None:  # type: ignore
-        super().__init__(*args, style=PipelineStyle.AsyncSchedule, **kwargs)
+    def create_pipeline(self) -> None:
+        # The micro-batch index where the checkpointing stops.
+        checkpoint_stop = {"always": self.chunks, "except_last": self.chunks - 1, "never": 0}[self.checkpoint]
+
+        self.pipeline = MultiProcessPipeline(
+            self.partitions,
+            self._skip_layout,
+            checkpoint_stop,
+            style=PipelineStyle.AsyncSchedule,
+            group=self.group,
+            worker_map=self.worker_map,
+            input_device=self.input_device,
+            final_stage=self.final_stage,
+        )
+
+    def instantiate_partition(
+        self,
+        module: Union[nn.Sequential, List[LazyModule]],
+        balance: Iterable[int],
+        group: torch.distributed.ProcessGroup,
+    ) -> List[ModuleWrapper]:
+        balance = list(balance)
+        check_balance(module, balance, True)
+
+        layers: NamedModules = OrderedDict()
+
+        def maybe_realize(layer: Any) -> nn.Module:
+            if isinstance(layer, nn.Module):
+                return layer
+            elif callable(layer):
+                return layer()
+            else:
+                raise TypeError(f"layer must be nn.Module or callable, is {type(layer)}")
+
+        def iterate_module(module: Union[nn.Sequential, list]) -> Iterable[Tuple[Any, nn.Module]]:
+            if isinstance(module, nn.Sequential):
+                yield from module.named_children()
+            else:
+                yield from ((str(k), v) for k, v in enumerate(module))
+
+        module_ids = list(map(id, module))
+        index_of_first_use = [module_ids.index(x) for x in module_ids]
+        locations: List[Location] = []
+        module_iter = enumerate(iterate_module(module))
+
+        partitions: List[List[PartitionInfo]] = []
+        for bi, b in enumerate(balance):
+            modules_for_rank: List[PartitionInfo] = []
+            current_module: OrderedDict[str, nn.Module] = OrderedDict()
+
+            def current_location() -> Location:
+                return Location(bi, len(modules_for_rank))
+
+            def append_module(mod: "OrderedDict[str, nn.Module]") -> None:
+                modules_for_rank.append(PartitionInfo(current_location(), mod))
+
+            while sum(map(len, modules_for_rank)) + len(current_module) < b:
+                module_index, (name, layer) = next(module_iter)
+
+                if index_of_first_use[module_index] != module_index:
+                    # Subsequent reuse of a module
+                    locations.append(locations[index_of_first_use[module_index]])
+                    continue
+
+                is_reused = index_of_first_use.count(index_of_first_use[module_index]) > 1
+
+                if is_reused and len(current_module) > 0:
+                    append_module(current_module)
+                    current_module = OrderedDict()
+
+                current_module[str(name)] = layer
+                locations.append(current_location())
+
+                if is_reused:
+                    append_module(current_module)
+                    current_module = OrderedDict()
+
+            if len(current_module) > 0:
+                append_module(current_module)
+
+            partitions.append(modules_for_rank)
+
+        filtered_locations: List[Optional[Location]] = [loc for loc, _ in itertools.groupby(locations)]
+        filtered_locations.append(None)
+
+        for i in range(len(filtered_locations) - 1):
+            loc = filtered_locations[i]
+            assert loc
+            if i == 0:
+                inv = Invocation(i, loc, None, filtered_locations[i + 1])
+            else:
+                inv = Invocation(i, loc, filtered_locations[i - 1], filtered_locations[i + 1])
+
+            partitions[loc.stage][loc.index].invocations.append(inv)
+
+        invocations = enumerate(iterate_module(module))
+
+        partition = partitions[group.rank()]
+        result: List[ModuleWrapper] = []
+
+        for partition_info in partition:
+            wrapper = ModuleWrapper(
+                nn.Sequential(OrderedDict((k, maybe_realize(m)) for k, m in partition_info.modules.items())),
+                partition_info.location,
+                partition_info.invocations,
+            )
+
+            if not isinstance(module, nn.Sequential):
+                for layer in wrapper.module:
+                    if isinstance(layer, Skippable):
+                        raise ValueError("Can't use Skippable layers with multi-process pipe and lazy construction")
+
+            result.append(wrapper)
+
+        return result
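A simplified, standalone illustration (not from the commit) of the balance bookkeeping that `instantiate_partition` performs above: each rank takes the next `balance[rank]` layers in order. Module reuse, lazy construction, and the `Location`/`Invocation` tracking are deliberately omitted here.

```python
from typing import Dict, List, Sequence

def assign_layers_to_ranks(num_layers: int, balance: Sequence[int]) -> Dict[int, List[int]]:
    """Each rank j owns the next balance[j] layer indices, in order."""
    assert sum(balance) == num_layers, "balance must account for every layer"
    assignment: Dict[int, List[int]] = {}
    start = 0
    for rank, count in enumerate(balance):
        assignment[rank] = list(range(start, start + count))
        start += count
    return assignment

# With three layers and balance=[2, 1]: rank 0 owns layers [0, 1], rank 1 owns [2].
print(assign_layers_to_ranks(3, [2, 1]))  # {0: [0, 1], 1: [2]}
```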
@@ -19,10 +19,8 @@
 """The MultiProcessPipe interface."""
 from collections import OrderedDict
-from dataclasses import dataclass, field
-import itertools
 import threading
-from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, Union, cast
+from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, Union
 import warnings

 import torch
@@ -33,7 +31,7 @@ import torch.cuda
 from fairscale.nn.model_parallel import get_model_parallel_world_size, get_pipeline_parallel_group

 from . import microbatch
-from .async_schedule import Invocation, Location, ModuleWrapper
+from .async_schedule import Location, ModuleWrapper
 from .batchnorm import DeferredBatchNorm
 from .multiprocess_pipeline import MultiProcessPipeline
 from .phony import get_phony
@@ -47,8 +45,6 @@ __all__ = ["MultiProcessPipe", "LazyModule"]
 Tensors = Tuple[Tensor, ...]
 TensorOrTensors = Union[Tensor, Tensors]

-ListOfLazyModules = List[LazyModule]
-
 if TYPE_CHECKING:
     Module = nn.Module[TensorOrTensors]
     NamedModules = OrderedDict[str, Module]
@@ -87,7 +83,7 @@ def verify_list_of_callable(module: Union[nn.Sequential, list]) -> None:
             raise TypeError(f"layer {type(layer)} must be nn.Module or LazyModule to be partitioned")


-def verify_module(module: Union[nn.Sequential, ListOfLazyModules]) -> None:
+def verify_module(module: Union[nn.Sequential, List[LazyModule]]) -> None:
     if isinstance(module, Iterable) and not isinstance(module, nn.Sequential):
         verify_list_of_callable(module)
     else:
@@ -135,145 +131,11 @@ def check_balance(module: Any, balance: Iterable[int], filter_unique: bool = False
         raise BalanceError(f"all balance numbers must be positive integer (balance: {balance})")


-@dataclass
-class PartitionInfo:
-    location: Location
-    modules: "OrderedDict[str, nn.Module]"
-    invocations: List[Invocation] = field(default_factory=list)
-
-    def __len__(self) -> int:
-        return len(self.modules)
-
-
-def instantiate_partition(
-    module: Union[nn.Sequential, ListOfLazyModules],
-    balance: Iterable[int],
-    group: torch.distributed.ProcessGroup,
-    style: PipelineStyle,
-) -> List[ModuleWrapper]:
-    balance = list(balance)
-    check_balance(module, balance, True)
-
-    layers: NamedModules = OrderedDict()
-
-    def maybe_realize(layer: Any) -> nn.Module:
-        if isinstance(layer, nn.Module):
-            return layer
-        elif callable(layer):
-            return layer()
-        else:
-            raise TypeError(f"layer must be nn.Module or callable, is {type(layer)}")
-
-    def iterate_module(module: Union[nn.Sequential, list]) -> Iterable[Tuple[Any, nn.Module]]:
-        if isinstance(module, nn.Sequential):
-            yield from module.named_children()
-        else:
-            yield from ((str(k), v) for k, v in enumerate(module))
-
-    if style == PipelineStyle.AsyncSchedule:
-        module_ids = list(map(id, module))
-        index_of_first_use = [module_ids.index(x) for x in module_ids]
-        locations: List[Location] = []
-        module_iter = enumerate(iterate_module(module))
-
-        partitions: List[List[PartitionInfo]] = []
-        for bi, b in enumerate(balance):
-            modules_for_rank: List[PartitionInfo] = []
-            current_module: OrderedDict[str, nn.Module] = OrderedDict()
-
-            def current_location() -> Location:
-                return Location(bi, len(modules_for_rank))
-
-            def append_module(mod: "OrderedDict[str, nn.Module]") -> None:
-                modules_for_rank.append(PartitionInfo(current_location(), mod))
-
-            while sum(map(len, modules_for_rank)) + len(current_module) < b:
-                module_index, (name, layer) = next(module_iter)
-
-                if index_of_first_use[module_index] != module_index:
-                    # Subsequent reuse of a module
-                    locations.append(locations[index_of_first_use[module_index]])
-                    continue
-
-                is_reused = index_of_first_use.count(index_of_first_use[module_index]) > 1
-
-                if is_reused and len(current_module) > 0:
-                    append_module(current_module)
-                    current_module = OrderedDict()
-
-                current_module[str(name)] = layer
-                locations.append(current_location())
-
-                if is_reused:
-                    append_module(current_module)
-                    current_module = OrderedDict()
-
-            if len(current_module) > 0:
-                append_module(current_module)
-
-            partitions.append(modules_for_rank)
-
-        filtered_locations: List[Optional[Location]] = [loc for loc, _ in itertools.groupby(locations)]
-        filtered_locations.append(None)
-        for i in range(len(filtered_locations) - 1):
-            loc = filtered_locations[i]
-            assert loc
-            if i == 0:
-                inv = Invocation(i, loc, None, filtered_locations[i + 1])
-            else:
-                inv = Invocation(i, loc, filtered_locations[i - 1], filtered_locations[i + 1])
-
-            partitions[loc.stage][loc.index].invocations.append(inv)
-
-        invocations = enumerate(iterate_module(module))
-
-        partition = partitions[group.rank()]
-        result: List[ModuleWrapper] = []
-
-        for partition_info in partition:
-            wrapper = ModuleWrapper(
-                nn.Sequential(OrderedDict((k, maybe_realize(m)) for k, m in partition_info.modules.items())),
-                partition_info.location,
-                partition_info.invocations,
-            )
-
-            if not isinstance(module, nn.Sequential):
-                for layer in wrapper.module:
-                    if isinstance(layer, Skippable):
-                        raise ValueError("Can't use Skippable layers with multi-process pipe and lazy construction")
-
-            result.append(wrapper)
-
-        return result
-
-    j = 0
-    for name, layer in iterate_module(module):
-        layers[name] = layer
-        if len(layers) == balance[j]:
-            if j == group.rank():
-                for key in layers:
-                    layers[key] = maybe_realize(layers[key])
-                if not isinstance(module, nn.Sequential):
-                    for layer in layers.values():
-                        if isinstance(layer, Skippable):
-                            raise ValueError("Can't use Skippable layers with multi-process pipe and lazy construction")
-                return [ModuleWrapper(nn.Sequential(layers), Location(j, 0))]
-            # Prepare for the next partition.
-            layers.clear()
-            j += 1
-
-    raise ValueError("Shouldn't get here, more ranks than partitions")
-
-
-def split_module(module: nn.Sequential, balance: Iterable[int],) -> Tuple[List[nn.Sequential], List[int]]:
+def split_module(module: nn.Sequential, balance: Iterable[int],) -> List[nn.Sequential]:
     """Splits a module into multiple partitions.

     Returns:
-        A tuple of (partitions, balance).
+        partitions
         Partitions are represented as a :class:`~torch.nn.ModuleList` whose
         item is a partition. All layers in a partition are placed in the
@@ -307,8 +169,7 @@ def split_module(module: nn.Sequential, balance: Iterable[int],) -> Tuple[List[nn.Sequential], List[int]]:
             layers.clear()
             j += 1

-    partitions = cast(List[nn.Sequential], nn.ModuleList(partitions))
-    return partitions, balance
+    return partitions


 MOVING_DENIED = TypeError("denied to move parameters and buffers, because Pipe should manage device placement")
@@ -415,10 +276,9 @@ class MultiProcessPipe(Module):
     def __init__(
         self,
-        module: Union[nn.Sequential, ListOfLazyModules],
+        module: Union[nn.Sequential, List[LazyModule]],
         balance: Optional[Iterable[int]] = None,
         *,
-        style: PipelineStyle = PipelineStyle.MultiProcess,
         group: Optional[torch.distributed.ProcessGroup] = None,
         worker_map: Optional[Dict[int, str]] = None,
         input_device: Union[None, int, str, torch.device] = None,
@@ -427,7 +287,6 @@ class MultiProcessPipe(Module):
         deferred_batch_norm: bool = False,
         pipelined_backward: bool = None,
         retain_graph: bool = False,
-        loss_fn: Optional[nn.Module] = None,
     ) -> None:
         super().__init__()
@@ -453,19 +312,16 @@ class MultiProcessPipe(Module):
         self.pipelined_backward = pipelined_backward
         self.retain_graph = retain_graph
         self.pipeline: Optional[MultiProcessPipeline]
-        self.loss_fn = loss_fn
         self.lock = threading.Lock()

-        self.group = group
         self.worker_map = worker_map
         self.input_device = input_device

-        # The micro-batch index where the checkpointing stops.
-        checkpoint_stop = {"always": self.chunks, "except_last": self.chunks - 1, "never": 0}[self.checkpoint]
-
-        if self.group is None:
+        self.group: torch.distributed.ProcessGroup
+        if group is None:
             self.group = get_pipeline_parallel_group()
-        assert self.group
+        else:
+            self.group = group

         self.balance = list(balance)
@@ -480,14 +336,14 @@ class MultiProcessPipe(Module):
             warnings.warn("More ranks than partitions, some ranks unused")
             self.partitions: List[ModuleWrapper] = []
         else:
-            self.partitions = instantiate_partition(module, balance, self.group, style)
+            self.partitions = self.instantiate_partition(module, balance, self.group)
             if deferred_batch_norm:
                 for part in self.partitions:
                     part.module = DeferredBatchNorm.convert_deferred_batch_norm(part.module, chunks)
             for name, part in enumerate(self.partitions):
                 self.add_module(str(name), part.module)
         if isinstance(module, nn.Sequential):
-            local_partitions, _ = split_module(module, balance)
+            local_partitions = split_module(module, balance)
             self._skip_layout = inspect_skip_layout(local_partitions)
         else:
             self._skip_layout = SkipLayout(len(module), {})  # FIXME(tom)
@@ -501,24 +357,78 @@ class MultiProcessPipe(Module):
             self.final_stage = False
         else:
             self.final_stage = rank == len(self.balance) - 1

-        assert loss_fn is None or self.final_stage
+        self.create_pipeline()
+
+        del module
+
+        if self.pipelined_backward is None:
+            if get_model_parallel_world_size() > 1:
+                self.pipelined_backward = True
+            else:
+                self.pipelined_backward = False
+
+    def create_pipeline(self) -> None:
+        # The micro-batch index where the checkpointing stops.
+        checkpoint_stop = {"always": self.chunks, "except_last": self.chunks - 1, "never": 0}[self.checkpoint]

         self.pipeline = MultiProcessPipeline(
-            cast(List[nn.Sequential], self.partitions),
+            self.partitions,
             self._skip_layout,
             checkpoint_stop,
-            style=style,
+            style=PipelineStyle.MultiProcess,
             group=self.group,
             worker_map=self.worker_map,
             input_device=self.input_device,
             final_stage=self.final_stage,
         )

-        del module
-
-        if self.pipelined_backward is None:
-            if get_model_parallel_world_size() > 1:
-                self.pipelined_backward = True
-            else:
-                self.pipelined_backward = False
+    def instantiate_partition(
+        self,
+        module: Union[nn.Sequential, List[LazyModule]],
+        balance: Iterable[int],
+        group: torch.distributed.ProcessGroup,
+    ) -> List[ModuleWrapper]:
+        balance = list(balance)
+        check_balance(module, balance, True)
+
+        layers: NamedModules = OrderedDict()
+
+        def maybe_realize(layer: Any) -> nn.Module:
+            if isinstance(layer, nn.Module):
+                return layer
+            elif callable(layer):
+                return layer()
+            else:
+                raise TypeError(f"layer must be nn.Module or callable, is {type(layer)}")
+
+        def iterate_module(module: Union[nn.Sequential, list]) -> Iterable[Tuple[Any, nn.Module]]:
+            if isinstance(module, nn.Sequential):
+                yield from module.named_children()
+            else:
+                yield from ((str(k), v) for k, v in enumerate(module))
+
+        j = 0
+        for name, layer in iterate_module(module):
+            layers[name] = layer
+
+            if len(layers) == balance[j]:
+                if j == group.rank():
+                    for key in layers:
+                        layers[key] = maybe_realize(layers[key])
+                    if not isinstance(module, nn.Sequential):
+                        for layer in layers.values():
+                            if isinstance(layer, Skippable):
+                                raise ValueError(
+                                    "Can't use Skippable layers with multi-process pipe and lazy construction"
+                                )
+                    return [ModuleWrapper(nn.Sequential(layers), Location(j, 0))]
+                # Prepare for the next partition.
+                layers.clear()
+                j += 1
+
+        raise ValueError("Shouldn't get here, more ranks than partitions")

     def __len__(self) -> int:
         """Counts the length of the underlying sequential module."""
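As a side note (not from the commit), the `checkpoint_stop` lookup that both `create_pipeline` implementations share maps the `checkpoint` mode to how many leading micro-batches run under activation checkpointing. A small sketch, assuming `chunks=4` purely for illustration:

```python
chunks = 4  # assumed number of micro-batches for illustration

# Same mapping as in create_pipeline(): the micro-batch index where checkpointing stops.
for mode in ("always", "except_last", "never"):
    checkpoint_stop = {"always": chunks, "except_last": chunks - 1, "never": 0}[mode]
    print(mode, checkpoint_stop)
# always 4, except_last 3, never 0
```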
@@ -23,7 +23,7 @@ from queue import Empty as QueueEmpty
 from queue import Queue
 from threading import Event
 from types import TracebackType
-from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Type, Union, cast
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Type, Union

 import torch
 from torch import Tensor, nn
@@ -171,7 +171,7 @@ class MultiProcessPipeline:
     def __init__(
         self,
-        partitions: List[nn.Sequential],
+        partitions: List[ModuleWrapper],
         skip_layout: SkipLayout,
         checkpoint_stop: int,
         style: PipelineStyle,
@@ -180,7 +180,7 @@ class MultiProcessPipeline:
         input_device: Union[None, int, str, torch.device] = None,
         final_stage: bool = False,
     ) -> None:
-        self.partitions: List[ModuleWrapper] = cast(List[ModuleWrapper], partitions)
+        self.partitions = partitions
         self.skip_layout = skip_layout
         self.__checkpoint_stop = checkpoint_stop
         self.style = style
@@ -32,7 +32,6 @@ from fairscale.nn.model_parallel.initialize import (
     initialize_model_parallel,
 )
 from fairscale.nn.pipe import AsyncPipe, LazyModule, MultiProcessPipe
-from fairscale.nn.pipe.types import PipelineStyle
 from fairscale.utils.testing import get_worker_map, set_random_seed, torch_spawn, torch_version
@@ -874,9 +873,12 @@ def reuse_lazy():
     assert torch.equal(model_out, pipe_out)


-def test_instantiate_partition():
+@torch_spawn([1])
+def instantiate_partition():
     from fairscale.nn.pipe.async_schedule import Location
-    from fairscale.nn.pipe.multiprocess_pipe import instantiate_partition
+
+    model = nn.Sequential(nn.Linear(1, 1))
+    pipe = AsyncPipe(model, balance=[1], worker_map=get_worker_map(), chunks=1)

     class FakeGroup:
         def __init__(self, rank, size):
@@ -904,9 +906,7 @@ def test_instantiate_partition():
     # Collect `Invocation` and `Invocation` -> `ModuleWrapper` mapping from
     # instantiated model
     for rank in range(len(balance)):
-        instantiated = instantiate_partition(
-            model, balance, FakeGroup(rank, len(balance)), PipelineStyle.AsyncSchedule
-        )
+        instantiated = pipe.instantiate_partition(model, balance, FakeGroup(rank, len(balance)))
         for part in instantiated:
             assert isinstance(part.module, nn.Sequential)
             for inv in part.invocations:
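For readers following the test above, here is a standalone sketch (not from the commit) of how `AsyncPipe.instantiate_partition` builds the `Invocation` chain from an ordered list of `Location`s: each invocation points at the previous location as its source and the next as its destination, with `None` at both ends. The tiny dataclasses below are simplified stand-ins for the ones in `fairscale.nn.pipe.async_schedule`, and the field names are assumptions for illustration.

```python
from dataclasses import dataclass
from typing import List, Optional

@dataclass(frozen=True)
class Location:  # simplified stand-in: pipeline stage and partition index within it
    stage: int
    index: int

@dataclass(frozen=True)
class Invocation:  # simplified stand-in: execution order plus this/source/dest locations
    order: int
    this: Location
    source: Optional[Location]
    dest: Optional[Location]

def build_invocations(locations: List[Location]) -> List[Invocation]:
    """Mirror the loop in instantiate_partition: chain consecutive locations."""
    padded: List[Optional[Location]] = list(locations) + [None]
    chain: List[Invocation] = []
    for i, loc in enumerate(locations):
        source = locations[i - 1] if i > 0 else None
        chain.append(Invocation(i, loc, source, padded[i + 1]))
    return chain

# Three partitions visited in execution order (made-up example).
for inv in build_invocations([Location(0, 0), Location(1, 0), Location(0, 1)]):
    print(inv)
```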