import gc
import torch
import numpy as np

try:
    import av
    av.logging.set_level(av.logging.ERROR)
except ImportError:
    av = None


def _check_av_available():
    if av is None:
        raise ImportError("""\
PyAV is not installed, and is necessary for the video operations in torchvision.
See https://github.com/mikeboers/PyAV#installation for instructions on how to
install PyAV on your system.
""")


# PyAV has some reference cycles, so we periodically force a garbage
# collection pass to make sure decoder resources get released
_CALLED_TIMES = 0
_GC_COLLECTION_INTERVAL = 20


def write_video(filename, video_array, fps, video_codec='libx264', options=None):
    """
    Writes a 4d tensor in [T, H, W, C] format to a video file

    Arguments:
        filename (str): path where the video will be saved
        video_array (Tensor[T, H, W, C]): tensor containing the individual frames,
            as a uint8 tensor in [T, H, W, C] format
        fps (Number): frames per second of the output video
        video_codec (str): name of the video codec, e.g. 'libx264' (the default)
        options (Dict): dictionary of options to be passed to the PyAV video stream
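
    Example (illustrative sketch; assumes an FFmpeg build with libx264
    available and write permission for the output path)::

        >>> frames = torch.randint(0, 256, (16, 240, 320, 3), dtype=torch.uint8)
        >>> write_video('output.mp4', frames, fps=30)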
    """
    _check_av_available()
    video_array = torch.as_tensor(video_array, dtype=torch.uint8).numpy()

    container = av.open(filename, mode='w')

    stream = container.add_stream(video_codec, rate=fps)
    stream.width = video_array.shape[2]
    stream.height = video_array.shape[1]
    stream.pix_fmt = 'yuv420p' if video_codec != 'libx264rgb' else 'rgb24'
    stream.options = options or {}

    for img in video_array:
        frame = av.VideoFrame.from_ndarray(img, format='rgb24')
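        # reset the picture type so the encoder is free to pick
        # the frame type (I/P/B) on its own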
        frame.pict_type = 'NONE'
        for packet in stream.encode(frame):
            container.mux(packet)

    # Flush stream
    for packet in stream.encode():
        container.mux(packet)

    # Close the file
    container.close()


def _read_from_stream(container, start_offset, end_offset, stream, stream_name):
    global _CALLED_TIMES, _GC_COLLECTION_INTERVAL
    _CALLED_TIMES += 1
    if _CALLED_TIMES % _GC_COLLECTION_INTERVAL == _GC_COLLECTION_INTERVAL - 1:
        gc.collect()

    frames = {}
    should_buffer = False
    max_buffer_size = 5
    if stream.type == "video":
        # TODO consider also using stream.codec_context.codec.reorder
        # videos with b frames can have out-of-order pts
        # so need to buffer some extra frames to sort everything
        # properly
        should_buffer = stream.codec_context.has_b_frames
    seek_offset = start_offset
    if should_buffer:
        # FIXME this is kind of a hack, but we will jump to the previous keyframe
        # so this will be safe
        seek_offset = max(seek_offset - max_buffer_size, 0)
    # TODO check if stream needs to always be the video stream here or not
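    # backward=True with any_frame=False seeks to the nearest keyframe
    # at or before seek_offset, so decoding always starts from a valid point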
    container.seek(seek_offset, any_frame=False, backward=True, stream=stream)
    buffer_count = 0
    for idx, frame in enumerate(container.decode(**stream_name)):
        frames[frame.pts] = frame
        if frame.pts >= end_offset:
            if should_buffer and buffer_count < max_buffer_size:
                buffer_count += 1
                continue
            break
    # ensure that the results are sorted wrt the pts
    result = [frames[i] for i in sorted(frames) if start_offset <= frames[i].pts <= end_offset]
    if start_offset > 0 and start_offset not in frames:
        # if there is no frame that exactly matches the pts of start_offset
        # add the last frame smaller than start_offset, to guarantee that
        # we will have all the necessary data. This is most useful for audio
        first_frame_pts = max(i for i in frames if i < start_offset)
        result.insert(0, frames[first_frame_pts])
    return result


def _align_audio_frames(aframes, audio_frames, ref_start, ref_end):
    # trim the decoded audio samples so that they cover the requested
    # [ref_start, ref_end] pts window as closely as possible
    start, end = audio_frames[0].pts, audio_frames[-1].pts
    total_aframes = aframes.shape[1]
    step_per_aframe = (end - start + 1) / total_aframes
    s_idx = 0
    e_idx = total_aframes
    if start < ref_start:
        s_idx = int((ref_start - start) / step_per_aframe)
    if end > ref_end:
        # (ref_end - end) is negative here, yielding a negative index that
        # drops the trailing samples past ref_end
        e_idx = int((ref_end - end) / step_per_aframe)
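    # Worked example (illustrative numbers): with start=0, end=999 and
    # total_aframes=1000, step_per_aframe == 1.0; requesting ref_start=100,
    # ref_end=899 gives s_idx == 100 and e_idx == int((899 - 999) / 1.0)
    # == -100, so the slice below keeps exactly the samples inside the
    # requested window via aframes[:, 100:-100].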
    return aframes[:, s_idx:e_idx]


def read_video(filename, start_pts=0, end_pts=None):
    """
    Reads a video from a file, returning both the video frames and the
    audio frames

    Arguments:
        filename (str): path to the video file
        start_pts (int, optional): the start presentation time of the video
        end_pts (int, optional): the end presentation time

    Returns:
        vframes (Tensor[T, H, W, C]): the `T` video frames
        aframes (Tensor[K, L]): the audio frames, where `K` is the number of channels
            and `L` is the number of points
        info (Dict): metadata for the video and audio. Can contain the fields
            - video_fps (float)
            - audio_fps (int)
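
    Example (illustrative; 'clip.mp4' is a placeholder path)::

        >>> vframes, aframes, info = read_video('clip.mp4')
        >>> vframes.shape  # torch.Size([T, H, W, 3]) for T decoded frames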
    """
    _check_av_available()

    if end_pts is None:
        end_pts = float("inf")

    if end_pts < start_pts:
        raise ValueError("end_pts should be larger than start_pts, got "
                         "start_pts={} and end_pts={}".format(start_pts, end_pts))

    container = av.open(filename, metadata_errors='ignore')
    info = {}

    video_frames = []
    if container.streams.video:
        video_frames = _read_from_stream(container, start_pts, end_pts,
                                         container.streams.video[0], {'video': 0})
        info["video_fps"] = float(container.streams.video[0].average_rate)
    audio_frames = []
    if container.streams.audio:
        audio_frames = _read_from_stream(container, start_pts, end_pts,
                                         container.streams.audio[0], {'audio': 0})
        info["audio_fps"] = container.streams.audio[0].rate

    container.close()

    vframes = [frame.to_rgb().to_ndarray() for frame in video_frames]
    aframes = [frame.to_ndarray() for frame in audio_frames]
    if vframes:
        vframes = torch.as_tensor(np.stack(vframes))
    else:
        # guard against np.stack on an empty list when the file has no
        # video stream (or nothing was decoded in the requested range)
        vframes = torch.empty((0, 1, 1, 3), dtype=torch.uint8)
    if aframes:
        aframes = np.concatenate(aframes, 1)
        aframes = torch.as_tensor(aframes)
        aframes = _align_audio_frames(aframes, audio_frames, start_pts, end_pts)
    else:
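        # no audio frames were decoded: return an empty placeholder tensor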
        aframes = torch.empty((1, 0), dtype=torch.float32)

    return vframes, aframes, info


def read_video_timestamps(filename):
    """
    Lists the video frame timestamps.

    Note that the function decodes the whole video frame-by-frame.

    Arguments:
        filename (str): path to the video file

    Returns:
        pts (List[int]): presentation timestamps for each one of the frames
            in the video.
        video_fps (float): the frame rate for the video
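
    Example (illustrative; 'clip.mp4' is a placeholder path)::

        >>> pts, fps = read_video_timestamps('clip.mp4')
        >>> len(pts)  # number of frames in the video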
    """
    _check_av_available()
    container = av.open(filename, metadata_errors='ignore')

    video_frames = []
    video_fps = None
    if container.streams.video:
        video_frames = _read_from_stream(container, 0, float("inf"),
                                         container.streams.video[0], {'video': 0})
        video_fps = float(container.streams.video[0].average_rate)
    container.close()
    return [x.pts for x in video_frames], video_fps