Commit 34a5a3e8 authored by Matthew Yu's avatar Matthew Yu Committed by Facebook GitHub Bot
Browse files

add ai infra checkpointer support for d2go

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/479

Pull Request resolved: https://github.com/facebookresearch/d2go/pull/467

Pull Request resolved: https://github.com/facebookresearch/d2go/pull/466

This allows an internal checkpointing solution to be plugged in, in a generic fashion, rather than relying on training patterns (FSDP or not).

Reviewed By: wat3rBro

Differential Revision: D42983444

fbshipit-source-id: a70bf0d25737d9cbbf22e3368363d3fdec57b8b5
parent 5d6bd7c2
from .api import is_distributed_checkpoint
from .fsdp_checkpoint import FSDPCheckpointer
__all__ = [
"is_distributed_checkpoint",
"FSDPCheckpointer",
]
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
from fvcore.common.checkpoint import Checkpointer
def is_distributed_checkpoint(checkpointer: Checkpointer) -> bool:
    """
    Tell whether ``checkpointer`` implements distributed checkpointing.

    A distributed checkpointer requires every rank to participate in
    save/load, so callers must invoke checkpoint ops on all ranks
    whenever this returns True.
    """
    # EAFP probe for the optional ``is_distributed`` hook; plain
    # fvcore Checkpointers don't define it and are treated as
    # non-distributed.
    try:
        probe = checkpointer.is_distributed
    except AttributeError:
        return False
    return probe()
...@@ -21,6 +21,9 @@ class FSDPCheckpointer(QATCheckpointer):
    Extend the Checkpointer to support saving/loading FSDP models
    """
def is_distributed(self) -> bool:
    """
    Mark this checkpointer as distributed: FSDP save/load must be
    executed on every rank, not only the main process, so callers
    (via ``is_distributed_checkpoint``) invoke checkpoint ops on all
    ranks.
    """
    return True
def load(self, path: str, checkpointables=None):
    """
    Add support for loading sharded optimizer states in FSDP.
......
...@@ -11,7 +11,7 @@ from typing import List, Optional, Type, Union
import d2go.utils.abnormal_checker as abnormal_checker
import detectron2.utils.comm as comm
import torch
from d2go.checkpoint import FSDPCheckpointer, is_distributed_checkpoint
from d2go.config import CfgNode, CONFIG_SCALING_METHOD_REGISTRY, temp_defrost
from d2go.config.utils import get_cfg_diff_table
from d2go.data.build import build_d2go_train_loader
...@@ -647,8 +647,8 @@ class Detectron2GoRunner(D2GoDataAPIMixIn, BaseRunner):
# Note: when precise BN is enabled, some checkpoints will have more precise
# statistics than others, if they are saved immediately after eval.
# Note: FSDP requires all ranks to execute saving/loading logic
if comm.is_main_process() or is_distributed_checkpoint(
    periodic_checkpointer.checkpointer
):
    periodic_checkpointer.step(trainer.iter)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment