test_datasets_utils.py 6.01 KB
Newer Older
1
import os
2
import sys
Francisco Massa's avatar
Francisco Massa committed
3
4
5
import tempfile
import torchvision.datasets.utils as utils
import unittest
6
7
8
import zipfile
import tarfile
import gzip
9
10
import warnings
from torch._utils_internal import get_file_path_2
Francisco Massa's avatar
Francisco Massa committed
11

12
13
14
15
16
17
18
19
20
21
from common_utils import get_tmp_dir

# URLError is imported from ``urllib.error`` on Python 3 and from ``urllib2``
# on Python 2; pick whichever matches the running interpreter.
if sys.version_info >= (3,):
    from urllib.error import URLError
else:
    from urllib2 import URLError


# Sample image used as a fixture by the checksum tests: the path is built
# from this file's directory, the ``assets`` folder and the image file name.
TEST_FILE = get_file_path_2(
    os.path.dirname(os.path.abspath(__file__)),
    'assets',
    'grace_hopper_517x606.jpg',
)
22

Francisco Massa's avatar
Francisco Massa committed
23
24
25

class Tester(unittest.TestCase):
    """Tests for the helpers in ``torchvision.datasets.utils``.

    Covers checksum verification (``check_md5`` / ``check_integrity``),
    downloading (``download_url``), archive extraction (``extract_archive``)
    and string-argument validation (``verify_str_arg``).
    """

    # Text stored in every archive built by the extraction tests and expected
    # back, unchanged, after extraction.
    _CONTENT = 'this is the content'
    # Name of the single member stored in the zip / tar test archives.
    _MEMBER = 'file.tst'

    def _assert_file_content(self, path, expected):
        """Assert that ``path`` exists and its text content equals ``expected``."""
        self.assertTrue(os.path.exists(path))
        with open(path, 'r') as extracted:
            data = extracted.read()
        self.assertEqual(data, expected)

    def _assert_download_succeeds(self, url):
        """Download ``url`` into a temporary directory and check it is not empty.

        Raises:
            unittest.SkipTest: if the download raises ``URLError``, so that an
                unreachable network skips the test instead of failing it. A
                ``RuntimeWarning`` with the same message is emitted first.
        """
        with get_tmp_dir() as temp_dir:
            try:
                utils.download_url(url, temp_dir)
            except URLError:
                msg = "could not download test file '{}'".format(url)
                warnings.warn(msg, RuntimeWarning)
                raise unittest.SkipTest(msg)
            self.assertNotEqual(os.listdir(temp_dir), [])

    def _assert_tar_roundtrip(self, ext, mode):
        """Build a tar archive with suffix ``ext`` and check it extracts correctly.

        Args:
            ext: file suffix given to the temporary archive (e.g. ``'.tar.gz'``).
            mode: ``tarfile.open`` write mode matching ``ext`` (e.g. ``'w:gz'``).
        """
        with get_tmp_dir() as temp_dir:
            with tempfile.NamedTemporaryFile() as payload:
                payload.write(self._CONTENT.encode())
                # tarfile re-reads the payload from disk by name, so the
                # buffered bytes must be flushed before it is added.
                payload.flush()
                with tempfile.NamedTemporaryFile(suffix=ext) as archive:
                    with tarfile.open(archive.name, mode=mode) as tf:
                        tf.add(payload.name, arcname=self._MEMBER)
                    utils.extract_archive(archive.name, temp_dir)
                    self._assert_file_content(
                        os.path.join(temp_dir, self._MEMBER), self._CONTENT)

    def test_check_md5(self):
        """``check_md5`` accepts the fixture's MD5 and rejects an empty digest."""
        fpath = TEST_FILE
        correct_md5 = '9c0bb82894bb3af7f7675ef2b3b6dcdc'
        false_md5 = ''
        self.assertTrue(utils.check_md5(fpath, correct_md5))
        self.assertFalse(utils.check_md5(fpath, false_md5))

    def test_check_integrity(self):
        """``check_integrity`` with and without a digest, and for a missing path."""
        existing_fpath = TEST_FILE
        nonexisting_fpath = ''
        correct_md5 = '9c0bb82894bb3af7f7675ef2b3b6dcdc'
        false_md5 = ''
        self.assertTrue(utils.check_integrity(existing_fpath, correct_md5))
        self.assertFalse(utils.check_integrity(existing_fpath, false_md5))
        # Without a digest only the existence of the path is being checked here.
        self.assertTrue(utils.check_integrity(existing_fpath))
        self.assertFalse(utils.check_integrity(nonexisting_fpath))

    def test_download_url(self):
        """Downloading over plain ``http`` leaves at least one file on disk."""
        self._assert_download_succeeds(
            "http://github.com/pytorch/vision/archive/master.zip")

    def test_download_url_retry_http(self):
        """Downloading from an ``https`` URL leaves at least one file on disk.

        NOTE(review): the test name suggests it targets a retry/fallback path
        inside ``download_url``; that logic is not visible from this file.
        """
        self._assert_download_succeeds(
            "https://github.com/pytorch/vision/archive/master.zip")

    @unittest.skipIf(sys.version_info < (3,), "Python2 doesn't raise error")
    def test_download_url_dont_exist(self):
        """A URL that does not exist makes ``download_url`` raise ``URLError``."""
        with get_tmp_dir() as temp_dir:
            url = "http://github.com/pytorch/vision/archive/this_doesnt_exist.zip"
            with self.assertRaises(URLError):
                utils.download_url(url, temp_dir)

    # NOTE: the platform is matched exactly. A substring test such as
    # ``'win' in sys.platform`` is also true for 'darwin' and would silently
    # skip these tests on macOS as well.
    @unittest.skipIf(sys.platform in ('win32', 'cygwin'), 'temporarily disabled on Windows')
    def test_extract_zip(self):
        """A ``.zip`` archive is extracted into the target directory."""
        with get_tmp_dir() as temp_dir:
            with tempfile.NamedTemporaryFile(suffix='.zip') as f:
                with zipfile.ZipFile(f, 'w') as zf:
                    zf.writestr(self._MEMBER, self._CONTENT)
                utils.extract_archive(f.name, temp_dir)
                self._assert_file_content(
                    os.path.join(temp_dir, self._MEMBER), self._CONTENT)

    @unittest.skipIf(sys.platform in ('win32', 'cygwin'), 'temporarily disabled on Windows')
    def test_extract_tar(self):
        """Plain and gzip-compressed tar archives are extracted correctly."""
        for ext, mode in [('.tar', 'w'), ('.tar.gz', 'w:gz'), ('.tgz', 'w:gz')]:
            self._assert_tar_roundtrip(ext, mode)

    @unittest.skipIf(sys.platform in ('win32', 'cygwin'), 'temporarily disabled on Windows')
    @unittest.skipIf(sys.version_info < (3,), "Extracting .tar.xz files is not supported under Python 2.x")
    def test_extract_tar_xz(self):
        """An xz-compressed tar archive is extracted correctly (Python 3 only)."""
        self._assert_tar_roundtrip('.tar.xz', 'w:xz')

    @unittest.skipIf(sys.platform in ('win32', 'cygwin'), 'temporarily disabled on Windows')
    def test_extract_gzip(self):
        """A bare ``.gz`` file is decompressed into the target directory."""
        with get_tmp_dir() as temp_dir:
            with tempfile.NamedTemporaryFile(suffix='.gz') as f:
                with gzip.GzipFile(f.name, 'wb') as zf:
                    zf.write(self._CONTENT.encode())
                utils.extract_archive(f.name, temp_dir)
                # The output is expected under the archive's base name with
                # the trailing ``.gz`` suffix removed.
                f_name = os.path.join(temp_dir, os.path.splitext(os.path.basename(f.name))[0])
                self._assert_file_content(f_name, self._CONTENT)

    def test_verify_str_arg(self):
        """``verify_str_arg`` returns valid values and rejects invalid ones.

        All three calls use the same positional order as the positive case:
        ``(value, arg, valid_values)``.
        """
        self.assertEqual("a", utils.verify_str_arg("a", "arg", ("a",)))
        # A non-string value is rejected.
        self.assertRaises(ValueError, utils.verify_str_arg, 0, "arg", ("a",))
        # A string outside ``valid_values`` is rejected.
        self.assertRaises(ValueError, utils.verify_str_arg, "b", "arg", ("a",))

Francisco Massa's avatar
Francisco Massa committed
135
136
137

# Run the tests in this module when the file is executed directly as a script.
if __name__ == '__main__':
    unittest.main()