"src/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "bfe94a3993e069bf386c84e16a84ebbecdd7c5db"
Unverified Commit 85ffd93c authored by Francisco Massa's avatar Francisco Massa Committed by GitHub
Browse files

Expose frame-rate and cache to video datasets (#1356)

parent 31fad34f
...@@ -50,7 +50,8 @@ class HMDB51(VisionDataset): ...@@ -50,7 +50,8 @@ class HMDB51(VisionDataset):
} }
def __init__(self, root, annotation_path, frames_per_clip, step_between_clips=1, def __init__(self, root, annotation_path, frames_per_clip, step_between_clips=1,
fold=1, train=True, transform=None): frame_rate=None, fold=1, train=True, transform=None,
_precomputed_metadata=None):
super(HMDB51, self).__init__(root) super(HMDB51, self).__init__(root)
if not 1 <= fold <= 3: if not 1 <= fold <= 3:
raise ValueError("fold should be between 1 and 3, got {}".format(fold)) raise ValueError("fold should be between 1 and 3, got {}".format(fold))
...@@ -64,7 +65,13 @@ class HMDB51(VisionDataset): ...@@ -64,7 +65,13 @@ class HMDB51(VisionDataset):
self.samples = make_dataset(self.root, class_to_idx, extensions, is_valid_file=None) self.samples = make_dataset(self.root, class_to_idx, extensions, is_valid_file=None)
self.classes = classes self.classes = classes
video_list = [x[0] for x in self.samples] video_list = [x[0] for x in self.samples]
video_clips = VideoClips(video_list, frames_per_clip, step_between_clips) video_clips = VideoClips(
video_list,
frames_per_clip,
step_between_clips,
frame_rate,
_precomputed_metadata,
)
self.indices = self._select_fold(video_list, annotation_path, fold, train) self.indices = self._select_fold(video_list, annotation_path, fold, train)
self.video_clips = video_clips.subset(self.indices) self.video_clips = video_clips.subset(self.indices)
self.transform = transform self.transform = transform
......
...@@ -36,7 +36,8 @@ class Kinetics400(VisionDataset): ...@@ -36,7 +36,8 @@ class Kinetics400(VisionDataset):
label (int): class of the video clip label (int): class of the video clip
""" """
def __init__(self, root, frames_per_clip, step_between_clips=1, transform=None): def __init__(self, root, frames_per_clip, step_between_clips=1, frame_rate=None,
extensions=('avi',), transform=None, _precomputed_metadata=None):
super(Kinetics400, self).__init__(root) super(Kinetics400, self).__init__(root)
extensions = ('avi',) extensions = ('avi',)
...@@ -45,7 +46,13 @@ class Kinetics400(VisionDataset): ...@@ -45,7 +46,13 @@ class Kinetics400(VisionDataset):
self.samples = make_dataset(self.root, class_to_idx, extensions, is_valid_file=None) self.samples = make_dataset(self.root, class_to_idx, extensions, is_valid_file=None)
self.classes = classes self.classes = classes
video_list = [x[0] for x in self.samples] video_list = [x[0] for x in self.samples]
self.video_clips = VideoClips(video_list, frames_per_clip, step_between_clips) self.video_clips = VideoClips(
video_list,
frames_per_clip,
step_between_clips,
frame_rate,
_precomputed_metadata,
)
self.transform = transform self.transform = transform
def __len__(self): def __len__(self):
......
...@@ -43,7 +43,8 @@ class UCF101(VisionDataset): ...@@ -43,7 +43,8 @@ class UCF101(VisionDataset):
""" """
def __init__(self, root, annotation_path, frames_per_clip, step_between_clips=1, def __init__(self, root, annotation_path, frames_per_clip, step_between_clips=1,
fold=1, train=True, transform=None): frame_rate=None, fold=1, train=True, transform=None,
_precomputed_metadata=None):
super(UCF101, self).__init__(root) super(UCF101, self).__init__(root)
if not 1 <= fold <= 3: if not 1 <= fold <= 3:
raise ValueError("fold should be between 1 and 3, got {}".format(fold)) raise ValueError("fold should be between 1 and 3, got {}".format(fold))
...@@ -57,7 +58,13 @@ class UCF101(VisionDataset): ...@@ -57,7 +58,13 @@ class UCF101(VisionDataset):
self.samples = make_dataset(self.root, class_to_idx, extensions, is_valid_file=None) self.samples = make_dataset(self.root, class_to_idx, extensions, is_valid_file=None)
self.classes = classes self.classes = classes
video_list = [x[0] for x in self.samples] video_list = [x[0] for x in self.samples]
video_clips = VideoClips(video_list, frames_per_clip, step_between_clips) video_clips = VideoClips(
video_list,
frames_per_clip,
step_between_clips,
frame_rate,
_precomputed_metadata,
)
self.indices = self._select_fold(video_list, annotation_path, fold, train) self.indices = self._select_fold(video_list, annotation_path, fold, train)
self.video_clips = video_clips.subset(self.indices) self.video_clips = video_clips.subset(self.indices)
self.transform = transform self.transform = transform
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment