Unverified Commit d6ee8757 authored by Negin Raoof's avatar Negin Raoof Committed by GitHub
Browse files

[ONNX] Fix for dynamic scale_factor export (#2087)

* Fixing nms on boxes when no detection

* test

* Fix for scale_factor computation

* remove newline

* Fix for mask_rcnn dynamic axes

* Clean up

* Update transform.py

* Fix for torchscript

* Fix scripting errors

* Fix annotation

* Fix lint

* Fix annotation

* Fix for interpolate scripting

* Fix for scripting

* refactoring

* refactor the code

* Fix annotation

* Fixed annotations

* Added test for resize

* lint

* format

* bump ORT

* ort-nightly version

* Going to ort 1.1.0

* remove version

* install typing-extension
parent 7b60f4db
......@@ -28,7 +28,8 @@ before_install:
- pip install typing
- |
if [[ $TRAVIS_PYTHON_VERSION == 3.6 ]]; then
pip install -q --user -i https://test.pypi.org/simple/ ort-nightly==1.0.0.dev1123
pip install -q --user typing-extensions==3.6.6
pip install -q --user -i https://test.pypi.org/simple/ ort-nightly==1.2.0.dev202004141
fi
- conda install av -c conda-forge
......
......@@ -129,6 +129,20 @@ class ONNXExporterTester(unittest.TestCase):
model = ops.RoIPool((pool_h, pool_w), 2)
self.run_model(model, [(x, rois)])
def test_resize_images(self):
class TransformModule(torch.nn.Module):
def __init__(self_module):
super(TransformModule, self_module).__init__()
self_module.transform = self._init_test_generalized_rcnn_transform()
def forward(self_module, images):
return self_module.transform.resize(images, None)[0]
input = torch.rand(3, 10, 20)
input_test = torch.rand(3, 100, 150)
self.run_model(TransformModule(), [(input,), (input_test,)],
input_names=["input1"], dynamic_axes={"input1": [0, 1, 2, 3]})
def test_transform_images(self):
class TransformModule(torch.nn.Module):
......@@ -321,10 +335,10 @@ class ONNXExporterTester(unittest.TestCase):
def get_test_images(self):
image_url = "http://farm3.staticflickr.com/2469/3915380994_2e611b1779_z.jpg"
image = self.get_image_from_url(url=image_url, size=(200, 300))
image = self.get_image_from_url(url=image_url, size=(100, 320))
image_url2 = "https://pytorch.org/tutorials/_static/img/tv_tutorial/tv_image05.png"
image2 = self.get_image_from_url(url=image_url2, size=(250, 200))
image2 = self.get_image_from_url(url=image_url2, size=(250, 380))
images = [image]
test_images = [image2]
......@@ -375,7 +389,6 @@ class ONNXExporterTester(unittest.TestCase):
assert torch.all(out2.eq(out_trace2))
@unittest.skip("Disable test until export of interpolate script module to ONNX is fixed")
def test_mask_rcnn(self):
images, test_images = self.get_test_images()
......@@ -384,8 +397,9 @@ class ONNXExporterTester(unittest.TestCase):
model(images)
self.run_model(model, [(images,), (test_images,)],
input_names=["images_tensors"],
output_names=["outputs"],
dynamic_axes={"images_tensors": [0, 1, 2, 3], "outputs": [0, 1, 2, 3]},
output_names=["boxes", "labels", "scores"],
dynamic_axes={"images_tensors": [0, 1, 2, 3], "boxes": [0, 1], "labels": [0],
"scores": [0], "masks": [0, 1, 2, 3]},
tolerate_small_mismatch=True)
# Verify that heatmaps_to_keypoints behaves the same in tracing.
......@@ -416,7 +430,6 @@ class ONNXExporterTester(unittest.TestCase):
assert torch.all(out2[0].eq(out_trace2[0]))
assert torch.all(out2[1].eq(out_trace2[1]))
@unittest.skip("Disable test until export of interpolate script module to ONNX is fixed")
def test_keypoint_rcnn(self):
class KeyPointRCNN(torch.nn.Module):
def __init__(self):
......
......@@ -10,6 +10,53 @@ from .image_list import ImageList
from .roi_heads import paste_masks_in_image
@torch.jit.unused
def _resize_image_and_masks_onnx(image, self_min_size, self_max_size, target):
# type: (Tensor, float, float, Optional[Dict[str, Tensor]]) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]
from torch.onnx import operators
im_shape = operators.shape_as_tensor(image)[-2:]
min_size = torch.min(im_shape).to(dtype=torch.float32)
max_size = torch.max(im_shape).to(dtype=torch.float32)
scale_factor = self_min_size / min_size
if max_size * scale_factor > self_max_size:
scale_factor = self_max_size / max_size
image = torch.nn.functional.interpolate(
image[None], scale_factor=scale_factor, mode='bilinear',
align_corners=False)[0]
if target is None:
return image, target
if "masks" in target:
mask = target["masks"]
mask = misc_nn_ops.interpolate(mask[None].float(), scale_factor=scale_factor)[0].byte()
target["masks"] = mask
return image, target
def _resize_image_and_masks(image, self_min_size, self_max_size, target):
# type: (Tensor, float, float, Optional[Dict[str, Tensor]]) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]
im_shape = torch.tensor(image.shape[-2:])
min_size = float(torch.min(im_shape))
max_size = float(torch.max(im_shape))
scale_factor = self_min_size / min_size
if max_size * scale_factor > self_max_size:
scale_factor = self_max_size / max_size
image = torch.nn.functional.interpolate(
image[None], scale_factor=scale_factor, mode='bilinear',
align_corners=False)[0]
if target is None:
return image, target
if "masks" in target:
mask = target["masks"]
mask = misc_nn_ops.interpolate(mask[None].float(), scale_factor=scale_factor)[0].byte()
target["masks"] = mask
return image, target
class GeneralizedRCNNTransform(nn.Module):
"""
Performs input / target transformation before feeding the data to a GeneralizedRCNN
......@@ -76,20 +123,15 @@ class GeneralizedRCNNTransform(nn.Module):
def resize(self, image, target):
# type: (Tensor, Optional[Dict[str, Tensor]])
h, w = image.shape[-2:]
im_shape = torch.tensor(image.shape[-2:])
min_size = float(torch.min(im_shape))
max_size = float(torch.max(im_shape))
if self.training:
size = float(self.torch_choice(self.min_size))
else:
# FIXME assume for now that testing uses the largest scale
size = float(self.min_size[-1])
scale_factor = size / min_size
if max_size * scale_factor > self.max_size:
scale_factor = self.max_size / max_size
image = torch.nn.functional.interpolate(
image[None], scale_factor=scale_factor, mode='bilinear',
align_corners=False)[0]
if torchvision._is_tracing():
image, target = _resize_image_and_masks_onnx(image, size, float(self.max_size), target)
else:
image, target = _resize_image_and_masks(image, size, float(self.max_size), target)
if target is None:
return image, target
......@@ -98,11 +140,6 @@ class GeneralizedRCNNTransform(nn.Module):
bbox = resize_boxes(bbox, (h, w), image.shape[-2:])
target["boxes"] = bbox
if "masks" in target:
mask = target["masks"]
mask = misc_nn_ops.interpolate(mask[None].float(), scale_factor=scale_factor)[0].byte()
target["masks"] = mask
if "keypoints" in target:
keypoints = target["keypoints"]
keypoints = resize_keypoints(keypoints, (h, w), image.shape[-2:])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment