Commit bd4ba04d authored by Yanghan Wang's avatar Yanghan Wang Committed by Facebook GitHub Bot
Browse files

make model export function as a registry

Reviewed By: zhanghang1989

Differential Revision: D27710199

fbshipit-source-id: 178a28972dcc06350e99263f4b38f284cf10c890
parent fb3ba095
......@@ -35,6 +35,7 @@ import torch.quantization.quantize_fx
from d2go.modeling.quantization import post_training_quantize
from detectron2.utils.file_io import PathManager
from mobile_cv.arch.utils import fuse_utils
from mobile_cv.common.misc.registry import Registry
from mobile_cv.predictor.api import FuncInfo, ModelInfo, PredictorInfo
from mobile_cv.predictor.builtin_functions import (
IdentityPostprocess,
......@@ -72,6 +73,7 @@ class PredictorExportConfig(NamedTuple):
# Shall we save data_generator in the predictor? This might be necessary when data
# is needed, eg. running benchmark for sub models
data_generator: Optional[Callable] = None
model_export_method: Optional[str] = None
model_export_kwargs: Optional[Union[Dict, Any]] = None
preprocess_info: FuncInfo = FuncInfo.gen_func_info(IdentityPreprocess, params={})
......@@ -103,7 +105,7 @@ def convert_and_export_predictor(
# TODO(future diff): move this logic to prepare_for_quant_convert
pytorch_model = torch.quantization.convert(pytorch_model, inplace=False)
else: # FX graph mode quantization
if hasattr(pytorch_model, 'prepare_for_quant_convert'):
if hasattr(pytorch_model, "prepare_for_quant_convert"):
pytorch_model = pytorch_model.prepare_for_quant_convert(cfg)
else:
# TODO(future diff): move this to a default function
......@@ -173,18 +175,19 @@ def default_export_predictor(
if export_config.data_generator is not None
else None
)
model_export_method = export_config.model_export_method or predictor_type
model_export_kwargs = export_config.model_export_kwargs or {}
# the default implementation assumes model type is the same as the predictor type
model_type = predictor_type
model_path = predictor_path # might be sub dir for multiple models
standard_model_export(
model,
model_type=model_type,
save_path=model_path,
load_kwargs = ModelExportMethodRegistry.get(model_export_method).export(
model=model,
input_args=input_args,
save_path=model_path,
**model_export_kwargs,
)
assert isinstance(load_kwargs, dict) # TODO: save this in predictor_info
model_rel_path = os.path.relpath(model_path, predictor_path)
# assemble predictor
......@@ -202,22 +205,29 @@ def default_export_predictor(
return predictor_path
# TODO: determine if saving data should be part of standard_model_export or not.
# TODO: determine how to support PTQ, option 1): do everything inside this function,
# drawback: needs data loader; no customization. option 2): do calibration outside,
# and only do tracing inside (same as fp32 torchscript model).
# TODO: define the supported model types, current caffe2/torchscript/torchscript_int8
# is not enough.
# TODO: determine if registry is needed (probably not since we only need to support
# a few known formats) as library code.
def standard_model_export(model, model_type, save_path, input_args, **kwargs):
if model_type.startswith("torchscript"):
from d2go.export.torchscript import trace_and_save_torchscript
trace_and_save_torchscript(model, input_args, save_path, **kwargs)
elif model_type == "caffe2":
ModelExportMethodRegistry = Registry("ModelExportMethod", allow_override=True)
@ModelExportMethodRegistry.register("caffe2")
class DefaultCaffe2Export(object):
@classmethod
def export(cls, model, input_args, save_path, **export_kwargs):
from d2go.export.caffe2 import export_caffe2
# TODO: export_caffe2 depends on D2, need to make a copy of the implementation
# TODO: support specifying optimization pass via kwargs
export_caffe2(model, input_args[0], save_path, **kwargs)
else:
raise NotImplementedError("Incorrect model_type: {}".format(model_type))
export_caffe2(model, input_args[0], save_path, **export_kwargs)
return {}
@ModelExportMethodRegistry.register("torchscript")
@ModelExportMethodRegistry.register("torchscript@tracing")
@ModelExportMethodRegistry.register("torchscript@scripting")
@ModelExportMethodRegistry.register("torchscript_int8")
@ModelExportMethodRegistry.register("torchscript_int8@tracing")
@ModelExportMethodRegistry.register("torchscript_int8@scripting")
class DefaultTorchscriptExport(object):
@classmethod
def export(cls, model, input_args, save_path, **export_kwargs):
from d2go.export.torchscript import trace_and_save_torchscript
trace_and_save_torchscript(model, input_args, save_path, **export_kwargs)
return {}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment