Unverified Commit 2eef71b9 authored by Min Xu's avatar Min Xu Committed by GitHub
Browse files

[doc] updating the pipe balance doc a bit (#243)

* [doc] updating the pipe balance doc a bit

- Also added a warning to pipeline.py when the partition output is not
supported.

* addressed Mandeep's comment
parent 138b2033
......@@ -23,13 +23,20 @@ Usage::
import torch
from fairscale.nn import Pipe
from fairscale.nn.balance import balance_by_time
from fairscale.nn.pipe.balance import balance_by_time
sample = torch.empty(128, 3, 224, 224)
balance = balance_by_time(torch.cuda.device_count(), model, sample)
pipe = Pipe(model, balance, chunks=8)
.. note::
balance_by_time does not work with in-place ReLU because we exhaustively search
every partition boundary, which could hit an in-place ReLU.
.. note::
If the model is larger than a single CUDA device memory, use "cpu"
in the balance_by_time function.
"""
from typing import List, Tuple, Union
......@@ -49,7 +56,7 @@ Tensors = Tuple[Tensor, ...]
TensorOrTensors = Union[Tensor, Tensors]
def balance_cost(cost: List[int], partitions: int) -> List[int]:
def _balance_cost(cost: List[int], partitions: int) -> List[int]:
partitioned = blockpartition.solve(cost, partitions)
return [len(p) for p in partitioned]
......@@ -87,14 +94,14 @@ def balance_by_time(
Returns:
A list of the number of layers in each partition. Use it for the `balance`
parameter of :class:`~torchpipe.Pipe`.
parameter of :class:`~fairscale.nn.Pipe`.
.. note::
`module` and `sample` must be placed on the same device.
"""
times = profile_times(module, sample, timeout, torch.device(device))
return balance_cost(times, partitions)
return _balance_cost(times, partitions)
def balance_by_size(
......@@ -167,11 +174,11 @@ def balance_by_size(
Returns:
A list of the number of layers in each partition. Use it for the `balance`
parameter of :class:`~torchpipe.Pipe`.
parameter of :class:`~fairscale.nn.Pipe`.
.. note::
`module` and `input` must be placed on the same CUDA device.
"""
sizes = profile_sizes(module, input, chunks, param_scale, torch.device(device))
return balance_cost(sizes, partitions)
return _balance_cost(sizes, partitions)
......@@ -180,7 +180,11 @@ def create_task(
part_id: int = j,
) -> TensorOrTensors:
with use_skip_tracker(skip_tracker), record_function("chunk%d-part%d" % (chunk_id, part_id)):
return partition(input)
ret = partition(input)
# We do a check here because the backtrace from the checkpoint backward code path
# is very hard to make sense of. It is much easier to catch the problem early, at this point.
assert type(ret) is not list, "Only Tensor or Tuple of Tensor output is supported"
return ret
chk = Checkpointing(function, batch)
if style is PipelineStyle.SingleProcess:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment