Commit 2a1ed1ec authored by Yanghan Wang, committed by Facebook GitHub Bot

Allow specifying filenames for torchscript export

Reviewed By: zhanghang1989

Differential Revision: D28083131

fbshipit-source-id: 8bad642800d3923db3f42047d1b1d85625c01bd9
parent 715b4f66
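For context, here is a minimal plain-PyTorch sketch of the behavior this commit makes configurable: tracing a model and saving the TorchScript artifact under a caller-chosen filename instead of a hard-coded `model.jit`. The model, directory, and filename below are placeholders, not part of the diff.

```python
import os

import torch
from torch import nn

# Plain-PyTorch sketch (not the d2go helper itself): trace a model, then
# save the TorchScript artifact under a caller-chosen filename.
model = nn.Linear(4, 2).eval()
inputs = (torch.rand(1, 4),)
script_model = torch.jit.trace(model, inputs)

output_path = "/tmp/ts_export"         # placeholder output directory
torchscript_filename = "my_model.jit"  # the knob this commit introduces
os.makedirs(output_path, exist_ok=True)
torch.jit.save(script_model, os.path.join(output_path, torchscript_filename))
```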
@@ -15,7 +15,6 @@ from detectron2.export.torchscript_patch import patch_builtin_len
 from detectron2.utils.file_io import PathManager
 from mobile_cv.common.misc.file_utils import make_temp_directory
 from mobile_cv.common.misc.iter_utils import recursive_iterate
-from mobile_cv.predictor.model_wrappers import load_model
 from torch import nn
 from torch._C import MobileOptimizerType
 from torch.utils.bundled_inputs import augment_model_with_bundled_inputs
@@ -30,12 +29,14 @@ class MobileOptimizationConfig(NamedTuple):
     optimization_blocklist: Set[MobileOptimizerType] = None
     preserved_methods: List[AnyStr] = None
     backend: str = "CPU"
+    torchscript_filename: str = "mobile_optimized.ptl"


 def trace_and_save_torchscript(
     model: nn.Module,
     inputs: Tuple[torch.Tensor],
     output_path: str,
+    torchscript_filename: str = "model.jit",
     mobile_optimization: Optional[MobileOptimizationConfig] = None,
     _extra_files: Optional[Dict[str, bytes]] = None,
 ):
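A self-contained sketch of the extended config, using a stand-in NamedTuple so the snippet runs on its own; the real class above types `optimization_blocklist` as `Set[MobileOptimizerType]`.

```python
from typing import AnyStr, List, NamedTuple, Optional, Set

# Stand-in mirror of MobileOptimizationConfig, for illustration only.
class MobileOptimizationConfig(NamedTuple):
    optimization_blocklist: Optional[Set] = None
    preserved_methods: Optional[List[AnyStr]] = None
    backend: str = "CPU"
    torchscript_filename: str = "mobile_optimized.ptl"

# The lite-interpreter artifact name now follows the config:
cfg = MobileOptimizationConfig(torchscript_filename="detector_mobile.ptl")
print(cfg.torchscript_filename)  # -> detector_mobile.ptl
```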
@@ -56,7 +57,7 @@ def trace_and_save_torchscript(
         yield local_file
         PathManager.copy_from_local(local_file, remote_file, overwrite=True)

-    with _synced_local_file("model.jit") as model_file:
+    with _synced_local_file(torchscript_filename) as model_file:
         torch.jit.save(script_model, model_file, _extra_files=_extra_files)

     with _synced_local_file("data.pth") as data_file:
@@ -70,7 +71,8 @@ def trace_and_save_torchscript(
             preserved_methods=mobile_optimization.preserved_methods,
             backend=mobile_optimization.backend,
         )
-        with _synced_local_file("mobile_optimized.ptl") as lite_path:
+        torchscript_filename = mobile_optimization.torchscript_filename
+        with _synced_local_file(torchscript_filename) as lite_path:
             liteopt_model._save_for_lite_interpreter(lite_path)
         # liteopt_model(*inputs)  # sanity check
         op_names = torch.jit.export_opnames(liteopt_model)
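A runnable sketch of this mobile branch using public PyTorch APIs; the model and filename are placeholders standing in for the traced model and `mobile_optimization.torchscript_filename`.

```python
import torch
from torch import nn
from torch.utils.mobile_optimizer import optimize_for_mobile

script_model = torch.jit.trace(nn.Linear(4, 2).eval(), (torch.rand(1, 4),))
liteopt_model = optimize_for_mobile(script_model)

# The lite artifact name now comes from the config instead of the
# hard-coded "mobile_optimized.ptl".
liteopt_model._save_for_lite_interpreter("/tmp/detector_mobile.ptl")
print(torch.jit.export_opnames(liteopt_model))  # ops the runtime must provide
```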
@@ -87,9 +89,36 @@ def trace_and_save_torchscript(
         inputs = iters.value
         augment_model_with_bundled_inputs(liteopt_model, [inputs])
         liteopt_model.run_on_bundled_input(0)  # sanity check
-        with _synced_local_file("mobile_optimized_bundled.ptl") as lite_path:
+        name, ext = os.path.splitext(torchscript_filename)
+        with _synced_local_file(name + "_bundled" + ext) as lite_path:
             liteopt_model._save_for_lite_interpreter(lite_path)
+
+    return torchscript_filename
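The bundled-inputs branch now derives its filename from the configurable name, inserting `_bundled` before the extension. A sketch with a placeholder model and inputs:

```python
import os

import torch
from torch import nn
from torch.utils.bundled_inputs import augment_model_with_bundled_inputs

script_model = torch.jit.trace(nn.Linear(4, 2).eval(), (torch.rand(1, 4),))
augment_model_with_bundled_inputs(script_model, [(torch.rand(1, 4),)])
script_model.run_on_bundled_input(0)  # sanity check, as in the diff

name, ext = os.path.splitext("detector_mobile.ptl")
print(name + "_bundled" + ext)  # -> detector_mobile_bundled.ptl
```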
+
+
+class TorchscriptWrapper(nn.Module):
+    """ """
+
+    def __init__(self, module, int8_backend=None):
+        super().__init__()
+        self.module = module
+        self.int8_backend = int8_backend
+
+    def forward(self, *args, **kwargs):
+        # TODO: set int8 backend accordingly if needed
+        return self.module(*args, **kwargs)
+
+
+def load_torchscript(model_path):
+    extra_files = {}
+    # NOTE: may support loading extra_file specified by model_info
+    # extra_files["predictor_info.json"] = ""
+    with PathManager.open(model_path, "rb") as f:
+        ts = torch.jit.load(f, _extra_files=extra_files)
+    return TorchscriptWrapper(ts)


 def tracing_adapter_wrap_export(old_f):
     def new_f(cls, model, input_args, *args, **kwargs):
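The new `load_torchscript` relies on `torch.jit.load` accepting a file-like object plus an `_extra_files` dict that is filled in place on load. A round-trip sketch with a placeholder path and extra file:

```python
import torch
from torch import nn

script_model = torch.jit.trace(nn.Linear(4, 2).eval(), (torch.rand(1, 4),))
torch.jit.save(
    script_model, "/tmp/my_model.jit",
    _extra_files={"note.txt": b"written at export time"},
)

extra_files = {"note.txt": ""}  # keys name the files to read back
with open("/tmp/my_model.jit", "rb") as f:
    ts = torch.jit.load(f, _extra_files=extra_files)
print(extra_files["note.txt"])  # contents restored on load
```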
@@ -150,12 +179,15 @@ class DefaultTorchscriptExport(ModelExportMethod):
         else:
             export_kwargs["mobile_optimization"] = MobileOptimizationConfig()

-        trace_and_save_torchscript(model, input_args, save_path, **export_kwargs)
-        return {}
+        torchscript_filename = trace_and_save_torchscript(
+            model, input_args, save_path, **export_kwargs
+        )
+        return {"torchscript_filename": torchscript_filename}

     @classmethod
-    def load(cls, save_path, **load_kwargs):
-        return load_model(save_path, "torchscript")
+    def load(cls, save_path, *, torchscript_filename="model.jit"):
+        model_path = os.path.join(save_path, torchscript_filename)
+        return load_torchscript(model_path)


 @ModelExportMethodRegistry.register("torchscript@tracing")
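Net effect of the export-method changes: `export` now reports which file it wrote, and `load` consumes that name instead of going through `mobile_cv`'s `load_model`. A hedged sketch of the load side, where the dict is a placeholder standing in for `export()`'s return value:

```python
import os
import torch

save_path = "/tmp/ts_export"  # placeholder export directory
export_info = {"torchscript_filename": "my_model.jit"}  # as returned by export()

# Join the reported filename onto save_path and jit-load it, mirroring
# load_torchscript in the real code.
model_path = os.path.join(save_path, export_info["torchscript_filename"])
with open(model_path, "rb") as f:
    reloaded = torch.jit.load(f)
```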