Unverified Commit 482944d9 authored by Alex Xiao, committed by GitHub

[feat] set requires_grad of output tensors of checkpointed modules properly (#787)



Before this commit, output tensors of checkpointed modules always
required grad, even when they shouldn't. With this commit, the outputs
of checkpointed modules require grad only if either the inputs or the
module's parameters require grad.
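
For illustration, a minimal sketch of the new behavior (the wrapped nn.Linear and tensor shapes are arbitrary; only checkpoint_wrapper is the real fairscale API):

```python
import torch
import torch.nn as nn

from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper

# A frozen checkpointed module: neither its parameters nor its input require grad.
frozen = checkpoint_wrapper(nn.Linear(2, 2))
frozen.requires_grad_(False)

x = torch.rand(4, 2)  # plain input, requires_grad defaults to False
out = frozen(x)

# Previously out.requires_grad was always True, because CheckpointFunction is
# fed a dummy tensor that requires grad. With this change the output is
# detached when neither the inputs nor the parameters require grad.
assert not out.requires_grad
```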

To achieve this, the commit also adds a new _unflattened_param_views
attribute to modules whose parameters have been flattened by
FlattenParamsWrapper. This lets the checkpointing code still access the
parameters and check whether gradients need to be computed.
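
A similar sketch for the flattened case (again illustrative; FlattenParamsWrapper and checkpoint_wrapper are the real fairscale APIs, the wrapped module is arbitrary):

```python
import torch
import torch.nn as nn

from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper
from fairscale.nn.misc import FlattenParamsWrapper

# FlattenParamsWrapper deletes the original parameters from the inner module,
# so the checkpointing code reads the views saved in _unflattened_param_views
# to decide whether the output needs grad.
flat = FlattenParamsWrapper(checkpoint_wrapper(nn.Linear(2, 2)))
flat.requires_grad_(False)

x = torch.rand(4, 2)
assert not flat(x).requires_grad
```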
Co-authored-by: Alex Xiao <axiao@fb.com>
parent 3fb8aa2b
@@ -191,13 +191,33 @@ def _checkpointed_forward(
    # when original_forward's input are non-tensor (i.e. a tuple). Using this dummy tensor
    # avoids requiring users to set their input tensors's requires_grad flag. In the case
    # of tuple type inputs, setting the flag won't even trigger the backward pass.
    #
    # One implication of this is that since we always feed in a dummy tensor
    # needing grad, then the output will always require grad, even if it originally
    # wouldn't, such as if the module and original input both do not require grad.
    # We get around this by saving the desired requires_grad value in output and
    # detaching the output if needed.
    output = CheckpointFunction.apply(
        torch.tensor([], requires_grad=True), original_forward, parent_ctx_dict, kwarg_keys, *flat_args
    )
    output_requires_grad = parent_ctx_dict["output_requires_grad"]

    if not isinstance(output, torch.Tensor):
        # If output should not require grad, then detach it, since otherwise it will
        # always have requires_grad = True due to our dummy tensor input above that
        # requires_grad
        output = [x.detach() if not output_requires_grad else x for x in output]
        packed_non_tensor_outputs = parent_ctx_dict["packed_non_tensor_outputs"]
        if packed_non_tensor_outputs:
            output = unpack_non_tensors(output, packed_non_tensor_outputs)
    else:
        # If output should not require grad, then detach it, since otherwise it will
        # always have requires_grad = True due to our dummy tensor input above that
        # requires_grad
        if not output_requires_grad:
            output = output.detach()

    return output
@@ -273,12 +293,29 @@ class CheckpointFunction(torch.autograd.Function):
        the_module = unpacked_args[0]
        inc_counter(the_module)

        # Because we run with torch.no_grad(), we can't actually access
        # outputs.requires_grad. Instead, we manually compute it by
        # checking if either the input or the module needs grads
        parameters = list(the_module.parameters())

        # If the module is wrapped by FlattenParamsWrapper, then the
        # parameters would have been deleted. If so, we need to access
        # the views into the flattened parameters.
        if hasattr(the_module, "_unflattened_param_views"):
            parameters += the_module._unflattened_param_views

        output_requires_grad = any(param.requires_grad for param in parameters) or any(
            x.requires_grad for x in tensor_inputs
        )
        parent_ctx_dict["output_requires_grad"] = output_requires_grad

        if not isinstance(outputs, torch.Tensor):
            # Autograd Functions don't like non-Tensor outputs. We can split the
            # non-Tensor and Tensor outputs, returning the former by reference
            # through *parent_ctx_dict* and returning the latter directly.
            outputs, packed_non_tensor_outputs = split_non_tensors(outputs)
            parent_ctx_dict["packed_non_tensor_outputs"] = packed_non_tensor_outputs

        return outputs

    @staticmethod
@@ -317,10 +354,12 @@ class CheckpointFunction(torch.autograd.Function):
            if tensor_outputs[i].requires_grad:
                outputs_with_grad.append(tensor_outputs[i])
                args_with_grad.append(args[i])

        if len(outputs_with_grad) == 0:
            raise RuntimeError("None of the outputs have requires_grad=True, " "this checkpoint() is not necessary")

        torch.autograd.backward(outputs_with_grad, args_with_grad)

        grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else None for inp in inputs)
        return (None, None, None, None) + grads
@@ -325,6 +325,11 @@ class FlattenParamsWrapper(nn.Module):
            delattr(m, n)
            m.register_parameter(n, getattr(shared_m, shared_n))

        # Delete the param views into the flat params since we will delete the
        # flat params next
        if hasattr(self._fpw_module, "_unflattened_param_views"):
            delattr(self._fpw_module, "_unflattened_param_views")

        for n in self.flat_param_names:
            # This ensures the flat params are removed from the module.
            delattr(self, n)
@@ -336,8 +341,15 @@ class FlattenParamsWrapper(nn.Module):
        """
        assert self.is_flattened
        ps = self.get_param_views()
        param_views = []
        for (_, m, n), p in zip(self._param_infos, ps):
            setattr(m, n, p)  # This will set as plain attr
            param_views.append(p)

        # Save param views for easy access if anyone still wants to access
        # parameters of the module.
        setattr(self._fpw_module, "_unflattened_param_views", param_views)

        for (_, _, m, n, shared_m, shared_n) in self._shared_param_infos:
            setattr(m, n, getattr(shared_m, shared_n))
@@ -11,6 +11,7 @@ import torch.nn as nn
from torch.utils.checkpoint import checkpoint as torch_checkpoint_wrapper

from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper, disable_checkpointing
from fairscale.nn.misc import FlattenParamsWrapper
from fairscale.nn.misc import checkpoint_wrapper as deprecated_checkpoint_wrapper
from fairscale.utils import torch_version
from fairscale.utils.testing import skip_if_no_cuda
@@ -272,7 +273,7 @@ def test_deprecated_path():
@skip_if_no_cuda
def test_list_input():
    """Test checkpointing with input argument type being a list.

    Note: Testing shows that PyTorch's torch.utils.checkpoint function does not pass this test.
    """
@@ -306,7 +307,7 @@ def test_list_input():
def test_checkpoint_disabling():
    """Test to check new disable_checkpoint() API added to checkpoint_wrapper."""

    class TestModel(nn.Module):
        def __init__(self):
@@ -339,3 +340,121 @@ def test_checkpoint_disabling():
    # Backward. cnt remains same as checkpointing is disabled
    y.backward()
    assert model2.cnt == 1


def test_checkpoint_requires_grad():
    """Test to check checkpointing when outputs do not require gradient."""

    class TestModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.cnt = 0
            self.linear = nn.Linear(2, 2)

        def forward(self, x):
            self.cnt += 1
            return self.linear(x)

    x = torch.rand(4, 2)
    model = nn.Sequential(
        checkpoint_wrapper(TestModel()),
        checkpoint_wrapper(TestModel()),
        checkpoint_wrapper(TestModel()),
        checkpoint_wrapper(TestModel()),
    )
    model[0].requires_grad_(False)
    model[1].requires_grad_(False)
    model[2].requires_grad_(False)
    y = model(x)
    y = y.sum()
    y.backward()

    # Since only last model needs grad, we only run forward twice for it
    assert model[0].cnt == 1
    assert model[1].cnt == 1
    assert model[2].cnt == 1
    assert model[3].cnt == 2

    # Now test with first model needing grad
    model = nn.Sequential(
        checkpoint_wrapper(TestModel()),
        checkpoint_wrapper(TestModel()),
        checkpoint_wrapper(TestModel()),
        checkpoint_wrapper(TestModel()),
    )
    model[0].requires_grad_(True)
    model[1].requires_grad_(False)
    model[2].requires_grad_(False)
    y = model(x)
    y = y.sum()
    y.backward()

    # Since first model needs grad, all models need grad, so we run forward twice for all
    assert model[0].cnt == 2
    assert model[1].cnt == 2
    assert model[2].cnt == 2
    assert model[3].cnt == 2

    # Stress test with multiple inputs/outputs, of which some are not Tensor
    class TestModel2(nn.Module):
        def __init__(self):
            super().__init__()
            self.cnt = 0
            self.linear = nn.Linear(2, 2)

        def forward(self, x, y, z):
            self.cnt += 1
            z = z + [self.cnt]
            return self.linear(x + y), z, ["hi"]

    model1 = checkpoint_wrapper(TestModel())
    model2 = checkpoint_wrapper(TestModel())
    model3 = checkpoint_wrapper(TestModel2())
    model4 = checkpoint_wrapper(TestModel())
    model1.requires_grad_(False)
    model2.requires_grad_(False)
    y = model4(model3(model1(x), model2(x), ["bye"])[0])
    y = y.sum()
    y.backward()
    assert model1.cnt == 1
    assert model2.cnt == 1
    assert model3.cnt == 2
    assert model4.cnt == 2

    model1 = checkpoint_wrapper(TestModel())
    model2 = checkpoint_wrapper(TestModel())
    model3 = checkpoint_wrapper(TestModel2())
    model4 = checkpoint_wrapper(TestModel())
    model2.requires_grad_(False)
    y = model4(model3(model1(x), model2(x), ["bye"])[0])
    y = y.sum()
    y.backward()
    assert model1.cnt == 2
    assert model2.cnt == 1
    assert model3.cnt == 2
    assert model4.cnt == 2
    # Test flattened parameters
    model = nn.Sequential(
        FlattenParamsWrapper(checkpoint_wrapper(TestModel())),
        FlattenParamsWrapper(checkpoint_wrapper(TestModel())),
        FlattenParamsWrapper(checkpoint_wrapper(TestModel())),
        FlattenParamsWrapper(checkpoint_wrapper(TestModel())),
    )
    model[0].requires_grad_(False)
    model[1].requires_grad_(False)
    y = model(x)
    y = y.sum()
    y.backward()
    assert model[0].cnt == 1
    assert model[1].cnt == 1
    assert model[2].cnt == 2
    assert model[3].cnt == 2