test_datasets_utils.py 7.3 KB
Newer Older
1
import os
2
import sys
Francisco Massa's avatar
Francisco Massa committed
3
4
5
import tempfile
import torchvision.datasets.utils as utils
import unittest
6
import unittest.mock
7
8
9
import zipfile
import tarfile
import gzip
10
11
import warnings
from torch._utils_internal import get_file_path_2
12
from urllib.error import URLError
Francisco Massa's avatar
Francisco Massa committed
13

14
15
16
17
from common_utils import get_tmp_dir


# Fixture image used by the md5 / integrity tests. The path is assembled by
# get_file_path_2 from this file's directory plus the components
# 'assets', 'encode_jpeg' and the JPEG file name.
_TEST_DIR = os.path.dirname(os.path.abspath(__file__))
TEST_FILE = get_file_path_2(
    _TEST_DIR, 'assets', 'encode_jpeg', 'grace_hopper_517x606.jpg')
19

Francisco Massa's avatar
Francisco Massa committed
20
21
22

class Tester(unittest.TestCase):

23
24
25
26
    def test_check_md5(self):
        # check_md5 must accept the digest recorded for TEST_FILE and reject
        # an empty digest string.
        expected_digest = '9c0bb82894bb3af7f7675ef2b3b6dcdc'
        wrong_digest = ''

        digest_matches = utils.check_md5(TEST_FILE, expected_digest)
        self.assertTrue(digest_matches)

        digest_matches = utils.check_md5(TEST_FILE, wrong_digest)
        self.assertFalse(digest_matches)
29
30
31
32
33
34

    def test_check_integrity(self):
        # Each row is (positional arguments for check_integrity, expected
        # truthiness). Rows run in order and the first failing row aborts the
        # test.
        present_path = TEST_FILE
        missing_path = ''
        good_md5 = '9c0bb82894bb3af7f7675ef2b3b6dcdc'
        bad_md5 = ''

        scenarios = (
            ((present_path, good_md5), True),
            ((present_path, bad_md5), False),
            ((present_path,), True),   # no md5 supplied
            ((missing_path,), False),  # empty path
        )
        for call_args, should_pass in scenarios:
            outcome = utils.check_integrity(*call_args)
            if should_pass:
                self.assertTrue(outcome)
            else:
                self.assertFalse(outcome)
39

40
41
42
43
44
45
46
47
48
49
50
51
    def test_get_redirect_url(self):
        # _get_redirect_url is expected to turn the CUB-200-2011 URL into the
        # Google Drive "view" URL below.
        # NOTE(review): this presumably performs a real HTTP request, so it
        # needs network access -- confirm.
        url = "http://www.vision.caltech.edu/visipedia-data/CUB-200-2011/CUB_200_2011.tgz"
        expected = "https://drive.google.com/file/d/1hbzc_P1FuxMkcabkgn9ZKinBwW683j45/view"

        actual = utils._get_redirect_url(url)

        # Use the TestCase assertion instead of a bare ``assert``: it is not
        # stripped under ``python -O`` and reports both values on failure.
        self.assertEqual(actual, expected)

    def test_get_redirect_url_max_hops_exceeded(self):
        # With max_hops=0 the call is expected to raise RecursionError.
        redirecting_url = "http://www.vision.caltech.edu/visipedia-data/CUB-200-2011/CUB_200_2011.tgz"
        self.assertRaises(
            RecursionError, utils._get_redirect_url, redirecting_url, max_hops=0)

52
53
54
55
56
57
58
59
60
61
62
63
    def test_get_google_drive_file_id(self):
        # The file id is the path segment between ".../file/d/" and "/view".
        url = "https://drive.google.com/file/d/1hbzc_P1FuxMkcabkgn9ZKinBwW683j45/view"
        expected = "1hbzc_P1FuxMkcabkgn9ZKinBwW683j45"

        actual = utils._get_google_drive_file_id(url)

        # TestCase assertion instead of a bare ``assert`` so the check still
        # runs under ``python -O`` and failures show both values.
        self.assertEqual(actual, expected)

    def test_get_google_drive_file_id_invalid_url(self):
        # A URL that is not a Google Drive link must yield None.
        url = "http://www.vision.caltech.edu/visipedia-data/CUB-200-2011/CUB_200_2011.tgz"

        # assertIsNone instead of a bare ``assert ... is None``: not stripped
        # under ``python -O`` and prints the offending value on failure.
        self.assertIsNone(utils._get_google_drive_file_id(url))

Francisco Massa's avatar
Francisco Massa committed
64
    def test_download_url(self):
        # Download the archive over plain http into a temporary directory and
        # check that something was written. A URLError turns the test into a
        # skip (with a RuntimeWarning) rather than a failure.
        url = "http://github.com/pytorch/vision/archive/master.zip"
        with get_tmp_dir() as download_dir:
            try:
                utils.download_url(url, download_dir)
                self.assertNotEqual(len(os.listdir(download_dir)), 0)
            except URLError:
                msg = "could not download test file '{}'".format(url)
                warnings.warn(msg, RuntimeWarning)
                self.skipTest(msg)
Francisco Massa's avatar
Francisco Massa committed
74
75

    def test_download_url_retry_http(self):
        # Same flow as the plain download test, but starting from an https
        # URL. A URLError turns the test into a skip (with a RuntimeWarning)
        # rather than a failure.
        url = "https://github.com/pytorch/vision/archive/master.zip"
        with get_tmp_dir() as download_dir:
            try:
                utils.download_url(url, download_dir)
                self.assertNotEqual(len(os.listdir(download_dir)), 0)
            except URLError:
                msg = "could not download test file '{}'".format(url)
                warnings.warn(msg, RuntimeWarning)
                self.skipTest(msg)
Francisco Massa's avatar
Francisco Massa committed
85

86
87
88
89
90
91
    def test_download_url_dont_exist(self):
        # Requesting an archive name that does not exist must raise URLError.
        missing_url = "http://github.com/pytorch/vision/archive/this_doesnt_exist.zip"
        with get_tmp_dir() as download_dir:
            self.assertRaises(URLError, utils.download_url, missing_url, download_dir)

92
93
94
95
96
97
98
99
100
101
102
103
104
    @unittest.mock.patch("torchvision.datasets.utils.download_file_from_google_drive")
    def test_download_url_dispatch_download_from_google_drive(self, mock):
        # download_file_from_google_drive is patched out; the test only checks
        # that download_url calls it exactly once with the file id taken from
        # the URL plus the root / filename / md5 it was given.
        drive_url = "https://drive.google.com/file/d/1hbzc_P1FuxMkcabkgn9ZKinBwW683j45/view"
        file_id = "1hbzc_P1FuxMkcabkgn9ZKinBwW683j45"
        filename, md5 = "filename", "md5"

        with get_tmp_dir() as root:
            utils.download_url(drive_url, root, filename, md5)

        mock.assert_called_once_with(file_id, root, filename, md5)

Francisco Massa's avatar
Francisco Massa committed
105
    # Explicit platform names: the substring test ``'win' in sys.platform`` is
    # also true for 'darwin', which silently skipped this test on macOS.
    @unittest.skipIf(sys.platform in ('win32', 'cygwin'), 'temporarily disabled on Windows')
    def test_extract_zip(self):
        # Write a one-member zip into a named temporary file, extract it with
        # extract_archive, and compare the extracted member with what was
        # written.
        content = 'this is the content'
        with get_tmp_dir() as temp_dir:
            with tempfile.NamedTemporaryFile(suffix='.zip') as f:
                with zipfile.ZipFile(f, 'w') as zf:
                    zf.writestr('file.tst', content)
                utils.extract_archive(f.name, temp_dir)

                extracted = os.path.join(temp_dir, 'file.tst')
                self.assertTrue(os.path.exists(extracted))
                with open(extracted, 'r') as nf:
                    data = nf.read()
                self.assertEqual(data, content)
116

Francisco Massa's avatar
Francisco Massa committed
117
    # Explicit platform names: the substring test ``'win' in sys.platform`` is
    # also true for 'darwin', which silently skipped this test on macOS.
    @unittest.skipIf(sys.platform in ('win32', 'cygwin'), 'temporarily disabled on Windows')
    def test_extract_tar(self):
        # For each (suffix, tarfile mode) pair: build a tar archive holding a
        # single member 'file.tst', extract it with extract_archive, and
        # compare the extracted content with what was written.
        content = 'this is the content'
        cases = (('.tar', 'w'), ('.tar.gz', 'w:gz'), ('.tgz', 'w:gz'))
        for ext, mode in cases:
            # subTest makes a failure report name the archive suffix involved.
            with self.subTest(ext=ext), get_tmp_dir() as temp_dir:
                with tempfile.NamedTemporaryFile() as bf:
                    bf.write(content.encode())
                    # presumably rewinds/flushes so tarfile, which re-reads
                    # the file by name, sees the bytes -- confirm
                    bf.seek(0)
                    with tempfile.NamedTemporaryFile(suffix=ext) as f:
                        with tarfile.open(f.name, mode=mode) as zf:
                            zf.add(bf.name, arcname='file.tst')
                        utils.extract_archive(f.name, temp_dir)

                        extracted = os.path.join(temp_dir, 'file.tst')
                        self.assertTrue(os.path.exists(extracted))
                        with open(extracted, 'r') as nf:
                            data = nf.read()
                        self.assertEqual(data, content)
Ardalan's avatar
Ardalan committed
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147

    # Explicit platform names: the substring test ``'win' in sys.platform`` is
    # also true for 'darwin', which silently skipped this test on macOS.
    @unittest.skipIf(sys.platform in ('win32', 'cygwin'), 'temporarily disabled on Windows')
    def test_extract_tar_xz(self):
        # Build an xz-compressed tar archive holding a single member
        # 'file.tst', extract it with extract_archive, and compare the
        # extracted content with what was written.
        ext, mode = '.tar.xz', 'w:xz'  # single case; no loop needed
        content = 'this is the content'
        with get_tmp_dir() as temp_dir:
            with tempfile.NamedTemporaryFile() as bf:
                bf.write(content.encode())
                # presumably rewinds/flushes so tarfile, which re-reads the
                # file by name, sees the bytes -- confirm
                bf.seek(0)
                with tempfile.NamedTemporaryFile(suffix=ext) as f:
                    with tarfile.open(f.name, mode=mode) as zf:
                        zf.add(bf.name, arcname='file.tst')
                    utils.extract_archive(f.name, temp_dir)

                    extracted = os.path.join(temp_dir, 'file.tst')
                    self.assertTrue(os.path.exists(extracted))
                    with open(extracted, 'r') as nf:
                        data = nf.read()
                    self.assertEqual(data, content)
148

Francisco Massa's avatar
Francisco Massa committed
149
    # Explicit platform names: the substring test ``'win' in sys.platform`` is
    # also true for 'darwin', which silently skipped this test on macOS.
    @unittest.skipIf(sys.platform in ('win32', 'cygwin'), 'temporarily disabled on Windows')
    def test_extract_gzip(self):
        # Gzip a short payload into a named temporary '.gz' file, extract it
        # with extract_archive, and compare the result with what was written.
        content = 'this is the content'
        with get_tmp_dir() as temp_dir:
            with tempfile.NamedTemporaryFile(suffix='.gz') as f:
                with gzip.GzipFile(f.name, 'wb') as zf:
                    zf.write(content.encode())
                utils.extract_archive(f.name, temp_dir)

                # The test expects the output in temp_dir under the archive's
                # base name with its final extension ('.gz') removed.
                f_name = os.path.join(temp_dir, os.path.splitext(os.path.basename(f.name))[0])
                self.assertTrue(os.path.exists(f_name))
                with open(f_name, 'r') as nf:
                    data = nf.read()
                self.assertEqual(data, content)
161

162
163
164
165
166
    def test_verify_str_arg(self):
        # Positional order used throughout: (value, arg name, valid values).
        # A valid value is returned unchanged.
        self.assertEqual("a", utils.verify_str_arg("a", "arg", ("a",)))

        # The failing cases now use the same argument order as the passing
        # call. Previously ("a",) and "arg" were swapped, so "b" was tested
        # for membership in the string "arg" rather than in ("a",).
        # NOTE(review): confirm the order against verify_str_arg's signature.
        self.assertRaises(ValueError, utils.verify_str_arg, 0, "arg", ("a",))
        self.assertRaises(ValueError, utils.verify_str_arg, "b", "arg", ("a",))

Francisco Massa's avatar
Francisco Massa committed
167
168
169

# Run the tests in this module when the file is executed as a script.
if __name__ == '__main__':
    unittest.main()