"vscode:/vscode.git/clone" did not exist on "38a3e4df926c59bc122191c0fc8066755e98b6d2"
Commit e62c0e4c authored by Yanghan Wang's avatar Yanghan Wang Committed by Facebook GitHub Bot
Browse files

enable (fake) inference for bolt exported model

Summary:
Enable the inference for boltnn (via running torchscript).
- merge rcnn's boltnn test with other export types.
- misc fixes.

Differential Revision: D30610386

fbshipit-source-id: 7b78136f8ca640b5fc179cb47e3218e709418d71
parent ad2973b2
......@@ -23,6 +23,8 @@ from torch.utils.mobile_optimizer import optimize_for_mobile
logger = logging.getLogger(__name__)
TORCHSCRIPT_FILENAME_KEY: str = "torchscript_filename"
class MobileOptimizationConfig(NamedTuple):
# optimize_for_mobile
......@@ -191,7 +193,7 @@ class DefaultTorchscriptExport(ModelExportMethod):
torchscript_filename = trace_and_save_torchscript(
model, input_args, save_path, **export_kwargs
)
return {"torchscript_filename": torchscript_filename}
return {TORCHSCRIPT_FILENAME_KEY: torchscript_filename}
@classmethod
def load(cls, save_path, *, torchscript_filename="model.jit"):
......
......@@ -310,10 +310,6 @@ class RCNNBaseTestCases:
)
predictor = create_predictor(predictor_path)
# This check is needed for models with unsupported backends.
# For these, predictor.model_or_models will be None -- as the
# model can't be loaded or run. (e.g. BoltNN).
if predictor.model_or_models is not None:
predictor_outputs = predictor(inputs)
_validate_outputs(inputs, predictor_outputs)
......
......@@ -19,7 +19,7 @@ from mobile_cv.common.misc.file_utils import make_temp_directory
patch_d2_meta_arch()
class TestFBNetV3MaskRCNNNormal(RCNNBaseTestCases.TemplateTestCase):
class TestFBNetV3MaskRCNNFP32(RCNNBaseTestCases.TemplateTestCase):
def setup_custom_test(self):
super().setup_custom_test()
self.cfg.merge_from_file("detectron2go://mask_rcnn_fbnetv3a_dsmask_C4.yaml")
......@@ -66,7 +66,7 @@ class TestFBNetV3MaskRCNNQATEager(RCNNBaseTestCases.TemplateTestCase):
self._test_export(predictor_type, compare_match=compare_match)
class TestFBNetV3KeypointRCNNNormal(RCNNBaseTestCases.TemplateTestCase):
class TestFBNetV3KeypointRCNNFP32(RCNNBaseTestCases.TemplateTestCase):
def setup_custom_test(self):
super().setup_custom_test()
self.cfg.merge_from_file("detectron2go://keypoint_rcnn_fbnetv3a_dsmask_C4.yaml")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment