Unverified Commit ea34cd1e authored by Nicolas Hug's avatar Nicolas Hug Committed by GitHub
Browse files

Use torch.testing.assert_close in test_onnx.py (#3882)


Co-authored-by: default avatarPhilip Meier <github.pmeier@posteo.de>
parent c307db4b
......@@ -7,6 +7,7 @@ except ImportError:
onnxruntime = None
from common_utils import set_rng_seed
from _assert_utils import assert_equal
import io
import torch
from torchvision import ops
......@@ -483,8 +484,8 @@ class ONNXExporterTester(unittest.TestCase):
jit_trace = torch.jit.trace(heatmaps_to_keypoints, (maps, rois))
out_trace = jit_trace(maps, rois)
assert torch.all(out[0].eq(out_trace[0]))
assert torch.all(out[1].eq(out_trace[1]))
assert_equal(out[0], out_trace[0])
assert_equal(out[1], out_trace[1])
maps2 = torch.rand(20, 2, 21, 21)
rois2 = torch.rand(20, 4)
......@@ -492,8 +493,8 @@ class ONNXExporterTester(unittest.TestCase):
out2 = heatmaps_to_keypoints(maps2, rois2)
out_trace2 = jit_trace(maps2, rois2)
assert torch.all(out2[0].eq(out_trace2[0]))
assert torch.all(out2[1].eq(out_trace2[1]))
assert_equal(out2[0], out_trace2[0])
assert_equal(out2[1], out_trace2[1])
def test_keypoint_rcnn(self):
images, test_images = self.get_test_images()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment