Unverified Commit 039c9288 authored by Francisco Massa's avatar Francisco Massa Committed by GitHub
Browse files

Minor improvements to VideoReader (#2781)



* Minor improvements to VideoReader

* update jupyter notebook with new naming
Co-authored-by: default avatarBruno Korbar <bjuncek@gmail.com>
parent 610c9d2a
......@@ -24,7 +24,7 @@ In addition to the :mod:`read_video` function, we provide a high-performance
lower-level API for more fine-grained control compared to the :mod:`read_video` function.
It does all this whilst fully supporting torchscript.
.. autoclass:: Video
.. autoclass:: VideoReader
:members: next, get_metadata, set_current_stream, seek
......@@ -37,7 +37,7 @@ Example of usage:
# Constructor allocates memory and a threaded decoder
# instance per video. At the momet it takes two arguments:
# path to the video file, and a wanted stream.
reader = torchvision.io.Video(video_path, "video")
reader = torchvision.io.VideoReader(video_path, "video")
# The information about the video can be retrieved using the
# `get_metadata()` method. It returns a dictionary for every stream, with
......
......@@ -16,16 +16,16 @@
},
{
"cell_type": "code",
"execution_count": 36,
"execution_count": 1,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"('1.7.0a0+f5c95d5', '0.8.0a0+6eff0a4')"
"('1.7.0a0+f5c95d5', '0.8.0a0+a2f405d')"
]
},
"execution_count": 36,
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
......@@ -37,14 +37,21 @@
},
{
"cell_type": "code",
"execution_count": 37,
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Using downloaded and verified file: ./WUzgd7C1pWA.mp4\n"
"Downloading https://github.com/pytorch/vision/blob/master/test/assets/videos/WUzgd7C1pWA.mp4?raw=true to ./WUzgd7C1pWA.mp4\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100.4%"
]
}
],
......@@ -65,7 +72,7 @@
},
{
"cell_type": "code",
"execution_count": 38,
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
......@@ -91,7 +98,7 @@
"\n",
"\n",
"\n",
"video = torch.classes.torchvision.Video(video_path, stream)"
"video = torchvision.io.VideoReader(video_path, stream)"
]
},
{
......@@ -103,7 +110,7 @@
},
{
"cell_type": "code",
"execution_count": 39,
"execution_count": 6,
"metadata": {},
"outputs": [
{
......@@ -113,7 +120,7 @@
" 'audio': {'duration': [10.9], 'framerate': [48000.0]}}"
]
},
"execution_count": 39,
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
......@@ -133,7 +140,7 @@
},
{
"cell_type": "code",
"execution_count": 40,
"execution_count": 8,
"metadata": {},
"outputs": [
{
......@@ -152,11 +159,8 @@
"video.set_current_stream(\"video:0\")\n",
"\n",
"frames = [] # we are going to save the frames here.\n",
"frame, pts = video.next()\n",
"# note that next will return emptyframe at the end of the video stream\n",
"while frame.numel() != 0:\n",
"for frame, pts in video:\n",
" frames.append(frame)\n",
" frame, pts = video.next()\n",
" \n",
"print(\"Total number of frames: \", len(frames))\n",
"approx_nf = metadata['video']['duration'][0] * metadata['video']['fps'][0]\n",
......@@ -175,7 +179,7 @@
},
{
"cell_type": "code",
"execution_count": 41,
"execution_count": 9,
"metadata": {},
"outputs": [
{
......@@ -193,11 +197,8 @@
"video.set_current_stream(\"audio\")\n",
"\n",
"frames = [] # we are going to save the frames here.\n",
"frame, pts = video.next()\n",
"# note that next will return emptyframe at the end of the audio stream\n",
"while frame.numel() != 0:\n",
"for frame, pts in video:\n",
" frames.append(frame)\n",
" frame, pts = video.next()\n",
" \n",
"print(\"Total number of frames: \", len(frames))\n",
"approx_nf = metadata['audio']['duration'][0] * metadata['audio']['framerate'][0]\n",
......@@ -211,14 +212,48 @@
"source": [
"But what if we only want to read certain time segment of the video?\n",
"\n",
"That can be done easily using the combination of our seek function, and the fact that each call to next returns the presentation timestamp of the returned frame in seconds.\n",
"That can be done easily using the combination of our seek function, and the fact that each call to next returns the presentation timestamp of the returned frame in seconds. Given that our implementation relies on python iterators, we can leverage `itertools` to simplify the process and make it more pythonic. \n",
"\n",
"For example, if we wanted to read video from second to fifth second:"
"For example, if we wanted to read ten frames from second second:"
]
},
{
"cell_type": "code",
"execution_count": 42,
"execution_count": 15,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Total number of frames: 10\n"
]
}
],
"source": [
"import itertools\n",
"video.set_current_stream(\"video\")\n",
"\n",
"frames = [] # we are going to save the frames here.\n",
"\n",
"# we seek into a second second of the video\n",
"# and use islice to get 10 frames since\n",
"for frame, pts in itertools.islice(video.seek(2), 10):\n",
" frames.append(frame)\n",
" \n",
"print(\"Total number of frames: \", len(frames))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Or if we wanted to read from 2nd to 5th second:"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
......@@ -236,15 +271,13 @@
"\n",
"frames = [] # we are going to save the frames here.\n",
"\n",
"# we seek into a second second of the video \n",
"# the following call to next returns the first following frame\n",
"video.seek(2) \n",
"frame, pts = video.next()\n",
"# note that we add exit condition\n",
"while pts < 5 and frame.numel() != 0:\n",
"# we seek into a second second of the video\n",
"video = video.seek(2)\n",
"# then we utilize the itertools takewhile to get the \n",
"# correct number of frames\n",
"for frame, pts in itertools.takewhile(lambda x: x[1] <= 5, video):\n",
" frames.append(frame)\n",
" frame, pts = video.next()\n",
" \n",
"\n",
"print(\"Total number of frames: \", len(frames))\n",
"approx_nf = (5-2) * video.get_metadata()['video']['fps'][0]\n",
"print(\"We can expect approx: \", approx_nf)\n",
......@@ -262,7 +295,7 @@
},
{
"cell_type": "code",
"execution_count": 43,
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
......@@ -280,13 +313,10 @@
" video_pts = []\n",
" if read_video:\n",
" video_object.set_current_stream(\"video\")\n",
" video_object.seek(start)\n",
" frames = []\n",
" t, pts = video_object.next()\n",
" while t.numel() > 0 and (pts >= start and pts <= end):\n",
" for t, pts in itertools.takewhile(lambda x: x[1] <= end, video_object.seek(start)):\n",
" frames.append(t)\n",
" video_pts.append(pts)\n",
" t, pts = video_object.next()\n",
" if len(frames) > 0:\n",
" video_frames = torch.stack(frames, 0)\n",
"\n",
......@@ -294,13 +324,10 @@
" audio_pts = []\n",
" if read_audio:\n",
" video_object.set_current_stream(\"audio\")\n",
" video_object.seek(start)\n",
" frames = []\n",
" t, pts = video_object.next()\n",
" while t.numel() > 0 and (pts >= start and pts <= end):\n",
" for t, pts in itertools.takewhile(lambda x: x[1] <= end, video_object.seek(start)):\n",
" frames.append(t)\n",
" audio_pts.append(pts)\n",
" t, pts = video_object.next()\n",
" video_pts.append(pts)\n",
" if len(frames) > 0:\n",
" audio_frames = torch.cat(frames, 0)\n",
"\n",
......@@ -309,7 +336,7 @@
},
{
"cell_type": "code",
"execution_count": 44,
"execution_count": 19,
"metadata": {},
"outputs": [
{
......@@ -322,13 +349,13 @@
],
"source": [
"vf, af, info, meta = example_read_video(video)\n",
"# total number of frames should be 327\n",
"# total number of frames should be 327 for video and 523264 datapoints for audio\n",
"print(vf.size(), af.size())"
]
},
{
"cell_type": "code",
"execution_count": 45,
"execution_count": 20,
"metadata": {},
"outputs": [
{
......@@ -337,7 +364,7 @@
"torch.Size([523264, 1])"
]
},
"execution_count": 45,
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
......@@ -362,7 +389,7 @@
},
{
"cell_type": "code",
"execution_count": 46,
"execution_count": 21,
"metadata": {},
"outputs": [],
"source": [
......@@ -375,7 +402,7 @@
},
{
"cell_type": "code",
"execution_count": 47,
"execution_count": 22,
"metadata": {},
"outputs": [
{
......@@ -461,7 +488,7 @@
},
{
"cell_type": "code",
"execution_count": 49,
"execution_count": 23,
"metadata": {},
"outputs": [],
"source": [
......@@ -499,7 +526,7 @@
},
{
"cell_type": "code",
"execution_count": 50,
"execution_count": 33,
"metadata": {},
"outputs": [],
"source": [
......@@ -523,16 +550,14 @@
" # get random sample\n",
" path, target = random.choice(self.samples)\n",
" # get video object\n",
" vid = torch.classes.torchvision.Video(path, \"video\")\n",
" vid = torchvision.io.VideoReader(path, \"video\")\n",
" metadata = vid.get_metadata()\n",
" video_frames = [] # video frame buffer \n",
" # seek and return frames\n",
" \n",
" max_seek = metadata[\"video\"]['duration'][0] - (self.clip_len / metadata[\"video\"]['fps'][0])\n",
" start = random.uniform(0., max_seek)\n",
" vid.seek(start)\n",
" while len(video_frames) < self.clip_len:\n",
" frame, current_pts = vid.next()\n",
" for frame, current_pts in itertools.islice(vid.seek(start), self.clip_len):\n",
" video_frames.append(self.frame_transform(frame))\n",
" # stack it into a tensor\n",
" video = torch.stack(video_frames, 0)\n",
......@@ -570,7 +595,7 @@
},
{
"cell_type": "code",
"execution_count": 51,
"execution_count": 34,
"metadata": {},
"outputs": [],
"source": [
......@@ -583,42 +608,48 @@
},
{
"cell_type": "code",
"execution_count": 52,
"execution_count": 39,
"metadata": {},
"outputs": [],
"source": [
"from torch.utils.data import DataLoader\n",
"loader = DataLoader(ds, batch_size=12)\n",
"d = {\"video\":[], 'start':[], 'end':[]}\n",
"d = {\"video\":[], 'start':[], 'end':[], 'tensorsize':[]}\n",
"for b in loader:\n",
" for i in range(len(b['path'])):\n",
" d['video'].append(b['path'][i])\n",
" d['start'].append(b['start'][i].item())\n",
" d['end'].append(b['end'][i].item())"
" d['end'].append(b['end'][i].item())\n",
" d['tensorsize'].append(b['video'][i].size())"
]
},
{
"cell_type": "code",
"execution_count": 53,
"execution_count": 40,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'video': ['./dataset/1/RATRACE_wave_f_nm_np1_fr_goo_37.avi',\n",
"{'video': ['./dataset/1/WUzgd7C1pWA.mp4',\n",
" './dataset/1/WUzgd7C1pWA.mp4',\n",
" './dataset/1/RATRACE_wave_f_nm_np1_fr_goo_37.avi',\n",
" './dataset/2/SOX5yA1l24A.mp4',\n",
" './dataset/2/v_SoccerJuggling_g23_c01.avi'],\n",
" 'start': [0.029482554081669773,\n",
" 3.439334232470971,\n",
" 1.1823159301599728,\n",
" 4.470027811314425,\n",
" 3.3126303902318432],\n",
" 'end': [0.5666669999999999, 3.970633, 1.7, 4.971633, 3.837167]}"
" './dataset/2/v_SoccerJuggling_g23_c01.avi',\n",
" './dataset/2/v_SoccerJuggling_g23_c01.avi',\n",
" './dataset/1/RATRACE_wave_f_nm_np1_fr_goo_37.avi'],\n",
" 'start': [8.97932147319667,\n",
" 9.421856461438313,\n",
" 2.1301381796579437,\n",
" 5.514273689529127,\n",
" 0.31979853297913124],\n",
" 'end': [9.5095, 9.943266999999999, 2.635967, 6.0393669999999995, 0.833333],\n",
" 'tensorsize': [torch.Size([16, 3, 112, 112]),\n",
" torch.Size([16, 3, 112, 112]),\n",
" torch.Size([16, 3, 112, 112]),\n",
" torch.Size([16, 3, 112, 112]),\n",
" torch.Size([16, 3, 112, 112])]}"
]
},
"execution_count": 53,
"execution_count": 40,
"metadata": {},
"output_type": "execute_result"
}
......@@ -629,7 +660,7 @@
},
{
"cell_type": "code",
"execution_count": 54,
"execution_count": 41,
"metadata": {},
"outputs": [],
"source": [
......@@ -638,6 +669,13 @@
"os.remove(\"./WUzgd7C1pWA.mp4\")\n",
"shutil.rmtree(\"./dataset\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
......
......@@ -5,12 +5,14 @@ import tempfile
import unittest
import random
import itertools
import numpy as np
import torch
import torchvision
from torchvision.io import _HAS_VIDEO_OPT, Video
from torchvision.io import _HAS_VIDEO_OPT, VideoReader
try:
import av
......@@ -242,11 +244,11 @@ def _template_read_video(video_object, s=0, e=None):
video_frames = torch.empty(0)
frames = []
video_pts = []
t, pts = video_object.next()
while t.numel() > 0 and (pts >= s and pts <= e):
for t, pts in itertools.takewhile(lambda x: x[1] <= e, video_object):
if pts < s:
continue
frames.append(t)
video_pts.append(pts)
t, pts = video_object.next()
if len(frames) > 0:
video_frames = torch.stack(frames, 0)
......@@ -255,11 +257,11 @@ def _template_read_video(video_object, s=0, e=None):
audio_frames = torch.empty(0)
frames = []
audio_pts = []
t, pts = video_object.next()
while t.numel() > 0 and (pts > s and pts <= e):
for t, pts in itertools.takewhile(lambda x: x[1] <= e, video_object):
if pts < s:
continue
frames.append(t)
audio_pts.append(pts)
t, pts = video_object.next()
if len(frames) > 0:
audio_frames = torch.stack(frames, 0)
......@@ -289,12 +291,10 @@ class TestVideo(unittest.TestCase):
tv_result, _, _ = torchvision.io.read_video(full_path, pts_unit="sec")
tv_result = tv_result.permute(0, 3, 1, 2)
# pass 2: decode all frames using new api
reader = Video(full_path, "video")
reader = VideoReader(full_path, "video")
frames = []
t, _ = reader.next()
while t.numel() > 0:
for t, _ in reader:
frames.append(t)
t, _ = reader.next()
new_api = torch.stack(frames, 0)
self.assertEqual(tv_result.size(), new_api.size())
......@@ -310,7 +310,7 @@ class TestVideo(unittest.TestCase):
# s = min(r)
# e = max(r)
# reader = Video(full_path, "video")
# reader = VideoReader(full_path, "video")
# results = _template_read_video(reader, s, e)
# tv_video, tv_audio, info = torchvision.io.read_video(
# full_path, start_pts=s, end_pts=e, pts_unit="sec"
......@@ -329,12 +329,12 @@ class TestVideo(unittest.TestCase):
# full_path, pts_unit="sec"
# )
# # pass 2: decode all frames using new api
# reader = Video(full_path, "video")
# reader = VideoReader(full_path, "video")
# pts = []
# t, p = reader.next()
# while t.numel() > 0:
# t, p = next(reader)
# while t.numel() > 0: # THIS NEEDS TO BE FIXED
# pts.append(p)
# t, p = reader.next()
# t, p = next(reader)
# tv_timestamps = [float(p) for p in tv_timestamps]
# napi_pts = [float(p) for p in pts]
......@@ -353,7 +353,7 @@ class TestVideo(unittest.TestCase):
torchvision.set_video_backend("pyav")
for test_video, config in test_videos.items():
full_path = os.path.join(VIDEO_DIR, test_video)
reader = Video(full_path, "video")
reader = VideoReader(full_path, "video")
reader_md = reader.get_metadata()
self.assertAlmostEqual(
config.video_fps, reader_md["video"]["fps"][0], delta=0.0001
......@@ -372,7 +372,7 @@ class TestVideo(unittest.TestCase):
ref_result = _decode_frames_by_av_module(full_path)
reader = Video(full_path, "video")
reader = VideoReader(full_path, "video")
newapi_result = _template_read_video(reader)
# First we check if the frames are approximately the same
......
......@@ -27,87 +27,98 @@ from .image import (
if _HAS_VIDEO_OPT:
def _has_video_opt():
return True
else:
def _has_video_opt():
return False
class Video:
"""
Fine-grained video-reading API.
Supports frame-by-frame reading of various streams from a single video
container.
Args:
path (string): Path to the video file in supported format
class VideoReader:
"""
Fine-grained video-reading API.
Supports frame-by-frame reading of various streams from a single video
container.
stream (string, optional): descriptor of the required stream. Defaults to "video:0"
Currently available options include :mod:`['video', 'audio', 'cc', 'sub']`
Example:
The following examples creates :mod:`Video` object, seeks into 2s
point, and returns a single frame::
import torchvision
video_path = "path_to_a_test_video"
Example:
The following examples creates :mod:`Video` object, seeks into 2s
point, and returns a single frame::
import torchvision
video_path = "path_to_a_test_video"
reader = torchvision.io.VideoReader(video_path, "video")
reader.seek(2.0)
frame, timestamp = next(reader)
reader = torchvision.io.Video(video_path, "video")
reader.seek(2.0)
frame, timestamp = reader.next()
"""
Args:
def __init__(self, path, stream="video"):
self._c = torch.classes.torchvision.Video(path, stream)
path (string): Path to the video file in supported format
def next(self):
"""Iterator that decodes the next frame of the current stream
stream (string, optional): descriptor of the required stream. Defaults to "video:0"
Currently available options include :mod:`['video', 'audio', 'cc', 'sub']`
"""
Returns:
([torch.Tensor, float]): list containing decoded frame and corresponding timestamp
def __init__(self, path, stream="video"):
if not _has_video_opt():
raise RuntimeError("Not compiled with video_reader support")
self._c = torch.classes.torchvision.Video(path, stream)
"""
return self._c.next()
def __next__(self):
"""Decodes and returns the next frame of the current stream
def seek(self, time_s: float):
"""Seek within current stream.
Returns:
([torch.Tensor, float]): list containing decoded frame and corresponding timestamp
Args:
time_s (float): seek time in seconds
"""
frame, pts = self._c.next()
if frame.numel() == 0:
raise StopIteration
return frame, pts
.. note::
Current implementation is the so-called precise seek. This
means following seek, call to :mod:`next()` will return the
frame with the exact timestamp if it exists or
the first frame with timestamp larger than time_s.
"""
self._c.seek(time_s)
def __iter__(self):
return self
def get_metadata(self):
"""Returns video metadata
def seek(self, time_s: float):
"""Seek within current stream.
Returns:
(dict): dictionary containing duration and frame rate for every stream
"""
return self._c.get_metadata()
Args:
time_s (float): seek time in seconds
def set_current_stream(self, stream: str):
"""Set current stream.
Explicitly define the stream we are operating on.
.. note::
Current implementation is the so-called precise seek. This
means following seek, call to :mod:`next()` will return the
frame with the exact timestamp if it exists or
the first frame with timestamp larger than time_s.
"""
self._c.seek(time_s)
return self
Args:
stream (string): descriptor of the required stream. Defaults to "video:0"
Currently available stream types include :mod:`['video', 'audio', 'cc', 'sub']`.
Each descriptor consists of two parts: stream type (e.g. 'video') and
a unique stream id (which are determined by video encoding).
In this way, if the video contaner contains multiple
streams of the same type, users can acces the one they want.
If only stream type is passed, the decoder auto-detects first stream
of that type and returns it.
def get_metadata(self):
"""Returns video metadata
Returns:
(bool): True on succes, False otherwise
"""
return self._c.set_current_stream(stream)
Returns:
(dict): dictionary containing duration and frame rate for every stream
"""
return self._c.get_metadata()
def set_current_stream(self, stream: str):
"""Set current stream.
Explicitly define the stream we are operating on.
else:
Video = None
Args:
stream (string): descriptor of the required stream. Defaults to "video:0"
Currently available stream types include :mod:`['video', 'audio', 'cc', 'sub']`.
Each descriptor consists of two parts: stream type (e.g. 'video') and
a unique stream id (which are determined by video encoding).
In this way, if the video contaner contains multiple
streams of the same type, users can acces the one they want.
If only stream type is passed, the decoder auto-detects first stream
of that type and returns it.
Returns:
(bool): True on succes, False otherwise
"""
return self._c.set_current_stream(stream)
__all__ = [
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment