Unverified Commit f21b5ffc authored by msbaines, committed by GitHub

[refactor] pipe: simplify balance and module checks (#346)

parent cd186441
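
For readers skimming the diff: the commit collapses the BalanceError-based checks into a single check_balance(module, balance) that raises plain ValueError. The snippet below is an illustrative sketch of that behaviour only; the import path is assumed from the diff (check_balance is defined in fairscale/nn/pipe/multiprocess_pipe.py) and the exact messages are quoted from the code further down.

import torch.nn as nn
# Assumed import path, based on the import change in async_pipe.py below.
from fairscale.nn.pipe.multiprocess_pipe import check_balance

module = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 2))

check_balance(module, [2, 1])       # ok: 3 layers, balance sums to 3

try:
    check_balance(module, [2, 2])   # balance sums to 4, but module has 3 layers
except ValueError as exc:
    print(exc)                      # "module and sum of balance have different length ..."

try:
    check_balance(module, [3, 0])   # zero or negative entries are rejected
except ValueError as exc:
    print(exc)                      # "all balance numbers must be positive integer ..."
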
@@ -13,7 +13,7 @@ from torch import Tensor, nn
from .async_pipeline import AsyncPipeline
from .async_schedule import Invocation, Location, ModuleWrapper
from .multiprocess_pipe import MultiProcessPipe, check_balance
from .multiprocess_pipe import MultiProcessPipe
from .skip.skippable import Skippable
from .types import LazyModule
@@ -54,14 +54,8 @@ class AsyncPipe(MultiProcessPipe):
)
def instantiate_partition(
self,
module: Union[nn.Sequential, List[LazyModule]],
balance: Iterable[int],
group: torch.distributed.ProcessGroup,
self, module: Union[nn.Sequential, List[LazyModule]], balance: List[int], group: torch.distributed.ProcessGroup,
) -> List[ModuleWrapper]:
balance = list(balance)
check_balance(module, balance, True)
layers: NamedModules = OrderedDict()
def maybe_realize(layer: Any) -> nn.Module:
@@ -53,85 +53,22 @@ else:
NamedModules = OrderedDict
def recommend_auto_balance(message: str) -> str:
"""Expands a message with recommendation to :mod:`torchpipe.balance`."""
return f"""{message}
If your model is still under development, its optimal balance would change
frequently. In this case, we highly recommend 'fairscale.nn.pipe.balance' for
naive automatic balancing:
from fairscale.nn import Pipe
from fairscale.nn.pipe.balance import balance_by_time
partitions = torch.cuda.device_count()
sample = torch.empty(...)
balance = balance_by_time(partitions, model, sample)
model = MultiProcessPipe(model, balance, ...)
"""
# FIXME(tom) make this a valid way to call
def verify_list_of_callable(module: Union[nn.Sequential, list]) -> None:
for layer in module:
if isinstance(layer, nn.Module):
pass
elif isinstance(layer, LazyModule):
pass
else:
raise TypeError(f"layer {type(layer)} must be nn.Module or LazyModule to be partitioned")
def verify_module(module: Union[nn.Sequential, List[LazyModule]]) -> None:
if isinstance(module, Iterable) and not isinstance(module, nn.Sequential):
verify_list_of_callable(module)
else:
if not isinstance(module, nn.Sequential):
raise TypeError("module must be nn.Sequential to be partitioned")
if len(set(map(id, module))) != len(module):
raise ValueError("module with duplicate children is not supported")
named_children = list(module.named_children())
if len(named_children) != len(module):
raise ValueError("module with duplicate children is not supported")
def verify_splitting(module: nn.Sequential, partitions: List[nn.Sequential], balance: Iterable[int],) -> None:
num_parameters = len(list(module.parameters()))
num_child_parameters = sum(len(list(child.parameters())) for child in module.children())
if num_parameters == num_child_parameters:
return
for i in range(len(partitions)):
for j in range(i + 1, len(partitions)):
parti = partitions[i]
partj = partitions[j]
for p in parti.parameters():
for q in partj.parameters():
if p is q:
raise ValueError("module with duplicate parameters on distinct devices is not supported")
class BalanceError(ValueError):
pass
def check_balance(module: Any, balance: Iterable[int], filter_unique: bool = False) -> None:
if filter_unique:
module_len = len(set(map(id, module)))
else:
module_len = len(module)
if module_len != sum(balance):
raise BalanceError(
def check_balance(module: Union[nn.Sequential, List[LazyModule]], balance: List[int]) -> None:
if len(module) != sum(balance):
raise ValueError(
f"module and sum of balance have different length (module: {len(module)}, sum of balance: {sum(balance)})"
)
if any(x <= 0 for x in balance):
raise BalanceError(f"all balance numbers must be positive integer (balance: {balance})")
raise ValueError(f"all balance numbers must be positive integer (balance: {balance})")
def split_module(module: nn.Sequential, balance: Iterable[int],) -> List[nn.Sequential]:
def split_module(module: nn.Sequential, balance: List[int]) -> List[nn.Sequential]:
"""Splits a module into multiple partitions.
Returns:
@@ -148,10 +85,6 @@ def split_module(module: nn.Sequential, balance: Iterable[int],) -> List[nn.Sequ
the number of devices is fewer than the number of partitions.
"""
balance = list(balance)
check_balance(module, balance)
j = 0
partitions = []
layers: NamedModules = OrderedDict()
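
To make the split semantics concrete, here is a small usage sketch of split_module, under the assumption (consistent with the balance semantics in this file) that layers are assigned to partitions in order, balance[i] layers per partition; the import path is assumed, as split_module is defined in fairscale/nn/pipe/multiprocess_pipe.py.

import torch.nn as nn
# Assumed import path for the function shown in this hunk.
from fairscale.nn.pipe.multiprocess_pipe import split_module

model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 2))
partitions = split_module(model, [2, 1])

assert len(partitions) == 2      # one nn.Sequential per balance entry
assert len(partitions[0]) == 2   # Linear + ReLU
assert len(partitions[1]) == 1   # final Linear
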
@@ -274,7 +207,7 @@ class MultiProcessPipe(Module):
def __init__(
self,
module: Union[nn.Sequential, List[LazyModule]],
balance: Optional[Iterable[int]] = None,
balance: Iterable[int],
*,
group: Optional[torch.distributed.ProcessGroup] = None,
worker_map: Optional[Dict[int, str]] = None,
@@ -290,14 +223,14 @@
chunks = int(chunks)
checkpoint = str(checkpoint)
if balance is None:
raise ValueError(recommend_auto_balance("balance is required"))
if chunks <= 0:
raise ValueError("number of chunks must be positive integer")
if checkpoint not in ["always", "except_last", "never"]:
raise ValueError("checkpoint is not one of 'always', 'except_last', or 'never'")
self.balance = list(balance)
verify_module(module)
check_balance(module, self.balance)
# Verify if the underlying skippable modules satisfy integrity. The
# integrity can be verified before forward() because it is static.
@@ -320,33 +253,28 @@ else:
else:
self.group = group
self.balance = list(balance)
if self.group.size() < len(self.balance):
raise IndexError(
f"too few ranks to hold given partitions (ranks: {self.group.size()}, partitions:"
f" {len(self.balance)})"
)
try:
rank = self.group.rank()
if rank >= len(self.balance):
warnings.warn("More ranks than partitions, some ranks unused")
self.partitions: List[ModuleWrapper] = []
else:
self.partitions = self.instantiate_partition(module, balance, self.group)
if deferred_batch_norm:
for part in self.partitions:
part.module = DeferredBatchNorm.convert_deferred_batch_norm(part.module, chunks)
for name, part in enumerate(self.partitions):
self.add_module(str(name), part.module)
if isinstance(module, nn.Sequential):
local_partitions = split_module(module, balance)
self._skip_layout = inspect_skip_layout(local_partitions)
else:
self._skip_layout = SkipLayout(len(module), {}) # FIXME(tom)
except BalanceError as exc:
raise ValueError(recommend_auto_balance(str(exc)))
rank = self.group.rank()
if rank >= len(self.balance):
warnings.warn("More ranks than partitions, some ranks unused")
self.partitions: List[ModuleWrapper] = []
else:
self.partitions = self.instantiate_partition(module, self.balance, self.group)
if deferred_batch_norm:
for part in self.partitions:
part.module = DeferredBatchNorm.convert_deferred_batch_norm(part.module, chunks)
for name, part in enumerate(self.partitions):
self.add_module(str(name), part.module)
if isinstance(module, nn.Sequential):
local_partitions = split_module(module, self.balance)
self._skip_layout = inspect_skip_layout(local_partitions)
else:
self._skip_layout = SkipLayout(len(module), {}) # FIXME(tom)
rank = self.group.rank()
if rank >= len(self.balance):
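
A minimal construction sketch reflecting the updated signature, where balance is required and validated in __init__. Distributed initialization is assumed to have happened already (as in the test harness later in this diff), and the worker_map shown here is a hypothetical placeholder.

import torch.nn as nn
from fairscale.nn.pipe import MultiProcessPipe

# Hypothetical rank -> worker-name mapping; in the tests this comes from
# fairscale.utils.testing.get_worker_map() after torch.distributed is initialized.
worker_map = {0: "worker0", 1: "worker1"}

model = nn.Sequential(nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 5))

# Two pipeline stages: the first two layers on rank 0, the last layer on rank 1.
# check_balance() now runs inside __init__ and raises ValueError if the balance
# does not match the number of layers or contains non-positive entries.
pipe = MultiProcessPipe(model, balance=[2, 1], worker_map=worker_map, chunks=2)
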
@@ -378,14 +306,8 @@ class MultiProcessPipe(Module):
)
def instantiate_partition(
self,
module: Union[nn.Sequential, List[LazyModule]],
balance: Iterable[int],
group: torch.distributed.ProcessGroup,
self, module: Union[nn.Sequential, List[LazyModule]], balance: List[int], group: torch.distributed.ProcessGroup,
) -> List[ModuleWrapper]:
balance = list(balance)
check_balance(module, balance, True)
layers: NamedModules = OrderedDict()
def maybe_realize(layer: Any) -> nn.Module:
@@ -32,7 +32,7 @@ from fairscale.nn.model_parallel.initialize import (
initialize_model_parallel,
)
from fairscale.nn.pipe import AsyncPipe, LazyModule, MultiProcessPipe
from fairscale.utils.testing import get_worker_map, set_random_seed, torch_spawn, torch_version
from fairscale.utils.testing import get_worker_map, torch_spawn, torch_version
@torch_spawn([2])
@@ -706,15 +706,11 @@ def named_children(pipe_class):
@torch_spawn([1])
@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
def recommend_auto_balance(pipe_class):
with pytest.raises(ValueError, match="fairscale.nn.pipe.balance"):
# balance is required
pipe_class(nn.Sequential())
with pytest.raises(ValueError, match="fairscale.nn.pipe.balance"):
with pytest.raises(ValueError):
# module and sum of balance have different length (module: 0, sum of balance: 1)
pipe_class(nn.Sequential(), [1])
with pytest.raises(ValueError, match="fairscale.nn.pipe.balance"):
with pytest.raises(ValueError):
# module and sum of balance have different length (module: 2, sum of balance: 1)
pipe_class(nn.Sequential(nn.Linear(1, 1), nn.Linear(1, 1)), [1])
@@ -805,174 +801,3 @@ def async_event_loop():
if pipe.final_stage:
loss = output.mean()
loss.backward()
@torch_spawn([4])
def reuse_lazy():
if False: # speed
reused = LazyModule(lambda: nn.Linear(10, 10))
model = [reused, nn.Linear(10, 10), nn.ReLU(), reused, nn.ReLU(), reused, nn.ReLU()]
# model = [reused, reused, nn.Linear(10, 10), nn.ReLU(), reused, reused, nn.ReLU(), reused, reused, nn.ReLU()]
pipe = AsyncPipe(model, [3, 1, 1], worker_map=get_worker_map())
pipe.eval()
output = pipe(torch.rand(10))
print(f"output on {pipe.group.rank()}, {output}")
torch.distributed.barrier()
set_random_seed(1234)
# test both forward
reused = nn.Linear(10, 10)
layers = [reused, nn.Linear(10, 10), nn.ReLU(), reused, nn.ReLU(), reused, nn.ReLU()]
model = nn.Sequential(*layers)
model.eval()
set_random_seed(1234)
# ensure identical weights but no sharing between model and pipe
reused = nn.Linear(10, 10)
layers = [reused, nn.Linear(10, 10), nn.ReLU(), reused, nn.ReLU(), reused, nn.ReLU()]
pipe = AsyncPipe(layers, [3, 1, 1], worker_map=get_worker_map())
pipe.eval()
model_optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
pipe_optimizer = torch.optim.SGD(pipe.parameters(), lr=0.01, momentum=0.9) if len(list(pipe.parameters())) else None
inputs = torch.rand(10)
if False: # speed
model_out = model(inputs)
pipe_out = pipe(inputs)
torch.distributed.barrier()
if pipe.final_stage:
assert torch.equal(model_out, pipe_out)
model.train()
pipe.train()
model_out = model(inputs)
pipe_out = pipe(inputs)
if pipe.final_stage:
pipe_loss = pipe_out.mean()
pipe_loss.backward()
model_loss = model_out.mean()
model_loss.backward()
model_optimizer.step()
if pipe_optimizer:
pipe_optimizer.step()
model.eval()
pipe.eval()
model_out = model(inputs)
pipe_out = pipe(inputs)
print(f"before barrier on {torch.distributed.get_rank()}")
torch.distributed.barrier()
print(f"after barrier on {torch.distributed.get_rank()}")
if pipe.final_stage:
assert torch.equal(model_out, pipe_out)
@torch_spawn([1])
def instantiate_partition():
from fairscale.nn.pipe.async_schedule import Location
model = nn.Sequential(nn.Linear(1, 1))
pipe = AsyncPipe(model, balance=[1], worker_map=get_worker_map(), chunks=1)
class FakeGroup:
def __init__(self, rank, size):
self._rank = rank
self._size = size
def rank(self):
return self._rank
def size(self):
return self._size
def check_partitions(model, balance, expected_order, expected_ranks):
"""Check the instantiated model matches expectation of order and rank
model: a list of modules or an nn.Sequential
balance: the balance argument to MultiProcessPipe
expected_order: the index of modules in `model` in the order they will
be executed, grouped by nn.Sequential
expected_rank: the rank that each module will be executed on
"""
invocations = []
invocation_wrapper = dict()
# Collect `Invocation` and `Invocation` -> `ModuleWrapper` mapping from
# instantiated model
for rank in range(len(balance)):
instantiated = pipe.instantiate_partition(model, balance, FakeGroup(rank, len(balance)))
for part in instantiated:
assert isinstance(part.module, nn.Sequential)
for inv in part.invocations:
invocations.append(inv)
invocation_wrapper[inv] = part
modules = []
prev = None
current = Location(0, 0)
ranks = []
for order, inv in enumerate(sorted(invocations, key=lambda x: x.order)):
# Check integrity of Location chain
assert inv.order == order
assert inv.source == prev
assert inv.this == current
prev = inv.this
current = inv.dest
modules.append(list(invocation_wrapper[inv].module.children()))
ranks.append(inv.this.stage)
# assert len(modules) == len(expected_order)
for left, right in zip(modules, expected_order):
assert len(left) == len(right), f"{right}"
assert list(map(id, left)) == list(map(id, (model[e] for e in right))), f"{right}"
assert ranks == expected_ranks
reused = nn.Linear(20, 20)
model = [reused, nn.Linear(10, 10), nn.ReLU(), reused, nn.ReLU(), reused, nn.ReLU()]
balance = [3, 1, 1]
check_partitions(
model, balance, expected_order=[[0], [1, 2], [0], [4], [0], [6]], expected_ranks=[0, 0, 0, 1, 0, 2]
)
reused2 = nn.Linear(5, 5)
model = [reused, reused2, nn.Linear(10, 10), nn.ReLU(), reused, reused2, nn.ReLU(), reused, reused2, nn.ReLU()]
balance = [4, 1, 1]
check_partitions(
model,
balance,
expected_order=[[0], [1], [2, 3], [0], [1], [6], [0], [1], [9]],
expected_ranks=[0, 0, 0, 0, 0, 1, 0, 0, 2],
)
reused2 = nn.Linear(5, 5)
model = [
nn.Linear(10, 10),
reused,
nn.Linear(10, 10),
nn.ReLU(),
reused,
reused2,
nn.ReLU(),
reused,
reused2,
nn.ReLU(),
]
# 0 1 2 3 1 5 6 1 5 9
balance = [4, 2, 1]
check_partitions(
model,
balance,
expected_order=[[0], [1], [2, 3], [1], [5], [6], [1], [5], [9]],
expected_ranks=[0, 0, 0, 0, 1, 1, 0, 1, 2],
)