Commit 6d276498 authored by Yanghan Wang, committed by Facebook GitHub Bot

remove unnecessary eager mode branch from prepare_for_quant_convert/custom_convert_fx

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/326

Since prepare_for_quant_convert/custom_convert_fx is FX-only, the eager-mode code can be removed from it. The logic in `d2go/export/exporter.py` is adjusted accordingly.
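
For context, a minimal sketch of the call path after this change; the wrapper name `convert_for_export` is illustrative and not part of the diff, only `convert_to_quantized_model` and the dispatch it performs come from this commit:

    from d2go.quantization.modeling import convert_to_quantized_model

    def convert_for_export(cfg, fake_quant_model):
        # convert_to_quantized_model now owns the dispatch that used to be
        # duplicated in exporter.py and in custom_convert_fx implementations:
        #   - cfg.QUANTIZATION.EAGER_MODE           -> convert(model, inplace=False)
        #   - FX mode, model defines custom_convert_fx -> model.custom_convert_fx(cfg)
        #   - FX mode, default                      -> convert_fx(model)
        return convert_to_quantized_model(cfg, fake_quant_model)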

Reviewed By: jerryzh168

Differential Revision: D37676977

fbshipit-source-id: ebd05082ee81bc1ac32fcc2a87bc0dfaacedd5bd
parent 97904ba4
@@ -24,24 +24,19 @@ NOTE:
 import json
 import logging
 import os
-from typing import Iterable, Tuple
+from typing import Iterable

-import torch
 import torch.nn as nn
 from d2go.config import CfgNode
 from d2go.export.api import ModelExportMethod, ModelExportMethodRegistry
-from d2go.quantization.modeling import post_training_quantize
+from d2go.quantization.modeling import (
+    convert_to_quantized_model,
+    post_training_quantize,
+)
 from detectron2.utils.file_io import PathManager
 from mobile_cv.arch.utils import fuse_utils
 from mobile_cv.predictor.api import ModelInfo, PredictorInfo

-TORCH_VERSION: Tuple[int, ...] = tuple(int(x) for x in torch.__version__.split(".")[:2])
-if TORCH_VERSION > (1, 10):
-    from torch.ao.quantization import convert
-    from torch.ao.quantization.quantize_fx import convert_fx
-else:
-    from torch.quantization import convert
-    from torch.quantization.quantize_fx import convert_fx
-
 logger = logging.getLogger(__name__)
@@ -63,8 +58,9 @@ def convert_model(
 def convert_quantized_model(
     cfg: CfgNode, pytorch_model: nn.Module, data_loader: Iterable
 ) -> nn.Module:
-    """Converts pytorch model to fake-quantized pytorch model."""
     if not cfg.QUANTIZATION.QAT.ENABLED:
+        # For PTQ, converts pytorch model to fake-quantized pytorch model. For QAT, the
+        # built pytorch model is already fake-quantized.
         logger.info(
             "The model is not quantized during training, running post"
             " training quantization ..."
@@ -76,15 +72,8 @@ def convert_quantized_model(
         logger.warn("Post training quantized model has bn inside fused ops")

     logger.info(f"Converting quantized model {cfg.QUANTIZATION.BACKEND}...")
-    if hasattr(pytorch_model, "custom_convert_fx"):
-        pytorch_model = pytorch_model.custom_convert_fx(cfg)
-    else:
-        # TODO(T93870381): move this to a default function
-        if cfg.QUANTIZATION.EAGER_MODE:
-            pytorch_model = convert(pytorch_model, inplace=False)
-        else:  # FX graph mode quantization
-            pytorch_model = convert_fx(pytorch_model)
+    # convert the fake-quantized model to int8 model
+    pytorch_model = convert_to_quantized_model(cfg, pytorch_model)
     logger.info(f"Quantized Model:\n{pytorch_model}")
     return pytorch_model
...
@@ -29,7 +29,6 @@ from mobile_cv.arch.utils.quantize_utils import (
     wrap_quant_subclass,
 )
 from mobile_cv.predictor.api import FuncInfo
-from torch.ao.quantization import convert
 from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx, prepare_qat_fx

 # from torch.ao.quantization.utils import get_fqn_to_example_inputs
@@ -329,35 +328,32 @@ def default_rcnn_prepare_for_quant(self, cfg, example_input=None):
 @RCNN_CUSTOM_CONVERT_FX_REGISTRY.register()
 def default_rcnn_custom_convert_fx(self, cfg):
-    if cfg.QUANTIZATION.EAGER_MODE:
-        convert(self, inplace=True)
-    else:
-        assert not isinstance(self.backbone, FPN), "FPN is not supported in FX mode"
-        self.backbone = convert_fx(
-            self.backbone,
-            convert_custom_config={
-                "preserved_attributes": ["size_divisibility", "padding_constraints"]
-            },
-        )
-        self.proposal_generator.rpn_head.rpn_feature = convert_fx(
-            self.proposal_generator.rpn_head.rpn_feature
-        )
-        self.proposal_generator.rpn_head.rpn_regressor.cls_logits = convert_fx(
-            self.proposal_generator.rpn_head.rpn_regressor.cls_logits
-        )
-        self.proposal_generator.rpn_head.rpn_regressor.bbox_pred = convert_fx(
-            self.proposal_generator.rpn_head.rpn_regressor.bbox_pred
-        )
-        self.roi_heads.box_head.roi_box_conv = convert_fx(
-            self.roi_heads.box_head.roi_box_conv
-        )
-        self.roi_heads.box_head.avgpool = convert_fx(self.roi_heads.box_head.avgpool)
-        self.roi_heads.box_predictor.cls_score = convert_fx(
-            self.roi_heads.box_predictor.cls_score
-        )
-        self.roi_heads.box_predictor.bbox_pred = convert_fx(
-            self.roi_heads.box_predictor.bbox_pred
-        )
+    assert not isinstance(self.backbone, FPN), "FPN is not supported in FX mode"
+    self.backbone = convert_fx(
+        self.backbone,
+        convert_custom_config={
+            "preserved_attributes": ["size_divisibility", "padding_constraints"]
+        },
+    )
+    self.proposal_generator.rpn_head.rpn_feature = convert_fx(
+        self.proposal_generator.rpn_head.rpn_feature
+    )
+    self.proposal_generator.rpn_head.rpn_regressor.cls_logits = convert_fx(
+        self.proposal_generator.rpn_head.rpn_regressor.cls_logits
+    )
+    self.proposal_generator.rpn_head.rpn_regressor.bbox_pred = convert_fx(
+        self.proposal_generator.rpn_head.rpn_regressor.bbox_pred
+    )
+    self.roi_heads.box_head.roi_box_conv = convert_fx(
+        self.roi_heads.box_head.roi_box_conv
+    )
+    self.roi_heads.box_head.avgpool = convert_fx(self.roi_heads.box_head.avgpool)
+    self.roi_heads.box_predictor.cls_score = convert_fx(
+        self.roi_heads.box_predictor.cls_score
+    )
+    self.roi_heads.box_predictor.bbox_pred = convert_fx(
+        self.roi_heads.box_predictor.bbox_pred
+    )
     return self
...
@@ -20,8 +20,10 @@ from .qconfig import set_backend_and_create_qconfig, smart_decode_backend
 TORCH_VERSION: Tuple[int, ...] = tuple(int(x) for x in torch.__version__.split(".")[:2])
 if TORCH_VERSION > (1, 10):
+    from torch.ao.quantization import convert
     from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx, prepare_qat_fx
 else:
+    from torch.quantization import convert
     from torch.quantization.quantize_fx import convert_fx, prepare_fx, prepare_qat_fx

 logger = logging.getLogger(__name__)
@@ -283,6 +285,21 @@ def apply_prepare_for_quant(cfg, model, example_input=None):
     return model


+def convert_to_quantized_model(cfg, fp32_model):
+    """
+    Convert fake quant model (fp32 operators) to "real" quantized model (int8 operators)
+    """
+    if cfg.QUANTIZATION.EAGER_MODE:
+        int8_model = convert(fp32_model, inplace=False)
+    else:
+        # FX graph mode quantization
+        if hasattr(fp32_model, "custom_convert_fx"):
+            int8_model = fp32_model.custom_convert_fx(cfg)
+        else:
+            int8_model = convert_fx(fp32_model)
+    return int8_model
+
+
 @mock_quantization_type
 def post_training_quantize(cfg, model, data_loader):
     """Calibrate a model, convert it to a quantized pytorch model"""
...