Commit 4a9c2a53 authored by Michael Carilli

Adding delay_overflow_check=False ninja control point

parent ae7f0def
...
@@ -17,7 +17,8 @@ def scale_loss(loss,
                optimizers,
                loss_id=0,
                model=None,
-               delay_unscale=False):
+               delay_unscale=False,
+               delay_overflow_check=False):
     """
     On context manager entrance, creates ``scaled_loss = (loss.float())*current loss scale``.
     ``scaled_loss`` is yielded so that the user can call ``scaled_loss.backward()``::
...
@@ -127,7 +128,7 @@ def scale_loss(loss,
             optimizer._amp_stash.params_have_scaled_gradients = False
     # For future fused optimizers that enable sync-free dynamic loss scaling,
     # should_skip will always be False.
-    should_skip = loss_scaler.update_scale()
+    should_skip = False if delay_overflow_check else loss_scaler.update_scale()
     if should_skip:
         for optimizer in optimizers:
             if not optimizer._amp_stash.already_patched:
...
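
For reference, a minimal usage sketch (not part of the commit) of how a training loop might pass the new keyword through amp.scale_loss. The model, optimizer, and input tensor below are placeholders, and a CUDA device plus the usual amp.initialize() call are assumed.

import torch
from apex import amp

# Placeholder model and optimizer; assumes a CUDA device is available.
model = torch.nn.Linear(4, 4).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

inputs = torch.randn(8, 4, device="cuda")
loss = model(inputs).float().sum()

# delay_overflow_check=True makes the context manager skip
# loss_scaler.update_scale() on exit, so should_skip is forced to False and
# optimizer.step() is not patched out here. The caller then takes
# responsibility for detecting gradient overflow before applying the update.
with amp.scale_loss(loss, optimizer, delay_overflow_check=True) as scaled_loss:
    scaled_loss.backward()
optimizer.step()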