Commit f0d3daa7 authored by Zhicheng Yan, committed by Francisco Massa

move sampler into TV core. Update UniformClipSampler (#1408)

* move sampler into TV core. Update UniformClipSampler

* Fix reference training script

* Skip test if pyav not available

* Change interpolation from round() to floor(), as round(0.5) behaves differently between Python 2 and Python 3
parent edfd5a77
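The round()-to-floor() change in the last bullet works around a genuine incompatibility: Python 2 rounds halves away from zero, while Python 3 rounds halves to even, so an index computed as round(x) on a half-integer can differ by one between interpreters, whereas floor() gives the same answer under both. A minimal illustration (plain Python, not code from this commit):

import math

for x in (0.5, 1.5, 2.5):
    print(x, round(x), math.floor(x))
# Python 3 prints:       0.5 0 0   / 1.5 2 1   / 2.5 2 2
# Python 2 would print:  0.5 1.0 0.0 / 1.5 2.0 1.0 / 2.5 3.0 2.0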
@@ -11,9 +11,10 @@ from torch import nn
 import torchvision
 import torchvision.datasets.video_utils
 from torchvision import transforms
+from torchvision.datasets.samplers import DistributedSampler, UniformClipSampler, RandomClipSampler
 import utils
-from sampler import DistributedSampler, UniformClipSampler, RandomClipSampler
 from scheduler import WarmupMultiStepLR
 import transforms as T
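With this hunk the samplers ship with torchvision itself instead of living next to the reference script. A minimal sketch of wiring them into a training loop; the video paths, clip parameters, and `video_dataset` below are illustrative assumptions, not values from this commit:

import torch
from torchvision.datasets.video_utils import VideoClips
from torchvision.datasets.samplers import RandomClipSampler, UniformClipSampler

video_paths = ["a.mp4", "b.mp4", "c.mp4"]  # placeholder file list
clips = VideoClips(video_paths, clip_length_in_frames=16, frames_between_clips=16)

train_sampler = RandomClipSampler(clips, 5)   # up to 5 random clips per video
val_sampler = UniformClipSampler(clips, 5)    # exactly 5 evenly spaced clips per video

# both samplers yield flat clip indices, so they plug straight into a DataLoader
# (`video_dataset` is a hypothetical dataset indexed by those clip indices)
loader = torch.utils.data.DataLoader(video_dataset, sampler=val_sampler, batch_size=8)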
......
import contextlib
import sys
import os
import torch
import unittest

from torchvision import io
from torchvision.datasets.samplers import RandomClipSampler, UniformClipSampler
from torchvision.datasets.video_utils import VideoClips, unfold
from torchvision import get_video_backend

from common_utils import get_tmp_dir


@contextlib.contextmanager
def get_list_of_videos(num_videos=5, sizes=None, fps=None):
    with get_tmp_dir() as tmp_dir:
        names = []
        for i in range(num_videos):
            if sizes is None:
                size = 5 * (i + 1)
            else:
                size = sizes[i]
            if fps is None:
                f = 5
            else:
                f = fps[i]
            data = torch.randint(0, 255, (size, 300, 400, 3), dtype=torch.uint8)
            name = os.path.join(tmp_dir, "{}.mp4".format(i))
            names.append(name)
            io.write_video(name, data, fps=f)

        yield names
@unittest.skipIf(not io.video._av_available(), "this test requires av")
class Tester(unittest.TestCase):
    def test_random_clip_sampler(self):
        with get_list_of_videos(num_videos=3, sizes=[25, 25, 25]) as video_list:
            video_clips = VideoClips(video_list, 5, 5)
            sampler = RandomClipSampler(video_clips, 3)
            self.assertEqual(len(sampler), 3 * 3)
            indices = torch.tensor(list(iter(sampler)))
            videos = indices // 5
            v_idxs, count = torch.unique(videos, return_counts=True)
            self.assertTrue(v_idxs.equal(torch.tensor([0, 1, 2])))
            self.assertTrue(count.equal(torch.tensor([3, 3, 3])))

    def test_random_clip_sampler_unequal(self):
        with get_list_of_videos(num_videos=3, sizes=[10, 25, 25]) as video_list:
            video_clips = VideoClips(video_list, 5, 5)
            sampler = RandomClipSampler(video_clips, 3)
            self.assertEqual(len(sampler), 2 + 3 + 3)
            indices = list(iter(sampler))
            self.assertIn(0, indices)
            self.assertIn(1, indices)
            # remove elements of the first video, to simplify testing
            indices.remove(0)
            indices.remove(1)
            indices = torch.tensor(indices) - 2
            videos = indices // 5
            v_idxs, count = torch.unique(videos, return_counts=True)
            self.assertTrue(v_idxs.equal(torch.tensor([0, 1])))
            self.assertTrue(count.equal(torch.tensor([3, 3])))

    def test_uniform_clip_sampler(self):
        with get_list_of_videos(num_videos=3, sizes=[25, 25, 25]) as video_list:
            video_clips = VideoClips(video_list, 5, 5)
            sampler = UniformClipSampler(video_clips, 3)
            self.assertEqual(len(sampler), 3 * 3)
            indices = torch.tensor(list(iter(sampler)))
            videos = indices // 5
            v_idxs, count = torch.unique(videos, return_counts=True)
            self.assertTrue(v_idxs.equal(torch.tensor([0, 1, 2])))
            self.assertTrue(count.equal(torch.tensor([3, 3, 3])))
            self.assertTrue(indices.equal(torch.tensor([0, 2, 4, 5, 7, 9, 10, 12, 14])))

    def test_uniform_clip_sampler_insufficient_clips(self):
        with get_list_of_videos(num_videos=3, sizes=[10, 25, 25]) as video_list:
            video_clips = VideoClips(video_list, 5, 5)
            sampler = UniformClipSampler(video_clips, 3)
            self.assertEqual(len(sampler), 3 * 3)
            indices = torch.tensor(list(iter(sampler)))
            self.assertTrue(indices.equal(torch.tensor([0, 0, 1, 2, 4, 6, 7, 9, 11])))


if __name__ == '__main__':
    unittest.main()
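The expected indices in the two uniform-sampler tests can be verified by hand against the new linspace/floor logic. Each 25-frame video with clip length 5 and step 5 produces 5 clips, so linspace over 5 clips with 3 steps lands on per-video offsets 0, 2, 4 (hence [0, 2, 4, 5, 7, 9, 10, 12, 14]); the 10-frame video in the insufficient-clips test has only 2 clips, so clip 0 is repeated:

import torch

# video 0: 2 clips (flat indices 0..1), 3 requested -> clip 0 repeats
torch.linspace(0, 1, steps=3).floor().to(torch.int64)   # tensor([0, 0, 1])
# videos 1 and 2: 5 clips each, starting at flat offsets 2 and 7
torch.linspace(2, 6, steps=3).floor().to(torch.int64)   # tensor([2, 4, 6])
torch.linspace(7, 11, steps=3).floor().to(torch.int64)  # tensor([7, 9, 11])
# concatenated: [0, 0, 1, 2, 4, 6, 7, 9, 11], exactly the tested expectation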
@@ -83,36 +83,6 @@ class Tester(unittest.TestCase):
             self.assertEqual(video_idx, v_idx)
             self.assertEqual(clip_idx, c_idx)

-    @unittest.skip("Moved to reference scripts for now")
-    def test_video_sampler(self):
-        with get_list_of_videos(num_videos=3, sizes=[25, 25, 25]) as video_list:
-            video_clips = VideoClips(video_list, 5, 5)
-            sampler = RandomClipSampler(video_clips, 3)  # noqa: F821
-            self.assertEqual(len(sampler), 3 * 3)
-            indices = torch.tensor(list(iter(sampler)))
-            videos = indices // 5
-            v_idxs, count = torch.unique(videos, return_counts=True)
-            self.assertTrue(v_idxs.equal(torch.tensor([0, 1, 2])))
-            self.assertTrue(count.equal(torch.tensor([3, 3, 3])))
-
-    @unittest.skip("Moved to reference scripts for now")
-    def test_video_sampler_unequal(self):
-        with get_list_of_videos(num_videos=3, sizes=[10, 25, 25]) as video_list:
-            video_clips = VideoClips(video_list, 5, 5)
-            sampler = RandomClipSampler(video_clips, 3)  # noqa: F821
-            self.assertEqual(len(sampler), 2 + 3 + 3)
-            indices = list(iter(sampler))
-            self.assertIn(0, indices)
-            self.assertIn(1, indices)
-            # remove elements of the first video, to simplify testing
-            indices.remove(0)
-            indices.remove(1)
-            indices = torch.tensor(indices) - 2
-            videos = indices // 5
-            v_idxs, count = torch.unique(videos, return_counts=True)
-            self.assertTrue(v_idxs.equal(torch.tensor([0, 1])))
-            self.assertTrue(count.equal(torch.tensor([3, 3])))
-
     @unittest.skipIf(not io.video._av_available(), "this test requires av")
     @unittest.skipIf('win' in sys.platform, 'temporarily disabled on Windows')
     def test_video_clips_custom_fps(self):
......
from .clip_sampler import DistributedSampler, UniformClipSampler, RandomClipSampler
__all__ = ('DistributedSampler', 'UniformClipSampler', 'RandomClipSampler')
@@ -60,33 +60,45 @@ class DistributedSampler(Sampler):
 class UniformClipSampler(torch.utils.data.Sampler):
     """
-    Samples at most `max_video_clips_per_video` clips for each video, equally spaced
+    Sample `num_video_clips_per_video` clips for each video, equally spaced.
+    When the number of unique clips in the video is smaller than `num_video_clips_per_video`,
+    repeat the clips until `num_video_clips_per_video` clips are collected.

     Arguments:
         video_clips (VideoClips): video clips to sample from
-        max_clips_per_video (int): maximum number of clips to be sampled per video
+        num_clips_per_video (int): number of clips to be sampled per video
     """
-    def __init__(self, video_clips, max_clips_per_video):
+    def __init__(self, video_clips, num_clips_per_video):
         if not isinstance(video_clips, torchvision.datasets.video_utils.VideoClips):
             raise TypeError("Expected video_clips to be an instance of VideoClips, "
                             "got {}".format(type(video_clips)))
         self.video_clips = video_clips
-        self.max_clips_per_video = max_clips_per_video
+        self.num_clips_per_video = num_clips_per_video

     def __iter__(self):
         idxs = []
         s = 0
-        # select at most max_clips_per_video for each video, uniformly spaced
+        # select num_clips_per_video for each video, uniformly spaced
         for c in self.video_clips.clips:
             length = len(c)
-            step = max(length // self.max_clips_per_video, 1)
-            sampled = torch.arange(length)[::step] + s
+            if length == 0:
+                # corner case where video decoding fails
+                continue
+            sampled = (
+                torch.linspace(s, s + length - 1, steps=self.num_clips_per_video)
+                .floor()
+                .to(torch.int64)
+            )
             s += length
             idxs.append(sampled)
         idxs = torch.cat(idxs).tolist()
         return iter(idxs)

     def __len__(self):
-        return sum(min(len(c), self.max_clips_per_video) for c in self.video_clips.clips)
+        return sum(
+            self.num_clips_per_video for c in self.video_clips.clips if len(c) > 0
+        )
class RandomClipSampler(torch.utils.data.Sampler):
......
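The behavioral change in UniformClipSampler is easiest to see side by side: the old stride-based selection returned at most max_clips_per_video indices and could come up short, while the new linspace-based selection always returns exactly num_clips_per_video, repeating clips for short videos. A small before/after sketch for a video with 2 clips and a budget of 3:

import torch

length, num, s = 2, 3, 0  # 2 available clips, 3 requested, flat offset 0

# old: stride-based, yields at most `num` indices (here only 2)
step = max(length // num, 1)
old = torch.arange(length)[::step] + s       # tensor([0, 1])

# new: exactly `num` indices, repeating clip 0 because the video is short
new = (torch.linspace(s, s + length - 1, steps=num)
       .floor()
       .to(torch.int64))                     # tensor([0, 0, 1])

This also explains the __len__ change: with repetition, every decodable video contributes exactly num_clips_per_video samples, so the length no longer depends on min(len(c), ...).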