Unverified Commit 1a6148d4 authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

add typehints for torchvision.datasets.svhn (#2539)

parent 7c1ed419
......@@ -3,6 +3,7 @@ from PIL import Image
import os
import os.path
import numpy as np
from typing import Any, Callable, Optional, Tuple
from .utils import download_url, check_integrity, verify_str_arg
......@@ -39,8 +40,14 @@ class SVHN(VisionDataset):
'extra': ["http://ufldl.stanford.edu/housenumbers/extra_32x32.mat",
"extra_32x32.mat", "a93ce644f1a588dc4d68dda5feec44a7"]}
def __init__(self, root, split='train', transform=None, target_transform=None,
download=False):
def __init__(
self,
root: str,
split: str = "train",
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
download: bool = False,
) -> None:
super(SVHN, self).__init__(root, transform=transform,
target_transform=target_transform)
self.split = verify_str_arg(split, "split", tuple(self.split_list.keys()))
......@@ -75,7 +82,7 @@ class SVHN(VisionDataset):
np.place(self.labels, self.labels == 10, 0)
self.data = np.transpose(self.data, (3, 2, 0, 1))
def __getitem__(self, index):
def __getitem__(self, index: int) -> Tuple[Any, Any]:
"""
Args:
index (int): Index
......@@ -97,18 +104,18 @@ class SVHN(VisionDataset):
return img, target
def __len__(self):
def __len__(self) -> int:
return len(self.data)
def _check_integrity(self):
def _check_integrity(self) -> bool:
root = self.root
md5 = self.split_list[self.split][2]
fpath = os.path.join(root, self.filename)
return check_integrity(fpath, md5)
def download(self):
def download(self) -> None:
md5 = self.split_list[self.split][2]
download_url(self.url, self.root, self.filename, md5)
def extra_repr(self):
def extra_repr(self) -> str:
return "Split: {split}".format(**self.__dict__)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment