Commit 0ab6d3f1 authored by Yanghan Wang's avatar Yanghan Wang Committed by Facebook GitHub Bot
Browse files

update RCNN model test base

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/75

Refactor the base test case
- make test_dir valid throughout the test (rather than under a local context), so individual tests can load back the exported model
- refactor the `setup_custom_test` method for easier override.
- move parameterized into base class to avoid copying naming function

Reviewed By: zhanghang1989

Differential Revision: D28651067

fbshipit-source-id: c59a311564f6114039e20ed3a23e5dd9c84f4ae4
parent 29b57165
...@@ -168,6 +168,7 @@ def _export_single_model( ...@@ -168,6 +168,7 @@ def _export_single_model(
model_export_method_str = model_export_method model_export_method_str = model_export_method
model_export_method = ModelExportMethodRegistry.get(model_export_method) model_export_method = ModelExportMethodRegistry.get(model_export_method)
assert issubclass(model_export_method, ModelExportMethod), model_export_method assert issubclass(model_export_method, ModelExportMethod), model_export_method
logger.info("Using model export method: {}".format(model_export_method))
load_kwargs = model_export_method.export( load_kwargs = model_export_method.export(
model=model, model=model,
......
...@@ -3,12 +3,15 @@ ...@@ -3,12 +3,15 @@
import copy import copy
import shutil
import tempfile
import unittest import unittest
from typing import Optional from typing import Optional
import d2go.data.transforms.box_utils as bu import d2go.data.transforms.box_utils as bu
import torch import torch
from d2go.export.api import convert_and_export_predictor from d2go.export.api import convert_and_export_predictor
from d2go.export.d2_meta_arch import patch_d2_meta_arch
from d2go.runner import GeneralizedRCNNRunner from d2go.runner import GeneralizedRCNNRunner
from d2go.utils.testing.data_loader_helper import create_fake_detection_data_loader from d2go.utils.testing.data_loader_helper import create_fake_detection_data_loader
from detectron2.structures import ( from detectron2.structures import (
...@@ -16,12 +19,12 @@ from detectron2.structures import ( ...@@ -16,12 +19,12 @@ from detectron2.structures import (
Instances, Instances,
) )
from detectron2.utils.testing import assert_instances_allclose from detectron2.utils.testing import assert_instances_allclose
from mobile_cv.common.misc.file_utils import make_temp_directory
from mobile_cv.predictor.api import create_predictor from mobile_cv.predictor.api import create_predictor
from parameterized import parameterized
def _get_image_with_box(image_size, boxes: Optional[Boxes] = None): def _get_image_with_box(image_size, boxes: Optional[Boxes] = None):
""" Draw boxes on the image, one box per channel, use values 10, 20, ... """ """Draw boxes on the image, one box per channel, use values 10, 20, ..."""
ret = torch.zeros((3, image_size[0], image_size[1])) ret = torch.zeros((3, image_size[0], image_size[1]))
if boxes is None: if boxes is None:
return ret return ret
...@@ -80,7 +83,7 @@ def get_batched_inputs( ...@@ -80,7 +83,7 @@ def get_batched_inputs(
def _get_keypoints_from_boxes(boxes: Boxes, num_keypoints: int): def _get_keypoints_from_boxes(boxes: Boxes, num_keypoints: int):
""" Use box center as keypoints """ """Use box center as keypoints"""
centers = boxes.get_centers() centers = boxes.get_centers()
kpts = torch.cat((centers, torch.ones(centers.shape[0], 1)), dim=1) kpts = torch.cat((centers, torch.ones(centers.shape[0], 1)), dim=1)
kpts = kpts.repeat(1, num_keypoints).reshape(len(boxes), num_keypoints, 3) kpts = kpts.repeat(1, num_keypoints).reshape(len(boxes), num_keypoints, 3)
...@@ -235,22 +238,62 @@ def get_quick_test_config_opts( ...@@ -235,22 +238,62 @@ def get_quick_test_config_opts(
return [str(x) for x in ret] return [str(x) for x in ret]
def get_export_test_name(testcase_func, param_num, param):
predictor_type, compare_match = param.args
assert isinstance(predictor_type, str)
assert isinstance(compare_match, bool)
return "{}_{}".format(
testcase_func.__name__, parameterized.to_safe_name(predictor_type)
)
class RCNNBaseTestCases: class RCNNBaseTestCases:
class TemplateTestCase(unittest.TestCase): # TODO: maybe subclass from TestMetaArch @staticmethod
def setup_custom_test(self): def expand_parameterized_test_export(*args, **kwargs):
raise NotImplementedError() if "name_func" not in kwargs:
kwargs["name_func"] = get_export_test_name
return parameterized.expand(*args, **kwargs)
class TemplateTestCase(unittest.TestCase): # TODO: maybe subclass from TestMetaArch
def setUp(self): def setUp(self):
runner = GeneralizedRCNNRunner() # Add APIs to D2's meta arch, this is usually called in runner's setup,
self.cfg = runner.get_default_cfg() # however in unittest it needs to be called separately.
self.is_mcs = False # TODO: maybe we should apply this by default
patch_d2_meta_arch()
self.setup_test_dir()
assert hasattr(self, "test_dir")
self.setup_custom_test() self.setup_custom_test()
assert hasattr(self, "runner")
assert hasattr(self, "cfg")
self.force_apply_overwrite_opts()
# NOTE: change some config to make the model run fast self.test_model = self.runner.build_model(self.cfg, eval_only=True)
self.cfg.merge_from_list(get_quick_test_config_opts())
def setup_test_dir(self):
self.test_dir = tempfile.mkdtemp(prefix="test_export_")
self.addCleanup(shutil.rmtree, self.test_dir)
def setup_custom_test(self):
"""
Override this when using different runner, using different base config file,
or setting specific config for certain test.
"""
self.runner = GeneralizedRCNNRunner()
self.cfg = self.runner.get_default_cfg()
# subclass can call: self.cfg.merge_from_file(...)
def force_apply_overwrite_opts(self):
"""
Recommend only overriding this for a group of tests, while individual tests
should have its own `setup_custom_test`.
"""
# update config to make the model run fast
self.cfg.merge_from_list(get_quick_test_config_opts())
# forcing test on CPU
self.cfg.merge_from_list(["MODEL.DEVICE", "cpu"]) self.cfg.merge_from_list(["MODEL.DEVICE", "cpu"])
self.test_model = runner.build_model(self.cfg, eval_only=True)
def _test_export(self, predictor_type, compare_match=True): def _test_export(self, predictor_type, compare_match=True):
size_divisibility = max(self.test_model.backbone.size_divisibility, 10) size_divisibility = max(self.test_model.backbone.size_divisibility, 10)
...@@ -258,27 +301,32 @@ class RCNNBaseTestCases: ...@@ -258,27 +301,32 @@ class RCNNBaseTestCases:
with create_fake_detection_data_loader(h, w, is_train=False) as data_loader: with create_fake_detection_data_loader(h, w, is_train=False) as data_loader:
inputs = next(iter(data_loader)) inputs = next(iter(data_loader))
with make_temp_directory( # TODO: the export may change the model itself, need to fix this
"test_export_{}".format(predictor_type) model_to_export = copy.deepcopy(self.test_model)
) as tmp_dir: predictor_path = convert_and_export_predictor(
# TODO: the export may change the model itself, need to fix this self.cfg,
model_to_export = copy.deepcopy(self.test_model) model_to_export,
predictor_path = convert_and_export_predictor( predictor_type,
self.cfg, model_to_export, predictor_type, tmp_dir, data_loader self.test_dir,
data_loader,
)
predictor = create_predictor(predictor_path)
predicotr_outputs = predictor(inputs)
_validate_outputs(inputs, predicotr_outputs)
if compare_match:
with torch.no_grad():
pytorch_outputs = self.test_model(inputs)
assert_instances_allclose(
predicotr_outputs[0]["instances"],
pytorch_outputs[0]["instances"],
) )
predictor = create_predictor(predictor_path) return predictor_path
predicotr_outputs = predictor(inputs)
_validate_outputs(inputs, predicotr_outputs)
if compare_match:
with torch.no_grad():
pytorch_outputs = self.test_model(inputs)
assert_instances_allclose( # TODO: add test_train
predicotr_outputs[0]["instances"],
pytorch_outputs[0]["instances"],
)
def _test_inference(self): def _test_inference(self):
size_divisibility = max(self.test_model.backbone.size_divisibility, 10) size_divisibility = max(self.test_model.backbone.size_divisibility, 10)
......
...@@ -30,6 +30,7 @@ requirements = [ ...@@ -30,6 +30,7 @@ requirements = [
'torch', 'torch',
'pytorch_lightning', 'pytorch_lightning',
'opencv-python', 'opencv-python',
'parameterized',
] ]
def d2go_gather_files(dst_module, file_path, extension="*") -> List[str]: def d2go_gather_files(dst_module, file_path, extension="*") -> List[str]:
......
...@@ -21,23 +21,26 @@ patch_d2_meta_arch() ...@@ -21,23 +21,26 @@ patch_d2_meta_arch()
class TestFBNetV3MaskRCNNNormal(RCNNBaseTestCases.TemplateTestCase): class TestFBNetV3MaskRCNNNormal(RCNNBaseTestCases.TemplateTestCase):
def setup_custom_test(self): def setup_custom_test(self):
super().setup_custom_test()
self.cfg.merge_from_file("detectron2go://mask_rcnn_fbnetv3a_dsmask_C4.yaml") self.cfg.merge_from_file("detectron2go://mask_rcnn_fbnetv3a_dsmask_C4.yaml")
def test_inference(self): def test_inference(self):
self._test_inference() self._test_inference()
def test_export_torchscript_tracing(self): @RCNNBaseTestCases.expand_parameterized_test_export(
self._test_export("torchscript@tracing", compare_match=True) [
["torchscript@tracing", True],
def test_export_torchscript_int8(self): ["torchscript_int8", False],
self._test_export("torchscript_int8", compare_match=False) ["torchscript_int8@tracing", False],
]
def test_export_torchscript_int8_tracing(self): )
self._test_export("torchscript_int8@tracing", compare_match=False) def test_export(self, predictor_type, compare_match):
self._test_export(predictor_type, compare_match=compare_match)
class TestFBNetV3MaskRCNNQATEager(RCNNBaseTestCases.TemplateTestCase): class TestFBNetV3MaskRCNNQATEager(RCNNBaseTestCases.TemplateTestCase):
def setup_custom_test(self): def setup_custom_test(self):
super().setup_custom_test()
self.cfg.merge_from_file("detectron2go://mask_rcnn_fbnetv3a_dsmask_C4.yaml") self.cfg.merge_from_file("detectron2go://mask_rcnn_fbnetv3a_dsmask_C4.yaml")
# enable QAT # enable QAT
self.cfg.merge_from_list( self.cfg.merge_from_list(
...@@ -54,12 +57,18 @@ class TestFBNetV3MaskRCNNQATEager(RCNNBaseTestCases.TemplateTestCase): ...@@ -54,12 +57,18 @@ class TestFBNetV3MaskRCNNQATEager(RCNNBaseTestCases.TemplateTestCase):
def test_inference(self): def test_inference(self):
self._test_inference() self._test_inference()
def test_export_torchscript_int8(self): @RCNNBaseTestCases.expand_parameterized_test_export(
self._test_export("torchscript_int8", compare_match=False) # TODO: fix mismatch [
["torchscript_int8", False], # TODO: fix mismatch
]
)
def test_export(self, predictor_type, compare_match):
self._test_export(predictor_type, compare_match=compare_match)
class TestFBNetV3KeypointRCNNNormal(RCNNBaseTestCases.TemplateTestCase): class TestFBNetV3KeypointRCNNNormal(RCNNBaseTestCases.TemplateTestCase):
def setup_custom_test(self): def setup_custom_test(self):
super().setup_custom_test()
self.cfg.merge_from_file("detectron2go://keypoint_rcnn_fbnetv3a_dsmask_C4.yaml") self.cfg.merge_from_file("detectron2go://keypoint_rcnn_fbnetv3a_dsmask_C4.yaml")
# FIXME: have to use qnnpack due to the following error: # FIXME: have to use qnnpack due to the following error:
...@@ -74,8 +83,13 @@ class TestFBNetV3KeypointRCNNNormal(RCNNBaseTestCases.TemplateTestCase): ...@@ -74,8 +83,13 @@ class TestFBNetV3KeypointRCNNNormal(RCNNBaseTestCases.TemplateTestCase):
def test_inference(self): def test_inference(self):
self._test_inference() self._test_inference()
def test_export_torchscript_int8(self): @RCNNBaseTestCases.expand_parameterized_test_export(
self._test_export("torchscript_int8", compare_match=False) [
["torchscript_int8", False], # TODO: fix mismatch
]
)
def test_export(self, predictor_type, compare_match):
self._test_export(predictor_type, compare_match=compare_match)
class TestTorchVisionExport(unittest.TestCase): class TestTorchVisionExport(unittest.TestCase):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment