Unverified Commit 3f70e3c4 authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

add typehints for torchvision.datasets.mnist (#2532)

parent 203a7841
......@@ -7,6 +7,7 @@ import numpy as np
import torch
import codecs
import string
from typing import Any, Callable, Dict, IO, List, Optional, Tuple, Union
from .utils import download_url, download_and_extract_archive, extract_archive, \
verify_str_arg
......@@ -60,8 +61,14 @@ class MNIST(VisionDataset):
warnings.warn("test_data has been renamed data")
return self.data
def __init__(self, root, train=True, transform=None, target_transform=None,
download=False):
def __init__(
self,
root: str,
train: bool = True,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
download: bool = False,
) -> None:
super(MNIST, self).__init__(root, transform=transform,
target_transform=target_transform)
self.train = train # training set or test set
......@@ -79,7 +86,7 @@ class MNIST(VisionDataset):
data_file = self.test_file
self.data, self.targets = torch.load(os.path.join(self.processed_folder, data_file))
def __getitem__(self, index):
def __getitem__(self, index: int) -> Tuple[Any, Any]:
"""
Args:
index (int): Index
......@@ -101,28 +108,28 @@ class MNIST(VisionDataset):
return img, target
def __len__(self):
def __len__(self) -> int:
return len(self.data)
@property
def raw_folder(self):
def raw_folder(self) -> str:
return os.path.join(self.root, self.__class__.__name__, 'raw')
@property
def processed_folder(self):
def processed_folder(self) -> str:
return os.path.join(self.root, self.__class__.__name__, 'processed')
@property
def class_to_idx(self):
def class_to_idx(self) -> Dict[str, int]:
return {_class: i for i, _class in enumerate(self.classes)}
def _check_exists(self):
def _check_exists(self) -> bool:
return (os.path.exists(os.path.join(self.processed_folder,
self.training_file)) and
os.path.exists(os.path.join(self.processed_folder,
self.test_file)))
def download(self):
def download(self) -> None:
"""Download the MNIST data if it doesn't exist in processed_folder already."""
if self._check_exists():
......@@ -154,7 +161,7 @@ class MNIST(VisionDataset):
print('Done!')
def extra_repr(self):
def extra_repr(self) -> str:
return "Split: {}".format("Train" if self.train is True else "Test")
......@@ -251,7 +258,7 @@ class EMNIST(MNIST):
'mnist': list(string.digits),
}
def __init__(self, root, split, **kwargs):
def __init__(self, root: str, split: str, **kwargs: Any) -> None:
self.split = verify_str_arg(split, "split", self.splits)
self.training_file = self._training_file(split)
self.test_file = self._test_file(split)
......@@ -259,14 +266,14 @@ class EMNIST(MNIST):
self.classes = self.classes_split_dict[self.split]
@staticmethod
def _training_file(split):
def _training_file(split) -> str:
return 'training_{}.pt'.format(split)
@staticmethod
def _test_file(split):
def _test_file(split) -> str:
return 'test_{}.pt'.format(split)
def download(self):
def download(self) -> None:
"""Download the EMNIST data if it doesn't exist in processed_folder already."""
import shutil
......@@ -343,7 +350,7 @@ class QMNIST(MNIST):
'test50k': 'test',
'nist': 'nist'
}
resources = { # type: ignore[assignment]
resources: Dict[str, List[Tuple[str, str]]] = { # type: ignore[assignment]
'train': [('https://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-train-images-idx3-ubyte.gz',
'ed72d4157d28c017586c42bc6afe6370'),
('https://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-train-labels-idx2-int.gz',
......@@ -360,7 +367,10 @@ class QMNIST(MNIST):
classes = ['0 - zero', '1 - one', '2 - two', '3 - three', '4 - four',
'5 - five', '6 - six', '7 - seven', '8 - eight', '9 - nine']
def __init__(self, root, what=None, compat=True, train=True, **kwargs):
def __init__(
self, root: str, what: Optional[str] = None, compat: bool = True,
train: bool = True, **kwargs: Any
) -> None:
if what is None:
what = 'train' if train else 'test'
self.what = verify_str_arg(what, "what", tuple(self.subsets.keys()))
......@@ -370,7 +380,7 @@ class QMNIST(MNIST):
self.test_file = self.data_file
super(QMNIST, self).__init__(root, train, **kwargs)
def download(self):
def download(self) -> None:
"""Download the QMNIST data if it doesn't exist in processed_folder already.
Note that we only download what has been asked for (argument 'what').
"""
......@@ -405,7 +415,7 @@ class QMNIST(MNIST):
with open(os.path.join(self.processed_folder, self.data_file), 'wb') as f:
torch.save((data, targets), f)
def __getitem__(self, index):
def __getitem__(self, index: int) -> Tuple[Any, Any]:
# redefined to handle the compat flag
img, target = self.data[index], self.targets[index]
img = Image.fromarray(img.numpy(), mode='L')
......@@ -417,15 +427,15 @@ class QMNIST(MNIST):
target = self.target_transform(target)
return img, target
def extra_repr(self):
def extra_repr(self) -> str:
return "Split: {}".format(self.what)
def get_int(b):
def get_int(b: bytes) -> int:
return int(codecs.encode(b, 'hex'), 16)
def open_maybe_compressed_file(path):
def open_maybe_compressed_file(path: Union[str, IO]) -> IO:
"""Return a file object that possibly decompresses 'path' on the fly.
Decompression occurs when argument `path` is a string and ends with '.gz' or '.xz'.
"""
......@@ -440,19 +450,20 @@ def open_maybe_compressed_file(path):
return open(path, 'rb')
def read_sn3_pascalvincent_tensor(path, strict=True):
"""Read a SN3 file in "Pascal Vincent" format (Lush file 'libidx/idx-io.lsh').
Argument may be a filename, compressed filename, or file object.
"""
# typemap
if not hasattr(read_sn3_pascalvincent_tensor, 'typemap'):
read_sn3_pascalvincent_tensor.typemap = {
SN3_PASCALVINCENT_TYPEMAP = {
8: (torch.uint8, np.uint8, np.uint8),
9: (torch.int8, np.int8, np.int8),
11: (torch.int16, np.dtype('>i2'), 'i2'),
12: (torch.int32, np.dtype('>i4'), 'i4'),
13: (torch.float32, np.dtype('>f4'), 'f4'),
14: (torch.float64, np.dtype('>f8'), 'f8')}
14: (torch.float64, np.dtype('>f8'), 'f8')
}
def read_sn3_pascalvincent_tensor(path: Union[str, IO], strict: bool = True) -> torch.Tensor:
"""Read a SN3 file in "Pascal Vincent" format (Lush file 'libidx/idx-io.lsh').
Argument may be a filename, compressed filename, or file object.
"""
# read
with open_maybe_compressed_file(path) as f:
data = f.read()
......@@ -462,14 +473,14 @@ def read_sn3_pascalvincent_tensor(path, strict=True):
ty = magic // 256
assert nd >= 1 and nd <= 3
assert ty >= 8 and ty <= 14
m = read_sn3_pascalvincent_tensor.typemap[ty]
m = SN3_PASCALVINCENT_TYPEMAP[ty]
s = [get_int(data[4 * (i + 1): 4 * (i + 2)]) for i in range(nd)]
parsed = np.frombuffer(data, dtype=m[1], offset=(4 * (nd + 1)))
assert parsed.shape[0] == np.prod(s) or not strict
return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)
def read_label_file(path):
def read_label_file(path: str) -> torch.Tensor:
with open(path, 'rb') as f:
x = read_sn3_pascalvincent_tensor(f, strict=False)
assert(x.dtype == torch.uint8)
......@@ -477,7 +488,7 @@ def read_label_file(path):
return x.long()
def read_image_file(path):
def read_image_file(path: str) -> torch.Tensor:
with open(path, 'rb') as f:
x = read_sn3_pascalvincent_tensor(f, strict=False)
assert(x.dtype == torch.uint8)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment