"git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "355b2788252fd5a43eafde94967aa2c26598fa25"
Unverified Commit d993ce56 authored by Gary Miguel's avatar Gary Miguel Committed by GitHub
Browse files

Use test images from repo rather than internet for ONNX tests (#4176)

parent c332a7f5
...@@ -18,6 +18,7 @@ from torchvision.models.detection.roi_heads import RoIHeads ...@@ -18,6 +18,7 @@ from torchvision.models.detection.roi_heads import RoIHeads
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor, TwoMLPHead from torchvision.models.detection.faster_rcnn import FastRCNNPredictor, TwoMLPHead
from collections import OrderedDict from collections import OrderedDict
from typing import List, Tuple
import pytest import pytest
from torchvision.ops._register_onnx_ops import _onnx_opset_version from torchvision.ops._register_onnx_ops import _onnx_opset_version
...@@ -365,32 +366,20 @@ class TestONNXExporter: ...@@ -365,32 +366,20 @@ class TestONNXExporter:
dynamic_axes={"input1": [0, 1, 2, 3], "input2": [0, 1, 2, 3], "input3": [0, 1, 2, 3], dynamic_axes={"input1": [0, 1, 2, 3], "input2": [0, 1, 2, 3], "input3": [0, 1, 2, 3],
"input4": [0, 1, 2, 3], "input5": [0, 1, 2, 3], "input6": [0, 1, 2, 3]}) "input4": [0, 1, 2, 3], "input5": [0, 1, 2, 3], "input6": [0, 1, 2, 3]})
def get_image_from_url(self, url, size=None): def get_image(self, rel_path: str, size: Tuple[int, int]) -> torch.Tensor:
import requests import os
from PIL import Image from PIL import Image
from io import BytesIO
from torchvision import transforms from torchvision import transforms
data = requests.get(url) data_dir = os.path.join(os.path.dirname(__file__), "assets")
image = Image.open(BytesIO(data.content)).convert("RGB") path = os.path.join(data_dir, *rel_path.split("/"))
image = Image.open(path).convert("RGB").resize(size, Image.BILINEAR)
if size is None: return transforms.ToTensor()(image)
size = (300, 200)
image = image.resize(size, Image.BILINEAR)
to_tensor = transforms.ToTensor() def get_test_images(self) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
return to_tensor(image) return ([self.get_image("encode_jpeg/grace_hopper_517x606.jpg", (100, 320))],
[self.get_image("fakedata/logos/rgb_pytorch.png", (250, 380))])
def get_test_images(self):
image_url = "http://farm3.staticflickr.com/2469/3915380994_2e611b1779_z.jpg"
image = self.get_image_from_url(url=image_url, size=(100, 320))
image_url2 = "https://pytorch.org/tutorials/_static/img/tv_tutorial/tv_image05.png"
image2 = self.get_image_from_url(url=image_url2, size=(250, 380))
images = [image]
test_images = [image2]
return images, test_images
def test_faster_rcnn(self): def test_faster_rcnn(self):
images, test_images = self.get_test_images() images, test_images = self.get_test_images()
...@@ -456,7 +445,6 @@ class TestONNXExporter: ...@@ -456,7 +445,6 @@ class TestONNXExporter:
dynamic_axes={"images_tensors": [0, 1, 2], "boxes": [0, 1], "labels": [0], dynamic_axes={"images_tensors": [0, 1, 2], "boxes": [0, 1], "labels": [0],
"scores": [0], "masks": [0, 1, 2]}, "scores": [0], "masks": [0, 1, 2]},
tolerate_small_mismatch=True) tolerate_small_mismatch=True)
# TODO: enable this test once dynamic model export is fixed
# Test exported model for an image with no detections on other images # Test exported model for an image with no detections on other images
self.run_model(model, [(dummy_image,), (images,)], self.run_model(model, [(dummy_image,), (images,)],
input_names=["images_tensors"], input_names=["images_tensors"],
...@@ -468,7 +456,6 @@ class TestONNXExporter: ...@@ -468,7 +456,6 @@ class TestONNXExporter:
# Verify that heatmaps_to_keypoints behaves the same in tracing. # Verify that heatmaps_to_keypoints behaves the same in tracing.
# This test also compares both heatmaps_to_keypoints and _onnx_heatmaps_to_keypoints # This test also compares both heatmaps_to_keypoints and _onnx_heatmaps_to_keypoints
# (since jit_trace witll call _heatmaps_to_keypoints). # (since jit_trace witll call _heatmaps_to_keypoints).
# @unittest.skip("Disable test until Resize bug fixed in ORT")
def test_heatmaps_to_keypoints(self): def test_heatmaps_to_keypoints(self):
# disable profiling # disable profiling
torch._C._jit_set_profiling_executor(False) torch._C._jit_set_profiling_executor(False)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment