Commit b0ef9f39 authored by Alan Lin's avatar Alan Lin Committed by Facebook GitHub Bot
Browse files

Add FCOS quantization support

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/480

As titled, this diff is a follow-up to D38102296, adding FCOS quantization support.

A few items:
1. Add the FCOSInferenceWrapper
2. Add `prepare_for_export` and `prepare_for_quant` in FCOS.

NOTE: To avoid changing how `CycleBatchNormList` is adopted, as mentioned in the previous diff, I had to add a hacky workaround in the `prepare_for_quant` function. Specifically, it flattens a one-element `CycleBatchNormList` into a plain `BatchNorm2d`.

Reviewed By: wat3rBro

Differential Revision: D43522795

fbshipit-source-id: d34eba006af675d0a90111aff0960b40a212c03c
parent 256b3f47
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import logging
import torch.nn as nn
from d2go.config import CfgNode as CN
from d2go.export.api import PredictorExportConfig
from d2go.modeling.meta_arch.rcnn import D2RCNNInferenceWrapper
from d2go.quantization.qconfig import set_backend_and_create_qconfig
from d2go.registry.builtin import META_ARCH_REGISTRY
from detectron2.config import configurable
from detectron2.layers.batch_norm import CycleBatchNormList
from detectron2.modeling.backbone import build_backbone
from detectron2.modeling.backbone.fpn import FPN
from detectron2.modeling.meta_arch.fcos import FCOS as d2_FCOS, FCOSHead
from mobile_cv.arch.utils import fuse_utils
from mobile_cv.arch.utils.quantize_utils import (
wrap_non_quant_group_norm,
wrap_quant_subclass,
)
from mobile_cv.predictor.api import FuncInfo
logger = logging.getLogger(__name__)
class FCOSInferenceWrapper(nn.Module):
    """Thin inference adapter for FCOS.

    Takes a single image tensor, packs it into the list-of-dicts format the
    wrapped detector expects, and unwraps the ``"instances"`` field of the
    first (and only) prediction.
    """

    def __init__(self, model):
        super().__init__()
        # The wrapped FCOS meta-arch (or any model with the same forward contract).
        self.model = model

    def forward(self, image):
        # Detectron2-style models consume a batch as a list of per-image dicts.
        batched_inputs = [{"image": image}]
        predictions = self.model.forward(batched_inputs)
        return predictions[0]["instances"]
def add_fcos_configs(cfg):
......@@ -66,3 +100,85 @@ class FCOS(d2_FCOS):
"test_nms_thresh": cfg.MODEL.FCOS.NMS_THRESH_TEST,
"max_detections_per_image": cfg.TEST.DETECTIONS_PER_IMAGE,
}
# HACK: FCOS can share the same prepare functions w/ RCNN, w/ certain constrains
def prepare_for_export(self, cfg, inputs, predictor_type):
    """Build the PredictorExportConfig used to export this FCOS model.

    Reuses the RCNN pre/post-processing wrappers; the model itself is wrapped
    in FCOSInferenceWrapper so it takes a single image tensor.
    """
    preprocess_info = FuncInfo.gen_func_info(
        D2RCNNInferenceWrapper.Preprocess, params={}
    )
    # Postprocess assumes detector_postprocess already ran inside the model.
    postprocess_info = FuncInfo.gen_func_info(
        D2RCNNInferenceWrapper.Postprocess,
        params={"detector_postprocess_done_in_model": True},
    )
    preprocess_func = preprocess_info.instantiate()
    return PredictorExportConfig(
        model=FCOSInferenceWrapper(self),
        # Export tracing feeds preprocessed inputs as a 1-tuple.
        data_generator=lambda x: (preprocess_func(x),),
        model_export_method=predictor_type,  # check this
        preprocess_info=preprocess_info,
        postprocess_info=postprocess_info,
    )
def prepare_for_quant(self, cfg, *args, **kwargs):
    """Wrap each quantized part of the model to insert Quant and DeQuant in-place.

    Applies an eager-mode quantization setup: assigns qconfigs to the backbone
    and head, wraps the quantizable submodules, flattens one-element
    CycleBatchNormList instances in the head (a temporary hack), and fuses the
    model. Returns the prepared model (modified in place).
    """
    model = self
    # QAT vs. post-training qconfig is chosen based on the config flag.
    qconfig = set_backend_and_create_qconfig(
        cfg, is_train=cfg.QUANTIZATION.QAT.ENABLED
    )
    logger.info("Setup the model with qconfig:\n{}".format(qconfig))
    model.backbone.qconfig = qconfig
    model.head.qconfig = qconfig
    # Wrap the backbone based on the architecture type
    if isinstance(model.backbone, FPN):
        # Same trick in RCNN's _apply_eager_mode_quant
        # For FPN, only the bottom-up part is wrapped; the FPN lateral/output
        # convs are left outside the quant boundary here.
        model.backbone.bottom_up = wrap_quant_subclass(
            model.backbone.bottom_up,
            n_inputs=1,
            n_outputs=len(model.backbone.bottom_up._out_features),
        )
    else:
        model.backbone = wrap_quant_subclass(
            model.backbone, n_inputs=1, n_outputs=len(model.backbone._out_features)
        )

    def unpack_cyclebatchnormlist(module):
        # HACK: This function flattens CycleBatchNormList for quantization purpose
        # Recursively walks the module tree; a one-element CycleBatchNormList is
        # replaced by an equivalent BatchNorm2d (which eager quantization knows
        # how to handle), reusing the original parameters and running stats.
        if isinstance(module, CycleBatchNormList):
            if len(module) > 1:
                # TODO: add quantization support of CycleBatchNormList
                raise NotImplementedError(
                    "CycleBatchNormList w/ more than one element cannot be quantized"
                )
            else:
                # NOTE(review): weight/bias live on the list itself while
                # running stats live on the single child BN — presumably how
                # CycleBatchNormList stores them; verify against detectron2.
                num_channel = module.weight.size(0)
                new_module = nn.BatchNorm2d(num_channel, affine=True)
                new_module.weight = module.weight
                new_module.bias = module.bias
                new_module.running_mean = module[0].running_mean
                new_module.running_var = module[0].running_var
                module = new_module
        else:
            for name, child in module.named_children():
                new_child = unpack_cyclebatchnormlist(child)
                if new_child is not child:
                    # add_module replaces the existing child of the same name.
                    module.add_module(name, new_child)
        return module

    model.head = unpack_cyclebatchnormlist(model.head)
    # Wrap the FCOS head
    # The head takes one feature map per IN_FEATURES entry and emits three
    # outputs (logits, regression, ctrness) per level.
    model.head = wrap_quant_subclass(
        model.head,
        n_inputs=len(cfg.MODEL.FCOS.IN_FEATURES),
        n_outputs=len(cfg.MODEL.FCOS.IN_FEATURES) * 3,
    )
    # Fuse conv/bn/relu patterns in place, then exclude GroupNorm from
    # quantization (it is not supported by the quantized backends).
    model = fuse_utils.fuse_model(
        model,
        is_qat=cfg.QUANTIZATION.QAT.ENABLED,
        inplace=True,
    )
    model = wrap_non_quant_group_norm(model)
    return model
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment