Commit 614b11ff authored by Carl Case

support multi-loss scaling per-optimizer correctly

parent 8be1b053
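
For orientation, a rough sketch of the training-loop pattern this change targets: several losses feeding one optimizer, each scaled, backpropagated, and unscaled independently. The `scale_loss` context-manager usage below is an assumption inferred from the diff (the wrapper yields `loss * loss_scale` and tracks `self._loss_idx` against `self._num_loss`); how the wrapper is constructed is outside this commit.

    # Hypothetical usage sketch; names are illustrative only.
    # `optimizer` is assumed to be an OptimWrapper configured for two losses.
    for inputs, targets in loader:
        optimizer.zero_grad()  # assumed to forward to the wrapped optimizer
        loss_a, loss_b = model(inputs, targets)

        # Each loss gets its own LossScaler; the wrapper caches and restores
        # accumulated grads so each loss can be unscaled in isolation.
        with optimizer.scale_loss(loss_a) as scaled_a:
            scaled_a.backward()
        with optimizer.scale_loss(loss_b) as scaled_b:
            scaled_b.backward()

        optimizer.step()  # internally skipped this iteration if any loss overflowed
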
 import contextlib
+import logging
 import warnings
-from .scaler import LossScaler
+from .scaler import LossScaler, iter_params
 import numpy as np

@@ -28,27 +29,33 @@ class OptimWrapper(object):
             loss_backward()
         loss.backward = warning_wrapper
-        # if loss_idx > 0:
-        #   save out current grads to buffers
-        #   keep some group caches
-        #   .detach().clone()
-        #   zero grads
+        # When there are multiple losses per-optimizer, we need
+        # to save out the current grad accumulation, since we won't be
+        # able to unscale this particular loss once the grads are
+        # all mixed together.
+        cached_grads = []
+        if self._loss_idx > 0:
+            for p in iter_params(self._optimizer.param_groups):
+                if p.grad is not None:
+                    cached_grads.append(p.grad.data.detach().clone())
+                else:
+                    cached_grads.append(None)
+            self._optimizer.zero_grad()
         loss_scale = self._cur_loss_scaler().loss_scale()
-        print('Loss scale (log): {}'.format(np.log2(loss_scale)))
         yield loss * loss_scale
         loss.backward = loss_backward
         self._skip_next[self._loss_idx] = self._cur_loss_scaler().unscale_and_update(
             self._optimizer.param_groups, loss_scale)
-        print('GOT SKIP NEXT: {}'.format(self._skip_next[self._loss_idx]))
         self._loss_idx += 1
-        # if loss_idx > 0:
-        #   += saved state into grads
+        if len(cached_grads) > 0:
+            for p, cached_grad in zip(iter_params(self._optimizer.param_groups),
+                                      cached_grads):
+                if cached_grad is not None:
+                    p.grad.data.add_(cached_grad)
+            cached_grads = []

     def _cur_loss_scaler(self):
         assert 0 <= self._loss_idx < self._num_loss
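
To make the caching above concrete: once a second scaled loss is backpropagated, its gradient is summed into `p.grad` at a different scale, and no single division can recover either contribution. A toy illustration in plain PyTorch (not part of the commit):

    import torch

    p = torch.tensor(1.0, requires_grad=True)
    scale_a, scale_b = 2.0 ** 16, 2.0 ** 8

    (p * 3.0 * scale_a).backward()
    grad_a = p.grad.item() / scale_a   # 3.0: unscaling works while this grad is isolated

    (p * 5.0 * scale_b).backward()     # without caching, this mixes into the same buffer
    mixed = p.grad.item()              # 3.0 * scale_a + 5.0 * scale_b
    # Neither mixed / scale_a nor mixed / scale_b is a true gradient, which is why
    # the diff stashes the accumulated grads, zeroes them, unscales the fresh loss
    # alone, and only then adds the stash back in.
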
@@ -69,8 +76,8 @@ class OptimWrapper(object):
                 'The `closure` argument is unsupported by the amp ' +
                 'optimizer wrapper.')
         if any(self._skip_next):
+            logging.info('Gradient overflow, skipping update')
             self._skip_next = [False] * self._num_loss
-            print('SKIP')
         else:
             return self._optimizer.step(closure=closure)
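
A side note on the switch from print to logging above: the overflow message is only emitted if the application's logging configuration allows INFO-level records, for example:

    import logging
    logging.basicConfig(level=logging.INFO)  # makes 'Gradient overflow, skipping update' visible
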
@@ -15,12 +15,11 @@ class LossScaler(object):
     def unscale_and_update(self, param_groups, scale):
         self._overflow_buf.zero_()
-        for group in param_groups:
-            for p in group['params']:
-                if p.grad is not None:
-                    scale_lib.scale_check_overflow(p.grad.data,
-                                                   1. / scale,
-                                                   self._overflow_buf)
+        for p in iter_params(param_groups):
+            if p.grad is not None:
+                scale_lib.scale_check_overflow(p.grad.data,
+                                               1. / scale,
+                                               self._overflow_buf)

         if self._overflow_buf.any():
             should_skip = True
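
Judging only from this call site, `scale_check_overflow` appears to scale the gradient in place by `1. / scale` and set a flag in the shared overflow buffer when non-finite values are present. A pure-PyTorch stand-in under that assumption (not the fused kernel's actual implementation) might look like:

    import torch

    def scale_check_overflow_py(grad_data, scale_factor, overflow_buf):
        # Hypothetical equivalent of scale_lib.scale_check_overflow.
        if not torch.isfinite(grad_data).all():
            overflow_buf.fill_(1)
        grad_data.mul_(scale_factor)
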
@@ -35,3 +34,8 @@ class LossScaler(object):
             self._unskipped = 0

         return should_skip
+
+def iter_params(param_groups):
+    for group in param_groups:
+        for p in group['params']:
+            yield p
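
The new module-level `iter_params` helper simply flattens `param_groups` into one stream of parameters, which is what lets both call sites above drop their nested loops. A quick check of its behaviour:

    import torch

    params = [torch.zeros(2, requires_grad=True) for _ in range(3)]
    groups = [{'params': params[:2]}, {'params': params[2:]}]

    flat = list(iter_params(groups))
    assert len(flat) == 3 and all(a is b for a, b in zip(flat, params))  # flattened in group order
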