Commit 55133c79 authored by Jiaxu Zhu, committed by Facebook GitHub Bot

Retry API Integration

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/386

Pull Request resolved: https://github.com/facebookresearch/d2go/pull/382

Pull Request resolved: https://github.com/facebookresearch/d2go/pull/364

As the title says:
when `QUANTIZATION.BACKEND` is set to `turing`, call the `odai.transforms` APIs instead of the OSS quantization APIs.

Also updated `image_classification` as an example of getting a Turing-quantized model via `custom_prepare_fx`/`custom_convert_fx`.
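For context, a minimal sketch of the OSS FX graph-mode path that these hooks wrap, using only stock PyTorch APIs (PyTorch >= 1.13 style; with the `turing` backend the internal override substitutes `odai.transforms` equivalents, which are not shown here):

import torch
from torch.ao.quantization import get_default_qconfig_mapping
from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx

model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU()).eval()
example_inputs = (torch.randn(1, 3, 32, 32),)

# prepare: trace the model and insert observers
prepared = prepare_fx(model, get_default_qconfig_mapping("qnnpack"), example_inputs)
prepared(*example_inputs)  # calibration pass with representative data
# convert: replace observed modules with quantized ones
quantized = convert_fx(prepared)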

Reviewed By: wat3rBro

Differential Revision: D40282390

fbshipit-source-id: 7d6509969cfe8537153e1d59f21967eeb7801fd1
parent 0fc2cd1c
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

from typing import Tuple

import torch
from mobile_cv.common.misc.oss_utils import fb_overwritable

TORCH_VERSION: Tuple[int, ...] = tuple(int(x) for x in torch.__version__.split(".")[:2])
if TORCH_VERSION > (1, 10):
    from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx, prepare_qat_fx
else:
    from torch.quantization.quantize_fx import convert_fx, prepare_fx, prepare_qat_fx


@fb_overwritable()
def get_prepare_fx_fn(cfg, is_qat):
    return prepare_qat_fx if is_qat else prepare_fx


@fb_overwritable()
def get_convert_fx_fn(cfg, example_inputs):
    return convert_fx
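A hedged usage sketch of the factories above, assuming the OSS defaults are in effect (the OSS implementations ignore `cfg`, so `None` stands in here purely for illustration; real callers pass the d2go config so an internal `fb_overwritable` override can branch on `cfg.QUANTIZATION.BACKEND == "turing"`):

# OSS behavior: the factories hand back the stock FX functions unchanged
prepare_fn = get_prepare_fx_fn(None, is_qat=True)          # prepare_qat_fx in OSS
convert_fn = get_convert_fx_fn(None, example_inputs=None)  # convert_fx in OSS
assert prepare_fn is prepare_qat_fx
assert convert_fn is convert_fx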
@@ -17,6 +17,7 @@ from mobile_cv.arch.quantization.observer import update_stat as observer_update_
 from mobile_cv.arch.utils import fuse_utils
 from mobile_cv.common.misc.iter_utils import recursive_iterate
 
+from .fx import get_convert_fx_fn, get_prepare_fx_fn
 from .qconfig import set_backend_and_create_qconfig, smart_decode_backend
 
 TORCH_VERSION: Tuple[int, ...] = tuple(int(x) for x in torch.__version__.split(".")[:2])
@@ -291,13 +292,14 @@ def default_custom_prepare_fx(cfg, model, is_qat, example_input=None):
             " their own MetaArch."
         )
 
-    if is_qat:
-        model = prepare_qat_fx(model, qconfig_dict, (example_input,))
-    else:
-        model = prepare_fx(model, qconfig_dict, (example_input,))
+    prepare_fn = get_prepare_fx_fn(cfg, is_qat)
+    model = prepare_fn(
+        model,
+        qconfig_mapping=qconfig_dict,
+        example_inputs=(example_input,),
+    )
     logger.info("Setup the model with qconfig:\n{}".format(qconfig))
 
-    return model, convert_fx
+    convert_fn = get_convert_fx_fn(cfg, (example_input,))
+    return model, convert_fn
 
 
 def prepare_fake_quant_model(cfg, model, is_qat, example_input=None):
...
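A hedged sketch of the downstream contract after this change: callers consume the returned `convert_fn` (possibly a Turing converter internally) rather than importing `convert_fx` directly. The import path is assumed from the diff context, and `cfg` is assumed to be a d2go config with its `QUANTIZATION` options already set; it is not constructed here:

import torch

from d2go.quantization.modeling import default_custom_prepare_fx


def post_training_quantize(cfg, model, example_input):
    # prepare with the backend-appropriate function chosen by get_prepare_fx_fn
    prepared, convert_fn = default_custom_prepare_fx(
        cfg, model.eval(), is_qat=False, example_input=example_input
    )
    prepared(example_input)  # calibration with representative data
    # convert with whatever callable get_convert_fx_fn returned
    return convert_fn(prepared)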
@@ -9,6 +9,7 @@ if TORCH_VERSION > (1, 10):
     from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx, prepare_qat_fx
 else:
     from torch.quantization.quantize_fx import convert_fx, prepare_fx, prepare_qat_fx
+from mobile_cv.common.misc.oss_utils import fb_overwritable
 
 QCONFIG_CREATOR_REGISTRY = Registry("QCONFIG_CREATOR_REGISTRY")
@@ -82,6 +83,7 @@ def validate_native_backend(backend):
     )
 
 
+@fb_overwritable()
 def _smart_parse_extended_backend(extended_backend):
     """
     D2Go extends the definition of quantization "backend". In addition to PyTorch's
...