Unverified Commit 3f70e3c4 authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

add typehints for torchvision.datasets.mnist (#2532)

parent 203a7841
...@@ -7,6 +7,7 @@ import numpy as np ...@@ -7,6 +7,7 @@ import numpy as np
import torch import torch
import codecs import codecs
import string import string
from typing import Any, Callable, Dict, IO, List, Optional, Tuple, Union
from .utils import download_url, download_and_extract_archive, extract_archive, \ from .utils import download_url, download_and_extract_archive, extract_archive, \
verify_str_arg verify_str_arg
...@@ -60,8 +61,14 @@ class MNIST(VisionDataset): ...@@ -60,8 +61,14 @@ class MNIST(VisionDataset):
warnings.warn("test_data has been renamed data") warnings.warn("test_data has been renamed data")
return self.data return self.data
def __init__(self, root, train=True, transform=None, target_transform=None, def __init__(
download=False): self,
root: str,
train: bool = True,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
download: bool = False,
) -> None:
super(MNIST, self).__init__(root, transform=transform, super(MNIST, self).__init__(root, transform=transform,
target_transform=target_transform) target_transform=target_transform)
self.train = train # training set or test set self.train = train # training set or test set
...@@ -79,7 +86,7 @@ class MNIST(VisionDataset): ...@@ -79,7 +86,7 @@ class MNIST(VisionDataset):
data_file = self.test_file data_file = self.test_file
self.data, self.targets = torch.load(os.path.join(self.processed_folder, data_file)) self.data, self.targets = torch.load(os.path.join(self.processed_folder, data_file))
def __getitem__(self, index): def __getitem__(self, index: int) -> Tuple[Any, Any]:
""" """
Args: Args:
index (int): Index index (int): Index
...@@ -101,28 +108,28 @@ class MNIST(VisionDataset): ...@@ -101,28 +108,28 @@ class MNIST(VisionDataset):
return img, target return img, target
def __len__(self) -> int:
    """Return the number of samples in the dataset (rows of ``self.data``)."""
    return len(self.data)
@property
def raw_folder(self) -> str:
    """Directory holding the raw, downloaded dataset files.

    Keyed on the concrete class name so subclasses (EMNIST, QMNIST, ...)
    get their own folder under ``root``.
    """
    return os.path.join(self.root, type(self).__name__, 'raw')
@property
def processed_folder(self) -> str:
    """Directory holding the processed ``.pt`` tensors for this class.

    Mirrors ``raw_folder`` but points at the 'processed' subdirectory.
    """
    return os.path.join(self.root, type(self).__name__, 'processed')
@property
def class_to_idx(self) -> Dict[str, int]:
    """Map each class name in ``self.classes`` to its positional index."""
    return dict(zip(self.classes, range(len(self.classes))))
def _check_exists(self) -> bool:
    """Return True iff both processed split files are already on disk."""
    required = (self.training_file, self.test_file)
    return all(
        os.path.exists(os.path.join(self.processed_folder, filename))
        for filename in required
    )
def download(self): def download(self) -> None:
"""Download the MNIST data if it doesn't exist in processed_folder already.""" """Download the MNIST data if it doesn't exist in processed_folder already."""
if self._check_exists(): if self._check_exists():
...@@ -154,7 +161,7 @@ class MNIST(VisionDataset): ...@@ -154,7 +161,7 @@ class MNIST(VisionDataset):
print('Done!') print('Done!')
def extra_repr(self) -> str:
    """Describe which split this instance holds, for the dataset repr."""
    split_name = "Train" if self.train is True else "Test"
    return "Split: {}".format(split_name)
...@@ -251,7 +258,7 @@ class EMNIST(MNIST): ...@@ -251,7 +258,7 @@ class EMNIST(MNIST):
'mnist': list(string.digits), 'mnist': list(string.digits),
} }
def __init__(self, root, split, **kwargs): def __init__(self, root: str, split: str, **kwargs: Any) -> None:
self.split = verify_str_arg(split, "split", self.splits) self.split = verify_str_arg(split, "split", self.splits)
self.training_file = self._training_file(split) self.training_file = self._training_file(split)
self.test_file = self._test_file(split) self.test_file = self._test_file(split)
...@@ -259,14 +266,14 @@ class EMNIST(MNIST): ...@@ -259,14 +266,14 @@ class EMNIST(MNIST):
self.classes = self.classes_split_dict[self.split] self.classes = self.classes_split_dict[self.split]
@staticmethod
def _training_file(split: str) -> str:
    """Name of the processed training-set file for an EMNIST *split*.

    The parameter annotation was missing in the original type-hint pass;
    ``split`` is one of the keys of ``EMNIST.splits`` (e.g. 'byclass').
    """
    return 'training_{}.pt'.format(split)
@staticmethod
def _test_file(split: str) -> str:
    """Name of the processed test-set file for an EMNIST *split*.

    The parameter annotation was missing in the original type-hint pass;
    ``split`` is one of the keys of ``EMNIST.splits`` (e.g. 'byclass').
    """
    return 'test_{}.pt'.format(split)
def download(self): def download(self) -> None:
"""Download the EMNIST data if it doesn't exist in processed_folder already.""" """Download the EMNIST data if it doesn't exist in processed_folder already."""
import shutil import shutil
...@@ -343,7 +350,7 @@ class QMNIST(MNIST): ...@@ -343,7 +350,7 @@ class QMNIST(MNIST):
'test50k': 'test', 'test50k': 'test',
'nist': 'nist' 'nist': 'nist'
} }
resources = { # type: ignore[assignment] resources: Dict[str, List[Tuple[str, str]]] = { # type: ignore[assignment]
'train': [('https://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-train-images-idx3-ubyte.gz', 'train': [('https://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-train-images-idx3-ubyte.gz',
'ed72d4157d28c017586c42bc6afe6370'), 'ed72d4157d28c017586c42bc6afe6370'),
('https://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-train-labels-idx2-int.gz', ('https://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-train-labels-idx2-int.gz',
...@@ -360,7 +367,10 @@ class QMNIST(MNIST): ...@@ -360,7 +367,10 @@ class QMNIST(MNIST):
classes = ['0 - zero', '1 - one', '2 - two', '3 - three', '4 - four', classes = ['0 - zero', '1 - one', '2 - two', '3 - three', '4 - four',
'5 - five', '6 - six', '7 - seven', '8 - eight', '9 - nine'] '5 - five', '6 - six', '7 - seven', '8 - eight', '9 - nine']
def __init__(self, root, what=None, compat=True, train=True, **kwargs): def __init__(
self, root: str, what: Optional[str] = None, compat: bool = True,
train: bool = True, **kwargs: Any
) -> None:
if what is None: if what is None:
what = 'train' if train else 'test' what = 'train' if train else 'test'
self.what = verify_str_arg(what, "what", tuple(self.subsets.keys())) self.what = verify_str_arg(what, "what", tuple(self.subsets.keys()))
...@@ -370,7 +380,7 @@ class QMNIST(MNIST): ...@@ -370,7 +380,7 @@ class QMNIST(MNIST):
self.test_file = self.data_file self.test_file = self.data_file
super(QMNIST, self).__init__(root, train, **kwargs) super(QMNIST, self).__init__(root, train, **kwargs)
def download(self): def download(self) -> None:
"""Download the QMNIST data if it doesn't exist in processed_folder already. """Download the QMNIST data if it doesn't exist in processed_folder already.
Note that we only download what has been asked for (argument 'what'). Note that we only download what has been asked for (argument 'what').
""" """
...@@ -405,7 +415,7 @@ class QMNIST(MNIST): ...@@ -405,7 +415,7 @@ class QMNIST(MNIST):
with open(os.path.join(self.processed_folder, self.data_file), 'wb') as f: with open(os.path.join(self.processed_folder, self.data_file), 'wb') as f:
torch.save((data, targets), f) torch.save((data, targets), f)
def __getitem__(self, index): def __getitem__(self, index: int) -> Tuple[Any, Any]:
# redefined to handle the compat flag # redefined to handle the compat flag
img, target = self.data[index], self.targets[index] img, target = self.data[index], self.targets[index]
img = Image.fromarray(img.numpy(), mode='L') img = Image.fromarray(img.numpy(), mode='L')
...@@ -417,15 +427,15 @@ class QMNIST(MNIST): ...@@ -417,15 +427,15 @@ class QMNIST(MNIST):
target = self.target_transform(target) target = self.target_transform(target)
return img, target return img, target
def extra_repr(self) -> str:
    """Report the requested QMNIST subset (``what``) in the dataset repr."""
    return f"Split: {self.what}"
def get_int(b: bytes) -> int:
    """Interpret *b* as a big-endian hex-encoded unsigned integer.

    Raises ValueError on empty input, matching ``int('', 16)``.
    """
    return int(b.hex(), 16)
def open_maybe_compressed_file(path): def open_maybe_compressed_file(path: Union[str, IO]) -> IO:
"""Return a file object that possibly decompresses 'path' on the fly. """Return a file object that possibly decompresses 'path' on the fly.
Decompression occurs when argument `path` is a string and ends with '.gz' or '.xz'. Decompression occurs when argument `path` is a string and ends with '.gz' or '.xz'.
""" """
...@@ -440,19 +450,20 @@ def open_maybe_compressed_file(path): ...@@ -440,19 +450,20 @@ def open_maybe_compressed_file(path):
return open(path, 'rb') return open(path, 'rb')
# SN3/IDX type-code table used by read_sn3_pascalvincent_tensor.
# Each entry maps the on-disk type code to a triple:
#   (torch dtype of the result,
#    numpy dtype used to parse the big-endian payload,
#    numpy dtype string for the final cast before torch.from_numpy).
SN3_PASCALVINCENT_TYPEMAP = {
    8: (torch.uint8, np.uint8, np.uint8),
    9: (torch.int8, np.int8, np.int8),
    11: (torch.int16, np.dtype('>i2'), 'i2'),
    12: (torch.int32, np.dtype('>i4'), 'i4'),
    13: (torch.float32, np.dtype('>f4'), 'f4'),
    14: (torch.float64, np.dtype('>f8'), 'f8'),
}
def read_sn3_pascalvincent_tensor(path: Union[str, IO], strict: bool = True) -> torch.Tensor:
"""Read a SN3 file in "Pascal Vincent" format (Lush file 'libidx/idx-io.lsh'). """Read a SN3 file in "Pascal Vincent" format (Lush file 'libidx/idx-io.lsh').
Argument may be a filename, compressed filename, or file object. Argument may be a filename, compressed filename, or file object.
""" """
# typemap
if not hasattr(read_sn3_pascalvincent_tensor, 'typemap'):
read_sn3_pascalvincent_tensor.typemap = {
8: (torch.uint8, np.uint8, np.uint8),
9: (torch.int8, np.int8, np.int8),
11: (torch.int16, np.dtype('>i2'), 'i2'),
12: (torch.int32, np.dtype('>i4'), 'i4'),
13: (torch.float32, np.dtype('>f4'), 'f4'),
14: (torch.float64, np.dtype('>f8'), 'f8')}
# read # read
with open_maybe_compressed_file(path) as f: with open_maybe_compressed_file(path) as f:
data = f.read() data = f.read()
...@@ -462,14 +473,14 @@ def read_sn3_pascalvincent_tensor(path, strict=True): ...@@ -462,14 +473,14 @@ def read_sn3_pascalvincent_tensor(path, strict=True):
ty = magic // 256 ty = magic // 256
assert nd >= 1 and nd <= 3 assert nd >= 1 and nd <= 3
assert ty >= 8 and ty <= 14 assert ty >= 8 and ty <= 14
m = read_sn3_pascalvincent_tensor.typemap[ty] m = SN3_PASCALVINCENT_TYPEMAP[ty]
s = [get_int(data[4 * (i + 1): 4 * (i + 2)]) for i in range(nd)] s = [get_int(data[4 * (i + 1): 4 * (i + 2)]) for i in range(nd)]
parsed = np.frombuffer(data, dtype=m[1], offset=(4 * (nd + 1))) parsed = np.frombuffer(data, dtype=m[1], offset=(4 * (nd + 1)))
assert parsed.shape[0] == np.prod(s) or not strict assert parsed.shape[0] == np.prod(s) or not strict
return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s) return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)
def read_label_file(path): def read_label_file(path: str) -> torch.Tensor:
with open(path, 'rb') as f: with open(path, 'rb') as f:
x = read_sn3_pascalvincent_tensor(f, strict=False) x = read_sn3_pascalvincent_tensor(f, strict=False)
assert(x.dtype == torch.uint8) assert(x.dtype == torch.uint8)
...@@ -477,7 +488,7 @@ def read_label_file(path): ...@@ -477,7 +488,7 @@ def read_label_file(path):
return x.long() return x.long()
def read_image_file(path): def read_image_file(path: str) -> torch.Tensor:
with open(path, 'rb') as f: with open(path, 'rb') as f:
x = read_sn3_pascalvincent_tensor(f, strict=False) x = read_sn3_pascalvincent_tensor(f, strict=False)
assert(x.dtype == torch.uint8) assert(x.dtype == torch.uint8)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment