"tests/git@developer.sourcefind.cn:OpenDAS/fairscale.git" did not exist on "35fdf5371e94ed4252a61cd885137e1a1d38eda1"
Commit 5654d831 authored by Mircea Cimpoi, committed by Facebook GitHub Bot

Refactor quantization logic into a separate function

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/296

Extract the model quantization logic into a separate function.
Plus minor cleanup: remove unused imports, use f-strings.

Reviewed By: wat3rBro

Differential Revision: D36999434

fbshipit-source-id: 7aad64921b8cdf8779527c19b077ee788403c6b8
parent 55dc3da1
@@ -24,21 +24,16 @@ NOTE:
 import json
 import logging
 import os
-import sys
-from typing import Callable, Dict, NamedTuple, Optional, Tuple, Union
+from typing import Iterable, Tuple
 
 import torch
 import torch.nn as nn
+from d2go.config import CfgNode
 from d2go.export.api import ModelExportMethod, ModelExportMethodRegistry
 from d2go.quantization.modeling import post_training_quantize
 from detectron2.utils.file_io import PathManager
 from mobile_cv.arch.utils import fuse_utils
-from mobile_cv.predictor.api import FuncInfo, ModelInfo, PredictorInfo
-from mobile_cv.predictor.builtin_functions import (
-    IdentityPostprocess,
-    IdentityPreprocess,
-    NaiveRunFunc,
-)
+from mobile_cv.predictor.api import ModelInfo, PredictorInfo
 
 TORCH_VERSION: Tuple[int, ...] = tuple(int(x) for x in torch.__version__.split(".")[:2])
 if TORCH_VERSION > (1, 10):
@@ -51,40 +46,57 @@ else:
 logger = logging.getLogger(__name__)
 
 
-def convert_predictor(
-    cfg,
-    pytorch_model,
-    predictor_type,
-    data_loader,
+def convert_model(
+    cfg: CfgNode,
+    pytorch_model: nn.Module,
+    predictor_type: str,
+    data_loader: Iterable,
 ):
-    if "int8" in predictor_type:
-        if not cfg.QUANTIZATION.QAT.ENABLED:
-            logger.info(
-                "The model is not quantized during training, running post"
-                " training quantization ..."
-            )
-            pytorch_model = post_training_quantize(cfg, pytorch_model, data_loader)
-            # only check bn exists in ptq as qat still has bn inside fused ops
-            if fuse_utils.check_bn_exist(pytorch_model):
-                logger.warn("Post training quantized model has bn inside fused ops")
-        logger.info(f"Converting quantized model {cfg.QUANTIZATION.BACKEND}...")
-
-        if hasattr(pytorch_model, "prepare_for_quant_convert"):
-            pytorch_model = pytorch_model.prepare_for_quant_convert(cfg)
-        else:
-            # TODO(T93870381): move this to a default function
-            if cfg.QUANTIZATION.EAGER_MODE:
-                pytorch_model = convert(pytorch_model, inplace=False)
-            else:  # FX graph mode quantization
-                pytorch_model = convert_fx(pytorch_model)
-
-        logger.info("Quantized Model:\n{}".format(pytorch_model))
-    else:
-        pytorch_model = fuse_utils.fuse_model(pytorch_model)
-        logger.info("Fused Model:\n{}".format(pytorch_model))
-        if fuse_utils.count_bn_exist(pytorch_model) > 0:
-            logger.warning("BN existed in pytorch model after fusing.")
+    """Converts pytorch model to pytorch model (fuse for fp32, fake quantize for int8)"""
+    return (
+        convert_quantized_model(cfg, pytorch_model, data_loader)
+        if "int8" in predictor_type
+        else _convert_fp_model(cfg, pytorch_model, data_loader)
+    )
+
+
+def convert_quantized_model(
+    cfg: CfgNode, pytorch_model: nn.Module, data_loader: Iterable
+) -> nn.Module:
+    """Converts pytorch model to fake-quantized pytorch model."""
+    if not cfg.QUANTIZATION.QAT.ENABLED:
+        logger.info(
+            "The model is not quantized during training, running post"
+            " training quantization ..."
+        )
+        pytorch_model = post_training_quantize(cfg, pytorch_model, data_loader)
+        # only check bn exists in ptq as qat still has bn inside fused ops
+        if fuse_utils.check_bn_exist(pytorch_model):
+            logger.warn("Post training quantized model has bn inside fused ops")
+    logger.info(f"Converting quantized model {cfg.QUANTIZATION.BACKEND}...")
+
+    if hasattr(pytorch_model, "prepare_for_quant_convert"):
+        pytorch_model = pytorch_model.prepare_for_quant_convert(cfg)
+    else:
+        # TODO(T93870381): move this to a default function
+        if cfg.QUANTIZATION.EAGER_MODE:
+            pytorch_model = convert(pytorch_model, inplace=False)
+        else:  # FX graph mode quantization
+            pytorch_model = convert_fx(pytorch_model)
+
+    logger.info(f"Quantized Model:\n{pytorch_model}")
+    return pytorch_model
+
+
+def _convert_fp_model(
+    cfg: CfgNode, pytorch_model: nn.Module, data_loader: Iterable
+) -> nn.Module:
+    """Converts floating point predictor"""
+    pytorch_model = fuse_utils.fuse_model(pytorch_model)
+    logger.info(f"Fused Model:\n{pytorch_model}")
+    if fuse_utils.count_bn_exist(pytorch_model) > 0:
+        logger.warning("BN existed in pytorch model after fusing.")
 
     return pytorch_model
@@ -102,7 +114,7 @@ def convert_and_export_predictor(
     - export: exporting the converted `pytorch_model` to predictor. This step
       should not alter the behaviour of model.
     """
-    pytorch_model = convert_predictor(cfg, pytorch_model, predictor_type, data_loader)
+    pytorch_model = convert_model(cfg, pytorch_model, predictor_type, data_loader)
 
     return export_predictor(cfg, pytorch_model, predictor_type, output_dir, data_loader)
@@ -146,7 +158,7 @@ def _export_single_model(
         model_export_method_str = model_export_method
         model_export_method = ModelExportMethodRegistry.get(model_export_method)
     assert issubclass(model_export_method, ModelExportMethod), model_export_method
-    logger.info("Using model export method: {}".format(model_export_method))
+    logger.info(f"Using model export method: {model_export_method}")
 
     load_kwargs = model_export_method.export(
         model=model,
@@ -159,9 +171,7 @@ def _export_single_model(
     model_rel_path = os.path.relpath(save_path, predictor_path)
     return ModelInfo(
         path=model_rel_path,
-        export_method="{}.{}".format(
-            model_export_method.__module__, model_export_method.__qualname__
-        ),
+        export_method=f"{model_export_method.__module__}.{model_export_method.__qualname__}",
        load_kwargs=load_kwargs,
    )
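For reference, below is a minimal standalone sketch of the fp32/int8 dispatch that the new convert_model entry point introduces. The stub converter functions and the example predictor_type strings are illustrative assumptions for the sketch only; the real conversion paths are convert_quantized_model and _convert_fp_model from this diff, which depend on d2go and mobile_cv.

# Stubs standing in for convert_quantized_model / _convert_fp_model from this diff;
# they only report which path was taken.
def _stub_quantized_convert(model: str) -> str:
    return f"fake-quantized({model})"


def _stub_fp_convert(model: str) -> str:
    return f"fused({model})"


def convert_model_sketch(model: str, predictor_type: str) -> str:
    """Mirrors the dispatch in convert_model: any predictor_type containing
    "int8" goes through the quantized path, everything else is fused as fp32."""
    return (
        _stub_quantized_convert(model)
        if "int8" in predictor_type
        else _stub_fp_convert(model)
    )


if __name__ == "__main__":
    # "torchscript_int8" and "torchscript" are example predictor_type strings.
    print(convert_model_sketch("my_model", "torchscript_int8"))  # fake-quantized(my_model)
    print(convert_model_sketch("my_model", "torchscript"))       # fused(my_model)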