Commit a2659f5c authored by Yanghan Wang's avatar Yanghan Wang Committed by Facebook GitHub Bot
Browse files

split regressor's prepare_for_quant into eager and FX

Summary: Pull Request resolved: https://github.com/facebookresearch/d2go/pull/336

Reviewed By: jerryzh168

Differential Revision: D37860495

fbshipit-source-id: 1ce0bc7bc8071d3bfbe53cd61ed180da62e29327
parent 762b4fd8
......@@ -285,9 +285,9 @@ def default_custom_convert_fx(cfg, model):
return convert_fx(model)
def convert_to_fake_quant_model(cfg, model, is_qat, example_input=None):
def prepare_fake_quant_model(cfg, model, is_qat, example_input=None):
"""
Centralized function to convert fp32 model (D2Go's MetaArch) to fake quant model.
Centralized function to prepare fp32 model (D2Go's MetaArch) to fake quant model.
"""
# TODO: create a warning for the direct use of `torch.ao.quantization.get_default_qconfig`
# or `torch.ao.quantization.get_default_qat_qconfig` without calling D2Go's high-level
......@@ -351,7 +351,7 @@ def post_training_quantize(cfg, model, data_loader):
param.requires_grad = False
example_input = next(iter(data_loader))
model = convert_to_fake_quant_model(cfg, model, False, example_input)
model = prepare_fake_quant_model(cfg, model, False, example_input)
logger.info("Prepared the PTQ model for calibration:\n{}".format(model))
# Option for forcing running calibration on GPU, works only when the model supports
......@@ -412,7 +412,7 @@ def setup_qat_model(
model_fp32_state_dict = model_fp32.state_dict()
# prepare model for qat
model = convert_to_fake_quant_model(cfg, model_fp32, True)
model = prepare_fake_quant_model(cfg, model_fp32, True)
# make sure the proper qconfig are used in the model
learnable_qat.check_for_learnable_fake_quant_ops(qat_method, model)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment