"official/modeling/fast_training/progressive/utils.py" did not exist on "99a30fa2527eb115d883ec7056fcf7a0e9182dac"
Commit 4ee0537a authored by Devendra Singh Sachan's avatar Devendra Singh Sachan
Browse files

small fix for CheckpointFunction's backward() method when sone args may be NoneType

parent 46a536cc
......@@ -234,7 +234,9 @@ class CheckpointFunction(torch.autograd.Function):
if isinstance(outputs, torch.Tensor):
outputs = (outputs,)
torch.autograd.backward(outputs, args)
return (None,) + tuple(inp.grad for inp in detached_inputs)
grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else inp
for inp in detached_inputs)
return (None,) + grads
def checkpoint(function, *args):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment