Unverified Commit 52b80c48 authored by F-G Fernandez's avatar F-G Fernandez Committed by GitHub
Browse files

style: Added typing to datasets/lfw (#6844)


Co-authored-by: default avatarPhilip Meier <github.pmeier@posteo.de>
parent e0068d8e
import os
from typing import Any, Callable, List, Optional, Tuple
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from PIL import Image
......@@ -38,7 +38,7 @@ class _LFW(VisionDataset):
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
download: bool = False,
):
) -> None:
super().__init__(os.path.join(root, self.base_folder), transform=transform, target_transform=target_transform)
self.image_set = verify_str_arg(image_set.lower(), "image_set", self.file_dict.keys())
......@@ -62,7 +62,7 @@ class _LFW(VisionDataset):
img = Image.open(f)
return img.convert("RGB")
def _check_integrity(self):
def _check_integrity(self) -> bool:
st1 = check_integrity(os.path.join(self.root, self.filename), self.md5)
st2 = check_integrity(os.path.join(self.root, self.labels_file), self.checksums[self.labels_file])
if not st1 or not st2:
......@@ -71,7 +71,7 @@ class _LFW(VisionDataset):
return check_integrity(os.path.join(self.root, self.names), self.checksums[self.names])
return True
def download(self):
def download(self) -> None:
if self._check_integrity():
print("Files already downloaded and verified")
return
......@@ -81,13 +81,13 @@ class _LFW(VisionDataset):
if self.view == "people":
download_url(f"{self.download_url_prefix}{self.names}", self.root)
def _get_path(self, identity, no):
def _get_path(self, identity: str, no: Union[int, str]) -> str:
return os.path.join(self.images_dir, identity, f"{identity}_{int(no):04d}.jpg")
def extra_repr(self) -> str:
return f"Alignment: {self.image_set}\nSplit: {self.split}"
def __len__(self):
def __len__(self) -> int:
return len(self.data)
......@@ -119,13 +119,13 @@ class LFWPeople(_LFW):
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
download: bool = False,
):
) -> None:
super().__init__(root, split, image_set, "people", transform, target_transform, download)
self.class_to_idx = self._get_classes()
self.data, self.targets = self._get_people()
def _get_people(self):
def _get_people(self) -> Tuple[List[str], List[int]]:
data, targets = [], []
with open(os.path.join(self.root, self.labels_file)) as f:
lines = f.readlines()
......@@ -143,7 +143,7 @@ class LFWPeople(_LFW):
return data, targets
def _get_classes(self):
def _get_classes(self) -> Dict[str, int]:
with open(os.path.join(self.root, self.names)) as f:
lines = f.readlines()
names = [line.strip().split()[0] for line in lines]
......@@ -201,12 +201,12 @@ class LFWPairs(_LFW):
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
download: bool = False,
):
) -> None:
super().__init__(root, split, image_set, "pairs", transform, target_transform, download)
self.pair_names, self.data, self.targets = self._get_pairs(self.images_dir)
def _get_pairs(self, images_dir):
def _get_pairs(self, images_dir: str) -> Tuple[List[Tuple[str, str]], List[Tuple[str, str]], List[int]]:
pair_names, data, targets = [], [], []
with open(os.path.join(self.root, self.labels_file)) as f:
lines = f.readlines()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment