"src/vscode:/vscode.git/clone" did not exist on "2c59af7222990a5d1cbf745acd01ceeb7eb80196"
Unverified Commit 85ffd93c authored by Francisco Massa's avatar Francisco Massa Committed by GitHub
Browse files

Expose frame-rate and cache to video datasets (#1356)

parent 31fad34f
......@@ -50,7 +50,8 @@ class HMDB51(VisionDataset):
}
def __init__(self, root, annotation_path, frames_per_clip, step_between_clips=1,
fold=1, train=True, transform=None):
frame_rate=None, fold=1, train=True, transform=None,
_precomputed_metadata=None):
super(HMDB51, self).__init__(root)
if not 1 <= fold <= 3:
raise ValueError("fold should be between 1 and 3, got {}".format(fold))
......@@ -64,7 +65,13 @@ class HMDB51(VisionDataset):
self.samples = make_dataset(self.root, class_to_idx, extensions, is_valid_file=None)
self.classes = classes
video_list = [x[0] for x in self.samples]
video_clips = VideoClips(video_list, frames_per_clip, step_between_clips)
video_clips = VideoClips(
video_list,
frames_per_clip,
step_between_clips,
frame_rate,
_precomputed_metadata,
)
self.indices = self._select_fold(video_list, annotation_path, fold, train)
self.video_clips = video_clips.subset(self.indices)
self.transform = transform
......
......@@ -36,7 +36,8 @@ class Kinetics400(VisionDataset):
label (int): class of the video clip
"""
def __init__(self, root, frames_per_clip, step_between_clips=1, transform=None):
def __init__(self, root, frames_per_clip, step_between_clips=1, frame_rate=None,
extensions=('avi',), transform=None, _precomputed_metadata=None):
super(Kinetics400, self).__init__(root)
extensions = ('avi',)
......@@ -45,7 +46,13 @@ class Kinetics400(VisionDataset):
self.samples = make_dataset(self.root, class_to_idx, extensions, is_valid_file=None)
self.classes = classes
video_list = [x[0] for x in self.samples]
self.video_clips = VideoClips(video_list, frames_per_clip, step_between_clips)
self.video_clips = VideoClips(
video_list,
frames_per_clip,
step_between_clips,
frame_rate,
_precomputed_metadata,
)
self.transform = transform
def __len__(self):
......
......@@ -43,7 +43,8 @@ class UCF101(VisionDataset):
"""
def __init__(self, root, annotation_path, frames_per_clip, step_between_clips=1,
fold=1, train=True, transform=None):
frame_rate=None, fold=1, train=True, transform=None,
_precomputed_metadata=None):
super(UCF101, self).__init__(root)
if not 1 <= fold <= 3:
raise ValueError("fold should be between 1 and 3, got {}".format(fold))
......@@ -57,7 +58,13 @@ class UCF101(VisionDataset):
self.samples = make_dataset(self.root, class_to_idx, extensions, is_valid_file=None)
self.classes = classes
video_list = [x[0] for x in self.samples]
video_clips = VideoClips(video_list, frames_per_clip, step_between_clips)
video_clips = VideoClips(
video_list,
frames_per_clip,
step_between_clips,
frame_rate,
_precomputed_metadata,
)
self.indices = self._select_fold(video_list, annotation_path, fold, train)
self.video_clips = video_clips.subset(self.indices)
self.transform = transform
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment