Commit 9d649b1e authored by Yanghan Wang, committed by Facebook GitHub Bot

Work around quantization for FPN

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/163

Make quantization work for FPN. Note that this is a workaround rather than a proper fix; the proper fix is likely to make PyTorch's quantization pick up D2's Conv2d (https://github.com/facebookresearch/d2go/commit/7992f91324aee6ae59795063a007c6837e60cdb8), and this diff should be reverted once that is supported.

Differential Revision: D33523917

fbshipit-source-id: 3d00f540a9fcb75a34125c244d86263d517a359f
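
For context on the Conv2d issue mentioned in the summary: eager-mode quantization attaches observers and swaps modules by looking up the exact module type in its default mappings, so a subclass of nn.Conv2d (such as detectron2's Conv2d) is left as a float module. The snippet below is a minimal standalone sketch of that behavior, not code from this diff; the SubclassedConv2d and TinyModel names are made up.

# Minimal sketch (not part of this diff): an nn.Conv2d subclass is skipped by
# eager-mode quantization because observer insertion and module swapping key
# on the exact module type.
import torch
import torch.nn as nn
from torch.quantization import DeQuantStub, QuantStub, convert, get_default_qconfig, prepare


class SubclassedConv2d(nn.Conv2d):  # stand-in for detectron2.layers.Conv2d
    pass


class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.quant = QuantStub()
        self.plain_conv = nn.Conv2d(3, 8, 3)       # swapped to a quantized Conv2d
        self.sub_conv = SubclassedConv2d(8, 8, 3)  # skipped: exact type not in the mapping
        self.dequant = DeQuantStub()

    def forward(self, x):
        return self.dequant(self.sub_conv(self.plain_conv(self.quant(x))))


m = TinyModel().eval()
m.qconfig = get_default_qconfig("fbgemm")
prepare(m, inplace=True)
m(torch.randn(1, 3, 32, 32))  # calibration pass
convert(m, inplace=True)
print(type(m.plain_conv))  # quantized Conv2d
print(type(m.sub_conv))    # still the float SubclassedConv2d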
parent 02ecf002
@@ -9,6 +9,7 @@ import torch.nn as nn
from d2go.export.api import PredictorExportConfig
from d2go.utils.qat_utils import get_qat_qconfig
from detectron2.modeling import GeneralizedRCNN
from detectron2.modeling.backbone.fpn import FPN
from detectron2.modeling.postprocessing import detector_postprocess
from detectron2.projects.point_rend import PointRendMaskHead
from detectron2.utils.registry import Registry
@@ -113,6 +114,19 @@ def _apply_eager_mode_quant(cfg, model):
"""Wrap each quantized part of the model to insert Quant and DeQuant in-place"""
# Wrap backbone and proposal_generator
if isinstance(model.backbone, FPN):
# HACK: currently the quantization won't pick up D2's the Conv2d, which is
# used by D2's default FPN (same as FBNetV2FPN), this causes problem if we
# warpping entire backbone as whole. The current solution is only quantizing
# bottom_up and leaving other parts un-quantized. TODO (T109761730): However
# we need to re-visit this if using other (fbnet-based) FPN module since the
# new FPN module might be pikced by quantization.
model.backbone.bottom_up = wrap_quant_subclass(
model.backbone.bottom_up,
n_inputs=1,
n_outputs=len(model.backbone.bottom_up._out_features),
)
else:
model.backbone = wrap_quant_subclass(
model.backbone, n_inputs=1, n_outputs=len(model.backbone._out_features)
)
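
(Aside, not part of this diff: conceptually, wrapping a submodule for eager-mode quantization surrounds it with QuantStub/DeQuantStub, so with the code above only bottom_up runs quantized while the FPN lateral/output convs stay in fp32. The wrapper below is a hypothetical sketch of that idea, not d2go's actual wrap_quant_subclass implementation.)

# Hypothetical sketch of the bottom_up-only wrapping idea; not d2go's
# actual wrap_quant_subclass.
import torch.nn as nn
from torch.quantization import DeQuantStub, QuantStub


class QuantWrappedBottomUp(nn.Module):
    """Quantize the single input of the wrapped module and dequantize each of its
    outputs, so everything outside the wrapper keeps running in fp32."""

    def __init__(self, bottom_up, out_features):
        super().__init__()
        self.bottom_up = bottom_up
        self.out_features = list(out_features)
        self.quant = QuantStub()  # n_inputs=1
        self.dequants = nn.ModuleList(DeQuantStub() for _ in self.out_features)  # one per output

    def forward(self, x):
        feats = self.bottom_up(self.quant(x))  # dict of feature maps keyed by out_feature name
        return {k: dq(feats[k]) for k, dq in zip(self.out_features, self.dequants)}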
@@ -175,6 +189,7 @@ def _apply_eager_mode_quant(cfg, model):
def _fx_quant_prepare(self, cfg):
    prep_fn = prepare_qat_fx if self.training else prepare_fx
    qconfig = {"": self.qconfig}
    assert not isinstance(self.backbone, FPN), "FPN is not supported in FX mode"
    self.backbone = prep_fn(
        self.backbone,
        qconfig,
@@ -261,6 +276,7 @@ def default_rcnn_prepare_for_quant_convert(self, cfg):
    if cfg.QUANTIZATION.EAGER_MODE:
        raise NotImplementedError()
    assert not isinstance(self.backbone, FPN), "FPN is not supported in FX mode"
    self.backbone = convert_fx(
        self.backbone,
        convert_custom_config_dict={"preserved_attributes": ["size_divisibility"]},
......
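
(Aside, not part of this diff: in FX mode the backbone is traced into a GraphModule, which only keeps attributes reachable from forward; that is why size_divisibility must be listed under preserved_attributes. The sketch below uses a made-up TinyBackbone and the older dict-based custom-config arguments that this file uses; newer PyTorch releases renamed these arguments.)

# Hypothetical sketch of the FX-mode prepare/convert flow for an FX-traceable,
# non-FPN backbone; TinyBackbone and the shapes are made up.
import torch
import torch.nn as nn
from torch.quantization import get_default_qconfig
from torch.quantization.quantize_fx import convert_fx, prepare_fx


class TinyBackbone(nn.Module):
    size_divisibility = 32  # not used inside forward, so tracing would drop it

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, 3)

    def forward(self, x):
        return self.conv(x)


backbone = TinyBackbone().eval()
qconfig = {"": get_default_qconfig("fbgemm")}
prepared = prepare_fx(
    backbone,
    qconfig,
    prepare_custom_config_dict={"preserved_attributes": ["size_divisibility"]},
)
prepared(torch.randn(1, 3, 64, 64))  # calibration pass
quantized = convert_fx(
    prepared,
    convert_custom_config_dict={"preserved_attributes": ["size_divisibility"]},
)
assert quantized.size_divisibility == 32  # attribute survives on the GraphModule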
@@ -41,6 +41,28 @@ class TestFBNetV3MaskRCNNFP32(RCNNBaseTestCases.TemplateTestCase):
        self._test_export(predictor_type, compare_match=compare_match)


class TestFBNetV3MaskRCNNFPNFP32(RCNNBaseTestCases.TemplateTestCase):
    def setup_custom_test(self):
        super().setup_custom_test()
        self.cfg.merge_from_file("detectron2go://mask_rcnn_fbnetv3g_fpn.yaml")

    def test_inference(self):
        self._test_inference()

    @RCNNBaseTestCases.expand_parameterized_test_export(
        [
            ["torchscript@c2_ops", True],
            ["torchscript", True],
            ["torchscript_int8@c2_ops", False],
            ["torchscript_int8", False],
        ]
    )
    def test_export(self, predictor_type, compare_match):
        if os.getenv("OSSRUN") == "1" and "@c2_ops" in predictor_type:
            self.skipTest("Caffe2 is not available for OSS")
        self._test_export(predictor_type, compare_match=compare_match)


class TestFBNetV3MaskRCNNQATEager(RCNNBaseTestCases.TemplateTestCase):
    def setup_custom_test(self):
        super().setup_custom_test()
......