Commit d4392a85 authored by Lara Haidar, committed by Francisco Massa

Support Exporting GeneralizedRCNNTransform to ONNX (#1325)

* Support Exporting GeneralizedRCNNTransform

* refactor code to address comments

* update tests

* address comments

* revert min_size to test CI

* re-revert min_size
parent 3c1ab2c1
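
For context, once this change lands the transform can be driven through torch.onnx.export much as the new (currently skipped) test below does. A minimal sketch, mirroring that test rather than prescribing an API; the shapes, opset version, and normalization constants are copied from it, and the resulting model will not validate in ONNX Runtime until Resize opset 11 is supported there:

import io
import torch
from torchvision.models.detection.transform import GeneralizedRCNNTransform

class TransformModule(torch.nn.Module):
    def __init__(self):
        super(TransformModule, self).__init__()
        # min_size, max_size and the ImageNet mean/std, as in the test below
        self.transform = GeneralizedRCNNTransform(800, 1333,
                                                  [0.485, 0.456, 0.406],
                                                  [0.229, 0.224, 0.225])

    def forward(self, images):
        # return only the batched tensor; ImageList is not an ONNX type
        return self.transform(images)[0].tensors

images = [torch.rand(3, 800, 1280), torch.rand(3, 800, 800)]
f = io.BytesIO()
torch.onnx.export(TransformModule(), (images,), f,
                  do_constant_folding=True, opset_version=10)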
test/test_onnx.py

 import io
 import torch
 from torchvision import ops
+from torchvision.models.detection.transform import GeneralizedRCNNTransform
 
 # onnxruntime requires python 3.5 or above
 try:
@@ -17,23 +18,23 @@ class ONNXExporterTester(unittest.TestCase):
     def setUpClass(cls):
         torch.manual_seed(123)
 
-    def run_model(self, model, inputs):
+    def run_model(self, model, inputs_list):
         model.eval()
-        # run pytorch model
-        with torch.no_grad():
-            if isinstance(inputs, torch.Tensor):
-                inputs = (inputs,)
-            outputs = model(*inputs)
-            if isinstance(outputs, torch.Tensor):
-                outputs = (outputs,)
 
         onnx_io = io.BytesIO()
-        # export to onnx
-        torch.onnx.export(model, inputs, onnx_io, do_constant_folding=True, opset_version=10)
+        # export to onnx with the first input in inputs_list
+        torch.onnx.export(model, inputs_list[0], onnx_io, do_constant_folding=True, opset_version=10)
 
         # validate the exported model with onnx runtime
-        self.ort_validate(onnx_io, inputs, outputs)
+        for test_inputs in inputs_list:
+            with torch.no_grad():
+                if isinstance(test_inputs, (torch.Tensor, list)):
+                    test_inputs = (test_inputs,)
+                test_outputs = model(*test_inputs)
+                if isinstance(test_outputs, torch.Tensor):
+                    test_outputs = (test_outputs,)
+            self.ort_validate(onnx_io, test_inputs, test_outputs)
 
     def ort_validate(self, onnx_io, inputs, outputs):
@@ -66,13 +67,13 @@ class ONNXExporterTester(unittest.TestCase):
         def forward(self, boxes, scores):
             return ops.nms(boxes, scores, 0.5)
 
-        self.run_model(Module(), (boxes, scores))
+        self.run_model(Module(), [(boxes, scores)])
 
     def test_roi_align(self):
         x = torch.rand(1, 1, 10, 10, dtype=torch.float32)
         single_roi = torch.tensor([[0, 0, 0, 4, 4]], dtype=torch.float32)
         model = ops.RoIAlign((5, 5), 1, 2)
-        self.run_model(model, (x, single_roi))
+        self.run_model(model, [(x, single_roi)])
 
     def test_roi_pool(self):
         x = torch.rand(1, 1, 10, 10, dtype=torch.float32)
@@ -81,7 +82,26 @@ class ONNXExporterTester(unittest.TestCase):
         pool_w = 5
         model = ops.RoIPool((pool_h, pool_w), 2)
         model.eval()
-        self.run_model(model, (x, rois))
+        self.run_model(model, [(x, rois)])
+    @unittest.skip("Disable test until Resize opset 11 is implemented in ONNX Runtime")
+    def test_transform_images(self):
+
+        class TransformModule(torch.nn.Module):
+            def __init__(self_module):
+                super(TransformModule, self_module).__init__()
+                min_size = 800
+                max_size = 1333
+                image_mean = [0.485, 0.456, 0.406]
+                image_std = [0.229, 0.224, 0.225]
+                self_module.transform = GeneralizedRCNNTransform(min_size, max_size, image_mean, image_std)
+
+            def forward(self_module, images):
+                return self_module.transform(images)[0].tensors
+
+        input = [torch.rand(3, 800, 1280), torch.rand(3, 800, 800)]
+        input_test = [torch.rand(3, 800, 1280), torch.rand(3, 800, 800)]
+        self.run_model(TransformModule(), [input, input_test])
 if __name__ == '__main__':
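
The body of ort_validate falls outside the hunks shown above. For readers following along, a plausible minimal version (an assumption, not necessarily the helper's real body) serializes the model to ONNX Runtime and compares its outputs against the eager PyTorch outputs:

import numpy as np
import onnxruntime

def ort_validate(onnx_io, inputs, outputs):
    # run the exported graph in ONNX Runtime and compare with eager outputs
    ort_session = onnxruntime.InferenceSession(onnx_io.getvalue())
    to_numpy = lambda t: t.detach().cpu().numpy()
    ort_inputs = dict((ort_session.get_inputs()[i].name, to_numpy(inp))
                      for i, inp in enumerate(inputs))
    ort_outs = ort_session.run(None, ort_inputs)
    for i in range(len(outputs)):
        np.testing.assert_allclose(to_numpy(outputs[i]), ort_outs[i],
                                   rtol=1e-03, atol=1e-05)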
torchvision/__init__.py

@@ -34,3 +34,8 @@ def get_image_backend():
     Gets the name of the package used to load images
     """
     return _image_backend
+
+
+def _is_tracing():
+    import torch
+    return torch._C._get_tracing_state()
torchvision/models/detection/transform.py

@@ -2,6 +2,7 @@ import random
 import math
 import torch
 from torch import nn
+import torchvision
 
 from torchvision.ops import misc as misc_nn_ops
 from .image_list import ImageList
@@ -56,8 +57,9 @@ class GeneralizedRCNNTransform(nn.Module):
     def resize(self, image, target):
         h, w = image.shape[-2:]
-        min_size = float(min(image.shape[-2:]))
-        max_size = float(max(image.shape[-2:]))
+        im_shape = torch.tensor(image.shape[-2:])
+        min_size = float(torch.min(im_shape))
+        max_size = float(torch.max(im_shape))
         if self.training:
             size = random.choice(self.min_size)
         else:
@@ -87,10 +89,45 @@ class GeneralizedRCNNTransform(nn.Module):
             target["keypoints"] = keypoints
         return image, target
 
+    # _onnx_dynamic_img_pad() creates dynamic padding for an image,
+    # in a form supported by ONNX tracing.
+    # It is used to process the images in _onnx_batch_images().
+    def _onnx_dynamic_img_pad(self, img, padding):
+        # pad each dimension in turn by concatenating a zero block
+        concat_0 = torch.cat((img, torch.zeros(padding[0], img.shape[1], img.shape[2])), 0)
+        concat_1 = torch.cat((concat_0, torch.zeros(concat_0.shape[0], padding[1], concat_0.shape[2])), 1)
+        padded_img = torch.cat((concat_1, torch.zeros(concat_1.shape[0], concat_1.shape[1], padding[2])), 2)
+        return padded_img
+
+    # _onnx_batch_images() is an implementation of
+    # batch_images() that is supported by ONNX tracing.
+    def _onnx_batch_images(self, images, size_divisible=32):
+        max_size = []
+        for i in range(images[0].dim()):
+            max_size_i = torch.max(torch.stack([img.shape[i] for img in images]).to(torch.float32)).to(torch.int64)
+            max_size.append(max_size_i)
+        stride = size_divisible
+        max_size[1] = (torch.ceil((max_size[1].to(torch.float32)) / stride) * stride).to(torch.int64)
+        max_size[2] = (torch.ceil((max_size[2].to(torch.float32)) / stride) * stride).to(torch.int64)
+        max_size = tuple(max_size)
+
+        # work around for
+        # pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
+        # which is not yet supported in onnx
+        padded_imgs = []
+        for img in images:
+            padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))]
+            padded_img = self._onnx_dynamic_img_pad(img, padding)
+            padded_imgs.append(padded_img)
+
+        return torch.stack(padded_imgs)
+
     def batch_images(self, images, size_divisible=32):
-        # concatenate
-        max_size = tuple(max(s) for s in zip(*[img.shape for img in images]))
+        if torchvision._is_tracing():
+            # batch_images() does not export well to ONNX
+            # call _onnx_batch_images() instead
+            return self._onnx_batch_images(images, size_divisible)
+
+        max_size = tuple(max(s) for s in zip(*[img.shape for img in images]))
         stride = size_divisible
         max_size = list(max_size)
         max_size[1] = int(math.ceil(float(max_size[1]) / stride) * stride)
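
Why the concat workaround: pad_img[...].copy_(img) mutates a preallocated tensor in place, which the ONNX exporter could not express at the time, while building the padded image out of torch.cat with zero blocks keeps every step a plain graph op. A standalone sketch of the same trick, with a hypothetical helper name:

import torch

def pad_3d_with_cat(img, padding):
    # pad a (C, H, W) tensor along each dim by concatenating zero blocks,
    # so the output shape stays a traced computation rather than an in-place write
    img = torch.cat((img, torch.zeros(padding[0], img.shape[1], img.shape[2])), 0)
    img = torch.cat((img, torch.zeros(img.shape[0], padding[1], img.shape[2])), 1)
    img = torch.cat((img, torch.zeros(img.shape[0], img.shape[1], padding[2])), 2)
    return img

x = torch.rand(3, 5, 7)
y = pad_3d_with_cat(x, [0, 3, 1])
assert y.shape == (3, 8, 8)
assert torch.equal(y[:, :5, :7], x)  # original content preserved, zeros elsewhere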