Unverified Commit c50d4884 authored by Prabhat Roy's avatar Prabhat Roy Committed by GitHub
Browse files

Improve test_video_reader (#5498)

* Improve test_video_reader

* Fix linter error
parent e3f1a822
import collections
import itertools
import math
import os
from fractions import Fraction
......@@ -112,7 +111,7 @@ DecoderResult = collections.namedtuple("DecoderResult", "vframes vframe_pts vtim
# av_seek_frame is imprecise so seek to a timestamp earlier by a margin
# The margin is expressed in seconds
seek_frame_margin = 0.25
SEEK_FRAME_MARGIN = 0.25
def _read_from_stream(container, start_pts, end_pts, stream, stream_name, buffer_size=4):
......@@ -369,7 +368,8 @@ class TestVideoReader:
assert_equal(atimebase, ref_result.atimebase)
def test_stress_test_read_video_from_file(self):
@pytest.mark.parametrize("test_video", test_videos.keys())
def test_stress_test_read_video_from_file(self, test_video):
pytest.skip(
"This stress test will iteratively decode the same set of videos."
"It helps to detect memory leak but it takes lots of time to run."
......@@ -386,13 +386,12 @@ class TestVideoReader:
audio_timebase_num, audio_timebase_den = 0, 1
for _i in range(num_iter):
for test_video, _config in test_videos.items():
full_path = os.path.join(VIDEO_DIR, test_video)
# pass 1: decode all frames using new decoder
torch.ops.video_reader.read_video_from_file(
full_path,
seek_frame_margin,
SEEK_FRAME_MARGIN,
0, # getPtsOnly
1, # readVideoStream
width,
......@@ -412,7 +411,8 @@ class TestVideoReader:
audio_timebase_den,
)
def test_read_video_from_file(self):
@pytest.mark.parametrize("test_video,config", test_videos.items())
def test_read_video_from_file(self, test_video, config):
"""
Test the case when decoder starts with a video file to decode frames.
"""
......@@ -425,13 +425,12 @@ class TestVideoReader:
audio_start_pts, audio_end_pts = 0, -1
audio_timebase_num, audio_timebase_den = 0, 1
for test_video, config in test_videos.items():
full_path = os.path.join(VIDEO_DIR, test_video)
# pass 1: decode all frames using new decoder
tv_result = torch.ops.video_reader.read_video_from_file(
full_path,
seek_frame_margin,
SEEK_FRAME_MARGIN,
0, # getPtsOnly
1, # readVideoStream
width,
......@@ -457,7 +456,11 @@ class TestVideoReader:
# compare decoding results
self.compare_decoding_result(tv_result, pyav_result, config)
def test_read_video_from_file_read_single_stream_only(self):
@pytest.mark.parametrize("test_video,config", test_videos.items())
@pytest.mark.parametrize("read_video_stream,read_audio_stream", [(1, 0), (0, 1)])
def test_read_video_from_file_read_single_stream_only(
self, test_video, config, read_video_stream, read_audio_stream
):
"""
Test the case when decoder starts with a video file to decode frames, and
reads only one of the video and audio streams while ignoring the other
......@@ -471,15 +474,13 @@ class TestVideoReader:
audio_start_pts, audio_end_pts = 0, -1
audio_timebase_num, audio_timebase_den = 0, 1
for test_video, config in test_videos.items():
full_path = os.path.join(VIDEO_DIR, test_video)
for readVideoStream, readAudioStream in [(1, 0), (0, 1)]:
# decode all frames using new decoder
tv_result = torch.ops.video_reader.read_video_from_file(
full_path,
seek_frame_margin,
SEEK_FRAME_MARGIN,
0, # getPtsOnly
readVideoStream,
read_video_stream,
width,
height,
min_dimension,
......@@ -488,7 +489,7 @@ class TestVideoReader:
video_end_pts,
video_timebase_num,
video_timebase_den,
readAudioStream,
read_audio_stream,
samples,
channels,
audio_start_pts,
......@@ -510,18 +511,19 @@ class TestVideoReader:
aduration,
) = tv_result
assert (vframes.numel() > 0) is bool(readVideoStream)
assert (vframe_pts.numel() > 0) is bool(readVideoStream)
assert (vtimebase.numel() > 0) is bool(readVideoStream)
assert (vfps.numel() > 0) is bool(readVideoStream)
assert (vframes.numel() > 0) is bool(read_video_stream)
assert (vframe_pts.numel() > 0) is bool(read_video_stream)
assert (vtimebase.numel() > 0) is bool(read_video_stream)
assert (vfps.numel() > 0) is bool(read_video_stream)
expect_audio_data = readAudioStream == 1 and config.audio_sample_rate is not None
expect_audio_data = read_audio_stream == 1 and config.audio_sample_rate is not None
assert (aframes.numel() > 0) is bool(expect_audio_data)
assert (aframe_pts.numel() > 0) is bool(expect_audio_data)
assert (atimebase.numel() > 0) is bool(expect_audio_data)
assert (asample_rate.numel() > 0) is bool(expect_audio_data)
def test_read_video_from_file_rescale_min_dimension(self):
@pytest.mark.parametrize("test_video", test_videos.keys())
def test_read_video_from_file_rescale_min_dimension(self, test_video):
"""
Test the case when decoder starts with a video file to decode frames, and
video min dimension between height and width is set.
......@@ -535,12 +537,11 @@ class TestVideoReader:
audio_start_pts, audio_end_pts = 0, -1
audio_timebase_num, audio_timebase_den = 0, 1
for test_video, _config in test_videos.items():
full_path = os.path.join(VIDEO_DIR, test_video)
tv_result = torch.ops.video_reader.read_video_from_file(
full_path,
seek_frame_margin,
SEEK_FRAME_MARGIN,
0, # getPtsOnly
1, # readVideoStream
width,
......@@ -561,7 +562,8 @@ class TestVideoReader:
)
assert min_dimension == min(tv_result[0].size(1), tv_result[0].size(2))
def test_read_video_from_file_rescale_max_dimension(self):
@pytest.mark.parametrize("test_video", test_videos.keys())
def test_read_video_from_file_rescale_max_dimension(self, test_video):
"""
Test the case when decoder starts with a video file to decode frames, and
video max dimension between height and width is set.
......@@ -575,12 +577,11 @@ class TestVideoReader:
audio_start_pts, audio_end_pts = 0, -1
audio_timebase_num, audio_timebase_den = 0, 1
for test_video, _config in test_videos.items():
full_path = os.path.join(VIDEO_DIR, test_video)
tv_result = torch.ops.video_reader.read_video_from_file(
full_path,
seek_frame_margin,
SEEK_FRAME_MARGIN,
0, # getPtsOnly
1, # readVideoStream
width,
......@@ -601,7 +602,8 @@ class TestVideoReader:
)
assert max_dimension == max(tv_result[0].size(1), tv_result[0].size(2))
def test_read_video_from_file_rescale_both_min_max_dimension(self):
@pytest.mark.parametrize("test_video", test_videos.keys())
def test_read_video_from_file_rescale_both_min_max_dimension(self, test_video):
"""
Test the case when decoder starts with a video file to decode frames, and
both video min and max dimensions between height and width are set.
......@@ -615,12 +617,11 @@ class TestVideoReader:
audio_start_pts, audio_end_pts = 0, -1
audio_timebase_num, audio_timebase_den = 0, 1
for test_video, _config in test_videos.items():
full_path = os.path.join(VIDEO_DIR, test_video)
tv_result = torch.ops.video_reader.read_video_from_file(
full_path,
seek_frame_margin,
SEEK_FRAME_MARGIN,
0, # getPtsOnly
1, # readVideoStream
width,
......@@ -642,7 +643,8 @@ class TestVideoReader:
assert min_dimension == min(tv_result[0].size(1), tv_result[0].size(2))
assert max_dimension == max(tv_result[0].size(1), tv_result[0].size(2))
def test_read_video_from_file_rescale_width(self):
@pytest.mark.parametrize("test_video", test_videos.keys())
def test_read_video_from_file_rescale_width(self, test_video):
"""
Test the case when decoder starts with a video file to decode frames, and
video width is set.
......@@ -656,12 +658,11 @@ class TestVideoReader:
audio_start_pts, audio_end_pts = 0, -1
audio_timebase_num, audio_timebase_den = 0, 1
for test_video, _config in test_videos.items():
full_path = os.path.join(VIDEO_DIR, test_video)
tv_result = torch.ops.video_reader.read_video_from_file(
full_path,
seek_frame_margin,
SEEK_FRAME_MARGIN,
0, # getPtsOnly
1, # readVideoStream
width,
......@@ -682,7 +683,8 @@ class TestVideoReader:
)
assert tv_result[0].size(2) == width
def test_read_video_from_file_rescale_height(self):
@pytest.mark.parametrize("test_video", test_videos.keys())
def test_read_video_from_file_rescale_height(self, test_video):
"""
Test the case when decoder starts with a video file to decode frames, and
video height is set.
......@@ -696,12 +698,11 @@ class TestVideoReader:
audio_start_pts, audio_end_pts = 0, -1
audio_timebase_num, audio_timebase_den = 0, 1
for test_video, _config in test_videos.items():
full_path = os.path.join(VIDEO_DIR, test_video)
tv_result = torch.ops.video_reader.read_video_from_file(
full_path,
seek_frame_margin,
SEEK_FRAME_MARGIN,
0, # getPtsOnly
1, # readVideoStream
width,
......@@ -722,7 +723,8 @@ class TestVideoReader:
)
assert tv_result[0].size(1) == height
def test_read_video_from_file_rescale_width_and_height(self):
@pytest.mark.parametrize("test_video", test_videos.keys())
def test_read_video_from_file_rescale_width_and_height(self, test_video):
"""
Test the case when decoder starts with a video file to decode frames, and
both video height and width are set.
......@@ -736,12 +738,11 @@ class TestVideoReader:
audio_start_pts, audio_end_pts = 0, -1
audio_timebase_num, audio_timebase_den = 0, 1
for test_video, _config in test_videos.items():
full_path = os.path.join(VIDEO_DIR, test_video)
tv_result = torch.ops.video_reader.read_video_from_file(
full_path,
seek_frame_margin,
SEEK_FRAME_MARGIN,
0, # getPtsOnly
1, # readVideoStream
width,
......@@ -763,13 +764,13 @@ class TestVideoReader:
assert tv_result[0].size(1) == height
assert tv_result[0].size(2) == width
def test_read_video_from_file_audio_resampling(self):
@pytest.mark.parametrize("test_video", test_videos.keys())
@pytest.mark.parametrize("samples", [9600, 96000])
def test_read_video_from_file_audio_resampling(self, test_video, samples):
"""
Test the case when decoder starts with a video file to decode frames, and
audio waveforms are resampled
"""
for samples in [9600, 96000]: # downsampling # upsampling
# video related
width, height, min_dimension, max_dimension = 0, 0, 0, 0
video_start_pts, video_end_pts = 0, -1
......@@ -779,12 +780,11 @@ class TestVideoReader:
audio_start_pts, audio_end_pts = 0, -1
audio_timebase_num, audio_timebase_den = 0, 1
for test_video, _config in test_videos.items():
full_path = os.path.join(VIDEO_DIR, test_video)
tv_result = torch.ops.video_reader.read_video_from_file(
full_path,
seek_frame_margin,
SEEK_FRAME_MARGIN,
0, # getPtsOnly
1, # readVideoStream
width,
......@@ -822,7 +822,8 @@ class TestVideoReader:
duration = float(aframe_pts[-1]) * float(atimebase[0]) / float(atimebase[1])
assert aframes.size(0) == approx(int(duration * asample_rate.item()), abs=0.1 * asample_rate.item())
def test_compare_read_video_from_memory_and_file(self):
@pytest.mark.parametrize("test_video,config", test_videos.items())
def test_compare_read_video_from_memory_and_file(self, test_video, config):
"""
Test the case when video is already in memory, and decoder reads data in memory
"""
......@@ -835,13 +836,12 @@ class TestVideoReader:
audio_start_pts, audio_end_pts = 0, -1
audio_timebase_num, audio_timebase_den = 0, 1
for test_video, config in test_videos.items():
full_path, video_tensor = _get_video_tensor(VIDEO_DIR, test_video)
# pass 1: decode all frames using cpp decoder
tv_result_memory = torch.ops.video_reader.read_video_from_memory(
video_tensor,
seek_frame_margin,
SEEK_FRAME_MARGIN,
0, # getPtsOnly
1, # readVideoStream
width,
......@@ -864,7 +864,7 @@ class TestVideoReader:
# pass 2: decode all frames from file
tv_result_file = torch.ops.video_reader.read_video_from_file(
full_path,
seek_frame_margin,
SEEK_FRAME_MARGIN,
0, # getPtsOnly
1, # readVideoStream
width,
......@@ -888,7 +888,8 @@ class TestVideoReader:
# finally, compare results decoded from memory and file
self.compare_decoding_result(tv_result_memory, tv_result_file)
def test_read_video_from_memory(self):
@pytest.mark.parametrize("test_video,config", test_videos.items())
def test_read_video_from_memory(self, test_video, config):
"""
Test the case when video is already in memory, and decoder reads data in memory
"""
......@@ -901,13 +902,12 @@ class TestVideoReader:
audio_start_pts, audio_end_pts = 0, -1
audio_timebase_num, audio_timebase_den = 0, 1
for test_video, config in test_videos.items():
full_path, video_tensor = _get_video_tensor(VIDEO_DIR, test_video)
# pass 1: decode all frames using cpp decoder
tv_result = torch.ops.video_reader.read_video_from_memory(
video_tensor,
seek_frame_margin,
SEEK_FRAME_MARGIN,
0, # getPtsOnly
1, # readVideoStream
width,
......@@ -932,7 +932,8 @@ class TestVideoReader:
self.check_separate_decoding_result(tv_result, config)
self.compare_decoding_result(tv_result, pyav_result, config)
def test_read_video_from_memory_get_pts_only(self):
@pytest.mark.parametrize("test_video,config", test_videos.items())
def test_read_video_from_memory_get_pts_only(self, test_video, config):
"""
Test the case when video is already in memory, and decoder reads data in memory.
Compare frame pts between decoding for pts only and full decoding
......@@ -947,13 +948,12 @@ class TestVideoReader:
audio_start_pts, audio_end_pts = 0, -1
audio_timebase_num, audio_timebase_den = 0, 1
for test_video, config in test_videos.items():
full_path, video_tensor = _get_video_tensor(VIDEO_DIR, test_video)
_, video_tensor = _get_video_tensor(VIDEO_DIR, test_video)
# pass 1: decode all frames using cpp decoder
tv_result = torch.ops.video_reader.read_video_from_memory(
video_tensor,
seek_frame_margin,
SEEK_FRAME_MARGIN,
0, # getPtsOnly
1, # readVideoStream
width,
......@@ -977,7 +977,7 @@ class TestVideoReader:
# pass 2: decode all frames to get PTS only using cpp decoder
tv_result_pts_only = torch.ops.video_reader.read_video_from_memory(
video_tensor,
seek_frame_margin,
SEEK_FRAME_MARGIN,
1, # getPtsOnly
1, # readVideoStream
width,
......@@ -1001,13 +1001,14 @@ class TestVideoReader:
assert not tv_result_pts_only[5].numel()
self.compare_decoding_result(tv_result, tv_result_pts_only)
def test_read_video_in_range_from_memory(self):
@pytest.mark.parametrize("test_video,config", test_videos.items())
@pytest.mark.parametrize("num_frames", [4, 8, 16, 32, 64, 128])
def test_read_video_in_range_from_memory(self, test_video, config, num_frames):
"""
Test the case when video is already in memory, and decoder reads data in memory.
In addition, decoder takes meaningful start- and end PTS as input, and decode
frames within that interval
"""
for test_video, config in test_videos.items():
full_path, video_tensor = _get_video_tensor(VIDEO_DIR, test_video)
# video related
width, height, min_dimension, max_dimension = 0, 0, 0, 0
......@@ -1020,7 +1021,7 @@ class TestVideoReader:
# pass 1: decode all frames using new decoder
tv_result = torch.ops.video_reader.read_video_from_memory(
video_tensor,
seek_frame_margin,
SEEK_FRAME_MARGIN,
0, # getPtsOnly
1, # readVideoStream
width,
......@@ -1053,10 +1054,9 @@ class TestVideoReader:
) = tv_result
assert abs(config.video_fps - vfps.item()) < 0.01
for num_frames in [4, 8, 16, 32, 64, 128]:
start_pts_ind_max = vframe_pts.size(0) - num_frames
if start_pts_ind_max <= 0:
continue
return
# randomly pick start pts
start_pts_ind = randint(0, start_pts_ind_max)
end_pts_ind = start_pts_ind + num_frames - 1
......@@ -1083,7 +1083,7 @@ class TestVideoReader:
# pass 2: decode frames in the randomly generated range
tv_result = torch.ops.video_reader.read_video_from_memory(
video_tensor,
seek_frame_margin,
SEEK_FRAME_MARGIN,
0, # getPtsOnly
1, # readVideoStream
width,
......@@ -1147,34 +1147,35 @@ class TestVideoReader:
# and PyAv
self.compare_decoding_result(tv_result, pyav_result, config)
def test_probe_video_from_file(self):
@pytest.mark.parametrize("test_video,config", test_videos.items())
def test_probe_video_from_file(self, test_video, config):
"""
Test the case when decoder probes a video file
"""
for test_video, config in test_videos.items():
full_path = os.path.join(VIDEO_DIR, test_video)
probe_result = torch.ops.video_reader.probe_video_from_file(full_path)
self.check_probe_result(probe_result, config)
def test_probe_video_from_memory(self):
@pytest.mark.parametrize("test_video,config", test_videos.items())
def test_probe_video_from_memory(self, test_video, config):
"""
Test the case when decoder probes a video in memory
"""
for test_video, config in test_videos.items():
full_path, video_tensor = _get_video_tensor(VIDEO_DIR, test_video)
_, video_tensor = _get_video_tensor(VIDEO_DIR, test_video)
probe_result = torch.ops.video_reader.probe_video_from_memory(video_tensor)
self.check_probe_result(probe_result, config)
def test_probe_video_from_memory_script(self):
@pytest.mark.parametrize("test_video,config", test_videos.items())
def test_probe_video_from_memory_script(self, test_video, config):
scripted_fun = torch.jit.script(io._probe_video_from_memory)
assert scripted_fun is not None
for test_video, config in test_videos.items():
full_path, video_tensor = _get_video_tensor(VIDEO_DIR, test_video)
_, video_tensor = _get_video_tensor(VIDEO_DIR, test_video)
probe_result = scripted_fun(video_tensor)
self.check_meta_result(probe_result, config)
def test_read_video_from_memory_scripted(self):
@pytest.mark.parametrize("test_video", test_videos.keys())
def test_read_video_from_memory_scripted(self, test_video):
"""
Test the case when video is already in memory, and decoder reads data in memory
"""
......@@ -1190,13 +1191,12 @@ class TestVideoReader:
scripted_fun = torch.jit.script(io._read_video_from_memory)
assert scripted_fun is not None
for test_video, _config in test_videos.items():
full_path, video_tensor = _get_video_tensor(VIDEO_DIR, test_video)
_, video_tensor = _get_video_tensor(VIDEO_DIR, test_video)
# decode all frames using cpp decoder
scripted_fun(
video_tensor,
seek_frame_margin,
SEEK_FRAME_MARGIN,
1, # readVideoStream
width,
height,
......@@ -1223,30 +1223,28 @@ class TestVideoReader:
with pytest.raises(RuntimeError):
io.read_video("foo.mp4")
def test_audio_present_pts(self):
@pytest.mark.parametrize("test_video", test_videos.keys())
@pytest.mark.parametrize("backend", ["video_reader", "pyav"])
@pytest.mark.parametrize("start_offset", [0, 1000])
@pytest.mark.parametrize("end_offset", [3000, None])
def test_audio_present_pts(self, test_video, backend, start_offset, end_offset):
"""Test if audio frames are returned with pts unit."""
backends = ["video_reader", "pyav"]
start_offsets = [0, 1000]
end_offsets = [3000, None]
for test_video, _ in test_videos.items():
full_path = os.path.join(VIDEO_DIR, test_video)
container = av.open(full_path)
if container.streams.audio:
for backend, start_offset, end_offset in itertools.product(backends, start_offsets, end_offsets):
set_video_backend(backend)
_, audio, _ = io.read_video(full_path, start_offset, end_offset, pts_unit="pts")
assert all([dimension > 0 for dimension in audio.shape[:2]])
def test_audio_present_sec(self):
@pytest.mark.parametrize("test_video", test_videos.keys())
@pytest.mark.parametrize("backend", ["video_reader", "pyav"])
@pytest.mark.parametrize("start_offset", [0, 0.1])
@pytest.mark.parametrize("end_offset", [0.3, None])
def test_audio_present_sec(self, test_video, backend, start_offset, end_offset):
"""Test if audio frames are returned with sec unit."""
backends = ["video_reader", "pyav"]
start_offsets = [0, 0.1]
end_offsets = [0.3, None]
for test_video, _ in test_videos.items():
full_path = os.path.join(VIDEO_DIR, test_video)
container = av.open(full_path)
if container.streams.audio:
for backend, start_offset, end_offset in itertools.product(backends, start_offsets, end_offsets):
set_video_backend(backend)
_, audio, _ = io.read_video(full_path, start_offset, end_offset, pts_unit="sec")
assert all([dimension > 0 for dimension in audio.shape[:2]])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment