Unverified Commit 97e21c10 authored by Negin Raoof's avatar Negin Raoof Committed by GitHub
Browse files

[ONNX] Fix export of images with no detection (#2215)

* Fixing nms on boxes when no detection

* test

* Fix for scale_factor computation

* remove newline

* Fix for mask_rcnn dynamic axes

* Clean up

* Update transform.py

* Fix for torchscript

* Fix scripting errors

* Fix annotation

* Fix lint

* Fix annotation

* Fix for interpolate scripting

* Fix for scripting

* refactoring

* refactor the code

* Fix annotation

* Fixed annotations

* Added test for resize

* lint

* format

* bump ORT

* ort-nightly version

* Going to ort 1.1.0

* remove version

* install typing-extensions

* Export model for images with no detection

* Upgrade ort nightly

* update ORT

* Update test_onnx.py

* updated tests

* Updated tests

* merge

* Update transforms.py

* Update cityscapes.py

* Update celeba.py

* Update caltech.py

* Update pkg_helpers.bash

* Clean up

* Clean up for dynamic split

* Remove extra casts

* flake8

* Fix for mask rcnn no detection export

* clean up

* Enable mask rcnn tests

* Added test

* update ORT

* Update .travis.yml

* fix annotation

* Clean up roi_heads

* clean up

* clean up misc ops
parent ba63fbdb
...@@ -29,7 +29,7 @@ before_install: ...@@ -29,7 +29,7 @@ before_install:
- | - |
if [[ $TRAVIS_PYTHON_VERSION == 3.6 ]]; then if [[ $TRAVIS_PYTHON_VERSION == 3.6 ]]; then
pip install -q --user typing-extensions==3.6.6 pip install -q --user typing-extensions==3.6.6
pip install -q --user -i https://test.pypi.org/simple/ ort-nightly==1.2.0.dev202005021 pip install -q --user -i https://test.pypi.org/simple/ ort-nightly==1.3.0.dev202005123
fi fi
- conda install av -c conda-forge - conda install av -c conda-forge
......
...@@ -397,25 +397,25 @@ class ONNXExporterTester(unittest.TestCase): ...@@ -397,25 +397,25 @@ class ONNXExporterTester(unittest.TestCase):
def test_mask_rcnn(self): def test_mask_rcnn(self):
images, test_images = self.get_test_images() images, test_images = self.get_test_images()
dummy_image = [torch.ones(3, 100, 320) * 0.3] dummy_image = [torch.ones(3, 100, 100) * 0.3]
model = models.detection.mask_rcnn.maskrcnn_resnet50_fpn(pretrained=True, min_size=200, max_size=300) model = models.detection.mask_rcnn.maskrcnn_resnet50_fpn(pretrained=True, min_size=200, max_size=300)
model.eval() model.eval()
model(images) model(images)
# Test exported model on images of different size, or dummy input # Test exported model on images of different size, or dummy input
self.run_model(model, [(images,), (test_images,), (dummy_image,)], self.run_model(model, [(images,), (test_images,), (dummy_image,)],
input_names=["images_tensors"], input_names=["images_tensors"],
output_names=["boxes", "labels", "scores"], output_names=["boxes", "labels", "scores", "masks"],
dynamic_axes={"images_tensors": [0, 1, 2, 3], "boxes": [0, 1], "labels": [0], dynamic_axes={"images_tensors": [0, 1, 2, 3], "boxes": [0, 1], "labels": [0],
"scores": [0], "masks": [0, 1, 2, 3]}, "scores": [0], "masks": [0, 1, 2, 3]},
tolerate_small_mismatch=True) tolerate_small_mismatch=True)
# TODO: enable this test once dynamic model export is fixed # TODO: enable this test once dynamic model export is fixed
# Test exported model for an image with no detections on other images # Test exported model for an image with no detections on other images
# self.run_model(model, [(images,),(test_images,)], self.run_model(model, [(dummy_image,), (images,)],
# input_names=["images_tensors"], input_names=["images_tensors"],
# output_names=["boxes", "labels", "scores"], output_names=["boxes", "labels", "scores", "masks"],
# dynamic_axes={"images_tensors": [0, 1, 2, 3], "boxes": [0, 1], "labels": [0], dynamic_axes={"images_tensors": [0, 1, 2, 3], "boxes": [0, 1], "labels": [0],
# "scores": [0], "masks": [0, 1, 2, 3]}, "scores": [0], "masks": [0, 1, 2, 3]},
# tolerate_small_mismatch=True) tolerate_small_mismatch=True)
# Verify that heatmaps_to_keypoints behaves the same in tracing. # Verify that heatmaps_to_keypoints behaves the same in tracing.
# This test also compares both heatmaps_to_keypoints and _onnx_heatmaps_to_keypoints # This test also compares both heatmaps_to_keypoints and _onnx_heatmaps_to_keypoints
......
...@@ -222,12 +222,12 @@ class KeypointRCNNHeads(nn.Sequential): ...@@ -222,12 +222,12 @@ class KeypointRCNNHeads(nn.Sequential):
d = [] d = []
next_feature = in_channels next_feature = in_channels
for out_channels in layers: for out_channels in layers:
d.append(misc_nn_ops.Conv2d(next_feature, out_channels, 3, stride=1, padding=1)) d.append(nn.Conv2d(next_feature, out_channels, 3, stride=1, padding=1))
d.append(nn.ReLU(inplace=True)) d.append(nn.ReLU(inplace=True))
next_feature = out_channels next_feature = out_channels
super(KeypointRCNNHeads, self).__init__(*d) super(KeypointRCNNHeads, self).__init__(*d)
for m in self.children(): for m in self.children():
if isinstance(m, misc_nn_ops.Conv2d): if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
nn.init.constant_(m.bias, 0) nn.init.constant_(m.bias, 0)
...@@ -237,7 +237,7 @@ class KeypointRCNNPredictor(nn.Module): ...@@ -237,7 +237,7 @@ class KeypointRCNNPredictor(nn.Module):
super(KeypointRCNNPredictor, self).__init__() super(KeypointRCNNPredictor, self).__init__()
input_features = in_channels input_features = in_channels
deconv_kernel = 4 deconv_kernel = 4
self.kps_score_lowres = misc_nn_ops.ConvTranspose2d( self.kps_score_lowres = nn.ConvTranspose2d(
input_features, input_features,
num_keypoints, num_keypoints,
deconv_kernel, deconv_kernel,
......
...@@ -229,7 +229,7 @@ class MaskRCNNHeads(nn.Sequential): ...@@ -229,7 +229,7 @@ class MaskRCNNHeads(nn.Sequential):
d = OrderedDict() d = OrderedDict()
next_feature = in_channels next_feature = in_channels
for layer_idx, layer_features in enumerate(layers, 1): for layer_idx, layer_features in enumerate(layers, 1):
d["mask_fcn{}".format(layer_idx)] = misc_nn_ops.Conv2d( d["mask_fcn{}".format(layer_idx)] = nn.Conv2d(
next_feature, layer_features, kernel_size=3, next_feature, layer_features, kernel_size=3,
stride=1, padding=dilation, dilation=dilation) stride=1, padding=dilation, dilation=dilation)
d["relu{}".format(layer_idx)] = nn.ReLU(inplace=True) d["relu{}".format(layer_idx)] = nn.ReLU(inplace=True)
...@@ -246,9 +246,9 @@ class MaskRCNNHeads(nn.Sequential): ...@@ -246,9 +246,9 @@ class MaskRCNNHeads(nn.Sequential):
class MaskRCNNPredictor(nn.Sequential): class MaskRCNNPredictor(nn.Sequential):
def __init__(self, in_channels, dim_reduced, num_classes): def __init__(self, in_channels, dim_reduced, num_classes):
super(MaskRCNNPredictor, self).__init__(OrderedDict([ super(MaskRCNNPredictor, self).__init__(OrderedDict([
("conv5_mask", misc_nn_ops.ConvTranspose2d(in_channels, dim_reduced, 2, 2, 0)), ("conv5_mask", nn.ConvTranspose2d(in_channels, dim_reduced, 2, 2, 0)),
("relu", nn.ReLU(inplace=True)), ("relu", nn.ReLU(inplace=True)),
("mask_fcn_logits", misc_nn_ops.Conv2d(dim_reduced, num_classes, 1, 1, 0)), ("mask_fcn_logits", nn.Conv2d(dim_reduced, num_classes, 1, 1, 0)),
])) ]))
for name, param in self.named_parameters(): for name, param in self.named_parameters():
......
...@@ -73,7 +73,7 @@ def maskrcnn_inference(x, labels): ...@@ -73,7 +73,7 @@ def maskrcnn_inference(x, labels):
""" """
mask_prob = x.sigmoid() mask_prob = x.sigmoid()
# select masks coresponding to the predicted classes # select masks corresponding to the predicted classes
num_masks = x.shape[0] num_masks = x.shape[0]
boxes_per_image = [label.shape[0] for label in labels] boxes_per_image = [label.shape[0] for label in labels]
labels = torch.cat(labels) labels = torch.cat(labels)
......
...@@ -16,59 +16,6 @@ import math ...@@ -16,59 +16,6 @@ import math
import warnings import warnings
import torch import torch
from torchvision.ops import _new_empty_tensor from torchvision.ops import _new_empty_tensor
from torch.nn import Module, Conv2d
import torch.nn.functional as F
class ConvTranspose2d(torch.nn.ConvTranspose2d):
"""
Equivalent to nn.ConvTranspose2d, but with support for empty batch sizes.
This will eventually be supported natively by PyTorch, and this
class can go away.
"""
def forward(self, x):
if x.numel() > 0:
return self.super_forward(x)
# get output shape
output_shape = [
(i - 1) * d - 2 * p + (di * (k - 1) + 1) + op
for i, p, di, k, d, op in zip(
x.shape[-2:],
list(self.padding),
list(self.dilation),
list(self.kernel_size),
list(self.stride),
list(self.output_padding),
)
]
output_shape = [x.shape[0], self.out_channels] + output_shape
return _new_empty_tensor(x, output_shape)
def super_forward(self, input, output_size=None):
# type: (Tensor, Optional[List[int]]) -> Tensor
if self.padding_mode != 'zeros':
raise ValueError('Only `zeros` padding mode is supported for ConvTranspose2d')
output_padding = self._output_padding(input, output_size, self.stride, self.padding, self.kernel_size)
return F.conv_transpose2d(
input, self.weight, self.bias, self.stride, self.padding,
output_padding, self.groups, self.dilation)
class BatchNorm2d(torch.nn.BatchNorm2d):
"""
Equivalent to nn.BatchNorm2d, but with support for empty batch sizes.
This will eventually be supported natively by PyTorch, and this
class can go away.
"""
def forward(self, x):
if x.numel() > 0:
return super(BatchNorm2d, self).forward(x)
# get output shape
output_shape = x.shape
return _new_empty_tensor(x, output_shape)
def _check_size_scale_factor(dim, size, scale_factor): def _check_size_scale_factor(dim, size, scale_factor):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment