Unverified commit a0042113 authored by Min Xu, committed by GitHub

[cleanup] mypy adascale (#149)

- close #143
parent 58e97aa6
@@ -11,6 +11,6 @@ try:
     from .adam import Adam, Precision
 except ImportError:  # pragma: no cover
     pass  # pragma: no cover
-from .adascale import AdaScale  # type: ignore
+from .adascale import AdaScale
 from .grad_scaler import GradScaler
 from .oss import OSS
+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+#
+# This source code is licensed under the BSD license found in the
+# LICENSE file in the root directory of this source tree.
+
 # Copyright 2020 Petuum, Inc. All Rights Reserved.
 #
 # Redistribution and use in source and binary forms, with or without
@@ -26,9 +31,8 @@
 # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 # POSSIBILITY OF SUCH DAMAGE.
-# type: ignore
 import functools
+from typing import Any, Dict, Optional

 import numpy as np
 from torch.autograd import Variable
@@ -67,11 +71,18 @@ class AdaScale(object):
     .. _AdaScale: https://proceedings.icml.cc/static/paper_files/icml/2020/4682-Supplemental.pdf
     """

-    def __init__(self, optimizer, world_size=None, scale=None, smoothing=0.999, patch_optimizer=False):
+    def __init__(
+        self,
+        optimizer: torch.optim.Optimizer,
+        world_size: Optional[int] = None,
+        scale: Optional[float] = None,
+        smoothing: float = 0.999,
+        patch_optimizer: bool = False,
+    ):
         self._optimizer = optimizer
         self._optimizer_step = optimizer.step
-        self._local_grad_sqr = None
-        self._world_size = world_size if world_size is not None else torch.distributed.get_world_size()
+        self._local_grad_sqr: Optional[torch.Tensor] = None
+        self._world_size: int = world_size if world_size is not None else torch.distributed.get_world_size()
         if self._world_size <= 1:
             raise RuntimeError("AdaScale does not support a single worker.")
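For reference, a minimal construction sketch matching the newly annotated signature. This is a hypothetical setup: the SGD optimizer and its parameters are placeholders, and the import assumes the re-export from the first hunk lives in fairscale.optim. AdaScale refuses to run with a world size of 1.

import torch
from fairscale.optim import AdaScale

# Placeholder model parameters; any torch.optim.Optimizer works.
optimizer = torch.optim.SGD(torch.nn.Linear(4, 2).parameters(), lr=0.1)

adascale = AdaScale(
    optimizer,
    world_size=2,           # normally left as None to read torch.distributed.get_world_size()
    scale=None,             # batch-size scale relative to the single-worker baseline
    smoothing=0.999,        # smoothing factor for the running gradient statistics
    patch_optimizer=False,  # leave optimizer.step() unpatched; call adascale.step() explicitly
)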
@@ -96,11 +107,11 @@ class AdaScale(object):
         self._smoothing = smoothing

     @property
-    def state(self):
+    def state(self) -> Dict[str, np.ndarray]:
         return self._optimizer.state["adascale"]

     @property
-    def scale(self):
+    def scale(self) -> float:
         """
         The scaling factor of the current batch size, relative to the baseline
         batch size when training with a single worker. For example, if the
@@ -109,7 +120,7 @@ class AdaScale(object):
         """
         return self._scale

-    def set_scale(self, scale):
+    def set_scale(self, scale: float) -> None:
         """
         Set the scaling factor of the current batch size. It is up to the
         application to invoke this function to make sure that AdaScale's
@@ -120,7 +131,7 @@ class AdaScale(object):
         """
         self._scale = scale

-    def grad_sqr_avg(self):
+    def grad_sqr_avg(self) -> float:
         """
         Current estimate of the squared l2-norm of the true gradient (sigma
         squared in the AdaScale paper).
@@ -129,7 +140,7 @@ class AdaScale(object):
         """
         return np.sum(self.state["grad_sqr_avg"])

-    def grad_var_avg(self):
+    def grad_var_avg(self) -> float:
         """
         Current estimate of the trace of the covariance of the true gradient
         (mu squared in the AdaScale paper).
@@ -138,7 +149,7 @@ class AdaScale(object):
         """
         return np.sum(self.state["grad_var_avg"])

-    def gain(self, scale=None):
+    def gain(self, scale: Optional[float] = None) -> float:
         """
         Current estimate of the AdaScale gain ratio (r_t).
@@ -152,7 +163,7 @@ class AdaScale(object):
         sqr = self.grad_sqr_avg()
         return (var + sqr) / (var / scale + sqr)

-    def _update_avg(self, name, value, factor):
+    def _update_avg(self, name: str, value: float, factor: float) -> None:
         biased = self.state.get(name + "_biased", 0.0)
         unbias = self.state.get(name + "_unbias", 0.0)
         biased = factor * biased + (1.0 - factor) * value
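The gain() hunk above computes r_t = (var + sqr) / (var / scale + sqr), where var is grad_var_avg() and sqr is grad_sqr_avg(). A small worked example with made-up estimates shows how the gain interpolates between 1 and the scale (illustration only, not values from a real run):

# Illustrative numbers only.
grad_var = 4.0   # grad_var_avg(): trace of the gradient covariance (mu squared)
grad_sqr = 1.0   # grad_sqr_avg(): squared l2-norm of the true gradient (sigma squared)
scale = 8.0      # batch size is 8x the single-worker baseline

gain = (grad_var + grad_sqr) / (grad_var / scale + grad_sqr)
print(gain)  # 5.0 / 1.5 = 3.33...; step() multiplies each param group's lr by this

# Limiting behaviour: noisy gradients (grad_var >> grad_sqr) push the gain
# toward `scale`; near-deterministic gradients (grad_var -> 0) push it toward 1.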
@@ -161,7 +172,7 @@ class AdaScale(object):
         self.state[name + "_unbias"] = unbias
         self.state[name] = biased / unbias

-    def _backward_hook(self, idx, grad):
+    def _backward_hook(self, idx: int, grad: torch.Tensor) -> None:
         # This method should be invoked once for each parameter during the
         # backward pass, before gradients are synchronized between world_size.
         if self._local_grad_sqr is None:
@@ -170,7 +181,7 @@ class AdaScale(object):
         self._final_callback_queued = False
         Variable._execution_engine.queue_callback(self._queue_callback)

-    def _queue_callback(self):
+    def _queue_callback(self) -> None:
         # This method should be invoked after the entire backward pass. We want
         # to make sure self._final_callback is invoked once, only after all
         # gradients have been synchronized between each worker. However, the
@@ -183,10 +194,11 @@ class AdaScale(object):
         self._final_callback_queued = True
         Variable._execution_engine.queue_callback(self._final_callback)

-    def _final_callback(self):
+    def _final_callback(self) -> None:
         # This method should be invoked once for each backward pass, after
         # gradients have been synchronized between each worker.
         self._final_callback_queued = False
+        assert isinstance(self._local_grad_sqr, torch.Tensor)
         torch.distributed.all_reduce(self._local_grad_sqr / self._world_size)
         local_grad_sqr = self._local_grad_sqr.cpu().numpy()
         total_grad_sqr = np.array(
@@ -201,7 +213,7 @@ class AdaScale(object):
         self._update_avg("grad_var_avg", grad_var, theta)
         self._local_grad_sqr = None

-    def step(self, *args, **kwargs):
+    def step(self, *args: Any, **kwargs: Any) -> Optional[float]:
         """
         Run one optimizer step using Adascale. Essentially just invokes
         ``optimizer.step(*args, **kwargs)`` with a scaled learning rate.
@@ -216,17 +228,18 @@ class AdaScale(object):
             grad_var = float(self.state["grad_var_avg"][idx])
             gain = (grad_var + grad_sqr) / (grad_var / self._scale + grad_sqr)
             param_group["lr"] = gain * param_group["lr"]
-        self._optimizer_step(*args, **kwargs)
+        res = self._optimizer_step(*args, **kwargs)
         for lr, param_group in zip(initial_lr, self._optimizer.param_groups):
             param_group["lr"] = lr
+        return res

-    def patch_optimizer(self):
+    def patch_optimizer(self) -> None:
         """
         Monkey-patch the optimizer's step function with :meth:`AdaScale.step`.
         """

         @functools.wraps(self._optimizer.step)
-        def wrapper(*args, **kwargs):
+        def wrapper(*args: Any, **kwargs: Any) -> Optional[float]:
             return self.step(*args, **kwargs)

-        self._optimizer.step = wrapper
+        setattr(self._optimizer, "step", wrapper)
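With the annotations in place, a typical loop looks like the sketch below. This is schematic, not a definitive recipe: the model, data, and process-group setup are placeholders, torch.distributed must already be initialized with more than one worker (e.g. via torch.distributed.launch), and it assumes that patch_optimizer=True routes optimizer.step() through the patched wrapper shown above.

import torch
import torch.nn as nn
from fairscale.optim import AdaScale

# Assumes torch.distributed.init_process_group() has already run with world_size > 1.
model = nn.parallel.DistributedDataParallel(nn.Linear(16, 2))
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
adascale = AdaScale(optimizer, patch_optimizer=True)

inputs = torch.randn(32, 16)   # placeholder batch
targets = torch.randn(32, 2)
for _ in range(10):
    optimizer.zero_grad()
    loss = nn.functional.mse_loss(model(inputs), targets)
    loss.backward()    # per-parameter hooks accumulate the local gradient statistics
    optimizer.step()   # patched wrapper: scales each lr by the gain, steps, then restores lr
    # adascale.gain() can be logged here to monitor the effective speedup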
@@ -6,9 +6,13 @@ from .grad_mode import no_grad as no_grad, enable_grad as enable_grad, \
     set_grad_enabled as set_grad_enabled
 from .profiler import record_function

+# This is defined in CPP in PyTorch source
+class ImperativeEngine:
+    def queue_callback(self, callback: Callable[..., None]): ...
+
 # TODO make Variable and Function more precise
 class Variable:
-    ...
+    _execution_engine: ImperativeEngine

 class Function:
     @staticmethod