Commit 83d9e8f7 authored by Yanghan Wang's avatar Yanghan Wang Committed by Facebook GitHub Bot
Browse files

finish the interface for ModelExportMethod

Reviewed By: zhanghang1989

Differential Revision: D27916281

fbshipit-source-id: 7ea01e99e9c2a9b19992f458abc786713ba5a4cd
parent d86ecc92
......@@ -27,7 +27,9 @@ NOTE:
import json
import logging
import os
from abc import ABC, abstractmethod
from typing import Callable, Dict, NamedTuple, Optional, Union
from typing import final
import torch
import torch.nn as nn
......@@ -39,6 +41,7 @@ from d2go.export.torchscript import (
from d2go.modeling.quantization import post_training_quantize
from detectron2.utils.file_io import PathManager
from mobile_cv.arch.utils import fuse_utils
from mobile_cv.common.misc.file_utils import make_temp_directory
from mobile_cv.common.misc.registry import Registry
from mobile_cv.predictor.api import FuncInfo, ModelInfo, PredictorInfo
from mobile_cv.predictor.builtin_functions import (
......@@ -46,6 +49,7 @@ from mobile_cv.predictor.builtin_functions import (
IdentityPreprocess,
NaiveRunFunc,
)
from mobile_cv.predictor.model_wrappers import load_model
logger = logging.getLogger(__name__)
......@@ -161,18 +165,28 @@ def _export_single_model(
save_path,
model_export_method,
model_export_kwargs,
predictor_type, # TODO: remove this after refactoring ModelInfo
):
assert isinstance(model, nn.Module), model
load_kwargs = ModelExportMethodRegistry.get(model_export_method).export(
# model_export_method either inherits ModelExportMethod or is a key in the registry
if isinstance(model_export_method, str):
model_export_method = ModelExportMethodRegistry.get(model_export_method)
assert issubclass(model_export_method, ModelExportMethod), model_export_method
load_kwargs = model_export_method.export(
model=model,
input_args=input_args,
save_path=save_path,
**model_export_kwargs,
)
assert isinstance(load_kwargs, dict) # TODO: save this in predictor_info
assert isinstance(load_kwargs, dict)
model_rel_path = os.path.relpath(save_path, predictor_path)
return ModelInfo(path=model_rel_path, type=predictor_type)
return ModelInfo(
path=model_rel_path,
export_method="{}.{}".format(
model_export_method.__module__, model_export_method.__qualname__
),
load_kwargs=load_kwargs,
)
def default_export_predictor(
......@@ -218,7 +232,6 @@ def default_export_predictor(
if export_config.model_export_kwargs is None
else export_config.model_export_kwargs[name]
),
predictor_type=predictor_type,
)
models_info[name] = model_info
predictor_init_kwargs["models"] = models_info
......@@ -231,7 +244,6 @@ def default_export_predictor(
save_path=save_path,
model_export_method=export_config.model_export_method or predictor_type,
model_export_kwargs=export_config.model_export_kwargs or {},
predictor_type=predictor_type,
)
predictor_init_kwargs["model"] = model_info
......@@ -245,18 +257,93 @@ def default_export_predictor(
return predictor_path
class ModelExportMethod(ABC):
"""
Base class for "model export method". Each model export method can export a pytorch
model to a certain deployable format, such as torchscript or caffe2. It consists
with `export` and `load` methods.
"""
@classmethod
@abstractmethod
def export(cls, model, input_args, save_path, **export_kwargs):
"""
Export the model to deployable format.
Args:
model (nn.Module): a pytorch model to export
input_args (Tuple[Any]): inputs of model, called as model(*input_args)
save_path (str): directory where the model will be exported
export_kwargs (Dict): additional parameters for exporting model defined
by each model export method.
Return:
load_kwargs (Dict): additional information (besides save_path) needed in
order to load the exported model. This needs to be JSON serializable.
"""
pass
@classmethod
@abstractmethod
def load(cls, save_path, **load_kwargs):
"""
Load the exported model back for inference.
Args:
save_path (str): directory where the model is stored.
load_kwargs (Dict): addtional information to load the exported model.
Returns:
model (nn.Module): a nn.Module (often time a wrapper for non torchscript
types like caffe2), it works the same as the original pytorch model,
i.e. getting the same output when called as model(*input_args)
"""
pass
@classmethod
@final
def test_export_and_load(cls, model, input_args, export_kwargs, output_checker):
"""
Illustrate the life-cycle of export and load, used for testing.
"""
with make_temp_directory("test_export_and_load") as save_path:
# run the orginal model
assert isinstance(model, nn.Module), model
assert isinstance(input_args, (list, tuple)), input_args
original_output = model(*input_args)
# export the model
model.eval() # TODO: decide where eval() should be called
load_kwargs = cls.export(model, input_args, save_path, **export_kwargs)
# sanity check for load_kwargs
assert isinstance(load_kwargs, dict), load_kwargs
assert json.dumps(load_kwargs), load_kwargs
# loaded model back
loaded_model = cls.load(save_path, **load_kwargs)
# run the loaded model
assert isinstance(loaded_model, nn.Module), loaded_model
new_output = loaded_model(*input_args)
# compare outputs
output_checker(new_output, original_output)
ModelExportMethodRegistry = Registry("ModelExportMethod", allow_override=True)
@ModelExportMethodRegistry.register("caffe2")
class DefaultCaffe2Export(object):
class DefaultCaffe2Export(ModelExportMethod):
@classmethod
def export(cls, model, input_args, save_path, **export_kwargs):
from d2go.export.caffe2 import export_caffe2
# HACK: workaround the current caffe2 export API
if not hasattr(model, "encode_additional_info"):
model.encode_additional_info = lambda predict_net, init_net: None
export_caffe2(model, input_args[0], save_path, **export_kwargs)
return {}
@classmethod
def load(cls, save_path, **load_kwargs):
return load_model(save_path, "caffe2")
@ModelExportMethodRegistry.register("torchscript")
@ModelExportMethodRegistry.register("torchscript@tracing")
......@@ -264,16 +351,20 @@ class DefaultCaffe2Export(object):
@ModelExportMethodRegistry.register("torchscript_int8")
@ModelExportMethodRegistry.register("torchscript_int8@tracing")
@ModelExportMethodRegistry.register("torchscript_int8@scripting")
class DefaultTorchscriptExport(object):
class DefaultTorchscriptExport(ModelExportMethod):
@classmethod
def export(cls, model, input_args, save_path, **export_kwargs):
trace_and_save_torchscript(model, input_args, save_path, **export_kwargs)
return {}
@classmethod
def load(cls, save_path, **load_kwargs):
return load_model(save_path, "torchscript")
@ModelExportMethodRegistry.register("torchscript_mobile")
@ModelExportMethodRegistry.register("torchscript_mobile_int8")
class DefaultTorchscriptMobileExport(object):
class DefaultTorchscriptMobileExport(ModelExportMethod):
@classmethod
def export(cls, model, input_args, save_path, **export_kwargs):
trace_and_save_torchscript(
......@@ -284,3 +375,7 @@ class DefaultTorchscriptMobileExport(object):
**export_kwargs,
)
return {}
@classmethod
def load(cls, save_path, **load_kwargs):
return load_model(save_path, "torchscript")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment