"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "3b2830618ddff967a1f3a1307a15e24a75c7ae6e"
Unverified commit a88c09cf authored by mcarilli, committed by GitHub

Update explanation of is_grad_enabled() use

parent dfd40f9a
@@ -96,6 +96,16 @@ def cached_cast(cast_fn, x, cache):
         # and reused from the cache, it will not actually have x as its parent.
         # Therefore, we choose to invalidate the cache (and force refreshing the cast)
         # if x.requires_grad and cached_x.requires_grad do not match.
+        #
+        # During eval (i.e., running under torch.no_grad()) the invalidation
+        # check would cause the cached value to be dropped every time, because
+        # cached_x would always be created with requires_grad=False, while x would
+        # still have requires_grad=True. This would render the cache effectively
+        # useless during eval. Therefore, if we are running under the no_grad()
+        # context manager (torch.is_grad_enabled() is False) we elide the invalidation
+        # check, and use the cached value even though its requires_grad flag doesn't
+        # match. During eval, we don't care that there's no autograd-graph
+        # connection between x and cached_x.
         if torch.is_grad_enabled() and x.requires_grad != cached_x.requires_grad:
             del cache[x]
         else:
...
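For context, here is a minimal runnable sketch of the caching logic this comment documents. The function name and the invalidation check come from the hunk above; the rest of the body (the cache lookup and refresh path) is an assumption about the surrounding code, not the verbatim apex source.

    import torch

    def cached_cast(cast_fn, x, cache):
        # Sketch (assumed shape): return a cached casted copy of x when one
        # exists, refreshing it when the requires_grad flags diverge.
        if x in cache:
            cached_x = cache[x]
            # Under no_grad() (torch.is_grad_enabled() is False), skip the
            # invalidation check: cached_x was created with
            # requires_grad=False, so dropping every mismatch would empty
            # the cache on each eval step.
            if torch.is_grad_enabled() and x.requires_grad != cached_x.requires_grad:
                del cache[x]  # force a fresh cast below
            else:
                return cached_x
        casted_x = cast_fn(x)
        cache[x] = casted_x
        return casted_x

    # Usage: the cache stays warm during eval despite the flag mismatch.
    cache = {}
    w = torch.randn(4, requires_grad=True)
    with torch.no_grad():
        a = cached_cast(lambda t: t.half(), w, cache)
        b = cached_cast(lambda t: t.half(), w, cache)
    assert a is b  # cache hit: invalidation check elided under no_grad()

Without the is_grad_enabled() guard, the second call in the usage example would see x.requires_grad=True but cached_x.requires_grad=False, delete the entry, and re-cast on every step, which is exactly the eval-time cache thrashing the comment describes.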