Unverified Commit 039c9288 authored by Francisco Massa, committed by GitHub

Minor improvements to VideoReader (#2781)



* Minor improvements to VideoReader

* update jupyter notebook with new naming
Co-authored-by: Bruno Korbar <bjuncek@gmail.com>
parent 610c9d2a
@@ -24,7 +24,7 @@ In addition to the :mod:`read_video` function, we provide a high-performance
 lower-level API for more fine-grained control compared to the :mod:`read_video` function.
 It does all this whilst fully supporting torchscript.
-.. autoclass:: Video
+.. autoclass:: VideoReader
   :members: next, get_metadata, set_current_stream, seek
@@ -37,7 +37,7 @@ Example of usage:
    # Constructor allocates memory and a threaded decoder
    # instance per video. At the momet it takes two arguments:
    # path to the video file, and a wanted stream.
-    reader = torchvision.io.Video(video_path, "video")
+    reader = torchvision.io.VideoReader(video_path, "video")
    # The information about the video can be retrieved using the
    # `get_metadata()` method. It returns a dictionary for every stream, with
...
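Putting the documented pieces together, a minimal usage sketch of the renamed API (the file path below is a placeholder and is not part of this commit):

    import torchvision

    video_path = "./WUzgd7C1pWA.mp4"  # placeholder; any supported container works

    # The constructor takes the path and a stream descriptor such as "video" or "audio".
    reader = torchvision.io.VideoReader(video_path, "video")

    # Per-stream metadata (duration, fps, framerate, ...) comes back as a dictionary.
    print(reader.get_metadata())

    # With this commit the reader behaves as an iterator over
    # (frame_tensor, pts_in_seconds) pairs.
    for frame, pts in reader:
        print(frame.shape, pts)
        break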
@@ -16,16 +16,16 @@
 },
 {
 "cell_type": "code",
-"execution_count": 36,
+"execution_count": 1,
 "metadata": {},
 "outputs": [
 {
 "data": {
 "text/plain": [
-"('1.7.0a0+f5c95d5', '0.8.0a0+6eff0a4')"
+"('1.7.0a0+f5c95d5', '0.8.0a0+a2f405d')"
 ]
 },
-"execution_count": 36,
+"execution_count": 1,
 "metadata": {},
 "output_type": "execute_result"
 }
@@ -37,14 +37,21 @@
 },
 {
 "cell_type": "code",
-"execution_count": 37,
+"execution_count": 2,
 "metadata": {},
 "outputs": [
 {
 "name": "stdout",
 "output_type": "stream",
 "text": [
-"Using downloaded and verified file: ./WUzgd7C1pWA.mp4\n"
+"Downloading https://github.com/pytorch/vision/blob/master/test/assets/videos/WUzgd7C1pWA.mp4?raw=true to ./WUzgd7C1pWA.mp4\n"
+]
+},
+{
+"name": "stderr",
+"output_type": "stream",
+"text": [
+"100.4%"
 ]
 }
 ],
@@ -65,7 +72,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 38,
+"execution_count": 5,
 "metadata": {},
 "outputs": [],
 "source": [
@@ -91,7 +98,7 @@
 "\n",
 "\n",
 "\n",
-"video = torch.classes.torchvision.Video(video_path, stream)"
+"video = torchvision.io.VideoReader(video_path, stream)"
 ]
 },
 {
@@ -103,7 +110,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 39,
+"execution_count": 6,
 "metadata": {},
 "outputs": [
 {
@@ -113,7 +120,7 @@
 " 'audio': {'duration': [10.9], 'framerate': [48000.0]}}"
 ]
 },
-"execution_count": 39,
+"execution_count": 6,
 "metadata": {},
 "output_type": "execute_result"
 }
@@ -133,7 +140,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 40,
+"execution_count": 8,
 "metadata": {},
 "outputs": [
 {
@@ -152,11 +159,8 @@
 "video.set_current_stream(\"video:0\")\n",
 "\n",
 "frames = [] # we are going to save the frames here.\n",
-"frame, pts = video.next()\n",
-"# note that next will return emptyframe at the end of the video stream\n",
-"while frame.numel() != 0:\n",
+"for frame, pts in video:\n",
 "    frames.append(frame)\n",
-"    frame, pts = video.next()\n",
 "    \n",
 "print(\"Total number of frames: \", len(frames))\n",
 "approx_nf = metadata['video']['duration'][0] * metadata['video']['fps'][0]\n",
@@ -175,7 +179,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 41,
+"execution_count": 9,
 "metadata": {},
 "outputs": [
 {
@@ -193,11 +197,8 @@
 "video.set_current_stream(\"audio\")\n",
 "\n",
 "frames = [] # we are going to save the frames here.\n",
-"frame, pts = video.next()\n",
-"# note that next will return emptyframe at the end of the audio stream\n",
-"while frame.numel() != 0:\n",
+"for frame, pts in video:\n",
 "    frames.append(frame)\n",
-"    frame, pts = video.next()\n",
 "    \n",
 "print(\"Total number of frames: \", len(frames))\n",
 "approx_nf = metadata['audio']['duration'][0] * metadata['audio']['framerate'][0]\n",
@@ -211,14 +212,48 @@
 "source": [
 "But what if we only want to read certain time segment of the video?\n",
 "\n",
-"That can be done easily using the combination of our seek function, and the fact that each call to next returns the presentation timestamp of the returned frame in seconds.\n",
+"That can be done easily using the combination of our seek function, and the fact that each call to next returns the presentation timestamp of the returned frame in seconds. Given that our implementation relies on python iterators, we can leverage `itertools` to simplify the process and make it more pythonic. \n",
 "\n",
-"For example, if we wanted to read video from second to fifth second:"
+"For example, if we wanted to read ten frames from second second:"
 ]
 },
 {
 "cell_type": "code",
-"execution_count": 42,
+"execution_count": 15,
+"metadata": {},
+"outputs": [
+{
+"name": "stdout",
+"output_type": "stream",
+"text": [
+"Total number of frames: 10\n"
+]
+}
+],
+"source": [
+"import itertools\n",
+"video.set_current_stream(\"video\")\n",
+"\n",
+"frames = [] # we are going to save the frames here.\n",
+"\n",
+"# we seek into a second second of the video\n",
+"# and use islice to get 10 frames since\n",
+"for frame, pts in itertools.islice(video.seek(2), 10):\n",
+"    frames.append(frame)\n",
+"    \n",
+"print(\"Total number of frames: \", len(frames))"
+]
+},
+{
+"cell_type": "markdown",
+"metadata": {},
+"source": [
+"Or if we wanted to read from 2nd to 5th second:"
+]
+},
+{
+"cell_type": "code",
+"execution_count": 16,
 "metadata": {},
 "outputs": [
 {
@@ -236,15 +271,13 @@
 "\n",
 "frames = [] # we are going to save the frames here.\n",
 "\n",
-"# we seek into a second second of the video \n",
-"# the following call to next returns the first following frame\n",
-"video.seek(2) \n",
-"frame, pts = video.next()\n",
-"# note that we add exit condition\n",
-"while pts < 5 and frame.numel() != 0:\n",
+"# we seek into a second second of the video\n",
+"video = video.seek(2)\n",
+"# then we utilize the itertools takewhile to get the \n",
+"# correct number of frames\n",
+"for frame, pts in itertools.takewhile(lambda x: x[1] <= 5, video):\n",
 "    frames.append(frame)\n",
-"    frame, pts = video.next()\n",
-"    \n",
+"\n",
 "print(\"Total number of frames: \", len(frames))\n",
 "approx_nf = (5-2) * video.get_metadata()['video']['fps'][0]\n",
 "print(\"We can expect approx: \", approx_nf)\n",
@@ -262,7 +295,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 43,
+"execution_count": 17,
 "metadata": {},
 "outputs": [],
 "source": [
@@ -280,13 +313,10 @@
 "    video_pts = []\n",
 "    if read_video:\n",
 "        video_object.set_current_stream(\"video\")\n",
-"        video_object.seek(start)\n",
 "        frames = []\n",
-"        t, pts = video_object.next()\n",
-"        while t.numel() > 0 and (pts >= start and pts <= end):\n",
+"        for t, pts in itertools.takewhile(lambda x: x[1] <= end, video_object.seek(start)):\n",
 "            frames.append(t)\n",
 "            video_pts.append(pts)\n",
-"            t, pts = video_object.next()\n",
 "        if len(frames) > 0:\n",
 "            video_frames = torch.stack(frames, 0)\n",
 "\n",
@@ -294,13 +324,10 @@
 "    audio_pts = []\n",
 "    if read_audio:\n",
 "        video_object.set_current_stream(\"audio\")\n",
-"        video_object.seek(start)\n",
 "        frames = []\n",
-"        t, pts = video_object.next()\n",
-"        while t.numel() > 0 and (pts >= start and pts <= end):\n",
+"        for t, pts in itertools.takewhile(lambda x: x[1] <= end, video_object.seek(start)):\n",
 "            frames.append(t)\n",
-"            audio_pts.append(pts)\n",
-"            t, pts = video_object.next()\n",
+"            video_pts.append(pts)\n",
 "        if len(frames) > 0:\n",
 "            audio_frames = torch.cat(frames, 0)\n",
 "\n",
@@ -309,7 +336,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 44,
+"execution_count": 19,
 "metadata": {},
 "outputs": [
 {
@@ -322,13 +349,13 @@
 ],
 "source": [
 "vf, af, info, meta = example_read_video(video)\n",
-"# total number of frames should be 327\n",
+"# total number of frames should be 327 for video and 523264 datapoints for audio\n",
 "print(vf.size(), af.size())"
 ]
 },
 {
 "cell_type": "code",
-"execution_count": 45,
+"execution_count": 20,
 "metadata": {},
 "outputs": [
 {
@@ -337,7 +364,7 @@
 "torch.Size([523264, 1])"
 ]
 },
-"execution_count": 45,
+"execution_count": 20,
 "metadata": {},
 "output_type": "execute_result"
 }
@@ -362,7 +389,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 46,
+"execution_count": 21,
 "metadata": {},
 "outputs": [],
 "source": [
@@ -375,7 +402,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 47,
+"execution_count": 22,
 "metadata": {},
 "outputs": [
 {
@@ -461,7 +488,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 49,
+"execution_count": 23,
 "metadata": {},
 "outputs": [],
 "source": [
@@ -499,7 +526,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 50,
+"execution_count": 33,
 "metadata": {},
 "outputs": [],
 "source": [
@@ -523,16 +550,14 @@
 "        # get random sample\n",
 "        path, target = random.choice(self.samples)\n",
 "        # get video object\n",
-"        vid = torch.classes.torchvision.Video(path, \"video\")\n",
+"        vid = torchvision.io.VideoReader(path, \"video\")\n",
 "        metadata = vid.get_metadata()\n",
 "        video_frames = [] # video frame buffer \n",
 "        # seek and return frames\n",
 "        \n",
 "        max_seek = metadata[\"video\"]['duration'][0] - (self.clip_len / metadata[\"video\"]['fps'][0])\n",
 "        start = random.uniform(0., max_seek)\n",
-"        vid.seek(start)\n",
-"        while len(video_frames) < self.clip_len:\n",
-"            frame, current_pts = vid.next()\n",
+"        for frame, current_pts in itertools.islice(vid.seek(start), self.clip_len):\n",
 "            video_frames.append(self.frame_transform(frame))\n",
 "        # stack it into a tensor\n",
 "        video = torch.stack(video_frames, 0)\n",
@@ -570,7 +595,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 51,
+"execution_count": 34,
 "metadata": {},
 "outputs": [],
 "source": [
@@ -583,42 +608,48 @@
 },
 {
 "cell_type": "code",
-"execution_count": 52,
+"execution_count": 39,
 "metadata": {},
 "outputs": [],
 "source": [
 "from torch.utils.data import DataLoader\n",
 "loader = DataLoader(ds, batch_size=12)\n",
-"d = {\"video\":[], 'start':[], 'end':[]}\n",
+"d = {\"video\":[], 'start':[], 'end':[], 'tensorsize':[]}\n",
 "for b in loader:\n",
 "    for i in range(len(b['path'])):\n",
 "        d['video'].append(b['path'][i])\n",
 "        d['start'].append(b['start'][i].item())\n",
-"        d['end'].append(b['end'][i].item())"
+"        d['end'].append(b['end'][i].item())\n",
+"        d['tensorsize'].append(b['video'][i].size())"
 ]
 },
 {
 "cell_type": "code",
-"execution_count": 53,
+"execution_count": 40,
 "metadata": {},
 "outputs": [
 {
 "data": {
 "text/plain": [
-"{'video': ['./dataset/1/RATRACE_wave_f_nm_np1_fr_goo_37.avi',\n",
+"{'video': ['./dataset/1/WUzgd7C1pWA.mp4',\n",
 " './dataset/1/WUzgd7C1pWA.mp4',\n",
-" './dataset/1/RATRACE_wave_f_nm_np1_fr_goo_37.avi',\n",
-" './dataset/2/SOX5yA1l24A.mp4',\n",
-" './dataset/2/v_SoccerJuggling_g23_c01.avi'],\n",
-" 'start': [0.029482554081669773,\n",
-" 3.439334232470971,\n",
-" 1.1823159301599728,\n",
-" 4.470027811314425,\n",
-" 3.3126303902318432],\n",
-" 'end': [0.5666669999999999, 3.970633, 1.7, 4.971633, 3.837167]}"
+" './dataset/2/v_SoccerJuggling_g23_c01.avi',\n",
+" './dataset/2/v_SoccerJuggling_g23_c01.avi',\n",
+" './dataset/1/RATRACE_wave_f_nm_np1_fr_goo_37.avi'],\n",
+" 'start': [8.97932147319667,\n",
+" 9.421856461438313,\n",
+" 2.1301381796579437,\n",
+" 5.514273689529127,\n",
+" 0.31979853297913124],\n",
+" 'end': [9.5095, 9.943266999999999, 2.635967, 6.0393669999999995, 0.833333],\n",
+" 'tensorsize': [torch.Size([16, 3, 112, 112]),\n",
+"  torch.Size([16, 3, 112, 112]),\n",
+"  torch.Size([16, 3, 112, 112]),\n",
+"  torch.Size([16, 3, 112, 112]),\n",
+"  torch.Size([16, 3, 112, 112])]}"
 ]
 },
-"execution_count": 53,
+"execution_count": 40,
 "metadata": {},
 "output_type": "execute_result"
 }
@@ -629,7 +660,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 54,
+"execution_count": 41,
 "metadata": {},
 "outputs": [],
 "source": [
@@ -638,6 +669,13 @@
 "os.remove(\"./WUzgd7C1pWA.mp4\")\n",
 "shutil.rmtree(\"./dataset\")"
 ]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {},
+"outputs": [],
+"source": []
 }
 ],
 "metadata": {
...
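For readers who want the seek-plus-itertools pattern from the updated notebook in one place, a condensed sketch (the path is a placeholder and the resulting counts depend on the clip):

    import itertools
    import torchvision

    video_path = "./WUzgd7C1pWA.mp4"  # placeholder clip, mirroring the notebook
    reader = torchvision.io.VideoReader(video_path, "video")

    # Ten frames starting at the 2 second mark: seek() now returns the reader
    # itself, so it can be fed directly to itertools.islice.
    ten_frames = [f for f, pts in itertools.islice(reader.seek(2), 10)]

    # Every frame whose presentation timestamp falls between 2s and 5s:
    # takewhile stops at the first frame with pts above 5 seconds.
    segment = [f for f, pts in itertools.takewhile(lambda x: x[1] <= 5, reader.seek(2))]

    print(len(ten_frames), len(segment))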
@@ -5,12 +5,14 @@ import tempfile
 import unittest
 import random
+import itertools
 import numpy as np
 import torch
 import torchvision
-from torchvision.io import _HAS_VIDEO_OPT, Video
+from torchvision.io import _HAS_VIDEO_OPT, VideoReader
 try:
     import av
@@ -242,11 +244,11 @@ def _template_read_video(video_object, s=0, e=None):
     video_frames = torch.empty(0)
     frames = []
     video_pts = []
-    t, pts = video_object.next()
-    while t.numel() > 0 and (pts >= s and pts <= e):
+    for t, pts in itertools.takewhile(lambda x: x[1] <= e, video_object):
+        if pts < s:
+            continue
         frames.append(t)
         video_pts.append(pts)
-        t, pts = video_object.next()
     if len(frames) > 0:
         video_frames = torch.stack(frames, 0)
@@ -255,11 +257,11 @@ def _template_read_video(video_object, s=0, e=None):
     audio_frames = torch.empty(0)
     frames = []
     audio_pts = []
-    t, pts = video_object.next()
-    while t.numel() > 0 and (pts > s and pts <= e):
+    for t, pts in itertools.takewhile(lambda x: x[1] <= e, video_object):
+        if pts < s:
+            continue
         frames.append(t)
         audio_pts.append(pts)
-        t, pts = video_object.next()
     if len(frames) > 0:
         audio_frames = torch.stack(frames, 0)
@@ -289,12 +291,10 @@ class TestVideo(unittest.TestCase):
         tv_result, _, _ = torchvision.io.read_video(full_path, pts_unit="sec")
         tv_result = tv_result.permute(0, 3, 1, 2)
         # pass 2: decode all frames using new api
-        reader = Video(full_path, "video")
+        reader = VideoReader(full_path, "video")
         frames = []
-        t, _ = reader.next()
-        while t.numel() > 0:
+        for t, _ in reader:
             frames.append(t)
-            t, _ = reader.next()
         new_api = torch.stack(frames, 0)
         self.assertEqual(tv_result.size(), new_api.size())
@@ -310,7 +310,7 @@ class TestVideo(unittest.TestCase):
         # s = min(r)
         # e = max(r)
-        # reader = Video(full_path, "video")
+        # reader = VideoReader(full_path, "video")
         # results = _template_read_video(reader, s, e)
         # tv_video, tv_audio, info = torchvision.io.read_video(
         #     full_path, start_pts=s, end_pts=e, pts_unit="sec"
@@ -329,12 +329,12 @@ class TestVideo(unittest.TestCase):
         #     full_path, pts_unit="sec"
         # )
         # # pass 2: decode all frames using new api
-        # reader = Video(full_path, "video")
+        # reader = VideoReader(full_path, "video")
         # pts = []
-        # t, p = reader.next()
-        # while t.numel() > 0:
+        # t, p = next(reader)
+        # while t.numel() > 0:  # THIS NEEDS TO BE FIXED
         #     pts.append(p)
-        #     t, p = reader.next()
+        #     t, p = next(reader)
        # tv_timestamps = [float(p) for p in tv_timestamps]
        # napi_pts = [float(p) for p in pts]
@@ -353,7 +353,7 @@ class TestVideo(unittest.TestCase):
         torchvision.set_video_backend("pyav")
         for test_video, config in test_videos.items():
             full_path = os.path.join(VIDEO_DIR, test_video)
-            reader = Video(full_path, "video")
+            reader = VideoReader(full_path, "video")
             reader_md = reader.get_metadata()
             self.assertAlmostEqual(
                 config.video_fps, reader_md["video"]["fps"][0], delta=0.0001
@@ -372,7 +372,7 @@ class TestVideo(unittest.TestCase):
             ref_result = _decode_frames_by_av_module(full_path)
-            reader = Video(full_path, "video")
+            reader = VideoReader(full_path, "video")
             newapi_result = _template_read_video(reader)
             # First we check if the frames are approximately the same
...
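The equivalence check performed by the test above can be reproduced outside the test suite; a rough sketch, assuming any locally available clip (the path is a stand-in for the test assets):

    import torch
    import torchvision
    from torchvision.io import VideoReader

    full_path = "./WUzgd7C1pWA.mp4"  # placeholder for a local clip

    # Reference decode with the existing high-level API, (T, H, W, C) -> (T, C, H, W).
    tv_result, _, _ = torchvision.io.read_video(full_path, pts_unit="sec")
    tv_result = tv_result.permute(0, 3, 1, 2)

    # Decode the same file with the new iterator-based API and stack the frames.
    reader = VideoReader(full_path, "video")
    new_api = torch.stack([t for t, _ in reader], 0)

    # Both decoders should yield tensors of the same shape.
    assert tv_result.size() == new_api.size()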
@@ -27,42 +27,56 @@ from .image import (
 if _HAS_VIDEO_OPT:
+    def _has_video_opt():
+        return True
+else:
+    def _has_video_opt():
+        return False
-class Video:
+class VideoReader:
     """
     Fine-grained video-reading API.
     Supports frame-by-frame reading of various streams from a single video
     container.
-    Args:
-        path (string): Path to the video file in supported format
-        stream (string, optional): descriptor of the required stream. Defaults to "video:0"
-            Currently available options include :mod:`['video', 'audio', 'cc', 'sub']`
     Example:
         The following examples creates :mod:`Video` object, seeks into 2s
         point, and returns a single frame::
            import torchvision
            video_path = "path_to_a_test_video"
-            reader = torchvision.io.Video(video_path, "video")
+            reader = torchvision.io.VideoReader(video_path, "video")
            reader.seek(2.0)
-            frame, timestamp = reader.next()
+            frame, timestamp = next(reader)
+    Args:
+        path (string): Path to the video file in supported format
+        stream (string, optional): descriptor of the required stream. Defaults to "video:0"
+            Currently available options include :mod:`['video', 'audio', 'cc', 'sub']`
     """
     def __init__(self, path, stream="video"):
+        if not _has_video_opt():
+            raise RuntimeError("Not compiled with video_reader support")
         self._c = torch.classes.torchvision.Video(path, stream)
-    def next(self):
-        """Iterator that decodes the next frame of the current stream
+    def __next__(self):
+        """Decodes and returns the next frame of the current stream
         Returns:
             ([torch.Tensor, float]): list containing decoded frame and corresponding timestamp
         """
-        return self._c.next()
+        frame, pts = self._c.next()
+        if frame.numel() == 0:
+            raise StopIteration
+        return frame, pts
+    def __iter__(self):
+        return self
     def seek(self, time_s: float):
         """Seek within current stream.
@@ -77,6 +91,7 @@ if _HAS_VIDEO_OPT:
         the first frame with timestamp larger than time_s.
         """
         self._c.seek(time_s)
+        return self
     def get_metadata(self):
         """Returns video metadata
@@ -106,10 +121,6 @@ if _HAS_VIDEO_OPT:
         return self._c.set_current_stream(stream)
-else:
-    Video = None
 __all__ = [
     "write_video",
     "read_video",
...
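The change above is mostly about making the reader a proper Python iterator: __next__ raises StopIteration instead of returning an empty sentinel frame, __iter__ returns self, and seek() returns self so it can be chained into loops and itertools. A simplified, self-contained analogue of that protocol (illustration only; it uses a plain list in place of the underlying decoder, not torchvision code):

    import itertools

    class ToyReader:
        # Stands in for VideoReader; "frames" is a list of (frame, pts) pairs.
        def __init__(self, frames):
            self._frames = list(frames)
            self._pos = 0

        def seek(self, time_s):
            # Jump to the first frame whose timestamp is at least time_s and
            # return self so the call can feed a for-loop or itertools directly.
            self._pos = next(
                (i for i, (_, pts) in enumerate(self._frames) if pts >= time_s),
                len(self._frames),
            )
            return self

        def __next__(self):
            # Signal exhaustion with StopIteration rather than an empty frame.
            if self._pos >= len(self._frames):
                raise StopIteration
            item = self._frames[self._pos]
            self._pos += 1
            return item

        def __iter__(self):
            return self

    r = ToyReader([("frame%d" % i, i * 0.5) for i in range(10)])
    print(list(itertools.islice(r.seek(2.0), 3)))  # three (frame, pts) pairs from t=2.0s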