Commit 4ee0537a authored by Devendra Singh Sachan's avatar Devendra Singh Sachan
Browse files

small fix for CheckpointFunction's backward() method when sone args may be NoneType

parent 46a536cc
......@@ -234,7 +234,9 @@ class CheckpointFunction(torch.autograd.Function):
if isinstance(outputs, torch.Tensor):
outputs = (outputs,)
torch.autograd.backward(outputs, args)
return (None,) + tuple(inp.grad for inp in detached_inputs)
grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else inp
for inp in detached_inputs)
return (None,) + grads
def checkpoint(function, *args):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment