Commit 09bd2869 authored by Naveen Suda, committed by Facebook GitHub Bot

pt2e quantization support in D2Go

Summary: Add pt2e quantization support in D2Go.
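
For context, the pt2e (PyTorch 2 export) flow that this diff wires into D2Go follows the standard `torch.ao.quantization.quantize_pt2e` recipe: capture the model into a pre-autograd graph, attach a quantizer, insert observers, calibrate, then convert. A minimal standalone sketch of the post-training (PTQ) variant is below; the toy module and random input are illustrative only, not D2Go's MetaArch:

```python
import torch
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    get_symmetric_quantization_config,
    XNNPACKQuantizer,
)


# Toy fp32 model and example input, for illustration only (not D2Go's MetaArch).
class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 8, 3)

    def forward(self, x):
        return torch.relu(self.conv(x))


example_inputs = (torch.randn(1, 3, 32, 32),)

# 1. Capture the model into a pre-autograd ATen graph.
captured = capture_pre_autograd_graph(Net().eval(), example_inputs)

# 2. Attach a quantizer and insert observers / fake-quant ops.
quantizer = XNNPACKQuantizer()
quantizer.set_global(get_symmetric_quantization_config(is_per_channel=False))
prepared = prepare_pt2e(captured, quantizer)

# 3. Calibrate on representative data, then lower to int8 operators.
prepared(*example_inputs)
quantized = convert_pt2e(prepared)
```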

Reviewed By: chakriu

Differential Revision: D54132092

fbshipit-source-id: 34a9ba79a5eb49ed27a3f33454078b0df37cf2f0
parent a637c6cc
@@ -42,7 +42,7 @@ logger = logging.getLogger(__name__)
def is_predictor_quantized(predictor_type: str) -> bool:
return "int8" in predictor_type
return "int8" in predictor_type or "quant" in predictor_type
def convert_model(
@@ -74,7 +74,7 @@ def convert_quantized_model(
# only check bn exists in ptq as qat still has bn inside fused ops
if fuse_utils.check_bn_exist(pytorch_model):
logger.warn("Post training quantized model has bn inside fused ops")
logger.info(f"Converting quantized model {cfg.QUANTIZATION.BACKEND}...")
logger.info("Converting quantized model...")
# convert the fake-quantized model to int8 model
pytorch_model = convert_to_quantized_model(cfg, pytorch_model)
@@ -199,6 +199,13 @@ def default_export_predictor(
models_info = {}
for name, model in export_config.model.items():
save_path = os.path.join(predictor_path, name)
model_export_kwargs = (
{}
if export_config.model_export_kwargs is None
else export_config.model_export_kwargs[name]
)
if hasattr(cfg, "QUANTIZATION") and cfg.QUANTIZATION.RECIPE is not None:
model_export_kwargs["recipe"] = cfg.QUANTIZATION.RECIPE
model_info = _export_single_model(
predictor_path=predictor_path,
model=model,
@@ -209,23 +216,26 @@
if export_config.model_export_method is None
else export_config.model_export_method[name]
),
model_export_kwargs=(
{}
if export_config.model_export_kwargs is None
else export_config.model_export_kwargs[name]
),
model_export_kwargs=model_export_kwargs,
)
models_info[name] = model_info
predictor_init_kwargs["models"] = models_info
else:
save_path = predictor_path # for single model exported files are put under `predictor_path` together with predictor_info.json
model_export_kwargs = (
{}
if export_config.model_export_kwargs is None
else export_config.model_export_kwargs
)
if hasattr(cfg, "QUANTIZATION") and cfg.QUANTIZATION.RECIPE is not None:
model_export_kwargs["recipe"] = cfg.QUANTIZATION.RECIPE
model_info = _export_single_model(
predictor_path=predictor_path,
model=export_config.model,
input_args=model_inputs,
save_path=save_path,
model_export_method=export_config.model_export_method or predictor_type,
model_export_kwargs=export_config.model_export_kwargs or {},
model_export_kwargs=model_export_kwargs,
)
predictor_init_kwargs["model"] = model_info
@@ -22,6 +22,16 @@ from mobile_cv.arch.quantization.observer import update_stat as observer_update_
from mobile_cv.arch.utils import fuse_utils
from mobile_cv.common.misc.iter_utils import recursive_iterate
from torch import nn
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantize_pt2e import (
convert_pt2e,
prepare_pt2e,
prepare_qat_pt2e,
)
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
get_symmetric_quantization_config,
XNNPACKQuantizer,
)
TORCH_VERSION: Tuple[int, ...] = tuple(int(x) for x in torch.__version__.split(".")[:2])
if TORCH_VERSION > (1, 10):
@@ -34,6 +44,7 @@ else:
logger = logging.getLogger(__name__)
_CONVERT_FX_CALLBACK_ATTRIBUTE = "_convert_fx_callback"
_CONVERT_PT2E_CALLBACK_ATTRIBUTE = "_convert_pt2e_callback"
_STATE_DICT_KEY = "state_dict"
_OLD_STATE_DICT_KEY = "model"
_OLD_EMA_KEY = "ema_state"
@@ -185,6 +196,9 @@ def add_quantization_default_configs(_C):
_C.QUANTIZATION.PTQ.CALIBRATION_NUM_IMAGES = 16 # NOTE: this is actually iterations
_C.QUANTIZATION.PTQ.CALIBRATION_FORCE_ON_GPU = False
_C.QUANTIZATION.PT2E = False
_C.QUANTIZATION.RECIPE = None
# register deprecated and renamed keys
_C.register_deprecated_key("QUANTIZATION.QAT.LOAD_PRETRAINED")
_C.register_renamed_key("QUANTIZATION.QAT.BACKEND", "QUANTIZATION.BACKEND")
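
Note: the two new keys gate the pt2e path and the optional export recipe. Illustrative overrides on an already-built D2Go cfg (a sketch; the recipe string is a placeholder, and real configs would set these via YAML or command-line opts rather than direct assignment):

```python
cfg.QUANTIZATION.PT2E = True           # route prepare/convert through the pt2e APIs
cfg.QUANTIZATION.RECIPE = "my_recipe"  # placeholder value; when set, it is forwarded to the
                                       # model export method as model_export_kwargs["recipe"]
```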
@@ -336,14 +350,34 @@ def default_custom_prepare_fx(cfg, model, is_qat, example_input=None):
return model, convert_fn
def _get_symmetric_xnnpack_quantizer() -> XNNPACKQuantizer:
quantizer = XNNPACKQuantizer()
operator_config = get_symmetric_quantization_config(is_per_channel=False)
quantizer.set_global(operator_config)
return quantizer
def prepare_fake_quant_model(cfg, model, is_qat, example_input=None):
"""
Centralized function to prepare fp32 model (D2Go's MetaArch) to fake quant model.
"""
if cfg.QUANTIZATION.PT2E: # pt2e quantization
if hasattr(model, "custom_prepare_pt2e"):
model, convert_pt2e_callback = model.custom_prepare_pt2e(cfg)
else:
logger.info("Using default pt2e quantization APIs with XNNPACKQuantizer")
captured_model = capture_pre_autograd_graph(model, example_input)
quantizer = _get_symmetric_xnnpack_quantizer()
if is_qat:
model = prepare_qat_pt2e(captured_model, quantizer)
else:
model = prepare_pt2e(captured_model, quantizer)
convert_pt2e_callback = convert_pt2e
setattr(model, _CONVERT_PT2E_CALLBACK_ATTRIBUTE, convert_pt2e_callback)
else: # pt1.x/legacy quantization recipe
# TODO: create a warning for the direct use of `torch.ao.quantization.get_default_qconfig`
# or `torch.ao.quantization.get_default_qat_qconfig` without calling D2Go's high-level
# `set_backend_and_create_qconfig` API.
if cfg.QUANTIZATION.EAGER_MODE:
if hasattr(model, "prepare_for_quant"):
model = model.prepare_for_quant(cfg)
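
When a MetaArch defines `custom_prepare_pt2e`, the default XNNPACK path above is bypassed and the model supplies both the prepared graph and the matching convert callback. A minimal sketch of such a hook follows; the toy model, hard-coded example input, and per-channel quantizer choice are all assumptions, and only the method name plus the (model, callback) return contract come from this diff:

```python
import torch
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    get_symmetric_quantization_config,
    XNNPACKQuantizer,
)


class ToyQuantizableModel(torch.nn.Module):
    """Illustrative MetaArch-like model with a custom pt2e prepare hook."""

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(16, 4)

    def forward(self, x):
        return self.linear(x)

    def custom_prepare_pt2e(self, cfg):
        # Pick a non-default quantizer (per-channel weights here, as an example).
        quantizer = XNNPACKQuantizer()
        quantizer.set_global(get_symmetric_quantization_config(is_per_channel=True))
        example_inputs = (torch.randn(2, 16),)  # hard-coded for illustration
        captured = capture_pre_autograd_graph(self, example_inputs)
        prepared = prepare_pt2e(captured, quantizer)
        # Return the prepared model plus the callback that
        # convert_to_quantized_model will invoke later.
        return prepared, convert_pt2e
```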
@@ -398,9 +432,14 @@ def convert_to_quantized_model(cfg, fp32_model):
Centralized function to convert fake quant model (fp32 operators) to "real"
quantized model (int8 operators).
"""
if cfg.QUANTIZATION.PT2E: # pt2e quantization
logger.info("Using pt2e convert")
convert_pt2e_callback = getattr(fp32_model, _CONVERT_PT2E_CALLBACK_ATTRIBUTE)
quantized_model = convert_pt2e_callback(fp32_model)
else:
if cfg.QUANTIZATION.EAGER_MODE:
convert_fn = get_convert_fn(cfg)
int8_model = convert_fn(fp32_model, inplace=False)
quantized_model = convert_fn(fp32_model, inplace=False)
else:
# FX graph mode quantization
if not hasattr(fp32_model, _CONVERT_FX_CALLBACK_ATTRIBUTE):
@@ -410,9 +449,10 @@ def convert_to_quantized_model(cfg, fp32_model):
)
convert_fx_callback = getattr(fp32_model, _CONVERT_FX_CALLBACK_ATTRIBUTE)
int8_model = convert_fx_callback(fp32_model)
quantized_model = convert_fx_callback(fp32_model)
logger.info(f"Quantization backend: {cfg.QUANTIZATION.BACKEND}")
return int8_model
return quantized_model
@mock_quantization_type
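
End to end, the prepared model carries its convert callback (the `_convert_pt2e_callback` attribute above), so conversion stays agnostic of how the model was prepared; the QAT flow mirrors the PTQ sketch from the summary, with `prepare_qat_pt2e` and a fine-tuning step in between. A self-contained sketch on a toy model (the model, data, and single dummy training step are assumptions; real D2Go training goes through its runners):

```python
import torch
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_qat_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    get_symmetric_quantization_config,
    XNNPACKQuantizer,
)

# Toy model and data, for illustration only.
model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU())
example_inputs = (torch.randn(1, 3, 32, 32),)

captured = capture_pre_autograd_graph(model, example_inputs)
quantizer = XNNPACKQuantizer()
quantizer.set_global(get_symmetric_quantization_config(is_per_channel=False, is_qat=True))
prepared = prepare_qat_pt2e(captured, quantizer)

# A single dummy fine-tuning step with fake-quant ops in the graph.
optimizer = torch.optim.SGD(prepared.parameters(), lr=1e-3)
loss = prepared(*example_inputs).sum()
loss.backward()
optimizer.step()

# Convert the fake-quant graph to one with int8 operators.
quantized = convert_pt2e(prepared)
```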