Unverified Commit 56506951 authored by anj-s, committed by GitHub

[offload] Add API, tutorial and smaller doc string changes. (#576)



* modify doc string

* add offload docs

* add tutorial

* remove print

* remove print statement

* modify import

* modify constants

* modify README and add Offload symbol

* fix lint

* smaller mods

* lint errors

* Update README.md

added the references at the bottom of the readme

* address comments

* doc changes

* add blank line
Co-authored-by: Anjali Sridhar <anj@devfair0443.h2.fair>
Co-authored-by: Vittorio Caggiano <caggiano@gmail.com>
parent 8f7ee69f
@@ -18,6 +18,7 @@ FairScale supports:
* Optimizer state sharding (`fairscale.optim.OSS`)
* Sharded Data Parallel (SDP) (`fairscale.nn.ShardedDataParallel`)
* Fully Sharded Data Parallel (FSDP) (`fairscale.nn.FullyShardedDataParallel`) (PyTorch >= 1.6)
* OffloadModel (`fairscale.experimental.nn.OffloadModel`)
* Optimization at scale:
* AdaScale SGD (`fairscale.optim.AdaScale`)
* GPU memory optimization:
@@ -192,3 +193,5 @@ Here is a list of all authors on relevant research papers this work is based on:
* AdaScale SGD: Tyler B. Johnson, Pulkit Agrawal, Haijie Gu, Carlos Guestrin. [[Paper](https://proceedings.icml.cc/static/paper_files/icml/2020/4682-Paper.pdf)]
* GShard: Dmitry Lepikhin, HyoukJoong Lee, Yuanzhong Xu, Dehao Chen, Orhan Firat, Yanping Huang, Maxim Krikun, Noam Shazeer, Zhifeng Chen [[Paper]](https://arxiv.org/abs/2006.16668)
* AMPNet: Alexander L. Gaunt, Matthew A. Johnson, Maik Riechert, Daniel Tarlow, Ryota Tomioka, Dimitrios Vytiniotis, Sam Webster [[Paper]](https://arxiv.org/abs/1705.09786)
* L2L: Training Large Neural Networks with Constant Memory using a New Execution Algorithm, 2020 [[Paper](https://arxiv.org/abs/2002.05645)]
* ZeRO-Offload: Democratizing Billion-Scale Model Training, 2021 [[Paper](https://arxiv.org/abs/2101.06840)]
OffloadModel
============

.. autoclass:: fairscale.experimental.nn.OffloadModel
    :members:
    :undoc-members:
@@ -12,3 +12,4 @@ API Reference
nn/fsdp
nn/fsdp_tips
nn/misc/checkpoint_activations
experimental/nn/offload_model
@@ -9,3 +9,5 @@ Tutorials
oss
adascale
offload_model
\ No newline at end of file
Training with `OffloadModel`
============================

The `fairscale.experimental.nn.offload.OffloadModel` API democratizes large scale distributed training by enabling
users to train large models on limited GPU resources that would traditionally have resulted in OOM errors.
The `OffloadModel` API wraps the given model and shards it into roughly equal parts. Each shard of the model is copied
from the CPU to the GPU for the forward pass and then copied back. The same process is repeated in reverse
order for the backward pass. `OffloadModel` supports mixed precision training, activation checkpointing to reduce
the memory footprint, and microbatching to amortize the cost of copying shards between the CPU and the GPU.

Note: We currently require the model to be an `nn.Sequential` model.
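
The data movement pattern can be illustrated with a short, purely conceptual sketch. This is not the actual
`OffloadModel` implementation (which also handles autograd, activation offloading and microbatching); the
helper name and arguments below are made up for illustration:

.. code-block:: python

    import torch

    def sharded_forward(shards, inputs, device=torch.device("cuda")):
        # `shards` is assumed to be a list of nn.Sequential slices kept on the CPU.
        for shard in shards:
            shard.to(device)                     # copy this shard's parameters CPU -> GPU
            inputs = shard(inputs.to(device))    # run the slice on the GPU
            shard.to(torch.device("cpu"))        # move the shard back, freeing GPU memory
        return inputs
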
Consider a training loop as described below:

.. code-block:: python

    import torch
    from torch.utils.data.dataloader import DataLoader
    from torchvision.datasets import FakeData
    from torchvision.transforms import ToTensor

    from fairscale.experimental.nn.offload import OffloadModel

    num_inputs = 8
    num_outputs = 8
    num_hidden = 4
    num_layers = 2
    batch_size = 8

    transform = ToTensor()
    dataloader = DataLoader(
        FakeData(
            image_size=(1, num_inputs, num_inputs),
            num_classes=num_outputs,
            transform=transform,
        ),
        batch_size=batch_size,
    )

    model = torch.nn.Sequential(
        torch.nn.Linear(num_inputs * num_inputs, num_hidden),
        *([torch.nn.Linear(num_hidden, num_hidden) for _ in range(num_layers)]),
        torch.nn.Linear(num_hidden, num_outputs),
    )

To use the `OffloadModel` API, wrap the model as shown below. You can specify the device used to compute
the forward and backward passes, the offload device on which the model is stored, and the number of
slices the model should be sharded into. By default, activation checkpointing is turned off and the number of microbatches is 1.

.. code-block:: python

    offload_model = OffloadModel(
        model=model,
        device=torch.device("cuda"),
        offload_device=torch.device("cpu"),
        num_slices=3,
        checkpoint_activation=True,
        num_microbatches=1,
    )

    torch.cuda.set_device(0)
    device = torch.device("cuda")

    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(offload_model.parameters(), lr=0.001)

    # To train 1 epoch.
    offload_model.train()
    for batch_inputs, batch_outputs in dataloader:
        batch_inputs, batch_outputs = batch_inputs.to("cuda"), batch_outputs.to("cuda")
        optimizer.zero_grad()
        inputs = batch_inputs.reshape(-1, num_inputs * num_inputs)
        with torch.cuda.amp.autocast():
            # Run the forward pass through the OffloadModel wrapper, not the original model.
            output = offload_model(inputs)
            loss = criterion(output, target=batch_outputs)
        # Run the backward pass outside the autocast context, as recommended for AMP.
        loss.backward()
        optimizer.step()
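
Since `OffloadModel` only accepts `nn.Sequential` models, a model written as a custom `nn.Module` with a purely
sequential forward pass has to be expressed as a flat sequence of layers before it can be wrapped. A minimal
sketch (the layer sizes below are made up for illustration):

.. code-block:: python

    import torch

    from fairscale.experimental.nn.offload import OffloadModel

    # Hypothetical example: express a small classifier as nn.Sequential so that
    # OffloadModel can slice it into shards.
    sequential_model = torch.nn.Sequential(
        torch.nn.Flatten(),             # replaces a manual reshape in a custom forward()
        torch.nn.Linear(28 * 28, 256),
        torch.nn.ReLU(),
        torch.nn.Linear(256, 10),
    )

    offload_sequential = OffloadModel(
        model=sequential_model,
        device=torch.device("cuda"),
        offload_device=torch.device("cpu"),
        num_slices=3,
    )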
@@ -5,4 +5,6 @@
from typing import List
from .offload import OffloadModel
__all__: List[str] = []
@@ -130,7 +130,7 @@ class ModelShard(nn.Module):
self.model_shard.to(self.offload_device, non_blocking=non_blocking)
class ActivationCheckpointing(torch.autograd.Function):
class OffloadFunction(torch.autograd.Function):
"""
This Function enables checkpointing of intermediate activations at
shard boundaries by overriding the forward and backward pass of the nn.Module.
@@ -293,13 +293,32 @@ class ActivationCheckpointing(torch.autograd.Function):
class OffloadModel(nn.Module):
"""Wrapper used offload parts of a model to the CPU.
The model is sharded into chunks and at each iteration, a
single chunk is copied from CPU->GPU, FW pass is computed and
the chunk is copied back to CPU. This process is repeated for
all the chunks. In the BW pass, the same process happens in
reverse.
"""Wraps an arbitrary :class:`nn.Sequential <torch.nn.Sequential>` module
to train by offloading majority of the model parameters to the CPU.
`OffloadModel` is heavily inspired by the _L2L algorithm and _Zero-Offload.
::
model = get_model()
offload_model = OffloadModel(model, device,
offload_device=torch.device(“cpu”),
num_slices=3,
checkpoint_activation=True,
num_microbatches=5)
.. _L2L: https://arxiv.org/abs/2002.05645
.. _Zero-Offload: https://arxiv.org/abs/2101.06840
At each step, a layer(or series of layers) are loaded
onto the GPU for the forward and backward pass with intermediate
activations being copied onto the GPU as required. Once the forward
or backward pass is completed for a given shard, it is moved back to
the CPU again.
`OffloadModel` supports activation checkpointing which reduces
the memory footprint. You can also increase the number of
microbatches which translates to more computation cycles for
every shard load. This helps offset the cost of moving the shard
from the CPU to GPU and vice versa.
Note: OffloadModel currently only supports nn.Sequential models.
@@ -380,10 +399,10 @@ class OffloadModel(nn.Module):
self._num_microbatches = num_microbatches
def forward(self, *inputs: Any, **_: Any) -> Any:
# `apply` calls the `forward` function of the `ActivationCheckpointing` class
# `apply` calls the `forward` function of the `OffloadFunction` class
# and the `forward` function calls `inputs` on the first model shard.
# Please see https://pytorch.org/docs/stable/autograd.html#function for more details.
# We need the second param to be a dummy input to enable the
# backward pass to be triggered for integer inputs.
return ActivationCheckpointing.apply(*inputs, torch.tensor([], requires_grad=True), self)
return OffloadFunction.apply(*inputs, torch.tensor([], requires_grad=True), self)