Alex Xiao authored
Before this commit, the output tensors of checkpointed modules always required grad, even when they should not. With this commit, the outputs of a checkpointed module require grad only if at least one of its inputs or at least one of its parameters requires grad. To support this, the commit also adds a new _unflattened_param_views attribute to modules whose parameters are being flattened, so that checkpointing can still access the parameters and check whether gradients need to be computed.

Co-authored-by: Alex Xiao <axiao@fb.com>
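The rule described above can be sketched as a small decision function. This is an illustrative sketch, not the code from this commit: the helper name outputs_require_grad is hypothetical, and it assumes _unflattened_param_views (when present) is an iterable of tensor-like objects exposing .requires_grad, falling back to a parameters() method otherwise. It uses duck typing so it runs without PyTorch; real tensors and nn.Parameter objects expose the same requires_grad attribute.

```python
from types import SimpleNamespace
from typing import Any, Iterable


def outputs_require_grad(module: Any, inputs: Iterable[Any]) -> bool:
    """Hypothetical helper: should a checkpointed module's outputs require grad?

    Assumption: `module._unflattened_param_views`, if set, is an iterable of
    tensor-like views of the flattened parameters. Otherwise we fall back to
    `module.parameters()`, as on an ordinary nn.Module.
    """
    params = getattr(module, "_unflattened_param_views", None)
    if params is None:
        params = module.parameters()
    # Non-tensor inputs (ints, None, ...) have no requires_grad and count as False.
    inputs_need_grad = any(getattr(x, "requires_grad", False) for x in inputs)
    params_need_grad = any(getattr(p, "requires_grad", False) for p in params)
    return inputs_need_grad or params_need_grad


if __name__ == "__main__":
    # A flattened module with frozen parameters, called on an input that
    # does not require grad: the outputs should not require grad either.
    frozen = SimpleNamespace(
        _unflattened_param_views=[SimpleNamespace(requires_grad=False)]
    )
    x = SimpleNamespace(requires_grad=False)
    print(outputs_require_grad(frozen, [x, 3]))  # False
```

Checking the unflattened views rather than the flat parameter is the point of the new attribute: the views let the check see the original per-parameter requires_grad flags.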
482944d9