Commit 4a9c2a53 authored by Michael Carilli

Adding delay_overflow_check=False ninja control point

parent ae7f0def
@@ -17,7 +17,8 @@ def scale_loss(loss,
                optimizers,
                loss_id=0,
                model=None,
-               delay_unscale=False):
+               delay_unscale=False,
+               delay_overflow_check=False):
     """
     On context manager entrance, creates ``scaled_loss = (loss.float())*current loss scale``.
     ``scaled_loss`` is yielded so that the user can call ``scaled_loss.backward()``::
@@ -127,7 +128,7 @@ def scale_loss(loss,
                 optimizer._amp_stash.params_have_scaled_gradients = False
     # For future fused optimizers that enable sync-free dynamic loss scaling,
     # should_skip will always be False.
-    should_skip = loss_scaler.update_scale()
+    should_skip = False if delay_overflow_check else loss_scaler.update_scale()
     if should_skip:
         for optimizer in optimizers:
             if not optimizer._amp_stash.already_patched:
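For context, a minimal usage sketch of the new flag. Passing delay_overflow_check=True forces should_skip to False on context exit, so loss_scaler.update_scale() (and the overflow check it performs) is not run and the caller takes responsibility for overflow handling. The model, optimizer, criterion, and loader names below are hypothetical training-loop placeholders; the sketch assumes Apex amp was set up earlier via amp.initialize.

    # Minimal sketch, assuming model and optimizer were wrapped earlier with
    # apex.amp.initialize(model, optimizer, opt_level="O1").
    from apex import amp

    for inputs, targets in loader:                 # hypothetical data loader
        optimizer.zero_grad()
        loss = criterion(model(inputs), targets)   # hypothetical loss computation
        # delay_overflow_check=True skips update_scale() on exit, deferring the
        # loss-scale update and its overflow check; the caller must handle
        # overflow before (or inside) optimizer.step().
        with amp.scale_loss(loss, optimizer,
                            delay_overflow_check=True) as scaled_loss:
            scaled_loss.backward()
        optimizer.step()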