Commit be6f398c authored by Lara Haidar's avatar Lara Haidar Committed by Francisco Massa
Browse files

Enable ONNX Test for FasterRcnn (#1555)

* enable faster rcnn test

* flake8

* smaller image size

* set min/max
parent af225a8a
......@@ -19,6 +19,7 @@ except ImportError:
onnxruntime = None
import unittest
from torchvision.ops._register_onnx_ops import _onnx_opset_version
@unittest.skipIf(onnxruntime is None, 'ONNX Runtime unavailable')
......@@ -32,7 +33,8 @@ class ONNXExporterTester(unittest.TestCase):
onnx_io = io.BytesIO()
# export to onnx with the first input
torch.onnx.export(model, inputs_list[0], onnx_io, do_constant_folding=True, opset_version=10)
torch.onnx.export(model, inputs_list[0], onnx_io,
do_constant_folding=True, opset_version=_onnx_opset_version)
# validate the exported model with onnx runtime
for test_inputs in inputs_list:
......@@ -97,7 +99,6 @@ class ONNXExporterTester(unittest.TestCase):
model = ops.RoIPool((pool_h, pool_w), 2)
self.run_model(model, [(x, rois)])
@unittest.skip("Disable test until Resize opset 11 is implemented in ONNX Runtime")
def test_transform_images(self):
class TransformModule(torch.nn.Module):
......@@ -108,13 +109,13 @@ class ONNXExporterTester(unittest.TestCase):
def forward(self_module, images):
    # Run the RCNN transform over the input images and return only the
    # batched image tensor ([0].tensors); any target annotations returned
    # by the transform are discarded.
    # NOTE(review): assumes self_module.transform is a GeneralizedRCNNTransform
    # whose first return value exposes a .tensors attribute (an ImageList) --
    # confirm against the enclosing TransformModule's __init__.
    return self_module.transform(images)[0].tensors
input = [torch.rand(3, 800, 1280), torch.rand(3, 800, 800)]
input_test = [torch.rand(3, 800, 1280), torch.rand(3, 800, 800)]
input = [torch.rand(3, 100, 200), torch.rand(3, 200, 200)]
input_test = [torch.rand(3, 100, 200), torch.rand(3, 200, 200)]
self.run_model(TransformModule(), [input, input_test])
def _init_test_generalized_rcnn_transform(self):
min_size = 800
max_size = 1333
min_size = 100
max_size = 200
image_mean = [0.485, 0.456, 0.406]
image_std = [0.229, 0.224, 0.225]
transform = GeneralizedRCNNTransform(min_size, max_size, image_mean, image_std)
......@@ -234,7 +235,6 @@ class ONNXExporterTester(unittest.TestCase):
self.run_model(TransformModule(), [(i, [boxes],), (i1, [boxes1],)])
@unittest.skipIf(torch.__version__ < "1.4.", "Disable test if torch version is less than 1.4")
def test_roi_heads(self):
class RoiHeadsModule(torch.nn.Module):
def __init__(self_module, images):
......@@ -271,7 +271,7 @@ class ONNXExporterTester(unittest.TestCase):
data = requests.get(url)
image = Image.open(BytesIO(data.content)).convert("RGB")
image = image.resize((800, 1280), Image.BILINEAR)
image = image.resize((300, 200), Image.BILINEAR)
to_tensor = transforms.ToTensor()
return to_tensor(image)
......@@ -285,12 +285,12 @@ class ONNXExporterTester(unittest.TestCase):
test_images = [image2]
return images, test_images
@unittest.skip("Disable test until Resize opset 11 is implemented in ONNX Runtime")
@unittest.skipIf(torch.__version__ < "1.4.", "Disable test if torch version is less than 1.4")
def test_faster_rcnn(self):
    """Export a pretrained Faster R-CNN to ONNX and validate it.

    Builds the torchvision fasterrcnn_resnet50_fpn model, exports it via
    self.run_model, and checks ONNX Runtime outputs against eager-mode
    PyTorch on two different input batches.
    """
    images, test_images = self.get_test_images()
    # NOTE(review): `model` is assigned twice below -- the first
    # (default-sized) model is immediately overwritten by the second, which
    # caps image sizes (min_size=200, max_size=300) to keep the exported
    # graph small and the test fast. The first assignment appears to be
    # leftover diff residue from this commit; confirm and remove it, since
    # it downloads/constructs the pretrained model a second time for nothing.
    model = models.detection.faster_rcnn.fasterrcnn_resnet50_fpn(pretrained=True)
    model = models.detection.faster_rcnn.fasterrcnn_resnet50_fpn(pretrained=True,
                                                                 min_size=200,
                                                                 max_size=300)
    model.eval()
    # One eager-mode forward pass before export (warms up / fixes dynamic state
    # prior to tracing).
    model(images)
    # run_model exports to ONNX and compares onnxruntime vs eager outputs
    # for both input batches.
    self.run_model(model, [(images,), (test_images,)])
......
......@@ -110,13 +110,7 @@ class AnchorGenerator(nn.Module):
shifts_y = torch.arange(
0, grid_height, dtype=torch.float32, device=device
) * stride_height
# TODO: remove tracing pass when exporting torch.meshgrid()
# is supported in ONNX
if torchvision._is_tracing():
shift_y = shifts_y.view(-1, 1).expand(grid_height, grid_width)
shift_x = shifts_x.view(1, -1).expand(grid_height, grid_width)
else:
shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x)
shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x)
shift_x = shift_x.reshape(-1)
shift_y = shift_y.reshape(-1)
shifts = torch.stack((shift_x, shift_y, shift_x, shift_y), dim=1)
......
......@@ -89,15 +89,6 @@ class GeneralizedRCNNTransform(nn.Module):
target["keypoints"] = keypoints
return image, target
# _onnx_dynamic_img_pad() creates a dynamic padding
# for an image supported in ONNx tracing.
# it is used to process the images in _onnx_batch_images().
def _onnx_dynamic_img_pad(self, img, padding):
concat_0 = torch.cat((img, torch.zeros(padding[0], img.shape[1], img.shape[2])), 0)
concat_1 = torch.cat((concat_0, torch.zeros(concat_0.shape[0], padding[1], concat_0.shape[2])), 1)
padded_img = torch.cat((concat_1, torch.zeros(concat_1.shape[0], concat_1.shape[1], padding[2])), 2)
return padded_img
# _onnx_batch_images() is an implementation of
# batch_images() that is supported by ONNX tracing.
def _onnx_batch_images(self, images, size_divisible=32):
......@@ -116,7 +107,7 @@ class GeneralizedRCNNTransform(nn.Module):
padded_imgs = []
for img in images:
padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))]
padded_img = self._onnx_dynamic_img_pad(img, padding)
padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0]))
padded_imgs.append(padded_img)
return torch.stack(padded_imgs)
......
import sys
import torch
_onnx_opset_version = 11
def _register_custom_op():
from torch.onnx.symbolic_helper import parse_args, scalar_type_to_onnx
......@@ -30,6 +32,6 @@ def _register_custom_op():
return roi_pool, None
from torch.onnx import register_custom_op_symbolic
register_custom_op_symbolic('torchvision::nms', symbolic_multi_label_nms, 10)
register_custom_op_symbolic('torchvision::roi_align', roi_align, 10)
register_custom_op_symbolic('torchvision::roi_pool', roi_pool, 10)
register_custom_op_symbolic('torchvision::nms', symbolic_multi_label_nms, _onnx_opset_version)
register_custom_op_symbolic('torchvision::roi_align', roi_align, _onnx_opset_version)
register_custom_op_symbolic('torchvision::roi_pool', roi_pool, _onnx_opset_version)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment