"vscode:/vscode.git/clone" did not exist on "b0e6ccaf57aced24b2ccf444aa09cbcee81ec5e0"
Unverified Commit 2287c8f2 authored by Francisco Massa's avatar Francisco Massa Committed by GitHub
Browse files

Optimize read_video_timestamps for some formats (#1168)

* Optimize read_video_timestamps for some formats

* Add some tests
parent 59c97d77
...@@ -52,12 +52,12 @@ def temp_video(num_frames, height, width, fps, lossless=False, video_codec=None, ...@@ -52,12 +52,12 @@ def temp_video(num_frames, height, width, fps, lossless=False, video_codec=None,
yield f.name, data yield f.name, data
@unittest.skipIf(av is None, "PyAV unavailable")
class Tester(unittest.TestCase): class Tester(unittest.TestCase):
# compression adds artifacts, thus we add a tolerance of # compression adds artifacts, thus we add a tolerance of
# 6 in 0-255 range # 6 in 0-255 range
TOLERANCE = 6 TOLERANCE = 6
@unittest.skipIf(av is None, "PyAV unavailable")
def test_write_read_video(self): def test_write_read_video(self):
with temp_video(10, 300, 300, 5, lossless=True) as (f_name, data): with temp_video(10, 300, 300, 5, lossless=True) as (f_name, data):
lv, _, info = io.read_video(f_name) lv, _, info = io.read_video(f_name)
...@@ -65,7 +65,6 @@ class Tester(unittest.TestCase): ...@@ -65,7 +65,6 @@ class Tester(unittest.TestCase):
self.assertTrue(data.equal(lv)) self.assertTrue(data.equal(lv))
self.assertEqual(info["video_fps"], 5) self.assertEqual(info["video_fps"], 5)
@unittest.skipIf(av is None, "PyAV unavailable")
def test_read_timestamps(self): def test_read_timestamps(self):
with temp_video(10, 300, 300, 5) as (f_name, data): with temp_video(10, 300, 300, 5) as (f_name, data):
pts, _ = io.read_video_timestamps(f_name) pts, _ = io.read_video_timestamps(f_name)
...@@ -81,7 +80,6 @@ class Tester(unittest.TestCase): ...@@ -81,7 +80,6 @@ class Tester(unittest.TestCase):
self.assertEqual(pts, expected_pts) self.assertEqual(pts, expected_pts)
@unittest.skipIf(av is None, "PyAV unavailable")
def test_read_partial_video(self): def test_read_partial_video(self):
with temp_video(10, 300, 300, 5, lossless=True) as (f_name, data): with temp_video(10, 300, 300, 5, lossless=True) as (f_name, data):
pts, _ = io.read_video_timestamps(f_name) pts, _ = io.read_video_timestamps(f_name)
...@@ -96,7 +94,6 @@ class Tester(unittest.TestCase): ...@@ -96,7 +94,6 @@ class Tester(unittest.TestCase):
self.assertEqual(len(lv), 4) self.assertEqual(len(lv), 4)
self.assertTrue(data[4:8].equal(lv)) self.assertTrue(data[4:8].equal(lv))
@unittest.skipIf(av is None, "PyAV unavailable")
def test_read_partial_video_bframes(self): def test_read_partial_video_bframes(self):
# do not use lossless encoding, to test the presence of B-frames # do not use lossless encoding, to test the presence of B-frames
options = {'bframes': '16', 'keyint': '10', 'min-keyint': '4'} options = {'bframes': '16', 'keyint': '10', 'min-keyint': '4'}
...@@ -113,7 +110,6 @@ class Tester(unittest.TestCase): ...@@ -113,7 +110,6 @@ class Tester(unittest.TestCase):
self.assertEqual(len(lv), 4) self.assertEqual(len(lv), 4)
self.assertTrue((data[4:8].float() - lv.float()).abs().max() < self.TOLERANCE) self.assertTrue((data[4:8].float() - lv.float()).abs().max() < self.TOLERANCE)
@unittest.skipIf(av is None, "PyAV unavailable")
def test_read_packed_b_frames_divx_file(self): def test_read_packed_b_frames_divx_file(self):
with get_tmp_dir() as temp_dir: with get_tmp_dir() as temp_dir:
name = "hmdb51_Turnk_r_Pippi_Michel_cartwheel_f_cm_np2_le_med_6.avi" name = "hmdb51_Turnk_r_Pippi_Michel_cartwheel_f_cm_np2_le_med_6.avi"
...@@ -129,6 +125,23 @@ class Tester(unittest.TestCase): ...@@ -129,6 +125,23 @@ class Tester(unittest.TestCase):
warnings.warn(msg, RuntimeWarning) warnings.warn(msg, RuntimeWarning)
raise unittest.SkipTest(msg) raise unittest.SkipTest(msg)
def test_read_timestamps_from_packet(self):
with temp_video(10, 300, 300, 5, video_codec='mpeg4') as (f_name, data):
pts, _ = io.read_video_timestamps(f_name)
# note: not all formats/codecs provide accurate information for computing the
# timestamps. For the format that we use here, this information is available,
# so we use it as a baseline
container = av.open(f_name)
stream = container.streams[0]
# make sure we went through the optimized codepath
self.assertIn(b'Lavc', stream.codec_context.extradata)
pts_step = int(round(float(1 / (stream.average_rate * stream.time_base))))
num_frames = int(round(float(stream.average_rate * stream.time_base * stream.duration)))
expected_pts = [i * pts_step for i in range(num_frames)]
self.assertEqual(pts, expected_pts)
# TODO add tests for audio # TODO add tests for audio
......
...@@ -185,6 +185,15 @@ def read_video(filename, start_pts=0, end_pts=None): ...@@ -185,6 +185,15 @@ def read_video(filename, start_pts=0, end_pts=None):
return vframes, aframes, info return vframes, aframes, info
def _can_read_timestamps_from_packets(container):
extradata = container.streams[0].codec_context.extradata
if extradata is None:
return False
if b"Lavc" in extradata:
return True
return False
def read_video_timestamps(filename): def read_video_timestamps(filename):
""" """
List the video frames timestamps. List the video frames timestamps.
...@@ -205,6 +214,10 @@ def read_video_timestamps(filename): ...@@ -205,6 +214,10 @@ def read_video_timestamps(filename):
video_frames = [] video_frames = []
video_fps = None video_fps = None
if container.streams.video: if container.streams.video:
if _can_read_timestamps_from_packets(container):
# fast path
video_frames = [x for x in container.demux(video=0) if x.pts is not None]
else:
video_frames = _read_from_stream(container, 0, float("inf"), video_frames = _read_from_stream(container, 0, float("inf"),
container.streams.video[0], {'video': 0}) container.streams.video[0], {'video': 0})
video_fps = float(container.streams.video[0].average_rate) video_fps = float(container.streams.video[0].average_rate)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment