Commit 715b4f66 authored by Yanghan Wang, committed by Facebook GitHub Bot

add export_method to the ModelExportMethod.export

Reviewed By: zhanghang1989

Differential Revision: D28081681

fbshipit-source-id: 3722f5db668c36c4f23c3fd0c10657a3cf14ad3c
parent 3e243c1a
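
The gist of the change: `_export_single_model` now remembers the registry key it was called with and forwards it to `ModelExportMethod.export()` as a new `export_method` argument, so a single registered class can serve several keys and branch on the variant. A minimal sketch of the new dispatch, assuming d2go with this change applied (the helper name `export_with_key` is illustrative, not part of the diff):

from torch import nn

from d2go.export.api import ModelExportMethod, ModelExportMethodRegistry


def export_with_key(model: nn.Module, input_args, save_path, model_export_method, **kwargs):
    # model_export_method is either a registry key (str) or a ModelExportMethod subclass
    model_export_method_str = None
    if isinstance(model_export_method, str):
        model_export_method_str = model_export_method  # remember the original key
        model_export_method = ModelExportMethodRegistry.get(model_export_method)
    assert issubclass(model_export_method, ModelExportMethod)

    # the key (or None, when a class was passed directly) is forwarded to export()
    return model_export_method.export(
        model, input_args, save_path, export_method=model_export_method_str, **kwargs
    )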
d2go/export/__init__.py

+#!/usr/bin/env python3
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+
+# enable registry
+from . import caffe2  # noqa
+from . import torchscript  # noqa
d2go/export/api.py

@@ -34,13 +34,7 @@ from typing import final
 import torch
 import torch.nn as nn
 import torch.quantization.quantize_fx
-from d2go.export.torchscript import (
-    trace_and_save_torchscript,
-    MobileOptimizationConfig,
-)
 from d2go.modeling.quantization import post_training_quantize
-from detectron2.config.instantiate import dump_dataclass, instantiate
-from detectron2.export.flatten import TracingAdapter, flatten_to_tuple
 from detectron2.utils.file_io import PathManager
 from mobile_cv.arch.utils import fuse_utils
 from mobile_cv.common.misc.file_utils import make_temp_directory
@@ -51,7 +45,6 @@ from mobile_cv.predictor.builtin_functions import (
     IdentityPreprocess,
     NaiveRunFunc,
 )
 
-from mobile_cv.predictor.model_wrappers import load_model
 
 logger = logging.getLogger(__name__)
@@ -170,7 +163,9 @@ def _export_single_model(
 ):
     assert isinstance(model, nn.Module), model
     # model_export_method either inherits ModelExportMethod or is a key in the registry
+    model_export_method_str = None
     if isinstance(model_export_method, str):
+        model_export_method_str = model_export_method
         model_export_method = ModelExportMethodRegistry.get(model_export_method)
     assert issubclass(model_export_method, ModelExportMethod), model_export_method
 
@@ -178,6 +173,7 @@
         model=model,
         input_args=input_args,
         save_path=save_path,
+        export_method=model_export_method_str,
         **model_export_kwargs,
     )
     assert isinstance(load_kwargs, dict)
@@ -268,7 +264,7 @@ class ModelExportMethod(ABC):
 
     @classmethod
     @abstractmethod
-    def export(cls, model, input_args, save_path, **export_kwargs):
+    def export(cls, model, input_args, save_path, export_method, **export_kwargs):
         """
         Export the model to deployable format.
 
@@ -276,6 +272,7 @@ class ModelExportMethod(ABC):
             model (nn.Module): a pytorch model to export
             input_args (Tuple[Any]): inputs of model, called as model(*input_args)
             save_path (str): directory where the model will be exported
+            export_method (str): string name for the export method
             export_kwargs (Dict): additional parameters for exporting model defined
                 by each model export method.
         Return:
@@ -302,7 +299,9 @@ class ModelExportMethod(ABC):
 
     @classmethod
     @final
-    def test_export_and_load(cls, model, input_args, export_kwargs, output_checker):
+    def test_export_and_load(
+        cls, model, input_args, export_method, export_kwargs, output_checker
+    ):
         """
         Illustrate the life-cycle of export and load, used for testing.
         """
@@ -313,7 +312,9 @@ class ModelExportMethod(ABC):
             original_output = model(*input_args)
             # export the model
             model.eval()  # TODO: decide where eval() should be called
-            load_kwargs = cls.export(model, input_args, save_path, **export_kwargs)
+            load_kwargs = cls.export(
+                model, input_args, save_path, export_method, **export_kwargs
+            )
             # sanity check for load_kwargs
             assert isinstance(load_kwargs, dict), load_kwargs
             assert json.dumps(load_kwargs), load_kwargs
@@ -327,89 +328,3 @@
 
 
 ModelExportMethodRegistry = Registry("ModelExportMethod", allow_override=True)
-
-
-@ModelExportMethodRegistry.register("caffe2")
-class DefaultCaffe2Export(ModelExportMethod):
-    @classmethod
-    def export(cls, model, input_args, save_path, **export_kwargs):
-        from d2go.export.caffe2 import export_caffe2
-
-        # HACK: workaround the current caffe2 export API
-        if not hasattr(model, "encode_additional_info"):
-            model.encode_additional_info = lambda predict_net, init_net: None
-
-        export_caffe2(model, input_args[0], save_path, **export_kwargs)
-        return {}
-
-    @classmethod
-    def load(cls, save_path, **load_kwargs):
-        return load_model(save_path, "caffe2")
-
-
-@ModelExportMethodRegistry.register("torchscript")
-@ModelExportMethodRegistry.register("torchscript@scripting")
-@ModelExportMethodRegistry.register("torchscript_int8")
-@ModelExportMethodRegistry.register("torchscript_int8@scripting")
-class DefaultTorchscriptExport(ModelExportMethod):
-    @classmethod
-    def export(cls, model, input_args, save_path, **export_kwargs):
-        trace_and_save_torchscript(model, input_args, save_path, **export_kwargs)
-        return {}
-
-    @classmethod
-    def load(cls, save_path, **load_kwargs):
-        return load_model(save_path, "torchscript")
-
-
-@ModelExportMethodRegistry.register("torchscript@tracing")
-@ModelExportMethodRegistry.register("torchscript_int8@tracing")
-class D2TorchscriptTracingExport(ModelExportMethod):
-    @classmethod
-    def export(cls, model, input_args, save_path, **export_kwargs):
-        adapter = TracingAdapter(model, input_args)
-        trace_and_save_torchscript(
-            adapter, adapter.flattened_inputs, save_path, **export_kwargs
-        )
-        inputs_schema = dump_dataclass(adapter.inputs_schema)
-        outputs_schema = dump_dataclass(adapter.outputs_schema)
-        return {"inputs_schema": inputs_schema, "outputs_schema": outputs_schema}
-
-    @classmethod
-    def load(cls, save_path, inputs_schema, outputs_schema, **load_kwargs):
-        inputs_schema = instantiate(inputs_schema)
-        outputs_schema = instantiate(outputs_schema)
-        traced_model = load_model(save_path, "torchscript")
-
-        class TracingAdapterWrapper(nn.Module):
-            def __init__(self, traced_model, inputs_schema, outputs_schema):
-                super().__init__()
-                self.traced_model = traced_model
-                self.inputs_schema = inputs_schema
-                self.outputs_schema = outputs_schema
-
-            def forward(self, *input_args):
-                flattened_inputs, _ = flatten_to_tuple(input_args)
-                flattened_outputs = self.traced_model(*flattened_inputs)
-                return self.outputs_schema(flattened_outputs)
-
-        return TracingAdapterWrapper(traced_model, inputs_schema, outputs_schema)
-
-
-@ModelExportMethodRegistry.register("torchscript_mobile")
-@ModelExportMethodRegistry.register("torchscript_mobile_int8")
-class DefaultTorchscriptMobileExport(ModelExportMethod):
-    @classmethod
-    def export(cls, model, input_args, save_path, **export_kwargs):
-        trace_and_save_torchscript(
-            model,
-            input_args,
-            save_path,
-            mobile_optimization=MobileOptimizationConfig(),
-            **export_kwargs,
-        )
-        return {}
-
-    @classmethod
-    def load(cls, save_path, **load_kwargs):
-        return load_model(save_path, "torchscript")
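
Under the new contract, a downstream backend implements `export()`/`load()` with the extra `export_method` argument and registers itself by key. A minimal sketch, assuming d2go with this change applied; the `my_backend` key and `MyBackendExport` class are hypothetical:

import torch
from torch import nn

from d2go.export.api import ModelExportMethod, ModelExportMethodRegistry


@ModelExportMethodRegistry.register("my_backend")  # hypothetical key
class MyBackendExport(ModelExportMethod):
    @classmethod
    def export(cls, model, input_args, save_path, export_method, **export_kwargs):
        # export_method carries the registry key ("my_backend") or None
        torch.jit.trace(model, input_args).save(f"{save_path}/model.jit")
        return {}  # load_kwargs; must be JSON-serializable

    @classmethod
    def load(cls, save_path, **load_kwargs):
        return torch.jit.load(f"{save_path}/model.jit")


# exercise the full round trip via the final helper shown above
model = nn.Sequential(nn.Conv2d(3, 8, 3)).eval()
MyBackendExport.test_export_and_load(
    model,
    (torch.rand(1, 3, 32, 32),),
    export_method="my_backend",
    export_kwargs={},
    output_checker=lambda *outputs: None,  # no-op check for the sketch
)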
d2go/export/caffe2.py

@@ -3,17 +3,18 @@
 import logging
-import torch
 import os
-from torch import nn
 from typing import Dict, Tuple
+
+import torch
+from d2go.export.api import ModelExportMethodRegistry, ModelExportMethod
+from d2go.export.logfiledb import export_to_logfiledb
 from detectron2.export.api import Caffe2Model
 from detectron2.export.caffe2_export import (
     export_caffe2_detection_model,
     run_and_save_graph,
 )
-from d2go.export.logfiledb import export_to_logfiledb
+from torch import nn
 
 logger = logging.getLogger(__name__)
@@ -37,10 +38,12 @@ def export_caffe2(
     caffe2_export_paths = {}
     if save_pb:
         caffe2_model.save_protobuf(output_dir)
-        caffe2_export_paths.update({
-            "predict_net_path": os.path.join(output_dir, "model.pb"),
-            "init_net_path": os.path.join(output_dir, "model_init.pb"),
-        })
+        caffe2_export_paths.update(
+            {
+                "predict_net_path": os.path.join(output_dir, "model.pb"),
+                "init_net_path": os.path.join(output_dir, "model_init.pb"),
+            }
+        )
 
     graph_save_path = os.path.join(output_dir, "model_def.svg")
     ws_blobs = run_and_save_graph(
@@ -49,15 +52,37 @@ def export_caffe2(
         tensor_inputs,
         graph_save_path=graph_save_path,
     )
-    caffe2_export_paths.update({
-        "model_def_path": graph_save_path,
-    })
+    caffe2_export_paths.update(
+        {
+            "model_def_path": graph_save_path,
+        }
+    )
 
     if save_logdb:
         logfiledb_path = os.path.join(output_dir, "model.logfiledb")
         export_to_logfiledb(predict_net, init_net, logfiledb_path, ws_blobs)
-        caffe2_export_paths.update({
-            "logfiledb_path": logfiledb_path if save_logdb else None,
-        })
+        caffe2_export_paths.update(
+            {
+                "logfiledb_path": logfiledb_path if save_logdb else None,
+            }
+        )
 
     return caffe2_model, caffe2_export_paths
+
+
+@ModelExportMethodRegistry.register("caffe2")
+class DefaultCaffe2Export(ModelExportMethod):
+    @classmethod
+    def export(cls, model, input_args, save_path, export_method, **export_kwargs):
+        # HACK: workaround the current caffe2 export API
+        if not hasattr(model, "encode_additional_info"):
+            model.encode_additional_info = lambda predict_net, init_net: None
+
+        export_caffe2(model, input_args[0], save_path, **export_kwargs)
+        return {}
+
+    @classmethod
+    def load(cls, save_path, **load_kwargs):
+        from mobile_cv.predictor.model_wrappers import load_model
+
+        return load_model(save_path, "caffe2")
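
For reference, this is how the "caffe2" key resolves through the registry at runtime. The actual `export()`/`load()` calls are left commented out because they require a real detectron2 caffe2-compatible model (and matching `input_args`); `model`, `input_args`, and `save_path` are placeholders:

from d2go.export.api import ModelExportMethod, ModelExportMethodRegistry

# resolve the string key to the class registered above
Caffe2Export = ModelExportMethodRegistry.get("caffe2")
assert issubclass(Caffe2Export, ModelExportMethod)

# export_method is accepted for interface uniformity; this backend ignores it:
# load_kwargs = Caffe2Export.export(model, input_args, save_path, export_method="caffe2")
# predictor = Caffe2Export.load(save_path, **load_kwargs)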
d2go/export/torchscript.py

@@ -8,8 +8,14 @@ import os
 from typing import Tuple, Optional, Dict, NamedTuple, List, AnyStr, Set
 
 import torch
+from d2go.export.api import ModelExportMethodRegistry, ModelExportMethod
+from detectron2.config.instantiate import dump_dataclass, instantiate
+from detectron2.export.flatten import TracingAdapter, flatten_to_tuple
+from detectron2.export.torchscript_patch import patch_builtin_len
 from detectron2.utils.file_io import PathManager
 from mobile_cv.common.misc.file_utils import make_temp_directory
+from mobile_cv.common.misc.iter_utils import recursive_iterate
+from mobile_cv.predictor.model_wrappers import load_model
 from torch import nn
 from torch._C import MobileOptimizerType
 from torch.utils.bundled_inputs import augment_model_with_bundled_inputs
@@ -38,11 +44,7 @@ def trace_and_save_torchscript(
     if _extra_files is None:
         _extra_files = {}
 
-    # TODO: patch_builtin_len depends on D2, we should either copy the function or
-    # dynamically registering the D2's version.
-    from detectron2.export.torchscript_patch import patch_builtin_len
-
-    with torch.no_grad(), patch_builtin_len():
+    with torch.no_grad():
         script_model = torch.jit.trace(model, inputs)
 
     with make_temp_directory("trace_and_save_torchscript") as tmp_dir:
@@ -77,7 +79,99 @@ def trace_and_save_torchscript(
             )
 
             logger.info("Applying augment_model_with_bundled_inputs ...")
+            # make all tensors zero-like to save storage
+            iters = recursive_iterate(inputs)
+            for x in iters:
+                if isinstance(x, torch.Tensor):
+                    iters.send(torch.zeros_like(x))
+            inputs = iters.value
+
             augment_model_with_bundled_inputs(liteopt_model, [inputs])
             liteopt_model.run_on_bundled_input(0)  # sanity check
             with _synced_local_file("mobile_optimized_bundled.ptl") as lite_path:
                 liteopt_model._save_for_lite_interpreter(lite_path)
+
+
+def tracing_adapter_wrap_export(old_f):
+    def new_f(cls, model, input_args, *args, **kwargs):
+        adapter = TracingAdapter(model, input_args)
+        load_kwargs = old_f(cls, adapter, adapter.flattened_inputs, *args, **kwargs)
+        inputs_schema = dump_dataclass(adapter.inputs_schema)
+        outputs_schema = dump_dataclass(adapter.outputs_schema)
+        assert "inputs_schema" not in load_kwargs
+        assert "outputs_schema" not in load_kwargs
+        load_kwargs.update(
+            {"inputs_schema": inputs_schema, "outputs_schema": outputs_schema}
+        )
+        return load_kwargs
+
+    return new_f
+
+
+def tracing_adapter_wrap_load(old_f):
+    def new_f(cls, save_path, **load_kwargs):
+        assert "inputs_schema" in load_kwargs
+        assert "outputs_schema" in load_kwargs
+        inputs_schema = instantiate(load_kwargs.pop("inputs_schema"))
+        outputs_schema = instantiate(load_kwargs.pop("outputs_schema"))
+        traced_model = old_f(cls, save_path, **load_kwargs)
+
+        class TracingAdapterModelWrapper(nn.Module):
+            def __init__(self, traced_model, inputs_schema, outputs_schema):
+                super().__init__()
+                self.traced_model = traced_model
+                self.inputs_schema = inputs_schema
+                self.outputs_schema = outputs_schema
+
+            def forward(self, *input_args):
+                flattened_inputs, _ = flatten_to_tuple(input_args)
+                flattened_outputs = self.traced_model(*flattened_inputs)
+                return self.outputs_schema(flattened_outputs)
+
+        return TracingAdapterModelWrapper(traced_model, inputs_schema, outputs_schema)
+
+    return new_f
+
+
+@ModelExportMethodRegistry.register("torchscript")
+@ModelExportMethodRegistry.register("torchscript_int8")
+@ModelExportMethodRegistry.register("torchscript_mobile")
+@ModelExportMethodRegistry.register("torchscript_mobile_int8")
+class DefaultTorchscriptExport(ModelExportMethod):
+    @classmethod
+    def export(cls, model, input_args, save_path, export_method, **export_kwargs):
+        if export_method is not None:
+            # update export_kwargs based on export_method
+            assert isinstance(export_method, str)
+            if "_mobile" in export_method:
+                if "mobile_optimization" in export_kwargs:
+                    logger.warning(
+                        "`mobile_optimization` is already specified, keep using it"
+                    )
+                else:
+                    export_kwargs["mobile_optimization"] = MobileOptimizationConfig()
+
+        trace_and_save_torchscript(model, input_args, save_path, **export_kwargs)
+        return {}
+
+    @classmethod
+    def load(cls, save_path, **load_kwargs):
+        return load_model(save_path, "torchscript")
+
+
+@ModelExportMethodRegistry.register("torchscript@tracing")
+@ModelExportMethodRegistry.register("torchscript_int8@tracing")
+@ModelExportMethodRegistry.register("torchscript_mobile@tracing")
+@ModelExportMethodRegistry.register("torchscript_mobile_int8@tracing")
+class D2TorchscriptTracingExport(DefaultTorchscriptExport):
+    @classmethod
+    @tracing_adapter_wrap_export
+    def export(cls, model, input_args, save_path, export_method, **export_kwargs):
+        with patch_builtin_len():
+            return super().export(
+                model, input_args, save_path, export_method, **export_kwargs
+            )
+
+    @classmethod
+    @tracing_adapter_wrap_load
+    def load(cls, save_path, **load_kwargs):
+        return super().load(save_path, **load_kwargs)
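
Net effect on the torchscript side: the separate `DefaultTorchscriptMobileExport` class is gone, the plain keys all map to `DefaultTorchscriptExport`, and `export()` now injects `mobile_optimization=MobileOptimizationConfig()` itself whenever the key it was invoked under contains `_mobile` (an explicitly passed `mobile_optimization` still wins, with a warning). A usage sketch, assuming d2go with this change applied; the toy model and path are illustrative:

import os

import torch
from torch import nn

from d2go.export.api import ModelExportMethodRegistry

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU()).eval()
inputs = (torch.rand(1, 3, 32, 32),)

save_path = "/tmp/ts_mobile_export"  # illustrative location
os.makedirs(save_path, exist_ok=True)

# "torchscript_mobile" resolves to DefaultTorchscriptExport; since the key
# contains "_mobile", export() adds MobileOptimizationConfig() on its own
exporter = ModelExportMethodRegistry.get("torchscript_mobile")
load_kwargs = exporter.export(
    model, inputs, save_path, export_method="torchscript_mobile"
)
reloaded = exporter.load(save_path, **load_kwargs)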