[feat] set requires_grad of output tensors of checkpointed modules properly (#787)
Before this commit, the output tensors of checkpointed modules always
required grad, even when they should not have. This commit makes the
outputs of checkpointed modules require grad only if the inputs require
grad or the parameters require grad.
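As a rough illustration of the decision (not the exact code in this PR; the helper name and arguments are assumptions), the check could look like this:

```python
import torch

def outputs_should_require_grad(module: torch.nn.Module, inputs) -> bool:
    # Hypothetical helper: outputs of a checkpointed forward should only
    # require grad if some input tensor or some parameter requires grad.
    any_input_grad = any(
        isinstance(t, torch.Tensor) and t.requires_grad for t in inputs
    )
    any_param_grad = any(p.requires_grad for p in module.parameters())
    return any_input_grad or any_param_grad
```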
To achieve this, this commit also adds a new _unflattened_param_views
attribute to modules whose parameters are flattened. This allows the
checkpointing logic to still access the individual parameters and check
whether gradients need to be computed.
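A minimal sketch of the idea, assuming a flattening step that concatenates all parameters into one flat buffer (this is not FairScale's actual FlattenParamsWrapper implementation; the function and attribute handling here are illustrative):

```python
import torch
import torch.nn as nn

def flatten_params_with_views(module: nn.Module) -> None:
    # Hypothetical sketch: collapse all parameters into one flat parameter
    # and keep per-parameter views so other code (e.g. the checkpointing
    # logic) can still inspect them, including their requires_grad flags.
    params = list(module.parameters())
    flat = nn.Parameter(torch.cat([p.detach().reshape(-1) for p in params]))
    views = []
    offset = 0
    for p in params:
        numel = p.numel()
        views.append(flat[offset:offset + numel].view(p.shape))
        offset += numel
    module.flat_param = flat
    # Store the views under the attribute name this commit introduces.
    module._unflattened_param_views = views
```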
Co-authored-by: Alex Xiao <axiao@fb.com>