Unverified commit c50d4884, authored by Prabhat Roy, committed by GitHub

Improve test_video_reader (#5498)

* Improve test_video_reader

* Fix linter error
parent e3f1a822
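
The diff below converts the per-video `for` loops inside each test into `@pytest.mark.parametrize` decorators, so every video (and every stream/sample combination) becomes its own pytest test case. As a rough illustration of the pattern only, here is a minimal sketch with placeholder data, not the torchvision test code itself:

    # Minimal sketch of the refactoring pattern applied throughout this diff
    # (placeholder dictionary; the real suite builds test_videos from checked-in clips).
    import pytest

    test_videos = {"clip_a.mp4": "config_a", "clip_b.mp4": "config_b"}


    # Before: one test that loops over every video internally; a failure on the
    # second clip is only reported after the first clip has been processed.
    def test_decode_loop():
        for test_video, config in test_videos.items():
            assert test_video.endswith(".mp4")


    # After: pytest expands the decorator into one independent test case per
    # (test_video, config) pair, each selectable and reported on its own.
    @pytest.mark.parametrize("test_video,config", test_videos.items())
    def test_decode_parametrized(test_video, config):
        assert test_video.endswith(".mp4")

Parametrized cases show up individually in test reports and can be filtered with `pytest -k`, which is the practical payoff of the change.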
 import collections
-import itertools
 import math
 import os
 from fractions import Fraction
@@ -112,7 +111,7 @@ DecoderResult = collections.namedtuple("DecoderResult", "vframes vframe_pts vtim
 # av_seek_frame is imprecise so seek to a timestamp earlier by a margin
 # The unit of margin is second
-seek_frame_margin = 0.25
+SEEK_FRAME_MARGIN = 0.25


 def _read_from_stream(container, start_pts, end_pts, stream, stream_name, buffer_size=4):
@@ -369,7 +368,8 @@ class TestVideoReader:
         assert_equal(atimebase, ref_result.atimebase)

-    def test_stress_test_read_video_from_file(self):
+    @pytest.mark.parametrize("test_video", test_videos.keys())
+    def test_stress_test_read_video_from_file(self, test_video):
         pytest.skip(
             "This stress test will iteratively decode the same set of videos."
             "It helps to detect memory leak but it takes lots of time to run."
@@ -386,13 +386,12 @@ class TestVideoReader:
         audio_timebase_num, audio_timebase_den = 0, 1

         for _i in range(num_iter):
-            for test_video, _config in test_videos.items():
             full_path = os.path.join(VIDEO_DIR, test_video)
             # pass 1: decode all frames using new decoder
             torch.ops.video_reader.read_video_from_file(
                 full_path,
-                seek_frame_margin,
+                SEEK_FRAME_MARGIN,
                 0,  # getPtsOnly
                 1,  # readVideoStream
                 width,
@@ -412,7 +411,8 @@ class TestVideoReader:
                 audio_timebase_den,
             )

-    def test_read_video_from_file(self):
+    @pytest.mark.parametrize("test_video,config", test_videos.items())
+    def test_read_video_from_file(self, test_video, config):
         """
         Test the case when decoder starts with a video file to decode frames.
         """
@@ -425,13 +425,12 @@ class TestVideoReader:
         audio_start_pts, audio_end_pts = 0, -1
         audio_timebase_num, audio_timebase_den = 0, 1

-        for test_video, config in test_videos.items():
         full_path = os.path.join(VIDEO_DIR, test_video)
         # pass 1: decode all frames using new decoder
         tv_result = torch.ops.video_reader.read_video_from_file(
             full_path,
-            seek_frame_margin,
+            SEEK_FRAME_MARGIN,
             0,  # getPtsOnly
             1,  # readVideoStream
             width,
@@ -457,7 +456,11 @@ class TestVideoReader:
         # compare decoding results
         self.compare_decoding_result(tv_result, pyav_result, config)

-    def test_read_video_from_file_read_single_stream_only(self):
+    @pytest.mark.parametrize("test_video,config", test_videos.items())
+    @pytest.mark.parametrize("read_video_stream,read_audio_stream", [(1, 0), (0, 1)])
+    def test_read_video_from_file_read_single_stream_only(
+        self, test_video, config, read_video_stream, read_audio_stream
+    ):
         """
         Test the case when decoder starts with a video file to decode frames, and
         only reads video stream and ignores audio stream
@@ -471,15 +474,13 @@ class TestVideoReader:
         audio_start_pts, audio_end_pts = 0, -1
         audio_timebase_num, audio_timebase_den = 0, 1

-        for test_video, config in test_videos.items():
         full_path = os.path.join(VIDEO_DIR, test_video)
-            for readVideoStream, readAudioStream in [(1, 0), (0, 1)]:
         # decode all frames using new decoder
         tv_result = torch.ops.video_reader.read_video_from_file(
             full_path,
-            seek_frame_margin,
+            SEEK_FRAME_MARGIN,
             0,  # getPtsOnly
-            readVideoStream,
+            read_video_stream,
             width,
             height,
             min_dimension,
@@ -488,7 +489,7 @@ class TestVideoReader:
             video_end_pts,
             video_timebase_num,
             video_timebase_den,
-            readAudioStream,
+            read_audio_stream,
             samples,
             channels,
             audio_start_pts,
@@ -510,18 +511,19 @@ class TestVideoReader:
             aduration,
         ) = tv_result

-        assert (vframes.numel() > 0) is bool(readVideoStream)
-        assert (vframe_pts.numel() > 0) is bool(readVideoStream)
-        assert (vtimebase.numel() > 0) is bool(readVideoStream)
-        assert (vfps.numel() > 0) is bool(readVideoStream)
+        assert (vframes.numel() > 0) is bool(read_video_stream)
+        assert (vframe_pts.numel() > 0) is bool(read_video_stream)
+        assert (vtimebase.numel() > 0) is bool(read_video_stream)
+        assert (vfps.numel() > 0) is bool(read_video_stream)

-        expect_audio_data = readAudioStream == 1 and config.audio_sample_rate is not None
+        expect_audio_data = read_audio_stream == 1 and config.audio_sample_rate is not None
         assert (aframes.numel() > 0) is bool(expect_audio_data)
         assert (aframe_pts.numel() > 0) is bool(expect_audio_data)
         assert (atimebase.numel() > 0) is bool(expect_audio_data)
         assert (asample_rate.numel() > 0) is bool(expect_audio_data)

-    def test_read_video_from_file_rescale_min_dimension(self):
+    @pytest.mark.parametrize("test_video", test_videos.keys())
+    def test_read_video_from_file_rescale_min_dimension(self, test_video):
         """
         Test the case when decoder starts with a video file to decode frames, and
         video min dimension between height and width is set.
@@ -535,12 +537,11 @@ class TestVideoReader:
         audio_start_pts, audio_end_pts = 0, -1
         audio_timebase_num, audio_timebase_den = 0, 1

-        for test_video, _config in test_videos.items():
         full_path = os.path.join(VIDEO_DIR, test_video)
         tv_result = torch.ops.video_reader.read_video_from_file(
             full_path,
-            seek_frame_margin,
+            SEEK_FRAME_MARGIN,
             0,  # getPtsOnly
             1,  # readVideoStream
             width,
@@ -561,7 +562,8 @@ class TestVideoReader:
         )
         assert min_dimension == min(tv_result[0].size(1), tv_result[0].size(2))

-    def test_read_video_from_file_rescale_max_dimension(self):
+    @pytest.mark.parametrize("test_video", test_videos.keys())
+    def test_read_video_from_file_rescale_max_dimension(self, test_video):
         """
         Test the case when decoder starts with a video file to decode frames, and
         video min dimension between height and width is set.
@@ -575,12 +577,11 @@ class TestVideoReader:
         audio_start_pts, audio_end_pts = 0, -1
         audio_timebase_num, audio_timebase_den = 0, 1

-        for test_video, _config in test_videos.items():
         full_path = os.path.join(VIDEO_DIR, test_video)
         tv_result = torch.ops.video_reader.read_video_from_file(
             full_path,
-            seek_frame_margin,
+            SEEK_FRAME_MARGIN,
             0,  # getPtsOnly
             1,  # readVideoStream
             width,
@@ -601,7 +602,8 @@ class TestVideoReader:
         )
         assert max_dimension == max(tv_result[0].size(1), tv_result[0].size(2))

-    def test_read_video_from_file_rescale_both_min_max_dimension(self):
+    @pytest.mark.parametrize("test_video", test_videos.keys())
+    def test_read_video_from_file_rescale_both_min_max_dimension(self, test_video):
         """
         Test the case when decoder starts with a video file to decode frames, and
         video min dimension between height and width is set.
@@ -615,12 +617,11 @@ class TestVideoReader:
         audio_start_pts, audio_end_pts = 0, -1
         audio_timebase_num, audio_timebase_den = 0, 1

-        for test_video, _config in test_videos.items():
         full_path = os.path.join(VIDEO_DIR, test_video)
         tv_result = torch.ops.video_reader.read_video_from_file(
             full_path,
-            seek_frame_margin,
+            SEEK_FRAME_MARGIN,
             0,  # getPtsOnly
             1,  # readVideoStream
             width,
@@ -642,7 +643,8 @@ class TestVideoReader:
         assert min_dimension == min(tv_result[0].size(1), tv_result[0].size(2))
         assert max_dimension == max(tv_result[0].size(1), tv_result[0].size(2))

-    def test_read_video_from_file_rescale_width(self):
+    @pytest.mark.parametrize("test_video", test_videos.keys())
+    def test_read_video_from_file_rescale_width(self, test_video):
         """
         Test the case when decoder starts with a video file to decode frames, and
         video width is set.
@@ -656,12 +658,11 @@ class TestVideoReader:
         audio_start_pts, audio_end_pts = 0, -1
         audio_timebase_num, audio_timebase_den = 0, 1

-        for test_video, _config in test_videos.items():
         full_path = os.path.join(VIDEO_DIR, test_video)
         tv_result = torch.ops.video_reader.read_video_from_file(
             full_path,
-            seek_frame_margin,
+            SEEK_FRAME_MARGIN,
             0,  # getPtsOnly
             1,  # readVideoStream
             width,
@@ -682,7 +683,8 @@ class TestVideoReader:
         )
         assert tv_result[0].size(2) == width

-    def test_read_video_from_file_rescale_height(self):
+    @pytest.mark.parametrize("test_video", test_videos.keys())
+    def test_read_video_from_file_rescale_height(self, test_video):
         """
         Test the case when decoder starts with a video file to decode frames, and
         video height is set.
@@ -696,12 +698,11 @@ class TestVideoReader:
         audio_start_pts, audio_end_pts = 0, -1
         audio_timebase_num, audio_timebase_den = 0, 1

-        for test_video, _config in test_videos.items():
         full_path = os.path.join(VIDEO_DIR, test_video)
         tv_result = torch.ops.video_reader.read_video_from_file(
             full_path,
-            seek_frame_margin,
+            SEEK_FRAME_MARGIN,
             0,  # getPtsOnly
             1,  # readVideoStream
             width,
@@ -722,7 +723,8 @@ class TestVideoReader:
         )
         assert tv_result[0].size(1) == height

-    def test_read_video_from_file_rescale_width_and_height(self):
+    @pytest.mark.parametrize("test_video", test_videos.keys())
+    def test_read_video_from_file_rescale_width_and_height(self, test_video):
         """
         Test the case when decoder starts with a video file to decode frames, and
         both video height and width are set.
@@ -736,12 +738,11 @@ class TestVideoReader:
         audio_start_pts, audio_end_pts = 0, -1
         audio_timebase_num, audio_timebase_den = 0, 1

-        for test_video, _config in test_videos.items():
         full_path = os.path.join(VIDEO_DIR, test_video)
         tv_result = torch.ops.video_reader.read_video_from_file(
             full_path,
-            seek_frame_margin,
+            SEEK_FRAME_MARGIN,
             0,  # getPtsOnly
             1,  # readVideoStream
             width,
@@ -763,13 +764,13 @@ class TestVideoReader:
         assert tv_result[0].size(1) == height
         assert tv_result[0].size(2) == width

-    def test_read_video_from_file_audio_resampling(self):
+    @pytest.mark.parametrize("test_video", test_videos.keys())
+    @pytest.mark.parametrize("samples", [9600, 96000])
+    def test_read_video_from_file_audio_resampling(self, test_video, samples):
         """
         Test the case when decoder starts with a video file to decode frames, and
         audio waveform are resampled
         """
-        for samples in [9600, 96000]:  # downsampling  # upsampling
         # video related
         width, height, min_dimension, max_dimension = 0, 0, 0, 0
         video_start_pts, video_end_pts = 0, -1
@@ -779,12 +780,11 @@ class TestVideoReader:
         audio_start_pts, audio_end_pts = 0, -1
         audio_timebase_num, audio_timebase_den = 0, 1

-            for test_video, _config in test_videos.items():
         full_path = os.path.join(VIDEO_DIR, test_video)
         tv_result = torch.ops.video_reader.read_video_from_file(
             full_path,
-            seek_frame_margin,
+            SEEK_FRAME_MARGIN,
             0,  # getPtsOnly
             1,  # readVideoStream
             width,
@@ -822,7 +822,8 @@ class TestVideoReader:
         duration = float(aframe_pts[-1]) * float(atimebase[0]) / float(atimebase[1])
         assert aframes.size(0) == approx(int(duration * asample_rate.item()), abs=0.1 * asample_rate.item())

-    def test_compare_read_video_from_memory_and_file(self):
+    @pytest.mark.parametrize("test_video,config", test_videos.items())
+    def test_compare_read_video_from_memory_and_file(self, test_video, config):
         """
         Test the case when video is already in memory, and decoder reads data in memory
         """
@@ -835,13 +836,12 @@ class TestVideoReader:
         audio_start_pts, audio_end_pts = 0, -1
         audio_timebase_num, audio_timebase_den = 0, 1

-        for test_video, config in test_videos.items():
         full_path, video_tensor = _get_video_tensor(VIDEO_DIR, test_video)
         # pass 1: decode all frames using cpp decoder
         tv_result_memory = torch.ops.video_reader.read_video_from_memory(
             video_tensor,
-            seek_frame_margin,
+            SEEK_FRAME_MARGIN,
             0,  # getPtsOnly
             1,  # readVideoStream
             width,
@@ -864,7 +864,7 @@ class TestVideoReader:
         # pass 2: decode all frames from file
         tv_result_file = torch.ops.video_reader.read_video_from_file(
             full_path,
-            seek_frame_margin,
+            SEEK_FRAME_MARGIN,
             0,  # getPtsOnly
             1,  # readVideoStream
             width,
@@ -888,7 +888,8 @@ class TestVideoReader:
         # finally, compare results decoded from memory and file
         self.compare_decoding_result(tv_result_memory, tv_result_file)

-    def test_read_video_from_memory(self):
+    @pytest.mark.parametrize("test_video,config", test_videos.items())
+    def test_read_video_from_memory(self, test_video, config):
         """
         Test the case when video is already in memory, and decoder reads data in memory
         """
@@ -901,13 +902,12 @@ class TestVideoReader:
         audio_start_pts, audio_end_pts = 0, -1
         audio_timebase_num, audio_timebase_den = 0, 1

-        for test_video, config in test_videos.items():
         full_path, video_tensor = _get_video_tensor(VIDEO_DIR, test_video)
         # pass 1: decode all frames using cpp decoder
         tv_result = torch.ops.video_reader.read_video_from_memory(
             video_tensor,
-            seek_frame_margin,
+            SEEK_FRAME_MARGIN,
             0,  # getPtsOnly
             1,  # readVideoStream
             width,
@@ -932,7 +932,8 @@ class TestVideoReader:
         self.check_separate_decoding_result(tv_result, config)
         self.compare_decoding_result(tv_result, pyav_result, config)

-    def test_read_video_from_memory_get_pts_only(self):
+    @pytest.mark.parametrize("test_video,config", test_videos.items())
+    def test_read_video_from_memory_get_pts_only(self, test_video, config):
         """
         Test the case when video is already in memory, and decoder reads data in memory.
         Compare frame pts between decoding for pts only and full decoding
@@ -947,13 +948,12 @@ class TestVideoReader:
         audio_start_pts, audio_end_pts = 0, -1
         audio_timebase_num, audio_timebase_den = 0, 1

-        for test_video, config in test_videos.items():
-            full_path, video_tensor = _get_video_tensor(VIDEO_DIR, test_video)
+        _, video_tensor = _get_video_tensor(VIDEO_DIR, test_video)
         # pass 1: decode all frames using cpp decoder
         tv_result = torch.ops.video_reader.read_video_from_memory(
             video_tensor,
-            seek_frame_margin,
+            SEEK_FRAME_MARGIN,
             0,  # getPtsOnly
             1,  # readVideoStream
             width,
@@ -977,7 +977,7 @@ class TestVideoReader:
         # pass 2: decode all frames to get PTS only using cpp decoder
         tv_result_pts_only = torch.ops.video_reader.read_video_from_memory(
             video_tensor,
-            seek_frame_margin,
+            SEEK_FRAME_MARGIN,
             1,  # getPtsOnly
             1,  # readVideoStream
             width,
@@ -1001,13 +1001,14 @@ class TestVideoReader:
         assert not tv_result_pts_only[5].numel()
         self.compare_decoding_result(tv_result, tv_result_pts_only)

-    def test_read_video_in_range_from_memory(self):
+    @pytest.mark.parametrize("test_video,config", test_videos.items())
+    @pytest.mark.parametrize("num_frames", [4, 8, 16, 32, 64, 128])
+    def test_read_video_in_range_from_memory(self, test_video, config, num_frames):
         """
         Test the case when video is already in memory, and decoder reads data in memory.
         In addition, decoder takes meaningful start- and end PTS as input, and decode
         frames within that interval
         """
-        for test_video, config in test_videos.items():
         full_path, video_tensor = _get_video_tensor(VIDEO_DIR, test_video)
         # video related
         width, height, min_dimension, max_dimension = 0, 0, 0, 0
@@ -1020,7 +1021,7 @@ class TestVideoReader:
         # pass 1: decode all frames using new decoder
         tv_result = torch.ops.video_reader.read_video_from_memory(
             video_tensor,
-            seek_frame_margin,
+            SEEK_FRAME_MARGIN,
             0,  # getPtsOnly
             1,  # readVideoStream
             width,
@@ -1053,10 +1054,9 @@ class TestVideoReader:
         ) = tv_result
         assert abs(config.video_fps - vfps.item()) < 0.01

-            for num_frames in [4, 8, 16, 32, 64, 128]:
         start_pts_ind_max = vframe_pts.size(0) - num_frames
         if start_pts_ind_max <= 0:
-                    continue
+            return
         # randomly pick start pts
         start_pts_ind = randint(0, start_pts_ind_max)
         end_pts_ind = start_pts_ind + num_frames - 1
@@ -1083,7 +1083,7 @@ class TestVideoReader:
         # pass 2: decode frames in the randomly generated range
         tv_result = torch.ops.video_reader.read_video_from_memory(
             video_tensor,
-            seek_frame_margin,
+            SEEK_FRAME_MARGIN,
             0,  # getPtsOnly
             1,  # readVideoStream
             width,
@@ -1147,34 +1147,35 @@ class TestVideoReader:
         # and PyAv
         self.compare_decoding_result(tv_result, pyav_result, config)

-    def test_probe_video_from_file(self):
+    @pytest.mark.parametrize("test_video,config", test_videos.items())
+    def test_probe_video_from_file(self, test_video, config):
         """
         Test the case when decoder probes a video file
         """
-        for test_video, config in test_videos.items():
         full_path = os.path.join(VIDEO_DIR, test_video)
         probe_result = torch.ops.video_reader.probe_video_from_file(full_path)
         self.check_probe_result(probe_result, config)

-    def test_probe_video_from_memory(self):
+    @pytest.mark.parametrize("test_video,config", test_videos.items())
+    def test_probe_video_from_memory(self, test_video, config):
         """
         Test the case when decoder probes a video in memory
         """
-        for test_video, config in test_videos.items():
-            full_path, video_tensor = _get_video_tensor(VIDEO_DIR, test_video)
+        _, video_tensor = _get_video_tensor(VIDEO_DIR, test_video)
         probe_result = torch.ops.video_reader.probe_video_from_memory(video_tensor)
         self.check_probe_result(probe_result, config)

-    def test_probe_video_from_memory_script(self):
+    @pytest.mark.parametrize("test_video,config", test_videos.items())
+    def test_probe_video_from_memory_script(self, test_video, config):
         scripted_fun = torch.jit.script(io._probe_video_from_memory)
         assert scripted_fun is not None

-        for test_video, config in test_videos.items():
-            full_path, video_tensor = _get_video_tensor(VIDEO_DIR, test_video)
+        _, video_tensor = _get_video_tensor(VIDEO_DIR, test_video)
         probe_result = scripted_fun(video_tensor)
         self.check_meta_result(probe_result, config)

-    def test_read_video_from_memory_scripted(self):
+    @pytest.mark.parametrize("test_video", test_videos.keys())
+    def test_read_video_from_memory_scripted(self, test_video):
         """
         Test the case when video is already in memory, and decoder reads data in memory
         """
@@ -1190,13 +1191,12 @@ class TestVideoReader:
         scripted_fun = torch.jit.script(io._read_video_from_memory)
         assert scripted_fun is not None

-        for test_video, _config in test_videos.items():
-            full_path, video_tensor = _get_video_tensor(VIDEO_DIR, test_video)
+        _, video_tensor = _get_video_tensor(VIDEO_DIR, test_video)
         # decode all frames using cpp decoder
         scripted_fun(
             video_tensor,
-            seek_frame_margin,
+            SEEK_FRAME_MARGIN,
             1,  # readVideoStream
             width,
             height,
@@ -1223,30 +1223,28 @@ class TestVideoReader:
         with pytest.raises(RuntimeError):
             io.read_video("foo.mp4")

-    def test_audio_present_pts(self):
+    @pytest.mark.parametrize("test_video", test_videos.keys())
+    @pytest.mark.parametrize("backend", ["video_reader", "pyav"])
+    @pytest.mark.parametrize("start_offset", [0, 1000])
+    @pytest.mark.parametrize("end_offset", [3000, None])
+    def test_audio_present_pts(self, test_video, backend, start_offset, end_offset):
         """Test if audio frames are returned with pts unit."""
-        backends = ["video_reader", "pyav"]
-        start_offsets = [0, 1000]
-        end_offsets = [3000, None]
-        for test_video, _ in test_videos.items():
         full_path = os.path.join(VIDEO_DIR, test_video)
         container = av.open(full_path)
         if container.streams.audio:
-                for backend, start_offset, end_offset in itertools.product(backends, start_offsets, end_offsets):
             set_video_backend(backend)
             _, audio, _ = io.read_video(full_path, start_offset, end_offset, pts_unit="pts")
             assert all([dimension > 0 for dimension in audio.shape[:2]])

-    def test_audio_present_sec(self):
+    @pytest.mark.parametrize("test_video", test_videos.keys())
+    @pytest.mark.parametrize("backend", ["video_reader", "pyav"])
+    @pytest.mark.parametrize("start_offset", [0, 0.1])
+    @pytest.mark.parametrize("end_offset", [0.3, None])
+    def test_audio_present_sec(self, test_video, backend, start_offset, end_offset):
         """Test if audio frames are returned with sec unit."""
-        backends = ["video_reader", "pyav"]
-        start_offsets = [0, 0.1]
-        end_offsets = [0.3, None]
-        for test_video, _ in test_videos.items():
         full_path = os.path.join(VIDEO_DIR, test_video)
         container = av.open(full_path)
         if container.streams.audio:
-                for backend, start_offset, end_offset in itertools.product(backends, start_offsets, end_offsets):
             set_video_backend(backend)
             _, audio, _ = io.read_video(full_path, start_offset, end_offset, pts_unit="sec")
             assert all([dimension > 0 for dimension in audio.shape[:2]])
...
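
In the test_audio_present_* hunks above, the itertools.product loop over backends and offsets is replaced by stacking several parametrize decorators; stacked decorators yield the Cartesian product of their values. A minimal sketch of that equivalence, with placeholder assertions rather than the real test bodies:

    import itertools

    import pytest

    backends = ["video_reader", "pyav"]
    start_offsets = [0, 1000]
    end_offsets = [3000, None]


    # Before: a single test body walking the 2 x 2 x 2 = 8 combinations itself.
    def test_combinations_loop():
        for backend, start_offset, end_offset in itertools.product(backends, start_offsets, end_offsets):
            assert backend in backends


    # After: stacked decorators generate the same 8 combinations as separate
    # test cases, so one failing combination does not hide the others.
    @pytest.mark.parametrize("backend", backends)
    @pytest.mark.parametrize("start_offset", start_offsets)
    @pytest.mark.parametrize("end_offset", end_offsets)
    def test_combinations_parametrized(backend, start_offset, end_offset):
        assert backend in backends

Once the product is expressed through decorators, the `import itertools` at the top of the file is no longer needed, which is why the diff removes it.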