Unverified Commit 7a3ed9e2 authored by Zhang Haitao's avatar Zhang Haitao Committed by GitHub
Browse files

[PyTorch] Support non-tensor inputs/outputs for te CheckpointFunction (#581)



* support non-tensor inputs/outputs for checkpoint
Signed-off-by: default avatarskydoorkai <htsantaclara@163.com>

* better format
Signed-off-by: default avatarskydoorkai <htsantaclara@163.com>

* modify to avoid python loops
Signed-off-by: default avatarskydoorkai <htsantaclara@163.com>

---------
Signed-off-by: default avatarskydoorkai <htsantaclara@163.com>
parent 94f54d71
......@@ -216,7 +216,10 @@ class CheckpointFunction(torch.autograd.Function):
)
# Store everything.
ctx.save_for_backward(*args)
ctx.inputs = [arg if not torch.is_tensor(arg) else None for arg in args]
tensor_inputs = [arg if torch.is_tensor(arg) else None for arg in args]
ctx.save_for_backward(*tensor_inputs)
ctx.get_cuda_rng_tracker = get_cuda_rng_tracker
ctx.tp_group = tp_group
ctx.kwargs = kwargs
......@@ -233,7 +236,12 @@ class CheckpointFunction(torch.autograd.Function):
"Checkpointing is not compatible with .grad(), "
"please use .backward() if possible"
)
inputs = ctx.saved_tensors
inputs = tuple(
t if t is not None else arg
for (t, arg) in zip(ctx.saved_tensors, ctx.inputs)
)
get_cuda_rng_tracker = ctx.get_cuda_rng_tracker
if ctx.distribute_saved_activations:
......@@ -269,9 +277,22 @@ class CheckpointFunction(torch.autograd.Function):
if isinstance(outputs, torch.Tensor):
outputs = (outputs,)
torch.autograd.backward(outputs, args)
outputs_with_grad = []
args_with_grad = []
for i, output in enumerate(outputs):
if torch.is_tensor(output) and output.requires_grad:
outputs_with_grad.append(output)
args_with_grad.append(args[i])
if len(outputs_with_grad) == 0:
raise RuntimeError(
"none of output has requires_grad=True,"
" this checkpoint() is not necessary"
)
torch.autograd.backward(outputs_with_grad, args_with_grad)
grads = tuple(
inp.grad if isinstance(inp, torch.Tensor) else inp
inp.grad if isinstance(inp, torch.Tensor) else None
for inp in detached_inputs
)
return (None, None, None, None, None) + grads
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment