Commit d98f74aa authored by Jiaxu Zhu, committed by Facebook GitHub Bot

API Integration

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/364

As the title says: when `QUANTIZATION.BACKEND` is set to `turing`, call the `odai.transforms` APIs instead of the OSS quantization path.

Also updated `image_classification` as an example of getting a Turing-quantized model via `custom_prepare_fx`/`custom_convert_fx`.
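
For readers trying this out, a minimal sketch of opting a config into the new path. The `turing` string and the `QUANTIZATION.BACKEND` key come from this commit; the runner boilerplate is standard d2go usage and is shown here as an assumption:

```python
# Hedged sketch: select the Turing backend on a d2go config.
# Only cfg.QUANTIZATION.BACKEND and the "turing" value come from this commit;
# the runner setup is ordinary d2go boilerplate.
from d2go.runner import GeneralizedRCNNRunner

runner = GeneralizedRCNNRunner()
cfg = runner.get_default_cfg()
cfg.QUANTIZATION.BACKEND = "turing"  # routes prepare/convert through odai.transforms
```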

Reviewed By: wat3rBro

Differential Revision: D38436675

fbshipit-source-id: e4f0e02290512bce8b18c2369a67ed9b0f116825
parent 272290cd
@@ -73,7 +73,7 @@ def convert_quantized_model(
     logger.info(f"Converting quantized model {cfg.QUANTIZATION.BACKEND}...")
     # convert the fake-quantized model to int8 model
-    pytorch_model = convert_to_quantized_model(cfg, pytorch_model)
+    pytorch_model = convert_to_quantized_model(cfg, pytorch_model, data_loader)
     logger.info(f"Quantized Model:\n{pytorch_model}")
     return pytorch_model
...
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

from typing import Tuple

import torch
from mobile_cv.common.misc.oss_utils import fb_overwritable

TORCH_VERSION: Tuple[int, ...] = tuple(int(x) for x in torch.__version__.split(".")[:2])
if TORCH_VERSION > (1, 10):
    from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx, prepare_qat_fx
else:
    from torch.quantization.quantize_fx import convert_fx, prepare_fx, prepare_qat_fx


@fb_overwritable()
def get_prepare_fx_fn(cfg, is_qat):
    # OSS default: plain FX prepare; an internal build can override this
    # (e.g. with a Turing-aware implementation).
    return prepare_qat_fx if is_qat else prepare_fx


@fb_overwritable()
def get_convert_fx_fn(cfg, example_inputs):
    # OSS default: plain FX convert; an internal build can override this.
    return convert_fx
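
In the OSS build these helpers simply hand back PyTorch's FX entry points. A sketch of resolving and applying them directly, assuming a recent torch (≥1.13 for `get_default_qconfig_mapping`); the `d2go.quantization.fx` module path, the toy model, and the qconfig choice are assumptions, and `cfg` is passed as `None` since the OSS implementations above ignore it:

```python
import torch
from torch.ao.quantization import get_default_qconfig_mapping

# Assumption: the new file above lives at d2go/quantization/fx.py.
from d2go.quantization.fx import get_convert_fx_fn, get_prepare_fx_fn

model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU()).eval()
example_inputs = (torch.randn(1, 3, 32, 32),)

prepare_fn = get_prepare_fx_fn(None, is_qat=False)  # -> prepare_fx in OSS
prepared = prepare_fn(model, get_default_qconfig_mapping("fbgemm"), example_inputs)
prepared(*example_inputs)  # one calibration pass so observers see real data

convert_fn = get_convert_fx_fn(None, example_inputs)  # -> convert_fx in OSS
int8_model = convert_fn(prepared)
```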
@@ -17,6 +17,7 @@ from mobile_cv.arch.quantization.observer import update_stat as observer_update_
 from mobile_cv.arch.utils import fuse_utils
 from mobile_cv.common.misc.iter_utils import recursive_iterate

+from .fx import get_convert_fx_fn, get_prepare_fx_fn
 from .qconfig import set_backend_and_create_qconfig, smart_decode_backend

 TORCH_VERSION: Tuple[int, ...] = tuple(int(x) for x in torch.__version__.split(".")[:2])
@@ -289,12 +290,12 @@ def default_custom_prepare_fx(cfg, model, is_qat, example_input=None):
             " their own MetaArch."
         )

-    if is_qat:
-        model = prepare_qat_fx(model, qconfig_dict, (example_input,))
-    else:
-        model = prepare_fx(model, qconfig_dict, (example_input,))
+    prepare_fn = get_prepare_fx_fn(cfg, is_qat)
+    model = prepare_fn(
+        model,
+        qconfig_mapping=qconfig_dict,
+        example_inputs=(example_input,),
+    )

     logger.info("Setup the model with qconfig:\n{}".format(qconfig))
     return model
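
`default_custom_prepare_fx` is only the fallback; a meta-arch can supply its own hooks, which is how the commit message says `image_classification` obtains a Turing-quantized model. A sketch of that shape follows: only the hook names and the `custom_convert_fx(cfg, example_input)` call signature come from this diff, while the wrapper pattern and the bodies are assumptions:

```python
import torch
from torch.ao.quantization import (
    get_default_qat_qconfig_mapping,
    get_default_qconfig_mapping,
)
from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx, prepare_qat_fx


class MyClassifier(torch.nn.Module):
    """Hypothetical meta-arch; only the hook names come from the diff."""

    def __init__(self):
        super().__init__()
        self.backbone = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU())

    def forward(self, x):
        return self.backbone(x)

    def custom_prepare_fx(self, cfg, is_qat, example_input=None):
        qconfig_mapping = (
            get_default_qat_qconfig_mapping() if is_qat else get_default_qconfig_mapping()
        )
        prepare = prepare_qat_fx if is_qat else prepare_fx
        # Prepare only the backbone so this wrapper survives tracing, keeping
        # the custom_convert_fx attribute that convert_to_quantized_model
        # checks for via hasattr.
        self.backbone = prepare(self.backbone, qconfig_mapping, (example_input,))
        return self

    def custom_convert_fx(self, cfg, example_input):
        # example_input is newly threaded through by this commit.
        self.backbone = convert_fx(self.backbone)
        return self
```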
@@ -343,7 +344,7 @@ def prepare_fake_quant_model(cfg, model, is_qat, example_input=None):
     return model


-def convert_to_quantized_model(cfg, fp32_model):
+def convert_to_quantized_model(cfg, fp32_model, data_loader):
     """
     Centralized function to convert fake quant model (fp32 operators) to "real"
     quantized model (int8 operators).
@@ -352,10 +353,12 @@ def convert_to_quantized_model(cfg, fp32_model):
         int8_model = convert(fp32_model, inplace=False)
     else:
         # FX graph mode quantization
+        example_input = next(iter(data_loader))
         if hasattr(fp32_model, "custom_convert_fx"):
-            int8_model = fp32_model.custom_convert_fx(cfg)
+            int8_model = fp32_model.custom_convert_fx(cfg, example_input)
         else:
-            int8_model = convert_fx(fp32_model)
+            convert_fn = get_convert_fx_fn(cfg, (example_input,))
+            int8_model = convert_fn(fp32_model)
     return int8_model
...
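
Putting the two hunks together, the flow this commit completes is prepare, calibrate, then convert, with `data_loader` now threaded into the convert step. A hedged sketch of that flow; `cfg`, `model`, and `data_loader` are as in the surrounding d2go code, and the calibration budget is illustrative rather than a d2go setting:

```python
import itertools

import torch

NUM_CALIBRATION_BATCHES = 16  # illustrative budget, not a d2go config key

example_input = next(iter(data_loader))
model = prepare_fake_quant_model(cfg, model, is_qat=False, example_input=example_input)

model.eval()
with torch.no_grad():
    for batch in itertools.islice(data_loader, NUM_CALIBRATION_BATCHES):
        model(batch)  # observers record activation statistics

# convert_to_quantized_model now also consumes data_loader, using its first
# batch as the example input for the FX convert step.
int8_model = convert_to_quantized_model(cfg, model, data_loader)
```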
@@ -9,6 +9,7 @@ if TORCH_VERSION > (1, 10):
     from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx, prepare_qat_fx
 else:
     from torch.quantization.quantize_fx import convert_fx, prepare_fx, prepare_qat_fx
+from mobile_cv.common.misc.oss_utils import fb_overwritable

 QCONFIG_CREATOR_REGISTRY = Registry("QCONFIG_CREATOR_REGISTRY")
@@ -82,6 +83,7 @@ def validate_native_backend(backend):
         )


+@fb_overwritable()
 def _smart_parse_extended_backend(extended_backend):
     """
     D2Go extends the definition of quantization "backend". In addition to PyTorch's
...
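
`fb_overwritable` is what lets a Meta-internal build substitute a Turing-aware implementation while OSS keeps the defaults. A simplified sketch of the pattern; this illustrates the idea only and is not `mobile_cv`'s actual implementation:

```python
# Toy stand-in for mobile_cv.common.misc.oss_utils.fb_overwritable: in OSS the
# decorated function runs as-is; an internal build can register a replacement
# under the same name without touching any call site.
from torch.ao.quantization.quantize_fx import prepare_fx, prepare_qat_fx

_OVERRIDES = {}  # an internal build would populate this


def fb_overwritable():
    def deco(fn):
        def wrapper(*args, **kwargs):
            impl = _OVERRIDES.get(fn.__name__, fn)  # internal override wins
            return impl(*args, **kwargs)

        return wrapper

    return deco


@fb_overwritable()
def get_prepare_fx_fn(cfg, is_qat):
    # OSS default, mirroring the new fx.py above; a Turing-aware build could
    # register an odai.transforms-based replacement here.
    return prepare_qat_fx if is_qat else prepare_fx
```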