Commit 90aff5da authored by Sanjeev Kumar's avatar Sanjeev Kumar Committed by Facebook GitHub Bot

Enable inference config in export step

Summary:
- Enable SDK inference config specification in the export step. This allows the SDK configuration to be added as part of the model file during export. The SDK config can be specified as inference_config.yaml and is bundled together with the TorchScript model (see the example sketch below). The main goal of the SDK configuration is to keep the settings that control inference behavior together with the model.
- SDK inference config design doc: https://docs.google.com/document/d/1j5qx8IrnFg1DJFzTnu4W8WmXFYJ-AgCDfSQHb2ACJsk/edit
- A one-click FBLearner pipeline follows in the next diff on the stack.
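
Example (illustrative only, not part of this diff; the config keys and values below are made up): the inference config is flattened into dotted keys and embedded as inference_config.json in the exported model's extra files:

    import json

    from d2go.config import CfgNode

    # Build a small inference config the same way the unit tests in this diff do;
    # the keys/values here are hypothetical.
    inference_config = CfgNode()
    inference_config.MODEL = CfgNode()
    inference_config.MODEL.SCORE_THRESHOLD = 0.8
    inference_config.ENABLE_ID_TRACKING = True

    # as_flattened_dict() (added in this diff) turns the nested config into dotted keys.
    flattened = inference_config.as_flattened_dict()
    # {"ENABLE_ID_TRACKING": True, "MODEL.SCORE_THRESHOLD": 0.8}

    # write_model_with_config() serializes this dict as JSON and stores it as an
    # extra file named inference_config.json inside the bundled .ptl model.
    extra_files = {"inference_config.json": json.dumps(flattened)}
    print(extra_files)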

Differential Revision: D27881742

fbshipit-source-id: 34a3ab7a88f456b74841cf671ea1b3f678cdb733
parent 27bef8e3
@@ -4,6 +4,7 @@
import contextlib
import logging
from typing import Optional, Any, Dict
import mock
import yaml
@@ -94,6 +95,60 @@ class CfgNode(_CfgNode):
        # dump follows alphabetical order, thus good for hash use
        return hash(self.dump())

    def get_field_or_none(self, field_path: str) -> Optional[Any]:
        """
        Reads out a value from the cfg node if the field is set, otherwise returns None.
        The field path is the name of the parameter prefixed with the names of all
        its enclosing config groups, joined with "." as the separator.
        e.g. if the config is
            MODEL:
                TEST:
                    SCORE_THRESHOLD: 0.7
        then the value of SCORE_THRESHOLD is accessed as
            >>> score_threshold = cfg.get_field_or_none("MODEL.TEST.SCORE_THRESHOLD")
        """
        attributes = field_path.split(".")
        path_to_last, last_attribute = attributes[:-1], attributes[-1]
        cfg_node = self
        for attribute in path_to_last:
            if not isinstance(cfg_node, _CfgNode) or attribute not in cfg_node:
                return None
            cfg_node = cfg_node[attribute]
        return (
            cfg_node[last_attribute]
            if isinstance(cfg_node, _CfgNode) and last_attribute in cfg_node
            else None
        )

    def as_flattened_dict(self) -> Dict[str, Any]:
        """
        Returns all keys from the config object as a flattened dict.
        For example, if the config is
            MODEL:
                TEST:
                    SCORE_THRESHOLD: 0.7
                    MIN_DIM_SIZE: 360
        the returned dict would be
            {
                "MODEL.TEST.SCORE_THRESHOLD": 0.7,
                "MODEL.TEST.MIN_DIM_SIZE": 360
            }
        """
        return self._as_flattened_dict()

    def _as_flattened_dict(self, prefix: str = "") -> Dict[str, Any]:
        ret = {}
        for key in sorted(self.keys()):
            value = self[key]
            key_path = f"{prefix}.{key}" if prefix else key
            if isinstance(value, CfgNode):
                ret.update(value._as_flattened_dict(key_path))
            else:
                ret[key_path] = value
        return ret
@contextlib.contextmanager
def reroute_load_yaml_with_base():
......
import unittest

from d2go.config import CfgNode


class TestConfigNode(unittest.TestCase):
    @staticmethod
    def _get_default_config():
        cfg = CfgNode()
        cfg.INPUT = CfgNode()
        cfg.INPUT.CROP = CfgNode()
        cfg.INPUT.CROP.ENABLED = False
        cfg.INPUT.CROP.SIZE = (0.9, 0.9)
        cfg.INPUT.CROP.TYPE = "relative_range"
        cfg.MODEL = CfgNode()
        cfg.MODEL.MIN_DIM_SIZE = 360
        cfg.INFERENCE_SDK = CfgNode()
        cfg.INFERENCE_SDK.MODEL = CfgNode()
        cfg.INFERENCE_SDK.MODEL.SCORE_THRESHOLD = 0.8
        cfg.INFERENCE_SDK.IOU_TRACKER = CfgNode()
        cfg.INFERENCE_SDK.IOU_TRACKER.IOU_THRESHOLD = 0.15
        cfg.INFERENCE_SDK.ENABLE_ID_TRACKING = True
        return cfg

    def test_get_field_or_none(self):
        cfg = self._get_default_config()
        self.assertEqual(cfg.get_field_or_none("MODEL.MIN_DIM_SIZE"), 360)
        self.assertEqual(
            cfg.get_field_or_none("INFERENCE_SDK.ENABLE_ID_TRACKING"), True
        )
        self.assertEqual(cfg.get_field_or_none("MODEL.INPUT_SIZE"), None)
        self.assertEqual(cfg.get_field_or_none("MODEL.INPUT_SIZE.HEIGHT"), None)

    def test_as_flattened_dict(self):
        cfg = self._get_default_config()
        cfg_dict = cfg.as_flattened_dict()
        target_cfg_dict = {
            "INPUT.CROP.ENABLED": False,
            "INPUT.CROP.SIZE": (0.9, 0.9),
            "INPUT.CROP.TYPE": "relative_range",
            "MODEL.MIN_DIM_SIZE": 360,
            "INFERENCE_SDK.MODEL.SCORE_THRESHOLD": 0.8,
            "INFERENCE_SDK.IOU_TRACKER.IOU_THRESHOLD": 0.15,
            "INFERENCE_SDK.ENABLE_ID_TRACKING": True,
        }
        self.assertEqual(target_cfg_dict, cfg_dict)
@@ -7,22 +7,72 @@ deployable format (such as torchscript, caffe2, ...)
"""
import copy
import json
import logging
import os
import tempfile
import typing
from typing import Optional
import mobile_cv.lut.lib.pt.flops_utils as flops_utils
from d2go.config import temp_defrost
import torch
from d2go.config import temp_defrost, CfgNode
from d2go.export.api import convert_and_export_predictor
from d2go.setup import (
    basic_argument_parser,
    prepare_for_launch,
    setup_after_launch,
)
from iopath.common.file_io import PathManager
from iopath.fb.manifold import ManifoldPathHandler
from mobile_cv.common.misc.py import post_mortem_if_fail
path_manager = PathManager()
path_manager.register_handler(ManifoldPathHandler())
logger = logging.getLogger("d2go.tools.export")
INFERENCE_CONFIG_FILENAME = "inference_config.json"
MOBILE_OPTIMIZED_BUNDLE_FILENAME = "mobile_optimized_bundled.ptl"


def write_model_with_config(
    output_dir: str, model_jit_path: str, inference_config: Optional[CfgNode] = None
):
    """
    Writes the SDK inference config along with the model file and saves the
    bundled model at ${output_dir}/mobile_optimized_bundled.ptl
    """
    model_jit_local_path = path_manager.get_local_path(model_jit_path)
    model = torch.jit.load(model_jit_local_path)
    extra_files = {}
    if inference_config:
        extra_files = {
            INFERENCE_CONFIG_FILENAME: json.dumps(inference_config.as_flattened_dict())
        }
    bundled_model_path = os.path.join(output_dir, MOBILE_OPTIMIZED_BUNDLE_FILENAME)
    with tempfile.NamedTemporaryFile() as temp_file:
        model._save_for_lite_interpreter(temp_file.name, _extra_files=extra_files)
        path_manager.copy_from_local(temp_file.name, bundled_model_path, overwrite=True)
    logger.info(f"Saved bundled model to: {bundled_model_path}")


def _add_inference_config(
    predictor_paths: typing.Dict[str, str],
    inference_config: Optional[CfgNode],
):
    """Adds inference config in _extra_files as json and writes the bundled model"""
    if inference_config is None:
        return

    for _, export_dir in predictor_paths.items():
        model_jit_path = os.path.join(export_dir, "model.jit")
        write_model_with_config(export_dir, model_jit_path, inference_config)


def main(
    cfg,
@@ -32,6 +82,7 @@ def main(
    predictor_types: typing.List[str],
    compare_accuracy: bool = False,
    skip_if_fail: bool = False,
    inference_config: Optional[CfgNode] = None,
):
    cfg = copy.deepcopy(cfg)
    setup_after_launch(cfg, output_dir, runner)
@@ -66,6 +117,8 @@ def main(
            if not skip_if_fail:
                raise e

    _add_inference_config(predictor_paths, inference_config)

    ret = {"predictor_paths": predictor_paths, "accuracy_comparison": {}}
    if compare_accuracy:
        raise NotImplementedError()
@@ -78,6 +131,12 @@ def main(
@post_mortem_if_fail()
def run_with_cmdline_args(args):
    cfg, output_dir, runner = prepare_for_launch(args)

    inference_config = None
    if args.inference_config_file:
        inference_config = CfgNode(
            CfgNode.load_yaml_with_base(args.inference_config_file)
        )

    return main(
        cfg,
        output_dir,
@@ -86,6 +145,7 @@ def run_with_cmdline_args(args):
        predictor_types=args.predictor_types,
        compare_accuracy=args.compare_accuracy,
        skip_if_fail=args.skip_if_fail,
        inference_config=inference_config,
    )
@@ -110,10 +170,18 @@ def get_parser():
        help="If set, suppress the exception for failed exporting and continue to"
        " export the next type of model",
    )
    parser.add_argument(
        "--inference-config-file",
        type=str,
        default=None,
        help="Inference config file containing the model parameters for the C++ SDK pipeline",
    )
    return parser
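
# Illustrative example (not part of this diff) of a YAML file that could be passed via
# --inference-config-file; the keys are hypothetical and simply mirror the nested-config
# shape exercised by the unit tests above:
#
#     MODEL:
#       SCORE_THRESHOLD: 0.8
#     ENABLE_ID_TRACKING: True
#
# run_with_cmdline_args() loads the file with CfgNode.load_yaml_with_base() and main()
# embeds the flattened result into the exported model via _add_inference_config().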


def cli():
    run_with_cmdline_args(get_parser().parse_args())


if __name__ == "__main__":
    cli()