Commit 12468504 authored by Yanghan Wang, committed by Facebook GitHub Bot

use new interface to export RCNN using D2's tracing

Summary:
This diff cleans up the process of exporting RCNN to a predictor by tracing.
- Implement a new `D2TorchscriptTracingExport` which utilizes D2's `TracingAdapter` (see the sketch after this list). It can handle more complicated input/output data structures, for example the `MultiDictInMultiDictOut` case in the unit tests. Some duplicated serialization code can also be removed.
- Later on we'll move `DefaultTorchscriptExport` to `mobile_cv.predictor`, which doesn't have a D2 dependency, while keeping `D2TorchscriptTracingExport` in D2Go as the more advanced version.
- Using `D2TorchscriptTracingExport` we can simplify `prepare_for_export` quite a bit and remove hacky code.
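For context, a minimal sketch of how D2's `TracingAdapter` flattens structured inputs/outputs. The toy model, shapes, and variable names are made up for illustration; `TracingAdapter` with its `flattened_inputs` and `outputs_schema` attributes is the detectron2 API used by the diff below:

```python
import torch
from torch import nn
from detectron2.export.flatten import TracingAdapter

# Hypothetical toy model taking a dict-in-dict and returning a dict-in-dict,
# mimicking the MultiDictInMultiDictOut case mentioned above.
class ToyDictModel(nn.Module):
    def forward(self, inputs):
        x = inputs["a"]["image"]
        return {"out": {"sum": x + 1, "neg": -x}}

model = ToyDictModel().eval()
inputs = ({"a": {"image": torch.rand(3)}},)

# TracingAdapter records input/output schemas and flattens the structured
# data into tuples of tensors, which torch.jit.trace can handle.
adapter = TracingAdapter(model, inputs)
traced = torch.jit.trace(adapter, adapter.flattened_inputs)

# The outputs_schema recorded during tracing rebuilds the original structure
# from the flattened tensors returned by the traced module.
flat_out = traced(*adapter.flattened_inputs)
structured_out = adapter.outputs_schema(flat_out)
```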

Reviewed By: zhanghang1989

Differential Revision: D27931029

fbshipit-source-id: 4a8d5e5ee3f10e29d98fca63e0e1c68bbda22745
parent 83d9e8f7
@@ -39,6 +39,8 @@ from d2go.export.torchscript import (
     MobileOptimizationConfig,
 )
 from d2go.modeling.quantization import post_training_quantize
+from detectron2.config.instantiate import dump_dataclass, instantiate
+from detectron2.export.flatten import TracingAdapter, flatten_to_tuple
 from detectron2.utils.file_io import PathManager
 from mobile_cv.arch.utils import fuse_utils
 from mobile_cv.common.misc.file_utils import make_temp_directory
@@ -346,10 +348,8 @@ class DefaultCaffe2Export(ModelExportMethod):
 @ModelExportMethodRegistry.register("torchscript")
-@ModelExportMethodRegistry.register("torchscript@tracing")
 @ModelExportMethodRegistry.register("torchscript@scripting")
 @ModelExportMethodRegistry.register("torchscript_int8")
-@ModelExportMethodRegistry.register("torchscript_int8@tracing")
 @ModelExportMethodRegistry.register("torchscript_int8@scripting")
 class DefaultTorchscriptExport(ModelExportMethod):
     @classmethod
@@ -362,6 +362,40 @@ class DefaultTorchscriptExport(ModelExportMethod):
         return load_model(save_path, "torchscript")


+@ModelExportMethodRegistry.register("torchscript@tracing")
+@ModelExportMethodRegistry.register("torchscript_int8@tracing")
+class D2TorchscriptTracingExport(ModelExportMethod):
+    @classmethod
+    def export(cls, model, input_args, save_path, **export_kwargs):
+        adapter = TracingAdapter(model, input_args)
+        trace_and_save_torchscript(
+            adapter, adapter.flattened_inputs, save_path, **export_kwargs
+        )
+        inputs_schema = dump_dataclass(adapter.inputs_schema)
+        outputs_schema = dump_dataclass(adapter.outputs_schema)
+        return {"inputs_schema": inputs_schema, "outputs_schema": outputs_schema}
+
+    @classmethod
+    def load(cls, save_path, inputs_schema, outputs_schema, **load_kwargs):
+        inputs_schema = instantiate(inputs_schema)
+        outputs_schema = instantiate(outputs_schema)
+        traced_model = load_model(save_path, "torchscript")
+
+        class TracingAdapterWrapper(nn.Module):
+            def __init__(self, traced_model, inputs_schema, outputs_schema):
+                super().__init__()
+                self.traced_model = traced_model
+                self.inputs_schema = inputs_schema
+                self.outputs_schema = outputs_schema
+
+            def forward(self, *input_args):
+                flattened_inputs, _ = flatten_to_tuple(input_args)
+                flattened_outputs = self.traced_model(*flattened_inputs)
+                return self.outputs_schema(flattened_outputs)
+
+        return TracingAdapterWrapper(traced_model, inputs_schema, outputs_schema)
+
+
 @ModelExportMethodRegistry.register("torchscript_mobile")
 @ModelExportMethodRegistry.register("torchscript_mobile_int8")
 class DefaultTorchscriptMobileExport(ModelExportMethod):
...
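For illustration, a hypothetical round trip through the new export method; `model`, `input_args`, and the save path are placeholders, while the two classmethods are exactly the ones added above:

```python
# Export: trace the model through TracingAdapter and persist the TorchScript
# module; the returned dict carries the dumped input/output schemas.
schemas = D2TorchscriptTracingExport.export(
    model, input_args, save_path="/tmp/rcnn_ts"  # hypothetical path
)

# Load: rebuild a module that accepts and returns the original structured data.
wrapped = D2TorchscriptTracingExport.load("/tmp/rcnn_ts", **schemas)
outputs = wrapped(*input_args)  # same structure as the eager model's outputs
```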
@@ -2,10 +2,7 @@
 # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

-import sys
-import json
-import importlib
-import dataclasses
+import torch.nn as nn

 from caffe2.proto import caffe2_pb2
 from detectron2.export.caffe2_modeling import (
     META_ARCH_CAFFE2_EXPORT_TYPE_MAP,
@@ -88,49 +85,38 @@ class D2Caffe2MetaArchPostprocessFunc(object):
     }


-def dataclass_object_dump(ob):
-    datacls = type(ob)
-    if not dataclasses.is_dataclass(datacls):
-        raise TypeError(f"Expected dataclass instance, got '{datacls!r}' object")
-    mod = sys.modules.get(datacls.__module__)
-    if mod is None or not hasattr(mod, datacls.__qualname__):
-        raise ValueError(f"Can't resolve '{datacls!r}' reference")
-    ref = f"{datacls.__module__}.{datacls.__qualname__}"
-    fields = (f.name for f in dataclasses.fields(ob))
-    return {**{f: getattr(ob, f) for f in fields}, "__dataclass__": ref}
-
-
-def dataclass_object_load(d):
-    ref = d.pop("__dataclass__", None)
-    if ref is None:
-        return d
-    try:
-        modname, hasdot, qualname = ref.rpartition(".")
-        module = importlib.import_module(modname)
-        datacls = getattr(module, qualname)
-        if not dataclasses.is_dataclass(datacls) or not isinstance(datacls, type):
-            raise ValueError
-        return datacls(**d)
-    except (ModuleNotFoundError, ValueError, AttributeError, TypeError):
-        raise ValueError(f"Invalid dataclass reference {ref!r}") from None
-
-
-class D2TracingAdapterPreprocessFunc(object):
-    def __call__(self, inputs):
-        assert len(inputs) == 1, "only support single batch"
-        return inputs[0]["image"]
-
-
-class D2TracingAdapterPostFunc(object):
-    def __init__(self, outputs_schema_json):
-        self.outputs_schema = json.loads(
-            outputs_schema_json, object_hook=dataclass_object_load
-        )
-
-    def __call__(self, inputs, tensor_inputs, tensor_outputs):
-        results_per_image = self.outputs_schema(tensor_outputs)
-        assert len(inputs) == 1, "only support single batch"
-        width, height = inputs[0]["width"], inputs[0]["height"]
-        r = detector_postprocess(results_per_image, height, width)
-        return [{"instances": r}]
+class D2RCNNTracingWrapper(nn.Module):
+    def __init__(self, model):
+        super().__init__()
+        self.model = model
+
+    def forward(self, image):
+        """
+        This function describes what happens during the tracing. Note that the
+        output contains non-tensor data, therefore the D2TorchscriptTracingExport
+        must be used in order to convert the output back from flattened tensors.
+        """
+        inputs = [{"image": image}]
+        return self.model.inference(inputs, do_postprocess=False)[0]
+
+    @staticmethod
+    def generator_trace_inputs(batch):
+        """
+        This function describes how to convert the original input (from the data
+        loader) to the inputs used during the tracing (i.e. the inputs of the
+        forward function).
+        """
+        return (batch[0]["image"],)
+
+    class RunFunc(object):
+        def __call__(self, tracing_adapter_wrapper, batch):
+            """
+            This function describes how to run the predictor using the exported
+            model. Note that `tracing_adapter_wrapper` runs the traced model under
+            the hood and behaves exactly the same as the forward function.
+            """
+            assert len(batch) == 1, "only support single batch"
+            width, height = batch[0]["width"], batch[0]["height"]
+            inputs = D2RCNNTracingWrapper.generator_trace_inputs(batch)
+            results_per_image = tracing_adapter_wrapper(inputs)
+            r = detector_postprocess(results_per_image, height, width)
+            return [{"instances": r}]
@@ -2,18 +2,16 @@
 # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

-import json
 import logging

-import torch
 from d2go.export.api import PredictorExportConfig
+from d2go.utils.export_utils import (
+    D2Caffe2MetaArchPreprocessFunc,
+    D2Caffe2MetaArchPostprocessFunc,
+    D2RCNNTracingWrapper,
+)
 from detectron2.export.caffe2_modeling import META_ARCH_CAFFE2_EXPORT_TYPE_MAP
 from mobile_cv.predictor.api import FuncInfo
-from detectron2.export.flatten import TracingAdapter
-from detectron2.export.torchscript_patch import patch_builtin_len
-from d2go.utils.export_utils import (D2Caffe2MetaArchPreprocessFunc,
-    D2Caffe2MetaArchPostprocessFunc, D2TracingAdapterPreprocessFunc, D2TracingAdapterPostFunc,
-    dataclass_object_dump)

 logger = logging.getLogger(__name__)
@@ -21,37 +19,11 @@ logger = logging.getLogger(__name__)

 def d2_meta_arch_prepare_for_export(self, cfg, inputs, predictor_type):

     if "torchscript" in predictor_type and "@tracing" in predictor_type:
-
-        def inference_func(model, image):
-            inputs = [{"image": image}]
-            return model.inference(inputs, do_postprocess=False)[0]
-
-        def data_generator(x):
-            return (x[0]["image"],)
-
-        image = data_generator(inputs)[0]
-        wrapper = TracingAdapter(self, image, inference_func)
-        wrapper.eval()
-
-        # HACK: outputs_schema can only be obtained after running tracing, but
-        # PredictorExportConfig requires a pre-defined postprocessing function, this
-        # causes tracing to run twice.
-        logger.info("tracing the model to get outputs_schema ...")
-        with torch.no_grad(), patch_builtin_len():
-            _ = torch.jit.trace(wrapper, (image,))
-        outputs_schema_json = json.dumps(
-            wrapper.outputs_schema, default=dataclass_object_dump
-        )
-
         return PredictorExportConfig(
-            model=wrapper,
-            data_generator=data_generator,
-            preprocess_info=FuncInfo.gen_func_info(
-                D2TracingAdapterPreprocessFunc, params={}
-            ),
-            postprocess_info=FuncInfo.gen_func_info(
-                D2TracingAdapterPostFunc,
-                params={"outputs_schema_json": outputs_schema_json},
-            ),
+            model=D2RCNNTracingWrapper(self),
+            data_generator=D2RCNNTracingWrapper.generator_trace_inputs,
+            run_func_info=FuncInfo.gen_func_info(
+                D2RCNNTracingWrapper.RunFunc, params={}
+            ),
         )

@@ -80,4 +52,3 @@ def d2_meta_arch_prepare_for_export(self, cfg, inputs, predictor_type):
         )
-
     raise NotImplementedError("Can't determine prepare_for_tracing!")
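At inference time the exported predictor then runs roughly like this (a sketch; `wrapped` stands for the `TracingAdapterWrapper` returned by `D2TorchscriptTracingExport.load`, and `batch` for a single-image list from the data loader):

```python
# RunFunc converts the batch to trace inputs, invokes the traced model via the
# wrapper, and rescales the results with detector_postprocess.
run = D2RCNNTracingWrapper.RunFunc()
results = run(wrapped, batch)  # -> [{"instances": ...}]
```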