Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
apex
Commits
a88c09cf
Unverified
Commit
a88c09cf
authored
Jan 25, 2019
by
mcarilli
Committed by
GitHub
Jan 25, 2019
Browse files
Update explanation of is_grad_enabled() use
parent
dfd40f9a
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
10 additions
and
0 deletions
+10
-0
apex/amp/utils.py
apex/amp/utils.py
+10
-0
No files found.
apex/amp/utils.py
View file @
a88c09cf
...
...
@@ -96,6 +96,16 @@ def cached_cast(cast_fn, x, cache):
# and reused from the cache, it will not actually have x as its parent.
# Therefore, we choose to invalidate the cache (and force refreshing the cast)
# if x.requires_grad and cached_x.requires_grad do not match.
#
# During eval (i.e. running under with torch.no_grad()) the invalidation
# check would cause the cached value to be dropped every time, because
# cached_x would always be created with requires_grad=False, while x would
# still have requires_grad=True. This would render the cache effectively
# useless during eval. Therefore, if we are running under the no_grad()
# context manager (torch.is_grad_enabled=False) we elide the invalidation
# check, and use the cached value even though its requires_grad flag doesn't
# match. During eval, we don't care that there's no autograd-graph
# connection between x and cached_x.
if
torch
.
is_grad_enabled
()
and
x
.
requires_grad
!=
cached_x
.
requires_grad
:
del
cache
[
x
]
else
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment