Commit a9dce74e authored by Yanghan Wang's avatar Yanghan Wang Committed by Facebook GitHub Bot
Browse files

support scripting for torchscript ExportMethod

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/118

This diff adds the proper support for using scripting when exporting model.

Rename tracing-related code:
- Previously `trace_and_save_torchscript` is the primary function to export model, replace it with `export_optimize_and_save_torchscript`.
- Also rename `D2 (https://github.com/facebookresearch/d2go/commit/7992f91324aee6ae59795063a007c6837e60cdb8)TorchscriptTracingExport` to `TracingAdaptedTorchscriptExport` since it's not only for tracing now.

Introduce `jit_mode`:
- Add `jit_mode` option as the `export_kwargs` of ExportMethod.
- Add `scripting` and `tracing` trigger words to overwrite `jit_mode`. Please note that the `tracing` now applies to all models, which is different from the previous meaning (using `TracingAdapter` for RCNN).
- Therefore there're two ways of using scripting mode, 1) setting `jit_mode` in prepare_for_export; 2) using `script` trigger words. Add unit test as examples to illustrate two ways.
- Don't use `TracingAdapter` when scripting since it's not scriptable.

Consolidate triggering words logic.
- Group logic of handling trigger words (eg. `_mobile`, `_int8`, `scripting`, `tracing`) into a single decorator `update_export_kwargs_from_export_method` for better structuring and readability.

Reviewed By: zhanghang1989

Differential Revision: D31181624

fbshipit-source-id: 5fbb0d4fa4c29ffa4a761af8ea8f93b4bad4cef9
parent 8adb146e
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
import contextlib import contextlib
import logging import logging
import os import os
from typing import Tuple, Optional, Dict, NamedTuple, List, AnyStr, Set from typing import Any, Tuple, Optional, Dict, NamedTuple, List, AnyStr, Set
import torch import torch
from d2go.export.api import ModelExportMethodRegistry, ModelExportMethod from d2go.export.api import ModelExportMethodRegistry, ModelExportMethod
...@@ -24,6 +24,7 @@ from torch.utils.mobile_optimizer import optimize_for_mobile ...@@ -24,6 +24,7 @@ from torch.utils.mobile_optimizer import optimize_for_mobile
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
TORCHSCRIPT_FILENAME_KEY: str = "torchscript_filename" TORCHSCRIPT_FILENAME_KEY: str = "torchscript_filename"
DEFAULT_JIT_MODE = "trace"
class MobileOptimizationConfig(NamedTuple): class MobileOptimizationConfig(NamedTuple):
...@@ -34,23 +35,54 @@ class MobileOptimizationConfig(NamedTuple): ...@@ -34,23 +35,54 @@ class MobileOptimizationConfig(NamedTuple):
torchscript_filename: str = "mobile_optimized.ptl" torchscript_filename: str = "mobile_optimized.ptl"
def trace_and_save_torchscript( def export_optimize_and_save_torchscript(
model: nn.Module, model: nn.Module,
inputs: Tuple[torch.Tensor], inputs: Optional[Tuple[Any]],
output_path: str, output_path: str,
*,
jit_mode: Optional[str] = DEFAULT_JIT_MODE,
torchscript_filename: str = "model.jit", torchscript_filename: str = "model.jit",
mobile_optimization: Optional[MobileOptimizationConfig] = None, mobile_optimization: Optional[MobileOptimizationConfig] = None,
_extra_files: Optional[Dict[str, bytes]] = None, _extra_files: Optional[Dict[str, bytes]] = None,
): ) -> str:
logger.info("Tracing and saving TorchScript to {} ...".format(output_path)) """
The primary function for exporting PyTorch model to TorchScript.
Args:
model (nn.Module): the model to export. When given a ScriptModule, skip the export
and only optimize and save model.
inputs (tuple or None): input arguments of model, can be called as model(*inputs).
Will not be used when scripting the model.
output_path (str): directory that the model will be saved.
jit_mode (str): trace/script or None if the model is already a ScriptModule.
torchscript_filename (str): the filename of non-mobile-optimized model.
mobile_optimization (MobileOptimizationConfig): when provided, the mobile optimization
will be applied.
_extra_files (Dict[str, bytes]): when provided, extra files will be saved.
Returns:
(str): filename of the final model no matter optmized or not.
"""
logger.info("Export, optimize and saving TorchScript to {} ...".format(output_path))
PathManager.mkdirs(output_path) PathManager.mkdirs(output_path)
if _extra_files is None: if _extra_files is None:
_extra_files = {} _extra_files = {}
if isinstance(model, torch.jit.ScriptModule):
if jit_mode is not None:
logger.info("The input model is already a ScriptModule, skip the jit step")
elif jit_mode == "trace":
logger.info("Tracing the model ...")
with torch.no_grad(): with torch.no_grad():
script_model = torch.jit.trace(model, inputs) script_model = torch.jit.trace(model, inputs)
elif jit_mode == "script":
logger.info("Scripting the model ...")
script_model = torch.jit.script(model)
else:
raise ValueError("Unsupported jit_mode: {}".format(jit_mode))
with make_temp_directory("trace_and_save_torchscript") as tmp_dir: with make_temp_directory("export_optimize_and_save_torchscript") as tmp_dir:
@contextlib.contextmanager @contextlib.contextmanager
def _synced_local_file(rel_path): def _synced_local_file(rel_path):
...@@ -100,6 +132,26 @@ def trace_and_save_torchscript( ...@@ -100,6 +132,26 @@ def trace_and_save_torchscript(
return torchscript_filename return torchscript_filename
# For backward compatibility, TODO: remove this function.
def trace_and_save_torchscript(
model: nn.Module,
inputs: Optional[Tuple[Any]],
output_path: str,
torchscript_filename: str = "model.jit",
mobile_optimization: Optional[MobileOptimizationConfig] = None,
_extra_files: Optional[Dict[str, bytes]] = None,
):
return export_optimize_and_save_torchscript(
model,
inputs,
output_path,
jit_mode="trace",
torchscript_filename=torchscript_filename,
mobile_optimization=mobile_optimization,
_extra_files=_extra_files,
)
class TorchscriptWrapper(nn.Module): class TorchscriptWrapper(nn.Module):
""" """ """ """
...@@ -140,14 +192,14 @@ def tracing_adapter_wrap_export(old_f): ...@@ -140,14 +192,14 @@ def tracing_adapter_wrap_export(old_f):
force_disable_tracing_adapter = export_kwargs.pop( force_disable_tracing_adapter = export_kwargs.pop(
"force_disable_tracing_adapter", False "force_disable_tracing_adapter", False
) )
if force_disable_tracing_adapter: is_trace_mode = export_kwargs.get("jit_mode", "trace") == "trace"
if force_disable_tracing_adapter or not is_trace_mode:
logger.info("Not trace mode, export normally") logger.info("Not trace mode, export normally")
return old_f( return old_f(
cls, model, input_args, save_path, export_method, **export_kwargs cls, model, input_args, save_path, export_method, **export_kwargs
) )
if _is_data_flattened_tensors(input_args): if _is_data_flattened_tensors(input_args):
# TODO: only dry-run for traceing
logger.info("Dry run the model to check if TracingAdapter is needed ...") logger.info("Dry run the model to check if TracingAdapter is needed ...")
outputs = model(*input_args) outputs = model(*input_args)
if _is_data_flattened_tensors(outputs): if _is_data_flattened_tensors(outputs):
...@@ -178,7 +230,7 @@ def tracing_adapter_wrap_export(old_f): ...@@ -178,7 +230,7 @@ def tracing_adapter_wrap_export(old_f):
adapter.flattened_inputs, adapter.flattened_inputs,
save_path, save_path,
export_method, export_method,
**export_kwargs **export_kwargs,
) )
inputs_schema = dump_dataclass(adapter.inputs_schema) inputs_schema = dump_dataclass(adapter.inputs_schema)
outputs_schema = dump_dataclass(adapter.outputs_schema) outputs_schema = dump_dataclass(adapter.outputs_schema)
...@@ -231,19 +283,20 @@ def tracing_adapter_wrap_load(old_f): ...@@ -231,19 +283,20 @@ def tracing_adapter_wrap_load(old_f):
return new_f return new_f
class DefaultTorchscriptExport(ModelExportMethod): def update_export_kwargs_from_export_method(old_f):
@classmethod """
def export( Provide some convenient way of updating export_kwargs by adding trigger words in
cls, `export_method`. For example, instead of setting `mobile_optimization` in the
model: nn.Module, model_export_kwargs of the PredictorExportConfig, user can simply put the `_mobile`
input_args: Tuple[Tuple[torch.Tensor]], trigger word in the --predictor-type (which will then be forwarded as `export_method`
save_path: str, in most cases) to enable mobile optimizaiton.
export_method: Optional[str], """
**export_kwargs
): def new_f(cls, model, input_args, save_path, export_method, **export_kwargs):
if export_method is not None: if export_method is not None:
# update export_kwargs based on export_method
assert isinstance(export_method, str) assert isinstance(export_method, str)
original_export_method = export_method
if "_mobile" in export_method: if "_mobile" in export_method:
if "mobile_optimization" in export_kwargs: if "mobile_optimization" in export_kwargs:
logger.warning( logger.warning(
...@@ -251,8 +304,56 @@ class DefaultTorchscriptExport(ModelExportMethod): ...@@ -251,8 +304,56 @@ class DefaultTorchscriptExport(ModelExportMethod):
) )
else: else:
export_kwargs["mobile_optimization"] = MobileOptimizationConfig() export_kwargs["mobile_optimization"] = MobileOptimizationConfig()
export_method = export_method.replace("_mobile", "", 1)
torchscript_filename = trace_and_save_torchscript( if "@scripting" in export_method:
jit_mode = export_kwargs.get("jit_mode", None)
if jit_mode and jit_mode != "script":
logger.warning(
"`jit_mode` is already specified as {}, overwrite it to `script`"
" since @scripting appears in export_method".format(jit_mode)
)
export_kwargs["jit_mode"] = "script"
export_method = export_method.replace("@scripting", "", 1)
if "@tracing" in export_method:
jit_mode = export_kwargs.get("jit_mode", None)
if jit_mode and jit_mode != "trace":
logger.warning(
"`jit_mode` is already specified as {}, overwrite it to `trace`"
" since @tracing appears in export_method".format(jit_mode)
)
export_kwargs["jit_mode"] = "trace"
export_method = export_method.replace("@tracing", "", 1)
if "_int8" in export_method:
export_method = export_method.replace("_int8", "", 1)
if export_method != "torchscript":
logger.warning(
"Suspcious export_method after removing triggering words,"
" original export_method: {}, remaining: {}".format(
original_export_method, export_method
)
)
return old_f(cls, model, input_args, save_path, export_method, **export_kwargs)
return new_f
class DefaultTorchscriptExport(ModelExportMethod):
@classmethod
@update_export_kwargs_from_export_method
def export(
cls,
model: nn.Module,
input_args: Tuple[Tuple[torch.Tensor]],
save_path: str,
export_method: Optional[str],
**export_kwargs,
):
torchscript_filename = export_optimize_and_save_torchscript(
model, input_args, save_path, **export_kwargs model, input_args, save_path, **export_kwargs
) )
return {TORCHSCRIPT_FILENAME_KEY: torchscript_filename} return {TORCHSCRIPT_FILENAME_KEY: torchscript_filename}
...@@ -267,8 +368,17 @@ class DefaultTorchscriptExport(ModelExportMethod): ...@@ -267,8 +368,17 @@ class DefaultTorchscriptExport(ModelExportMethod):
@ModelExportMethodRegistry.register("torchscript_int8") @ModelExportMethodRegistry.register("torchscript_int8")
@ModelExportMethodRegistry.register("torchscript_mobile") @ModelExportMethodRegistry.register("torchscript_mobile")
@ModelExportMethodRegistry.register("torchscript_mobile_int8") @ModelExportMethodRegistry.register("torchscript_mobile_int8")
class D2TorchscriptTracingExport(DefaultTorchscriptExport): @ModelExportMethodRegistry.register("torchscript@scripting")
@ModelExportMethodRegistry.register("torchscript_int8@scripting")
@ModelExportMethodRegistry.register("torchscript_mobile@scripting")
@ModelExportMethodRegistry.register("torchscript_mobile_int8@scripting")
@ModelExportMethodRegistry.register("torchscript@tracing")
@ModelExportMethodRegistry.register("torchscript_int8@tracing")
@ModelExportMethodRegistry.register("torchscript_mobile@tracing")
@ModelExportMethodRegistry.register("torchscript_mobile_int8@tracing")
class TracingAdaptedTorchscriptExport(DefaultTorchscriptExport):
@classmethod @classmethod
@update_export_kwargs_from_export_method
@tracing_adapter_wrap_export @tracing_adapter_wrap_export
def export(cls, model, input_args, save_path, export_method, **export_kwargs): def export(cls, model, input_args, save_path, export_method, **export_kwargs):
with patch_builtin_len(): with patch_builtin_len():
......
...@@ -369,7 +369,7 @@ class D2RCNNInferenceWrapper(nn.Module): ...@@ -369,7 +369,7 @@ class D2RCNNInferenceWrapper(nn.Module):
def forward(self, image): def forward(self, image):
""" """
This function describes what happends during the tracing. Note that the output This function describes what happends during the tracing. Note that the output
contains non-tensor, therefore the D2TorchscriptTracingExport must be used in contains non-tensor, therefore the TracingAdaptedTorchscriptExport must be used in
order to convert the output back from flattened tensors. order to convert the output back from flattened tensors.
""" """
inputs = [{"image": image}] inputs = [{"image": image}]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment