Commit fd0cbb8f authored by Anthony Chen's avatar Anthony Chen Committed by Facebook GitHub Bot

add correctness test for ai infra checkpointer with FSDP local mode in d2go runner

Summary:
X-link: https://github.com/facebookresearch/mobile-vision/pull/141

Pull Request resolved: https://github.com/facebookresearch/d2go/pull/484

This diff adds a correctness test for ai infra checkpointer + FSDP local mode within a d2go runner. It verifies that ai infra checkpointer saves the exact same model as the old checkpointer, and that we can convert between ai infra checkpoints (local) and fsdp checkpoints (local + global) seamlessly.

Note: adapted from mattcyu1's script D43492498.

## Testing

Testing is done by saving with both the AI Infra and FSDP checkpointers and comparing the state dicts produced. Here are the steps:

1. Build the model. Save a local ckpt using the FSDP checkpointer and another local ckpt using the AI Infra checkpointer.
2. Reset the model. Load the local ckpt using the FSDP checkpointer and convert it to a global ckpt.
3. Reset the model. Load the local ckpt using the AI Infra checkpointer and re-save it as a global ckpt using the FSDP checkpointer.
4. Compare the two global state dicts.
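The comparison in step 4 can be sketched as follows. This is a simplified, hypothetical stand-in using plain Python values; the actual test would compare torch tensors (e.g. with `torch.equal`) and the helper name `compare_state_dicts` is illustrative, not part of the diff:

```python
def compare_state_dicts(sd_a: dict, sd_b: dict) -> bool:
    """Return True iff two flat state dicts have identical keys and values.

    Simplified sketch of step 4: the real test compares tensors produced by
    the two checkpointers instead of plain Python values.
    """
    if sd_a.keys() != sd_b.keys():
        return False
    return all(sd_a[k] == sd_b[k] for k in sd_a)


# Toy usage: two "global" state dicts produced by the two load paths.
sd_from_fsdp = {"layer.weight": [1.0, 2.0], "layer.bias": [0.5]}
sd_from_aiinfra = {"layer.weight": [1.0, 2.0], "layer.bias": [0.5]}
assert compare_state_dicts(sd_from_fsdp, sd_from_aiinfra)
```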

## Others

1. Add a launch decorator for the d2go.distributed worker, using the one in `fbcode/mobile-vision/mobile_cv/mobile_cv/torch/utils_pytorch/distributed_helper.py`.

2. Remove `ema_state.load_state_dict()` from the loading path. This is needed because the AI Infra checkpointer loads the state dict in place before `ema_state.load_state_dict()` is called. Since the load is in-place, both `ema_state` and `state_dict['ema_state']` point to the same tensors. Calling `ema_state.load_state_dict()` clears `ema_state`, effectively freeing those tensors and causing it to return an empty dict.
Solution: don't call `ema_state.load_state_dict()`, because the state is already loaded. More info: https://www.internalfb.com/intern/wiki/Components_in_AI/Checkpoint/Getting_Started/Input_Output_Contract/#load
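
The aliasing problem can be illustrated with a minimal pure-Python sketch (no torch, and the dict contents are made up for illustration): after an in-place load, `ema_state` and `state_dict["ema_state"]` are the same object, so a `load_state_dict()` that clears its target also empties the checkpoint entry.

```python
# In-place load: the checkpoint's "ema_state" entry is not a copy, it is a
# reference to the live ema_state object.
ema_state = {"decay": 0.999, "shadow": [1.0, 2.0]}
state_dict = {"ema_state": ema_state}

# Both names refer to the same dict.
assert state_dict["ema_state"] is ema_state

# A load_state_dict() implementation that first clears its target would
# therefore wipe the shared object:
ema_state.clear()
assert state_dict["ema_state"] == {}  # the checkpoint entry is now empty too
```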

Reviewed By: xunnanxu

Differential Revision: D43423572

fbshipit-source-id: 8c4a47917670ea1205f952540d1e4cb9fc9232c0
parent b0ef9f39
```diff
@@ -27,6 +27,7 @@ from mobile_cv.torch.utils_pytorch.distributed_helper import (
     DistributedParams,
     enable_dist_process_groups,
     launch as _launch,
+    launch_deco as _launch_deco,
     save_return_deco,
 )
@@ -79,6 +80,13 @@ def distributed_worker(
     return deco(main_func)(*args, **kwargs)

+def launch_deco(**kwargs):
+    """
+    launch_deco for d2go distributed worker
+    """
+    return _launch_deco(launcher=launch, **kwargs)

 def launch(
     main_func: Callable[..., _RT],
     num_processes_per_machine: int,
 ...
```