"vscode:/vscode.git/clone" did not exist on "b0b92e70479069776b1143cc47dd8ba48f790fc3"
Unverified commit d710f3d1, authored by Bruno Korbar and committed by GitHub
Browse files

Bkorbar/pyavapi (#6943)



* Test: add backend parameter

* VideoReader object now works on backend

* Frame reading now passes

* Keyframe seek now passes

* Pyav backend now supports metadata

* changes in test to reflect GPU decoder change

* Linter?

* Test GPU output

* Addressing Joao's comments

* lint

* lint

* Revert "Test GPU output"

This reverts commit f62e955d7dc81bcb23b40d58ea75413b9b62e76d.

* lint?

* lint

* lint

* Address issues in build?

* hopefully doc fix

* Arrgh

* arrgh

* fix typos

* fix input options

* remove read from memory option in pyav

* skip read from mem test for gpu and pyav backends

* fix test

* remove unused import

* Hack to get reading from memory work with pyav

* patch audio test

* gallery change in a hope that docs won't break

* check video decoder inside io

* adding missing lib loading code

* remove unused input
Co-authored-by: Bruno Korbar <bkorbar@quansight.com>
Co-authored-by: Joao Gomes <jdsgomes@fb.com>
parent b1054cbb
...@@ -32,6 +32,7 @@ videos, together with the examples on how to build datasets and more. ...@@ -32,6 +32,7 @@ videos, together with the examples on how to build datasets and more.
import torch import torch
import torchvision import torchvision
from torchvision.datasets.utils import download_url from torchvision.datasets.utils import download_url
torchvision.set_video_backend("video_reader")
# Download the sample video # Download the sample video
download_url( download_url(
......
...@@ -3,6 +3,7 @@ import os ...@@ -3,6 +3,7 @@ import os
import pytest import pytest
import torch import torch
import torchvision
from torchvision.io import _HAS_GPU_VIDEO_DECODER, VideoReader from torchvision.io import _HAS_GPU_VIDEO_DECODER, VideoReader
try: try:
...@@ -29,8 +30,9 @@ class TestVideoGPUDecoder: ...@@ -29,8 +30,9 @@ class TestVideoGPUDecoder:
], ],
) )
def test_frame_reading(self, video_file): def test_frame_reading(self, video_file):
torchvision.set_video_backend("cuda")
full_path = os.path.join(VIDEO_DIR, video_file) full_path = os.path.join(VIDEO_DIR, video_file)
decoder = VideoReader(full_path, device="cuda") decoder = VideoReader(full_path)
with av.open(full_path) as container: with av.open(full_path) as container:
for av_frame in container.decode(container.streams.video[0]): for av_frame in container.decode(container.streams.video[0]):
av_frames = torch.tensor(av_frame.to_rgb(src_colorspace="ITU709").to_ndarray()) av_frames = torch.tensor(av_frame.to_rgb(src_colorspace="ITU709").to_ndarray())
...@@ -54,7 +56,8 @@ class TestVideoGPUDecoder: ...@@ -54,7 +56,8 @@ class TestVideoGPUDecoder:
], ],
) )
def test_seek_reading(self, keyframes, full_path, duration): def test_seek_reading(self, keyframes, full_path, duration):
decoder = VideoReader(full_path, device="cuda") torchvision.set_video_backend("cuda")
decoder = VideoReader(full_path)
time = duration / 2 time = duration / 2
decoder.seek(time, keyframes_only=keyframes) decoder.seek(time, keyframes_only=keyframes)
with av.open(full_path) as container: with av.open(full_path) as container:
...@@ -79,8 +82,9 @@ class TestVideoGPUDecoder: ...@@ -79,8 +82,9 @@ class TestVideoGPUDecoder:
], ],
) )
def test_metadata(self, video_file): def test_metadata(self, video_file):
torchvision.set_video_backend("cuda")
full_path = os.path.join(VIDEO_DIR, video_file) full_path = os.path.join(VIDEO_DIR, video_file)
decoder = VideoReader(full_path, device="cuda") decoder = VideoReader(full_path)
video_metadata = decoder.get_metadata()["video"] video_metadata = decoder.get_metadata()["video"]
with av.open(full_path) as container: with av.open(full_path) as container:
video = container.streams.video[0] video = container.streams.video[0]
......
...@@ -53,7 +53,9 @@ test_videos = { ...@@ -53,7 +53,9 @@ test_videos = {
class TestVideoApi: class TestVideoApi:
@pytest.mark.skipif(av is None, reason="PyAV unavailable") @pytest.mark.skipif(av is None, reason="PyAV unavailable")
@pytest.mark.parametrize("test_video", test_videos.keys()) @pytest.mark.parametrize("test_video", test_videos.keys())
def test_frame_reading(self, test_video): @pytest.mark.parametrize("backend", ["video_reader", "pyav"])
def test_frame_reading(self, test_video, backend):
torchvision.set_video_backend(backend)
full_path = os.path.join(VIDEO_DIR, test_video) full_path = os.path.join(VIDEO_DIR, test_video)
with av.open(full_path) as av_reader: with av.open(full_path) as av_reader:
if av_reader.streams.video: if av_reader.streams.video:
...@@ -117,9 +119,15 @@ class TestVideoApi: ...@@ -117,9 +119,15 @@ class TestVideoApi:
@pytest.mark.parametrize("stream", ["video", "audio"]) @pytest.mark.parametrize("stream", ["video", "audio"])
@pytest.mark.parametrize("test_video", test_videos.keys()) @pytest.mark.parametrize("test_video", test_videos.keys())
def test_frame_reading_mem_vs_file(self, test_video, stream): @pytest.mark.parametrize("backend", ["video_reader", "pyav"])
def test_frame_reading_mem_vs_file(self, test_video, stream, backend):
torchvision.set_video_backend(backend)
full_path = os.path.join(VIDEO_DIR, test_video) full_path = os.path.join(VIDEO_DIR, test_video)
reader = VideoReader(full_path)
reader_md = reader.get_metadata()
if stream in reader_md:
# Test video reading from file vs from memory # Test video reading from file vs from memory
vr_frames, vr_frames_mem = [], [] vr_frames, vr_frames_mem = [], []
vr_pts, vr_pts_mem = [], [] vr_pts, vr_pts_mem = [], []
...@@ -154,13 +162,17 @@ class TestVideoApi: ...@@ -154,13 +162,17 @@ class TestVideoApi:
assert mean_delta.item() < 2.55 assert mean_delta.item() < 2.55
del vr_frames, vr_pts, vr_frames_mem, vr_pts_mem del vr_frames, vr_pts, vr_frames_mem, vr_pts_mem
else:
del reader, reader_md
@pytest.mark.parametrize("test_video,config", test_videos.items()) @pytest.mark.parametrize("test_video,config", test_videos.items())
def test_metadata(self, test_video, config): @pytest.mark.parametrize("backend", ["video_reader", "pyav"])
def test_metadata(self, test_video, config, backend):
""" """
Test that the metadata returned via pyav corresponds to the one returned Test that the metadata returned via pyav corresponds to the one returned
by the new video decoder API by the new video decoder API
""" """
torchvision.set_video_backend(backend)
full_path = os.path.join(VIDEO_DIR, test_video) full_path = os.path.join(VIDEO_DIR, test_video)
reader = VideoReader(full_path, "video") reader = VideoReader(full_path, "video")
reader_md = reader.get_metadata() reader_md = reader.get_metadata()
...@@ -168,7 +180,9 @@ class TestVideoApi: ...@@ -168,7 +180,9 @@ class TestVideoApi:
assert config.duration == approx(reader_md["video"]["duration"][0], abs=0.5) assert config.duration == approx(reader_md["video"]["duration"][0], abs=0.5)
@pytest.mark.parametrize("test_video", test_videos.keys()) @pytest.mark.parametrize("test_video", test_videos.keys())
def test_seek_start(self, test_video): @pytest.mark.parametrize("backend", ["video_reader", "pyav"])
def test_seek_start(self, test_video, backend):
torchvision.set_video_backend(backend)
full_path = os.path.join(VIDEO_DIR, test_video) full_path = os.path.join(VIDEO_DIR, test_video)
video_reader = VideoReader(full_path, "video") video_reader = VideoReader(full_path, "video")
num_frames = 0 num_frames = 0
...@@ -194,7 +208,9 @@ class TestVideoApi: ...@@ -194,7 +208,9 @@ class TestVideoApi:
assert start_num_frames == num_frames assert start_num_frames == num_frames
@pytest.mark.parametrize("test_video", test_videos.keys()) @pytest.mark.parametrize("test_video", test_videos.keys())
def test_accurateseek_middle(self, test_video): @pytest.mark.parametrize("backend", ["video_reader"])
def test_accurateseek_middle(self, test_video, backend):
torchvision.set_video_backend(backend)
full_path = os.path.join(VIDEO_DIR, test_video) full_path = os.path.join(VIDEO_DIR, test_video)
stream = "video" stream = "video"
video_reader = VideoReader(full_path, stream) video_reader = VideoReader(full_path, stream)
...@@ -233,7 +249,9 @@ class TestVideoApi: ...@@ -233,7 +249,9 @@ class TestVideoApi:
@pytest.mark.skipif(av is None, reason="PyAV unavailable") @pytest.mark.skipif(av is None, reason="PyAV unavailable")
@pytest.mark.parametrize("test_video,config", test_videos.items()) @pytest.mark.parametrize("test_video,config", test_videos.items())
def test_keyframe_reading(self, test_video, config): @pytest.mark.parametrize("backend", ["pyav", "video_reader"])
def test_keyframe_reading(self, test_video, config, backend):
torchvision.set_video_backend(backend)
full_path = os.path.join(VIDEO_DIR, test_video) full_path = os.path.join(VIDEO_DIR, test_video)
av_reader = av.open(full_path) av_reader = av.open(full_path)
......
import os import os
import warnings import warnings
from modulefinder import Module
import torch import torch
from torchvision import datasets, io, models, ops, transforms, utils from torchvision import datasets, io, models, ops, transforms, utils
...@@ -11,6 +12,7 @@ try: ...@@ -11,6 +12,7 @@ try:
except ImportError: except ImportError:
pass pass
# Check if torchvision is being imported within the root folder # Check if torchvision is being imported within the root folder
if not _HAS_OPS and os.path.dirname(os.path.realpath(__file__)) == os.path.join( if not _HAS_OPS and os.path.dirname(os.path.realpath(__file__)) == os.path.join(
os.path.realpath(os.getcwd()), "torchvision" os.path.realpath(os.getcwd()), "torchvision"
...@@ -66,11 +68,16 @@ def set_video_backend(backend): ...@@ -66,11 +68,16 @@ def set_video_backend(backend):
backend, please compile torchvision from source. backend, please compile torchvision from source.
""" """
global _video_backend global _video_backend
if backend not in ["pyav", "video_reader"]: if backend not in ["pyav", "video_reader", "cuda"]:
raise ValueError("Invalid video backend '%s'. Options are 'pyav' and 'video_reader'" % backend) raise ValueError("Invalid video backend '%s'. Options are 'pyav', 'video_reader' and 'cuda'" % backend)
if backend == "video_reader" and not io._HAS_VIDEO_OPT: if backend == "video_reader" and not io._HAS_VIDEO_OPT:
# TODO: better messages
message = "video_reader video backend is not available. Please compile torchvision from source and try again" message = "video_reader video backend is not available. Please compile torchvision from source and try again"
warnings.warn(message) raise RuntimeError(message)
elif backend == "cuda" and not io._HAS_GPU_VIDEO_DECODER:
# TODO: better messages
message = "cuda video backend is not available."
raise RuntimeError(message)
else: else:
_video_backend = backend _video_backend = backend
......
...@@ -4,10 +4,6 @@ import torch ...@@ -4,10 +4,6 @@ import torch
from ..utils import _log_api_usage_once from ..utils import _log_api_usage_once
try:
from ._load_gpu_decoder import _HAS_GPU_VIDEO_DECODER
except ModuleNotFoundError:
_HAS_GPU_VIDEO_DECODER = False
from ._video_opt import ( from ._video_opt import (
_HAS_VIDEO_OPT, _HAS_VIDEO_OPT,
_probe_video_from_file, _probe_video_from_file,
...@@ -32,7 +28,7 @@ from .image import ( ...@@ -32,7 +28,7 @@ from .image import (
write_jpeg, write_jpeg,
write_png, write_png,
) )
from .video import read_video, read_video_timestamps, write_video from .video import _HAS_GPU_VIDEO_DECODER, read_video, read_video_timestamps, write_video
from .video_reader import VideoReader from .video_reader import VideoReader
......
from ..extension import _load_library
try:
_load_library("Decoder")
_HAS_GPU_VIDEO_DECODER = True
except (ImportError, OSError):
_HAS_GPU_VIDEO_DECODER = False
...@@ -9,9 +9,16 @@ from typing import Any, Dict, List, Optional, Tuple, Union ...@@ -9,9 +9,16 @@ from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np import numpy as np
import torch import torch
from ..extension import _load_library
from ..utils import _log_api_usage_once from ..utils import _log_api_usage_once
from . import _video_opt from . import _video_opt
try:
_load_library("Decoder")
_HAS_GPU_VIDEO_DECODER = True
except (ImportError, OSError, ModuleNotFoundError):
_HAS_GPU_VIDEO_DECODER = False
try: try:
import av import av
......
import io
import warnings import warnings
from typing import Any, Dict, Iterator, Optional from typing import Any, Dict, Iterator, Optional
import torch import torch
from ..utils import _log_api_usage_once from ..utils import _log_api_usage_once
try:
from ._load_gpu_decoder import _HAS_GPU_VIDEO_DECODER
except ModuleNotFoundError:
_HAS_GPU_VIDEO_DECODER = False
from ._video_opt import _HAS_VIDEO_OPT from ._video_opt import _HAS_VIDEO_OPT
if _HAS_VIDEO_OPT: if _HAS_VIDEO_OPT:
...@@ -22,11 +20,37 @@ else: ...@@ -22,11 +20,37 @@ else:
return False return False
try:
import av
av.logging.set_level(av.logging.ERROR)
if not hasattr(av.video.frame.VideoFrame, "pict_type"):
av = ImportError(
"""\
Your version of PyAV is too old for the necessary video operations in torchvision.
If you are on Python 3.5, you will have to build from source (the conda-forge
packages are not up-to-date). See
https://github.com/mikeboers/PyAV#installation for instructions on how to
install PyAV on your system.
"""
)
except ImportError:
av = ImportError(
"""\
PyAV is not installed, and is necessary for the video operations in torchvision.
See https://github.com/mikeboers/PyAV#installation for instructions on how to
install PyAV on your system.
"""
)
class VideoReader: class VideoReader:
""" """
Fine-grained video-reading API. Fine-grained video-reading API.
Supports frame-by-frame reading of various streams from a single video Supports frame-by-frame reading of various streams from a single video
container. container. Much like previous video_reader API it supports the following
backends: video_reader, pyav, and cuda.
Backends can be set via `torchvision.set_video_backend` function.
.. betastatus:: VideoReader class .. betastatus:: VideoReader class
...@@ -88,16 +112,11 @@ class VideoReader: ...@@ -88,16 +112,11 @@ class VideoReader:
Default value (0) enables multithreading with codec-dependent heuristic. The performance Default value (0) enables multithreading with codec-dependent heuristic. The performance
will depend on the version of FFMPEG codecs supported. will depend on the version of FFMPEG codecs supported.
device (str, optional): Device to be used for decoding. Defaults to ``"cpu"``.
To use GPU decoding, pass ``device="cuda"``.
path (str, optional): path (str, optional):
.. warning: .. warning:
This parameter was deprecated in ``0.15`` and will be removed in ``0.17``. This parameter was deprecated in ``0.15`` and will be removed in ``0.17``.
Please use ``src`` instead. Please use ``src`` instead.
""" """
def __init__( def __init__(
...@@ -105,44 +124,58 @@ class VideoReader: ...@@ -105,44 +124,58 @@ class VideoReader:
src: str = "", src: str = "",
stream: str = "video", stream: str = "video",
num_threads: int = 0, num_threads: int = 0,
device: str = "cpu",
path: Optional[str] = None, path: Optional[str] = None,
) -> None: ) -> None:
_log_api_usage_once(self) _log_api_usage_once(self)
self.is_cuda = False from .. import get_video_backend
device = torch.device(device)
if device.type == "cuda":
if not _HAS_GPU_VIDEO_DECODER:
raise RuntimeError("Not compiled with GPU decoder support.")
self.is_cuda = True
self._c = torch.classes.torchvision.GPUDecoder(src, device)
return
if not _has_video_opt():
raise RuntimeError(
"Not compiled with video_reader support, "
+ "to enable video_reader support, please install "
+ "ffmpeg (version 4.2 is currently supported) and "
+ "build torchvision from source."
)
self.backend = get_video_backend()
if isinstance(src, str):
if src == "": if src == "":
if path is None: if path is None:
raise TypeError("src cannot be empty") raise TypeError("src cannot be empty")
src = path src = path
warnings.warn("path is deprecated and will be removed in 0.17. Please use src instead") warnings.warn("path is deprecated and will be removed in 0.17. Please use src instead")
elif isinstance(src, bytes): elif isinstance(src, bytes):
if self.backend in ["cuda"]:
raise RuntimeError(
"VideoReader cannot be initialized from bytes object when using cuda or pyav backend."
)
elif self.backend == "pyav":
src = io.BytesIO(src)
else:
src = torch.frombuffer(src, dtype=torch.uint8) src = torch.frombuffer(src, dtype=torch.uint8)
elif isinstance(src, torch.Tensor):
if self.backend in ["cuda", "pyav"]:
raise RuntimeError(
"VideoReader cannot be initialized from Tensor object when using cuda or pyav backend."
)
else:
raise TypeError("`src` must be either string, Tensor or bytes object.")
if self.backend == "cuda":
device = torch.device("cuda")
self._c = torch.classes.torchvision.GPUDecoder(src, device)
elif self.backend == "video_reader":
if isinstance(src, str): if isinstance(src, str):
self._c = torch.classes.torchvision.Video(src, stream, num_threads) self._c = torch.classes.torchvision.Video(src, stream, num_threads)
elif isinstance(src, torch.Tensor): elif isinstance(src, torch.Tensor):
if self.is_cuda:
raise RuntimeError("GPU VideoReader cannot be initialized from Tensor or bytes object.")
self._c = torch.classes.torchvision.Video("", "", 0) self._c = torch.classes.torchvision.Video("", "", 0)
self._c.init_from_memory(src, stream, num_threads) self._c.init_from_memory(src, stream, num_threads)
elif self.backend == "pyav":
self.container = av.open(src, metadata_errors="ignore")
# TODO: load metadata
stream_type = stream.split(":")[0]
stream_id = 0 if len(stream.split(":")) == 1 else int(stream.split(":")[1])
self.pyav_stream = {stream_type: stream_id}
self._c = self.container.decode(**self.pyav_stream)
# TODO: add extradata exception
else: else:
raise TypeError("`src` must be either string, Tensor or bytes object.") raise RuntimeError("Unknown video backend: {}".format(self.backend))
def __next__(self) -> Dict[str, Any]: def __next__(self) -> Dict[str, Any]:
"""Decodes and returns the next frame of the current stream. """Decodes and returns the next frame of the current stream.
...@@ -156,14 +189,29 @@ class VideoReader: ...@@ -156,14 +189,29 @@ class VideoReader:
and corresponding timestamp (``pts``) in seconds and corresponding timestamp (``pts``) in seconds
""" """
if self.is_cuda: if self.backend == "cuda":
frame = self._c.next() frame = self._c.next()
if frame.numel() == 0: if frame.numel() == 0:
raise StopIteration raise StopIteration
return {"data": frame} return {"data": frame, "pts": None}
elif self.backend == "video_reader":
frame, pts = self._c.next() frame, pts = self._c.next()
else:
try:
frame = next(self._c)
pts = float(frame.pts * frame.time_base)
if "video" in self.pyav_stream:
frame = torch.tensor(frame.to_rgb().to_ndarray()).permute(2, 0, 1)
elif "audio" in self.pyav_stream:
frame = torch.tensor(frame.to_ndarray()).permute(1, 0)
else:
frame = None
except av.error.EOFError:
raise StopIteration
if frame.numel() == 0: if frame.numel() == 0:
raise StopIteration raise StopIteration
return {"data": frame, "pts": pts} return {"data": frame, "pts": pts}
def __iter__(self) -> Iterator[Dict[str, Any]]: def __iter__(self) -> Iterator[Dict[str, Any]]:
...@@ -182,7 +230,18 @@ class VideoReader: ...@@ -182,7 +230,18 @@ class VideoReader:
frame with the exact timestamp if it exists or frame with the exact timestamp if it exists or
the first frame with timestamp larger than ``time_s``. the first frame with timestamp larger than ``time_s``.
""" """
if self.backend in ["cuda", "video_reader"]:
self._c.seek(time_s, keyframes_only) self._c.seek(time_s, keyframes_only)
else:
# handle special case as pyav doesn't catch it
if time_s < 0:
time_s = 0
temp_str = self.container.streams.get(**self.pyav_stream)[0]
offset = int(round(time_s / temp_str.time_base))
if not keyframes_only:
warnings.warn("Accurate seek is not implemented for pyav backend")
self.container.seek(offset, backward=True, any_frame=False, stream=temp_str)
self._c = self.container.decode(**self.pyav_stream)
return self return self
def get_metadata(self) -> Dict[str, Any]: def get_metadata(self) -> Dict[str, Any]:
...@@ -191,6 +250,21 @@ class VideoReader: ...@@ -191,6 +250,21 @@ class VideoReader:
Returns: Returns:
(dict): dictionary containing duration and frame rate for every stream (dict): dictionary containing duration and frame rate for every stream
""" """
if self.backend == "pyav":
metadata = {} # type: Dict[str, Any]
for stream in self.container.streams:
if stream.type not in metadata:
if stream.type == "video":
rate_n = "fps"
else:
rate_n = "framerate"
metadata[stream.type] = {rate_n: [], "duration": []}
rate = stream.average_rate if stream.average_rate is not None else stream.sample_rate
metadata[stream.type]["duration"].append(float(stream.duration * stream.time_base))
metadata[stream.type][rate_n].append(float(rate))
return metadata
return self._c.get_metadata() return self._c.get_metadata()
def set_current_stream(self, stream: str) -> bool: def set_current_stream(self, stream: str) -> bool:
...@@ -210,6 +284,12 @@ class VideoReader: ...@@ -210,6 +284,12 @@ class VideoReader:
Returns: Returns:
(bool): True on succes, False otherwise (bool): True on succes, False otherwise
""" """
if self.is_cuda: if self.backend == "cuda":
print("GPU decoding only works with video stream.") warnings.warn("GPU decoding only works with video stream.")
if self.backend == "pyav":
stream_type = stream.split(":")[0]
stream_id = 0 if len(stream.split(":")) == 1 else int(stream.split(":")[1])
self.pyav_stream = {stream_type: stream_id}
self._c = self.container.decode(**self.pyav_stream)
return True
return self._c.set_current_stream(stream) return self._c.set_current_stream(stream)
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment