test_datasets.py 490 Bytes
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
import torch
from torchvision.datasets import MNIST, FashionMNIST
import unittest
import tempfile
import shutil
import os


class Tester(unittest.TestCase):
    """Regression tests for the torchvision MNIST-family datasets."""

    def setUp(self):
        # Give every test a fresh, empty download root so files left behind
        # by one dataset (or by a previous run) cannot leak into another test.
        self.test_dir = tempfile.mkdtemp()

    def tearDown(self):
        # Remove the per-test download root and everything downloaded into it.
        shutil.rmtree(self.test_dir)

    def test_fashion_mnist_doesnt_load_mnist(self):
        """FashionMNIST must not reuse MNIST files already present in ``root``.

        Both datasets are downloaded into the same root directory, MNIST
        first. If FashionMNIST mistook the existing MNIST files for its own,
        it would skip its download and serve MNIST images. Comparing the
        first training image of each dataset detects that. Without the
        assertion, this test would only check that no exception is raised.

        Requires network access, because both datasets use ``download=True``.
        """
        mnist = MNIST(root=self.test_dir, download=True)
        fashion = FashionMNIST(root=self.test_dir, download=True)

        # Indexing a torchvision dataset returns (image, target). With no
        # transform, the image is a PIL image. Comparing raw pixel bytes
        # avoids depending on the datasets' internal tensor attribute names,
        # which differ between torchvision versions.
        mnist_img, _ = mnist[0]
        fashion_img, _ = fashion[0]
        self.assertNotEqual(mnist_img.tobytes(), fashion_img.tobytes())


if __name__ == "__main__":
    # Executed directly as a script: discover and run the tests in this module.
    unittest.main()