Commit 614b11ff authored by Carl Case

support multi-loss scaling per-optimizer correctly

parent 8be1b053
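In outline, the change makes scale_loss correct when several losses share a single optimizer: before back-propagating any loss after the first, the gradients already accumulated from earlier losses are cached and zeroed, so that only the current loss's gradients are unscaled, and the cached accumulation is added back afterwards. Below is a minimal standalone sketch of that pattern outside the amp wrapper; the helper backward_with_scale and the toy model are illustrative only, not part of this commit.

import torch

# Illustrative sketch of the grad-caching pattern (not apex's implementation).
def backward_with_scale(optimizer, loss, loss_scale):
    """Backprop loss * loss_scale, then unscale only the grads this loss produced."""
    params = [p for group in optimizer.param_groups for p in group['params']]

    # Cache grads accumulated by earlier losses and clear them, so the unscale
    # below cannot touch gradients that were produced under a different scale.
    cached = [None if p.grad is None else p.grad.detach().clone() for p in params]
    optimizer.zero_grad()

    (loss * loss_scale).backward()

    # Unscale this loss's grads in place, then restore the cached accumulation.
    for p, c in zip(params, cached):
        if p.grad is not None:
            p.grad.div_(loss_scale)
        if c is not None:
            if p.grad is None:
                p.grad = c
            else:
                p.grad.add_(c)

# Hypothetical usage: two losses feeding one optimizer, a different scale per loss.
model = torch.nn.Linear(4, 1)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
x = torch.randn(8, 4)
backward_with_scale(opt, model(x).mean(), 128.0)
backward_with_scale(opt, model(x).pow(2).mean(), 64.0)
opt.step()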
import contextlib
import logging
import warnings
from .scaler import LossScaler
from .scaler import LossScaler, iter_params
import numpy as np
@@ -28,27 +29,33 @@ class OptimWrapper(object):
loss_backward()
loss.backward = warning_wrapper
# if loss_idx > 0:
# save out current grads to buffers
# keep some group caches
# .detach().clone()
# zero grads
# When there are multiple losses per-optimizer, we need
# to save out current grad accumulation, since we won't be
# able to unscale this particular loss once the grads are
# all mixed together.
cached_grads = []
if self._loss_idx > 0:
for p in iter_params(self._optimizer.param_groups):
if p.grad is not None:
cached_grads.append(p.grad.data.detach().clone())
else:
cached_grads.append(None)
self._optimizer.zero_grad()
loss_scale = self._cur_loss_scaler().loss_scale()
print('Loss scale (log): {}'.format(np.log2(loss_scale)))
yield loss * loss_scale
loss.backward = loss_backward
self._skip_next[self._loss_idx] = self._cur_loss_scaler().unscale_and_update(
self._optimizer.param_groups, loss_scale)
print('GOT SKIP NEXT: {}'.format(self._skip_next[self._loss_idx]))
self._loss_idx += 1
# if loss_idx > 0:
# += saved state into grads
if len(cached_grads) > 0:
for p, cached_grad in zip(iter_params(self._optimizer.param_groups),
cached_grads):
if cached_grad is not None:
p.grad.data.add_(cached_grad)
cached_grads = []
def _cur_loss_scaler(self):
assert 0 <= self._loss_idx < self._num_loss
@@ -69,8 +76,8 @@ class OptimWrapper(object):
'The `closure` argument is unsupported by the amp ' +
'optimizer wrapper.')
if any(self._skip_next):
logging.info('Gradient overflow, skipping update')
self._skip_next = [False] * self._num_loss
print('SKIP')
else:
return self._optimizer.step(closure=closure)
@@ -15,8 +15,7 @@ class LossScaler(object):
def unscale_and_update(self, param_groups, scale):
self._overflow_buf.zero_()
for group in param_groups:
for p in group['params']:
for p in iter_params(param_groups):
if p.grad is not None:
scale_lib.scale_check_overflow(p.grad.data,
1. / scale,
@@ -35,3 +34,8 @@ class LossScaler(object):
self._unskipped = 0
return should_skip
def iter_params(param_groups):
for group in param_groups:
for p in group['params']:
yield p
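For context on the step() change above: the wrapper records one skip flag per loss, and the underlying optimizer step is skipped whenever any per-loss unscale reported a gradient overflow. A self-contained sketch of that skip-on-overflow pattern with a single loss follows; unscale_and_check_overflow is a stand-in for LossScaler.unscale_and_update (which also updates the dynamic loss scale), not amp's actual API.

import torch

# Illustrative stand-in for the scaler's unscale-and-check step (not amp's API).
def unscale_and_check_overflow(params, scale):
    """Divide grads by scale in place; return True if any grad is non-finite."""
    overflow = False
    for p in params:
        if p.grad is not None:
            p.grad.div_(scale)
            if not torch.isfinite(p.grad).all():
                overflow = True
    return overflow

model = torch.nn.Linear(4, 1)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
scale = 2.0 ** 15

loss = model(torch.randn(8, 4)).mean()
(loss * scale).backward()

if unscale_and_check_overflow(model.parameters(), scale):
    opt.zero_grad()   # gradient overflow: drop this update and reduce the scale
    scale /= 2.0
else:
    opt.step()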