Make AMP compatible with FSDP
Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/458

Make AMP compatible with FSDP. FSDP does not depend on the torch AMP module and implements its own MixedPrecision module. This MixedPrecision module directly keeps an additional copy of the weights in lower precision and runs these tensors during mixed precision training. This is very different from AMP, which automatically casts tensors to lower precision at the level of individual tensor operations. This diff fixes some compatibility bugs between AMP and FSDP with two changes:

1. Use "never_wrap_policy" as the default dummy autowrap policy. FSDP Mixed Precision does not work with BatchNorm layers, because FSDP and other resources such as NVIDIA Apex strongly discourage running BatchNorm in lower precision: https://github.com/pytorch/pytorch/issues/75478. We need to pass some autowrap policy so that FSDP skips BatchNorm layers when constructing mixed precision. (See the first sketch after this list.)

2. Wrap FSDPWrapper.forward() with autocast(). FSDP Mixed Precision uses lower-precision tensors in computation, which can raise type-mismatch errors when amp.autocast() is not enabled, e.g. during eval. Thus, we wrap FSDP forward() with autocast(). (See the second sketch after this list.)

Reviewed By: wat3rBro

Differential Revision: D41328834

fbshipit-source-id: 18cf94c4ad8d9422ffd3bb335873cd29ac987ae9
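
For reference, here is a minimal sketch of change 1: a dummy "never_wrap_policy" passed to FSDP together with MixedPrecision. The policy name comes from this diff, but the exact signature, the fp16 dtypes, and the toy model below are illustrative assumptions, not the actual d2go implementation.

```python
import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import MixedPrecision


def never_wrap_policy(*args, **kwargs) -> bool:
    # Dummy autowrap policy: never wrap any submodule on its own.
    # Passing *some* policy is what lets FSDP keep BatchNorm layers
    # out of mixed precision when constructing the wrapped model.
    return False


# Illustrative model and wrapping; assumes torch.distributed is already
# initialized (e.g. via torch.distributed.init_process_group).
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())

fp16 = MixedPrecision(
    param_dtype=torch.float16,
    reduce_dtype=torch.float16,
    buffer_dtype=torch.float16,
)

fsdp_model = FSDP(
    model.cuda(),
    auto_wrap_policy=never_wrap_policy,
    mixed_precision=fp16,
)
```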
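
And a hedged sketch of change 2, wrapping forward() in autocast(). The FSDPWrapper name is from this diff; the `autocast_dtype` constructor argument and the dtype handling are illustrative assumptions rather than the actual d2go code.

```python
from contextlib import nullcontext

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


class FSDPWrapper(FSDP):
    """FSDP subclass whose forward() always runs under autocast."""

    def __init__(self, *args, autocast_dtype=torch.float16, **kwargs):
        # `autocast_dtype` is an illustrative knob; pass None to disable.
        super().__init__(*args, **kwargs)
        self.autocast_dtype = autocast_dtype

    def forward(self, *args, **kwargs):
        # Enable autocast even when the training loop does not (e.g. eval),
        # so activations computed from the lower-precision parameters kept
        # by FSDP MixedPrecision do not hit dtype-mismatch errors.
        ctx = (
            torch.cuda.amp.autocast(dtype=self.autocast_dtype)
            if self.autocast_dtype is not None
            else nullcontext()
        )
        with ctx:
            return super().forward(*args, **kwargs)
```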