Unverified Commit 52b80c48 authored by F-G Fernandez's avatar F-G Fernandez Committed by GitHub
Browse files

style: Added typing to datasets/lfw (#6844)


Co-authored-by: default avatarPhilip Meier <github.pmeier@posteo.de>
parent e0068d8e
import os import os
from typing import Any, Callable, List, Optional, Tuple from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from PIL import Image from PIL import Image
...@@ -38,7 +38,7 @@ class _LFW(VisionDataset): ...@@ -38,7 +38,7 @@ class _LFW(VisionDataset):
transform: Optional[Callable] = None, transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None, target_transform: Optional[Callable] = None,
download: bool = False, download: bool = False,
): ) -> None:
super().__init__(os.path.join(root, self.base_folder), transform=transform, target_transform=target_transform) super().__init__(os.path.join(root, self.base_folder), transform=transform, target_transform=target_transform)
self.image_set = verify_str_arg(image_set.lower(), "image_set", self.file_dict.keys()) self.image_set = verify_str_arg(image_set.lower(), "image_set", self.file_dict.keys())
...@@ -62,7 +62,7 @@ class _LFW(VisionDataset): ...@@ -62,7 +62,7 @@ class _LFW(VisionDataset):
img = Image.open(f) img = Image.open(f)
return img.convert("RGB") return img.convert("RGB")
def _check_integrity(self): def _check_integrity(self) -> bool:
st1 = check_integrity(os.path.join(self.root, self.filename), self.md5) st1 = check_integrity(os.path.join(self.root, self.filename), self.md5)
st2 = check_integrity(os.path.join(self.root, self.labels_file), self.checksums[self.labels_file]) st2 = check_integrity(os.path.join(self.root, self.labels_file), self.checksums[self.labels_file])
if not st1 or not st2: if not st1 or not st2:
...@@ -71,7 +71,7 @@ class _LFW(VisionDataset): ...@@ -71,7 +71,7 @@ class _LFW(VisionDataset):
return check_integrity(os.path.join(self.root, self.names), self.checksums[self.names]) return check_integrity(os.path.join(self.root, self.names), self.checksums[self.names])
return True return True
def download(self): def download(self) -> None:
if self._check_integrity(): if self._check_integrity():
print("Files already downloaded and verified") print("Files already downloaded and verified")
return return
...@@ -81,13 +81,13 @@ class _LFW(VisionDataset): ...@@ -81,13 +81,13 @@ class _LFW(VisionDataset):
if self.view == "people": if self.view == "people":
download_url(f"{self.download_url_prefix}{self.names}", self.root) download_url(f"{self.download_url_prefix}{self.names}", self.root)
def _get_path(self, identity, no): def _get_path(self, identity: str, no: Union[int, str]) -> str:
return os.path.join(self.images_dir, identity, f"{identity}_{int(no):04d}.jpg") return os.path.join(self.images_dir, identity, f"{identity}_{int(no):04d}.jpg")
def extra_repr(self) -> str: def extra_repr(self) -> str:
return f"Alignment: {self.image_set}\nSplit: {self.split}" return f"Alignment: {self.image_set}\nSplit: {self.split}"
def __len__(self): def __len__(self) -> int:
return len(self.data) return len(self.data)
...@@ -119,13 +119,13 @@ class LFWPeople(_LFW): ...@@ -119,13 +119,13 @@ class LFWPeople(_LFW):
transform: Optional[Callable] = None, transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None, target_transform: Optional[Callable] = None,
download: bool = False, download: bool = False,
): ) -> None:
super().__init__(root, split, image_set, "people", transform, target_transform, download) super().__init__(root, split, image_set, "people", transform, target_transform, download)
self.class_to_idx = self._get_classes() self.class_to_idx = self._get_classes()
self.data, self.targets = self._get_people() self.data, self.targets = self._get_people()
def _get_people(self): def _get_people(self) -> Tuple[List[str], List[int]]:
data, targets = [], [] data, targets = [], []
with open(os.path.join(self.root, self.labels_file)) as f: with open(os.path.join(self.root, self.labels_file)) as f:
lines = f.readlines() lines = f.readlines()
...@@ -143,7 +143,7 @@ class LFWPeople(_LFW): ...@@ -143,7 +143,7 @@ class LFWPeople(_LFW):
return data, targets return data, targets
def _get_classes(self): def _get_classes(self) -> Dict[str, int]:
with open(os.path.join(self.root, self.names)) as f: with open(os.path.join(self.root, self.names)) as f:
lines = f.readlines() lines = f.readlines()
names = [line.strip().split()[0] for line in lines] names = [line.strip().split()[0] for line in lines]
...@@ -201,12 +201,12 @@ class LFWPairs(_LFW): ...@@ -201,12 +201,12 @@ class LFWPairs(_LFW):
transform: Optional[Callable] = None, transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None, target_transform: Optional[Callable] = None,
download: bool = False, download: bool = False,
): ) -> None:
super().__init__(root, split, image_set, "pairs", transform, target_transform, download) super().__init__(root, split, image_set, "pairs", transform, target_transform, download)
self.pair_names, self.data, self.targets = self._get_pairs(self.images_dir) self.pair_names, self.data, self.targets = self._get_pairs(self.images_dir)
def _get_pairs(self, images_dir): def _get_pairs(self, images_dir: str) -> Tuple[List[Tuple[str, str]], List[Tuple[str, str]], List[int]]:
pair_names, data, targets = [], [], [] pair_names, data, targets = [], [], []
with open(os.path.join(self.root, self.labels_file)) as f: with open(os.path.join(self.root, self.labels_file)) as f:
lines = f.readlines() lines = f.readlines()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment