test_datasets_utils.py 9.17 KB
Newer Older
1
2
import contextlib
import gzip
3
import os
4
5
import pathlib
import re
6
import tarfile
7
import zipfile
Francisco Massa's avatar
Francisco Massa committed
8

9
import pytest
Philip Meier's avatar
Philip Meier committed
10
import torch
11
import torchvision.datasets.utils as utils
Philip Meier's avatar
Philip Meier committed
12
from common_utils import assert_equal
13
from torch._utils_internal import get_file_path_2
14
from torchvision.datasets.folder import make_dataset
15
from torchvision.datasets.utils import _COMPRESSED_FILE_OPENERS
16
17

# Sample JPEG used by the checksum tests below. The path is built from this
# module's own absolute location, then "assets/encode_jpeg/<image>".
TEST_FILE = get_file_path_2(
    os.path.dirname(os.path.abspath(__file__)), "assets", "encode_jpeg", "grace_hopper_517x606.jpg"
)
20

Francisco Massa's avatar
Francisco Massa committed
21

22
23
24
25
26
27
28
29
30
31
32
33
def patch_url_redirection(mocker, redirect_url):
    """Patch ``urlopen`` in ``torchvision.datasets.utils`` with a fake that reports a redirect.

    Every call to the patched opener yields, as a context manager, a response
    object whose ``url`` attribute is ``redirect_url``; the arguments passed to
    the opener are ignored.

    Args:
        mocker: Object providing ``patch(target, side_effect=...)``.
        redirect_url: URL exposed by every fake response.

    Returns:
        Whatever ``mocker.patch`` returns for the patched ``urlopen``.
    """

    class _FakeResponse:
        def __init__(self, url):
            self.url = url

    @contextlib.contextmanager
    def _fake_urlopen(*_args, **_kwargs):
        response = _FakeResponse(redirect_url)
        yield response

    target = "torchvision.datasets.utils.urllib.request.urlopen"
    return mocker.patch(target, side_effect=_fake_urlopen)


34
class TestDatasetsUtils:
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
    def test_get_redirect_url(self, mocker):
        """Resolve a redirecting URL and check both the result and the requests issued."""
        start_url = "https://url.org"
        final_url = "https://redirect.url.org"

        urlopen_mock = patch_url_redirection(mocker, final_url)

        assert utils._get_redirect_url(start_url) == final_url

        # Two requests are expected: the starting URL first, then the URL it pointed to.
        assert urlopen_mock.call_count == 2
        first_call, second_call = urlopen_mock.call_args_list
        first_request = first_call[0][0]
        assert first_request.full_url == start_url
        second_request = second_call[0][0]
        assert second_request.full_url == final_url

    def test_get_redirect_url_max_hops_exceeded(self, mocker):
        """With ``max_hops=0`` a redirecting URL is expected to raise ``RecursionError``."""
        start_url = "https://url.org"
        urlopen_mock = patch_url_redirection(mocker, "https://redirect.url.org")

        with pytest.raises(RecursionError):
            utils._get_redirect_url(start_url, max_hops=0)

        # Exactly one request, for the starting URL, is made before the error.
        assert urlopen_mock.call_count == 1
        only_request = urlopen_mock.call_args[0][0]
        assert only_request.full_url == start_url
Francisco Massa's avatar
Francisco Massa committed
60

61
62
    def test_check_md5(self):
        """``check_md5`` must accept the sample file's known digest and reject an empty one."""
        fpath = TEST_FILE
        correct_md5 = "9c0bb82894bb3af7f7675ef2b3b6dcdc"
        false_md5 = ""

        assert utils.check_md5(fpath, correct_md5)
        assert not utils.check_md5(fpath, false_md5)
67
68
69

    def test_check_integrity(self):
        """Exercise ``check_integrity`` with and without a digest, and with a missing path."""
        existing_fpath = TEST_FILE
        nonexisting_fpath = ""
        correct_md5 = "9c0bb82894bb3af7f7675ef2b3b6dcdc"
        false_md5 = ""

        # Existing file: a matching digest passes, a wrong one fails.
        assert utils.check_integrity(existing_fpath, correct_md5)
        assert not utils.check_integrity(existing_fpath, false_md5)
        # Without a digest the result is expected to follow whether the path exists.
        assert utils.check_integrity(existing_fpath)
        assert not utils.check_integrity(nonexisting_fpath)
77

78
    def test_get_google_drive_file_id(self):
        """The file id is extracted from a ``drive.google.com/file/d/<id>/view`` URL."""
        url = "https://drive.google.com/file/d/1GO-BHUYRuvzr1Gtp2_fqXRsr9TIeYbhV/view"
        expected = "1GO-BHUYRuvzr1Gtp2_fqXRsr9TIeYbhV"

        actual = utils._get_google_drive_file_id(url)
        assert actual == expected

    def test_get_google_drive_file_id_invalid_url(self):
        """A URL that is not a Google Drive link is expected to yield ``None``."""
        non_drive_url = "http://www.vision.caltech.edu/visipedia-data/CUB-200-2011/CUB_200_2011.tgz"

        file_id = utils._get_google_drive_file_id(non_drive_url)
        assert file_id is None

90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
    # Each case pairs a file name with the 3-tuple ``_detect_file_type`` should return.
    # NOTE(review): the tuple looks like (matched suffix, archive type, compression) -
    # confirm against ``torchvision.datasets.utils._detect_file_type``.
    @pytest.mark.parametrize(
        "file, expected",
        [
            ("foo.tar.bz2", (".tar.bz2", ".tar", ".bz2")),
            ("foo.tar.xz", (".tar.xz", ".tar", ".xz")),
            ("foo.tar", (".tar", ".tar", None)),
            ("foo.tar.gz", (".tar.gz", ".tar", ".gz")),
            ("foo.tbz", (".tbz", ".tar", ".bz2")),
            ("foo.tbz2", (".tbz2", ".tar", ".bz2")),
            ("foo.tgz", (".tgz", ".tar", ".gz")),
            ("foo.bz2", (".bz2", None, ".bz2")),
            ("foo.gz", (".gz", None, ".gz")),
            ("foo.zip", (".zip", ".zip", None)),
            ("foo.xz", (".xz", None, ".xz")),
            ("foo.bar.tar.gz", (".tar.gz", ".tar", ".gz")),
            ("foo.bar.gz", (".gz", None, ".gz")),
            ("foo.bar.zip", (".zip", ".zip", None)),
        ],
    )
    def test_detect_file_type(self, file, expected):
        """``_detect_file_type`` returns the expected tuple for each known extension."""
        assert utils._detect_file_type(file) == expected

112
    @pytest.mark.parametrize("file", ["foo", "foo.tar.baz", "foo.bar"])
    def test_detect_file_type_incompatible(self, file):
        """``_detect_file_type`` must raise ``RuntimeError`` for names it cannot classify."""
        # Covers: no extension, unknown compression and unknown partial extension.
        with pytest.raises(RuntimeError):
            utils._detect_file_type(file)

118
    @pytest.mark.parametrize("extension", [".bz2", ".gz", ".xz"])
    def test_decompress(self, extension, tmpdir):
        """Compress a small text with each supported opener and check ``_decompress`` restores it."""

        def create_compressed(root, content="this is the content"):
            # ``file`` is where the decompressed output is expected: the compressed
            # path with the compression extension stripped.
            file = os.path.join(root, "file")
            compressed = f"{file}{extension}"
            compressed_file_opener = _COMPRESSED_FILE_OPENERS[extension]

            with compressed_file_opener(compressed, "wb") as fh:
                fh.write(content.encode())

            return compressed, file, content

        compressed, file, content = create_compressed(tmpdir)

        utils._decompress(compressed)

        assert os.path.exists(file)

        with open(file) as fh:
            assert fh.read() == content
138
139

    def test_decompress_no_compression(self):
        """``_decompress`` must raise ``RuntimeError`` for a plain ``.tar`` name."""
        with pytest.raises(RuntimeError):
            utils._decompress("foo.tar")

143
    def test_decompress_remove_finished(self, tmpdir):
        """``extract_archive(..., remove_finished=True)`` must delete the compressed source."""
        compressed = os.path.join(tmpdir, "file.gz")

        with gzip.open(compressed, "wb") as fh:
            fh.write(b"this is the content")

        utils.extract_archive(compressed, tmpdir, remove_finished=True)

        assert not os.path.exists(compressed)
158

159
160
    @pytest.mark.parametrize("extension", [".gz", ".xz"])
    @pytest.mark.parametrize("remove_finished", [True, False])
    def test_extract_archive_defer_to_decompress(self, extension, remove_finished, mocker):
        """For compression-only files ``extract_archive`` must hand off to ``_decompress``."""
        filename = "foo"
        file = f"{filename}{extension}"

        # ``_decompress`` is mocked, so no file needs to exist on disk.
        mocked = mocker.patch("torchvision.datasets.utils._decompress")
        utils.extract_archive(file, remove_finished=remove_finished)

        # Expected hand-off: source path, the name without its extension, and the flag unchanged.
        mocked.assert_called_once_with(file, filename, remove_finished=remove_finished)
169

170
    def test_extract_zip(self, tmpdir):
        """Write a one-member zip archive and check ``extract_archive`` restores the member."""

        def create_archive(root, content="this is the content"):
            file = os.path.join(root, "dst.txt")
            archive = os.path.join(root, "archive.zip")

            # The member is stored under the bare file name, so extracting into
            # ``root`` should produce ``file``.
            with zipfile.ZipFile(archive, "w") as zf:
                zf.writestr(os.path.basename(file), content)

            return archive, file, content

        archive, file, content = create_archive(tmpdir)

        utils.extract_archive(archive, tmpdir)

        assert os.path.exists(file)

        with open(file) as fh:
            assert fh.read() == content
188

189
190
191
    @pytest.mark.parametrize(
        "extension, mode", [(".tar", "w"), (".tar.gz", "w:gz"), (".tgz", "w:gz"), (".tar.xz", "w:xz")]
    )
    def test_extract_tar(self, extension, mode, tmpdir):
        """Pack one text file into a tar archive and check ``extract_archive`` restores it.

        ``mode`` is the ``tarfile.open`` write mode used to build the archive
        named with ``extension``.
        """

        def create_archive(root, extension, mode, content="this is the content"):
            src = os.path.join(root, "src.txt")
            dst = os.path.join(root, "dst.txt")
            archive = os.path.join(root, f"archive{extension}")

            with open(src, "w") as fh:
                fh.write(content)

            # Store the source under the destination's name so that extraction
            # creates ``dst.txt`` rather than overwriting ``src.txt``.
            with tarfile.open(archive, mode=mode) as fh:
                fh.add(src, arcname=os.path.basename(dst))

            return archive, dst, content

        archive, file, content = create_archive(tmpdir, extension, mode)

        utils.extract_archive(archive, tmpdir)

        assert os.path.exists(file)

        with open(file) as fh:
            assert fh.read() == content
214

215
    def test_verify_str_arg(self):
        """``verify_str_arg`` returns a valid value and raises ``ValueError`` otherwise."""
        assert utils.verify_str_arg("a", "arg", ("a",)) == "a"

        # Keep the (value, arg, valid_values) order of the passing call above, so
        # the failures are checked against the intended set of valid values.
        with pytest.raises(ValueError):
            utils.verify_str_arg(0, "arg", ("a",))  # value is not a string
        with pytest.raises(ValueError):
            utils.verify_str_arg("b", "arg", ("a",))  # value is not among the valid values
219

Philip Meier's avatar
Philip Meier committed
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
    # In every case the expected bytes are the input bytes reversed within each
    # element of the given dtype (1, 2, 4 and 8 bytes wide respectively).
    @pytest.mark.parametrize(
        ("dtype", "actual_hex", "expected_hex"),
        [
            (torch.uint8, "01 23 45 67 89 AB CD EF", "01 23 45 67 89 AB CD EF"),
            (torch.float16, "01 23 45 67 89 AB CD EF", "23 01 67 45 AB 89 EF CD"),
            (torch.int32, "01 23 45 67 89 AB CD EF", "67 45 23 01 EF CD AB 89"),
            (torch.float64, "01 23 45 67 89 AB CD EF", "EF CD AB 89 67 45 23 01"),
        ],
    )
    def test_flip_byte_order(self, dtype, actual_hex, expected_hex):
        """``_flip_byte_order`` of the input bytes must equal the expected bytes, both read as ``dtype``."""

        def tensor_from_hex(hex_string):
            raw = bytes.fromhex(hex_string)
            return torch.frombuffer(raw, dtype=dtype)

        flipped = utils._flip_byte_order(tensor_from_hex(actual_hex))
        assert_equal(flipped, tensor_from_hex(expected_hex))

Francisco Massa's avatar
Francisco Massa committed
238

239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
@pytest.mark.parametrize(
    ("kwargs", "expected_error_msg"),
    [
        (dict(is_valid_file=lambda path: pathlib.Path(path).suffix in {".png", ".jpeg"}), "classes c"),
        (dict(extensions=".png"), re.escape("classes b, c. Supported extensions are: .png")),
        (dict(extensions=(".png", ".jpeg")), re.escape("classes c. Supported extensions are: .png, .jpeg")),
    ],
)
def test_make_dataset_no_valid_files(tmpdir, kwargs, expected_error_msg):
    """``make_dataset`` must raise ``FileNotFoundError`` naming the classes without valid files.

    The directory tree holds three class folders, each with a single empty
    file: ``a/a.png``, ``b/b.jpeg`` and ``c/c.unknown``.
    """
    root = pathlib.Path(tmpdir)

    for class_name, file_name in (("a", "a.png"), ("b", "b.jpeg"), ("c", "c.unknown")):
        class_dir = root / class_name
        class_dir.mkdir()
        (class_dir / file_name).touch()

    with pytest.raises(FileNotFoundError, match=expected_error_msg):
        make_dataset(str(root), **kwargs)


263
if __name__ == "__main__":
    # Propagate pytest's exit status so that running this file directly
    # reports failure to the calling shell when a test fails.
    raise SystemExit(pytest.main([__file__]))