Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
apex
Commits
8be1b053
Commit
8be1b053
authored
May 29, 2018
by
Carl Case
Browse files
bugfix: keep cache up-to-date on parameter require_grad-ness
parent
9cc74429
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
8 additions
and
0 deletions
+8
-0
apex/amp/utils.py
apex/amp/utils.py
+8
-0
No files found.
apex/amp/utils.py
View file @
8be1b053
...
...
@@ -85,7 +85,15 @@ def cached_cast(cast_fn, x, cache):
if
is_nested
(
x
):
return
type
(
x
)([
cached_cast
(
y
)
for
y
in
x
])
if
x
in
cache
:
cached_x
=
cache
[
x
]
# During eval, it's possible to end up caching casted weights
# with requires_grad == False. This is then a problem when they
# get reused on the next train iter. So we ensure that cached
# weights have same requires_grad flag of most recent request.
if
x
.
requires_grad
!=
cached_x
.
requires_grad
:
cached_x
.
requires_grad_
(
x
.
requires_grad
)
return
cache
[
x
]
casted_x
=
cast_fn
(
x
)
cache
[
x
]
=
casted_x
return
casted_x
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment