Commit 3ba489fa authored by Jiaxu Zhu, committed by Facebook GitHub Bot

use `get_convert_fn` for eager mode convert

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/564

As titled: we need `ai_factory.quantization.convert.convert_eager` for Stinson models. This diff renames `get_convert_fx_fn` to `get_convert_fn` and extends it to return the eager mode convert function as well.
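
For reference, these are the two underlying convert entry points that the renamed helper dispatches between (torch > 1.10 paths shown; a sketch for context, not part of this diff):

from torch.ao.quantization.quantize import convert        # eager mode convert
from torch.ao.quantization.quantize_fx import convert_fx  # FX graph mode convert

# eager mode: convert(prepared_module, inplace=False) swaps observed modules
# for their quantized counterparts; FX mode: convert_fx(prepared_graph_module)
# lowers a GraphModule produced by prepare_fx / prepare_qat_fx.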

Reviewed By: wat3rBro

Differential Revision: D46368438

fbshipit-source-id: 5ebea1f05b43b476a14ab1091f6ce39bffe614d3
parent 7d35bae7
@@ -10,8 +10,10 @@ from mobile_cv.common.misc.oss_utils import fb_overwritable
 TORCH_VERSION: Tuple[int, ...] = tuple(int(x) for x in torch.__version__.split(".")[:2])
 if TORCH_VERSION > (1, 10):
+    from torch.ao.quantization.quantize import convert
     from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx, prepare_qat_fx
 else:
+    from torch.quantization.quantize import convert
     from torch.quantization.quantize_fx import convert_fx, prepare_fx, prepare_qat_fx
@@ -21,5 +23,8 @@ def get_prepare_fx_fn(cfg, is_qat):
 @fb_overwritable()
-def get_convert_fx_fn(cfg, example_inputs, qconfig_mapping=None, backend_config=None):
-    return convert_fx
+def get_convert_fn(cfg, example_inputs=None, qconfig_mapping=None, backend_config=None):
+    if cfg.QUANTIZATION.EAGER_MODE:
+        return convert
+    else:
+        return convert_fx
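
For context, a minimal sketch of how the returned callable could be used; `prepared_model` is a hypothetical module that has already gone through the matching prepare step (not part of this diff):

convert_fn = get_convert_fn(cfg)
if cfg.QUANTIZATION.EAGER_MODE:
    # eager convert swaps observed modules; pass inplace=False to keep the original
    int8_model = convert_fn(prepared_model, inplace=False)
else:
    # FX convert lowers a GraphModule produced by prepare_fx / prepare_qat_fx
    int8_model = convert_fn(prepared_model)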
@@ -10,7 +10,7 @@ from typing import Any, Dict, Tuple
 import detectron2.utils.comm as comm
 import torch
 from d2go.quantization import learnable_qat
-from d2go.quantization.fx import get_convert_fx_fn, get_prepare_fx_fn
+from d2go.quantization.fx import get_convert_fn, get_prepare_fx_fn
 from d2go.quantization.qconfig import (
     set_backend_and_create_qconfig,
     smart_decode_backend,
@@ -329,7 +329,7 @@ def default_custom_prepare_fx(cfg, model, is_qat, example_input=None):
         qconfig_mapping=qconfig_dict,
         example_inputs=(example_input,),
     )
-    convert_fn = get_convert_fx_fn(cfg, (example_input,))
+    convert_fn = get_convert_fn(cfg, (example_input,))
     return model, convert_fn
@@ -396,7 +396,8 @@ def convert_to_quantized_model(cfg, fp32_model):
         quantized model (int8 operators).
     """
     if cfg.QUANTIZATION.EAGER_MODE:
-        int8_model = convert(fp32_model, inplace=False)
+        convert_fn = get_convert_fn(cfg)
+        int8_model = convert_fn(fp32_model, inplace=False)
     else:
         # FX graph mode quantization
         if not hasattr(fp32_model, _CONVERT_FX_CALLBACK_ATTRIBUTE):
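
For readers unfamiliar with the eager mode path this enables, here is a self-contained toy example of the standard PyTorch eager quantization flow (prepare, calibrate, convert) that `get_convert_fn` now plugs into; `ToyModel` is hypothetical and independent of d2go:

import torch
import torch.ao.quantization as tq

class ToyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.quant = tq.QuantStub()      # marks the float -> quantized boundary
        self.fc = torch.nn.Linear(4, 4)
        self.dequant = tq.DeQuantStub()  # marks the quantized -> float boundary

    def forward(self, x):
        return self.dequant(self.fc(self.quant(x)))

fp32_model = ToyModel().eval()
fp32_model.qconfig = tq.get_default_qconfig("qnnpack")
prepared = tq.prepare(fp32_model, inplace=False)   # insert observers
prepared(torch.randn(1, 4))                        # calibration pass
int8_model = tq.convert(prepared, inplace=False)   # the eager convert used above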