Commit a536c85b authored by Anthony Chen, committed by Facebook GitHub Bot

Add logging for checkpointer type, distributed mode, and checkpointing mode in d2go

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/534

Currently, d2go supports 2 checkpointers, 2 distributed modes, and 3 checkpointing modes. So many options make it hard to maintain and manage all use cases; for example, after the recent migration to FSDP sharded_state_dict, it is hard to understand and track down usage of the deprecated version.

Per crassirostris's and wat3rBro's advice, this diff adds API logging to better keep track of checkpointer usage in d2go.

## Appendix
2 checkpointers: FSDPCheckpointer, AIInfraCheckpointer
2 distributed modes: ddp, fsdp
3 checkpointing modes (fsdp only): local_state_dict, sharded_state_dict, full_state_dict
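For reference, the combinations above translate into the identifier strings sketched below once this diff's logging is in place. This is an illustrative sketch, not code from the diff: the list name is made up, the strings are derived from `LOG_API_IDENTIFIER = "checkpointing.FSDPCheckpointer"` plus the call sites shown further down, and the `d2go.` prefix is added by `_log_api_usage`. Call sites for AIInfraCheckpointer are not part of this diff, so they are omitted.

```python
# Sketch only: identifiers emitted by the new FSDPCheckpointer logging, based on
# LOG_API_IDENTIFIER = "checkpointing.FSDPCheckpointer" and the call sites in the
# diff below. _log_api_usage prepends "d2go." before calling
# torch._C._log_api_usage_once.
FSDP_CHECKPOINTER_LOG_KEYS = [
    # load()
    "d2go.checkpointing.FSDPCheckpointer.load.ddp",
    "d2go.checkpointing.FSDPCheckpointer.load.fsdp.LOCAL_STATE_DICT",
    "d2go.checkpointing.FSDPCheckpointer.load.fsdp.SHARDED_STATE_DICT",
    "d2go.checkpointing.FSDPCheckpointer.load.fsdp.FULL_STATE_DICT",
    # save()
    "d2go.checkpointing.FSDPCheckpointer.save.ddp",
    "d2go.checkpointing.FSDPCheckpointer.save.fsdp.LOCAL_STATE_DICT",
    "d2go.checkpointing.FSDPCheckpointer.save.fsdp.SHARDED_STATE_DICT",
    "d2go.checkpointing.FSDPCheckpointer.save.fsdp.FULL_STATE_DICT",
]
```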

Reviewed By: tglik

Differential Revision: D45385021

fbshipit-source-id: 5d2cb115ed0fdada254b819793e376e410ecd97d
parent c7bd7dfe
```diff
@@ -13,12 +13,16 @@ from d2go.checkpoint.utils import (
 )
 from d2go.quantization.modeling import QATCheckpointer
 from d2go.trainer.fsdp import FSDPWrapper
+from d2go.utils.misc import _log_api_usage_on_main_process
 from mobile_cv.torch.utils_pytorch.distributed_helper import interleave_by_rank
 from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType
 
+LOG_API_IDENTIFIER = "checkpointing.FSDPCheckpointer"
+
 
 def get_max_checkpoint_concurrency() -> int:
     return comm.get_world_size()
@@ -60,19 +64,21 @@ class FSDPCheckpointer(QATCheckpointer):
                 )
                 assert state_dict_type in ["LOCAL_STATE_DICT", "SHARDED_STATE_DICT"]
-                type_str = "local" if "LOCAL_STATE_DICT" else "sharded"
                 self.logger.info(
-                    f"[FSDPCheckpointer] Loading from {type_str} checkpoint ..."
+                    f"[FSDPCheckpointer] Loading from {state_dict_type} checkpoint ..."
                 )
                 self.model.load_state_dict_type = StateDictType[state_dict_type]
                 load_path = os.path.join(path, f"rank{comm.get_rank()}.pth")
             # loading path is a file: full global state dict is used
             else:
                 self.logger.info(
-                    "[FSDPCheckpointer] Loading from full checkpoint ..."
+                    "[FSDPCheckpointer] Loading from FULL_STATE_DICT checkpoint ..."
                 )
                 self.model.load_state_dict_type = StateDictType.FULL_STATE_DICT
+            _log_api_usage_on_main_process(
+                f"{LOG_API_IDENTIFIER}.load.fsdp.{self.model.load_state_dict_type.name}"  # pyre-ignore
+            )
 
             # Convert local ckpt to global ckpt when we load from a local ckpt but want to save to global ckpt
             convert_local_ckpt_to_global = (
                 path
@@ -121,6 +127,7 @@ class FSDPCheckpointer(QATCheckpointer):
             # return all remaining checkpoints
             return checkpoint
         else:
+            _log_api_usage_on_main_process(f"{LOG_API_IDENTIFIER}.load.ddp")
             return super().load(path, checkpointables=checkpointables)
 
     def save(self, name: str, tag_last_ckpt=True, **kwargs) -> None:
@@ -131,9 +138,15 @@ class FSDPCheckpointer(QATCheckpointer):
         # If no sharding, only the main process enters the saving codepath;
         # otherwise, all processes need to call state_dict() to enable state broadcasting among ranks
         if not isinstance(self.model, FSDPWrapper):
+            _log_api_usage_on_main_process(f"{LOG_API_IDENTIFIER}.save.ddp")
            if comm.is_main_process():
                 return super().save(name, **kwargs)
             return
+
+        _log_api_usage_on_main_process(
+            f"{LOG_API_IDENTIFIER}.save.fsdp.{self.model.state_dict_type.name}"
+        )
+
         data = {}
         # FSDP: model.state_dict() needs to be called by all ranks before saving
         data["model"] = self.model.state_dict()
...
@@ -129,6 +129,14 @@ def _log_api_usage(identifier: str):
     torch._C._log_api_usage_once("d2go." + identifier)
 
 
+def _log_api_usage_on_main_process(identifier: str):
+    """
+    Log the usage of d2go API only on the main process.
+    """
+    if comm.is_main_process():
+        _log_api_usage(identifier)
+
+
 def inplace_delegate(
     self,
     api_name: str,
...
```
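For context, here is a self-contained sketch of how the new helper behaves; it is illustrative only. The `_is_main_process` stand-in below is hypothetical and approximates d2go's `comm.is_main_process()`, which the diff relies on but does not show.

```python
# Standalone sketch of the new logging helper, for illustration only.
# In d2go, comm.is_main_process() comes from its distributed utilities; the
# rank-0 check below is a hypothetical stand-in so the snippet runs on its own.
import torch
import torch.distributed as dist


def _is_main_process() -> bool:
    # Treat the process as "main" when not distributed, or when it is rank 0.
    if dist.is_available() and dist.is_initialized():
        return dist.get_rank() == 0
    return True


def _log_api_usage(identifier: str) -> None:
    # Same behavior as d2go's helper: namespace every identifier under "d2go.".
    torch._C._log_api_usage_once("d2go." + identifier)


def _log_api_usage_on_main_process(identifier: str) -> None:
    # Log once per job (main process only) to avoid one event per rank.
    if _is_main_process():
        _log_api_usage(identifier)


if __name__ == "__main__":
    # Example call sites mirroring the diff above.
    _log_api_usage_on_main_process("checkpointing.FSDPCheckpointer.save.ddp")
    _log_api_usage_on_main_process(
        "checkpointing.FSDPCheckpointer.load.fsdp.SHARDED_STATE_DICT"
    )
```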