Unverified commit de96b977, authored by F-G Fernandez, committed by GitHub
Browse files

Added missing typing annotations in datasets/hmdb51 (#4169)



* style: Added missing typing annotations

* style: Fixed last missing typing annotation

* chore: Fixed missing import

* style: Fixed typing

* refactor: Switched file selection from list to set
Co-authored-by: Philip Meier <github.pmeier@posteo.de>
Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>
parent b29ed34f
import glob
import os
from typing import Optional, Callable, Tuple, Dict, Any, List
from torch import Tensor
from .folder import find_classes, make_dataset
from .video_utils import VideoClips
......@@ -52,10 +54,23 @@ class HMDB51(VisionDataset):
TRAIN_TAG = 1
TEST_TAG = 2
def __init__(self, root, annotation_path, frames_per_clip, step_between_clips=1,
frame_rate=None, fold=1, train=True, transform=None,
_precomputed_metadata=None, num_workers=1, _video_width=0,
_video_height=0, _video_min_dimension=0, _audio_samples=0):
def __init__(
self,
root: str,
annotation_path: str,
frames_per_clip: int,
step_between_clips: int = 1,
frame_rate: Optional[int] = None,
fold: int = 1,
train: bool = True,
transform: Optional[Callable] = None,
_precomputed_metadata: Optional[Dict[str, Any]] = None,
num_workers: int = 1,
_video_width: int = 0,
_video_height: int = 0,
_video_min_dimension: int = 0,
_audio_samples: int = 0,
) -> None:
super(HMDB51, self).__init__(root)
if fold not in (1, 2, 3):
raise ValueError("fold should be between 1 and 3, got {}".format(fold))
......@@ -92,15 +107,15 @@ class HMDB51(VisionDataset):
self.transform = transform
@property
def metadata(self):
def metadata(self) -> Dict[str, Any]:
return self.full_video_clips.metadata
def _select_fold(self, video_list, annotations_dir, fold, train):
def _select_fold(self, video_list: List[str], annotations_dir: str, fold: int, train: bool) -> List[int]:
target_tag = self.TRAIN_TAG if train else self.TEST_TAG
split_pattern_name = "*test_split{}.txt".format(fold)
split_pattern_path = os.path.join(annotations_dir, split_pattern_name)
annotation_paths = glob.glob(split_pattern_path)
selected_files = []
selected_files = set()
for filepath in annotation_paths:
with open(filepath) as fid:
lines = fid.readlines()
......@@ -108,8 +123,7 @@ class HMDB51(VisionDataset):
video_filename, tag_string = line.split()
tag = int(tag_string)
if tag == target_tag:
selected_files.append(video_filename)
selected_files = set(selected_files)
selected_files.add(video_filename)
indices = []
for video_index, video_path in enumerate(video_list):
......@@ -118,10 +132,10 @@ class HMDB51(VisionDataset):
return indices
def __len__(self):
def __len__(self) -> int:
return self.video_clips.num_clips()
def __getitem__(self, idx):
def __getitem__(self, idx: int) -> Tuple[Tensor, Tensor, int]:
video, audio, _, video_idx = self.video_clips.get_clip(idx)
sample_index = self.indices[video_idx]
_, class_index = self.samples[sample_index]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.