Commit bbc14dd7 authored by Yanghan Wang, committed by Facebook GitHub Bot

initial support for loading lightning-based checkpointer in runner-based trainer

Summary: Pull Request resolved: https://github.com/facebookresearch/d2go/pull/345

Reviewed By: xiecong

Differential Revision: D38086885

fbshipit-source-id: 808e104ee50c8870ae091533ac67b440e1bb8351
parent b04ba38b
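
In short, this change lets the runner-based (detectron2-style) trainer consume checkpoints saved by the Lightning-based trainer. A minimal usage sketch follows (not part of the diff); the model object, output directory, checkpoint path, and the QATCheckpointer import path are hypothetical placeholders that depend on the d2go version in use:

# Minimal usage sketch (not part of the diff). `model`, the output directory,
# and the ".ckpt" path are hypothetical placeholders.
checkpointer = QATCheckpointer(model, save_dir="./output")
checkpointer.load("lightning_export.ckpt")  # Lightning-produced ".ckpt" file
# Internally, _load_file() detects the ".ckpt" suffix, loads the file on CPU
# with torch.load(), rewrites the Lightning checkpoint into detectron2 layout
# via _convert_to_d2(), and then the regular _load_model() logic takes over.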
@@ -12,6 +12,7 @@ import torch
from d2go.quantization import learnable_qat
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.engine import HookBase, SimpleTrainer
from detectron2.utils.file_io import PathManager
from mobile_cv.arch.quantization.observer import update_stat as observer_update_stat
from mobile_cv.arch.utils import fuse_utils
from mobile_cv.common.misc.iter_utils import recursive_iterate
@@ -34,6 +35,8 @@ def _is_observer_key(state_dict_key):
    return any(x in state_dict_key for x in observer_keys)


# TODO: replace QATCheckpointer with central D2GoCheckpointer which supports customized
# state_dict re-mapping (which includes QAT re-mapping).
class QATCheckpointer(DetectionCheckpointer):
    """
    Extend the Checkpointer to support loading (QAT / non-QAT) weight into
@@ -44,6 +47,20 @@ class QATCheckpointer(DetectionCheckpointer):
    def _is_q_state_dict(cls, state_dict):
        return any(_is_observer_key(k) for k in state_dict)

    # HACK: temporarily put it here, move to central D2GoCheckpointer later on
    def _load_file(self, filename):
        # support loading lightning checkpoint
        if filename.endswith(".ckpt"):
            # assume the file is from Lightning; no one else seems to use the ".ckpt" extension
            with PathManager.open(filename, "rb") as f:
                data = torch.load(f, map_location=torch.device("cpu"))

            from d2go.runner.lightning_task import _convert_to_d2

            _convert_to_d2(data)
            return data
        return super()._load_file(filename)

    def _load_model(self, checkpoint):
        model_is_qat = self._is_q_state_dict(self.model.state_dict())
        checkpoint_is_qat = self._is_q_state_dict(checkpoint["model"])
@@ -96,6 +96,36 @@ def _convert_to_lightning(d2_checkpoint: Dict[str, Any]) -> None:
    d2_checkpoint["epoch"] = 0


def _convert_to_d2(lightning_checkpoint: Dict[str, Any]) -> None:
    prefix = "model"  # based on DefaultTask.model.
    old_keys = [x.lstrip("model.") for x in lightning_checkpoint[_STATE_DICT_KEY]]
    for key in old_keys:
        lightning_checkpoint[_STATE_DICT_KEY][key] = lightning_checkpoint[
            _STATE_DICT_KEY
        ][f"{prefix}.{key}"]
        del lightning_checkpoint[_STATE_DICT_KEY][f"{prefix}.{key}"]

    for old, new in zip(
        [_STATE_DICT_KEY, "global_step"], [_OLD_STATE_DICT_KEY, "iteration"]
    ):
        lightning_checkpoint[new] = lightning_checkpoint[old]
        del lightning_checkpoint[old]

    for old, new in zip(
        ["optimizer_states", "lr_schedulers"], ["optimizer", "scheduler"]
    ):
        if old not in lightning_checkpoint:
            continue
        lightning_checkpoint[new] = [lightning_checkpoint[old]]
        del lightning_checkpoint[old]

    del lightning_checkpoint["epoch"]
    del lightning_checkpoint["pytorch-lightning_version"]
    del lightning_checkpoint["callbacks"]
    del lightning_checkpoint["hparams_name"]
    del lightning_checkpoint["hyper_parameters"]


class ModelTag(str, Enum):
    DEFAULT = "default"
    EMA = "ema"
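
To make the remapping concrete, here is an illustrative before/after sketch of what _convert_to_d2 (added to d2go/runner/lightning_task.py, per the import in the hunk above) does to a Lightning checkpoint dict. The literal key names for _STATE_DICT_KEY ("state_dict") and _OLD_STATE_DICT_KEY ("model"), and all values, are assumptions based on the usual Lightning and detectron2 conventions, not part of the diff:

import torch
from d2go.runner.lightning_task import _convert_to_d2

# Placeholder values; only the dict layout matters for this sketch.
w = torch.zeros(3)
opt_states = [{"state": {}}]
sched_states = [{}]

lightning_ckpt = {
    "state_dict": {"model.backbone.weight": w},  # weights nested under "model."
    "global_step": 1000,
    "optimizer_states": opt_states,
    "lr_schedulers": sched_states,
    "epoch": 3,
    "pytorch-lightning_version": "1.6.0",
    "callbacks": {},
    "hparams_name": "cfg",
    "hyper_parameters": {},
}
_convert_to_d2(lightning_ckpt)
# The dict is rewritten in place into the detectron2 checkpoint layout:
# {
#     "model": {"backbone.weight": w},   # "model." prefix stripped from keys
#     "iteration": 1000,                 # renamed from "global_step"
#     "optimizer": [opt_states],         # wrapped in a single-element list
#     "scheduler": [sched_states],       # likewise
# }
# Lightning-only bookkeeping keys ("epoch", "pytorch-lightning_version",
# "callbacks", "hparams_name", "hyper_parameters") are dropped.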