Commit a9dce74e authored by Yanghan Wang, committed by Facebook GitHub Bot

support scripting for torchscript ExportMethod

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/118

This diff adds proper support for using scripting when exporting a model.

Rename tracing-related code:
- Previously `trace_and_save_torchscript` was the primary function for exporting a model; replace it with `export_optimize_and_save_torchscript` (see the usage sketch below).
- Also rename `D2TorchscriptTracingExport` to `TracingAdaptedTorchscriptExport`, since it is no longer only for tracing.
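
For illustration, a minimal usage sketch of the renamed entry point; the module path `d2go.export.torchscript`, the toy model, and the output directory are assumptions, not part of this diff:

```python
# Hypothetical usage of the renamed export entry point; module path and
# argument values are assumptions for illustration only.
import torch
from d2go.export.torchscript import export_optimize_and_save_torchscript

model = torch.nn.Linear(4, 2)
inputs = (torch.rand(1, 4),)  # must be callable as model(*inputs) when tracing

# jit_mode defaults to "trace"; pass "script" to use torch.jit.script instead.
filename = export_optimize_and_save_torchscript(
    model,
    inputs,
    "/tmp/exported_model",
    jit_mode="trace",
)
print(filename)  # the saved TorchScript filename, e.g. "model.jit"
```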

Introduce `jit_mode`:
- Add a `jit_mode` option to the `export_kwargs` of ExportMethod.
- Add `@scripting` and `@tracing` trigger words to override `jit_mode`. Note that `@tracing` now applies to all models, which differs from the previous meaning (using `TracingAdapter` for RCNN).
- Therefore there are two ways of using scripting mode: 1) setting `jit_mode` in prepare_for_export; 2) using the `@scripting` trigger word. Unit tests are added as examples to illustrate both ways (see the sketch after this list).
- Don't use `TracingAdapter` when scripting since it's not scriptable.
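
A hedged sketch of both ways; `PredictorExportConfig` and `model_export_kwargs` are referenced by this diff, while the function shape and the data generator are schematic assumptions:

```python
# Way 1: request scripting explicitly through model_export_kwargs.
# PredictorExportConfig / model_export_kwargs come from this diff; the
# function signature and data generator below are schematic placeholders.
from d2go.export.api import PredictorExportConfig

def prepare_for_export(cfg, model, inputs):
    return PredictorExportConfig(
        model=model,
        data_generator=lambda x: (x,),  # assumed placeholder
        model_export_kwargs={"jit_mode": "script"},
    )

# Way 2: leave prepare_for_export untouched and append the trigger word to
# the predictor type, e.g. --predictor-type torchscript@scripting; the string
# is forwarded as export_method and the decorator sets
# export_kwargs["jit_mode"] = "script".
```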

Consolidate trigger-word logic:
- Group the logic for handling trigger words (e.g. `_mobile`, `_int8`, `@scripting`, `@tracing`) into a single decorator, `update_export_kwargs_from_export_method`, for better structure and readability. A simplified sketch of the pattern follows.
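
A minimal, runnable sketch of the decorator pattern (simplified for illustration; the real implementation is `update_export_kwargs_from_export_method` in the diff below, and the helper names here are hypothetical):

```python
# Simplified sketch of the trigger-word decorator pattern: markers found in
# the export_method string are folded into export_kwargs before the wrapped
# export function runs. Not the exact d2go implementation.
import functools

def update_kwargs_from_trigger_words(old_f):
    @functools.wraps(old_f)
    def new_f(export_method, **export_kwargs):
        if export_method and "@scripting" in export_method:
            export_kwargs["jit_mode"] = "script"
            export_method = export_method.replace("@scripting", "", 1)
        if export_method and "_mobile" in export_method:
            # the real code inserts a MobileOptimizationConfig here
            export_kwargs.setdefault("mobile_optimization", "default")
            export_method = export_method.replace("_mobile", "", 1)
        return old_f(export_method, **export_kwargs)
    return new_f

@update_kwargs_from_trigger_words
def export(export_method, **export_kwargs):
    return export_kwargs

assert export("torchscript_mobile@scripting") == {
    "jit_mode": "script",
    "mobile_optimization": "default",
}
```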

Reviewed By: zhanghang1989

Differential Revision: D31181624

fbshipit-source-id: 5fbb0d4fa4c29ffa4a761af8ea8f93b4bad4cef9
parent 8adb146e
@@ -5,7 +5,7 @@
 import contextlib
 import logging
 import os
-from typing import Tuple, Optional, Dict, NamedTuple, List, AnyStr, Set
+from typing import Any, Tuple, Optional, Dict, NamedTuple, List, AnyStr, Set

 import torch
 from d2go.export.api import ModelExportMethodRegistry, ModelExportMethod
@@ -24,6 +24,7 @@ from torch.utils.mobile_optimizer import optimize_for_mobile
 logger = logging.getLogger(__name__)

 TORCHSCRIPT_FILENAME_KEY: str = "torchscript_filename"
+DEFAULT_JIT_MODE = "trace"


 class MobileOptimizationConfig(NamedTuple):
@@ -34,23 +35,54 @@ class MobileOptimizationConfig(NamedTuple):
     torchscript_filename: str = "mobile_optimized.ptl"


-def trace_and_save_torchscript(
+def export_optimize_and_save_torchscript(
     model: nn.Module,
-    inputs: Tuple[torch.Tensor],
+    inputs: Optional[Tuple[Any]],
     output_path: str,
+    *,
+    jit_mode: Optional[str] = DEFAULT_JIT_MODE,
     torchscript_filename: str = "model.jit",
     mobile_optimization: Optional[MobileOptimizationConfig] = None,
     _extra_files: Optional[Dict[str, bytes]] = None,
-):
-    logger.info("Tracing and saving TorchScript to {} ...".format(output_path))
+) -> str:
+    """
+    The primary function for exporting a PyTorch model to TorchScript.
+
+    Args:
+        model (nn.Module): the model to export. When given a ScriptModule, the jit
+            step is skipped and the model is only optimized and saved.
+        inputs (tuple or None): input arguments of the model, such that it can be
+            called as model(*inputs). Not used when scripting the model.
+        output_path (str): directory where the model will be saved.
+        jit_mode (str): "trace" or "script", or None if the model is already a
+            ScriptModule.
+        torchscript_filename (str): the filename of the non-mobile-optimized model.
+        mobile_optimization (MobileOptimizationConfig): when provided, mobile
+            optimization will be applied.
+        _extra_files (Dict[str, bytes]): when provided, extra files will be saved.
+
+    Returns:
+        (str): filename of the final model, whether optimized or not.
+    """
+    logger.info("Export, optimize and save TorchScript to {} ...".format(output_path))
     PathManager.mkdirs(output_path)
     if _extra_files is None:
         _extra_files = {}

+    if isinstance(model, torch.jit.ScriptModule):
+        if jit_mode is not None:
+            logger.info("The input model is already a ScriptModule, skip the jit step")
+    elif jit_mode == "trace":
+        logger.info("Tracing the model ...")
+        with torch.no_grad():
+            script_model = torch.jit.trace(model, inputs)
+    elif jit_mode == "script":
+        logger.info("Scripting the model ...")
+        script_model = torch.jit.script(model)
+    else:
+        raise ValueError("Unsupported jit_mode: {}".format(jit_mode))

-    with make_temp_directory("trace_and_save_torchscript") as tmp_dir:
+    with make_temp_directory("export_optimize_and_save_torchscript") as tmp_dir:

         @contextlib.contextmanager
         def _synced_local_file(rel_path):
@@ -100,6 +132,26 @@ def trace_and_save_torchscript(
     return torchscript_filename


+# For backward compatibility. TODO: remove this function.
+def trace_and_save_torchscript(
+    model: nn.Module,
+    inputs: Optional[Tuple[Any]],
+    output_path: str,
+    torchscript_filename: str = "model.jit",
+    mobile_optimization: Optional[MobileOptimizationConfig] = None,
+    _extra_files: Optional[Dict[str, bytes]] = None,
+):
+    return export_optimize_and_save_torchscript(
+        model,
+        inputs,
+        output_path,
+        jit_mode="trace",
+        torchscript_filename=torchscript_filename,
+        mobile_optimization=mobile_optimization,
+        _extra_files=_extra_files,
+    )


 class TorchscriptWrapper(nn.Module):
     """ """
@@ -140,14 +192,14 @@ def tracing_adapter_wrap_export(old_f):
         force_disable_tracing_adapter = export_kwargs.pop(
             "force_disable_tracing_adapter", False
         )
-        if force_disable_tracing_adapter:
+        is_trace_mode = export_kwargs.get("jit_mode", "trace") == "trace"
+        if force_disable_tracing_adapter or not is_trace_mode:
             logger.info("Not trace mode, export normally")
             return old_f(
                 cls, model, input_args, save_path, export_method, **export_kwargs
             )

         if _is_data_flattened_tensors(input_args):
+            # TODO: only dry-run for tracing
             logger.info("Dry run the model to check if TracingAdapter is needed ...")
             outputs = model(*input_args)
             if _is_data_flattened_tensors(outputs):
@@ -178,7 +230,7 @@ def tracing_adapter_wrap_export(old_f):
             adapter.flattened_inputs,
             save_path,
             export_method,
-            **export_kwargs
+            **export_kwargs,
         )
         inputs_schema = dump_dataclass(adapter.inputs_schema)
         outputs_schema = dump_dataclass(adapter.outputs_schema)
@@ -231,19 +283,20 @@ def tracing_adapter_wrap_load(old_f):
     return new_f


-class DefaultTorchscriptExport(ModelExportMethod):
-    @classmethod
-    def export(
-        cls,
-        model: nn.Module,
-        input_args: Tuple[Tuple[torch.Tensor]],
-        save_path: str,
-        export_method: Optional[str],
-        **export_kwargs
-    ):
+def update_export_kwargs_from_export_method(old_f):
+    """
+    Provide a convenient way of updating export_kwargs by adding trigger words to
+    `export_method`. For example, instead of setting `mobile_optimization` in the
+    model_export_kwargs of the PredictorExportConfig, the user can simply put the
+    `_mobile` trigger word in the --predictor-type (which will then be forwarded as
+    `export_method` in most cases) to enable mobile optimization.
+    """
+
+    def new_f(cls, model, input_args, save_path, export_method, **export_kwargs):
         if export_method is not None:
             # update export_kwargs based on export_method
             assert isinstance(export_method, str)
             original_export_method = export_method

             if "_mobile" in export_method:
                 if "mobile_optimization" in export_kwargs:
                     logger.warning(
@@ -251,8 +304,56 @@ class DefaultTorchscriptExport(ModelExportMethod):
                     )
                 else:
                     export_kwargs["mobile_optimization"] = MobileOptimizationConfig()
                 export_method = export_method.replace("_mobile", "", 1)

-        torchscript_filename = trace_and_save_torchscript(
+            if "@scripting" in export_method:
+                jit_mode = export_kwargs.get("jit_mode", None)
+                if jit_mode and jit_mode != "script":
+                    logger.warning(
+                        "`jit_mode` is already specified as {}, overwrite it to `script`"
+                        " since @scripting appears in export_method".format(jit_mode)
+                    )
+                export_kwargs["jit_mode"] = "script"
+                export_method = export_method.replace("@scripting", "", 1)
+
+            if "@tracing" in export_method:
+                jit_mode = export_kwargs.get("jit_mode", None)
+                if jit_mode and jit_mode != "trace":
+                    logger.warning(
+                        "`jit_mode` is already specified as {}, overwrite it to `trace`"
+                        " since @tracing appears in export_method".format(jit_mode)
+                    )
+                export_kwargs["jit_mode"] = "trace"
+                export_method = export_method.replace("@tracing", "", 1)
+
+            if "_int8" in export_method:
+                export_method = export_method.replace("_int8", "", 1)
+
+            if export_method != "torchscript":
+                logger.warning(
+                    "Suspicious export_method after removing trigger words,"
+                    " original export_method: {}, remaining: {}".format(
+                        original_export_method, export_method
+                    )
+                )
+
+        return old_f(cls, model, input_args, save_path, export_method, **export_kwargs)
+
+    return new_f
+
+
+class DefaultTorchscriptExport(ModelExportMethod):
+    @classmethod
+    @update_export_kwargs_from_export_method
+    def export(
+        cls,
+        model: nn.Module,
+        input_args: Tuple[Tuple[torch.Tensor]],
+        save_path: str,
+        export_method: Optional[str],
+        **export_kwargs,
+    ):
+        torchscript_filename = export_optimize_and_save_torchscript(
             model, input_args, save_path, **export_kwargs
         )
         return {TORCHSCRIPT_FILENAME_KEY: torchscript_filename}
@@ -267,8 +368,17 @@ class DefaultTorchscriptExport(ModelExportMethod):
 @ModelExportMethodRegistry.register("torchscript_int8")
 @ModelExportMethodRegistry.register("torchscript_mobile")
 @ModelExportMethodRegistry.register("torchscript_mobile_int8")
-class D2TorchscriptTracingExport(DefaultTorchscriptExport):
+@ModelExportMethodRegistry.register("torchscript@scripting")
+@ModelExportMethodRegistry.register("torchscript_int8@scripting")
+@ModelExportMethodRegistry.register("torchscript_mobile@scripting")
+@ModelExportMethodRegistry.register("torchscript_mobile_int8@scripting")
+@ModelExportMethodRegistry.register("torchscript@tracing")
+@ModelExportMethodRegistry.register("torchscript_int8@tracing")
+@ModelExportMethodRegistry.register("torchscript_mobile@tracing")
+@ModelExportMethodRegistry.register("torchscript_mobile_int8@tracing")
+class TracingAdaptedTorchscriptExport(DefaultTorchscriptExport):
     @classmethod
+    @update_export_kwargs_from_export_method
     @tracing_adapter_wrap_export
     def export(cls, model, input_args, save_path, export_method, **export_kwargs):
         with patch_builtin_len():
@@ -369,7 +369,7 @@ class D2RCNNInferenceWrapper(nn.Module):
     def forward(self, image):
         """
         This function describes what happens during the tracing. Note that the output
-        contains non-tensors, therefore the D2TorchscriptTracingExport must be used in
+        contains non-tensors, therefore the TracingAdaptedTorchscriptExport must be used in
         order to convert the output back from flattened tensors.
         """
         inputs = [{"image": image}]
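
Taken together, the registered names compose trigger words with the base `torchscript` method. A hedged lookup sketch (the registry keys are those added above; the `.get()` accessor is an assumption):

```python
# Hedged sketch: resolving a registered export method by name. The registry
# keys are those added in this diff; the .get() accessor is an assumption.
from d2go.export.api import ModelExportMethodRegistry

export_cls = ModelExportMethodRegistry.get("torchscript_mobile@scripting")
# export_cls.export() is wrapped by update_export_kwargs_from_export_method,
# so the trigger words in the name enable mobile optimization and set
# jit_mode="script" without touching prepare_for_export.
```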