Commit 83d9e8f7 authored by Yanghan Wang's avatar Yanghan Wang Committed by Facebook GitHub Bot
Browse files

finish the interface for ModelExportMethod

Reviewed By: zhanghang1989

Differential Revision: D27916281

fbshipit-source-id: 7ea01e99e9c2a9b19992f458abc786713ba5a4cd
parent d86ecc92
...@@ -27,7 +27,9 @@ NOTE: ...@@ -27,7 +27,9 @@ NOTE:
import json import json
import logging import logging
import os import os
from abc import ABC, abstractmethod
from typing import Callable, Dict, NamedTuple, Optional, Union from typing import Callable, Dict, NamedTuple, Optional, Union
from typing import final
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -39,6 +41,7 @@ from d2go.export.torchscript import ( ...@@ -39,6 +41,7 @@ from d2go.export.torchscript import (
from d2go.modeling.quantization import post_training_quantize from d2go.modeling.quantization import post_training_quantize
from detectron2.utils.file_io import PathManager from detectron2.utils.file_io import PathManager
from mobile_cv.arch.utils import fuse_utils from mobile_cv.arch.utils import fuse_utils
from mobile_cv.common.misc.file_utils import make_temp_directory
from mobile_cv.common.misc.registry import Registry from mobile_cv.common.misc.registry import Registry
from mobile_cv.predictor.api import FuncInfo, ModelInfo, PredictorInfo from mobile_cv.predictor.api import FuncInfo, ModelInfo, PredictorInfo
from mobile_cv.predictor.builtin_functions import ( from mobile_cv.predictor.builtin_functions import (
...@@ -46,6 +49,7 @@ from mobile_cv.predictor.builtin_functions import ( ...@@ -46,6 +49,7 @@ from mobile_cv.predictor.builtin_functions import (
IdentityPreprocess, IdentityPreprocess,
NaiveRunFunc, NaiveRunFunc,
) )
from mobile_cv.predictor.model_wrappers import load_model
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -161,18 +165,28 @@ def _export_single_model( ...@@ -161,18 +165,28 @@ def _export_single_model(
save_path, save_path,
model_export_method, model_export_method,
model_export_kwargs, model_export_kwargs,
predictor_type, # TODO: remove this after refactoring ModelInfo
): ):
assert isinstance(model, nn.Module), model assert isinstance(model, nn.Module), model
load_kwargs = ModelExportMethodRegistry.get(model_export_method).export( # model_export_method either inherits ModelExportMethod or is a key in the registry
if isinstance(model_export_method, str):
model_export_method = ModelExportMethodRegistry.get(model_export_method)
assert issubclass(model_export_method, ModelExportMethod), model_export_method
load_kwargs = model_export_method.export(
model=model, model=model,
input_args=input_args, input_args=input_args,
save_path=save_path, save_path=save_path,
**model_export_kwargs, **model_export_kwargs,
) )
assert isinstance(load_kwargs, dict) # TODO: save this in predictor_info assert isinstance(load_kwargs, dict)
model_rel_path = os.path.relpath(save_path, predictor_path) model_rel_path = os.path.relpath(save_path, predictor_path)
return ModelInfo(path=model_rel_path, type=predictor_type) return ModelInfo(
path=model_rel_path,
export_method="{}.{}".format(
model_export_method.__module__, model_export_method.__qualname__
),
load_kwargs=load_kwargs,
)
def default_export_predictor( def default_export_predictor(
...@@ -218,7 +232,6 @@ def default_export_predictor( ...@@ -218,7 +232,6 @@ def default_export_predictor(
if export_config.model_export_kwargs is None if export_config.model_export_kwargs is None
else export_config.model_export_kwargs[name] else export_config.model_export_kwargs[name]
), ),
predictor_type=predictor_type,
) )
models_info[name] = model_info models_info[name] = model_info
predictor_init_kwargs["models"] = models_info predictor_init_kwargs["models"] = models_info
...@@ -231,7 +244,6 @@ def default_export_predictor( ...@@ -231,7 +244,6 @@ def default_export_predictor(
save_path=save_path, save_path=save_path,
model_export_method=export_config.model_export_method or predictor_type, model_export_method=export_config.model_export_method or predictor_type,
model_export_kwargs=export_config.model_export_kwargs or {}, model_export_kwargs=export_config.model_export_kwargs or {},
predictor_type=predictor_type,
) )
predictor_init_kwargs["model"] = model_info predictor_init_kwargs["model"] = model_info
...@@ -245,18 +257,93 @@ def default_export_predictor( ...@@ -245,18 +257,93 @@ def default_export_predictor(
return predictor_path return predictor_path
class ModelExportMethod(ABC):
    """
    Base class for "model export method". Each model export method can export a
    pytorch model to a certain deployable format, such as torchscript or caffe2.
    It consists of `export` and `load` methods.
    """

    @classmethod
    @abstractmethod
    def export(cls, model, input_args, save_path, **export_kwargs):
        """
        Export the model to deployable format.

        Args:
            model (nn.Module): a pytorch model to export
            input_args (Tuple[Any]): inputs of model, called as model(*input_args)
            save_path (str): directory where the model will be exported
            export_kwargs (Dict): additional parameters for exporting model defined
                by each model export method.
        Return:
            load_kwargs (Dict): additional information (besides save_path) needed in
                order to load the exported model. This needs to be JSON serializable.
        """
        pass

    @classmethod
    @abstractmethod
    def load(cls, save_path, **load_kwargs):
        """
        Load the exported model back for inference.

        Args:
            save_path (str): directory where the model is stored.
            load_kwargs (Dict): additional information to load the exported model.
        Returns:
            model (nn.Module): a nn.Module (often a wrapper for non-torchscript
                types like caffe2); it works the same as the original pytorch
                model, i.e. gives the same output when called as
                model(*input_args)
        """
        pass

    @classmethod
    @final
    def test_export_and_load(cls, model, input_args, export_kwargs, output_checker):
        """
        Illustrate the life-cycle of export and load, used for testing: run the
        original model, export it, load it back, run the loaded model, and compare
        the two outputs.

        Args:
            model (nn.Module): pytorch model to round-trip through export/load.
            input_args (List/Tuple): inputs of model, called as model(*input_args).
            export_kwargs (Dict): additional parameters forwarded to `export`.
            output_checker (Callable): called as output_checker(new_output,
                original_output); should raise when the outputs don't match.
        """
        with make_temp_directory("test_export_and_load") as save_path:
            # run the original model
            assert isinstance(model, nn.Module), model
            assert isinstance(input_args, (list, tuple)), input_args
            original_output = model(*input_args)

            # export the model
            model.eval()  # TODO: decide where eval() should be called
            load_kwargs = cls.export(model, input_args, save_path, **export_kwargs)

            # sanity check for load_kwargs: must be a dict and JSON serializable
            # (json.dumps raises TypeError when it isn't), because it is stored
            # alongside the exported model in the predictor info.
            assert isinstance(load_kwargs, dict), load_kwargs
            assert json.dumps(load_kwargs), load_kwargs

            # load the exported model back
            loaded_model = cls.load(save_path, **load_kwargs)

            # run the loaded model
            assert isinstance(loaded_model, nn.Module), loaded_model
            new_output = loaded_model(*input_args)

            # compare outputs
            output_checker(new_output, original_output)
ModelExportMethodRegistry = Registry("ModelExportMethod", allow_override=True) ModelExportMethodRegistry = Registry("ModelExportMethod", allow_override=True)
@ModelExportMethodRegistry.register("caffe2")
class DefaultCaffe2Export(ModelExportMethod):
    """Export a pytorch model to caffe2 format using d2go's caffe2 exporter."""

    @classmethod
    def export(cls, model, input_args, save_path, **export_kwargs):
        from d2go.export.caffe2 import export_caffe2

        # HACK: workaround the current caffe2 export API, which requires the
        # model to provide an `encode_additional_info` hook; supply a no-op
        # when the model doesn't define one.
        if not hasattr(model, "encode_additional_info"):
            model.encode_additional_info = lambda predict_net, init_net: None

        export_caffe2(model, input_args[0], save_path, **export_kwargs)
        # no extra information is needed to load a caffe2 model back
        return {}

    @classmethod
    def load(cls, save_path, **load_kwargs):
        return load_model(save_path, "caffe2")
@ModelExportMethodRegistry.register("torchscript") @ModelExportMethodRegistry.register("torchscript")
@ModelExportMethodRegistry.register("torchscript@tracing") @ModelExportMethodRegistry.register("torchscript@tracing")
...@@ -264,16 +351,20 @@ class DefaultCaffe2Export(object): ...@@ -264,16 +351,20 @@ class DefaultCaffe2Export(object):
@ModelExportMethodRegistry.register("torchscript_int8") @ModelExportMethodRegistry.register("torchscript_int8")
@ModelExportMethodRegistry.register("torchscript_int8@tracing") @ModelExportMethodRegistry.register("torchscript_int8@tracing")
@ModelExportMethodRegistry.register("torchscript_int8@scripting") @ModelExportMethodRegistry.register("torchscript_int8@scripting")
class DefaultTorchscriptExport(object): class DefaultTorchscriptExport(ModelExportMethod):
@classmethod @classmethod
def export(cls, model, input_args, save_path, **export_kwargs): def export(cls, model, input_args, save_path, **export_kwargs):
trace_and_save_torchscript(model, input_args, save_path, **export_kwargs) trace_and_save_torchscript(model, input_args, save_path, **export_kwargs)
return {} return {}
@classmethod
def load(cls, save_path, **load_kwargs):
return load_model(save_path, "torchscript")
@ModelExportMethodRegistry.register("torchscript_mobile") @ModelExportMethodRegistry.register("torchscript_mobile")
@ModelExportMethodRegistry.register("torchscript_mobile_int8") @ModelExportMethodRegistry.register("torchscript_mobile_int8")
class DefaultTorchscriptMobileExport(object): class DefaultTorchscriptMobileExport(ModelExportMethod):
@classmethod @classmethod
def export(cls, model, input_args, save_path, **export_kwargs): def export(cls, model, input_args, save_path, **export_kwargs):
trace_and_save_torchscript( trace_and_save_torchscript(
...@@ -284,3 +375,7 @@ class DefaultTorchscriptMobileExport(object): ...@@ -284,3 +375,7 @@ class DefaultTorchscriptMobileExport(object):
**export_kwargs, **export_kwargs,
) )
return {} return {}
@classmethod
def load(cls, save_path, **load_kwargs):
return load_model(save_path, "torchscript")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment