"git@developer.sourcefind.cn:OpenDAS/mmcv.git" did not exist on "a0939977a3b3c34c925c565c3fd3dcbe5d09e23c"
Unverified Commit 94c94170 authored by Francisco Massa's avatar Francisco Massa Committed by GitHub
Browse files

Move RandomClipSampler to references (#1186)

* Move RandomClipSampler to references

* Lint and bugfix
parent fe4d17fc
...@@ -87,3 +87,38 @@ class UniformClipSampler(torch.utils.data.Sampler): ...@@ -87,3 +87,38 @@ class UniformClipSampler(torch.utils.data.Sampler):
def __len__(self): def __len__(self):
return sum(min(len(c), self.max_clips_per_video) for c in self.video_clips.clips) return sum(min(len(c), self.max_clips_per_video) for c in self.video_clips.clips)
class RandomClipSampler(torch.utils.data.Sampler):
    """
    Samples at most `max_clips_per_video` clips for each video randomly

    Arguments:
        video_clips (VideoClips): video clips to sample from
        max_clips_per_video (int): maximum number of clips to be sampled per video
    """
    def __init__(self, video_clips, max_clips_per_video):
        if not isinstance(video_clips, torchvision.datasets.video_utils.VideoClips):
            raise TypeError("Expected video_clips to be an instance of VideoClips, "
                            "got {}".format(type(video_clips)))
        self.video_clips = video_clips
        self.max_clips_per_video = max_clips_per_video

    def __iter__(self):
        """Yield global clip indices: at most `max_clips_per_video` per video,
        chosen uniformly without replacement, then shuffled across videos."""
        idxs = []
        s = 0
        # select at most max_clips_per_video for each video, randomly
        for c in self.video_clips.clips:
            length = len(c)
            size = min(length, self.max_clips_per_video)
            # offset by s so indices refer to the flattened clip list
            sampled = torch.randperm(length)[:size] + s
            s += length
            idxs.append(sampled)
        # torch.cat raises on an empty list, so handle the no-videos
        # case explicitly instead of crashing
        if not idxs:
            return iter([])
        idxs = torch.cat(idxs)
        # shuffle all clips randomly
        perm = torch.randperm(len(idxs))
        idxs = idxs[perm].tolist()
        return iter(idxs)

    def __len__(self):
        """Total number of clips that will be sampled in one pass."""
        return sum(min(len(c), self.max_clips_per_video) for c in self.video_clips.clips)
...@@ -13,7 +13,7 @@ import torchvision.datasets.video_utils ...@@ -13,7 +13,7 @@ import torchvision.datasets.video_utils
from torchvision import transforms from torchvision import transforms
import utils import utils
from sampler import DistributedSampler, UniformClipSampler from sampler import DistributedSampler, UniformClipSampler, RandomClipSampler
from scheduler import WarmupMultiStepLR from scheduler import WarmupMultiStepLR
import transforms as T import transforms as T
...@@ -184,7 +184,7 @@ def main(args): ...@@ -184,7 +184,7 @@ def main(args):
dataset_test.video_clips.compute_clips(args.clip_len, 1, frame_rate=15) dataset_test.video_clips.compute_clips(args.clip_len, 1, frame_rate=15)
print("Creating data loaders") print("Creating data loaders")
train_sampler = torchvision.datasets.video_utils.RandomClipSampler(dataset.video_clips, args.clips_per_video) train_sampler = RandomClipSampler(dataset.video_clips, args.clips_per_video)
test_sampler = UniformClipSampler(dataset_test.video_clips, args.clips_per_video) test_sampler = UniformClipSampler(dataset_test.video_clips, args.clips_per_video)
if args.distributed: if args.distributed:
train_sampler = DistributedSampler(train_sampler) train_sampler = DistributedSampler(train_sampler)
......
...@@ -4,7 +4,7 @@ import torch ...@@ -4,7 +4,7 @@ import torch
import unittest import unittest
from torchvision import io from torchvision import io
from torchvision.datasets.video_utils import VideoClips, unfold, RandomClipSampler from torchvision.datasets.video_utils import VideoClips, unfold
from common_utils import get_tmp_dir from common_utils import get_tmp_dir
...@@ -80,10 +80,11 @@ class Tester(unittest.TestCase): ...@@ -80,10 +80,11 @@ class Tester(unittest.TestCase):
self.assertEqual(video_idx, v_idx) self.assertEqual(video_idx, v_idx)
self.assertEqual(clip_idx, c_idx) self.assertEqual(clip_idx, c_idx)
@unittest.skip("Moved to reference scripts for now")
def test_video_sampler(self): def test_video_sampler(self):
with get_list_of_videos(num_videos=3, sizes=[25, 25, 25]) as video_list: with get_list_of_videos(num_videos=3, sizes=[25, 25, 25]) as video_list:
video_clips = VideoClips(video_list, 5, 5) video_clips = VideoClips(video_list, 5, 5)
sampler = RandomClipSampler(video_clips, 3) sampler = RandomClipSampler(video_clips, 3) # noqa: F821
self.assertEqual(len(sampler), 3 * 3) self.assertEqual(len(sampler), 3 * 3)
indices = torch.tensor(list(iter(sampler))) indices = torch.tensor(list(iter(sampler)))
videos = indices // 5 videos = indices // 5
...@@ -91,10 +92,11 @@ class Tester(unittest.TestCase): ...@@ -91,10 +92,11 @@ class Tester(unittest.TestCase):
self.assertTrue(v_idxs.equal(torch.tensor([0, 1, 2]))) self.assertTrue(v_idxs.equal(torch.tensor([0, 1, 2])))
self.assertTrue(count.equal(torch.tensor([3, 3, 3]))) self.assertTrue(count.equal(torch.tensor([3, 3, 3])))
@unittest.skip("Moved to reference scripts for now")
def test_video_sampler_unequal(self): def test_video_sampler_unequal(self):
with get_list_of_videos(num_videos=3, sizes=[10, 25, 25]) as video_list: with get_list_of_videos(num_videos=3, sizes=[10, 25, 25]) as video_list:
video_clips = VideoClips(video_list, 5, 5) video_clips = VideoClips(video_list, 5, 5)
sampler = RandomClipSampler(video_clips, 3) sampler = RandomClipSampler(video_clips, 3) # noqa: F821
self.assertEqual(len(sampler), 2 + 3 + 3) self.assertEqual(len(sampler), 2 + 3 + 3)
indices = list(iter(sampler)) indices = list(iter(sampler))
self.assertIn(0, indices) self.assertIn(0, indices)
......
import bisect import bisect
import math import math
import torch import torch
import torch.utils.data
from torchvision.io import read_video_timestamps, read_video from torchvision.io import read_video_timestamps, read_video
from .utils import tqdm from .utils import tqdm
...@@ -214,38 +213,3 @@ class VideoClips(object): ...@@ -214,38 +213,3 @@ class VideoClips(object):
info["video_fps"] = self.frame_rate info["video_fps"] = self.frame_rate
assert len(video) == self.num_frames, "{} x {}".format(video.shape, self.num_frames) assert len(video) == self.num_frames, "{} x {}".format(video.shape, self.num_frames)
return video, audio, info, video_idx return video, audio, info, video_idx
class RandomClipSampler(torch.utils.data.Sampler):
    """Randomly sample up to ``max_clips_per_video`` clips from every video.

    Arguments:
        video_clips (VideoClips): video clips to sample from
        max_clips_per_video (int): maximum number of clips to be sampled per video
    """

    def __init__(self, video_clips, max_clips_per_video):
        if not isinstance(video_clips, VideoClips):
            raise TypeError("Expected video_clips to be an instance of VideoClips, "
                            "got {}".format(type(video_clips)))
        self.video_clips = video_clips
        self.max_clips_per_video = max_clips_per_video

    def __iter__(self):
        # For each video, draw a random subset of its clips and translate the
        # per-video positions into the flattened (global) clip index space.
        per_video = []
        offset = 0
        for clips in self.video_clips.clips:
            n_clips = len(clips)
            n_keep = min(n_clips, self.max_clips_per_video)
            per_video.append(torch.randperm(n_clips)[:n_keep] + offset)
            offset += n_clips
        flat = torch.cat(per_video)
        # One final shuffle so clips from different videos are interleaved.
        shuffled = flat[torch.randperm(len(flat))]
        return iter(shuffled.tolist())

    def __len__(self):
        # Each video contributes its clip count, capped at max_clips_per_video.
        capped = (min(len(c), self.max_clips_per_video) for c in self.video_clips.clips)
        return sum(capped)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment