Unverified Commit f20177b7 authored by Nicolas Hug's avatar Nicolas Hug Committed by GitHub
Browse files

[FBcode->GH] Revert "Pyav backend for VideoReader API (#6598)" (#6908)

This reverts commit 2e833520.
parent dc11b1f6
......@@ -3,9 +3,7 @@ import os
import pytest
import torch
import torchvision
from torchvision import _HAS_GPU_VIDEO_DECODER
from torchvision.io import VideoReader
from torchvision.io import _HAS_GPU_VIDEO_DECODER, VideoReader
try:
import av
......@@ -31,9 +29,8 @@ class TestVideoGPUDecoder:
],
)
def test_frame_reading(self, video_file):
torchvision.set_video_backend("cuda")
full_path = os.path.join(VIDEO_DIR, video_file)
decoder = VideoReader(full_path)
decoder = VideoReader(full_path, device="cuda")
with av.open(full_path) as container:
for av_frame in container.decode(container.streams.video[0]):
av_frames = torch.tensor(av_frame.to_rgb(src_colorspace="ITU709").to_ndarray())
......@@ -57,8 +54,7 @@ class TestVideoGPUDecoder:
],
)
def test_seek_reading(self, keyframes, full_path, duration):
torchvision.set_video_backend("cuda")
decoder = VideoReader(full_path)
decoder = VideoReader(full_path, device="cuda")
time = duration / 2
decoder.seek(time, keyframes_only=keyframes)
with av.open(full_path) as container:
......@@ -83,9 +79,8 @@ class TestVideoGPUDecoder:
],
)
def test_metadata(self, video_file):
torchvision.set_video_backend("cuda")
full_path = os.path.join(VIDEO_DIR, video_file)
decoder = VideoReader(full_path)
decoder = VideoReader(full_path, device="cuda")
video_metadata = decoder.get_metadata()["video"]
with av.open(full_path) as container:
video = container.streams.video[0]
......
......@@ -53,9 +53,7 @@ test_videos = {
class TestVideoApi:
@pytest.mark.skipif(av is None, reason="PyAV unavailable")
@pytest.mark.parametrize("test_video", test_videos.keys())
@pytest.mark.parametrize("backend", ["video_reader", "pyav"])
def test_frame_reading(self, test_video, backend):
torchvision.set_video_backend(backend)
def test_frame_reading(self, test_video):
full_path = os.path.join(VIDEO_DIR, test_video)
with av.open(full_path) as av_reader:
if av_reader.streams.video:
......@@ -119,60 +117,50 @@ class TestVideoApi:
@pytest.mark.parametrize("stream", ["video", "audio"])
@pytest.mark.parametrize("test_video", test_videos.keys())
@pytest.mark.parametrize("backend", ["video_reader", "pyav"])
def test_frame_reading_mem_vs_file(self, test_video, stream, backend):
torchvision.set_video_backend(backend)
def test_frame_reading_mem_vs_file(self, test_video, stream):
full_path = os.path.join(VIDEO_DIR, test_video)
reader = VideoReader(full_path)
reader_md = reader.get_metadata()
if stream in reader_md:
# Test video reading from file vs from memory
vr_frames, vr_frames_mem = [], []
vr_pts, vr_pts_mem = [], []
# get vr frames
video_reader = VideoReader(full_path, stream)
for vr_frame in video_reader:
vr_frames.append(vr_frame["data"])
vr_pts.append(vr_frame["pts"])
# get vr frames = read from memory
f = open(full_path, "rb")
fbytes = f.read()
f.close()
video_reader_from_mem = VideoReader(fbytes, stream)
for vr_frame_from_mem in video_reader_from_mem:
vr_frames_mem.append(vr_frame_from_mem["data"])
vr_pts_mem.append(vr_frame_from_mem["pts"])
# same number of frames
assert len(vr_frames) == len(vr_frames_mem)
assert len(vr_pts) == len(vr_pts_mem)
# compare the frames and ptss
for i in range(len(vr_frames)):
assert vr_pts[i] == vr_pts_mem[i]
mean_delta = torch.mean(torch.abs(vr_frames[i].float() - vr_frames_mem[i].float()))
# on average the difference is very small and caused
# by decoding (around 1%)
# TODO: assess empirically how to set this? atm it's 1%
# averaged over all frames
assert mean_delta.item() < 2.55
del vr_frames, vr_pts, vr_frames_mem, vr_pts_mem
else:
del reader, reader_md
# Test video reading from file vs from memory
vr_frames, vr_frames_mem = [], []
vr_pts, vr_pts_mem = [], []
# get vr frames
video_reader = VideoReader(full_path, stream)
for vr_frame in video_reader:
vr_frames.append(vr_frame["data"])
vr_pts.append(vr_frame["pts"])
# get vr frames = read from memory
f = open(full_path, "rb")
fbytes = f.read()
f.close()
video_reader_from_mem = VideoReader(fbytes, stream)
for vr_frame_from_mem in video_reader_from_mem:
vr_frames_mem.append(vr_frame_from_mem["data"])
vr_pts_mem.append(vr_frame_from_mem["pts"])
# same number of frames
assert len(vr_frames) == len(vr_frames_mem)
assert len(vr_pts) == len(vr_pts_mem)
# compare the frames and ptss
for i in range(len(vr_frames)):
assert vr_pts[i] == vr_pts_mem[i]
mean_delta = torch.mean(torch.abs(vr_frames[i].float() - vr_frames_mem[i].float()))
# on average the difference is very small and caused
# by decoding (around 1%)
# TODO: assess empirically how to set this? atm it's 1%
# averaged over all frames
assert mean_delta.item() < 2.55
del vr_frames, vr_pts, vr_frames_mem, vr_pts_mem
@pytest.mark.parametrize("test_video,config", test_videos.items())
@pytest.mark.parametrize("backend", ["video_reader", "pyav"])
def test_metadata(self, test_video, config, backend):
def test_metadata(self, test_video, config):
"""
Test that the metadata returned via pyav corresponds to the one returned
by the new video decoder API
"""
torchvision.set_video_backend(backend)
full_path = os.path.join(VIDEO_DIR, test_video)
reader = VideoReader(full_path, "video")
reader_md = reader.get_metadata()
......@@ -180,9 +168,7 @@ class TestVideoApi:
assert config.duration == approx(reader_md["video"]["duration"][0], abs=0.5)
@pytest.mark.parametrize("test_video", test_videos.keys())
@pytest.mark.parametrize("backend", ["video_reader", "pyav"])
def test_seek_start(self, test_video, backend):
torchvision.set_video_backend(backend)
def test_seek_start(self, test_video):
full_path = os.path.join(VIDEO_DIR, test_video)
video_reader = VideoReader(full_path, "video")
num_frames = 0
......@@ -208,9 +194,7 @@ class TestVideoApi:
assert start_num_frames == num_frames
@pytest.mark.parametrize("test_video", test_videos.keys())
@pytest.mark.parametrize("backend", ["video_reader"])
def test_accurateseek_middle(self, test_video, backend):
torchvision.set_video_backend(backend)
def test_accurateseek_middle(self, test_video):
full_path = os.path.join(VIDEO_DIR, test_video)
stream = "video"
video_reader = VideoReader(full_path, stream)
......@@ -249,9 +233,7 @@ class TestVideoApi:
@pytest.mark.skipif(av is None, reason="PyAV unavailable")
@pytest.mark.parametrize("test_video,config", test_videos.items())
@pytest.mark.parametrize("backend", ["pyav", "video_reader"])
def test_keyframe_reading(self, test_video, config, backend):
torchvision.set_video_backend(backend)
def test_keyframe_reading(self, test_video, config):
full_path = os.path.join(VIDEO_DIR, test_video)
av_reader = av.open(full_path)
......
import os
import warnings
from modulefinder import Module
import torch
from torchvision import datasets, io, models, ops, transforms, utils
from .extension import _HAS_OPS, _load_library
from .extension import _HAS_OPS
try:
from .version import __version__ # noqa: F401
except ImportError:
pass
try:
_load_library("Decoder")
_HAS_GPU_VIDEO_DECODER = True
except (ImportError, OSError, ModuleNotFoundError):
_HAS_GPU_VIDEO_DECODER = False
# Check if torchvision is being imported within the root folder
if not _HAS_OPS and os.path.dirname(os.path.realpath(__file__)) == os.path.join(
os.path.realpath(os.getcwd()), "torchvision"
......@@ -74,16 +66,11 @@ def set_video_backend(backend):
backend, please compile torchvision from source.
"""
global _video_backend
if backend not in ["pyav", "video_reader", "cuda"]:
raise ValueError("Invalid video backend '%s'. Options are 'pyav', 'video_reader' and 'cuda'" % backend)
if backend not in ["pyav", "video_reader"]:
raise ValueError("Invalid video backend '%s'. Options are 'pyav' and 'video_reader'" % backend)
if backend == "video_reader" and not io._HAS_VIDEO_OPT:
# TODO: better messages
message = "video_reader video backend is not available. Please compile torchvision from source and try again"
raise RuntimeError(message)
elif backend == "cuda" and not _HAS_GPU_VIDEO_DECODER:
# TODO: better messages
message = "cuda video backend is not available."
raise RuntimeError(message)
warnings.warn(message)
else:
_video_backend = backend
......
......@@ -4,6 +4,10 @@ import torch
from ..utils import _log_api_usage_once
try:
from ._load_gpu_decoder import _HAS_GPU_VIDEO_DECODER
except ModuleNotFoundError:
_HAS_GPU_VIDEO_DECODER = False
from ._video_opt import (
_HAS_VIDEO_OPT,
_probe_video_from_file,
......@@ -43,6 +47,7 @@ __all__ = [
"_read_video_timestamps_from_memory",
"_probe_video_from_memory",
"_HAS_VIDEO_OPT",
"_HAS_GPU_VIDEO_DECODER",
"_read_video_clip_from_memory",
"_read_video_meta_data",
"VideoMetaData",
......
from ..extension import _load_library
try:
_load_library("Decoder")
_HAS_GPU_VIDEO_DECODER = True
except (ImportError, OSError):
_HAS_GPU_VIDEO_DECODER = False
import io
import warnings
from typing import Any, Dict, Iterator, Optional
import torch
from ..utils import _log_api_usage_once
try:
from ._load_gpu_decoder import _HAS_GPU_VIDEO_DECODER
except ModuleNotFoundError:
_HAS_GPU_VIDEO_DECODER = False
from ._video_opt import _HAS_VIDEO_OPT
if _HAS_VIDEO_OPT:
......@@ -20,37 +22,11 @@ else:
return False
try:
import av
av.logging.set_level(av.logging.ERROR)
if not hasattr(av.video.frame.VideoFrame, "pict_type"):
av = ImportError(
"""\
Your version of PyAV is too old for the necessary video operations in torchvision.
If you are on Python 3.5, you will have to build from source (the conda-forge
packages are not up-to-date). See
https://github.com/mikeboers/PyAV#installation for instructions on how to
install PyAV on your system.
"""
)
except ImportError:
av = ImportError(
"""\
PyAV is not installed, and is necessary for the video operations in torchvision.
See https://github.com/mikeboers/PyAV#installation for instructions on how to
install PyAV on your system.
"""
)
class VideoReader:
"""
Fine-grained video-reading API.
Supports frame-by-frame reading of various streams from a single video
container. Much like previous video_reader API it supports the following
backends: video_reader, pyav, and cuda.
Backends can be set via `torchvision.set_video_backend` function.
container.
.. betastatus:: VideoReader class
......@@ -112,11 +88,16 @@ class VideoReader:
Default value (0) enables multithreading with codec-dependent heuristic. The performance
will depend on the version of FFMPEG codecs supported.
device (str, optional): Device to be used for decoding. Defaults to ``"cpu"``.
To use GPU decoding, pass ``device="cuda"``.
path (str, optional):
.. warning:
This parameter was deprecated in ``0.15`` and will be removed in ``0.17``.
Please use ``src`` instead.
"""
def __init__(
......@@ -124,59 +105,45 @@ class VideoReader:
src: str = "",
stream: str = "video",
num_threads: int = 0,
device: str = "cpu",
path: Optional[str] = None,
) -> None:
_log_api_usage_once(self)
from .. import get_video_backend
self.is_cuda = False
device = torch.device(device)
if device.type == "cuda":
if not _HAS_GPU_VIDEO_DECODER:
raise RuntimeError("Not compiled with GPU decoder support.")
self.is_cuda = True
self._c = torch.classes.torchvision.GPUDecoder(src, device)
return
if not _has_video_opt():
raise RuntimeError(
"Not compiled with video_reader support, "
+ "to enable video_reader support, please install "
+ "ffmpeg (version 4.2 is currently supported) and "
+ "build torchvision from source."
)
if src == "":
if path is None:
raise TypeError("src cannot be empty")
src = path
warnings.warn("path is deprecated and will be removed in 0.17. Please use src instead")
self.backend = get_video_backend()
if isinstance(src, str):
if src == "":
if path is None:
raise TypeError("src cannot be empty")
src = path
warnings.warn("path is deprecated and will be removed in 0.17. Please use src instead")
elif isinstance(src, bytes):
if self.backend in ["cuda"]:
raise RuntimeError(
"VideoReader cannot be initialized from bytes object when using cuda or pyav backend."
)
elif self.backend == "pyav":
src = io.BytesIO(src)
else:
src = torch.frombuffer(src, dtype=torch.uint8)
src = torch.frombuffer(src, dtype=torch.uint8)
if isinstance(src, str):
self._c = torch.classes.torchvision.Video(src, stream, num_threads)
elif isinstance(src, torch.Tensor):
if self.backend in ["cuda", "pyav"]:
raise RuntimeError(
"VideoReader cannot be initialized from Tensor object when using cuda or pyav backend."
)
if self.is_cuda:
raise RuntimeError("GPU VideoReader cannot be initialized from Tensor or bytes object.")
self._c = torch.classes.torchvision.Video("", "", 0)
self._c.init_from_memory(src, stream, num_threads)
else:
raise TypeError("`src` must be either string, Tensor or bytes object.")
if self.backend == "cuda":
device = torch.device("cuda")
self._c = torch.classes.torchvision.GPUDecoder(src, device)
elif self.backend == "video_reader":
if isinstance(src, str):
self._c = torch.classes.torchvision.Video(src, stream, num_threads)
elif isinstance(src, torch.Tensor):
self._c = torch.classes.torchvision.Video("", "", 0)
self._c.init_from_memory(src, stream, num_threads)
elif self.backend == "pyav":
self.container = av.open(src, metadata_errors="ignore")
# TODO: load metadata
stream_type = stream.split(":")[0]
stream_id = 0 if len(stream.split(":")) == 1 else int(stream.split(":")[1])
self.pyav_stream = {stream_type: stream_id}
self._c = self.container.decode(**self.pyav_stream)
# TODO: add extradata exception
else:
raise RuntimeError("Unknown video backend: {}".format(self.backend))
def __next__(self) -> Dict[str, Any]:
"""Decodes and returns the next frame of the current stream.
Frames are encoded as a dict with mandatory
......@@ -189,29 +156,14 @@ class VideoReader:
and corresponding timestamp (``pts``) in seconds
"""
if self.backend == "cuda":
if self.is_cuda:
frame = self._c.next()
if frame.numel() == 0:
raise StopIteration
return {"data": frame, "pts": None}
elif self.backend == "video_reader":
frame, pts = self._c.next()
else:
try:
frame = next(self._c)
pts = float(frame.pts * frame.time_base)
if "video" in self.pyav_stream:
frame = torch.tensor(frame.to_rgb().to_ndarray()).permute(2, 0, 1)
elif "audio" in self.pyav_stream:
frame = torch.tensor(frame.to_ndarray()).permute(1, 0)
else:
frame = None
except av.error.EOFError:
raise StopIteration
return {"data": frame}
frame, pts = self._c.next()
if frame.numel() == 0:
raise StopIteration
return {"data": frame, "pts": pts}
def __iter__(self) -> Iterator[Dict[str, Any]]:
......@@ -230,18 +182,7 @@ class VideoReader:
frame with the exact timestamp if it exists or
the first frame with timestamp larger than ``time_s``.
"""
if self.backend in ["cuda", "video_reader"]:
self._c.seek(time_s, keyframes_only)
else:
# handle special case as pyav doesn't catch it
if time_s < 0:
time_s = 0
temp_str = self.container.streams.get(**self.pyav_stream)[0]
offset = int(round(time_s / temp_str.time_base))
if not keyframes_only:
warnings.warn("Accurate seek is not implemented for pyav backend")
self.container.seek(offset, backward=True, any_frame=False, stream=temp_str)
self._c = self.container.decode(**self.pyav_stream)
self._c.seek(time_s, keyframes_only)
return self
def get_metadata(self) -> Dict[str, Any]:
......@@ -250,21 +191,6 @@ class VideoReader:
Returns:
(dict): dictionary containing duration and frame rate for every stream
"""
if self.backend == "pyav":
metadata = {} # type: Dict[str, Any]
for stream in self.container.streams:
if stream.type not in metadata:
if stream.type == "video":
rate_n = "fps"
else:
rate_n = "framerate"
metadata[stream.type] = {rate_n: [], "duration": []}
rate = stream.average_rate if stream.average_rate is not None else stream.sample_rate
metadata[stream.type]["duration"].append(float(stream.duration * stream.time_base))
metadata[stream.type][rate_n].append(float(rate))
return metadata
return self._c.get_metadata()
def set_current_stream(self, stream: str) -> bool:
......@@ -284,12 +210,6 @@ class VideoReader:
Returns:
(bool): True on success, False otherwise
"""
if self.backend == "cuda":
warnings.warn("GPU decoding only works with video stream.")
if self.backend == "pyav":
stream_type = stream.split(":")[0]
stream_id = 0 if len(stream.split(":")) == 1 else int(stream.split(":")[1])
self.pyav_stream = {stream_type: stream_id}
self._c = self.container.decode(**self.pyav_stream)
return True
if self.is_cuda:
print("GPU decoding only works with video stream.")
return self._c.set_current_stream(stream)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment