"""
=========
Video API
=========

.. note::
    Try on `Colab <https://colab.research.google.com/github/pytorch/vision/blob/gh-pages/main/_generated_ipynb_notebooks/plot_video_api.ipynb>`_
    or :ref:`go to the end <sphx_glr_download_auto_examples_others_plot_video_api.py>` to download the full example code.

This example illustrates some of the APIs that torchvision offers for
videos, together with examples of how to build datasets and more.
"""

# %%
# 1. Introduction: building a new video object and examining the properties
# -------------------------------------------------------------------------
# First we select a video to test the object out. For the sake of argument
# we're using one from the Kinetics400 dataset.
# To create it, we need to define the path and the stream we want to use.

# %%
# Chosen video statistics:
#
# - WUzgd7C1pWA.mp4
#     - source:
#         - kinetics-400
#     - video:
#         - H-264
#         - MPEG-4 AVC (part 10) (avc1)
#         - fps: 29.97
#     - audio:
#         - MPEG AAC audio (mp4a)
#         - sample rate: 48K Hz
#

import torch
import torchvision
from torchvision.datasets.utils import download_url
# Select the decoding backend before the first reader is created below.
torchvision.set_video_backend("video_reader")

# Download the sample video into the current directory
download_url(
    "https://github.com/pytorch/vision/blob/main/test/assets/videos/WUzgd7C1pWA.mp4?raw=true",
    ".",
    "WUzgd7C1pWA.mp4",
)
video_path = "./WUzgd7C1pWA.mp4"

# %%
# Streams are defined in a similar fashion as torch devices. We encode them as strings in the form
# of ``stream_type:stream_id`` where ``stream_type`` is a string and ``stream_id`` a long int.
# The constructor accepts passing a ``stream_type`` only, in which case the stream is auto-discovered.
# Firstly, let's get the metadata for our particular video:

# Only the stream type is given here, without a stream id.
stream = "video"
video = torchvision.io.VideoReader(video_path, stream)
# NOTE(review): presumably left as a bare expression so the gallery renders the
# returned metadata as the cell output - confirm before wrapping it in print().
video.get_metadata()

# %%
# Here we can see that the video has two streams - a video and an audio stream.
# Currently available stream types include ['video', 'audio'].
# Each descriptor consists of two parts: stream type (e.g. 'video') and a unique stream id
# (which is determined by the video encoding).
# In this way, if the video container contains multiple streams of the same type,
# users can access the one they want.
# If only the stream type is passed, the decoder auto-detects the first stream of that type and returns it.

# %%
# Let's read all the frames from the audio stream. By default, the return value of
# ``next(video_reader)`` is a dict containing the following fields.
#
# The return fields are:
#
# - ``data``: containing a torch.tensor
# - ``pts``: containing a float timestamp of this particular frame

# Grab the container metadata, then point the reader at the audio stream.
metadata = video.get_metadata()
video.set_current_stream("audio")

# Collect every decoded item together with its presentation timestamp
# (pts, a float number of seconds).
frames, ptss = [], []
for frame in video:
    frames.append(frame['data'])
    ptss.append(frame['pts'])

print("PTS for first five frames ", ptss[:5])
print("Total number of frames: ", len(frames))

# duration * framerate of the audio stream estimates how many datapoints to expect.
audio_meta = metadata['audio']
approx_nf = audio_meta['duration'][0] * audio_meta['framerate'][0]
print("Approx total number of datapoints we can expect: ", approx_nf)
print("Read data size: ", frames[0].size(0) * len(frames))

# %%
# But what if we only want to read a certain time segment of the video?
# That can be done easily using the combination of our ``seek`` function, and the fact that each call
# to next returns the presentation timestamp of the returned frame in seconds.
#
# Given that our implementation relies on python iterators,
# we can leverage itertools to simplify the process and make it more pythonic.
#
# For example, if we wanted to read ten frames starting from the 2-second mark:


import itertools
video.set_current_stream("video")

frames = []  # we are going to save the frames here.

# Seek to the 2-second mark of the video and use islice to take the next 10 frames.
# Every item yielded by the reader is a dict with ``data`` and ``pts`` entries, so
# index it: tuple-unpacking a dict would yield its keys, not the tensor and timestamp.
for frame in itertools.islice(video.seek(2), 10):
    frames.append(frame['data'])

print("Total number of frames: ", len(frames))

# %%
# Or if we wanted to read from the 2nd to the 5th second,
# we seek to the 2-second mark of the video,
# then we utilize the itertools takewhile to get the
# correct number of frames:

# Select the video stream and jump to the 2-second mark.
video.set_current_stream("video")
video = video.seek(2)

# Keep taking frames until one has a presentation timestamp above 5 seconds.
frames = [
    item['data']
    for item in itertools.takewhile(lambda x: x['pts'] <= 5, video)
]

print("Total number of frames: ", len(frames))
approx_nf = (5 - 2) * video.get_metadata()['video']['fps'][0]
print("We can expect approx: ", approx_nf)
print("Tensor size: ", frames[0].size())

# %%
# 2. Building a sample read_video function
# ----------------------------------------------------------------------------------------
# We can utilize the methods above to build a read video function that follows
# the same API as the existing ``read_video`` function.


def example_read_video(video_object, start=0, end=None, read_video=True, read_audio=True):
    """Read the video and/or audio frames of ``video_object`` between two timestamps.

    Args:
        video_object: reader exposing ``set_current_stream``, ``seek`` and
            ``get_metadata``; iterating the result of ``seek`` must yield dicts
            with ``data`` and ``pts`` entries.
        start: timestamp passed to ``seek`` before reading each stream.
        end: last timestamp (inclusive) to keep; ``None`` means no upper bound.
        read_video: whether to read the ``"video"`` stream.
        read_audio: whether to read the ``"audio"`` stream.

    Returns:
        A tuple ``(video_frames, audio_frames, (video_pts, audio_pts), metadata)``.
        Video frames are stacked along a new first dimension, audio frames are
        concatenated along the first dimension. A stream that was not requested,
        or that produced no frames, is returned as ``torch.empty(0)``.

    Raises:
        ValueError: if ``end`` is smaller than ``start``.
    """
    if end is None:
        end = float("inf")
    if end < start:
        raise ValueError(
            "end time should be larger than start time, got "
            f"start time={start} and end time={end}"
        )

    video_frames = torch.empty(0)
    video_pts = []
    if read_video:
        video_object.set_current_stream("video")
        frames = []
        # Stop at the first frame whose timestamp lies past ``end``.
        for frame in itertools.takewhile(lambda x: x['pts'] <= end, video_object.seek(start)):
            frames.append(frame['data'])
            video_pts.append(frame['pts'])
        if len(frames) > 0:
            video_frames = torch.stack(frames, 0)

    audio_frames = torch.empty(0)
    audio_pts = []
    if read_audio:
        video_object.set_current_stream("audio")
        frames = []
        for frame in itertools.takewhile(lambda x: x['pts'] <= end, video_object.seek(start)):
            frames.append(frame['data'])
            audio_pts.append(frame['pts'])
        if len(frames) > 0:
            # Audio items are joined end to end rather than stacked.
            audio_frames = torch.cat(frames, 0)

    return video_frames, audio_frames, (video_pts, audio_pts), video_object.get_metadata()


# Total number of frames should be 327 for video and 523264 datapoints for audio
# Called with the defaults: start=0, no end bound, both streams read.
vf, af, info, meta = example_read_video(video)
print(vf.size(), af.size())

# %%
# 3. Building an example randomly sampled dataset (can be applied to training dataset of kinetics400)
# -------------------------------------------------------------------------------------------------------
# Cool, so now we can use the same principle to make the sample dataset.
# We suggest trying out an iterable dataset for this purpose.
# Here, we are going to build an example dataset that reads a clip of consecutive frames
# (16 by default) starting at a random position of a randomly selected video.

# %%
# Make sample dataset
import os

# One sub-directory per class; ``makedirs`` also creates the ``./dataset`` parent,
# so it does not need a call of its own.
for class_dir in ("./dataset/1", "./dataset/2"):
    os.makedirs(class_dir, exist_ok=True)

# %%
# Download the videos
from torchvision.datasets.utils import download_url

# Every sample lives under the same test-assets URL and is saved under its own name.
_ASSET_URL = "https://github.com/pytorch/vision/blob/main/test/assets/videos/{}?raw=true"

# (destination folder, file name) pairs; the folder is the class the video belongs to.
_SAMPLE_VIDEOS = [
    ("./dataset/1", "WUzgd7C1pWA.mp4"),
    ("./dataset/1", "RATRACE_wave_f_nm_np1_fr_goo_37.avi"),
    ("./dataset/2", "SOX5yA1l24A.mp4"),
    ("./dataset/2", "v_SoccerJuggling_g23_c01.avi"),
    ("./dataset/2", "v_SoccerJuggling_g24_c01.avi"),
]

for folder, filename in _SAMPLE_VIDEOS:
    download_url(_ASSET_URL.format(filename), folder, filename)

# %%
# Housekeeping and utilities
import os
import random

from torchvision.datasets.folder import make_dataset
from torchvision import transforms as t


def _find_classes(dir):
    """Return the class names found in ``dir`` and a name -> index mapping.

    Every immediate sub-directory of ``dir`` counts as one class. Names are
    sorted and each index is the name's position in that sorted list.
    """
    classes = sorted(entry.name for entry in os.scandir(dir) if entry.is_dir())
    class_to_idx = dict(zip(classes, range(len(classes))))
    return classes, class_to_idx


def get_samples(root, extensions=(".mp4", ".avi")):
    """Collect the dataset samples stored under ``root``.

    The class sub-directories of ``root`` are mapped to indices and passed,
    together with the accepted file ``extensions``, to ``make_dataset``, whose
    result is returned unchanged.
    """
    class_to_idx = _find_classes(root)[1]
    samples = make_dataset(root, class_to_idx, extensions=extensions)
    return samples

# %%
# We are going to define the dataset and some basic arguments.
# We assume the structure of the FolderDataset, and add the following parameters:
#
# - ``clip_len``: length of a clip in frames
# - ``frame_transform``: transform for every frame individually
# - ``video_transform``: transform on a video sequence
#
# .. note::
#   We also add an ``epoch_size`` parameter, since using the
#   :class:`~torch.utils.data.IterableDataset` class allows us to naturally
#   oversample clips or images from each video if needed.


class RandomDataset(torch.utils.data.IterableDataset):
    """Iterable dataset that yields randomly positioned clips from the videos under ``root``.

    Args:
        root: directory holding one sub-directory per class with the video files.
        epoch_size: number of clips yielded per iteration pass; defaults to the
            number of videos found under ``root``.
        frame_transform: optional callable applied to every frame individually.
        video_transform: optional callable applied to the stacked clip.
        clip_len: number of frames per clip.

    Each yielded item is a dict with the keys ``path``, ``video``, ``target``,
    ``start`` (the seek time) and ``end`` (pts of the last frame of the clip).
    """

    def __init__(self, root, epoch_size=None, frame_transform=None, video_transform=None, clip_len=16):
        # Zero-argument ``super()`` so that the parent initialiser actually runs.
        super().__init__()

        self.samples = get_samples(root)

        # Allow for temporal jittering
        if epoch_size is None:
            epoch_size = len(self.samples)
        self.epoch_size = epoch_size

        self.clip_len = clip_len
        self.frame_transform = frame_transform
        self.video_transform = video_transform

    def __iter__(self):
        for _ in range(self.epoch_size):
            # Get random sample
            path, target = random.choice(self.samples)
            # Get video object
            vid = torchvision.io.VideoReader(path, "video")
            metadata = vid.get_metadata()
            video_frames = []  # video frame buffer

            # Latest start time that still leaves room for a whole clip; clamped
            # at zero so a video shorter than the clip never gets a negative seek.
            max_seek = metadata["video"]['duration'][0] - (self.clip_len / metadata["video"]['fps'][0])
            start = random.uniform(0., max(0., max_seek))
            current_pts = start
            # Seek and collect up to ``clip_len`` frames
            for frame in itertools.islice(vid.seek(start), self.clip_len):
                data = frame['data']
                # ``frame_transform`` is optional (its default is None).
                if self.frame_transform is not None:
                    data = self.frame_transform(data)
                video_frames.append(data)
                current_pts = frame['pts']
            if not video_frames:
                # Fail with a readable message instead of an opaque stacking error.
                raise RuntimeError(f"no frames could be read from {path} at start={start}")
            # Stack it into a tensor
            video = torch.stack(video_frames, 0)
            if self.video_transform is not None:
                video = self.video_transform(video)
            output = {
                'path': path,
                'video': video,
                'target': target,
                'start': start,
                'end': current_pts}
            yield output

# %%
# Given a path of videos in a folder structure, i.e.:
#
# - dataset
#     - class 1
#         - file 0
#         - file 1
#         - ...
#     - class 2
#         - file 0
#         - file 1
#         - ...
#     - ...
#
# We can generate a dataloader and test the dataset.


# Per-frame pipeline: resize every frame to 112x112.
# NOTE(review): presumably needed so clips from differently sized videos can be
# batched together - confirm.
transforms = [t.Resize((112, 112))]
frame_transform = t.Compose(transforms)

# ``epoch_size=None`` makes one pass yield as many clips as there are videos under ./dataset.
dataset = RandomDataset("./dataset", epoch_size=None, frame_transform=frame_transform)

# %%
from torch.utils.data import DataLoader

loader = DataLoader(dataset, batch_size=12)
# Per-clip summary collected over all batches.
data = {"video": [], 'start': [], 'end': [], 'tensorsize': []}
for batch in loader:
    # Walk the batch clip by clip instead of indexing every field with ``range(len(...))``.
    for path, start, end, clip in zip(batch['path'], batch['start'], batch['end'], batch['video']):
        data['video'].append(path)
        data['start'].append(start.item())
        data['end'].append(end.item())
        data['tensorsize'].append(clip.size())
print(data)

# %%
# 4. Data Visualization
# ----------------------------------
# Example of visualized video

import matplotlib.pyplot as plt

# Show the first 16 frames (the default ``clip_len``) of the first clip in the
# last batch produced by the loader loop, laid out on a 4x4 grid.
plt.figure(figsize=(12, 12))
for i in range(16):
    plt.subplot(4, 4, i + 1)
    # Move the first axis of the frame to the end before plotting
    # (presumably channels-first -> channels-last for imshow - confirm).
    plt.imshow(batch["video"][0, i, ...].permute(1, 2, 0))
    plt.axis("off")

# %%
# Cleanup the video and dataset:
import os
import shutil

# Remove the sample video downloaded into the current directory ...
os.remove("./WUzgd7C1pWA.mp4")
# ... and the whole sample dataset tree.
shutil.rmtree("./dataset")