Unverified Commit 3a159dfe authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

add typehints for torchvision.datasets.stl10 (#2540)

* add typehints for torchvision.datasets.stl10

* move annotation from class to instance scope
parent 7fc47eaa
...@@ -2,6 +2,7 @@ from PIL import Image ...@@ -2,6 +2,7 @@ from PIL import Image
import os import os
import os.path import os.path
import numpy as np import numpy as np
from typing import Any, Callable, Optional, Tuple
from .vision import VisionDataset from .vision import VisionDataset
from .utils import check_integrity, download_and_extract_archive, verify_str_arg from .utils import check_integrity, download_and_extract_archive, verify_str_arg
...@@ -45,8 +46,15 @@ class STL10(VisionDataset): ...@@ -45,8 +46,15 @@ class STL10(VisionDataset):
] ]
splits = ('train', 'train+unlabeled', 'unlabeled', 'test') splits = ('train', 'train+unlabeled', 'unlabeled', 'test')
def __init__(self, root, split='train', folds=None, transform=None, def __init__(
target_transform=None, download=False): self,
root: str,
split: str = "train",
folds: Optional[int] = None,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
download: bool = False,
) -> None:
super(STL10, self).__init__(root, transform=transform, super(STL10, self).__init__(root, transform=transform,
target_transform=target_transform) target_transform=target_transform)
self.split = verify_str_arg(split, "split", self.splits) self.split = verify_str_arg(split, "split", self.splits)
...@@ -60,6 +68,7 @@ class STL10(VisionDataset): ...@@ -60,6 +68,7 @@ class STL10(VisionDataset):
'You can use download=True to download it') 'You can use download=True to download it')
# now load the picked numpy arrays # now load the picked numpy arrays
self.labels: np.ndarray
if self.split == 'train': if self.split == 'train':
self.data, self.labels = self.__loadfile( self.data, self.labels = self.__loadfile(
self.train_list[0][0], self.train_list[1][0]) self.train_list[0][0], self.train_list[1][0])
...@@ -87,7 +96,7 @@ class STL10(VisionDataset): ...@@ -87,7 +96,7 @@ class STL10(VisionDataset):
with open(class_file) as f: with open(class_file) as f:
self.classes = f.read().splitlines() self.classes = f.read().splitlines()
def _verify_folds(self, folds): def _verify_folds(self, folds: Optional[int]) -> Optional[int]:
if folds is None: if folds is None:
return folds return folds
elif isinstance(folds, int): elif isinstance(folds, int):
...@@ -100,7 +109,7 @@ class STL10(VisionDataset): ...@@ -100,7 +109,7 @@ class STL10(VisionDataset):
msg = "Expected type None or int for argument folds, but got type {}." msg = "Expected type None or int for argument folds, but got type {}."
raise ValueError(msg.format(type(folds))) raise ValueError(msg.format(type(folds)))
def __getitem__(self, index): def __getitem__(self, index: int) -> Tuple[Any, Any]:
""" """
Args: Args:
index (int): Index index (int): Index
...@@ -108,6 +117,7 @@ class STL10(VisionDataset): ...@@ -108,6 +117,7 @@ class STL10(VisionDataset):
Returns: Returns:
tuple: (image, target) where target is index of the target class. tuple: (image, target) where target is index of the target class.
""" """
target: Optional[int]
if self.labels is not None: if self.labels is not None:
img, target = self.data[index], int(self.labels[index]) img, target = self.data[index], int(self.labels[index])
else: else:
...@@ -125,10 +135,10 @@ class STL10(VisionDataset): ...@@ -125,10 +135,10 @@ class STL10(VisionDataset):
return img, target return img, target
def __len__(self): def __len__(self) -> int:
return self.data.shape[0] return self.data.shape[0]
def __loadfile(self, data_file, labels_file=None): def __loadfile(self, data_file: str, labels_file: Optional[str] = None) -> Tuple[np.ndarray, Optional[np.ndarray]]:
labels = None labels = None
if labels_file: if labels_file:
path_to_labels = os.path.join( path_to_labels = os.path.join(
...@@ -145,7 +155,7 @@ class STL10(VisionDataset): ...@@ -145,7 +155,7 @@ class STL10(VisionDataset):
return images, labels return images, labels
def _check_integrity(self): def _check_integrity(self) -> bool:
root = self.root root = self.root
for fentry in (self.train_list + self.test_list): for fentry in (self.train_list + self.test_list):
filename, md5 = fentry[0], fentry[1] filename, md5 = fentry[0], fentry[1]
...@@ -154,17 +164,17 @@ class STL10(VisionDataset): ...@@ -154,17 +164,17 @@ class STL10(VisionDataset):
return False return False
return True return True
def download(self): def download(self) -> None:
if self._check_integrity(): if self._check_integrity():
print('Files already downloaded and verified') print('Files already downloaded and verified')
return return
download_and_extract_archive(self.url, self.root, filename=self.filename, md5=self.tgz_md5) download_and_extract_archive(self.url, self.root, filename=self.filename, md5=self.tgz_md5)
self._check_integrity() self._check_integrity()
def extra_repr(self): def extra_repr(self) -> str:
return "Split: {split}".format(**self.__dict__) return "Split: {split}".format(**self.__dict__)
def __load_folds(self, folds): def __load_folds(self, folds: Optional[int]) -> None:
# loads one of the folds if specified # loads one of the folds if specified
if folds is None: if folds is None:
return return
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment