Commit 859f0bb9 authored by John Lee's avatar John Lee Committed by Facebook GitHub Bot

Instrument checkpoints for FSDPCheckpointer

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/536

This diff instruments checkpoints with signpost for FSDPCheckpointer, using D44278485 as a reference.

Reviewed By: miqueljubert

Differential Revision: D45524792

fbshipit-source-id: 9b7e004e6853141ee26d65ae11f79b1f5f5db0e6
parent 5ecbb174
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import logging
import uuid
from contextlib import ContextDecorator

from d2go.checkpoint.log_checkpoint import log_checkpoint

logger = logging.getLogger(__name__)


class instrument_checkpoint(ContextDecorator):
    def __init__(
        self,
        checkpoint_type: str,
    ) -> None:
        super().__init__()
        # 31-bit id derived from a time-based UUID, used to correlate the
        # "begin" and "end" events of a single checkpoint operation.
        self.unique_id = uuid.uuid1().int >> 97
        self.checkpoint_type = checkpoint_type

    def __enter__(self) -> "instrument_checkpoint":
        log_checkpoint(
            checkpoint_type=self.checkpoint_type,
            unique_id=self.unique_id,
            state="begin",
        )
        return self

    def __exit__(self, exc_type, exc_value, tb) -> bool:
        log_checkpoint(
            checkpoint_type=self.checkpoint_type,
            unique_id=self.unique_id,
            state="end",
        )
        # Returning True suppresses any exception raised inside the
        # instrumented block.
        return True
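For reference, a minimal usage sketch of the decorator above. Because it subclasses ContextDecorator, it can wrap a callable or be used as a with-block; the function name sharded_save below is hypothetical and only illustrates the call pattern.

# Minimal sketch, assuming the module above is importable as
# d2go.checkpoint.checkpoint_instrumentation. `sharded_save` is a
# hypothetical placeholder for real checkpoint-writing code.
from d2go.checkpoint.checkpoint_instrumentation import instrument_checkpoint

@instrument_checkpoint("save")
def sharded_save() -> None:
    # ... real code would write model/optimizer state here ...
    pass

# Equivalent context-manager form:
with instrument_checkpoint("load"):
    # ... real code would read a checkpoint here ...
    pass

# Either form emits a "begin" log event before the body runs and an
# "end" log event afterwards, sharing the same unique_id.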
@@ -5,6 +5,8 @@ from typing import Callable, cast, IO
 import detectron2.utils.comm as comm
 import torch
+from d2go.checkpoint.checkpoint_instrumentation import instrument_checkpoint
 from d2go.checkpoint.utils import (
     gather_ema_state_dict,
     gather_optimizer_state_dict,
@@ -45,6 +47,7 @@ class FSDPCheckpointer(QATCheckpointer):
     def is_distributed(self) -> bool:
         return True

+    @instrument_checkpoint("load")
     def load(self, path: str, checkpointables=None):
         """
         Add support for loading sharded optimizer states in FSDP.
@@ -130,6 +133,7 @@ class FSDPCheckpointer(QATCheckpointer):
         _log_api_usage_on_main_process(f"{LOG_API_IDENTIFIER}.load.ddp")
         return super().load(path, checkpointables=checkpointables)

+    @instrument_checkpoint("save")
     def save(self, name: str, tag_last_ckpt=True, **kwargs) -> None:
         """
         Add support for saving sharding models and optimizers.
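A rough sketch of the resulting behavior, assuming an already-constructed FSDPCheckpointer bound to a variable named checkpointer and default logging configuration; the checkpoint names and path below are placeholders. The exact log format comes from log_checkpoint, shown next.

# Illustrative only: `checkpointer` stands for an existing FSDPCheckpointer.
checkpointer.save("model_final")
# -> Checkpoint:<unique_id> save begin
# -> Checkpoint:<unique_id> save end

checkpointer.load("path/to/model_final.pth")
# -> Checkpoint:<unique_id> load begin
# -> Checkpoint:<unique_id> load end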
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import logging

from mobile_cv.common.misc.oss_utils import fb_overwritable

logger = logging.getLogger(__name__)


@fb_overwritable()
def log_checkpoint(checkpoint_type: str, unique_id: int, state: str) -> None:
    # OSS default: log to the standard logger. The @fb_overwritable
    # decorator allows an internal (e.g. signpost-based) implementation
    # to replace this function.
    logger.info(f"Checkpoint:{unique_id} {checkpoint_type} {state}")