Unverified Commit d293c4c5 authored by Francisco Massa's avatar Francisco Massa Committed by GitHub
Browse files

Adds video reading / saving functionalities (#1039)

* WIP

* WIP

* Add some documentation

* Improve tests and add GC collection

* [WIP] add timestamp getter

* Bugfixes

* Improvements and travis

* Add audio fine-grained alignment

* More doc

* Remove unnecessary file

* Remove comment

* Lazy import av

* Remove hard-coded constants for the test

* Return info stats from read

* Fix for Python-2
parent 4c56f429
...@@ -35,6 +35,7 @@ before_install: ...@@ -35,6 +35,7 @@ before_install:
- pip install future - pip install future
- pip install pytest pytest-cov codecov - pip install pytest pytest-cov codecov
- pip install mock - pip install mock
- conda install av -c conda-forge
install: install:
......
import os
import tempfile
import torch
import torchvision.io as io
import unittest
try:
import av
except ImportError:
av = None
class Tester(unittest.TestCase):
    # Lossy encoding introduces artifacts, so comparisons against the
    # original frames allow an absolute error of 6 in the 0-255 range.
    TOLERANCE = 6

    def _create_video_frames(self, num_frames, height, width):
        """Build a uint8 [T, H, W, C] clip of a Gaussian blob drifting over time."""
        grid_y, grid_x = torch.meshgrid(torch.linspace(-2, 2, height),
                                        torch.linspace(-2, 2, width))
        frames = []
        for t in range(num_frames):
            center_x = float(t) / num_frames
            center_y = 1 - float(t) / (2 * num_frames)
            blob = torch.exp(-((grid_x - center_x) ** 2 + (grid_y - center_y) ** 2) / 2) * 255
            frames.append(blob.unsqueeze(2).repeat(1, 1, 3).byte())
        return torch.stack(frames, 0)

    @unittest.skipIf(av is None, "PyAV unavailable")
    def test_write_read_video(self):
        # Round-trip: decoded frames must match the written ones within
        # TOLERANCE, and the reported fps must match what was requested.
        with tempfile.NamedTemporaryFile(suffix='.mp4') as f:
            original = self._create_video_frames(10, 300, 300)
            io.write_video(f.name, original, fps=5)
            decoded, _, info = io.read_video(f.name)
            max_err = (original.float() - decoded.float()).abs().max()
            self.assertTrue(max_err < self.TOLERANCE)
            self.assertEqual(info["video_fps"], 5)

    @unittest.skipIf(av is None, "PyAV unavailable")
    def test_read_timestamps(self):
        with tempfile.NamedTemporaryFile(suffix='.mp4') as f:
            original = self._create_video_frames(10, 300, 300)
            io.write_video(f.name, original, fps=5)
            pts = io.read_video_timestamps(f.name)
            # note: not all formats/codecs provide accurate information for computing the
            # timestamps. For the format that we use here, this information is available,
            # so we use it as a baseline
            container = av.open(f.name)
            stream = container.streams[0]
            pts_step = int(round(float(1 / (stream.average_rate * stream.time_base))))
            num_frames = int(round(float(stream.average_rate * stream.time_base * stream.duration)))
            self.assertEqual(pts, [n * pts_step for n in range(num_frames)])

    @unittest.skipIf(av is None, "PyAV unavailable")
    def test_read_partial_video(self):
        with tempfile.NamedTemporaryFile(suffix='.mp4') as f:
            original = self._create_video_frames(10, 300, 300)
            io.write_video(f.name, original, fps=5)
            pts = io.read_video_timestamps(f.name)
            # Every window of 1-3 frames starting at each of the first 5
            # timestamps should decode to exactly the requested frames.
            for start in range(5):
                for length in range(1, 4):
                    decoded, _, _ = io.read_video(f.name, pts[start], pts[start + length - 1])
                    expected = original[start:(start + length)]
                    self.assertEqual(len(decoded), length)
                    self.assertTrue((expected.float() - decoded.float()).abs().max() < self.TOLERANCE)
            # A start offset just past pts[4] still yields frames 4..7: the
            # frame preceding a non-exact start timestamp is included.
            decoded, _, _ = io.read_video(f.name, pts[4] + 1, pts[7])
            self.assertEqual(len(decoded), 4)
            self.assertTrue((original[4:8].float() - decoded.float()).abs().max() < self.TOLERANCE)
# TODO add tests for audio
# Run the test suite when executed as a script.
if __name__ == '__main__':
    unittest.main()
...@@ -3,6 +3,7 @@ from torchvision import datasets ...@@ -3,6 +3,7 @@ from torchvision import datasets
from torchvision import ops from torchvision import ops
from torchvision import transforms from torchvision import transforms
from torchvision import utils from torchvision import utils
from torchvision import io
try: try:
from .version import __version__ # noqa: F401 from .version import __version__ # noqa: F401
......
from .video import write_video, read_video, read_video_timestamps

# Public API of the torchvision.io package.
__all__ = [
    'write_video', 'read_video', 'read_video_timestamps'
]
import gc
import torch
import numpy as np
try:
import av
except ImportError:
av = None
def _check_av_available():
    """Raise a helpful ImportError when the optional PyAV dependency is missing."""
    if av is not None:
        return
    raise ImportError("""\
PyAV is not installed, and is necessary for the video operations in torchvision.
See https://github.com/mikeboers/PyAV#installation for instructions on how to
install PyAV on your system.
""")
# PyAV has some reference cycles; _read_from_stream triggers a manual garbage
# collection every _GC_COLLECTION_INTERVAL calls to bound memory growth.
_CALLED_TIMES = 0
_GC_COLLECTION_INTERVAL = 20
def write_video(filename, video_array, fps):
    """
    Writes a 4d tensor in [T, H, W, C] format in a video file

    Arguments:
        filename (str): path where the video will be saved
        video_array (Tensor[T, H, W, C]): tensor containing the individual frames,
            as a uint8 tensor in [T, H, W, C] format
        fps (Number): frames per second
    """
    _check_av_available()
    frames = torch.as_tensor(video_array, dtype=torch.uint8).numpy()

    container = av.open(filename, mode='w')
    # Single mpeg4 stream; frame size is taken from the array itself.
    stream = container.add_stream('mpeg4', rate=fps)
    stream.width = frames.shape[2]
    stream.height = frames.shape[1]
    stream.pix_fmt = 'yuv420p'

    for image in frames:
        video_frame = av.VideoFrame.from_ndarray(image, format='rgb24')
        for packet in stream.encode(video_frame):
            container.mux(packet)

    # Flush frames buffered in the encoder, then finalize the file.
    for packet in stream.encode():
        container.mux(packet)
    container.close()
def _read_from_stream(container, start_offset, end_offset, stream, stream_name):
    """
    Decode the frames of ``stream`` whose pts fall in [start_offset, end_offset].

    If ``start_offset`` does not land exactly on a frame timestamp, the last
    frame *before* the start is also returned so the interval is fully covered.

    Arguments:
        container: an opened av container
        start_offset (int): start pts, in the stream's time base
        end_offset (int or float): end pts; ``float('inf')`` reads to the end
        stream: the av stream object to seek on
        stream_name (dict): kwargs for ``container.decode``, e.g. ``{'video': 0}``

    Returns:
        List of decoded av frames, in presentation order.
    """
    global _CALLED_TIMES, _GC_COLLECTION_INTERVAL
    # PyAV has reference cycles; collect periodically to bound memory usage.
    _CALLED_TIMES += 1
    if _CALLED_TIMES % _GC_COLLECTION_INTERVAL == _GC_COLLECTION_INTERVAL - 1:
        gc.collect()
    # Seek to the closest keyframe at or before start_offset.
    container.seek(start_offset, any_frame=False, backward=True, stream=stream)
    frames = []
    first_frame = None
    for frame in container.decode(**stream_name):
        if frame.pts < start_offset:
            # Remember the most recent frame before the requested start.
            first_frame = frame
            continue
        # Fix: test identity rather than relying on frame truthiness.
        # first_frame is only ever set with pts < start_offset, so the
        # previous redundant pts comparison is dropped.
        if first_frame is not None:
            if frame.pts != start_offset:
                # start_offset falls between frames: include the preceding one.
                frames.append(first_frame)
            first_frame = None
        frames.append(frame)
        if frame.pts >= end_offset:
            break
    return frames
def _align_audio_frames(aframes, audio_frames, ref_start, ref_end):
start, end = audio_frames[0].pts, audio_frames[-1].pts
total_aframes = aframes.shape[1]
step_per_aframe = (end - start + 1) / total_aframes
s_idx = 0
e_idx = total_aframes
if start < ref_start:
s_idx = int((ref_start - start) / step_per_aframe)
if end > ref_end:
e_idx = int((ref_end - end) / step_per_aframe)
return aframes[:, s_idx:e_idx]
def read_video(filename, start_pts=0, end_pts=None):
    """
    Reads a video from a file, returning both the video frames as well as
    the audio frames

    Arguments:
        filename (str): path to the video file
        start_pts (int, optional): the start presentation time of the video
        end_pts (int, optional): the end presentation time

    Returns:
        vframes (Tensor[T, H, W, C]): the `T` video frames
        aframes (Tensor[K, L]): the audio frames, where `K` is the number of channels
            and `L` is the number of points
        info (Dict): metadata for the video and audio. Can contain the fields
            - video_fps (float)
            - audio_fps (int)

    Raises:
        ImportError: if PyAV is not installed
        ValueError: if ``end_pts`` is smaller than ``start_pts``
    """
    _check_av_available()
    if end_pts is None:
        end_pts = float("inf")
    if end_pts < start_pts:
        raise ValueError("end_pts should be larger than start_pts, got "
                         "start_pts={} and end_pts={}".format(start_pts, end_pts))

    container = av.open(filename)
    info = {}

    video_frames = []
    if container.streams.video:
        video_frames = _read_from_stream(container, start_pts, end_pts,
                                         container.streams.video[0], {'video': 0})
        info["video_fps"] = float(container.streams.video[0].average_rate)
    audio_frames = []
    if container.streams.audio:
        audio_frames = _read_from_stream(container, start_pts, end_pts,
                                         container.streams.audio[0], {'audio': 0})
        info["audio_fps"] = container.streams.audio[0].rate
    container.close()

    vframes = [frame.to_rgb().to_ndarray() for frame in video_frames]
    aframes = [frame.to_ndarray() for frame in audio_frames]

    if vframes:
        vframes = torch.as_tensor(np.stack(vframes))
    else:
        # Robustness fix: np.stack raises on an empty list, so a file with no
        # video stream (or an empty pts window) previously crashed here.
        vframes = torch.empty((0, 1, 1, 3), dtype=torch.uint8)

    if aframes:
        aframes = np.concatenate(aframes, 1)
        aframes = torch.as_tensor(aframes)
        aframes = _align_audio_frames(aframes, audio_frames, start_pts, end_pts)
    else:
        aframes = torch.empty((1, 0), dtype=torch.float32)

    return vframes, aframes, info
def read_video_timestamps(filename):
    """
    List the video frames timestamps.

    Note that the function decodes the whole video frame-by-frame.

    Arguments:
        filename (str): path to the video file

    Returns:
        pts (List[int]): presentation timestamps for each one of the frames
            in the video.
    """
    _check_av_available()
    container = av.open(filename)
    timestamps = []
    if container.streams.video:
        decoded = _read_from_stream(container, 0, float("inf"),
                                    container.streams.video[0], {'video': 0})
        timestamps = [frame.pts for frame in decoded]
    container.close()
    return timestamps
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment