Unverified Commit df7db85c authored by Min Xu, committed by GitHub

[fix] using dummy tensor to ensure checkpoint backward pass is called in corner cases (#701)



* [do not merge] testing a corner case

* workaround

* using dummy tensor to fix

* lint

* changelog

* update a comment
Co-authored-by: Min Xu <min.xu.public@gmail.com>
parent 1bcab8dd
@@ -6,6 +6,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 ## NEXT - TBD
 ### Fixed
+- checkpointing: use dummy tensor to ensure backward pass is called [#701]
 ### Added
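
For context, the corner case this changelog entry refers to: `torch.autograd.Function` decides whether its `backward` will ever run by inspecting only the direct tensor arguments passed to `apply`. Tensors hidden inside a list or tuple are treated as one opaque non-tensor argument, so setting `requires_grad` on them has no effect, which is why the fix cannot rely on input flags. The sketch below is a minimal illustration of that behavior, not fairscale code; the `ListFn` class and its `backward_called` probe flag are invented for the example.

```python
import torch


class ListFn(torch.autograd.Function):
    """Illustration only: autograd looks at *direct* tensor arguments of apply()
    to decide whether this Function's backward() can ever be reached. A Python
    list is an opaque non-tensor argument, so the tensors inside it are not seen,
    even if they have requires_grad=True."""

    backward_called = False  # probe flag, just for this demo

    @staticmethod
    def forward(ctx, x_list):
        # No direct tensor argument is visible to autograd here.
        return sum(t.sum() for t in x_list)

    @staticmethod
    def backward(ctx, grad_output):
        ListFn.backward_called = True
        return None  # one slot per forward input (the list)


out = ListFn.apply([torch.rand(2, requires_grad=True)])
print(out.requires_grad)  # False: the list hid the tensor from autograd
print(out.grad_fn)        # None: ListFn.backward() can never be invoked
```

Depending on the surrounding model, this either raises an error when `backward()` is called on the output or, if other parameters keep the loss differentiable, silently skips the recomputation so the wrapped module's parameters never receive gradients; the dummy tensor guards against both.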
@@ -165,7 +165,13 @@ def _checkpointed_forward(
     parent_ctx_dict: Dict[str, Any] = {
         "offload": offload_to_cpu,
     }
-    output = CheckpointFunction.apply(original_forward, parent_ctx_dict, kwarg_keys, *flat_args)
+    # Dummy tensor with grad is used to ensure the backward pass is called. This is needed
+    # when original_forward's input are non-tensor (i.e. a tuple). Using this dummy tensor
+    # avoids requiring users to set their input tensors's requires_grad flag. In the case
+    # of tuple type inputs, setting the flag won't even trigger the backward pass.
+    output = CheckpointFunction.apply(
+        torch.tensor([], requires_grad=True), original_forward, parent_ctx_dict, kwarg_keys, *flat_args
+    )
     if not isinstance(output, torch.Tensor):
         packed_non_tensor_outputs = parent_ctx_dict["packed_non_tensor_outputs"]
         if packed_non_tensor_outputs:
@@ -214,6 +220,7 @@ class CheckpointFunction(torch.autograd.Function):
     @staticmethod
     def forward(  # type: ignore
         ctx: Any,
+        dummy_tensor_requires_grad: torch.Tensor,
         run_function: Any,
         parent_ctx_dict: Dict[str, Any],
         kwarg_keys: Tuple[str, ...],
@@ -295,4 +302,4 @@ class CheckpointFunction(torch.autograd.Function):
         torch.autograd.backward(outputs_with_grad, args_with_grad)
         grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else None for inp in inputs)
-        return (None, None, None) + grads
+        return (None, None, None, None) + grads
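
To make the mechanism concrete, here is a stripped-down sketch of the same dummy-tensor pattern outside fairscale. It is an illustration under simplifying assumptions (all real inputs are tensors and the wrapped function returns a single tensor), not the actual `CheckpointFunction`, which additionally handles kwargs, non-tensor inputs and outputs, RNG state, and CPU offload; the `EnsureBackward` name is made up for the example.

```python
import torch


class EnsureBackward(torch.autograd.Function):
    """Simplified checkpoint-like Function (sketch, not fairscale's implementation).
    The dummy leaf tensor with requires_grad=True guarantees that the output of
    apply() carries a grad_fn, so backward() below runs during loss.backward()
    even when no real tensor input requires grad."""

    @staticmethod
    def forward(ctx, dummy, run_function, *args):
        ctx.run_function = run_function
        ctx.save_for_backward(*args)  # assumes every element of *args is a tensor
        with torch.no_grad():  # no graph is built now; it is rebuilt in backward
            return run_function(*args)  # assumes a single tensor output

    @staticmethod
    def backward(ctx, *grad_outputs):
        inputs = ctx.saved_tensors
        detached = [t.detach().requires_grad_(t.requires_grad) for t in inputs]
        with torch.enable_grad():  # recompute, this time recording a graph
            output = ctx.run_function(*detached)
        torch.autograd.backward(output, grad_outputs)
        # One slot per forward argument: dummy, run_function, then *args.
        return (None, None) + tuple(t.grad for t in detached)


# The module's parameters receive gradients even though the data tensor does not
# require grad and therefore could not have triggered backward() on its own.
lin = torch.nn.Linear(2, 2)
x = torch.rand(4, 2)  # requires_grad is False
loss = EnsureBackward.apply(torch.tensor([], requires_grad=True), lambda t: lin(t).sum(), x)
loss.backward()
assert lin.weight.grad is not None
```

The extra leading `None` in the sketch's return tuple plays the same role as the `(None, None, None, None) + grads` change above: every non-differentiable argument added to `forward` (here the dummy tensor) needs a matching slot in `backward`'s return value.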
@@ -52,7 +52,7 @@ def get_loss_and_gnorm(model, input):
 class BasicModel(nn.Module):
     """Basic model with a single FFN being checkpointed.
 
     Used for extensive checkings: equivalency with non-checkpoint, torch-checkpoint, etc.
     """
 
     def __init__(self, use_pytorch_checkpoint=False, use_fairscale_checkpoint=False, **kwargs):
@@ -267,3 +267,38 @@ def test_deprecated_path():
     # Check if direct import works as before.
     ffn = nn.Sequential(nn.Linear(32, 128), nn.Dropout(p=0.5), nn.Linear(128, 32),)
     ffn = deprecated_checkpoint_wrapper(ffn, {})
+
+
+@skip_if_no_cuda
+def test_list_input():
+    """ Test checkpointing with input argument type being a list.
+
+    Note: Testing shows that PyTorch's torch.utils.checkpoint function does not pass this test.
+    """
+    count = 0
+
+    class Model(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.conv = nn.Linear(2, 2)
+
+        def forward(self, x):
+            nonlocal count
+            count += 1
+            y = []
+            for i in x:
+                y.append(self.conv(i))
+            return y
+
+    model = nn.Sequential(checkpoint_wrapper(Model()), Model()).cuda()
+    in_data1 = torch.rand(4, 2).cuda()
+    in_data2 = torch.rand(4, 2).cuda()
+
+    # Forward. Count should be 2 for 2 modules.
+    out = model([in_data1, in_data2])
+    loss = sum(x.sum() for x in out)
+    assert count == 2, f"Incorrect count {count}"
+
+    # Backward. Adds 1 more forward call due to checkpoint.
+    loss.backward()
+    assert count == 3, f"Incorrect count {count}"
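
The test above verifies that the checkpointed forward is re-run during the backward pass; a natural follow-up check (not part of this commit) is that gradients actually reach the wrapped module's parameters and match a non-checkpointed copy. The sketch below is one way to do that under a few assumptions: `checkpoint_wrapper` is importable from `fairscale.nn.checkpoint`, a small `ListModule` stands in for the test's local `Model`, and it runs on CPU for simplicity.

```python
import copy

import torch
import torch.nn as nn

from fairscale.nn.checkpoint import checkpoint_wrapper  # import path assumed


class ListModule(nn.Module):
    """Small module that, like the test's Model, maps a list of tensors to a list."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(2, 2)

    def forward(self, xs):
        return [self.linear(x) for x in xs]


def check_grad_flow_with_list_input():
    plain = ListModule()
    wrapped = copy.deepcopy(plain)  # identical initial weights
    ckpt = checkpoint_wrapper(wrapped)

    data = [torch.rand(4, 2), torch.rand(4, 2)]

    sum(y.sum() for y in plain([d.clone() for d in data])).backward()
    sum(y.sum() for y in ckpt([d.clone() for d in data])).backward()

    # Gradients should flow into the checkpointed copy and match the plain one,
    # since recomputing a Linear layer is deterministic.
    for p_ref, p_ckpt in zip(plain.parameters(), wrapped.parameters()):
        assert p_ckpt.grad is not None
        assert torch.allclose(p_ref.grad, p_ckpt.grad)


check_grad_flow_with_list_input()
```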