OpenDAS / apex / Commits / 1d4a95d4

Commit 1d4a95d4, authored Mar 12, 2020 by Thor Johnsen
Parent: d48218a0

Modify fused_adam to take advantage of undo feature

Showing 1 changed file with 7 additions and 4 deletions.

apex/contrib/optimizers/fused_adam.py (+7, -4)
@@ -92,7 +92,7 @@ class FusedAdam(torch.optim.Optimizer):
                              stride,
                              1 if clear else 0)
 
-    def step(self, closure=None, grads=None, output_params=None, scale=1., grad_norms=None):
+    def step(self, closure=None, grads=None, output_params=None, scale=1., grad_norms=None, allow_undo=False):
         """Performs a single optimization step.
 
         Arguments:
@@ -106,12 +106,15 @@ class FusedAdam(torch.optim.Optimizer):
                 updated weights. Have to be of same type as gradients. (default: None)
             scale (float, optional): factor to divide gradient tensor values
                 by before applying to weights. (default: 1)
+            allow_undo (bool, optional): allow use of undo feature. Internal buffers
+                will be restored to pre-step state if overflow is detected in gradient.
 
         """
         loss = None
         if closure is not None:
             loss = closure()
         self._step(grads, output_params, scale, grad_norms, False)
-        self.strided_check_finite(output_params, output_params.numel(), 0, output_params.numel())
-        if self.peek_overflow:
-            self._step(grads, output_params, scale, grad_norms, True)
+        if allow_undo:
+            self.strided_check_finite(output_params, output_params.numel(), 0, output_params.numel())
+            if self.peek_overflow:
+                self._step(grads, output_params, scale, grad_norms, True)
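The diff only threads the new flag through step(); the undo itself happens inside the fused CUDA kernel reached via `_step`, which this commit does not touch. As background for why an undo can be cheap here: given the same gradient, a plain Adam update is algebraically invertible, so pre-step state can be recovered without checkpointing the buffers. A pure-PyTorch illustration of that idea (the kernel's actual strategy, and details such as bias correction, are not visible in this commit, so treat this strictly as a sketch):

    import torch

    def adam_step(p, m, v, g, lr=1e-3, b1=0.9, b2=0.999, eps=1e-8):
        # One Adam step without bias correction, for illustration only.
        m.mul_(b1).add_(g, alpha=1 - b1)
        v.mul_(b2).addcmul_(g, g, value=1 - b2)
        p.sub_(lr * m / (v.sqrt() + eps))

    def adam_undo(p, m, v, g, lr=1e-3, b1=0.9, b2=0.999, eps=1e-8):
        # Exact inverse of adam_step for the same g: restore the parameter
        # while m and v still hold their post-step values, then invert the
        # two moment updates algebraically.
        p.add_(lr * m / (v.sqrt() + eps))
        m.sub_(g, alpha=1 - b1).div_(b1)
        v.sub_(g * g, alpha=1 - b2).div_(b2)

    p, m, v, g = torch.randn(8), torch.zeros(8), torch.zeros(8), torch.randn(8)
    p0 = p.clone()
    adam_step(p, m, v, g)
    adam_undo(p, m, v, g)
    assert torch.allclose(p, p0, atol=1e-6)  # parameters restored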
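How a caller might use the flag: a hypothetical dynamic loss-scaling loop. Only `step(..., allow_undo=True)` and the `peek_overflow` attribute come from the code above; the flat master/fp16 tensors (the diff calls `output_params.numel()` directly, so a flat tensor is assumed), the stand-in gradients, and the halve-on-overflow policy are illustrative assumptions:

    import torch
    from apex.contrib.optimizers.fused_adam import FusedAdam

    master = torch.zeros(1024, device='cuda')   # flat fp32 master weights (assumed layout)
    params_fp16 = master.half()                 # fp16 copy written out by step()
    optimizer = FusedAdam([master], lr=1e-3)
    scale = 2.0 ** 16                           # assumed initial loss scale

    for _ in range(10):
        grads = torch.randn(1024, device='cuda').half()  # stand-in for real gradients
        # With allow_undo=True, step() runs the fused update, checks the
        # written-out fp16 weights for non-finite values, and on overflow
        # re-runs _step with the undo flag so internal buffers return to
        # their pre-step state.
        optimizer.step(grads=grads, output_params=params_fp16,
                       scale=scale, allow_undo=True)
        if optimizer.peek_overflow:   # assumption: flag stays queryable after the undo
            scale /= 2.0              # keep training at a lower loss scale

The point, per the new docstring, is that an overflowed step no longer corrupts the optimizer's internal buffers, so the caller can simply lower the loss scale and continue instead of restoring optimizer state from a checkpoint.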