Commit b0ef9f39 authored by Alan Lin's avatar Alan Lin Committed by Facebook GitHub Bot
Browse files

Add FCOS quantization support

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/480

As titled, this diff is a follow up of D38102296, adding FCOS quantization support.

A few items:
1. Add the FCOSInferenceWrapper
2. Add `prepare_for_export` and `prepare_for_quant` in FCOS.

NOTE: To avoid changing the adoption of `CycleBatchNormList` as mentioned in the previous diff, I have to add a hacky solution in the `prepare_for_quant` function. Specifically, flatten the one-element CycleBatchNormList to a BatchNorm2d.

Reviewed By: wat3rBro

Differential Revision: D43522795

fbshipit-source-id: d34eba006af675d0a90111aff0960b40a212c03c
parent 256b3f47
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import logging
import torch.nn as nn
from d2go.config import CfgNode as CN from d2go.config import CfgNode as CN
from d2go.export.api import PredictorExportConfig
from d2go.modeling.meta_arch.rcnn import D2RCNNInferenceWrapper
from d2go.quantization.qconfig import set_backend_and_create_qconfig
from d2go.registry.builtin import META_ARCH_REGISTRY from d2go.registry.builtin import META_ARCH_REGISTRY
from detectron2.config import configurable from detectron2.config import configurable
from detectron2.layers.batch_norm import CycleBatchNormList
from detectron2.modeling.backbone import build_backbone from detectron2.modeling.backbone import build_backbone
from detectron2.modeling.backbone.fpn import FPN
from detectron2.modeling.meta_arch.fcos import FCOS as d2_FCOS, FCOSHead from detectron2.modeling.meta_arch.fcos import FCOS as d2_FCOS, FCOSHead
from mobile_cv.arch.utils import fuse_utils
from mobile_cv.arch.utils.quantize_utils import (
wrap_non_quant_group_norm,
wrap_quant_subclass,
)
from mobile_cv.predictor.api import FuncInfo
logger = logging.getLogger(__name__)
class FCOSInferenceWrapper(nn.Module):
def __init__(
self,
model,
):
super().__init__()
self.model = model
def forward(self, image):
inputs = [{"image": image}]
return self.model.forward(inputs)[0]["instances"]
def add_fcos_configs(cfg): def add_fcos_configs(cfg):
...@@ -66,3 +100,85 @@ class FCOS(d2_FCOS): ...@@ -66,3 +100,85 @@ class FCOS(d2_FCOS):
"test_nms_thresh": cfg.MODEL.FCOS.NMS_THRESH_TEST, "test_nms_thresh": cfg.MODEL.FCOS.NMS_THRESH_TEST,
"max_detections_per_image": cfg.TEST.DETECTIONS_PER_IMAGE, "max_detections_per_image": cfg.TEST.DETECTIONS_PER_IMAGE,
} }
# HACK: FCOS can share the same prepare functions w/ RCNN, w/ certain constrains
def prepare_for_export(self, cfg, inputs, predictor_type):
preprocess_info = FuncInfo.gen_func_info(
D2RCNNInferenceWrapper.Preprocess, params={}
)
preprocess_func = preprocess_info.instantiate()
return PredictorExportConfig(
model=FCOSInferenceWrapper(self),
data_generator=lambda x: (preprocess_func(x),),
model_export_method=predictor_type, # check this
preprocess_info=preprocess_info,
postprocess_info=FuncInfo.gen_func_info(
D2RCNNInferenceWrapper.Postprocess,
params={"detector_postprocess_done_in_model": True},
),
)
def prepare_for_quant(self, cfg, *args, **kwargs):
"""Wrap each quantized part of the model to insert Quant and DeQuant in-place"""
model = self
qconfig = set_backend_and_create_qconfig(
cfg, is_train=cfg.QUANTIZATION.QAT.ENABLED
)
logger.info("Setup the model with qconfig:\n{}".format(qconfig))
model.backbone.qconfig = qconfig
model.head.qconfig = qconfig
# Wrap the backbone based on the architecture type
if isinstance(model.backbone, FPN):
# Same trick in RCNN's _apply_eager_mode_quant
model.backbone.bottom_up = wrap_quant_subclass(
model.backbone.bottom_up,
n_inputs=1,
n_outputs=len(model.backbone.bottom_up._out_features),
)
else:
model.backbone = wrap_quant_subclass(
model.backbone, n_inputs=1, n_outputs=len(model.backbone._out_features)
)
def unpack_cyclebatchnormlist(module):
# HACK: This function flattens CycleBatchNormList for quantization purpose
if isinstance(module, CycleBatchNormList):
if len(module) > 1:
# TODO: add quantization support of CycleBatchNormList
raise NotImplementedError(
"CycleBatchNormList w/ more than one element cannot be quantized"
)
else:
num_channel = module.weight.size(0)
new_module = nn.BatchNorm2d(num_channel, affine=True)
new_module.weight = module.weight
new_module.bias = module.bias
new_module.running_mean = module[0].running_mean
new_module.running_var = module[0].running_var
module = new_module
else:
for name, child in module.named_children():
new_child = unpack_cyclebatchnormlist(child)
if new_child is not child:
module.add_module(name, new_child)
return module
model.head = unpack_cyclebatchnormlist(model.head)
# Wrap the FCOS head
model.head = wrap_quant_subclass(
model.head,
n_inputs=len(cfg.MODEL.FCOS.IN_FEATURES),
n_outputs=len(cfg.MODEL.FCOS.IN_FEATURES) * 3,
)
model = fuse_utils.fuse_model(
model,
is_qat=cfg.QUANTIZATION.QAT.ENABLED,
inplace=True,
)
model = wrap_non_quant_group_norm(model)
return model
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment