Unverified Commit 010984d4 authored by Francisco Massa's avatar Francisco Massa Committed by GitHub
Browse files

Test videos with B-Frames (#1157)

Also extend video saving to support different codecs and options. Notably, we can now save with lossless compression
parent b25f81e0
import os
import contextlib
import tempfile
import torch
import torchvision.io as io
......@@ -11,12 +12,7 @@ except ImportError:
av = None
class Tester(unittest.TestCase):
# compression adds artifacts, thus we add a tolerance of
# 6 in 0-255 range
TOLERANCE = 6
def _create_video_frames(self, num_frames, height, width):
def _create_video_frames(num_frames, height, width):
y, x = torch.meshgrid(torch.linspace(-2, 2, height), torch.linspace(-2, 2, width))
data = []
for i in range(num_frames):
......@@ -27,29 +23,48 @@ class Tester(unittest.TestCase):
return torch.stack(data, 0)
@unittest.skipIf(av is None, "PyAV unavailable")
def test_write_read_video(self):
@contextlib.contextmanager
def temp_video(num_frames, height, width, fps, lossless=False, video_codec=None, options=None):
if lossless:
assert video_codec is None, "video_codec can't be specified together with lossless"
assert options is None, "options can't be specified together with lossless"
video_codec = 'libx264rgb'
options = {'crf': '0'}
if video_codec is None:
video_codec = 'libx264'
if options is None:
options = {}
data = _create_video_frames(num_frames, height, width)
with tempfile.NamedTemporaryFile(suffix='.mp4') as f:
data = self._create_video_frames(10, 300, 300)
io.write_video(f.name, data, fps=5)
io.write_video(f.name, data, fps=fps, video_codec=video_codec, options=options)
yield f.name, data
class Tester(unittest.TestCase):
# compression adds artifacts, thus we add a tolerance of
# 6 in 0-255 range
TOLERANCE = 6
lv, _, info = io.read_video(f.name)
@unittest.skipIf(av is None, "PyAV unavailable")
def test_write_read_video(self):
with temp_video(10, 300, 300, 5, lossless=True) as (f_name, data):
lv, _, info = io.read_video(f_name)
self.assertTrue((data.float() - lv.float()).abs().max() < self.TOLERANCE)
self.assertTrue(data.equal(lv))
self.assertEqual(info["video_fps"], 5)
@unittest.skipIf(av is None, "PyAV unavailable")
def test_read_timestamps(self):
with tempfile.NamedTemporaryFile(suffix='.mp4') as f:
data = self._create_video_frames(10, 300, 300)
io.write_video(f.name, data, fps=5)
pts, _ = io.read_video_timestamps(f.name)
with temp_video(10, 300, 300, 5) as (f_name, data):
pts, _ = io.read_video_timestamps(f_name)
# note: not all formats/codecs provide accurate information for computing the
# timestamps. For the format that we use here, this information is available,
# so we use it as a baseline
container = av.open(f.name)
container = av.open(f_name)
stream = container.streams[0]
pts_step = int(round(float(1 / (stream.average_rate * stream.time_base))))
num_frames = int(round(float(stream.average_rate * stream.time_base * stream.duration)))
......@@ -59,20 +74,33 @@ class Tester(unittest.TestCase):
@unittest.skipIf(av is None, "PyAV unavailable")
def test_read_partial_video(self):
with tempfile.NamedTemporaryFile(suffix='.mp4') as f:
data = self._create_video_frames(10, 300, 300)
io.write_video(f.name, data, fps=5)
with temp_video(10, 300, 300, 5, lossless=True) as (f_name, data):
pts, _ = io.read_video_timestamps(f_name)
for start in range(5):
for l in range(1, 4):
lv, _, _ = io.read_video(f_name, pts[start], pts[start + l - 1])
s_data = data[start:(start + l)]
self.assertEqual(len(lv), l)
self.assertTrue(s_data.equal(lv))
pts, _ = io.read_video_timestamps(f.name)
lv, _, _ = io.read_video(f_name, pts[4] + 1, pts[7])
self.assertEqual(len(lv), 4)
self.assertTrue(data[4:8].equal(lv))
for start in range(5):
@unittest.skipIf(av is None, "PyAV unavailable")
def test_read_partial_video_bframes(self):
# do not use lossless encoding, to test the presence of B-frames
options = {'bframes': '16', 'keyint': '10', 'min-keyint': '4'}
with temp_video(100, 300, 300, 5, options=options) as (f_name, data):
pts, _ = io.read_video_timestamps(f_name)
for start in range(0, 80, 20):
for l in range(1, 4):
lv, _, _ = io.read_video(f.name, pts[start], pts[start + l - 1])
lv, _, _ = io.read_video(f_name, pts[start], pts[start + l - 1])
s_data = data[start:(start + l)]
self.assertEqual(len(lv), l)
self.assertTrue((s_data.float() - lv.float()).abs().max() < self.TOLERANCE)
lv, _, _ = io.read_video(f.name, pts[4] + 1, pts[7])
lv, _, _ = io.read_video(f_name, pts[4] + 1, pts[7])
self.assertEqual(len(lv), 4)
self.assertTrue((data[4:8].float() - lv.float()).abs().max() < self.TOLERANCE)
......
......@@ -23,7 +23,7 @@ _CALLED_TIMES = 0
_GC_COLLECTION_INTERVAL = 20
def write_video(filename, video_array, fps):
def write_video(filename, video_array, fps, video_codec='libx264', options=None):
"""
Writes a 4d tensor in [T, H, W, C] format in a video file
......@@ -38,13 +38,15 @@ def write_video(filename, video_array, fps):
container = av.open(filename, mode='w')
stream = container.add_stream('mpeg4', rate=fps)
stream = container.add_stream(video_codec, rate=fps)
stream.width = video_array.shape[2]
stream.height = video_array.shape[1]
stream.pix_fmt = 'yuv420p'
stream.pix_fmt = 'yuv420p' if video_codec != 'libx264rgb' else 'rgb24'
stream.options = options or {}
for img in video_array:
frame = av.VideoFrame.from_ndarray(img, format='rgb24')
frame.pict_type = 'NONE'
for packet in stream.encode(frame):
container.mux(packet)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment