# Tests for torchvision.datasets.utils (test_datasets_utils.py)
import contextlib
import gzip
import os
import tarfile
import zipfile
import pytest
import torchvision.datasets.utils as utils
from torch._utils_internal import get_file_path_2
from torchvision.datasets.utils import _COMPRESSED_FILE_OPENERS


# Absolute path of a JPEG fixture stored under ``assets/encode_jpeg`` next to this
# test module; the checksum/integrity tests use it as a file with known content.
TEST_FILE = get_file_path_2(
    os.path.dirname(os.path.abspath(__file__)), "assets", "encode_jpeg", "grace_hopper_517x606.jpg"
)


def patch_url_redirection(mocker, redirect_url):
    """Patch ``urllib.request.urlopen`` as seen by ``torchvision.datasets.utils``.

    Every call to the patched ``urlopen`` returns a context manager. Regardless
    of the request it was given, that context manager yields a response object
    whose ``url`` attribute is ``redirect_url``.

    Args:
        mocker: the ``mocker`` fixture whose ``patch`` installs the replacement.
        redirect_url: URL that every fake response reports as its final location.

    Returns:
        Whatever ``mocker.patch`` returns (the installed mock), so callers can
        inspect ``call_count`` / ``call_args_list``.
    """

    class _FakeResponse:
        # Stand-in for the urlopen response; it only carries a ``url`` attribute.
        def __init__(self, url):
            self.url = url

    @contextlib.contextmanager
    def _fake_urlopen(*_args, **_kwargs):
        # Ignore the request entirely and always "redirect" to redirect_url.
        yield _FakeResponse(redirect_url)

    target = "torchvision.datasets.utils.urllib.request.urlopen"
    return mocker.patch(target, side_effect=_fake_urlopen)


class TestDatasetsUtils:
    """Tests for helpers in ``torchvision.datasets.utils``.

    Covers URL-redirect resolution, MD5/integrity checks, Google Drive file-id
    parsing, file-type detection, decompression, archive extraction and
    string-argument validation.
    """

    # Expected MD5 digest of ``TEST_FILE``; shared by the checksum tests below.
    _TEST_FILE_MD5 = "9c0bb82894bb3af7f7675ef2b3b6dcdc"

    def test_get_redirect_url(self, mocker):
        url = "https://url.org"
        expected_redirect_url = "https://redirect.url.org"

        mock = patch_url_redirection(mocker, expected_redirect_url)

        actual = utils._get_redirect_url(url)
        assert actual == expected_redirect_url

        # The fake always answers with ``expected_redirect_url``. Two requests are
        # expected: one for the original URL and one for the redirect target.
        assert mock.call_count == 2
        call_args_1, call_args_2 = mock.call_args_list
        assert call_args_1[0][0].full_url == url
        assert call_args_2[0][0].full_url == expected_redirect_url

    def test_get_redirect_url_max_hops_exceeded(self, mocker):
        url = "https://url.org"
        redirect_url = "https://redirect.url.org"

        mock = patch_url_redirection(mocker, redirect_url)

        # With no hops allowed, the first observed redirect must raise.
        with pytest.raises(RecursionError):
            utils._get_redirect_url(url, max_hops=0)

        assert mock.call_count == 1
        assert mock.call_args[0][0].full_url == url

    def test_check_md5(self):
        fpath = TEST_FILE
        correct_md5 = self._TEST_FILE_MD5
        false_md5 = ""

        assert utils.check_md5(fpath, correct_md5)
        assert not utils.check_md5(fpath, false_md5)

    def test_check_integrity(self):
        existing_fpath = TEST_FILE
        nonexisting_fpath = ""
        correct_md5 = self._TEST_FILE_MD5
        false_md5 = ""

        assert utils.check_integrity(existing_fpath, correct_md5)
        assert not utils.check_integrity(existing_fpath, false_md5)
        # Without an MD5, an existing file passes and a missing one fails.
        assert utils.check_integrity(existing_fpath)
        assert not utils.check_integrity(nonexisting_fpath)

    def test_get_google_drive_file_id(self):
        url = "https://drive.google.com/file/d/1GO-BHUYRuvzr1Gtp2_fqXRsr9TIeYbhV/view"
        expected = "1GO-BHUYRuvzr1Gtp2_fqXRsr9TIeYbhV"

        actual = utils._get_google_drive_file_id(url)
        assert actual == expected

    def test_get_google_drive_file_id_invalid_url(self):
        url = "http://www.vision.caltech.edu/visipedia-data/CUB-200-2011/CUB_200_2011.tgz"

        assert utils._get_google_drive_file_id(url) is None

    # Each expected value is (detected suffix, archive type, compression).
    @pytest.mark.parametrize(
        "file, expected",
        [
            ("foo.tar.bz2", (".tar.bz2", ".tar", ".bz2")),
            ("foo.tar.xz", (".tar.xz", ".tar", ".xz")),
            ("foo.tar", (".tar", ".tar", None)),
            ("foo.tar.gz", (".tar.gz", ".tar", ".gz")),
            ("foo.tbz", (".tbz", ".tar", ".bz2")),
            ("foo.tbz2", (".tbz2", ".tar", ".bz2")),
            ("foo.tgz", (".tgz", ".tar", ".gz")),
            ("foo.bz2", (".bz2", None, ".bz2")),
            ("foo.gz", (".gz", None, ".gz")),
            ("foo.zip", (".zip", ".zip", None)),
            ("foo.xz", (".xz", None, ".xz")),
            ("foo.bar.tar.gz", (".tar.gz", ".tar", ".gz")),
            ("foo.bar.gz", (".gz", None, ".gz")),
            ("foo.bar.zip", (".zip", ".zip", None)),
        ],
    )
    def test_detect_file_type(self, file, expected):
        assert utils._detect_file_type(file) == expected

    @pytest.mark.parametrize("file", ["foo", "foo.tar.baz", "foo.bar"])
    def test_detect_file_type_incompatible(self, file):
        # tests detect file type for no extension, unknown compression and unknown partial extension
        with pytest.raises(RuntimeError):
            utils._detect_file_type(file)

    @pytest.mark.parametrize("extension", [".bz2", ".gz", ".xz"])
    def test_decompress(self, extension, tmpdir):
        def create_compressed(root, content="this is the content"):
            file = os.path.join(root, "file")
            compressed = f"{file}{extension}"
            compressed_file_opener = _COMPRESSED_FILE_OPENERS[extension]

            with compressed_file_opener(compressed, "wb") as fh:
                fh.write(content.encode())

            return compressed, file, content

        compressed, file, content = create_compressed(tmpdir)

        utils._decompress(compressed)

        assert os.path.exists(file)

        with open(file) as fh:
            assert fh.read() == content

    def test_decompress_no_compression(self):
        # "foo.tar" has no compression suffix, so there is nothing to decompress.
        with pytest.raises(RuntimeError):
            utils._decompress("foo.tar")

    def test_decompress_remove_finished(self, tmpdir):
        def create_compressed(root, content="this is the content"):
            file = os.path.join(root, "file")
            compressed = f"{file}.gz"

            with gzip.open(compressed, "wb") as fh:
                fh.write(content.encode())

            return compressed, file, content

        compressed, file, content = create_compressed(tmpdir)

        utils.extract_archive(compressed, tmpdir, remove_finished=True)

        # The compressed source is removed, and the decompressed output must
        # still have been produced with the original content.
        assert not os.path.exists(compressed)
        assert os.path.exists(file)
        with open(file) as fh:
            assert fh.read() == content

    @pytest.mark.parametrize("extension", [".bz2", ".gz", ".xz"])
    @pytest.mark.parametrize("remove_finished", [True, False])
    def test_extract_archive_defer_to_decompress(self, extension, remove_finished, mocker):
        filename = "foo"
        file = f"{filename}{extension}"

        # A compressed-but-not-archived file must be handed to _decompress with the
        # suffix stripped from the destination name.
        mocked = mocker.patch("torchvision.datasets.utils._decompress")
        utils.extract_archive(file, remove_finished=remove_finished)

        mocked.assert_called_once_with(file, filename, remove_finished=remove_finished)

    def test_extract_zip(self, tmpdir):
        def create_archive(root, content="this is the content"):
            file = os.path.join(root, "dst.txt")
            archive = os.path.join(root, "archive.zip")

            with zipfile.ZipFile(archive, "w") as zf:
                zf.writestr(os.path.basename(file), content)

            return archive, file, content

        archive, file, content = create_archive(tmpdir)

        utils.extract_archive(archive, tmpdir)

        assert os.path.exists(file)

        with open(file) as fh:
            assert fh.read() == content

    @pytest.mark.parametrize(
        "extension, mode",
        [(".tar", "w"), (".tar.gz", "w:gz"), (".tgz", "w:gz"), (".tar.xz", "w:xz"), (".tar.bz2", "w:bz2")],
    )
    def test_extract_tar(self, extension, mode, tmpdir):
        def create_archive(root, extension, mode, content="this is the content"):
            src = os.path.join(root, "src.txt")
            dst = os.path.join(root, "dst.txt")
            archive = os.path.join(root, f"archive{extension}")

            with open(src, "w") as fh:
                fh.write(content)

            # Store the source under the destination's basename so extraction
            # creates ``dst`` rather than overwriting ``src``.
            with tarfile.open(archive, mode=mode) as fh:
                fh.add(src, arcname=os.path.basename(dst))

            return archive, dst, content

        archive, file, content = create_archive(tmpdir, extension, mode)

        utils.extract_archive(archive, tmpdir)

        assert os.path.exists(file)

        with open(file) as fh:
            assert fh.read() == content

    def test_verify_str_arg(self):
        assert "a" == utils.verify_str_arg("a", "arg", ("a",))
        # Same argument order as above: (value, arg, valid_values).
        # A non-string value is rejected.
        with pytest.raises(ValueError):
            utils.verify_str_arg(0, "arg", ("a",))
        # A string outside ``valid_values`` is rejected.
        with pytest.raises(ValueError):
            utils.verify_str_arg("b", "arg", ("a",))


if __name__ == "__main__":
    # Propagate pytest's exit status so failures are visible to the shell / CI
    # when this module is run directly as a script.
    raise SystemExit(pytest.main([__file__]))