import gc
import math
import re
import warnings
from typing import Tuple, List

import numpy as np
import torch

from . import _video_opt
from ._video_opt import VideoMetaData


try:
    import av

    av.logging.set_level(av.logging.ERROR)

    if not hasattr(av.video.frame.VideoFrame, "pict_type"):
        av = ImportError(
            """\
Your version of PyAV is too old for the necessary video operations in torchvision.
If you are on Python 3.5, you will have to build from source (the conda-forge
packages are not up-to-date).  See
https://github.com/mikeboers/PyAV#installation for instructions on how to
install PyAV on your system.
"""
        )
except ImportError:
    av = ImportError(
        """\
PyAV is not installed, and is necessary for the video operations in torchvision.
See https://github.com/mikeboers/PyAV#installation for instructions on how to
install PyAV on your system.
"""
    )


def _check_av_available():
    if isinstance(av, Exception):
        raise av


def _av_available():
    return not isinstance(av, Exception)


# PyAV has some reference cycles
_CALLED_TIMES = 0
_GC_COLLECTION_INTERVAL = 10


def write_video(filename, video_array, fps, video_codec="libx264", options=None):
    """
    Writes a 4d tensor in [T, H, W, C] format to a video file

    Parameters
    ----------
    filename : str
        path where the video will be saved
    video_array : Tensor[T, H, W, C]
        tensor containing the individual frames, as a uint8 tensor in [T, H, W, C] format
    fps : Number
        frames per second
    video_codec : str
        the name of the video codec, e.g. "libx264" or "libx264rgb"
    options : Dict
        dictionary containing options to be passed into the PyAV video stream
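
    Examples
    --------
    A minimal sketch of typical usage (the file name and the random clip
    below are illustrative, not part of the API):

    >>> clip = torch.randint(0, 256, (16, 240, 320, 3), dtype=torch.uint8)
    >>> write_video("out.mp4", clip, fps=30)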
    """
    _check_av_available()
    video_array = torch.as_tensor(video_array, dtype=torch.uint8).numpy()

    container = av.open(filename, mode="w")

    stream = container.add_stream(video_codec, rate=fps)
    stream.width = video_array.shape[2]
    stream.height = video_array.shape[1]
    stream.pix_fmt = "yuv420p" if video_codec != "libx264rgb" else "rgb24"
    stream.options = options or {}

    for img in video_array:
        frame = av.VideoFrame.from_ndarray(img, format="rgb24")
        frame.pict_type = "NONE"
        for packet in stream.encode(frame):
            container.mux(packet)

    # Flush stream
    for packet in stream.encode():
        container.mux(packet)

    # Close the file
    container.close()


def _read_from_stream(
    container, start_offset, end_offset, pts_unit, stream, stream_name
):
    global _CALLED_TIMES, _GC_COLLECTION_INTERVAL
    _CALLED_TIMES += 1
    if _CALLED_TIMES % _GC_COLLECTION_INTERVAL == _GC_COLLECTION_INTERVAL - 1:
        gc.collect()
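    # with an interval of 10, the collection above runs on calls 9, 19, 29, ...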

    if pts_unit == "sec":
        start_offset = int(math.floor(start_offset * (1 / stream.time_base)))
        if end_offset != float("inf"):
            end_offset = int(math.ceil(end_offset * (1 / stream.time_base)))
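        # e.g. with a stream time_base of 1/15360 (illustrative), a
        # start_offset of 2.0 seconds becomes pts int(2.0 * 15360) == 30720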
    else:
        warnings.warn(
            "The pts_unit 'pts' gives wrong results and will be removed in a "
            + "follow-up version. Please use pts_unit 'sec'."
        )

    frames = {}
    should_buffer = False
    max_buffer_size = 5
    if stream.type == "video":
        # DivX-style packed B-frames can have out-of-order pts (two frames
        # in a single packet), so we need to buffer some extra frames to
        # sort everything properly
        extradata = stream.codec_context.extradata
        # overly complicated way of finding if `divx_packed` is set, following
        # https://github.com/FFmpeg/FFmpeg/commit/d5a21172283572af587b3d939eba0091484d3263
        if extradata and b"DivX" in extradata:
            # can't use regex directly because of some weird characters sometimes...
            pos = extradata.find(b"DivX")
            d = extradata[pos:]
            o = re.search(br"DivX(\d+)Build(\d+)(\w)", d)
            if o is None:
                o = re.search(br"DivX(\d+)b(\d+)(\w)", d)
            if o is not None:
                should_buffer = o.group(3) == b"p"
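                # e.g. extradata containing b"DivX503Build853p" (an
                # illustrative value) matches with group(3) == b"p",
                # which marks packed B-frames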
    seek_offset = start_offset
    # some files don't seek to the right location, so it's better to be safe here
    seek_offset = max(seek_offset - 1, 0)
    if should_buffer:
        # FIXME this is kind of a hack, but we will jump to the previous keyframe
        # so this will be safe
        seek_offset = max(seek_offset - max_buffer_size, 0)
    try:
        # TODO check if stream needs to always be the video stream here or not
        container.seek(seek_offset, any_frame=False, backward=True, stream=stream)
    except av.AVError:
        # TODO add some warnings in this case
        # print("Corrupted file?", container.name)
        return []
    buffer_count = 0
    try:
        for _idx, frame in enumerate(container.decode(**stream_name)):
            frames[frame.pts] = frame
            if frame.pts >= end_offset:
                if should_buffer and buffer_count < max_buffer_size:
                    buffer_count += 1
                    continue
                break
    except av.AVError:
        # TODO add a warning
        pass
    # ensure that the results are sorted wrt the pts
    result = [
        frames[i] for i in sorted(frames) if start_offset <= frames[i].pts <= end_offset
    ]
    if len(frames) > 0 and start_offset > 0 and start_offset not in frames:
        # if there is no frame that exactly matches the pts of start_offset,
        # add the last frame smaller than start_offset, to guarantee that
        # we will have all the necessary data. This is most useful for audio
        preceding_frames = [i for i in frames if i < start_offset]
        if len(preceding_frames) > 0:
            first_frame_pts = max(preceding_frames)
            result.insert(0, frames[first_frame_pts])
    return result


def _align_audio_frames(aframes, audio_frames, ref_start, ref_end):
    # trim the decoded audio samples so that they cover [ref_start, ref_end]
    # in pts units, since decoding always starts and ends on frame boundaries
    start, end = audio_frames[0].pts, audio_frames[-1].pts
    total_aframes = aframes.shape[1]
    step_per_aframe = (end - start + 1) / total_aframes
    s_idx = 0
    e_idx = total_aframes
    if start < ref_start:
        s_idx = int((ref_start - start) / step_per_aframe)
    if end > ref_end:
        # (ref_end - end) is negative, so the negative index trims samples
        # from the end of the array
        e_idx = int((ref_end - end) / step_per_aframe)
    return aframes[:, s_idx:e_idx]
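# Worked example for _align_audio_frames (illustrative numbers): if the
# decoded audio spans pts [0, 1023] with 1024 samples (step_per_aframe == 1)
# and the requested window is [100, 899], then s_idx == 100 and
# e_idx == -124, so the slice keeps samples 100:900.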


def read_video(filename, start_pts=0, end_pts=None, pts_unit="pts"):
    """
    Reads a video from a file, returning both the video frames and the
    audio frames

    Parameters
    ----------
    filename : str
        path to the video file
    start_pts : int if pts_unit = 'pts', float / Fraction if pts_unit = 'sec', optional
        the start presentation time of the video
    end_pts : int if pts_unit = 'pts', float / Fraction if pts_unit = 'sec', optional
        the end presentation time
    pts_unit : str, optional
        unit in which start_pts and end_pts values will be interpreted, either 'pts' or 'sec'. Defaults to 'pts'.

    Returns
    -------
    vframes : Tensor[T, H, W, C]
        the `T` video frames
    aframes : Tensor[K, L]
        the audio frames, where `K` is the number of channels and `L` is the
        number of points
    info : Dict
        metadata for the video and audio. Can contain the fields video_fps (float)
        and audio_fps (int)
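
    Examples
    --------
    A minimal sketch of typical usage ("video.mp4" below is an illustrative
    path, not part of the API):

    >>> vframes, aframes, info = read_video("video.mp4", pts_unit="sec")
    >>> # vframes is a uint8 Tensor[T, H, W, C], aframes is Tensor[K, L],
    >>> # and info may contain "video_fps" and "audio_fps"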
    """
    from torchvision import get_video_backend

    if get_video_backend() != "pyav":
        return _video_opt._read_video(filename, start_pts, end_pts, pts_unit)

    _check_av_available()

    if end_pts is None:
        end_pts = float("inf")

    if end_pts < start_pts:
        raise ValueError(
            "end_pts should be larger than start_pts, got "
            "start_pts={} and end_pts={}".format(start_pts, end_pts)
        )

    info = {}
    video_frames = []
    audio_frames = []

    try:
        container = av.open(filename, metadata_errors="ignore")
    except av.AVError:
        # TODO raise a warning?
        pass
    else:
        if container.streams.video:
            video_frames = _read_from_stream(
                container,
                start_pts,
                end_pts,
                pts_unit,
                container.streams.video[0],
                {"video": 0},
            )
            video_fps = container.streams.video[0].average_rate
            # guard against potentially corrupted files
            if video_fps is not None:
                info["video_fps"] = float(video_fps)

        if container.streams.audio:
            audio_frames = _read_from_stream(
                container,
                start_pts,
                end_pts,
                pts_unit,
                container.streams.audio[0],
                {"audio": 0},
            )
            info["audio_fps"] = container.streams.audio[0].rate

        container.close()

    vframes = [frame.to_rgb().to_ndarray() for frame in video_frames]
    aframes = [frame.to_ndarray() for frame in audio_frames]

    if vframes:
        vframes = torch.as_tensor(np.stack(vframes))
    else:
        vframes = torch.empty((0, 1, 1, 3), dtype=torch.uint8)

    if aframes:
        aframes = np.concatenate(aframes, 1)
        aframes = torch.as_tensor(aframes)
        aframes = _align_audio_frames(aframes, audio_frames, start_pts, end_pts)
    else:
        aframes = torch.empty((1, 0), dtype=torch.float32)

    return vframes, aframes, info


def _can_read_timestamps_from_packets(container):
    # heuristic: streams encoded by libavcodec carry a "Lavc" signature in
    # their extradata; for those, the packet pts values are reliable, so
    # timestamps can be read by demuxing without decoding every frame
    extradata = container.streams[0].codec_context.extradata
    if extradata is None:
        return False
    if b"Lavc" in extradata:
        return True
    return False


def read_video_timestamps(filename, pts_unit="pts"):
    """
    Lists the presentation timestamps of the video frames.

    Note that the function decodes the whole video frame-by-frame, unless the
    timestamps can be read directly from the packets.

    Parameters
    ----------
    filename : str
        path to the video file
    pts_unit : str, optional
        unit in which timestamp values will be returned, either 'pts' or 'sec'. Defaults to 'pts'.

    Returns
    -------
    pts : List[int] if pts_unit = 'pts', List[Fraction] if pts_unit = 'sec'
        presentation timestamps for each one of the frames in the video.
    video_fps : float
        the frame rate for the video
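
    Examples
    --------
    A minimal sketch ("video.mp4" below is an illustrative path, not part
    of the API):

    >>> pts, fps = read_video_timestamps("video.mp4", pts_unit="sec")
    >>> # pts[i] is the presentation time of frame i, in seconds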

    """
    from torchvision import get_video_backend

    if get_video_backend() != "pyav":
        return _video_opt._read_video_timestamps(filename, pts_unit)

    _check_av_available()

    video_frames = []
    video_fps = None

    try:
        container = av.open(filename, metadata_errors="ignore")
    except av.AVError:
        # TODO add a warning
        pass
    else:
        if container.streams.video:
            video_stream = container.streams.video[0]
            video_time_base = video_stream.time_base
            if _can_read_timestamps_from_packets(container):
                # fast path
                video_frames = [
                    x for x in container.demux(video=0) if x.pts is not None
                ]
            else:
                video_frames = _read_from_stream(
                    container, 0, float("inf"), pts_unit, video_stream, {"video": 0}
                )
            video_fps = float(video_stream.average_rate)
        container.close()

    pts = [x.pts for x in video_frames]

    if pts_unit == "sec":
        pts = [x * video_time_base for x in pts]

    return pts, video_fps


def read_video_meta_data_from_memory(video_data):
    # type: (torch.Tensor) -> VideoMetaData
    return _video_opt._probe_video_from_memory(video_data)


def read_video_from_memory(
    video_data,  # type: torch.Tensor
    seek_frame_margin=0.25,  # type: float
    read_video_stream=1,  # type: int
    video_width=0,  # type: int
    video_height=0,  # type: int
    video_min_dimension=0,  # type: int
    video_pts_range=(0, -1),  # type: List[int]
    video_timebase_numerator=0,  # type: int
    video_timebase_denominator=1,  # type: int
    read_audio_stream=1,  # type: int
    audio_samples=0,  # type: int
    audio_channels=0,  # type: int
    audio_pts_range=(0, -1),  # type: List[int]
    audio_timebase_numerator=0,  # type: int
    audio_timebase_denominator=1,  # type: int
):
    # type: (...) -> Tuple[torch.Tensor, torch.Tensor]
    return _video_opt._read_video_from_memory(
        video_data,
        seek_frame_margin,
        read_video_stream,
        video_width,
        video_height,
        video_min_dimension,
        video_pts_range,
        video_timebase_numerator,
        video_timebase_denominator,
        read_audio_stream,
        audio_samples,
        audio_channels,
        audio_pts_range,
        audio_timebase_numerator,
        audio_timebase_denominator,
    )
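

# A minimal usage sketch for read_video_from_memory, assuming the
# "video_reader" backend is available ("video.mp4" is an illustrative path):
#
#   import numpy as np
#   import torch
#
#   data = torch.from_numpy(np.fromfile("video.mp4", dtype=np.uint8))
#   vframes, aframes = read_video_from_memory(data)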