Unverified Commit ed239b8a authored by Nicolas Hug's avatar Nicolas Hug Committed by GitHub
Browse files

Use torch.testing.assert_close in test_io.py (#3878)


Co-authored-by: default avatarPhilip Meier <github.pmeier@posteo.de>
parent 195bb86e
......@@ -10,6 +10,7 @@ import warnings
from urllib.error import URLError
from common_utils import get_tmp_dir
from _assert_utils import assert_equal
try:
......@@ -74,7 +75,7 @@ class TestIO(unittest.TestCase):
def test_write_read_video(self):
with temp_video(10, 300, 300, 5, lossless=True) as (f_name, data):
lv, _, info = io.read_video(f_name)
self.assertTrue(data.equal(lv))
assert_equal(data, lv)
self.assertEqual(info["video_fps"], 5)
@unittest.skipIf(not io._HAS_VIDEO_OPT, "video_reader backend is not chosen")
......@@ -116,14 +117,14 @@ class TestIO(unittest.TestCase):
lv, _, _ = io.read_video(f_name, pts[start], pts[start + offset - 1])
s_data = data[start:(start + offset)]
self.assertEqual(len(lv), offset)
self.assertTrue(s_data.equal(lv))
assert_equal(s_data, lv)
if get_video_backend() == "pyav":
# for "video_reader" backend, we don't decode the closest early frame
# when the given start pts is not matching any frame pts
lv, _, _ = io.read_video(f_name, pts[4] + 1, pts[7])
self.assertEqual(len(lv), 4)
self.assertTrue(data[4:8].equal(lv))
assert_equal(data[4:8], lv)
def test_read_partial_video_bframes(self):
# do not use lossless encoding, to test the presence of B-frames
......@@ -135,16 +136,16 @@ class TestIO(unittest.TestCase):
lv, _, _ = io.read_video(f_name, pts[start], pts[start + offset - 1])
s_data = data[start:(start + offset)]
self.assertEqual(len(lv), offset)
self.assertTrue((s_data.float() - lv.float()).abs().max() < self.TOLERANCE)
assert_equal(s_data, lv, rtol=0.0, atol=self.TOLERANCE)
lv, _, _ = io.read_video(f_name, pts[4] + 1, pts[7])
# TODO fix this
if get_video_backend() == 'pyav':
self.assertEqual(len(lv), 4)
self.assertTrue((data[4:8].float() - lv.float()).abs().max() < self.TOLERANCE)
assert_equal(data[4:8], lv, rtol=0.0, atol=self.TOLERANCE)
else:
self.assertEqual(len(lv), 3)
self.assertTrue((data[5:8].float() - lv.float()).abs().max() < self.TOLERANCE)
assert_equal(data[5:8], lv, rtol=0.0, atol=self.TOLERANCE)
def test_read_packed_b_frames_divx_file(self):
name = "hmdb51_Turnk_r_Pippi_Michel_cartwheel_f_cm_np2_le_med_6.avi"
......@@ -175,7 +176,7 @@ class TestIO(unittest.TestCase):
with temp_video(10, 300, 300, 5, lossless=True) as (f_name, data):
lv, _, info = io.read_video(f_name, pts_unit='sec')
self.assertTrue(data.equal(lv))
assert_equal(data, lv)
self.assertEqual(info["video_fps"], 5)
self.assertEqual(info, {"video_fps": 5})
......@@ -201,7 +202,7 @@ class TestIO(unittest.TestCase):
lv, _, _ = io.read_video(f_name, pts[start], pts[start + offset - 1], pts_unit='sec')
s_data = data[start:(start + offset)]
self.assertEqual(len(lv), offset)
self.assertTrue(s_data.equal(lv))
assert_equal(s_data, lv)
container = av.open(f_name)
stream = container.streams[0]
......@@ -212,7 +213,7 @@ class TestIO(unittest.TestCase):
# for "video_reader" backend, we don't decode the closest early frame
# when the given start pts is not matching any frame pts
self.assertEqual(len(lv), 4)
self.assertTrue(data[4:8].equal(lv))
assert_equal(data[4:8], lv)
container.close()
def test_read_video_corrupted_file(self):
......@@ -251,9 +252,10 @@ class TestIO(unittest.TestCase):
else:
self.assertEqual(len(video), 4)
# but the valid decoded content is still correct
self.assertTrue(video[:3].equal(data[:3]))
assert_equal(video[:3], data[:3])
# and the last few frames are wrong
self.assertFalse(video.equal(data))
with self.assertRaises(AssertionError):
assert_equal(video, data)
@unittest.skipIf(sys.platform == 'win32', 'temporarily disabled on Windows')
def test_write_video_with_audio(self):
......@@ -278,7 +280,7 @@ class TestIO(unittest.TestCase):
)
self.assertEqual(info["video_fps"], out_info["video_fps"])
self.assertTrue(video_tensor.equal(out_video_tensor))
assert_equal(video_tensor, out_video_tensor)
audio_stream = av.open(f_name).streams.audio[0]
out_audio_stream = av.open(out_f_name).streams.audio[0]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment