Unverified Commit d6ee8757 authored by Negin Raoof's avatar Negin Raoof Committed by GitHub
Browse files

[ONNX] Fix for dynamic scale_factor export (#2087)

* Fixing nms on boxes when no detection

* test

* Fix for scale_factor computation

* remove newline

* Fix for mask_rcnn dynamic axes

* Clean up

* Update transform.py

* Fix for torchscript

* Fix scripting errors

* Fix annotation

* Fix lint

* Fix annotation

* Fix for interpolate scripting

* Fix for scripting

* refactoring

* refactor the code

* Fix annotation

* Fixed annotations

* Added test for resize

* lint

* format

* bump ORT

* ort-nightly version

* Going to ort 1.1.0

* remove version

* install typing-extension
parent 7b60f4db
...@@ -28,7 +28,8 @@ before_install: ...@@ -28,7 +28,8 @@ before_install:
- pip install typing - pip install typing
- | - |
if [[ $TRAVIS_PYTHON_VERSION == 3.6 ]]; then if [[ $TRAVIS_PYTHON_VERSION == 3.6 ]]; then
pip install -q --user -i https://test.pypi.org/simple/ ort-nightly==1.0.0.dev1123 pip install -q --user typing-extensions==3.6.6
pip install -q --user -i https://test.pypi.org/simple/ ort-nightly==1.2.0.dev202004141
fi fi
- conda install av -c conda-forge - conda install av -c conda-forge
......
...@@ -129,6 +129,20 @@ class ONNXExporterTester(unittest.TestCase): ...@@ -129,6 +129,20 @@ class ONNXExporterTester(unittest.TestCase):
model = ops.RoIPool((pool_h, pool_w), 2) model = ops.RoIPool((pool_h, pool_w), 2)
self.run_model(model, [(x, rois)]) self.run_model(model, [(x, rois)])
def test_resize_images(self):
    """Export GeneralizedRCNNTransform.resize to ONNX and check it on two
    differently sized inputs, exercising the dynamic-shape path."""

    outer = self  # captured so the inner module can reach the test fixture

    class TransformModule(torch.nn.Module):
        def __init__(self_module):
            super(TransformModule, self_module).__init__()
            self_module.transform = outer._init_test_generalized_rcnn_transform()

        def forward(self_module, images):
            # resize() returns (image, target); only the image is compared.
            return self_module.transform.resize(images, None)[0]

    small_image = torch.rand(3, 10, 20)
    large_image = torch.rand(3, 100, 150)
    # Two inputs of different spatial size verify the exported graph keeps
    # every image axis dynamic rather than baking in the traced shape.
    self.run_model(TransformModule(), [(small_image,), (large_image,)],
                   input_names=["input1"],
                   dynamic_axes={"input1": [0, 1, 2, 3]})
def test_transform_images(self): def test_transform_images(self):
class TransformModule(torch.nn.Module): class TransformModule(torch.nn.Module):
...@@ -321,10 +335,10 @@ class ONNXExporterTester(unittest.TestCase): ...@@ -321,10 +335,10 @@ class ONNXExporterTester(unittest.TestCase):
def get_test_images(self): def get_test_images(self):
image_url = "http://farm3.staticflickr.com/2469/3915380994_2e611b1779_z.jpg" image_url = "http://farm3.staticflickr.com/2469/3915380994_2e611b1779_z.jpg"
image = self.get_image_from_url(url=image_url, size=(200, 300)) image = self.get_image_from_url(url=image_url, size=(100, 320))
image_url2 = "https://pytorch.org/tutorials/_static/img/tv_tutorial/tv_image05.png" image_url2 = "https://pytorch.org/tutorials/_static/img/tv_tutorial/tv_image05.png"
image2 = self.get_image_from_url(url=image_url2, size=(250, 200)) image2 = self.get_image_from_url(url=image_url2, size=(250, 380))
images = [image] images = [image]
test_images = [image2] test_images = [image2]
...@@ -375,7 +389,6 @@ class ONNXExporterTester(unittest.TestCase): ...@@ -375,7 +389,6 @@ class ONNXExporterTester(unittest.TestCase):
assert torch.all(out2.eq(out_trace2)) assert torch.all(out2.eq(out_trace2))
@unittest.skip("Disable test until export of interpolate script module to ONNX is fixed")
def test_mask_rcnn(self): def test_mask_rcnn(self):
images, test_images = self.get_test_images() images, test_images = self.get_test_images()
...@@ -384,8 +397,9 @@ class ONNXExporterTester(unittest.TestCase): ...@@ -384,8 +397,9 @@ class ONNXExporterTester(unittest.TestCase):
model(images) model(images)
self.run_model(model, [(images,), (test_images,)], self.run_model(model, [(images,), (test_images,)],
input_names=["images_tensors"], input_names=["images_tensors"],
output_names=["outputs"], output_names=["boxes", "labels", "scores"],
dynamic_axes={"images_tensors": [0, 1, 2, 3], "outputs": [0, 1, 2, 3]}, dynamic_axes={"images_tensors": [0, 1, 2, 3], "boxes": [0, 1], "labels": [0],
"scores": [0], "masks": [0, 1, 2, 3]},
tolerate_small_mismatch=True) tolerate_small_mismatch=True)
# Verify that heatmaps_to_keypoints behaves the same in tracing. # Verify that heatmaps_to_keypoints behaves the same in tracing.
...@@ -416,7 +430,6 @@ class ONNXExporterTester(unittest.TestCase): ...@@ -416,7 +430,6 @@ class ONNXExporterTester(unittest.TestCase):
assert torch.all(out2[0].eq(out_trace2[0])) assert torch.all(out2[0].eq(out_trace2[0]))
assert torch.all(out2[1].eq(out_trace2[1])) assert torch.all(out2[1].eq(out_trace2[1]))
@unittest.skip("Disable test until export of interpolate script module to ONNX is fixed")
def test_keypoint_rcnn(self): def test_keypoint_rcnn(self):
class KeyPointRCNN(torch.nn.Module): class KeyPointRCNN(torch.nn.Module):
def __init__(self): def __init__(self):
......
...@@ -10,6 +10,53 @@ from .image_list import ImageList ...@@ -10,6 +10,53 @@ from .image_list import ImageList
from .roi_heads import paste_masks_in_image from .roi_heads import paste_masks_in_image
@torch.jit.unused
def _resize_image_and_masks_onnx(image, self_min_size, self_max_size, target):
    # type: (Tensor, float, float, Optional[Dict[str, Tensor]]) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]
    """ONNX-export variant of the resize: keeps the image-size arithmetic as
    tensor ops so the exported graph handles dynamic input shapes.

    Args:
        image: CHW image tensor to resize.
        self_min_size: target size for the shorter image side.
        self_max_size: upper bound for the longer image side.
        target: optional annotation dict; an entry under "masks" is resized
            with the same scale factor.

    Returns:
        The resized image and the (possibly updated) target.
    """
    from torch.onnx import operators
    # shape_as_tensor keeps the H/W values symbolic in the traced graph
    # instead of freezing the concrete shape of the example input.
    im_shape = operators.shape_as_tensor(image)[-2:]
    min_size = torch.min(im_shape).to(dtype=torch.float32)
    max_size = torch.max(im_shape).to(dtype=torch.float32)
    # Pick the smaller of the two candidate ratios with torch.min rather than
    # a Python `if` on traced tensors: a Python branch is resolved eagerly at
    # trace time, which would burn the example image's min/max decision into
    # the ONNX graph and break other input sizes.
    scale_factor = torch.min(self_min_size / min_size, self_max_size / max_size)
    image = torch.nn.functional.interpolate(
        image[None], scale_factor=scale_factor, mode='bilinear',
        align_corners=False)[0]

    if target is None:
        return image, target

    if "masks" in target:
        mask = target["masks"]
        # Masks get nearest-style resizing via misc_nn_ops with the same scale
        # so boxes/masks stay aligned with the resized image.
        mask = misc_nn_ops.interpolate(mask[None].float(), scale_factor=scale_factor)[0].byte()
        target["masks"] = mask
    return image, target
def _resize_image_and_masks(image, self_min_size, self_max_size, target):
# type: (Tensor, float, float, Optional[Dict[str, Tensor]]) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]
im_shape = torch.tensor(image.shape[-2:])
min_size = float(torch.min(im_shape))
max_size = float(torch.max(im_shape))
scale_factor = self_min_size / min_size
if max_size * scale_factor > self_max_size:
scale_factor = self_max_size / max_size
image = torch.nn.functional.interpolate(
image[None], scale_factor=scale_factor, mode='bilinear',
align_corners=False)[0]
if target is None:
return image, target
if "masks" in target:
mask = target["masks"]
mask = misc_nn_ops.interpolate(mask[None].float(), scale_factor=scale_factor)[0].byte()
target["masks"] = mask
return image, target
class GeneralizedRCNNTransform(nn.Module): class GeneralizedRCNNTransform(nn.Module):
""" """
Performs input / target transformation before feeding the data to a GeneralizedRCNN Performs input / target transformation before feeding the data to a GeneralizedRCNN
...@@ -76,20 +123,15 @@ class GeneralizedRCNNTransform(nn.Module): ...@@ -76,20 +123,15 @@ class GeneralizedRCNNTransform(nn.Module):
def resize(self, image, target): def resize(self, image, target):
# type: (Tensor, Optional[Dict[str, Tensor]]) # type: (Tensor, Optional[Dict[str, Tensor]])
h, w = image.shape[-2:] h, w = image.shape[-2:]
im_shape = torch.tensor(image.shape[-2:])
min_size = float(torch.min(im_shape))
max_size = float(torch.max(im_shape))
if self.training: if self.training:
size = float(self.torch_choice(self.min_size)) size = float(self.torch_choice(self.min_size))
else: else:
# FIXME assume for now that testing uses the largest scale # FIXME assume for now that testing uses the largest scale
size = float(self.min_size[-1]) size = float(self.min_size[-1])
scale_factor = size / min_size if torchvision._is_tracing():
if max_size * scale_factor > self.max_size: image, target = _resize_image_and_masks_onnx(image, size, float(self.max_size), target)
scale_factor = self.max_size / max_size else:
image = torch.nn.functional.interpolate( image, target = _resize_image_and_masks(image, size, float(self.max_size), target)
image[None], scale_factor=scale_factor, mode='bilinear',
align_corners=False)[0]
if target is None: if target is None:
return image, target return image, target
...@@ -98,11 +140,6 @@ class GeneralizedRCNNTransform(nn.Module): ...@@ -98,11 +140,6 @@ class GeneralizedRCNNTransform(nn.Module):
bbox = resize_boxes(bbox, (h, w), image.shape[-2:]) bbox = resize_boxes(bbox, (h, w), image.shape[-2:])
target["boxes"] = bbox target["boxes"] = bbox
if "masks" in target:
mask = target["masks"]
mask = misc_nn_ops.interpolate(mask[None].float(), scale_factor=scale_factor)[0].byte()
target["masks"] = mask
if "keypoints" in target: if "keypoints" in target:
keypoints = target["keypoints"] keypoints = target["keypoints"]
keypoints = resize_keypoints(keypoints, (h, w), image.shape[-2:]) keypoints = resize_keypoints(keypoints, (h, w), image.shape[-2:])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment