Commit d98f74aa authored by Jiaxu Zhu's avatar Jiaxu Zhu Committed by Facebook GitHub Bot
Browse files

API Integration

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/364

As titled: when `QUANTIZATION.BACKEND` is set to `turing`, call the `odai.transforms` APIs instead of the OSS quantization path.

Also updated `image_classification` as an example of obtaining a Turing-quantized model via `custom_prepare_fx`/`custom_convert_fx`.

Reviewed By: wat3rBro

Differential Revision: D38436675

fbshipit-source-id: e4f0e02290512bce8b18c2369a67ed9b0f116825
parent 272290cd
......@@ -73,7 +73,7 @@ def convert_quantized_model(
logger.info(f"Converting quantized model {cfg.QUANTIZATION.BACKEND}...")
# convert the fake-quantized model to int8 model
pytorch_model = convert_to_quantized_model(cfg, pytorch_model)
pytorch_model = convert_to_quantized_model(cfg, pytorch_model, data_loader)
logger.info(f"Quantized Model:\n{pytorch_model}")
return pytorch_model
......
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from typing import Tuple
import torch
from mobile_cv.common.misc.oss_utils import fb_overwritable
# Major/minor version of the installed torch, e.g. (1, 12); used to pick the
# correct import path for the FX graph-mode quantization APIs below.
TORCH_VERSION: Tuple[int, ...] = tuple(int(x) for x in torch.__version__.split(".")[:2])
if TORCH_VERSION > (1, 10):
    # torch >= 1.11: quantization APIs live under the `torch.ao` namespace.
    from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx, prepare_qat_fx
else:
    # Older torch: legacy `torch.quantization` location.
    from torch.quantization.quantize_fx import convert_fx, prepare_fx, prepare_qat_fx
@fb_overwritable()
def get_prepare_fx_fn(cfg, is_qat):
    """Return the FX graph-mode "prepare" function to use for this config.

    Args:
        cfg: the D2Go config node (unused in the OSS implementation; the
            fb-internal override may inspect it, e.g. for the Turing backend).
        is_qat: True for quantization-aware training, False for post-training
            (static) quantization.

    Returns:
        ``prepare_qat_fx`` when ``is_qat`` is set, otherwise ``prepare_fx``.
    """
    if is_qat:
        return prepare_qat_fx
    return prepare_fx
@fb_overwritable()
def get_convert_fx_fn(cfg, example_inputs):
    """Return the FX graph-mode "convert" function to use for this config.

    Args:
        cfg: the D2Go config node. Unused here; the fb-internal override keys
            off it (e.g. ``QUANTIZATION.BACKEND == "turing"``).
        example_inputs: example inputs for the model being converted. Unused
            in the OSS implementation; kept so overrides that need them (such
            as Turing's converter) share the same signature.

    Returns:
        The stock ``convert_fx`` function.
    """
    # OSS path always uses PyTorch's convert_fx regardless of cfg/inputs.
    return convert_fx
......@@ -17,6 +17,7 @@ from mobile_cv.arch.quantization.observer import update_stat as observer_update_
from mobile_cv.arch.utils import fuse_utils
from mobile_cv.common.misc.iter_utils import recursive_iterate
from .fx import get_convert_fx_fn, get_prepare_fx_fn
from .qconfig import set_backend_and_create_qconfig, smart_decode_backend
TORCH_VERSION: Tuple[int, ...] = tuple(int(x) for x in torch.__version__.split(".")[:2])
......@@ -289,12 +290,12 @@ def default_custom_prepare_fx(cfg, model, is_qat, example_input=None):
" their own MetaArch."
)
if is_qat:
model = prepare_qat_fx(model, qconfig_dict, (example_input,))
else:
model = prepare_fx(model, qconfig_dict, (example_input,))
logger.info("Setup the model with qconfig:\n{}".format(qconfig))
prepare_fn = get_prepare_fx_fn(cfg, is_qat)
model = prepare_fn(
model,
qconfig_mapping=qconfig_dict,
example_inputs=(example_input,),
)
return model
......@@ -343,7 +344,7 @@ def prepare_fake_quant_model(cfg, model, is_qat, example_input=None):
return model
def convert_to_quantized_model(cfg, fp32_model):
def convert_to_quantized_model(cfg, fp32_model, data_loader):
"""
Contralized function to convert fake quant model (fp32 operators) to "real"
quantized model (int8 operators).
......@@ -352,10 +353,12 @@ def convert_to_quantized_model(cfg, fp32_model):
int8_model = convert(fp32_model, inplace=False)
else:
# FX graph mode quantization
example_input = next(iter(data_loader))
if hasattr(fp32_model, "custom_convert_fx"):
int8_model = fp32_model.custom_convert_fx(cfg)
int8_model = fp32_model.custom_convert_fx(cfg, example_input)
else:
int8_model = convert_fx(fp32_model)
convert_fn = get_convert_fx_fn(cfg, (example_input,))
int8_model = convert_fn(fp32_model)
return int8_model
......
......@@ -9,6 +9,7 @@ if TORCH_VERSION > (1, 10):
from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx, prepare_qat_fx
else:
from torch.quantization.quantize_fx import convert_fx, prepare_fx, prepare_qat_fx
from mobile_cv.common.misc.oss_utils import fb_overwritable
QCONFIG_CREATOR_REGISTRY = Registry("QCONFIG_CREATOR_REGISTRY")
......@@ -82,6 +83,7 @@ def validate_native_backend(backend):
)
@fb_overwritable()
def _smart_parse_extended_backend(extended_backend):
"""
D2Go extends the definition of quantization "backend". In addition to PyTorch's
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment