Unverified commit d293c4c5, authored by Francisco Massa, committed by GitHub

Adds video reading / saving functionalities (#1039)

* WIP

* WIP

* Add some documentation

* Improve tests and add GC collection

* [WIP] add timestamp getter

* Bugfixes

* Improvements and travis

* Add audio fine-grained alignment

* More doc

* Remove unnecessary file

* Remove comment

* Lazy import av

* Remove hard-coded constants for the test

* Return info stats from read

* Fix for Python-2
parent 4c56f429
.travis.yml
@@ -35,6 +35,7 @@ before_install:
- pip install future
- pip install pytest pytest-cov codecov
- pip install mock
- conda install av -c conda-forge
install:
test/test_io.py (new file)
import os
import tempfile

import torch
import torchvision.io as io
import unittest

try:
    import av
except ImportError:
    av = None


class Tester(unittest.TestCase):
    # compression adds artifacts, thus we add a tolerance of
    # 6 in 0-255 range
    TOLERANCE = 6

    def _create_video_frames(self, num_frames, height, width):
        # synthesize frames as a 2D Gaussian blob drifting across the image
        y, x = torch.meshgrid(torch.linspace(-2, 2, height), torch.linspace(-2, 2, width))
        data = []
        for i in range(num_frames):
            xc = float(i) / num_frames
            yc = 1 - float(i) / (2 * num_frames)
            d = torch.exp(-((x - xc) ** 2 + (y - yc) ** 2) / 2) * 255
            data.append(d.unsqueeze(2).repeat(1, 1, 3).byte())
        return torch.stack(data, 0)

    @unittest.skipIf(av is None, "PyAV unavailable")
    def test_write_read_video(self):
        with tempfile.NamedTemporaryFile(suffix='.mp4') as f:
            data = self._create_video_frames(10, 300, 300)
            io.write_video(f.name, data, fps=5)

            lv, _, info = io.read_video(f.name)

            self.assertTrue((data.float() - lv.float()).abs().max() < self.TOLERANCE)
            self.assertEqual(info["video_fps"], 5)

    @unittest.skipIf(av is None, "PyAV unavailable")
    def test_read_timestamps(self):
        with tempfile.NamedTemporaryFile(suffix='.mp4') as f:
            data = self._create_video_frames(10, 300, 300)
            io.write_video(f.name, data, fps=5)

            pts = io.read_video_timestamps(f.name)

            # note: not all formats/codecs provide accurate information for computing the
            # timestamps. For the format that we use here, this information is available,
            # so we use it as a baseline
            container = av.open(f.name)
            stream = container.streams[0]
            pts_step = int(round(float(1 / (stream.average_rate * stream.time_base))))
            num_frames = int(round(float(stream.average_rate * stream.time_base * stream.duration)))
            expected_pts = [i * pts_step for i in range(num_frames)]

            self.assertEqual(pts, expected_pts)

    @unittest.skipIf(av is None, "PyAV unavailable")
    def test_read_partial_video(self):
        with tempfile.NamedTemporaryFile(suffix='.mp4') as f:
            data = self._create_video_frames(10, 300, 300)
            io.write_video(f.name, data, fps=5)

            pts = io.read_video_timestamps(f.name)
            for start in range(5):
                for l in range(1, 4):
                    lv, _, _ = io.read_video(f.name, pts[start], pts[start + l - 1])
                    s_data = data[start:(start + l)]
                    self.assertEqual(len(lv), l)
                    self.assertTrue((s_data.float() - lv.float()).abs().max() < self.TOLERANCE)

            lv, _, _ = io.read_video(f.name, pts[4] + 1, pts[7])
            self.assertEqual(len(lv), 4)
            self.assertTrue((data[4:8].float() - lv.float()).abs().max() < self.TOLERANCE)

    # TODO add tests for audio


if __name__ == '__main__':
    unittest.main()
torchvision/__init__.py
@@ -3,6 +3,7 @@ from torchvision import datasets
from torchvision import ops
from torchvision import transforms
from torchvision import utils
from torchvision import io
try:
    from .version import __version__  # noqa: F401
torchvision/io/__init__.py (new file)
from .video import write_video, read_video, read_video_timestamps

__all__ = [
    'write_video', 'read_video', 'read_video_timestamps'
]
torchvision/io/video.py (new file)

import gc

import torch
import numpy as np

try:
    import av
except ImportError:
    av = None
def _check_av_available():
    if av is None:
        raise ImportError("""\
PyAV is not installed, and is necessary for the video operations in torchvision.
See https://github.com/mikeboers/PyAV#installation for instructions on how to
install PyAV on your system.
""")


# PyAV has some reference cycles; trigger a garbage collection
# every _GC_COLLECTION_INTERVAL calls to _read_from_stream
_CALLED_TIMES = 0
_GC_COLLECTION_INTERVAL = 20
def write_video(filename, video_array, fps):
    """
    Writes a 4d tensor in [T, H, W, C] format to a video file

    Arguments:
        filename (str): path where the video will be saved
        video_array (Tensor[T, H, W, C]): tensor containing the individual frames,
            as a uint8 tensor in [T, H, W, C] format
        fps (Number): frames per second
    """
    _check_av_available()
    video_array = torch.as_tensor(video_array, dtype=torch.uint8).numpy()

    container = av.open(filename, mode='w')

    stream = container.add_stream('mpeg4', rate=fps)
    stream.width = video_array.shape[2]
    stream.height = video_array.shape[1]
    stream.pix_fmt = 'yuv420p'

    for img in video_array:
        frame = av.VideoFrame.from_ndarray(img, format='rgb24')
        for packet in stream.encode(frame):
            container.mux(packet)

    # Flush stream
    for packet in stream.encode():
        container.mux(packet)

    # Close the file
    container.close()
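For reference, a minimal usage sketch of the writer (not part of the diff; the file name and frame contents are made up for illustration):

import torch
import torchvision.io as io

# 100 random RGB frames of size 64x64, as a uint8 tensor in the
# [T, H, W, C] layout that write_video expects
frames = torch.randint(0, 256, (100, 64, 64, 3), dtype=torch.uint8)
io.write_video("out.mp4", frames, fps=25)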
def _read_from_stream(container, start_offset, end_offset, stream, stream_name):
    global _CALLED_TIMES, _GC_COLLECTION_INTERVAL
    _CALLED_TIMES += 1
    if _CALLED_TIMES % _GC_COLLECTION_INTERVAL == _GC_COLLECTION_INTERVAL - 1:
        gc.collect()

    # seek to the closest keyframe at or before start_offset
    container.seek(start_offset, any_frame=False, backward=True, stream=stream)
    frames = []
    first_frame = None
    for idx, frame in enumerate(container.decode(**stream_name)):
        if frame.pts < start_offset:
            # decoding started before start_offset (at a keyframe); keep track
            # of the last frame seen before the requested window
            first_frame = frame
            continue
        if first_frame and first_frame.pts < start_offset:
            if frame.pts != start_offset:
                # no frame lands exactly on start_offset, so include the last
                # frame seen before it to cover the start of the window
                frames.append(first_frame)
            first_frame = None
        frames.append(frame)
        if frame.pts >= end_offset:
            break
    return frames
def _align_audio_frames(aframes, audio_frames, ref_start, ref_end):
    start, end = audio_frames[0].pts, audio_frames[-1].pts
    total_aframes = aframes.shape[1]
    step_per_aframe = (end - start + 1) / total_aframes
    s_idx = 0
    e_idx = total_aframes
    if start < ref_start:
        s_idx = int((ref_start - start) / step_per_aframe)
    if end > ref_end:
        # (ref_end - end) is negative here, so e_idx slices from the end
        e_idx = int((ref_end - end) / step_per_aframe)
    return aframes[:, s_idx:e_idx]
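The trimming arithmetic in a worked example (a minimal sketch with invented numbers, not part of the diff):

import torch

# suppose the decoded audio spans pts 0..999 in 1000 samples,
# so step_per_aframe == 1.0, and the caller asked for pts 100..899
aframes = torch.zeros(2, 1000)              # 2 channels, 1000 samples
step_per_aframe = (999 - 0 + 1) / 1000      # 1.0
s_idx = int((100 - 0) / step_per_aframe)    # 100
e_idx = int((899 - 999) / step_per_aframe)  # -100, a negative index
trimmed = aframes[:, s_idx:e_idx]
assert trimmed.shape == (2, 800)            # only the requested window remains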
def read_video(filename, start_pts=0, end_pts=None):
    """
    Reads a video from a file, returning both the video frames and the
    audio frames

    Arguments:
        filename (str): path to the video file
        start_pts (int, optional): the start presentation time of the video
        end_pts (int, optional): the end presentation time
    Returns:
        vframes (Tensor[T, H, W, C]): the `T` video frames
        aframes (Tensor[K, L]): the audio frames, where `K` is the number of channels
            and `L` is the number of points
        info (Dict): metadata for the video and audio. Can contain the fields
            video_fps (float) and audio_fps (int)
    """
    _check_av_available()

    if end_pts is None:
        end_pts = float("inf")

    if end_pts < start_pts:
        raise ValueError("end_pts should be larger than start_pts, got "
                         "start_pts={} and end_pts={}".format(start_pts, end_pts))

    container = av.open(filename)
    info = {}

    video_frames = []
    if container.streams.video:
        video_frames = _read_from_stream(container, start_pts, end_pts,
                                         container.streams.video[0], {'video': 0})
        info["video_fps"] = float(container.streams.video[0].average_rate)
    audio_frames = []
    if container.streams.audio:
        audio_frames = _read_from_stream(container, start_pts, end_pts,
                                         container.streams.audio[0], {'audio': 0})
        info["audio_fps"] = container.streams.audio[0].rate

    container.close()

    vframes = [frame.to_rgb().to_ndarray() for frame in video_frames]
    aframes = [frame.to_ndarray() for frame in audio_frames]
    vframes = torch.as_tensor(np.stack(vframes))
    if aframes:
        aframes = np.concatenate(aframes, 1)
        aframes = torch.as_tensor(aframes)
        aframes = _align_audio_frames(aframes, audio_frames, start_pts, end_pts)
    else:
        aframes = torch.empty((1, 0), dtype=torch.float32)
    return vframes, aframes, info
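A quick usage sketch (not part of the diff; "clip.mp4" is a hypothetical file):

import torchvision.io as io

# decode the whole file: frames as uint8 [T, H, W, C], audio as [K, L]
vframes, aframes, info = io.read_video("clip.mp4")
print(vframes.shape, aframes.shape, info.get("video_fps"))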
def read_video_timestamps(filename):
    """
    Lists the presentation timestamps of the video frames.

    Note that the function decodes the whole video frame-by-frame.

    Arguments:
        filename (str): path to the video file
    Returns:
        pts (List[int]): presentation timestamps for each one of the frames
            in the video.
    """
    _check_av_available()
    container = av.open(filename)

    video_frames = []
    if container.streams.video:
        video_frames = _read_from_stream(container, 0, float("inf"),
                                         container.streams.video[0], {'video': 0})
    container.close()
    return [x.pts for x in video_frames]
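These timestamps pair naturally with the start_pts/end_pts arguments of read_video, mirroring the partial-read test above (a sketch; "clip.mp4" is hypothetical and assumed to have at least 20 frames):

import torchvision.io as io

pts = io.read_video_timestamps("clip.mp4")
# decode only the 10 frames whose timestamps fall in [pts[10], pts[19]]
vframes, _, _ = io.read_video("clip.mp4", pts[10], pts[19])
assert len(vframes) == 10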