Commit 8cb50233 authored by Alan Lin's avatar Alan Lin Committed by Facebook GitHub Bot
Browse files

Add export registry for FCOS

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/491

As titled, although FCOS usually requires no customized export methods. We found that our internal MUI platform asks the exported model to follow certain protocols. To avoid mixing-up with external code, adding a export func registry to bypass it.

Reviewed By: wat3rBro

Differential Revision: D43800839

fbshipit-source-id: 41c8ebb10610ec92d17461211315c15908277b28
parent a7dc757c
......@@ -14,6 +14,7 @@ from detectron2.layers.batch_norm import CycleBatchNormList
from detectron2.modeling.backbone import build_backbone
from detectron2.modeling.backbone.fpn import FPN
from detectron2.modeling.meta_arch.fcos import FCOS as d2_FCOS, FCOSHead
from detectron2.utils.registry import Registry
from mobile_cv.arch.utils import fuse_utils
from mobile_cv.arch.utils.quantize_utils import (
wrap_non_quant_group_norm,
......@@ -25,6 +26,9 @@ from mobile_cv.predictor.api import FuncInfo
logger = logging.getLogger(__name__)
# Registry to store custom export logic
FCOS_PREPARE_FOR_EXPORT_REGISTRY = Registry("FCOS_PREPARE_FOR_EXPORT")
class FCOSInferenceWrapper(nn.Module):
def __init__(
......@@ -56,6 +60,9 @@ def add_fcos_configs(cfg):
cfg.MODEL.FCOS.FOCAL_LOSS_ALPHA = 0.25
cfg.MODEL.FCOS.FOCAL_LOSS_GAMMA = 2.0
# Export method
cfg.FCOS_PREPARE_FOR_EXPORT = "default_fcos_prepare_for_export"
# Re-register D2's meta-arch in D2Go with updated APIs
@META_ARCH_REGISTRY.register()
......@@ -101,23 +108,10 @@ class FCOS(d2_FCOS):
"max_detections_per_image": cfg.TEST.DETECTIONS_PER_IMAGE,
}
# HACK: FCOS can share the same prepare functions w/ RCNN, w/ certain constrains
def prepare_for_export(self, cfg, inputs, predictor_type):
preprocess_info = FuncInfo.gen_func_info(
D2RCNNInferenceWrapper.Preprocess, params={}
)
preprocess_func = preprocess_info.instantiate()
return PredictorExportConfig(
model=FCOSInferenceWrapper(self),
data_generator=lambda x: (preprocess_func(x),),
model_export_method=predictor_type, # check this
preprocess_info=preprocess_info,
postprocess_info=FuncInfo.gen_func_info(
D2RCNNInferenceWrapper.Postprocess,
params={"detector_postprocess_done_in_model": True},
),
)
# HACK: default FCOS export shares the same prepare functions w/ RCNN under certain constrains
def prepare_for_export(self, cfg, *args, **kwargs):
func = FCOS_PREPARE_FOR_EXPORT_REGISTRY.get(cfg.FCOS_PREPARE_FOR_EXPORT)
return func(self, cfg, *args, **kwargs)
def prepare_for_quant(self, cfg, *args, **kwargs):
"""Wrap each quantized part of the model to insert Quant and DeQuant in-place"""
......@@ -182,3 +176,24 @@ class FCOS(d2_FCOS):
)
model = wrap_non_quant_group_norm(model)
return model
@FCOS_PREPARE_FOR_EXPORT_REGISTRY.register()
def default_fcos_prepare_for_export(self, cfg, inputs, predictor_type):
pytorch_model = self
preprocess_info = FuncInfo.gen_func_info(
D2RCNNInferenceWrapper.Preprocess, params={}
)
preprocess_func = preprocess_info.instantiate()
return PredictorExportConfig(
model=FCOSInferenceWrapper(pytorch_model),
data_generator=lambda x: (preprocess_func(x),),
model_export_method=predictor_type,
preprocess_info=preprocess_info,
postprocess_info=FuncInfo.gen_func_info(
D2RCNNInferenceWrapper.Postprocess,
params={"detector_postprocess_done_in_model": True},
),
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment