_video_opt.py 20.1 KB
Newer Older
Francisco Massa's avatar
Francisco Massa committed
1
import math
2
3
import warnings
from fractions import Fraction
4
from typing import List, Tuple, Dict, Optional, Union
5

6
import torch
7

8
from .._internally_replaced_utils import _get_extension_path
9
10
11


# Try to load the optional compiled "video_reader" extension; every
# torch.ops.video_reader.* call in this module requires it.  Its absence is
# recorded in _HAS_VIDEO_OPT instead of failing at import time.
try:
    lib_path = _get_extension_path("video_reader")
    torch.ops.load_library(lib_path)
    _HAS_VIDEO_OPT = True
except (ImportError, OSError):
    _HAS_VIDEO_OPT = False
17
18
19
20

# Default (zero) timebase used when the caller does not supply one.  Note that
# Fraction(0, 1) is falsy, which _read_video relies on when picking a timebase.
default_timebase = Fraction(0, 1)


21
22
# simple class for torch scripting
# the complex Fraction class from fractions module is not scriptable
23
class Timebase:
    """Rational timebase (numerator / denominator) of a media stream.

    Minimal scriptable stand-in for :class:`fractions.Fraction`, which
    TorchScript cannot compile (see the note above this class).
    """

    __annotations__ = {"numerator": int, "denominator": int}
    __slots__ = ["numerator", "denominator"]

    def __init__(
        self,
        numerator: int,
        denominator: int,
    ) -> None:
        self.numerator = numerator
        self.denominator = denominator


36
class VideoMetaData:
    """Scriptable container for the metadata of a video's streams.

    The defaults set in ``__init__`` describe a file with no streams;
    ``_fill_info`` populates the fields from the decoder's output.
    Durations are expressed in seconds (pts ticks multiplied by the
    stream timebase).
    """

    __annotations__ = {
        "has_video": bool,
        "video_timebase": Timebase,
        "video_duration": float,
        "video_fps": float,
        "has_audio": bool,
        "audio_timebase": Timebase,
        "audio_duration": float,
        "audio_sample_rate": float,
    }
    __slots__ = [
        "has_video",
        "video_timebase",
        "video_duration",
        "video_fps",
        "has_audio",
        "audio_timebase",
        "audio_duration",
        "audio_sample_rate",
    ]

    def __init__(self) -> None:
        # "Empty" defaults: no video stream ...
        self.has_video = False
        self.video_timebase = Timebase(0, 1)
        self.video_duration = 0.0
        self.video_fps = 0.0
        # ... and no audio stream.
        self.has_audio = False
        self.audio_timebase = Timebase(0, 1)
        self.audio_duration = 0.0
        self.audio_sample_rate = 0.0


69
def _validate_pts(pts_range: Tuple[int, int]) -> None:
70

71
    if pts_range[1] > 0:
72
73
74
        assert (
            pts_range[0] <= pts_range[1]
        ), """Start pts should not be smaller than end pts, got
75
            start pts: {0:d} and end pts: {1:d}""".format(
76
77
78
            pts_range[0],
            pts_range[1],
        )
79
80


81
82
83
84
85
86
87
88
def _fill_info(
    vtimebase: torch.Tensor,
    vfps: torch.Tensor,
    vduration: torch.Tensor,
    atimebase: torch.Tensor,
    asample_rate: torch.Tensor,
    aduration: torch.Tensor,
) -> VideoMetaData:
    """Build a VideoMetaData struct from the raw tensors returned by the
    video_reader ops.

    An empty tensor means the corresponding field/stream is absent, in
    which case the VideoMetaData defaults are left untouched.
    """
    info = VideoMetaData()

    # Video stream: timebase, duration (converted from pts ticks to
    # seconds), and frame rate.
    if vtimebase.numel() > 0:
        info.video_timebase = Timebase(int(vtimebase[0].item()), int(vtimebase[1].item()))
        vtb = vtimebase[0].item() / float(vtimebase[1].item())
        if vduration.numel() > 0:
            info.has_video = True
            info.video_duration = float(vduration.item()) * vtb
    if vfps.numel() > 0:
        info.video_fps = float(vfps.item())

    # Audio stream: same treatment with the audio timebase and sample rate.
    if atimebase.numel() > 0:
        info.audio_timebase = Timebase(int(atimebase[0].item()), int(atimebase[1].item()))
        atb = atimebase[0].item() / float(atimebase[1].item())
        if aduration.numel() > 0:
            info.has_audio = True
            info.audio_duration = float(aduration.item()) * atb
    if asample_rate.numel() > 0:
        info.audio_sample_rate = float(asample_rate.item())

    return info
111
112


113
114
115
def _align_audio_frames(
    aframes: torch.Tensor, aframe_pts: torch.Tensor, audio_pts_range: Tuple[int, int]
) -> torch.Tensor:
116
117
118
119
120
121
122
    start, end = aframe_pts[0], aframe_pts[-1]
    num_samples = aframes.size(0)
    step_per_aframe = float(end - start + 1) / float(num_samples)
    s_idx = 0
    e_idx = num_samples
    if start < audio_pts_range[0]:
        s_idx = int((audio_pts_range[0] - start) / step_per_aframe)
123
    if audio_pts_range[1] != -1 and end > audio_pts_range[1]:
124
125
126
127
128
        e_idx = int((audio_pts_range[1] - end) / step_per_aframe)
    return aframes[s_idx:e_idx, :]


def _read_video_from_file(
    filename: str,
    seek_frame_margin: float = 0.25,
    read_video_stream: bool = True,
    video_width: int = 0,
    video_height: int = 0,
    video_min_dimension: int = 0,
    video_max_dimension: int = 0,
    video_pts_range: Tuple[int, int] = (0, -1),
    video_timebase: Fraction = default_timebase,
    read_audio_stream: bool = True,
    audio_samples: int = 0,
    audio_channels: int = 0,
    audio_pts_range: Tuple[int, int] = (0, -1),
    audio_timebase: Fraction = default_timebase,
) -> Tuple[torch.Tensor, torch.Tensor, VideoMetaData]:
    """
    Reads a video from a file, returning both the video frames as well as
    the audio frames

    Args:
    filename (str): path to the video file
    seek_frame_margin (double, optional): seeking frame in the stream is imprecise. Thus,
        when video_start_pts is specified, we seek the pts earlier by seek_frame_margin seconds
    read_video_stream (int, optional): whether read video stream. If yes, set to 1. Otherwise, 0
    video_width/video_height/video_min_dimension/video_max_dimension (int): together decide
        the size of decoded frames:

            - When video_width = 0, video_height = 0, video_min_dimension = 0,
                and video_max_dimension = 0, keep the original frame resolution
            - When video_width = 0, video_height = 0, video_min_dimension != 0,
                and video_max_dimension = 0, keep the aspect ratio and resize the
                frame so that shorter edge size is video_min_dimension
            - When video_width = 0, video_height = 0, video_min_dimension = 0,
                and video_max_dimension != 0, keep the aspect ratio and resize
                the frame so that longer edge size is video_max_dimension
            - When video_width = 0, video_height = 0, video_min_dimension != 0,
                and video_max_dimension != 0, resize the frame so that shorter
                edge size is video_min_dimension, and longer edge size is
                video_max_dimension. The aspect ratio may not be preserved
            - When video_width = 0, video_height != 0, video_min_dimension = 0,
                and video_max_dimension = 0, keep the aspect ratio and resize
                the frame so that frame video_height is $video_height
            - When video_width != 0, video_height == 0, video_min_dimension = 0,
                and video_max_dimension = 0, keep the aspect ratio and resize
                the frame so that frame video_width is $video_width
            - When video_width != 0, video_height != 0, video_min_dimension = 0,
                and video_max_dimension = 0, resize the frame so that frame
                video_width and  video_height are set to $video_width and
                $video_height, respectively
    video_pts_range (list(int), optional): the start and end presentation timestamp of video stream
    video_timebase (Fraction, optional): a Fraction rational number which denotes timebase in video stream
    read_audio_stream (int, optional): whether read audio stream. If yes, set to 1. Otherwise, 0
    audio_samples (int, optional): audio sampling rate
    audio_channels (int optional): audio channels
    audio_pts_range (list(int), optional): the start and end presentation timestamp of audio stream
    audio_timebase (Fraction, optional): a Fraction rational number which denotes time base in audio stream

    Returns
        vframes (Tensor[T, H, W, C]): the `T` video frames
        aframes (Tensor[L, K]): the audio frames, where `L` is the number of points and
            `K` is the number of audio_channels
        info (VideoMetaData): metadata of the video and audio streams
    """
    _validate_pts(video_pts_range)
    _validate_pts(audio_pts_range)

    # The extension op takes its arguments positionally; the order below must
    # mirror the C++ signature of read_video_from_file exactly.
    result = torch.ops.video_reader.read_video_from_file(
        filename,
        seek_frame_margin,
        0,  # getPtsOnly: 0 -> decode pixel/sample data, not just timestamps
        read_video_stream,
        video_width,
        video_height,
        video_min_dimension,
        video_max_dimension,
        video_pts_range[0],
        video_pts_range[1],
        video_timebase.numerator,
        video_timebase.denominator,
        read_audio_stream,
        audio_samples,
        audio_channels,
        audio_pts_range[0],
        audio_pts_range[1],
        audio_timebase.numerator,
        audio_timebase.denominator,
    )
    vframes, _vframe_pts, vtimebase, vfps, vduration, aframes, aframe_pts, atimebase, asample_rate, aduration = result
    info = _fill_info(vtimebase, vfps, vduration, atimebase, asample_rate, aduration)
    if aframes.numel() > 0:
        # when audio stream is found, trim the decoded samples to the
        # requested pts range (decoding may overshoot on either side)
        aframes = _align_audio_frames(aframes, aframe_pts, audio_pts_range)
    return vframes, aframes, info


225
def _read_video_timestamps_from_file(filename: str) -> Tuple[List[int], List[int], VideoMetaData]:
    """
    Decode all video- and audio frames in the video. Only pts
    (presentation timestamp) is returned. The actual frame pixel data is not
    copied. Thus, it is much faster than read_video(...)

    Returns:
        vframe_pts (List[int]): pts of each video frame
        aframe_pts (List[int]): pts of each audio frame
        info (VideoMetaData): metadata of the video and audio streams
    """
    # Positional call into the extension op; getPtsOnly=1 skips decoding the
    # actual frame data, and the pts ranges (0, -1) select the whole file.
    result = torch.ops.video_reader.read_video_from_file(
        filename,
        0,  # seek_frame_margin
        1,  # getPtsOnly
        1,  # read_video_stream
        0,  # video_width
        0,  # video_height
        0,  # video_min_dimension
        0,  # video_max_dimension
        0,  # video_start_pts
        -1,  # video_end_pts
        0,  # video_timebase_num
        1,  # video_timebase_den
        1,  # read_audio_stream
        0,  # audio_samples
        0,  # audio_channels
        0,  # audio_start_pts
        -1,  # audio_end_pts
        0,  # audio_timebase_num
        1,  # audio_timebase_den
    )
    _vframes, vframe_pts, vtimebase, vfps, vduration, _aframes, aframe_pts, atimebase, asample_rate, aduration = result
    info = _fill_info(vtimebase, vfps, vduration, atimebase, asample_rate, aduration)

    # Convert the pts tensors to plain Python lists for the callers.
    vframe_pts = vframe_pts.numpy().tolist()
    aframe_pts = aframe_pts.numpy().tolist()
    return vframe_pts, aframe_pts, info


260
def _probe_video_from_file(filename: str) -> VideoMetaData:
    """
    Probe a video file and return VideoMetaData with info about the video
    """
    # The op returns (vtimebase, vfps, vduration, atimebase, asample_rate,
    # aduration) -- exactly the positional arguments _fill_info expects.
    probe_result = torch.ops.video_reader.probe_video_from_file(filename)
    return _fill_info(*probe_result)


270
def _read_video_from_memory(
    video_data: torch.Tensor,
    seek_frame_margin: float = 0.25,
    read_video_stream: int = 1,
    video_width: int = 0,
    video_height: int = 0,
    video_min_dimension: int = 0,
    video_max_dimension: int = 0,
    video_pts_range: Tuple[int, int] = (0, -1),
    video_timebase_numerator: int = 0,
    video_timebase_denominator: int = 1,
    read_audio_stream: int = 1,
    audio_samples: int = 0,
    audio_channels: int = 0,
    audio_pts_range: Tuple[int, int] = (0, -1),
    audio_timebase_numerator: int = 0,
    audio_timebase_denominator: int = 1,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Reads a video from memory, returning both the video frames as well as
    the audio frames
    This function is torchscriptable.

    Args:
    video_data (data type could be 1) torch.Tensor, dtype=torch.int8 or 2) python bytes):
        compressed video content stored in either 1) torch.Tensor 2) python bytes
    seek_frame_margin (double, optional): seeking frame in the stream is imprecise.
        Thus, when video_start_pts is specified, we seek the pts earlier by seek_frame_margin seconds
    read_video_stream (int, optional): whether read video stream. If yes, set to 1. Otherwise, 0
    video_width/video_height/video_min_dimension/video_max_dimension (int): together decide
        the size of decoded frames:

            - When video_width = 0, video_height = 0, video_min_dimension = 0,
                and video_max_dimension = 0, keep the original frame resolution
            - When video_width = 0, video_height = 0, video_min_dimension != 0,
                and video_max_dimension = 0, keep the aspect ratio and resize the
                frame so that shorter edge size is video_min_dimension
            - When video_width = 0, video_height = 0, video_min_dimension = 0,
                and video_max_dimension != 0, keep the aspect ratio and resize
                the frame so that longer edge size is video_max_dimension
            - When video_width = 0, video_height = 0, video_min_dimension != 0,
                and video_max_dimension != 0, resize the frame so that shorter
                edge size is video_min_dimension, and longer edge size is
                video_max_dimension. The aspect ratio may not be preserved
            - When video_width = 0, video_height != 0, video_min_dimension = 0,
                and video_max_dimension = 0, keep the aspect ratio and resize
                the frame so that frame video_height is $video_height
            - When video_width != 0, video_height == 0, video_min_dimension = 0,
                and video_max_dimension = 0, keep the aspect ratio and resize
                the frame so that frame video_width is $video_width
            - When video_width != 0, video_height != 0, video_min_dimension = 0,
                and video_max_dimension = 0, resize the frame so that frame
                video_width and  video_height are set to $video_width and
                $video_height, respectively
    video_pts_range (list(int), optional): the start and end presentation timestamp of video stream
    video_timebase_numerator / video_timebase_denominator (float, optional): a rational
        number which denotes timebase in video stream
    read_audio_stream (int, optional): whether read audio stream. If yes, set to 1. Otherwise, 0
    audio_samples (int, optional): audio sampling rate
    audio_channels (int optional): audio channels
    audio_pts_range (list(int), optional): the start and end presentation timestamp of audio stream
    audio_timebase_numerator / audio_timebase_denominator (float, optional):
        a rational number which denotes time base in audio stream

    Returns:
        vframes (Tensor[T, H, W, C]): the `T` video frames
        aframes (Tensor[L, K]): the audio frames, where `L` is the number of points and
            `K` is the number of channels
    """

    _validate_pts(video_pts_range)
    _validate_pts(audio_pts_range)

    # Accept raw python bytes as well as a tensor holding the encoded stream.
    if not isinstance(video_data, torch.Tensor):
        video_data = torch.frombuffer(video_data, dtype=torch.uint8)

    # Positional call into the extension op; the order must mirror the C++
    # signature of read_video_from_memory exactly.
    result = torch.ops.video_reader.read_video_from_memory(
        video_data,
        seek_frame_margin,
        0,  # getPtsOnly: decode frame data, not just timestamps
        read_video_stream,
        video_width,
        video_height,
        video_min_dimension,
        video_max_dimension,
        video_pts_range[0],
        video_pts_range[1],
        video_timebase_numerator,
        video_timebase_denominator,
        read_audio_stream,
        audio_samples,
        audio_channels,
        audio_pts_range[0],
        audio_pts_range[1],
        audio_timebase_numerator,
        audio_timebase_denominator,
    )

    vframes, _vframe_pts, vtimebase, vfps, vduration, aframes, aframe_pts, atimebase, asample_rate, aduration = result

    if aframes.numel() > 0:
        # when audio stream is found, trim the decoded samples to the
        # requested pts range (decoding may overshoot on either side)
        aframes = _align_audio_frames(aframes, aframe_pts, audio_pts_range)

    return vframes, aframes
375
376


377
378
379
def _read_video_timestamps_from_memory(
    video_data: torch.Tensor,
) -> Tuple[List[int], List[int], VideoMetaData]:
    """
    Decode all frames in the video. Only pts (presentation timestamp) is returned.
    The actual frame pixel data is not copied. Thus, read_video_timestamps(...)
    is much faster than read_video(...)

    Returns:
        vframe_pts (List[int]): pts of each video frame
        aframe_pts (List[int]): pts of each audio frame
        info (VideoMetaData): metadata of the video and audio streams
    """
    # Accept raw python bytes as well as a tensor holding the encoded stream.
    if not isinstance(video_data, torch.Tensor):
        video_data = torch.frombuffer(video_data, dtype=torch.uint8)
    # Positional call into the extension op; getPtsOnly=1 skips decoding the
    # actual frame data, and the pts ranges (0, -1) select the whole stream.
    result = torch.ops.video_reader.read_video_from_memory(
        video_data,
        0,  # seek_frame_margin
        1,  # getPtsOnly
        1,  # read_video_stream
        0,  # video_width
        0,  # video_height
        0,  # video_min_dimension
        0,  # video_max_dimension
        0,  # video_start_pts
        -1,  # video_end_pts
        0,  # video_timebase_num
        1,  # video_timebase_den
        1,  # read_audio_stream
        0,  # audio_samples
        0,  # audio_channels
        0,  # audio_start_pts
        -1,  # audio_end_pts
        0,  # audio_timebase_num
        1,  # audio_timebase_den
    )
    _vframes, vframe_pts, vtimebase, vfps, vduration, _aframes, aframe_pts, atimebase, asample_rate, aduration = result
    info = _fill_info(vtimebase, vfps, vduration, atimebase, asample_rate, aduration)

    # Convert the pts tensors to plain Python lists for the callers.
    vframe_pts = vframe_pts.numpy().tolist()
    aframe_pts = aframe_pts.numpy().tolist()
    return vframe_pts, aframe_pts, info
414
415


416
417
418
def _probe_video_from_memory(
    video_data: torch.Tensor,
) -> VideoMetaData:
    """
    Probe a video in memory and return VideoMetaData with info about the video
    This function is torchscriptable
    """
    # Accept raw python bytes as well as a tensor holding the encoded stream.
    if not isinstance(video_data, torch.Tensor):
        video_data = torch.frombuffer(video_data, dtype=torch.uint8)
    result = torch.ops.video_reader.probe_video_from_memory(video_data)
    vtimebase, vfps, vduration, atimebase, asample_rate, aduration = result
    info = _fill_info(vtimebase, vfps, vduration, atimebase, asample_rate, aduration)
    return info
Francisco Massa's avatar
Francisco Massa committed
429
430


431
432
433
def _convert_to_sec(
    start_pts: Union[float, Fraction], end_pts: Union[float, Fraction], pts_unit: str, time_base: Fraction
) -> Tuple[Union[float, Fraction], Union[float, Fraction], str]:
434
    if pts_unit == "pts":
435
436
        start_pts = float(start_pts * time_base)
        end_pts = float(end_pts * time_base)
437
        pts_unit = "sec"
438
439
440
    return start_pts, end_pts, pts_unit


441
442
443
444
445
446
def _read_video(
    filename: str,
    start_pts: Union[float, Fraction] = 0,
    end_pts: Optional[Union[float, Fraction]] = None,
    pts_unit: str = "pts",
) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, float]]:
    """
    Read the video and audio frames of ``filename`` that fall inside
    [start_pts, end_pts], where the endpoints are expressed in ``pts_unit``
    ("pts" or "sec").

    Returns:
        vframes (Tensor[T, H, W, C]): the `T` video frames
        aframes (Tensor[L, K]): the audio frames
        info (Dict): may contain the keys "video_fps" (float) and
            "audio_fps" (taken from the stream's sample rate)
    """
    if end_pts is None:
        # No explicit end: read until the end of the stream.
        end_pts = float("inf")

    if pts_unit == "pts":
        warnings.warn(
            "The pts_unit 'pts' gives wrong results and will be removed in a "
            + "follow-up version. Please use pts_unit 'sec'."
        )

    # First probe the file for stream presence and timebases.
    info = _probe_video_from_file(filename)

    has_video = info.has_video
    has_audio = info.has_audio

    video_pts_range = (0, -1)
    video_timebase = default_timebase
    audio_pts_range = (0, -1)
    audio_timebase = default_timebase
    time_base = default_timebase

    if has_video:
        video_timebase = Fraction(info.video_timebase.numerator, info.video_timebase.denominator)
        time_base = video_timebase

    if has_audio:
        audio_timebase = Fraction(info.audio_timebase.numerator, info.audio_timebase.denominator)
        # Fraction(0, 1) is falsy, so the audio timebase is only used when no
        # video timebase was set above.
        time_base = time_base if time_base else audio_timebase

    # video_timebase is the default time_base
    start_pts_sec, end_pts_sec, pts_unit = _convert_to_sec(start_pts, end_pts, pts_unit, time_base)

    def get_pts(time_base):
        # Convert the second-based range into integer pts for the given
        # stream timebase, flooring the start and ceiling the end so the
        # requested interval is fully covered; inf maps to the -1 sentinel.
        start_offset = start_pts_sec
        end_offset = end_pts_sec
        if pts_unit == "sec":
            start_offset = int(math.floor(start_pts_sec * (1 / time_base)))
            if end_offset != float("inf"):
                end_offset = int(math.ceil(end_pts_sec * (1 / time_base)))
        if end_offset == float("inf"):
            end_offset = -1
        return start_offset, end_offset

    if has_video:
        video_pts_range = get_pts(video_timebase)

    if has_audio:
        audio_pts_range = get_pts(audio_timebase)

    vframes, aframes, info = _read_video_from_file(
        filename,
        read_video_stream=True,
        video_pts_range=video_pts_range,
        video_timebase=video_timebase,
        read_audio_stream=True,
        audio_pts_range=audio_pts_range,
        audio_timebase=audio_timebase,
    )
    # Expose only the fps-style fields from the metadata, as a plain dict.
    _info = {}
    if has_video:
        _info["video_fps"] = info.video_fps
    if has_audio:
        _info["audio_fps"] = info.audio_sample_rate

    return vframes, aframes, _info
Francisco Massa's avatar
Francisco Massa committed
510
511


512
513
514
def _read_video_timestamps(
    filename: str, pts_unit: str = "pts"
) -> Tuple[Union[List[int], List[Fraction]], Optional[float]]:
    """Return the pts of every video frame in ``filename`` and the video fps.

    With ``pts_unit == "sec"`` each pts is multiplied by the video timebase
    (yielding a Fraction); with "pts" the raw integer timestamps are
    returned.  The fps is None when the file has no video stream.
    """
    if pts_unit == "pts":
        warnings.warn(
            "The pts_unit 'pts' gives wrong results and will be removed in a "
            + "follow-up version. Please use pts_unit 'sec'."
        )

    pts: Union[List[int], List[Fraction]]
    pts, _, info = _read_video_timestamps_from_file(filename)

    if pts_unit == "sec":
        tb = Fraction(info.video_timebase.numerator, info.video_timebase.denominator)
        pts = [frame_pts * tb for frame_pts in pts]

    return pts, (info.video_fps if info.has_video else None)