Unverified Commit aa365993 authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

add tests for MNIST and variants (#3423)

* add tests for MNIST and variants

* remove old tests and fakedata generation

* fix default config detection for datasets that accept variable keyword arguments

* use split="mnist" as default for EMNIST

* fix QMNIST tests

* lint

* fix special kwargs detection

* Revert "use split="mnist" as default for EMNIST"

This reverts commit 62c9b23597f4a391cad409cbd93edac1b3474acf.

* fix tests

* fix QMNIST test case name

* remove dead code from test

* Revert "remove old tests and fakedata generation"

This reverts commit a285b97c4827566a5bc06c288f5bba8d614a4f7a.

* remove old tests

* re-add removed import
parent 0818c682
......@@ -10,8 +10,8 @@ from torch._utils_internal import get_file_path_2
import torchvision
from torchvision.datasets import utils
from common_utils import get_tmp_dir
from fakedata_generation import mnist_root, \
cityscapes_root, svhn_root, places365_root, widerface_root, stl10_root
from fakedata_generation import cityscapes_root, svhn_root, places365_root, widerface_root, stl10_root
import xml.etree.ElementTree as ET
from urllib.request import Request, urlopen
import itertools
......@@ -119,37 +119,6 @@ class Tester(DatasetTestcase):
root, loader=lambda x: x, is_valid_file=lambda x: False
)
@mock.patch('torchvision.datasets.mnist.download_and_extract_archive')
@mock.patch('torchvision.datasets.mnist.check_integrity', return_value=True)
def test_mnist(self, mock_check_integrity, mock_download_extract):
    # Stacked ``mock.patch`` decorators are applied bottom-up, so the mock for ``check_integrity`` is passed
    # first and the mock for ``download_and_extract_archive`` second. Neither mock is inspected here; they
    # only replace the patched functions while the dataset is constructed.
    num_examples = 30
    with mnist_root(num_examples, "MNIST") as root:
        dataset = torchvision.datasets.MNIST(root, download=True)
        self.generic_classification_dataset_test(dataset, num_images=num_examples)
        # Only the target of the first example is needed for the class index check.
        _, target = dataset[0]
        self.assertEqual(dataset.class_to_idx[dataset.classes[0]], target)
@mock.patch('torchvision.datasets.mnist.download_and_extract_archive')
@mock.patch('torchvision.datasets.mnist.check_integrity', return_value=True)
def test_kmnist(self, mock_check_integrity, mock_download_extract):
    # Stacked ``mock.patch`` decorators are applied bottom-up, so the mock for ``check_integrity`` is passed
    # first and the mock for ``download_and_extract_archive`` second. Neither mock is inspected here; they
    # only replace the patched functions while the dataset is constructed.
    num_examples = 30
    with mnist_root(num_examples, "KMNIST") as root:
        dataset = torchvision.datasets.KMNIST(root, download=True)
        self.generic_classification_dataset_test(dataset, num_images=num_examples)
        # Only the target of the first example is needed for the class index check.
        _, target = dataset[0]
        self.assertEqual(dataset.class_to_idx[dataset.classes[0]], target)
@mock.patch('torchvision.datasets.mnist.download_and_extract_archive')
@mock.patch('torchvision.datasets.mnist.check_integrity', return_value=True)
def test_fashionmnist(self, mock_check_integrity, mock_download_extract):
    # Stacked ``mock.patch`` decorators are applied bottom-up, so the mock for ``check_integrity`` is passed
    # first and the mock for ``download_and_extract_archive`` second. Neither mock is inspected here; they
    # only replace the patched functions while the dataset is constructed.
    num_examples = 30
    with mnist_root(num_examples, "FashionMNIST") as root:
        dataset = torchvision.datasets.FashionMNIST(root, download=True)
        self.generic_classification_dataset_test(dataset, num_images=num_examples)
        # Only the target of the first example is needed for the class index check.
        _, target = dataset[0]
        self.assertEqual(dataset.class_to_idx[dataset.classes[0]], target)
@unittest.skipIf(sys.platform in ('win32', 'cygwin'), 'temporarily disabled on Windows')
def test_cityscapes(self):
with cityscapes_root() as root:
......@@ -1499,5 +1468,131 @@ class Flickr30kTestCase(Flickr8kTestCase):
fh.write(f"{image.name}#{idx}\t{caption}\n")
class MNISTTestCase(datasets_utils.ImageDatasetTestCase):
    """Test case for ``datasets.MNIST`` backed by small, randomly filled binary fixture files.

    Subclasses adapt the fixture to an MNIST variant by overriding ``DATASET_CLASS`` and, where the file names,
    label layout, or number of examples differ, the ``_prefix`` / ``_images_file`` / ``_labels_file`` /
    ``_num_images`` hooks and the ``_IMAGES_*`` / ``_LABELS_*`` attributes.
    """

    DATASET_CLASS = datasets.MNIST

    ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(train=(True, False))

    # Per-dtype code that ``_magic`` combines with the number of dimensions to form the file's magic number.
    _MAGIC_DTYPES = {
        torch.uint8: 8,
        torch.int8: 9,
        torch.int16: 11,
        torch.int32: 12,
        torch.float32: 13,
        torch.float64: 14,
    }

    # Per-example shape and dtype of the generated images.
    _IMAGES_SIZE = (28, 28)
    _IMAGES_DTYPE = torch.uint8

    # Per-example shape and dtype of the generated labels; ``()`` means one scalar per example.
    _LABELS_SIZE = ()
    _LABELS_DTYPE = torch.uint8

    def inject_fake_data(self, tmpdir, config):
        """Create the image and label files for ``config`` and return the number of generated images.

        The files are written to ``<tmpdir>/<DATASET_CLASS.__name__>/raw``, which is created if missing.
        """
        raw_dir = pathlib.Path(tmpdir) / self.DATASET_CLASS.__name__ / "raw"
        os.makedirs(raw_dir, exist_ok=True)

        num_images = self._num_images(config)
        self._create_binary_file(
            raw_dir, self._images_file(config), (num_images, *self._IMAGES_SIZE), self._IMAGES_DTYPE
        )
        self._create_binary_file(
            raw_dir, self._labels_file(config), (num_images, *self._LABELS_SIZE), self._LABELS_DTYPE
        )
        return num_images

    def _num_images(self, config):
        """Return how many examples to generate: 2 for the training split, 1 otherwise."""
        return 2 if config["train"] else 1

    def _images_file(self, config):
        """Return the name of the images file for ``config``."""
        return f"{self._prefix(config)}-images-idx3-ubyte"

    def _labels_file(self, config):
        """Return the name of the labels file for ``config``."""
        return f"{self._prefix(config)}-labels-idx1-ubyte"

    def _prefix(self, config):
        """Return the file name prefix selecting the split: ``"train"`` or ``"t10k"``."""
        return "train" if config["train"] else "t10k"

    def _create_binary_file(self, root, filename, size, dtype):
        """Write ``root / filename``: the encoded magic number and dimensions, followed by random data.

        Args:
            root: Directory the file is created in.
            filename: Name of the file to create.
            size: Shape of the data; every entry is written to the header after the magic number.
            dtype: Integer torch dtype of the data; has to be a key of ``_MAGIC_DTYPES``.
        """
        with open(pathlib.Path(root) / filename, "wb") as fh:
            for meta in (self._magic(dtype, len(size)), *size):
                fh.write(self._encode(meta))

            # If ever an MNIST variant is added that uses floating point data, this should be adapted.
            data = torch.randint(0, torch.iinfo(dtype).max + 1, size, dtype=dtype)
            fh.write(data.numpy().tobytes())

    def _magic(self, dtype, dims):
        """Return the magic number for ``dtype`` and ``dims``: the dtype code times 256 plus ``dims``."""
        return self._MAGIC_DTYPES[dtype] * 256 + dims

    def _encode(self, v):
        """Return ``v`` encoded as a 4-byte big-endian signed integer.

        ``int.to_bytes`` is used rather than reversing the native bytes of an ``int32`` value, so the result is
        the same regardless of the byte order of the host.
        """
        return int(v).to_bytes(4, byteorder="big", signed=True)
class FashionMNISTTestCase(MNISTTestCase):
    """Runs the ``MNISTTestCase`` fixture and checks against ``datasets.FashionMNIST``."""

    DATASET_CLASS = datasets.FashionMNIST
class KMNISTTestCase(MNISTTestCase):
    """Runs the ``MNISTTestCase`` fixture and checks against ``datasets.KMNIST``."""

    DATASET_CLASS = datasets.KMNIST
class EMNISTTestCase(MNISTTestCase):
    """Runs the ``MNISTTestCase`` fixture against ``datasets.EMNIST`` for every ``split`` / ``train`` combination."""

    DATASET_CLASS = datasets.EMNIST
    DEFAULT_CONFIG = {"split": "byclass"}
    ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(
        split=("byclass", "bymerge", "balanced", "letters", "digits", "mnist"),
        train=(True, False),
    )

    def _prefix(self, config):
        """Return the file name prefix ``emnist-<split>-<train|test>`` for ``config``."""
        split = config["split"]
        subset = "train" if config["train"] else "test"
        return f"emnist-{split}-{subset}"
class QMNISTTestCase(MNISTTestCase):
    """Runs the ``MNISTTestCase`` fixture against ``datasets.QMNIST``.

    The split is selected through the ``what`` keyword, and each label consists of 8 ``int32`` values.
    """

    DATASET_CLASS = datasets.QMNIST
    ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(what=("train", "test", "test10k", "nist"))

    _LABELS_SIZE = (8,)
    _LABELS_DTYPE = torch.int32

    def _num_images(self, config):
        """Return the number of examples to generate for the split named by ``config["what"]``."""
        what = config["what"]
        if what == "nist":
            return 3
        if what == "train":
            return 2
        if what == "test50k":
            # The split 'test50k' is defined as the last 50k images beginning at index 10000. Thus, we need to create
            # more than 10000 images for the dataset to not be empty. Since this takes significantly longer than the
            # creation of all other splits, this is excluded from the 'ADDITIONAL_CONFIGS' and is tested only once in
            # 'test_num_examples_test50k'.
            return 10001
        return 1

    def _labels_file(self, config):
        """Return the name of the labels file for ``config``."""
        prefix = self._prefix(config)
        return f"{prefix}-labels-idx2-int"

    def _prefix(self, config):
        """Return the file name prefix: ``"xnist"`` for the 'nist' split, ``"qmnist-<split>"`` otherwise."""
        what = config["what"]
        if what == "nist":
            return "xnist"

        if what is None:
            # Without an explicit 'what', the split follows the 'train' flag.
            what = "train" if config["train"] else "test"
        elif what.startswith("test"):
            # Every 'test*' split shares the files of the 'test' split.
            what = "test"
        return f"qmnist-{what}"

    def test_num_examples_test50k(self):
        with self.create_dataset(what="test50k") as (dataset, info):
            actual = len(dataset)
            # Since the split 'test50k' selects all images beginning from the index 10000, we subtract the number of
            # created examples by this.
            expected = info["num_examples"] - 10000
            self.assertEqual(actual, expected)
# Run the test suite when this file is executed as a script.
if __name__ == "__main__":
    unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment