Unverified commit a0042113 authored by Min Xu, committed by GitHub

[cleanup] mypy adascale (#149)

- close #143
parent 58e97aa6
@@ -11,6 +11,6 @@ try:
from .adam import Adam, Precision
except ImportError: # pragma: no cover
pass # pragma: no cover
from .adascale import AdaScale # type: ignore
from .adascale import AdaScale
from .grad_scaler import GradScaler
from .oss import OSS
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
# Copyright 2020 Petuum, Inc. All Rights Reserved.
#
# Redistribution and use in source and binary forms, with or without
@@ -26,9 +31,8 @@
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.
# type: ignore
import functools
from typing import Any, Dict, Optional
import numpy as np
from torch.autograd import Variable
@@ -67,11 +71,18 @@ class AdaScale(object):
.. _AdaScale: https://proceedings.icml.cc/static/paper_files/icml/2020/4682-Supplemental.pdf
"""
def __init__(self, optimizer, world_size=None, scale=None, smoothing=0.999, patch_optimizer=False):
def __init__(
self,
optimizer: torch.optim.Optimizer,
world_size: Optional[int] = None,
scale: Optional[float] = None,
smoothing: float = 0.999,
patch_optimizer: bool = False,
):
self._optimizer = optimizer
self._optimizer_step = optimizer.step
self._local_grad_sqr = None
self._world_size = world_size if world_size is not None else torch.distributed.get_world_size()
self._local_grad_sqr: Optional[torch.Tensor] = None
self._world_size: int = world_size if world_size is not None else torch.distributed.get_world_size()
if self._world_size <= 1:
raise RuntimeError("AdaScale does not support a single worker.")
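The annotated constructor above keeps the original call pattern; only the type hints change. A minimal usage sketch, assuming a multi-worker torch.distributed setup (the backend, model, and hyper-parameters below are illustrative and not part of this commit):

import torch
import torch.distributed as dist

from fairscale.optim import AdaScale

# Assumed to run under a multi-process launcher with world_size > 1;
# AdaScale raises a RuntimeError when the process group has a single worker.
dist.init_process_group(backend="gloo")

model = torch.nn.Linear(16, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Matches the new signature:
#   AdaScale(optimizer, world_size=None, scale=None, smoothing=0.999, patch_optimizer=False)
# world_size defaults to torch.distributed.get_world_size(); scale=None falls
# back to a default resolved in code that is collapsed out of this hunk.
adascale = AdaScale(optimizer)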
@@ -96,11 +107,11 @@ class AdaScale(object):
self._smoothing = smoothing
@property
def state(self):
def state(self) -> Dict[str, np.ndarray]:
return self._optimizer.state["adascale"]
@property
def scale(self):
def scale(self) -> float:
"""
The scaling factor of the current batch size, relative to the baseline
batch size when training with a single worker. For example, if the
@@ -109,7 +120,7 @@
"""
return self._scale
def set_scale(self, scale):
def set_scale(self, scale: float) -> None:
"""
Set the scaling factor of the current batch size. It is up to the
application to invoke this function to make sure that AdaScale's
@@ -120,7 +131,7 @@
"""
self._scale = scale
def grad_sqr_avg(self):
def grad_sqr_avg(self) -> float:
"""
Current estimate of the squared l2-norm of the true gradient (sigma
squared in the AdaScale paper).
@@ -129,7 +140,7 @@
"""
return np.sum(self.state["grad_sqr_avg"])
def grad_var_avg(self):
def grad_var_avg(self) -> float:
"""
Current estimate of the trace of the covariance of the true gradient
(mu squared in the AdaScale paper).
@@ -138,7 +149,7 @@
"""
return np.sum(self.state["grad_var_avg"])
def gain(self, scale=None):
def gain(self, scale: Optional[float] = None) -> float:
"""
Current estimate of the AdaScale gain ratio (r_t).
@@ -152,7 +163,7 @@
sqr = self.grad_sqr_avg()
return (var + sqr) / (var / scale + sqr)
def _update_avg(self, name, value, factor):
def _update_avg(self, name: str, value: float, factor: float) -> None:
biased = self.state.get(name + "_biased", 0.0)
unbias = self.state.get(name + "_unbias", 0.0)
biased = factor * biased + (1.0 - factor) * value
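The gain returned above is a single closed-form expression; a small numeric sanity check with arbitrary, made-up values (not taken from any real run):

# gain(scale) = (var + sqr) / (var / scale + sqr), where
#   sqr = grad_sqr_avg()  -- estimate of the squared l2-norm of the true gradient (sigma squared)
#   var = grad_var_avg()  -- estimate of the trace of the gradient covariance (mu squared)
var, sqr, scale = 4.0, 1.0, 8.0
gain = (var + sqr) / (var / scale + sqr)   # (4 + 1) / (0.5 + 1) = 10/3 ≈ 3.33
# With var == 0 the gain is 1.0; when var dominates sqr it approaches `scale`,
# so the learning-rate multiplier applied in step() stays between 1 and `scale`.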
@@ -161,7 +172,7 @@
self.state[name + "_unbias"] = unbias
self.state[name] = biased / unbias
def _backward_hook(self, idx, grad):
def _backward_hook(self, idx: int, grad: torch.Tensor) -> None:
# This method should be invoked once for each parameter during the
# backward pass, before gradients are synchronized across workers.
if self._local_grad_sqr is None:
@@ -170,7 +181,7 @@
self._final_callback_queued = False
Variable._execution_engine.queue_callback(self._queue_callback)
def _queue_callback(self):
def _queue_callback(self) -> None:
# This method should be invoked after the entire backward pass. We want
# to make sure self._final_callback is invoked once, only after all
# gradients have been synchronized between each worker. However, the
@@ -183,10 +194,11 @@
self._final_callback_queued = True
Variable._execution_engine.queue_callback(self._final_callback)
def _final_callback(self):
def _final_callback(self) -> None:
# This method should be invoked once for each backward pass, after
# gradients have been synchronized between each worker.
self._final_callback_queued = False
assert isinstance(self._local_grad_sqr, torch.Tensor)
torch.distributed.all_reduce(self._local_grad_sqr / self._world_size)
local_grad_sqr = self._local_grad_sqr.cpu().numpy()
total_grad_sqr = np.array(
@@ -201,7 +213,7 @@
self._update_avg("grad_var_avg", grad_var, theta)
self._local_grad_sqr = None
def step(self, *args, **kwargs):
def step(self, *args: Any, **kwargs: Any) -> Optional[float]:
"""
Run one optimizer step using AdaScale. Essentially just invokes
``optimizer.step(*args, **kwargs)`` with a scaled learning rate.
@@ -216,17 +228,18 @@
grad_var = float(self.state["grad_var_avg"][idx])
gain = (grad_var + grad_sqr) / (grad_var / self._scale + grad_sqr)
param_group["lr"] = gain * param_group["lr"]
self._optimizer_step(*args, **kwargs)
res = self._optimizer_step(*args, **kwargs)
for lr, param_group in zip(initial_lr, self._optimizer.param_groups):
param_group["lr"] = lr
return res
def patch_optimizer(self):
def patch_optimizer(self) -> None:
"""
Monkey-patch the optimizer's step function with :meth:`AdaScale.step`.
"""
@functools.wraps(self._optimizer.step)
def wrapper(*args, **kwargs):
def wrapper(*args: Any, **kwargs: Any) -> Optional[float]:
return self.step(*args, **kwargs)
self._optimizer.step = wrapper
setattr(self._optimizer, "step", wrapper)
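patch_optimizer lets an existing training loop keep calling optimizer.step() unchanged. A sketch of the intended flow, reusing the illustrative names from the constructor sketch above (the loop body and loss are made up):

adascale = AdaScale(optimizer)
adascale.patch_optimizer()        # optimizer.step now routes through AdaScale.step
# (The patch_optimizer=True constructor argument presumably triggers the same
#  patching, but that branch is collapsed out of the hunks shown here.)

for inputs, targets in loader:    # `loader` is illustrative
    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(model(inputs), targets)
    loss.backward()               # AdaScale's per-parameter backward hooks and
                                  # final callback run during this pass
    optimizer.step()              # scales each param group's lr by the gain,
                                  # steps the wrapped optimizer, then restores lr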
@@ -6,9 +6,13 @@ from .grad_mode import no_grad as no_grad, enable_grad as enable_grad, \
set_grad_enabled as set_grad_enabled
from .profiler import record_function
# This is defined in CPP in PyTorch source
class ImperativeEngine:
def queue_callback(self, callback: Callable[..., None]): ...
# TODO make Variable and Function more precise
class Variable:
...
_execution_engine: ImperativeEngine
class Function:
@staticmethod
......
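The stub addition above exists so that adascale.py's calls to Variable._execution_engine.queue_callback type-check without the removed # type: ignore. Roughly, it covers call sites of this shape (a sketch only; as in AdaScale's hooks, the call is meant to happen during a backward pass):

import torch
from torch.autograd import Variable

def _after_backward() -> None:
    # stand-in for AdaScale._queue_callback / AdaScale._final_callback
    ...

def _backward_hook(grad: torch.Tensor) -> None:
    # The stub types queue_callback's parameter as Callable[..., None], so
    # passing a zero-argument function returning None now satisfies mypy.
    Variable._execution_engine.queue_callback(_after_backward)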