"...text-generation-inference.git" did not exist on "5d27f5259b710c5d77925930bd804edcaa3821a6"
Make EMA checkpointing with FSDP more robust
Summary: Pull Request resolved: https://github.com/facebookresearch/d2go/pull/615 Previous FSDP EMA checkpointing logic directly handles `EMAState`: it manually calls `FSDP.summon_full_params()` to gather the full model params, and reconstruct/load an `EMAState` for checkpointing. This logic has two drawbacks: 1. `FSDP.summon_full_params()` gathers all model weights at the same time, which could cause OOM issues if the model can't fit into a single GPU. This is quite common for FSDP workloads. 2. Directly saving and loading `EMAState` is error-prone. EMA state dict has different semantics and behaviors than `model.state_dict()`. However, users often expect it to function seamlessly like the model state dict This diff modifies the save/load logic of EMA to directly use `model.state_dict()` to solve the above 2 painpoints Reviewed By: wat3rBro Differential Revision: D48813697 fbshipit-source-id: be53c2677d2e493ba923508bbd82d9d295397941
Showing
Please register or sign in to comment