Unverified Commit 94c94170 authored by Francisco Massa, committed by GitHub

Move RandomClipSampler to references (#1186)

* Move RandomClipSampler to references

* Lint and bugfix
parent fe4d17fc
@@ -87,3 +87,38 @@ class UniformClipSampler(torch.utils.data.Sampler):
    def __len__(self):
        return sum(min(len(c), self.max_clips_per_video) for c in self.video_clips.clips)


class RandomClipSampler(torch.utils.data.Sampler):
    """
    Samples at most `max_clips_per_video` clips for each video randomly

    Arguments:
        video_clips (VideoClips): video clips to sample from
        max_clips_per_video (int): maximum number of clips to be sampled per video
    """
    def __init__(self, video_clips, max_clips_per_video):
        if not isinstance(video_clips, torchvision.datasets.video_utils.VideoClips):
            raise TypeError("Expected video_clips to be an instance of VideoClips, "
                            "got {}".format(type(video_clips)))
        self.video_clips = video_clips
        self.max_clips_per_video = max_clips_per_video

    def __iter__(self):
        idxs = []
        s = 0
        # select at most max_clips_per_video for each video, randomly
        for c in self.video_clips.clips:
            length = len(c)
            size = min(length, self.max_clips_per_video)
            sampled = torch.randperm(length)[:size] + s
            s += length
            idxs.append(sampled)
        idxs = torch.cat(idxs)
        # shuffle all clips randomly
        perm = torch.randperm(len(idxs))
        idxs = idxs[perm].tolist()
        return iter(idxs)

    def __len__(self):
        return sum(min(len(c), self.max_clips_per_video) for c in self.video_clips.clips)
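
For context (not part of the diff): the sampler added above is meant to be handed to a DataLoader through its `sampler` argument, which is what the training script does further down. A minimal usage sketch, assuming a Kinetics400-style dataset that exposes a `video_clips` attribute; the dataset path and loader settings below are placeholders, not values from this commit:

# Illustrative usage sketch: not part of this commit.
import torch
from torchvision.datasets import Kinetics400

# Placeholder dataset root; Kinetics400 builds a VideoClips index internally.
dataset = Kinetics400("path/to/kinetics/train", frames_per_clip=16, step_between_clips=1)

# Draw at most 5 random clips per video each epoch.
train_sampler = RandomClipSampler(dataset.video_clips, max_clips_per_video=5)

data_loader = torch.utils.data.DataLoader(
    dataset, batch_size=8, sampler=train_sampler, num_workers=4)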
@@ -13,7 +13,7 @@ import torchvision.datasets.video_utils
from torchvision import transforms
import utils
from sampler import DistributedSampler, UniformClipSampler
from sampler import DistributedSampler, UniformClipSampler, RandomClipSampler
from scheduler import WarmupMultiStepLR
import transforms as T
@@ -184,7 +184,7 @@ def main(args):
    dataset_test.video_clips.compute_clips(args.clip_len, 1, frame_rate=15)
    print("Creating data loaders")
    train_sampler = torchvision.datasets.video_utils.RandomClipSampler(dataset.video_clips, args.clips_per_video)
    train_sampler = RandomClipSampler(dataset.video_clips, args.clips_per_video)
    test_sampler = UniformClipSampler(dataset_test.video_clips, args.clips_per_video)
    if args.distributed:
        train_sampler = DistributedSampler(train_sampler)
@@ -4,7 +4,7 @@ import torch
import unittest
from torchvision import io
from torchvision.datasets.video_utils import VideoClips, unfold, RandomClipSampler
from torchvision.datasets.video_utils import VideoClips, unfold
from common_utils import get_tmp_dir
@@ -80,10 +80,11 @@ class Tester(unittest.TestCase):
                self.assertEqual(video_idx, v_idx)
                self.assertEqual(clip_idx, c_idx)

    @unittest.skip("Moved to reference scripts for now")
    def test_video_sampler(self):
        with get_list_of_videos(num_videos=3, sizes=[25, 25, 25]) as video_list:
            video_clips = VideoClips(video_list, 5, 5)
            sampler = RandomClipSampler(video_clips, 3)
            sampler = RandomClipSampler(video_clips, 3)  # noqa: F821
            self.assertEqual(len(sampler), 3 * 3)
            indices = torch.tensor(list(iter(sampler)))
            videos = indices // 5
@@ -91,10 +92,11 @@ class Tester(unittest.TestCase):
            self.assertTrue(v_idxs.equal(torch.tensor([0, 1, 2])))
            self.assertTrue(count.equal(torch.tensor([3, 3, 3])))

    @unittest.skip("Moved to reference scripts for now")
    def test_video_sampler_unequal(self):
        with get_list_of_videos(num_videos=3, sizes=[10, 25, 25]) as video_list:
            video_clips = VideoClips(video_list, 5, 5)
            sampler = RandomClipSampler(video_clips, 3)
            sampler = RandomClipSampler(video_clips, 3)  # noqa: F821
            self.assertEqual(len(sampler), 2 + 3 + 3)
            indices = list(iter(sampler))
            self.assertIn(0, indices)
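
The two sampler tests above are skipped rather than deleted, which is why the `# noqa: F821` markers are needed: `RandomClipSampler` is no longer imported in this file. A hedged sketch of how they could be re-enabled later, assuming the reference copy lives at `references/video_classification/sampler.py` relative to the repository root; the path handling below is illustrative, not part of this commit:

# Illustrative only: resolve the undefined RandomClipSampler by importing the
# copy that now lives in the reference scripts (the path is an assumption).
import os
import sys

_REFERENCES = os.path.join(os.path.dirname(__file__), "..", "references", "video_classification")
sys.path.insert(0, _REFERENCES)

from sampler import RandomClipSampler  # noqa: E402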
import bisect
import math
import torch
import torch.utils.data
from torchvision.io import read_video_timestamps, read_video
from .utils import tqdm
@@ -214,38 +213,3 @@ class VideoClips(object):
info["video_fps"] = self.frame_rate
assert len(video) == self.num_frames, "{} x {}".format(video.shape, self.num_frames)
return video, audio, info, video_idx


class RandomClipSampler(torch.utils.data.Sampler):
    """
    Samples at most `max_video_clips_per_video` clips for each video randomly

    Arguments:
        video_clips (VideoClips): video clips to sample from
        max_clips_per_video (int): maximum number of clips to be sampled per video
    """
    def __init__(self, video_clips, max_clips_per_video):
        if not isinstance(video_clips, VideoClips):
            raise TypeError("Expected video_clips to be an instance of VideoClips, "
                            "got {}".format(type(video_clips)))
        self.video_clips = video_clips
        self.max_clips_per_video = max_clips_per_video

    def __iter__(self):
        idxs = []
        s = 0
        # select at most max_clips_per_video for each video, randomly
        for c in self.video_clips.clips:
            length = len(c)
            size = min(length, self.max_clips_per_video)
            sampled = torch.randperm(length)[:size] + s
            s += length
            idxs.append(sampled)
        idxs = torch.cat(idxs)
        # shuffle all clips randomly
        perm = torch.randperm(len(idxs))
        idxs = idxs[perm].tolist()
        return iter(idxs)

    def __len__(self):
        return sum(min(len(c), self.max_clips_per_video) for c in self.video_clips.clips)