Commit d86ecc92 authored by Yanghan Wang, committed by Facebook GitHub Bot

add option for exporting torchscript mobile

Reviewed By: zhanghang1989

Differential Revision: D27805428

fbshipit-source-id: c588bdb456e606ca333c2f99eb5c3668edddcbfa
parent 387020a9
@@ -32,6 +32,10 @@ from typing import Callable, Dict, NamedTuple, Optional, Union
 import torch
 import torch.nn as nn
 import torch.quantization.quantize_fx
+from d2go.export.torchscript import (
+    trace_and_save_torchscript,
+    MobileOptimizationConfig,
+)
 from d2go.modeling.quantization import post_training_quantize
 from detectron2.utils.file_io import PathManager
 from mobile_cv.arch.utils import fuse_utils
@@ -263,7 +267,20 @@ class DefaultCaffe2Export(object):
 class DefaultTorchscriptExport(object):
     @classmethod
     def export(cls, model, input_args, save_path, **export_kwargs):
-        from d2go.export.torchscript import trace_and_save_torchscript
-
         trace_and_save_torchscript(model, input_args, save_path, **export_kwargs)
         return {}
+
+
+@ModelExportMethodRegistry.register("torchscript_mobile")
+@ModelExportMethodRegistry.register("torchscript_mobile_int8")
+class DefaultTorchscriptMobileExport(object):
+    @classmethod
+    def export(cls, model, input_args, save_path, **export_kwargs):
+        trace_and_save_torchscript(
+            model,
+            input_args,
+            save_path,
+            mobile_optimization=MobileOptimizationConfig(),
+            **export_kwargs,
+        )
+        return {}
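For reference, a minimal sketch (not part of the commit) of exercising the new export class directly. The toy model, the ./export_out path, and the d2go.export.api module path are assumptions for illustration:

import torch
import torch.nn as nn

from d2go.export.api import DefaultTorchscriptMobileExport  # assumed module path

# Any traceable module works here; a tiny conv stack stands in for a real model.
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU()).eval()
input_args = (torch.randn(1, 3, 224, 224),)

# Traces the model, applies optimize_for_mobile with the default
# MobileOptimizationConfig, and writes model.jit, data.pth, mobile_optimized.ptl
# and mobile_optimized_bundled.ptl under the save path (see torchscript.py below).
DefaultTorchscriptMobileExport.export(model, input_args, "./export_out")

Both the "torchscript_mobile" and "torchscript_mobile_int8" keys map to the same class, so any int8 quantization presumably happens earlier in the export flow rather than inside this method.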
@@ -2,43 +2,84 @@
 # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
 
+import contextlib
 import logging
 import os
-from typing import Tuple, Optional, Dict
+from typing import Tuple, Optional, Dict, NamedTuple, List, AnyStr, Set
 
 import torch
 from detectron2.utils.file_io import PathManager
+from mobile_cv.common.misc.file_utils import make_temp_directory
 from torch import nn
+from torch._C import MobileOptimizerType
+from torch.utils.bundled_inputs import augment_model_with_bundled_inputs
+from torch.utils.mobile_optimizer import optimize_for_mobile
 
 logger = logging.getLogger(__name__)
 
 
+class MobileOptimizationConfig(NamedTuple):
+    # optimize_for_mobile
+    optimization_blocklist: Set[MobileOptimizerType] = None
+    preserved_methods: List[AnyStr] = None
+    backend: str = "CPU"
+    methods_to_optimize: List[AnyStr] = None
+
+
 def trace_and_save_torchscript(
     model: nn.Module,
     inputs: Tuple[torch.Tensor],
     output_path: str,
+    mobile_optimization: Optional[MobileOptimizationConfig] = None,
     _extra_files: Optional[Dict[str, bytes]] = None,
 ):
     logger.info("Tracing and saving TorchScript to {} ...".format(output_path))
+    PathManager.mkdirs(output_path)
+    if _extra_files is None:
+        _extra_files = {}
 
     # TODO: patch_builtin_len depends on D2, we should either copy the function or
     # dynamically registering the D2's version.
     from detectron2.export.torchscript_patch import patch_builtin_len
 
     with torch.no_grad(), patch_builtin_len():
         script_model = torch.jit.trace(model, inputs)
 
-    if _extra_files is None:
-        _extra_files = {}
-    model_file = os.path.join(output_path, "model.jit")
-    PathManager.mkdirs(output_path)
-    with PathManager.open(model_file, "wb") as f:
-        torch.jit.save(script_model, f, _extra_files=_extra_files)
-
-    data_file = os.path.join(output_path, "data.pth")
-    with PathManager.open(data_file, "wb") as f:
-        torch.save(inputs, f)
-
-    # NOTE: new API doesn't require return
-    return model_file
+    with make_temp_directory("trace_and_save_torchscript") as tmp_dir:
+
+        @contextlib.contextmanager
+        def _synced_local_file(rel_path):
+            remote_file = os.path.join(output_path, rel_path)
+            local_file = os.path.join(tmp_dir, rel_path)
+            yield local_file
+            PathManager.copy_from_local(local_file, remote_file, overwrite=True)
+
+        with _synced_local_file("model.jit") as model_file:
+            torch.jit.save(script_model, model_file, _extra_files=_extra_files)
+
+        with _synced_local_file("data.pth") as data_file:
+            torch.save(inputs, data_file)
+
+        if mobile_optimization is not None:
+            logger.info("Applying optimize_for_mobile ...")
+            liteopt_model = optimize_for_mobile(
+                script_model,
+                optimization_blocklist=mobile_optimization.optimization_blocklist,
+                preserved_methods=mobile_optimization.preserved_methods,
+                backend=mobile_optimization.backend,
+                methods_to_optimize=mobile_optimization.methods_to_optimize,
+            )
+            with _synced_local_file("mobile_optimized.ptl") as lite_path:
+                liteopt_model._save_for_lite_interpreter(lite_path)
+            # liteopt_model(*inputs)  # sanity check
+            op_names = torch.jit.export_opnames(liteopt_model)
+            logger.info(
+                "Operator names from lite interpreter:\n{}".format("\n".join(op_names))
+            )
+
+            logger.info("Applying augment_model_with_bundled_inputs ...")
+            augment_model_with_bundled_inputs(liteopt_model, [inputs])
+            liteopt_model.run_on_bundled_input(0)  # sanity check
+            with _synced_local_file("mobile_optimized_bundled.ptl") as lite_path:
+                liteopt_model._save_for_lite_interpreter(lite_path)
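A minimal sketch of calling the updated helper directly with a non-default MobileOptimizationConfig; the toy model and the ./export_out path are assumptions, and MobileOptimizerType.CONV_BN_FUSION is used only to illustrate the optimization_blocklist field:

import torch
from torch import nn
from torch._C import MobileOptimizerType

from d2go.export.torchscript import (
    MobileOptimizationConfig,
    trace_and_save_torchscript,
)

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU()).eval()
inputs = (torch.randn(1, 3, 32, 32),)

# Keep the default mobile passes but skip Conv/BN fusion, as an example of
# blocklisting a specific MobileOptimizerType.
config = MobileOptimizationConfig(
    optimization_blocklist={MobileOptimizerType.CONV_BN_FUSION},
)
trace_and_save_torchscript(model, inputs, "./export_out", mobile_optimization=config)

With mobile_optimization set, the helper writes model.jit, data.pth, mobile_optimized.ptl, and mobile_optimized_bundled.ptl under the output path; the bundled variant embeds inputs via augment_model_with_bundled_inputs, so the exported model can be exercised through run_on_bundled_input(0) without wiring up real data.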