You need to sign in or sign up before continuing.
Commit 4ee0537a authored by Devendra Singh Sachan's avatar Devendra Singh Sachan
Browse files

small fix for CheckpointFunction's backward() method when sone args may be NoneType

parent 46a536cc
...@@ -234,7 +234,9 @@ class CheckpointFunction(torch.autograd.Function): ...@@ -234,7 +234,9 @@ class CheckpointFunction(torch.autograd.Function):
if isinstance(outputs, torch.Tensor): if isinstance(outputs, torch.Tensor):
outputs = (outputs,) outputs = (outputs,)
torch.autograd.backward(outputs, args) torch.autograd.backward(outputs, args)
return (None,) + tuple(inp.grad for inp in detached_inputs) grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else inp
for inp in detached_inputs)
return (None,) + grads
def checkpoint(function, *args): def checkpoint(function, *args):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment