Unverified commit a8dd9254, authored by msbaines, committed by GitHub

[refactor] pipe: move async-specific code out of MultiProcessPipe (#344)

parent e348806b
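In brief: the async-schedule code paths move out of MultiProcessPipe into a new AsyncPipe subclass, which overrides create_pipeline() and instantiate_partition(). A minimal usage sketch of the resulting split (the model, balance, chunks, and worker_map values are illustrative, and a pipeline-parallel process group is assumed to be initialized already, as in the repo's tests):

import torch.nn as nn
from fairscale.nn.pipe import AsyncPipe, MultiProcessPipe

model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 4))
worker_map = {0: "worker0", 1: "worker1"}  # rank -> RPC worker name (illustrative)

# The base class now always builds a PipelineStyle.MultiProcess pipeline ...
pipe = MultiProcessPipe(model, balance=[2, 1], worker_map=worker_map, chunks=4)

# ... while the async schedule lives in AsyncPipe, whose __init__ forwards
# style=PipelineStyle.AsyncSchedule to the base constructor.
async_pipe = AsyncPipe(model, balance=[2, 1], worker_map=worker_map, chunks=4)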
@@ -3,10 +3,153 @@
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from .multiprocess_pipe import MultiProcessPipe
from .types import PipelineStyle
from collections import OrderedDict
from dataclasses import dataclass, field
import itertools
from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Tuple, Union
import torch
from torch import Tensor, nn
from .async_schedule import Invocation, Location, ModuleWrapper
from .multiprocess_pipe import MultiProcessPipe, check_balance
from .multiprocess_pipeline import MultiProcessPipeline
from .skip.skippable import Skippable
from .types import LazyModule, PipelineStyle
if TYPE_CHECKING:
Module = nn.Module[TensorOrTensors]
NamedModules = OrderedDict[str, Module]
else:
Module = nn.Module
NamedModules = OrderedDict
Tensors = Tuple[Tensor, ...]
TensorOrTensors = Union[Tensor, Tensors]
@dataclass
class PartitionInfo:
location: Location
modules: "OrderedDict[str, nn.Module]"
invocations: List[Invocation] = field(default_factory=list)
def __len__(self) -> int:
return len(self.modules)
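PartitionInfo pairs a pipeline Location with the ordered modules occupying that slot; a minimal sketch using the dataclass above (the layer shape is illustrative):

from collections import OrderedDict
from torch import nn
from fairscale.nn.pipe.async_schedule import Location

info = PartitionInfo(Location(0, 0), OrderedDict([("fc", nn.Linear(4, 4))]))
assert len(info) == 1           # __len__ counts the wrapped modules
assert info.invocations == []   # populated later by instantiate_partition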
class AsyncPipe(MultiProcessPipe):
def __init__(self, *args, **kwargs) -> None: # type: ignore
super().__init__(*args, style=PipelineStyle.AsyncSchedule, **kwargs)
def create_pipeline(self) -> None:
# The micro-batch index where the checkpointing stops.
checkpoint_stop = {"always": self.chunks, "except_last": self.chunks - 1, "never": 0}[self.checkpoint]
self.pipeline = MultiProcessPipeline(
self.partitions,
self._skip_layout,
checkpoint_stop,
style=PipelineStyle.AsyncSchedule,
group=self.group,
worker_map=self.worker_map,
input_device=self.input_device,
final_stage=self.final_stage,
)
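For reference, checkpoint_stop is the index of the first micro-batch that is not checkpointed; with chunks=4 the lookup above yields:

chunks = 4
checkpoint_stop = {"always": chunks, "except_last": chunks - 1, "never": 0}
assert checkpoint_stop["always"] == 4       # micro-batches 0..3 all checkpointed
assert checkpoint_stop["except_last"] == 3  # 0..2 checkpointed, the last is not
assert checkpoint_stop["never"] == 0        # nothing checkpointed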
def instantiate_partition(
self,
module: Union[nn.Sequential, List[LazyModule]],
balance: Iterable[int],
group: torch.distributed.ProcessGroup,
) -> List[ModuleWrapper]:
balance = list(balance)
check_balance(module, balance, True)
layers: NamedModules = OrderedDict()
def maybe_realize(layer: Any) -> nn.Module:
if isinstance(layer, nn.Module):
return layer
elif callable(layer):
return layer()
else:
raise TypeError(f"layer must be nn.Module or callable, is {type(layer)}")
def iterate_module(module: Union[nn.Sequential, list]) -> Iterable[Tuple[Any, nn.Module]]:
if isinstance(module, nn.Sequential):
yield from module.named_children()
else:
yield from ((str(k), v) for k, v in enumerate(module))
module_ids = list(map(id, module))
index_of_first_use = [module_ids.index(x) for x in module_ids]
locations: List[Location] = []
module_iter = enumerate(iterate_module(module))
partitions: List[List[PartitionInfo]] = []
for bi, b in enumerate(balance):
modules_for_rank: List[PartitionInfo] = []
current_module: OrderedDict[str, nn.Module] = OrderedDict()
def current_location() -> Location:
return Location(bi, len(modules_for_rank))
def append_module(mod: "OrderedDict[str, nn.Module]") -> None:
modules_for_rank.append(PartitionInfo(current_location(), mod))
while sum(map(len, modules_for_rank)) + len(current_module) < b:
module_index, (name, layer) = next(module_iter)
if index_of_first_use[module_index] != module_index:
# Subsequent reuse of a module
locations.append(locations[index_of_first_use[module_index]])
continue
is_reused = index_of_first_use.count(index_of_first_use[module_index]) > 1
if is_reused and len(current_module) > 0:
append_module(current_module)
current_module = OrderedDict()
current_module[str(name)] = layer
locations.append(current_location())
if is_reused:
append_module(current_module)
current_module = OrderedDict()
if len(current_module) > 0:
append_module(current_module)
partitions.append(modules_for_rank)
filtered_locations: List[Optional[Location]] = [loc for loc, _ in itertools.groupby(locations)]
filtered_locations.append(None)
for i in range(len(filtered_locations) - 1):
loc = filtered_locations[i]
assert loc
if i == 0:
inv = Invocation(i, loc, None, filtered_locations[i + 1])
else:
inv = Invocation(i, loc, filtered_locations[i - 1], filtered_locations[i + 1])
partitions[loc.stage][loc.index].invocations.append(inv)
invocations = enumerate(iterate_module(module))
partition = partitions[group.rank()]
result: List[ModuleWrapper] = []
for partition_info in partition:
wrapper = ModuleWrapper(
nn.Sequential(OrderedDict((k, maybe_realize(m)) for k, m in partition_info.modules.items())),
partition_info.location,
partition_info.invocations,
)
if not isinstance(module, nn.Sequential):
for layer in wrapper.module:
if isinstance(layer, Skippable):
raise ValueError("Can't use Skippable layers with multi-process pipe and lazy construction")
result.append(wrapper)
return result
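To see the first-use bookkeeping above in isolation, a standalone sketch (the layers are placeholders, not from this diff):

import torch.nn as nn

shared = nn.Linear(4, 4)
layers = [shared, nn.ReLU(), shared]  # the first layer appears twice
module_ids = list(map(id, layers))
index_of_first_use = [module_ids.index(x) for x in module_ids]
assert index_of_first_use == [0, 1, 0]
# Entry 2 maps back to index 0, so the loop records the existing Location
# for it and skips ahead; a reused module is also flushed into its own
# PartitionInfo, so every invocation of it targets the same ModuleWrapper.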
@@ -19,10 +19,8 @@
"""The MultiProcessPipe interface."""
from collections import OrderedDict
from dataclasses import dataclass, field
import itertools
import threading
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, Union, cast
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, Union
import warnings
import torch
@@ -33,7 +31,7 @@ import torch.cuda
from fairscale.nn.model_parallel import get_model_parallel_world_size, get_pipeline_parallel_group
from . import microbatch
from .async_schedule import Invocation, Location, ModuleWrapper
from .async_schedule import Location, ModuleWrapper
from .batchnorm import DeferredBatchNorm
from .multiprocess_pipeline import MultiProcessPipeline
from .phony import get_phony
@@ -47,8 +45,6 @@ __all__ = ["MultiProcessPipe", "LazyModule"]
Tensors = Tuple[Tensor, ...]
TensorOrTensors = Union[Tensor, Tensors]
ListOfLazyModules = List[LazyModule]
if TYPE_CHECKING:
Module = nn.Module[TensorOrTensors]
NamedModules = OrderedDict[str, Module]
@@ -87,7 +83,7 @@ def verify_list_of_callable(module: Union[nn.Sequential, list]) -> None:
raise TypeError(f"layer {type(layer)} must be nn.Module or LazyModule to be partitioned")
def verify_module(module: Union[nn.Sequential, ListOfLazyModules]) -> None:
def verify_module(module: Union[nn.Sequential, List[LazyModule]]) -> None:
if isinstance(module, Iterable) and not isinstance(module, nn.Sequential):
verify_list_of_callable(module)
else:
@@ -135,145 +131,11 @@ def check_balance(module: Any, balance: Iterable[int], filter_unique: bool = Fal
raise BalanceError(f"all balance numbers must be positive integer (balance: {balance})")
@dataclass
class PartitionInfo:
location: Location
modules: "OrderedDict[str, nn.Module]"
invocations: List[Invocation] = field(default_factory=list)
def __len__(self) -> int:
return len(self.modules)
def instantiate_partition(
module: Union[nn.Sequential, ListOfLazyModules],
balance: Iterable[int],
group: torch.distributed.ProcessGroup,
style: PipelineStyle,
) -> List[ModuleWrapper]:
balance = list(balance)
check_balance(module, balance, True)
layers: NamedModules = OrderedDict()
def maybe_realize(layer: Any) -> nn.Module:
if isinstance(layer, nn.Module):
return layer
elif callable(layer):
return layer()
else:
raise TypeError(f"layer must be nn.Module or callable, is {type(layer)}")
def iterate_module(module: Union[nn.Sequential, list]) -> Iterable[Tuple[Any, nn.Module]]:
if isinstance(module, nn.Sequential):
yield from module.named_children()
else:
yield from ((str(k), v) for k, v in enumerate(module))
if style == PipelineStyle.AsyncSchedule:
module_ids = list(map(id, module))
index_of_first_use = [module_ids.index(x) for x in module_ids]
locations: List[Location] = []
module_iter = enumerate(iterate_module(module))
partitions: List[List[PartitionInfo]] = []
for bi, b in enumerate(balance):
modules_for_rank: List[PartitionInfo] = []
current_module: OrderedDict[str, nn.Module] = OrderedDict()
def current_location() -> Location:
return Location(bi, len(modules_for_rank))
def append_module(mod: "OrderedDict[str, nn.Module]") -> None:
modules_for_rank.append(PartitionInfo(current_location(), mod))
while sum(map(len, modules_for_rank)) + len(current_module) < b:
module_index, (name, layer) = next(module_iter)
if index_of_first_use[module_index] != module_index:
# Subsequent reuse of a module
locations.append(locations[index_of_first_use[module_index]])
continue
is_reused = index_of_first_use.count(index_of_first_use[module_index]) > 1
if is_reused and len(current_module) > 0:
append_module(current_module)
current_module = OrderedDict()
current_module[str(name)] = layer
locations.append(current_location())
if is_reused:
append_module(current_module)
current_module = OrderedDict()
if len(current_module) > 0:
append_module(current_module)
partitions.append(modules_for_rank)
filtered_locations: List[Optional[Location]] = [loc for loc, _ in itertools.groupby(locations)]
filtered_locations.append(None)
for i in range(len(filtered_locations) - 1):
loc = filtered_locations[i]
assert loc
if i == 0:
inv = Invocation(i, loc, None, filtered_locations[i + 1])
else:
inv = Invocation(i, loc, filtered_locations[i - 1], filtered_locations[i + 1])
partitions[loc.stage][loc.index].invocations.append(inv)
invocations = enumerate(iterate_module(module))
partition = partitions[group.rank()]
result: List[ModuleWrapper] = []
for partition_info in partition:
wrapper = ModuleWrapper(
nn.Sequential(OrderedDict((k, maybe_realize(m)) for k, m in partition_info.modules.items())),
partition_info.location,
partition_info.invocations,
)
if not isinstance(module, nn.Sequential):
for layer in wrapper.module:
if isinstance(layer, Skippable):
raise ValueError("Can't use Skippable layers with multi-process pipe and lazy construction")
result.append(wrapper)
return result
j = 0
for name, layer in iterate_module(module):
layers[name] = layer
if len(layers) == balance[j]:
if j == group.rank():
for key in layers:
layers[key] = maybe_realize(layers[key])
if not isinstance(module, nn.Sequential):
for layer in layers.values():
if isinstance(layer, Skippable):
raise ValueError("Can't use Skippable layers with multi-process pipe and lazy construction")
return [ModuleWrapper(nn.Sequential(layers), Location(j, 0))]
# Prepare for the next partition.
layers.clear()
j += 1
raise ValueError("Souldn't get here, more ranks than partitions")
def split_module(module: nn.Sequential, balance: Iterable[int],) -> Tuple[List[nn.Sequential], List[int]]:
def split_module(module: nn.Sequential, balance: Iterable[int],) -> List[nn.Sequential]:
"""Splits a module into multiple partitions.
Returns:
partitions
Partitions are represented as a :class:`~torch.nn.ModuleList` whose
item is a partition. All layers in a partition are placed in the
@@ -307,8 +169,7 @@ def split_module(module: nn.Sequential, balance: Iterable[int],) -> Tuple[List[n
layers.clear()
j += 1
partitions = cast(List[nn.Sequential], nn.ModuleList(partitions))
return partitions, balance
return partitions
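Call sites drop the tuple unpacking now that the balance is no longer returned; a minimal sketch (the model is illustrative):

import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 2))
partitions = split_module(model, [2, 1])  # previously: partitions, balance = ...
assert len(partitions) == 2
assert len(partitions[0]) == 2 and len(partitions[1]) == 1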
MOVING_DENIED = TypeError("denied to move parameters and buffers, because Pipe should manage device placement")
@@ -415,10 +276,9 @@ class MultiProcessPipe(Module):
def __init__(
self,
module: Union[nn.Sequential, ListOfLazyModules],
module: Union[nn.Sequential, List[LazyModule]],
balance: Optional[Iterable[int]] = None,
*,
style: PipelineStyle = PipelineStyle.MultiProcess,
group: Optional[torch.distributed.ProcessGroup] = None,
worker_map: Optional[Dict[int, str]] = None,
input_device: Union[None, int, str, torch.device] = None,
@@ -427,7 +287,6 @@ class MultiProcessPipe(Module):
deferred_batch_norm: bool = False,
pipelined_backward: bool = None,
retain_graph: bool = False,
loss_fn: Optional[nn.Module] = None,
) -> None:
super().__init__()
@@ -453,19 +312,16 @@
self.pipelined_backward = pipelined_backward
self.retain_graph = retain_graph
self.pipeline: Optional[MultiProcessPipeline]
self.loss_fn = loss_fn
self.lock = threading.Lock()
self.group = group
self.worker_map = worker_map
self.input_device = input_device
# The micro-batch index where the checkpointing stops.
checkpoint_stop = {"always": self.chunks, "except_last": self.chunks - 1, "never": 0}[self.checkpoint]
if self.group is None:
self.group: torch.distributed.ProcessGroup
if group is None:
self.group = get_pipeline_parallel_group()
assert self.group
else:
self.group = group
self.balance = list(balance)
@@ -480,14 +336,14 @@
warnings.warn("More ranks than partitions, some ranks unused")
self.partitions: List[ModuleWrapper] = []
else:
self.partitions = instantiate_partition(module, balance, self.group, style)
self.partitions = self.instantiate_partition(module, balance, self.group)
if deferred_batch_norm:
for part in self.partitions:
part.module = DeferredBatchNorm.convert_deferred_batch_norm(part.module, chunks)
for name, part in enumerate(self.partitions):
self.add_module(str(name), part.module)
if isinstance(module, nn.Sequential):
local_partitions, _ = split_module(module, balance)
local_partitions = split_module(module, balance)
self._skip_layout = inspect_skip_layout(local_partitions)
else:
self._skip_layout = SkipLayout(len(module), {}) # FIXME(tom)
@@ -501,18 +357,8 @@
self.final_stage = False
else:
self.final_stage = rank == len(self.balance) - 1
assert loss_fn is None or self.final_stage
self.pipeline = MultiProcessPipeline(
cast(List[nn.Sequential], self.partitions),
self._skip_layout,
checkpoint_stop,
style=style,
group=self.group,
worker_map=self.worker_map,
input_device=self.input_device,
final_stage=self.final_stage,
)
self.create_pipeline()
del module
if self.pipelined_backward is None:
if get_model_parallel_world_size() > 1:
@@ -520,6 +366,70 @@
else:
self.pipelined_backward = False
def create_pipeline(self) -> None:
# The micro-batch index where the checkpointing stops.
checkpoint_stop = {"always": self.chunks, "except_last": self.chunks - 1, "never": 0}[self.checkpoint]
self.pipeline = MultiProcessPipeline(
self.partitions,
self._skip_layout,
checkpoint_stop,
style=PipelineStyle.MultiProcess,
group=self.group,
worker_map=self.worker_map,
input_device=self.input_device,
final_stage=self.final_stage,
)
def instantiate_partition(
self,
module: Union[nn.Sequential, List[LazyModule]],
balance: Iterable[int],
group: torch.distributed.ProcessGroup,
) -> List[ModuleWrapper]:
balance = list(balance)
check_balance(module, balance, True)
layers: NamedModules = OrderedDict()
def maybe_realize(layer: Any) -> nn.Module:
if isinstance(layer, nn.Module):
return layer
elif callable(layer):
return layer()
else:
raise TypeError(f"layer must be nn.Module or callable, is {type(layer)}")
def iterate_module(module: Union[nn.Sequential, list]) -> Iterable[Tuple[Any, nn.Module]]:
if isinstance(module, nn.Sequential):
yield from module.named_children()
else:
yield from ((str(k), v) for k, v in enumerate(module))
j = 0
for name, layer in iterate_module(module):
layers[name] = layer
if len(layers) == balance[j]:
if j == group.rank():
for key in layers:
layers[key] = maybe_realize(layers[key])
if not isinstance(module, nn.Sequential):
for layer in layers.values():
if isinstance(layer, Skippable):
raise ValueError(
"Can't use Skippable layers with multi-process pipe and lazy construction"
)
return [ModuleWrapper(nn.Sequential(layers), Location(j, 0))]
# Prepare for the next partition.
layers.clear()
j += 1
raise ValueError("Souldn't get here, more ranks than partitions")
def __len__(self) -> int:
"""Counts the length of the underlying sequential module."""
return sum(len(p) for p in self.partitions)
@@ -23,7 +23,7 @@ from queue import Empty as QueueEmpty
from queue import Queue
from threading import Event
from types import TracebackType
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Type, Union, cast
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Type, Union
import torch
from torch import Tensor, nn
@@ -171,7 +171,7 @@ class MultiProcessPipeline:
def __init__(
self,
partitions: List[nn.Sequential],
partitions: List[ModuleWrapper],
skip_layout: SkipLayout,
checkpoint_stop: int,
style: PipelineStyle,
@@ -180,7 +180,7 @@
input_device: Union[None, int, str, torch.device] = None,
final_stage: bool = False,
) -> None:
self.partitions: List[ModuleWrapper] = cast(List[ModuleWrapper], partitions)
self.partitions = partitions
self.skip_layout = skip_layout
self.__checkpoint_stop = checkpoint_stop
self.style = style
@@ -32,7 +32,6 @@ from fairscale.nn.model_parallel.initialize import (
initialize_model_parallel,
)
from fairscale.nn.pipe import AsyncPipe, LazyModule, MultiProcessPipe
from fairscale.nn.pipe.types import PipelineStyle
from fairscale.utils.testing import get_worker_map, set_random_seed, torch_spawn, torch_version
@@ -874,9 +873,12 @@ def reuse_lazy():
assert torch.equal(model_out, pipe_out)
def test_instantiate_partition():
@torch_spawn([1])
def instantiate_partition():
from fairscale.nn.pipe.async_schedule import Location
from fairscale.nn.pipe.multiprocess_pipe import instantiate_partition
model = nn.Sequential(nn.Linear(1, 1))
pipe = AsyncPipe(model, balance=[1], worker_map=get_worker_map(), chunks=1)
class FakeGroup:
def __init__(self, rank, size):
@@ -904,9 +906,7 @@
# Collect `Invocation` and `Invocation` -> `ModuleWrapper` mapping from
# instantiated model
for rank in range(len(balance)):
instantiated = instantiate_partition(
model, balance, FakeGroup(rank, len(balance)), PipelineStyle.AsyncSchedule
)
instantiated = pipe.instantiate_partition(model, balance, FakeGroup(rank, len(balance)))
for part in instantiated:
assert isinstance(part.module, nn.Sequential)
for inv in part.invocations:
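Since instantiate_partition only calls group.rank() on the group it receives, the test's FakeGroup needs little more than that; a reconstructed sketch of its likely shape (hypothetical, not the verbatim test body, which the hunk truncates):

class FakeGroup:
    # Minimal stand-in for torch.distributed.ProcessGroup in this test.
    def __init__(self, rank: int, size: int) -> None:
        self._rank = rank
        self._size = size

    def rank(self) -> int:
        return self._rank

    def size(self) -> int:
        return self._size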