Commit 715b4f66 authored by Yanghan Wang, committed by Facebook GitHub Bot

add export_method to ModelExportMethod.export

Reviewed By: zhanghang1989

Differential Revision: D28081681

fbshipit-source-id: 3722f5db668c36c4f23c3fd0c10657a3cf14ad3c
parent 3e243c1a
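In outline: the string key used to look up a ModelExportMethod in the registry is now forwarded to its export() call as export_method, so one class registered under several keys can branch on the key. A minimal sketch of the new contract (the class and key names below are hypothetical; only ModelExportMethod, ModelExportMethodRegistry, and the export/load signatures come from this commit):

# Hypothetical sketch, not part of the commit; assumes d2go is importable.
from d2go.export.api import ModelExportMethod, ModelExportMethodRegistry

@ModelExportMethodRegistry.register("demo")
@ModelExportMethodRegistry.register("demo_mobile")
class DemoExport(ModelExportMethod):
    @classmethod
    def export(cls, model, input_args, save_path, export_method, **export_kwargs):
        # export_method carries the registry key ("demo" or "demo_mobile"),
        # or None if the class was passed directly instead of a key.
        if export_method is not None and "_mobile" in export_method:
            pass  # enable mobile-specific handling here
        return {}  # load_kwargs consumed by the matching load()

    @classmethod
    def load(cls, save_path, **load_kwargs):
        raise NotImplementedError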
d2go/export/__init__.py:

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# enable registry: importing these modules registers their export methods
from . import caffe2 # noqa
from . import torchscript # noqa
d2go/export/api.py:

@@ -34,13 +34,7 @@ from typing import final
import torch
import torch.nn as nn
import torch.quantization.quantize_fx
-from d2go.export.torchscript import (
-    trace_and_save_torchscript,
-    MobileOptimizationConfig,
-)
from d2go.modeling.quantization import post_training_quantize
-from detectron2.config.instantiate import dump_dataclass, instantiate
-from detectron2.export.flatten import TracingAdapter, flatten_to_tuple
from detectron2.utils.file_io import PathManager
from mobile_cv.arch.utils import fuse_utils
from mobile_cv.common.misc.file_utils import make_temp_directory
@@ -51,7 +45,6 @@ from mobile_cv.predictor.builtin_functions import (
IdentityPreprocess,
NaiveRunFunc,
)
-from mobile_cv.predictor.model_wrappers import load_model
logger = logging.getLogger(__name__)
@@ -170,7 +163,9 @@ def _export_single_model(
):
assert isinstance(model, nn.Module), model
    # model_export_method is either a subclass of ModelExportMethod or a key in the registry
+    model_export_method_str = None
    if isinstance(model_export_method, str):
+        model_export_method_str = model_export_method
        model_export_method = ModelExportMethodRegistry.get(model_export_method)
assert issubclass(model_export_method, ModelExportMethod), model_export_method
@@ -178,6 +173,7 @@
model=model,
input_args=input_args,
save_path=save_path,
+        export_method=model_export_method_str,
**model_export_kwargs,
)
assert isinstance(load_kwargs, dict)
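Distilled to a standalone pattern (the registry contents below are placeholders, not the d2go API): the string key is remembered before it is resolved to a class, and stays None when a class is passed directly.

# Standalone distillation of the resolution above; names are placeholders.
REGISTRY = {"torchscript": type("TS", (), {})}  # key -> export-method class

def resolve(method):
    method_str = None
    if isinstance(method, str):
        method_str = method        # remember the key for export(...)
        method = REGISTRY[method]  # resolve the key to a class
    return method, method_str      # method_str is None when a class was passed

cls_, key = resolve("torchscript")  # -> (TS, "torchscript")
cls_, key = resolve(cls_)           # -> (TS, None)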
@@ -268,7 +264,7 @@ class ModelExportMethod(ABC):
@classmethod
@abstractmethod
-    def export(cls, model, input_args, save_path, **export_kwargs):
+    def export(cls, model, input_args, save_path, export_method, **export_kwargs):
"""
        Export the model to a deployable format.
@@ -276,6 +272,7 @@
model (nn.Module): a pytorch model to export
input_args (Tuple[Any]): inputs of model, called as model(*input_args)
save_path (str): directory where the model will be exported
+            export_method (Optional[str]): registry key used to resolve this export method, or None if the class was passed directly
export_kwargs (Dict): additional parameters for exporting model defined
by each model export method.
Return:
@@ -302,7 +299,9 @@
@classmethod
@final
-    def test_export_and_load(cls, model, input_args, export_kwargs, output_checker):
+    def test_export_and_load(
+        cls, model, input_args, export_method, export_kwargs, output_checker
+    ):
"""
Illustrate the life-cycle of export and load, used for testing.
"""
@@ -313,7 +312,9 @@
original_output = model(*input_args)
# export the model
model.eval() # TODO: decide where eval() should be called
-        load_kwargs = cls.export(model, input_args, save_path, **export_kwargs)
+        load_kwargs = cls.export(
+            model, input_args, save_path, export_method, **export_kwargs
+        )
# sanity check for load_kwargs
assert isinstance(load_kwargs, dict), load_kwargs
assert json.dumps(load_kwargs), load_kwargs
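The rest of test_export_and_load loads the exported model back and hands the result to output_checker. A plausible checker, assuming a (new_output, original_output) signature (the signature is not shown in this hunk, so it is an assumption), might be:

import torch

def allclose_checker(new_output, original_output):
    # Assumed signature, for illustration only: compare the exported
    # model's output against the eager model's output.
    assert torch.allclose(new_output, original_output, atol=1e-4), (
        "exported model output diverges from the original model"
    )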
@@ -327,89 +328,3 @@
ModelExportMethodRegistry = Registry("ModelExportMethod", allow_override=True)
@ModelExportMethodRegistry.register("caffe2")
class DefaultCaffe2Export(ModelExportMethod):
@classmethod
def export(cls, model, input_args, save_path, **export_kwargs):
from d2go.export.caffe2 import export_caffe2
        # HACK: work around the current caffe2 export API
if not hasattr(model, "encode_additional_info"):
model.encode_additional_info = lambda predict_net, init_net: None
export_caffe2(model, input_args[0], save_path, **export_kwargs)
return {}
@classmethod
def load(cls, save_path, **load_kwargs):
return load_model(save_path, "caffe2")
@ModelExportMethodRegistry.register("torchscript")
@ModelExportMethodRegistry.register("torchscript@scripting")
@ModelExportMethodRegistry.register("torchscript_int8")
@ModelExportMethodRegistry.register("torchscript_int8@scripting")
class DefaultTorchscriptExport(ModelExportMethod):
@classmethod
def export(cls, model, input_args, save_path, **export_kwargs):
trace_and_save_torchscript(model, input_args, save_path, **export_kwargs)
return {}
@classmethod
def load(cls, save_path, **load_kwargs):
return load_model(save_path, "torchscript")
@ModelExportMethodRegistry.register("torchscript@tracing")
@ModelExportMethodRegistry.register("torchscript_int8@tracing")
class D2TorchscriptTracingExport(ModelExportMethod):
@classmethod
def export(cls, model, input_args, save_path, **export_kwargs):
adapter = TracingAdapter(model, input_args)
trace_and_save_torchscript(
adapter, adapter.flattened_inputs, save_path, **export_kwargs
)
inputs_schema = dump_dataclass(adapter.inputs_schema)
outputs_schema = dump_dataclass(adapter.outputs_schema)
return {"inputs_schema": inputs_schema, "outputs_schema": outputs_schema}
@classmethod
def load(cls, save_path, inputs_schema, outputs_schema, **load_kwargs):
inputs_schema = instantiate(inputs_schema)
outputs_schema = instantiate(outputs_schema)
traced_model = load_model(save_path, "torchscript")
class TracingAdapterWrapper(nn.Module):
def __init__(self, traced_model, inputs_schema, outputs_schema):
super().__init__()
self.traced_model = traced_model
self.inputs_schema = inputs_schema
self.outputs_schema = outputs_schema
def forward(self, *input_args):
flattened_inputs, _ = flatten_to_tuple(input_args)
flattened_outputs = self.traced_model(*flattened_inputs)
return self.outputs_schema(flattened_outputs)
return TracingAdapterWrapper(traced_model, inputs_schema, outputs_schema)
@ModelExportMethodRegistry.register("torchscript_mobile")
@ModelExportMethodRegistry.register("torchscript_mobile_int8")
class DefaultTorchscriptMobileExport(ModelExportMethod):
@classmethod
def export(cls, model, input_args, save_path, **export_kwargs):
trace_and_save_torchscript(
model,
input_args,
save_path,
mobile_optimization=MobileOptimizationConfig(),
**export_kwargs,
)
return {}
@classmethod
def load(cls, save_path, **load_kwargs):
return load_model(save_path, "torchscript")
d2go/export/caffe2.py:

@@ -3,17 +3,18 @@
import logging
-import torch
import os
-from torch import nn
from typing import Dict, Tuple

+import torch
+from d2go.export.api import ModelExportMethodRegistry, ModelExportMethod
+from d2go.export.logfiledb import export_to_logfiledb
from detectron2.export.api import Caffe2Model
from detectron2.export.caffe2_export import (
    export_caffe2_detection_model,
    run_and_save_graph,
)
-from d2go.export.logfiledb import export_to_logfiledb
+from torch import nn
logger = logging.getLogger(__name__)
@@ -37,10 +38,12 @@ def export_caffe2(
caffe2_export_paths = {}
if save_pb:
caffe2_model.save_protobuf(output_dir)
-        caffe2_export_paths.update({
-            "predict_net_path": os.path.join(output_dir, "model.pb"),
-            "init_net_path": os.path.join(output_dir, "model_init.pb"),
-        })
+        caffe2_export_paths.update(
+            {
+                "predict_net_path": os.path.join(output_dir, "model.pb"),
+                "init_net_path": os.path.join(output_dir, "model_init.pb"),
+            }
+        )
graph_save_path = os.path.join(output_dir, "model_def.svg")
ws_blobs = run_and_save_graph(
@@ -49,15 +52,37 @@
tensor_inputs,
graph_save_path=graph_save_path,
)
-        caffe2_export_paths.update({
-            "model_def_path": graph_save_path,
-        })
+        caffe2_export_paths.update(
+            {
+                "model_def_path": graph_save_path,
+            }
+        )
if save_logdb:
logfiledb_path = os.path.join(output_dir, "model.logfiledb")
export_to_logfiledb(predict_net, init_net, logfiledb_path, ws_blobs)
-        caffe2_export_paths.update({
-            "logfiledb_path": logfiledb_path if save_logdb else None,
-        })
+        caffe2_export_paths.update(
+            {
+                "logfiledb_path": logfiledb_path if save_logdb else None,
+            }
+        )
return caffe2_model, caffe2_export_paths
@ModelExportMethodRegistry.register("caffe2")
class DefaultCaffe2Export(ModelExportMethod):
@classmethod
def export(cls, model, input_args, save_path, export_method, **export_kwargs):
        # HACK: work around the current caffe2 export API
if not hasattr(model, "encode_additional_info"):
model.encode_additional_info = lambda predict_net, init_net: None
export_caffe2(model, input_args[0], save_path, **export_kwargs)
return {}
@classmethod
def load(cls, save_path, **load_kwargs):
from mobile_cv.predictor.model_wrappers import load_model
return load_model(save_path, "caffe2")
d2go/export/torchscript.py:

@@ -8,8 +8,14 @@ import os
from typing import Tuple, Optional, Dict, NamedTuple, List, AnyStr, Set
import torch
from d2go.export.api import ModelExportMethodRegistry, ModelExportMethod
from detectron2.config.instantiate import dump_dataclass, instantiate
from detectron2.export.flatten import TracingAdapter, flatten_to_tuple
from detectron2.export.torchscript_patch import patch_builtin_len
from detectron2.utils.file_io import PathManager
from mobile_cv.common.misc.file_utils import make_temp_directory
from mobile_cv.common.misc.iter_utils import recursive_iterate
from mobile_cv.predictor.model_wrappers import load_model
from torch import nn
from torch._C import MobileOptimizerType
from torch.utils.bundled_inputs import augment_model_with_bundled_inputs
@@ -38,11 +44,7 @@ def trace_and_save_torchscript(
if _extra_files is None:
_extra_files = {}
-    # TODO: patch_builtin_len depends on D2, we should either copy the function or
-    # dynamically registering the D2's version.
-    from detectron2.export.torchscript_patch import patch_builtin_len
-    with torch.no_grad(), patch_builtin_len():
+    with torch.no_grad():
script_model = torch.jit.trace(model, inputs)
with make_temp_directory("trace_and_save_torchscript") as tmp_dir:
@@ -77,7 +79,99 @@ trace_and_save_torchscript(
)
logger.info("Applying augment_model_with_bundled_inputs ...")
# make all tensors zero-like to save storage
iters = recursive_iterate(inputs)
for x in iters:
if isinstance(x, torch.Tensor):
iters.send(torch.zeros_like(x))
inputs = iters.value
augment_model_with_bundled_inputs(liteopt_model, [inputs])
liteopt_model.run_on_bundled_input(0) # sanity check
with _synced_local_file("mobile_optimized_bundled.ptl") as lite_path:
liteopt_model._save_for_lite_interpreter(lite_path)
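The zeroing loop above relies on the send protocol of mobile_cv's recursive_iterate: values sent back into the iterator replace the originals, and iters.value rebuilds the nested structure. A small standalone illustration (the input structure below is made up):

# Standalone illustration of the zeroing idiom; the input structure is made up.
import torch
from mobile_cv.common.misc.iter_utils import recursive_iterate

inputs = ({"image": torch.rand(3, 4, 4)}, [torch.rand(2)])
iters = recursive_iterate(inputs)
for x in iters:
    if isinstance(x, torch.Tensor):
        iters.send(torch.zeros_like(x))  # replace each tensor with zeros
inputs = iters.value  # same nesting, all-zero tensors; shrinks the bundled input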
def tracing_adapter_wrap_export(old_f):
def new_f(cls, model, input_args, *args, **kwargs):
adapter = TracingAdapter(model, input_args)
load_kwargs = old_f(cls, adapter, adapter.flattened_inputs, *args, **kwargs)
inputs_schema = dump_dataclass(adapter.inputs_schema)
outputs_schema = dump_dataclass(adapter.outputs_schema)
assert "inputs_schema" not in load_kwargs
assert "outputs_schema" not in load_kwargs
load_kwargs.update(
{"inputs_schema": inputs_schema, "outputs_schema": outputs_schema}
)
return load_kwargs
return new_f
def tracing_adapter_wrap_load(old_f):
def new_f(cls, save_path, **load_kwargs):
assert "inputs_schema" in load_kwargs
assert "outputs_schema" in load_kwargs
inputs_schema = instantiate(load_kwargs.pop("inputs_schema"))
outputs_schema = instantiate(load_kwargs.pop("outputs_schema"))
traced_model = old_f(cls, save_path, **load_kwargs)
class TracingAdapterModelWrapper(nn.Module):
def __init__(self, traced_model, inputs_schema, outputs_schema):
super().__init__()
self.traced_model = traced_model
self.inputs_schema = inputs_schema
self.outputs_schema = outputs_schema
def forward(self, *input_args):
flattened_inputs, _ = flatten_to_tuple(input_args)
flattened_outputs = self.traced_model(*flattened_inputs)
return self.outputs_schema(flattened_outputs)
return TracingAdapterModelWrapper(traced_model, inputs_schema, outputs_schema)
return new_f
@ModelExportMethodRegistry.register("torchscript")
@ModelExportMethodRegistry.register("torchscript_int8")
@ModelExportMethodRegistry.register("torchscript_mobile")
@ModelExportMethodRegistry.register("torchscript_mobile_int8")
class DefaultTorchscriptExport(ModelExportMethod):
@classmethod
def export(cls, model, input_args, save_path, export_method, **export_kwargs):
if export_method is not None:
# update export_kwargs based on export_method
assert isinstance(export_method, str)
if "_mobile" in export_method:
if "mobile_optimization" in export_kwargs:
logger.warning(
"`mobile_optimization` is already specified, keep using it"
)
else:
export_kwargs["mobile_optimization"] = MobileOptimizationConfig()
trace_and_save_torchscript(model, input_args, save_path, **export_kwargs)
return {}
@classmethod
def load(cls, save_path, **load_kwargs):
return load_model(save_path, "torchscript")
@ModelExportMethodRegistry.register("torchscript@tracing")
@ModelExportMethodRegistry.register("torchscript_int8@tracing")
@ModelExportMethodRegistry.register("torchscript_mobile@tracing")
@ModelExportMethodRegistry.register("torchscript_mobile_int8@tracing")
class D2TorchscriptTracingExport(DefaultTorchscriptExport):
@classmethod
@tracing_adapter_wrap_export
def export(cls, model, input_args, save_path, export_method, **export_kwargs):
with patch_builtin_len():
return super().export(
model, input_args, save_path, export_method, **export_kwargs
)
@classmethod
@tracing_adapter_wrap_load
def load(cls, save_path, **load_kwargs):
return super().load(save_path, **load_kwargs)
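Putting the two decorators together: the schemas that tracing_adapter_wrap_export adds to load_kwargs are exactly what tracing_adapter_wrap_load pops to rebuild structured outputs. A hypothetical round-trip (a toy model stands in for a real d2go model; the path is a placeholder):

# Hypothetical round-trip; a toy model stands in for a real d2go model.
import torch
model = torch.nn.Linear(4, 2).eval()
inputs = torch.rand(1, 4)
load_kwargs = D2TorchscriptTracingExport.export(
    model, (inputs,), "/tmp/ts_trace", "torchscript@tracing"
)  # returns {"inputs_schema": ..., "outputs_schema": ...}
wrapped = D2TorchscriptTracingExport.load("/tmp/ts_trace", **load_kwargs)
outputs = wrapped(inputs)  # same output structure as model(inputs)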