Unverified Commit 49ec4a16 authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

add typehints for torchvision.datasets.semeion (#2534)

parent 262d6177
......@@ -2,6 +2,7 @@ from PIL import Image
import os
import os.path
import numpy as np
from typing import Any, Callable, Optional, Tuple
from .vision import VisionDataset
from .utils import download_url, check_integrity
......@@ -23,7 +24,13 @@ class SEMEION(VisionDataset):
filename = "semeion.data"
md5_checksum = 'cb545d371d2ce14ec121470795a77432'
def __init__(self, root, transform=None, target_transform=None, download=True):
def __init__(
self,
root: str,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
download: bool = True,
) -> None:
super(SEMEION, self).__init__(root, transform=transform,
target_transform=target_transform)
......@@ -44,7 +51,7 @@ class SEMEION(VisionDataset):
self.data = np.reshape(self.data, (-1, 16, 16))
self.labels = np.nonzero(data[:, 256:])[1]
def __getitem__(self, index):
def __getitem__(self, index: int) -> Tuple[Any, Any]:
"""
Args:
index (int): Index
......@@ -65,17 +72,17 @@ class SEMEION(VisionDataset):
return img, target
def __len__(self):
def __len__(self) -> int:
return len(self.data)
def _check_integrity(self):
def _check_integrity(self) -> bool:
root = self.root
fpath = os.path.join(root, self.filename)
if not check_integrity(fpath, self.md5_checksum):
return False
return True
def download(self):
def download(self) -> None:
if self._check_integrity():
print('Files already downloaded and verified')
return
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment