video.py 11.2 KB
Newer Older
1
import gc
2
import math
3
import re
4
import warnings
5
from typing import List, Tuple, Union
6

7
8
import numpy as np
import torch
Francisco Massa's avatar
Francisco Massa committed
9

10
11
from . import _video_opt
from ._video_opt import VideoMetaData
Francisco Massa's avatar
Francisco Massa committed
12
13


14
15
# `av` ends up bound either to the PyAV module or to an ImportError instance
# describing why PyAV cannot be used; `_check_av_available` raises that stored
# error lazily, so importing this module never fails because of PyAV.
try:
    import av

    av.logging.set_level(av.logging.ERROR)
    # older PyAV releases lack VideoFrame.pict_type, which write_video sets
    if not hasattr(av.video.frame.VideoFrame, "pict_type"):
        av = ImportError(
            """\
Your version of PyAV is too old for the necessary video operations in torchvision.
If you are on Python 3.5, you will have to build from source (the conda-forge
packages are not up-to-date).  See
https://github.com/mikeboers/PyAV#installation for instructions on how to
install PyAV on your system.
"""
        )
except ImportError:
    av = ImportError(
        """\
PyAV is not installed, and is necessary for the video operations in torchvision.
See https://github.com/mikeboers/PyAV#installation for instructions on how to
install PyAV on your system.
"""
    )
36
37


38
39
40
41
42
43
44
45
46
def _check_av_available():
    """Raise the stored import failure when PyAV is unusable.

    The module-level ``av`` name holds an exception instance instead of the
    PyAV module when the import (or its version check) failed; in that case
    the exception is raised here. Otherwise this is a no-op.
    """
    if not isinstance(av, Exception):
        return
    raise av


def _av_available():
    """Return whether PyAV was imported successfully (``av`` is not an exception)."""
    import_failed = isinstance(av, Exception)
    return not import_failed


47
48
# PyAV has some reference cycles
_CALLED_TIMES = 0
49
_GC_COLLECTION_INTERVAL = 10
50
51


52
def write_video(filename, video_array, fps: Union[int, float], video_codec="libx264", options=None):
    """
    Writes a 4d tensor in [T, H, W, C] format in a video file

    Parameters
    ----------
    filename : str
        path where the video will be saved
    video_array : Tensor[T, H, W, C]
        tensor containing the individual frames, as a uint8 tensor in [T, H, W, C] format
    fps : Number
        frames per second
    video_codec : str, optional
        codec name passed to PyAV's ``add_stream``. Defaults to "libx264".
        With "libx264rgb" frames are stored as rgb24, otherwise as yuv420p.
    options : Dict, optional
        codec options assigned to the stream. Defaults to no options.

    Raises
    ------
    ImportError
        if PyAV is not available
    """
    _check_av_available()
    # move to CPU first so that tensors living on other devices can be written
    video_array = torch.as_tensor(video_array, dtype=torch.uint8).cpu().numpy()

    # PyAV does not support floating point numbers with decimal point
    # and will throw OverflowException in case this is not the case
    if isinstance(fps, float):
        fps = np.round(fps)

    container = av.open(filename, mode="w")
    try:
        stream = container.add_stream(video_codec, rate=fps)
        stream.width = video_array.shape[2]
        stream.height = video_array.shape[1]
        stream.pix_fmt = "yuv420p" if video_codec != "libx264rgb" else "rgb24"
        stream.options = options or {}

        for img in video_array:
            frame = av.VideoFrame.from_ndarray(img, format="rgb24")
            frame.pict_type = "NONE"
            for packet in stream.encode(frame):
                container.mux(packet)

        # Flush stream
        for packet in stream.encode():
            container.mux(packet)
    finally:
        # Close the file even when encoding fails so the handle is not leaked
        container.close()


95
96
97
def _read_from_stream(
    container, start_offset, end_offset, pts_unit, stream, stream_name
):
    """
    Decode the frames of one stream whose pts fall within a requested window.

    Parameters
    ----------
    container : PyAV container
        opened input container to seek in and decode from
    start_offset, end_offset : Number
        window bounds; converted from seconds to ``stream.time_base`` units
        when ``pts_unit == 'sec'``, otherwise used as raw pts
    pts_unit : str
        'sec' or 'pts' ('pts' emits a deprecation-style warning)
    stream : PyAV stream
        stream used for seeking and for the time-base conversion
    stream_name : Dict
        keyword arguments forwarded to ``container.decode`` (e.g. {"video": 0})

    Returns
    -------
    List of frames sorted by pts with ``start_offset <= pts <= end_offset``.
    If no frame has exactly ``start_offset`` as pts, the last frame before it
    is prepended. Returns ``[]`` if seeking fails.
    """
    global _CALLED_TIMES
    _CALLED_TIMES += 1
    # PyAV has some reference cycles; collect them periodically
    if _CALLED_TIMES % _GC_COLLECTION_INTERVAL == _GC_COLLECTION_INTERVAL - 1:
        gc.collect()

    if pts_unit == "sec":
        start_offset = int(math.floor(start_offset * (1 / stream.time_base)))
        if end_offset != float("inf"):
            end_offset = int(math.ceil(end_offset * (1 / stream.time_base)))
    else:
        warnings.warn(
            "The pts_unit 'pts' gives wrong results and will be removed in a "
            + "follow-up version. Please use pts_unit 'sec'."
        )

    frames = {}
    should_buffer = True
    max_buffer_size = 5
    if stream.type == "video":
        # DivX-style packed B-frames can have out-of-order pts (2 frames in a single pkt)
        # so need to buffer some extra frames to sort everything
        # properly
        extradata = stream.codec_context.extradata
        # overly complicated way of finding if `divx_packed` is set, following
        # https://github.com/FFmpeg/FFmpeg/commit/d5a21172283572af587b3d939eba0091484d3263
        if extradata and b"DivX" in extradata:
            # can't use regex directly because of some weird characters sometimes...
            pos = extradata.find(b"DivX")
            d = extradata[pos:]
            o = re.search(br"DivX(\d+)Build(\d+)(\w)", d)
            if o is None:
                o = re.search(br"DivX(\d+)b(\d+)(\w)", d)
            if o is not None:
                should_buffer = o.group(3) == b"p"

    # some files don't seek to the right location, so better be safe here
    seek_offset = max(start_offset - 1, 0)
    if should_buffer:
        # FIXME this is kind of a hack, but we will jump to the previous keyframe
        # so this will be safe
        seek_offset = max(seek_offset - max_buffer_size, 0)
    try:
        # TODO check if stream needs to always be the video stream here or not
        container.seek(seek_offset, any_frame=False, backward=True, stream=stream)
    except av.AVError:
        # TODO add some warnings in this case
        # print("Corrupted file?", container.name)
        return []

    buffer_count = 0
    try:
        for frame in container.decode(**stream_name):
            # A frame without pts cannot be placed in the window, and a None
            # key would break the comparisons and sorting below; skip it, as
            # read_video_timestamps also discards frames whose pts is None.
            if frame.pts is None:
                continue
            frames[frame.pts] = frame
            if frame.pts >= end_offset:
                # keep decoding a few extra frames so out-of-order pts get sorted in
                if should_buffer and buffer_count < max_buffer_size:
                    buffer_count += 1
                    continue
                break
    except av.AVError:
        # TODO add a warning
        pass

    # ensure that the results are sorted wrt the pts
    result = [frames[pts] for pts in sorted(frames) if start_offset <= pts <= end_offset]
    if frames and start_offset > 0 and start_offset not in frames:
        # if there is no frame that exactly matches the pts of start_offset
        # add the last frame smaller than start_offset, to guarantee that
        # we will have all the necessary data. This is most useful for audio
        preceding_frames = [pts for pts in frames if pts < start_offset]
        if preceding_frames:
            result.insert(0, frames[max(preceding_frames)])
    return result
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185


def _align_audio_frames(aframes, audio_frames, ref_start, ref_end):
    start, end = audio_frames[0].pts, audio_frames[-1].pts
    total_aframes = aframes.shape[1]
    step_per_aframe = (end - start + 1) / total_aframes
    s_idx = 0
    e_idx = total_aframes
    if start < ref_start:
        s_idx = int((ref_start - start) / step_per_aframe)
    if end > ref_end:
        e_idx = int((ref_end - end) / step_per_aframe)
    return aframes[:, s_idx:e_idx]


186
def read_video(filename, start_pts=0, end_pts=None, pts_unit="pts"):
    """
    Reads a video from a file, returning both the video frames as well as
    the audio frames

    Parameters
    ----------
    filename : str
        path to the video file
    start_pts : int if pts_unit = 'pts', optional
        float / Fraction if pts_unit = 'sec', optional
        the start presentation time of the video
    end_pts : int if pts_unit = 'pts', optional
        float / Fraction if pts_unit = 'sec', optional
        the end presentation time
    pts_unit : str, optional
        unit in which start_pts and end_pts values will be interpreted, either 'pts' or 'sec'. Defaults to 'pts'.

    Returns
    -------
    vframes : Tensor[T, H, W, C]
        the `T` video frames
    aframes : Tensor[K, L]
        the audio frames, where `K` is the number of channels and `L` is the
        number of points
    info : Dict
        metadata for the video and audio. Can contain the fields video_fps (float)
        and audio_fps (int)

    Raises
    ------
    ValueError
        if ``end_pts`` is smaller than ``start_pts``
    ImportError
        if the "pyav" backend is selected but PyAV is unavailable
    """
    from torchvision import get_video_backend

    if get_video_backend() != "pyav":
        return _video_opt._read_video(filename, start_pts, end_pts, pts_unit)

    _check_av_available()

    if end_pts is None:
        end_pts = float("inf")

    if end_pts < start_pts:
        raise ValueError(
            "end_pts should be larger than start_pts, got "
            "start_pts={} and end_pts={}".format(start_pts, end_pts)
        )

    info = {}
    video_frames = []
    audio_frames = []
    audio_timebase = None

    try:
        container = av.open(filename, metadata_errors="ignore")
    except av.AVError:
        # TODO raise a warning?
        pass
    else:
        try:
            if container.streams.video:
                video_frames = _read_from_stream(
                    container,
                    start_pts,
                    end_pts,
                    pts_unit,
                    container.streams.video[0],
                    {"video": 0},
                )
                video_fps = container.streams.video[0].average_rate
                # guard against potentially corrupted files
                if video_fps is not None:
                    info["video_fps"] = float(video_fps)

            if container.streams.audio:
                audio_timebase = container.streams.audio[0].time_base
                audio_frames = _read_from_stream(
                    container,
                    start_pts,
                    end_pts,
                    pts_unit,
                    container.streams.audio[0],
                    {"audio": 0},
                )
                info["audio_fps"] = container.streams.audio[0].rate
        finally:
            # release the file handle even if decoding raised
            container.close()

    vframes = [frame.to_rgb().to_ndarray() for frame in video_frames]
    aframes = [frame.to_ndarray() for frame in audio_frames]

    if vframes:
        vframes = torch.as_tensor(np.stack(vframes))
    else:
        vframes = torch.empty((0, 1, 1, 3), dtype=torch.uint8)

    if aframes:
        aframes = np.concatenate(aframes, 1)
        aframes = torch.as_tensor(aframes)
        ref_start, ref_end = start_pts, end_pts
        if pts_unit == "sec":
            # Audio frame pts are in the audio stream's time base, so convert
            # the requested window from seconds before aligning (same
            # conversion that _read_from_stream applies to its offsets).
            ref_start = int(math.floor(start_pts * (1 / audio_timebase)))
            if end_pts != float("inf"):
                ref_end = int(math.ceil(end_pts * (1 / audio_timebase)))
        aframes = _align_audio_frames(aframes, audio_frames, ref_start, ref_end)
    else:
        aframes = torch.empty((1, 0), dtype=torch.float32)

    return vframes, aframes, info


287
288
289
290
291
292
293
294
295
def _can_read_timestamps_from_packets(container):
    extradata = container.streams[0].codec_context.extradata
    if extradata is None:
        return False
    if b"Lavc" in extradata:
        return True
    return False


296
def read_video_timestamps(filename, pts_unit="pts"):
    """
    List the video frames timestamps.

    Note that the function decodes the whole video frame-by-frame.

    Parameters
    ----------
    filename : str
        path to the video file
    pts_unit : str, optional
        unit in which timestamp values will be returned either 'pts' or 'sec'. Defaults to 'pts'.

    Returns
    -------
    pts : List[int] if pts_unit = 'pts'
        List[Fraction] if pts_unit = 'sec'
        presentation timestamps for each one of the frames in the video.
    video_fps : float or None
        the frame rate for the video; None if the file could not be opened,
        has no video stream, or does not report an average frame rate
    """
    from torchvision import get_video_backend

    if get_video_backend() != "pyav":
        return _video_opt._read_video_timestamps(filename, pts_unit)

    _check_av_available()

    video_fps = None
    video_time_base = None
    pts = []

    try:
        container = av.open(filename, metadata_errors="ignore")
    except av.AVError:
        # TODO add a warning
        pass
    else:
        try:
            if container.streams.video:
                video_stream = container.streams.video[0]
                video_time_base = video_stream.time_base
                try:
                    if _can_read_timestamps_from_packets(container):
                        # fast path
                        pts = [x.pts for x in container.demux(video=0) if x.pts is not None]
                    else:
                        pts = [
                            x.pts for x in container.decode(video=0) if x.pts is not None
                        ]
                except av.AVError:
                    warnings.warn(f"Failed decoding frames for file {filename}")
                # guard against potentially corrupted files
                if video_stream.average_rate is not None:
                    video_fps = float(video_stream.average_rate)
        finally:
            # release the file handle even if reading raised
            container.close()

    pts.sort()

    if pts_unit == "sec":
        # pts is only non-empty when a video stream (and its time base) was found
        pts = [x * video_time_base for x in pts]

    return pts, video_fps