"git@developer.sourcefind.cn:OpenDAS/fairseq.git" did not exist on "885e7ec9ecbf19c36471e9aea095e2c8f238be83"
Commit cb985322 authored by Francisc Bungiu, committed by Facebook GitHub Bot

Implement Boltnn conversion

Summary:
Implementing `prepare_for_export` using the boltnn conversion from https://fburl.com/diffusion/ql1i3358.
Implementing `prepare_for_quant` using the quantization from https://fburl.com/diffusion/8nre9o03.

Differential Revision: D29817424

fbshipit-source-id: 800571ecf7f07d01c0a3a12100525354b48fe568
parent cbb6843e
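For context, models in this stack expose `prepare_for_quant` and `prepare_for_export` hooks that the export pipeline calls before quantization and tracing. Below is a minimal eager-mode sketch of what such hooks can look like; `ToyModel`, the layer choices, and the hook signatures are illustrative assumptions, and the bodies stand in for the internal Boltnn and quantization code behind the links in the summary:

```python
import torch
import torch.nn as nn
from torch.quantization import QuantStub, DeQuantStub


class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, 3)
        self.relu = nn.ReLU()
        # Stubs mark where tensors enter and leave the quantized region
        self.quant = QuantStub()
        self.dequant = DeQuantStub()

    def forward(self, x):
        return self.dequant(self.relu(self.conv(self.quant(x))))

    def prepare_for_quant(self, cfg):
        # Attach a qconfig for the configured backend before observers
        # are inserted; the real hook may also rewrite submodules.
        self.qconfig = torch.quantization.get_default_qconfig(
            cfg.QUANTIZATION.BACKEND
        )
        return self

    def prepare_for_export(self, cfg, inputs, predictor_type):
        # Hypothetical: return the model in an export-ready form; the
        # actual Boltnn conversion is internal and not reproduced here.
        return self
```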
@@ -114,7 +114,8 @@ def convert_and_export_predictor(
         )
         pytorch_model = post_training_quantize(cfg, pytorch_model, data_loader)
         # only check bn exists in ptq as qat still has bn inside fused ops
-        assert not fuse_utils.check_bn_exist(pytorch_model)
+        if fuse_utils.check_bn_exist(pytorch_model):
+            logger.warn(f"Post training quantized model has bn inside fused ops")
     logger.info(f"Converting quantized model {cfg.QUANTIZATION.BACKEND}...")
     if cfg.QUANTIZATION.EAGER_MODE:
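`fuse_utils.check_bn_exist` comes from mobile_cv; the sketch below is a hypothetical, behavior-equivalent version of such a check, together with the fuse step that makes the PTQ path bn-free, which is what the old assert relied on (QAT keeps bn inside fused ops, hence the downgrade from assert to warning):

```python
import torch
import torch.nn as nn


def check_bn_exist(model: nn.Module) -> bool:
    # True if any standalone batch-norm module is still present
    return any(
        isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d))
        for m in model.modules()
    )


# Fusing conv + bn + relu ahead of PTQ folds the bn away entirely
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU()).eval()
fused = torch.quantization.fuse_modules(model, [["0", "1", "2"]])
assert not check_bn_exist(fused)
```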
@@ -288,6 +288,9 @@ def post_training_quantize(cfg, model, data_loader):
     calibration_iters = cfg.QUANTIZATION.PTQ.CALIBRATION_NUM_IMAGES
     for idx, inputs in enumerate(data_loader):
+        # Setting CALIBRATION_NUM_IMAGES to 0 allows skipping calibration
+        if idx == calibration_iters:
+            break
         logger.info("Running calibration iter: {}/{}".format(idx, calibration_iters))
         if calibration_force_on_gpu:
@@ -299,8 +302,6 @@ def post_training_quantize(cfg, model, data_loader):
         with torch.no_grad():
             model(inputs)
-        if idx + 1 == calibration_iters:
-            break
     else:
         logger.warning("Can't run enough calibration iterations")
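Taken together, the two hunks in `post_training_quantize` move the calibration cutoff from the bottom of the loop to the top, so setting `CALIBRATION_NUM_IMAGES` to 0 now runs zero calibration iterations instead of at least one. A self-contained sketch of the resulting control flow (the reduced `calibrate` signature is illustrative; the real function also inserts observers and handles GPU placement):

```python
import logging

import torch

logger = logging.getLogger(__name__)


def calibrate(model, data_loader, calibration_iters):
    model.eval()
    for idx, inputs in enumerate(data_loader):
        # Checking at the top of the loop means calibration_iters == 0
        # skips calibration entirely.
        if idx == calibration_iters:
            break
        logger.info("Running calibration iter: %d/%d", idx, calibration_iters)
        with torch.no_grad():
            model(inputs)
    else:
        # for/else fires only when the loader ran dry before the break,
        # i.e. there was not enough data to calibrate fully.
        logger.warning("Can't run enough calibration iterations")
```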