Unverified Commit ea34cd1e authored by Nicolas Hug's avatar Nicolas Hug Committed by GitHub
Browse files

Use torch.testing.assert_close in test_onnx.py (#3882)


Co-authored-by: default avatarPhilip Meier <github.pmeier@posteo.de>
parent c307db4b
...@@ -7,6 +7,7 @@ except ImportError: ...@@ -7,6 +7,7 @@ except ImportError:
onnxruntime = None onnxruntime = None
from common_utils import set_rng_seed from common_utils import set_rng_seed
from _assert_utils import assert_equal
import io import io
import torch import torch
from torchvision import ops from torchvision import ops
...@@ -483,8 +484,8 @@ class ONNXExporterTester(unittest.TestCase): ...@@ -483,8 +484,8 @@ class ONNXExporterTester(unittest.TestCase):
jit_trace = torch.jit.trace(heatmaps_to_keypoints, (maps, rois)) jit_trace = torch.jit.trace(heatmaps_to_keypoints, (maps, rois))
out_trace = jit_trace(maps, rois) out_trace = jit_trace(maps, rois)
assert torch.all(out[0].eq(out_trace[0])) assert_equal(out[0], out_trace[0])
assert torch.all(out[1].eq(out_trace[1])) assert_equal(out[1], out_trace[1])
maps2 = torch.rand(20, 2, 21, 21) maps2 = torch.rand(20, 2, 21, 21)
rois2 = torch.rand(20, 4) rois2 = torch.rand(20, 4)
...@@ -492,8 +493,8 @@ class ONNXExporterTester(unittest.TestCase): ...@@ -492,8 +493,8 @@ class ONNXExporterTester(unittest.TestCase):
out2 = heatmaps_to_keypoints(maps2, rois2) out2 = heatmaps_to_keypoints(maps2, rois2)
out_trace2 = jit_trace(maps2, rois2) out_trace2 = jit_trace(maps2, rois2)
assert torch.all(out2[0].eq(out_trace2[0])) assert_equal(out2[0], out_trace2[0])
assert torch.all(out2[1].eq(out_trace2[1])) assert_equal(out2[1], out_trace2[1])
def test_keypoint_rcnn(self): def test_keypoint_rcnn(self):
images, test_images = self.get_test_images() images, test_images = self.get_test_images()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment