Commit 9d649b1e authored by Yanghan Wang, committed by Facebook GitHub Bot

workaround the quantization for FPN

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/163

Make quantizing FPN work. Note that this is not a proper fix; the proper fix might be making PyTorch pick up D2's Conv2d (https://github.com/facebookresearch/d2go/commit/7992f91324aee6ae59795063a007c6837e60cdb8), and we need to revert this diff once that is supported.

Differential Revision: D33523917

fbshipit-source-id: 3d00f540a9fcb75a34125c244d86263d517a359f
parent 02ecf002
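For context on the diff below: the eager-mode path relies on the `wrap_quant_subclass` helper used throughout this file, whose implementation is not part of this diff. The sketch below is a hypothetical stand-in (the `QuantWrapSketch` name and body are made up) that only illustrates the general pattern the call sites assume: one QuantStub per input and one DeQuantStub per output are placed around the wrapped module, so eager-mode prepare/convert can insert observers and quantized kernels there.

import torch.nn as nn
import torch.ao.quantization as tq


class QuantWrapSketch(nn.Module):
    """Conceptual stand-in for wrap_quant_subclass (not d2go's real helper)."""

    def __init__(self, module: nn.Module, n_inputs: int, n_outputs: int):
        super().__init__()
        self.module = module
        # One QuantStub per input and one DeQuantStub per output, mirroring the
        # n_inputs / n_outputs arguments seen at the call sites in this diff.
        self.quant_stubs = nn.ModuleList([tq.QuantStub() for _ in range(n_inputs)])
        self.dequant_stubs = nn.ModuleList([tq.DeQuantStub() for _ in range(n_outputs)])

    def forward(self, *inputs):
        # Quantize each input, run the wrapped module, de-quantize each output.
        q_inputs = [stub(x) for stub, x in zip(self.quant_stubs, inputs)]
        outputs = self.module(*q_inputs)
        if not isinstance(outputs, (tuple, list)):
            outputs = (outputs,)
        deq = [stub(y) for stub, y in zip(self.dequant_stubs, outputs)]
        return deq[0] if len(deq) == 1 else tuple(deq)

The real helper presumably also preserves the wrapped module's class and attributes (e.g. `_out_features`, `size_divisibility`), which downstream detectron2 code still reads; a plain wrapper like this sketch would not.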
@@ -9,6 +9,7 @@ import torch.nn as nn
 from d2go.export.api import PredictorExportConfig
 from d2go.utils.qat_utils import get_qat_qconfig
 from detectron2.modeling import GeneralizedRCNN
+from detectron2.modeling.backbone.fpn import FPN
 from detectron2.modeling.postprocessing import detector_postprocess
 from detectron2.projects.point_rend import PointRendMaskHead
 from detectron2.utils.registry import Registry
@@ -113,9 +114,22 @@ def _apply_eager_mode_quant(cfg, model):
     """Wrap each quantized part of the model to insert Quant and DeQuant in-place"""
     # Wrap backbone and proposal_generator
-    model.backbone = wrap_quant_subclass(
-        model.backbone, n_inputs=1, n_outputs=len(model.backbone._out_features)
-    )
+    if isinstance(model.backbone, FPN):
+        # HACK: currently the quantization won't pick up D2's Conv2d, which is used
+        # by D2's default FPN (same as FBNetV2FPN); this causes problems if we wrap
+        # the entire backbone as a whole. The current solution is to only quantize
+        # bottom_up and leave the other parts un-quantized. TODO (T109761730): we
+        # need to re-visit this if using another (fbnet-based) FPN module, since the
+        # new FPN module might be picked up by quantization.
+        model.backbone.bottom_up = wrap_quant_subclass(
+            model.backbone.bottom_up,
+            n_inputs=1,
+            n_outputs=len(model.backbone.bottom_up._out_features),
+        )
+    else:
+        model.backbone = wrap_quant_subclass(
+            model.backbone, n_inputs=1, n_outputs=len(model.backbone._out_features)
+        )
     model.proposal_generator.rpn_head = wrap_quant_subclass(
         model.proposal_generator.rpn_head,
         n_inputs=len(cfg.MODEL.RPN.IN_FEATURES),
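As a concrete illustration of what the hunk above means by only quantizing `bottom_up` and leaving the other parts un-quantized, here is a minimal, self-contained eager-mode example in plain PyTorch (not d2go code; `QuantizableTrunk` and `ToyBackbone` are made-up names). Assigning a qconfig only to one sub-module makes prepare/convert leave every other module in fp32, which is the same effect the FPN branch achieves by wrapping only `backbone.bottom_up`.

import torch
import torch.nn as nn
import torch.ao.quantization as tq


class QuantizableTrunk(nn.Module):
    """Stands in for backbone.bottom_up: carries its own Quant/DeQuant stubs."""

    def __init__(self):
        super().__init__()
        self.quant = tq.QuantStub()
        self.conv = nn.Conv2d(3, 8, 3, padding=1)
        self.relu = nn.ReLU()
        self.dequant = tq.DeQuantStub()

    def forward(self, x):
        return self.dequant(self.relu(self.conv(self.quant(x))))


class ToyBackbone(nn.Module):
    """Stands in for the FPN backbone: the trunk is quantized, the rest stays fp32."""

    def __init__(self):
        super().__init__()
        self.bottom_up = QuantizableTrunk()
        self.lateral = nn.Conv2d(8, 8, 1)  # analogous to FPN convs, left un-quantized

    def forward(self, x):
        return self.lateral(self.bottom_up(x))


model = ToyBackbone().eval()
# qconfig is set only on the sub-module we want quantized; modules without a
# qconfig are skipped by prepare/convert and keep running in fp32.
model.bottom_up.qconfig = tq.get_default_qconfig("fbgemm")
tq.prepare(model, inplace=True)
model(torch.randn(1, 3, 32, 32))  # calibration pass
tq.convert(model, inplace=True)
print(type(model.bottom_up.conv))  # quantized Conv2d
print(type(model.lateral))         # still torch.nn.Conv2d (fp32)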
@@ -175,6 +189,7 @@ def _apply_eager_mode_quant(cfg, model):
 def _fx_quant_prepare(self, cfg):
     prep_fn = prepare_qat_fx if self.training else prepare_fx
     qconfig = {"": self.qconfig}
+    assert not isinstance(self.backbone, FPN), "FPN is not supported in FX mode"
     self.backbone = prep_fn(
         self.backbone,
         qconfig,
@@ -261,6 +276,7 @@ def default_rcnn_prepare_for_quant_convert(self, cfg):
     if cfg.QUANTIZATION.EAGER_MODE:
         raise NotImplementedError()
+    assert not isinstance(self.backbone, FPN), "FPN is not supported in FX mode"
     self.backbone = convert_fx(
         self.backbone,
         convert_custom_config_dict={"preserved_attributes": ["size_divisibility"]},
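The two asserts added in the FX-mode paths above simply refuse FPN backbones rather than working around them. For reference, below is a minimal standalone sketch of the graph-mode flow they guard (not d2go code; `TinyBackbone` is made up, and the dict-style `*_custom_config_dict` arguments follow the PyTorch ~1.10-era API this file uses; newer releases expect `example_inputs` and config objects instead).

import torch
import torch.nn as nn
from torch.ao.quantization import get_default_qconfig
from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx


class TinyBackbone(nn.Module):
    size_divisibility = 32  # attribute callers still read after conversion

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, 3, padding=1)

    def forward(self, x):
        return torch.relu(self.conv(x))


backbone = TinyBackbone().eval()
qconfig_dict = {"": get_default_qconfig("fbgemm")}  # same shape as {"": self.qconfig}
prepared = prepare_fx(
    backbone,
    qconfig_dict,
    prepare_custom_config_dict={"preserved_attributes": ["size_divisibility"]},
)
prepared(torch.randn(1, 3, 32, 32))  # calibration pass
quantized = convert_fx(
    prepared,
    convert_custom_config_dict={"preserved_attributes": ["size_divisibility"]},
)
assert quantized.size_divisibility == 32  # attribute survived tracing and convert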
@@ -41,6 +41,28 @@ class TestFBNetV3MaskRCNNFP32(RCNNBaseTestCases.TemplateTestCase):
         self._test_export(predictor_type, compare_match=compare_match)
+
+
+class TestFBNetV3MaskRCNNFPNFP32(RCNNBaseTestCases.TemplateTestCase):
+    def setup_custom_test(self):
+        super().setup_custom_test()
+        self.cfg.merge_from_file("detectron2go://mask_rcnn_fbnetv3g_fpn.yaml")
+
+    def test_inference(self):
+        self._test_inference()
+
+    @RCNNBaseTestCases.expand_parameterized_test_export(
+        [
+            ["torchscript@c2_ops", True],
+            ["torchscript", True],
+            ["torchscript_int8@c2_ops", False],
+            ["torchscript_int8", False],
+        ]
+    )
+    def test_export(self, predictor_type, compare_match):
+        if os.getenv("OSSRUN") == "1" and "@c2_ops" in predictor_type:
+            self.skipTest("Caffe2 is not available for OSS")
+        self._test_export(predictor_type, compare_match=compare_match)
 
 
 class TestFBNetV3MaskRCNNQATEager(RCNNBaseTestCases.TemplateTestCase):
     def setup_custom_test(self):
         super().setup_custom_test()