Unverified Commit 587b707d authored by Min Xu, committed by GitHub

[doc] add AdaScale API doc (#191)

- removed the experimental warning, since we have validated it on CIFAR and
  ImageNet; transformer is looking good so far too.
- fixed API doc formatting
- made it consistent with the other code in the repo
- tested by building the doc locally and inspecting the results
parent ade312c4
@@ -4,5 +4,6 @@ API Reference

.. toctree::
    :maxdepth: 1

    optim/adascale
    optim/oss
    nn/pipe
AdaScale SGD
============
.. autoclass:: fairscale.optim.AdaScale
    :members:
    :undoc-members:
.. fairscale documentation master file, created by
   sphinx-quickstart on Tue Sep 8 16:19:17 2020.
   You can adapt this file completely to your liking,
   but it should at least contain the root `toctree`
   directive.

Welcome to fairscale's documentation!
=====================================
@@ -14,7 +15,10 @@ Welcome to fairscale's documentation!
    tutorials/index
    api/index

*fairscale* is a PyTorch extension library for high performance and
large scale training, on one machine or across multiple
machines/nodes. It extends basic PyTorch capabilities while
adding new experimental ones.
Components
@@ -25,11 +29,14 @@ Components
  * `tensor parallelism <../../build/html/api/nn/model_parallel.html>`_

* Optimization:

  * `optimizer state sharding <../../build/html/api/optim/oss.html>`_
  * `AdaScale SGD <../../build/html/api/optim/adascale.html>`_
.. warning::
    This library is under active development.
    Please be mindful and create an
    `issue <https://github.com/facebookresearch/fairscale/issues>`_
    if you have any trouble and/or suggestions.
Reference
......
AdaScale SGD Tutorial
=====================

`AdaScale <https://arxiv.org/pdf/2007.05105.pdf>`_ adaptively scales the learning rate when
using larger batch sizes for data-parallel training. Let's suppose that your trainer looks
like the following.
.. code-block:: python

@@ -46,7 +45,8 @@

    optimizer.step()

Applying AdaScale is as simple as wrapping your SGD optimizer with ``fairscale.optim.AdaScale``,
as follows.

.. code-block:: python
......
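The two code blocks above are collapsed in this diff view. As an illustration only, here is a minimal sketch of the pattern the tutorial describes; ``MyModel``, ``device``, ``dataloader``, and ``loss_fn`` are hypothetical placeholders, and a ``torch.distributed`` process group with more than one worker is assumed to be initialized already.

.. code-block:: python

    import torch
    from torch.nn.parallel import DistributedDataParallel as DDP
    from fairscale.optim import AdaScale

    # `MyModel`, `device`, `dataloader`, and `loss_fn` are placeholders.
    model = DDP(MyModel().to(device))
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    # Wrap the SGD optimizer; AdaScale watches the gradients accumulated by
    # backward() and adjusts the effective learning rate at each step.
    adascale = AdaScale(optimizer)

    for inputs, targets in dataloader:
        optimizer.zero_grad()
        loss = loss_fn(model(inputs), targets)
        loss.backward()
        adascale.step()  # call this instead of optimizer.step()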
@@ -4,7 +4,7 @@
# LICENSE file in the root directory of this source tree.

"""
:mod:`fairscale.optim` is a package implementing various torch optimization algorithms.
"""

try:
......
@@ -32,7 +32,6 @@
# POSSIBILITY OF SUCH DAMAGE.

import functools
from typing import Any, Dict, Optional

import numpy as np
@@ -46,6 +45,8 @@ class AdaScale(object):
    distributed and large batch size training. Can be used in combination with
    ``torch.nn.parallel.DistributedDataParallel`` and ``torch.optim.SGD``.

    .. _AdaScale: https://proceedings.icml.cc/static/paper_files/icml/2020/4682-Supplemental.pdf

    .. code-block:: python

        optim = torch.optim.SGD(model.parameters(), lr=0.001)
@@ -59,17 +60,19 @@ class AdaScale(object):
        loss.backward()
        adascale.step()
    Args:
        optimizer (torch.optim.Optimizer):
            Optimizer to apply AdaScale to.
        world_size (int):
            Number of workers for distributed training. If
            None, defaults to ``torch.distributed.get_world_size()``.
        scale (float):
            Scaling factor of the batch size, e.g. using a 10x
            larger batch size (summed across all workers) means a scale of
            10. If None, defaults to ``world_size``.
        patch_optimizer (bool):
            If True, monkey-patches the ``step`` method of
            the optimizer with AdaScale's ``step`` method.
    """
    def __init__(
@@ -80,12 +83,10 @@ class AdaScale(object):
        smoothing: float = 0.999,
        patch_optimizer: bool = False,
    ):
        self._optimizer = optimizer
        self._optimizer_step = optimizer.step
        self._local_grad_sqr: Optional[torch.Tensor] = None
        self._world_size: int = (world_size if world_size is not None else torch.distributed.get_world_size())
        if self._world_size <= 1:
            raise RuntimeError("AdaScale does not support a single worker.")
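As an editorial aside (not part of the source file), here is a minimal sketch of the ``scale`` and ``patch_optimizer`` arguments documented in the class docstring above; ``ddp_model``, ``loader``, and ``loss_fn`` are hypothetical placeholders, and a multi-worker ``torch.distributed`` group is assumed, matching the check just above.

.. code-block:: python

    import torch
    from fairscale.optim import AdaScale

    # `ddp_model`, `loader`, and `loss_fn` are placeholders.
    optim = torch.optim.SGD(ddp_model.parameters(), lr=0.001)

    # scale=8.0: the summed batch size is 8x the one the base LR was tuned for.
    # patch_optimizer=True: optim.step() is monkey-patched with AdaScale's step,
    # so an existing training loop can keep calling optim.step() unchanged.
    adascale = AdaScale(optim, scale=8.0, patch_optimizer=True)

    for inputs, targets in loader:
        optim.zero_grad()
        loss_fn(ddp_model(inputs), targets).backward()
        optim.step()  # now runs AdaScale's step with a gain-scaled LR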
@@ -129,8 +130,9 @@ class AdaScale(object):
        application to invoke this function to make sure that AdaScale's
        scaling factor matches the actual batch size used during training.

        Args:
            scale (float):
                New scaling factor to be applied to AdaScale.
        """
        self._scale = scale
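A short, hypothetical sketch of when this setter (exposed as ``set_scale`` on the wrapper) might be called: keeping the scaling factor in sync when the effective batch size changes mid-training. ``adascale`` and the batch sizes below are placeholders.

.. code-block:: python

    # Hypothetical values: the scale is the ratio of the summed batch size
    # to the batch size the base learning rate was tuned for.
    base_batch_size = 256
    new_total_batch_size = 2048   # e.g. after the job gains workers
    adascale.set_scale(new_total_batch_size / base_batch_size)  # scale = 8.0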
@@ -139,7 +141,9 @@ class AdaScale(object):
        Current estimate of the squared l2-norm of the true gradient (sigma
        squared in the AdaScale paper).

        Returns:
            (float):
                Estimate of squared l2-norm.
        """
        return np.sum(self.state["grad_sqr_avg"])
@@ -148,7 +152,9 @@ class AdaScale(object):
        Current estimate of the trace of the covariance of the true gradient
        (mu squared in the AdaScale paper).

        Returns:
            (float):
                Estimate of trace of the covariance.
        """
        return np.sum(self.state["grad_var_avg"])
@@ -156,10 +162,13 @@ class AdaScale(object):
        """
        Current estimate of the AdaScale gain ratio (r_t).

        Args:
            scale (float):
                The batch size scale to estimate the gain ratio for.

        Returns:
            (float):
                Estimate of gain ratio.
        """
        scale = self._scale if scale is None else scale
        var = self.grad_var_avg()
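The rest of this method is collapsed in the diff. Following the AdaScale paper, the gain can be computed from the two running estimates roughly as sketched below; ``estimated_gain`` is a hypothetical helper written for illustration, not a copy of the collapsed source.

.. code-block:: python

    def estimated_gain(adascale, scale):
        """Hypothetical helper: gain ratio r_t from AdaScale's running estimates."""
        var = adascale.grad_var_avg()   # trace of the gradient covariance
        sqr = adascale.grad_sqr_avg()   # squared l2-norm of the true gradient
        # r_t tends to `scale` when the variance dominates and to 1.0 when
        # the true-gradient norm dominates, keeping the scaled LR in [lr, scale*lr].
        return (var + sqr) / (var / scale + sqr)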
@@ -225,9 +234,14 @@ class AdaScale(object):
        Run one optimizer step using AdaScale. Essentially just invokes
        ``optimizer.step(*args, **kwargs)`` with a scaled learning rate.

        Args:
            args:
                Positional arguments passed to ``optimizer.step``.
            kwargs:
                Keyword arguments passed to ``optimizer.step``.

        Returns:
            (Tensor):
                The loss, if a closure is passed to the optimizer to reevaluate the model.
        """
        initial_lr = [pg["lr"] for pg in self._optimizer.param_groups]
        for idx, param_group in enumerate(self._optimizer.param_groups):
......
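The remainder of ``step`` is collapsed above. Below is a hedged sketch of the pattern the docstring describes: scale each parameter group's learning rate by the gain, delegate to the wrapped optimizer's original ``step``, then restore the base learning rates. ``scaled_step`` is a hypothetical free function, and ``optimizer``, ``optimizer_step``, and ``gain`` stand in for the instance's state.

.. code-block:: python

    def scaled_step(optimizer, optimizer_step, gain, *args, **kwargs):
        # Illustration only, not a copy of the collapsed source.
        initial_lr = [pg["lr"] for pg in optimizer.param_groups]
        for idx, pg in enumerate(optimizer.param_groups):
            pg["lr"] = gain * initial_lr[idx]   # temporarily scale each LR
        loss = optimizer_step(*args, **kwargs)  # the original optimizer.step
        for lr, pg in zip(initial_lr, optimizer.param_groups):
            pg["lr"] = lr                       # restore the base LRs
        return loss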