Unverified Commit 039c9288 authored by Francisco Massa's avatar Francisco Massa Committed by GitHub
Browse files

Minor improvements to VideoReader (#2781)



* Minor improvements to VideoReader

* update jupyter notebook with new naming
Co-authored-by: default avatarBruno Korbar <bjuncek@gmail.com>
parent 610c9d2a
......@@ -24,7 +24,7 @@ In addition to the :mod:`read_video` function, we provide a high-performance
lower-level API for more fine-grained control compared to the :mod:`read_video` function.
It does all this whilst fully supporting torchscript.
.. autoclass:: Video
.. autoclass:: VideoReader
:members: next, get_metadata, set_current_stream, seek
......@@ -37,7 +37,7 @@ Example of usage:
# Constructor allocates memory and a threaded decoder
# instance per video. At the momet it takes two arguments:
# path to the video file, and a wanted stream.
reader = torchvision.io.Video(video_path, "video")
reader = torchvision.io.VideoReader(video_path, "video")
# The information about the video can be retrieved using the
# `get_metadata()` method. It returns a dictionary for every stream, with
......
......@@ -16,16 +16,16 @@
},
{
"cell_type": "code",
"execution_count": 36,
"execution_count": 1,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"('1.7.0a0+f5c95d5', '0.8.0a0+6eff0a4')"
"('1.7.0a0+f5c95d5', '0.8.0a0+a2f405d')"
]
},
"execution_count": 36,
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
......@@ -37,14 +37,21 @@
},
{
"cell_type": "code",
"execution_count": 37,
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Using downloaded and verified file: ./WUzgd7C1pWA.mp4\n"
"Downloading https://github.com/pytorch/vision/blob/master/test/assets/videos/WUzgd7C1pWA.mp4?raw=true to ./WUzgd7C1pWA.mp4\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100.4%"
]
}
],
......@@ -65,7 +72,7 @@
},
{
"cell_type": "code",
"execution_count": 38,
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
......@@ -91,7 +98,7 @@
"\n",
"\n",
"\n",
"video = torch.classes.torchvision.Video(video_path, stream)"
"video = torchvision.io.VideoReader(video_path, stream)"
]
},
{
......@@ -103,7 +110,7 @@
},
{
"cell_type": "code",
"execution_count": 39,
"execution_count": 6,
"metadata": {},
"outputs": [
{
......@@ -113,7 +120,7 @@
" 'audio': {'duration': [10.9], 'framerate': [48000.0]}}"
]
},
"execution_count": 39,
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
......@@ -133,7 +140,7 @@
},
{
"cell_type": "code",
"execution_count": 40,
"execution_count": 8,
"metadata": {},
"outputs": [
{
......@@ -152,11 +159,8 @@
"video.set_current_stream(\"video:0\")\n",
"\n",
"frames = [] # we are going to save the frames here.\n",
"frame, pts = video.next()\n",
"# note that next will return emptyframe at the end of the video stream\n",
"while frame.numel() != 0:\n",
"for frame, pts in video:\n",
" frames.append(frame)\n",
" frame, pts = video.next()\n",
" \n",
"print(\"Total number of frames: \", len(frames))\n",
"approx_nf = metadata['video']['duration'][0] * metadata['video']['fps'][0]\n",
......@@ -175,7 +179,7 @@
},
{
"cell_type": "code",
"execution_count": 41,
"execution_count": 9,
"metadata": {},
"outputs": [
{
......@@ -193,11 +197,8 @@
"video.set_current_stream(\"audio\")\n",
"\n",
"frames = [] # we are going to save the frames here.\n",
"frame, pts = video.next()\n",
"# note that next will return emptyframe at the end of the audio stream\n",
"while frame.numel() != 0:\n",
"for frame, pts in video:\n",
" frames.append(frame)\n",
" frame, pts = video.next()\n",
" \n",
"print(\"Total number of frames: \", len(frames))\n",
"approx_nf = metadata['audio']['duration'][0] * metadata['audio']['framerate'][0]\n",
......@@ -211,14 +212,48 @@
"source": [
"But what if we only want to read certain time segment of the video?\n",
"\n",
"That can be done easily using the combination of our seek function, and the fact that each call to next returns the presentation timestamp of the returned frame in seconds.\n",
"That can be done easily using the combination of our seek function, and the fact that each call to next returns the presentation timestamp of the returned frame in seconds. Given that our implementation relies on python iterators, we can leverage `itertools` to simplify the process and make it more pythonic. \n",
"\n",
"For example, if we wanted to read video from second to fifth second:"
"For example, if we wanted to read ten frames from second second:"
]
},
{
"cell_type": "code",
"execution_count": 42,
"execution_count": 15,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Total number of frames: 10\n"
]
}
],
"source": [
"import itertools\n",
"video.set_current_stream(\"video\")\n",
"\n",
"frames = [] # we are going to save the frames here.\n",
"\n",
"# we seek into a second second of the video\n",
"# and use islice to get 10 frames since\n",
"for frame, pts in itertools.islice(video.seek(2), 10):\n",
" frames.append(frame)\n",
" \n",
"print(\"Total number of frames: \", len(frames))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Or if we wanted to read from 2nd to 5th second:"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
......@@ -236,15 +271,13 @@
"\n",
"frames = [] # we are going to save the frames here.\n",
"\n",
"# we seek into a second second of the video \n",
"# the following call to next returns the first following frame\n",
"video.seek(2) \n",
"frame, pts = video.next()\n",
"# note that we add exit condition\n",
"while pts < 5 and frame.numel() != 0:\n",
"# we seek into a second second of the video\n",
"video = video.seek(2)\n",
"# then we utilize the itertools takewhile to get the \n",
"# correct number of frames\n",
"for frame, pts in itertools.takewhile(lambda x: x[1] <= 5, video):\n",
" frames.append(frame)\n",
" frame, pts = video.next()\n",
" \n",
"\n",
"print(\"Total number of frames: \", len(frames))\n",
"approx_nf = (5-2) * video.get_metadata()['video']['fps'][0]\n",
"print(\"We can expect approx: \", approx_nf)\n",
......@@ -262,7 +295,7 @@
},
{
"cell_type": "code",
"execution_count": 43,
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
......@@ -280,13 +313,10 @@
" video_pts = []\n",
" if read_video:\n",
" video_object.set_current_stream(\"video\")\n",
" video_object.seek(start)\n",
" frames = []\n",
" t, pts = video_object.next()\n",
" while t.numel() > 0 and (pts >= start and pts <= end):\n",
" for t, pts in itertools.takewhile(lambda x: x[1] <= end, video_object.seek(start)):\n",
" frames.append(t)\n",
" video_pts.append(pts)\n",
" t, pts = video_object.next()\n",
" if len(frames) > 0:\n",
" video_frames = torch.stack(frames, 0)\n",
"\n",
......@@ -294,13 +324,10 @@
" audio_pts = []\n",
" if read_audio:\n",
" video_object.set_current_stream(\"audio\")\n",
" video_object.seek(start)\n",
" frames = []\n",
" t, pts = video_object.next()\n",
" while t.numel() > 0 and (pts >= start and pts <= end):\n",
" for t, pts in itertools.takewhile(lambda x: x[1] <= end, video_object.seek(start)):\n",
" frames.append(t)\n",
" audio_pts.append(pts)\n",
" t, pts = video_object.next()\n",
" video_pts.append(pts)\n",
" if len(frames) > 0:\n",
" audio_frames = torch.cat(frames, 0)\n",
"\n",
......@@ -309,7 +336,7 @@
},
{
"cell_type": "code",
"execution_count": 44,
"execution_count": 19,
"metadata": {},
"outputs": [
{
......@@ -322,13 +349,13 @@
],
"source": [
"vf, af, info, meta = example_read_video(video)\n",
"# total number of frames should be 327\n",
"# total number of frames should be 327 for video and 523264 datapoints for audio\n",
"print(vf.size(), af.size())"
]
},
{
"cell_type": "code",
"execution_count": 45,
"execution_count": 20,
"metadata": {},
"outputs": [
{
......@@ -337,7 +364,7 @@
"torch.Size([523264, 1])"
]
},
"execution_count": 45,
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
......@@ -362,7 +389,7 @@
},
{
"cell_type": "code",
"execution_count": 46,
"execution_count": 21,
"metadata": {},
"outputs": [],
"source": [
......@@ -375,7 +402,7 @@
},
{
"cell_type": "code",
"execution_count": 47,
"execution_count": 22,
"metadata": {},
"outputs": [
{
......@@ -461,7 +488,7 @@
},
{
"cell_type": "code",
"execution_count": 49,
"execution_count": 23,
"metadata": {},
"outputs": [],
"source": [
......@@ -499,7 +526,7 @@
},
{
"cell_type": "code",
"execution_count": 50,
"execution_count": 33,
"metadata": {},
"outputs": [],
"source": [
......@@ -523,16 +550,14 @@
" # get random sample\n",
" path, target = random.choice(self.samples)\n",
" # get video object\n",
" vid = torch.classes.torchvision.Video(path, \"video\")\n",
" vid = torchvision.io.VideoReader(path, \"video\")\n",
" metadata = vid.get_metadata()\n",
" video_frames = [] # video frame buffer \n",
" # seek and return frames\n",
" \n",
" max_seek = metadata[\"video\"]['duration'][0] - (self.clip_len / metadata[\"video\"]['fps'][0])\n",
" start = random.uniform(0., max_seek)\n",
" vid.seek(start)\n",
" while len(video_frames) < self.clip_len:\n",
" frame, current_pts = vid.next()\n",
" for frame, current_pts in itertools.islice(vid.seek(start), self.clip_len):\n",
" video_frames.append(self.frame_transform(frame))\n",
" # stack it into a tensor\n",
" video = torch.stack(video_frames, 0)\n",
......@@ -570,7 +595,7 @@
},
{
"cell_type": "code",
"execution_count": 51,
"execution_count": 34,
"metadata": {},
"outputs": [],
"source": [
......@@ -583,42 +608,48 @@
},
{
"cell_type": "code",
"execution_count": 52,
"execution_count": 39,
"metadata": {},
"outputs": [],
"source": [
"from torch.utils.data import DataLoader\n",
"loader = DataLoader(ds, batch_size=12)\n",
"d = {\"video\":[], 'start':[], 'end':[]}\n",
"d = {\"video\":[], 'start':[], 'end':[], 'tensorsize':[]}\n",
"for b in loader:\n",
" for i in range(len(b['path'])):\n",
" d['video'].append(b['path'][i])\n",
" d['start'].append(b['start'][i].item())\n",
" d['end'].append(b['end'][i].item())"
" d['end'].append(b['end'][i].item())\n",
" d['tensorsize'].append(b['video'][i].size())"
]
},
{
"cell_type": "code",
"execution_count": 53,
"execution_count": 40,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'video': ['./dataset/1/RATRACE_wave_f_nm_np1_fr_goo_37.avi',\n",
"{'video': ['./dataset/1/WUzgd7C1pWA.mp4',\n",
" './dataset/1/WUzgd7C1pWA.mp4',\n",
" './dataset/1/RATRACE_wave_f_nm_np1_fr_goo_37.avi',\n",
" './dataset/2/SOX5yA1l24A.mp4',\n",
" './dataset/2/v_SoccerJuggling_g23_c01.avi'],\n",
" 'start': [0.029482554081669773,\n",
" 3.439334232470971,\n",
" 1.1823159301599728,\n",
" 4.470027811314425,\n",
" 3.3126303902318432],\n",
" 'end': [0.5666669999999999, 3.970633, 1.7, 4.971633, 3.837167]}"
" './dataset/2/v_SoccerJuggling_g23_c01.avi',\n",
" './dataset/2/v_SoccerJuggling_g23_c01.avi',\n",
" './dataset/1/RATRACE_wave_f_nm_np1_fr_goo_37.avi'],\n",
" 'start': [8.97932147319667,\n",
" 9.421856461438313,\n",
" 2.1301381796579437,\n",
" 5.514273689529127,\n",
" 0.31979853297913124],\n",
" 'end': [9.5095, 9.943266999999999, 2.635967, 6.0393669999999995, 0.833333],\n",
" 'tensorsize': [torch.Size([16, 3, 112, 112]),\n",
" torch.Size([16, 3, 112, 112]),\n",
" torch.Size([16, 3, 112, 112]),\n",
" torch.Size([16, 3, 112, 112]),\n",
" torch.Size([16, 3, 112, 112])]}"
]
},
"execution_count": 53,
"execution_count": 40,
"metadata": {},
"output_type": "execute_result"
}
......@@ -629,7 +660,7 @@
},
{
"cell_type": "code",
"execution_count": 54,
"execution_count": 41,
"metadata": {},
"outputs": [],
"source": [
......@@ -638,6 +669,13 @@
"os.remove(\"./WUzgd7C1pWA.mp4\")\n",
"shutil.rmtree(\"./dataset\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
......
......@@ -5,12 +5,14 @@ import tempfile
import unittest
import random
import itertools
import numpy as np
import torch
import torchvision
from torchvision.io import _HAS_VIDEO_OPT, Video
from torchvision.io import _HAS_VIDEO_OPT, VideoReader
try:
import av
......@@ -242,11 +244,11 @@ def _template_read_video(video_object, s=0, e=None):
video_frames = torch.empty(0)
frames = []
video_pts = []
t, pts = video_object.next()
while t.numel() > 0 and (pts >= s and pts <= e):
for t, pts in itertools.takewhile(lambda x: x[1] <= e, video_object):
if pts < s:
continue
frames.append(t)
video_pts.append(pts)
t, pts = video_object.next()
if len(frames) > 0:
video_frames = torch.stack(frames, 0)
......@@ -255,11 +257,11 @@ def _template_read_video(video_object, s=0, e=None):
audio_frames = torch.empty(0)
frames = []
audio_pts = []
t, pts = video_object.next()
while t.numel() > 0 and (pts > s and pts <= e):
for t, pts in itertools.takewhile(lambda x: x[1] <= e, video_object):
if pts < s:
continue
frames.append(t)
audio_pts.append(pts)
t, pts = video_object.next()
if len(frames) > 0:
audio_frames = torch.stack(frames, 0)
......@@ -289,12 +291,10 @@ class TestVideo(unittest.TestCase):
tv_result, _, _ = torchvision.io.read_video(full_path, pts_unit="sec")
tv_result = tv_result.permute(0, 3, 1, 2)
# pass 2: decode all frames using new api
reader = Video(full_path, "video")
reader = VideoReader(full_path, "video")
frames = []
t, _ = reader.next()
while t.numel() > 0:
for t, _ in reader:
frames.append(t)
t, _ = reader.next()
new_api = torch.stack(frames, 0)
self.assertEqual(tv_result.size(), new_api.size())
......@@ -310,7 +310,7 @@ class TestVideo(unittest.TestCase):
# s = min(r)
# e = max(r)
# reader = Video(full_path, "video")
# reader = VideoReader(full_path, "video")
# results = _template_read_video(reader, s, e)
# tv_video, tv_audio, info = torchvision.io.read_video(
# full_path, start_pts=s, end_pts=e, pts_unit="sec"
......@@ -329,12 +329,12 @@ class TestVideo(unittest.TestCase):
# full_path, pts_unit="sec"
# )
# # pass 2: decode all frames using new api
# reader = Video(full_path, "video")
# reader = VideoReader(full_path, "video")
# pts = []
# t, p = reader.next()
# while t.numel() > 0:
# t, p = next(reader)
# while t.numel() > 0: # THIS NEEDS TO BE FIXED
# pts.append(p)
# t, p = reader.next()
# t, p = next(reader)
# tv_timestamps = [float(p) for p in tv_timestamps]
# napi_pts = [float(p) for p in pts]
......@@ -353,7 +353,7 @@ class TestVideo(unittest.TestCase):
torchvision.set_video_backend("pyav")
for test_video, config in test_videos.items():
full_path = os.path.join(VIDEO_DIR, test_video)
reader = Video(full_path, "video")
reader = VideoReader(full_path, "video")
reader_md = reader.get_metadata()
self.assertAlmostEqual(
config.video_fps, reader_md["video"]["fps"][0], delta=0.0001
......@@ -372,7 +372,7 @@ class TestVideo(unittest.TestCase):
ref_result = _decode_frames_by_av_module(full_path)
reader = Video(full_path, "video")
reader = VideoReader(full_path, "video")
newapi_result = _template_read_video(reader)
# First we check if the frames are approximately the same
......
......@@ -27,42 +27,56 @@ from .image import (
if _HAS_VIDEO_OPT:
def _has_video_opt():
return True
else:
def _has_video_opt():
return False
class Video:
class VideoReader:
"""
Fine-grained video-reading API.
Supports frame-by-frame reading of various streams from a single video
container.
Args:
path (string): Path to the video file in supported format
stream (string, optional): descriptor of the required stream. Defaults to "video:0"
Currently available options include :mod:`['video', 'audio', 'cc', 'sub']`
Example:
The following examples creates :mod:`Video` object, seeks into 2s
point, and returns a single frame::
import torchvision
video_path = "path_to_a_test_video"
reader = torchvision.io.Video(video_path, "video")
reader = torchvision.io.VideoReader(video_path, "video")
reader.seek(2.0)
frame, timestamp = reader.next()
frame, timestamp = next(reader)
Args:
path (string): Path to the video file in supported format
stream (string, optional): descriptor of the required stream. Defaults to "video:0"
Currently available options include :mod:`['video', 'audio', 'cc', 'sub']`
"""
def __init__(self, path, stream="video"):
if not _has_video_opt():
raise RuntimeError("Not compiled with video_reader support")
self._c = torch.classes.torchvision.Video(path, stream)
def next(self):
"""Iterator that decodes the next frame of the current stream
def __next__(self):
"""Decodes and returns the next frame of the current stream
Returns:
([torch.Tensor, float]): list containing decoded frame and corresponding timestamp
"""
return self._c.next()
frame, pts = self._c.next()
if frame.numel() == 0:
raise StopIteration
return frame, pts
def __iter__(self):
return self
def seek(self, time_s: float):
"""Seek within current stream.
......@@ -77,6 +91,7 @@ if _HAS_VIDEO_OPT:
the first frame with timestamp larger than time_s.
"""
self._c.seek(time_s)
return self
def get_metadata(self):
"""Returns video metadata
......@@ -106,10 +121,6 @@ if _HAS_VIDEO_OPT:
return self._c.set_current_stream(stream)
else:
Video = None
__all__ = [
"write_video",
"read_video",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment