test_datasets_utils.py 3.58 KB
Newer Older
1
import os
Francisco Massa's avatar
Francisco Massa committed
2
3
4
5
import shutil
import tempfile
import torchvision.datasets.utils as utils
import unittest
6
7
8
import zipfile
import tarfile
import gzip
Francisco Massa's avatar
Francisco Massa committed
9

10
11
12
# Absolute path of the JPEG fixture stored in the ``assets`` folder that sits
# next to this test module.
TEST_FILE = os.path.join(
    os.path.dirname(os.path.abspath(__file__)),
    'assets',
    'grace_hopper_517x606.jpg',
)

Francisco Massa's avatar
Francisco Massa committed
13
14
15

class Tester(unittest.TestCase):
    """Tests for the file helpers in ``torchvision.datasets.utils``.

    Covers md5 / integrity checks, URL downloads and archive extraction.
    Every temporary directory is registered with ``addCleanup`` so it is
    removed even when a download or an assertion fails part-way through.
    """

    # md5 digest of the TEST_FILE fixture, shared by the checksum tests.
    _TEST_FILE_MD5 = '9c0bb82894bb3af7f7675ef2b3b6dcdc'

    # Payload written into every archive built by the extraction tests.
    _CONTENT = 'this is the content'

    def _make_temp_dir(self):
        """Create a temporary directory that is removed when the test ends."""
        temp_dir = tempfile.mkdtemp()
        self.addCleanup(shutil.rmtree, temp_dir)
        return temp_dir

    def _assert_file_content(self, path, expected):
        """Assert that ``path`` exists and its text content equals ``expected``."""
        assert os.path.exists(path), path
        with open(path, 'r') as nf:
            data = nf.read()
        assert data == expected, data

    def _assert_download_populates_dir(self, url):
        """Download ``url`` into a fresh directory and assert it is non-empty."""
        temp_dir = self._make_temp_dir()
        utils.download_url(url, temp_dir)
        assert os.listdir(temp_dir), \
            'The downloaded root directory is empty after download.'

    def test_check_md5(self):
        false_md5 = ''
        assert utils.check_md5(TEST_FILE, self._TEST_FILE_MD5)
        assert not utils.check_md5(TEST_FILE, false_md5)

    def test_check_integrity(self):
        nonexisting_fpath = ''
        false_md5 = ''
        assert utils.check_integrity(TEST_FILE, self._TEST_FILE_MD5)
        assert not utils.check_integrity(TEST_FILE, false_md5)
        # With no md5 given, an existing file passes and a missing path fails.
        assert utils.check_integrity(TEST_FILE)
        assert not utils.check_integrity(nonexisting_fpath)

    def test_download_url(self):
        self._assert_download_populates_dir(
            "http://github.com/pytorch/vision/archive/master.zip")

    def test_download_url_retry_http(self):
        self._assert_download_populates_dir(
            "https://github.com/pytorch/vision/archive/master.zip")

    def test_extract_zip(self):
        temp_dir = self._make_temp_dir()
        with tempfile.NamedTemporaryFile(suffix='.zip') as f:
            with zipfile.ZipFile(f, 'w') as zf:
                zf.writestr('file.tst', self._CONTENT)
            # extract_archive reopens the archive by name, so make sure every
            # byte written through ``f`` has reached the file first.
            f.flush()
            utils.extract_archive(f.name, temp_dir)
        self._assert_file_content(os.path.join(temp_dir, 'file.tst'), self._CONTENT)

    def test_extract_tar(self):
        for ext, mode in (('.tar', 'w'), ('.tar.gz', 'w:gz')):
            temp_dir = self._make_temp_dir()
            with tempfile.NamedTemporaryFile() as bf:
                bf.write(self._CONTENT.encode())
                # tarfile reads the source file by name, so flush the buffer.
                bf.flush()
                with tempfile.NamedTemporaryFile(suffix=ext) as f:
                    with tarfile.open(f.name, mode=mode) as zf:
                        zf.add(bf.name, arcname='file.tst')
                    utils.extract_archive(f.name, temp_dir)
            self._assert_file_content(os.path.join(temp_dir, 'file.tst'), self._CONTENT)

    def test_extract_gzip(self):
        temp_dir = self._make_temp_dir()
        with tempfile.NamedTemporaryFile(suffix='.gz') as f:
            with gzip.GzipFile(f.name, 'wb') as zf:
                zf.write(self._CONTENT.encode())
            utils.extract_archive(f.name, temp_dir)
            # The test expects the output to be named after the archive with
            # its extension stripped.
            f_name = os.path.join(temp_dir, os.path.splitext(os.path.basename(f.name))[0])
        self._assert_file_content(f_name, self._CONTENT)

Francisco Massa's avatar
Francisco Massa committed
88
89
90

# Allow running this test module directly (``python test_datasets_utils.py``)
# in addition to collection by a test runner.
if __name__ == '__main__':
    unittest.main()