"git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "a23ad87d7aecf9e3db7f18287beb6116f96c2313"
Commit 1581776b authored by Ajinkya Deogade's avatar Ajinkya Deogade Committed by Facebook GitHub Bot
Browse files

Quantization: create a separate buck target

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/552

This diff breaks down the TARGETS for dir `quantization`.
Apart from creating the TARGETS the diff temporarily copies the function `_convert_to_d2` from `d2go/runner/lightning_task.py` to avoid circular dependencies. The change is reverted in the diff D46096373.

Reviewed By: tglik

Differential Revision: D45912067

fbshipit-source-id: b430b2abd129690f8c56479bb75819940fde4e3b
parent 77dfafa2
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
import copy import copy
import logging import logging
import math import math
from typing import Tuple from typing import Any, Dict, Tuple
import detectron2.utils.comm as comm import detectron2.utils.comm as comm
import torch import torch
...@@ -76,7 +76,9 @@ class QATCheckpointer(DetectionCheckpointer): ...@@ -76,7 +76,9 @@ class QATCheckpointer(DetectionCheckpointer):
# assume file is from lightning; no one else seems to use the ".ckpt" extension # assume file is from lightning; no one else seems to use the ".ckpt" extension
with PathManager.open(filename, "rb") as f: with PathManager.open(filename, "rb") as f:
data = self._torch_load(f) data = self._torch_load(f)
from d2go.runner.lightning_task import _convert_to_d2 # TODO: Remove once buck targets are modularized and directly use
# from d2go.runner.lightning_task import _convert_to_d2
# from d2go.runner.lightning_task import _convert_to_d2
_convert_to_d2(data) _convert_to_d2(data)
return data return data
...@@ -687,3 +689,45 @@ def forward_custom_prepare_fx(root, sub_module_name, orig_ret): ...@@ -687,3 +689,45 @@ def forward_custom_prepare_fx(root, sub_module_name, orig_ret):
return m return m
return root, new_callback return root, new_callback
# TODO: Remove once buck targets are modularized and directly use
# from d2go.runner.lightning_task import _convert_to_d2
_STATE_DICT_KEY = "state_dict"
_OLD_STATE_DICT_KEY = "model"
_OLD_EMA_KEY = "ema_state"
def _convert_to_d2(lightning_checkpoint: Dict[str, Any]) -> None:
prefix = "model" # based on DefaultTask.model.
old_keys = [x.lstrip("model.") for x in lightning_checkpoint[_STATE_DICT_KEY]]
for key in old_keys:
if f"{prefix}.{key}" in lightning_checkpoint[_STATE_DICT_KEY]:
lightning_checkpoint[_STATE_DICT_KEY][key] = lightning_checkpoint[
_STATE_DICT_KEY
][f"{prefix}.{key}"]
del lightning_checkpoint[_STATE_DICT_KEY][f"{prefix}.{key}"]
for old, new in zip(
[_STATE_DICT_KEY, "global_step"], [_OLD_STATE_DICT_KEY, "iteration"]
):
lightning_checkpoint[new] = lightning_checkpoint[old]
del lightning_checkpoint[old]
for old, new in zip(
["optimizer_states", "lr_schedulers"], ["optimizer", "scheduler"]
):
if old not in lightning_checkpoint:
continue
lightning_checkpoint[new] = [lightning_checkpoint[old]]
del lightning_checkpoint[old]
for key in [
"epoch",
"pytorch-lightning_versio",
"callbacks",
"hparams_name",
"hyper_parameters",
]:
if key in lightning_checkpoint:
del lightning_checkpoint[key]
from typing import Tuple from typing import Tuple
import torch import torch
from d2go.quantization import learnable_qat from d2go.quantization.learnable_qat import convert_to_learnable_qconfig
from mobile_cv.common.misc.registry import Registry from mobile_cv.common.misc.registry import Registry
TORCH_VERSION: Tuple[int, ...] = tuple(int(x) for x in torch.__version__.split(".")[:2]) TORCH_VERSION: Tuple[int, ...] = tuple(int(x) for x in torch.__version__.split(".")[:2])
...@@ -69,7 +69,7 @@ def _smart_set_backend_and_create_qconfig(cfg, *, is_train): ...@@ -69,7 +69,7 @@ def _smart_set_backend_and_create_qconfig(cfg, *, is_train):
backend=backend, is_qat=is_train, use_symmetric=is_symmetric backend=backend, is_qat=is_train, use_symmetric=is_symmetric
) )
if is_train and qat_method == "learnable": if is_train and qat_method == "learnable":
qconfig = learnable_qat.convert_to_learnable_qconfig(qconfig) qconfig = convert_to_learnable_qconfig(qconfig)
return qconfig return qconfig
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment