Unverified commit bb2805a6, authored by Philip Meier and committed by GitHub
Browse files

Remove caching from MNIST and variants (#3420)

* remove caching from (Fashion|K)?MNIST

* remove unnecessary lazy import

* remove false check of binaries against the md5 of archives

* remove caching from EMNIST

* remove caching from QMNIST

* lint

* fix EMNIST

* streamline QMNIST download
parent 9e474c3c
......@@ -120,7 +120,8 @@ class Tester(DatasetTestcase):
)
@mock.patch('torchvision.datasets.mnist.download_and_extract_archive')
def test_mnist(self, mock_download_extract):
@mock.patch('torchvision.datasets.mnist.check_integrity', return_value=True)
def test_mnist(self, mock_download_extract, mock_check_integrity):
num_examples = 30
with mnist_root(num_examples, "MNIST") as root:
dataset = torchvision.datasets.MNIST(root, download=True)
......@@ -129,7 +130,8 @@ class Tester(DatasetTestcase):
self.assertEqual(dataset.class_to_idx[dataset.classes[0]], target)
@mock.patch('torchvision.datasets.mnist.download_and_extract_archive')
def test_kmnist(self, mock_download_extract):
@mock.patch('torchvision.datasets.mnist.check_integrity', return_value=True)
def test_kmnist(self, mock_download_extract, mock_check_integrity):
num_examples = 30
with mnist_root(num_examples, "KMNIST") as root:
dataset = torchvision.datasets.KMNIST(root, download=True)
......@@ -138,7 +140,8 @@ class Tester(DatasetTestcase):
self.assertEqual(dataset.class_to_idx[dataset.classes[0]], target)
@mock.patch('torchvision.datasets.mnist.download_and_extract_archive')
def test_fashionmnist(self, mock_download_extract):
@mock.patch('torchvision.datasets.mnist.check_integrity', return_value=True)
def test_fashionmnist(self, mock_download_extract, mock_check_integrity):
num_examples = 30
with mnist_root(num_examples, "FashionMNIST") as root:
dataset = torchvision.datasets.FashionMNIST(root, download=True)
......
......@@ -7,12 +7,10 @@ import numpy as np
import torch
import codecs
import string
import gzip
import lzma
from typing import Any, Callable, Dict, IO, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Tuple
from urllib.error import URLError
from .utils import download_url, download_and_extract_archive, extract_archive, \
verify_str_arg
from .utils import download_and_extract_archive, extract_archive, verify_str_arg, check_integrity
import shutil
class MNIST(VisionDataset):
......@@ -81,6 +79,10 @@ class MNIST(VisionDataset):
target_transform=target_transform)
self.train = train # training set or test set
if self._check_legacy_exist():
self.data, self.targets = self._load_legacy_data()
return
if download:
self.download()
......@@ -88,11 +90,31 @@ class MNIST(VisionDataset):
raise RuntimeError('Dataset not found.' +
' You can use download=True to download it')
if self.train:
data_file = self.training_file
else:
data_file = self.test_file
self.data, self.targets = torch.load(os.path.join(self.processed_folder, data_file))
self.data, self.targets = self._load_data()
def _check_legacy_exist(self):
    """Report whether both legacy processed files are available.

    Returns:
        bool: ``False`` when ``self.processed_folder`` does not exist.
        Otherwise ``True`` only if ``check_integrity`` accepts both
        ``self.training_file`` and ``self.test_file`` inside that folder.
    """
    if not os.path.exists(self.processed_folder):
        return False

    # Only the path is handed to check_integrity; no checksum is supplied.
    for legacy_name in (self.training_file, self.test_file):
        if not check_integrity(os.path.join(self.processed_folder, legacy_name)):
            return False
    return True
def _load_legacy_data(self):
    """Load the split selected by ``self.train`` from the processed folder.

    Returns:
        Whatever ``torch.load`` yields for ``self.training_file`` (when
        ``self.train`` is truthy) or ``self.test_file`` (otherwise), looked up
        inside ``self.processed_folder``.
    """
    # Kept for backward compatibility only: data is no longer cached in a
    # custom binary, it is read straight from the raw files instead.
    if self.train:
        legacy_name = self.training_file
    else:
        legacy_name = self.test_file
    legacy_path = os.path.join(self.processed_folder, legacy_name)
    return torch.load(legacy_path)
def _load_data(self):
    """Read images and labels for the selected split from the raw folder.

    The file names start with ``train`` when ``self.train`` is truthy and
    with ``t10k`` otherwise.

    Returns:
        tuple: ``(data, targets)`` as produced by ``read_image_file`` and
        ``read_label_file`` respectively.
    """
    prefix = 'train' if self.train else 't10k'

    images_path = os.path.join(self.raw_folder, f"{prefix}-images-idx3-ubyte")
    data = read_image_file(images_path)

    labels_path = os.path.join(self.raw_folder, f"{prefix}-labels-idx1-ubyte")
    targets = read_label_file(labels_path)

    return data, targets
def __getitem__(self, index: int) -> Tuple[Any, Any]:
"""
......@@ -132,19 +154,18 @@ class MNIST(VisionDataset):
return {_class: i for i, _class in enumerate(self.classes)}
def _check_exists(self) -> bool:
return (os.path.exists(os.path.join(self.processed_folder,
self.training_file)) and
os.path.exists(os.path.join(self.processed_folder,
self.test_file)))
return all(
check_integrity(os.path.join(self.raw_folder, os.path.splitext(os.path.basename(url))[0]))
for url, _ in self.resources
)
def download(self) -> None:
"""Download the MNIST data if it doesn't exist in processed_folder already."""
"""Download the MNIST data if it doesn't exist already."""
if self._check_exists():
return
os.makedirs(self.raw_folder, exist_ok=True)
os.makedirs(self.processed_folder, exist_ok=True)
# download files
for filename, md5 in self.resources:
......@@ -168,24 +189,6 @@ class MNIST(VisionDataset):
else:
raise RuntimeError("Error downloading {}".format(filename))
# process and save as torch files
print('Processing...')
training_set = (
read_image_file(os.path.join(self.raw_folder, 'train-images-idx3-ubyte')),
read_label_file(os.path.join(self.raw_folder, 'train-labels-idx1-ubyte'))
)
test_set = (
read_image_file(os.path.join(self.raw_folder, 't10k-images-idx3-ubyte')),
read_label_file(os.path.join(self.raw_folder, 't10k-labels-idx1-ubyte'))
)
with open(os.path.join(self.processed_folder, self.training_file), 'wb') as f:
torch.save(training_set, f)
with open(os.path.join(self.processed_folder, self.test_file), 'wb') as f:
torch.save(test_set, f)
print('Done!')
def extra_repr(self) -> str:
    """Return a one-line description of the selected split.

    Returns:
        ``"Split: Train"`` only when ``self.train`` is the ``True`` singleton;
        ``"Split: Test"`` for any other value.
    """
    split_name = "Train" if self.train is True else "Test"
    return f"Split: {split_name}"
......@@ -298,44 +301,39 @@ class EMNIST(MNIST):
def _test_file(split) -> str:
    """Return the processed test-set file name for ``split``."""
    return f'test_{split}.pt'
@property
def _file_prefix(self) -> str:
    """Common file-name prefix: ``emnist-<split>-train`` or ``emnist-<split>-test``.

    The trailing part is ``train`` when ``self.train`` is truthy, ``test`` otherwise.
    """
    subset = 'train' if self.train else 'test'
    return f"emnist-{self.split}-{subset}"
@property
def images_file(self) -> str:
    """Path of the image file for this split inside ``self.raw_folder``."""
    filename = f"{self._file_prefix}-images-idx3-ubyte"
    return os.path.join(self.raw_folder, filename)
@property
def labels_file(self) -> str:
    """Path of the label file for this split inside ``self.raw_folder``."""
    filename = f"{self._file_prefix}-labels-idx1-ubyte"
    return os.path.join(self.raw_folder, filename)
def _load_data(self):
    """Read the image and label tensors for this split.

    Returns:
        tuple: the result of ``read_image_file(self.images_file)`` followed by
        the result of ``read_label_file(self.labels_file)``.
    """
    images = read_image_file(self.images_file)
    labels = read_label_file(self.labels_file)
    return images, labels
def _check_exists(self) -> bool:
    """Return ``True`` when ``check_integrity`` accepts both the image and label file."""
    for raw_path in (self.images_file, self.labels_file):
        if not check_integrity(raw_path):
            return False
    return True
def download(self) -> None:
"""Download the EMNIST data if it doesn't exist in processed_folder already."""
import shutil
"""Download the EMNIST data if it doesn't exist already."""
if self._check_exists():
return
os.makedirs(self.raw_folder, exist_ok=True)
os.makedirs(self.processed_folder, exist_ok=True)
# download files
print('Downloading and extracting zip archive')
download_and_extract_archive(self.url, download_root=self.raw_folder, filename="emnist.zip",
remove_finished=True, md5=self.md5)
download_and_extract_archive(self.url, download_root=self.raw_folder, md5=self.md5)
gzip_folder = os.path.join(self.raw_folder, 'gzip')
for gzip_file in os.listdir(gzip_folder):
if gzip_file.endswith('.gz'):
extract_archive(os.path.join(gzip_folder, gzip_file), gzip_folder)
# process and save as torch files
for split in self.splits:
print('Processing ' + split)
training_set = (
read_image_file(os.path.join(gzip_folder, 'emnist-{}-train-images-idx3-ubyte'.format(split))),
read_label_file(os.path.join(gzip_folder, 'emnist-{}-train-labels-idx1-ubyte'.format(split)))
)
test_set = (
read_image_file(os.path.join(gzip_folder, 'emnist-{}-test-images-idx3-ubyte'.format(split))),
read_label_file(os.path.join(gzip_folder, 'emnist-{}-test-labels-idx1-ubyte'.format(split)))
)
with open(os.path.join(self.processed_folder, self._training_file(split)), 'wb') as f:
torch.save(training_set, f)
with open(os.path.join(self.processed_folder, self._test_file(split)), 'wb') as f:
torch.save(test_set, f)
extract_archive(os.path.join(gzip_folder, gzip_file), self.raw_folder)
shutil.rmtree(gzip_folder)
print('Done!')
class QMNIST(MNIST):
"""`QMNIST <https://github.com/facebookresearch/qmnist>`_ Dataset.
......@@ -404,40 +402,51 @@ class QMNIST(MNIST):
self.test_file = self.data_file
super(QMNIST, self).__init__(root, train, **kwargs)
@property
def images_file(self) -> str:
    """Path of the extracted image file for the requested subset.

    The name is the basename of the first resource URL for
    ``self.subsets[self.what]`` with its last extension removed, placed in
    ``self.raw_folder``.
    """
    image_resource, _ = self.resources[self.subsets[self.what]]
    url, _ = image_resource
    stem, _ = os.path.splitext(os.path.basename(url))
    return os.path.join(self.raw_folder, stem)
@property
def labels_file(self) -> str:
    """Path of the extracted label file for the requested subset.

    The name is the basename of the second resource URL for
    ``self.subsets[self.what]`` with its last extension removed, placed in
    ``self.raw_folder``.
    """
    _, label_resource = self.resources[self.subsets[self.what]]
    url, _ = label_resource
    stem, _ = os.path.splitext(os.path.basename(url))
    return os.path.join(self.raw_folder, stem)
def _check_exists(self) -> bool:
    """Return ``True`` when ``check_integrity`` accepts both the image and label file."""
    for raw_path in (self.images_file, self.labels_file):
        if not check_integrity(raw_path):
            return False
    return True
def _load_data(self):
    """Read the image and label tensors and cut them down to the requested subset.

    Returns:
        tuple: ``(data, targets)``. For ``self.what == 'test10k'`` only the
        first 10000 entries are kept, for ``'test50k'`` only the entries from
        index 10000 onwards; any other value leaves both tensors untouched.

    Raises:
        AssertionError: if the images are not a 3-dimensional ``torch.uint8``
            tensor or the targets are not 2-dimensional.
    """
    images = read_sn3_pascalvincent_tensor(self.images_file)
    assert (images.dtype == torch.uint8)
    assert (images.ndimension() == 3)

    labels = read_sn3_pascalvincent_tensor(self.labels_file).long()
    assert (labels.ndimension() == 2)

    # Both test subsets are carved out of the same file pair.
    subset_rows = {
        'test10k': slice(0, 10000),
        'test50k': slice(10000, None),
    }
    rows = subset_rows.get(self.what)
    if rows is not None:
        images = images[rows, :, :].clone()
        labels = labels[rows, :].clone()

    return images, labels
def download(self) -> None:
"""Download the QMNIST data if it doesn't exist in processed_folder already.
"""Download the QMNIST data if it doesn't exist already.
Note that we only download what has been asked for (argument 'what').
"""
if self._check_exists():
return
os.makedirs(self.raw_folder, exist_ok=True)
os.makedirs(self.processed_folder, exist_ok=True)
split = self.resources[self.subsets[self.what]]
files = []
# download data files if not already there
for url, md5 in split:
filename = url.rpartition('/')[2]
file_path = os.path.join(self.raw_folder, filename)
if not os.path.isfile(file_path):
download_url(url, root=self.raw_folder, filename=filename, md5=md5)
files.append(file_path)
# process and save as torch files
print('Processing...')
data = read_sn3_pascalvincent_tensor(files[0])
assert(data.dtype == torch.uint8)
assert(data.ndimension() == 3)
targets = read_sn3_pascalvincent_tensor(files[1]).long()
assert(targets.ndimension() == 2)
if self.what == 'test10k':
data = data[0:10000, :, :].clone()
targets = targets[0:10000, :].clone()
if self.what == 'test50k':
data = data[10000:, :, :].clone()
targets = targets[10000:, :].clone()
with open(os.path.join(self.processed_folder, self.data_file), 'wb') as f:
torch.save((data, targets), f)
download_and_extract_archive(url, self.raw_folder, filename=filename, md5=md5)
def __getitem__(self, index: int) -> Tuple[Any, Any]:
# redefined to handle the compat flag
......@@ -459,19 +468,6 @@ def get_int(b: bytes) -> int:
return int(codecs.encode(b, 'hex'), 16)
def open_maybe_compressed_file(path: Union[str, IO]) -> Union[IO, gzip.GzipFile]:
    """Return a binary file object for ``path``, decompressing on the fly if needed.

    Args:
        path: A filesystem path (``str``, ``bytes`` or ``os.PathLike``) or an
            already opened file object.

    Returns:
        ``path`` itself when it is not a filesystem path. Otherwise a handle
        opened in ``'rb'`` mode: through ``gzip.open`` when the name ends with
        ``.gz``, through ``lzma.open`` when it ends with ``.xz``, and through
        the builtin ``open`` in every other case.
    """
    # The path test uses only the standard library. It previously went through
    # the private ``torch._six.string_classes`` helper, which current PyTorch
    # releases no longer provide.
    if not isinstance(path, (str, bytes, os.PathLike)):
        return path

    # Normalise to ``str`` for the suffix checks so that ``bytes`` and
    # ``os.PathLike`` paths are handled as well.
    name = os.fsdecode(path)
    if name.endswith('.gz'):
        return gzip.open(path, 'rb')
    if name.endswith('.xz'):
        return lzma.open(path, 'rb')
    return open(path, 'rb')
SN3_PASCALVINCENT_TYPEMAP = {
8: (torch.uint8, np.uint8, np.uint8),
9: (torch.int8, np.int8, np.int8),
......@@ -482,12 +478,12 @@ SN3_PASCALVINCENT_TYPEMAP = {
}
def read_sn3_pascalvincent_tensor(path: Union[str, IO], strict: bool = True) -> torch.Tensor:
def read_sn3_pascalvincent_tensor(path: str, strict: bool = True) -> torch.Tensor:
"""Read a SN3 file in "Pascal Vincent" format (Lush file 'libidx/idx-io.lsh').
Argument may be a filename, compressed filename, or file object.
"""
# read
with open_maybe_compressed_file(path) as f:
with open(path, "rb") as f:
data = f.read()
# parse
magic = get_int(data[0:4])
......@@ -503,16 +499,14 @@ def read_sn3_pascalvincent_tensor(path: Union[str, IO], strict: bool = True) ->
def read_label_file(path: str) -> torch.Tensor:
with open(path, 'rb') as f:
x = read_sn3_pascalvincent_tensor(f, strict=False)
x = read_sn3_pascalvincent_tensor(path, strict=False)
assert(x.dtype == torch.uint8)
assert(x.ndimension() == 1)
return x.long()
def read_image_file(path: str) -> torch.Tensor:
with open(path, 'rb') as f:
x = read_sn3_pascalvincent_tensor(f, strict=False)
x = read_sn3_pascalvincent_tensor(path, strict=False)
assert(x.dtype == torch.uint8)
assert(x.ndimension() == 3)
return x
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment