Commit 20844b5b authored by JMistele's avatar JMistele Committed by Francisco Massa
Browse files

Fixed video labelling after `subset` call for HMDB51 dataset (hmdb51.py)...


Fixed video labelling after `subset` call for HMDB51 dataset (hmdb51.py) (EDIT: UCF101 as well) (#1240)

* Fixed video labelling after subset for HMDB51 dataset

* Fixed video labelling after subset for HMDB51 dataset
Co-authored-by: default avatarEric Tang <etang21@stanford.edu>
Co-authored-by: default avatarRyan Cao <ryancao@stanford.edu>

* UCF 101 Labeling fixes

- Analogous fix to HMDB51 to maintain correct labels after the train-test split
- Additional change to the `select_fold` method in `ucf101.py` to correctly reflect the annotation format
Co-authored-by: default avatarRyan Cao <ryancao@stanford.edu>
Co-authored-by: default avatarEric Tang <etang21@stanford.edu>
parent 0bd7080c
...@@ -65,8 +65,8 @@ class HMDB51(VisionDataset): ...@@ -65,8 +65,8 @@ class HMDB51(VisionDataset):
self.classes = classes self.classes = classes
video_list = [x[0] for x in self.samples] video_list = [x[0] for x in self.samples]
video_clips = VideoClips(video_list, frames_per_clip, step_between_clips) video_clips = VideoClips(video_list, frames_per_clip, step_between_clips)
indices = self._select_fold(video_list, annotation_path, fold, train) self.indices = self._select_fold(video_list, annotation_path, fold, train)
self.video_clips = video_clips.subset(indices) self.video_clips = video_clips.subset(self.indices)
self.transform = transform self.transform = transform
def _select_fold(self, video_list, annotation_path, fold, train): def _select_fold(self, video_list, annotation_path, fold, train):
...@@ -89,7 +89,7 @@ class HMDB51(VisionDataset): ...@@ -89,7 +89,7 @@ class HMDB51(VisionDataset):
def __getitem__(self, idx): def __getitem__(self, idx):
video, audio, info, video_idx = self.video_clips.get_clip(idx) video, audio, info, video_idx = self.video_clips.get_clip(idx)
label = self.samples[video_idx][1] label = self.samples[self.indices[video_idx]][1]
if self.transform is not None: if self.transform is not None:
video = self.transform(video) video = self.transform(video)
......
...@@ -58,8 +58,8 @@ class UCF101(VisionDataset): ...@@ -58,8 +58,8 @@ class UCF101(VisionDataset):
self.classes = classes self.classes = classes
video_list = [x[0] for x in self.samples] video_list = [x[0] for x in self.samples]
video_clips = VideoClips(video_list, frames_per_clip, step_between_clips) video_clips = VideoClips(video_list, frames_per_clip, step_between_clips)
indices = self._select_fold(video_list, annotation_path, fold, train) self.indices = self._select_fold(video_list, annotation_path, fold, train)
self.video_clips = video_clips.subset(indices) self.video_clips = video_clips.subset(self.indices)
self.transform = transform self.transform = transform
def _select_fold(self, video_list, annotation_path, fold, train): def _select_fold(self, video_list, annotation_path, fold, train):
...@@ -81,7 +81,7 @@ class UCF101(VisionDataset): ...@@ -81,7 +81,7 @@ class UCF101(VisionDataset):
def __getitem__(self, idx): def __getitem__(self, idx):
video, audio, info, video_idx = self.video_clips.get_clip(idx) video, audio, info, video_idx = self.video_clips.get_clip(idx)
label = self.samples[video_idx][1] label = self.samples[self.indices[video_idx]][1]
if self.transform is not None: if self.transform is not None:
video = self.transform(video) video = self.transform(video)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment