Commit 09bd2869 authored by Naveen Suda, committed by Facebook GitHub Bot

pt2e quantization support in D2Go

Summary: Add pt2e quantization support in D2Go. A new `QUANTIZATION.PT2E` config flag routes model preparation and conversion through the PyTorch 2 export (pt2e) quantization APIs (`capture_pre_autograd_graph`, `prepare_pt2e`/`prepare_qat_pt2e`, `convert_pt2e`), using a symmetric `XNNPACKQuantizer` by default and honoring an optional model-level `custom_prepare_pt2e` hook. A new `QUANTIZATION.RECIPE` key, when set, is forwarded to model export as the `recipe` kwarg. A brief usage sketch follows the commit metadata below.

Reviewed By: chakriu

Differential Revision: D54132092

fbshipit-source-id: 34a9ba79a5eb49ed27a3f33454078b0df37cf2f0
parent a637c6cc
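
For context, a minimal sketch (not part of this commit) of how the two new config keys might be toggled, assuming `cfg` is a D2Go yacs config on which `add_quantization_default_configs` has already been run; the helper name and values are illustrative:

def enable_pt2e_quantization(cfg, recipe=None):
    # Route prepare/convert through the pt2e path instead of eager/FX graph mode.
    cfg.QUANTIZATION.PT2E = True
    # Optional recipe identifier; when not None it is forwarded to model export
    # as the "recipe" kwarg (see the exporter changes in the diff below).
    cfg.QUANTIZATION.RECIPE = recipe
    return cfg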
@@ -42,7 +42,7 @@ logger = logging.getLogger(__name__)

 def is_predictor_quantized(predictor_type: str) -> bool:
-    return "int8" in predictor_type
+    return "int8" in predictor_type or "quant" in predictor_type


 def convert_model(
@@ -74,7 +74,7 @@ def convert_quantized_model(
     # only check bn exists in ptq as qat still has bn inside fused ops
     if fuse_utils.check_bn_exist(pytorch_model):
         logger.warn("Post training quantized model has bn inside fused ops")
-    logger.info(f"Converting quantized model {cfg.QUANTIZATION.BACKEND}...")
+    logger.info("Converting quantized model...")

     # convert the fake-quantized model to int8 model
     pytorch_model = convert_to_quantized_model(cfg, pytorch_model)
@@ -199,6 +199,13 @@ def default_export_predictor(
         models_info = {}
         for name, model in export_config.model.items():
             save_path = os.path.join(predictor_path, name)
+            model_export_kwargs = (
+                {}
+                if export_config.model_export_kwargs is None
+                else export_config.model_export_kwargs[name]
+            )
+            if hasattr(cfg, "QUANTIZATION") and cfg.QUANTIZATION.RECIPE is not None:
+                model_export_kwargs["recipe"] = cfg.QUANTIZATION.RECIPE
             model_info = _export_single_model(
                 predictor_path=predictor_path,
                 model=model,
@@ -209,23 +216,26 @@ def default_export_predictor(
                     if export_config.model_export_method is None
                     else export_config.model_export_method[name]
                 ),
-                model_export_kwargs=(
-                    {}
-                    if export_config.model_export_kwargs is None
-                    else export_config.model_export_kwargs[name]
-                ),
+                model_export_kwargs=model_export_kwargs,
             )
             models_info[name] = model_info
         predictor_init_kwargs["models"] = models_info
     else:
         save_path = predictor_path  # for single model exported files are put under `predictor_path` together with predictor_info.json
+        model_export_kwargs = (
+            {}
+            if export_config.model_export_kwargs is None
+            else export_config.model_export_kwargs
+        )
+        if hasattr(cfg, "QUANTIZATION") and cfg.QUANTIZATION.RECIPE is not None:
+            model_export_kwargs["recipe"] = cfg.QUANTIZATION.RECIPE
         model_info = _export_single_model(
             predictor_path=predictor_path,
             model=export_config.model,
             input_args=model_inputs,
             save_path=save_path,
             model_export_method=export_config.model_export_method or predictor_type,
-            model_export_kwargs=export_config.model_export_kwargs or {},
+            model_export_kwargs=model_export_kwargs,
         )
         predictor_init_kwargs["model"] = model_info
......
@@ -22,6 +22,16 @@ from mobile_cv.arch.quantization.observer import update_stat as observer_update_
 from mobile_cv.arch.utils import fuse_utils
 from mobile_cv.common.misc.iter_utils import recursive_iterate
 from torch import nn
+from torch._export import capture_pre_autograd_graph
+from torch.ao.quantization.quantize_pt2e import (
+    convert_pt2e,
+    prepare_pt2e,
+    prepare_qat_pt2e,
+)
+from torch.ao.quantization.quantizer.xnnpack_quantizer import (
+    get_symmetric_quantization_config,
+    XNNPACKQuantizer,
+)

 TORCH_VERSION: Tuple[int, ...] = tuple(int(x) for x in torch.__version__.split(".")[:2])

 if TORCH_VERSION > (1, 10):
@@ -34,6 +44,7 @@ else:
 logger = logging.getLogger(__name__)

 _CONVERT_FX_CALLBACK_ATTRIBUTE = "_convert_fx_callback"
+_CONVERT_PT2E_CALLBACK_ATTRIBUTE = "_convert_pt2e_callback"
 _STATE_DICT_KEY = "state_dict"
 _OLD_STATE_DICT_KEY = "model"
 _OLD_EMA_KEY = "ema_state"
@@ -185,6 +196,9 @@ def add_quantization_default_configs(_C):
     _C.QUANTIZATION.PTQ.CALIBRATION_NUM_IMAGES = 16  # NOTE: this is actually iterations
     _C.QUANTIZATION.PTQ.CALIBRATION_FORCE_ON_GPU = False

+    _C.QUANTIZATION.PT2E = False
+    _C.QUANTIZATION.RECIPE = None
+
     # register deprecated and renamed keys
     _C.register_deprecated_key("QUANTIZATION.QAT.LOAD_PRETRAINED")
     _C.register_renamed_key("QUANTIZATION.QAT.BACKEND", "QUANTIZATION.BACKEND")
@@ -336,59 +350,79 @@ def default_custom_prepare_fx(cfg, model, is_qat, example_input=None):
     return model, convert_fn


+def _get_symmetric_xnnpack_quantizer() -> XNNPACKQuantizer:
+    quantizer = XNNPACKQuantizer()
+    operator_config = get_symmetric_quantization_config(is_per_channel=False)
+    quantizer.set_global(operator_config)
+    return quantizer
+
+
 def prepare_fake_quant_model(cfg, model, is_qat, example_input=None):
     """
     Centralized function to prepare fp32 model (D2Go's MetaArch) to fake quant model.
     """
-    # TODO: create a warning for the direct use of `torch.ao.quantization.get_default_qconfig`
-    # or `torch.ao.quantization.get_default_qat_qconfig` without calling D2Go's high-level
-    # `set_backend_and_create_qconfig` API.
-    if cfg.QUANTIZATION.EAGER_MODE:
-        if hasattr(model, "prepare_for_quant"):
-            model = model.prepare_for_quant(cfg)
-        else:
-            logger.info(
-                "Using default implementation for prepare_for_quant (eager mode)"
-            )
-            model = default_prepare_for_quant(cfg, model)
-        # NOTE: eager model needs to call prepare after `prepare_for_quant`
-        if is_qat:
-            torch.ao.quantization.prepare_qat(model, inplace=True)
-        else:
-            torch.ao.quantization.prepare(model, inplace=True)
-    else:
-        # FX graph mode requires the model to be symbolically traceable, swap common
-        # modules like SyncBN to FX-friendly version.
-        if not is_qat:
-            # NOTE: we only do this for PTQ, because we want to keep using unmodified
-            # model during QAT.
-            model = fuse_utils.swap_modules(model)
-
-        if hasattr(model, "custom_prepare_fx"):
-            ret = model.custom_prepare_fx(cfg, is_qat, example_input)
-            if not (isinstance(ret, tuple) and len(ret) == 2):
-                raise ValueError(
-                    "`custom_prepare_fx` requires return model and convert_callback"
-                )
-            model, convert_fx_callback = ret
-        else:
-            logger.info(
-                "Using default implementation for custom_prepare_fx (FX graph mode)"
-            )
-            model, convert_fx_callback = default_custom_prepare_fx(
-                cfg, model, is_qat, example_input
-            )
-        # HACK: store the convert_callback function as model attribute, which can be
-        # later accessed to convert fake quant model to quantized model. We'll find a
-        # better place to store this.
-        if hasattr(model, _CONVERT_FX_CALLBACK_ATTRIBUTE):
-            raise AttributeError(
-                f"{_CONVERT_FX_CALLBACK_ATTRIBUTE} is already set in model: {model}"
-            )
-        setattr(model, _CONVERT_FX_CALLBACK_ATTRIBUTE, convert_fx_callback)
+    if cfg.QUANTIZATION.PT2E:  # pt2e quantization
+        if hasattr(model, "custom_prepare_pt2e"):
+            model, convert_pt2e_callback = model.custom_prepare_pt2e(cfg)
+        else:
+            logger.info("Using default pt2e quantization APIs with XNNPACKQuantizer")
+            captured_model = capture_pre_autograd_graph(model, example_input)
+            quantizer = _get_symmetric_xnnpack_quantizer()
+            if is_qat:
+                model = prepare_qat_pt2e(captured_model, quantizer)
+            else:
+                model = prepare_pt2e(captured_model, quantizer)
+            convert_pt2e_callback = convert_pt2e
+        setattr(model, _CONVERT_PT2E_CALLBACK_ATTRIBUTE, convert_pt2e_callback)
+    else:  # pt1.x/legacy quantization recipe
+        # TODO: create a warning for the direct use of `torch.ao.quantization.get_default_qconfig`
+        # or `torch.ao.quantization.get_default_qat_qconfig` without calling D2Go's high-level
+        # `set_backend_and_create_qconfig` API.
+        if cfg.QUANTIZATION.EAGER_MODE:
+            if hasattr(model, "prepare_for_quant"):
+                model = model.prepare_for_quant(cfg)
+            else:
+                logger.info(
+                    "Using default implementation for prepare_for_quant (eager mode)"
+                )
+                model = default_prepare_for_quant(cfg, model)
+            # NOTE: eager model needs to call prepare after `prepare_for_quant`
+            if is_qat:
+                torch.ao.quantization.prepare_qat(model, inplace=True)
+            else:
+                torch.ao.quantization.prepare(model, inplace=True)
+        else:
+            # FX graph mode requires the model to be symbolically traceable, swap common
+            # modules like SyncBN to FX-friendly version.
+            if not is_qat:
+                # NOTE: we only do this for PTQ, because we want to keep using unmodified
+                # model during QAT.
+                model = fuse_utils.swap_modules(model)
+
+            if hasattr(model, "custom_prepare_fx"):
+                ret = model.custom_prepare_fx(cfg, is_qat, example_input)
+                if not (isinstance(ret, tuple) and len(ret) == 2):
+                    raise ValueError(
+                        "`custom_prepare_fx` requires return model and convert_callback"
+                    )
+                model, convert_fx_callback = ret
+            else:
+                logger.info(
+                    "Using default implementation for custom_prepare_fx (FX graph mode)"
+                )
+                model, convert_fx_callback = default_custom_prepare_fx(
+                    cfg, model, is_qat, example_input
+                )
+            # HACK: store the convert_callback function as model attribute, which can be
+            # later accessed to convert fake quant model to quantized model. We'll find a
+            # better place to store this.
+            if hasattr(model, _CONVERT_FX_CALLBACK_ATTRIBUTE):
+                raise AttributeError(
+                    f"{_CONVERT_FX_CALLBACK_ATTRIBUTE} is already set in model: {model}"
+                )
+            setattr(model, _CONVERT_FX_CALLBACK_ATTRIBUTE, convert_fx_callback)

     return model
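
The branch above also honors a model-level `custom_prepare_pt2e(cfg)` hook that must return the prepared model together with a convert callback. Below is a minimal sketch of such a hook on a toy module, using the same torch pt2e APIs imported in this diff; the toy architecture, example inputs, and per-channel quantizer config are illustrative assumptions, not part of the change:

import torch
from torch import nn
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    get_symmetric_quantization_config,
    XNNPACKQuantizer,
)


class ToyMetaArch(nn.Module):
    """Stand-in for a D2Go meta-arch; only the hook contract matters here."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3)

    def forward(self, x):
        return self.conv(x)

    def custom_prepare_pt2e(self, cfg):
        # Example inputs are illustrative; a real meta-arch would build them from cfg.
        example_inputs = (torch.randn(1, 3, 32, 32),)
        captured = capture_pre_autograd_graph(self, example_inputs)
        # Unlike the default helper above, this opts into per-channel weight quantization.
        quantizer = XNNPACKQuantizer().set_global(
            get_symmetric_quantization_config(is_per_channel=True)
        )
        prepared = prepare_pt2e(captured, quantizer)
        # The callback is stored under _CONVERT_PT2E_CALLBACK_ATTRIBUTE and later
        # applied by convert_to_quantized_model() to produce the int8 graph.
        return prepared, convert_pt2e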
@@ -398,21 +432,27 @@ def convert_to_quantized_model(cfg, fp32_model):
     Contralized function to convert fake quant model (fp32 operators) to "real"
     quantized model (int8 operators).
     """
-    if cfg.QUANTIZATION.EAGER_MODE:
-        convert_fn = get_convert_fn(cfg)
-        int8_model = convert_fn(fp32_model, inplace=False)
+    if cfg.QUANTIZATION.PT2E:  # pt2e quantization
+        logger.info("Using pt2e convert")
+        convert_pt2e_callback = getattr(fp32_model, _CONVERT_PT2E_CALLBACK_ATTRIBUTE)
+        quantized_model = convert_pt2e_callback(fp32_model)
     else:
-        # FX graph mode quantization
-        if not hasattr(fp32_model, _CONVERT_FX_CALLBACK_ATTRIBUTE):
-            raise AttributeError(
-                f"Can't find {_CONVERT_FX_CALLBACK_ATTRIBUTE} in model, please check "
-                f"`prepare_fake_quant_model` has been called: {fp32_model}"
-            )
-        convert_fx_callback = getattr(fp32_model, _CONVERT_FX_CALLBACK_ATTRIBUTE)
-        int8_model = convert_fx_callback(fp32_model)
+        if cfg.QUANTIZATION.EAGER_MODE:
+            convert_fn = get_convert_fn(cfg)
+            quantized_model = convert_fn(fp32_model, inplace=False)
+        else:
+            # FX graph mode quantization
+            if not hasattr(fp32_model, _CONVERT_FX_CALLBACK_ATTRIBUTE):
+                raise AttributeError(
+                    f"Can't find {_CONVERT_FX_CALLBACK_ATTRIBUTE} in model, please check "
+                    f"`prepare_fake_quant_model` has been called: {fp32_model}"
                )
+            convert_fx_callback = getattr(fp32_model, _CONVERT_FX_CALLBACK_ATTRIBUTE)
+            quantized_model = convert_fx_callback(fp32_model)
+        logger.info(f"Quantization backend: {cfg.QUANTIZATION.BACKEND}")

-    return int8_model
+    return quantized_model


 @mock_quantization_type
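
Putting the two helpers together, an end-to-end post-training quantization pass through the new pt2e branch could look like the sketch below; the module path `d2go.quantization.modeling` and the shape of the calibration loop are assumptions, and the caller supplies the config, model, calibration batches, and example inputs:

import torch

from d2go.quantization.modeling import (
    convert_to_quantized_model,
    prepare_fake_quant_model,
)


def ptq_int8_model(cfg, model, calibration_batches, example_input):
    """Post-training quantization via the pt2e path gated by QUANTIZATION.PT2E."""
    cfg.QUANTIZATION.PT2E = True
    fake_quant_model = prepare_fake_quant_model(
        cfg, model, is_qat=False, example_input=example_input
    )
    # Calibration: run a few batches so the inserted observers record activation
    # ranges; each batch must match whatever the model's forward() expects.
    with torch.no_grad():
        for batch in calibration_batches:
            fake_quant_model(batch)
    # Invokes the convert callback stored on the prepared model (convert_pt2e by default).
    return convert_to_quantized_model(cfg, fake_quant_model)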
......