Commit be6f398c authored by Lara Haidar's avatar Lara Haidar Committed by Francisco Massa
Browse files

Enable ONNX Test for FasterRcnn (#1555)

* enable faster rcnn test

* flake8

* smaller image size

* set min/max
parent af225a8a
...@@ -19,6 +19,7 @@ except ImportError: ...@@ -19,6 +19,7 @@ except ImportError:
onnxruntime = None onnxruntime = None
import unittest import unittest
from torchvision.ops._register_onnx_ops import _onnx_opset_version
@unittest.skipIf(onnxruntime is None, 'ONNX Runtime unavailable') @unittest.skipIf(onnxruntime is None, 'ONNX Runtime unavailable')
...@@ -32,7 +33,8 @@ class ONNXExporterTester(unittest.TestCase): ...@@ -32,7 +33,8 @@ class ONNXExporterTester(unittest.TestCase):
onnx_io = io.BytesIO() onnx_io = io.BytesIO()
# export to onnx with the first input # export to onnx with the first input
torch.onnx.export(model, inputs_list[0], onnx_io, do_constant_folding=True, opset_version=10) torch.onnx.export(model, inputs_list[0], onnx_io,
do_constant_folding=True, opset_version=_onnx_opset_version)
# validate the exported model with onnx runtime # validate the exported model with onnx runtime
for test_inputs in inputs_list: for test_inputs in inputs_list:
...@@ -97,7 +99,6 @@ class ONNXExporterTester(unittest.TestCase): ...@@ -97,7 +99,6 @@ class ONNXExporterTester(unittest.TestCase):
model = ops.RoIPool((pool_h, pool_w), 2) model = ops.RoIPool((pool_h, pool_w), 2)
self.run_model(model, [(x, rois)]) self.run_model(model, [(x, rois)])
@unittest.skip("Disable test until Resize opset 11 is implemented in ONNX Runtime")
def test_transform_images(self): def test_transform_images(self):
class TransformModule(torch.nn.Module): class TransformModule(torch.nn.Module):
...@@ -108,13 +109,13 @@ class ONNXExporterTester(unittest.TestCase): ...@@ -108,13 +109,13 @@ class ONNXExporterTester(unittest.TestCase):
def forward(self_module, images): def forward(self_module, images):
return self_module.transform(images)[0].tensors return self_module.transform(images)[0].tensors
input = [torch.rand(3, 800, 1280), torch.rand(3, 800, 800)] input = [torch.rand(3, 100, 200), torch.rand(3, 200, 200)]
input_test = [torch.rand(3, 800, 1280), torch.rand(3, 800, 800)] input_test = [torch.rand(3, 100, 200), torch.rand(3, 200, 200)]
self.run_model(TransformModule(), [input, input_test]) self.run_model(TransformModule(), [input, input_test])
def _init_test_generalized_rcnn_transform(self): def _init_test_generalized_rcnn_transform(self):
min_size = 800 min_size = 100
max_size = 1333 max_size = 200
image_mean = [0.485, 0.456, 0.406] image_mean = [0.485, 0.456, 0.406]
image_std = [0.229, 0.224, 0.225] image_std = [0.229, 0.224, 0.225]
transform = GeneralizedRCNNTransform(min_size, max_size, image_mean, image_std) transform = GeneralizedRCNNTransform(min_size, max_size, image_mean, image_std)
...@@ -234,7 +235,6 @@ class ONNXExporterTester(unittest.TestCase): ...@@ -234,7 +235,6 @@ class ONNXExporterTester(unittest.TestCase):
self.run_model(TransformModule(), [(i, [boxes],), (i1, [boxes1],)]) self.run_model(TransformModule(), [(i, [boxes],), (i1, [boxes1],)])
@unittest.skipIf(torch.__version__ < "1.4.", "Disable test if torch version is less than 1.4")
def test_roi_heads(self): def test_roi_heads(self):
class RoiHeadsModule(torch.nn.Module): class RoiHeadsModule(torch.nn.Module):
def __init__(self_module, images): def __init__(self_module, images):
...@@ -271,7 +271,7 @@ class ONNXExporterTester(unittest.TestCase): ...@@ -271,7 +271,7 @@ class ONNXExporterTester(unittest.TestCase):
data = requests.get(url) data = requests.get(url)
image = Image.open(BytesIO(data.content)).convert("RGB") image = Image.open(BytesIO(data.content)).convert("RGB")
image = image.resize((800, 1280), Image.BILINEAR) image = image.resize((300, 200), Image.BILINEAR)
to_tensor = transforms.ToTensor() to_tensor = transforms.ToTensor()
return to_tensor(image) return to_tensor(image)
...@@ -285,12 +285,12 @@ class ONNXExporterTester(unittest.TestCase): ...@@ -285,12 +285,12 @@ class ONNXExporterTester(unittest.TestCase):
test_images = [image2] test_images = [image2]
return images, test_images return images, test_images
@unittest.skip("Disable test until Resize opset 11 is implemented in ONNX Runtime")
@unittest.skipIf(torch.__version__ < "1.4.", "Disable test if torch version is less than 1.4")
def test_faster_rcnn(self): def test_faster_rcnn(self):
images, test_images = self.get_test_images() images, test_images = self.get_test_images()
model = models.detection.faster_rcnn.fasterrcnn_resnet50_fpn(pretrained=True) model = models.detection.faster_rcnn.fasterrcnn_resnet50_fpn(pretrained=True,
min_size=200,
max_size=300)
model.eval() model.eval()
model(images) model(images)
self.run_model(model, [(images,), (test_images,)]) self.run_model(model, [(images,), (test_images,)])
......
...@@ -110,13 +110,7 @@ class AnchorGenerator(nn.Module): ...@@ -110,13 +110,7 @@ class AnchorGenerator(nn.Module):
shifts_y = torch.arange( shifts_y = torch.arange(
0, grid_height, dtype=torch.float32, device=device 0, grid_height, dtype=torch.float32, device=device
) * stride_height ) * stride_height
# TODO: remove tracing pass when exporting torch.meshgrid() shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x)
# is supported in ONNX
if torchvision._is_tracing():
shift_y = shifts_y.view(-1, 1).expand(grid_height, grid_width)
shift_x = shifts_x.view(1, -1).expand(grid_height, grid_width)
else:
shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x)
shift_x = shift_x.reshape(-1) shift_x = shift_x.reshape(-1)
shift_y = shift_y.reshape(-1) shift_y = shift_y.reshape(-1)
shifts = torch.stack((shift_x, shift_y, shift_x, shift_y), dim=1) shifts = torch.stack((shift_x, shift_y, shift_x, shift_y), dim=1)
......
...@@ -89,15 +89,6 @@ class GeneralizedRCNNTransform(nn.Module): ...@@ -89,15 +89,6 @@ class GeneralizedRCNNTransform(nn.Module):
target["keypoints"] = keypoints target["keypoints"] = keypoints
return image, target return image, target
# _onnx_dynamic_img_pad() creates a dynamic padding
# for an image supported in ONNX tracing.
# it is used to process the images in _onnx_batch_images().
def _onnx_dynamic_img_pad(self, img, padding):
concat_0 = torch.cat((img, torch.zeros(padding[0], img.shape[1], img.shape[2])), 0)
concat_1 = torch.cat((concat_0, torch.zeros(concat_0.shape[0], padding[1], concat_0.shape[2])), 1)
padded_img = torch.cat((concat_1, torch.zeros(concat_1.shape[0], concat_1.shape[1], padding[2])), 2)
return padded_img
# _onnx_batch_images() is an implementation of # _onnx_batch_images() is an implementation of
# batch_images() that is supported by ONNX tracing. # batch_images() that is supported by ONNX tracing.
def _onnx_batch_images(self, images, size_divisible=32): def _onnx_batch_images(self, images, size_divisible=32):
...@@ -116,7 +107,7 @@ class GeneralizedRCNNTransform(nn.Module): ...@@ -116,7 +107,7 @@ class GeneralizedRCNNTransform(nn.Module):
padded_imgs = [] padded_imgs = []
for img in images: for img in images:
padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))] padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))]
padded_img = self._onnx_dynamic_img_pad(img, padding) padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0]))
padded_imgs.append(padded_img) padded_imgs.append(padded_img)
return torch.stack(padded_imgs) return torch.stack(padded_imgs)
......
import sys import sys
import torch import torch
_onnx_opset_version = 11
def _register_custom_op(): def _register_custom_op():
from torch.onnx.symbolic_helper import parse_args, scalar_type_to_onnx from torch.onnx.symbolic_helper import parse_args, scalar_type_to_onnx
...@@ -30,6 +32,6 @@ def _register_custom_op(): ...@@ -30,6 +32,6 @@ def _register_custom_op():
return roi_pool, None return roi_pool, None
from torch.onnx import register_custom_op_symbolic from torch.onnx import register_custom_op_symbolic
register_custom_op_symbolic('torchvision::nms', symbolic_multi_label_nms, 10) register_custom_op_symbolic('torchvision::nms', symbolic_multi_label_nms, _onnx_opset_version)
register_custom_op_symbolic('torchvision::roi_align', roi_align, 10) register_custom_op_symbolic('torchvision::roi_align', roi_align, _onnx_opset_version)
register_custom_op_symbolic('torchvision::roi_pool', roi_pool, 10) register_custom_op_symbolic('torchvision::roi_pool', roi_pool, _onnx_opset_version)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment