"examples/tensorflow/dgi/gcn.py" did not exist on "b355d1eddf80fca8d80396d776006a9be2195ae4"
Commit 3204f147 authored by Yanghan Wang, committed by Facebook GitHub Bot

consolidate the creation of qconfig

Summary: Pull Request resolved: https://github.com/facebookresearch/d2go/pull/210

Reviewed By: kimishpatel

Differential Revision: D35631192

fbshipit-source-id: a713d86734c6937c16c7ced705171db9ea2f0894
parent ae2f2f64
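
In practical terms, the consolidation replaces per-call-site branching on training mode and backend with a single helper call. A minimal sketch of the resulting call pattern (prepare_model_qconfig is a hypothetical wrapper for illustration, not part of the diff below):

from d2go.modeling.quantization import set_backend_and_create_qconfig


def prepare_model_qconfig(model, cfg):
    # One call sets torch.backends.quantized.engine from cfg.QUANTIZATION.BACKEND and
    # returns the matching qconfig for QAT (training) or PTQ (eval).
    model.qconfig = set_backend_and_create_qconfig(cfg, is_train=model.training)
    return model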
@@ -5,10 +5,9 @@
 import inspect
 import logging
-import torch
 import torch.nn as nn
 from d2go.export.api import PredictorExportConfig
-from d2go.utils.qat_utils import get_qat_qconfig
+from d2go.modeling.quantization import set_backend_and_create_qconfig
 from detectron2.modeling import GeneralizedRCNN
 from detectron2.modeling.backbone.fpn import FPN
 from detectron2.modeling.postprocessing import detector_postprocess
@@ -255,14 +254,7 @@ def _fx_quant_prepare(self, cfg):
 @RCNN_PREPARE_FOR_QUANT_REGISTRY.register()
 def default_rcnn_prepare_for_quant(self, cfg):
     model = self
-    torch.backends.quantized.engine = cfg.QUANTIZATION.BACKEND
-    model.qconfig = (
-        get_qat_qconfig(
-            cfg.QUANTIZATION.BACKEND, cfg.QUANTIZATION.QAT.FAKE_QUANT_METHOD
-        )
-        if model.training
-        else torch.ao.quantization.get_default_qconfig(cfg.QUANTIZATION.BACKEND)
-    )
+    model.qconfig = set_backend_and_create_qconfig(cfg, is_train=model.training)
     if (
         hasattr(model, "roi_heads")
         and hasattr(model.roi_heads, "mask_head")
...
@@ -16,6 +16,7 @@ from detectron2.engine import SimpleTrainer
 from mobile_cv.arch.quantization.observer import update_stat as observer_update_stat
 from mobile_cv.arch.utils import fuse_utils
 from mobile_cv.common.misc.iter_utils import recursive_iterate
+from mobile_cv.common.misc.registry import Registry
 
 TORCH_VERSION: Tuple[int, ...] = tuple(int(x) for x in torch.__version__.split(".")[:2])
 if TORCH_VERSION > (1, 10):
@@ -83,6 +84,8 @@ def add_quantization_default_configs(_C):
     _C.QUANTIZATION = CfgNode()
     # Note: EAGER_MODE == False currently represents FX graph mode quantization
     _C.QUANTIZATION.EAGER_MODE = True
+    # Available backends include PyTorch's natively supported backends (i.e. fbgemm and
+    # qnnpack), plus D2Go-defined backends such as "qnnpack@symmetric".
     _C.QUANTIZATION.BACKEND = "fbgemm"
     # used to enable metarch set_custom_qscheme (need to implement)
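
To illustrate the extended BACKEND values documented above, a config-driven workflow could opt into symmetric QNNPACK as in the hedged sketch below. It assumes `cfg` already carries the QUANTIZATION defaults registered by add_quantization_default_configs:

# Sketch: selecting the D2Go-extended backend via the usual yacs merge API.
cfg.merge_from_list(
    [
        "QUANTIZATION.BACKEND", "qnnpack@symmetric",      # QNNPACK engine + symmetric QConfig
        "QUANTIZATION.QAT.FAKE_QUANT_METHOD", "default",
    ]
)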
@@ -202,6 +205,109 @@ def mock_quantization_type(quant_func):
     return wrapper
 
 
+def holistic_get_qconfig(backend, is_qat, use_symmetric=False):
+    """
+    Config-less vanilla way to create the QConfig, suitable for explicitly creating qconfig.
+    """
+    if use_symmetric:
+        if not backend == "qnnpack":
+            raise ValueError(
+                f"Only qnnpack supports Symmetric quantization, given: {backend}"
+            )
+        if is_qat:
+            return torch.ao.quantization.default_symmetric_qnnpack_qat_qconfig
+        else:
+            return torch.ao.quantization.default_symmetric_qnnpack_qconfig
+    else:
+        if is_qat:
+            return torch.ao.quantization.get_default_qat_qconfig(backend)
+        else:
+            return torch.ao.quantization.get_default_qconfig(backend)
+
+
+def validate_native_backend(backend):
+    _PYTORCH_NATIVE_BACKENDS = ["fbgemm", "qnnpack"]
+    if backend not in _PYTORCH_NATIVE_BACKENDS:
+        raise ValueError(
+            f"Unrecognized backend: {backend}, PyTorch"
+            f" supported backends are: {_PYTORCH_NATIVE_BACKENDS}"
+        )
+
+
+def _smart_parse_extended_backend(extended_backend):
+    """
+    D2Go extends the definition of quantization "backend". In addition to PyTorch's
+    native backends (i.e. qnnpack and fbgemm), we allow other types of backends so users
+    can easily express different settings. Here are the supported cases:
+        1. Symmetric quantization: "qnnpack@symmetric" refers to using QNNPACK with
+           symmetric QConfig.
+    """
+    backend = extended_backend
+    # default options
+    options = {
+        "is_symmetric": False,
+    }
+    if "@symmetric" in backend:
+        options["is_symmetric"] = True
+        backend = backend.replace("@symmetric", "", 1)
+    validate_native_backend(backend)
+    return backend, options
+
+
+def _smart_decode_backend(extended_backend):
+    """
+    Since we extend the definition of quantization backend, users shouldn't directly use
+    cfg.QUANTIZATION.BACKEND under PyTorch's context; this is the translation function
+    if direct use is necessary.
+    """
+    return _smart_parse_extended_backend(extended_backend)[0]
+
+
+QCONFIG_CREATOR_REGISTRY = Registry("QCONFIG_CREATOR_REGISTRY")
+
+
+@QCONFIG_CREATOR_REGISTRY.register("smart")
+def _smart_set_backend_and_create_qconfig(cfg, *, is_train):
+    """
+    This is the default / "smart" way to create the qconfig based on various configs. It
+    supports:
+    - learnable QAT
+    - setting symmetric quantization via the backend.
+    """
+    backend, options = _smart_parse_extended_backend(cfg.QUANTIZATION.BACKEND)
+    is_symmetric = options["is_symmetric"]
+
+    # Set backend
+    torch.backends.quantized.engine = backend
+
+    qat_method = cfg.QUANTIZATION.QAT.FAKE_QUANT_METHOD
+    assert qat_method in ["default", "learnable"]
+    if is_train and qat_method == "learnable":
+        qconfig = qat_utils.get_learnable_qat_qconfig(backend)
+    else:
+        qconfig = holistic_get_qconfig(
+            backend=backend, is_qat=is_train, use_symmetric=is_symmetric
+        )
+
+    return qconfig
+
+
+def set_backend_and_create_qconfig(cfg, *, is_train):
+    """
+    Recommended function to create qconfig given D2Go's quantization config.
+    """
+    # In case we need a different implementation, we can add a new key called
+    # QUANTIZATION.QCONFIG_CREATOR with "smart" as the default value, and use this key
+    # to toggle between registries.
+    return QCONFIG_CREATOR_REGISTRY.get("smart")(cfg, is_train=is_train)
+
+
 def default_prepare_for_quant(cfg, model):
     """
     Default implementation of preparing a model for quantization. This function will
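
For context on the helpers added in the hunk above, here is a minimal usage sketch. It is not part of the commit: the hand-built cfg is illustrative, and the symmetric QNNPACK qconfig requires a PyTorch build that ships it.

from d2go.config import CfgNode
from d2go.modeling.quantization import (
    _smart_parse_extended_backend,
    set_backend_and_create_qconfig,
)

# A minimal cfg carrying only the keys the "smart" creator reads.
cfg = CfgNode()
cfg.QUANTIZATION = CfgNode()
cfg.QUANTIZATION.BACKEND = "qnnpack@symmetric"   # D2Go-extended backend string
cfg.QUANTIZATION.QAT = CfgNode()
cfg.QUANTIZATION.QAT.FAKE_QUANT_METHOD = "default"

backend, options = _smart_parse_extended_backend(cfg.QUANTIZATION.BACKEND)
assert backend == "qnnpack" and options == {"is_symmetric": True}

# QAT (training) path: sets torch.backends.quantized.engine to "qnnpack" as a side
# effect and returns the symmetric QNNPACK QAT qconfig.
qat_qconfig = set_backend_and_create_qconfig(cfg, is_train=True)

# PTQ (eval) path: same backend handling, PTQ qconfig instead.
ptq_qconfig = set_backend_and_create_qconfig(cfg, is_train=False)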
@@ -226,13 +332,7 @@ def default_prepare_for_quant(cfg, model):
     Return:
         nn.Module: a ready model for QAT training or PTQ calibration
     """
-    qconfig = (
-        qat_utils.get_qat_qconfig(
-            cfg.QUANTIZATION.BACKEND, cfg.QUANTIZATION.QAT.FAKE_QUANT_METHOD
-        )
-        if model.training
-        else torch.ao.quantization.get_default_qconfig(cfg.QUANTIZATION.BACKEND)
-    )
+    qconfig = set_backend_and_create_qconfig(cfg, is_train=model.training)
 
     if cfg.QUANTIZATION.EAGER_MODE:
         model = fuse_utils.fuse_model(
@@ -241,7 +341,6 @@ def default_prepare_for_quant(cfg, model):
            inplace=True,
         )
 
-        torch.backends.quantized.engine = cfg.QUANTIZATION.BACKEND
         model.qconfig = qconfig
         # TODO(future diff): move the torch.ao.quantization.prepare(...) call
         # here, to be consistent with the FX branch
@@ -261,6 +360,20 @@ def default_prepare_for_quant_convert(cfg, model):
     return convert_fx(model)
 
 
+def apply_prepare_for_quant(cfg, model):
+    # TODO: create a warning for the direct use of `torch.ao.quantization.get_default_qconfig`
+    # or `torch.ao.quantization.get_default_qat_qconfig` without calling D2Go's high-level
+    # `set_backend_and_create_qconfig` API.
+    if hasattr(model, "prepare_for_quant"):
+        model = model.prepare_for_quant(cfg)
+    else:
+        logger.info("Using default implementation for prepare_for_quant")
+        model = default_prepare_for_quant(cfg, model)
+
+    return model
+
+
 @mock_quantization_type
 def post_training_quantize(cfg, model, data_loader):
     """Calibrate a model, convert it to a quantized pytorch model"""
@@ -270,12 +383,7 @@ def post_training_quantize(cfg, model, data_loader):
     for param in model.parameters():
         param.requires_grad = False
 
-    if hasattr(model, "prepare_for_quant"):
-        model = model.prepare_for_quant(cfg)
-    else:
-        logger.info("Using default implementation for prepare_for_quant")
-        model = default_prepare_for_quant(cfg, model)
+    model = apply_prepare_for_quant(cfg, model)
 
     if cfg.QUANTIZATION.EAGER_MODE:
         torch.ao.quantization.prepare(model, inplace=True)
     logger.info("Prepared the PTQ model for calibration:\n{}".format(model))
@@ -316,25 +424,6 @@ def post_training_quantize(cfg, model, data_loader):
     return model
 
 
-def _prepare_model_for_qat(cfg, model):
-    if cfg.QUANTIZATION.EAGER_MODE:
-        if hasattr(model, "prepare_for_quant"):
-            model = model.prepare_for_quant(cfg)
-        else:
-            logger.info("Using default implementation for prepare_for_quant")
-            model = default_prepare_for_quant(cfg, model)
-        # TODO(future diff): move this into prepare_for_quant to match FX branch
-        torch.ao.quantization.prepare_qat(model, inplace=True)
-    else:  # FX graph mode quantization
-        if hasattr(model, "prepare_for_quant"):
-            model = model.prepare_for_quant(cfg)
-        else:
-            logger.info("Using default implementation for prepare_for_quant")
-            model = default_prepare_for_quant(cfg, model)
-    return model
-
-
 @mock_quantization_type
 def setup_qat_model(
     cfg,
@@ -349,14 +438,17 @@ def setup_qat_model(
         raise RuntimeError("The model is already setup to be QAT, cannot setup again!")
 
     device = model_fp32.device
-    torch.backends.quantized.engine = cfg.QUANTIZATION.BACKEND
+    # FIXME: seems that we can remove this
+    torch.backends.quantized.engine = _smart_decode_backend(cfg.QUANTIZATION.BACKEND)
     qat_method = cfg.QUANTIZATION.QAT.FAKE_QUANT_METHOD
 
     # prepare for qat may modify the fp32 model directly so we create a copy
     model_fp32_state_dict = model_fp32.state_dict()
 
     # prepare model for qat
-    model = _prepare_model_for_qat(cfg, model_fp32)
+    model = apply_prepare_for_quant(cfg, model_fp32)
+    if cfg.QUANTIZATION.EAGER_MODE:
+        torch.ao.quantization.prepare_qat(model, inplace=True)
 
     # make sure the proper qconfig are used in the model
     qat_utils.check_for_learnable_fake_quant_ops(qat_method, model)
...
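
The hunks above route both PTQ and QAT setup through apply_prepare_for_quant. A hedged sketch of the dispatch follows; ToyModel and the hand-built cfg are illustrative, not part of the commit.

import torch
import torch.nn as nn
from d2go.config import CfgNode
from d2go.modeling.quantization import (
    apply_prepare_for_quant,
    set_backend_and_create_qconfig,
)

cfg = CfgNode()
cfg.QUANTIZATION = CfgNode()
cfg.QUANTIZATION.BACKEND = "fbgemm"
cfg.QUANTIZATION.QAT = CfgNode()
cfg.QUANTIZATION.QAT.FAKE_QUANT_METHOD = "default"


class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, 3)

    def forward(self, x):
        return self.conv(x)

    def prepare_for_quant(self, cfg):
        # A custom hook takes precedence inside apply_prepare_for_quant;
        # without it, default_prepare_for_quant is used instead.
        self.qconfig = set_backend_and_create_qconfig(cfg, is_train=self.training)
        return self


model = apply_prepare_for_quant(cfg, ToyModel().train())
# Eager-mode QAT insertion now happens in the caller, mirroring setup_qat_model above.
torch.ao.quantization.prepare_qat(model, inplace=True)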
@@ -38,7 +38,7 @@ def check_for_learnable_fake_quant_ops(qat_method, model):
     if qat_method == "learnable":
         if not _has_module(model, _LearnableFakeQuantize):
             raise Exception(
-                "No learnable fake quant is used for learnable quantzation, please use d2go.utils.qat_utils.get_qat_qconfig() to get proper qconfig"
+                "No learnable fake quant is used for learnable quantzation, please use d2go.utils.qat_utils.get_learnable_qat_qconfig() to get proper qconfig"
             )
@@ -57,11 +57,8 @@ def iterate_module_named_parameters(model, check_requires_grad=True):
                 yield module_name, module, module_param_name, value
 
 
-def get_qat_qconfig(backend, qat_method="default"):
+def get_learnable_qat_qconfig(backend):
     assert backend in ["qnnpack", "fbgemm"]
-    assert qat_method in ["default", "learnable"]
-
-    if qat_method == "default":
-        return torch.quantization.get_default_qat_qconfig(backend)
 
     ACT_CONFIGS = {
         # follow `get_default_qat_qconfig()`
...
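
A short hedged sketch of where the renamed helper fits: with FAKE_QUANT_METHOD set to "learnable", the "smart" creator defers to get_learnable_qat_qconfig, and the check above can verify the prepared model. The hand-built cfg is illustrative.

from d2go.config import CfgNode
from d2go.modeling.quantization import set_backend_and_create_qconfig
from d2go.utils import qat_utils

cfg = CfgNode()
cfg.QUANTIZATION = CfgNode()
cfg.QUANTIZATION.BACKEND = "fbgemm"
cfg.QUANTIZATION.QAT = CfgNode()
cfg.QUANTIZATION.QAT.FAKE_QUANT_METHOD = "learnable"

# Training-time call goes through qat_utils.get_learnable_qat_qconfig under the hood.
qconfig = set_backend_and_create_qconfig(cfg, is_train=True)

# Equivalent direct call, per the error message updated above.
qconfig_direct = qat_utils.get_learnable_qat_qconfig("fbgemm")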
@@ -3,6 +3,7 @@
 import torch
+from d2go.modeling.quantization import set_backend_and_create_qconfig
 from d2go.utils.testing.data_loader_helper import create_local_dataset
 from detectron2.modeling import META_ARCH_REGISTRY
 from detectron2.structures import Boxes, ImageList, Instances
@@ -54,7 +55,7 @@ class DetMetaArchForTest(torch.nn.Module):
     def prepare_for_quant(self, cfg):
         self.avgpool = prepare_qat_fx(
             self.avgpool,
-            {"": torch.ao.quantization.get_default_qat_qconfig()},
+            {"": set_backend_and_create_qconfig(cfg, is_train=self.training)},
         )
         return self
...
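
The test change above is also the pattern for user models that quantize a submodule with FX graph mode. A hedged sketch (module names are illustrative; the two-argument prepare_qat_fx call mirrors the test code above, while newer PyTorch releases also expect example_inputs):

import torch.nn as nn
from d2go.modeling.quantization import set_backend_and_create_qconfig
from torch.ao.quantization.quantize_fx import prepare_qat_fx


class MyMetaArch(nn.Module):
    def __init__(self):
        super().__init__()
        self.avgpool = nn.AdaptiveAvgPool2d(1)

    def forward(self, x):
        return self.avgpool(x)

    def prepare_for_quant(self, cfg):
        # The backend-aware qconfig derived from cfg replaces the hard-coded
        # torch.ao.quantization.get_default_qat_qconfig() call.
        self.avgpool = prepare_qat_fx(
            self.avgpool,
            {"": set_backend_and_create_qconfig(cfg, is_train=self.training)},
        )
        return self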
@@ -110,6 +110,8 @@ class BaseSemanticSegTestCase:
 class TestR50FPN(BaseSemanticSegTestCase.TemplateTestCase):
     def setup_custom_test(self):
         self.cfg.merge_from_file("detectron2://Misc/semantic_R_50_FPN_1x.yaml")
+        # discard pretrained backbone weights
+        self.cfg.merge_from_list(["MODEL.WEIGHTS", ""])
 
     def test_export_torchscript(self):
         self._test_export("torchscript", compare_match=True)
@@ -10,6 +10,7 @@ from typing import Dict
 import pytorch_lightning as pl  # type: ignore
 import torch
 from d2go.config import CfgNode, temp_defrost
+from d2go.modeling.quantization import set_backend_and_create_qconfig
 from d2go.runner import create_runner
 from d2go.runner.callbacks.quantization import (
     QuantizationAwareTraining,
@@ -175,7 +176,7 @@ class TestLightningTask(unittest.TestCase):
     def prepare_for_quant(self, cfg):
         self.avgpool = prepare_qat_fx(
             self.avgpool,
-            {"": torch.ao.quantization.get_default_qat_qconfig()},
+            {"": set_backend_and_create_qconfig(cfg, is_train=self.training)},
             self.custom_config_dict,
         )
         return self
...