test_datasets_utils.py 10.2 KB
Newer Older
1
2
import contextlib
import gzip
3
import os
4
5
import pathlib
import re
6
import tarfile
7
import zipfile
Francisco Massa's avatar
Francisco Massa committed
8

9
import pytest
Philip Meier's avatar
Philip Meier committed
10
import torch
11
import torchvision.datasets.utils as utils
Philip Meier's avatar
Philip Meier committed
12
from common_utils import assert_equal
13
from torch._utils_internal import get_file_path_2
14
from torchvision.datasets.folder import make_dataset
15
from torchvision.datasets.utils import _COMPRESSED_FILE_OPENERS
16
17

# Path to the JPEG asset stored under the "assets" directory next to this test module. The checksum
# tests below use it as their existing, known-content file.
_THIS_DIR = os.path.dirname(os.path.abspath(__file__))
TEST_FILE = get_file_path_2(_THIS_DIR, "assets", "encode_jpeg", "grace_hopper_517x606.jpg")
20

Francisco Massa's avatar
Francisco Massa committed
21

22
23
24
25
26
27
28
29
30
31
32
33
def patch_url_redirection(mocker, redirect_url):
    """Patch ``urlopen`` as seen by ``torchvision.datasets.utils`` with a fake that always redirects.

    Args:
        mocker: Object exposing ``patch(target, side_effect=...)`` (the ``pytest-mock`` fixture in the tests).
        redirect_url: Value exposed as the ``url`` attribute of every fake response.

    Returns:
        Whatever ``mocker.patch`` returns, so callers can inspect the recorded calls.
    """

    class _FakeResponse:
        # Only the attribute the tests read from a response is provided.
        def __init__(self, url):
            self.url = url

    @contextlib.contextmanager
    def _fake_urlopen(*_args, **_kwargs):
        # Arguments are accepted and ignored: every call yields the same redirect target.
        response = _FakeResponse(redirect_url)
        yield response

    target = "torchvision.datasets.utils.urllib.request.urlopen"
    return mocker.patch(target, side_effect=_fake_urlopen)


34
class TestDatasetsUtils:
    """Tests for the helper functions in ``torchvision.datasets.utils``."""

    def test_get_redirect_url(self, mocker):
        """``_get_redirect_url`` returns the redirect target and queries the original URL, then the target."""
        url = "https://url.org"
        expected_redirect_url = "https://redirect.url.org"

        mock = patch_url_redirection(mocker, expected_redirect_url)

        actual = utils._get_redirect_url(url)
        assert actual == expected_redirect_url

        # First request goes to the original URL, the second one to the redirect target.
        assert mock.call_count == 2
        call_args_1, call_args_2 = mock.call_args_list
        assert call_args_1[0][0].full_url == url
        assert call_args_2[0][0].full_url == expected_redirect_url

    def test_get_redirect_url_max_hops_exceeded(self, mocker):
        """With ``max_hops=0`` a redirect raises ``RecursionError`` after a single request."""
        url = "https://url.org"
        redirect_url = "https://redirect.url.org"

        mock = patch_url_redirection(mocker, redirect_url)

        with pytest.raises(RecursionError):
            utils._get_redirect_url(url, max_hops=0)

        assert mock.call_count == 1
        assert mock.call_args[0][0].full_url == url

    @pytest.mark.parametrize("use_pathlib", (True, False))
    def test_check_md5(self, use_pathlib):
        """``check_md5`` accepts the matching digest and rejects a wrong one, for ``str`` and ``Path`` inputs."""
        fpath = TEST_FILE
        if use_pathlib:
            fpath = pathlib.Path(fpath)
        correct_md5 = "9c0bb82894bb3af7f7675ef2b3b6dcdc"
        false_md5 = ""
        assert utils.check_md5(fpath, correct_md5)
        assert not utils.check_md5(fpath, false_md5)

    def test_check_integrity(self):
        """``check_integrity`` is true for an existing file with matching or omitted md5, and false otherwise."""
        existing_fpath = TEST_FILE
        nonexisting_fpath = ""
        correct_md5 = "9c0bb82894bb3af7f7675ef2b3b6dcdc"
        false_md5 = ""
        assert utils.check_integrity(existing_fpath, correct_md5)
        assert not utils.check_integrity(existing_fpath, false_md5)
        assert utils.check_integrity(existing_fpath)
        assert not utils.check_integrity(nonexisting_fpath)

    def test_get_google_drive_file_id(self):
        """The file id is extracted from a Google Drive ``/file/d/<id>/view`` URL."""
        url = "https://drive.google.com/file/d/1GO-BHUYRuvzr1Gtp2_fqXRsr9TIeYbhV/view"
        expected = "1GO-BHUYRuvzr1Gtp2_fqXRsr9TIeYbhV"

        actual = utils._get_google_drive_file_id(url)
        assert actual == expected

    def test_get_google_drive_file_id_invalid_url(self):
        """A URL that is not a Google Drive link yields ``None``."""
        url = "http://www.vision.caltech.edu/visipedia-data/CUB-200-2011/CUB_200_2011.tgz"

        assert utils._get_google_drive_file_id(url) is None

    @pytest.mark.parametrize(
        "file, expected",
        [
            ("foo.tar.bz2", (".tar.bz2", ".tar", ".bz2")),
            ("foo.tar.xz", (".tar.xz", ".tar", ".xz")),
            ("foo.tar", (".tar", ".tar", None)),
            ("foo.tar.gz", (".tar.gz", ".tar", ".gz")),
            ("foo.tbz", (".tbz", ".tar", ".bz2")),
            ("foo.tbz2", (".tbz2", ".tar", ".bz2")),
            ("foo.tgz", (".tgz", ".tar", ".gz")),
            ("foo.bz2", (".bz2", None, ".bz2")),
            ("foo.gz", (".gz", None, ".gz")),
            ("foo.zip", (".zip", ".zip", None)),
            ("foo.xz", (".xz", None, ".xz")),
            ("foo.bar.tar.gz", (".tar.gz", ".tar", ".gz")),
            ("foo.bar.gz", (".gz", None, ".gz")),
            ("foo.bar.zip", (".zip", ".zip", None)),
        ],
    )
    def test_detect_file_type(self, file, expected):
        """``_detect_file_type`` returns the expected 3-tuple for each known file name."""
        assert utils._detect_file_type(file) == expected

    @pytest.mark.parametrize("file", ["foo", "foo.tar.baz", "foo.bar"])
    def test_detect_file_type_incompatible(self, file):
        """No extension, unknown compression and unknown partial extension all raise ``RuntimeError``."""
        with pytest.raises(RuntimeError):
            utils._detect_file_type(file)

    @pytest.mark.parametrize("extension", [".bz2", ".gz", ".xz"])
    @pytest.mark.parametrize("use_pathlib", (True, False))
    def test_decompress(self, extension, tmpdir, use_pathlib):
        """``_decompress`` writes the decompressed content next to the archive, without its extension."""

        def create_compressed(root, content="this is the content"):
            file = os.path.join(root, "file")
            compressed = f"{file}{extension}"
            compressed_file_opener = _COMPRESSED_FILE_OPENERS[extension]

            with compressed_file_opener(compressed, "wb") as fh:
                fh.write(content.encode())

            return compressed, file, content

        compressed, file, content = create_compressed(tmpdir)
        if use_pathlib:
            compressed = pathlib.Path(compressed)

        utils._decompress(compressed)

        assert os.path.exists(file)

        with open(file) as fh:
            assert fh.read() == content

    def test_decompress_no_compression(self):
        """A file name without a compression suffix is rejected with ``RuntimeError``."""
        with pytest.raises(RuntimeError):
            utils._decompress("foo.tar")

    @pytest.mark.parametrize("use_pathlib", (True, False))
    def test_decompress_remove_finished(self, tmpdir, use_pathlib):
        """``extract_archive(..., remove_finished=True)`` removes the compressed file.

        The returned path is expected to be a ``pathlib.Path`` when the inputs are paths and a ``str`` otherwise.
        """
        compressed = os.path.join(tmpdir, "file.gz")
        with gzip.open(compressed, "wb") as fh:
            fh.write(b"this is the content")

        if use_pathlib:
            compressed = pathlib.Path(compressed)
            tmpdir = pathlib.Path(tmpdir)

        extracted_dir = utils.extract_archive(compressed, tmpdir, remove_finished=True)

        assert not os.path.exists(compressed)

        expected_type = pathlib.Path if use_pathlib else str
        assert isinstance(extracted_dir, expected_type)

    @pytest.mark.parametrize("extension", [".gz", ".xz"])
    @pytest.mark.parametrize("remove_finished", [True, False])
    def test_extract_archive_defer_to_decompress(self, extension, remove_finished, mocker):
        """For compressed, non-archive files ``extract_archive`` delegates to ``_decompress``."""
        filename = "foo"
        file = f"{filename}{extension}"

        mocked = mocker.patch("torchvision.datasets.utils._decompress")
        utils.extract_archive(file, remove_finished=remove_finished)

        mocked.assert_called_once_with(file, filename, remove_finished=remove_finished)

    @pytest.mark.parametrize("use_pathlib", (True, False))
    def test_extract_zip(self, tmpdir, use_pathlib):
        """A zip archive is extracted into the target directory with its content intact."""

        def create_archive(root, content="this is the content"):
            file = os.path.join(root, "dst.txt")
            archive = os.path.join(root, "archive.zip")

            with zipfile.ZipFile(archive, "w") as zf:
                zf.writestr(os.path.basename(file), content)

            return archive, file, content

        if use_pathlib:
            tmpdir = pathlib.Path(tmpdir)
        archive, file, content = create_archive(tmpdir)

        utils.extract_archive(archive, tmpdir)

        assert os.path.exists(file)

        with open(file) as fh:
            assert fh.read() == content

    @pytest.mark.parametrize(
        "extension, mode", [(".tar", "w"), (".tar.gz", "w:gz"), (".tgz", "w:gz"), (".tar.xz", "w:xz")]
    )
    @pytest.mark.parametrize("use_pathlib", (True, False))
    def test_extract_tar(self, extension, mode, tmpdir, use_pathlib):
        """Plain and compressed tar archives are extracted into the target directory with content intact."""

        def create_archive(root, extension, mode, content="this is the content"):
            src = os.path.join(root, "src.txt")
            dst = os.path.join(root, "dst.txt")
            archive = os.path.join(root, f"archive{extension}")

            with open(src, "w") as fh:
                fh.write(content)

            # Store the file under the name of ``dst`` so extraction creates a new file instead of
            # overwriting ``src``.
            with tarfile.open(archive, mode=mode) as fh:
                fh.add(src, arcname=os.path.basename(dst))

            return archive, dst, content

        if use_pathlib:
            tmpdir = pathlib.Path(tmpdir)
        archive, file, content = create_archive(tmpdir, extension, mode)

        utils.extract_archive(archive, tmpdir)

        assert os.path.exists(file)

        with open(file) as fh:
            assert fh.read() == content

    def test_verify_str_arg(self):
        """``verify_str_arg`` returns a valid value and raises ``ValueError`` for non-str or unlisted values."""
        assert "a" == utils.verify_str_arg("a", "arg", ("a",))

        # Positional order is (value, arg, valid_values), matching the call above.
        with pytest.raises(ValueError):
            utils.verify_str_arg(0, "arg", ("a",))
        with pytest.raises(ValueError):
            utils.verify_str_arg("b", "arg", ("a",))

    @pytest.mark.parametrize(
        ("dtype", "actual_hex", "expected_hex"),
        [
            (torch.uint8, "01 23 45 67 89 AB CD EF", "01 23 45 67 89 AB CD EF"),
            (torch.float16, "01 23 45 67 89 AB CD EF", "23 01 67 45 AB 89 EF CD"),
            (torch.int32, "01 23 45 67 89 AB CD EF", "67 45 23 01 EF CD AB 89"),
            (torch.float64, "01 23 45 67 89 AB CD EF", "EF CD AB 89 67 45 23 01"),
        ],
    )
    def test_flip_byte_order(self, dtype, actual_hex, expected_hex):
        """``_flip_byte_order`` reverses the bytes within every element of the given dtype."""

        def to_tensor(hex_str):
            return torch.frombuffer(bytes.fromhex(hex_str), dtype=dtype)

        assert_equal(
            utils._flip_byte_order(to_tensor(actual_hex)),
            to_tensor(expected_hex),
        )

Francisco Massa's avatar
Francisco Massa committed
261

262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
@pytest.mark.parametrize(
    ("kwargs", "expected_error_msg"),
    [
        (dict(is_valid_file=lambda path: pathlib.Path(path).suffix in {".png", ".jpeg"}), "classes c"),
        (dict(extensions=".png"), re.escape("classes b, c. Supported extensions are: .png")),
        (dict(extensions=(".png", ".jpeg")), re.escape("classes c. Supported extensions are: .png, .jpeg")),
    ],
)
def test_make_dataset_no_valid_files(tmpdir, kwargs, expected_error_msg):
    """``make_dataset`` raises ``FileNotFoundError`` naming the class folders that hold no valid file."""
    root = pathlib.Path(tmpdir)

    # One folder per class, each holding a single empty file with a distinct suffix.
    layout = (("a", "a.png"), ("b", "b.jpeg"), ("c", "c.unknown"))
    for class_name, file_name in layout:
        class_dir = root / class_name
        class_dir.mkdir()
        (class_dir / file_name).touch()

    with pytest.raises(FileNotFoundError, match=expected_error_msg):
        make_dataset(str(root), **kwargs)


286
if __name__ == "__main__":
    # Propagate pytest's exit code so a failing run is visible to the calling shell / CI
    # instead of always exiting with status 0.
    raise SystemExit(pytest.main([__file__]))