Commit 0750a757 authored by Michael Carilli

delay_unscale is never necessary and generally discouraged, but should still work for some cases

parent 3f87614f
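For reference, the plain pattern that makes delay_unscale unnecessary is ordinary gradient accumulation with the default delay_unscale=False: Amp unscales gradients on every scale_loss exit, and optimizer.step() is simply called every few iterations. The sketch below is illustrative only and not part of this commit; the model, data, loop sizes, and opt_level are invented, and it assumes apex is installed and a CUDA device is available.

import torch
from apex import amp

model = torch.nn.Linear(10, 10).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

accumulation_steps = 4
optimizer.zero_grad()
for step in range(16):
    inputs = torch.randn(8, 10, device="cuda")
    # Scale the loss down so the accumulated gradient matches a full batch.
    loss = model(inputs).sum() / accumulation_steps
    # Default delay_unscale=False: gradients are unscaled on exit and accumulate.
    with amp.scale_loss(loss, optimizer) as scaled_loss:
        scaled_loss.backward()
    if (step + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()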
@@ -261,6 +261,7 @@ def _process_optimizer(optimizer, properties):
     optimizer._amp_stash.lazy_init_called = False
     optimizer._amp_stash.already_patched = False
+    optimizer._amp_stash.params_have_scaled_gradients = False
 
     for name in ("_lazy_init_maybe_master_weights",
                  "_master_params_to_model_params",
...
@@ -57,8 +57,9 @@ def scale_loss(loss,
             will use the default global loss scaler for this backward pass.
         model(torch.nn.Module, optional, default=None): Currently unused, reserved to enable future
             optimizations.
-        delay_unscale(bool, optional, default=False): ``delay_unscale`` is a ninja option that only
-            serves as a minor performance optimization, so only use it if you know what you're doing.
+        delay_unscale(bool, optional, default=False): ``delay_unscale`` is never necessary.
+            It's a minor ninja performance optimization and can result in weird gotchas (especially
+            with multiple models/optimizers/losses), so only use it if you know what you're doing.
             If ``True``, Amp will not unscale the gradients or perform model->master
             gradient copies on context manager exit.
             "Gradient accumulation across iterations" under `Advanced Amp Usage`_
@@ -98,18 +99,24 @@ def scale_loss(loss,
             _amp_state.handle._clear_cache()
         return
 
-    if isinstance(optimizers, list):
-        for optimizer in optimizers:
-            optimizer._prepare_amp_backward()
+    if not delay_unscale:
+        if isinstance(optimizers, list):
+            for optimizer in optimizers:
+                if not optimizer._amp_stash.params_have_scaled_gradients:
+                    optimizer._prepare_amp_backward()
 
     yield (loss.float())*loss_scale
 
-    if not delay_unscale:
+    if delay_unscale:
+        for optimizer in optimizers:
+            optimizer._amp_stash.params_have_scaled_gradients = True
+    else:
         # FusedAdam and FusedSGD will take care of unscaling as part of their step() methods.
         if not isinstance(optimizers, FP16_Optimizer_for_fused):
             loss_scaler.clear_overflow_state()
             for optimizer in optimizers:
                 optimizer._post_amp_backward(loss_scaler)
+                optimizer._amp_stash.params_have_scaled_gradients = False
             # For future fused optimizers that enable sync-free dynamic loss scaling,
             # should_skip will always be False.
             should_skip = loss_scaler.update_scale()
...
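The new per-optimizer flag params_have_scaled_gradients records that a backward pass exited with delay_unscale=True and left scaled values in the gradients, so the next scale_loss entry skips _prepare_amp_backward; a normal (delay_unscale=False) exit then unscales via _post_amp_backward and clears the flag. The trace below is only illustrative: reading _amp_stash is an internal detail, and the model, data, and opt_level are invented.

import torch
from apex import amp

model = torch.nn.Linear(4, 4).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

loss = model(torch.randn(2, 4, device="cuda")).sum()
with amp.scale_loss(loss, optimizer, delay_unscale=True) as scaled_loss:
    scaled_loss.backward()
# Gradients were left scaled, so the stash flag should now be True.
print(optimizer._amp_stash.params_have_scaled_gradients)

loss = model(torch.randn(2, 4, device="cuda")).sum()
with amp.scale_loss(loss, optimizer) as scaled_loss:  # default: unscale on exit
    scaled_loss.backward()
# A normal exit unscales the accumulated gradients and clears the flag.
print(optimizer._amp_stash.params_have_scaled_gradients)
optimizer.step()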