Unverified Commit 039c9288 authored by Francisco Massa, committed by GitHub

Minor improvements to VideoReader (#2781)



* Minor improvements to VideoReader

* update jupyter notebook with new naming
Co-authored-by: Bruno Korbar <bjuncek@gmail.com>
parent 610c9d2a
@@ -24,7 +24,7 @@ In addition to the :mod:`read_video` function, we provide a high-performance
lower-level API for more fine-grained control compared to the :mod:`read_video` function.
It does all this whilst fully supporting torchscript.
- .. autoclass:: Video
+ .. autoclass:: VideoReader
  :members: next, get_metadata, set_current_stream, seek
@@ -37,7 +37,7 @@ Example of usage:
    # Constructor allocates memory and a threaded decoder
    # instance per video. At the momet it takes two arguments:
    # path to the video file, and a wanted stream.
-     reader = torchvision.io.Video(video_path, "video")
+     reader = torchvision.io.VideoReader(video_path, "video")
    # The information about the video can be retrieved using the
    # `get_metadata()` method. It returns a dictionary for every stream, with
...
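To make the renamed API concrete, here is a minimal usage sketch based on the documentation hunk above; the file path is a placeholder, and this assumes torchvision was built with the video_reader backend:

import torchvision

# placeholder path; any container with a video (and optionally audio) stream works
video_path = "./WUzgd7C1pWA.mp4"

# constructor allocates memory and a threaded decoder instance per video;
# the second argument selects the default stream
reader = torchvision.io.VideoReader(video_path, "video")

# duration and fps/framerate for every stream in the container
metadata = reader.get_metadata()

# with this commit the reader behaves like a plain Python iterator:
# each step yields a (frame, pts_in_seconds) pair
for frame, pts in reader:
    pass  # e.g. collect frames or apply a transform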
@@ -16,16 +16,16 @@
},
{
"cell_type": "code",
- "execution_count": 36,
+ "execution_count": 1,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
- "('1.7.0a0+f5c95d5', '0.8.0a0+6eff0a4')"
+ "('1.7.0a0+f5c95d5', '0.8.0a0+a2f405d')"
]
},
- "execution_count": 36,
+ "execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
@@ -37,14 +37,21 @@
},
{
"cell_type": "code",
- "execution_count": 37,
+ "execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
- "Using downloaded and verified file: ./WUzgd7C1pWA.mp4\n"
+ "Downloading https://github.com/pytorch/vision/blob/master/test/assets/videos/WUzgd7C1pWA.mp4?raw=true to ./WUzgd7C1pWA.mp4\n"
]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "100.4%"
+ ]
}
],
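The download output above looks like it comes from torchvision's dataset utilities; a hedged sketch of the call that would produce it (the cell's actual source is not part of this hunk):

from torchvision.datasets.utils import download_url

# fetch the test asset used throughout the notebook into the working directory
download_url(
    "https://github.com/pytorch/vision/blob/master/test/assets/videos/WUzgd7C1pWA.mp4?raw=true",
    ".",
    "WUzgd7C1pWA.mp4",
)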
@@ -65,7 +72,7 @@
},
{
"cell_type": "code",
- "execution_count": 38,
+ "execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
@@ -91,7 +98,7 @@
"\n",
"\n",
"\n",
- "video = torch.classes.torchvision.Video(video_path, stream)"
+ "video = torchvision.io.VideoReader(video_path, stream)"
]
},
{
@@ -103,7 +110,7 @@
},
{
"cell_type": "code",
- "execution_count": 39,
+ "execution_count": 6,
"metadata": {},
"outputs": [
{
@@ -113,7 +120,7 @@
" 'audio': {'duration': [10.9], 'framerate': [48000.0]}}"
]
},
- "execution_count": 39,
+ "execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
@@ -133,7 +140,7 @@
},
{
"cell_type": "code",
- "execution_count": 40,
+ "execution_count": 8,
"metadata": {},
"outputs": [
{
@@ -152,11 +159,8 @@
"video.set_current_stream(\"video:0\")\n",
"\n",
"frames = []  # we are going to save the frames here.\n",
- "frame, pts = video.next()\n",
- "# note that next will return emptyframe at the end of the video stream\n",
- "while frame.numel() != 0:\n",
+ "for frame, pts in video:\n",
"    frames.append(frame)\n",
- "    frame, pts = video.next()\n",
"    \n",
"print(\"Total number of frames: \", len(frames))\n",
"approx_nf = metadata['video']['duration'][0] * metadata['video']['fps'][0]\n",
@@ -175,7 +179,7 @@
},
{
"cell_type": "code",
- "execution_count": 41,
+ "execution_count": 9,
"metadata": {},
"outputs": [
{
@@ -193,11 +197,8 @@
"video.set_current_stream(\"audio\")\n",
"\n",
"frames = []  # we are going to save the frames here.\n",
- "frame, pts = video.next()\n",
- "# note that next will return emptyframe at the end of the audio stream\n",
- "while frame.numel() != 0:\n",
+ "for frame, pts in video:\n",
"    frames.append(frame)\n",
- "    frame, pts = video.next()\n",
"    \n",
"print(\"Total number of frames: \", len(frames))\n",
"approx_nf = metadata['audio']['duration'][0] * metadata['audio']['framerate'][0]\n",
@@ -211,14 +212,48 @@
"source": [
"But what if we only want to read certain time segment of the video?\n",
"\n",
- "That can be done easily using the combination of our seek function, and the fact that each call to next returns the presentation timestamp of the returned frame in seconds.\n",
+ "That can be done easily using the combination of our seek function, and the fact that each call to next returns the presentation timestamp of the returned frame in seconds. Given that our implementation relies on python iterators, we can leverage `itertools` to simplify the process and make it more pythonic. \n",
"\n",
- "For example, if we wanted to read video from second to fifth second:"
+ "For example, if we wanted to read ten frames from second second:"
]
},
{
"cell_type": "code",
- "execution_count": 42,
+ "execution_count": 15,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Total number of frames: 10\n"
+ ]
+ }
+ ],
+ "source": [
+ "import itertools\n",
+ "video.set_current_stream(\"video\")\n",
+ "\n",
+ "frames = []  # we are going to save the frames here.\n",
+ "\n",
+ "# we seek into a second second of the video\n",
+ "# and use islice to get 10 frames since\n",
+ "for frame, pts in itertools.islice(video.seek(2), 10):\n",
+ "    frames.append(frame)\n",
+ "    \n",
+ "print(\"Total number of frames: \", len(frames))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Or if we wanted to read from 2nd to 5th second:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
"metadata": {},
"outputs": [
{
@@ -236,15 +271,13 @@
"\n",
"frames = []  # we are going to save the frames here.\n",
"\n",
- "# we seek into a second second of the video \n",
- "# the following call to next returns the first following frame\n",
- "video.seek(2) \n",
- "frame, pts = video.next()\n",
- "# note that we add exit condition\n",
- "while pts < 5 and frame.numel() != 0:\n",
+ "# we seek into a second second of the video\n",
+ "video = video.seek(2)\n",
+ "# then we utilize the itertools takewhile to get the \n",
+ "# correct number of frames\n",
+ "for frame, pts in itertools.takewhile(lambda x: x[1] <= 5, video):\n",
"    frames.append(frame)\n",
- "    frame, pts = video.next()\n",
- "    \n",
+ "\n",
"print(\"Total number of frames: \", len(frames))\n",
"approx_nf = (5-2) * video.get_metadata()['video']['fps'][0]\n",
"print(\"We can expect approx: \", approx_nf)\n",
@@ -262,7 +295,7 @@
},
{
"cell_type": "code",
- "execution_count": 43,
+ "execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
@@ -280,13 +313,10 @@
"    video_pts = []\n",
"    if read_video:\n",
"        video_object.set_current_stream(\"video\")\n",
- "        video_object.seek(start)\n",
"        frames = []\n",
- "        t, pts = video_object.next()\n",
- "        while t.numel() > 0 and (pts >= start and pts <= end):\n",
+ "        for t, pts in itertools.takewhile(lambda x: x[1] <= end, video_object.seek(start)):\n",
"            frames.append(t)\n",
"            video_pts.append(pts)\n",
- "            t, pts = video_object.next()\n",
"        if len(frames) > 0:\n",
"            video_frames = torch.stack(frames, 0)\n",
"\n",
@@ -294,13 +324,10 @@
"    audio_pts = []\n",
"    if read_audio:\n",
"        video_object.set_current_stream(\"audio\")\n",
- "        video_object.seek(start)\n",
"        frames = []\n",
- "        t, pts = video_object.next()\n",
- "        while t.numel() > 0 and (pts >= start and pts <= end):\n",
+ "        for t, pts in itertools.takewhile(lambda x: x[1] <= end, video_object.seek(start)):\n",
"            frames.append(t)\n",
- "            audio_pts.append(pts)\n",
- "            t, pts = video_object.next()\n",
+ "            video_pts.append(pts)\n",
"        if len(frames) > 0:\n",
"            audio_frames = torch.cat(frames, 0)\n",
"\n",
@@ -309,7 +336,7 @@
},
{
"cell_type": "code",
- "execution_count": 44,
+ "execution_count": 19,
"metadata": {},
"outputs": [
{
@@ -322,13 +349,13 @@
],
"source": [
"vf, af, info, meta = example_read_video(video)\n",
- "# total number of frames should be 327\n",
+ "# total number of frames should be 327 for video and 523264 datapoints for audio\n",
"print(vf.size(), af.size())"
]
},
{
"cell_type": "code",
- "execution_count": 45,
+ "execution_count": 20,
"metadata": {},
"outputs": [
{
@@ -337,7 +364,7 @@
"torch.Size([523264, 1])"
]
},
- "execution_count": 45,
+ "execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
@@ -362,7 +389,7 @@
},
{
"cell_type": "code",
- "execution_count": 46,
+ "execution_count": 21,
"metadata": {},
"outputs": [],
"source": [
@@ -375,7 +402,7 @@
},
{
"cell_type": "code",
- "execution_count": 47,
+ "execution_count": 22,
"metadata": {},
"outputs": [
{
@@ -461,7 +488,7 @@
},
{
"cell_type": "code",
- "execution_count": 49,
+ "execution_count": 23,
"metadata": {},
"outputs": [],
"source": [
@@ -499,7 +526,7 @@
},
{
"cell_type": "code",
- "execution_count": 50,
+ "execution_count": 33,
"metadata": {},
"outputs": [],
"source": [
@@ -523,16 +550,14 @@
"        # get random sample\n",
"        path, target = random.choice(self.samples)\n",
"        # get video object\n",
- "        vid = torch.classes.torchvision.Video(path, \"video\")\n",
+ "        vid = torchvision.io.VideoReader(path, \"video\")\n",
"        metadata = vid.get_metadata()\n",
"        video_frames = []  # video frame buffer \n",
"        # seek and return frames\n",
"        \n",
"        max_seek = metadata[\"video\"]['duration'][0] - (self.clip_len / metadata[\"video\"]['fps'][0])\n",
"        start = random.uniform(0., max_seek)\n",
- "        vid.seek(start)\n",
- "        while len(video_frames) < self.clip_len:\n",
- "            frame, current_pts = vid.next()\n",
+ "        for frame, current_pts in itertools.islice(vid.seek(start), self.clip_len):\n",
"            video_frames.append(self.frame_transform(frame))\n",
"        # stack it into a tensor\n",
"        video = torch.stack(video_frames, 0)\n",
@@ -570,7 +595,7 @@
},
{
"cell_type": "code",
- "execution_count": 51,
+ "execution_count": 34,
"metadata": {},
"outputs": [],
"source": [
@@ -583,42 +608,48 @@
},
{
"cell_type": "code",
- "execution_count": 52,
+ "execution_count": 39,
"metadata": {},
"outputs": [],
"source": [
"from torch.utils.data import DataLoader\n",
"loader = DataLoader(ds, batch_size=12)\n",
- "d = {\"video\":[], 'start':[], 'end':[]}\n",
+ "d = {\"video\":[], 'start':[], 'end':[], 'tensorsize':[]}\n",
"for b in loader:\n",
"    for i in range(len(b['path'])):\n",
"        d['video'].append(b['path'][i])\n",
"        d['start'].append(b['start'][i].item())\n",
- "        d['end'].append(b['end'][i].item())"
+ "        d['end'].append(b['end'][i].item())\n",
+ "        d['tensorsize'].append(b['video'][i].size())"
]
},
{
"cell_type": "code",
- "execution_count": 53,
+ "execution_count": 40,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
- "{'video': ['./dataset/1/RATRACE_wave_f_nm_np1_fr_goo_37.avi',\n",
- "  './dataset/1/WUzgd7C1pWA.mp4',\n",
- "  './dataset/1/RATRACE_wave_f_nm_np1_fr_goo_37.avi',\n",
- "  './dataset/2/SOX5yA1l24A.mp4',\n",
- "  './dataset/2/v_SoccerJuggling_g23_c01.avi'],\n",
- " 'start': [0.029482554081669773,\n",
- "  3.439334232470971,\n",
- "  1.1823159301599728,\n",
- "  4.470027811314425,\n",
- "  3.3126303902318432],\n",
- " 'end': [0.5666669999999999, 3.970633, 1.7, 4.971633, 3.837167]}"
+ "{'video': ['./dataset/1/WUzgd7C1pWA.mp4',\n",
+ "  './dataset/1/WUzgd7C1pWA.mp4',\n",
+ "  './dataset/2/v_SoccerJuggling_g23_c01.avi',\n",
+ "  './dataset/2/v_SoccerJuggling_g23_c01.avi',\n",
+ "  './dataset/1/RATRACE_wave_f_nm_np1_fr_goo_37.avi'],\n",
+ " 'start': [8.97932147319667,\n",
+ "  9.421856461438313,\n",
+ "  2.1301381796579437,\n",
+ "  5.514273689529127,\n",
+ "  0.31979853297913124],\n",
+ " 'end': [9.5095, 9.943266999999999, 2.635967, 6.0393669999999995, 0.833333],\n",
+ " 'tensorsize': [torch.Size([16, 3, 112, 112]),\n",
+ "  torch.Size([16, 3, 112, 112]),\n",
+ "  torch.Size([16, 3, 112, 112]),\n",
+ "  torch.Size([16, 3, 112, 112]),\n",
+ "  torch.Size([16, 3, 112, 112])]}"
]
},
- "execution_count": 53,
+ "execution_count": 40,
"metadata": {},
"output_type": "execute_result"
}
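Given the tensor sizes recorded above, each collated batch element is a stack of clips; a short sketch of inspecting the batches the loader yields (using only keys evidenced by the cell output):

for batch in loader:
    # 'video' is collated across the sampled clips: up to 12 clips of 16 frames each
    print(batch['video'].shape)   # e.g. torch.Size([12, 16, 3, 112, 112])
    print(batch['start'], batch['end'])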
@@ -629,7 +660,7 @@
},
{
"cell_type": "code",
- "execution_count": 54,
+ "execution_count": 41,
"metadata": {},
"outputs": [],
"source": [
@@ -638,6 +669,13 @@
"os.remove(\"./WUzgd7C1pWA.mp4\")\n",
"shutil.rmtree(\"./dataset\")"
]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
}
],
"metadata": {
...
@@ -5,12 +5,14 @@ import tempfile
import unittest
import random
+ import itertools

import numpy as np
import torch
import torchvision
- from torchvision.io import _HAS_VIDEO_OPT, Video
+ from torchvision.io import _HAS_VIDEO_OPT, VideoReader

try:
    import av
@@ -242,11 +244,11 @@ def _template_read_video(video_object, s=0, e=None):
    video_frames = torch.empty(0)
    frames = []
    video_pts = []
-     t, pts = video_object.next()
-     while t.numel() > 0 and (pts >= s and pts <= e):
+     for t, pts in itertools.takewhile(lambda x: x[1] <= e, video_object):
+         if pts < s:
+             continue
        frames.append(t)
        video_pts.append(pts)
-         t, pts = video_object.next()
    if len(frames) > 0:
        video_frames = torch.stack(frames, 0)
@@ -255,11 +257,11 @@ def _template_read_video(video_object, s=0, e=None):
    audio_frames = torch.empty(0)
    frames = []
    audio_pts = []
-     t, pts = video_object.next()
-     while t.numel() > 0 and (pts > s and pts <= e):
+     for t, pts in itertools.takewhile(lambda x: x[1] <= e, video_object):
+         if pts < s:
+             continue
        frames.append(t)
        audio_pts.append(pts)
-         t, pts = video_object.next()
    if len(frames) > 0:
        audio_frames = torch.stack(frames, 0)
@@ -289,12 +291,10 @@ class TestVideo(unittest.TestCase):
        tv_result, _, _ = torchvision.io.read_video(full_path, pts_unit="sec")
        tv_result = tv_result.permute(0, 3, 1, 2)
        # pass 2: decode all frames using new api
-         reader = Video(full_path, "video")
+         reader = VideoReader(full_path, "video")
        frames = []
-         t, _ = reader.next()
-         while t.numel() > 0:
+         for t, _ in reader:
            frames.append(t)
-             t, _ = reader.next()
        new_api = torch.stack(frames, 0)
        self.assertEqual(tv_result.size(), new_api.size())
@@ -310,7 +310,7 @@
        # s = min(r)
        # e = max(r)
-         # reader = Video(full_path, "video")
+         # reader = VideoReader(full_path, "video")
        # results = _template_read_video(reader, s, e)
        # tv_video, tv_audio, info = torchvision.io.read_video(
        #     full_path, start_pts=s, end_pts=e, pts_unit="sec"
@@ -329,12 +329,12 @@
        #     full_path, pts_unit="sec"
        # )
        # # pass 2: decode all frames using new api
-         # reader = Video(full_path, "video")
+         # reader = VideoReader(full_path, "video")
        # pts = []
-         # t, p = reader.next()
-         # while t.numel() > 0:
+         # t, p = next(reader)
+         # while t.numel() > 0:  # THIS NEEDS TO BE FIXED
        #     pts.append(p)
-         #     t, p = reader.next()
+         #     t, p = next(reader)
        # tv_timestamps = [float(p) for p in tv_timestamps]
        # napi_pts = [float(p) for p in pts]
@@ -353,7 +353,7 @@
        torchvision.set_video_backend("pyav")
        for test_video, config in test_videos.items():
            full_path = os.path.join(VIDEO_DIR, test_video)
-             reader = Video(full_path, "video")
+             reader = VideoReader(full_path, "video")
            reader_md = reader.get_metadata()
            self.assertAlmostEqual(
                config.video_fps, reader_md["video"]["fps"][0], delta=0.0001
@@ -372,7 +372,7 @@
            ref_result = _decode_frames_by_av_module(full_path)
-             reader = Video(full_path, "video")
+             reader = VideoReader(full_path, "video")
            newapi_result = _template_read_video(reader)
            # First we check if the frames are approximately the same
...
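The first test above reduces to the following comparison between the two decoding paths; the path below is a placeholder for one of the repository's test assets:

import torch
import torchvision
from torchvision.io import VideoReader

full_path = "test/assets/videos/WUzgd7C1pWA.mp4"  # placeholder

# pass 1: decode every frame with the existing read_video API (THWC -> TCHW)
tv_result, _, _ = torchvision.io.read_video(full_path, pts_unit="sec")
tv_result = tv_result.permute(0, 3, 1, 2)

# pass 2: decode every frame by simply iterating the new reader
reader = VideoReader(full_path, "video")
new_api = torch.stack([t for t, _ in reader], 0)

assert tv_result.size() == new_api.size()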
@@ -27,87 +27,98 @@ from .image import (

if _HAS_VIDEO_OPT:

-     class Video:
-         """
-         Fine-grained video-reading API.
-         Supports frame-by-frame reading of various streams from a single video
-         container.
-
-         Args:
-             path (string): Path to the video file in supported format
-
-             stream (string, optional): descriptor of the required stream. Defaults to "video:0"
-                 Currently available options include :mod:`['video', 'audio', 'cc', 'sub']`
-
-         Example:
-             The following examples creates :mod:`Video` object, seeks into 2s
-             point, and returns a single frame::
-                 import torchvision
-                 video_path = "path_to_a_test_video"
-                 reader = torchvision.io.Video(video_path, "video")
-                 reader.seek(2.0)
-                 frame, timestamp = reader.next()
-         """
-
-         def __init__(self, path, stream="video"):
-             self._c = torch.classes.torchvision.Video(path, stream)
-
-         def next(self):
-             """Iterator that decodes the next frame of the current stream
-
-             Returns:
-                 ([torch.Tensor, float]): list containing decoded frame and corresponding timestamp
-             """
-             return self._c.next()
-
-         def seek(self, time_s: float):
-             """Seek within current stream.
-
-             Args:
-                 time_s (float): seek time in seconds
-
-             .. note::
-                 Current implementation is the so-called precise seek. This
-                 means following seek, call to :mod:`next()` will return the
-                 frame with the exact timestamp if it exists or
-                 the first frame with timestamp larger than time_s.
-             """
-             self._c.seek(time_s)
-
-         def get_metadata(self):
-             """Returns video metadata
-
-             Returns:
-                 (dict): dictionary containing duration and frame rate for every stream
-             """
-             return self._c.get_metadata()
-
-         def set_current_stream(self, stream: str):
-             """Set current stream.
-             Explicitly define the stream we are operating on.
-
-             Args:
-                 stream (string): descriptor of the required stream. Defaults to "video:0"
-                     Currently available stream types include :mod:`['video', 'audio', 'cc', 'sub']`.
-                     Each descriptor consists of two parts: stream type (e.g. 'video') and
-                     a unique stream id (which are determined by video encoding).
-                     In this way, if the video contaner contains multiple
-                     streams of the same type, users can acces the one they want.
-                     If only stream type is passed, the decoder auto-detects first stream
-                     of that type and returns it.
-
-             Returns:
-                 (bool): True on succes, False otherwise
-             """
-             return self._c.set_current_stream(stream)
-
- else:
-     Video = None

+     def _has_video_opt():
+         return True
+
+ else:
+
+     def _has_video_opt():
+         return False
+
+
+ class VideoReader:
+     """
+     Fine-grained video-reading API.
+     Supports frame-by-frame reading of various streams from a single video
+     container.
+
+     Example:
+         The following examples creates :mod:`Video` object, seeks into 2s
+         point, and returns a single frame::
+
+             import torchvision
+             video_path = "path_to_a_test_video"
+             reader = torchvision.io.VideoReader(video_path, "video")
+             reader.seek(2.0)
+             frame, timestamp = next(reader)
+
+     Args:
+         path (string): Path to the video file in supported format
+
+         stream (string, optional): descriptor of the required stream. Defaults to "video:0"
+             Currently available options include :mod:`['video', 'audio', 'cc', 'sub']`
+     """
+
+     def __init__(self, path, stream="video"):
+         if not _has_video_opt():
+             raise RuntimeError("Not compiled with video_reader support")
+         self._c = torch.classes.torchvision.Video(path, stream)
+
+     def __next__(self):
+         """Decodes and returns the next frame of the current stream
+
+         Returns:
+             ([torch.Tensor, float]): list containing decoded frame and corresponding timestamp
+         """
+         frame, pts = self._c.next()
+         if frame.numel() == 0:
+             raise StopIteration
+         return frame, pts
+
+     def __iter__(self):
+         return self
+
+     def seek(self, time_s: float):
+         """Seek within current stream.
+
+         Args:
+             time_s (float): seek time in seconds
+
+         .. note::
+             Current implementation is the so-called precise seek. This
+             means following seek, call to :mod:`next()` will return the
+             frame with the exact timestamp if it exists or
+             the first frame with timestamp larger than time_s.
+         """
+         self._c.seek(time_s)
+         return self
+
+     def get_metadata(self):
+         """Returns video metadata
+
+         Returns:
+             (dict): dictionary containing duration and frame rate for every stream
+         """
+         return self._c.get_metadata()
+
+     def set_current_stream(self, stream: str):
+         """Set current stream.
+         Explicitly define the stream we are operating on.
+
+         Args:
+             stream (string): descriptor of the required stream. Defaults to "video:0"
+                 Currently available stream types include :mod:`['video', 'audio', 'cc', 'sub']`.
+                 Each descriptor consists of two parts: stream type (e.g. 'video') and
+                 a unique stream id (which are determined by video encoding).
+                 In this way, if the video contaner contains multiple
+                 streams of the same type, users can acces the one they want.
+                 If only stream type is passed, the decoder auto-detects first stream
+                 of that type and returns it.
+
+         Returns:
+             (bool): True on succes, False otherwise
+         """
+         return self._c.set_current_stream(stream)

__all__ = [
...
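Because VideoReader now implements __iter__ and __next__, it composes directly with anything that consumes Python iterators; a short sketch of the idioms this change enables (the path is a placeholder):

import itertools
from torchvision.io import VideoReader

reader = VideoReader("path_to_a_test_video", "video")

# manual stepping: next() raises StopIteration once the decoder returns an empty frame
frame, timestamp = next(reader)

# ten frames starting at the 2-second mark
clip = [f for f, pts in itertools.islice(reader.seek(2.0), 10)]

# everything up to t = 5 s, relying on the pts returned with each frame
head = list(itertools.takewhile(lambda x: x[1] <= 5.0, reader.seek(0.0)))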