OpenDAS / apex · Commit 887a50bd

Better way to expose scale adjustment

Authored Apr 16, 2019 by Michael Carilli
Parent commit: 9efb2809
Showing 2 changed files with 6 additions and 3 deletions:

    apex/amp/handle.py              +1  -1
    apex/optimizers/fused_adam.py   +5  -2
apex/amp/handle.py

@@ -122,7 +122,7 @@ def scale_loss(loss,
     # necessary because amp.scale_loss is already creating a temporary scope.
     def patch_step(opt, loss_scaler, loss_id):
         opt_step = opt.step
-        def skip_step(scale=None):
+        def skip_step():
             maybe_print(("Gradient overflow. Skipping step, loss scaler " +
                          "{} reducing loss scale to {}").format(loss_id,
                                                                 loss_scaler.loss_scale()))
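For context, patch_step temporarily replaces the optimizer's step method with skip_step when the loss scaler detects a gradient overflow, so the weight update for that iteration is skipped. This hunk drops the scale keyword that the previous approach threaded through skip_step; the adjustment now lives on the optimizer itself (see fused_adam.py below). The following is a minimal sketch of that monkey-patch pattern, not apex's actual implementation: the restore line and the plain print are assumptions about the surrounding code, which the diff does not show.

    def patch_step(opt, loss_scaler, loss_id):
        opt_step = opt.step                  # keep a handle on the real step()

        def skip_step():
            # Report the skipped iteration and the reduced loss scale.
            print(("Gradient overflow. Skipping step, loss scaler "
                   "{} reducing loss scale to {}").format(loss_id,
                                                          loss_scaler.loss_scale()))
            opt.step = opt_step              # restore the real step() afterwards

        opt.step = skip_step                 # the user's next optimizer.step() is skipped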
apex/optimizers/fused_adam.py

@@ -39,7 +39,8 @@ class FusedAdam(torch.optim.Optimizer):
     def __init__(self, params,
                  lr=1e-3, bias_correction = True,
                  betas=(0.9, 0.999), eps=1e-8, eps_inside_sqrt = False,
-                 weight_decay=0., max_grad_norm=0., amsgrad=False, use_mt=False):
+                 weight_decay=0., max_grad_norm=0., amsgrad=False, use_mt=False,
+                 amp_scale_adjustment=1.0):
         global fused_adam_cuda
         fused_adam_cuda = importlib.import_module("fused_adam_cuda")
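The new amp_scale_adjustment keyword defaults to 1.0, so existing callers are unaffected; a caller that wants the optimizer to apply an extra multiplier on top of amp's loss scale can now pass it at construction time. A hedged usage sketch follows: it assumes a CUDA build of apex with the fused_adam_cuda extension installed, and the toy model and the value 2.0 are illustrative only.

    import torch
    from apex.optimizers import FusedAdam

    model = torch.nn.Linear(16, 16).cuda()           # toy model, purely illustrative
    optimizer = FusedAdam(model.parameters(),
                          lr=1e-3,
                          amp_scale_adjustment=2.0)  # extra factor applied to amp's loss scale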
@@ -51,6 +52,8 @@ class FusedAdam(torch.optim.Optimizer):
                 self._use_multi_tensor = True
                 self._overflow_buf = torch.cuda.IntTensor([0])
 
+        self._amp_scale_adjustment = amp_scale_adjustment
+
         if amsgrad:
             raise RuntimeError('FusedAdam does not support the AMSGrad variant.')
         defaults = dict(lr=lr, bias_correction=bias_correction,
@@ -81,7 +84,7 @@ class FusedAdam(torch.optim.Optimizer):
         if hasattr(self, "_amp_stash"):
             grads = self._amp_stash.grads
             output_params = self._amp_stash.output_params
-            scale = self._amp_stash.scale*scale
+            scale = self._amp_stash.scale*self._amp_scale_adjustment
             grad_norms = self._amp_stash.grad_norms
 
         if grads is None:
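In step(), the scale pulled from amp's stash is now multiplied by the stored adjustment before gradients are unscaled, which is what makes the constructor argument the "better way" to expose the adjustment compared with threading a scale through skip_step. A small illustrative sketch of the arithmetic, in plain Python standing in for the fused CUDA kernel (the helper names and the numbers below are invented for illustration):

    def combined_scale(amp_loss_scale, amp_scale_adjustment=1.0):
        # Mirrors the changed line: scale = self._amp_stash.scale * self._amp_scale_adjustment
        return amp_loss_scale * amp_scale_adjustment

    def unscale(grads, scale):
        # Gradients computed from a scaled loss are divided by the combined scale
        # before the Adam update; illustrative only, not the fused kernel.
        return [g / scale for g in grads]

    # Example: amp is using a loss scale of 65536 and the user asked for a 2x adjustment.
    scale = combined_scale(65536.0, amp_scale_adjustment=2.0)   # -> 131072.0
    print(unscale([131072.0, 262144.0], scale))                 # -> [1.0, 2.0]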