Unverified commit 16b9a40c authored by Vincent Moens, committed by GitHub

Pytest for test_videoapi.py and test_video_reader.py (#4233)



* test_video_reader pytest refactoring

* pytest refactoring of test_videoapi.py

* test_video_reader pytest refactoring

* pytest refactoring of test_videoapi.py

* using pytest.approx for test_video_reader.py

* using pytest.approx for test_videoapi.py

* Fixing minor comments

* linting fixes

* minor comments
Co-authored-by: Vincent Moens <vmoens@fb.com>
parent a8397963
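The diff below applies one translation pattern throughout: `unittest.TestCase` subclasses become plain classes, `unittest.skipIf` becomes `pytest.mark.skipif`, `assertAlmostEqual(..., delta=d)` becomes `assert x == approx(y, abs=d)`, and `assertRaises` becomes `pytest.raises`. The first hunks touch the shared test utilities module (the one defining `PY39_SKIP` and the old `TestCase` helper); the remaining hunks are test_video_reader.py and test_videoapi.py. A minimal sketch of the mapping, using a made-up `TestDuration` class and made-up values rather than code from the patch:

    import unittest

    import pytest
    from pytest import approx


    class TestDurationOld(unittest.TestCase):
        # unittest style, as the tests looked before this commit
        def test_duration(self):
            duration, expected = 10.02, 10.0  # illustrative values only
            self.assertAlmostEqual(duration, expected, delta=0.5)
            self.assertEqual(int(duration), 10)
            with self.assertRaises(RuntimeError):
                raise RuntimeError("invalid file")


    @pytest.mark.skipif(False, reason="stands in for unittest.skipIf(...)")
    class TestDurationNew:
        # pytest style: no TestCase base class, plain asserts
        def test_duration(self):
            duration, expected = 10.02, 10.0
            assert duration == approx(expected, abs=0.5)  # assertAlmostEqual(delta=...) -> approx(abs=...)
            assert int(duration) == 10                    # assertEqual -> assert ==
            with pytest.raises(RuntimeError):             # assertRaises -> pytest.raises
                raise RuntimeError("invalid file")

Note that `approx(..., abs=d)` reproduces `assertAlmostEqual`'s `delta`, i.e. an absolute tolerance; without `abs=`, `pytest.approx` defaults to a relative tolerance of roughly 1e-6, which is why the patch passes `abs=` everywhere.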
@@ -3,6 +3,7 @@ import shutil
 import tempfile
 import contextlib
 import unittest
+import pytest
 import argparse
 import sys
 import torch
@@ -20,7 +21,7 @@ from PIL import Image
 IS_PY39 = sys.version_info.major == 3 and sys.version_info.minor == 9
 PY39_SEGFAULT_SKIP_MSG = "Segmentation fault with Python 3.9, see https://github.com/pytorch/vision/issues/3367"
-PY39_SKIP = unittest.skipIf(IS_PY39, PY39_SEGFAULT_SKIP_MSG)
+PY39_SKIP = pytest.mark.skipif(IS_PY39, reason=PY39_SEGFAULT_SKIP_MSG)
 IN_CIRCLE_CI = os.getenv("CIRCLECI", False) == 'true'
 IN_RE_WORKER = os.environ.get("INSIDE_RE_WORKER") is not None
 IN_FBCODE = os.environ.get("IN_FBCODE_TORCHVISION") == "1"
@@ -83,129 +84,6 @@ def is_iterable(obj):
     return False
 
-# adapted from TestCase in torch/test/common_utils to accept non-string
-# inputs and set maximum binary size
-class TestCase(unittest.TestCase):
-    precision = 1e-5
-
-    def assertEqual(self, x, y, prec=None, message='', allow_inf=False):
-        """
-        This is copied from pytorch/test/common_utils.py's TestCase.assertEqual
-        """
-        if isinstance(prec, str) and message == '':
-            message = prec
-            prec = None
-        if prec is None:
-            prec = self.precision
-        if isinstance(x, torch.Tensor) and isinstance(y, Number):
-            self.assertEqual(x.item(), y, prec=prec, message=message,
-                             allow_inf=allow_inf)
-        elif isinstance(y, torch.Tensor) and isinstance(x, Number):
-            self.assertEqual(x, y.item(), prec=prec, message=message,
-                             allow_inf=allow_inf)
-        elif isinstance(x, torch.Tensor) and isinstance(y, torch.Tensor):
-            def assertTensorsEqual(a, b):
-                super(TestCase, self).assertEqual(a.size(), b.size(), message)
-                if a.numel() > 0:
-                    if (a.device.type == 'cpu' and (a.dtype == torch.float16 or a.dtype == torch.bfloat16)):
-                        # CPU half and bfloat16 tensors don't have the methods we need below
-                        a = a.to(torch.float32)
-                    b = b.to(a)
-                    if (a.dtype == torch.bool) != (b.dtype == torch.bool):
-                        raise TypeError("Was expecting both tensors to be bool type.")
-                    else:
-                        if a.dtype == torch.bool and b.dtype == torch.bool:
-                            # we want to respect precision but as bool doesn't support substraction,
-                            # boolean tensor has to be converted to int
-                            a = a.to(torch.int)
-                            b = b.to(torch.int)
-                        diff = a - b
-                        if a.is_floating_point():
-                            # check that NaNs are in the same locations
-                            nan_mask = torch.isnan(a)
-                            self.assertTrue(torch.equal(nan_mask, torch.isnan(b)), message)
-                            diff[nan_mask] = 0
-                            # inf check if allow_inf=True
-                            if allow_inf:
-                                inf_mask = torch.isinf(a)
-                                inf_sign = inf_mask.sign()
-                                self.assertTrue(torch.equal(inf_sign, torch.isinf(b).sign()), message)
-                                diff[inf_mask] = 0
-                        # TODO: implement abs on CharTensor (int8)
-                        if diff.is_signed() and diff.dtype != torch.int8:
-                            diff = diff.abs()
-                        max_err = diff.max()
-                        tolerance = prec + prec * abs(a.max())
-                        self.assertLessEqual(max_err, tolerance, message)
-            super(TestCase, self).assertEqual(x.is_sparse, y.is_sparse, message)
-            super(TestCase, self).assertEqual(x.is_quantized, y.is_quantized, message)
-            if x.is_sparse:
-                x = self.safeCoalesce(x)
-                y = self.safeCoalesce(y)
-                assertTensorsEqual(x._indices(), y._indices())
-                assertTensorsEqual(x._values(), y._values())
-            elif x.is_quantized and y.is_quantized:
-                self.assertEqual(x.qscheme(), y.qscheme(), prec=prec,
-                                 message=message, allow_inf=allow_inf)
-                if x.qscheme() == torch.per_tensor_affine:
-                    self.assertEqual(x.q_scale(), y.q_scale(), prec=prec,
-                                     message=message, allow_inf=allow_inf)
-                    self.assertEqual(x.q_zero_point(), y.q_zero_point(),
-                                     prec=prec, message=message,
-                                     allow_inf=allow_inf)
-                elif x.qscheme() == torch.per_channel_affine:
-                    self.assertEqual(x.q_per_channel_scales(), y.q_per_channel_scales(), prec=prec,
-                                     message=message, allow_inf=allow_inf)
-                    self.assertEqual(x.q_per_channel_zero_points(), y.q_per_channel_zero_points(),
-                                     prec=prec, message=message,
-                                     allow_inf=allow_inf)
-                    self.assertEqual(x.q_per_channel_axis(), y.q_per_channel_axis(),
-                                     prec=prec, message=message)
-                self.assertEqual(x.dtype, y.dtype)
-                self.assertEqual(x.int_repr().to(torch.int32),
-                                 y.int_repr().to(torch.int32), prec=prec,
-                                 message=message, allow_inf=allow_inf)
-            else:
-                assertTensorsEqual(x, y)
-        elif isinstance(x, string_classes) and isinstance(y, string_classes):
-            super(TestCase, self).assertEqual(x, y, message)
-        elif type(x) == set and type(y) == set:
-            super(TestCase, self).assertEqual(x, y, message)
-        elif isinstance(x, dict) and isinstance(y, dict):
-            if isinstance(x, OrderedDict) and isinstance(y, OrderedDict):
-                self.assertEqual(x.items(), y.items(), prec=prec,
-                                 message=message, allow_inf=allow_inf)
-            else:
-                self.assertEqual(set(x.keys()), set(y.keys()), prec=prec,
-                                 message=message, allow_inf=allow_inf)
-                key_list = list(x.keys())
-                self.assertEqual([x[k] for k in key_list],
-                                 [y[k] for k in key_list],
-                                 prec=prec, message=message,
-                                 allow_inf=allow_inf)
-        elif is_iterable(x) and is_iterable(y):
-            super(TestCase, self).assertEqual(len(x), len(y), message)
-            for x_, y_ in zip(x, y):
-                self.assertEqual(x_, y_, prec=prec, message=message,
-                                 allow_inf=allow_inf)
-        elif isinstance(x, bool) and isinstance(y, bool):
-            super(TestCase, self).assertEqual(x, y, message)
-        elif isinstance(x, Number) and isinstance(y, Number):
-            inf = float("inf")
-            if abs(x) == inf or abs(y) == inf:
-                if allow_inf:
-                    super(TestCase, self).assertEqual(x, y, message)
-                else:
-                    self.fail("Expected finite numeric values - x={}, y={}".format(x, y))
-                return
-            super(TestCase, self).assertLessEqual(abs(x - y), prec, message)
-        else:
-            super(TestCase, self).assertEqual(x, y, message)
-
 @contextlib.contextmanager
 def freeze_rng_state():
     rng_state = torch.get_rng_state()
test_video_reader.py
@@ -2,7 +2,8 @@ import collections
 import itertools
 import math
 import os
-import unittest
+import pytest
+from pytest import approx
 from fractions import Fraction
 
 import numpy as np
@@ -268,9 +269,9 @@ def _get_video_tensor(video_dir, video_file):
     return full_path, video_tensor
 
-@unittest.skipIf(av is None, "PyAV unavailable")
-@unittest.skipIf(_HAS_VIDEO_OPT is False, "Didn't compile with ffmpeg")
-class TestVideoReader(unittest.TestCase):
+@pytest.mark.skipif(av is None, reason="PyAV unavailable")
+@pytest.mark.skipif(_HAS_VIDEO_OPT is False, reason="Didn't compile with ffmpeg")
+class TestVideoReader:
     def check_separate_decoding_result(self, tv_result, config):
         """check the decoding results from TorchVision decoder
         """
@@ -282,45 +283,46 @@ class TestVideoReader(unittest.TestCase):
         video_duration = vduration.item() * Fraction(
             vtimebase[0].item(), vtimebase[1].item()
         )
-        self.assertAlmostEqual(video_duration, config.duration, delta=0.5)
-        self.assertAlmostEqual(vfps.item(), config.video_fps, delta=0.5)
+        assert video_duration == approx(config.duration, abs=0.5)
+        assert vfps.item() == approx(config.video_fps, abs=0.5)
 
         if asample_rate.numel() > 0:
-            self.assertEqual(asample_rate.item(), config.audio_sample_rate)
+            assert asample_rate.item() == config.audio_sample_rate
             audio_duration = aduration.item() * Fraction(
                 atimebase[0].item(), atimebase[1].item()
            )
-            self.assertAlmostEqual(audio_duration, config.duration, delta=0.5)
+            assert audio_duration == approx(config.duration, abs=0.5)
 
         # check if pts of video frames are sorted in ascending order
         for i in range(len(vframe_pts) - 1):
-            self.assertEqual(vframe_pts[i] < vframe_pts[i + 1], True)
+            assert vframe_pts[i] < vframe_pts[i + 1]
 
         if len(aframe_pts) > 1:
            # check if pts of audio frames are sorted in ascending order
            for i in range(len(aframe_pts) - 1):
-               self.assertEqual(aframe_pts[i] < aframe_pts[i + 1], True)
+               assert aframe_pts[i] < aframe_pts[i + 1]
 
     def check_probe_result(self, result, config):
         vtimebase, vfps, vduration, atimebase, asample_rate, aduration = result
         video_duration = vduration.item() * Fraction(
             vtimebase[0].item(), vtimebase[1].item()
         )
-        self.assertAlmostEqual(video_duration, config.duration, delta=0.5)
-        self.assertAlmostEqual(vfps.item(), config.video_fps, delta=0.5)
+        assert video_duration == approx(config.duration, abs=0.5)
+        assert vfps.item() == approx(config.video_fps, abs=0.5)
         if asample_rate.numel() > 0:
-            self.assertEqual(asample_rate.item(), config.audio_sample_rate)
+            assert asample_rate.item() == config.audio_sample_rate
             audio_duration = aduration.item() * Fraction(
                 atimebase[0].item(), atimebase[1].item()
            )
-            self.assertAlmostEqual(audio_duration, config.duration, delta=0.5)
+            assert audio_duration == approx(config.duration, abs=0.5)
 
     def check_meta_result(self, result, config):
-        self.assertAlmostEqual(result.video_duration, config.duration, delta=0.5)
-        self.assertAlmostEqual(result.video_fps, config.video_fps, delta=0.5)
+        assert result.video_duration == approx(config.duration, abs=0.5)
+        assert result.video_fps == approx(config.video_fps, abs=0.5)
         if result.has_audio > 0:
-            self.assertEqual(result.audio_sample_rate, config.audio_sample_rate)
-            self.assertAlmostEqual(result.audio_duration, config.duration, delta=0.5)
+            assert result.audio_sample_rate == config.audio_sample_rate
+            assert result.audio_duration == approx(config.duration, abs=0.5)
 
     def compare_decoding_result(self, tv_result, ref_result, config=all_check_config):
         """
@@ -350,12 +352,12 @@ class TestVideoReader(unittest.TestCase):
         mean_delta = torch.mean(
             torch.abs(vframes.float() - ref_result.vframes.float())
         )
-        self.assertAlmostEqual(mean_delta, 0, delta=8.0)
+        assert mean_delta == approx(0.0, abs=8.0)
 
         mean_delta = torch.mean(
             torch.abs(vframe_pts.float() - ref_result.vframe_pts.float())
         )
-        self.assertAlmostEqual(mean_delta, 0, delta=1.0)
+        assert mean_delta == approx(0.0, abs=1.0)
 
         assert_equal(vtimebase, ref_result.vtimebase)
@@ -378,12 +380,12 @@ class TestVideoReader(unittest.TestCase):
             assert_equal(atimebase, ref_result.atimebase)
 
-    @unittest.skip(
-        "This stress test will iteratively decode the same set of videos."
-        "It helps to detect memory leak but it takes lots of time to run."
-        "By default, it is disabled"
-    )
     def test_stress_test_read_video_from_file(self):
+        pytest.skip(
+            "This stress test will iteratively decode the same set of videos."
+            "It helps to detect memory leak but it takes lots of time to run."
+            "By default, it is disabled"
+        )
         num_iter = 10000
         # video related
         width, height, min_dimension, max_dimension = 0, 0, 0, 0
@@ -513,18 +515,18 @@ class TestVideoReader(unittest.TestCase):
             tv_result
         )
-        self.assertEqual(vframes.numel() > 0, readVideoStream)
-        self.assertEqual(vframe_pts.numel() > 0, readVideoStream)
-        self.assertEqual(vtimebase.numel() > 0, readVideoStream)
-        self.assertEqual(vfps.numel() > 0, readVideoStream)
+        assert (vframes.numel() > 0) is bool(readVideoStream)
+        assert (vframe_pts.numel() > 0) is bool(readVideoStream)
+        assert (vtimebase.numel() > 0) is bool(readVideoStream)
+        assert (vfps.numel() > 0) is bool(readVideoStream)
 
         expect_audio_data = (
             readAudioStream == 1 and config.audio_sample_rate is not None
         )
-        self.assertEqual(aframes.numel() > 0, expect_audio_data)
-        self.assertEqual(aframe_pts.numel() > 0, expect_audio_data)
-        self.assertEqual(atimebase.numel() > 0, expect_audio_data)
-        self.assertEqual(asample_rate.numel() > 0, expect_audio_data)
+        assert (aframes.numel() > 0) is bool(expect_audio_data)
+        assert (aframe_pts.numel() > 0) is bool(expect_audio_data)
+        assert (atimebase.numel() > 0) is bool(expect_audio_data)
+        assert (asample_rate.numel() > 0) is bool(expect_audio_data)
 
     def test_read_video_from_file_rescale_min_dimension(self):
         """
@@ -564,9 +566,7 @@ class TestVideoReader(unittest.TestCase):
             audio_timebase_num,
             audio_timebase_den,
         )
-        self.assertEqual(
-            min_dimension, min(tv_result[0].size(1), tv_result[0].size(2))
-        )
+        assert min_dimension == min(tv_result[0].size(1), tv_result[0].size(2))
 
     def test_read_video_from_file_rescale_max_dimension(self):
         """
@@ -606,9 +606,7 @@ class TestVideoReader(unittest.TestCase):
             audio_timebase_num,
             audio_timebase_den,
         )
-        self.assertEqual(
-            max_dimension, max(tv_result[0].size(1), tv_result[0].size(2))
-        )
+        assert max_dimension == max(tv_result[0].size(1), tv_result[0].size(2))
 
     def test_read_video_from_file_rescale_both_min_max_dimension(self):
         """
@@ -648,12 +646,8 @@ class TestVideoReader(unittest.TestCase):
             audio_timebase_num,
             audio_timebase_den,
         )
-        self.assertEqual(
-            min_dimension, min(tv_result[0].size(1), tv_result[0].size(2))
-        )
-        self.assertEqual(
-            max_dimension, max(tv_result[0].size(1), tv_result[0].size(2))
-        )
+        assert min_dimension == min(tv_result[0].size(1), tv_result[0].size(2))
+        assert max_dimension == max(tv_result[0].size(1), tv_result[0].size(2))
 
     def test_read_video_from_file_rescale_width(self):
         """
@@ -693,7 +687,7 @@ class TestVideoReader(unittest.TestCase):
             audio_timebase_num,
             audio_timebase_den,
         )
-        self.assertEqual(tv_result[0].size(2), width)
+        assert tv_result[0].size(2) == width
 
     def test_read_video_from_file_rescale_height(self):
         """
@@ -733,7 +727,7 @@ class TestVideoReader(unittest.TestCase):
             audio_timebase_num,
             audio_timebase_den,
         )
-        self.assertEqual(tv_result[0].size(1), height)
+        assert tv_result[0].size(1) == height
 
     def test_read_video_from_file_rescale_width_and_height(self):
         """
@@ -773,8 +767,8 @@ class TestVideoReader(unittest.TestCase):
             audio_timebase_num,
             audio_timebase_den,
         )
-        self.assertEqual(tv_result[0].size(1), height)
-        self.assertEqual(tv_result[0].size(2), width)
+        assert tv_result[0].size(1) == height
+        assert tv_result[0].size(2) == width
 
     @PY39_SKIP
     def test_read_video_from_file_audio_resampling(self):
@@ -822,19 +816,15 @@ class TestVideoReader(unittest.TestCase):
             tv_result
         )
         if aframes.numel() > 0:
-            self.assertEqual(samples, asample_rate.item())
-            self.assertEqual(1, aframes.size(1))
+            assert samples == asample_rate.item()
+            assert 1 == aframes.size(1)
             # when audio stream is found
             duration = (
                 float(aframe_pts[-1])
                 * float(atimebase[0])
                / float(atimebase[1])
            )
-            self.assertAlmostEqual(
-                aframes.size(0),
-                int(duration * asample_rate.item()),
-                delta=0.1 * asample_rate.item(),
-            )
+            assert aframes.size(0) == approx(int(duration * asample_rate.item()), abs=0.1 * asample_rate.item())
 
     @PY39_SKIP
     def test_compare_read_video_from_memory_and_file(self):
@@ -989,7 +979,7 @@ class TestVideoReader(unittest.TestCase):
             audio_timebase_num,
             audio_timebase_den,
         )
-        self.assertAlmostEqual(config.video_fps, tv_result[3].item(), delta=0.01)
+        assert abs(config.video_fps - tv_result[3].item()) < 0.01
 
         # pass 2: decode all frames to get PTS only using cpp decoder
         tv_result_pts_only = torch.ops.video_reader.read_video_from_memory(
@@ -1014,8 +1004,8 @@ class TestVideoReader(unittest.TestCase):
             audio_timebase_den,
         )
-        self.assertEqual(tv_result_pts_only[0].numel(), 0)
-        self.assertEqual(tv_result_pts_only[5].numel(), 0)
+        assert not tv_result_pts_only[0].numel()
+        assert not tv_result_pts_only[5].numel()
         self.compare_decoding_result(tv_result, tv_result_pts_only)
 
     @PY39_SKIP
@@ -1061,7 +1051,7 @@ class TestVideoReader(unittest.TestCase):
         aframes, aframe_pts, atimebase, asample_rate, aduration = (
             tv_result
         )
-        self.assertAlmostEqual(config.video_fps, vfps.item(), delta=0.01)
+        assert abs(config.video_fps - vfps.item()) < 0.01
 
         for num_frames in [4, 8, 16, 32, 64, 128]:
             start_pts_ind_max = vframe_pts.size(0) - num_frames
@@ -1160,7 +1150,7 @@ class TestVideoReader(unittest.TestCase):
                 audio_end_pts,
             )
-            self.assertEqual(tv_result[0].size(0), num_frames)
+            assert tv_result[0].size(0) == num_frames
 
             if pyav_result.vframes.size(0) == num_frames:
                 # if PyAv decodes a different number of video frames, skip
                 # comparing the decoding results between Torchvision video reader
@@ -1187,7 +1177,7 @@ class TestVideoReader(unittest.TestCase):
     def test_probe_video_from_memory_script(self):
         scripted_fun = torch.jit.script(io._probe_video_from_memory)
-        self.assertIsNotNone(scripted_fun)
+        assert scripted_fun is not None
 
         for test_video, config in test_videos.items():
             full_path, video_tensor = _get_video_tensor(VIDEO_DIR, test_video)
@@ -1209,7 +1199,7 @@ class TestVideoReader(unittest.TestCase):
         audio_timebase_num, audio_timebase_den = 0, 1
         scripted_fun = torch.jit.script(io._read_video_from_memory)
-        self.assertIsNotNone(scripted_fun)
+        assert scripted_fun is not None
 
         for test_video, _config in test_videos.items():
             full_path, video_tensor = _get_video_tensor(VIDEO_DIR, test_video)
@@ -1237,11 +1227,11 @@ class TestVideoReader(unittest.TestCase):
     def test_invalid_file(self):
         set_video_backend('video_reader')
-        with self.assertRaises(RuntimeError):
+        with pytest.raises(RuntimeError):
             io.read_video('foo.mp4')
 
         set_video_backend('pyav')
-        with self.assertRaises(RuntimeError):
+        with pytest.raises(RuntimeError):
             io.read_video('foo.mp4')
 
     def test_audio_present_pts(self):
@@ -1258,8 +1248,7 @@ class TestVideoReader(unittest.TestCase):
             set_video_backend(backend)
             _, audio, _ = io.read_video(
                 full_path, start_offset, end_offset, pts_unit='pts')
-            self.assertGreaterEqual(audio.shape[0], 1)
-            self.assertGreaterEqual(audio.shape[1], 1)
+            assert all([dimension > 0 for dimension in audio.shape[:2]])
 
     def test_audio_present_sec(self):
         """Test if audio frames are returned with sec unit."""
@@ -1275,9 +1264,8 @@ class TestVideoReader(unittest.TestCase):
             set_video_backend(backend)
             _, audio, _ = io.read_video(
                 full_path, start_offset, end_offset, pts_unit='sec')
-            self.assertGreaterEqual(audio.shape[0], 1)
-            self.assertGreaterEqual(audio.shape[1], 1)
+            assert all([dimension > 0 for dimension in audio.shape[:2]])
 
-if __name__ == "__main__":
-    unittest.main()
+if __name__ == '__main__':
+    pytest.main([__file__])
test_videoapi.py
 import collections
 import os
-import unittest
+import pytest
+from pytest import approx
 
 import torch
 import torchvision
@@ -62,10 +63,10 @@ test_videos = {
 }
 
-@unittest.skipIf(_HAS_VIDEO_OPT is False, "Didn't compile with ffmpeg")
+@pytest.mark.skipif(_HAS_VIDEO_OPT is False, reason="Didn't compile with ffmpeg")
 @PY39_SKIP
-class TestVideoApi(unittest.TestCase):
-    @unittest.skipIf(av is None, "PyAV unavailable")
+class TestVideoApi:
+    @pytest.mark.skipif(av is None, reason="PyAV unavailable")
     def test_frame_reading(self):
         for test_video, config in test_videos.items():
             full_path = os.path.join(VIDEO_DIR, test_video)
@@ -77,11 +78,7 @@ class TestVideoApi(unittest.TestCase):
             for av_frame in av_reader.decode(av_reader.streams.video[0]):
                 vr_frame = next(video_reader)
-                self.assertAlmostEqual(
-                    float(av_frame.pts * av_frame.time_base),
-                    vr_frame["pts"],
-                    delta=0.1,
-                )
+                assert float(av_frame.pts * av_frame.time_base) == approx(vr_frame["pts"], abs=0.1)
 
                 av_array = torch.tensor(av_frame.to_rgb().to_ndarray()).permute(
                     2, 0, 1
@@ -94,18 +91,14 @@ class TestVideoApi(unittest.TestCase):
                 # by decoding (around 1%)
                 # TODO: asses empirically how to set this? atm it's 1%
                 # averaged over all frames
-                self.assertTrue(mean_delta.item() < 2.5)
+                assert mean_delta.item() < 2.5
 
             av_reader = av.open(full_path)
             if av_reader.streams.audio:
                 video_reader = VideoReader(full_path, "audio")
                 for av_frame in av_reader.decode(av_reader.streams.audio[0]):
                     vr_frame = next(video_reader)
-                    self.assertAlmostEqual(
-                        float(av_frame.pts * av_frame.time_base),
-                        vr_frame["pts"],
-                        delta=0.1,
-                    )
+                    assert float(av_frame.pts * av_frame.time_base) == approx(vr_frame["pts"], abs=0.1)
 
                     av_array = torch.tensor(av_frame.to_ndarray()).permute(1, 0)
                     vr_array = vr_frame["data"]
@@ -114,7 +107,7 @@ class TestVideoApi(unittest.TestCase):
                         torch.abs(av_array.float() - vr_array.float())
                     )
                     # we assure that there is never more than 1% difference in signal
-                    self.assertTrue(max_delta.item() < 0.001)
+                    assert max_delta.item() < 0.001
 
     def test_metadata(self):
         """
@@ -125,12 +118,8 @@ class TestVideoApi(unittest.TestCase):
             full_path = os.path.join(VIDEO_DIR, test_video)
             reader = VideoReader(full_path, "video")
             reader_md = reader.get_metadata()
-            self.assertAlmostEqual(
-                config.video_fps, reader_md["video"]["fps"][0], delta=0.0001
-            )
-            self.assertAlmostEqual(
-                config.duration, reader_md["video"]["duration"][0], delta=0.5
-            )
+            assert config.video_fps == approx(reader_md["video"]["fps"][0], abs=0.0001)
+            assert config.duration == approx(reader_md["video"]["duration"][0], abs=0.5)
 
     def test_seek_start(self):
         for test_video, config in test_videos.items():
@@ -149,7 +138,7 @@ class TestVideoApi(unittest.TestCase):
             for frame in video_reader:
                 start_num_frames += 1
-            self.assertEqual(start_num_frames, num_frames)
+            assert start_num_frames == num_frames
 
             # now seek the container to < 0 to check for unexpected behaviour
             video_reader.seek(-1)
@@ -157,7 +146,7 @@ class TestVideoApi(unittest.TestCase):
             for frame in video_reader:
                 start_num_frames += 1
-            self.assertEqual(start_num_frames, num_frames)
+            assert start_num_frames == num_frames
 
     def test_accurateseek_middle(self):
         for test_video, config in test_videos.items():
@@ -178,23 +167,23 @@ class TestVideoApi(unittest.TestCase):
             for frame in video_reader:
                 middle_num_frames += 1
-            self.assertTrue(middle_num_frames < num_frames)
-            self.assertAlmostEqual(middle_num_frames, num_frames // 2, delta=1)
+            assert middle_num_frames < num_frames
+            assert middle_num_frames == approx(num_frames // 2, abs=1)
 
             video_reader.seek(duration / 2)
             frame = next(video_reader)
             lb = duration / 2 - 1 / md[stream]["fps"][0]
             ub = duration / 2 + 1 / md[stream]["fps"][0]
-            self.assertTrue((lb <= frame["pts"]) & (ub >= frame["pts"]))
+            assert (lb <= frame["pts"]) and (ub >= frame["pts"])
 
     def test_fate_suite(self):
         video_path = fate("sub/MovText_capability_tester.mp4", VIDEO_DIR)
         vr = VideoReader(video_path)
         metadata = vr.get_metadata()
-        self.assertTrue(metadata["subtitles"]["duration"] is not None)
+        assert metadata["subtitles"]["duration"] is not None
         os.remove(video_path)
 
-if __name__ == "__main__":
-    unittest.main()
+if __name__ == '__main__':
+    pytest.main([__file__])
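With `unittest.main()` replaced by `pytest.main([__file__])`, each file still works as a standalone script, and the suites can equally be collected by pytest itself. A small runner sketch; the `test/...` paths below are an assumption about the repository layout, not part of the patch:

    # Hypothetical convenience runner; adjust the paths to your checkout.
    import sys

    import pytest

    exit_code = pytest.main([
        "test/test_video_reader.py",  # TestVideoReader: video_reader backend tests
        "test/test_videoapi.py",      # TestVideoApi: fine-grained VideoReader API tests
        "-v",                         # report each test function individually
    ])
    sys.exit(exit_code)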