Commit c20f4d48 authored by Mohammad Shoeybi's avatar Mohammad Shoeybi
Browse files

Merge branch 'checkpointing-fix' into 'master'

small fix for CheckpointFunction's backward() method when some args may be NoneType

See merge request ADLR/megatron-lm!92
parents 46a536cc 4ee0537a
...@@ -234,7 +234,9 @@ class CheckpointFunction(torch.autograd.Function): ...@@ -234,7 +234,9 @@ class CheckpointFunction(torch.autograd.Function):
if isinstance(outputs, torch.Tensor): if isinstance(outputs, torch.Tensor):
outputs = (outputs,) outputs = (outputs,)
torch.autograd.backward(outputs, args) torch.autograd.backward(outputs, args)
return (None,) + tuple(inp.grad for inp in detached_inputs) grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else inp
for inp in detached_inputs)
return (None,) + grads
def checkpoint(function, *args): def checkpoint(function, *args):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment