Unverified Commit a2408eb8 authored by Min Xu, committed by GitHub

[feat] Add AdaScaleWrapper (#347)

* [feat] Add AdaScaleWrapper

- This enables a different API for wrapping an optimizer with AdaScale.
- This also enables AdaScale to be wrapped by OSS.
- However, OSS wrapping AdaScale results in different optimization behavior,
  whose effects will need to be studied in future research (see the sketch below).

testing: add unit tests.

* addressed comment: typo
parent a265586b
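An illustrative sketch of the two wrapping orders enabled by this change (the model, learning rate, and the use of SGD are placeholders mirroring the updated unit tests):

    # Recommended: AdaScale wrapping OSS; numerically identical to AdaScale without OSS.
    optim = AdaScale(OSS(model.parameters(), SGD, lr=0.1))

    # Newly possible: OSS wrapping AdaScaleWrapper, since the wrapper's constructor
    # looks like a standard optimizer's. Experimental; each rank's AdaScale sees a
    # different shard of the parameters.
    optim = OSS(model.parameters(), AdaScaleWrapper, optim_cls=SGD, lr=0.1)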
@@ -8,7 +8,7 @@
"""
import logging
from .adascale import AdaScale
from .adascale import AdaScale, AdaScaleWrapper
from .oss import OSS
try:
......
@@ -32,13 +32,18 @@
# POSSIBILITY OF SUCH DAMAGE.
import functools
from typing import Any, Dict, List, Optional
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type
import numpy as np
import torch
from torch.autograd import Variable
import torch.distributed as dist
from torch.optim import Optimizer
from torch.optim import SGD, Optimizer
if TYPE_CHECKING:  # pragma: no cover
    from torch.optim.optimizer import _params_t
else:
    _params_t = Any
class AdaScale(Optimizer):
@@ -582,3 +587,47 @@ class AdaScale(Optimizer):
            # When effective world size is large enough, smoothing is probably
            # not needed, so the smoothing factor is 0.
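            # For example (illustrative arithmetic): with world_size=2 and
            # num_grads_to_accum=1, the line below yields max(1 - 2 * 1 / 1000, 0) = 0.998;
            # once world_size * num_grads_to_accum reaches 1000, the factor becomes 0.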
            self._smoothing = max(1 - self._world_size * self._num_grads_to_accum / 1000, 0)


class AdaScaleWrapper(AdaScale):
    """
    A thin wrapper for AdaScale so that the constructor resembles a
    standard optimizer. This allows it to work with other optimizer
    wrappers, like `OSS`.

    .. warning::
        OSS(AdaScaleWrapper) (i.e. OSS wrapping AdaScale) results in each
        rank's AdaScale operating on a different set of parameters. The
        ranks will compute different gain values, and it is unclear how to
        adjust the effective step size in that case. We have not validated
        the effectiveness or benefit of this configuration.

        On the other hand, AdaScale(OSS) (i.e. AdaScale wrapping OSS) is
        recommended and is numerically identical to AdaScale without OSS.
        Since AdaScale doesn't incur per-parameter state, the memory benefit
        of OSS is still the same.

    Args:
        params (list of tensors):
            parameters to be optimized.
        optim_cls (class subtyping torch.optim.Optimizer):
            an optimizer class to be wrapped.
        additional_optim_args (argument dict):
            keyword arguments passed to the `optim_cls` class above.

        The remaining arguments are the same as in the `AdaScale` class above.
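
    Example (an illustrative sketch; the model and learning rate are
    placeholders, and `SGD` is simply the default `optim_cls`)::

        model = torch.nn.Linear(2, 2)
        # Behaves like AdaScale(SGD(model.parameters(), lr=0.1)), but the
        # optimizer-style constructor also lets it be wrapped, e.g. by OSS.
        optim = AdaScaleWrapper(model.parameters(), optim_cls=SGD, lr=0.1)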
"""
def __init__(
self,
params: _params_t,
world_size: Optional[int] = None,
scale: Optional[float] = None,
smoothing: float = None,
num_gradients_to_accumulate: int = 1,
debias_ewma: bool = True,
optim_cls: Type[Optimizer] = SGD,
**additional_optim_args: Any,
):
optim_obj = optim_cls(params, **additional_optim_args)
super().__init__(optim_obj, world_size, scale, smoothing, num_gradients_to_accumulate, debias_ewma)
@@ -22,7 +22,7 @@ from torch.nn import Linear, Sequential
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import SGD
from fairscale.optim import OSS, AdaScale
from fairscale.optim import OSS, AdaScale, AdaScaleWrapper
from fairscale.utils.golden_testing_data import adascale_test_data
from fairscale.utils.testing import skip_if_single_gpu
@@ -40,11 +40,16 @@ def _test_basic_func(rank, world_size, tempfile_name, test_case, oss, model=None
        model = Linear(2, 2, bias=False)
    model.to("cuda")
    model = DDP(model, device_ids=[rank])
    if oss:
    # For now, we can only wrap AdaScale over OSS. If we do it the other way around,
    # AdaScale needs to take different parameter types, i.e. the parameter list, etc.
    assert oss in ["none", "ada-oss", "wrapper-oss", "oss-wrapper"]
    if oss == "ada-oss":
        optim = AdaScale(OSS(model.parameters(), SGD, lr=0.1))
    elif oss == "wrapper-oss":
        optim = AdaScaleWrapper(model.parameters(), optim_cls=OSS, optim=SGD, lr=0.1)
    elif oss == "oss-wrapper":
        optim = OSS(model.parameters(), AdaScaleWrapper, optim_cls=SGD, lr=0.1)
    else:
        assert oss == "none"
        optim = AdaScale(SGD(model.parameters(), lr=0.1))
if "input" in test_case:
@@ -59,6 +64,7 @@
        optim.step()
        optim.zero_grad()
    if "expected_gain" in test_case:
        assert np.allclose(optim.gain(), test_case["expected_gain"]), optim.gain()
    if "expected_mean_weight" in test_case:
@@ -75,11 +81,11 @@ def test_basic(test_case):
    world_size = 2
    temp_file_name = tempfile.mkstemp()[1]
    mp.spawn(_test_basic_func, args=(world_size, temp_file_name, test_case, True), nprocs=world_size, join=True)
    mp.spawn(_test_basic_func, args=(world_size, temp_file_name, test_case, "ada-oss"), nprocs=world_size, join=True)
@skip_if_single_gpu
@pytest.mark.parametrize("oss", [True, False])
@pytest.mark.parametrize("oss", ["none", "ada-oss", "wrapper-oss", "oss-wrapper"])
def test_sequential(oss):
"""Test adascale with DDP + OSS with a sequential model"""
world_size = 2
@@ -98,6 +104,14 @@ def test_sequential(oss):
"expected_mean_weight": 52.92657661437988,
}
if oss == "oss-wrapper":
# When OSS wraps AdaScale, the training is numerically different
# and it exists only to enable future research. So we don't check
# the gain (OSS doesn't have a gain() function, different rank's
# gains are different). We just ensure the mean_weight is expected.
del test_case["expected_gain"]
test_case["expected_mean_weight"] = 94.93386840820312
    # The model.
    model = Sequential(
        Linear(2, 3, bias=False), Linear(3, 4, bias=False), Linear(4, 5, bias=False), Linear(5, 6, bias=False)
......