Convert local checkpoint to global one automatically in d2go FSDP checkpointer
Summary: Pull Request resolved: https://github.com/facebookresearch/d2go/pull/446 ## Design Following D41861308, local checkpoints need to be converted to global ones before being loaded and used in non-FSDP wrapped models. This diff implements such conversion in d2go checkpointer level to allow automatic conversion with minimal user interference and no new config key. In previous diff, `FSDPWrapper` has 2 loading modes and 2 saving modes: it uses `load_local_state_dict` to determine whether the ckpt we want to load is local or global, and uses `use_local_state_dict` to decide whether to save new ckpts as local or global. Thus, there are 4 combinations of loading/saving modes: 1. load local + save local 2. load local + save global 3. load global + save local 4. load global + save global And the local-to-global checkpoint conversion maps to mode 2: load local + save global. Thus, when the checkpointer is in mode 2, it automatically saves the model to a global ckpt right after it loads the local ckpt. Because this happens in checkpointer level, normal training/eval can resume after ckpt conversion. This gives users a consistent and seamless experience with normal training/eval, while also providing a separate ckpt conversion feature via eval-only. ## Usage Suppose we want to convert local checkpoint `/tmp/model_final`, user can run the same training command with extra args: `MODEL.WEIGHTS=/tmp/model_final` and `FSDP.USE_LOCAL_STATE_DICT=False` Wiki: https://www.internalfb.com/intern/wiki/Mobile_Vision/Detectron2Go/D2 (https://github.com/facebookresearch/d2go/commit/87374efb134e539090e0b5c476809dc35bf6aedb)Go_Tutorials/Diffusion_Pipeline/Diffusion_Model_Inference/#using-checkpoints-traine Reviewed By: wat3rBro Differential Revision: D41926662 fbshipit-source-id: 18a62607a79b0e917d929e9ea85ac1658fb895ca
Showing
Please register or sign in to comment