Commit a11c45a4 authored by Michael Carilli

Removing patching of loss.backward, which appears to cause memory leaks (reference cycles?) in some models
parent 45537d34
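
The pattern being removed wrapped the tensor's bound backward method in a closure and assigned the wrapper back onto the tensor as an attribute. Below is a minimal standalone sketch of why that can leak, using illustrative names rather than the actual handle code, and offered as one plausible reading of the "reference cycles?" in the message above:

import warnings
import torch

def patch_backward_sketch(loss):
    # Illustrative reproduction of the removed pattern, not the handle code itself.
    loss_backward = loss.backward      # bound method: holds a reference to `loss`

    def warning_wrapper():             # closure: holds `loss_backward` in a cell
        warnings.warn("You called .backward() on the unscaled loss "
                      "inside a scale_loss block.", stacklevel=2)
        loss_backward()

    loss.backward = warning_wrapper    # instance attribute: holds the closure
    # Reference cycle: loss -> warning_wrapper (attribute) -> loss_backward
    # (closure cell) -> loss (the bound method's __self__). The tensor and the
    # autograd graph it retains are now freed only by the cyclic garbage
    # collector rather than by reference counting.

loss = (torch.randn(4, requires_grad=True) ** 2).sum()
patch_backward_sketch(loss)

Restoring the original attribute on exit, as the old code did with loss.backward = loss_backward, still leaves a bound method stored in the tensor's __dict__ whose __self__ is the tensor, so the cycle persists either way.
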
@@ -40,16 +40,8 @@ class AmpHandle(object):
'use `optimizer.scale_loss(loss)`.')
# TODO: this code block is duplicated here and `opt.py`. Unify.
loss_backward = loss.backward
def warning_wrapper():
warnings.warn("You called .backward() on the unscaled loss "
"inside a scale_loss block. This is almost "
"certainly an error.", stacklevel=2)
loss_backward()
loss.backward = warning_wrapper
loss_scale = self._default_scaler.loss_scale()
yield loss * loss_scale
loss.backward = loss_backward
should_skip = self._default_scaler.unscale_and_update(
optimizer.param_groups, loss_scale)
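
The hunk above is the single-loss handle path. From the caller's side it behaves the same without the patching; here is a sketch of the intended usage, assuming the handle-based amp API of this era (amp.init() returning an AmpHandle) and illustrative model, data, and optimizer names:

import torch
from apex import amp

model = torch.nn.Linear(16, 4).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
amp_handle = amp.init()                 # handle-based amp API

inputs = torch.randn(8, 16).cuda()
targets = torch.randn(8, 4).cuda()

optimizer.zero_grad()
loss = torch.nn.functional.mse_loss(model(inputs), targets)
# Call backward() on the scaled loss yielded by the context manager; calling
# loss.backward() on the unscaled loss inside the block is the mistake the
# removed wrapper used to warn about.
with amp_handle.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()
optimizer.step()
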
@@ -21,14 +21,6 @@ class OptimWrapper(object):
yield loss
return
loss_backward = loss.backward
def warning_wrapper():
warnings.warn("You called .backward() on the unscaled loss "
"inside a scale_loss block. This is almost "
"certainly an error.", stacklevel=2)
loss_backward()
loss.backward = warning_wrapper
# When there are multiple losses per-optimizer, we need
# to save out current grad accumulation, since we won't be
# able to unscale this particular loss once the grads are
@@ -44,7 +36,6 @@ class OptimWrapper(object):
loss_scale = self._cur_loss_scaler().loss_scale()
yield loss * loss_scale
loss.backward = loss_backward
self._skip_next[self._loss_idx] = self._cur_loss_scaler().unscale_and_update(
self._optimizer.param_groups, loss_scale)
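
The remaining hunks are the wrapped-optimizer path, which keeps one loss scaler per loss so that several losses can share an optimizer. A sketch of how that path is typically driven, continuing with the names from the sketch above; the scale_loss method on the wrapped optimizer matches the error message in the first hunk, while wrap_optimizer and its num_loss argument are assumptions about the old handle API that are not shown in this diff:

# wrap_optimizer / num_loss are assumed, not shown in this diff.
optimizer = amp_handle.wrap_optimizer(optimizer, num_loss=2)

optimizer.zero_grad()
out = model(inputs)
loss_a = torch.nn.functional.mse_loss(out, targets)
loss_b = out.abs().mean()               # second, purely illustrative loss term

with optimizer.scale_loss(loss_a) as scaled_a:
    # retain_graph because both losses share the same forward graph through `out`
    scaled_a.backward(retain_graph=True)
with optimizer.scale_loss(loss_b) as scaled_b:
    scaled_b.backward()
optimizer.step()
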