Commit 6d152388 authored by Yanghan Wang, committed by Facebook GitHub Bot

implement the example input for rcnn

Summary: Pull Request resolved: https://github.com/facebookresearch/d2go/pull/328

Reviewed By: jerryzh168

Differential Revision: D37695452

fbshipit-source-id: 744b1085365d1e155ea9e9fe51a6237994d90fa7
parent 6d276498
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import functools
import inspect
import json
import logging
import math
from typing import Any, Optional, Tuple
import numpy as np
import torch
import torch.nn as nn
from d2go.config import CfgNode
@@ -21,6 +23,8 @@ from detectron2.modeling import (
from detectron2.modeling.backbone.fpn import FPN
from detectron2.modeling.postprocessing import detector_postprocess
from detectron2.projects.point_rend import PointRendMaskHead
from detectron2.structures import Boxes, Instances, Keypoints, PolygonMasks
from detectron2.utils.events import EventStorage
from detectron2.utils.registry import Registry
from mobile_cv.arch.utils import fuse_utils
from mobile_cv.arch.utils.quantize_utils import (
@@ -30,8 +34,7 @@ from mobile_cv.arch.utils.quantize_utils import (
)
from mobile_cv.predictor.api import FuncInfo
from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx, prepare_qat_fx
# from torch.ao.quantization.utils import get_fqn_to_example_inputs
from torch.ao.quantization.utils import get_fqn_to_example_inputs
logger = logging.getLogger(__name__)
@@ -42,7 +45,6 @@ logger = logging.getLogger(__name__)
# of registries might be overkill.
RCNN_PREPARE_FOR_EXPORT_REGISTRY = Registry("RCNN_PREPARE_FOR_EXPORT")
RCNN_PREPARE_FOR_QUANT_REGISTRY = Registry("RCNN_PREPARE_FOR_QUANT")
RCNN_CUSTOM_CONVERT_FX_REGISTRY = Registry("RCNN_CUSTOM_CONVERT_FX")
# Re-register D2's meta-arch in D2Go with updated APIs
@@ -56,18 +58,15 @@ class GeneralizedRCNN(_GeneralizedRCNN):
func = RCNN_PREPARE_FOR_QUANT_REGISTRY.get(cfg.RCNN_PREPARE_FOR_QUANT)
return func(self, cfg, *args, **kwargs)
def custom_convert_fx(self, cfg, *args, **kwargs):
func = RCNN_CUSTOM_CONVERT_FX_REGISTRY.get(cfg.RCNN_CUSTOM_CONVERT_FX)
return func(self, cfg, *args, **kwargs)
def custom_prepare_fx(self, cfg, example_input=None):
return default_rcnn_custom_prepare_fx(self, cfg, example_input)
def custom_convert_fx(self, cfg):
return default_rcnn_custom_convert_fx(self, cfg)
def _cast_model_to_device(self, device):
return _cast_detection_model(self, device)
@property
def example_input(self):
# TODO[quant-example-inputs]: provide correct example_input for GeneralizedRCNN
return torch.randn(1, 3, 224, 224)
# Re-register D2's meta-arch in D2Go with updated APIs
@META_ARCH_REGISTRY.register()
@@ -213,26 +212,113 @@ def _apply_eager_mode_quant(cfg, model):
return model
def _lcm(x: Optional[int], y: Optional[int]) -> int:
if x is None or x == 0:
return y
if y is None or y == 0:
return x
return x * y // math.gcd(x, y)
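# Hedged sketch (illustration only, not part of this commit): `_lcm` treats
# None/0 as "no constraint", so it folds cleanly over a mixed list of stride
# requirements with functools.reduce. The stride values below are hypothetical.
def _demo_lcm_of_strides() -> int:
    strides = [None, 4, 8, 16, 32]  # e.g. size_divisibility + output strides
    size = functools.reduce(_lcm, strides)
    assert size == 32  # every stride divides the result evenly
    return size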
def _get_example_rcnn_input(image_tensor_size: int):
def _get_batch():
# example input image
# TODO: do not hard-code channel size 3
image = torch.randn(3, image_tensor_size, image_tensor_size)
# example GT instances
num_instances = 2
gt_boxes = torch.tensor([[0.0, 0.0, 10.0, 10.0]] * num_instances)
gt_boxes = Boxes(gt_boxes)
gt_classes = torch.tensor([0] * num_instances)
polygon = np.array([0.0, 0.0, 10.0, 0.0, 10.0, 10.0]) # x1,y1,x2,y2,x3,y3
gt_masks = PolygonMasks([[polygon]] * num_instances)
# TODO: make keypoints inside box and set visibility
# TODO: do not hard-code num_keypoints 17
keypoints = torch.randn(num_instances, 17, 3)
gt_keypoints = Keypoints(keypoints)
# NOTE: currently supports faster/mask/keypoint RCNN
instances = Instances(
image_size=(10, 10),
gt_boxes=gt_boxes,
gt_classes=gt_classes,
gt_masks=gt_masks,
gt_keypoints=gt_keypoints,
)
return {
# `file_name` and `image_id` are not used; they can be any value.
"file_name": "fake_example_image.jpg",
"image_id": 42,
# `height` and `width` are used in post-processing to scale predictions back
# to the original size; they are not used during training.
"height": 10,
"width": 10,
"image": image,
"instances": instances,
# NOTE: proposals are not supported
}
return [_get_batch(), _get_batch()]
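# Hedged usage sketch (illustration only, not part of this commit): the helper
# above returns a batch of two identical per-image dicts in D2's standard
# input format; the size 32 below is hypothetical.
def _demo_example_rcnn_input() -> None:
    batch = _get_example_rcnn_input(32)
    assert len(batch) == 2  # two images per example batch
    assert batch[0]["image"].shape == (3, 32, 32)
    assert len(batch[0]["instances"]) == 2  # two GT instances per image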
def _set_qconfig(model, cfg):
model.qconfig = set_backend_and_create_qconfig(cfg, is_train=model.training)
# skip quantization for point rend head
if (
hasattr(model, "roi_heads")
and hasattr(model.roi_heads, "mask_head")
and isinstance(model.roi_heads.mask_head, PointRendMaskHead)
):
model.roi_heads.mask_head.qconfig = None
logger.info("Setup the model with qconfig:\n{}".format(model.qconfig))
@RCNN_PREPARE_FOR_QUANT_REGISTRY.register()
def default_rcnn_prepare_for_quant(self, cfg):
model = self
_set_qconfig(model, cfg)
# Modify the model for eager mode
model = _apply_eager_mode_quant(cfg, model)
model = fuse_utils.fuse_model(
model,
is_qat=cfg.QUANTIZATION.QAT.ENABLED,
inplace=True,
)
return model
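# Hedged note (illustration only, not part of this commit): in eager mode the
# hook above only rewrites and fuses the model; observer insertion happens
# afterwards in `apply_prepare_for_quant` via torch.ao.quantization.prepare /
# prepare_qat (see the quantization helper further down in this diff).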
def default_rcnn_custom_prepare_fx(self, cfg, example_input=None):
model = self
_set_qconfig(model, cfg)
# construct example input for FX when not provided
if example_input is None:
assert (
    model.training
), "`example_input` must be provided for (FX mode) PTQ; it is auto-generated only for QAT (training mode)"
# make sure the image size can be divided by all strides and size_divisibility
required_strides = [model.backbone.size_divisibility] + [
shape_spec.stride for shape_spec in model.backbone.output_shape().values()
]
image_tensor_size = functools.reduce(_lcm, required_strides)
example_input = _get_example_rcnn_input(image_tensor_size)
_fx_quant_prepare(model, cfg, example_input)
return model
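# Hedged usage sketch (illustration only, not part of this commit): in
# training mode (QAT) the example input is auto-generated above, while in eval
# mode (PTQ) a calibration batch must be supplied by the caller.
def _demo_custom_prepare_fx_calls(qat_model, ptq_model, cfg, calibration_batch):
    qat_model.train()
    qat_model = qat_model.custom_prepare_fx(cfg)
    ptq_model.eval()
    ptq_model = ptq_model.custom_prepare_fx(cfg, example_input=calibration_batch)
    return qat_model, ptq_model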
def _fx_quant_prepare(self, cfg, example_input):
prep_fn = prepare_qat_fx if self.training else prepare_fx
qconfig = {"": self.qconfig}
assert not isinstance(self.backbone, FPN), "FPN is not supported in FX mode"
# TODO[quant-example-inputs]: set a correct example_input and uncomment the next line
# fqn_to_example_inputs = get_fqn_to_example_inputs(self, (example_input,))
fqn_to_example_inputs = {
"backbone": (torch.randn(1, 3, 224, 224),),
"proposal_generator.rpn_head.rpn_feature": (torch.randn(1, 3, 224, 224),),
"proposal_generator.rpn_head.rpn_regressor.cls_logits": (
torch.randn(1, 3, 224, 224),
),
"proposal_generator.rpn_head.rpn_regressor.bbox_pred": (
torch.randn(1, 3, 224, 224),
),
"roi_heads.box_head.roi_box_conv": (torch.randn(1, 3, 224, 224),),
"roi_heads.box_head.avgpool": (torch.randn(1, 3, 224, 224),),
"roi_heads.box_predictor.cls_score": (torch.randn(1, 3, 224, 224),),
"roi_heads.box_predictor.bbox_pred": (torch.randn(1, 3, 224, 224),),
}
with EventStorage() as _:  # D2's rcnn requires EventStorage when computing loss
with torch.no_grad():
fqn_to_example_inputs = get_fqn_to_example_inputs(self, (example_input,))
self.backbone = prep_fn(
self.backbone,
qconfig,
@@ -278,7 +364,7 @@ def _fx_quant_prepare(self, cfg, example_input):
self.roi_heads.box_head.avgpool = prep_fn(
self.roi_heads.box_head.avgpool,
qconfig,
fqn_to_example_inputs["roi_heads.box_head.avgpool"],
(torch.randn(1, 3, 224, 224),),
prepare_custom_config={
"input_quantized_idxs": [0],
"output_quantized_idxs": [0],
@@ -298,35 +384,6 @@ def _fx_quant_prepare(self, cfg, example_input):
)
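# Hedged sketch (illustration only, not part of this commit):
# get_fqn_to_example_inputs runs the model once and records the inputs each
# submodule received, keyed by fully qualified name, which is what lets the
# hard-coded torch.randn placeholders above be deleted.
def _demo_fqn_to_example_inputs(model, example_input):
    with EventStorage():  # loss computation in training mode needs a storage
        with torch.no_grad():
            fqn_to_inputs = get_fqn_to_example_inputs(model, (example_input,))
    # each entry can be fed to prepare_fx / prepare_qat_fx for that submodule
    return fqn_to_inputs["backbone"]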
@RCNN_PREPARE_FOR_QUANT_REGISTRY.register()
def default_rcnn_prepare_for_quant(self, cfg, example_input=None):
model = self
model.qconfig = set_backend_and_create_qconfig(cfg, is_train=model.training)
if (
hasattr(model, "roi_heads")
and hasattr(model.roi_heads, "mask_head")
and isinstance(model.roi_heads.mask_head, PointRendMaskHead)
):
model.roi_heads.mask_head.qconfig = None
logger.info("Setup the model with qconfig:\n{}".format(model.qconfig))
# Modify the model for eager mode
if cfg.QUANTIZATION.EAGER_MODE:
model = _apply_eager_mode_quant(cfg, model)
model = fuse_utils.fuse_model(
model,
is_qat=cfg.QUANTIZATION.QAT.ENABLED,
inplace=True,
)
else:
if example_input is None:
example_input = model.example_input
_fx_quant_prepare(model, cfg, example_input)
return model
@RCNN_CUSTOM_CONVERT_FX_REGISTRY.register()
def default_rcnn_custom_convert_fx(self, cfg):
assert not isinstance(self.backbone, FPN), "FPN is not supported in FX mode"
self.backbone = convert_fx(
......
@@ -276,7 +276,23 @@ def apply_prepare_for_quant(cfg, model, example_input=None):
# or `torch.ao.quantization.get_default_qat_qconfig` without calling D2Go's high-level
# `set_backend_and_create_qconfig` API.
if cfg.QUANTIZATION.EAGER_MODE:
if hasattr(model, "prepare_for_quant"):
model = model.prepare_for_quant(cfg)
else:
logger.info("Using default implementation for prepare_for_quant")
model = default_prepare_for_quant(cfg, model, example_input)
# NOTE: eager mode models need to call prepare/prepare_qat after `prepare_for_quant`
if model.training:
torch.ao.quantization.prepare_qat(model, inplace=True)
else:
torch.ao.quantization.prepare(model, inplace=True)
else:
if hasattr(model, "custom_prepare_fx"):
model = model.custom_prepare_fx(cfg, example_input)
# TODO: remove this branch after completely separating the eager and FX APIs
elif hasattr(model, "prepare_for_quant"):
model = model.prepare_for_quant(cfg, example_input)
else:
logger.info("Using default implementation for prepare_for_quant")
@@ -311,8 +327,6 @@ def post_training_quantize(cfg, model, data_loader):
example_input = next(iter(data_loader))
model = apply_prepare_for_quant(cfg, model, example_input)
if cfg.QUANTIZATION.EAGER_MODE:
torch.ao.quantization.prepare(model, inplace=True)
logger.info("Prepared the PTQ model for calibration:\n{}".format(model))
# Option for forcing running calibration on GPU, works only when the model supports
@@ -374,8 +388,6 @@ def setup_qat_model(
# prepare model for qat
model = apply_prepare_for_quant(cfg, model_fp32)
if cfg.QUANTIZATION.EAGER_MODE:
torch.ao.quantization.prepare_qat(model, inplace=True)
# make sure the proper qconfig is used in the model
learnable_qat.check_for_learnable_fake_quant_ops(qat_method, model)
......
@@ -108,7 +108,-6 @@ def _add_rcnn_default_config(_C: CN) -> None:
_C.RCNN_PREPARE_FOR_EXPORT = "default_rcnn_prepare_for_export"
_C.RCNN_PREPARE_FOR_QUANT = "default_rcnn_prepare_for_quant"
_C.RCNN_CUSTOM_CONVERT_FX = "default_rcnn_custom_convert_fx"
_C.register_deprecated_key("RCNN_PREPARE_FOR_QUANT_CONVERT")
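# Hedged sketch (illustration only, not part of this commit): projects can
# swap in their own hook by registering a function and pointing the config key
# at its name, e.g.:
#
#     @RCNN_PREPARE_FOR_QUANT_REGISTRY.register()
#     def my_rcnn_prepare_for_quant(self, cfg):  # hypothetical
#         ...
#
#     cfg.RCNN_PREPARE_FOR_QUANT = "my_rcnn_prepare_for_quant"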
......