Unverified Commit 986d2423 authored by Negin Raoof, committed by GitHub

ONNX export for variable input sizes (#1840)



* fixes and tests for variable input size

* transform test fix

* Fix comment

* Dynamic shape for keypoint_rcnn

* Update test_onnx.py

* Update rpn.py

* Fix for split on RPN

* Fixes for feedbacks

* flake8

* topk fix

* Fix build

* branch on tracing

* fix for scalar tensor

* Fixes for script type annotations

* Update rpn.py

* clean up

* clean up

* Update rpn.py

* Updated for feedback

* Fix for comments

* revert to use tensor

* Added test for box clip

* Fixes for feedback

* Fix for feedback

* ORT version revert

* Update ort

* Update .travis.yml

* Update test_onnx.py

* Update test_onnx.py

* Tensor sizes

* Fix for dynamic split

* Try disable tests

* pytest verbose

* revert one test

* enable tests

* Update .travis.yml

* Update .travis.yml

* Update .travis.yml

* Update test_onnx.py

* Update .travis.yml

* Passing device

* Fixes for test

* Fix for boxes datatype

* clean up
Co-authored-by: Francisco Massa <fvsmassa@gmail.com>
parent 504d20c6
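
For context, a minimal sketch (ours, not part of the diff) of the export pattern this change enables: export a detection model once with dynamic image axes, then run it on other input sizes. The name `images_tensors` and the axis list mirror the tests below; opset 11 is an assumption for this era of the exporter.

```python
import io
import torch
import torchvision

model = torchvision.models.detection.fasterrcnn_resnet50_fpn(
    pretrained=True, min_size=200, max_size=300)
model.eval()

images = [torch.rand(3, 200, 300)]  # one CHW image, sizes illustrative
onnx_io = io.BytesIO()
torch.onnx.export(
    model, (images,), onnx_io,
    do_constant_folding=True,
    opset_version=11,  # assumed; the tests use _onnx_opset_version
    input_names=["images_tensors"],
    output_names=["outputs"],
    # declare the image axes dynamic so other resolutions are accepted
    dynamic_axes={"images_tensors": [0, 1, 2, 3], "outputs": [0, 1, 2, 3]})
```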
test/test_onnx.py
@@ -28,14 +28,15 @@ class ONNXExporterTester(unittest.TestCase):
     def setUpClass(cls):
         torch.manual_seed(123)
 
-    def run_model(self, model, inputs_list, tolerate_small_mismatch=False, do_constant_folding=True):
+    def run_model(self, model, inputs_list, tolerate_small_mismatch=False, do_constant_folding=True, dynamic_axes=None,
+                  output_names=None, input_names=None):
         model.eval()
 
         onnx_io = io.BytesIO()
         # export to onnx with the first input
         torch.onnx.export(model, inputs_list[0], onnx_io,
-                          do_constant_folding=do_constant_folding, opset_version=_onnx_opset_version)
+                          do_constant_folding=do_constant_folding, opset_version=_onnx_opset_version,
+                          dynamic_axes=dynamic_axes, input_names=input_names, output_names=output_names)
 
         # validate the exported model with onnx runtime
         for test_inputs in inputs_list:
             with torch.no_grad():
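
The `dynamic_axes` argument threaded through `run_model` accepts either a list of axis indices (names are auto-generated) or an explicit index-to-name mapping; a hedged illustration:

```python
# Both forms are accepted by torch.onnx.export; the axis names in the
# second form are our own illustrative choices.
dynamic_axes = {"images_tensors": [0, 1, 2, 3]}
# or, equivalently, with explicit names:
dynamic_axes = {"images_tensors": {0: "batch", 1: "channels",
                                   2: "height", 3: "width"}}
```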
@@ -99,6 +100,21 @@ class ONNXExporterTester(unittest.TestCase):
         self.run_model(Module(), [(boxes, scores)])
 
+    def test_clip_boxes_to_image(self):
+        boxes = torch.randn(5, 4) * 500
+        boxes[:, 2:] += boxes[:, :2]
+        size = torch.randn(200, 300)
+        size_2 = torch.randn(300, 400)
+
+        class Module(torch.nn.Module):
+            def forward(self, boxes, size):
+                return ops.boxes.clip_boxes_to_image(boxes, size.shape)
+
+        self.run_model(Module(), [(boxes, size), (boxes, size_2)],
+                       input_names=["boxes", "size"],
+                       dynamic_axes={"size": [0, 1]})
+
     def test_roi_align(self):
         x = torch.rand(1, 1, 10, 10, dtype=torch.float32)
         single_roi = torch.tensor([[0, 0, 0, 4, 4]], dtype=torch.float32)
@@ -123,9 +139,9 @@ class ONNXExporterTester(unittest.TestCase):
             def forward(self_module, images):
                 return self_module.transform(images)[0].tensors
 
-        input = [torch.rand(3, 100, 200), torch.rand(3, 200, 200)]
-        input_test = [torch.rand(3, 100, 200), torch.rand(3, 200, 200)]
-        self.run_model(TransformModule(), [input, input_test])
+        input = torch.rand(3, 100, 200), torch.rand(3, 200, 200)
+        input_test = torch.rand(3, 100, 200), torch.rand(3, 200, 200)
+        self.run_model(TransformModule(), [(input,), (input_test,)])
 
     def _init_test_generalized_rcnn_transform(self):
         min_size = 100
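
A side note on the tuple wrapping above: `torch.onnx.export` (which `run_model` calls) treats its second argument as the tuple of positional inputs, so a model whose single argument is a sequence of images needs that sequence wrapped in a one-element tuple. A minimal sketch, with a stand-in module:

```python
import io
import torch

class SumImages(torch.nn.Module):
    # stand-in for a model taking one sequence-of-tensors argument
    def forward(self, images):
        return torch.stack([img.sum() for img in images])

image_list = [torch.rand(3, 4, 4), torch.rand(3, 4, 4)]
onnx_io = io.BytesIO()
# (image_list,) -> one positional argument; passing [image_list[0],
# image_list[1]] directly would be unpacked into two arguments
torch.onnx.export(SumImages(), (image_list,), onnx_io,
                  opset_version=11)  # opset assumed
```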
@@ -207,22 +223,28 @@ class ONNXExporterTester(unittest.TestCase):
     def test_rpn(self):
         class RPNModule(torch.nn.Module):
-            def __init__(self_module, images):
+            def __init__(self_module):
                 super(RPNModule, self_module).__init__()
                 self_module.rpn = self._init_test_rpn()
-                self_module.images = ImageList(images, [i.shape[-2:] for i in images])
 
-            def forward(self_module, features):
-                return self_module.rpn(self_module.images, features)
+            def forward(self_module, images, features):
+                images = ImageList(images, [i.shape[-2:] for i in images])
+                return self_module.rpn(images, features)
 
-        images = torch.rand(2, 3, 600, 600)
+        images = torch.rand(2, 3, 150, 150)
         features = self.get_features(images)
-        test_features = self.get_features(images)
+        images2 = torch.rand(2, 3, 80, 80)
+        test_features = self.get_features(images2)
 
-        model = RPNModule(images)
+        model = RPNModule()
         model.eval()
-        model(features)
-        self.run_model(model, [(features,), (test_features,)], tolerate_small_mismatch=True)
+        model(images, features)
+
+        self.run_model(model, [(images, features), (images2, test_features)], tolerate_small_mismatch=True,
+                       input_names=["input1", "input2", "input3", "input4", "input5", "input6"],
+                       dynamic_axes={"input1": [0, 1, 2, 3], "input2": [0, 1, 2, 3],
+                                     "input3": [0, 1, 2, 3], "input4": [0, 1, 2, 3],
+                                     "input5": [0, 1, 2, 3], "input6": [0, 1, 2, 3]})
 
     def test_multi_scale_roi_align(self):
@@ -251,50 +273,59 @@ class ONNXExporterTester(unittest.TestCase):
     def test_roi_heads(self):
         class RoiHeadsModule(torch.nn.Module):
-            def __init__(self_module, images):
+            def __init__(self_module):
                 super(RoiHeadsModule, self_module).__init__()
                 self_module.transform = self._init_test_generalized_rcnn_transform()
                 self_module.rpn = self._init_test_rpn()
                 self_module.roi_heads = self._init_test_roi_heads_faster_rcnn()
-                self_module.original_image_sizes = [img.shape[-2:] for img in images]
-                self_module.images = ImageList(images, [i.shape[-2:] for i in images])
 
-            def forward(self_module, features):
-                proposals, _ = self_module.rpn(self_module.images, features)
-                detections, _ = self_module.roi_heads(features, proposals, self_module.images.image_sizes)
+            def forward(self_module, images, features):
+                original_image_sizes = [img.shape[-2:] for img in images]
+                images = ImageList(images, [i.shape[-2:] for i in images])
+
+                proposals, _ = self_module.rpn(images, features)
+                detections, _ = self_module.roi_heads(features, proposals, images.image_sizes)
                 detections = self_module.transform.postprocess(detections,
-                                                               self_module.images.image_sizes,
-                                                               self_module.original_image_sizes)
+                                                               images.image_sizes,
+                                                               original_image_sizes)
                 return detections
 
-        images = torch.rand(2, 3, 600, 600)
+        images = torch.rand(2, 3, 100, 100)
         features = self.get_features(images)
-        test_features = self.get_features(images)
+        images2 = torch.rand(2, 3, 150, 150)
+        test_features = self.get_features(images2)
 
-        model = RoiHeadsModule(images)
+        model = RoiHeadsModule()
         model.eval()
-        model(features)
-        self.run_model(model, [(features,), (test_features,)])
+        model(images, features)
+
+        self.run_model(model, [(images, features), (images2, test_features)], tolerate_small_mismatch=True,
+                       input_names=["input1", "input2", "input3", "input4", "input5", "input6"],
+                       dynamic_axes={"input1": [0, 1, 2, 3], "input2": [0, 1, 2, 3], "input3": [0, 1, 2, 3],
+                                     "input4": [0, 1, 2, 3], "input5": [0, 1, 2, 3], "input6": [0, 1, 2, 3]})
 
-    def get_image_from_url(self, url):
+    def get_image_from_url(self, url, size=None):
         import requests
-        import numpy
         from PIL import Image
         from io import BytesIO
         from torchvision import transforms
 
         data = requests.get(url)
         image = Image.open(BytesIO(data.content)).convert("RGB")
-        image = image.resize((300, 200), Image.BILINEAR)
+
+        if size is None:
+            size = (300, 200)
+        image = image.resize(size, Image.BILINEAR)
 
         to_tensor = transforms.ToTensor()
         return to_tensor(image)
 
     def get_test_images(self):
         image_url = "http://farm3.staticflickr.com/2469/3915380994_2e611b1779_z.jpg"
-        image = self.get_image_from_url(url=image_url)
+        image = self.get_image_from_url(url=image_url, size=(200, 300))
+
         image_url2 = "https://pytorch.org/tutorials/_static/img/tv_tutorial/tv_image05.png"
-        image2 = self.get_image_from_url(url=image_url2)
+        image2 = self.get_image_from_url(url=image_url2, size=(250, 200))
+
         images = [image]
         test_images = [image2]
         return images, test_images
@@ -302,12 +333,13 @@ class ONNXExporterTester(unittest.TestCase):
     def test_faster_rcnn(self):
         images, test_images = self.get_test_images()
 
-        model = models.detection.faster_rcnn.fasterrcnn_resnet50_fpn(pretrained=True,
-                                                                     min_size=200,
-                                                                     max_size=300)
+        model = models.detection.faster_rcnn.fasterrcnn_resnet50_fpn(pretrained=True, min_size=200, max_size=300)
         model.eval()
         model(images)
-        self.run_model(model, [(images,), (test_images,)])
+        self.run_model(model, [(images,), (test_images,)], input_names=["images_tensors"],
+                       output_names=["outputs"],
+                       dynamic_axes={"images_tensors": [0, 1, 2, 3], "outputs": [0, 1, 2, 3]},
+                       tolerate_small_mismatch=True)
 
     # Verify that paste_mask_in_image behaves the same in tracing.
     # This test also compares both paste_masks_in_image and _onnx_paste_masks_in_image
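
After export, `run_model` feeds each test input through ONNX Runtime and compares against eager PyTorch; a condensed sketch of that check (variable names such as `onnx_io` are carried over from the helper above, not from this hunk):

```python
import onnxruntime

# onnx_io holds the serialized model produced by torch.onnx.export
ort_session = onnxruntime.InferenceSession(onnx_io.getvalue())

# feed a *different* image size than the one used for tracing; the
# dynamic axes declared at export time make this legal
ort_inputs = {"images_tensors": test_images[0].numpy()}
ort_outs = ort_session.run(None, ort_inputs)
```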
@@ -350,7 +382,11 @@ class ONNXExporterTester(unittest.TestCase):
         model = models.detection.mask_rcnn.maskrcnn_resnet50_fpn(pretrained=True, min_size=200, max_size=300)
         model.eval()
         model(images)
-        self.run_model(model, [(images,), (test_images,)])
+        self.run_model(model, [(images,), (test_images,)],
+                       input_names=["images_tensors"],
+                       output_names=["outputs"],
+                       dynamic_axes={"images_tensors": [0, 1, 2, 3], "outputs": [0, 1, 2, 3]},
+                       tolerate_small_mismatch=True)
 
     # Verify that heatmaps_to_keypoints behaves the same in tracing.
     # This test also compares both heatmaps_to_keypoints and _onnx_heatmaps_to_keypoints
@@ -385,9 +421,7 @@ class ONNXExporterTester(unittest.TestCase):
         class KeyPointRCNN(torch.nn.Module):
             def __init__(self):
                 super(KeyPointRCNN, self).__init__()
-                self.model = models.detection.keypoint_rcnn.keypointrcnn_resnet50_fpn(pretrained=True,
-                                                                                      min_size=200,
-                                                                                      max_size=300)
+                self.model = models.detection.keypoint_rcnn.keypointrcnn_resnet50_fpn(pretrained=True, min_size=200, max_size=300)
 
             def forward(self, images):
                 output = self.model(images)
...@@ -399,8 +433,12 @@ class ONNXExporterTester(unittest.TestCase): ...@@ -399,8 +433,12 @@ class ONNXExporterTester(unittest.TestCase):
images, test_images = self.get_test_images() images, test_images = self.get_test_images()
model = KeyPointRCNN() model = KeyPointRCNN()
model.eval() model.eval()
model(test_images) model(images)
self.run_model(model, [(images,), (test_images,)]) self.run_model(model, [(images,), (test_images,)],
input_names=["images_tensors"],
output_names=["outputs1", "outputs2", "outputs3", "outputs4"],
dynamic_axes={"images_tensors": [0, 1, 2, 3]},
tolerate_small_mismatch=True)
if __name__ == '__main__': if __name__ == '__main__':
......
torchvision/models/detection/roi_heads.py
@@ -678,20 +678,13 @@ class RoIHeads(torch.nn.Module):
         device = class_logits.device
         num_classes = class_logits.shape[-1]
 
-        boxes_per_image = [len(boxes_in_image) for boxes_in_image in proposals]
+        boxes_per_image = [boxes_in_image.shape[0] for boxes_in_image in proposals]
         pred_boxes = self.box_coder.decode(box_regression, proposals)
 
         pred_scores = F.softmax(class_logits, -1)
 
-        # split boxes and scores per image
-        if len(boxes_per_image) == 1:
-            # TODO : remove this when ONNX support dynamic split sizes
-            # and just assign to pred_boxes instead of pred_boxes_list
-            pred_boxes_list = [pred_boxes]
-            pred_scores_list = [pred_scores]
-        else:
-            pred_boxes_list = pred_boxes.split(boxes_per_image, 0)
-            pred_scores_list = pred_scores.split(boxes_per_image, 0)
+        pred_boxes_list = pred_boxes.split(boxes_per_image, 0)
+        pred_scores_list = pred_scores.split(boxes_per_image, 0)
 
         all_boxes = []
         all_scores = []
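
The removed branch was flagged by its own TODO: the exporter used to choke on dynamic split sizes, so single-image batches bypassed `split` entirely. With the fixes in this PR, the traced `split` path exports directly. A minimal illustration of the pattern, with assumed sizes:

```python
import torch

pred_boxes = torch.rand(5, 4)      # 5 proposals across the whole batch
boxes_per_image = [2, 3]           # proposals contributed by each image
pred_boxes_list = pred_boxes.split(boxes_per_image, 0)
assert [b.shape[0] for b in pred_boxes_list] == boxes_per_image
```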
torchvision/models/detection/rpn.py
@@ -114,7 +114,7 @@ class AnchorGenerator(nn.Module):
     # For every combination of (a, (g, s), i) in (self.cell_anchors, zip(grid_sizes, strides), 0:2),
     # output g[i] anchors that are s[i] distance apart in direction i, with the same dimensions as a.
     def grid_anchors(self, grid_sizes, strides):
-        # type: (List[List[int]], List[List[int]])
+        # type: (List[List[int]], List[List[Tensor]])
         anchors = []
         cell_anchors = self.cell_anchors
         assert cell_anchors is not None
@@ -124,10 +124,6 @@ class AnchorGenerator(nn.Module):
         ):
             grid_height, grid_width = size
             stride_height, stride_width = stride
-            if torchvision._is_tracing():
-                # required in ONNX export for mult operation with float32
-                stride_width = torch.tensor(stride_width, dtype=torch.float32)
-                stride_height = torch.tensor(stride_height, dtype=torch.float32)
             device = base_anchors.device
 
             # For output anchor, compute [x_center, y_center, x_center, y_center]
@@ -151,8 +147,8 @@ class AnchorGenerator(nn.Module):
         return anchors
 
     def cached_grid_anchors(self, grid_sizes, strides):
-        # type: (List[List[int]], List[List[int]])
-        key = str(grid_sizes + strides)
+        # type: (List[List[int]], List[List[Tensor]])
+        key = str(grid_sizes) + str(strides)
         if key in self._cache:
             return self._cache[key]
         anchors = self.grid_anchors(grid_sizes, strides)
@@ -163,9 +159,9 @@ class AnchorGenerator(nn.Module):
         # type: (ImageList, List[Tensor])
         grid_sizes = list([feature_map.shape[-2:] for feature_map in feature_maps])
         image_size = image_list.tensors.shape[-2:]
-        strides = [[int(image_size[0] / g[0]), int(image_size[1] / g[1])] for g in grid_sizes]
         dtype, device = feature_maps[0].dtype, feature_maps[0].device
+        strides = [[torch.tensor(image_size[0] / g[0], dtype=torch.int64, device=device),
+                    torch.tensor(image_size[1] / g[1], dtype=torch.int64, device=device)] for g in grid_sizes]
         self.set_cell_anchors(dtype, device)
         anchors_over_all_feature_maps = self.cached_grid_anchors(grid_sizes, strides)
         anchors = torch.jit.annotate(List[List[torch.Tensor]], [])
@@ -480,7 +476,8 @@ class RegionProposalNetwork(torch.nn.Module):
         anchors = self.anchor_generator(images, features)
 
         num_images = len(anchors)
-        num_anchors_per_level = [o[0].numel() for o in objectness]
+        num_anchors_per_level_shape_tensors = [o[0].shape for o in objectness]
+        num_anchors_per_level = [s[0] * s[1] * s[2] for s in num_anchors_per_level_shape_tensors]
         objectness, pred_bbox_deltas = \
             concat_box_prediction_layers(objectness, pred_bbox_deltas)
         # apply pred_bbox_deltas to anchors to obtain the decoded proposals
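
A common thread in these rpn.py changes is avoiding Python-number escapes during tracing: converting a traced value through `int()` (and, per the `numel` change, reductions that return plain ints) hard-codes the tracing-time result into the exported graph, while keeping the arithmetic on tensors or per-dimension `shape` accesses lets the graph follow the runtime input size. A hedged sketch of the contrast, with made-up helper names:

```python
import torch

def stride_baked(image, grid):
    # int() yields a plain Python number; the tracer warns here and the
    # exported graph freezes this stride at its tracing-time value
    return int(image.shape[-1] / grid.shape[-1])

def stride_traced(image, grid):
    # mirrors the new rpn.py code: the stride stays a tensor in the graph
    return torch.tensor(image.shape[-1] / grid.shape[-1],
                        dtype=torch.int64, device=image.device)
```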
torchvision/models/detection/transform.py
@@ -88,7 +88,8 @@ class GeneralizedRCNNTransform(nn.Module):
         if max_size * scale_factor > self.max_size:
             scale_factor = self.max_size / max_size
         image = torch.nn.functional.interpolate(
-            image[None], scale_factor=scale_factor, mode='bilinear', align_corners=False)[0]
+            image[None], scale_factor=scale_factor, mode='bilinear',
+            align_corners=False)[0]
 
         if target is None:
             return image, target
@@ -191,7 +192,8 @@ class GeneralizedRCNNTransform(nn.Module):
 def resize_keypoints(keypoints, original_size, new_size):
     # type: (Tensor, List[int], List[int])
-    ratios = [float(s) / float(s_orig) for s, s_orig in zip(new_size, original_size)]
+    ratios = [torch.tensor(s, dtype=torch.float32, device=keypoints.device) / torch.tensor(s_orig, dtype=torch.float32, device=keypoints.device)
+              for s, s_orig in zip(new_size, original_size)]
     ratio_h, ratio_w = ratios
     resized_data = keypoints.clone()
     if torch._C._get_tracing_state():
@@ -206,7 +208,8 @@ def resize_keypoints(keypoints, original_size, new_size):
 def resize_boxes(boxes, original_size, new_size):
     # type: (Tensor, List[int], List[int])
-    ratios = [float(s) / float(s_orig) for s, s_orig in zip(new_size, original_size)]
+    ratios = [torch.tensor(s, dtype=torch.float32, device=boxes.device) / torch.tensor(s_orig, dtype=torch.float32, device=boxes.device)
+              for s, s_orig in zip(new_size, original_size)]
     ratio_height, ratio_width = ratios
     xmin, ymin, xmax, ymax = boxes.unbind(1)
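
`resize_keypoints` and `resize_boxes` now build the ratios as float32 tensors on the input's own device (see the "Passing device" commit above) instead of Python floats, so the scaling stays inside the traced graph and remains consistent when tracing on GPU. A worked example with assumed sizes:

```python
import torch

boxes = torch.tensor([[10., 20., 50., 80.]])
original_size, new_size = [200, 300], [100, 150]
ratios = [torch.tensor(s, dtype=torch.float32, device=boxes.device) /
          torch.tensor(s_orig, dtype=torch.float32, device=boxes.device)
          for s, s_orig in zip(new_size, original_size)]
ratio_height, ratio_width = ratios  # both 0.5 for these sizes
```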
torchvision/ops/boxes.py
 import torch
 from torch.jit.annotations import Tuple
 from torch import Tensor
+import torchvision
 
 
 def nms(boxes, scores, iou_threshold):
@@ -110,8 +111,16 @@ def clip_boxes_to_image(boxes, size):
     boxes_x = boxes[..., 0::2]
     boxes_y = boxes[..., 1::2]
     height, width = size
-    boxes_x = boxes_x.clamp(min=0, max=width)
-    boxes_y = boxes_y.clamp(min=0, max=height)
+
+    if torchvision._is_tracing():
+        boxes_x = torch.max(boxes_x, torch.tensor(0, dtype=boxes.dtype, device=boxes.device))
+        boxes_x = torch.min(boxes_x, torch.tensor(width, dtype=boxes.dtype, device=boxes.device))
+        boxes_y = torch.max(boxes_y, torch.tensor(0, dtype=boxes.dtype, device=boxes.device))
+        boxes_y = torch.min(boxes_y, torch.tensor(height, dtype=boxes.dtype, device=boxes.device))
+    else:
+        boxes_x = boxes_x.clamp(min=0, max=width)
+        boxes_y = boxes_y.clamp(min=0, max=height)
 
     clipped_boxes = torch.stack((boxes_x, boxes_y), dim=dim)
     return clipped_boxes.reshape(boxes.shape)
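
A quick usage check of the updated op: in eager mode the cheaper `clamp` branch runs, while under tracing the tensorized min/max branch keeps the clip bounds tied to graph values rather than baked-in constants.

```python
import torch
from torchvision.ops.boxes import clip_boxes_to_image

boxes = torch.tensor([[-10., -10., 400., 250.]])
# size is (height, width); x is clipped to [0, 300], y to [0, 200]
clipped = clip_boxes_to_image(boxes, (200, 300))
print(clipped)  # tensor([[  0.,   0., 300., 200.]])
```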