"...DensePose/densepose/evaluation/d2_evaluator_adapter.py" did not exist on "5b3792fc3ef9ab6a6f8f30634ab2e52fb0941af3"
video.py 15 KB
Newer Older
1
import gc
2
import math
3
import os
4
import re
5
import warnings
6
from fractions import Fraction
7
from typing import Any, Dict, List, Optional, Tuple, Union
8

9
10
import numpy as np
import torch
Francisco Massa's avatar
Francisco Massa committed
11

Kai Zhang's avatar
Kai Zhang committed
12
from ..utils import _log_api_usage_once
13
from . import _video_opt
Francisco Massa's avatar
Francisco Massa committed
14
15


16
17
try:
    import av

    # Lower PyAV's logging verbosity to ERROR.
    av.logging.set_level(av.logging.ERROR)
    # ``write_video`` sets ``VideoFrame.pict_type``; a PyAV build without that
    # attribute is treated as unavailable by storing an ImportError in ``av``.
    if not hasattr(av.video.frame.VideoFrame, "pict_type"):
        av = ImportError(
            """\
Your version of PyAV is too old for the necessary video operations in torchvision.
If you are on Python 3.5, you will have to build from source (the conda-forge
packages are not up-to-date).  See
https://github.com/mikeboers/PyAV#installation for instructions on how to
install PyAV on your system.
"""
        )
except ImportError:
    # PyAV is optional: keep the error object in ``av`` instead of failing at
    # import time, so it can be raised lazily when a PyAV-backed function runs.
    av = ImportError(
        """\
PyAV is not installed, and is necessary for the video operations in torchvision.
See https://github.com/mikeboers/PyAV#installation for instructions on how to
install PyAV on your system.
"""
    )
38
39


40
def _check_av_available() -> None:
    """Re-raise the error recorded at import time when PyAV is unusable.

    ``av`` is either the imported PyAV module or an exception instance that
    describes why it could not be used; in the latter case it is raised here.
    """
    pending_error = av
    if not isinstance(pending_error, Exception):
        return
    raise pending_error


45
def _av_available() -> bool:
    """Tell whether PyAV was imported successfully.

    Returns:
        False when ``av`` holds a stored exception instead of the module, else True.
    """
    if isinstance(av, Exception):
        return False
    return True


49
50
# PyAV has some reference cycles. ``_read_from_stream`` counts its calls in
# ``_CALLED_TIMES`` and periodically forces ``gc.collect()``, once every
# ``_GC_COLLECTION_INTERVAL`` calls.
_GC_COLLECTION_INTERVAL = 10
_CALLED_TIMES = 0
52
53


54
55
56
57
58
59
def write_video(
    filename: str,
    video_array: torch.Tensor,
    fps: float,
    video_codec: str = "libx264",
    options: Optional[Dict[str, Any]] = None,
    audio_array: Optional[torch.Tensor] = None,
    audio_fps: Optional[float] = None,
    audio_codec: Optional[str] = None,
    audio_options: Optional[Dict[str, Any]] = None,
) -> None:
    """
    Writes a 4d tensor in [T, H, W, C] format in a video file

    Args:
        filename (str): path where the video will be saved
        video_array (Tensor[T, H, W, C]): tensor containing the individual frames,
            as a uint8 tensor in [T, H, W, C] format. It may live on any device;
            it is detached and copied to host memory before encoding.
        fps (Number): video frames per second. A float value is rounded with
            ``np.round`` because PyAV rejects rates with a fractional part here.
        video_codec (str): the name of the video codec, i.e. "libx264", "h264", etc.
        options (Dict): dictionary containing options to be passed into the PyAV video stream
        audio_array (Tensor[C, N]): tensor containing the audio, where C is the number of channels
            and N is the number of samples. It may live on any device.
        audio_fps (Number): audio sample rate, typically 44100 or 48000.
            Required when ``audio_array`` is given.
        audio_codec (str): the name of the audio codec, i.e. "mp3", "aac", etc.
            Required when ``audio_array`` is given.
        audio_options (Dict): dictionary containing options to be passed into the PyAV audio stream

    Raises:
        ImportError: if PyAV is not installed or is too old.
        ValueError: if ``audio_array`` is given without ``audio_fps`` or ``audio_codec``.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(write_video)
    _check_av_available()

    if audio_array is not None and (audio_fps is None or audio_codec is None):
        # Fail before the output file is created rather than deep inside PyAV
        # with a half-written container on disk.
        raise ValueError("audio_fps and audio_codec must be provided when audio_array is given")

    # ``Tensor.numpy()`` only accepts CPU tensors that do not require grad,
    # so detach and move to host memory first.
    video_array = torch.as_tensor(video_array, dtype=torch.uint8).detach().cpu().numpy()

    # PyAV does not support floating point numbers with decimal point
    # and will throw OverflowException in case this is not the case
    if isinstance(fps, float):
        fps = np.round(fps)

    with av.open(filename, mode="w") as container:
        stream = container.add_stream(video_codec, rate=fps)
        stream.width = video_array.shape[2]
        stream.height = video_array.shape[1]
        stream.pix_fmt = "yuv420p" if video_codec != "libx264rgb" else "rgb24"
        stream.options = options or {}

        if audio_array is not None:
            # Maps PyAV sample-format names to the numpy dtype expected by
            # ``AudioFrame.from_ndarray`` for that format.
            audio_format_dtypes = {
                "dbl": "<f8",
                "dblp": "<f8",
                "flt": "<f4",
                "fltp": "<f4",
                "s16": "<i2",
                "s16p": "<i2",
                "s32": "<i4",
                "s32p": "<i4",
                "u8": "u1",
                "u8p": "u1",
            }
            a_stream = container.add_stream(audio_codec, rate=audio_fps)
            a_stream.options = audio_options or {}

            num_channels = audio_array.shape[0]
            audio_layout = "stereo" if num_channels > 1 else "mono"
            # Read the sample format from the stream that was just added.
            audio_sample_fmt = a_stream.format.name

            format_dtype = np.dtype(audio_format_dtypes[audio_sample_fmt])
            audio_array = torch.as_tensor(audio_array).detach().cpu().numpy().astype(format_dtype)

            frame = av.AudioFrame.from_ndarray(audio_array, format=audio_sample_fmt, layout=audio_layout)

            frame.sample_rate = audio_fps

            for packet in a_stream.encode(frame):
                container.mux(packet)

            # Flush the audio encoder.
            for packet in a_stream.encode():
                container.mux(packet)

        for img in video_array:
            frame = av.VideoFrame.from_ndarray(img, format="rgb24")
            frame.pict_type = "NONE"
            for packet in stream.encode(frame):
                container.mux(packet)

        # Flush stream
        for packet in stream.encode():
            container.mux(packet)


142
def _read_from_stream(
    container: "av.container.Container",
    start_offset: float,
    end_offset: float,
    pts_unit: str,
    stream: "av.stream.Stream",
    stream_name: Dict[str, Optional[Union[int, Tuple[int, ...], List[int]]]],
) -> List["av.frame.Frame"]:
    """
    Decode the frames of ``stream`` whose pts fall within ``[start_offset, end_offset]``.

    Args:
        container: the opened PyAV container to seek in and decode from.
        start_offset: first timestamp to return. When ``pts_unit == "sec"`` it is
            converted to ``stream.time_base`` ticks (rounded down); otherwise it
            is used as a raw pts.
        end_offset: last timestamp to return, in the same unit as ``start_offset``
            (rounded up when converted). ``float("inf")`` is left unconverted.
        pts_unit: ``"sec"``; any other value is treated as raw pts and emits a warning.
        stream: stream whose ``time_base``, ``type`` and codec extradata drive the
            unit conversion, seeking and buffering decisions.
        stream_name: keyword arguments forwarded to ``container.decode``
            (callers in this file pass e.g. ``{"video": 0}``).

    Returns:
        The selected frames, sorted by pts. If ``start_offset`` (after conversion)
        is positive and no decoded frame has exactly that pts, the last decoded
        frame before it is prepended. Returns ``[]`` if seeking raises
        ``av.AVError``; an ``av.AVError`` during decoding stops decoding and the
        frames collected so far are used.
    """
    # PyAV objects can form reference cycles; force a full collection
    # periodically based on the module-level call counter.
    global _CALLED_TIMES, _GC_COLLECTION_INTERVAL
    _CALLED_TIMES += 1
    if _CALLED_TIMES % _GC_COLLECTION_INTERVAL == _GC_COLLECTION_INTERVAL - 1:
        gc.collect()

    if pts_unit == "sec":
        # TODO: we should change all of this from ground up to simply take
        # sec and convert to MS in C++
        # Widen the requested window outward: floor the start, ceil the end.
        start_offset = int(math.floor(start_offset * (1 / stream.time_base)))
        if end_offset != float("inf"):
            end_offset = int(math.ceil(end_offset * (1 / stream.time_base)))
    else:
        warnings.warn("The pts_unit 'pts' gives wrong results. Please use pts_unit 'sec'.")

    # pts -> frame; a dict so duplicates collapse and results can be re-sorted.
    frames = {}
    # Buffering stays on for every stream except DivX video whose build tag
    # is not followed by "p" (see below).
    should_buffer = True
    max_buffer_size = 5
    if stream.type == "video":
        # DivX-style packed B-frames can have out-of-order pts (2 frames in a single pkt)
        # so need to buffer some extra frames to sort everything
        # properly
        extradata = stream.codec_context.extradata
        # overly complicated way of finding if `divx_packed` is set, following
        # https://github.com/FFmpeg/FFmpeg/commit/d5a21172283572af587b3d939eba0091484d3263
        if extradata and b"DivX" in extradata:
            # can't use regex directly because of some weird characters sometimes...
            pos = extradata.find(b"DivX")
            d = extradata[pos:]
            o = re.search(rb"DivX(\d+)Build(\d+)(\w)", d)
            if o is None:
                o = re.search(rb"DivX(\d+)b(\d+)(\w)", d)
            if o is not None:
                should_buffer = o.group(3) == b"p"
    seek_offset = start_offset
    # some files don't seek to the right location, so better be safe here
    seek_offset = max(seek_offset - 1, 0)
    if should_buffer:
        # FIXME this is kind of a hack, but we will jump to the previous keyframe
        # so this will be safe
        seek_offset = max(seek_offset - max_buffer_size, 0)
    try:
        # TODO check if stream needs to always be the video stream here or not
        container.seek(seek_offset, any_frame=False, backward=True, stream=stream)
    except av.AVError:
        # TODO add some warnings in this case
        # print("Corrupted file?", container.name)
        return []
    # Number of extra frames decoded after reaching end_offset.
    buffer_count = 0
    try:
        for _idx, frame in enumerate(container.decode(**stream_name)):
            frames[frame.pts] = frame
            if frame.pts >= end_offset:
                # When buffering, keep decoding up to max_buffer_size more
                # frames: out-of-order pts may still yield frames <= end_offset.
                if should_buffer and buffer_count < max_buffer_size:
                    buffer_count += 1
                    continue
                break
    except av.AVError:
        # TODO add a warning
        pass
    # ensure that the results are sorted wrt the pts
    result = [frames[i] for i in sorted(frames) if start_offset <= frames[i].pts <= end_offset]
    if len(frames) > 0 and start_offset > 0 and start_offset not in frames:
        # if there is no frame that exactly matches the pts of start_offset
        # add the last frame smaller than start_offset, to guarantee that
        # we will have all the necessary data. This is most useful for audio
        preceding_frames = [i for i in frames if i < start_offset]
        if len(preceding_frames) > 0:
            first_frame_pts = max(preceding_frames)
            result.insert(0, frames[first_frame_pts])
    return result
220
221


222
223
224
def _align_audio_frames(
    aframes: torch.Tensor, audio_frames: List["av.frame.Frame"], ref_start: int, ref_end: float
) -> torch.Tensor:
    """
    Trim concatenated audio samples so they cover only ``[ref_start, ref_end]``.

    ``aframes`` holds the samples of ``audio_frames`` concatenated along dim 1.
    The pts range from the first to the last frame is assumed to be spread
    uniformly over those samples. That gives a pts-per-sample step, which is
    used to turn the reference bounds into sample indices.

    Args:
        aframes: audio samples with the sample axis on dim 1.
        audio_frames: the frames ``aframes`` was built from, ordered by pts.
            Only their ``pts`` attribute is read.
        ref_start: first pts to keep.
        ref_end: last pts to keep; may be ``float("inf")``.

    Returns:
        A slice of ``aframes`` along dim 1. ``aframes`` is returned unchanged
        when ``audio_frames`` is empty or there are no samples.
    """
    total_aframes = aframes.shape[1]
    if not audio_frames or total_aframes == 0:
        # Nothing to align against; avoids IndexError / ZeroDivisionError below.
        return aframes
    start, end = audio_frames[0].pts, audio_frames[-1].pts
    step_per_aframe = (end - start + 1) / total_aframes
    s_idx = 0
    e_idx = total_aframes
    if start < ref_start:
        s_idx = int((ref_start - start) / step_per_aframe)
    if end > ref_end:
        # Count the stop index from the front. A negative index would be
        # truncated by int() to 0 when the overshoot is smaller than one step,
        # which would drop every sample.
        e_idx = max(total_aframes - int((end - ref_end) / step_per_aframe), 0)
    return aframes[:, s_idx:e_idx]


237
def read_video(
    filename: str,
    start_pts: Union[float, Fraction] = 0,
    end_pts: Optional[Union[float, Fraction]] = None,
    pts_unit: str = "pts",
    output_format: str = "THWC",
) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, Any]]:
    """
    Reads a video from a file, returning both the video frames as well as
    the audio frames

    Args:
        filename (str): path to the video file
        start_pts (int if pts_unit = 'pts', float / Fraction if pts_unit = 'sec', optional):
            The start presentation time of the video
        end_pts (int if pts_unit = 'pts', float / Fraction if pts_unit = 'sec', optional):
            The end presentation time
        pts_unit (str, optional): unit in which start_pts and end_pts values will be interpreted,
            either 'pts' or 'sec'. Defaults to 'pts'.
        output_format (str, optional): The format of the output video tensors. Can be either "THWC" (default) or "TCHW".
            Applied for every video backend.

    Returns:
        vframes (Tensor[T, H, W, C] or Tensor[T, C, H, W]): the `T` video frames
        aframes (Tensor[K, L]): the audio frames, where `K` is the number of channels and `L` is the number of points
        info (Dict): metadata for the video and audio. Can contain the fields video_fps (float) and audio_fps (int)

    Raises:
        ValueError: if ``output_format`` is invalid, or (PyAV backend) if
            ``end_pts`` is smaller than ``start_pts``.
        RuntimeError: if ``filename`` does not exist.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(read_video)

    output_format = output_format.upper()
    if output_format not in ("THWC", "TCHW"):
        raise ValueError(f"output_format should be either 'THWC' or 'TCHW', got {output_format}.")

    from torchvision import get_video_backend

    if not os.path.exists(filename):
        raise RuntimeError(f"File not found: {filename}")

    if get_video_backend() != "pyav":
        # Do not return early: output_format must be honored for this backend too.
        vframes, aframes, info = _video_opt._read_video(filename, start_pts, end_pts, pts_unit)
    else:
        _check_av_available()

        if end_pts is None:
            end_pts = float("inf")

        if end_pts < start_pts:
            raise ValueError(
                f"end_pts should be larger than start_pts, got start_pts={start_pts} and end_pts={end_pts}"
            )

        info = {}
        video_frames = []
        audio_frames = []
        audio_timebase = _video_opt.default_timebase

        try:
            with av.open(filename, metadata_errors="ignore") as container:
                if container.streams.audio:
                    audio_timebase = container.streams.audio[0].time_base
                if container.streams.video:
                    video_frames = _read_from_stream(
                        container,
                        start_pts,
                        end_pts,
                        pts_unit,
                        container.streams.video[0],
                        {"video": 0},
                    )
                    video_fps = container.streams.video[0].average_rate
                    # guard against potentially corrupted files
                    if video_fps is not None:
                        info["video_fps"] = float(video_fps)

                if container.streams.audio:
                    audio_frames = _read_from_stream(
                        container,
                        start_pts,
                        end_pts,
                        pts_unit,
                        container.streams.audio[0],
                        {"audio": 0},
                    )
                    info["audio_fps"] = container.streams.audio[0].rate

        except av.AVError as e:
            # Best effort: keep whatever was read, but do not hide the failure.
            msg = f"Failed to read video file {filename}; Caught error: {e}"
            warnings.warn(msg, RuntimeWarning)

        vframes_list = [frame.to_rgb().to_ndarray() for frame in video_frames]
        aframes_list = [frame.to_ndarray() for frame in audio_frames]

        if vframes_list:
            vframes = torch.as_tensor(np.stack(vframes_list))
        else:
            vframes = torch.empty((0, 1, 1, 3), dtype=torch.uint8)

        if aframes_list:
            aframes = np.concatenate(aframes_list, 1)
            aframes = torch.as_tensor(aframes)
            if pts_unit == "sec":
                # Express the bounds in audio-stream ticks for alignment.
                start_pts = int(math.floor(start_pts * (1 / audio_timebase)))
                if end_pts != float("inf"):
                    end_pts = int(math.ceil(end_pts * (1 / audio_timebase)))
            aframes = _align_audio_frames(aframes, audio_frames, start_pts, end_pts)
        else:
            aframes = torch.empty((1, 0), dtype=torch.float32)

    if output_format == "TCHW":
        # [T,H,W,C] --> [T,C,H,W]
        vframes = vframes.permute(0, 3, 1, 2)

    return vframes, aframes, info


350
def _can_read_timestamps_from_packets(container: "av.container.Container") -> bool:
    """Return True when packet timestamps can be used without decoding.

    The check looks for the ``b"Lavc"`` marker in the codec extradata of the
    container's first stream. It returns False when there is no extradata or
    the marker is absent.
    """
    codec_extradata = container.streams[0].codec_context.extradata
    return codec_extradata is not None and b"Lavc" in codec_extradata


359
def _decode_video_timestamps(container: "av.container.Container") -> List[int]:
    """Collect the non-None pts of the first video stream, in container order.

    Packets are demuxed directly (fast path) when their timestamps are
    usable; otherwise the frames are fully decoded.
    """
    if _can_read_timestamps_from_packets(container):
        # fast path: no decoding required
        source = container.demux(video=0)
    else:
        source = container.decode(video=0)
    return [item.pts for item in source if item.pts is not None]


367
def read_video_timestamps(filename: str, pts_unit: str = "pts") -> Tuple[List[int], Optional[float]]:
    """
    List the video frames timestamps.

    Note that the function decodes the whole video frame-by-frame.

    Args:
        filename (str): path to the video file
        pts_unit (str, optional): unit in which timestamp values will be returned
            either 'pts' or 'sec'. Defaults to 'pts'.

    Returns:
        pts (List[int] if pts_unit = 'pts', List[Fraction] if pts_unit = 'sec'):
            presentation timestamps for each one of the frames in the video.
        video_fps (float, optional): the frame rate for the video; None when the
            file has no video stream or its average rate is unknown.

    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(read_video_timestamps)
    from torchvision import get_video_backend

    if get_video_backend() != "pyav":
        return _video_opt._read_video_timestamps(filename, pts_unit)

    _check_av_available()

    video_fps = None
    video_time_base = None
    pts = []

    try:
        with av.open(filename, metadata_errors="ignore") as container:
            if container.streams.video:
                video_stream = container.streams.video[0]
                video_time_base = video_stream.time_base
                try:
                    pts = _decode_video_timestamps(container)
                except av.AVError:
                    warnings.warn(f"Failed decoding frames for file {filename}")
                # guard against potentially corrupted files: float(None) would
                # raise TypeError
                if video_stream.average_rate is not None:
                    video_fps = float(video_stream.average_rate)
    except av.AVError as e:
        msg = f"Failed to open container for {filename}; Caught error: {e}"
        warnings.warn(msg, RuntimeWarning)

    pts.sort()

    if pts_unit == "sec":
        # pts is only non-empty when a video stream (and thus its time base) was found.
        pts = [x * video_time_base for x in pts]

    return pts, video_fps