import bisect
import math
import warnings
from fractions import Fraction
from typing import Any, Callable, cast, Dict, List, Optional, Tuple, TypeVar, Union

import torch
from torchvision.io import _probe_video_from_file, _read_video_from_file, read_video, read_video_timestamps

from .utils import tqdm

T = TypeVar("T")


def pts_convert(pts: int, timebase_from: Fraction, timebase_to: Fraction, round_func: Callable = math.floor) -> int:
    """convert pts between different time bases
    Args:
        pts: presentation timestamp, float
        timebase_from: original timebase. Fraction
        timebase_to: new timebase. Fraction
        round_func: rounding function.
    """
    new_pts = Fraction(pts, 1) * timebase_from / timebase_to
    return round_func(new_pts)
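# Illustrative example (not part of the original module): converting a pts value
# expressed in a 1/30000 video timebase into a 1/44100 audio timebase rounds down
# with the default math.floor:
#
#   >>> pts_convert(3003, Fraction(1, 30000), Fraction(1, 44100))
#   4414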


def unfold(tensor: torch.Tensor, size: int, step: int, dilation: int = 1) -> torch.Tensor:
    """
    similar to tensor.unfold, but with the dilation
    and specialized for 1d tensors

    Returns all consecutive windows of `size` elements, with
    `step` between windows. The distance between each element
    in a window is given by `dilation`.
    """
    if tensor.dim() != 1:
        raise ValueError(f"tensor should have 1 dimension instead of {tensor.dim()}")
    o_stride = tensor.stride(0)
    numel = tensor.numel()
    new_stride = (step * o_stride, dilation * o_stride)
    new_size = ((numel - (dilation * (size - 1) + 1)) // step + 1, size)
    if new_size[0] < 1:
        new_size = (0, size)
    return torch.as_strided(tensor, new_size, new_stride)
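# Illustrative example (not part of the original module): with the default
# dilation=1, windows over torch.arange(8) of size=3 and step=2 are
# [0, 1, 2], [2, 3, 4], [4, 5, 6]; with dilation=2 they would be
# [0, 2, 4] and [2, 4, 6].
#
#   >>> unfold(torch.arange(8), size=3, step=2)
#   tensor([[0, 1, 2],
#           [2, 3, 4],
#           [4, 5, 6]])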


class _VideoTimestampsDataset:
    """
    Dataset used to parallelize the reading of the timestamps
    of a list of videos, given their paths in the filesystem.

    Used in VideoClips and defined at top level so it can be
    pickled when forking.
    """

    def __init__(self, video_paths: List[str]) -> None:
        self.video_paths = video_paths

    def __len__(self) -> int:
        return len(self.video_paths)

    def __getitem__(self, idx: int) -> Tuple[List[int], Optional[float]]:
        return read_video_timestamps(self.video_paths[idx])


def _collate_fn(x: T) -> T:
    """
    Dummy collate function to be used with _VideoTimestampsDataset
    """
    return x


class VideoClips:
    """
    Given a list of video files, computes all consecutive subvideos of size
    `clip_length_in_frames`, where the distance between each subvideo in the
    same video is defined by `frames_between_clips`.
    If `frame_rate` is specified, it will also resample all the videos to have
    the same frame rate, and the clips will refer to this frame rate.

    Creating this instance the first time is time-consuming, as it needs to
    decode all the videos in `video_paths`. It is recommended that you
    cache the results after instantiation of the class.

    Recreating the clips for different clip lengths is fast, and can be done
    with the `compute_clips` method.

    Args:
        video_paths (List[str]): paths to the video files
        clip_length_in_frames (int): size of a clip in number of frames
        frames_between_clips (int): step (in frames) between each clip
        frame_rate (int, optional): if specified, it will resample the video
            so that it has `frame_rate`, and then the clips will be defined
            on the resampled video
        num_workers (int): how many subprocesses to use for data loading.
            0 means that the data will be loaded in the main process. (default: 0)
        output_format (str): The format of the output video tensors. Can be either "THWC" (default) or "TCHW".
    """

    def __init__(
        self,
        video_paths: List[str],
        clip_length_in_frames: int = 16,
        frames_between_clips: int = 1,
        frame_rate: Optional[int] = None,
        _precomputed_metadata: Optional[Dict[str, Any]] = None,
        num_workers: int = 0,
        _video_width: int = 0,
        _video_height: int = 0,
        _video_min_dimension: int = 0,
        _video_max_dimension: int = 0,
        _audio_samples: int = 0,
        _audio_channels: int = 0,
        output_format: str = "THWC",
    ) -> None:

        self.video_paths = video_paths
        self.num_workers = num_workers

        # these options are not valid for pyav backend
        self._video_width = _video_width
        self._video_height = _video_height
        self._video_min_dimension = _video_min_dimension
        self._video_max_dimension = _video_max_dimension
        self._audio_samples = _audio_samples
        self._audio_channels = _audio_channels
        self.output_format = output_format.upper()
        if self.output_format not in ("THWC", "TCHW"):
            raise ValueError(f"output_format should be either 'THWC' or 'TCHW', got {output_format}.")

        if _precomputed_metadata is None:
            self._compute_frame_pts()
        else:
            self._init_from_metadata(_precomputed_metadata)
        self.compute_clips(clip_length_in_frames, frames_between_clips, frame_rate)

    def _compute_frame_pts(self) -> None:
        self.video_pts = []
        self.video_fps = []

        # strategy: use a DataLoader to parallelize read_video_timestamps,
        # so we need to create a dummy dataset first
        import torch.utils.data

        dl: torch.utils.data.DataLoader = torch.utils.data.DataLoader(
            _VideoTimestampsDataset(self.video_paths),  # type: ignore[arg-type]
            batch_size=16,
            num_workers=self.num_workers,
            collate_fn=_collate_fn,
        )

        with tqdm(total=len(dl)) as pbar:
            for batch in dl:
                pbar.update(1)
                clips, fps = list(zip(*batch))
                # we need to specify dtype=torch.long because, for an empty list,
                # torch.as_tensor will use torch.float as the default dtype. This
                # happens when decoding fails and no pts is returned in the list.
                clips = [torch.as_tensor(c, dtype=torch.long) for c in clips]
                self.video_pts.extend(clips)
                self.video_fps.extend(fps)

    def _init_from_metadata(self, metadata: Dict[str, Any]) -> None:
        self.video_paths = metadata["video_paths"]
        assert len(self.video_paths) == len(metadata["video_pts"])
        self.video_pts = metadata["video_pts"]
        assert len(self.video_paths) == len(metadata["video_fps"])
        self.video_fps = metadata["video_fps"]

    @property
    def metadata(self) -> Dict[str, Any]:
        _metadata = {
            "video_paths": self.video_paths,
            "video_pts": self.video_pts,
            "video_fps": self.video_fps,
        }
        return _metadata
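    # Illustrative caching pattern (the file name is hypothetical): the metadata
    # dict can be saved once and passed back as _precomputed_metadata to skip the
    # expensive decoding pass on the next instantiation:
    #
    #   >>> torch.save(clips.metadata, "clips_metadata.pt")
    #   >>> cached = torch.load("clips_metadata.pt")
    #   >>> fast_clips = VideoClips(cached["video_paths"], clip_length_in_frames=16,
    #   ...                         _precomputed_metadata=cached)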

    def subset(self, indices: List[int]) -> "VideoClips":
        video_paths = [self.video_paths[i] for i in indices]
        video_pts = [self.video_pts[i] for i in indices]
        video_fps = [self.video_fps[i] for i in indices]
        metadata = {
            "video_paths": video_paths,
            "video_pts": video_pts,
            "video_fps": video_fps,
        }
        return type(self)(
            video_paths,
            self.num_frames,
            self.step,
            self.frame_rate,
            _precomputed_metadata=metadata,
            num_workers=self.num_workers,
            _video_width=self._video_width,
            _video_height=self._video_height,
            _video_min_dimension=self._video_min_dimension,
            _video_max_dimension=self._video_max_dimension,
            _audio_samples=self._audio_samples,
            _audio_channels=self._audio_channels,
            output_format=self.output_format,
        )

    @staticmethod
    def compute_clips_for_video(
        video_pts: torch.Tensor, num_frames: int, step: int, fps: Optional[int], frame_rate: Optional[int] = None
    ) -> Tuple[torch.Tensor, Union[List[slice], torch.Tensor]]:
        if fps is None:
            # if for some reason the video doesn't have an fps (because it doesn't
            # have a video stream), set the fps to 1. The value doesn't matter,
            # because video_pts is empty anyway.
            fps = 1
        if frame_rate is None:
            frame_rate = fps
        total_frames = len(video_pts) * (float(frame_rate) / fps)
        _idxs = VideoClips._resample_video_idx(int(math.floor(total_frames)), fps, frame_rate)
        video_pts = video_pts[_idxs]
        clips = unfold(video_pts, num_frames, step)
        if not clips.numel():
            warnings.warn(
                "There aren't enough frames in the current video to get a clip for the given clip length and "
                "frames between clips. The video (and potentially others) will be skipped."
            )
        idxs: Union[List[slice], torch.Tensor]
        if isinstance(_idxs, slice):
            idxs = [_idxs] * len(clips)
        else:
            idxs = unfold(_idxs, num_frames, step)
        return clips, idxs
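    # Illustrative example: for a video with 90 decoded pts values kept at its
    # native frame rate (frame_rate=None), num_frames=16 and step=16 produce a
    # clips tensor of shape [5, 16] ((90 - 16) // 16 + 1 = 5 clips), and the
    # trailing 10 frames are dropped.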

    def compute_clips(self, num_frames: int, step: int, frame_rate: Optional[int] = None) -> None:
        """
        Compute all consecutive sequences of clips from video_pts.
        Always returns clips of size `num_frames`, meaning that the
        last few frames in a video can potentially be dropped.

        Args:
            num_frames (int): number of frames for the clip
            step (int): distance between two clips
            frame_rate (int, optional): the frame rate to resample the videos to.
                If None, the native frame rate of each video is used.
        """
        self.num_frames = num_frames
        self.step = step
        self.frame_rate = frame_rate
        self.clips = []
        self.resampling_idxs = []
        for video_pts, fps in zip(self.video_pts, self.video_fps):
            clips, idxs = self.compute_clips_for_video(video_pts, num_frames, step, fps, frame_rate)
            self.clips.append(clips)
            self.resampling_idxs.append(idxs)
        clip_lengths = torch.as_tensor([len(v) for v in self.clips])
        self.cumulative_sizes = clip_lengths.cumsum(0).tolist()
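    # Illustrative note: re-windowing an existing instance is cheap, e.g.
    # clips.compute_clips(num_frames=8, step=4, frame_rate=10) rebuilds self.clips,
    # self.resampling_idxs and self.cumulative_sizes without re-decoding any video.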

    def __len__(self) -> int:
        return self.num_clips()

    def num_videos(self) -> int:
        return len(self.video_paths)

    def num_clips(self) -> int:
        """
        Number of subclips that are available in the video list.
        """
        return self.cumulative_sizes[-1]

    def get_clip_location(self, idx: int) -> Tuple[int, int]:
        """
        Converts a flattened representation of the indices into a video_idx, clip_idx
        representation.
        """
        video_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        if video_idx == 0:
            clip_idx = idx
        else:
            clip_idx = idx - self.cumulative_sizes[video_idx - 1]
        return video_idx, clip_idx
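    # Illustrative example: with self.cumulative_sizes == [5, 5, 12] (the second
    # video contributed no clips), get_clip_location(7) == (2, 2) and
    # get_clip_location(4) == (0, 4).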

    @staticmethod
    def _resample_video_idx(num_frames: int, original_fps: int, new_fps: int) -> Union[slice, torch.Tensor]:
        step = float(original_fps) / new_fps
        if step.is_integer():
            # optimization: if step is integer, don't need to perform
            # advanced indexing
            step = int(step)
            return slice(None, None, step)
        idxs = torch.arange(num_frames, dtype=torch.float32) * step
        idxs = idxs.floor().to(torch.int64)
        return idxs
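    # Illustrative example: downsampling 30 fps to 15 fps is an integer step and
    # returns a plain slice, while 30 fps to 20 fps needs fractional indexing:
    #
    #   >>> VideoClips._resample_video_idx(6, 30, 15)
    #   slice(None, None, 2)
    #   >>> VideoClips._resample_video_idx(6, 30, 20)
    #   tensor([0, 1, 3, 4, 6, 7])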

    def get_clip(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, Any], int]:
        """
        Gets a subclip from a list of videos.

        Args:
            idx (int): index of the subclip. Must be between 0 and num_clips().

        Returns:
            video (Tensor)
            audio (Tensor)
            info (Dict)
            video_idx (int): index of the video in `video_paths`
        """
        if idx >= self.num_clips():
            raise IndexError(f"Index {idx} out of range ({self.num_clips()} clips available)")
        video_idx, clip_idx = self.get_clip_location(idx)
        video_path = self.video_paths[video_idx]
        clip_pts = self.clips[video_idx][clip_idx]

        from torchvision import get_video_backend

        backend = get_video_backend()

        if backend == "pyav":
            # check for invalid options
            if self._video_width != 0:
                raise ValueError("pyav backend doesn't support _video_width != 0")
            if self._video_height != 0:
                raise ValueError("pyav backend doesn't support _video_height != 0")
            if self._video_min_dimension != 0:
                raise ValueError("pyav backend doesn't support _video_min_dimension != 0")
            if self._video_max_dimension != 0:
                raise ValueError("pyav backend doesn't support _video_max_dimension != 0")
            if self._audio_samples != 0:
                raise ValueError("pyav backend doesn't support _audio_samples != 0")

        if backend == "pyav":
            start_pts = clip_pts[0].item()
            end_pts = clip_pts[-1].item()
            video, audio, info = read_video(video_path, start_pts, end_pts)
        else:
            _info = _probe_video_from_file(video_path)
            video_fps = _info.video_fps
            audio_fps = None

            video_start_pts = cast(int, clip_pts[0].item())
            video_end_pts = cast(int, clip_pts[-1].item())

            audio_start_pts, audio_end_pts = 0, -1
            audio_timebase = Fraction(0, 1)
            video_timebase = Fraction(_info.video_timebase.numerator, _info.video_timebase.denominator)
            if _info.has_audio:
                audio_timebase = Fraction(_info.audio_timebase.numerator, _info.audio_timebase.denominator)
                audio_start_pts = pts_convert(video_start_pts, video_timebase, audio_timebase, math.floor)
                audio_end_pts = pts_convert(video_end_pts, video_timebase, audio_timebase, math.ceil)
                audio_fps = _info.audio_sample_rate
            video, audio, _ = _read_video_from_file(
                video_path,
                video_width=self._video_width,
                video_height=self._video_height,
                video_min_dimension=self._video_min_dimension,
                video_max_dimension=self._video_max_dimension,
                video_pts_range=(video_start_pts, video_end_pts),
                video_timebase=video_timebase,
                audio_samples=self._audio_samples,
                audio_channels=self._audio_channels,
                audio_pts_range=(audio_start_pts, audio_end_pts),
                audio_timebase=audio_timebase,
            )

            info = {"video_fps": video_fps}
            if audio_fps is not None:
                info["audio_fps"] = audio_fps

        if self.frame_rate is not None:
            resampling_idx = self.resampling_idxs[video_idx][clip_idx]
            if isinstance(resampling_idx, torch.Tensor):
                resampling_idx = resampling_idx - resampling_idx[0]
            video = video[resampling_idx]
            info["video_fps"] = self.frame_rate
        assert len(video) == self.num_frames, f"{video.shape} x {self.num_frames}"

        if self.output_format == "TCHW":
            # [T,H,W,C] --> [T,C,H,W]
            video = video.permute(0, 3, 1, 2)

        return video, audio, info, video_idx

    def __getstate__(self) -> Dict[str, Any]:
        video_pts_sizes = [len(v) for v in self.video_pts]
        # To stay backward-compatible, we convert the data to dtype torch.long as
        # needed, because for an empty list the legacy implementation of
        # torch.as_tensor would use torch.float as the default dtype. This happens
        # when decoding fails and no pts is returned in the list.
        video_pts = [x.to(torch.int64) for x in self.video_pts]
        # video_pts can be an empty list if no frames have been decoded
        if video_pts:
            video_pts = torch.cat(video_pts)  # type: ignore[assignment]
            # avoid bug in https://github.com/pytorch/pytorch/issues/32351
            # TODO: Revert it once the bug is fixed.
            video_pts = video_pts.numpy()  # type: ignore[attr-defined]

        # make a copy of the fields of self
        d = self.__dict__.copy()
        d["video_pts_sizes"] = video_pts_sizes
        d["video_pts"] = video_pts
        # delete the following attributes to reduce the size of dictionary. They
        # will be re-computed in "__setstate__()"
        del d["clips"]
        del d["resampling_idxs"]
        del d["cumulative_sizes"]

        # for backwards-compatibility
        d["_version"] = 2
        return d

    def __setstate__(self, d: Dict[str, Any]) -> None:
        # for backwards-compatibility
        if "_version" not in d:
            self.__dict__ = d
            return

        video_pts = torch.as_tensor(d["video_pts"], dtype=torch.int64)
        video_pts = torch.split(video_pts, d["video_pts_sizes"], dim=0)
        # don't need this info anymore
        del d["video_pts_sizes"]

        d["video_pts"] = video_pts
        self.__dict__ = d
        # recompute attributes "clips", "resampling_idxs" and other derivative ones
        self.compute_clips(self.num_frames, self.step, self.frame_rate)
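
    # Illustrative round trip (assumes the standard pickle module and an existing
    # `clips` instance): serializing stores a flattened int64 pts array plus
    # per-video sizes; unpickling re-splits them and recomputes the clips:
    #
    #   >>> import pickle
    #   >>> restored = pickle.loads(pickle.dumps(clips))
    #   >>> restored.num_clips() == clips.num_clips()
    #   True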