Unverified commit 2412429d, authored by Frank Lee, committed by GitHub
Browse files

[util] fixed activation checkpointing on torch 1.9 (#719)

parent 04ff5ea5
@@ -68,7 +68,10 @@ class CheckpointFunction(torch.autograd.Function):
             else:
                 ctx.inputs.append(arg)
-        ctx.save_for_backward(*tensor_inputs)
+        if activation_offload:
+            ctx.tensor_inputs = tensor_inputs
+        else:
+            ctx.save_for_backward(*tensor_inputs)
         return outputs

     @staticmethod
@@ -79,7 +82,11 @@ class CheckpointFunction(torch.autograd.Function):
         # Copy the list to avoid modifying original list.
         inputs = list(ctx.inputs)
         tensor_indices = ctx.tensor_indices
-        tensors = ctx.saved_tensors
+        if ctx.activation_offload:
+            tensors = ctx.tensor_inputs
+        else:
+            tensors = ctx.saved_tensors

         # store the current states
         bwd_cpu_rng_state = torch.get_rng_state()
...
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment