Unverified Commit 49ec4a16 authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

add typehints for torchvision.datasets.semeion (#2534)

parent 262d6177
...@@ -2,6 +2,7 @@ from PIL import Image ...@@ -2,6 +2,7 @@ from PIL import Image
import os import os
import os.path import os.path
import numpy as np import numpy as np
from typing import Any, Callable, Optional, Tuple
from .vision import VisionDataset from .vision import VisionDataset
from .utils import download_url, check_integrity from .utils import download_url, check_integrity
...@@ -23,7 +24,13 @@ class SEMEION(VisionDataset): ...@@ -23,7 +24,13 @@ class SEMEION(VisionDataset):
filename = "semeion.data" filename = "semeion.data"
md5_checksum = 'cb545d371d2ce14ec121470795a77432' md5_checksum = 'cb545d371d2ce14ec121470795a77432'
def __init__(self, root, transform=None, target_transform=None, download=True): def __init__(
self,
root: str,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
download: bool = True,
) -> None:
super(SEMEION, self).__init__(root, transform=transform, super(SEMEION, self).__init__(root, transform=transform,
target_transform=target_transform) target_transform=target_transform)
...@@ -44,7 +51,7 @@ class SEMEION(VisionDataset): ...@@ -44,7 +51,7 @@ class SEMEION(VisionDataset):
self.data = np.reshape(self.data, (-1, 16, 16)) self.data = np.reshape(self.data, (-1, 16, 16))
self.labels = np.nonzero(data[:, 256:])[1] self.labels = np.nonzero(data[:, 256:])[1]
def __getitem__(self, index): def __getitem__(self, index: int) -> Tuple[Any, Any]:
""" """
Args: Args:
index (int): Index index (int): Index
...@@ -65,17 +72,17 @@ class SEMEION(VisionDataset): ...@@ -65,17 +72,17 @@ class SEMEION(VisionDataset):
return img, target return img, target
def __len__(self): def __len__(self) -> int:
return len(self.data) return len(self.data)
def _check_integrity(self): def _check_integrity(self) -> bool:
root = self.root root = self.root
fpath = os.path.join(root, self.filename) fpath = os.path.join(root, self.filename)
if not check_integrity(fpath, self.md5_checksum): if not check_integrity(fpath, self.md5_checksum):
return False return False
return True return True
def download(self): def download(self) -> None:
if self._check_integrity(): if self._check_integrity():
print('Files already downloaded and verified') print('Files already downloaded and verified')
return return
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment