Unverified Commit a65fc83e authored by Min Xu, committed by GitHub

[fix] fixing circleCI for AdaScale (#142)

* [fix] fixing circleCI for AdaScale

- ran black, isort, flake8, mypy

* more fix
parent 64d1e312
@@ -58,7 +58,7 @@ flake8
### Static analysis
```
-mypy .
+mypy --ignore-missing-imports --scripts-are-modules --pretty .
```
### Unit tests
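For CI debugging it can help to drive the same check from Python rather than the shell. Below is a minimal sketch using mypy's `mypy.api` entry point with the flags from the new CONTRIBUTING.md line above; the helper script itself is illustrative and not part of this commit.

```python
# Hypothetical CI helper; mirrors the mypy flags added above.
from mypy import api

# api.run returns (normal_report, error_report, exit_status).
stdout, stderr, exit_status = api.run(
    ["--ignore-missing-imports", "--scripts-are-modules", "--pretty", "."]
)
print(stdout, end="")
raise SystemExit(exit_status)
```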
@@ -11,6 +11,6 @@ try:
    from .adam import Adam, Precision
except ImportError:  # pragma: no cover
    pass  # pragma: no cover
-from .adascale import AdaScale
+from .adascale import AdaScale  # type: ignore
from .grad_scaler import GradScaler
from .oss import OSS
# Copyright 2020 Petuum, Inc. All Rights Reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of Petuum, Inc. nor the names of its contributors may be
# used to endorse or promote products derived from this software without
# specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
@@ -26,12 +26,13 @@
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.
+# type: ignore
import functools
import numpy as np
-import torch.distributed
from torch.autograd import Variable
+import torch.distributed
class AdaScale(object):
@@ -65,28 +66,29 @@ class AdaScale(object):
    .. _AdaScale: https://proceedings.icml.cc/static/paper_files/icml/2020/4682-Supplemental.pdf
    """

-    def __init__(self, optimizer, world_size=None, scale=None, smoothing=0.999,
-                 patch_optimizer=False):
+    def __init__(self, optimizer, world_size=None, scale=None, smoothing=0.999, patch_optimizer=False):
        self._optimizer = optimizer
        self._optimizer_step = optimizer.step
        self._local_grad_sqr = None
-        self._world_size = (world_size if world_size is not None
-                            else torch.distributed.get_world_size())
+        self._world_size = world_size if world_size is not None else torch.distributed.get_world_size()
        if self._world_size <= 1:
            raise RuntimeError("AdaScale does not support a single worker.")
-        self._optimizer.state.setdefault("adascale", {
-            "grad_sqr_avg": np.ones(len(optimizer.param_groups)),
-            "grad_var_avg": np.zeros(len(optimizer.param_groups)),
-        })
+        self._optimizer.state.setdefault(
+            "adascale",
+            {
+                "grad_sqr_avg": np.ones(len(optimizer.param_groups)),
+                "grad_var_avg": np.zeros(len(optimizer.param_groups)),
+            },
+        )
        self.set_scale(self._world_size if scale is None else scale)
        for idx, param_group in enumerate(self._optimizer.param_groups):
            for param in param_group["params"]:
-                param.register_hook(
-                    functools.partial(self._backward_hook, idx))
+                param.register_hook(functools.partial(self._backward_hook, idx))
        if patch_optimizer:
            self.patch_optimizer()
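As a reading aid for the constructor above, here is a minimal sketch of attaching AdaScale to a stock optimizer. It assumes fairscale is importable and that a `torch.distributed` process group with more than one worker is already initialized (otherwise the `RuntimeError` path above fires); the helper name `build_optimizer` is illustrative.

```python
import torch
import torch.distributed as dist

from fairscale.optim import AdaScale


def build_optimizer(model: torch.nn.Module) -> AdaScale:
    # Requires an initialized process group with world_size > 1; a single
    # worker trips the RuntimeError added in __init__ above.
    assert dist.is_initialized() and dist.get_world_size() > 1
    base = torch.optim.SGD(model.parameters(), lr=0.1)
    # patch_optimizer=True presumably routes base.step() through AdaScale
    # (the self.patch_optimizer() branch above); scale defaults to world_size.
    return AdaScale(base, smoothing=0.999, patch_optimizer=True)
```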
@@ -163,8 +165,7 @@ class AdaScale(object):
        # This method should be invoked once for each parameter during the
        # backward pass, before gradients are synchronized between world_size.
        if self._local_grad_sqr is None:
-            self._local_grad_sqr = torch.zeros(
-                len(self._optimizer.param_groups), device=grad.device)
+            self._local_grad_sqr = torch.zeros(len(self._optimizer.param_groups), device=grad.device)
        self._local_grad_sqr[idx] += grad.pow(2).sum()
        self._final_callback_queued = False
        Variable._execution_engine.queue_callback(self._queue_callback)
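The hunk above is easier to follow with the autograd pattern pulled out on its own: a per-parameter `register_hook` accumulates squared gradient norms, and `Variable._execution_engine.queue_callback` schedules a single callback for the end of the backward pass. This is a standalone sketch, not fairscale code; the guard flag mirrors `_final_callback_queued`.

```python
import torch
from torch.autograd import Variable

model = torch.nn.Linear(4, 2)
local_grad_sqr = torch.zeros(1)
callback_queued = False


def end_of_backward():
    # Runs once after the whole backward pass has finished.
    global callback_queued
    callback_queued = False
    print("sum of squared grads this backward:", local_grad_sqr.item())
    local_grad_sqr.zero_()


def per_param_hook(grad):
    # Called once per parameter as its gradient is computed.
    global callback_queued
    local_grad_sqr[0] += grad.pow(2).sum()
    if not callback_queued:  # queue the end-of-backward callback only once
        Variable._execution_engine.queue_callback(end_of_backward)
        callback_queued = True


for param in model.parameters():
    param.register_hook(per_param_hook)

loss = model(torch.randn(8, 4)).sum()
loss.backward()  # hooks fire per parameter, then end_of_backward runs once
```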
@@ -188,18 +189,16 @@ class AdaScale(object):
        self._final_callback_queued = False
        torch.distributed.all_reduce(self._local_grad_sqr / self._world_size)
        local_grad_sqr = self._local_grad_sqr.cpu().numpy()
-        total_grad_sqr = np.array([sum(param.grad.pow(2).sum().item()
-                                       for param in group["params"])
-                                   for group in self._optimizer.param_groups])
-        grad_sqr = ((self._world_size * total_grad_sqr - local_grad_sqr)
-                    / (self._world_size - 1))
-        grad_var = ((local_grad_sqr - total_grad_sqr) * self._scale
-                    / (self._world_size - 1))
+        total_grad_sqr = np.array(
+            [sum(param.grad.pow(2).sum().item() for param in group["params"]) for group in self._optimizer.param_groups]
+        )
+        grad_sqr = (self._world_size * total_grad_sqr - local_grad_sqr) / (self._world_size - 1)
+        grad_var = (local_grad_sqr - total_grad_sqr) * self._scale / (self._world_size - 1)
        grad_sqr = np.maximum(grad_sqr, 0.0)
        grad_var = np.maximum(grad_var, 1e-6)
        theta = self._smoothing ** self._scale
-        self._update_avg('grad_sqr_avg', grad_sqr, theta)
-        self._update_avg('grad_var_avg', grad_var, theta)
+        self._update_avg("grad_sqr_avg", grad_sqr, theta)
+        self._update_avg("grad_var_avg", grad_var, theta)
        self._local_grad_sqr = None

    def step(self, *args, **kwargs):
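A worked NumPy version of the two estimators in the hunk above may help. `S` stands in for `self._world_size` and `scale` for `self._scale`; the per-step numbers are made up for illustration, and the final `gain` line is an assumption based on the linked AdaScale paper rather than code from this diff.

```python
import numpy as np

S, scale, smoothing = 4, 4.0, 0.999
local_grad_sqr = np.array([1.30])   # mean over workers of per-worker |g_i|^2
total_grad_sqr = np.array([1.10])   # |mean_i g_i|^2 of the synchronized gradient

# Unbiased estimates, exactly as in the hunk above:
grad_sqr = (S * total_grad_sqr - local_grad_sqr) / (S - 1)       # ~1.033
grad_var = (local_grad_sqr - total_grad_sqr) * scale / (S - 1)   # ~0.267
grad_sqr = np.maximum(grad_sqr, 0.0)
grad_var = np.maximum(grad_var, 1e-6)
theta = smoothing ** scale  # smoothing factor for the running averages

# Gain ratio from the AdaScale paper (assumption, not part of this hunk):
gain = (grad_var + grad_sqr) / (grad_var / scale + grad_sqr)
print(grad_sqr, grad_var, round(float(gain), 3))
```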