Unverified Commit 1a6148d4 authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

add typehints for torchvision.datasets.svhn (#2539)

parent 7c1ed419
...@@ -3,6 +3,7 @@ from PIL import Image ...@@ -3,6 +3,7 @@ from PIL import Image
import os import os
import os.path import os.path
import numpy as np import numpy as np
from typing import Any, Callable, Optional, Tuple
from .utils import download_url, check_integrity, verify_str_arg from .utils import download_url, check_integrity, verify_str_arg
...@@ -39,8 +40,14 @@ class SVHN(VisionDataset): ...@@ -39,8 +40,14 @@ class SVHN(VisionDataset):
'extra': ["http://ufldl.stanford.edu/housenumbers/extra_32x32.mat", 'extra': ["http://ufldl.stanford.edu/housenumbers/extra_32x32.mat",
"extra_32x32.mat", "a93ce644f1a588dc4d68dda5feec44a7"]} "extra_32x32.mat", "a93ce644f1a588dc4d68dda5feec44a7"]}
def __init__(self, root, split='train', transform=None, target_transform=None, def __init__(
download=False): self,
root: str,
split: str = "train",
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
download: bool = False,
) -> None:
super(SVHN, self).__init__(root, transform=transform, super(SVHN, self).__init__(root, transform=transform,
target_transform=target_transform) target_transform=target_transform)
self.split = verify_str_arg(split, "split", tuple(self.split_list.keys())) self.split = verify_str_arg(split, "split", tuple(self.split_list.keys()))
...@@ -75,7 +82,7 @@ class SVHN(VisionDataset): ...@@ -75,7 +82,7 @@ class SVHN(VisionDataset):
np.place(self.labels, self.labels == 10, 0) np.place(self.labels, self.labels == 10, 0)
self.data = np.transpose(self.data, (3, 2, 0, 1)) self.data = np.transpose(self.data, (3, 2, 0, 1))
def __getitem__(self, index): def __getitem__(self, index: int) -> Tuple[Any, Any]:
""" """
Args: Args:
index (int): Index index (int): Index
...@@ -97,18 +104,18 @@ class SVHN(VisionDataset): ...@@ -97,18 +104,18 @@ class SVHN(VisionDataset):
return img, target return img, target
def __len__(self): def __len__(self) -> int:
return len(self.data) return len(self.data)
def _check_integrity(self): def _check_integrity(self) -> bool:
root = self.root root = self.root
md5 = self.split_list[self.split][2] md5 = self.split_list[self.split][2]
fpath = os.path.join(root, self.filename) fpath = os.path.join(root, self.filename)
return check_integrity(fpath, md5) return check_integrity(fpath, md5)
def download(self): def download(self) -> None:
md5 = self.split_list[self.split][2] md5 = self.split_list[self.split][2]
download_url(self.url, self.root, self.filename, md5) download_url(self.url, self.root, self.filename, md5)
def extra_repr(self): def extra_repr(self) -> str:
return "Split: {split}".format(**self.__dict__) return "Split: {split}".format(**self.__dict__)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment