Unverified Commit 5d1372c0 authored by Francisco Massa, committed by GitHub

Add VideoClips and Kinetics dataset (#1077)

* Add VideoClips and Kinetics dataset

* Lint + add back missing line

* Adds ClipSampler following Bruno's comment

* Change name following Bruno's suggestion

* Enable specifying a target framerate

* Fix test_io for new interface

* Add comment mentioning drop_last behavior

* Make compute_clips more robust

* Flake8

* Fix for Python2
parent 2b81ad8c
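For orientation, the sketch below shows the call pattern this PR introduces: `VideoClips` indexes fixed-length clips across a list of video files, and `get_clip` decodes one of them on demand. The file paths are hypothetical and only illustrate the API; this is a usage sketch, not part of the committed code.

```python
from torchvision.datasets.video_utils import VideoClips

# Hypothetical local files; any videos decodable by torchvision.io work here.
video_paths = ["videos/a.mp4", "videos/b.mp4"]

# 16-frame clips, one starting every 8 frames within each video.
clips = VideoClips(video_paths, clip_length_in_frames=16, frames_between_clips=8)

print(clips.num_clips())                           # total clips across all videos
video, audio, info, video_idx = clips.get_clip(0)  # decode the first clip
print(video.shape, video_idx)                      # (16, H, W, C) uint8, index into video_paths
```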
import contextlib
import os
import torch
import unittest

from torchvision import io
from torchvision.datasets.video_utils import VideoClips, unfold, RandomClipSampler

from common_utils import get_tmp_dir


@contextlib.contextmanager
def get_list_of_videos(num_videos=5, sizes=None, fps=None):
    with get_tmp_dir() as tmp_dir:
        names = []
        for i in range(num_videos):
            if sizes is None:
                size = 5 * (i + 1)
            else:
                size = sizes[i]
            if fps is None:
                f = 5
            else:
                f = fps[i]
            data = torch.randint(0, 255, (size, 300, 400, 3), dtype=torch.uint8)
            name = os.path.join(tmp_dir, "{}.mp4".format(i))
            names.append(name)
            io.write_video(name, data, fps=f)

        yield names

class Tester(unittest.TestCase):

    def test_unfold(self):
        a = torch.arange(7)

        r = unfold(a, 3, 3, 1)
        expected = torch.tensor([
            [0, 1, 2],
            [3, 4, 5],
        ])
        self.assertTrue(r.equal(expected))

        r = unfold(a, 3, 2, 1)
        expected = torch.tensor([
            [0, 1, 2],
            [2, 3, 4],
            [4, 5, 6]
        ])
        self.assertTrue(r.equal(expected))

        r = unfold(a, 3, 2, 2)
        expected = torch.tensor([
            [0, 2, 4],
            [2, 4, 6],
        ])
        self.assertTrue(r.equal(expected))

    def test_video_clips(self):
        with get_list_of_videos(num_videos=3) as video_list:
            video_clips = VideoClips(video_list, 5, 5)
            self.assertEqual(video_clips.num_clips(), 1 + 2 + 3)
            for i, (v_idx, c_idx) in enumerate([(0, 0), (1, 0), (1, 1), (2, 0), (2, 1), (2, 2)]):
                video_idx, clip_idx = video_clips.get_clip_location(i)
                self.assertEqual(video_idx, v_idx)
                self.assertEqual(clip_idx, c_idx)

            video_clips = VideoClips(video_list, 6, 6)
            self.assertEqual(video_clips.num_clips(), 0 + 1 + 2)
            for i, (v_idx, c_idx) in enumerate([(1, 0), (2, 0), (2, 1)]):
                video_idx, clip_idx = video_clips.get_clip_location(i)
                self.assertEqual(video_idx, v_idx)
                self.assertEqual(clip_idx, c_idx)

            video_clips = VideoClips(video_list, 6, 1)
            self.assertEqual(video_clips.num_clips(), 0 + (10 - 6 + 1) + (15 - 6 + 1))
            for i, v_idx, c_idx in [(0, 1, 0), (4, 1, 4), (5, 2, 0), (6, 2, 1)]:
                video_idx, clip_idx = video_clips.get_clip_location(i)
                self.assertEqual(video_idx, v_idx)
                self.assertEqual(clip_idx, c_idx)

    def test_video_sampler(self):
        with get_list_of_videos(num_videos=3, sizes=[25, 25, 25]) as video_list:
            video_clips = VideoClips(video_list, 5, 5)
            sampler = RandomClipSampler(video_clips, 3)
            self.assertEqual(len(sampler), 3 * 3)
            indices = torch.tensor(list(iter(sampler)))
            videos = indices // 5
            v_idxs, count = torch.unique(videos, return_counts=True)
            self.assertTrue(v_idxs.equal(torch.tensor([0, 1, 2])))
            self.assertTrue(count.equal(torch.tensor([3, 3, 3])))

    def test_video_sampler_unequal(self):
        with get_list_of_videos(num_videos=3, sizes=[10, 25, 25]) as video_list:
            video_clips = VideoClips(video_list, 5, 5)
            sampler = RandomClipSampler(video_clips, 3)
            self.assertEqual(len(sampler), 2 + 3 + 3)
            indices = list(iter(sampler))
            self.assertIn(0, indices)
            self.assertIn(1, indices)
            # remove elements of the first video, to simplify testing
            indices.remove(0)
            indices.remove(1)
            indices = torch.tensor(indices) - 2
            videos = indices // 5
            v_idxs, count = torch.unique(videos, return_counts=True)
            self.assertTrue(v_idxs.equal(torch.tensor([0, 1])))
            self.assertTrue(count.equal(torch.tensor([3, 3])))

    def test_video_clips_custom_fps(self):
        with get_list_of_videos(num_videos=3, sizes=[12, 12, 12], fps=[3, 4, 6]) as video_list:
            num_frames = 4
            for fps in [1, 3, 4, 10]:
                video_clips = VideoClips(video_list, num_frames, num_frames, fps)
                for i in range(video_clips.num_clips()):
                    video, audio, info, video_idx = video_clips.get_clip(i)
                    self.assertEqual(video.shape[0], num_frames)
                    self.assertEqual(info["video_fps"], fps)
                    # TODO add tests checking that the content is right

    def test_compute_clips_for_video(self):
        video_pts = torch.arange(30)
        # case 1: single clip
        num_frames = 13
        orig_fps = 30
        duration = float(len(video_pts)) / orig_fps
        new_fps = 13
        clips, idxs = VideoClips.compute_clips_for_video(video_pts, num_frames, num_frames,
                                                         orig_fps, new_fps)
        resampled_idxs = VideoClips._resample_video_idx(int(duration * new_fps), orig_fps, new_fps)
        self.assertEqual(len(clips), 1)
        self.assertTrue(clips.equal(idxs))
        self.assertTrue(idxs[0].equal(resampled_idxs))

        # case 2: all frames appear only once
        num_frames = 4
        orig_fps = 30
        duration = float(len(video_pts)) / orig_fps
        new_fps = 12
        clips, idxs = VideoClips.compute_clips_for_video(video_pts, num_frames, num_frames,
                                                         orig_fps, new_fps)
        resampled_idxs = VideoClips._resample_video_idx(int(duration * new_fps), orig_fps, new_fps)
        self.assertEqual(len(clips), 3)
        self.assertTrue(clips.equal(idxs))
        self.assertTrue(idxs.flatten().equal(resampled_idxs))


if __name__ == '__main__':
    unittest.main()

@@ -44,7 +44,7 @@ class Tester(unittest.TestCase):
             data = self._create_video_frames(10, 300, 300)
             io.write_video(f.name, data, fps=5)
-            pts = io.read_video_timestamps(f.name)
+            pts, _ = io.read_video_timestamps(f.name)
             # note: not all formats/codecs provide accurate information for computing the
             # timestamps. For the format that we use here, this information is available,
@@ -63,7 +63,7 @@ class Tester(unittest.TestCase):
             data = self._create_video_frames(10, 300, 300)
             io.write_video(f.name, data, fps=5)
-            pts = io.read_video_timestamps(f.name)
+            pts, _ = io.read_video_timestamps(f.name)
             for start in range(5):
                 for l in range(1, 4):

@@ -19,6 +19,7 @@ from .celeba import CelebA
 from .sbd import SBDataset
 from .vision import VisionDataset
 from .usps import USPS
+from .kinetics import KineticsVideo

 __all__ = ('LSUN', 'LSUNClass',
            'ImageFolder', 'DatasetFolder', 'FakeData',
@@ -28,4 +29,4 @@ __all__ = ('LSUN', 'LSUNClass',
            'Omniglot', 'SBU', 'Flickr8k', 'Flickr30k',
            'VOCSegmentation', 'VOCDetection', 'Cityscapes', 'ImageNet',
            'Caltech101', 'Caltech256', 'CelebA', 'SBDataset', 'VisionDataset',
-           'USPS')
+           'USPS', 'KineticsVideo')

from .video_utils import VideoClips
from .utils import list_dir
from .folder import make_dataset
from .vision import VisionDataset


class KineticsVideo(VisionDataset):
    def __init__(self, root, frames_per_clip, step_between_clips=1):
        super(KineticsVideo, self).__init__(root)
        extensions = ('avi',)

        classes = list(sorted(list_dir(root)))
        class_to_idx = {classes[i]: i for i in range(len(classes))}
        self.samples = make_dataset(self.root, class_to_idx, extensions, is_valid_file=None)
        self.classes = classes
        video_list = [x[0] for x in self.samples]
        self.video_clips = VideoClips(video_list, frames_per_clip, step_between_clips)

    def __len__(self):
        return self.video_clips.num_clips()

    def __getitem__(self, idx):
        video, audio, info, video_idx = self.video_clips.get_clip(idx)
        label = self.samples[video_idx][1]

        return video, audio, label
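For context, a short usage sketch of the new dataset follows. It assumes a Kinetics-style directory layout (one subfolder per class containing `.avi` files, as implied by `list_dir`/`make_dataset` above); the root path is a placeholder, not part of the committed code.

```python
from torchvision.datasets import KineticsVideo

# Hypothetical dataset root: <root>/<class_name>/<video>.avi
dataset = KineticsVideo("/data/kinetics400/train", frames_per_clip=16, step_between_clips=8)

print(len(dataset))               # total number of fixed-length clips, not videos
video, audio, label = dataset[0]  # video: (16, H, W, C) uint8 tensor, label: class index
```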

import bisect
import math

import torch
import torch.utils.data

from torchvision.io import read_video_timestamps, read_video


def unfold(tensor, size, step, dilation=1):
    """
    similar to tensor.unfold, but with the dilation
    and specialized for 1d tensors

    Returns all consecutive windows of `size` elements, with
    `step` between windows. The distance between each element
    in a window is given by `dilation`.
    """
    assert tensor.dim() == 1
    o_stride = tensor.stride(0)
    numel = tensor.numel()
    new_stride = (step * o_stride, dilation * o_stride)
    new_size = ((numel - (dilation * (size - 1) + 1)) // step + 1, size)
    if new_size[0] < 1:
        new_size = (0, size)
    return torch.as_strided(tensor, new_size, new_stride)

class VideoClips(object):
    """
    Given a list of video files, computes all consecutive subvideos of size
    `clip_length_in_frames`, where the distance between each subvideo in the
    same video is defined by `frames_between_clips`.
    If `frame_rate` is specified, it will also resample all the videos to have
    the same frame rate, and the clips will refer to this frame rate.

    Creating this instance the first time is time-consuming, as it needs to
    decode all the videos in `video_paths`. It is recommended that you
    cache the results after instantiation of the class.

    Recreating the clips for different clip lengths is fast, and can be done
    with the `compute_clips` method.

    Arguments:
        video_paths (List[str]): paths to the video files
        clip_length_in_frames (int): size of a clip in number of frames
        frames_between_clips (int): step (in frames) between each clip
        frame_rate (int, optional): if specified, it will resample the video
            so that it has `frame_rate`, and then the clips will be defined
            on the resampled video
    """
    def __init__(self, video_paths, clip_length_in_frames=16, frames_between_clips=1,
                 frame_rate=None):
        self.video_paths = video_paths
        self._compute_frame_pts()
        self.compute_clips(clip_length_in_frames, frames_between_clips, frame_rate)

    def _compute_frame_pts(self):
        self.video_pts = []
        self.video_fps = []
        # TODO maybe parallelize this
        for video_file in self.video_paths:
            clips, fps = read_video_timestamps(video_file)
            self.video_pts.append(torch.as_tensor(clips))
            self.video_fps.append(fps)

    @staticmethod
    def compute_clips_for_video(video_pts, num_frames, step, fps, frame_rate):
        if frame_rate is None:
            frame_rate = fps
        total_frames = len(video_pts) * (float(frame_rate) / fps)
        idxs = VideoClips._resample_video_idx(int(math.floor(total_frames)), fps, frame_rate)
        video_pts = video_pts[idxs]
        clips = unfold(video_pts, num_frames, step)
        if isinstance(idxs, slice):
            idxs = [idxs] * len(clips)
        else:
            idxs = unfold(idxs, num_frames, step)
        return clips, idxs

    def compute_clips(self, num_frames, step, frame_rate=None):
        """
        Compute all consecutive sequences of clips from video_pts.
        Always returns clips of size `num_frames`, meaning that the
        last few frames in a video can potentially be dropped.

        Arguments:
            num_frames (int): number of frames for the clip
            step (int): distance between two clips
            dilation (int): distance between two consecutive frames
                in a clip
        """
        self.num_frames = num_frames
        self.step = step
        self.frame_rate = frame_rate
        self.clips = []
        self.resampling_idxs = []
        for video_pts, fps in zip(self.video_pts, self.video_fps):
            clips, idxs = self.compute_clips_for_video(video_pts, num_frames, step, fps, frame_rate)
            self.clips.append(clips)
            self.resampling_idxs.append(idxs)
        clip_lengths = torch.as_tensor([len(v) for v in self.clips])
        self.cumulative_sizes = clip_lengths.cumsum(0).tolist()

    def __len__(self):
        return self.num_clips()

    def num_videos(self):
        return len(self.video_paths)

    def num_clips(self):
        """
        Number of subclips that are available in the video list.
        """
        return self.cumulative_sizes[-1]

    def get_clip_location(self, idx):
        """
        Converts a flattened representation of the indices into a video_idx, clip_idx
        representation.
        """
        video_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        if video_idx == 0:
            clip_idx = idx
        else:
            clip_idx = idx - self.cumulative_sizes[video_idx - 1]
        return video_idx, clip_idx

    @staticmethod
    def _resample_video_idx(num_frames, original_fps, new_fps):
        step = float(original_fps) / new_fps
        if step.is_integer():
            # optimization: if step is integer, don't need to perform
            # advanced indexing
            step = int(step)
            return slice(None, None, step)
        idxs = torch.arange(num_frames, dtype=torch.float32) * step
        idxs = idxs.floor().to(torch.int64)
        return idxs

    def get_clip(self, idx):
        """
        Gets a subclip from a list of videos.

        Arguments:
            idx (int): index of the subclip. Must be between 0 and num_clips().

        Returns:
            video (Tensor)
            audio (Tensor)
            info (Dict)
            video_idx (int): index of the video in `video_paths`
        """
        if idx >= self.num_clips():
            raise IndexError("Index {} out of range "
                             "({} number of clips)".format(idx, self.num_clips()))
        video_idx, clip_idx = self.get_clip_location(idx)
        video_path = self.video_paths[video_idx]
        clip_pts = self.clips[video_idx][clip_idx]
        video, audio, info = read_video(video_path, clip_pts[0].item(), clip_pts[-1].item())
        if self.frame_rate is not None:
            resampling_idx = self.resampling_idxs[video_idx][clip_idx]
            if isinstance(resampling_idx, torch.Tensor):
                resampling_idx = resampling_idx - resampling_idx[0]
            video = video[resampling_idx]
            info["video_fps"] = self.frame_rate
        assert len(video) == self.num_frames
        return video, audio, info, video_idx

class RandomClipSampler(torch.utils.data.Sampler):
    """
    Samples at most `max_clips_per_video` clips for each video randomly

    Arguments:
        video_clips (VideoClips): video clips to sample from
        max_clips_per_video (int): maximum number of clips to be sampled per video
    """
    def __init__(self, video_clips, max_clips_per_video):
        if not isinstance(video_clips, VideoClips):
            raise TypeError("Expected video_clips to be an instance of VideoClips, "
                            "got {}".format(type(video_clips)))
        self.video_clips = video_clips
        self.max_clips_per_video = max_clips_per_video

    def __iter__(self):
        idxs = []
        s = 0
        # select at most max_clips_per_video for each video, randomly
        for c in self.video_clips.clips:
            length = len(c)
            size = min(length, self.max_clips_per_video)
            sampled = torch.randperm(length)[:size] + s
            s += length
            idxs.append(sampled)
        idxs = torch.cat(idxs)
        # shuffle all clips randomly
        perm = torch.randperm(len(idxs))
        idxs = idxs[perm].tolist()
        return iter(idxs)

    def __len__(self):
        return sum(min(len(c), self.max_clips_per_video) for c in self.video_clips.clips)
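To show how the sampler is meant to be used, here is a minimal, hypothetical sketch pairing `RandomClipSampler` with the `KineticsVideo` dataset defined above; the dataset root is a placeholder. The sampler yields flat clip indices, so it also plugs into a `DataLoader` via the `sampler=` argument, though a custom `collate_fn` may be needed if the audio streams differ in length.

```python
from torchvision.datasets import KineticsVideo
from torchvision.datasets.video_utils import RandomClipSampler

# Hypothetical dataset root: one subfolder of .avi files per class.
dataset = KineticsVideo("/data/kinetics400/train", frames_per_clip=16)

# Draw at most 5 clips per video, shuffled across all videos each epoch.
sampler = RandomClipSampler(dataset.video_clips, max_clips_per_video=5)

# Iterating the sampler yields flat clip indices into the dataset.
for idx in sampler:
    video, audio, label = dataset[idx]
    # video is a (16, H, W, C) uint8 tensor; feed it to a model or DataLoader here.
    break
```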

@@ -159,13 +159,16 @@ def read_video_timestamps(filename):
     Returns:
         pts (List[int]): presentation timestamps for each one of the frames
             in the video.
+        video_fps (int): the frame rate for the video
     """
     _check_av_available()
     container = av.open(filename)
     video_frames = []
+    video_fps = None
     if container.streams.video:
         video_frames = _read_from_stream(container, 0, float("inf"),
                                          container.streams.video[0], {'video': 0})
+        video_fps = float(container.streams.video[0].average_rate)
     container.close()
-    return [x.pts for x in video_frames]
+    return [x.pts for x in video_frames], video_fps
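Since `read_video_timestamps` now returns two values, callers need to unpack both (as the updated tests above do). A one-line illustration with a hypothetical path:

```python
from torchvision.io import read_video_timestamps

# The path is a placeholder; any decodable video file works.
pts, video_fps = read_video_timestamps("videos/a.mp4")
print(len(pts), video_fps)
```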