Commit 8be1b053 authored by Carl Case

bugfix: keep cache up-to-date on parameter require_grad-ness

parent 9cc74429
@@ -85,7 +85,15 @@ def cached_cast(cast_fn, x, cache):
     if is_nested(x):
         return type(x)([cached_cast(y) for y in x])
     if x in cache:
+        cached_x = cache[x]
+        # During eval, it's possible to end up caching casted weights
+        # with requires_grad == False. This is then a problem when they
+        # get reused on the next train iter. So we ensure that cached
+        # weights have the same requires_grad flag as the most recent request.
+        if x.requires_grad != cached_x.requires_grad:
+            cached_x.requires_grad_(x.requires_grad)
         return cache[x]
     casted_x = cast_fn(x)
     cache[x] = casted_x
     return casted_x
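
The scenario this patch addresses can be reproduced outside of apex. Below is a minimal, self-contained sketch of the same caching pattern (the name demo_cached_cast is hypothetical, not apex's API): a cast performed under torch.no_grad(), as in an eval pass, is cached as a leaf tensor with requires_grad == False, and without the flag sync the next training iteration would reuse a copy that cannot receive gradients.

import torch

def demo_cached_cast(cast_fn, x, cache):
    # Same pattern as the patched helper: cast x once, reuse the result,
    # and keep the cached copy's requires_grad flag in sync with x's.
    if x in cache:
        cached_x = cache[x]
        if x.requires_grad != cached_x.requires_grad:
            cached_x.requires_grad_(x.requires_grad)
        return cached_x
    casted_x = cast_fn(x)
    cache[x] = casted_x
    return casted_x

weight = torch.randn(4, 4, requires_grad=True)
cache = {}

# Eval-style pass: under no_grad the casted copy is created as a leaf
# tensor with requires_grad == False, and that is what lands in the cache.
with torch.no_grad():
    demo_cached_cast(lambda t: t.half(), weight, cache)
assert cache[weight].requires_grad is False

# Train-style pass: the sync restores the flag from the live parameter,
# so the reused half-precision copy can participate in autograd again.
half_weight = demo_cached_cast(lambda t: t.half(), weight, cache)
assert half_weight.requires_grad is True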