Unverified Commit 642ad750 authored by F-G Fernandez's avatar F-G Fernandez Committed by GitHub
Browse files

Added missing typing annotations in datasets/ucf101 (#4171)

parent f96a8a00
import os
from typing import Any, Dict, List, Tuple, Optional, Callable
from torch import Tensor
from .folder import find_classes, make_dataset
from .video_utils import VideoClips
......@@ -42,10 +44,23 @@ class UCF101(VisionDataset):
- label (int): class of the video clip
"""
def __init__(self, root, annotation_path, frames_per_clip, step_between_clips=1,
frame_rate=None, fold=1, train=True, transform=None,
_precomputed_metadata=None, num_workers=1, _video_width=0,
_video_height=0, _video_min_dimension=0, _audio_samples=0):
def __init__(
self,
root: str,
annotation_path: str,
frames_per_clip: int,
step_between_clips: int = 1,
frame_rate: Optional[int] = None,
fold: int = 1,
train: bool = True,
transform: Optional[Callable] = None,
_precomputed_metadata: Optional[Dict[str, Any]] = None,
num_workers: int = 1,
_video_width: int = 0,
_video_height: int = 0,
_video_min_dimension: int = 0,
_audio_samples: int = 0
) -> None:
super(UCF101, self).__init__(root)
if not 1 <= fold <= 3:
raise ValueError("fold should be between 1 and 3, got {}".format(fold))
......@@ -78,27 +93,26 @@ class UCF101(VisionDataset):
self.transform = transform
@property
def metadata(self) -> Dict[str, Any]:
    """Metadata dict of the underlying full (un-split) video clips.

    Delegates to ``self.full_video_clips.metadata`` (see ``VideoClips``);
    presumably contains per-video paths/timestamps — confirm against
    ``VideoClips`` for the exact keys.
    """
    # NOTE: diff artifact removed — the untyped duplicate signature is dropped,
    # keeping only the annotated final version.
    return self.full_video_clips.metadata
def _select_fold(self, video_list: List[str], annotation_path: str, fold: int, train: bool) -> List[int]:
    """Return the indices into ``video_list`` of videos belonging to a split/fold.

    The official annotation file ``{train|test}list{fold:02d}.txt`` lists one
    relative video path per line (train files carry a trailing label after a
    space, which is discarded here by taking the first token).

    Args:
        video_list: absolute paths of all discovered videos.
        annotation_path: directory containing the official split files.
        fold: split number (expected 1-3; validated in ``__init__``).
        train: select the train split if True, otherwise the test split.

    Returns:
        Ascending list of indices of ``video_list`` entries present in the split.
    """
    # NOTE: diff artifact removed — only the final set-based implementation is
    # kept; the superseded list-then-set variant is dropped.
    name = "train" if train else "test"
    name = "{}list{:02d}.txt".format(name, fold)
    f = os.path.join(annotation_path, name)
    selected_files = set()  # set: O(1) membership test in the filter below
    with open(f, "r") as fid:
        data = fid.readlines()
        data = [x.strip().split(" ")[0] for x in data]
        data = [os.path.join(self.root, x) for x in data]
        selected_files.update(data)
    indices = [i for i in range(len(video_list)) if video_list[i] in selected_files]
    return indices
def __len__(self) -> int:
    """Return the number of clips (one dataset item per clip, not per video)."""
    # NOTE: diff artifact removed — the untyped duplicate signature is dropped,
    # keeping only the annotated final version.
    return self.video_clips.num_clips()
def __getitem__(self, idx):
def __getitem__(self, idx: int) -> Tuple[Tensor, Tensor, int]:
video, audio, info, video_idx = self.video_clips.get_clip(idx)
label = self.samples[self.indices[video_idx]][1]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment