Unverified commit 31d600cc authored by Rahul Iyer, committed by GitHub

Fix pre-commit hook failures (#756)

Pre-commit hook fails when run on all files for three reasons:
(see trace below)

1. Trailing whitespace on multiple files
2. mypy fails to load numpy and then subsequently fails to load
LazyModule from pipe.py
3. isort sees issues with known_third_party packages

```
> pre-commit run --all-files

Trim Trailing Whitespace.................................................Failed
- hook id: trailing-whitespace
- exit code: 1
- files were modified by this hook

Fixing docs/source/conf.py
Fixing fairscale/experimental/nn/auto_shard.py
Fixing docs/source/deep_dive/activation_checkpointing.rst
Fixing docs/source/tutorials/pipe.rst
Fixing docs/source/installation_instructions.rst
Fixing docs/source/deep_dive/pipeline_parallelism.rst
Fixing docs/source/tutorials/activation_checkpointing.rst
Fixing docs/source/tutorials/offload_model.rst
Fixing docs/source/deep_dive/oss_sdp_fsdp.rst
Fixing docs/source/what_is_fairscale.rst
Fixing CHANGELOG.md
Fixing fairscale/experimental/nn/offload.py
Fixing docs/source/index.rst
Fixing docs/source/deep_dive/adascale.rst
Fixing README.md
Fixing docs/source/tutorials/oss.rst
Fixing docs/source/deep_dive/offload.rst

Check python ast.........................................................Passed
Check for merge conflicts................................................Passed
Don't commit to branch...................................................Passed
Check for added large files..............................................Passed
Fix End of Files.........................................................Failed
- hook id: end-of-file-fixer
- exit code: 1
- files were modified by this hook

Fixing requirements.txt
Fixing docs/source/getting_started.rst
Fixing docs/source/installation_instructions.rst
Fixing codecov.yml
Fixing docs/source/deep_dive/adascale.rst
Fixing docs/source/tutorials/oss.rst
Fixing docs/source/deep_dive/offload.rst

black....................................................................Passed
flake8...................................................................Passed
seed isort known_third_party.............................................Failed
- hook id: seed-isort-config
- exit code: 1
- files were modified by this hook
isort....................................................................Passed
mypy.....................................................................Failed
- hook id: mypy
- exit code: 2

setup.cfg:45: error: Error importing plugin 'numpy.typing.mypy_plugin': No module named 'numpy'
Found 1 error in 1 file (checked 197 source files)
```
parent 2dc2617c
@@ -43,3 +43,4 @@ repos:
rev: 'v0.790'
hooks:
- id: mypy
additional_dependencies: [numpy]
@@ -6,24 +6,24 @@
--------------------------------------------------------------------------------
## Description

FairScale is a PyTorch extension library for high performance and large scale training.
This library extends basic PyTorch capabilities while adding new SOTA scaling techniques.
FairScale makes available the latest distributed training techniques in the form of composable
modules and easy to use APIs. These APIs are a fundamental part of a researcher's toolbox as
they attempt to scale models with limited resources.

FairScale was designed with the following values in mind:

* **Usability** - Users should be able to understand and use FairScale APIs with minimum cognitive overload.
* **Modularity** - Users should be able to combine multiple FairScale APIs as part of their training loop seamlessly.
* **Performance** - FairScale APIs provide the best performance in terms of scaling and efficiency.

## Installation

To install FairScale, please see the following [instructions](https://github.com/facebookresearch/fairscale/blob/master/docs/source/installation_instructions.rst). You should be able to install a pip package or
build directly from source.

## Getting Started
...
@@ -16,4 +16,4 @@ parsers:
comment:
layout: "reach,diff,flags,tree"
behavior: default
require_changes: no
\ No newline at end of file
@@ -129,7 +129,7 @@ def setup(app):
app.add_config_value(
"recommonmark_config",
{
"url_resolver": lambda url: github_doc_root + url,
"auto_toc_tree_section": "Contents",
"enable_math": True,
"enable_inline_math": True,
...
Enhanced Activation Checkpointing
=================================
Activation checkpointing is a technique used to reduce GPU memory usage during training. This is
done by avoiding the need to store intermediate activation tensors during the forward pass. Instead,
the forward pass is recomputed by keeping track of the original input during the backward pass.
There is a slight increase in computation cost (about 33%), but this reduces the need to store
large activation tensors, which allows us to increase the batch size and thereby the net throughput
of the model.

Activation checkpointing is implemented by overriding `torch.autograd.Function`. In the `forward`
function, which handles the forward pass of the module, we can use `no_grad` to prevent the creation
of the forward graph and materialization of intermediate activation tensors for a long period of
time (i.e., until the backward pass). Instead, during the backward pass, the forward pass is executed
again followed by the backward pass. The inputs to the forward pass are saved using a context object
that is then accessed in the backward pass to retrieve the original inputs. We also save the
Random Number Generator (RNG) state for the forward and backward passes, as required for Dropout layers.

The above functionality is already implemented as part of the `torch.utils.checkpoint` API,
whereby different modules in the forward pass can be wrapped. The wrapper in FairScale offers
functionality beyond that provided by the PyTorch API; specifically, you can use
`fairscale.nn.checkpoint.checkpoint_wrapper` to wrap an `nn.Module`, handle kwargs in the forward
pass, offload intermediate activations to the CPU and handle non-tensor outputs returned from the
forward function. A short usage sketch follows the best practices below.

Best practices for `fairscale.nn.checkpoint.checkpoint_wrapper`
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
1. Memory savings depend entirely on the model and the segmentation of checkpoint wrapping.
Each backprop consists of several mini forward and backprop passes. The gain is entirely dependent
on the memory footprint of the layer’s activations.
2. When using BatchNormalization you may need to freeze the calculation of statistics since we run
the forward pass twice.
3. Ensure that the input tensor’s `requires_grad` field is set to True. In order to trigger the
backward function, the output needs to have this field set. By setting it on the input tensor we
ensure that this is propagated to the output and the `backward` function is triggered.
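
To make the wrapping concrete, here is a minimal sketch of wrapping a single submodule. It is not taken from the FairScale docs: the toy model, layer sizes and input shape are invented for illustration, and the import path simply follows the name used above.

.. code-block:: python

    import torch
    import torch.nn as nn
    from fairscale.nn.checkpoint import checkpoint_wrapper

    class ToyModel(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            # Wrap only the memory-hungry block so its activations are
            # recomputed during the backward pass instead of being stored.
            self.ffn = checkpoint_wrapper(
                nn.Sequential(nn.Linear(32, 128), nn.ReLU(), nn.Linear(128, 32))
            )
            self.head = nn.Linear(32, 1)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.head(self.ffn(x))

    # Per best practice 3, the input must have requires_grad=True so that the
    # backward function of the checkpointed block is triggered.
    x = torch.randn(8, 32, requires_grad=True)
    ToyModel()(x).sum().backward()
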
AdaScale
=========
`AdaScale <https://arxiv.org/abs/2007.05105>`_ is a technique used to enable large batch training that allows you to increase batch size
without loss of accuracy. When increasing batch size with the number of devices, the learning rate
is typically tuned based on the batch size. With AdaScale, users no longer need to modify the
learning rate schedule and still achieve the desired accuracy. AdaScale has been implemented as
the AdaScale API in FairScale. This technique typically works well for SGD (with and without momentum).
The assumption is that you already have a good learning rate schedule that works well for small
batch sizes. (AdaScale has not been validated to work effectively with Adam; further research in
that direction is needed.)

AdaScale adapts the learning rate schedule and determines when to stop based on comparing statistics
of large-batch gradients with those of small-batch gradients. Small-batch gradients are gradients that
have been computed on each GPU, and large-batch gradients are the average of the gradients computed on N
such GPUs. AdaScale uses the concept of a gain ratio, which is intuitively a measure of how much the
variance has been reduced by averaging N small-batch gradients. It is a quantity between 1 and N.

In practice, the implementation tracks estimates of the gradient variance and norm-squared, which
are smoothed using an exponentially weighted moving average. If T is the number of steps used to
train the original small batch size before scaling, AdaScale stops training once the accumulated
gain ratio is greater than T. As you use more and more GPUs in training, the total number of steps needed
to train decreases, but because the gain ratio lies between 1 and N, the total number of steps does not
decrease linearly as you add GPUs. Additional training steps are taken to maintain the
model accuracy when compared with original_total_steps/N (i.e., linear scaling). In other words,
whenever the gain ratio is less than N, we could not take as large a step as we may have hoped for,
and so the total number of iterations ends up being larger than T/N.

The current implementation in FairScale supports gradient accumulation training, can be used
with Optimizer State Sharding (OSS), and works with PyTorch LR scheduler classes.

The training process is as follows:
@@ -42,4 +42,3 @@ The training process is as follows:

Best practices for `fairscale.optim.AdaScale`
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AdaScale only works for the SGD optimizer (with and without momentum).
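
A minimal sketch of how the wrapped optimizer might be used in a training loop is shown below. The names `model`, `dataset` and `max_steps` are placeholders for your own setup, and the distributed process group is assumed to have been initialized by your launcher.

.. code-block:: python

    import torch
    from torch.nn.parallel import DistributedDataParallel as DDP
    from fairscale.optim import AdaScale

    optim = AdaScale(torch.optim.SGD(model.parameters(), lr=0.1))
    model = DDP(model)

    step, done = 0.0, False
    while not done:
        for inputs, targets in dataset:
            optim.zero_grad()
            loss = torch.nn.functional.mse_loss(model(inputs), targets)
            loss.backward()
            # gain() is the estimated gain ratio in [1, N]; accumulating it
            # mirrors the accumulated gain that is compared against T above.
            step += optim.gain()
            optim.step()
            if step >= max_steps:
                done = True
                break
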
OffloadModel
=============
Heavily inspired by the `Layer-to-Layer <https://arxiv.org/abs/2002.05645>`_ algorithm and
`ZeRO-Offload <https://arxiv.org/abs/2101.06840>`_, OffloadModel uses the CPU to store
the entire model, optimizer state and gradients. OffloadModel then brings one layer (or a number of
layers) at a time onto the GPU for training during the forward and backward pass. The intermediate
activations for the layer boundaries are also stored on the CPU and copied to the GPU as needed for
the backward pass. Once the backward pass is completed, all the parameters are updated with the
gradients present on the CPU.

.. image:: ../_static/img/offload.png
@@ -15,25 +15,25 @@ gradients present on the CPU.
Offload uses the following techniques to enable large model training (a usage sketch follows the best practices below):

1. The model is assumed to be nn.Sequential and is sharded (almost) equally, based on the number of
parameters, into a list of nn.Modules. Each nn.Module now contains a fraction of the whole model,
which we shall refer to as a model shard.
2. At each iteration, each of the model shards is copied from the CPU to the GPU, the FW pass is computed
using the minibatch of data, and the model shard is copied back from the GPU to the CPU. In the BW pass, the
same process is repeated.
3. The optimizer remains on the CPU, and gradients and parameters are all moved onto the CPU before
running optimizer.step. This ensures that the CPU is responsible for updating the parameters and
holding onto the optimizer state.
4. If activation checkpointing is enabled, we use torch.autograd.Function to disable graph construction
in the FW pass and copy intermediate activations from the GPU to the CPU after the FW pass of a given shard is
complete. The reverse copy is carried out in the BW pass.
5. Microbatches are used to enable larger throughput and offset the cost of moving model parameters
and activations between the CPU and GPU. Micro-batches allow you to specify large mini-batches which are
broken down into micro-batches and fed to the model shards at each iteration. In short, it is a way
to allow more computation at a given time on a model shard to offset the cost of copying between the CPU and GPU.

Best practices for using `fairscale.experimental.nn.OffloadModel`
@@ -42,4 +42,3 @@ Best practices for using `fairscale.experimental.nn.OffloadModel`
1. Using OffloadModel to train large models can result in loss of throughput, which can be overcome by using activation checkpointing and microbatches.
2. OffloadModel currently only works for `nn.Sequential` models.
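
A minimal sketch of wrapping an `nn.Sequential` model with `OffloadModel` is shown below. The layer sizes, slice count and microbatch count are illustrative only, the keyword names are assumptions based on the API description above, and a GPU is required.

.. code-block:: python

    import torch
    import torch.nn as nn
    from fairscale.experimental.nn.offload import OffloadModel

    model = nn.Sequential(
        nn.Linear(8, 4), nn.ReLU(),
        nn.Linear(4, 4), nn.ReLU(),
        nn.Linear(4, 8),
    )

    offload_model = OffloadModel(
        model=model,                         # must be nn.Sequential
        device=torch.device("cuda"),         # device used for compute
        offload_device=torch.device("cpu"),  # device where shards are stored
        num_slices=3,                        # number of model shards
        checkpoint_activation=True,
        num_microbatches=1,
    )

    optimizer = torch.optim.SGD(offload_model.parameters(), lr=0.001)
    inputs, targets = torch.randn(8, 8).cuda(), torch.randn(8, 8).cuda()

    optimizer.zero_grad()
    loss = nn.functional.mse_loss(offload_model(inputs), targets)
    loss.backward()
    optimizer.step()
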
Efficient Memory management
============================
FairScale provides implementations inspired by the `ZeRO <https://arxiv.org/pdf/1910.02054.pdf>`_ class of algorithms in the form of modular
APIs that you can plug into your model training. Zero Redundancy Optimizer is a class of algorithms
that aim to tackle the tradeoff between using Data Parallel training and Model Parallel training.
When using Data Parallel training, you trade off memory for computation/communication efficiency.
On the other hand, when using Model Parallel training, you trade off computation/communication
efficiency for memory. ZeRO attempts to solve this problem. Model training generally involves memory
footprints that fall into two categories:

1. Model states - optimizer states, gradients, parameters
2. Residual states - activations, temp buffers, fragmented memory

To reduce redundancy in model states, three different algorithms were proposed. These have been
implemented in FairScale as Optimizer State Sharding (OSS), Sharded Data Parallel (SDP) and finally
Fully Sharded Data Parallel (FSDP). Let’s dive deeper into the actual mechanics of each of these
algorithms and understand why they provide the memory savings that they do.
Optimizer State Sharding (OSS)
------------------------------
FairScale has implemented memory optimizations related to the optimizer memory footprint (inspired by `ZeRO-1 <https://arxiv.org/pdf/1910.02054.pdf>`_)
in the `fairscale.optim.OSS` API. Optimizers such as Adam usually require maintaining momentum, variance,
parameters and gradients all in FP32 precision, even though training can be carried out with parameters
and gradients in FP16 precision. When each of the ranks updates the full model, this means that a sizable
part of the memory is occupied by redundant representations of the optimizer state.

To overcome this redundancy, optimizer state sharding entails partitioning the model optimization step
between the different ranks, so that each of them is only in charge of updating a unique shard of the
model. This in turn makes sure that the optimizer state is a lot smaller on each rank, and that it contains
no redundant information across ranks.

.. image:: ../_static/img/oss.png

The training process can be modified from that carried out by DDP as follows:

1. The wrapped optimizer shards the optimizer state in a greedy fashion based on the parameter size, but not
the order in which it is used. This is to ensure that each rank has almost the same optimizer memory
footprint.
2. The training process is similar to that used by PyTorch’s Distributed Data Parallel (DDP). The forward
pass completes on each of the ranks followed by the backward pass. During the backward pass, gradients
are synchronized using allreduce.
3. Each rank updates the parameters for the shard of optimizer state that it is responsible for and then
discards the rest.
4. After the update, a broadcast or allgather follows to ensure all ranks receive the latest updated parameter
values.

OSS is very useful when you are using an optimizer such as Adam that has additional state. The wrapping
of the optimizer is a one-line, non-intrusive change that provides memory savings.
If you are using SGD or any optimizer with a limited memory footprint, it is likely that you will see a
slowdown when using multiple nodes, due to the additional communication in step 4. There is also some
wasteful memory used to store gradients during allreduce in step 2 that is then discarded, although this
also happens with normal PyTorch (nothing extraneous here).
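
As a concrete illustration, the sketch below wraps Adam with OSS inside a single-process "gloo" group purely so the snippet is self-contained; in real training the process group is created by your launcher across many ranks, and the address, port and model here are arbitrary stand-ins.

.. code-block:: python

    import torch
    import torch.distributed as dist
    from fairscale.optim.oss import OSS

    # Single-process group so the example can run standalone.
    dist.init_process_group(
        backend="gloo", init_method="tcp://127.0.0.1:29500", rank=0, world_size=1
    )

    model = torch.nn.Linear(16, 4)  # stand-in for the real model on this rank

    # The one-line change: pass the optimizer *class*; extra kwargs such as lr
    # are forwarded to the wrapped optimizer, which only keeps its own shard.
    optimizer = OSS(params=model.parameters(), optim=torch.optim.Adam, lr=1e-4)

    x, y = torch.randn(8, 16), torch.randn(8, 4)
    optimizer.zero_grad()
    torch.nn.functional.mse_loss(model(x), y).backward()
    optimizer.step()  # each rank updates its shard, then parameters are synced
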
Best practices for using `fairscale.optim.oss` Best practices for using `fairscale.optim.oss`
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
1. OSS exposes a broadcast_fp16 flag that you should probably use in multi-node jobs, unless this leads to 1. OSS exposes a broadcast_fp16 flag that you should probably use in multi-node jobs, unless this leads to
accuracy issues (which is very unlikely). This can be used with or without Torch AMP. This is usually not accuracy issues (which is very unlikely). This can be used with or without Torch AMP. This is usually not
needed in a single node experiment. needed in a single node experiment.
2. If your model is extremely unbalanced in terms of size (one giant tensor for instance), then this method 2. If your model is extremely unbalanced in terms of size (one giant tensor for instance), then this method
will not be very helpful, and tensor sharding options such as `fairscale.nn.FullyShardedDataParallel` will not be very helpful, and tensor sharding options such as `fairscale.nn.FullyShardedDataParallel`
would be preferable. would be preferable.
3. OSS should be a drop in solution in a DDP context, and stays compatible with most of the DDP features 3. OSS should be a drop in solution in a DDP context, and stays compatible with most of the DDP features
such as the `fp16 gradient compression hook <https://pytorch.org/docs/stable/ddp_comm_hooks.html#torch.distributed.algorithms.ddp_comm_hooks.default_hooks.fp16_compress_hook>`_, such as the `fp16 gradient compression hook <https://pytorch.org/docs/stable/ddp_comm_hooks.html#torch.distributed.algorithms.ddp_comm_hooks.default_hooks.fp16_compress_hook>`_,
gradient accumulation and PyTorch AMP. gradient accumulation and PyTorch AMP.
Performance tips for `fairscale.optim.oss` Performance tips for `fairscale.optim.oss`
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
1. On a single node, OSS should be always faster than vanilla PyTorch, memory savings will vary depending 1. On a single node, OSS should be always faster than vanilla PyTorch, memory savings will vary depending
on the optimizer being used on the optimizer being used
2. When using multiple nodes, OSS can alternatively be faster or slower than vanilla PyTorch, depending 2. When using multiple nodes, OSS can alternatively be faster or slower than vanilla PyTorch, depending
on the optimizer being used, and optional flags (E.g broadcast_fp16, gradient compression, gradient on the optimizer being used, and optional flags (E.g broadcast_fp16, gradient compression, gradient
accumulation as mentioned above.) accumulation as mentioned above.)
3. If applicable (if your experiment can do with a bigger batch size), it’s usually beneficial to reinvest 3. If applicable (if your experiment can do with a bigger batch size), it’s usually beneficial to reinvest
the saved memory in a larger batch size and reduce the number of ranks involved, or to use gradient the saved memory in a larger batch size and reduce the number of ranks involved, or to use gradient
accumulation since this diminishes the communication cost. accumulation since this diminishes the communication cost.
Optimizer + Gradient State Sharding
-----------------------------------
To overcome redundant gradient memory and to enable further memory savings, gradient sharding or
`ZeRO-2 <https://arxiv.org/pdf/1910.02054.pdf>`_ was proposed. This has been implemented by the Sharded Data Parallel (SDP) API in FairScale.
While OSS solved the redundancy problem in the optimizer state, the data parallel training steps described
above still involve duplicated gradient aggregation computation, as well as additional memory being used
for gradients that are then discarded.

To enable gradient sharding, each rank is assigned a set of parameters for which it manages the
optimizer state as well as the gradient aggregation. By assigning a model shard to a given
rank we ensure that gradients are reduced to specific ranks that are in turn responsible for the update.
This reduces communication as well as memory usage.

.. image:: ../_static/img/sdp.png
@@ -116,45 +116,45 @@ The training process is as follows:
4. During the backward pass, gradients are reduced to the rank that they are assigned to as part of the sharding process in step 1. Instead of an allreduce op, a reduce op is used, which reduces the communication overhead.
5. Each rank updates the parameters that it is responsible for.
6. After the update, a broadcast or allgather follows to ensure all ranks receive the latest updated parameter values.

Both the OSS and SDP APIs allow you to reduce the memory used for gradients and optimizer states.
Additional communication costs can be present on slow interconnects, but both APIs are useful first steps to try
when running into Out Of Memory (OOM) issues.
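
A minimal sketch of the wrapping step is shown below, assuming `model` is an `nn.Module` living on this rank and the process group has already been initialized.

.. code-block:: python

    import torch
    from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP
    from fairscale.optim.oss import OSS

    # Shard the optimizer state first, then hand the OSS instance to ShardedDDP
    # so that each gradient is reduced only to the rank that owns its shard.
    optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3)
    model = ShardedDDP(model, optimizer)

    # The training loop itself (zero_grad / backward / step) is unchanged.
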
Best practices for `fairscale.nn.ShardedDataParallel`
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
1. If using multiple nodes, make sure that SDP is using reduce buffers by specifying the
`reduce_buffer_size` arg. Changing their size can be an optimization target; the best configuration
could depend on the interconnect.
2. If on a single node, it’s usually best not to use `reduce_buffer_size`, since there is a latency
cost associated with it but no memory gain. Setting this value to 0 means that this feature is not
used, and this is the recommended single-node setting.
3. If applicable (if your experiment can do with a bigger batch size), it’s usually beneficial to
reinvest the saved memory in a larger batch size and reduce the number of ranks involved, or to use
gradient accumulation, since this diminishes the communication cost.
Optimizer + Gradient + Horizontal Model Sharding
------------------------------------------------
To further optimize training and achieve greater memory savings, we need to enable parameter sharding.
With parameter sharding, similar to gradient and optimizer states, data parallel ranks are responsible
for a shard of the model parameters. FairScale implements parameter sharding by way of the Fully Sharded
Data Parallel (FSDP) API, which is heavily inspired by `ZeRO-3 <https://arxiv.org/pdf/1910.02054.pdf>`_. Parameter sharding is possible because of
two key insights:

1. The allreduce operation can be broken up into reduce and allgather, similar to the previous sharding
techniques (optimizer state and gradient sharding).
2. Individual layers can be wrapped with the FSDP API, which allows us to bring all the parameters
required for a single layer onto a given GPU at a given instant, compute the forward pass and then
discard the parameters not owned by that rank. Please see the tutorial section for how you can use
auto_wrap to enable wrapping individual layers of your model.

The training process is as follows:
@@ -174,18 +174,18 @@ The training process is as follows:
6. Let each rank update the parameters that have been assigned to it using the aggregated gradients.

With FSDP there are small changes one needs to make when using APIs for checkpointing and saving optimizer
state. Given the sharded nature of optimizer state and parameters, any API that aims to save the model
state for training or inference needs to account for saving weights from all workers. FSDP implements the
required plumbing to save weights from all workers, save weights on individual workers and save optimizer
state from all workers.

FSDP also supports mixed precision training, where both the computation and communication are carried out
in FP16 precision. If you want reduce operations to be carried out in FP32, which is the default
behavior of DDP, then you must set `fp32_reduce_scatter=True`.

To enable further memory savings, FSDP supports offloading parameters and gradients that are currently
not being used onto the CPU. This can be enabled by setting `move_params_to_cpu` and `move_grads_to_cpu`
to True.
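
A minimal sketch of wrapping a module with FSDP and the flags discussed above is shown below; `my_module` is a placeholder for your own model, and the flag values are illustrative rather than recommendations.

.. code-block:: python

    import torch
    from fairscale.nn import FullyShardedDataParallel as FSDP

    model = FSDP(
        my_module,
        mixed_precision=True,        # FP16 compute and communication
        fp32_reduce_scatter=False,   # set True to reduce gradients in FP32
        reshard_after_forward=True,  # free gathered parameters after each forward
        move_params_to_cpu=False,    # set True to offload unused parameters to CPU
    )
    # Build the optimizer after wrapping, over the sharded parameters.
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
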
Best practices for `fairscale.nn.FullyShardedDataParallel`
@@ -194,21 +194,21 @@ Best practices for `fairscale.nn.FullyShardedDataParallel`
1. For FSDP, it is preferable to use `model.zero_grad(set_to_none=True)` since it saves a large amount of
memory after stepping.
2. `torch.cuda.amp.autocast` for mixed precision is fully compatible with FSDP. However, you will need
to set the `mixed_precision` arg to True.
3. If combined with activation checkpointing, it is preferable to use FSDP(checkpoint_wrapper(module))
over checkpoint_wrapper(FSDP(module)). The latter will result in more communication and will be slower.
4. Results should be identical to DDP with pointwise optimizers, e.g., Adam, AdamW, Adadelta, Adamax,
SGD, etc. However, the sharding will result in slightly different results when using non-pointwise
optimizers, e.g., Adagrad, Adafactor, LAMB, etc.

Performance tips for `fairscale.nn.FullyShardedDataParallel`
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
1. For best memory efficiency, use auto_wrap to wrap each layer in your network with FSDP and
set `reshard_after_forward` to True.
2. For best training speed, set `reshard_after_forward` to False (wrapping each layer is not
required, but will improve speed further).
Pipeline Parallelism
=====================
Training large models can lead to out-of-memory errors when the size of the model is too large for a single GPU.
To train such a large model, layers can be pipelined across different GPU devices as described in GPipe.
`fairscale.nn.Pipe` is an implementation of GPipe which has been adopted from torchgpipe. This API
has also been upstreamed to PyTorch in the 1.8 release with the experimental tag.

.. image:: ../_static/img/pipe.png

GPipe first shards the model across different devices, where each device hosts a shard of the model.
A shard can be a single layer or a series of layers. GPipe then splits a mini-batch of data into
micro-batches and feeds them to the device hosting the first shard. The layers on each device process
the micro-batches and send the output to the following shard/device. In the meantime, each device is ready to
process the micro-batch from the previous shard/device. By pipelining the input in this way, GPipe is
able to reduce the idle time of devices.
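
A minimal sketch of wrapping a toy sequential model with `fairscale.nn.Pipe` is shown below; the layer sizes, balance and chunk count are illustrative, and at least two devices are assumed to be available.

.. code-block:: python

    import torch.nn as nn
    from fairscale.nn import Pipe

    model = nn.Sequential(nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 5))

    # `balance` places the first two layers on the first device and the last
    # layer on the second; `chunks` is the number of micro-batches each
    # mini-batch is split into for pipelining.
    model = Pipe(model, balance=[2, 1], chunks=4)
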
Best practices for using `fairscale.nn.Pipe`
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
...
@@ -3,4 +3,4 @@ User Workflow
User workflow diagram with explanation of various decision points

.. image:: _static/img/flowchart.png
\ No newline at end of file
@@ -7,8 +7,8 @@
FairScale Documentation
=======================
FairScale is a PyTorch extension library for high performance and large scale training.
FairScale makes available the latest distributed training techniques in the form of composable
modules and easy to use APIs.

.. toctree::
...
@@ -13,12 +13,12 @@ from source using the instructions below.
.. code-block:: bash
pip install fairscale
### Installing from source
.. code-block:: bash
git clone https://github.com/facebookresearch/fairscale.git
cd fairscale
pip install -r requirements.txt
@@ -26,4 +26,4 @@ from source using the instructions below.
pip install -e .
Note: If either of the above fails, add `--no-build-isolation` to the `pip install` command (this could be a problem with recent versions of pip).
\ No newline at end of file
@@ -4,7 +4,7 @@ Efficient memory usage using Activation Checkpointing
Adapted from `torch.utils.checkpoint`, this is a friendlier wrapper for performing activation checkpointing.
Compared to the PyTorch version, this version wraps a `nn.Module` and allows for all subsequent calls to be
checkpointed.
.. code-block:: python
@@ -22,7 +22,7 @@ checkpointed.
nn.Dropout(p=0.5),
nn.Linear(128, 32),
)
self.ffn_module = checkpoint_wrapper(self.ffn_module, **kwargs)
self.last_linear = nn.Linear(32, 1)
...
@@ -2,9 +2,9 @@ Scale your model on a single GPU using OffloadModel
====================================================
The `fairscale.experimental.nn.offload.OffloadModel` API democratizes large scale distributed training by enabling
users to train large models on limited GPU resources that would have traditionally resulted in OOM errors.
The `OffloadModel` API wraps the given model and shards it almost equally. Each shard of the model is copied
from the CPU to the GPU for the forward pass and then copied back. The same process is repeated in the reverse
order for the backward pass. `OffloadModel` supports mixed precision training, activation checkpointing for reducing
the memory footprint, and micro-batches for improving throughput.
@@ -24,7 +24,7 @@ Consider a training loop as described below:
num_inputs = 8
num_outputs = 8
num_hidden = 4
num_layers = 2
batch_size = 8
@@ -45,8 +45,8 @@ Consider a training loop as described below:
)

To use the `OffloadModel` API, we should wrap the model as shown below. You can specify the device that you want
to use for computing the forward and backward pass, the offload device on which the model will be stored, and the number
of slices that the model should be sharded into. By default, activation checkpointing is turned off and the number of microbatches is 1.

.. code-block:: python
...
Optimizer, Gradient and Model Sharding
=======================================
Using torch.nn.parallel.DistributedDataParallel leads to some wasted communications in the case of OSS,
but it is possible, and makes OSS a drop-in solution in your existing torch distributed code.
Let's suppose that your trainer looks like
@@ -45,9 +45,9 @@ Let's suppose that your trainer looks like
optimizer.step()

Then sharding the optimizer state is merely a matter of wrapping your optimizer in `fairscale.optim.OSS`,
as follows.
DDP can be used in place of ShardedDDP in the example below, but the memory savings will be reduced
(the gradients are not as efficiently sharded).

.. code-block:: python
@@ -97,7 +97,7 @@ DDP can be used in place of ShardedDDP in the example below, but the memory savi
optimizer.step()

The above `train` function can then be run via a `multiprocessing.spawn` call. Note that any launcher
can be used, the only assumption being that each of the ranks lives in its own Python process.

.. code-block:: python
@@ -111,12 +111,12 @@ can be used, the only assumption being that each of the ranks lives in its own p
)

Using PyTorch Automatic Mixed Precision is possible, and its actual usage will depend on whether OSS
is used with DDP or with ShardedDDP.
If OSS is used with DDP, then the normal PyTorch GradScaler can be used; nothing needs to be changed.
If OSS is used with ShardedDDP (to
get the gradient sharding), then a very similar flow can be used, but it requires a shard-aware GradScaler,
which is available in `fairscale.optim.grad_scaler`. In both cases Autocast can be used as is, and the
loss will be scaled and handled in the same way.
See `the original documentation <https://pytorch.org/docs/stable/notes/amp_examples.html?highlight=automatic%20mixed%20precision>`_
for more information.
@@ -156,7 +156,7 @@ for more information.
scaler.update()
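
For context, here is a minimal sketch of the shard-aware scaler flow just described; `model` (wrapped in ShardedDDP), `optimizer` (an OSS instance), `inputs` and `targets` are assumed to come from the surrounding training loop.

.. code-block:: python

    import torch
    from fairscale.optim.grad_scaler import ShardedGradScaler

    scaler = ShardedGradScaler()  # used like torch.cuda.amp.GradScaler

    with torch.cuda.amp.autocast():
        loss = torch.nn.functional.mse_loss(model(inputs), targets)

    scaler.scale(loss).backward()
    scaler.step(optimizer)  # shard-aware unscaling and inf/NaN checks across ranks
    scaler.update()
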
Parameters can be sharded using the FullyShardedDataParallel (FSDP) API. It involves wrapping your model similar to the
SDP API above.

.. code-block:: python
@@ -201,9 +201,9 @@ SDP API above.
optimizer.step()

Auto wrapping sub-modules with FSDP is a convenient way to improve training speed by overlapping
the allgather step across the forward passes of different submodules.
It also improves memory efficiency by freeing gathered parameters after each layer finishes executing.
For example:

.. code-block:: python
@@ -233,4 +233,4 @@ For example:
assert l.mixed_precision and l.flatten_parameters
assert isinstance(l.linear1, FSDP)
assert isinstance(l.linear2, FSDP)
assert not isinstance(l.self_attn, FSDP)  # self attention is not auto-wrapped
\ No newline at end of file
Model sharding using Pipeline Parallel Model sharding using Pipeline Parallel
====================================== ======================================
Let us start with a toy model that contains two linear layers. Let us start with a toy model that contains two linear layers.
...@@ -57,12 +57,12 @@ You can then define any optimizer and loss function
optimizer.zero_grad()
target = torch.randint(0,2,size=(20,1)).squeeze()
data = torch.randn(20, 10)
Finally, to run the model and compute the loss function, make sure that outputs and target are on the same device.
.. code-block:: default
device = model.devices[0]
## outputs and target need to be on the same device
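# Illustrative continuation: run the forward pass, compute the loss with the
# loss function defined earlier (called loss_fn here as an assumption), then
# backpropagate and step the optimizer.
outputs = model(data.to(device))
loss = loss_fn(outputs.to(device), target.to(device))
loss.backward()
optimizer.step()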
...
What is FairScale?
====================
FairScale is a PyTorch extension library for high performance and large scale training.
This library extends basic PyTorch capabilities while adding new SOTA scaling techniques.
FairScale makes available the latest distributed training techniques in the form of composable
modules and easy to use APIs. These APIs are a fundamental part of a researcher's toolbox as
they attempt to scale models with limited resources.
.. image:: _static/img/global.png
...@@ -16,19 +16,19 @@ FairScale was designed with the following values in mind:
1. **Usability** - Users should be able to understand and use FairScale APIs with minimum cognitive overload.
2. **Modularity** - Users should be able to combine multiple FairScale APIs as part of their training loop seamlessly.
3. **Performance** - FairScale APIs provide the best performance in terms of scaling and efficiency.
.. image:: _static/img/ddp.png
ML training at scale traditionally means `data parallelism <https://pytorch.org/tutorials/intermediate/ddp_tutorial.html>`_,
which allows us to use multiple devices at the same
time to train a large batch size per step, thereby achieving the goal accuracy in a shorter period of time
as compared to training on a single device. With recent advances in ML research, the size of ML models
has only increased over the years and data parallelism no longer serves all “scaling” purposes.
There are multiple axes across which you can scale training and FairScale provides the following broad
categories of solutions:
1. **Parallelism** → These techniques allow scaling of models by layer parallelism and tensor parallelism.
...
...@@ -117,17 +117,17 @@ def shard_model(model: torch.nn.Module, shard_count: int = 3) -> List[torch.fx.GraphModule]:
This function traces the model twice in an attempt to identify the
right cutpoints and then shard the model. In the first pass we calculate
the number of parameters as we are tracing the graph and mark nodes at
which we might want to create a new module. In the second pass we
modify the graph by inserting placeholders and output nodes to essentially
shard the graph.
We don't support skip connections between shards. This means that all
input and output are self-contained within a given shard. A node from
shard 1 cannot be an input to a node from shard 3. We expect all inputs
to a given shard to come from the last node in the previous shard.
This means that we may not be able to produce exactly the `shard_count`
shards requested by the user.
Args:
model (nn.Module): Model to be sharded as specified by the device count.
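# A hypothetical usage sketch, not part of auto_shard.py itself: the import path
# mirrors this file's location and the toy Sequential model is arbitrary.
# shard_model returns one torch.fx.GraphModule per shard and may return fewer
# shards than requested when clean cutpoints cannot be found.
import torch
from fairscale.experimental.nn.auto_shard import shard_model

toy = torch.nn.Sequential(
    torch.nn.Linear(10, 10), torch.nn.ReLU(),
    torch.nn.Linear(10, 10), torch.nn.ReLU(),
    torch.nn.Linear(10, 2),
)
shards = shard_model(toy, shard_count=3)
assert len(shards) <= 3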
...
...@@ -75,7 +75,7 @@ def _split(modules: nn.Sequential, number_splits: int) -> List[List[nn.Module]]:
class ModelShard(nn.Module):
"""
Wrap one shard of the model and make it possible to load parameters on the
fly for the FW and BW passes on the given device.
"""
...@@ -146,8 +146,8 @@ class OffloadFunction(torch.autograd.Function):
This enables us to offload intermediate activations present at the shard
boundaries.
- In the BW pass, it does the reverse. We run the forward pass using the
saved intermediate activations and calculate gradients as needed.
The trade-off is latency vs memory when using activation checkpointing.
- Follows heavily from https://pytorch.org/docs/stable/_modules/torch/utils/checkpoint.html#checkpoint.
...@@ -382,8 +382,8 @@ class OffloadModel(nn.Module):
model = get_model()
offload_model = OffloadModel(model, device,
offload_device=torch.device("cpu"),
num_slices=3,
checkpoint_activation=True,
num_microbatches=5)
.. _L2L: https://arxiv.org/abs/2002.05645
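# A hypothetical training-step sketch continuing the example above; dataloader,
# device (a CUDA device) and the SGD hyper-parameters are placeholder assumptions.
# OffloadModel moves each slice between offload_device and device on the fly
# during the forward and backward passes.
import torch

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(offload_model.parameters(), lr=1e-3)

for data, target in dataloader:
    data, target = data.to(device), target.to(device)
    optimizer.zero_grad()
    loss = criterion(offload_model(data), target)
    loss.backward()
    optimizer.step()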
...
...@@ -22,5 +22,6 @@ from .async_pipe import AsyncPipe
from .checkpoint import is_checkpointing, is_recomputing
from .pipe import Pipe
from .rpc import PipeRPCWrapper
from .types import LazyModule
__all__ = ["Pipe", "is_checkpointing", "is_recomputing", "LazyModule"]
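# A hypothetical illustration of the newly exported LazyModule, under the
# assumption that it simply wraps a zero-argument factory and only builds the
# real module when called, letting a pipeline rank defer constructing layers
# that it does not own.
import torch.nn as nn
from fairscale.nn.pipe import LazyModule

lazy_layers = [LazyModule(lambda: nn.Linear(10, 10)) for _ in range(4)]
layer0 = lazy_layers[0]()  # the underlying nn.Linear is constructed here
assert isinstance(layer0, nn.Linear)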