Propagate include_frozen/buffers to EMAState in FSDP FULL_STATE_DICT checkpoints
Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/620

EMA can be configured to exclude frozen (`requires_grad=False`) parameters and buffers, reducing memory use and checkpoint size. However, `FULL_STATE_DICT` FSDP + EMA checkpointing constructs an inner `EMAState` after unsharding the FSDP parameters, and this inner `EMAState` is built with the default `include_frozen` and `include_buffers` settings. As a result, checkpoints contain frozen parameters and buffers regardless of how EMA is configured.

Fix: propagate the `include_frozen` and `include_buffers` settings to the inner `EMAState` when gathering `FULL_STATE_DICT` FSDP EMA state.

Note that for frozen parameters this change only takes effect together with a parallel fix in PyTorch FSDP that propagates `requires_grad` across parameter sharding/unsharding: https://github.com/pytorch/pytorch/pull/109892.

Reviewed By: daveboat

Differential Revision: D49517178

fbshipit-source-id: 0fe159dcec9ec1f2c456ae2ee7798681e7536249
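To illustrate the failure mode and the fix, here is a minimal, self-contained sketch. The `EMAState` class, its `save_from` method, and the `gather_full_ema_state` helper below are simplified, hypothetical stand-ins for the d2go and FSDP internals, not the actual implementation; the real code unshards parameters via FSDP before building the inner state.

```python
import torch
import torch.nn as nn


class EMAState:
    """Simplified stand-in for d2go's EMAState (hypothetical API, for illustration only)."""

    def __init__(self, include_frozen: bool = True, include_buffers: bool = True):
        self.include_frozen = include_frozen
        self.include_buffers = include_buffers
        self.state = {}

    def save_from(self, model: nn.Module) -> None:
        # Collect parameters, optionally skipping frozen (requires_grad=False) ones.
        for name, param in model.named_parameters():
            if not param.requires_grad and not self.include_frozen:
                continue
            self.state[name] = param.detach().clone()
        # Collect buffers only when configured to do so.
        if self.include_buffers:
            for name, buf in model.named_buffers():
                self.state[name] = buf.detach().clone()


def gather_full_ema_state(outer: EMAState, unsharded_model: nn.Module) -> EMAState:
    # Bug: building the inner state with defaults stores everything:
    #   inner = EMAState()
    # Fix: propagate the outer settings so frozen params/buffers stay excluded.
    inner = EMAState(
        include_frozen=outer.include_frozen,
        include_buffers=outer.include_buffers,
    )
    inner.save_from(unsharded_model)
    return inner


if __name__ == "__main__":
    model = nn.Sequential(nn.Linear(4, 4), nn.BatchNorm1d(4))
    model[0].weight.requires_grad_(False)  # freeze one parameter

    outer = EMAState(include_frozen=False, include_buffers=False)
    inner = gather_full_ema_state(outer, model)

    # The frozen weight and the BatchNorm buffers are absent from the gathered state.
    assert "0.weight" not in inner.state
    assert not any("running" in k for k in inner.state)
    print(sorted(inner.state))  # ['0.bias', '1.bias', '1.weight']
```

With the default-constructed inner state, the frozen `0.weight` and all BatchNorm buffers would reappear in the checkpoint even though the outer EMA configuration excludes them; propagating the two flags keeps the gathered state consistent with that configuration.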