"vscode:/vscode.git/clone" did not exist on "61ea57c5a712b862a88c387892b5a25dfc504b4a"
test_datasets_utils.py 8.41 KB
Newer Older
1
2
import contextlib
import gzip
3
import os
4
5
import pathlib
import re
6
import tarfile
7
import zipfile
Francisco Massa's avatar
Francisco Massa committed
8

9
10
11
import pytest
import torchvision.datasets.utils as utils
from torch._utils_internal import get_file_path_2
12
from torchvision.datasets.folder import make_dataset
13
from torchvision.datasets.utils import _COMPRESSED_FILE_OPENERS
14
15

# Path of the JPEG asset used by the checksum/integrity tests; it is resolved
# relative to the directory that contains this test module.
TEST_FILE = get_file_path_2(
    os.path.dirname(os.path.abspath(__file__)), "assets", "encode_jpeg", "grace_hopper_517x606.jpg"
)
18

Francisco Massa's avatar
Francisco Massa committed
19

20
21
22
23
24
25
26
27
28
29
30
31
def patch_url_redirection(mocker, redirect_url):
    """Patch ``urlopen`` so that every request appears to land on ``redirect_url``.

    Args:
        mocker: Object exposing ``patch(target, side_effect=...)`` (the
            ``mocker`` fixture in the tests of this module).
        redirect_url: Value exposed as the ``url`` attribute of every fake
            response.

    Returns:
        Whatever ``mocker.patch`` returns for the patched ``urlopen``.
    """

    class _RedirectedResponse:
        # Minimal response stand-in: only the ``url`` attribute is provided.
        def __init__(self, url):
            self.url = url

    @contextlib.contextmanager
    def _fake_urlopen(*_args, **_kwargs):
        # The request arguments are ignored; the response always reports the
        # configured redirect target.
        yield _RedirectedResponse(redirect_url)

    urlopen_target = "torchvision.datasets.utils.urllib.request.urlopen"
    return mocker.patch(urlopen_target, side_effect=_fake_urlopen)


32
class TestDatasetsUtils:
    """Tests for the helper functions of ``torchvision.datasets.utils``."""

    def test_get_redirect_url(self, mocker):
        """The redirect target is returned after one request per hop."""
        url = "https://url.org"
        expected_redirect_url = "https://redirect.url.org"

        mock = patch_url_redirection(mocker, expected_redirect_url)

        actual = utils._get_redirect_url(url)
        assert actual == expected_redirect_url

        # First request goes to the original URL, the second to the redirect.
        assert mock.call_count == 2
        call_args_1, call_args_2 = mock.call_args_list
        assert call_args_1[0][0].full_url == url
        assert call_args_2[0][0].full_url == expected_redirect_url

    def test_get_redirect_url_max_hops_exceeded(self, mocker):
        """With ``max_hops=0`` a single redirect already raises ``RecursionError``."""
        url = "https://url.org"
        redirect_url = "https://redirect.url.org"

        mock = patch_url_redirection(mocker, redirect_url)

        with pytest.raises(RecursionError):
            utils._get_redirect_url(url, max_hops=0)

        # Only the original URL was requested before giving up.
        assert mock.call_count == 1
        assert mock.call_args[0][0].full_url == url

    def test_check_md5(self):
        """``check_md5`` accepts the matching digest and rejects an empty one."""
        fpath = TEST_FILE
        correct_md5 = "9c0bb82894bb3af7f7675ef2b3b6dcdc"
        false_md5 = ""
        assert utils.check_md5(fpath, correct_md5)
        assert not utils.check_md5(fpath, false_md5)

    def test_check_integrity(self):
        """``check_integrity`` with and without a digest, for present and absent paths."""
        existing_fpath = TEST_FILE
        nonexisting_fpath = ""
        correct_md5 = "9c0bb82894bb3af7f7675ef2b3b6dcdc"
        false_md5 = ""
        assert utils.check_integrity(existing_fpath, correct_md5)
        assert not utils.check_integrity(existing_fpath, false_md5)
        # Without a digest the result differs only by whether the path exists.
        assert utils.check_integrity(existing_fpath)
        assert not utils.check_integrity(nonexisting_fpath)

    def test_get_google_drive_file_id(self):
        """The file id is extracted from a Google Drive ``/file/d/<id>/view`` URL."""
        url = "https://drive.google.com/file/d/1GO-BHUYRuvzr1Gtp2_fqXRsr9TIeYbhV/view"
        expected = "1GO-BHUYRuvzr1Gtp2_fqXRsr9TIeYbhV"

        actual = utils._get_google_drive_file_id(url)
        assert actual == expected

    def test_get_google_drive_file_id_invalid_url(self):
        """A URL that is not a Google Drive link yields ``None``."""
        url = "http://www.vision.caltech.edu/visipedia-data/CUB-200-2011/CUB_200_2011.tgz"

        assert utils._get_google_drive_file_id(url) is None

    @pytest.mark.parametrize(
        "file, expected",
        [
            ("foo.tar.bz2", (".tar.bz2", ".tar", ".bz2")),
            ("foo.tar.xz", (".tar.xz", ".tar", ".xz")),
            ("foo.tar", (".tar", ".tar", None)),
            ("foo.tar.gz", (".tar.gz", ".tar", ".gz")),
            ("foo.tbz", (".tbz", ".tar", ".bz2")),
            ("foo.tbz2", (".tbz2", ".tar", ".bz2")),
            ("foo.tgz", (".tgz", ".tar", ".gz")),
            ("foo.bz2", (".bz2", None, ".bz2")),
            ("foo.gz", (".gz", None, ".gz")),
            ("foo.zip", (".zip", ".zip", None)),
            ("foo.xz", (".xz", None, ".xz")),
            ("foo.bar.tar.gz", (".tar.gz", ".tar", ".gz")),
            ("foo.bar.gz", (".gz", None, ".gz")),
            ("foo.bar.zip", (".zip", ".zip", None)),
        ],
    )
    def test_detect_file_type(self, file, expected):
        """``_detect_file_type`` returns the expected 3-tuple for known extensions.

        NOTE(review): the tuples read as (suffix, archive type, compression);
        confirm against ``_detect_file_type``.
        """
        assert utils._detect_file_type(file) == expected

    @pytest.mark.parametrize("file", ["foo", "foo.tar.baz", "foo.bar"])
    def test_detect_file_type_incompatible(self, file):
        """Unsupported names make ``_detect_file_type`` raise ``RuntimeError``."""
        # tests detect file type for no extension, unknown compression and unknown partial extension
        with pytest.raises(RuntimeError):
            utils._detect_file_type(file)

    @pytest.mark.parametrize("extension", [".bz2", ".gz", ".xz"])
    def test_decompress(self, extension, tmpdir):
        """``_decompress`` writes the decompressed file next to the compressed one."""

        def create_compressed(root, content="this is the content"):
            # Write ``content`` through the opener registered for ``extension``.
            file = os.path.join(root, "file")
            compressed = f"{file}{extension}"
            compressed_file_opener = _COMPRESSED_FILE_OPENERS[extension]

            with compressed_file_opener(compressed, "wb") as fh:
                fh.write(content.encode())

            return compressed, file, content

        compressed, file, content = create_compressed(tmpdir)

        utils._decompress(compressed)

        assert os.path.exists(file)

        with open(file) as fh:
            assert fh.read() == content

    def test_decompress_no_compression(self):
        """A plain ``.tar`` name is rejected by ``_decompress`` with ``RuntimeError``."""
        with pytest.raises(RuntimeError):
            utils._decompress("foo.tar")

    def test_decompress_remove_finished(self, tmpdir):
        """``extract_archive(..., remove_finished=True)`` deletes the compressed file."""

        def create_compressed(root, content="this is the content"):
            # Create ``file.gz`` inside ``root`` holding ``content``.
            file = os.path.join(root, "file")
            compressed = f"{file}.gz"

            with gzip.open(compressed, "wb") as fh:
                fh.write(content.encode())

            return compressed, file, content

        # Only the path of the compressed file is needed by this test.
        compressed, _, _ = create_compressed(tmpdir)

        utils.extract_archive(compressed, tmpdir, remove_finished=True)

        assert not os.path.exists(compressed)

    @pytest.mark.parametrize("extension", [".gz", ".xz"])
    @pytest.mark.parametrize("remove_finished", [True, False])
    def test_extract_archive_defer_to_decompress(self, extension, remove_finished, mocker):
        """For ``.gz``/``.xz`` names ``extract_archive`` calls ``_decompress`` exactly once."""
        filename = "foo"
        file = f"{filename}{extension}"

        mocked = mocker.patch("torchvision.datasets.utils._decompress")
        utils.extract_archive(file, remove_finished=remove_finished)

        # The target passed on is the file name without its extension.
        mocked.assert_called_once_with(file, filename, remove_finished=remove_finished)

    def test_extract_zip(self, tmpdir):
        """``extract_archive`` unpacks a zip archive member into the target directory."""

        def create_archive(root, content="this is the content"):
            # Build ``archive.zip`` with a single member named ``dst.txt``.
            file = os.path.join(root, "dst.txt")
            archive = os.path.join(root, "archive.zip")

            with zipfile.ZipFile(archive, "w") as zf:
                zf.writestr(os.path.basename(file), content)

            return archive, file, content

        archive, file, content = create_archive(tmpdir)

        utils.extract_archive(archive, tmpdir)

        assert os.path.exists(file)

        with open(file) as fh:
            assert fh.read() == content

    @pytest.mark.parametrize(
        "extension, mode", [(".tar", "w"), (".tar.gz", "w:gz"), (".tgz", "w:gz"), (".tar.xz", "w:xz")]
    )
    def test_extract_tar(self, extension, mode, tmpdir):
        """``extract_archive`` unpacks plain and compressed tar archives."""

        def create_archive(root, extension, mode, content="this is the content"):
            # ``src.txt`` is stored in the archive under the name ``dst.txt`` so
            # that extraction has to create a file that does not exist yet.
            src = os.path.join(root, "src.txt")
            dst = os.path.join(root, "dst.txt")
            archive = os.path.join(root, f"archive{extension}")

            with open(src, "w") as fh:
                fh.write(content)

            with tarfile.open(archive, mode=mode) as fh:
                fh.add(src, arcname=os.path.basename(dst))

            return archive, dst, content

        archive, file, content = create_archive(tmpdir, extension, mode)

        utils.extract_archive(archive, tmpdir)

        assert os.path.exists(file)

        with open(file) as fh:
            assert fh.read() == content

    def test_verify_str_arg(self):
        """``verify_str_arg`` returns valid values and raises ``ValueError`` otherwise."""
        assert utils.verify_str_arg("a", "arg", ("a",)) == "a"
        # Positional order is (value, arg, valid_values), as in the call above.
        with pytest.raises(ValueError):
            utils.verify_str_arg(0, "arg", ("a",))
        with pytest.raises(ValueError):
            utils.verify_str_arg("b", "arg", ("a",))
217

Francisco Massa's avatar
Francisco Massa committed
218

219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
@pytest.mark.parametrize(
    ("kwargs", "expected_error_msg"),
    [
        (dict(is_valid_file=lambda path: pathlib.Path(path).suffix in {".png", ".jpeg"}), "classes c"),
        (dict(extensions=".png"), re.escape("classes b, c. Supported extensions are: .png")),
        (dict(extensions=(".png", ".jpeg")), re.escape("classes c. Supported extensions are: .png, .jpeg")),
    ],
)
def test_make_dataset_no_valid_files(tmpdir, kwargs, expected_error_msg):
    """``make_dataset`` raises ``FileNotFoundError`` matching ``expected_error_msg``.

    The directory tree holds three class folders (``a``, ``b``, ``c``), each
    with one empty file: ``a.png``, ``b.jpeg`` and ``c.unknown``.
    """
    root = pathlib.Path(tmpdir)

    # One folder per class, each containing a single empty file.
    for class_name, suffix in (("a", ".png"), ("b", ".jpeg"), ("c", ".unknown")):
        class_dir = root / class_name
        class_dir.mkdir()
        (class_dir / f"{class_name}{suffix}").touch()

    with pytest.raises(FileNotFoundError, match=expected_error_msg):
        make_dataset(str(root), **kwargs)


243
if __name__ == "__main__":
    # Allow running this test module directly as a script.
    pytest.main([__file__])