add correctness test for ai infra checkpointer with FSDP local mode in d2go runner
Summary:
X-link: https://github.com/facebookresearch/mobile-vision/pull/141

Pull Request resolved: https://github.com/facebookresearch/d2go/pull/484

This diff adds a correctness test for the AIInfra checkpointer with FSDP local mode within a d2go runner. It verifies that the AIInfra checkpointer saves exactly the same model as the old checkpointer, and that we can convert seamlessly between AIInfra checkpoints (local) and FSDP checkpoints (local + global).

Note: adapted from mattcyu1's script D43492498.

## Testing
Testing is done by saving with both the AIInfra and FSDP checkpointers and comparing the state dicts they produce. Here are the steps:
1. Build the model. Save a local ckpt using the FSDP checkpointer and another local ckpt using the AIInfra checkpointer.
2. Reset the model. Load the local ckpt using the FSDP checkpointer and convert it to a global ckpt.
3. Reset the model. Load the local ckpt using the AIInfra checkpointer and re-save it as a global ckpt using the FSDP checkpointer.
4. Compare the two global state dicts.

## Others
1. Add a launch decorator for the d2go.distributed worker using the one in `fbcode/mobile-vision/mobile_cv/mobile_cv/torch/utils_pytorch/distributed_helper.py`.
2. Remove `ema_state.load_state_dict()` in loading. This is needed because the AIInfra checkpointer loads the state dict in place before `ema_state.load_state_dict()` is called. Since the loading is in place, both `ema_state` and `state_dict['ema_state']` point to the same tensors. Calling `ema_state.load_state_dict()` clears `ema_state` first, which effectively frees those tensors and leaves an empty dict. Solution: don't call `ema_state.load_state_dict()`, because the state is already loaded. More info: https://www.internalfb.com/intern/wiki/Components_in_AI/Checkpoint/Getting_Started/Input_Output_Contract/#load

Reviewed By: xunnanxu

Differential Revision: D43423572

fbshipit-source-id: 8c4a47917670ea1205f952540d1e4cb9fc9232c0
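The aliasing problem behind the removed `ema_state.load_state_dict()` call can be reproduced with plain Python dicts. The sketch below is only an illustration under stated assumptions: `ToyEMAState` and `load_in_place` are hypothetical stand-ins, not the d2go EMA class or the AIInfra checkpointer API, and the toy assumes a clear-then-copy `load_state_dict()` as described in this summary.

```python
class ToyEMAState:
    """Hypothetical stand-in for an EMA state holder; not the d2go class."""

    def __init__(self):
        self.state = {}

    def state_dict(self):
        # Returns the live dict, so an in-place loader writes straight into it.
        return self.state

    def load_state_dict(self, state_dict):
        # Clear-then-copy: only safe when state_dict is a separate object.
        self.state.clear()
        for key, value in state_dict.items():
            self.state[key] = value


def load_in_place(target, saved):
    """Hypothetical in-place loader: fills `target` without replacing it."""
    for key, value in saved.items():
        target[key] = value


saved = {"weight": [0.1, 0.2]}

# Case 1: in-place load followed by the redundant load_state_dict() call.
ema = ToyEMAState()
checkpoint = {"ema_state": ema.state_dict()}
load_in_place(checkpoint["ema_state"], saved)
aliased = checkpoint["ema_state"] is ema.state  # same object on both sides
ema.load_state_dict(checkpoint["ema_state"])    # clears its own source
after_redundant_load = dict(ema.state)

# Case 2: in-place load only, which is the behaviour after this change.
ema_fixed = ToyEMAState()
checkpoint_fixed = {"ema_state": ema_fixed.state_dict()}
load_in_place(checkpoint_fixed["ema_state"], saved)
after_in_place_only = dict(ema_fixed.state)
```

In the first case the clear step empties the very dict that is about to be copied from, so nothing is left to copy; in the second case the in-place load alone leaves the state populated.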