Unverified Commit a65fc83e authored by Min Xu, committed by GitHub

[fix] fixing circleCI for AdaScale (#142)

* [fix] fixing circleCI for AdaScale

- ran black, isort, flake8, mypy

* more fix
parent 64d1e312
@@ -58,7 +58,7 @@ flake8
 ### Static analysis
 ```
-mypy .
+mypy --ignore-missing-imports --scripts-are-modules --pretty .
 ```
 ### Unit tests
...
@@ -11,6 +11,6 @@ try:
     from .adam import Adam, Precision
 except ImportError:  # pragma: no cover
     pass  # pragma: no cover
-from .adascale import AdaScale
+from .adascale import AdaScale  # type: ignore
 from .grad_scaler import GradScaler
 from .oss import OSS
@@ -26,12 +26,13 @@
 # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 # POSSIBILITY OF SUCH DAMAGE.
+# type: ignore
 import functools
 import numpy as np
-import torch.distributed
 from torch.autograd import Variable
+import torch.distributed
 class AdaScale(object):
@@ -65,28 +66,29 @@ class AdaScale(object):
     .. _AdaScale: https://proceedings.icml.cc/static/paper_files/icml/2020/4682-Supplemental.pdf
     """
-    def __init__(self, optimizer, world_size=None, scale=None, smoothing=0.999,
-                 patch_optimizer=False):
+    def __init__(self, optimizer, world_size=None, scale=None, smoothing=0.999, patch_optimizer=False):
         self._optimizer = optimizer
         self._optimizer_step = optimizer.step
         self._local_grad_sqr = None
-        self._world_size = (world_size if world_size is not None
-                            else torch.distributed.get_world_size())
+        self._world_size = world_size if world_size is not None else torch.distributed.get_world_size()
         if self._world_size <= 1:
             raise RuntimeError("AdaScale does not support a single worker.")
-        self._optimizer.state.setdefault("adascale", {
-            "grad_sqr_avg": np.ones(len(optimizer.param_groups)),
-            "grad_var_avg": np.zeros(len(optimizer.param_groups)),
-        })
+        self._optimizer.state.setdefault(
+            "adascale",
+            {
+                "grad_sqr_avg": np.ones(len(optimizer.param_groups)),
+                "grad_var_avg": np.zeros(len(optimizer.param_groups)),
+            },
+        )
         self.set_scale(self._world_size if scale is None else scale)
         for idx, param_group in enumerate(self._optimizer.param_groups):
             for param in param_group["params"]:
-                param.register_hook(
-                    functools.partial(self._backward_hook, idx))
+                param.register_hook(functools.partial(self._backward_hook, idx))
         if patch_optimizer:
             self.patch_optimizer()
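For orientation, the constructor changed above is the public entry point of the wrapper. A minimal usage sketch, not part of this commit: the model, learning rate, and process-group setup are illustrative assumptions, and the import path assumes the `__init__.py` hunk above lives in `fairscale/optim/__init__.py`.

```python
# Hedged sketch: wrapping a plain SGD optimizer with AdaScale.
# Assumes torch.distributed.init_process_group(...) has already run with
# world_size > 1, since the constructor raises RuntimeError for one worker.
import torch
import torch.distributed
from fairscale.optim import AdaScale

model = torch.nn.Linear(16, 2)
base_opt = torch.optim.SGD(model.parameters(), lr=0.1)

# scale defaults to the world size; smoothing drives the moving averages.
adascale = AdaScale(base_opt, smoothing=0.999)

# Alternatively, let AdaScale patch base_opt.step() in place:
# adascale = AdaScale(base_opt, patch_optimizer=True)
```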
@@ -163,8 +165,7 @@ class AdaScale(object):
         # This method should be invoked once for each parameter during the
         # backward pass, before gradients are synchronized between world_size.
         if self._local_grad_sqr is None:
-            self._local_grad_sqr = torch.zeros(
-                len(self._optimizer.param_groups), device=grad.device)
+            self._local_grad_sqr = torch.zeros(len(self._optimizer.param_groups), device=grad.device)
         self._local_grad_sqr[idx] += grad.pow(2).sum()
         self._final_callback_queued = False
         Variable._execution_engine.queue_callback(self._queue_callback)
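The hook above leans on two PyTorch mechanisms: per-tensor backward hooks and the (private) autograd engine callback queue, which runs a function once the whole backward pass has finished. A self-contained sketch of that pattern, with illustrative names that are not part of the diff:

```python
import torch
from torch.autograd import Variable

# Accumulator the hook updates during the backward pass (illustrative).
local_sqr = torch.zeros(1)

def final_callback():
    # Runs once, after the entire backward pass completes.
    print("sum of squared grads:", local_sqr.item())

def hook(grad):
    # Called with this parameter's gradient as soon as it is produced.
    local_sqr[0] += grad.pow(2).sum()
    # Private API; used the same way in the hunk above.
    Variable._execution_engine.queue_callback(final_callback)

p = torch.nn.Parameter(torch.randn(4))
p.register_hook(hook)
(p * 2.0).sum().backward()  # prints the accumulated value at the end
```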
@@ -188,18 +189,16 @@ class AdaScale(object):
         self._final_callback_queued = False
         torch.distributed.all_reduce(self._local_grad_sqr / self._world_size)
         local_grad_sqr = self._local_grad_sqr.cpu().numpy()
-        total_grad_sqr = np.array([sum(param.grad.pow(2).sum().item()
-                                       for param in group["params"])
-                                   for group in self._optimizer.param_groups])
-        grad_sqr = ((self._world_size * total_grad_sqr - local_grad_sqr)
-                    / (self._world_size - 1))
-        grad_var = ((local_grad_sqr - total_grad_sqr) * self._scale
-                    / (self._world_size - 1))
+        total_grad_sqr = np.array(
+            [sum(param.grad.pow(2).sum().item() for param in group["params"]) for group in self._optimizer.param_groups]
+        )
+        grad_sqr = (self._world_size * total_grad_sqr - local_grad_sqr) / (self._world_size - 1)
+        grad_var = (local_grad_sqr - total_grad_sqr) * self._scale / (self._world_size - 1)
         grad_sqr = np.maximum(grad_sqr, 0.0)
         grad_var = np.maximum(grad_var, 1e-6)
         theta = self._smoothing ** self._scale
-        self._update_avg('grad_sqr_avg', grad_sqr, theta)
-        self._update_avg('grad_var_avg', grad_var, theta)
+        self._update_avg("grad_sqr_avg", grad_sqr, theta)
+        self._update_avg("grad_var_avg", grad_var, theta)
         self._local_grad_sqr = None

     def step(self, *args, **kwargs):
...
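The statistics in the last hunk are what later drive the AdaScale gain; the gain computation and `_update_avg` themselves are not shown in this diff. As a rough, hedged sketch following the AdaScale paper (everything below is an assumption, not code from this commit): `grad_sqr_avg` tracks the squared gradient mean, `grad_var_avg` the gradient variance, both smoothed with a `theta`-weighted moving average and then combined into a gain between 1 and the scale S.

```python
import numpy as np

def update_avg(state: dict, name: str, value: np.ndarray, theta: float) -> None:
    # Plausible form of _update_avg: an exponential moving average with
    # weight theta = smoothing ** scale, as computed in the hunk above.
    state[name] = theta * state[name] + (1.0 - theta) * value

def gain(grad_var_avg: np.ndarray, grad_sqr_avg: np.ndarray, scale: float) -> np.ndarray:
    # Sketch of the AdaScale gain r_S = (sigma^2 + mu^2) / (sigma^2 / S + mu^2):
    # it equals 1 when the gradient is noise-free (sigma^2 = 0) and approaches
    # S when the gradient is dominated by noise (mu^2 -> 0).
    return (grad_var_avg + grad_sqr_avg) / (grad_var_avg / scale + grad_sqr_avg)
```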