Unverified commit 7a3ed9e2 authored by Zhang Haitao, committed by GitHub

[PyTorch] Support non-tensor inputs/outputs for te CheckpointFunction (#581)



* support non-tensor inputs/outputs for checkpoint
Signed-off-by: skydoorkai <htsantaclara@163.com>

* better format
Signed-off-by: skydoorkai <htsantaclara@163.com>

* modify to avoid python loops
Signed-off-by: skydoorkai <htsantaclara@163.com>

---------
Signed-off-by: skydoorkai <htsantaclara@163.com>
parent 94f54d71
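
The change the commits describe boils down to one pattern: split the checkpointed arguments into tensor and non-tensor parts so that only tensors (or None placeholders) reach ctx.save_for_backward, zip the two halves back together before recomputation, and backpropagate only through tensor outputs that require grad. The following standalone sketch illustrates that pattern; it is not the TransformerEngine implementation, names such as _CheckpointSketch and run_function are invented for the example, and TE-specific machinery (RNG-state tracking, tensor-parallel groups, activation redistribution) is deliberately omitted.

    import torch


    class _CheckpointSketch(torch.autograd.Function):
        """Illustrative reconstruction of the checkpoint pattern in this PR (not TE code)."""

        @staticmethod
        def forward(ctx, run_function, *args):
            ctx.run_function = run_function
            # Non-tensor args live on ctx; tensor positions become None placeholders.
            ctx.non_tensor_inputs = [a if not torch.is_tensor(a) else None for a in args]
            # Mirror image: tensors (or None) only, as save_for_backward requires.
            tensor_inputs = [a if torch.is_tensor(a) else None for a in args]
            ctx.save_for_backward(*tensor_inputs)
            with torch.no_grad():
                outputs = run_function(*args)
            return outputs

        @staticmethod
        def backward(ctx, *grad_outputs):
            # Re-interleave saved tensors with the non-tensor args kept on ctx.
            inputs = tuple(
                t if t is not None else a
                for (t, a) in zip(ctx.saved_tensors, ctx.non_tensor_inputs)
            )
            # Detach tensors so recomputation builds a fresh graph rooted at them.
            detached_inputs = tuple(
                i.detach().requires_grad_(i.requires_grad) if torch.is_tensor(i) else i
                for i in inputs
            )
            with torch.enable_grad():
                outputs = ctx.run_function(*detached_inputs)
            if torch.is_tensor(outputs):
                outputs = (outputs,)
            # Backpropagate only through tensor outputs that actually require grad;
            # grad_outputs has one entry per forward output (None for non-tensors).
            outputs_with_grad = []
            grads_for_outputs = []
            for i, out in enumerate(outputs):
                if torch.is_tensor(out) and out.requires_grad:
                    outputs_with_grad.append(out)
                    grads_for_outputs.append(grad_outputs[i])
            if not outputs_with_grad:
                raise RuntimeError("no output requires grad; checkpointing is unnecessary")
            torch.autograd.backward(outputs_with_grad, grads_for_outputs)
            grads = tuple(
                i.grad if torch.is_tensor(i) else None for i in detached_inputs
            )
            # One None for run_function, then one gradient slot per original arg.
            return (None,) + grads

The constraint driving the design is that ctx.save_for_backward accepts only tensors (or None), so non-tensor arguments have to travel on ctx and be merged back into their original positions before the forward pass is replayed.
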
@@ -216,7 +216,10 @@ class CheckpointFunction(torch.autograd.Function):
             )
         # Store everything.
-        ctx.save_for_backward(*args)
+        ctx.inputs = [arg if not torch.is_tensor(arg) else None for arg in args]
+        tensor_inputs = [arg if torch.is_tensor(arg) else None for arg in args]
+        ctx.save_for_backward(*tensor_inputs)
         ctx.get_cuda_rng_tracker = get_cuda_rng_tracker
         ctx.tp_group = tp_group
         ctx.kwargs = kwargs
@@ -233,7 +236,12 @@ class CheckpointFunction(torch.autograd.Function):
                 "Checkpointing is not compatible with .grad(), "
                 "please use .backward() if possible"
             )
-        inputs = ctx.saved_tensors
+        inputs = tuple(
+            t if t is not None else arg
+            for (t, arg) in zip(ctx.saved_tensors, ctx.inputs)
+        )
         get_cuda_rng_tracker = ctx.get_cuda_rng_tracker
         if ctx.distribute_saved_activations:
@@ -269,9 +277,22 @@ class CheckpointFunction(torch.autograd.Function):
         if isinstance(outputs, torch.Tensor):
             outputs = (outputs,)
-        torch.autograd.backward(outputs, args)
+        outputs_with_grad = []
+        args_with_grad = []
+        for i, output in enumerate(outputs):
+            if torch.is_tensor(output) and output.requires_grad:
+                outputs_with_grad.append(output)
+                args_with_grad.append(args[i])
+        if len(outputs_with_grad) == 0:
+            raise RuntimeError(
+                "none of output has requires_grad=True,"
+                " this checkpoint() is not necessary"
+            )
+        torch.autograd.backward(outputs_with_grad, args_with_grad)
         grads = tuple(
-            inp.grad if isinstance(inp, torch.Tensor) else inp
+            inp.grad if isinstance(inp, torch.Tensor) else None
             for inp in detached_inputs
         )
         return (None, None, None, None, None) + grads
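
Assuming the _CheckpointSketch class from the sketch above, a short driver shows gradients flowing through a checkpointed function that takes and returns non-tensor values, which is the behaviour this PR enables for the TE CheckpointFunction:

    def run(x, scale, return_stats):
        y = (x * scale).sum()
        return (y, "stats") if return_stats else y

    x = torch.randn(4, requires_grad=True)
    # 2.0 and True are non-tensor inputs; "stats" is a non-tensor output.
    out, tag = _CheckpointSketch.apply(run, x, 2.0, True)
    out.backward()
    print(tag, x.grad)  # "stats" passes through untouched; x.grad is 2.0 in every slot
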