Commit e84d3414 authored by Yanghan Wang, committed by Facebook GitHub Bot

move some of `test_meta_arch_rcnn.py` to oss

Reviewed By: newstzpz

Differential Revision: D27747996

fbshipit-source-id: 6ae3b89c3944098828e246e5a4a89209b8e171a1
parent 77ebe09f
@@ -28,8 +28,7 @@ MODEL:
     POOLER_RESOLUTION: 6
     NORM: "naiveSyncBN"
   ROI_MASK_HEAD:
-    NAME: "MaskRCNNConvUpsampleHead"
-    NUM_CONV: 4
+    NAME: "FBNetV2RoIMaskHead"
     POOLER_RESOLUTION: 14
     NORM: "naiveSyncBN"
 MODEL_EMA:
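For context, a minimal sketch of applying the same mask-head override on top of a default config via merge_from_list, the override pattern the tests below rely on; the config keys are taken from the diff above, everything else is illustrative:

    from d2go.runner import GeneralizedRCNNRunner

    # Build the default D2Go config and swap in the FBNetV2 mask head,
    # mirroring the YAML change above.
    cfg = GeneralizedRCNNRunner().get_default_cfg()
    cfg.merge_from_list(["MODEL.ROI_MASK_HEAD.NAME", "FBNetV2RoIMaskHead"])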
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

import copy
import os
import unittest

import torch
from d2go.export.api import convert_and_export_predictor
from d2go.export.d2_meta_arch import patch_d2_meta_arch
from d2go.runner import GeneralizedRCNNRunner
from d2go.utils.testing.data_loader_helper import create_fake_detection_data_loader
from d2go.utils.testing.rcnn_helper import RCNNBaseTestCases, get_quick_test_config_opts
from mobile_cv.common.misc.file_utils import make_temp_directory

# Add APIs to D2's meta arch. This is usually called in the runner's setup; however,
# in unittest it needs to be called separately. (Maybe we should apply this by default.)
patch_d2_meta_arch()
class TestFBNetV3MaskRCNNNormal(RCNNBaseTestCases.TemplateTestCase):
    def setup_custom_test(self):
        self.cfg.merge_from_file("detectron2go://mask_rcnn_fbnetv3a_dsmask_C4.yaml")

    def test_inference(self):
        self._test_inference()

    def test_export_torchscript_tracing(self):
        self._test_export("torchscript@tracing", compare_match=True)

    def test_export_torchscript_int8(self):
        self._test_export("torchscript_int8", compare_match=False)

    def test_export_torchscript_int8_tracing(self):
        self._test_export("torchscript_int8@tracing", compare_match=False)
class TestFBNetV3MaskRCNNQATEager(RCNNBaseTestCases.TemplateTestCase):
    def setup_custom_test(self):
        self.cfg.merge_from_file("detectron2go://mask_rcnn_fbnetv3a_dsmask_C4.yaml")
        # enable QAT
        self.cfg.merge_from_list(
            [
                "QUANTIZATION.BACKEND",
                "qnnpack",
                "QUANTIZATION.QAT.ENABLED",
                "True",
            ]
        )
        # FIXME: NaiveSyncBN is not supported
        self.cfg.merge_from_list(["MODEL.FBNET_V2.NORM", "bn"])

    def test_inference(self):
        self._test_inference()

    def test_export_torchscript_int8(self):
        self._test_export("torchscript_int8", compare_match=False)  # TODO: fix mismatch
class TestFBNetV3KeypointRCNNNormal(RCNNBaseTestCases.TemplateTestCase):
    def setup_custom_test(self):
        self.cfg.merge_from_file("detectron2go://keypoint_rcnn_fbnetv3a_dsmask_C4.yaml")
        # FIXME: have to use qnnpack due to the following error:
        # Per Channel Quantization is currently disabled for transposed conv
        self.cfg.merge_from_list(
            [
                "QUANTIZATION.BACKEND",
                "qnnpack",
            ]
        )

    def test_inference(self):
        self._test_inference()

    def test_export_torchscript_int8(self):
        self._test_export("torchscript_int8", compare_match=False)
class TestTorchVisionExport(unittest.TestCase):
    def test_export_torchvision_format(self):
        runner = GeneralizedRCNNRunner()
        cfg = runner.get_default_cfg()
        cfg.merge_from_file("detectron2go://mask_rcnn_fbnetv3a_dsmask_C4.yaml")
        cfg.merge_from_list(get_quick_test_config_opts())

        cfg.merge_from_list(["MODEL.DEVICE", "cpu"])
        pytorch_model = runner.build_model(cfg, eval_only=True)

        from typing import List, Dict

        # Wrap the exported predictor so its inputs/outputs follow torchvision's
        # detection format: a list of CHW tensors in, a list of dicts out.
        class Wrapper(torch.nn.Module):
            def __init__(self, model):
                super().__init__()
                self.model = model

            def forward(self, inputs: List[torch.Tensor]):
                x = inputs[0].unsqueeze(0) * 255
                # resize so that the shorter side is 320 pixels
                scale = 320.0 / min(x.shape[-2], x.shape[-1])
                x = torch.nn.functional.interpolate(
                    x,
                    scale_factor=scale,
                    mode="bilinear",
                    align_corners=True,
                    recompute_scale_factor=True,
                )
                out = self.model(x[0])
                res: Dict[str, torch.Tensor] = {}
                # map boxes back to the original input resolution
                res["boxes"] = out[0] / scale
                res["labels"] = out[2]
                res["scores"] = out[1]
                return inputs, [res]

        size_divisibility = max(pytorch_model.backbone.size_divisibility, 10)
        h, w = size_divisibility, size_divisibility * 2
        with create_fake_detection_data_loader(h, w, is_train=False) as data_loader:
            with make_temp_directory("test_export_torchvision_format") as tmp_dir:
                predictor_path = convert_and_export_predictor(
                    cfg,
                    copy.deepcopy(pytorch_model),
                    "torchscript@tracing",
                    tmp_dir,
                    data_loader,
                )

                orig_model = torch.jit.load(os.path.join(predictor_path, "model.jit"))
                wrapped_model = Wrapper(orig_model)
                # optionally do a forward
                wrapped_model([torch.rand(3, 600, 600)])
                scripted_model = torch.jit.script(wrapped_model)
                scripted_model.save(os.path.join(tmp_dir, "new_file.pt"))


if __name__ == "__main__":
    unittest.main()
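For reference, a minimal sketch of consuming the torchvision-format export produced by the test above; the "new_file.pt" path and the random input are illustrative assumptions:

    import torch

    # Load the scripted wrapper and run it on a single CHW float tensor in [0, 1].
    loaded = torch.jit.load("new_file.pt")
    inputs, detections = loaded([torch.rand(3, 600, 600)])
    # Each detection dict holds "boxes", "labels", and "scores" tensors, with
    # boxes already rescaled back to the original input resolution.
    print(detections[0]["boxes"].shape, detections[0]["scores"].shape)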