Commit 8cb50233 authored by Alan Lin's avatar Alan Lin Committed by Facebook GitHub Bot
Browse files

Add export registry for FCOS

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/491

As titled, although FCOS usually requires no customized export methods. We found that our internal MUI platform asks the exported model to follow certain protocols. To avoid mixing-up with external code, adding a export func registry to bypass it.

Reviewed By: wat3rBro

Differential Revision: D43800839

fbshipit-source-id: 41c8ebb10610ec92d17461211315c15908277b28
parent a7dc757c
...@@ -14,6 +14,7 @@ from detectron2.layers.batch_norm import CycleBatchNormList ...@@ -14,6 +14,7 @@ from detectron2.layers.batch_norm import CycleBatchNormList
from detectron2.modeling.backbone import build_backbone from detectron2.modeling.backbone import build_backbone
from detectron2.modeling.backbone.fpn import FPN from detectron2.modeling.backbone.fpn import FPN
from detectron2.modeling.meta_arch.fcos import FCOS as d2_FCOS, FCOSHead from detectron2.modeling.meta_arch.fcos import FCOS as d2_FCOS, FCOSHead
from detectron2.utils.registry import Registry
from mobile_cv.arch.utils import fuse_utils from mobile_cv.arch.utils import fuse_utils
from mobile_cv.arch.utils.quantize_utils import ( from mobile_cv.arch.utils.quantize_utils import (
wrap_non_quant_group_norm, wrap_non_quant_group_norm,
...@@ -25,6 +26,9 @@ from mobile_cv.predictor.api import FuncInfo ...@@ -25,6 +26,9 @@ from mobile_cv.predictor.api import FuncInfo
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Registry to store custom export logic
FCOS_PREPARE_FOR_EXPORT_REGISTRY = Registry("FCOS_PREPARE_FOR_EXPORT")
class FCOSInferenceWrapper(nn.Module): class FCOSInferenceWrapper(nn.Module):
def __init__( def __init__(
...@@ -56,6 +60,9 @@ def add_fcos_configs(cfg): ...@@ -56,6 +60,9 @@ def add_fcos_configs(cfg):
cfg.MODEL.FCOS.FOCAL_LOSS_ALPHA = 0.25 cfg.MODEL.FCOS.FOCAL_LOSS_ALPHA = 0.25
cfg.MODEL.FCOS.FOCAL_LOSS_GAMMA = 2.0 cfg.MODEL.FCOS.FOCAL_LOSS_GAMMA = 2.0
# Export method
cfg.FCOS_PREPARE_FOR_EXPORT = "default_fcos_prepare_for_export"
# Re-register D2's meta-arch in D2Go with updated APIs # Re-register D2's meta-arch in D2Go with updated APIs
@META_ARCH_REGISTRY.register() @META_ARCH_REGISTRY.register()
...@@ -101,23 +108,10 @@ class FCOS(d2_FCOS): ...@@ -101,23 +108,10 @@ class FCOS(d2_FCOS):
"max_detections_per_image": cfg.TEST.DETECTIONS_PER_IMAGE, "max_detections_per_image": cfg.TEST.DETECTIONS_PER_IMAGE,
} }
# HACK: FCOS can share the same prepare functions w/ RCNN, w/ certain constrains # HACK: default FCOS export shares the same prepare functions w/ RCNN under certain constrains
def prepare_for_export(self, cfg, inputs, predictor_type): def prepare_for_export(self, cfg, *args, **kwargs):
preprocess_info = FuncInfo.gen_func_info( func = FCOS_PREPARE_FOR_EXPORT_REGISTRY.get(cfg.FCOS_PREPARE_FOR_EXPORT)
D2RCNNInferenceWrapper.Preprocess, params={} return func(self, cfg, *args, **kwargs)
)
preprocess_func = preprocess_info.instantiate()
return PredictorExportConfig(
model=FCOSInferenceWrapper(self),
data_generator=lambda x: (preprocess_func(x),),
model_export_method=predictor_type, # check this
preprocess_info=preprocess_info,
postprocess_info=FuncInfo.gen_func_info(
D2RCNNInferenceWrapper.Postprocess,
params={"detector_postprocess_done_in_model": True},
),
)
def prepare_for_quant(self, cfg, *args, **kwargs): def prepare_for_quant(self, cfg, *args, **kwargs):
"""Wrap each quantized part of the model to insert Quant and DeQuant in-place""" """Wrap each quantized part of the model to insert Quant and DeQuant in-place"""
...@@ -182,3 +176,24 @@ class FCOS(d2_FCOS): ...@@ -182,3 +176,24 @@ class FCOS(d2_FCOS):
) )
model = wrap_non_quant_group_norm(model) model = wrap_non_quant_group_norm(model)
return model return model
@FCOS_PREPARE_FOR_EXPORT_REGISTRY.register()
def default_fcos_prepare_for_export(self, cfg, inputs, predictor_type):
pytorch_model = self
preprocess_info = FuncInfo.gen_func_info(
D2RCNNInferenceWrapper.Preprocess, params={}
)
preprocess_func = preprocess_info.instantiate()
return PredictorExportConfig(
model=FCOSInferenceWrapper(pytorch_model),
data_generator=lambda x: (preprocess_func(x),),
model_export_method=predictor_type,
preprocess_info=preprocess_info,
postprocess_info=FuncInfo.gen_func_info(
D2RCNNInferenceWrapper.Postprocess,
params={"detector_postprocess_done_in_model": True},
),
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment