"src/vscode:/vscode.git/clone" did not exist on "a7e941c379f0f9ab472c844b6a9f1d05d687b4e1"
Commit 34a5a3e8 authored by Matthew Yu's avatar Matthew Yu Committed by Facebook GitHub Bot
Browse files

Add AI infra checkpointer support for d2go

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/479

Pull Request resolved: https://github.com/facebookresearch/d2go/pull/467

Pull Request resolved: https://github.com/facebookresearch/d2go/pull/466

This allows an internal checkpointing solution to be plugged in generically: the runner now asks the checkpointer whether it is distributed, rather than inferring that from the training pattern (FSDP or not).

Reviewed By: wat3rBro

Differential Revision: D42983444

fbshipit-source-id: a70bf0d25737d9cbbf22e3368363d3fdec57b8b5
parent 5d6bd7c2
from .api import is_distributed_checkpoint
from .fsdp_checkpoint import FSDPCheckpointer
__all__ = ["FSDPCheckpointer"]
__all__ = [
"is_distributed_checkpoint",
"FSDPCheckpointer",
]
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
from fvcore.common.checkpoint import Checkpointer
def is_distributed_checkpoint(checkpointer: Checkpointer) -> bool:
    """
    Report whether ``checkpointer`` performs distributed checkpointing.

    A distributed checkpointer needs all of its operations invoked on
    every rank. A checkpointer opts in by exposing an ``is_distributed()``
    method; one without that attribute is treated as non-distributed.

    Args:
        checkpointer: the checkpointer instance to inspect.

    Returns:
        The result of ``checkpointer.is_distributed()`` when the attribute
        exists, otherwise ``False``.
    """
    # Guard clause: no opt-in hook means a regular (non-distributed) checkpointer.
    if not hasattr(checkpointer, "is_distributed"):
        return False
    return checkpointer.is_distributed()
......@@ -21,6 +21,9 @@ class FSDPCheckpointer(QATCheckpointer):
Extend the Checkpointer to support saving/loading FSDP models
"""
def is_distributed(self) -> bool:
    """
    Mark this checkpointer as distributed.

    Always returns ``True``: FSDP requires all ranks to execute the
    saving/loading logic.
    """
    return True
def load(self, path: str, checkpointables=None):
"""
Add support for loading sharded optimizer states in FSDP.
......
......@@ -11,7 +11,7 @@ from typing import List, Optional, Type, Union
import d2go.utils.abnormal_checker as abnormal_checker
import detectron2.utils.comm as comm
import torch
from d2go.checkpoint import FSDPCheckpointer
from d2go.checkpoint import FSDPCheckpointer, is_distributed_checkpoint
from d2go.config import CfgNode, CONFIG_SCALING_METHOD_REGISTRY, temp_defrost
from d2go.config.utils import get_cfg_diff_table
from d2go.data.build import build_d2go_train_loader
......@@ -647,8 +647,8 @@ class Detectron2GoRunner(D2GoDataAPIMixIn, BaseRunner):
# Note: when precise BN is enabled, some checkpoints will have more precise
# statistics than others, if they are saved immediately after eval.
# Note: FSDP requires all ranks to execute saving/loading logic
if comm.is_main_process() or isinstance(
periodic_checkpointer.checkpointer, FSDPCheckpointer
if comm.is_main_process() or is_distributed_checkpoint(
periodic_checkpointer.checkpointer
):
periodic_checkpointer.step(trainer.iter)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment