Unverified Commit 3a159dfe authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

add typehints for torchvision.datasets.stl10 (#2540)

* add typehints for torchvision.datasets.stl10

* move annotation from class to instance scope
parent 7fc47eaa
......@@ -2,6 +2,7 @@ from PIL import Image
import os
import os.path
import numpy as np
from typing import Any, Callable, Optional, Tuple
from .vision import VisionDataset
from .utils import check_integrity, download_and_extract_archive, verify_str_arg
......@@ -45,8 +46,15 @@ class STL10(VisionDataset):
]
splits = ('train', 'train+unlabeled', 'unlabeled', 'test')
def __init__(self, root, split='train', folds=None, transform=None,
target_transform=None, download=False):
def __init__(
self,
root: str,
split: str = "train",
folds: Optional[int] = None,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
download: bool = False,
) -> None:
super(STL10, self).__init__(root, transform=transform,
target_transform=target_transform)
self.split = verify_str_arg(split, "split", self.splits)
......@@ -60,6 +68,7 @@ class STL10(VisionDataset):
'You can use download=True to download it')
# now load the picked numpy arrays
self.labels: np.ndarray
if self.split == 'train':
self.data, self.labels = self.__loadfile(
self.train_list[0][0], self.train_list[1][0])
......@@ -87,7 +96,7 @@ class STL10(VisionDataset):
with open(class_file) as f:
self.classes = f.read().splitlines()
def _verify_folds(self, folds):
def _verify_folds(self, folds: Optional[int]) -> Optional[int]:
if folds is None:
return folds
elif isinstance(folds, int):
......@@ -100,7 +109,7 @@ class STL10(VisionDataset):
msg = "Expected type None or int for argument folds, but got type {}."
raise ValueError(msg.format(type(folds)))
def __getitem__(self, index):
def __getitem__(self, index: int) -> Tuple[Any, Any]:
"""
Args:
index (int): Index
......@@ -108,6 +117,7 @@ class STL10(VisionDataset):
Returns:
tuple: (image, target) where target is index of the target class.
"""
target: Optional[int]
if self.labels is not None:
img, target = self.data[index], int(self.labels[index])
else:
......@@ -125,10 +135,10 @@ class STL10(VisionDataset):
return img, target
def __len__(self):
def __len__(self) -> int:
return self.data.shape[0]
def __loadfile(self, data_file, labels_file=None):
def __loadfile(self, data_file: str, labels_file: Optional[str] = None) -> Tuple[np.ndarray, Optional[np.ndarray]]:
labels = None
if labels_file:
path_to_labels = os.path.join(
......@@ -145,7 +155,7 @@ class STL10(VisionDataset):
return images, labels
def _check_integrity(self):
def _check_integrity(self) -> bool:
root = self.root
for fentry in (self.train_list + self.test_list):
filename, md5 = fentry[0], fentry[1]
......@@ -154,17 +164,17 @@ class STL10(VisionDataset):
return False
return True
def download(self):
def download(self) -> None:
if self._check_integrity():
print('Files already downloaded and verified')
return
download_and_extract_archive(self.url, self.root, filename=self.filename, md5=self.tgz_md5)
self._check_integrity()
def extra_repr(self):
def extra_repr(self) -> str:
return "Split: {split}".format(**self.__dict__)
def __load_folds(self, folds):
def __load_folds(self, folds: Optional[int]) -> None:
# loads one of the folds if specified
if folds is None:
return
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment