Unverified Commit 3932a1f6 authored by Min Xu, committed by GitHub

[feat] sync adascale from internal repo, support add_param_group (#266)

* [feat] sync adascale from internal repo

- tbd

testing: tbd

* Update argument documentation of __init__

* update documentation around set_num_gradients_to_accumulate

* added checking code for proper API calling places

* rename internal APIs to make them internal

* updated changelog

* added support for add_param_group and its unit test

* added unit test for set_num_gradients_to_accumulate

* added debias_ewma unit test

* fixed test_set_num_gradients_to_accumulate (need zero_grad() call)

* added missing zero_grad() to test_lr_scheduler

* fixed test_add_param_group with respect to optim.zero_grad()

* added test_gradient_value

* added test_scale_not_equal_default for scale != world_size * grad_accum

* added test_unhook()

* removed print statements

* fixed a typo

* addressed Ben's comment
parent 84a3bdbe
......@@ -7,8 +7,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## [next rel] - TBD
### Added
- AdaScale: Added gradient accumulation feature (#202)
- AdaScale: Added support of torch.lr_scheduler (#229)
- AdaScale:
. Added gradient accumulation feature (#202)
. Added support of torch.lr_scheduler (#229)
. Added support for add_param_groups (#266)
. Added support for scale != world_size (#266)
### Fixed
- AdaScale: smoothing factor value fixed when using gradient accumulation (#235)
......
......@@ -32,7 +32,7 @@
# POSSIBILITY OF SUCH DAMAGE.
import functools
from typing import Any, Dict, Optional
from typing import Any, Dict, List, Optional
import numpy as np
import torch
......@@ -82,7 +82,7 @@ class AdaScale(Optimizer):
done = True
Example 2: using a custom `update_lr()` function that updates the learning
rate based on the current step count.
rate based on the current step count per epoch.
.. code-block:: python
......@@ -104,19 +104,27 @@ class AdaScale(Optimizer):
optimizer (torch.optim.Optimizer):
Optimizer to apply AdaScale to.
world_size (int):
Number of world_size for distributed training. If
None, defaults to ``dist.get_world_size()``.
World size (number of workers) for distributed training.
If None, defaults to ``dist.get_world_size()``.
scale (float):
Scaling factor of the batch size from scale equals 1, e.g. using a 10x
larger batch size (summed across all ranks with gradient accumulation)
means a scale of 10. If None, defaults to
``world_size * num_gradients_to_accumulate``.
means a scale of 10.
If None, defaults to ``world_size * num_gradients_to_accumulate``.
smoothing (float):
Smoothing factor for moving average. If None, it defaults to
max(1 - (world_size * num_gradients_to_accumulate)/1000, 0).
Smoothing factor for moving average.
If None, it defaults to ``max(1 - (world_size * num_gradients_to_accumulate)/1000, 0)``.
num_gradients_to_accumulate (int):
Number of passes that we accumulate gradients locally.
Number of passes that we accumulate gradients locally
between each optimizer step. This can be changed during
training as long as the train loop changes the gradient
accumulation accordingly.
Defaults to 1, which does not accumulate gradients.
debias_ewma (bool):
(experimental) Use a debiased exponential moving average
for smoothing the mu and sigma variables. If False, the
method from the paper's Appendix B.3 is used instead.
Default: True, which is what has been validated so far.
"""
def __init__(
......@@ -126,6 +134,7 @@ class AdaScale(Optimizer):
scale: Optional[float] = None,
smoothing: float = None,
num_gradients_to_accumulate: int = 1,
debias_ewma: bool = True,
):
self._optimizer = optimizer
self._local_grad_sqr: Optional[torch.Tensor] = None
......@@ -133,11 +142,15 @@ class AdaScale(Optimizer):
world_size if world_size is not None else dist.get_world_size() if dist.is_initialized() else 1
)
self._num_backward_calls = 0
self._last_final_backward_call = 0
self._num_grads_to_accum = num_gradients_to_accumulate
self._debias_ewma = debias_ewma
# Proxy the param_groups so that `torch.optim.lr_scheduler` can work.
self.param_groups = self._optimizer.param_groups
self.set_num_gradients_to_accumulate(num_gradients_to_accumulate, update_smoothing=True)
if self._world_size * self._num_grads_to_accum <= 1:
# gain will be NaN since we would be dividing by zero: the paper's B.3 estimator divides by (cN - 1), which is 0 here.
raise RuntimeError("AdaScale does not support a single worker without grad accumulation.")
......@@ -151,19 +164,47 @@ class AdaScale(Optimizer):
},
)
self._scale = 1.0 # Assign to inform mypy about the typing of this variable.
self.set_scale(self._world_size * self._num_grads_to_accum if scale is None else scale)
# Set smoothing based on effective world_size rather than scale here, since world_size
# determines the number of samples being averaged over at every update
self._smoothing = (
max(1 - (self._world_size * self._num_grads_to_accum) / 1000, 0) if smoothing is None else smoothing
)
self._hook_handles: List[Any] = []
self._hook()
def _hook(self) -> None:
""" Internal function to register the gradient hooks.
# Register the gradient hooks. Note, don't assume every param will generate
# a gradient (i.e. triggering the hook) in every backward pass.
Note: don't assume every parameter will produce a gradient (i.e., trigger the hook)
in every backward pass. This is the reason the DDP class in ``torch.nn.parallel``
has a ``find_unused_parameters`` flag.
"""
assert self._hook_handles == [], "Must run unhook first"
for idx, param_group in enumerate(self._optimizer.param_groups):
for param in param_group["params"]:
param.register_hook(functools.partial(self._backward_hook, idx))
h = param.register_hook(functools.partial(self._backward_hook, idx))
self._hook_handles.append(h)
def __del__(self) -> None:
""" Unhook in case caller forgets to call unhook.
This however may not "work" since there would be circular reference
between the hook objects and this objects. In that case, neither will
get GC'ed. Calling unhook explicitly if you really want to delete
AdaScale from memory.
"""
self.unhook()
def unhook(self) -> None:
""" Unregister hook handles.
This is public because the caller may need to call it to ensure all GPU
memory is released. Otherwise, the hooks may prevent parameters from being
released from the GPU memory pool.
Internally, we use this to support the ``add_param_group()`` API.
"""
for h in self._hook_handles:
h.remove()
self._hook_handles = []
@property
def _state(self) -> Dict[str, np.ndarray]:
......@@ -176,9 +217,12 @@ class AdaScale(Optimizer):
def scale(self) -> float:
"""
The scaling factor of the current batch size, relative to the baseline
batch size when training with a single worker. For example, if the
baseline batch size is 32, but using a scaled-up batch size of 80, then
then the scaling factor is 2.5.
batch size, which could itself be a DDP training run. For example, if the
baseline per-GPU batch size is 32 on 2 GPUs and the scaled-up per-GPU batch size
is 80 on 4 GPUs, then the scaling factor is (80 * 4) / (32 * 2) = 5.
This API is exposed mainly for logging purposes. Note that this is different
from ``self.gain()``.
Returns:
(float):
......@@ -186,7 +230,22 @@ class AdaScale(Optimizer):
"""
return self._scale
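# Worked check (illustration only) of the example in the docstring above:
# per-GPU batch size 32 on 2 GPUs as the baseline vs. 80 on 4 GPUs scaled up.
assert (80 * 4) / (32 * 2) == 5.0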
def set_scale(self, scale: float) -> None:
@property
def smoothing(self) -> float:
"""
The smoothing constant used in exponentially-weighted moving average
tracking the gradient norm mean and variance within AdaScale.
This API is exposed since the value is computed internally and the caller
may want to obtain and log it.
Returns:
(float):
The current smoothing value.
"""
return self._smoothing
def set_scale(self, scale: float, update_estimate: bool = True) -> None:
"""
Set the scaling factor of the current batch size. It is up to the
application to invoke this function to make sure that AdaScale's
......@@ -195,10 +254,23 @@ class AdaScale(Optimizer):
Args:
scale (float):
New scaling factor to be applied to AdaScale.
update_estimate (bool):
Whether to update the scale-dependent estimate of the gradient
variance; this is highly recommended. (default: True)
"""
assert self._local_grad_sqr is None, "Don't change scale in backward phase"
assert scale >= 1, "Scale must be at least 1"
if update_estimate and hasattr(self, "_scale"):
assert self._scale >= 1, "bug: old scale isn't valid"
# Rescale grad_var_avg to account for the change in scale
if self._debias_ewma and "grad_var_avg_biased" in self._state:
self._state["grad_var_avg_biased"] *= self._scale / scale
elif "grad_var_avg_total" in self._state: # _debias_ewma==False
self._state["grad_var_avg_total"] *= self._scale / scale
self._state["grad_var_avg"] *= self._scale / scale
self._scale = scale
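# Illustration (hypothetical numbers, not part of this commit) of the rescaling
# above: when the scale changes, the stored grad_var_avg estimate is multiplied
# by old_scale / new_scale.
import numpy as np

old_scale, new_scale = 4.0, 8.0
grad_var_avg = np.array([0.5])  # hypothetical tracked estimate at old_scale
grad_var_avg = grad_var_avg * (old_scale / new_scale)
assert np.allclose(grad_var_avg, [0.25])  # halved when the scale doubles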
def grad_sqr_avg(self, pg_idx: Optional[int] = None) -> float:
def _grad_sqr_avg(self, pg_idx: Optional[int] = None) -> float:
"""
Current estimate of the squared l2-norm of the true gradient
(sigma squared in the AdaScale paper).
......@@ -216,7 +288,7 @@ class AdaScale(Optimizer):
else:
return np.sum(self._state["grad_sqr_avg"])
def grad_var_avg(self, pg_idx: Optional[int] = None) -> float:
def _grad_var_avg(self, pg_idx: Optional[int] = None) -> float:
"""
Current estimate of the trace of the covariance of the true gradient
(mu squared in the AdaScale paper).
......@@ -234,43 +306,71 @@ class AdaScale(Optimizer):
else:
return np.sum(self._state["grad_var_avg"])
def gain(self, scale: Optional[float] = None, pg_idx: Optional[int] = None) -> float:
def gain(self, pg_idx: Optional[int] = None) -> float:
"""
Current estimate of the AdaScale gain ratio (r_t in the paper).
Args:
scale (float):
Optional batch size scale to estimate the gain ratio for.
pg_idx (int):
Optional index of a parameter group.
Default None: returns "averaged" gain for all groups.
Returns:
(float):
Estimate of gain ratio.
"""
scale = self._scale if scale is None else scale
var = self.grad_var_avg(pg_idx)
sqr = self.grad_sqr_avg(pg_idx)
return (var + sqr) / (var / scale + sqr)
def _update_avg(self, name: str, value: torch.Tensor, factor: float) -> None:
# This function computes and stores the moving average of a vector
# using a smoothing factor.
biased = self._state.get(name + "_biased", 0.0)
unbias = self._state.get(name + "_unbias", 0.0)
biased = factor * biased + (1.0 - factor) * value
unbias = factor * unbias + (1.0 - factor)
self._state[name + "_biased"] = biased
self._state[name + "_unbias"] = unbias
self._state[name] = biased / unbias
var = self._grad_var_avg(pg_idx)
sqr = self._grad_sqr_avg(pg_idx)
gain = (var + sqr) / (var / self.scale + sqr)
return gain
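# Numeric sketch (illustration only) of the gain formula above: with
# var == sqr and scale S, the gain reduces to 2 * S / (S + 1).
var, sqr, S = 1.0, 1.0, 8.0
gain = (var + sqr) / (var / S + sqr)
assert abs(gain - 2 * S / (S + 1)) < 1e-12  # about 1.78 for S = 8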
def _update_avg(self, name: str, value: np.ndarray, factor: float) -> None:
if self._debias_ewma:
# This function computes and stores the moving average of a vector
# using a smoothing factor.
biased = self._state.get(name + "_biased", np.zeros(value.shape[0]))
unbias = self._state.get(name + "_unbias", np.zeros(value.shape[0]))
biased = factor * biased + (1.0 - factor) * value
unbias = factor * unbias + (1.0 - factor)
self._state[name + "_biased"] = biased
self._state[name + "_unbias"] = unbias
self._state[name] = biased / unbias
else:
# Moving average procedure described in Appendix B.3
# For iterations t < 1 / (1 - smoothing) define grad_var_avg
# and grad_sqr_avg as the mean of the past samples. After that,
# start using the running average.
#
# Note: we only keep a single _count for all parameter groups.
# Ideally, it should be a vector, in case a PG is added
# after some iterations are done. But then the if condition
# below would need to be an np.where. We leave this corner
# case as a future exercise.
count = self._state.get(name + "_count", 0)
count += 1
self._state[name + "_count"] = count
if count < 1 / (1 - self._smoothing):
total = self._state.get(name + "_total", None)
if total is None:
total = value
else:
total += value
self._state[name + "_total"] = total
self._state[name] = total / count
else:
self._state[name] = factor * self._state[name] + (1.0 - factor) * value
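# Standalone sketch (illustration only) of the debiased EWMA branch above:
# dividing the biased average by the accumulated weight makes the very first
# estimate equal the first sample instead of being shrunk toward zero.
import numpy as np

factor = 0.9
biased, unbias = np.zeros(1), np.zeros(1)
value = np.array([4.0])  # hypothetical first sample
biased = factor * biased + (1.0 - factor) * value
unbias = factor * unbias + (1.0 - factor)
assert np.allclose(biased / unbias, value)  # debiased estimate == first sample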
def _backward_hook(self, pg_idx: int, grad: torch.Tensor) -> None:
# This method should be invoked once for each parameter during the
# backward pass, before gradients are synchronized across workers (world_size).
# Store the local gradient square sums in a vector.
# This vector is also used for error checking. Whenever it is not None,
# it means that we are in the backward phase.
if self._local_grad_sqr is None:
self._local_grad_sqr = torch.zeros(len(self._optimizer.param_groups), device=grad.device)
self._local_grad_sqr = torch.zeros(
len(self._optimizer.param_groups), device=grad.device, requires_grad=False,
)
self._local_grad_sqr[pg_idx] += grad.pow(2).sum()
# Now, ensure we queue a callback at the end of the callback queue.
......@@ -310,7 +410,13 @@ class AdaScale(Optimizer):
# in this backward pass.
# Longer term, we may compute the gain and then inform
# the training loop when it is a good time to step().
if self._num_backward_calls % self._num_grads_to_accum != 0:
assert (
self._num_backward_calls - self._last_final_backward_call
) <= self._num_grads_to_accum, (
f"bug: {self._num_backward_calls} - {self._last_final_backward_call} should <= {self._num_grads_to_accum}"
)
if (self._num_backward_calls - self._last_final_backward_call) % self._num_grads_to_accum != 0:
assert self._local_grad_sqr is not None, "We should still be in backward phase"
return
# Since self._local_grad_sqr is FP32, sum shouldn't overflow.
......@@ -322,33 +428,37 @@ class AdaScale(Optimizer):
# Compute the sums of squares for reduced gradients.
# Divide by _num_grads_to_accum since the gradients are accumulated.
#
# Note: we are mutating the gradients here!!!
total_grad_sqr = np.array(
[
sum(param.grad.div_(self._num_grads_to_accum).pow(2).sum().item() for param in group["params"])
for group in self._optimizer.param_groups
]
[sum(param.grad.pow(2).sum().item() for param in group["params"]) for group in self._optimizer.param_groups]
)
# Divide by (_num_grads_to_accum ** 2) to account for gradient
# accumulation.
if self._num_grads_to_accum > 1:
# np array doesn't support /=.
total_grad_sqr = total_grad_sqr / (self._num_grads_to_accum ** 2)
# Wait for all_reduce to be done and move it to cpu & np.
if work:
work.wait()
local_grad_sqr = self._local_grad_sqr.cpu().numpy()
self._local_grad_sqr = None
# See appendix B.3 of the paper.
# Modified to handle cases where scale != world_size
#
# local_grad_sqr is \sigma_{i=1}^{S}\norm{g_t_i}
# total_grad_sqr is \norm{g_t}
# local_grad_sqr is \sum_{i=1}^{c N} \norm{g_t_i}^2
# where N is world size and c is num_grads_to_accum
# total_grad_sqr is \norm{\bar{g}_t}^2
S = self._scale
grad_var = local_grad_sqr / (S - 1) - total_grad_sqr * S / (S - 1)
cN = self._world_size * self._num_grads_to_accum
grad_var = local_grad_sqr * (S / cN) / (cN - 1) - total_grad_sqr * S / (cN - 1)
grad_sqr = total_grad_sqr - grad_var / S
grad_var = np.maximum(grad_var, 1e-6)
grad_sqr = np.maximum(grad_sqr, 0.0)
theta = self._smoothing
self._update_avg("grad_sqr_avg", grad_sqr, theta)
self._update_avg("grad_var_avg", grad_var, theta)
self._update_avg("grad_sqr_avg", grad_sqr, self.smoothing)
self._update_avg("grad_var_avg", grad_var, self.smoothing)
self._last_final_backward_call = self._num_backward_calls
# Indicating backward is done.
self._local_grad_sqr = None
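# Numeric sanity check (illustration only) of the estimators above: if all
# cN per-step gradients are identical, the estimated variance is 0 and the
# squared-norm estimate equals the true squared norm.
import numpy as np

g = np.array([3.0, 4.0])  # hypothetical per-step gradient
cN, S = 4, 4.0  # world_size * num_grads_to_accum, and the scale
local_grad_sqr = cN * np.dot(g, g)  # sum of per-step squared norms
total_grad_sqr = np.dot(g, g)  # squared norm of the averaged gradient
grad_var = local_grad_sqr * (S / cN) / (cN - 1) - total_grad_sqr * S / (cN - 1)
grad_sqr = total_grad_sqr - grad_var / S
assert np.allclose(grad_var, 0.0) and np.allclose(grad_sqr, 25.0)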
def step(self, *args: Any, **kwargs: Any) -> Optional[float]:
"""
......@@ -372,6 +482,7 @@ class AdaScale(Optimizer):
(Tensor):
The loss tensor if a closure is used to re-evaluate the model.
"""
assert self._local_grad_sqr is None, "Don't step without finishing backward phase"
# Set original LR and set new LR.
original_lr = []
for idx, param_group in enumerate(self._optimizer.param_groups):
......@@ -387,8 +498,30 @@ class AdaScale(Optimizer):
return res
def add_param_group(self, pg: Dict) -> None:
""" Support adding parameter groups
We need to re-size some of the state and re-register the backward hooks.
"""
assert self._local_grad_sqr is None, "Can't add parameter group during backward"
self._optimizer.add_param_group(pg)
# Update the hooks.
self.unhook()
self._hook()
# Extend the states.
for name in self._state.keys():
assert name.startswith("grad_sqr_avg") or name.startswith("grad_var_avg"), name
if isinstance(self._state[name], int):
# This is the "_count" variable.
continue
# Must be an np array; extend it with the right value and check the shape.
val = 1 if name == "grad_sqr_avg" else 0
self._state[name] = np.append(self._state[name], val)
assert self._state[name].shape == (len(self._optimizer.param_groups),)
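# Why the appended value is 1 for grad_sqr_avg and 0 for the variance entries:
# with these neutral values, a freshly added group reports a gain of 1 under the
# gain formula above (illustrative check; S is a hypothetical scale).
var, sqr, S = 0.0, 1.0, 8.0
assert (var + sqr) / (var / S + sqr) == 1.0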
def zero_grad(self) -> None:
"""Proxy function to optimizer, because some training loops need this."""
assert self._local_grad_sqr is None, "Don't zero_grad in backward"
return self._optimizer.zero_grad()
def state_dict(self) -> Dict:
......@@ -399,6 +532,7 @@ class AdaScale(Optimizer):
Do NOT checkpoint in the middle of gradient accumulation since
associated AdaScale internal states are not saved in the checkpoint.
"""
assert self._local_grad_sqr is None, "Don't checkpoint in backward"
return self._optimizer.state_dict()
def load_state_dict(self, data: Dict) -> None:
......@@ -409,4 +543,37 @@ class AdaScale(Optimizer):
Do NOT checkpoint in the middle of gradient accumulation since
associated AdaScale internal states are not saved in the checkpoint.
"""
assert self._local_grad_sqr is None, "Don't load checkpoint in backward"
return self._optimizer.load_state_dict(data)
def set_num_gradients_to_accumulate(self, num_gradients_to_accumulate: int, update_smoothing: bool = True,) -> None:
"""Set the number of gradients to accumulate to a new value.
This is experimental. This could be called while training so that
we can gradually increase the steps between updates. Almost always,
`set_scale` needs to be called to update the scale as well.
TODO (min): need a way to determine how much to increase the step size?
TODO (min): having both `set_scale` and `set_num_gradients_to_accumulate`
is hard to use and easy to get wrong. It may be better
to specify a single `base_scale` instead. But more discussion is
needed here.
Args:
num_gradients_to_accumulate (int):
Number of gradients to accumulate (calls to backward) between
each optimizer step
update_smoothing (bool):
Whether to update smoothing factor or not. Default: True.
"""
assert self._local_grad_sqr is None, "Don't change num_grad_to_accum in backward"
assert num_gradients_to_accumulate >= 1, f"Invalid value {num_gradients_to_accumulate}"
self._num_grads_to_accum = num_gradients_to_accumulate
if update_smoothing:
# Set smoothing based on effective world_size rather than scale here,
# since world_size determines the number of samples being averaged over
# at every update.
#
# When effective world size is large enough, smoothing is probably
# not needed, so the smoothing factor is 0.
self._smoothing = max(1 - self._world_size * self._num_grads_to_accum / 1000, 0)
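# Sketch (illustration only, mirroring the unit test below) of the pairing
# recommended in the docstring above: when the accumulation count changes
# between optimizer steps, set_scale is called to keep the scale in sync.
# Assumes the import path shown below.
import torch
from torch.nn import Linear
from torch.optim import SGD

from fairscale.optim import AdaScale  # assumed import path

model = Linear(2, 2, bias=False)
optim = AdaScale(SGD(model.parameters(), lr=0.1), num_gradients_to_accumulate=2)
for accum in [2, 4]:  # grow accumulation over time
    optim.set_scale(float(accum))
    optim.set_num_gradients_to_accumulate(accum)
    for _ in range(accum):
        model(torch.rand(2)).sum().backward()
    optim.step()
    optim.zero_grad()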
......@@ -9,13 +9,14 @@
""" Test AdaScale with a single node (1 CPU or 1 GPU). """
import gc
import tempfile
import numpy as np
import pytest
import torch
from torch import Tensor
from torch.nn import Linear
from torch.nn import Linear, Sequential
from torch.optim import SGD
from torch.optim.lr_scheduler import LambdaLR
......@@ -41,7 +42,7 @@ def test_loss_accum_cpu():
"""
model = Linear(2, 2, bias=False)
# num_gradients_to_accumulate value doesn't matter in this negative test.
optim = AdaScale(SGD(model.parameters(), lr=0.1), num_gradients_to_accumulate=123)
optim = AdaScale(SGD(model.parameters(), lr=0.1), num_gradients_to_accumulate=3)
# data 1
in_data = Tensor([0.0, 1.0])
loss = model(in_data).sum()
......@@ -53,9 +54,9 @@ def test_loss_accum_cpu():
loss += model(in_data).sum()
# backward, but gradient is only produced once by the autograd engine.
loss.backward()
# therefore, the gain will always be 1, which renders adascale as noop.
optim.step()
# The gain will always be 1, which renders AdaScale a no-op.
assert np.allclose(optim.gain(), 1.0), optim.gain()
# We don't call optim.step(), since it will detect that backward is not yet done.
# IMPORTANT: make sure these test_cases values are sync'ed with the DDP
......@@ -138,7 +139,6 @@ def test_state_checkpointing():
# Run a bit.
def run_a_bit(replay_data=None):
print("running")
data = []
replay_data_idx = 0
for _ in range(6): # run some steps
......@@ -151,8 +151,6 @@ def test_state_checkpointing():
replay_data_idx += 1
out = model(in_data)
out.sum().backward()
# print(out.sum().item())
print(model.weight.grad)
if i == accum_steps - 1:
optim.step()
optim.zero_grad()
......@@ -188,7 +186,7 @@ def test_state_checkpointing():
def test_lr_scheduler():
"""Test AdaScale working with torch.optim.lr_scheduler """
"""Test AdaScale working with torch.optim.lr_scheduler."""
model = Linear(2, 2, bias=False)
optim = AdaScale(SGD(model.parameters(), lr=0.1), num_gradients_to_accumulate=3)
# We use 1, not 0.1 here since scheduler.step() is called here first.
......@@ -201,8 +199,211 @@ def test_lr_scheduler():
loss.backward()
assert optim.gain() <= 3, optim.gain()
optim.step()
optim.zero_grad()
# asserting LR is right
assert np.allclose(optim.param_groups[0]["lr"], 0.1 / 10 ** epoch), optim.param_groups[0]["lr"]
scheduler.step()
# asserting LR is right
assert np.allclose(optim.param_groups[0]["lr"], 0.1 / 10 ** (epoch + 1)), optim.param_groups[0]["lr"]
@skip_if_no_gpu
@pytest.mark.parametrize("debias_ewma", [True, False])
def test_add_param_group(debias_ewma):
"""Test AdaScale supports add_param_group() API."""
model1 = Linear(2, 2, bias=True)
with torch.no_grad():
# make weights and bias deterministic, which is needed for
# multi-layer models. For them, adascale gain is affected by
# parameters from other layers.
model1.weight.copy_(Tensor([1.0, 2.0, 3.0, 4.0]).reshape(2, 2))
model1.bias.fill_(0.1)
optim = AdaScale(SGD(model1.parameters(), lr=0.1), num_gradients_to_accumulate=2, debias_ewma=debias_ewma)
assert len(optim._hook_handles) == 2
model2 = Linear(2, 3, bias=True)
with torch.no_grad():
# make weights and bias deterministic
model2.weight.copy_(Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).reshape(3, 2))
model2.bias.fill_(0.2)
optim.add_param_group({"params": model2.parameters()})
assert len(optim._hook_handles) == 4
# make sure we can run the model.
model = Sequential(model1, model2).cuda()
in_data_0 = Tensor([1.0, 2.0]).cuda()
out = model(in_data_0)
out.sum().backward()
in_data_1 = Tensor([3.0, 4.0]).cuda()
out = model(in_data_1)
out.sum().backward()
# make sure the gains are right and we can step.
# since this is the first step, debias_ewma doesn't affect the value.
assert np.allclose(optim.gain(), 1.1440223454935758), optim.gain()
assert np.allclose(optim.gain(0), 1.1428571428571428), optim.gain(0)
assert np.allclose(optim.gain(1), 1.1471258476157762), optim.gain(1)
optim.step()
optim.zero_grad()
# make sure we can add a PG again after stepping.
model3 = Linear(3, 4, bias=True)
with torch.no_grad():
# make weights and bias deterministic
model3.weight.copy_(Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0] * 2).reshape(4, 3))
model3.bias.fill_(0.2)
optim.add_param_group({"params": model3.parameters()})
assert len(optim._hook_handles) == 6
# make sure we can run the model.
model = Sequential(model1, model2, model3).cuda()
in_data_0 = Tensor([1.0, 2.0]).cuda()
out = model(in_data_0)
out.sum().backward()
in_data_1 = Tensor([3.0, 4.0]).cuda()
out = model(in_data_1)
out.sum().backward()
# make sure gains are right and we can step.
# the last PG's gain is not affected by debias_ewma since it is the first step for that PG.
assert np.allclose(optim.gain(), 1.1191193589460822 if debias_ewma else 1.1192783954732368), optim.gain()
assert np.allclose(optim.gain(0), 1.1428571880897151 if debias_ewma else 1.142857188085096), optim.gain(0)
assert np.allclose(optim.gain(1), 1.1167103578364508 if debias_ewma else 1.1167104954034948), optim.gain(1)
assert np.allclose(optim.gain(2), 1.117381091722702), optim.gain(2)
optim.step()
optim.zero_grad()
@pytest.mark.parametrize(
"test_case",
[
{"new_accum": 3, "exp_gain": 1.2573902104603087},
{"new_accum": 6, "exp_gain": 1.0903738977361481},
{"new_accum": 9, "exp_gain": 1.0432658660558123},
],
)
def test_set_num_gradients_to_accumulate(test_case):
"""Test set_num_gradients_to_accumulate experimental feature."""
new_accum = test_case["new_accum"]
exp_gain = test_case["exp_gain"]
model = Linear(2, 2, bias=False)
optim = AdaScale(SGD(model.parameters(), lr=0.1), num_gradients_to_accumulate=2)
out = model(Tensor([0.0, 1.0]))
out.sum().backward()
out = model(Tensor([1.0, 0.0]))
out.sum().backward()
assert np.allclose(optim.gain(), 2.0)
optim.step()
optim.zero_grad()
optim.set_scale(float(new_accum))
optim.set_num_gradients_to_accumulate(new_accum)
for _ in range(new_accum):
out = model(Tensor([0.0, 1.0]))
out.sum().backward()
assert np.allclose(optim.gain(), exp_gain), optim.gain()
optim.step()
optim.zero_grad()
def test_debias_ewma():
"""Test debias_ewma experimental feature"""
model = Linear(2, 2, bias=False)
optim = AdaScale(SGD(model.parameters(), lr=0.1), num_gradients_to_accumulate=2, debias_ewma=True)
for _ in range(4):
out = model(Tensor([0.0, 1.0]))
out.sum().backward()
out = model(Tensor([1.0, 0.0]))
out.sum().backward()
assert np.allclose(optim.gain(), 2.0), optim.gain()
optim.step()
optim.zero_grad()
def test_gradient_value():
"""Test that we don't mutate the gradients during backward"""
model = Linear(2, 2, bias=False)
optim = AdaScale(SGD(model.parameters(), lr=0.1), num_gradients_to_accumulate=2)
# fwd 1
out = model(Tensor([0.0, 1.0]))
out.sum().backward()
assert np.allclose(model.weight.grad.numpy(), [[0.0, 1.0], [0.0, 1.0]]), model.weight.grad
# fwd 2, grad is accumulated
out = model(Tensor([0.0, 1.0]))
out.sum().backward()
assert np.allclose(model.weight.grad.numpy(), [[0.0, 2.0], [0.0, 2.0]]), model.weight.grad
# assert gain and grad value before/after step/zero_grad
assert np.allclose(optim.gain(), 1.0000002499999376), optim.gain()
optim.step()
assert np.allclose(model.weight.grad.numpy(), [[0.0, 2.0], [0.0, 2.0]]), model.weight.grad
optim.zero_grad()
assert np.allclose(model.weight.grad.numpy(), [[0.0, 0.0], [0.0, 0.0]]), model.weight.grad
@pytest.mark.parametrize(
"test_case",
[
{"scale": None, "exp_gain": 4.0}, # default, baseline is single batch
{"scale": 4.0 / 3, "exp_gain": 4.0 / 3}, # baseline is grad_accum = 3
{"scale": 4.0 / 2, "exp_gain": 2.0}, # baseline is grad_accum = 2
{"scale": 4.0 / 1, "exp_gain": 4.0}, # baseline is single batch
],
)
def test_scale_not_equal_default(test_case):
"""Test gain value when scale doesn't equal world size * grad_accum"""
scale = test_case["scale"]
exp_gain = test_case["exp_gain"]
model = Linear(4, 2, bias=False)
optim = AdaScale(SGD(model.parameters(), lr=0.1), num_gradients_to_accumulate=4, scale=scale)
data = [
[1.0, 0.0, 0.0, 0.0],
[0.0, 1.0, 0.0, 0.0],
[0.0, 0.0, 1.0, 0.0],
[0.0, 0.0, 0.0, 1.0],
]
for i in range(4):
out = model(Tensor(data[i]))
out.sum().backward()
# Since the inputs are perfectly orthogonal, the gain should equal the scale.
assert np.allclose(optim.gain(), exp_gain), optim.gain()
@skip_if_no_gpu
def test_unhook():
"""Test unhook that frees the tensor from CUDA memory."""
model = Linear(123, 456, bias=False).cuda() # unique shape so that it can be found
optim = AdaScale(SGD(model.parameters(), lr=0.1), num_gradients_to_accumulate=2)
def find_tensor():
""" Find the weight tensor from the heap
Return True if found.
"""
for obj in gc.get_objects():
try:
# Only need to check parameter type objects
if "torch.nn.parameter.Parameter" not in str(type(obj)):
continue
if torch.is_tensor(obj) or (hasattr(obj, "data") and torch.is_tensor(obj.data)):
if obj.shape == (456, 123):
return True
except Exception as e:
pass
return False
torch.cuda.empty_cache()
assert find_tensor(), "something is wrong with the gc-based method to find the tensor"
optim.unhook()
del model
del optim
torch.cuda.empty_cache()
assert not find_tensor(), "tensor should have been released"