Commit c619fe6e authored by Michael Carilli

Updating comments

parent 646fc0d0
@@ -88,10 +88,15 @@ def cached_cast(cast_fn, x, cache):
     if x in cache:
         cached_x = cache[x]
         if x.requires_grad and cached_x.requires_grad:
-            # Check to make sure x is actually cached_x's autograd parent.
+            # Make sure x is actually cached_x's autograd parent.
             if cached_x.grad_fn.next_functions[1][0].variable is not x:
                 raise RuntimeError("x and cache[x] both require grad, but x is not "
                                    "cache[x]'s parent. This is likely an error.")
+        # During eval, it's possible to end up caching casted weights with
+        # requires_grad=False. On the next training iter, if cached_x is found
+        # and reused from the cache, it will not actually have x as its parent.
+        # Therefore, we choose to invalidate the cache (and force refreshing the cast)
+        # if x.requires_grad and cached_x.requires_grad do not match.
         if x.requires_grad != cached_x.requires_grad:
             del cache[x]
         else:
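For context (not part of this commit), a minimal sketch of the behavior the new comment describes, assuming a standard PyTorch install: a cast performed under torch.no_grad(), as typically happens during eval, yields a copy with requires_grad=False and no autograd link back to the source tensor, whereas the same cast during training keeps that link.

import torch

# Leaf weight tensor, as it would look during training.
w = torch.ones(4, requires_grad=True)

# A cast recorded by autograd: the half-precision copy keeps a grad_fn
# that leads back to w, so w is its autograd parent.
w_half = w.half()
print(w_half.requires_grad, w_half.grad_fn is not None)   # True True

# The same cast under torch.no_grad() (eval) produces a detached copy.
with torch.no_grad():
    w_half_eval = w.half()
print(w_half_eval.requires_grad, w_half_eval.grad_fn)     # False None

# If the detached copy were left in the cast cache, reusing it on the next
# training iteration would silently drop w from the autograd graph; the code
# above avoids this by deleting the entry whenever the requires_grad flags disagree.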