# test_datasets_utils.py -- tests for torchvision.datasets.utils
import os
import torchvision.datasets.utils as utils
import unittest
import unittest.mock
import zipfile
import tarfile
import gzip
import warnings
from torch._utils_internal import get_file_path_2
from urllib.error import URLError
from common_utils import get_tmp_dir


# Path to a JPEG under the ``assets`` directory that sits next to this file;
# the checksum tests below hash it.
TEST_FILE = get_file_path_2(
    os.path.dirname(os.path.abspath(__file__)),
    'assets',
    'encode_jpeg',
    'grace_hopper_517x606.jpg',
)

class Tester(unittest.TestCase):
    """Tests for the helper functions in ``torchvision.datasets.utils``.

    Several tests hand real URLs to the code under test. The download tests
    turn a ``URLError`` into a skip so that an unreachable server does not
    show up as a failure.
    """

    # MD5 digest that the checksum tests expect for TEST_FILE.
    _TEST_FILE_MD5 = '9c0bb82894bb3af7f7675ef2b3b6dcdc'

    @staticmethod
    def _create_tar_archive(root, ext, mode, content="this is the content"):
        """Build a one-member tar archive inside ``root``.

        ``content`` is written to ``src.txt`` and added to ``archive<ext>``
        under the member name ``dst.txt``.

        Args:
            root: Directory in which the source file and archive are created.
            ext: Archive file extension, including the leading dot.
            mode: Mode string passed to ``tarfile.open`` (e.g. ``'w:gz'``).
            content: Text stored in the archived file.

        Returns:
            Tuple ``(archive_path, expected_extracted_path, content)``.
        """
        src = os.path.join(root, "src.txt")
        dst = os.path.join(root, "dst.txt")
        archive = os.path.join(root, f"archive{ext}")

        with open(src, "w") as fh:
            fh.write(content)

        # Store the file under the name of ``dst`` so that extraction yields
        # a new file rather than overwriting ``src``.
        with tarfile.open(archive, mode=mode) as fh:
            fh.add(src, arcname=os.path.basename(dst))

        return archive, dst, content

    def _assert_extracted(self, archive, root, file, content):
        """Extract ``archive`` into ``root`` and check that ``file`` holds ``content``."""
        utils.extract_archive(archive, root)

        self.assertTrue(os.path.exists(file))

        with open(file, "r") as fh:
            self.assertEqual(fh.read(), content)

    def _assert_downloads(self, url):
        """Download ``url`` into a temporary directory and check it is non-empty.

        Raises:
            unittest.SkipTest: If ``download_url`` raises ``URLError``.
        """
        with get_tmp_dir() as temp_dir:
            try:
                utils.download_url(url, temp_dir)
            except URLError:
                msg = "could not download test file '{}'".format(url)
                warnings.warn(msg, RuntimeWarning)
                raise unittest.SkipTest(msg)

            self.assertTrue(os.listdir(temp_dir))

    def test_check_md5(self):
        """``check_md5`` accepts the expected digest and rejects an empty one."""
        self.assertTrue(utils.check_md5(TEST_FILE, self._TEST_FILE_MD5))
        self.assertFalse(utils.check_md5(TEST_FILE, ''))

    def test_check_integrity(self):
        """``check_integrity`` with and without a digest, and for an empty path."""
        existing_fpath = TEST_FILE
        nonexisting_fpath = ''

        self.assertTrue(utils.check_integrity(existing_fpath, self._TEST_FILE_MD5))
        self.assertFalse(utils.check_integrity(existing_fpath, ''))
        # Without a digest only the path itself is checked.
        self.assertTrue(utils.check_integrity(existing_fpath))
        self.assertFalse(utils.check_integrity(nonexisting_fpath))

    def test_get_redirect_url(self):
        """``_get_redirect_url`` resolves the CUB-200 URL to its Google Drive page."""
        url = "http://www.vision.caltech.edu/visipedia-data/CUB-200-2011/CUB_200_2011.tgz"
        expected = "https://drive.google.com/file/d/1hbzc_P1FuxMkcabkgn9ZKinBwW683j45/view"

        # self.assert* instead of a bare ``assert``: the latter is removed
        # when Python runs with -O.
        self.assertEqual(utils._get_redirect_url(url), expected)

    def test_get_redirect_url_max_hops_exceeded(self):
        """With ``max_hops=0`` a redirecting URL must raise ``RecursionError``."""
        url = "http://www.vision.caltech.edu/visipedia-data/CUB-200-2011/CUB_200_2011.tgz"
        with self.assertRaises(RecursionError):
            utils._get_redirect_url(url, max_hops=0)

    def test_get_google_drive_file_id(self):
        """The file id is extracted from a Google Drive ``/file/d/<id>/view`` URL."""
        url = "https://drive.google.com/file/d/1hbzc_P1FuxMkcabkgn9ZKinBwW683j45/view"
        expected = "1hbzc_P1FuxMkcabkgn9ZKinBwW683j45"

        self.assertEqual(utils._get_google_drive_file_id(url), expected)

    def test_get_google_drive_file_id_invalid_url(self):
        """A URL that is not a Google Drive link yields ``None``."""
        url = "http://www.vision.caltech.edu/visipedia-data/CUB-200-2011/CUB_200_2011.tgz"

        self.assertIsNone(utils._get_google_drive_file_id(url))

    def test_download_url(self):
        """``download_url`` puts something into the target directory (http URL)."""
        self._assert_downloads("http://github.com/pytorch/vision/archive/master.zip")

    def test_download_url_retry_http(self):
        """Same check as ``test_download_url`` but starting from the https URL.

        NOTE(review): the test name suggests ``download_url`` falls back to
        http when https fails; that logic is not visible in this file.
        """
        self._assert_downloads("https://github.com/pytorch/vision/archive/master.zip")

    def test_download_url_dont_exist(self):
        """Downloading a URL that does not exist raises ``URLError``."""
        with get_tmp_dir() as temp_dir:
            url = "http://github.com/pytorch/vision/archive/this_doesnt_exist.zip"
            with self.assertRaises(URLError):
                utils.download_url(url, temp_dir)

    @unittest.mock.patch("torchvision.datasets.utils.download_file_from_google_drive")
    def test_download_url_dispatch_download_from_google_drive(self, mock):
        """A Google Drive URL is handed to ``download_file_from_google_drive``."""
        url = "https://drive.google.com/file/d/1hbzc_P1FuxMkcabkgn9ZKinBwW683j45/view"

        file_id = "1hbzc_P1FuxMkcabkgn9ZKinBwW683j45"
        filename = "filename"
        md5 = "md5"

        with get_tmp_dir() as root:
            utils.download_url(url, root, filename, md5)

        mock.assert_called_once_with(file_id, root, filename, md5)

    def test_extract_zip(self):
        """``extract_archive`` unpacks a ``.zip`` archive."""
        content = "this is the content"

        with get_tmp_dir() as temp_dir:
            file = os.path.join(temp_dir, "dst.txt")
            archive = os.path.join(temp_dir, "archive.zip")

            with zipfile.ZipFile(archive, "w") as zf:
                zf.writestr(os.path.basename(file), content)

            self._assert_extracted(archive, temp_dir, file, content)

    def test_extract_tar(self):
        """``extract_archive`` unpacks plain and gzip-compressed tar archives."""
        formats = (('.tar', 'w'), ('.tar.gz', 'w:gz'), ('.tgz', 'w:gz'))

        for ext, mode in formats:
            # subTest makes a failure report name the archive format.
            with self.subTest(ext=ext), get_tmp_dir() as temp_dir:
                archive, file, content = self._create_tar_archive(temp_dir, ext, mode)
                self._assert_extracted(archive, temp_dir, file, content)

    def test_extract_tar_xz(self):
        """``extract_archive`` unpacks an xz-compressed tar archive."""
        with get_tmp_dir() as temp_dir:
            archive, file, content = self._create_tar_archive(temp_dir, '.tar.xz', 'w:xz')
            self._assert_extracted(archive, temp_dir, file, content)

    def test_extract_gzip(self):
        """``extract_archive`` decompresses a bare ``.gz`` file."""
        content = "this is the content"

        with get_tmp_dir() as temp_dir:
            file = os.path.join(temp_dir, "file")
            compressed = f"{file}.gz"

            with gzip.GzipFile(compressed, "wb") as fh:
                fh.write(content.encode())

            self._assert_extracted(compressed, temp_dir, file, content)

    def test_verify_str_arg(self):
        """``verify_str_arg`` returns valid values and rejects everything else."""
        self.assertEqual("a", utils.verify_str_arg("a", "arg", ("a",)))

        # Argument order is (value, arg, valid_values). The name must come
        # before the tuple, otherwise the tuple is taken as the argument name
        # and the string "arg" as the set of valid values.
        with self.assertRaises(ValueError):
            utils.verify_str_arg(0, "arg", ("a",))
        with self.assertRaises(ValueError):
            utils.verify_str_arg("b", "arg", ("a",))


# Run the suite with unittest's own runner when this file is executed as a
# script rather than collected by a test runner.
if __name__ == "__main__":
    unittest.main()