Unverified Commit 3ac864dc authored by Negin Raoof's avatar Negin Raoof Committed by GitHub
Browse files

[ONNX] Fix model export for images w/ no detection (#2126)

* Fixing nms on boxes when no detection

* test

* Fix for scale_factor computation

* remove newline

* Fix for mask_rcnn dynamic axes

* Clean up

* Update transform.py

* Fix for torchscript

* Fix scripting errors

* Fix annotation

* Fix lint

* Fix annotation

* Fix for interpolate scripting

* Fix for scripting

* refactoring

* refactor the code

* Fix annotation

* Fixed annotations

* Added test for resize

* lint

* format

* bump ORT

* ort-nightly version

* Going to ort 1.1.0

* remove version

* install typing-extension

* Export model for images with no detection

* Upgrade ort nightly

* update ORT

* Update test_onnx.py

* updated tests

* Updated tests

* merge

* Update transforms.py

* Update cityscapes.py

* Update celeba.py

* Update caltech.py

* Update pkg_helpers.bash

* Clean up

* Clean up for dynamic split

* Remove extra casts

* flake8
parent 14af9de6
...@@ -29,7 +29,7 @@ before_install: ...@@ -29,7 +29,7 @@ before_install:
- | - |
if [[ $TRAVIS_PYTHON_VERSION == 3.6 ]]; then if [[ $TRAVIS_PYTHON_VERSION == 3.6 ]]; then
pip install -q --user typing-extensions==3.6.6 pip install -q --user typing-extensions==3.6.6
pip install -q --user -i https://test.pypi.org/simple/ ort-nightly==1.2.0.dev202004141 pip install -q --user -i https://test.pypi.org/simple/ ort-nightly==1.2.0.dev202005021
fi fi
- conda install av -c conda-forge - conda install av -c conda-forge
......
...@@ -346,11 +346,17 @@ class ONNXExporterTester(unittest.TestCase): ...@@ -346,11 +346,17 @@ class ONNXExporterTester(unittest.TestCase):
def test_faster_rcnn(self): def test_faster_rcnn(self):
images, test_images = self.get_test_images() images, test_images = self.get_test_images()
dummy_image = [torch.ones(3, 100, 100) * 0.3]
model = models.detection.faster_rcnn.fasterrcnn_resnet50_fpn(pretrained=True, min_size=200, max_size=300) model = models.detection.faster_rcnn.fasterrcnn_resnet50_fpn(pretrained=True, min_size=200, max_size=300)
model.eval() model.eval()
model(images) model(images)
self.run_model(model, [(images,), (test_images,)], input_names=["images_tensors"], # Test exported model on images of different size, or dummy input
self.run_model(model, [(images,), (test_images,), (dummy_image,)], input_names=["images_tensors"],
output_names=["outputs"],
dynamic_axes={"images_tensors": [0, 1, 2, 3], "outputs": [0, 1, 2, 3]},
tolerate_small_mismatch=True)
# Test exported model for an image with no detections on other images
self.run_model(model, [(dummy_image,), (images,)], input_names=["images_tensors"],
output_names=["outputs"], output_names=["outputs"],
dynamic_axes={"images_tensors": [0, 1, 2, 3], "outputs": [0, 1, 2, 3]}, dynamic_axes={"images_tensors": [0, 1, 2, 3], "outputs": [0, 1, 2, 3]},
tolerate_small_mismatch=True) tolerate_small_mismatch=True)
...@@ -391,16 +397,25 @@ class ONNXExporterTester(unittest.TestCase): ...@@ -391,16 +397,25 @@ class ONNXExporterTester(unittest.TestCase):
def test_mask_rcnn(self): def test_mask_rcnn(self):
images, test_images = self.get_test_images() images, test_images = self.get_test_images()
dummy_image = [torch.ones(3, 100, 320) * 0.3]
model = models.detection.mask_rcnn.maskrcnn_resnet50_fpn(pretrained=True, min_size=200, max_size=300) model = models.detection.mask_rcnn.maskrcnn_resnet50_fpn(pretrained=True, min_size=200, max_size=300)
model.eval() model.eval()
model(images) model(images)
self.run_model(model, [(images,), (test_images,)], # Test exported model on images of different size, or dummy input
self.run_model(model, [(images,), (test_images,), (dummy_image,)],
input_names=["images_tensors"], input_names=["images_tensors"],
output_names=["boxes", "labels", "scores"], output_names=["boxes", "labels", "scores"],
dynamic_axes={"images_tensors": [0, 1, 2, 3], "boxes": [0, 1], "labels": [0], dynamic_axes={"images_tensors": [0, 1, 2, 3], "boxes": [0, 1], "labels": [0],
"scores": [0], "masks": [0, 1, 2, 3]}, "scores": [0], "masks": [0, 1, 2, 3]},
tolerate_small_mismatch=True) tolerate_small_mismatch=True)
# TODO: enable this test once dynamic model export is fixed
# Test exported model for an image with no detections on other images
# self.run_model(model, [(images,),(test_images,)],
# input_names=["images_tensors"],
# output_names=["boxes", "labels", "scores"],
# dynamic_axes={"images_tensors": [0, 1, 2, 3], "boxes": [0, 1], "labels": [0],
# "scores": [0], "masks": [0, 1, 2, 3]},
# tolerate_small_mismatch=True)
# Verify that heatmaps_to_keypoints behaves the same in tracing. # Verify that heatmaps_to_keypoints behaves the same in tracing.
# This test also compares both heatmaps_to_keypoints and _onnx_heatmaps_to_keypoints # This test also compares both heatmaps_to_keypoints and _onnx_heatmaps_to_keypoints
...@@ -445,6 +460,10 @@ class ONNXExporterTester(unittest.TestCase): ...@@ -445,6 +460,10 @@ class ONNXExporterTester(unittest.TestCase):
return output[0]['boxes'], output[0]['labels'], output[0]['scores'], output[0]['keypoints'] return output[0]['boxes'], output[0]['labels'], output[0]['scores'], output[0]['keypoints']
images, test_images = self.get_test_images() images, test_images = self.get_test_images()
# TODO:
# Enable test for dummy_image (no detection) once the issue in
# _onnx_heatmaps_to_keypoints_loop for empty heatmaps is fixed
# dummy_images = [torch.ones(3, 100, 100) * 0.3]
model = KeyPointRCNN() model = KeyPointRCNN()
model.eval() model.eval()
model(images) model(images)
...@@ -453,6 +472,13 @@ class ONNXExporterTester(unittest.TestCase): ...@@ -453,6 +472,13 @@ class ONNXExporterTester(unittest.TestCase):
output_names=["outputs1", "outputs2", "outputs3", "outputs4"], output_names=["outputs1", "outputs2", "outputs3", "outputs4"],
dynamic_axes={"images_tensors": [0, 1, 2, 3]}, dynamic_axes={"images_tensors": [0, 1, 2, 3]},
tolerate_small_mismatch=True) tolerate_small_mismatch=True)
# TODO: enable this test once dynamic model export is fixed
# Test exported model for an image with no detections on other images
# self.run_model(model, [(dummy_images,), (test_images,)],
# input_names=["images_tensors"],
# output_names=["outputs1", "outputs2", "outputs3", "outputs4"],
# dynamic_axes={"images_tensors": [0, 1, 2, 3]},
# tolerate_small_mismatch=True)
if __name__ == '__main__': if __name__ == '__main__':
......
import warnings import warnings
from .extension import _HAS_OPS
from torchvision import models from torchvision import models
from torchvision import datasets from torchvision import datasets
from torchvision import ops from torchvision import ops
...@@ -7,7 +9,6 @@ from torchvision import transforms ...@@ -7,7 +9,6 @@ from torchvision import transforms
from torchvision import utils from torchvision import utils
from torchvision import io from torchvision import io
from .extension import _HAS_OPS
import torch import torch
try: try:
......
...@@ -75,19 +75,13 @@ def maskrcnn_inference(x, labels): ...@@ -75,19 +75,13 @@ def maskrcnn_inference(x, labels):
# select masks corresponding to the predicted classes # select masks corresponding to the predicted classes
num_masks = x.shape[0] num_masks = x.shape[0]
boxes_per_image = [len(l) for l in labels] boxes_per_image = [l.shape[0] for l in labels]
labels = torch.cat(labels) labels = torch.cat(labels)
index = torch.arange(num_masks, device=labels.device) index = torch.arange(num_masks, device=labels.device)
mask_prob = mask_prob[index, labels][:, None] mask_prob = mask_prob[index, labels][:, None]
mask_prob = mask_prob.split(boxes_per_image, dim=0)
if len(boxes_per_image) == 1: return mask_prob
# TODO : remove when dynamic split supported in ONNX
# and remove assignment to mask_prob_list, just assign to mask_prob
mask_prob_list = [mask_prob]
else:
mask_prob_list = mask_prob.split(boxes_per_image, dim=0)
return mask_prob_list
def project_masks_on_boxes(gt_masks, boxes, matched_idxs, M): def project_masks_on_boxes(gt_masks, boxes, matched_idxs, M):
...@@ -318,12 +312,6 @@ def keypointrcnn_inference(x, boxes): ...@@ -318,12 +312,6 @@ def keypointrcnn_inference(x, boxes):
kp_scores = [] kp_scores = []
boxes_per_image = [box.size(0) for box in boxes] boxes_per_image = [box.size(0) for box in boxes]
if len(boxes_per_image) == 1:
# TODO : remove when dynamic split supported in ONNX
kp_prob, scores = heatmaps_to_keypoints(x, boxes[0])
return [kp_prob], [scores]
x2 = x.split(boxes_per_image, dim=0) x2 = x.split(boxes_per_image, dim=0)
for xx, bb in zip(x2, boxes): for xx, bb in zip(x2, boxes):
......
...@@ -17,11 +17,9 @@ def _onnx_get_num_anchors_and_pre_nms_top_n(ob, orig_pre_nms_top_n): ...@@ -17,11 +17,9 @@ def _onnx_get_num_anchors_and_pre_nms_top_n(ob, orig_pre_nms_top_n):
# type: (Tensor, int) -> Tuple[int, int] # type: (Tensor, int) -> Tuple[int, int]
from torch.onnx import operators from torch.onnx import operators
num_anchors = operators.shape_as_tensor(ob)[1].unsqueeze(0) num_anchors = operators.shape_as_tensor(ob)[1].unsqueeze(0)
# TODO : remove cast to IntTensor/num_anchors.dtype when
# ONNX Runtime version is updated with ReduceMin int64 support
pre_nms_top_n = torch.min(torch.cat( pre_nms_top_n = torch.min(torch.cat(
(torch.tensor([orig_pre_nms_top_n], dtype=num_anchors.dtype), (torch.tensor([orig_pre_nms_top_n], dtype=num_anchors.dtype),
num_anchors), 0).to(torch.int32)).to(num_anchors.dtype) num_anchors), 0))
return num_anchors, pre_nms_top_n return num_anchors, pre_nms_top_n
......
...@@ -17,9 +17,7 @@ def _resize_image_and_masks_onnx(image, self_min_size, self_max_size, target): ...@@ -17,9 +17,7 @@ def _resize_image_and_masks_onnx(image, self_min_size, self_max_size, target):
im_shape = operators.shape_as_tensor(image)[-2:] im_shape = operators.shape_as_tensor(image)[-2:]
min_size = torch.min(im_shape).to(dtype=torch.float32) min_size = torch.min(im_shape).to(dtype=torch.float32)
max_size = torch.max(im_shape).to(dtype=torch.float32) max_size = torch.max(im_shape).to(dtype=torch.float32)
scale_factor = self_min_size / min_size scale_factor = torch.min(self_min_size / min_size, self_max_size / max_size)
if max_size * scale_factor > self_max_size:
scale_factor = self_max_size / max_size
image = torch.nn.functional.interpolate( image = torch.nn.functional.interpolate(
image[None], scale_factor=scale_factor, mode='bilinear', image[None], scale_factor=scale_factor, mode='bilinear',
......
...@@ -4,6 +4,7 @@ from torch import Tensor ...@@ -4,6 +4,7 @@ from torch import Tensor
import torchvision import torchvision
@torch.jit.script
def nms(boxes, scores, iou_threshold): def nms(boxes, scores, iou_threshold):
# type: (Tensor, Tensor, float) # type: (Tensor, Tensor, float)
""" """
...@@ -40,6 +41,7 @@ def nms(boxes, scores, iou_threshold): ...@@ -40,6 +41,7 @@ def nms(boxes, scores, iou_threshold):
return torch.ops.torchvision.nms(boxes, scores, iou_threshold) return torch.ops.torchvision.nms(boxes, scores, iou_threshold)
@torch.jit.script
def batched_nms(boxes, scores, idxs, iou_threshold): def batched_nms(boxes, scores, idxs, iou_threshold):
# type: (Tensor, Tensor, Tensor, float) # type: (Tensor, Tensor, Tensor, float)
""" """
...@@ -74,11 +76,12 @@ def batched_nms(boxes, scores, idxs, iou_threshold): ...@@ -74,11 +76,12 @@ def batched_nms(boxes, scores, idxs, iou_threshold):
# we add an offset to all the boxes. The offset is dependent # we add an offset to all the boxes. The offset is dependent
# only on the class idx, and is large enough so that boxes # only on the class idx, and is large enough so that boxes
# from different classes do not overlap # from different classes do not overlap
max_coordinate = boxes.max() else:
offsets = idxs.to(boxes) * (max_coordinate + 1) max_coordinate = boxes.max()
boxes_for_nms = boxes + offsets[:, None] offsets = idxs.to(boxes) * (max_coordinate + torch.tensor(1).to(boxes))
keep = nms(boxes_for_nms, scores, iou_threshold) boxes_for_nms = boxes + offsets[:, None]
return keep keep = nms(boxes_for_nms, scores, iou_threshold)
return keep
def remove_small_boxes(boxes, min_size): def remove_small_boxes(boxes, min_size):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment