Commit 1581776b authored by Ajinkya Deogade, committed by Facebook GitHub Bot

Quantization: create a separate buck target

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/552

This diff breaks down the TARGETS for the `quantization` directory.
Apart from creating the TARGETS, it temporarily copies the function `_convert_to_d2` from `d2go/runner/lightning_task.py` to avoid a circular dependency; the copy is reverted in D46096373.
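
For reference, a per-directory Buck target of the kind this diff introduces would look roughly like the sketch below. This is a hypothetical illustration only: the rule name, `srcs`, and dependency labels are made up, not the actual TARGETS content.

```python
# Hypothetical Buck target sketch (Starlark); names and labels are illustrative.
python_library(
    name = "qconfig",
    srcs = ["qconfig.py"],
    deps = [
        # Sibling target in d2go/quantization (hypothetical label):
        ":learnable_qat",
        # Dep providing mobile_cv's Registry (hypothetical label):
        "//mobile_cv/common/misc:registry",
    ],
)
```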

Reviewed By: tglik

Differential Revision: D45912067

fbshipit-source-id: b430b2abd129690f8c56479bb75819940fde4e3b
parent 77dfafa2
......@@ -5,7 +5,7 @@
 import copy
 import logging
 import math
-from typing import Tuple
+from typing import Any, Dict, Tuple

 import detectron2.utils.comm as comm
 import torch
......@@ -76,7 +76,9 @@ class QATCheckpointer(DetectionCheckpointer):
             # assume file is from lightning; no one else seems to use the ".ckpt" extension
             with PathManager.open(filename, "rb") as f:
                 data = self._torch_load(f)
-            from d2go.runner.lightning_task import _convert_to_d2
+            # TODO: Remove once buck targets are modularized and directly use
+            # from d2go.runner.lightning_task import _convert_to_d2
+            # from d2go.runner.lightning_task import _convert_to_d2
             _convert_to_d2(data)
             return data
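
For orientation, the load path this hunk touches would be exercised roughly as follows; the module path for `QATCheckpointer` and the pre-built `model` are assumptions, not shown in the diff:

```python
# Rough usage sketch; assumes QATCheckpointer is importable from
# d2go.quantization.modeling (an assumption) and that `model` is an
# already-constructed d2go model.
from d2go.quantization.modeling import QATCheckpointer

checkpointer = QATCheckpointer(model)
# A ".ckpt" file is treated as a Lightning checkpoint and converted
# in place via _convert_to_d2 before the usual d2 loading continues.
checkpointer.load("/path/to/lightning_checkpoint.ckpt")  # made-up path
```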
......@@ -687,3 +689,45 @@ def forward_custom_prepare_fx(root, sub_module_name, orig_ret)
         return m

     return root, new_callback
+
+
+# TODO: Remove once buck targets are modularized and directly use
+# from d2go.runner.lightning_task import _convert_to_d2
+_STATE_DICT_KEY = "state_dict"
+_OLD_STATE_DICT_KEY = "model"
+_OLD_EMA_KEY = "ema_state"
+
+
+def _convert_to_d2(lightning_checkpoint: Dict[str, Any]) -> None:
+    prefix = "model"  # based on DefaultTask.model.
+    old_keys = [x.lstrip("model.") for x in lightning_checkpoint[_STATE_DICT_KEY]]
+    for key in old_keys:
+        if f"{prefix}.{key}" in lightning_checkpoint[_STATE_DICT_KEY]:
+            lightning_checkpoint[_STATE_DICT_KEY][key] = lightning_checkpoint[
+                _STATE_DICT_KEY
+            ][f"{prefix}.{key}"]
+            del lightning_checkpoint[_STATE_DICT_KEY][f"{prefix}.{key}"]
+
+    for old, new in zip(
+        [_STATE_DICT_KEY, "global_step"], [_OLD_STATE_DICT_KEY, "iteration"]
+    ):
+        lightning_checkpoint[new] = lightning_checkpoint[old]
+        del lightning_checkpoint[old]
+
+    for old, new in zip(
+        ["optimizer_states", "lr_schedulers"], ["optimizer", "scheduler"]
+    ):
+        if old not in lightning_checkpoint:
+            continue
+        lightning_checkpoint[new] = [lightning_checkpoint[old]]
+        del lightning_checkpoint[old]
+
+    for key in [
+        "epoch",
+        "pytorch-lightning_version",
+        "callbacks",
+        "hparams_name",
+        "hyper_parameters",
+    ]:
+        if key in lightning_checkpoint:
+            del lightning_checkpoint[key]
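
To make the conversion concrete, here is a minimal sketch of what the copied `_convert_to_d2` does to a Lightning checkpoint dict, assuming the function above is in scope; the keys and tensor values are made up for illustration:

```python
import torch

# Made-up Lightning-style checkpoint using the keys handled above.
ckpt = {
    "state_dict": {"model.backbone.weight": torch.zeros(1)},
    "global_step": 1000,
    "optimizer_states": {"state": {}},
    "epoch": 3,
}
_convert_to_d2(ckpt)

assert "backbone.weight" in ckpt["model"]    # "state_dict" -> "model", prefix stripped
assert ckpt["iteration"] == 1000             # "global_step" -> "iteration"
assert ckpt["optimizer"] == [{"state": {}}]  # optimizer states wrapped in a list
assert "epoch" not in ckpt                   # Lightning-only keys dropped

# Caveat in the original code: str.lstrip strips a character *set*, not a
# prefix, so a key like "model.decoder.weight" is mangled to "coder.weight",
# fails the f"{prefix}.{key}" guard, and keeps its prefixed name.
```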
......
 from typing import Tuple

 import torch

-from d2go.quantization import learnable_qat
+from d2go.quantization.learnable_qat import convert_to_learnable_qconfig
 from mobile_cv.common.misc.registry import Registry

 TORCH_VERSION: Tuple[int, ...] = tuple(int(x) for x in torch.__version__.split(".")[:2])
......@@ -69,7 +69,7 @@ def _smart_set_backend_and_create_qconfig(cfg, *, is_train):
         backend=backend, is_qat=is_train, use_symmetric=is_symmetric
     )

     if is_train and qat_method == "learnable":
-        qconfig = learnable_qat.convert_to_learnable_qconfig(qconfig)
+        qconfig = convert_to_learnable_qconfig(qconfig)

     return qconfig
......
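
The import change above would be used roughly as follows; only `convert_to_learnable_qconfig(qconfig)` itself appears in the diff, so the starting qconfig and backend choice here are assumptions:

```python
import torch.ao.quantization as tq
from d2go.quantization.learnable_qat import convert_to_learnable_qconfig

# Start from a stock PyTorch QAT qconfig (assumed; not shown in the diff).
qconfig = tq.get_default_qat_qconfig("fbgemm")
# Convert to a learnable-QAT qconfig, mirroring the qat_method == "learnable" branch.
learnable_qconfig = convert_to_learnable_qconfig(qconfig)
```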