import io
import warnings

from typing import Any, Dict, Iterator, Optional

import torch

from ..utils import _log_api_usage_once

from ._video_opt import _HAS_VIDEO_OPT

if not _HAS_VIDEO_OPT:

    def _has_video_opt() -> bool:
        """Report that ``_HAS_VIDEO_OPT`` was falsy when this module was imported."""
        return False

else:

    def _has_video_opt() -> bool:
        """Report that ``_HAS_VIDEO_OPT`` was truthy when this module was imported."""
        return True


# Optional PyAV dependency. On failure ``av`` is rebound to an ImportError
# *instance* instead of raising, so this module still imports; any later use
# of ``av`` attributes (e.g. by the "pyav" backend) will then fail at call time.
try:
    import av

    # Restrict PyAV logging to the ERROR level.
    av.logging.set_level(av.logging.ERROR)
    # ``VideoFrame.pict_type`` is used as the marker of a sufficiently recent
    # PyAV; older versions are treated like a missing installation.
    if not hasattr(av.video.frame.VideoFrame, "pict_type"):
        av = ImportError(
            """\
Your version of PyAV is too old for the necessary video operations in torchvision.
If you are on Python 3.5, you will have to build from source (the conda-forge
packages are not up-to-date).  See
https://github.com/mikeboers/PyAV#installation for instructions on how to
install PyAV on your system.
"""
        )
except ImportError:
    # PyAV is not importable at all; store the error rather than raising it.
    av = ImportError(
        """\
PyAV is not installed, and is necessary for the video operations in torchvision.
See https://github.com/mikeboers/PyAV#installation for instructions on how to
install PyAV on your system.
"""
    )


class VideoReader:
    """
    Fine-grained video-reading API.
    Supports frame-by-frame reading of various streams from a single video
    container. Much like previous video_reader API it supports the following
    backends: video_reader, pyav, and cuda.
    Backends can be set via `torchvision.set_video_backend` function.

    .. betastatus:: VideoReader class

    Example:
        The following example creates a :mod:`VideoReader` object, seeks to the 2s
        point, and returns a single frame::

            import torchvision
            video_path = "path_to_a_test_video"
            reader = torchvision.io.VideoReader(video_path, "video")
            reader.seek(2.0)
            frame = next(reader)

        :mod:`VideoReader` implements the iterable API, which makes it suitable for
        use in conjunction with :mod:`itertools` for more advanced reading.
        As such, we can use a :mod:`VideoReader` instance inside for loops::

            frames = []
            reader.seek(2)
            for frame in reader:
                frames.append(frame['data'])
            # additionally, `seek` implements a fluent API, so we can do
            for frame in reader.seek(2):
                frames.append(frame['data'])

        With :mod:`itertools`, we can read all frames between 2 and 5 seconds with the
        following code::

            for frame in itertools.takewhile(lambda x: x['pts'] <= 5, reader.seek(2)):
                frames.append(frame['data'])

        and similarly, reading 10 frames after the 2s timestamp can be achieved
        as follows::

            for frame in itertools.islice(reader.seek(2), 10):
                frames.append(frame['data'])

    .. note::

        Each stream descriptor consists of two parts: stream type (e.g. 'video') and
        a unique stream id (which are determined by the video encoding).
        In this way, if the video container contains multiple
        streams of the same type, users can access the one they want.
        If only stream type is passed, the decoder auto-detects first stream of that type.

    Args:
        src (string, bytes object, or tensor): The media source.
            If string-type, it must be a file path supported by FFMPEG.
            If bytes, it should be an in-memory representation of a file supported by FFMPEG.
            Bytes are not supported by the ``cuda`` backend.
            If Tensor, it is interpreted internally as byte buffer.
            It must be one-dimensional, of type ``torch.uint8``.
            Tensors are only supported by the ``video_reader`` backend.

        stream (string, optional): descriptor of the required stream, followed by the stream id,
            in the format ``{stream_type}:{stream_id}``. Defaults to ``"video"``
            (the first video stream).
            Currently available options include ``['video', 'audio']``

        num_threads (int, optional): number of threads used by the codec to decode video.
            Default value (0) enables multithreading with codec-dependent heuristic. The performance
            will depend on the version of FFMPEG codecs supported.
            Only used by the ``video_reader`` backend.

        path (str, optional):
            .. warning::
                This parameter was deprecated in ``0.15`` and will be removed in ``0.17``.
                Please use ``src`` instead.

    Raises:
        TypeError: if ``src`` is empty and no ``path`` is given, or ``src`` has an unsupported type.
        RuntimeError: if ``src`` type is not supported by the active backend, or the backend is unknown.
    """

    def __init__(
        self,
        src: str = "",
        stream: str = "video",
        num_threads: int = 0,
        path: Optional[str] = None,
    ) -> None:
        _log_api_usage_once(self)
        # Imported lazily, presumably to avoid a circular import with the package root.
        from .. import get_video_backend

        self.backend = get_video_backend()
        if isinstance(src, str):
            if src == "":
                if path is None:
                    raise TypeError("src cannot be empty")
                src = path
                warnings.warn("path is deprecated and will be removed in 0.17. Please use src instead")
        elif isinstance(src, bytes):
            if self.backend == "cuda":
                raise RuntimeError("VideoReader cannot be initialized from bytes object when using cuda backend.")
            elif self.backend == "pyav":
                # PyAV opens file-like objects, so wrap the raw bytes.
                src = io.BytesIO(src)
            else:
                with warnings.catch_warnings():
                    # Ignore the warning because we actually don't modify the buffer in this function
                    warnings.filterwarnings("ignore", message="The given buffer is not writable")
                    src = torch.frombuffer(src, dtype=torch.uint8)
        elif isinstance(src, torch.Tensor):
            if self.backend in ["cuda", "pyav"]:
                raise RuntimeError(
                    "VideoReader cannot be initialized from Tensor object when using cuda or pyav backend."
                )
        else:
            raise TypeError("`src` must be either string, Tensor or bytes object.")

        if self.backend == "cuda":
            device = torch.device("cuda")
            self._c = torch.classes.torchvision.GPUDecoder(src, device)

        elif self.backend == "video_reader":
            if isinstance(src, str):
                self._c = torch.classes.torchvision.Video(src, stream, num_threads)
            elif isinstance(src, torch.Tensor):
                self._c = torch.classes.torchvision.Video("", "", 0)
                self._c.init_from_memory(src, stream, num_threads)

        elif self.backend == "pyav":
            self.container = av.open(src, metadata_errors="ignore")
            # TODO: load metadata
            self.pyav_stream = self._parse_pyav_stream(stream)
            self._c = self.container.decode(**self.pyav_stream)

            # TODO: add extradata exception

        else:
            raise RuntimeError("Unknown video backend: {}".format(self.backend))

    def __next__(self) -> Dict[str, Any]:
        """Decodes and returns the next frame of the current stream.
        Frames are encoded as a dict with mandatory
        data and pts fields, where data is a tensor, and pts is a
        presentation timestamp of the frame expressed in seconds
        as a float.

        Returns:
            (dict): a dictionary containing the decoded frame (``data``)
            and corresponding timestamp (``pts``) in seconds. With the ``cuda``
            backend ``pts`` is ``None``; with the ``pyav`` backend ``data`` is
            ``None`` for stream types other than video and audio.

        Raises:
            StopIteration: when the current stream is exhausted.
        """
        if self.backend == "cuda":
            frame = self._c.next()
            if frame.numel() == 0:
                raise StopIteration
            return {"data": frame, "pts": None}
        elif self.backend == "video_reader":
            frame, pts = self._c.next()
        else:
            try:
                frame = next(self._c)
                pts = float(frame.pts * frame.time_base)
                if "video" in self.pyav_stream:
                    # HWC RGB array -> CHW tensor.
                    frame = torch.tensor(frame.to_rgb().to_ndarray()).permute(2, 0, 1)
                elif "audio" in self.pyav_stream:
                    frame = torch.tensor(frame.to_ndarray()).permute(1, 0)
                else:
                    frame = None
            except av.error.EOFError:
                raise StopIteration

        # An empty tensor ends iteration; frames without tensor data (``None``)
        # are passed through instead of crashing on ``None.numel()``.
        if frame is not None and frame.numel() == 0:
            raise StopIteration
        return {"data": frame, "pts": pts}

    def __iter__(self) -> Iterator[Dict[str, Any]]:
        return self

    def seek(self, time_s: float, keyframes_only: bool = False) -> "VideoReader":
        """Seek within current stream.

        Args:
            time_s (float): seek time in seconds
            keyframes_only (bool): allow to seek only to keyframes

        Returns:
            VideoReader: ``self``, so calls can be chained (fluent API).

        .. note::
            Current implementation is the so-called precise seek. This
            means following seek, call to :mod:`next()` will return the
            frame with the exact timestamp if it exists or
            the first frame with timestamp larger than ``time_s``.
            The ``pyav`` backend does not implement precise seek: it seeks
            backwards to a keyframe and warns when ``keyframes_only`` is False.
        """
        if self.backend in ["cuda", "video_reader"]:
            self._c.seek(time_s, keyframes_only)
        else:
            # handle special case as pyav doesn't catch it
            if time_s < 0:
                time_s = 0
            temp_str = self.container.streams.get(**self.pyav_stream)[0]
            # Convert seconds to the stream's time_base units.
            offset = int(round(time_s / temp_str.time_base))
            if not keyframes_only:
                warnings.warn("Accurate seek is not implemented for pyav backend")
            self.container.seek(offset, backward=True, any_frame=False, stream=temp_str)
            # Restart decoding from the new container position.
            self._c = self.container.decode(**self.pyav_stream)
        return self

    def get_metadata(self) -> Dict[str, Any]:
        """Returns video metadata

        Returns:
            (dict): dictionary containing duration and frame rate for every stream.
            With the ``pyav`` backend it maps each stream type to
            ``{"fps" | "framerate": [...], "duration": [...]}`` with one entry per
            stream of that type (``"fps"`` for video, ``"framerate"`` otherwise);
            values that are unavailable for a stream are reported as ``nan``.
        """
        if self.backend == "pyav":
            metadata = {}  # type:  Dict[str, Any]
            for stream in self.container.streams:
                # Choose the rate key for *this* stream on every iteration; reusing the
                # key of a previous stream of another type raised KeyError.
                rate_n = "fps" if stream.type == "video" else "framerate"
                if stream.type not in metadata:
                    metadata[stream.type] = {rate_n: [], "duration": []}

                # Look the rate attributes up defensively: not every stream object is
                # guaranteed to expose both ``average_rate`` and ``sample_rate``.
                rate = getattr(stream, "average_rate", None)
                if rate is None:
                    rate = getattr(stream, "sample_rate", None)

                if stream.duration is not None and stream.time_base is not None:
                    duration = float(stream.duration * stream.time_base)
                else:
                    # Duration unknown for this stream; NaN keeps the lists aligned.
                    duration = float("nan")

                metadata[stream.type]["duration"].append(duration)
                metadata[stream.type][rate_n].append(float(rate) if rate is not None else float("nan"))
            return metadata
        return self._c.get_metadata()

    def set_current_stream(self, stream: str) -> bool:
        """Set current stream.
        Explicitly define the stream we are operating on.

        Args:
            stream (string): descriptor of the required stream. Defaults to ``"video:0"``
                Currently available stream types include ``['video', 'audio']``.
                Each descriptor consists of two parts: stream type (e.g. 'video') and
                a unique stream id (which are determined by video encoding).
                In this way, if the video container contains multiple
                streams of the same type, users can access the one they want.
                If only stream type is passed, the decoder auto-detects first stream
                of that type and returns it.

        Returns:
            (bool): True on success, False otherwise
        """
        if self.backend == "cuda":
            warnings.warn("GPU decoding only works with video stream.")
        if self.backend == "pyav":
            self.pyav_stream = self._parse_pyav_stream(stream)
            self._c = self.container.decode(**self.pyav_stream)
            return True
        return self._c.set_current_stream(stream)

    @staticmethod
    def _parse_pyav_stream(stream: str) -> Dict[str, int]:
        """Convert a ``"{type}[:{id}]"`` descriptor into PyAV ``decode`` kwargs.

        The id defaults to 0 when omitted: ``"audio"`` -> ``{"audio": 0}``,
        ``"video:1"`` -> ``{"video": 1}``.
        """
        parts = stream.split(":")
        stream_id = int(parts[1]) if len(parts) > 1 else 0
        return {parts[0]: stream_id}