Unverified Commit d993ce56 authored by Gary Miguel's avatar Gary Miguel Committed by GitHub
Browse files

Use test images from repo rather than internet for ONNX tests (#4176)

parent c332a7f5
...@@ -18,6 +18,7 @@ from torchvision.models.detection.roi_heads import RoIHeads ...@@ -18,6 +18,7 @@ from torchvision.models.detection.roi_heads import RoIHeads
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor, TwoMLPHead from torchvision.models.detection.faster_rcnn import FastRCNNPredictor, TwoMLPHead
from collections import OrderedDict from collections import OrderedDict
from typing import List, Tuple
import pytest import pytest
from torchvision.ops._register_onnx_ops import _onnx_opset_version from torchvision.ops._register_onnx_ops import _onnx_opset_version
...@@ -365,32 +366,20 @@ class TestONNXExporter: ...@@ -365,32 +366,20 @@ class TestONNXExporter:
dynamic_axes={"input1": [0, 1, 2, 3], "input2": [0, 1, 2, 3], "input3": [0, 1, 2, 3], dynamic_axes={"input1": [0, 1, 2, 3], "input2": [0, 1, 2, 3], "input3": [0, 1, 2, 3],
"input4": [0, 1, 2, 3], "input5": [0, 1, 2, 3], "input6": [0, 1, 2, 3]}) "input4": [0, 1, 2, 3], "input5": [0, 1, 2, 3], "input6": [0, 1, 2, 3]})
def get_image_from_url(self, url, size=None): def get_image(self, rel_path: str, size: Tuple[int, int]) -> torch.Tensor:
import requests import os
from PIL import Image from PIL import Image
from io import BytesIO
from torchvision import transforms from torchvision import transforms
data = requests.get(url) data_dir = os.path.join(os.path.dirname(__file__), "assets")
image = Image.open(BytesIO(data.content)).convert("RGB") path = os.path.join(data_dir, *rel_path.split("/"))
image = Image.open(path).convert("RGB").resize(size, Image.BILINEAR)
if size is None: return transforms.ToTensor()(image)
size = (300, 200)
image = image.resize(size, Image.BILINEAR)
to_tensor = transforms.ToTensor() def get_test_images(self) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
return to_tensor(image) return ([self.get_image("encode_jpeg/grace_hopper_517x606.jpg", (100, 320))],
[self.get_image("fakedata/logos/rgb_pytorch.png", (250, 380))])
def get_test_images(self):
image_url = "http://farm3.staticflickr.com/2469/3915380994_2e611b1779_z.jpg"
image = self.get_image_from_url(url=image_url, size=(100, 320))
image_url2 = "https://pytorch.org/tutorials/_static/img/tv_tutorial/tv_image05.png"
image2 = self.get_image_from_url(url=image_url2, size=(250, 380))
images = [image]
test_images = [image2]
return images, test_images
def test_faster_rcnn(self): def test_faster_rcnn(self):
images, test_images = self.get_test_images() images, test_images = self.get_test_images()
...@@ -456,7 +445,6 @@ class TestONNXExporter: ...@@ -456,7 +445,6 @@ class TestONNXExporter:
dynamic_axes={"images_tensors": [0, 1, 2], "boxes": [0, 1], "labels": [0], dynamic_axes={"images_tensors": [0, 1, 2], "boxes": [0, 1], "labels": [0],
"scores": [0], "masks": [0, 1, 2]}, "scores": [0], "masks": [0, 1, 2]},
tolerate_small_mismatch=True) tolerate_small_mismatch=True)
# TODO: enable this test once dynamic model export is fixed
# Test exported model for an image with no detections on other images # Test exported model for an image with no detections on other images
self.run_model(model, [(dummy_image,), (images,)], self.run_model(model, [(dummy_image,), (images,)],
input_names=["images_tensors"], input_names=["images_tensors"],
...@@ -468,7 +456,6 @@ class TestONNXExporter: ...@@ -468,7 +456,6 @@ class TestONNXExporter:
# Verify that heatmaps_to_keypoints behaves the same in tracing. # Verify that heatmaps_to_keypoints behaves the same in tracing.
# This test also compares both heatmaps_to_keypoints and _onnx_heatmaps_to_keypoints # This test also compares both heatmaps_to_keypoints and _onnx_heatmaps_to_keypoints
# (since jit_trace witll call _heatmaps_to_keypoints). # (since jit_trace witll call _heatmaps_to_keypoints).
# @unittest.skip("Disable test until Resize bug fixed in ORT")
def test_heatmaps_to_keypoints(self): def test_heatmaps_to_keypoints(self):
# disable profiling # disable profiling
torch._C._jit_set_profiling_executor(False) torch._C._jit_set_profiling_executor(False)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment