Commit 5aa5e3c8 authored by Yanghan Wang, committed by Facebook GitHub Bot

allow replacing prepare_for_export/quant for RCNN via config

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/67

When extending RCNN, we sometimes also need to modify `prepare_for_export`/`prepare_for_quant`. Creating a new meta-arch just to override these methods is inconvenient, so this diff adds registries that let users swap them via config.
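For illustration, the intended workflow is to register a replacement function and point the config at it by name. A minimal sketch follows; this is hypothetical user code, not part of the diff (the function name `my_rcnn_prepare_for_export` is made up, while the registry and default come from the new d2go/modeling/meta_arch/rcnn.py below):

# Hypothetical user code: customize export behavior without creating a new meta-arch.
from d2go.modeling.meta_arch.rcnn import (
    RCNN_PREPARE_FOR_EXPORT_REGISTRY,
    default_rcnn_prepare_for_export,
)


@RCNN_PREPARE_FOR_EXPORT_REGISTRY.register()
def my_rcnn_prepare_for_export(self, cfg, inputs, predictor_type):
    # tweak inputs / predictor_type here, then reuse the default export path
    return default_rcnn_prepare_for_export(self, cfg, inputs, predictor_type)


# selected in the config: RCNN_PREPARE_FOR_EXPORT: "my_rcnn_prepare_for_export"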

Reviewed By: zhanghang1989

Differential Revision: D28308056

fbshipit-source-id: 4f169eb38292a75d15d3b9f44694480eaa9244e0
parent 608bf2ec
@@ -5,196 +5,31 @@

New version (added by this diff):

import logging
from functools import lru_cache

from d2go.modeling.meta_arch.rcnn import GeneralizedRCNNPatch
from detectron2.modeling import GeneralizedRCNN

logger = logging.getLogger(__name__)


@lru_cache()  # only call once
def patch_d2_meta_arch():
    """
    D2Go requires interfaces like prepare_for_export/prepare_for_quant from meta-arch in
    order to do export/quant, this function applies the monkey patch to the original
    D2's meta-archs.
    """

    def _check_and_set(cls_obj, method_name, method_func):
        if hasattr(cls_obj, method_name):
            assert getattr(cls_obj, method_name) == method_func
        else:
            setattr(cls_obj, method_name, method_func)

    def _apply_patch(dst_cls, src_cls):
        assert hasattr(src_cls, "METHODS_TO_PATCH")
        for method_name in src_cls.METHODS_TO_PATCH:
            assert hasattr(src_cls, method_name)
            _check_and_set(dst_cls, method_name, getattr(src_cls, method_name))

    _apply_patch(GeneralizedRCNN, GeneralizedRCNNPatch)
    # TODO: patch other meta-archs defined in D2

Old version (removed by this diff):

import logging
from functools import lru_cache

import torch
from d2go.export.api import PredictorExportConfig
from d2go.utils.prepare_for_export import d2_meta_arch_prepare_for_export
from detectron2.export.caffe2_modeling import (
    META_ARCH_CAFFE2_EXPORT_TYPE_MAP,
    convert_batched_inputs_to_c2_format,
)
from detectron2.export.shared import get_pb_arg_vali, get_pb_arg_vals
from detectron2.modeling import META_ARCH_REGISTRY, GeneralizedRCNN
from detectron2.modeling.postprocessing import detector_postprocess
from mobile_cv.arch.utils import fuse_utils
from mobile_cv.arch.utils.quantize_utils import (
    wrap_non_quant_group_norm,
    wrap_quant_subclass,
    QuantWrapper,
)
from mobile_cv.predictor.api import FuncInfo
from torch.quantization.quantize_fx import prepare_fx, prepare_qat_fx, convert_fx
from detectron2.projects.point_rend import PointRendMaskHead

logger = logging.getLogger(__name__)


@lru_cache()  # only call once
def patch_d2_meta_arch():
    # HACK: inject prepare_for_export for all D2's meta-arch
    for cls_obj in META_ARCH_REGISTRY._obj_map.values():
        if cls_obj.__module__.startswith("detectron2."):
            if hasattr(cls_obj, "prepare_for_export"):
                assert cls_obj.prepare_for_export == d2_meta_arch_prepare_for_export
            else:
                cls_obj.prepare_for_export = d2_meta_arch_prepare_for_export

            if hasattr(cls_obj, "prepare_for_quant"):
                assert cls_obj.prepare_for_quant == d2_meta_arch_prepare_for_quant
            else:
                cls_obj.prepare_for_quant = d2_meta_arch_prepare_for_quant

            if hasattr(cls_obj, "prepare_for_quant_convert"):
                assert (
                    cls_obj.prepare_for_quant_convert
                    == d2_meta_arch_prepare_for_quant_convert
                )
            else:
                cls_obj.prepare_for_quant_convert = (
                    d2_meta_arch_prepare_for_quant_convert
                )
def _apply_eager_mode_quant(cfg, model):
if isinstance(model, GeneralizedRCNN):
""" Wrap each quantized part of the model to insert Quant and DeQuant in-place """
# Wrap backbone and proposal_generator
model.backbone = wrap_quant_subclass(
model.backbone, n_inputs=1, n_outputs=len(model.backbone._out_features)
)
model.proposal_generator.rpn_head = wrap_quant_subclass(
model.proposal_generator.rpn_head,
n_inputs=len(cfg.MODEL.RPN.IN_FEATURES),
n_outputs=len(cfg.MODEL.RPN.IN_FEATURES) * 2,
)
# Wrap the roi_heads, box_pooler is not quantized
if hasattr(model.roi_heads, "box_head"):
model.roi_heads.box_head = wrap_quant_subclass(
model.roi_heads.box_head,
n_inputs=1,
n_outputs=1,
)
# for faster_rcnn_R_50_C4
if hasattr(model.roi_heads, "res5"):
model.roi_heads.res5 = wrap_quant_subclass(
model.roi_heads.res5,
n_inputs=1,
n_outputs=1,
)
model.roi_heads.box_predictor = wrap_quant_subclass(
model.roi_heads.box_predictor, n_inputs=1, n_outputs=2
)
# Optionally wrap keypoint and mask heads, pools are not quantized
if hasattr(model.roi_heads, "keypoint_head"):
model.roi_heads.keypoint_head = wrap_quant_subclass(
model.roi_heads.keypoint_head,
n_inputs=1,
n_outputs=1,
wrapped_method_name="layers",
)
if hasattr(model.roi_heads, "mask_head"):
model.roi_heads.mask_head = wrap_quant_subclass(
model.roi_heads.mask_head,
n_inputs=1,
n_outputs=1,
wrapped_method_name="layers",
)
# StandardROIHeadsWithSubClass uses a subclass head
if hasattr(model.roi_heads, "subclass_head"):
q_subclass_head = QuantWrapper(model.roi_heads.subclass_head)
model.roi_heads.subclass_head = q_subclass_head
else:
raise NotImplementedError(
"Eager mode for {} is not supported".format(type(model))
)
# TODO: wrap the normalizer and make it quantizable
# NOTE: GN is not quantizable, assuming all GN follows a quantized conv,
# wrap them with dequant-quant
model = wrap_non_quant_group_norm(model)
return model
def _fx_quant_prepare(self, cfg):
prep_fn = prepare_qat_fx if self.training else prepare_fx
qconfig = {"": self.qconfig}
self.backbone = prep_fn(
self.backbone,
qconfig,
{"preserved_attributes": ["size_divisibility"]},
)
self.proposal_generator.rpn_head.rpn_feature = prep_fn(
self.proposal_generator.rpn_head.rpn_feature, qconfig
)
self.proposal_generator.rpn_head.rpn_regressor.cls_logits = prep_fn(
self.proposal_generator.rpn_head.rpn_regressor.cls_logits, qconfig
)
self.proposal_generator.rpn_head.rpn_regressor.bbox_pred = prep_fn(
self.proposal_generator.rpn_head.rpn_regressor.bbox_pred, qconfig
)
self.roi_heads.box_head.roi_box_conv = prep_fn(
self.roi_heads.box_head.roi_box_conv, qconfig
)
self.roi_heads.box_head.avgpool = prep_fn(self.roi_heads.box_head.avgpool, qconfig)
self.roi_heads.box_predictor.cls_score = prep_fn(
self.roi_heads.box_predictor.cls_score, qconfig
)
self.roi_heads.box_predictor.bbox_pred = prep_fn(
self.roi_heads.box_predictor.bbox_pred, qconfig
)
def d2_meta_arch_prepare_for_quant(self, cfg):
model = self
torch.backends.quantized.engine = cfg.QUANTIZATION.BACKEND
model.qconfig = (
torch.quantization.get_default_qat_qconfig(cfg.QUANTIZATION.BACKEND)
if model.training
else torch.quantization.get_default_qconfig(cfg.QUANTIZATION.BACKEND)
)
if hasattr(model, "roi_heads") and hasattr(model.roi_heads, "mask_head") and \
isinstance(model.roi_heads.mask_head, PointRendMaskHead):
model.roi_heads.mask_head.qconfig = None
logger.info("Setup the model with qconfig:\n{}".format(model.qconfig))
# Modify the model for eager mode
if cfg.QUANTIZATION.EAGER_MODE:
model = _apply_eager_mode_quant(cfg, model)
model = fuse_utils.fuse_model(model, inplace=True)
else:
_fx_quant_prepare(model, cfg)
return model
def d2_meta_arch_prepare_for_quant_convert(self, cfg):
if cfg.QUANTIZATION.EAGER_MODE:
raise NotImplementedError()
self.backbone = convert_fx(
self.backbone,
convert_custom_config_dict={"preserved_attributes": ["size_divisibility"]},
)
self.proposal_generator.rpn_head.rpn_feature = convert_fx(
self.proposal_generator.rpn_head.rpn_feature
)
self.proposal_generator.rpn_head.rpn_regressor.cls_logits = convert_fx(
self.proposal_generator.rpn_head.rpn_regressor.cls_logits
)
self.proposal_generator.rpn_head.rpn_regressor.bbox_pred = convert_fx(
self.proposal_generator.rpn_head.rpn_regressor.bbox_pred
)
self.roi_heads.box_head.roi_box_conv = convert_fx(
self.roi_heads.box_head.roi_box_conv
)
self.roi_heads.box_head.avgpool = convert_fx(self.roi_heads.box_head.avgpool)
self.roi_heads.box_predictor.cls_score = convert_fx(
self.roi_heads.box_predictor.cls_score
)
self.roi_heads.box_predictor.bbox_pred = convert_fx(
self.roi_heads.box_predictor.bbox_pred
)
return self
d2go/modeling/meta_arch/rcnn.py (referenced by the new import above):
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import logging
import torch
from d2go.export.api import PredictorExportConfig
from d2go.utils.export_utils import (
D2Caffe2MetaArchPreprocessFunc,
D2Caffe2MetaArchPostprocessFunc,
D2RCNNTracingWrapper,
)
from detectron2.export.caffe2_modeling import META_ARCH_CAFFE2_EXPORT_TYPE_MAP
from detectron2.modeling import GeneralizedRCNN
from detectron2.projects.point_rend import PointRendMaskHead
from detectron2.utils.registry import Registry
from mobile_cv.arch.utils import fuse_utils
from mobile_cv.arch.utils.quantize_utils import (
wrap_non_quant_group_norm,
wrap_quant_subclass,
QuantWrapper,
)
from mobile_cv.predictor.api import FuncInfo
from torch.quantization.quantize_fx import prepare_fx, prepare_qat_fx, convert_fx
logger = logging.getLogger(__name__)
# NOTE: Customized heads are often used in the GeneralizedRCNN, which also creates the
# need to customize the export/quant APIs; therefore registries are provided for easy
# override without creating new meta-archs. For other, less general meta-archs, this
# type of registry might be overkill.
RCNN_PREPARE_FOR_EXPORT_REGISTRY = Registry("RCNN_PREPARE_FOR_EXPORT")
RCNN_PREPARE_FOR_QUANT_REGISTRY = Registry("RCNN_PREPARE_FOR_QUANT")
RCNN_PREPARE_FOR_QUANT_CONVERT_REGISTRY = Registry("RCNN_PREPARE_FOR_QUANT_CONVERT")
class GeneralizedRCNNPatch:
METHODS_TO_PATCH = [
"prepare_for_export",
"prepare_for_quant",
"prepare_for_quant_convert",
]
def prepare_for_export(self, cfg, *args, **kwargs):
func = RCNN_PREPARE_FOR_EXPORT_REGISTRY.get(cfg.RCNN_PREPARE_FOR_EXPORT)
return func(self, cfg, *args, **kwargs)
def prepare_for_quant(self, cfg, *args, **kwargs):
func = RCNN_PREPARE_FOR_QUANT_REGISTRY.get(cfg.RCNN_PREPARE_FOR_QUANT)
return func(self, cfg, *args, **kwargs)
def prepare_for_quant_convert(self, cfg, *args, **kwargs):
func = RCNN_PREPARE_FOR_QUANT_CONVERT_REGISTRY.get(
cfg.RCNN_PREPARE_FOR_QUANT_CONVERT
)
return func(self, cfg, *args, **kwargs)
@RCNN_PREPARE_FOR_EXPORT_REGISTRY.register()
def default_rcnn_prepare_for_export(self, cfg, inputs, predictor_type):
if "torchscript" in predictor_type and "@tracing" in predictor_type:
return PredictorExportConfig(
model=D2RCNNTracingWrapper(self),
data_generator=D2RCNNTracingWrapper.generator_trace_inputs,
run_func_info=FuncInfo.gen_func_info(
D2RCNNTracingWrapper.RunFunc, params={}
),
)
if cfg.MODEL.META_ARCHITECTURE in META_ARCH_CAFFE2_EXPORT_TYPE_MAP:
C2MetaArch = META_ARCH_CAFFE2_EXPORT_TYPE_MAP[cfg.MODEL.META_ARCHITECTURE]
c2_compatible_model = C2MetaArch(cfg, self)
preprocess_info = FuncInfo.gen_func_info(
D2Caffe2MetaArchPreprocessFunc,
params=D2Caffe2MetaArchPreprocessFunc.get_params(cfg, c2_compatible_model),
)
postprocess_info = FuncInfo.gen_func_info(
D2Caffe2MetaArchPostprocessFunc,
params=D2Caffe2MetaArchPostprocessFunc.get_params(cfg, c2_compatible_model),
)
preprocess_func = preprocess_info.instantiate()
return PredictorExportConfig(
model=c2_compatible_model,
# Caffe2MetaArch takes a single tuple as input (which is the return of
# preprocess_func), data_generator requires all positional args as a tuple.
data_generator=lambda x: (preprocess_func(x),),
preprocess_info=preprocess_info,
postprocess_info=postprocess_info,
)
raise NotImplementedError("Can't determine prepare_for_tracing!")
def _apply_eager_mode_quant(cfg, model):
if isinstance(model, GeneralizedRCNN):
"""Wrap each quantized part of the model to insert Quant and DeQuant in-place"""
# Wrap backbone and proposal_generator
model.backbone = wrap_quant_subclass(
model.backbone, n_inputs=1, n_outputs=len(model.backbone._out_features)
)
model.proposal_generator.rpn_head = wrap_quant_subclass(
model.proposal_generator.rpn_head,
n_inputs=len(cfg.MODEL.RPN.IN_FEATURES),
n_outputs=len(cfg.MODEL.RPN.IN_FEATURES) * 2,
)
# Wrap the roi_heads, box_pooler is not quantized
if hasattr(model.roi_heads, "box_head"):
model.roi_heads.box_head = wrap_quant_subclass(
model.roi_heads.box_head,
n_inputs=1,
n_outputs=1,
)
# for faster_rcnn_R_50_C4
if hasattr(model.roi_heads, "res5"):
model.roi_heads.res5 = wrap_quant_subclass(
model.roi_heads.res5,
n_inputs=1,
n_outputs=1,
)
model.roi_heads.box_predictor = wrap_quant_subclass(
model.roi_heads.box_predictor, n_inputs=1, n_outputs=2
)
# Optionally wrap keypoint and mask heads, pools are not quantized
if hasattr(model.roi_heads, "keypoint_head"):
model.roi_heads.keypoint_head = wrap_quant_subclass(
model.roi_heads.keypoint_head,
n_inputs=1,
n_outputs=1,
wrapped_method_name="layers",
)
if hasattr(model.roi_heads, "mask_head"):
model.roi_heads.mask_head = wrap_quant_subclass(
model.roi_heads.mask_head,
n_inputs=1,
n_outputs=1,
wrapped_method_name="layers",
)
# StandardROIHeadsWithSubClass uses a subclass head
if hasattr(model.roi_heads, "subclass_head"):
q_subclass_head = QuantWrapper(model.roi_heads.subclass_head)
model.roi_heads.subclass_head = q_subclass_head
else:
raise NotImplementedError(
"Eager mode for {} is not supported".format(type(model))
)
# TODO: wrap the normalizer and make it quantizable
# NOTE: GN is not quantizable, assuming all GN follows a quantized conv,
# wrap them with dequant-quant
model = wrap_non_quant_group_norm(model)
return model
def _fx_quant_prepare(self, cfg):
prep_fn = prepare_qat_fx if self.training else prepare_fx
qconfig = {"": self.qconfig}
self.backbone = prep_fn(
self.backbone,
qconfig,
{"preserved_attributes": ["size_divisibility"]},
)
self.proposal_generator.rpn_head.rpn_feature = prep_fn(
self.proposal_generator.rpn_head.rpn_feature, qconfig
)
self.proposal_generator.rpn_head.rpn_regressor.cls_logits = prep_fn(
self.proposal_generator.rpn_head.rpn_regressor.cls_logits, qconfig
)
self.proposal_generator.rpn_head.rpn_regressor.bbox_pred = prep_fn(
self.proposal_generator.rpn_head.rpn_regressor.bbox_pred, qconfig
)
self.roi_heads.box_head.roi_box_conv = prep_fn(
self.roi_heads.box_head.roi_box_conv, qconfig
)
self.roi_heads.box_head.avgpool = prep_fn(self.roi_heads.box_head.avgpool, qconfig)
self.roi_heads.box_predictor.cls_score = prep_fn(
self.roi_heads.box_predictor.cls_score, qconfig
)
self.roi_heads.box_predictor.bbox_pred = prep_fn(
self.roi_heads.box_predictor.bbox_pred, qconfig
)
@RCNN_PREPARE_FOR_QUANT_REGISTRY.register()
def default_rcnn_prepare_for_quant(self, cfg):
model = self
torch.backends.quantized.engine = cfg.QUANTIZATION.BACKEND
model.qconfig = (
torch.quantization.get_default_qat_qconfig(cfg.QUANTIZATION.BACKEND)
if model.training
else torch.quantization.get_default_qconfig(cfg.QUANTIZATION.BACKEND)
)
if (
hasattr(model, "roi_heads")
and hasattr(model.roi_heads, "mask_head")
and isinstance(model.roi_heads.mask_head, PointRendMaskHead)
):
model.roi_heads.mask_head.qconfig = None
logger.info("Setup the model with qconfig:\n{}".format(model.qconfig))
# Modify the model for eager mode
if cfg.QUANTIZATION.EAGER_MODE:
model = _apply_eager_mode_quant(cfg, model)
model = fuse_utils.fuse_model(model, inplace=True)
else:
_fx_quant_prepare(model, cfg)
return model
@RCNN_PREPARE_FOR_QUANT_CONVERT_REGISTRY.register()
def default_rcnn_prepare_for_quant_convert(self, cfg):
if cfg.QUANTIZATION.EAGER_MODE:
raise NotImplementedError()
self.backbone = convert_fx(
self.backbone,
convert_custom_config_dict={"preserved_attributes": ["size_divisibility"]},
)
self.proposal_generator.rpn_head.rpn_feature = convert_fx(
self.proposal_generator.rpn_head.rpn_feature
)
self.proposal_generator.rpn_head.rpn_regressor.cls_logits = convert_fx(
self.proposal_generator.rpn_head.rpn_regressor.cls_logits
)
self.proposal_generator.rpn_head.rpn_regressor.bbox_pred = convert_fx(
self.proposal_generator.rpn_head.rpn_regressor.bbox_pred
)
self.roi_heads.box_head.roi_box_conv = convert_fx(
self.roi_heads.box_head.roi_box_conv
)
self.roi_heads.box_head.avgpool = convert_fx(self.roi_heads.box_head.avgpool)
self.roi_heads.box_predictor.cls_score = convert_fx(
self.roi_heads.box_predictor.cls_score
)
self.roi_heads.box_predictor.bbox_pred = convert_fx(
self.roi_heads.box_predictor.bbox_pred
)
return self
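Tying the pieces together: after `patch_d2_meta_arch()` (shown in the first hunk) has run, these `GeneralizedRCNNPatch` methods live on `GeneralizedRCNN` itself, and each call resolves its implementation by name through the registries above. A rough, assumption-laden sketch of that flow; the config setup and call sites are not part of this diff:

# Assumed flow, for illustration only; patch_d2_meta_arch() is the function from the
# first hunk, and GeneralizedRCNNRunner is the runner whose defaults gain the new keys
# in the hunk below.
from detectron2.modeling import build_model

cfg = GeneralizedRCNNRunner.get_default_cfg()  # includes the RCNN_PREPARE_FOR_* defaults
patch_d2_meta_arch()                           # copies GeneralizedRCNNPatch methods onto GeneralizedRCNN
model = build_model(cfg)                       # a GeneralizedRCNN instance
model = model.prepare_for_quant(cfg)           # dispatches to cfg.RCNN_PREPARE_FOR_QUANT via the registry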
@@ -663,4 +663,9 @@ class GeneralizedRCNNRunner(Detectron2GoRunner):
         _C = super(GeneralizedRCNNRunner, GeneralizedRCNNRunner).get_default_cfg()
         _C.EXPORT_CAFFE2 = CN()
         _C.EXPORT_CAFFE2.USE_HEATMAP_MAX_KEYPOINT = False
+        _C.RCNN_PREPARE_FOR_EXPORT = "default_rcnn_prepare_for_export"
+        _C.RCNN_PREPARE_FOR_QUANT = "default_rcnn_prepare_for_quant"
+        _C.RCNN_PREPARE_FOR_QUANT_CONVERT = "default_rcnn_prepare_for_quant_convert"
         return _C
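Since these defaults point at the functions registered in rcnn.py, existing configs keep the current behavior; a hypothetical override only changes the name, for example:

# Hypothetical override; "my_rcnn_prepare_for_export" must already be registered in
# RCNN_PREPARE_FOR_EXPORT_REGISTRY (see the registration sketch in the summary above).
cfg = GeneralizedRCNNRunner.get_default_cfg()
cfg.RCNN_PREPARE_FOR_EXPORT = "my_rcnn_prepare_for_export"
# the quant hooks keep their default implementations
assert cfg.RCNN_PREPARE_FOR_QUANT == "default_rcnn_prepare_for_quant"
assert cfg.RCNN_PREPARE_FOR_QUANT_CONVERT == "default_rcnn_prepare_for_quant_convert"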
d2go/utils/prepare_for_export.py (referenced by the removed import in the first hunk):
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import logging
from d2go.export.api import PredictorExportConfig
from d2go.utils.export_utils import (
D2Caffe2MetaArchPreprocessFunc,
D2Caffe2MetaArchPostprocessFunc,
D2RCNNTracingWrapper,
)
from detectron2.export.caffe2_modeling import META_ARCH_CAFFE2_EXPORT_TYPE_MAP
from mobile_cv.predictor.api import FuncInfo
logger = logging.getLogger(__name__)
def d2_meta_arch_prepare_for_export(self, cfg, inputs, predictor_type):
if "torchscript" in predictor_type and "@tracing" in predictor_type:
return PredictorExportConfig(
model=D2RCNNTracingWrapper(self),
data_generator=D2RCNNTracingWrapper.generator_trace_inputs,
run_func_info=FuncInfo.gen_func_info(
D2RCNNTracingWrapper.RunFunc, params={}
),
)
if cfg.MODEL.META_ARCHITECTURE in META_ARCH_CAFFE2_EXPORT_TYPE_MAP:
C2MetaArch = META_ARCH_CAFFE2_EXPORT_TYPE_MAP[cfg.MODEL.META_ARCHITECTURE]
c2_compatible_model = C2MetaArch(cfg, self)
preprocess_info = FuncInfo.gen_func_info(
D2Caffe2MetaArchPreprocessFunc,
params=D2Caffe2MetaArchPreprocessFunc.get_params(cfg, c2_compatible_model),
)
postprocess_info = FuncInfo.gen_func_info(
D2Caffe2MetaArchPostprocessFunc,
params=D2Caffe2MetaArchPostprocessFunc.get_params(cfg, c2_compatible_model),
)
preprocess_func = preprocess_info.instantiate()
return PredictorExportConfig(
model=c2_compatible_model,
# Caffe2MetaArch takes a single tuple as input (which is the return of
# preprocess_func), data_generator requires all positional args as a tuple.
data_generator=lambda x: (preprocess_func(x),),
preprocess_info=preprocess_info,
postprocess_info=postprocess_info,
)
raise NotImplementedError("Can't determine prepare_for_tracing!")