    [feat] set requires_grad of output tensors of checkpointed modules properly (#787) · 482944d9
    Alex Xiao authored
    
    
    Before this commit, output tensors of checkpointed modules always
    required grad, even when they should not have. This commit makes
    the outputs of a checkpointed module require grad only if at least
    one of its inputs or at least one of its parameters requires grad.
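    The rule above can be sketched in PyTorch as follows. This is a minimal
    sketch, assuming PyTorch 1.11 or newer (for the `use_reentrant` argument
    of `torch.utils.checkpoint.checkpoint`). `checkpointed_forward` is a
    hypothetical helper for illustration, not the code changed in this
    commit. When neither an input nor a parameter requires grad, it skips
    checkpointing and runs a plain forward, so the outputs come back with
    `requires_grad=False`.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


def checkpointed_forward(module: nn.Module, *inputs):
    """Hypothetical helper: checkpoint `module` only when a gradient is needed."""
    inputs_need_grad = any(
        isinstance(x, torch.Tensor) and x.requires_grad for x in inputs
    )
    params_need_grad = any(p.requires_grad for p in module.parameters())

    if not torch.is_grad_enabled() or not (inputs_need_grad or params_need_grad):
        # No input and no parameter requires grad: a plain forward pass
        # produces outputs with requires_grad=False.
        return module(*inputs)

    # Non-reentrant checkpointing runs the forward with autograd enabled,
    # so the outputs require grad whenever the inputs or the parameters do.
    return checkpoint(module, *inputs, use_reentrant=False)
```

    With an `nn.Linear` whose weights are frozen and an input that does not
    require grad, the output does not require grad; unfreezing the weights
    or passing an input with `requires_grad=True` makes it require grad again.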
    
    To achieve this, the commit also adds a new _unflattened_param_views
    attribute to modules whose parameters are flattened. This lets the
    checkpointing code still reach the module's parameters and check
    whether gradients need to be computed.
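    One way the checkpointing side could use that attribute is sketched
    below. It assumes PyTorch and assumes `_unflattened_param_views` holds an
    iterable of tensors (views into the flattened parameter); the commit
    message does not specify its type. `params_to_check` and
    `any_param_requires_grad` are hypothetical names used only for
    illustration.

```python
import torch
import torch.nn as nn


def params_to_check(module: nn.Module):
    """Hypothetical helper: tensors whose requires_grad decides if outputs need grad.

    Assumption: after flattening, the original parameters are reachable only
    through views stored in `_unflattened_param_views`.
    """
    views = getattr(module, "_unflattened_param_views", None)
    if views is not None:
        return list(views)
    # Unflattened module: its registered parameters are still available.
    return list(module.parameters())


def any_param_requires_grad(module: nn.Module) -> bool:
    """True if any (possibly flattened) parameter of `module` requires grad."""
    return any(p.requires_grad for p in params_to_check(module))
```

    The fallback matters because a module whose parameters have been
    flattened away may report no parameters from `parameters()`, in which
    case a check based only on `parameters()` would wrongly conclude that no
    gradient is needed.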
    Co-authored-by: Alex Xiao <axiao@fb.com>