Commit e62c0e4c authored by Yanghan Wang's avatar Yanghan Wang Committed by Facebook GitHub Bot
Browse files

enable (fake) inference for bolt exported model

Summary:
Enable the inference for boltnn (via running torchscript).
- merge rcnn's boltnn test with other export types.
- misc fixes.

Differential Revision: D30610386

fbshipit-source-id: 7b78136f8ca640b5fc179cb47e3218e709418d71
parent ad2973b2
...@@ -23,6 +23,8 @@ from torch.utils.mobile_optimizer import optimize_for_mobile ...@@ -23,6 +23,8 @@ from torch.utils.mobile_optimizer import optimize_for_mobile
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
TORCHSCRIPT_FILENAME_KEY: str = "torchscript_filename"
class MobileOptimizationConfig(NamedTuple): class MobileOptimizationConfig(NamedTuple):
# optimize_for_mobile # optimize_for_mobile
...@@ -191,7 +193,7 @@ class DefaultTorchscriptExport(ModelExportMethod): ...@@ -191,7 +193,7 @@ class DefaultTorchscriptExport(ModelExportMethod):
torchscript_filename = trace_and_save_torchscript( torchscript_filename = trace_and_save_torchscript(
model, input_args, save_path, **export_kwargs model, input_args, save_path, **export_kwargs
) )
return {"torchscript_filename": torchscript_filename} return {TORCHSCRIPT_FILENAME_KEY: torchscript_filename}
@classmethod @classmethod
def load(cls, save_path, *, torchscript_filename="model.jit"): def load(cls, save_path, *, torchscript_filename="model.jit"):
......
...@@ -310,21 +310,17 @@ class RCNNBaseTestCases: ...@@ -310,21 +310,17 @@ class RCNNBaseTestCases:
) )
predictor = create_predictor(predictor_path) predictor = create_predictor(predictor_path)
# This check is needed for models with unsupported backends. predictor_outputs = predictor(inputs)
# For these, predictor.model_or_models will be None -- as the _validate_outputs(inputs, predictor_outputs)
# model can't be loaded or run. (e.g. BoltNN).
if predictor.model_or_models is not None: if compare_match:
predictor_outputs = predictor(inputs) with torch.no_grad():
_validate_outputs(inputs, predictor_outputs) pytorch_outputs = self.test_model(inputs)
if compare_match: assert_instances_allclose(
with torch.no_grad(): predictor_outputs[0]["instances"],
pytorch_outputs = self.test_model(inputs) pytorch_outputs[0]["instances"],
)
assert_instances_allclose(
predictor_outputs[0]["instances"],
pytorch_outputs[0]["instances"],
)
return predictor_path return predictor_path
......
...@@ -19,7 +19,7 @@ from mobile_cv.common.misc.file_utils import make_temp_directory ...@@ -19,7 +19,7 @@ from mobile_cv.common.misc.file_utils import make_temp_directory
patch_d2_meta_arch() patch_d2_meta_arch()
class TestFBNetV3MaskRCNNNormal(RCNNBaseTestCases.TemplateTestCase): class TestFBNetV3MaskRCNNFP32(RCNNBaseTestCases.TemplateTestCase):
def setup_custom_test(self): def setup_custom_test(self):
super().setup_custom_test() super().setup_custom_test()
self.cfg.merge_from_file("detectron2go://mask_rcnn_fbnetv3a_dsmask_C4.yaml") self.cfg.merge_from_file("detectron2go://mask_rcnn_fbnetv3a_dsmask_C4.yaml")
...@@ -66,7 +66,7 @@ class TestFBNetV3MaskRCNNQATEager(RCNNBaseTestCases.TemplateTestCase): ...@@ -66,7 +66,7 @@ class TestFBNetV3MaskRCNNQATEager(RCNNBaseTestCases.TemplateTestCase):
self._test_export(predictor_type, compare_match=compare_match) self._test_export(predictor_type, compare_match=compare_match)
class TestFBNetV3KeypointRCNNNormal(RCNNBaseTestCases.TemplateTestCase): class TestFBNetV3KeypointRCNNFP32(RCNNBaseTestCases.TemplateTestCase):
def setup_custom_test(self): def setup_custom_test(self):
super().setup_custom_test() super().setup_custom_test()
self.cfg.merge_from_file("detectron2go://keypoint_rcnn_fbnetv3a_dsmask_C4.yaml") self.cfg.merge_from_file("detectron2go://keypoint_rcnn_fbnetv3a_dsmask_C4.yaml")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment