Commit 0ab6d3f1 authored by Yanghan Wang, committed by Facebook GitHub Bot

update RCNN model test base

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/75

Refactor the base test case
- make `test_dir` valid throughout the test (rather than only within a local context), so individual tests can load back the exported model
- refactor `setup_custom_test` for easier overriding
- move `parameterized` into the base class to avoid copying the naming function (a usage sketch follows below)
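
As a rough illustration of how a subclass uses the refactored base, here is a minimal sketch. The import path `d2go.utils.testing.rcnn_helper` and the class name `TestMyMaskRCNN` are assumptions for illustration; the config file and parameter pairs mirror the tests in this diff.

```python
import unittest

# assumed module path for the base test case defined in this diff
from d2go.utils.testing.rcnn_helper import RCNNBaseTestCases


class TestMyMaskRCNN(RCNNBaseTestCases.TemplateTestCase):  # hypothetical subclass
    def setup_custom_test(self):
        # The base class sets up the default runner/cfg; subclasses only add overrides.
        super().setup_custom_test()
        self.cfg.merge_from_file("detectron2go://mask_rcnn_fbnetv3a_dsmask_C4.yaml")

    # The naming function now lives in the base class, so subclasses only list
    # the (predictor_type, compare_match) pairs to parameterize over.
    @RCNNBaseTestCases.expand_parameterized_test_export(
        [
            ["torchscript@tracing", True],
            ["torchscript_int8", False],
        ]
    )
    def test_export(self, predictor_type, compare_match):
        # self.test_dir outlives the export call, so the exported predictor
        # can be loaded back later within the same test if needed.
        self._test_export(predictor_type, compare_match=compare_match)


if __name__ == "__main__":
    unittest.main()
```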

Reviewed By: zhanghang1989

Differential Revision: D28651067

fbshipit-source-id: c59a311564f6114039e20ed3a23e5dd9c84f4ae4
parent 29b57165
......@@ -168,6 +168,7 @@ def _export_single_model(
model_export_method_str = model_export_method
model_export_method = ModelExportMethodRegistry.get(model_export_method)
assert issubclass(model_export_method, ModelExportMethod), model_export_method
logger.info("Using model export method: {}".format(model_export_method))
load_kwargs = model_export_method.export(
model=model,
......
......@@ -3,12 +3,15 @@
import copy
import shutil
import tempfile
import unittest
from typing import Optional
import d2go.data.transforms.box_utils as bu
import torch
from d2go.export.api import convert_and_export_predictor
from d2go.export.d2_meta_arch import patch_d2_meta_arch
from d2go.runner import GeneralizedRCNNRunner
from d2go.utils.testing.data_loader_helper import create_fake_detection_data_loader
from detectron2.structures import (
......@@ -16,12 +19,12 @@ from detectron2.structures import (
Instances,
)
from detectron2.utils.testing import assert_instances_allclose
from mobile_cv.common.misc.file_utils import make_temp_directory
from mobile_cv.predictor.api import create_predictor
from parameterized import parameterized
def _get_image_with_box(image_size, boxes: Optional[Boxes] = None):
""" Draw boxes on the image, one box per channel, use values 10, 20, ... """
"""Draw boxes on the image, one box per channel, use values 10, 20, ..."""
ret = torch.zeros((3, image_size[0], image_size[1]))
if boxes is None:
return ret
......@@ -80,7 +83,7 @@ def get_batched_inputs(
def _get_keypoints_from_boxes(boxes: Boxes, num_keypoints: int):
""" Use box center as keypoints """
"""Use box center as keypoints"""
centers = boxes.get_centers()
kpts = torch.cat((centers, torch.ones(centers.shape[0], 1)), dim=1)
kpts = kpts.repeat(1, num_keypoints).reshape(len(boxes), num_keypoints, 3)
......@@ -235,22 +238,62 @@ def get_quick_test_config_opts(
return [str(x) for x in ret]
def get_export_test_name(testcase_func, param_num, param):
predictor_type, compare_match = param.args
assert isinstance(predictor_type, str)
assert isinstance(compare_match, bool)
return "{}_{}".format(
testcase_func.__name__, parameterized.to_safe_name(predictor_type)
)
class RCNNBaseTestCases:
class TemplateTestCase(unittest.TestCase): # TODO: maybe subclass from TestMetaArch
def setup_custom_test(self):
raise NotImplementedError()
@staticmethod
def expand_parameterized_test_export(*args, **kwargs):
if "name_func" not in kwargs:
kwargs["name_func"] = get_export_test_name
return parameterized.expand(*args, **kwargs)
class TemplateTestCase(unittest.TestCase): # TODO: maybe subclass from TestMetaArch
def setUp(self):
runner = GeneralizedRCNNRunner()
self.cfg = runner.get_default_cfg()
self.is_mcs = False
# Add APIs to D2's meta arch; this is usually called in the runner's setup,
# however in unittest it needs to be called separately.
# TODO: maybe we should apply this by default
patch_d2_meta_arch()
self.setup_test_dir()
assert hasattr(self, "test_dir")
self.setup_custom_test()
assert hasattr(self, "runner")
assert hasattr(self, "cfg")
self.force_apply_overwrite_opts()
# NOTE: change some config to make the model run fast
self.cfg.merge_from_list(get_quick_test_config_opts())
self.test_model = self.runner.build_model(self.cfg, eval_only=True)
def setup_test_dir(self):
self.test_dir = tempfile.mkdtemp(prefix="test_export_")
self.addCleanup(shutil.rmtree, self.test_dir)
def setup_custom_test(self):
"""
Override this when using a different runner, using a different base config file,
or setting a specific config for a certain test.
"""
self.runner = GeneralizedRCNNRunner()
self.cfg = self.runner.get_default_cfg()
# subclass can call: self.cfg.merge_from_file(...)
def force_apply_overwrite_opts(self):
"""
Recommend only overriding this for a group of tests; each individual test
should have its own `setup_custom_test`.
"""
# update config to make the model run fast
self.cfg.merge_from_list(get_quick_test_config_opts())
# forcing test on CPU
self.cfg.merge_from_list(["MODEL.DEVICE", "cpu"])
self.test_model = runner.build_model(self.cfg, eval_only=True)
def _test_export(self, predictor_type, compare_match=True):
size_divisibility = max(self.test_model.backbone.size_divisibility, 10)
......@@ -258,27 +301,32 @@ class RCNNBaseTestCases:
with create_fake_detection_data_loader(h, w, is_train=False) as data_loader:
inputs = next(iter(data_loader))
with make_temp_directory(
"test_export_{}".format(predictor_type)
) as tmp_dir:
# TODO: the export may change the model itself, need to fix this
model_to_export = copy.deepcopy(self.test_model)
predictor_path = convert_and_export_predictor(
self.cfg, model_to_export, predictor_type, tmp_dir, data_loader
# TODO: the export may change the model itself, need to fix this
model_to_export = copy.deepcopy(self.test_model)
predictor_path = convert_and_export_predictor(
self.cfg,
model_to_export,
predictor_type,
self.test_dir,
data_loader,
)
predictor = create_predictor(predictor_path)
predictor_outputs = predictor(inputs)
_validate_outputs(inputs, predictor_outputs)
if compare_match:
with torch.no_grad():
pytorch_outputs = self.test_model(inputs)
assert_instances_allclose(
predictor_outputs[0]["instances"],
pytorch_outputs[0]["instances"],
)
predictor = create_predictor(predictor_path)
predictor_outputs = predictor(inputs)
_validate_outputs(inputs, predictor_outputs)
if compare_match:
with torch.no_grad():
pytorch_outputs = self.test_model(inputs)
return predictor_path
assert_instances_allclose(
predictor_outputs[0]["instances"],
pytorch_outputs[0]["instances"],
)
# TODO: add test_train
def _test_inference(self):
size_divisibility = max(self.test_model.backbone.size_divisibility, 10)
......
......@@ -30,6 +30,7 @@ requirements = [
'torch',
'pytorch_lightning',
'opencv-python',
'parameterized',
]
def d2go_gather_files(dst_module, file_path, extension="*") -> List[str]:
......
......@@ -21,23 +21,26 @@ patch_d2_meta_arch()
class TestFBNetV3MaskRCNNNormal(RCNNBaseTestCases.TemplateTestCase):
def setup_custom_test(self):
super().setup_custom_test()
self.cfg.merge_from_file("detectron2go://mask_rcnn_fbnetv3a_dsmask_C4.yaml")
def test_inference(self):
self._test_inference()
def test_export_torchscript_tracing(self):
self._test_export("torchscript@tracing", compare_match=True)
def test_export_torchscript_int8(self):
self._test_export("torchscript_int8", compare_match=False)
def test_export_torchscript_int8_tracing(self):
self._test_export("torchscript_int8@tracing", compare_match=False)
@RCNNBaseTestCases.expand_parameterized_test_export(
[
["torchscript@tracing", True],
["torchscript_int8", False],
["torchscript_int8@tracing", False],
]
)
def test_export(self, predictor_type, compare_match):
self._test_export(predictor_type, compare_match=compare_match)
class TestFBNetV3MaskRCNNQATEager(RCNNBaseTestCases.TemplateTestCase):
def setup_custom_test(self):
super().setup_custom_test()
self.cfg.merge_from_file("detectron2go://mask_rcnn_fbnetv3a_dsmask_C4.yaml")
# enable QAT
self.cfg.merge_from_list(
......@@ -54,12 +57,18 @@ class TestFBNetV3MaskRCNNQATEager(RCNNBaseTestCases.TemplateTestCase):
def test_inference(self):
self._test_inference()
def test_export_torchscript_int8(self):
self._test_export("torchscript_int8", compare_match=False) # TODO: fix mismatch
@RCNNBaseTestCases.expand_parameterized_test_export(
[
["torchscript_int8", False], # TODO: fix mismatch
]
)
def test_export(self, predictor_type, compare_match):
self._test_export(predictor_type, compare_match=compare_match)
class TestFBNetV3KeypointRCNNNormal(RCNNBaseTestCases.TemplateTestCase):
def setup_custom_test(self):
super().setup_custom_test()
self.cfg.merge_from_file("detectron2go://keypoint_rcnn_fbnetv3a_dsmask_C4.yaml")
# FIXME: have to use qnnpack due to the following error:
......@@ -74,8 +83,13 @@ class TestFBNetV3KeypointRCNNNormal(RCNNBaseTestCases.TemplateTestCase):
def test_inference(self):
self._test_inference()
def test_export_torchscript_int8(self):
self._test_export("torchscript_int8", compare_match=False)
@RCNNBaseTestCases.expand_parameterized_test_export(
[
["torchscript_int8", False], # TODO: fix mismatch
]
)
def test_export(self, predictor_type, compare_match):
self._test_export(predictor_type, compare_match=compare_match)
class TestTorchVisionExport(unittest.TestCase):
......