OpenDAS / apex / Commits
0750a757
Commit
0750a757
authored
Apr 04, 2019
by
Michael Carilli
Browse files
delay_unscale is never necessary and generally discouraged, but should still work for some cases
parent
3f87614f
Showing 2 changed files with 14 additions and 6 deletions:

apex/amp/_process_optimizer.py   +1  -0
apex/amp/handle.py              +13  -6
apex/amp/_process_optimizer.py

@@ -261,6 +261,7 @@ def _process_optimizer(optimizer, properties):
     optimizer._amp_stash.lazy_init_called = False
     optimizer._amp_stash.already_patched = False
+    optimizer._amp_stash.params_have_scaled_gradients = False
 
     for name in ("_lazy_init_maybe_master_weights",
                  "_master_params_to_model_params",
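This one-line addition gives each optimizer's _amp_stash a params_have_scaled_gradients flag, which the handle.py change below reads and resets. A minimal, illustrative check only (it assumes a model and optimizer already constructed, and _amp_stash is an internal, underscore-prefixed attribute, so this is for inspection, not something to build on):

from apex import amp

# Illustrative only: model and optimizer are assumed to exist already.
model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

# Freshly initialized, so no backward pass has left scaled gradients behind yet.
print(optimizer._amp_stash.params_have_scaled_gradients)  # False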
apex/amp/handle.py

@@ -57,8 +57,9 @@ def scale_loss(loss,
             will use the default global loss scaler for this backward pass.
         model(torch.nn.Module, optional, default=None): Currently unused, reserved to enable future
             optimizations.
-        delay_unscale(bool, optional, default=False): ``delay_unscale`` is a ninja option that only
-            serves as a minor performance optimization, so only use it if you know what you're doing.
+        delay_unscale(bool, optional, default=False): ``delay_unscale`` is never necessary.
+            It's a minor ninja performance optimization and can result in weird gotchas (especially
+            with multiple models/optimzers/losses), so only use it if you know what you're doing.
             If ``True``, Amp will not unscale the gradients or perform model->master
             gradient copies on context manager exit.
             "Gradient accumulation across iterations" under `Advanced Amp Usage`_
@@ -98,18 +99,24 @@ def scale_loss(loss,
             _amp_state.handle._clear_cache()
         return
 
-    if isinstance(optimizers, list):
-        for optimizer in optimizers:
-            optimizer._prepare_amp_backward()
+    if not delay_unscale:
+        if isinstance(optimizers, list):
+            for optimizer in optimizers:
+                if not optimizer._amp_stash.params_have_scaled_gradients:
+                    optimizer._prepare_amp_backward()
 
     yield (loss.float())*loss_scale
 
-    if not delay_unscale:
+    if delay_unscale:
+        for optimizer in optimizers:
+            optimizer._amp_stash.params_have_scaled_gradients = True
+    else:
         # FusedAdam and FusedSGD will take care of unscaling as part of their step() methods.
         if not isinstance(optimizers, FP16_Optimizer_for_fused):
             loss_scaler.clear_overflow_state()
             for optimizer in optimizers:
                 optimizer._post_amp_backward(loss_scaler)
+                optimizer._amp_stash.params_have_scaled_gradients = False
         # For future fused optimizers that enable sync-free dynamic loss scaling,
         # should_skip will always be False.
         should_skip = loss_scaler.update_scale()
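Net effect of the control-flow change: a backward pass taken with delay_unscale=True now records, per optimizer, that gradients are still in scaled form, and the next scale_loss call skips _prepare_amp_backward for those optimizers and performs the unscale on exit. A minimal sketch with two hypothetical losses feeding one step (loss1 and loss2 are placeholders; model/optimizer are assumed to come from amp.initialize):

from apex import amp

# First backward: gradients stay scaled, and params_have_scaled_gradients flips to True.
with amp.scale_loss(loss1, optimizer, delay_unscale=True) as scaled_loss:
    scaled_loss.backward()

# Second backward: _prepare_amp_backward is skipped because the flag is still True,
# then the accumulated gradients are unscaled on exit and the flag resets to False.
with amp.scale_loss(loss2, optimizer) as scaled_loss:
    scaled_loss.backward()

optimizer.step()
optimizer.zero_grad()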