Unverified Commit dcfb7a99 authored by anj-s's avatar anj-s Committed by GitHub
Browse files

[docs] Revamp FairScale documentation (#698)

* add tutorials

* add new context, modify and delete existing docs

* remove duplicate labels

* modify layout and more nits

* address comments

* fix merge conflicts
parent 29aae007
Enhanced Activation Checkpointing
=================================
Activation checkpointing is a technique used to reduce GPU memory usage during training. This is
done by avoiding the need to store intermediate activation tensors during the forward pass. Instead,
the forward pass is recomputed by keeping track of the original input during the backward pass.
There is a slight increase in computation cost (about 33%) but this reduces the need to store
large activation tensors which allows us to increase the batch size and thereby the net throughput
of the model.
Activation checkpointing is implemented by overriding `torch.autograd.Function`. In the `forward`
function which handles the forward pass of the module, using `no_grad`, we can prevent the creation
of the forward graph and materialization of intermediate activation tensors for a long period of
time (i.e till the backward pass). Instead, during the backward pass, the forward pass is executed
again followed by the backward pass. The inputs to the forward pass are saved using a context object
that is then accessed in the backward pass to retrieve the original inputs. We also save the
Random Number Generator(RNG) state for the forward and backward passes as required for Dropout layers.
The above functionality is already implemented as part of the `torch.utils.checkpoint.checkpoint_wrapper`
API whereby different modules in the forward pass can be wrapped. The wrapper in FairScale offers
functionality beyond that provided by the PyTorch API specifically you can use
`fairscale.nn.checkpoint.checkpoint_wrapper` to wrap a `nn.Module`, handle kwargs in the forward
pass, offload intermediate activations to the CPU and handle non-tensor outputs returned from the
forward function.
Best practices for `fairscale.nn.checkpoint.checkpoint_wrapper`
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
1. Memory savings depends entirely on the model and the segmentation of checkpoint wrapping.
Each backprop consists of several mini-forward and backprop passes. The gain is entirely dependent
on the memory footprint of the layer’s activations.
2. When using BatchNormalization you may need to freeze the calculation of statistics since we run
the forward pass twice.
3. Ensure that the input tensor’s `requires_grad` field is set to True. In order to trigger the
backward function, the output needs to have this field set. By setting it on the input tensor we
ensure that this is propagated to the output and the `backward` function is triggered.
Adascale
=========
`Adascale <https://arxiv.org/abs/2007.05105>`_ is a technique used to enable large batch training that allows you to increase batch size
without loss of accuracy. When increasing batch size with the number of devices, the learning rate
is typically tuned based on the batch size. With Adascale, users no longer need to modify the
learning rate schedule and still achieve the desired accuracy. Adascale has been implemented as
the Adascale API in FairScale. This technique typically works well for SGD (with and without momentum)
The assumption is that you already have a good learning rate schedule that works well for small
batch sizes. (AdaScale has not been validated to work effectively with Adam, further research in
that direction is needed.)
AdaScale adapts the learning rate schedule and determines when to stop based on comparing statistics
of large-batch gradients with those of small-batch gradients. Small batch gradients are gradients that
have been computed on each GPU and large batch gradients are the average of gradients computed on N
such GPUs. Adascale uses the concept of gain ratio which is intuitively a measure of how much the
variance has reduced by averaging N small batch gradients. It is a quantity between 1 and N.
In practice, the implementation tracks estimates of the gradient variance and norm-squared which
are smoothed using an exponentially-weighted moving average. If T is the number of steps used to
train the original small batch size before scaling, Adascale stops training once the accumulated
gain ratio is greater than T. As you use more and more GPUs in the training the total steps needed
to train decreases, but due to the value of gain ratio between [1, N], the total steps does not
linearly decrease as you increase the GPUs. Additional training steps are taken to maintain the
model accuracy, when compared with original_total_step/N (i.e. linear scaling). In other words,
whenever the gain ratio is less than N, we could not take as large a step as we may have hoped for,
and so the total number of iterations ends up being larger than T / N.
The current implementation in FairScale supports gradient accumulation training, can be used
with Optimizer State Sharding (OSS), and works with PyTorch LR scheduler classes.
The training process is as follows:
1. Compute the forward pass
2. During the backward pass, hooks attached to each of the parameters fire before the allreduce operation. This is to enable us to calculate the accumulated squares of the local gradients.
3. A final backward hook fires after all the gradients have been reduced using the allreduce op. Using the global gradient square and the accumulated local gradient square, the gradient square average and gradient variance average is calculated.
4. These values are then used to calculate the gain ratio. During the `step` call of the optimizer, the learning rate is updated using this gain ratio value.
5. The training loop terminates once maximum number of steps has been reached
Best practices for `fairscale.optim.AdaScale`
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Adascale only works for the SGD optimizer (with and without momentum)
OffloadModel
=============
Heavily inspired by the `Layer-to-Layer <https://arxiv.org/abs/2002.05645>`_ algorithm and
`Zero-Offload <https://arxiv.org/abs/2101.06840>`_, OffloadModel uses the CPU to store
the entire model, optimizer state and gradients. OffloadModel then brings in a layer (or a number of
layers) onto the GPU for training at a time during the forward and backward pass. The intermediate
activations for the layer boundaries are also stored on the CPU and copied to the GPU as needed for
the backward pass. Once the backward pass is completed all the parameters are updated with the
gradients present on the CPU.
.. image:: ../_static/img/offload.png
:height: 500px
:width: 500px
Offload uses the following techniques to enable large model training:
1. The model is assumed to be nn.Sequential and sharded (almost) equally based on the number of
parameters into a list of nn.Modules. Each nn.Module now contains a fraction of the whole model
which we shall refer to as model shards.
2. At each iteration, each of the model shards are copied from the CPU -> GPU, FW pass is computed
using the minibatch of data and the model shard is copied back from GPU -> CPU. In the BW pass, the
same process is repeated.
3. The optimizer remains on the CPU and gradients and parameters are all moved onto the CPU before
running optimizer.step. This ensures that the CPU is responsible for updating the parameters and
holding onto the optimizer state.
4. If activation checkpointing is enabled, we use torch.autograd.Function to disable graph construction
in the FW pass and copy intermediate activations from GPU -> CPU after the FW pass of a given shard is
complete. The reverse copy is carried out in the BW pass.
5. Microbatches are used to enable larger throughput and offset the cost of moving model parameters
and activations from CPU <-> GPU. Micro-batches allow you to specify large mini-batches which are
broken down into micro-batches and fed to the model shards at each iteration. In short it is a way
to allow more computation at a given time on a model shard to offset the cost of copying from CPU <-> GPU.
Best practices for using `fairscale.experimental.nn.OffloadModel`
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
1. Using OffloadModel to train large models can result in loss of throughput which can be overcome by using activation checkpointing and microbatches.
2. OffloadModel currently only works for `nn.Sequential` models.
Efficient Memory management
============================
FairScale provides implementations inspired by the `ZeRO <https://arxiv.org/pdf/1910.02054.pdf>`_ class of algorithms in the form of modular
APIs that you can plug into your model training. Zero Redundancy Optimizer is a class of algorithms
that aim to tackle the tradeoff between using Data Parallel training and Model Parallel training.
When using Data Parallel training, you tradeoff memory for computation/communication efficiency.
On the other hand, when using Model Parallel training, you tradeoff computation/communication
efficiency for memory. ZeRO attempts to solve this problem. Model training generally involves memory
footprints that falls into two categories:
1. Model states - optimizer states, gradients, parameters
2. Residual states - activations, temp buffers, fragmented memory
To reduce redundancy in model states, three different algorithms were proposed. These have been
implemented in FairScale as Optimizer State Sharding (OSS), Sharded Data Parallel (SDP) and finally
Fully Sharded Data Parallel (FSDP). Let’s dive deeper into the actual mechanics of each of these
algorithms and understand why they provide the memory savings that they do.
Optimizer State Sharding (OSS)
------------------------------
FairScale has implemented memory optimization related to optimizer memory (inspired by `ZeRO-1 <https://arxiv.org/pdf/1910.02054.pdf>`_) footprint
using `fairscale.optim.OSS` API. Optimizers such as Adam usually require maintaining momentum, variance,
parameters and gradients all in FP32 precision even though training can be carried out with parameters
and gradients in FP16 precision. When each of the ranks update the full model, this means that a sizable
part of the memory is occupied by redundant representations of the optimizer state.
To overcome this redundancy, optimizer state sharding entails partitioning the model optimization step in
between the different ranks, so that each of them is only in charge of updating a unique shard of the
model. This in turn makes sure that the optimizer state is a lot smaller on each rank, and that it contains
no redundant information across ranks.
.. image:: ../_static/img/oss.png
The training process can be modified from that carried out by DDP as follows:
1. The wrapped optimizer shards the optimizer state in a greedy fashion based on the parameter size but not
the order in which it is used. This is to ensure that each rank has almost the same optimizer memory
footprint.
2. The training process is similar to that used by PyTorch’s Distributed Data Parallel (DDP). The forward
pass completes on each of the ranks followed by the backward pass. During the backward pass, gradients
are synchronized using allreduce.
3. Each rank updates the parameters for the shard of optimizer state that it is responsible for and then
discards the rest.
4. After update, a broadcast or allgather follows to ensure all ranks receive the latest updated parameter
values.
OSS is very useful when you are using an optimizer such as Adam that has additional state. The wrapping
of the optimizer is a one-line non intrusive change that provides memory savings.
If you are using SGD or any optimizer with a limited memory footprint, it is likely that you will see a
slowdown when using multiple nodes, due to the additional communication in step 4. There is also some
wasteful memory used to store gradients during allreduce in step 2 that is then discarded, although this
also happens with normal PyTorch (nothing extraneous here).
Best practices for using `fairscale.optim.oss`
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
1. OSS exposes a broadcast_fp16 flag that you should probably use in multi-node jobs, unless this leads to
accuracy issues (which is very unlikely). This can be used with or without Torch AMP. This is usually not
needed in a single node experiment.
2. If your model is extremely unbalanced in terms of size (one giant tensor for instance), then this method
will not be very helpful, and tensor sharding options such as `fairscale.nn.FullyShardedDataParallel`
would be preferable.
3. OSS should be a drop in solution in a DDP context, and stays compatible with most of the DDP features
such as the `fp16 gradient compression hook <https://pytorch.org/docs/stable/ddp_comm_hooks.html#torch.distributed.algorithms.ddp_comm_hooks.default_hooks.fp16_compress_hook>`_,
gradient accumulation and PyTorch AMP.
Performance tips for `fairscale.optim.oss`
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
1. On a single node, OSS should be always faster than vanilla PyTorch, memory savings will vary depending
on the optimizer being used
2. When using multiple nodes, OSS can alternatively be faster or slower than vanilla PyTorch, depending
on the optimizer being used, and optional flags (E.g broadcast_fp16, gradient compression, gradient
accumulation as mentioned above.)
3. If applicable (if your experiment can do with a bigger batch size), it’s usually beneficial to reinvest
the saved memory in a larger batch size and reduce the number of ranks involved, or to use gradient
accumulation since this diminishes the communication cost.
Optimizer + Gradient State Sharding
-----------------------------------
To overcome redundant gradient memory and to enable further memory savings, gradient sharding or
`ZeRO-2 <https://arxiv.org/pdf/1910.02054.pdf>`_ was proposed. This has been implemented by the Sharded Data Parallel(SDP) API in FairScale.
While OSS solved the redundancy problem in optimizers, the above data parallel training steps revealed
a duplication of computation of gradient aggregation as well as additional memory being used for gradients
are discarded.
To enable gradient sharding, each rank is assigned a set of parameters for which they are responsible
for managing optimizer state as well as gradient aggregation. By assigning a model shard to a given
rank we ensure that gradients are reduced to specific ranks that are in turn responsible for the update.
This reduces communication as well as memory usage.
.. image:: ../_static/img/sdp.png
The training process is as follows:
1. As before the wrapped optimizer shards parameters across the different ranks.
2. The model is now wrapped with a Sharded Data Parallel (SDP) wrapper that allows us to add the appropriate hooks and maintain state during the training process.
3. SDP focuses on trainable parameters and adds a backward hook for each of the them.
4. During the backward pass, gradients are reduced to the rank that they are assigned to as part of the sharding process in 1. Instead of an allreduce op, a reduce op is used which reduces the communication overhead.
5. Each rank updates the parameters that they are responsible for.
6. After the update, a broadcast or allgather follows to ensure all ranks receive the latest updated parameter values.
Both the OSS and SDP APIs allow you to reduce the memory used for gradients and optimizer states.
Additional communication costs can be present in slow interconnects but are useful to try as first steps
when running into Out Of Memory (OOM) issues.
Best practices for `fairscale.nn.ShardedDataParallel`
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
1. If using multiple nodes, make sure that SDP is using reduce buffers by specifying the
`reduce_buffer_size` arg. Changing their size can be an optimization target, the best configuration
could depend on the interconnect.
2. If on a single node, it’s usually best not to use `reduce_buffer_size` since there is a latency
cost associated with it but no memory gain. Setting this value to 0 means that this feature is not
used and this is the recommended single node setting.
3. If applicable (if your experiment can do with a bigger batch size), it’s usually beneficial to
reinvest the saved memory in a larger batch size and reduce the number of ranks involved, or to use
gradient accumulation since this diminishes the communication cost.
Optimizer + Gradient + Horizontal Model Sharding
------------------------------------------------
To further optimize training and achieve greater memory savings, we need to enable parameter sharding.
With parameter sharding similar to gradient and optimizer states, data parallel ranks are responsible
for a shard of the model parameters. FairScale implements parameter sharding by way of the Fully Sharded
Data Parallel (FSDP) API which is heavily inspired by `ZeRO-3 <https://arxiv.org/pdf/1910.02054.pdf>`_. Parameter sharding is possible because of
two key insights:
1. The allreduce operation can be broken up into reduce and allgather similar to the previous sharding
technologies (optimizer state and gradient).
2. Individual layers can be wrapped with the FSDP API that allows us to bring in all the parameters
required for a single layer onto a given GPU at a given instance, compute the forward pass and then
discard the parameters not owned by that rank. Please see the tutorial section for how you can use
autowrap to enable wrapping individual layers of your model.
The training process is as follows:
.. image:: ../_static/img/fsdp.png
1. `allgather` the parameters required for the forward pass of each of the layers of the model just before the compute of a specific layer commences.
2. Compute the forward pass.
3. `allgather` the parameters required for the backward pass of each of the layers of the model just before the backward pass of a specific layer commences.
4. Compute the backward pass.
5. `reduce` the gradients such that aggregated grads are accumulated on the ranks that are responsible for the corresponding parameters.
6. Let each rank update the parameters that have been assigned to it using the aggregated gradients.
With FSDP there are small changes one needs to make when using APIs for checkpointing and saving optimizer
state. Given the sharded nature of optimizer state and parameters, any API that aims to save the model
state for training or inference needs to account for saving weights from all workers. FSDP implements the
required plumbing to save weights from all workers, save weights on individual workers and save optimizer
state from all workers.
FSDP also supports mixed precision training where both the computation and communication are carried out
in FP16 precision. If you want to reduce operations to be carried out in FP32 which is the default
behavior of DDP, then you must set `fp32_reduce_scatter=True`.
To enable further memory savings, FSDP supports offloading parameters and gradients that are currently
not being used onto the CPU. This can be enabled by setting `move_params_to_cpu` and `move_grads_to_cpu`
to be equal to True.
Best practices for `fairscale.nn.FullyShardedDataParallel`
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
1. For FSDP, it is preferable to use `model.zero_grad(set_to_none=True)` since it saves a large amount of
memory after stepping.
2. `torch.cuda.amp.autocast for mixed precision` is fully compatible with FSDP. However you will need
to set the `mixed_precision` arg to be True.
3. If combined with activation checkpointing, it is preferable to use FSDP(checkpoint_wrapper(module))
over checkpoint_wrapper(FSDP(module)). The latter will result in more communication and will be slower.
4. Results should be identical to DDP with pointwise Optimizers, e.g., Adam, AdamW, Adadelta, Adamax,
SGD, etc.. However, the sharding will result in slightly different results when using non-pointwise
Optimizers, e.g., Adagrad, Adafactor, LAMB, etc.
Performance tips for `fairscale.nn.FullyShardedDataParallel`
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
1. For best memory efficiency use auto_wrap to wrap each layer in your network with FSDP and
set `reshard_after_forward` to be True
2. For best training speed set `reshard_after_forward` to be False (wrapping each layer is not
required, but will improve speed further)
Pipeline Parallelism
=====================
Training large models can lead to out-of-memory when the size of the model is too large for a single GPU.
To train such a large model, layers can be pipelined across different GPU devices as described in GPipe.
The `fairscale.nn.Pipe` is an implementation of GPipe which has been adopted from torchgpipe. This API
has also been upstreamed to PyTorch in the 1.8 release with the experimental tag.
.. image:: ../_static/img/pipe.png
Gpipe first shards the model across different devices where each device hosts a shard of the model.
A shard can be a single layer or a series of layers. However Gpipe splits a mini-batch of data into
micro-batches and feeds it to the device hosting the first shard. The layers on each device process
the micro-batches and send the output to the following shard/device. In the meantime it is ready to
process the micro batch from the previous shard/device. By pipepling the input in this way, Gpipe is
able to reduce the idle time of devices.
Best practices for using `fairscale.nn.Pipe`
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
1. Choice of size of micro-batches can affect GPU utilization. A smaller microbatch can reduce latency of shards waiting for previous shard outputs but a large microbatch better utilizes GPUs.
2. Sharding the model can also impact GPU utilization where layers with heavier computation can slow down the shards downstream.
Getting Involved
=================
We welcome contributions from everyone! Please see the `CONTRIBUTING <https://github.com/facebookresearch/fairscale/blob/master/CONTRIBUTING.md>`_
guide on GitHub for more details on how you can contribute to FairScale.
User Workflow
==============
User workflow Diagram with explanation of various decision points
.. image:: _static/img/flowchart.png
\ No newline at end of file
...@@ -4,57 +4,63 @@ ...@@ -4,57 +4,63 @@
but it should at least contain the root `toctree` but it should at least contain the root `toctree`
directive. directive.
Welcome to FairScale's documentation! FairScale Documentation
===================================== =======================
*FairScale* is a PyTorch extension library for high performance and FairScale is a PyTorch extension library for high performance and large scale training.
large scale training for optimizing training on one or across multiple FairScale makes available the latest distributed training techniques in the form of composable
machines/nodes. This library extend basic pytorch capabilities while modules and easy to use APIs.
adding new experimental ones.
.. toctree::
:maxdepth: 1
:caption: Index
Components what_is_fairscale
---------- getting_started
blogs_and_press
* Parallelism: getting_involved
* `Pipeline parallelism <../../en/latest/api/nn/pipe.html>`_ integrations
* Sharded training:
* `Optimizer state sharding <../../en/latest/api/optim/oss.html>`_
* `Sharded grad scaler - automatic mixed precision <../../en/latest/api/optim/grad_scaler.html>`_
* `Sharded distributed data parallel <../../en/latest/api/nn/sharded_ddp.html>`_
* `Fully Sharded Data Parallel FSDP <../../en/latest/api/nn/fsdp.html>`_
* `FSDP Tips <../../en/latest/api/nn/fsdp_tips.html>`_
* Mixture-of-Experts: |
* `MOE <../../en/latest/api/nn/moe.html>`_ |
* Optimization at scale: .. toctree::
* `AdaScale SGD <../../en/latest/api/optim/adascale.html>`_ :maxdepth: 1
:caption: Installation
* GPU memory optimization: installation_instructions
* `Activation checkpointing wrapper <../../en/latest/api/nn/misc/checkpoint_activations.html>`_
|
|
* `Tutorials <../../en/latest/tutorials/index.html>`_ .. toctree::
:maxdepth: 1
:caption: Deep Dive
deep_dive/oss_sdp_fsdp
deep_dive/offload
deep_dive/adascale
deep_dive/pipeline_parallelism
deep_dive/activation_checkpointing
.. warning:: |
This library is under active development. |
Please be mindful and create an
`issue <https://github.com/facebookresearch/fairscale/issues>`_
if you have any trouble and/or suggestions.
.. toctree:: .. toctree::
:maxdepth: 5 :maxdepth: 1
:caption: Contents: :caption: Tutorials
:hidden:
tutorials/index tutorials/oss
api/index tutorials/activation_checkpointing
tutorials/offload_model
tutorials/adascale
tutorials/pipe
|
|
Reference .. toctree::
========= :maxdepth: 1
:caption: API Documentation
:ref:`genindex` | :ref:`modindex` | :ref:`search` api/index
Installing FairScale
====================
To install the pip package:
.. code-block:: bash
pip install fairscale
To install the master branch:
.. code-block:: bash
cd fairscale
pip install -r requirements.txt
pip install -e .
Integrations
============
FairScale has integrated with the following frameworks:
1. Fairseq
2. VISSL
3. PyTorch Lightening
4. Hugging Face
Efficient memory usage using Activation Checkpointing
=====================================================
Adaped from `torch.utils.checkpoint`, this is a friendlier wrapper for performing activation checkpointing.
Compared to the PyTorch version, this version wraps a `nn.Module` and allows for all subsequent calls to be
checkpointed.
.. code-block:: python
from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper
class CheckpointModel(nn.Module):
def __init__(self, **kwargs):
super().__init__()
torch.manual_seed(0) # make sure weights are deterministic.
self.ffn_module = nn.Sequential(
nn.Linear(32, 128),
nn.Dropout(p=0.5),
nn.Linear(128, 32),
)
self.ffn_module = checkpoint_wrapper(self.ffn_module, **kwargs)
self.last_linear = nn.Linear(32, 1)
def forward(self, input):
output = self.ffn_module(input)
return self.last_linear(output)
AdaScale SGD Tutorial Scale without modifying learning rate using Adascale
===================== ====================================================
`AdaScale <https://arxiv.org/pdf/2007.05105.pdf>`_ adaptively scales the learning rate when `AdaScale <https://arxiv.org/pdf/2007.05105.pdf>`_ adaptively scales the learning rate when
using larger batch sizes for data-parallel training. Let's suppose that your trainer looks using larger batch sizes for data-parallel training. Let's suppose that your trainer looks
......
Tutorials
=========
.. toctree::
:maxdepth: 1
pipe
oss
adascale
offload_model
\ No newline at end of file
Training with `OffloadModel` Scale your model on a single GPU using OffloadModel
============================ ====================================================
`fairscale.experimental.nn.offload.OffloadModel` API democratizes large scale distributed training by enabling `fairscale.experimental.nn.offload.OffloadModel` API democratizes large scale distributed training by enabling
users to train large models on limited GPU resources that would have traditionally resulted in OOM errors. users to train large models on limited GPU resources that would have traditionally resulted in OOM errors.
...@@ -21,6 +21,7 @@ Consider a training loop as described below: ...@@ -21,6 +21,7 @@ Consider a training loop as described below:
from fairscale.experimental.nn.offload import OffloadModel from fairscale.experimental.nn.offload import OffloadModel
num_inputs = 8 num_inputs = 8
num_outputs = 8 num_outputs = 8
num_hidden = 4 num_hidden = 4
......
Optimizer state sharding Optimizer, Gradient and Model Sharding
======================== =======================================
Using torch.nn.parallel.DistributedDataParallel leads to some wasted communications in the case of OSS, but it is possible and makes OSS a drop in solution in your existing torch distributed code. Using torch.nn.parallel.DistributedDataParallel leads to some wasted communications in the case of OSS,
but it is possible and makes OSS a drop in solution in your existing torch distributed code.
Let's suppose that your trainer looks like Let's suppose that your trainer looks like
.. code-block:: python .. code-block:: python
...@@ -44,8 +45,10 @@ Let's suppose that your trainer looks like ...@@ -44,8 +45,10 @@ Let's suppose that your trainer looks like
optimizer.step() optimizer.step()
Then sharding the optimizer state is merely a matter of wrapping your optimizer in fairscale.optim.OSS, as follows. Then sharding the optimizer state is merely a matter of wrapping your optimizer in `fairscale.optim.OSS`,
DDP can be used in place of ShardedDDP in the example below, but the memory savings will be reduced (the gradients are not as efficiently sharded) as follows.
DDP can be used in place of ShardedDDP in the example below, but the memory savings will be reduced
(the gradients are not as efficiently sharded).
.. code-block:: python .. code-block:: python
...@@ -71,7 +74,7 @@ DDP can be used in place of ShardedDDP in the example below, but the memory savi ...@@ -71,7 +74,7 @@ DDP can be used in place of ShardedDDP in the example below, but the memory savi
# optimizer specific arguments e.g. LR, momentum, etc... # optimizer specific arguments e.g. LR, momentum, etc...
base_optimizer_arguments = { "lr": 1e-4} base_optimizer_arguments = { "lr": 1e-4}
# ** NEW ** Wrap a base optimizer into OSS # Wrap a base optimizer into OSS
base_optimizer = torch.optim.SGD # any pytorch compliant optimizer base_optimizer = torch.optim.SGD # any pytorch compliant optimizer
optimizer = OSS( optimizer = OSS(
params=model.parameters(), params=model.parameters(),
...@@ -94,8 +97,8 @@ DDP can be used in place of ShardedDDP in the example below, but the memory savi ...@@ -94,8 +97,8 @@ DDP can be used in place of ShardedDDP in the example below, but the memory savi
optimizer.step() optimizer.step()
The above `train` function can then be run via a `multiprocessing.spawn` call. Note that any launcher can be used, The above `train` function can then be run via a `multiprocessing.spawn` call. Note that any launcher
the only assumption being that each of the ranks lives in its own python process. can be used, the only assumption being that each of the ranks lives in its own python process.
.. code-block:: python .. code-block:: python
...@@ -108,20 +111,18 @@ the only assumption being that each of the ranks lives in its own python process ...@@ -108,20 +111,18 @@ the only assumption being that each of the ranks lives in its own python process
) )
to see it in action, you can test it with the following script `here <../../../examples/tutorial_oss.py>`_. Using PyTorch Automatic Mixed Precision is possible, and its actual usage will depend on whether OSS
is used with DDP or with ShardedDDP.
If OSS is used with DDP, then the normal PyTorch GradScaler can be used, nothing needs to be changed.
Using PyTorch Automatic Mixed Precision is possible, and its actual usage will depend on whether OSS is used with DDP or with ShardedDDP. If OSS is used with ShardedDDP (to
If OSS is used with DDP, then the normal PyTorch GradScaler can be used, nothing needs to be changed. If OSS is used with ShardedDDP (to get the gradient sharding), then a very similar flow can be used, but it requires a shard-aware GradScaler,
get the gradient sharding), then a very similar flow can be used, but it requires a shard-aware GradScaler, which is available in which is available in `fairscale.optim.grad_scaler`. In both cases Autocast can be used as is, and the
`fairscale.optim.grad_scaler`. In both cases Autocast can be used as is, and the loss will be scaled and handled in the same way. loss will be scaled and handled in the same way.
See [the original documentation] (https://pytorch.org/docs/stable/notes/amp_examples.html?highlight=automatic%20mixed%20precision) See [the original documentation] (https://pytorch.org/docs/stable/notes/amp_examples.html?highlight=automatic%20mixed%20precision)
for more information. for more information.
.. code-block:: python .. code-block:: python
from fairscale.optim.grad_scaler import ShardedGradScaler from fairscale.optim.grad_scaler import ShardedGradScaler
...@@ -153,3 +154,83 @@ for more information. ...@@ -153,3 +154,83 @@ for more information.
# Updates the scale for next iteration. # Updates the scale for next iteration.
scaler.update() scaler.update()
Parameters can be sharded using the FullyShardedDataParallel (FSDP) API. It involves wrapping your model similar to the
SDP API above.
.. code-block:: python
import torch
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
def train(
rank: int,
world_size: int,
epochs: int):
# process group init
dist_init(rank, world_size)
# Problem statement
model = myAwesomeModel().to(rank)
dataloader = mySuperFastDataloader()
loss_ln = myVeryRelevantLoss()
# optimizer specific arguments e.g. LR, momentum, etc...
base_optimizer_arguments = { "lr": 1e-4}
# Wrap a base optimizer into OSS
base_optimizer = torch.optim.SGD # any pytorch compliant optimizer
# Wrap the model into FSDP, which will reduce parameters to the proper ranks
model = FSDP(model)
# Any relevant training loop. For example:
model.train()
for e in range(epochs):
for (data, target) in dataloader:
data, target = data.to(rank), target.to(rank)
# Train
model.zero_grad()
outputs = model(data)
loss = loss_fn(outputs, target)
loss.backward()
optimizer.step()
Auto wrapping sub-modules with FSDP is a convenient way to improve training speed by overlapping
the allgather step across the forward passes of different submodules.
It also improves memory efficiency by freeing gathered parameters after each layer finishes executing.
For example:
.. code-block:: python
import torch
from fairscale.nn.wrap import auto_wrap, enable_wrap, wrap
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
from fairscale.utils.testing import DummyProcessGroup
tfmr = torch.nn.Transformer(num_encoder_layers=2, num_decoder_layers=2)
group = DummyProcessGroup(rank=0, size=1)
fsdp_params = dict(mixed_precision=True, flatten_parameters=True)
with enable_wrap(wrapper_cls=FSDP, process_group=group, **fsdp_params):
# Wraps layer in FSDP by default if within context
l1 = wrap(torch.nn.Linear(5, 5))
assert isinstance(l1, FSDP)
assert l1.mixed_precision and l1.flatten_parameters
# Separately Wraps children modules with more than 1e8 params
tfmr_auto_wrapped = auto_wrap(tfmr, min_num_params=1e6)
assert isinstance(l2, nn.Transformer)
for l in l2.encoder.layers:
assert isinstance(l, FSDP)
assert l.mixed_precision and l.flatten_parameters
assert isinstance(l.linear1, FSDP)
assert isinstance(l.linear2, FSDP)
assert not isinstance(l.self_attn, FSDP) # self attention is not auto-wrapped
\ No newline at end of file
Pipeline Parallel Model sharding using Pipeline Parallel
================= ======================================
Let us start with a toy model that contains two linear layers. Let us start with a toy model that contains two linear layers.
......
What is FairScale?
====================
FairScale is a PyTorch extension library for high performance and large scale training.
This library extends basic PyTorch capabilities while adding new SOTA scaling techniques.
FairScale makes available the latest distributed training techniques in the form of composable
modules and easy to use APIs. These APIs are a fundamental part of a researcher's toolbox as
they attempt to scale models with limited resources.
.. image:: _static/img/global.png
:width: 400px
:height: 400px
:align: center
FairScale was designed with the following values in mind:
1. **Usability** - Users should be able to understand and use FairScale APIs with minimum cognitive overload.
2. **Modularity** - Users should be able to combine multiple FairScale APIs as part of their training loop seamlessly.
3. **Performance** - FairScale APIs provide the best performance in terms of scaling and efficiency.
.. image:: _static/img/ddp.png
ML training at scale traditionally means `data parallelism <https://pytorch.org/tutorials/intermediate/ddp_tutorial.html>`_
which allows us to use multiple devices at the same
time to train a large batch size per step thereby achieving the goal accuracy in a shorter period of time
as compared to training on a single device. With recent advances in ML research, the size of ML models
has only increased over the years and data parallelism no longer serves all “scaling” purposes.
There are multiple axes across which you can scale training and FairScale provides the following broad
categories of solutions:
1. **Parallelism** → These techniques allow scaling of models by layer parallelism and tensor parallelism.
2. **Sharding Methods** → Memory and computation are usually trade-offs and in this category we attempt to achieve both low memory utilization and efficient computation by sharding model layers or parameters, optimizer state and gradients.
3. **Optimization** → This bucket deals with optimizing memory usage irrespective of the scale of the model, training without hyperparameter tuning and all other techniques that attempt to optimize training performance in some way.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment