"docs/git@developer.sourcefind.cn:OpenDAS/pytorch3d.git" did not exist on "6c053a6cea38ad323047d92da99d38e740f19845"
Unverified Commit de96b977 authored by F-G Fernandez's avatar F-G Fernandez Committed by GitHub
Browse files

Added missing typing annotations in datasets/hmdb51 (#4169)



* style: Added missing typing annotations

* style: Fixed last missing typing annotation

* chore: Fixed missing import

* style: Fixed typing

* refactor: Switched file selection from list to set
Co-authored-by: default avatarPhilip Meier <github.pmeier@posteo.de>
Co-authored-by: default avatarVasilis Vryniotis <datumbox@users.noreply.github.com>
parent b29ed34f
import glob import glob
import os import os
from typing import Optional, Callable, Tuple, Dict, Any, List
from torch import Tensor
from .folder import find_classes, make_dataset from .folder import find_classes, make_dataset
from .video_utils import VideoClips from .video_utils import VideoClips
...@@ -52,10 +54,23 @@ class HMDB51(VisionDataset): ...@@ -52,10 +54,23 @@ class HMDB51(VisionDataset):
TRAIN_TAG = 1 TRAIN_TAG = 1
TEST_TAG = 2 TEST_TAG = 2
def __init__(self, root, annotation_path, frames_per_clip, step_between_clips=1, def __init__(
frame_rate=None, fold=1, train=True, transform=None, self,
_precomputed_metadata=None, num_workers=1, _video_width=0, root: str,
_video_height=0, _video_min_dimension=0, _audio_samples=0): annotation_path: str,
frames_per_clip: int,
step_between_clips: int = 1,
frame_rate: Optional[int] = None,
fold: int = 1,
train: bool = True,
transform: Optional[Callable] = None,
_precomputed_metadata: Optional[Dict[str, Any]] = None,
num_workers: int = 1,
_video_width: int = 0,
_video_height: int = 0,
_video_min_dimension: int = 0,
_audio_samples: int = 0,
) -> None:
super(HMDB51, self).__init__(root) super(HMDB51, self).__init__(root)
if fold not in (1, 2, 3): if fold not in (1, 2, 3):
raise ValueError("fold should be between 1 and 3, got {}".format(fold)) raise ValueError("fold should be between 1 and 3, got {}".format(fold))
...@@ -92,15 +107,15 @@ class HMDB51(VisionDataset): ...@@ -92,15 +107,15 @@ class HMDB51(VisionDataset):
self.transform = transform self.transform = transform
@property @property
def metadata(self): def metadata(self) -> Dict[str, Any]:
return self.full_video_clips.metadata return self.full_video_clips.metadata
def _select_fold(self, video_list, annotations_dir, fold, train): def _select_fold(self, video_list: List[str], annotations_dir: str, fold: int, train: bool) -> List[int]:
target_tag = self.TRAIN_TAG if train else self.TEST_TAG target_tag = self.TRAIN_TAG if train else self.TEST_TAG
split_pattern_name = "*test_split{}.txt".format(fold) split_pattern_name = "*test_split{}.txt".format(fold)
split_pattern_path = os.path.join(annotations_dir, split_pattern_name) split_pattern_path = os.path.join(annotations_dir, split_pattern_name)
annotation_paths = glob.glob(split_pattern_path) annotation_paths = glob.glob(split_pattern_path)
selected_files = [] selected_files = set()
for filepath in annotation_paths: for filepath in annotation_paths:
with open(filepath) as fid: with open(filepath) as fid:
lines = fid.readlines() lines = fid.readlines()
...@@ -108,8 +123,7 @@ class HMDB51(VisionDataset): ...@@ -108,8 +123,7 @@ class HMDB51(VisionDataset):
video_filename, tag_string = line.split() video_filename, tag_string = line.split()
tag = int(tag_string) tag = int(tag_string)
if tag == target_tag: if tag == target_tag:
selected_files.append(video_filename) selected_files.add(video_filename)
selected_files = set(selected_files)
indices = [] indices = []
for video_index, video_path in enumerate(video_list): for video_index, video_path in enumerate(video_list):
...@@ -118,10 +132,10 @@ class HMDB51(VisionDataset): ...@@ -118,10 +132,10 @@ class HMDB51(VisionDataset):
return indices return indices
def __len__(self): def __len__(self) -> int:
return self.video_clips.num_clips() return self.video_clips.num_clips()
def __getitem__(self, idx): def __getitem__(self, idx: int) -> Tuple[Tensor, Tensor, int]:
video, audio, _, video_idx = self.video_clips.get_clip(idx) video, audio, _, video_idx = self.video_clips.get_clip(idx)
sample_index = self.indices[video_idx] sample_index = self.indices[video_idx]
_, class_index = self.samples[sample_index] _, class_index = self.samples[sample_index]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment