# test_datasets_utils.py
1
import bz2
2
import os
# Francisco Massa committed
3
import torchvision.datasets.utils as utils
4
import pytest
5
6
7
import zipfile
import tarfile
import gzip
8
9
import warnings
from torch._utils_internal import get_file_path_2
10
from urllib.error import URLError
11
12
import itertools
import lzma
# Francisco Massa committed
13

14
15
from common_utils import get_tmp_dir
from torchvision.datasets.utils import _COMPRESSED_FILE_OPENERS
16
17
18


# Absolute path of the JPEG fixture assets/encode_jpeg/grace_hopper_517x606.jpg,
# located relative to this test module's own directory and assembled with
# torch's internal ``get_file_path_2`` helper.
TEST_FILE = get_file_path_2(
    os.path.dirname(os.path.abspath(__file__)),
    'assets',
    'encode_jpeg',
    'grace_hopper_517x606.jpg',
)


class TestDatasetsUtils:
    """Tests for the helpers in ``torchvision.datasets.utils``.

    Covers md5 / integrity checks, Google Drive URL parsing, archive and
    compression type detection, decompression, archive extraction and
    ``verify_str_arg``.
    """

    # Expected md5 digest of TEST_FILE, shared by the checksum tests.
    _TEST_FILE_MD5 = '9c0bb82894bb3af7f7675ef2b3b6dcdc'

    def test_check_md5(self):
        """``check_md5`` accepts the file's real digest and rejects an empty one."""
        fpath = TEST_FILE
        correct_md5 = self._TEST_FILE_MD5
        false_md5 = ''

        assert utils.check_md5(fpath, correct_md5)
        assert not utils.check_md5(fpath, false_md5)

    def test_check_integrity(self):
        """``check_integrity`` with and without an expected md5.

        With a digest, the file must match it. Without one, the existing
        fixture is accepted and an empty (non-existent) path is rejected.
        """
        existing_fpath = TEST_FILE
        nonexisting_fpath = ''
        correct_md5 = self._TEST_FILE_MD5
        false_md5 = ''

        assert utils.check_integrity(existing_fpath, correct_md5)
        assert not utils.check_integrity(existing_fpath, false_md5)
        assert utils.check_integrity(existing_fpath)
        assert not utils.check_integrity(nonexisting_fpath)

    def test_get_google_drive_file_id(self):
        """The id is extracted from a ``drive.google.com/file/d/<id>/view`` URL."""
        url = "https://drive.google.com/file/d/1hbzc_P1FuxMkcabkgn9ZKinBwW683j45/view"
        expected = "1hbzc_P1FuxMkcabkgn9ZKinBwW683j45"

        actual = utils._get_google_drive_file_id(url)
        assert actual == expected

    def test_get_google_drive_file_id_invalid_url(self):
        """A URL that is not a Google Drive file link yields ``None``."""
        url = "http://www.vision.caltech.edu/visipedia-data/CUB-200-2011/CUB_200_2011.tgz"

        assert utils._get_google_drive_file_id(url) is None

    # Each expected value is (suffix, archive type, compression); None marks a
    # missing component, e.g. ".tar" has no compression and ".gz" no archive.
    @pytest.mark.parametrize('file, expected', [
        ("foo.tar.bz2", (".tar.bz2", ".tar", ".bz2")),
        ("foo.tar.xz", (".tar.xz", ".tar", ".xz")),
        ("foo.tar", (".tar", ".tar", None)),
        ("foo.tar.gz", (".tar.gz", ".tar", ".gz")),
        ("foo.tbz", (".tbz", ".tar", ".bz2")),
        ("foo.tbz2", (".tbz2", ".tar", ".bz2")),
        ("foo.tgz", (".tgz", ".tar", ".gz")),
        ("foo.bz2", (".bz2", None, ".bz2")),
        ("foo.gz", (".gz", None, ".gz")),
        ("foo.zip", (".zip", ".zip", None)),
        ("foo.xz", (".xz", None, ".xz")),
        ("foo.bar.tar.gz", (".tar.gz", ".tar", ".gz")),
        ("foo.bar.gz", (".gz", None, ".gz")),
        ("foo.bar.zip", (".zip", ".zip", None))])
    def test_detect_file_type(self, file, expected):
        assert utils._detect_file_type(file) == expected

    # Respectively: no extension, an unknown compression after ".tar", and an
    # unknown extension.
    @pytest.mark.parametrize('file', ["foo", "foo.tar.baz", "foo.bar"])
    def test_detect_file_type_incompatible(self, file):
        with pytest.raises(RuntimeError):
            utils._detect_file_type(file)

    @pytest.mark.parametrize('extension', [".bz2", ".gz", ".xz"])
    def test_decompress(self, extension):
        """``_decompress`` writes the data to the input path minus its suffix."""

        def create_compressed(root, content="this is the content"):
            file = os.path.join(root, "file")
            compressed = f"{file}{extension}"
            # Use the same opener table as the code under test for each suffix.
            compressed_file_opener = _COMPRESSED_FILE_OPENERS[extension]

            with compressed_file_opener(compressed, "wb") as fh:
                fh.write(content.encode())

            return compressed, file, content

        with get_tmp_dir() as temp_dir:
            compressed, file, content = create_compressed(temp_dir)

            utils._decompress(compressed)

            assert os.path.exists(file)

            with open(file, "r") as fh:
                assert fh.read() == content

    def test_decompress_no_compression(self):
        # ".tar" has no compression component (see test_detect_file_type), so
        # there is nothing for _decompress to undo.
        with pytest.raises(RuntimeError):
            utils._decompress("foo.tar")

    def test_decompress_remove_finished(self):
        """``extract_archive(..., remove_finished=True)`` on a ``.gz`` file
        produces the decompressed file and deletes the compressed input."""

        def create_compressed(root, content="this is the content"):
            file = os.path.join(root, "file")
            compressed = f"{file}.gz"

            with gzip.open(compressed, "wb") as fh:
                fh.write(content.encode())

            return compressed, file, content

        with get_tmp_dir() as temp_dir:
            compressed, file, content = create_compressed(temp_dir)

            utils.extract_archive(compressed, temp_dir, remove_finished=True)

            assert not os.path.exists(compressed)
            # Removing the source is only correct if the output was produced;
            # a no-op that just deleted the input must not pass this test.
            assert os.path.exists(file)
            with open(file, "r") as fh:
                assert fh.read() == content

    @pytest.mark.parametrize('extension', [".gz", ".xz"])
    @pytest.mark.parametrize('remove_finished', [True, False])
    def test_extract_archive_defer_to_decompress(self, extension, remove_finished, mocker):
        """A compressed non-archive file is handed to ``_decompress`` with the
        suffix-stripped path as destination and ``remove_finished`` forwarded."""
        filename = "foo"
        file = f"{filename}{extension}"

        mocked = mocker.patch("torchvision.datasets.utils._decompress")
        utils.extract_archive(file, remove_finished=remove_finished)

        mocked.assert_called_once_with(file, filename, remove_finished=remove_finished)

    def test_extract_zip(self):
        """A single-member zip archive is extracted into the target directory."""

        def create_archive(root, content="this is the content"):
            file = os.path.join(root, "dst.txt")
            archive = os.path.join(root, "archive.zip")

            with zipfile.ZipFile(archive, "w") as zf:
                zf.writestr(os.path.basename(file), content)

            return archive, file, content

        with get_tmp_dir() as temp_dir:
            archive, file, content = create_archive(temp_dir)

            utils.extract_archive(archive, temp_dir)

            assert os.path.exists(file)

            with open(file, "r") as fh:
                assert fh.read() == content

    @pytest.mark.parametrize('extension, mode', [
        ('.tar', 'w'), ('.tar.gz', 'w:gz'), ('.tgz', 'w:gz'), ('.tar.xz', 'w:xz')])
    def test_extract_tar(self, extension, mode):
        """Plain and compressed tar archives are extracted into the target directory."""

        def create_archive(root, extension, mode, content="this is the content"):
            src = os.path.join(root, "src.txt")
            dst = os.path.join(root, "dst.txt")
            archive = os.path.join(root, f"archive{extension}")

            with open(src, "w") as fh:
                fh.write(content)

            # Store the member under dst's name so extraction creates a file
            # distinct from the source used to build the archive.
            with tarfile.open(archive, mode=mode) as fh:
                fh.add(src, arcname=os.path.basename(dst))

            return archive, dst, content

        with get_tmp_dir() as temp_dir:
            archive, file, content = create_archive(temp_dir, extension, mode)

            utils.extract_archive(archive, temp_dir)

            assert os.path.exists(file)

            with open(file, "r") as fh:
                assert fh.read() == content

    def test_verify_str_arg(self):
        """``verify_str_arg(value, arg, valid_values)`` returns valid values and
        raises ``ValueError`` for non-strings and for values outside the set."""
        assert "a" == utils.verify_str_arg("a", "arg", ("a",))
        # Use the same (value, arg, valid_values) order as the call above, so the
        # ValueError comes from the value check and not from a mis-ordered call.
        with pytest.raises(ValueError):
            utils.verify_str_arg(0, "arg", ("a",))
        with pytest.raises(ValueError):
            utils.verify_str_arg("b", "arg", ("a",))


if __name__ == '__main__':
    # Let the module be executed directly as a script; this delegates to
    # pytest, restricted to the tests in this file.
    pytest.main(args=[__file__])