This document describes how ``FSDP`` works, including subtle behaviors that can change performance significantly.
See :doc:`this page <fsdp>` for the Python API docstrings.
Overview
---------
Recent work by `Microsoft <https://arxiv.org/abs/1910.02054>`__ and
`Google <https://arxiv.org/abs/2004.13336>`__ has shown that data
parallel training can be made significantly more efficient by sharding
the model parameters and optimizer state across data parallel workers.
These ideas are encapsulated in the new ``FullyShardedDataParallel``
(FSDP) wrapper, which is a drop-in replacement for the PyTorch
``DistributedDataParallel`` (DDP) wrapper.
Compared to PyTorch ``DistributedDataParallel``:
* FSDP shards parameters (FP16 + FP32) and optimizer state across data parallel GPUs
* FSDP with ``reshard_after_forward=False`` has the same communication cost as PyTorch DDP and is similar to ZeRO-2
* FSDP with ``reshard_after_forward=True`` increases total communication by 50% and is similar to ZeRO-3:

  * all-gather parameters at start of forward pass and start of backward pass
  * reduce-scatter grads at end of the backward pass

* In practice, FSDP is faster than DDP because the optimizer step is sharded, and the extra communication can be overlapped with the forward pass.
* FSDP enables training 13B parameter models on 8 GPUs and 175B parameter models on 128 GPUs. When using the ``cpu_offload=True`` option, it's possible to train 1T parameter models on 256 GPUs.
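As a minimal sketch of the drop-in usage (assuming ``torch.distributed`` has already been initialized, e.g. via ``torchrun``, with one GPU per process):

.. code-block:: python

    import torch
    from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

    # Where you would previously wrap with DistributedDataParallel, wrap with FSDP.
    # reshard_after_forward=False keeps communication cost the same as DDP (ZeRO-2-like).
    model = torch.nn.Linear(1024, 1024).cuda()
    model = FSDP(model, reshard_after_forward=False)

    # The optimizer only sees this worker's shard of the parameters.
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)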
General usage notes
--------------------
- For best memory efficiency use ``auto_wrap`` to wrap each layer in your network with ``FSDP`` and set ``reshard_after_forward=True``
- For best training speed set ``reshard_after_forward=False`` (wrapping each layer is not required, but will improve speed further)
- If you're using ``torch.cuda.amp.autocast`` for mixed precision, that's fully compatible with the FSDP wrapper; just set ``mixed_precision=True``
- If combining with `activation checkpointing <https://github.com/facebookresearch/fairscale/blob/master/fairscale/nn/misc/checkpoint_activations.py>`__,
  prefer ``FSDP(checkpoint_wrapper(module))`` over ``checkpoint_wrapper(FSDP(module))``. The latter will result in more communication and will be slower (see the sketch after this list).
- Results should be identical to DDP with pointwise optimizers, e.g.,
  Adam, AdamW, Adadelta, Adamax, SGD, etc. However, the sharding will
  result in slightly different results when using non-pointwise
  optimizers, e.g., Adagrad, Adafactor, LAMB, etc.
- In `fairseq <https://github.com/pytorch/fairseq>`_, FSDP is activated by the command line option ``--ddp-backend=fully_sharded``.
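As a rough sketch of the activation-checkpointing and mixed-precision notes above (assuming ``torch.distributed`` is already initialized; the ``Sequential`` block stands in for one layer of your own network, and the ``checkpoint_wrapper`` import path may differ across fairscale versions):

.. code-block:: python

    import torch
    from fairscale.nn import checkpoint_wrapper
    from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

    layer = torch.nn.Sequential(torch.nn.Linear(1024, 1024), torch.nn.ReLU()).cuda()

    # Preferred ordering: checkpoint the module first, then shard it with FSDP.
    layer = FSDP(checkpoint_wrapper(layer), mixed_precision=True)

    # torch.cuda.amp.autocast composes with the wrapper as usual.
    with torch.cuda.amp.autocast():
        out = layer(torch.randn(8, 1024, device="cuda"))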
How it works
------------
In standard distributed data parallel (DDP) training every worker processes a separate batch and the gradients are
summed across workers using an `all-reduce operation <https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/collectives.html#allreduce>`__.
While DDP has become very popular, it wastes GPU memory because the model weights and optimizer states are replicated across all DDP workers.
The key insight to unlock full parameter sharding is that we can decompose the all-reduce operation in DDP into separate reduce-scatter and all-gather operations.
We can then rearrange the reduce-scatter and all-gather so that each DDP worker only needs to store a single shard of the parameters and optimizer state, as illustrated in the figure comparing standard DDP training (left) and fully sharded training (right).
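To make the decomposition concrete, here is a hedged sketch (assuming a ``torch.distributed`` process group is already initialized, e.g. via ``torchrun``, with one GPU per rank) showing that a reduce-scatter followed by an all-gather reproduces the result of an all-reduce:

.. code-block:: python

    import torch
    import torch.distributed as dist

    world_size = dist.get_world_size()
    shard_size = 4
    grad = torch.randn(world_size * shard_size, device="cuda")  # this rank's gradient

    # reduce-scatter: each rank receives the sum of one shard of the gradient
    shard = torch.empty(shard_size, device="cuda")
    dist.reduce_scatter(shard, list(grad.chunk(world_size)))

    # all-gather: reassemble the full summed gradient from the per-rank shards
    gathered = [torch.empty(shard_size, device="cuda") for _ in range(world_size)]
    dist.all_gather(gathered, shard)
    summed = torch.cat(gathered)  # matches what dist.all_reduce(grad) would produce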
To maximize memory efficiency we can discard the full weights after each
layer's forward pass, saving memory for subsequent layers. This can be
implemented by applying the FSDP wrapper to every layer in your network
(with ``reshard_after_forward=True``). In pseudo-code::

    FSDP forward pass:
        for layer_i in layers:
            all-gather full weights for layer_i
            forward pass for layer_i
            discard full weights for layer_i

    FSDP backward pass:
        for layer_i in layers:
            all-gather full weights for layer_i
            backward pass for layer_i
            discard full weights for layer_i
            reduce-scatter gradients for layer_i
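In actual code, a hedged sketch of this per-layer wrapping pattern (assuming ``torch.distributed`` is already initialized; the stack of ``Linear`` layers stands in for your own network):

.. code-block:: python

    import torch.nn as nn
    from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

    # Wrap each layer in its own FSDP instance so the gathered full weights are
    # freed as soon as that layer has run (reshard_after_forward=True is the default).
    layers = nn.Sequential(
        *[FSDP(nn.Linear(1024, 1024), reshard_after_forward=True) for _ in range(12)]
    )
    # An outer wrapper shards any parameters not covered by the inner wrappers.
    model = FSDP(layers)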
Saving and Loading
------------------
There are two ways to save and load FSDP instances:

- ``state_dict()`` returns a dictionary containing the full (unsharded) parameters, which can be loaded with ``load_state_dict()``
- ``local_state_dict()`` returns a dictionary containing only this worker's shard of the parameters, which can be loaded with ``load_local_state_dict()``
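A minimal sketch of both paths (the module and checkpoint filenames are illustrative; assumes ``torch.distributed`` is already initialized):

.. code-block:: python

    import torch
    import torch.distributed as dist
    from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

    sharded_module = FSDP(torch.nn.Linear(1024, 1024).cuda())

    # Consolidated checkpoint: every rank materializes the full parameters.
    torch.save(sharded_module.state_dict(), "ckpt_full.pt")
    sharded_module.load_state_dict(torch.load("ckpt_full.pt"))

    # Sharded checkpoint: each rank saves and restores only its own shard.
    rank = dist.get_rank()
    torch.save(sharded_module.local_state_dict(), f"ckpt_rank{rank}.pt")
    sharded_module.load_local_state_dict(torch.load(f"ckpt_rank{rank}.pt"))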
Mixed Precision
---------------
When ``mixed_precision=True``:
- Sharded parameters are downcast to FP16 before the forward pass and promoted back to FP32 afterwards.
- Buffers are kept in FP16 unless ``buffer_dtype=torch.float32`` is passed. Buffers are never sharded, regardless of arguments.
- By default, gradients are computed and reduced in FP16. If FP32 reductions are important, set ``fp32_reduce_scatter=True`` (see the sketch after this list).
- If ``torch.cuda.amp.autocast`` is enabled, it will override the output dtypes of some operations, such as ``BatchNorm2d``.
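A hedged configuration sketch of these options (assuming ``torch.distributed`` is already initialized; the ``Sequential`` model stands in for your own module):

.. code-block:: python

    import torch
    from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

    model = FSDP(
        torch.nn.Sequential(torch.nn.Linear(1024, 1024), torch.nn.ReLU()).cuda(),
        mixed_precision=True,        # compute with FP16 parameters and gradients
        fp32_reduce_scatter=True,    # but reduce-scatter gradients in FP32
        buffer_dtype=torch.float32,  # and keep (unsharded) buffers in FP32
    )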
Auto-wrap
---------
Auto wrapping sub-modules with ``FSDP`` is a convenient way to improve training speed by overlapping the all-gather step across the forward passes of different submodules.
It also improves memory efficiency by freeing gathered parameters after each layer finishes executing.
.. code-block:: python

    import torch
    from fairscale.nn.wrap import auto_wrap, enable_wrap, wrap
    from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
    from fairscale.utils.testing import DummyProcessGroup