Unverified commit 8a64dbcd, authored by Francisco Massa, committed by GitHub
Browse files

Mock MNIST download for less flaky tests (#1004)

parent de387e8c
......@@ -5,6 +5,7 @@ import tempfile
import unittest
import mock
import PIL
import torch
import torchvision
FAKEDATA_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)),
......@@ -23,6 +24,41 @@ def tmp_dir(src=None, **kwargs):
shutil.rmtree(tmp_dir)
@contextlib.contextmanager
def get_mnist_data(num_images, cls_name, **kwargs):
    """Yield a temporary root directory populated with fake MNIST-style raw files.

    Four files are written under ``<root>/<cls_name>/raw``: image and label
    files for both the "train" and the "t10k" split, each describing
    ``num_images`` random 28x28 ``uint8`` images or random labels in ``[0, 10)``.

    Args:
        num_images: number of samples stored in every generated file.
        cls_name: name of the sub-directory created below the temporary root
            (e.g. ``"MNIST"``).
        **kwargs: forwarded unchanged to ``tempfile.mkdtemp``.

    Yields:
        Path of the temporary root directory. The directory is removed when
        the context exits, including when building the fixture itself fails.
    """
    # Local import keeps this helper self-contained without touching the
    # module-level import block.
    import struct

    def _encode(v):
        # Header fields are 32-bit integers stored most-significant byte
        # first; packing explicitly keeps the output independent of the
        # host's native byte order.
        return struct.pack(">i", v)

    def _make_image_file(filename, num_images):
        # The upper bound of randint is exclusive, so 256 covers the whole
        # uint8 range including 255.
        img = torch.randint(0, 256, size=(28 * 28 * num_images,), dtype=torch.uint8)
        with open(filename, "wb") as f:
            f.write(_encode(2051))  # magic header
            f.write(_encode(num_images))
            f.write(_encode(28))
            f.write(_encode(28))
            f.write(img.numpy().tobytes())

    def _make_label_file(filename, num_images):
        labels = torch.randint(0, 10, size=(num_images,), dtype=torch.uint8)
        with open(filename, "wb") as f:
            f.write(_encode(2049))  # magic header
            f.write(_encode(num_images))
            f.write(labels.numpy().tobytes())

    tmp_dir = tempfile.mkdtemp(**kwargs)
    # Everything after mkdtemp runs inside the try so the directory is
    # cleaned up even if creating the fixture files raises.
    try:
        raw_dir = os.path.join(tmp_dir, cls_name, "raw")
        os.makedirs(raw_dir)
        _make_image_file(os.path.join(raw_dir, "train-images-idx3-ubyte"), num_images)
        _make_label_file(os.path.join(raw_dir, "train-labels-idx1-ubyte"), num_images)
        _make_image_file(os.path.join(raw_dir, "t10k-images-idx3-ubyte"), num_images)
        _make_label_file(os.path.join(raw_dir, "t10k-labels-idx1-ubyte"), num_images)
        yield tmp_dir
    finally:
        shutil.rmtree(tmp_dir)
class Tester(unittest.TestCase):
def test_imagefolder(self):
......@@ -70,25 +106,33 @@ class Tester(unittest.TestCase):
outputs = sorted([dataset[i] for i in range(len(dataset))])
self.assertEqual(imgs, outputs)
def test_mnist(self):
with tmp_dir() as root:
@mock.patch('torchvision.datasets.mnist.download_and_extract_archive')
def test_mnist(self, mock_download_extract):
num_examples = 30
with get_mnist_data(num_examples, "MNIST") as root:
dataset = torchvision.datasets.MNIST(root, download=True)
self.assertEqual(len(dataset), 60000)
self.assertEqual(len(dataset), num_examples)
img, target = dataset[0]
self.assertTrue(isinstance(img, PIL.Image.Image))
self.assertTrue(isinstance(target, int))
def test_kmnist(self):
with tmp_dir() as root:
@mock.patch('torchvision.datasets.mnist.download_and_extract_archive')
def test_kmnist(self, mock_download_extract):
num_examples = 30
with get_mnist_data(num_examples, "KMNIST") as root:
dataset = torchvision.datasets.KMNIST(root, download=True)
img, target = dataset[0]
self.assertEqual(len(dataset), num_examples)
self.assertTrue(isinstance(img, PIL.Image.Image))
self.assertTrue(isinstance(target, int))
def test_fashionmnist(self):
with tmp_dir() as root:
@mock.patch('torchvision.datasets.mnist.download_and_extract_archive')
def test_fashionmnist(self, mock_download_extract):
num_examples = 30
with get_mnist_data(num_examples, "FashionMNIST") as root:
dataset = torchvision.datasets.FashionMNIST(root, download=True)
img, target = dataset[0]
self.assertEqual(len(dataset), num_examples)
self.assertTrue(isinstance(img, PIL.Image.Image))
self.assertTrue(isinstance(target, int))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment