test_datasets_utils.py 10.2 KB
Newer Older
limm's avatar
limm committed
1
2
import contextlib
import gzip
3
import os
limm's avatar
limm committed
4
5
import pathlib
import re
6
import tarfile
limm's avatar
limm committed
7
8
9
10
11
12
import zipfile

import pytest
import torch
import torchvision.datasets.utils as utils
from common_utils import assert_equal
13
from torch._utils_internal import get_file_path_2
limm's avatar
limm committed
14
15
from torchvision.datasets.folder import make_dataset
from torchvision.datasets.utils import _COMPRESSED_FILE_OPENERS
Francisco Massa's avatar
Francisco Massa committed
16

limm's avatar
limm committed
17
18
19
# JPEG asset located relative to this test module; the md5 and integrity
# tests below hash this file.
TEST_FILE = get_file_path_2(
    os.path.dirname(os.path.abspath(__file__)),
    "assets",
    "encode_jpeg",
    "grace_hopper_517x606.jpg",
)
20
21


limm's avatar
limm committed
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
def patch_url_redirection(mocker, redirect_url):
    """Patch ``urlopen`` so that every request appears to land on ``redirect_url``.

    Args:
        mocker: Object exposing ``patch(target, side_effect=...)`` (the
            ``pytest-mock`` fixture in the tests of this module).
        redirect_url: URL reported by the ``url`` attribute of every fake response.

    Returns:
        Whatever ``mocker.patch`` returns for the patched ``urlopen``.
    """

    class _FakeResponse:
        # Minimal response stand-in: it only carries the final URL.
        def __init__(self, url):
            self.url = url

    @contextlib.contextmanager
    def _fake_urlopen(*_args, **_kwargs):
        # Arguments are ignored on purpose: each call yields the same redirect target.
        response = _FakeResponse(redirect_url)
        yield response

    target = "torchvision.datasets.utils.urllib.request.urlopen"
    return mocker.patch(target, side_effect=_fake_urlopen)


class TestDatasetsUtils:
    """Tests for the helper functions in ``torchvision.datasets.utils``."""

    def test_get_redirect_url(self, mocker):
        """A redirect is followed and the final URL is returned."""
        url = "https://url.org"
        expected_redirect_url = "https://redirect.url.org"

        mock = patch_url_redirection(mocker, expected_redirect_url)

        actual = utils._get_redirect_url(url)
        assert actual == expected_redirect_url

        # One request for the original URL, one for the redirect target.
        assert mock.call_count == 2
        call_args_1, call_args_2 = mock.call_args_list
        assert call_args_1[0][0].full_url == url
        assert call_args_2[0][0].full_url == expected_redirect_url

    def test_get_redirect_url_max_hops_exceeded(self, mocker):
        """With ``max_hops=0`` a redirect raises ``RecursionError`` after one request."""
        url = "https://url.org"
        redirect_url = "https://redirect.url.org"

        mock = patch_url_redirection(mocker, redirect_url)

        with pytest.raises(RecursionError):
            utils._get_redirect_url(url, max_hops=0)

        assert mock.call_count == 1
        assert mock.call_args[0][0].full_url == url

    @pytest.mark.parametrize("use_pathlib", (True, False))
    def test_check_md5(self, use_pathlib):
        """``check_md5`` accepts the right digest and rejects a wrong one, for str and Path."""
        fpath = TEST_FILE
        if use_pathlib:
            fpath = pathlib.Path(fpath)
        correct_md5 = "9c0bb82894bb3af7f7675ef2b3b6dcdc"
        false_md5 = ""
        assert utils.check_md5(fpath, correct_md5)
        assert not utils.check_md5(fpath, false_md5)

    def test_check_integrity(self):
        """``check_integrity`` covers existing/missing files with and without an md5."""
        existing_fpath = TEST_FILE
        nonexisting_fpath = ""
        correct_md5 = "9c0bb82894bb3af7f7675ef2b3b6dcdc"
        false_md5 = ""
        assert utils.check_integrity(existing_fpath, correct_md5)
        assert not utils.check_integrity(existing_fpath, false_md5)
        # Without an md5 only the existence of the file matters.
        assert utils.check_integrity(existing_fpath)
        assert not utils.check_integrity(nonexisting_fpath)

    def test_get_google_drive_file_id(self):
        """The file id is extracted from a Google Drive ``/file/d/<id>/view`` URL."""
        url = "https://drive.google.com/file/d/1GO-BHUYRuvzr1Gtp2_fqXRsr9TIeYbhV/view"
        expected = "1GO-BHUYRuvzr1Gtp2_fqXRsr9TIeYbhV"

        actual = utils._get_google_drive_file_id(url)
        assert actual == expected

    def test_get_google_drive_file_id_invalid_url(self):
        """A URL that is not a Google Drive link yields ``None``."""
        url = "http://www.vision.caltech.edu/visipedia-data/CUB-200-2011/CUB_200_2011.tgz"

        assert utils._get_google_drive_file_id(url) is None

    @pytest.mark.parametrize(
        "file, expected",
        [
            ("foo.tar.bz2", (".tar.bz2", ".tar", ".bz2")),
            ("foo.tar.xz", (".tar.xz", ".tar", ".xz")),
            ("foo.tar", (".tar", ".tar", None)),
            ("foo.tar.gz", (".tar.gz", ".tar", ".gz")),
            ("foo.tbz", (".tbz", ".tar", ".bz2")),
            ("foo.tbz2", (".tbz2", ".tar", ".bz2")),
            ("foo.tgz", (".tgz", ".tar", ".gz")),
            ("foo.bz2", (".bz2", None, ".bz2")),
            ("foo.gz", (".gz", None, ".gz")),
            ("foo.zip", (".zip", ".zip", None)),
            ("foo.xz", (".xz", None, ".xz")),
            ("foo.bar.tar.gz", (".tar.gz", ".tar", ".gz")),
            ("foo.bar.gz", (".gz", None, ".gz")),
            ("foo.bar.zip", (".zip", ".zip", None)),
        ],
    )
    def test_detect_file_type(self, file, expected):
        """``_detect_file_type`` returns the expected triple for known extensions.

        NOTE(review): judging by the cases above the triple looks like
        (matched suffix, archive type, compression) -- confirm against the
        implementation.
        """
        assert utils._detect_file_type(file) == expected

    @pytest.mark.parametrize("file", ["foo", "foo.tar.baz", "foo.bar"])
    def test_detect_file_type_incompatible(self, file):
        """Unsupported names make ``_detect_file_type`` raise ``RuntimeError``."""
        # tests detect file type for no extension, unknown compression and unknown partial extension
        with pytest.raises(RuntimeError):
            utils._detect_file_type(file)

    @pytest.mark.parametrize("extension", [".bz2", ".gz", ".xz"])
    @pytest.mark.parametrize("use_pathlib", (True, False))
    def test_decompress(self, extension, tmpdir, use_pathlib):
        """``_decompress`` writes the decompressed content next to the compressed file."""

        def create_compressed(root, content="this is the content"):
            # Compress ``content`` into ``<root>/file<extension>`` using the
            # opener registered for the parametrized extension.
            file = os.path.join(root, "file")
            compressed = f"{file}{extension}"
            compressed_file_opener = _COMPRESSED_FILE_OPENERS[extension]

            with compressed_file_opener(compressed, "wb") as fh:
                fh.write(content.encode())

            return compressed, file, content

        compressed, file, content = create_compressed(tmpdir)
        if use_pathlib:
            compressed = pathlib.Path(compressed)

        utils._decompress(compressed)

        assert os.path.exists(file)

        with open(file) as fh:
            assert fh.read() == content

    def test_decompress_no_compression(self):
        """A plain ``.tar`` name is rejected by ``_decompress`` with ``RuntimeError``."""
        with pytest.raises(RuntimeError):
            utils._decompress("foo.tar")

    @pytest.mark.parametrize("use_pathlib", (True, False))
    def test_decompress_remove_finished(self, tmpdir, use_pathlib):
        """``extract_archive(..., remove_finished=True)`` deletes the source file.

        Also checks that the returned path has the same flavour (``str`` or
        ``pathlib.Path``) as the inputs.
        """

        def create_compressed(root, content="this is the content"):
            # Gzip ``content`` into ``<root>/file.gz``.
            file = os.path.join(root, "file")
            compressed = f"{file}.gz"

            with gzip.open(compressed, "wb") as fh:
                fh.write(content.encode())

            return compressed, file, content

        compressed, file, content = create_compressed(tmpdir)
        if use_pathlib:
            compressed = pathlib.Path(compressed)
            tmpdir = pathlib.Path(tmpdir)

        extracted_dir = utils.extract_archive(compressed, tmpdir, remove_finished=True)

        assert not os.path.exists(compressed)
        if use_pathlib:
            assert isinstance(extracted_dir, pathlib.Path)
            assert isinstance(compressed, pathlib.Path)
        else:
            assert isinstance(extracted_dir, str)
            assert isinstance(compressed, str)

    @pytest.mark.parametrize("extension", [".gz", ".xz"])
    @pytest.mark.parametrize("remove_finished", [True, False])
    def test_extract_archive_defer_to_decompress(self, extension, remove_finished, mocker):
        """For compressed non-archive files ``extract_archive`` delegates to ``_decompress``."""
        filename = "foo"
        file = f"{filename}{extension}"

        mocked = mocker.patch("torchvision.datasets.utils._decompress")
        utils.extract_archive(file, remove_finished=remove_finished)

        mocked.assert_called_once_with(file, filename, remove_finished=remove_finished)

    @pytest.mark.parametrize("use_pathlib", (True, False))
    def test_extract_zip(self, tmpdir, use_pathlib):
        """A zip archive is extracted into the target directory."""

        def create_archive(root, content="this is the content"):
            # Build ``<root>/archive.zip`` holding a single member ``dst.txt``.
            file = os.path.join(root, "dst.txt")
            archive = os.path.join(root, "archive.zip")

            with zipfile.ZipFile(archive, "w") as zf:
                zf.writestr(os.path.basename(file), content)

            return archive, file, content

        if use_pathlib:
            tmpdir = pathlib.Path(tmpdir)
        archive, file, content = create_archive(tmpdir)

        utils.extract_archive(archive, tmpdir)

        assert os.path.exists(file)

        with open(file) as fh:
            assert fh.read() == content

    @pytest.mark.parametrize(
        "extension, mode", [(".tar", "w"), (".tar.gz", "w:gz"), (".tgz", "w:gz"), (".tar.xz", "w:xz")]
    )
    @pytest.mark.parametrize("use_pathlib", (True, False))
    def test_extract_tar(self, extension, mode, tmpdir, use_pathlib):
        """Plain and compressed tar archives are extracted into the target directory."""

        def create_archive(root, extension, mode, content="this is the content"):
            # Write ``src.txt`` and add it to the archive under the name
            # ``dst.txt`` so that extraction produces a new file.
            src = os.path.join(root, "src.txt")
            dst = os.path.join(root, "dst.txt")
            archive = os.path.join(root, f"archive{extension}")

            with open(src, "w") as fh:
                fh.write(content)

            with tarfile.open(archive, mode=mode) as fh:
                fh.add(src, arcname=os.path.basename(dst))

            return archive, dst, content

        if use_pathlib:
            tmpdir = pathlib.Path(tmpdir)
        archive, file, content = create_archive(tmpdir, extension, mode)

        utils.extract_archive(archive, tmpdir)

        assert os.path.exists(file)

        with open(file) as fh:
            assert fh.read() == content

    def test_verify_str_arg(self):
        """``verify_str_arg`` returns valid values and raises ``ValueError`` otherwise."""
        assert "a" == utils.verify_str_arg("a", "arg", ("a",))

        # Positional order is (value, arg, valid_values), matching the call
        # above. Passing the tuple as ``arg`` and the name as ``valid_values``
        # would only raise by accident, via a substring check against "arg".
        with pytest.raises(ValueError):
            utils.verify_str_arg(0, "arg", ("a",))  # value is not a string
        with pytest.raises(ValueError):
            utils.verify_str_arg("b", "arg", ("a",))  # value not among the valid ones

    @pytest.mark.parametrize(
        ("dtype", "actual_hex", "expected_hex"),
        [
            (torch.uint8, "01 23 45 67 89 AB CD EF", "01 23 45 67 89 AB CD EF"),
            (torch.float16, "01 23 45 67 89 AB CD EF", "23 01 67 45 AB 89 EF CD"),
            (torch.int32, "01 23 45 67 89 AB CD EF", "67 45 23 01 EF CD AB 89"),
            (torch.float64, "01 23 45 67 89 AB CD EF", "EF CD AB 89 67 45 23 01"),
        ],
    )
    def test_flip_byte_order(self, dtype, actual_hex, expected_hex):
        """``_flip_byte_order`` reverses the bytes inside each element of the tensor."""

        def to_tensor(hex_string):
            # Reinterpret the raw bytes as a tensor of the parametrized dtype.
            return torch.frombuffer(bytes.fromhex(hex_string), dtype=dtype)

        assert_equal(
            utils._flip_byte_order(to_tensor(actual_hex)),
            to_tensor(expected_hex),
        )


@pytest.mark.parametrize(
    ("kwargs", "expected_error_msg"),
    [
        ({"is_valid_file": lambda path: pathlib.Path(path).suffix in {".png", ".jpeg"}}, "classes c"),
        ({"extensions": ".png"}, re.escape("classes b, c. Supported extensions are: .png")),
        ({"extensions": (".png", ".jpeg")}, re.escape("classes c. Supported extensions are: .png, .jpeg")),
    ],
)
def test_make_dataset_no_valid_files(tmpdir, kwargs, expected_error_msg):
    """``make_dataset`` raises ``FileNotFoundError`` naming the classes without valid files."""
    root = pathlib.Path(tmpdir)

    # One directory per class, each holding a single empty file.
    for class_name, file_name in (("a", "a.png"), ("b", "b.jpeg"), ("c", "c.unknown")):
        class_dir = root / class_name
        class_dir.mkdir()
        (class_dir / file_name).touch()

    with pytest.raises(FileNotFoundError, match=expected_error_msg):
        make_dataset(str(root), **kwargs)


# Allow running this test module directly as a script: hand this file to pytest.
if __name__ == "__main__":
    pytest.main([__file__])