Commit 05b33018 authored by Naveen Suda's avatar Naveen Suda Committed by Facebook GitHub Bot

use custom prepare function

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/660

To enable w8a16 (8-bit weight, 16-bit activation) quantization of sigmoid in d2go, we need to use a custom prepare function.
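For context, a minimal sketch (hypothetical, not part of this diff) of how an internal build might replace the default `get_prepare_fn` so that eager-mode QAT runs a custom prepare step. The helper `_custom_w8a16_prepare_qat` and the placeholder qconfig are illustrative assumptions; the actual w8a16 qconfig for sigmoid is build-specific and not shown here.

```python
import torch
from torch.ao.quantization import get_default_qat_qconfig
from torch.ao.quantization.quantize import prepare, prepare_qat


def _custom_w8a16_prepare_qat(model, inplace=True):
    # Illustrative only: tag sigmoid modules with a qconfig before the
    # standard eager-mode prepare_qat runs. A real w8a16 setup would attach
    # a 16-bit activation fake-quant here instead of this default placeholder.
    for m in model.modules():
        if isinstance(m, torch.nn.Sigmoid):
            m.qconfig = get_default_qat_qconfig("qnnpack")
    return prepare_qat(model, inplace=inplace)


def get_prepare_fn(cfg, is_qat):
    # Same signature as the OSS default added in this diff; fb_overwritable()
    # is what lets an internal build substitute an implementation like this
    # for the OSS one.
    if cfg.QUANTIZATION.EAGER_MODE:
        return _custom_w8a16_prepare_qat if is_qat else prepare
```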

Reviewed By: ayushidalmia, jiaxuzhu92

Differential Revision: D56275899

fbshipit-source-id: 654900011a1393e81289e8c9412b5886831765e2
parent ba7c235b
@@ -6,6 +6,7 @@ from typing import Tuple
 import torch
 from mobile_cv.common.misc.oss_utils import fb_overwritable
+from torch.ao.quantization.quantize import prepare, prepare_qat

 TORCH_VERSION: Tuple[int, ...] = tuple(int(x) for x in torch.__version__.split(".")[:2])
@@ -28,3 +29,9 @@ def get_convert_fn(cfg, example_inputs=None, qconfig_mapping=None, backend_config
         return convert
     else:
         return convert_fx
+
+
+@fb_overwritable()
+def get_prepare_fn(cfg, is_qat):
+    if cfg.QUANTIZATION.EAGER_MODE:
+        return prepare_qat if is_qat else prepare
@@ -10,7 +10,7 @@ from typing import Any, Dict, Optional, Tuple
 import detectron2.utils.comm as comm
 import torch
 from d2go.quantization import learnable_qat
-from d2go.quantization.fx import get_convert_fn, get_prepare_fx_fn
+from d2go.quantization.fx import get_convert_fn, get_prepare_fn, get_prepare_fx_fn
 from d2go.quantization.qconfig import (
     set_backend_and_create_qconfig,
     smart_decode_backend,
@@ -34,6 +34,7 @@ from torch.ao.quantization.quantizer.xnnpack_quantizer import (
 )
 TORCH_VERSION: Tuple[int, ...] = tuple(int(x) for x in torch.__version__.split(".")[:2])

+# Some tests still import prepare/convert from below, so don't remove these.
 if TORCH_VERSION > (1, 10):
     from torch.ao.quantization.quantize import convert
     from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx, prepare_qat_fx
@@ -388,10 +389,8 @@ def prepare_fake_quant_model(cfg, model, is_qat, example_input=None):
         )
         model = default_prepare_for_quant(cfg, model)
         # NOTE: eager model needs to call prepare after `prepare_for_quant`
-        if is_qat:
-            torch.ao.quantization.prepare_qat(model, inplace=True)
-        else:
-            torch.ao.quantization.prepare(model, inplace=True)
+        prepare_fn = get_prepare_fn(cfg, is_qat)
+        prepare_fn(model, inplace=True)
     else:
         # FX graph mode requires the model to be symbolically traceable, swap common
......
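A quick behavioral check on the call-site change above: with the default OSS `get_prepare_fn`, the new code resolves to exactly the same functions the removed branch called. The `_Cfg` stand-in below is a hypothetical minimal config, assuming `QUANTIZATION.EAGER_MODE` is the only field consulted.

```python
from torch.ao.quantization.quantize import prepare, prepare_qat

from d2go.quantization.fx import get_prepare_fn


class _Cfg:
    class QUANTIZATION:
        EAGER_MODE = True


# The default implementation preserves the old eager-mode behavior:
assert get_prepare_fn(_Cfg, is_qat=True) is prepare_qat
assert get_prepare_fn(_Cfg, is_qat=False) is prepare
```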