video.py 8.49 KB
Newer Older
1
import re
2
3
4
5
6
7
import gc
import torch
import numpy as np

# PyAV is an optional dependency. Rather than failing when this module is
# imported, `av` is rebound to an ImportError instance whenever the package is
# missing or unusable; `_check_av_available` raises that stored error later.
try:
    import av
    # set PyAV's log level to ERROR
    av.logging.set_level(av.logging.ERROR)
    # `write_video` assigns `frame.pict_type`, so a VideoFrame class without
    # that attribute is reported as a too-old PyAV
    if not hasattr(av.video.frame.VideoFrame, 'pict_type'):
        av = ImportError("""\
Your version of PyAV is too old for the necessary video operations in torchvision.
If you are on Python 3.5, you will have to build from source (the conda-forge
packages are not up-to-date).  See
https://github.com/mikeboers/PyAV#installation for instructions on how to
install PyAV on your system.
""")
except ImportError:
    # PyAV could not be imported at all: keep the error around instead of raising
    av = ImportError("""\
PyAV is not installed, and is necessary for the video operations in torchvision.
See https://github.com/mikeboers/PyAV#installation for instructions on how to
install PyAV on your system.
""")


25
26
27
28
29
30
31
32
33
def _check_av_available():
    """
    Raise the import error stored in ``av`` when PyAV could not be loaded.

    Does nothing when ``av`` is the real module.
    """
    if not isinstance(av, Exception):
        return
    raise av


def _av_available():
    """
    Tell whether PyAV was imported successfully.

    Returns
    -------
    available : bool
        ``False`` when ``av`` holds an exception instead of the module.
    """
    import_failed = isinstance(av, Exception)
    return not import_failed


34
35
# PyAV has some reference cycles
_CALLED_TIMES = 0
36
_GC_COLLECTION_INTERVAL = 10
37
38


39
def write_video(filename, video_array, fps, video_codec='libx264', options=None):
    """
    Writes a 4d tensor in [T, H, W, C] format in a video file

    Parameters
    ----------
    filename : str
        path where the video will be saved
    video_array : Tensor[T, H, W, C]
        tensor containing the individual frames, as a uint8 tensor in [T, H, W, C] format
    fps : Number
        frames per second
    video_codec : str, optional
        name of the codec used to create the output stream. Frames are
        encoded as 'rgb24' for 'libx264rgb' and as 'yuv420p' for any other
        codec.
    options : Dict, optional
        options assigned to the output stream. ``None`` means no options.

    Raises
    ------
    ImportError
        if PyAV is not available
    """
    _check_av_available()
    video_array = torch.as_tensor(video_array, dtype=torch.uint8).numpy()

    container = av.open(filename, mode='w')
    try:
        stream = container.add_stream(video_codec, rate=fps)
        stream.width = video_array.shape[2]
        stream.height = video_array.shape[1]
        stream.pix_fmt = 'yuv420p' if video_codec != 'libx264rgb' else 'rgb24'
        stream.options = options or {}

        for img in video_array:
            frame = av.VideoFrame.from_ndarray(img, format='rgb24')
            frame.pict_type = 'NONE'
            for packet in stream.encode(frame):
                container.mux(packet)

        # Flush stream
        for packet in stream.encode():
            container.mux(packet)
    finally:
        # Close the file, even when encoding or muxing failed part-way
        container.close()


def _read_from_stream(container, start_offset, end_offset, stream, stream_name):
    """
    Decodes the frames of one stream whose pts lie in [start_offset, end_offset]

    Parameters
    ----------
    container : av container
        opened container to seek in and decode from
    start_offset : int
        first presentation timestamp to keep
    end_offset : Number
        last presentation timestamp to keep
    stream : av stream
        stream used for seeking and for the DivX packed B-frames detection
    stream_name : Dict
        keyword arguments forwarded to ``container.decode``, for example
        ``{'video': 0}``

    Returns
    -------
    frames : List
        decoded frames sorted by pts. When ``start_offset`` is positive and
        no frame has exactly that pts, the closest earlier frame (if any)
        is prepended. The list is empty when seeking fails.
    """
    global _CALLED_TIMES
    _CALLED_TIMES += 1
    if _CALLED_TIMES % _GC_COLLECTION_INTERVAL == _GC_COLLECTION_INTERVAL - 1:
        gc.collect()

    frames = {}
    should_buffer = False
    max_buffer_size = 5
    if stream.type == "video":
        # DivX-style packed B-frames can have out-of-order pts (2 frames in a single pkt)
        # so need to buffer some extra frames to sort everything
        # properly
        extradata = stream.codec_context.extradata
        # overly complicated way of finding if `divx_packed` is set, following
        # https://github.com/FFmpeg/FFmpeg/commit/d5a21172283572af587b3d939eba0091484d3263
        if extradata and b"DivX" in extradata:
            # can't use regex directly because of some weird characters sometimes...
            pos = extradata.find(b"DivX")
            d = extradata[pos:]
            o = re.search(br"DivX(\d+)Build(\d+)(\w)", d)
            if o is None:
                o = re.search(br"DivX(\d+)b(\d+)(\w)", d)
            if o is not None:
                should_buffer = o.group(3) == b"p"
    seek_offset = start_offset
    # some files don't seek to the right location, so better be safe here
    seek_offset = max(seek_offset - 1, 0)
    if should_buffer:
        # FIXME this is kind of a hack, but we will jump to the previous keyframe
        # so this will be safe
        seek_offset = max(seek_offset - max_buffer_size, 0)
    try:
        # TODO check if stream needs to always be the video stream here or not
        container.seek(seek_offset, any_frame=False, backward=True, stream=stream)
    except av.AVError:
        # TODO add some warnings in this case
        # print("Corrupted file?", container.name)
        return []
    buffer_count = 0
    for frame in container.decode(**stream_name):
        frames[frame.pts] = frame
        if frame.pts >= end_offset:
            # with packed B-frames, keep decoding a few more frames so that
            # out-of-order pts are not missed
            if should_buffer and buffer_count < max_buffer_size:
                buffer_count += 1
                continue
            break
    # ensure that the results are sorted wrt the pts
    result = [frames[i] for i in sorted(frames) if start_offset <= frames[i].pts <= end_offset]
    if start_offset > 0 and start_offset not in frames:
        # if there is no frame that exactly matches the pts of start_offset
        # add the last frame smaller than start_offset, to guarantee that
        # we will have all the necessary data. This is most useful for audio
        preceding_pts = [i for i in frames if i < start_offset]
        # nothing may have been decoded before start_offset (or at all), in
        # which case there is no earlier frame to prepend
        if preceding_pts:
            result.insert(0, frames[max(preceding_pts)])
    return result
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152


def _align_audio_frames(aframes, audio_frames, ref_start, ref_end):
    start, end = audio_frames[0].pts, audio_frames[-1].pts
    total_aframes = aframes.shape[1]
    step_per_aframe = (end - start + 1) / total_aframes
    s_idx = 0
    e_idx = total_aframes
    if start < ref_start:
        s_idx = int((ref_start - start) / step_per_aframe)
    if end > ref_end:
        e_idx = int((ref_end - end) / step_per_aframe)
    return aframes[:, s_idx:e_idx]


def read_video(filename, start_pts=0, end_pts=None):
    """
    Reads a video from a file, returning both the video frames as well as
    the audio frames

    Parameters
    ----------
    filename : str
        path to the video file
    start_pts : int, optional
        the start presentation time of the video
    end_pts : int, optional
        the end presentation time. ``None`` reads until the end of the file

    Returns
    -------
    vframes : Tensor[T, H, W, C]
        the `T` video frames. An empty uint8 tensor of shape [0, 1, 1, 3]
        when no video frame could be read
    aframes : Tensor[K, L]
        the audio frames, where `K` is the number of channels and `L` is the
        number of points. An empty float32 tensor of shape [1, 0] when no
        audio frame could be read
    info : Dict
        metadata for the video and audio. Can contain the fields video_fps (float)
        and audio_fps (int)

    Raises
    ------
    ImportError
        if PyAV is not available
    ValueError
        if ``end_pts`` is smaller than ``start_pts``
    """
    _check_av_available()

    if end_pts is None:
        end_pts = float("inf")

    if end_pts < start_pts:
        raise ValueError("end_pts should be larger than start_pts, got "
                         "start_pts={} and end_pts={}".format(start_pts, end_pts))

    container = av.open(filename, metadata_errors='ignore')
    info = {}
    video_frames = []
    audio_frames = []
    try:
        if container.streams.video:
            video_frames = _read_from_stream(container, start_pts, end_pts,
                                             container.streams.video[0], {'video': 0})
            info["video_fps"] = float(container.streams.video[0].average_rate)
        if container.streams.audio:
            audio_frames = _read_from_stream(container, start_pts, end_pts,
                                             container.streams.audio[0], {'audio': 0})
            info["audio_fps"] = container.streams.audio[0].rate
    finally:
        # release the file even when decoding fails
        container.close()

    vframes = [frame.to_rgb().to_ndarray() for frame in video_frames]
    aframes = [frame.to_ndarray() for frame in audio_frames]
    if vframes:
        vframes = torch.as_tensor(np.stack(vframes))
    else:
        # np.stack refuses an empty list (no video stream, or nothing decoded)
        vframes = torch.empty((0, 1, 1, 3), dtype=torch.uint8)
    if aframes:
        aframes = np.concatenate(aframes, 1)
        aframes = torch.as_tensor(aframes)
        aframes = _align_audio_frames(aframes, audio_frames, start_pts, end_pts)
    else:
        aframes = torch.empty((1, 0), dtype=torch.float32)

    return vframes, aframes, info


211
212
213
214
215
216
217
218
219
def _can_read_timestamps_from_packets(container):
    extradata = container.streams[0].codec_context.extradata
    if extradata is None:
        return False
    if b"Lavc" in extradata:
        return True
    return False


220
221
222
223
224
225
def read_video_timestamps(filename):
    """
    List the video frames timestamps.

    Note that the function decodes the whole video frame-by-frame, unless
    the timestamps can be read directly from the packets.

    Parameters
    ----------
    filename : str
        path to the video file

    Returns
    -------
    pts : List[int]
        presentation timestamps for each one of the frames in the video,
        in increasing order.
    video_fps : float or None
        the frame rate for the video, or ``None`` when the file has no
        video stream

    Raises
    ------
    ImportError
        if PyAV is not available
    """
    _check_av_available()
    container = av.open(filename, metadata_errors='ignore')

    video_frames = []
    video_fps = None
    try:
        if container.streams.video:
            if _can_read_timestamps_from_packets(container):
                # fast path
                video_frames = [x for x in container.demux(video=0) if x.pts is not None]
            else:
                video_frames = _read_from_stream(container, 0, float("inf"),
                                                 container.streams.video[0], {'video': 0})
            video_fps = float(container.streams.video[0].average_rate)
    finally:
        # release the file even when demuxing or decoding fails
        container.close()
    # the fast path does not go through the pts sort done by
    # `_read_from_stream`, so order the timestamps here for both paths
    return sorted(x.pts for x in video_frames), video_fps