Unverified Commit d65cd838 authored by Min Xu, committed by GitHub

[feat]: AdaScale work with lr_scheduler and tests, examples (#229)

* [doc]: AdaScale example and notes

* formatted notes correctly as suggested by Benjamin

* added feature and unit test to make sure lr_scheduler works

* update the example with lr_scheduler

* fixed doc with "make html"

* addressed Mike's suggestions
parent 4402c410
......@@ -117,6 +117,32 @@ AdaScale can be used to wrap an SGD optimizer for use with DDP (Distributed Data Parallel)
training or non-DDP training with gradient accumulation. The benefit is re-using the same LR
schedule from a baseline batch size when the effective batch size is bigger.
```python
from torch.optim import SGD
from torch.optim.lr_scheduler import LambdaLR # or your scheduler
from fairscale.optim import AdaScale
...
optim = AdaScale(SGD(model.parameters(), lr=0.1))
scheduler = LambdaLR(optim, ...)
...
# Note: the training loop should use DDP or gradient accumulation.
last_epoch = 0
step = 0
done = False
while not done:
    for sample in dataset:
        ...
        step += optim.gain()
        optim.step()
        epoch = step // len(dataset)
        if last_epoch != epoch:
            scheduler.step()
            last_epoch = epoch
        if epoch > max_epoch:
            done = True
```
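The same wrapper also covers single-process training, as long as AdaScale is told how many
backward passes are accumulated per optimizer step. Below is a minimal non-DDP sketch, not part
of this change; `num_gradients_to_accumulate=2`, the toy loss, and the concrete `lr_lambda` are
assumptions chosen only for illustration.
```python
from torch.optim import SGD
from torch.optim.lr_scheduler import LambdaLR
from fairscale.optim import AdaScale

optim = AdaScale(SGD(model.parameters(), lr=0.1), num_gradients_to_accumulate=2)
scheduler = LambdaLR(optim, lr_lambda=lambda e: 0.1 ** e)  # any schedule works

last_epoch = 0
step = 0
for i, sample in enumerate(dataset):
    loss = model(sample).sum()   # toy loss for illustration
    loss.backward()              # AdaScale's hooks record each accumulated backward
    if (i + 1) % 2 == 0:         # step once per 2 accumulated backward passes
        step += optim.gain()
        optim.step()
        optim.zero_grad()
        epoch = step // len(dataset)
        if last_epoch != epoch:
            scheduler.step()
            last_epoch = epoch
```
Since the gain lies between 1 and the scale, the effective step count advances at least as fast
as the number of real optimizer steps, which is what lets the baseline LR schedule be reused at
larger effective batch sizes.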
The primary goal is to allow scaling to bigger batch sizes without losing model accuracy.
At a high level, we want ML researchers to:
......
......@@ -3,4 +3,4 @@ sphinx==3.2.1
sphinx_rtd_theme==0.4.3
sphinxcontrib-programoutput==0.16
git+https://github.com/pytorch/pytorch_sphinx_theme.git#egg=pytorch_sphinx_theme
torch==1.6.0
torch>=1.6.0
......@@ -31,6 +31,8 @@ like the following.
optimizer = torch.optim.SGD(
params=model.parameters(),
**base_optimizer_arguments)
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer, lr_lambda=lambda x: 1 / 10 ** x)
# Any relevant training loop. For example:
model.train()
......@@ -43,10 +45,12 @@ like the following.
loss = loss_fn(outputs, target)
loss.backward()
optimizer.step()
scheduler.step()
Applying AdaScale is as simple as wrapping your SGD optimizer with fairscale.optim.AdaScale,
as follows.
Applying AdaScale is as simple as wrapping your SGD optimizer with
`fairscale.optim.AdaScale`, as follows, and then using its `gain()` to advance
the effective step count so that the learning rate schedule is driven accordingly.
.. code-block:: python
......@@ -75,13 +79,18 @@ as follows.
optimizer = torch.optim.SGD(
params=model.parameters(),
**base_optimizer_arguments)
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer, lr_lambda=lambda x: 1 / 10 ** x)
# Wrap optimizer with AdaScale
optimizer = AdaScale(optimizer)
# Any relevant training loop. For example:
model.train()
for e in range(epochs):
last_epoch = 0
step = 0
done = False
while not done:
    for (data, target) in dataloader:
        data, target = data.to(rank), target.to(rank)
        # Train
......@@ -89,4 +98,11 @@ as follows.
        outputs = model(data)
        loss = loss_fn(outputs, target)
        loss.backward()
        step += optimizer.gain()
        optimizer.step()
        epoch = step // len(dataloader)
        if last_epoch != epoch:
            scheduler.step()
            last_epoch = epoch
        if epoch >= epochs:
            done = True
......@@ -38,28 +38,67 @@ import numpy as np
import torch
from torch.autograd import Variable
import torch.distributed as dist
from torch.optim import Optimizer
class AdaScale(object):
class AdaScale(Optimizer):
"""
Implements the AdaScale_ algorithm for scaling the learning rate for
distributed and large batch size training. Can be used in combination with
``torch.nn.parallel.DistributedDataParallel`` and ``torch.optim.SGD``.
This class subclasses `Optimizer` so that `torch.optim.lr_scheduler` can work with it. In other
words, AdaScale is intended to be a complete wrapper of a torch Optimizer.
.. _AdaScale: https://proceedings.icml.cc/static/paper_files/icml/2020/4682-Supplemental.pdf
There are several ways to integrate AdaScale with your training loop.
We show two examples below.
Example 1: using PyTorch's `lr_scheduler` classes.
.. code-block:: python
optim = AdaScale(SGD(model.parameters(), lr=0.001))
model = DistributedDataParallel(model)
scheduler = LambdaLR(optim, lr_lambda=...)
last_epoch = 0
done = False
step = 0
while not done:
    for batch in dataset:
        optim.zero_grad()
        logits = model(batch)
        loss = criterion(logits, ...)
        loss.backward()
        step += optim.gain()
        optim.step()
        epoch = step // len(dataset)
        if epoch > last_epoch:
            scheduler.step()
            last_epoch = epoch
        if epoch >= max_epochs:
            done = True
Example 2: using a custom `update_lr()` function that updates the learning
rate based on the current step count (a sketch of such a function follows the example).
.. code-block:: python
optim = torch.optim.SGD(model.parameters(), lr=0.001)
optim = AdaScale(SGD(model.parameters(), lr=0.001))
model = DistributedDataParallel(model)
adascale = AdaScale(optim)
for epoch in ...:
step = 0
while step < max_steps:
for batch in ...:
optim.zero_grad()
loss = ...
logits = model()
loss = criterion()
loss.backward()
adascale.step()
step += optim.gain()
optim.step()
update_lr(step)
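Here ``update_lr()`` is user code rather than part of fairscale. A hypothetical sketch, assuming
we simply decay the base LR of ``0.001`` by 10x every 1000 effective (gain-weighted) steps:

.. code-block:: python

    def update_lr(step):
        # Illustrative only: scale the base LR by a step-based decay factor.
        # ``optim`` is the AdaScale-wrapped optimizer from the example above.
        for pg in optim.param_groups:
            pg["lr"] = 0.001 * (0.1 ** (step // 1000))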
Args:
optimizer (torch.optim.Optimizer):
......@@ -68,9 +107,10 @@ class AdaScale(object):
World size (number of ranks) for distributed training. If
None, defaults to ``dist.get_world_size()``.
scale (float):
Scaling factor of the batch size, e.g. using a 10x
larger batch size (summed across all world_size) means a scale of
10. If None, defaults to ``world_size``.
Scaling factor of the batch size relative to a baseline scale of 1, e.g. using a 10x
larger batch size (summed across all ranks, including gradient accumulation)
means a scale of 10. If None, defaults to
``world_size * num_gradients_to_accumulate``.
smoothing (float):
Smoothing factor for moving average. If None, it defaults to
max(1 - (world_size * num_gradients_to_accumulate)/1000, 0).
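To make these defaults concrete, consider a hypothetical setup with 8 ranks and 4 gradient
accumulation steps (numbers chosen purely for illustration):
```python
world_size = 8
num_gradients_to_accumulate = 4
# Default scale: one optimizer step aggregates 8 * 4 = 32 baseline batches.
scale = world_size * num_gradients_to_accumulate  # 32
# Default smoothing for the moving averages of the gradient statistics.
smoothing = max(1 - (world_size * num_gradients_to_accumulate) / 1000, 0)  # 0.968
```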
......@@ -95,6 +135,9 @@ class AdaScale(object):
self._num_backward_calls = 0
self._num_grads_to_accum = num_gradients_to_accumulate
# Proxy the param_groups so that `torch.optim.lr_scheduler` can work.
self.param_groups = self._optimizer.param_groups
if self._world_size * self._num_grads_to_accum <= 1:
# gain would be NaN since we would be dividing by zero in the paper's B.3, where (S - 1) == 0.
raise RuntimeError("AdaScale does not support a single worker without grad accumulation.")
......@@ -123,7 +166,7 @@ class AdaScale(object):
param.register_hook(functools.partial(self._backward_hook, idx))
@property
def state(self) -> Dict[str, np.ndarray]:
def _state(self) -> Dict[str, np.ndarray]:
"""
Return the states of AdaScale.
"""
......@@ -169,9 +212,9 @@ class AdaScale(object):
Estimate of squared l2-norm.
"""
if pg_idx is not None:
return self.state["grad_sqr_avg"][pg_idx]
return self._state["grad_sqr_avg"][pg_idx]
else:
return np.sum(self.state["grad_sqr_avg"])
return np.sum(self._state["grad_sqr_avg"])
def grad_var_avg(self, pg_idx: Optional[int] = None) -> float:
"""
......@@ -187,9 +230,9 @@ class AdaScale(object):
Estimate of trace of the covariance.
"""
if pg_idx is not None:
return self.state["grad_var_avg"][pg_idx]
return self._state["grad_var_avg"][pg_idx]
else:
return np.sum(self.state["grad_var_avg"])
return np.sum(self._state["grad_var_avg"])
def gain(self, scale: Optional[float] = None, pg_idx: Optional[int] = None) -> float:
"""
......@@ -213,13 +256,13 @@ class AdaScale(object):
def _update_avg(self, name: str, value: torch.Tensor, factor: float) -> None:
# This function computes and stores the moving average of a vector
# using a smoothing factor.
biased = self.state.get(name + "_biased", 0.0)
unbias = self.state.get(name + "_unbias", 0.0)
biased = self._state.get(name + "_biased", 0.0)
unbias = self._state.get(name + "_unbias", 0.0)
biased = factor * biased + (1.0 - factor) * value
unbias = factor * unbias + (1.0 - factor)
self.state[name + "_biased"] = biased
self.state[name + "_unbias"] = unbias
self.state[name] = biased / unbias
self._state[name + "_biased"] = biased
self._state[name + "_unbias"] = unbias
self._state[name] = biased / unbias
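The moving-average update above is a debiased exponential moving average, in the same spirit as
Adam's bias correction: a biased accumulator is divided by a normalizer so early values are not
pulled toward the zero initialization. A standalone sketch of the idea, independent of this class:
```python
def debiased_moving_average(values, factor):
    # Keep a biased EMA plus a normalizer and yield their ratio,
    # mirroring the biased/unbias bookkeeping in _update_avg above.
    biased, unbias = 0.0, 0.0
    for v in values:
        biased = factor * biased + (1.0 - factor) * v
        unbias = factor * unbias + (1.0 - factor)
        yield biased / unbias
```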
def _backward_hook(self, pg_idx: int, grad: torch.Tensor) -> None:
# This method should be invoked once for each parameter during the
......@@ -258,6 +301,8 @@ class AdaScale(object):
assert isinstance(self._local_grad_sqr, torch.Tensor)
# Keep track of number of backward calls for gradient accumulation.
# TODO (min): this may not work with activation checkpointing, where
# multiple backward calls can happen within a single outer backward pass.
self._num_backward_calls += 1
# TODO (min, mike): We need to have a way to check that training loop & DDP
......@@ -310,6 +355,13 @@ class AdaScale(object):
Run one optimizer step using AdaScale. Essentially just invokes
``optimizer.step(*args, **kwargs)`` with a scaled learning rate.
.. note::
    It is possible for this function to become a performance
    bottleneck if you step very frequently. To avoid that, taking
    bigger steps and reducing the update frequency is generally
    better for performance.
Args:
args (Any):
Positional arguments passed to ``optimizer.step``.
......@@ -340,9 +392,21 @@ class AdaScale(object):
return self._optimizer.zero_grad()
def state_dict(self) -> Dict:
"""Proxy function to optimizer, checkpointing needs this."""
""" Proxy function to optimizer, checkpointing needs this.
.. note::
Do NOT checkpoint in the middle of gradient accumulation since
associated AdaScale internal states are not saved in the checkpoint.
"""
return self._optimizer.state_dict()
def load_state_dict(self, data: Dict) -> None:
"""Proxy function to optimizer, checkpointing needs this."""
""" Proxy function to optimizer, checkpointing needs this.
.. note::
Do NOT checkpoint in the middle of gradient accumulation since
associated AdaScale internal states are not saved in the checkpoint.
"""
return self._optimizer.load_state_dict(data)
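Because ``state_dict()`` and ``load_state_dict()`` simply proxy to the wrapped optimizer,
checkpointing follows the usual PyTorch pattern. A minimal sketch, assuming `model` and the
AdaScale-wrapped `optim` from the examples above and a hypothetical file name, saved only at an
accumulation boundary as the notes above require:
```python
# Save after a full accumulation cycle has been stepped.
torch.save({"model": model.state_dict(), "optim": optim.state_dict()}, "checkpoint.pt")

# Later, restore both before resuming training.
ckpt = torch.load("checkpoint.pt")
model.load_state_dict(ckpt["model"])
optim.load_state_dict(ckpt["optim"])
```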
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
from typing import Any, List, Iterable, Union, Callable, Optional
from typing import Any, List, Dict, Iterable, Union, Callable, Optional
from .. import Tensor
_params_t = Union[Iterable[Tensor], Iterable[dict]]
_params_t = Union[Iterable[Tensor], Iterable[Dict]]
class Optimizer(object):
param_groups: List[dict]
state: dict
def __init__(self, params: _params_t, defaults: dict) -> None: ...
def state_dict(self) -> dict: ...
def load_state_dict(self, state_dict: dict) -> None: ...
param_groups: List[Dict]
state: Dict
def __init__(self, params: _params_t, defaults: Dict) -> None: ...
def state_dict(self) -> Dict: ...
def load_state_dict(self, state_dict: Dict) -> None: ...
def zero_grad(self) -> None: ...
def step(self, closure: Optional[Callable[[], float]]=...) -> Optional[float]: ...
def add_param_group(self, param_group: dict) -> None: ...
def add_param_group(self, param_group: Dict) -> None: ...
......@@ -17,6 +17,7 @@ import torch
from torch import Tensor
from torch.nn import Linear
from torch.optim import SGD
from torch.optim.lr_scheduler import LambdaLR
from fairscale.optim import AdaScale
......@@ -116,6 +117,9 @@ def test_grad_accum(test_case, cpu):
def test_state_checkpointing():
""" Test state checkpointing on GPU since that's the common case.
Note, we don't support checkpointing in the middle of gradient accumulation
step. Therefore, it is not tested here.
AdaScale doesn't have distributed state. Otherwise, it will need
a unit test for checkpointing with DDP.
"""
......@@ -181,3 +185,24 @@ def test_state_checkpointing():
# Assert the results.
assert np.allclose(out.sum().item(), expected_out), out.sum().item()
assert np.allclose(optim.gain(), expected_gain), optim.gain()
def test_lr_scheduler():
"""Test AdaScale working with torch.optim.lr_scheduler """
model = Linear(2, 2, bias=False)
optim = AdaScale(SGD(model.parameters(), lr=0.1), num_gradients_to_accumulate=3)
# We use 1, not 0.1 here since scheduler.step() is called here first.
scheduler = LambdaLR(optim, lr_lambda=lambda epoch: 1 / 10 ** epoch)
for epoch in range(3):
for data_idx in range(10):
for accumulation in range(3):
in_data = torch.rand(2)
loss = model(in_data).sum()
loss.backward()
assert optim.gain() <= 3, optim.gain()
optim.step()
# asserting LR is right
assert np.allclose(optim.param_groups[0]["lr"], 0.1 / 10 ** epoch), optim.param_groups[0]["lr"]
scheduler.step()
# asserting LR is right
assert np.allclose(optim.param_groups[0]["lr"], 0.1 / 10 ** (epoch + 1)), optim.param_groups[0]["lr"]