Unverified Commit bb2805a6 authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

Remove caching from MNIST and variants (#3420)

* remove caching from (Fashion|K)?MNIST

* remove unnecessary lazy import

* remove incorrect check of the extracted binaries against the md5 of the archives

* remove caching from EMNIST

* remove caching from QMNIST

* lint

* fix EMNIST

* streamline QMNIST download
parent 9e474c3c
...@@ -120,7 +120,8 @@ class Tester(DatasetTestcase): ...@@ -120,7 +120,8 @@ class Tester(DatasetTestcase):
) )
@mock.patch('torchvision.datasets.mnist.download_and_extract_archive') @mock.patch('torchvision.datasets.mnist.download_and_extract_archive')
def test_mnist(self, mock_download_extract): @mock.patch('torchvision.datasets.mnist.check_integrity', return_value=True)
def test_mnist(self, mock_download_extract, mock_check_integrity):
num_examples = 30 num_examples = 30
with mnist_root(num_examples, "MNIST") as root: with mnist_root(num_examples, "MNIST") as root:
dataset = torchvision.datasets.MNIST(root, download=True) dataset = torchvision.datasets.MNIST(root, download=True)
...@@ -129,7 +130,8 @@ class Tester(DatasetTestcase): ...@@ -129,7 +130,8 @@ class Tester(DatasetTestcase):
self.assertEqual(dataset.class_to_idx[dataset.classes[0]], target) self.assertEqual(dataset.class_to_idx[dataset.classes[0]], target)
@mock.patch('torchvision.datasets.mnist.download_and_extract_archive') @mock.patch('torchvision.datasets.mnist.download_and_extract_archive')
def test_kmnist(self, mock_download_extract): @mock.patch('torchvision.datasets.mnist.check_integrity', return_value=True)
def test_kmnist(self, mock_download_extract, mock_check_integrity):
num_examples = 30 num_examples = 30
with mnist_root(num_examples, "KMNIST") as root: with mnist_root(num_examples, "KMNIST") as root:
dataset = torchvision.datasets.KMNIST(root, download=True) dataset = torchvision.datasets.KMNIST(root, download=True)
...@@ -138,7 +140,8 @@ class Tester(DatasetTestcase): ...@@ -138,7 +140,8 @@ class Tester(DatasetTestcase):
self.assertEqual(dataset.class_to_idx[dataset.classes[0]], target) self.assertEqual(dataset.class_to_idx[dataset.classes[0]], target)
@mock.patch('torchvision.datasets.mnist.download_and_extract_archive') @mock.patch('torchvision.datasets.mnist.download_and_extract_archive')
def test_fashionmnist(self, mock_download_extract): @mock.patch('torchvision.datasets.mnist.check_integrity', return_value=True)
def test_fashionmnist(self, mock_download_extract, mock_check_integrity):
num_examples = 30 num_examples = 30
with mnist_root(num_examples, "FashionMNIST") as root: with mnist_root(num_examples, "FashionMNIST") as root:
dataset = torchvision.datasets.FashionMNIST(root, download=True) dataset = torchvision.datasets.FashionMNIST(root, download=True)
......
...@@ -7,12 +7,10 @@ import numpy as np ...@@ -7,12 +7,10 @@ import numpy as np
import torch import torch
import codecs import codecs
import string import string
import gzip from typing import Any, Callable, Dict, List, Optional, Tuple
import lzma
from typing import Any, Callable, Dict, IO, List, Optional, Tuple, Union
from urllib.error import URLError from urllib.error import URLError
from .utils import download_url, download_and_extract_archive, extract_archive, \ from .utils import download_and_extract_archive, extract_archive, verify_str_arg, check_integrity
verify_str_arg import shutil
class MNIST(VisionDataset): class MNIST(VisionDataset):
...@@ -81,6 +79,10 @@ class MNIST(VisionDataset): ...@@ -81,6 +79,10 @@ class MNIST(VisionDataset):
target_transform=target_transform) target_transform=target_transform)
self.train = train # training set or test set self.train = train # training set or test set
if self._check_legacy_exist():
self.data, self.targets = self._load_legacy_data()
return
if download: if download:
self.download() self.download()
...@@ -88,11 +90,31 @@ class MNIST(VisionDataset): ...@@ -88,11 +90,31 @@ class MNIST(VisionDataset):
raise RuntimeError('Dataset not found.' + raise RuntimeError('Dataset not found.' +
' You can use download=True to download it') ' You can use download=True to download it')
if self.train: self.data, self.targets = self._load_data()
data_file = self.training_file
else: def _check_legacy_exist(self):
data_file = self.test_file processed_folder_exists = os.path.exists(self.processed_folder)
self.data, self.targets = torch.load(os.path.join(self.processed_folder, data_file)) if not processed_folder_exists:
return False
return all(
check_integrity(os.path.join(self.processed_folder, file)) for file in (self.training_file, self.test_file)
)
def _load_legacy_data(self):
# This is for BC only. We no longer cache the data in a custom binary, but simply read from the raw data
# directly.
data_file = self.training_file if self.train else self.test_file
return torch.load(os.path.join(self.processed_folder, data_file))
def _load_data(self):
image_file = f"{'train' if self.train else 't10k'}-images-idx3-ubyte"
data = read_image_file(os.path.join(self.raw_folder, image_file))
label_file = f"{'train' if self.train else 't10k'}-labels-idx1-ubyte"
targets = read_label_file(os.path.join(self.raw_folder, label_file))
return data, targets
def __getitem__(self, index: int) -> Tuple[Any, Any]: def __getitem__(self, index: int) -> Tuple[Any, Any]:
""" """
...@@ -132,19 +154,18 @@ class MNIST(VisionDataset): ...@@ -132,19 +154,18 @@ class MNIST(VisionDataset):
return {_class: i for i, _class in enumerate(self.classes)} return {_class: i for i, _class in enumerate(self.classes)}
def _check_exists(self) -> bool: def _check_exists(self) -> bool:
return (os.path.exists(os.path.join(self.processed_folder, return all(
self.training_file)) and check_integrity(os.path.join(self.raw_folder, os.path.splitext(os.path.basename(url))[0]))
os.path.exists(os.path.join(self.processed_folder, for url, _ in self.resources
self.test_file))) )
def download(self) -> None: def download(self) -> None:
"""Download the MNIST data if it doesn't exist in processed_folder already.""" """Download the MNIST data if it doesn't exist already."""
if self._check_exists(): if self._check_exists():
return return
os.makedirs(self.raw_folder, exist_ok=True) os.makedirs(self.raw_folder, exist_ok=True)
os.makedirs(self.processed_folder, exist_ok=True)
# download files # download files
for filename, md5 in self.resources: for filename, md5 in self.resources:
...@@ -168,24 +189,6 @@ class MNIST(VisionDataset): ...@@ -168,24 +189,6 @@ class MNIST(VisionDataset):
else: else:
raise RuntimeError("Error downloading {}".format(filename)) raise RuntimeError("Error downloading {}".format(filename))
# process and save as torch files
print('Processing...')
training_set = (
read_image_file(os.path.join(self.raw_folder, 'train-images-idx3-ubyte')),
read_label_file(os.path.join(self.raw_folder, 'train-labels-idx1-ubyte'))
)
test_set = (
read_image_file(os.path.join(self.raw_folder, 't10k-images-idx3-ubyte')),
read_label_file(os.path.join(self.raw_folder, 't10k-labels-idx1-ubyte'))
)
with open(os.path.join(self.processed_folder, self.training_file), 'wb') as f:
torch.save(training_set, f)
with open(os.path.join(self.processed_folder, self.test_file), 'wb') as f:
torch.save(test_set, f)
print('Done!')
def extra_repr(self) -> str: def extra_repr(self) -> str:
return "Split: {}".format("Train" if self.train is True else "Test") return "Split: {}".format("Train" if self.train is True else "Test")
...@@ -298,44 +301,39 @@ class EMNIST(MNIST): ...@@ -298,44 +301,39 @@ class EMNIST(MNIST):
def _test_file(split) -> str: def _test_file(split) -> str:
return 'test_{}.pt'.format(split) return 'test_{}.pt'.format(split)
@property
def _file_prefix(self) -> str:
return f"emnist-{self.split}-{'train' if self.train else 'test'}"
@property
def images_file(self) -> str:
return os.path.join(self.raw_folder, f"{self._file_prefix}-images-idx3-ubyte")
@property
def labels_file(self) -> str:
return os.path.join(self.raw_folder, f"{self._file_prefix}-labels-idx1-ubyte")
def _load_data(self):
return read_image_file(self.images_file), read_label_file(self.labels_file)
def _check_exists(self) -> bool:
return all(check_integrity(file) for file in (self.images_file, self.labels_file))
def download(self) -> None: def download(self) -> None:
"""Download the EMNIST data if it doesn't exist in processed_folder already.""" """Download the EMNIST data if it doesn't exist already."""
import shutil
if self._check_exists(): if self._check_exists():
return return
os.makedirs(self.raw_folder, exist_ok=True) os.makedirs(self.raw_folder, exist_ok=True)
os.makedirs(self.processed_folder, exist_ok=True)
# download files download_and_extract_archive(self.url, download_root=self.raw_folder, md5=self.md5)
print('Downloading and extracting zip archive')
download_and_extract_archive(self.url, download_root=self.raw_folder, filename="emnist.zip",
remove_finished=True, md5=self.md5)
gzip_folder = os.path.join(self.raw_folder, 'gzip') gzip_folder = os.path.join(self.raw_folder, 'gzip')
for gzip_file in os.listdir(gzip_folder): for gzip_file in os.listdir(gzip_folder):
if gzip_file.endswith('.gz'): if gzip_file.endswith('.gz'):
extract_archive(os.path.join(gzip_folder, gzip_file), gzip_folder) extract_archive(os.path.join(gzip_folder, gzip_file), self.raw_folder)
# process and save as torch files
for split in self.splits:
print('Processing ' + split)
training_set = (
read_image_file(os.path.join(gzip_folder, 'emnist-{}-train-images-idx3-ubyte'.format(split))),
read_label_file(os.path.join(gzip_folder, 'emnist-{}-train-labels-idx1-ubyte'.format(split)))
)
test_set = (
read_image_file(os.path.join(gzip_folder, 'emnist-{}-test-images-idx3-ubyte'.format(split))),
read_label_file(os.path.join(gzip_folder, 'emnist-{}-test-labels-idx1-ubyte'.format(split)))
)
with open(os.path.join(self.processed_folder, self._training_file(split)), 'wb') as f:
torch.save(training_set, f)
with open(os.path.join(self.processed_folder, self._test_file(split)), 'wb') as f:
torch.save(test_set, f)
shutil.rmtree(gzip_folder) shutil.rmtree(gzip_folder)
print('Done!')
class QMNIST(MNIST): class QMNIST(MNIST):
"""`QMNIST <https://github.com/facebookresearch/qmnist>`_ Dataset. """`QMNIST <https://github.com/facebookresearch/qmnist>`_ Dataset.
...@@ -404,40 +402,51 @@ class QMNIST(MNIST): ...@@ -404,40 +402,51 @@ class QMNIST(MNIST):
self.test_file = self.data_file self.test_file = self.data_file
super(QMNIST, self).__init__(root, train, **kwargs) super(QMNIST, self).__init__(root, train, **kwargs)
@property
def images_file(self) -> str:
(url, _), _ = self.resources[self.subsets[self.what]]
return os.path.join(self.raw_folder, os.path.splitext(os.path.basename(url))[0])
@property
def labels_file(self) -> str:
_, (url, _) = self.resources[self.subsets[self.what]]
return os.path.join(self.raw_folder, os.path.splitext(os.path.basename(url))[0])
def _check_exists(self) -> bool:
return all(check_integrity(file) for file in (self.images_file, self.labels_file))
def _load_data(self):
data = read_sn3_pascalvincent_tensor(self.images_file)
assert (data.dtype == torch.uint8)
assert (data.ndimension() == 3)
targets = read_sn3_pascalvincent_tensor(self.labels_file).long()
assert (targets.ndimension() == 2)
if self.what == 'test10k':
data = data[0:10000, :, :].clone()
targets = targets[0:10000, :].clone()
elif self.what == 'test50k':
data = data[10000:, :, :].clone()
targets = targets[10000:, :].clone()
return data, targets
def download(self) -> None: def download(self) -> None:
"""Download the QMNIST data if it doesn't exist in processed_folder already. """Download the QMNIST data if it doesn't exist already.
Note that we only download what has been asked for (argument 'what'). Note that we only download what has been asked for (argument 'what').
""" """
if self._check_exists(): if self._check_exists():
return return
os.makedirs(self.raw_folder, exist_ok=True) os.makedirs(self.raw_folder, exist_ok=True)
os.makedirs(self.processed_folder, exist_ok=True)
split = self.resources[self.subsets[self.what]] split = self.resources[self.subsets[self.what]]
files = []
# download data files if not already there
for url, md5 in split: for url, md5 in split:
filename = url.rpartition('/')[2] filename = url.rpartition('/')[2]
file_path = os.path.join(self.raw_folder, filename) file_path = os.path.join(self.raw_folder, filename)
if not os.path.isfile(file_path): if not os.path.isfile(file_path):
download_url(url, root=self.raw_folder, filename=filename, md5=md5) download_and_extract_archive(url, self.raw_folder, filename=filename, md5=md5)
files.append(file_path)
# process and save as torch files
print('Processing...')
data = read_sn3_pascalvincent_tensor(files[0])
assert(data.dtype == torch.uint8)
assert(data.ndimension() == 3)
targets = read_sn3_pascalvincent_tensor(files[1]).long()
assert(targets.ndimension() == 2)
if self.what == 'test10k':
data = data[0:10000, :, :].clone()
targets = targets[0:10000, :].clone()
if self.what == 'test50k':
data = data[10000:, :, :].clone()
targets = targets[10000:, :].clone()
with open(os.path.join(self.processed_folder, self.data_file), 'wb') as f:
torch.save((data, targets), f)
def __getitem__(self, index: int) -> Tuple[Any, Any]: def __getitem__(self, index: int) -> Tuple[Any, Any]:
# redefined to handle the compat flag # redefined to handle the compat flag
...@@ -459,19 +468,6 @@ def get_int(b: bytes) -> int: ...@@ -459,19 +468,6 @@ def get_int(b: bytes) -> int:
return int(codecs.encode(b, 'hex'), 16) return int(codecs.encode(b, 'hex'), 16)
def open_maybe_compressed_file(path: Union[str, IO]) -> Union[IO, gzip.GzipFile]:
"""Return a file object that possibly decompresses 'path' on the fly.
Decompression occurs when argument `path` is a string and ends with '.gz' or '.xz'.
"""
if not isinstance(path, torch._six.string_classes):
return path
if path.endswith('.gz'):
return gzip.open(path, 'rb')
if path.endswith('.xz'):
return lzma.open(path, 'rb')
return open(path, 'rb')
SN3_PASCALVINCENT_TYPEMAP = { SN3_PASCALVINCENT_TYPEMAP = {
8: (torch.uint8, np.uint8, np.uint8), 8: (torch.uint8, np.uint8, np.uint8),
9: (torch.int8, np.int8, np.int8), 9: (torch.int8, np.int8, np.int8),
...@@ -482,12 +478,12 @@ SN3_PASCALVINCENT_TYPEMAP = { ...@@ -482,12 +478,12 @@ SN3_PASCALVINCENT_TYPEMAP = {
} }
def read_sn3_pascalvincent_tensor(path: Union[str, IO], strict: bool = True) -> torch.Tensor: def read_sn3_pascalvincent_tensor(path: str, strict: bool = True) -> torch.Tensor:
"""Read a SN3 file in "Pascal Vincent" format (Lush file 'libidx/idx-io.lsh'). """Read a SN3 file in "Pascal Vincent" format (Lush file 'libidx/idx-io.lsh').
Argument may be a filename, compressed filename, or file object. Argument may be a filename, compressed filename, or file object.
""" """
# read # read
with open_maybe_compressed_file(path) as f: with open(path, "rb") as f:
data = f.read() data = f.read()
# parse # parse
magic = get_int(data[0:4]) magic = get_int(data[0:4])
...@@ -503,16 +499,14 @@ def read_sn3_pascalvincent_tensor(path: Union[str, IO], strict: bool = True) -> ...@@ -503,16 +499,14 @@ def read_sn3_pascalvincent_tensor(path: Union[str, IO], strict: bool = True) ->
def read_label_file(path: str) -> torch.Tensor:
    """Read an IDX label file and return its contents as a 1-D int64 tensor."""
    labels = read_sn3_pascalvincent_tensor(path, strict=False)
    assert labels.dtype == torch.uint8
    assert labels.ndimension() == 1
    return labels.long()
def read_image_file(path: str) -> torch.Tensor:
    """Read an IDX image file and return its contents as a 3-D uint8 tensor."""
    images = read_sn3_pascalvincent_tensor(path, strict=False)
    assert images.dtype == torch.uint8
    assert images.ndimension() == 3
    return images
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment