Commit 5654d831 authored by Mircea Cimpoi, committed by Facebook GitHub Bot

Refactor quantization logic into a separate function

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/296

Extract the model quantization logic into a separate function.
Plus minor cleanup: remove unused imports and convert remaining format() calls to f-strings.
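
For illustration, a minimal sketch of the resulting API, assuming the usual d2go setup for cfg, pytorch_model, output_dir, and data_loader (the predictor_type strings are examples, not an exhaustive list):

    # Existing callers are unaffected; the public entry point keeps its name:
    predictor_path = convert_and_export_predictor(
        cfg, pytorch_model, "torchscript_int8", output_dir, data_loader
    )

    # After this diff, the conversion step is also reusable on its own;
    # convert_model dispatches on predictor_type ("int8" -> quantized path):
    converted_model = convert_model(cfg, pytorch_model, "torchscript_int8", data_loader)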

Reviewed By: wat3rBro

Differential Revision: D36999434

fbshipit-source-id: 7aad64921b8cdf8779527c19b077ee788403c6b8
parent 55dc3da1
@@ -24,21 +24,16 @@ NOTE:
 import json
 import logging
 import os
-import sys
-from typing import Callable, Dict, NamedTuple, Optional, Tuple, Union
+from typing import Iterable, Tuple
 
 import torch
+import torch.nn as nn
+from d2go.config import CfgNode
 from d2go.export.api import ModelExportMethod, ModelExportMethodRegistry
 from d2go.quantization.modeling import post_training_quantize
 from detectron2.utils.file_io import PathManager
 from mobile_cv.arch.utils import fuse_utils
-from mobile_cv.predictor.api import FuncInfo, ModelInfo, PredictorInfo
-from mobile_cv.predictor.builtin_functions import (
-    IdentityPostprocess,
-    IdentityPreprocess,
-    NaiveRunFunc,
-)
+from mobile_cv.predictor.api import ModelInfo, PredictorInfo
 
 TORCH_VERSION: Tuple[int, ...] = tuple(int(x) for x in torch.__version__.split(".")[:2])
 if TORCH_VERSION > (1, 10):
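
The truncated `if TORCH_VERSION > (1, 10):` gate above guards where convert_fx is imported from; quantize_fx moved under the torch.ao namespace in PyTorch 1.10. A reconstruction of the gated import, matching the `else:` context of the next hunk (the exact module paths are an assumption, not part of this diff):

    if TORCH_VERSION > (1, 10):
        from torch.ao.quantization.quantize_fx import convert_fx
    else:
        from torch.quantization.quantize_fx import convert_fx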
@@ -51,13 +46,24 @@ else:
 logger = logging.getLogger(__name__)
 
-def convert_predictor(
-    cfg,
-    pytorch_model,
-    predictor_type,
-    data_loader,
+def convert_model(
+    cfg: CfgNode,
+    pytorch_model: nn.Module,
+    predictor_type: str,
+    data_loader: Iterable,
 ):
-    if "int8" in predictor_type:
+    """Converts pytorch model to pytorch model (fuse for fp32, fake quantize for int8)"""
+    return (
+        convert_quantized_model(cfg, pytorch_model, data_loader)
+        if "int8" in predictor_type
+        else _convert_fp_model(cfg, pytorch_model, data_loader)
+    )
+
+
+def convert_quantized_model(
+    cfg: CfgNode, pytorch_model: nn.Module, data_loader: Iterable
+) -> nn.Module:
+    """Converts pytorch model to fake-quantized pytorch model."""
     if not cfg.QUANTIZATION.QAT.ENABLED:
         logger.info(
             "The model is not quantized during training, running post"
@@ -79,10 +85,16 @@ def convert_predictor(
     else:  # FX graph mode quantization
         pytorch_model = convert_fx(pytorch_model)
-        logger.info("Quantized Model:\n{}".format(pytorch_model))
-    else:
+        logger.info(f"Quantized Model:\n{pytorch_model}")
+
+    return pytorch_model
+
+
+def _convert_fp_model(
+    cfg: CfgNode, pytorch_model: nn.Module, data_loader: Iterable
+) -> nn.Module:
+    """Converts floating point predictor"""
     pytorch_model = fuse_utils.fuse_model(pytorch_model)
-    logger.info("Fused Model:\n{}".format(pytorch_model))
+    logger.info(f"Fused Model:\n{pytorch_model}")
     if fuse_utils.count_bn_exist(pytorch_model) > 0:
         logger.warning("BN existed in pytorch model after fusing.")
     return pytorch_model
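
A short usage sketch of the new quantized helper; QUANTIZATION.QAT.ENABLED is the config key the code above actually checks, while the surrounding setup is hypothetical:

    # With QAT disabled, convert_quantized_model falls back to
    # post_training_quantize, calibrating on batches from data_loader.
    cfg.QUANTIZATION.QAT.ENABLED = False
    int8_model = convert_quantized_model(cfg, pytorch_model, data_loader)

    # With QAT enabled, the model is already fake-quantized from training,
    # so only the final convert (eager or FX graph mode) is applied.
    cfg.QUANTIZATION.QAT.ENABLED = True
    int8_model = convert_quantized_model(cfg, pytorch_model, data_loader)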
@@ -102,7 +114,7 @@ def convert_and_export_predictor(
     - export: exporting the converted `pytorch_model` to predictor. This step
       should not alter the behaviour of model.
     """
-    pytorch_model = convert_predictor(cfg, pytorch_model, predictor_type, data_loader)
+    pytorch_model = convert_model(cfg, pytorch_model, predictor_type, data_loader)
     return export_predictor(cfg, pytorch_model, predictor_type, output_dir, data_loader)
@@ -146,7 +158,7 @@ def _export_single_model(
         model_export_method_str = model_export_method
         model_export_method = ModelExportMethodRegistry.get(model_export_method)
     assert issubclass(model_export_method, ModelExportMethod), model_export_method
-    logger.info("Using model export method: {}".format(model_export_method))
+    logger.info(f"Using model export method: {model_export_method}")
 
     load_kwargs = model_export_method.export(
         model=model,
@@ -159,9 +171,7 @@
     model_rel_path = os.path.relpath(save_path, predictor_path)
 
     return ModelInfo(
         path=model_rel_path,
-        export_method="{}.{}".format(
-            model_export_method.__module__, model_export_method.__qualname__
-        ),
+        export_method=f"{model_export_method.__module__}.{model_export_method.__qualname__}",
         load_kwargs=load_kwargs,
     )
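
String predictor types are resolved to classes through ModelExportMethodRegistry, as _export_single_model shows above. A sketch of plugging in a custom export method, assuming the registry follows the decorator pattern used elsewhere in d2go/mobile_cv (the name, class, and exact export signature here are hypothetical):

    from d2go.export.api import ModelExportMethod, ModelExportMethodRegistry

    @ModelExportMethodRegistry.register("my_export_method")
    class MyExportMethod(ModelExportMethod):
        @classmethod
        def export(cls, model, **kwargs):
            # Write the model under the given save path and return the
            # load_kwargs that _export_single_model stores in ModelInfo.
            ...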