import contextlib
import itertools
import time
import unittest.mock
from datetime import datetime
from os import path
from urllib.error import HTTPError, URLError
from urllib.parse import urlparse
from urllib.request import urlopen, Request

import pytest

from torchvision import datasets
from torchvision.datasets.utils import download_url, check_integrity, download_file_from_google_drive

from common_utils import get_tmp_dir
from fakedata_generation import places365_root


def limit_requests_per_time(min_secs_between_requests=2.0):
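    # Decorator factory: throttles the wrapped callable so that consecutive
    # requests to the same host (netloc) are at least
    # `min_secs_between_requests` seconds apart.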
    last_requests = {}

    def outer_wrapper(fn):
        def inner_wrapper(request, *args, **kwargs):
            url = request.full_url if isinstance(request, Request) else request

            netloc = urlparse(url).netloc
            last_request = last_requests.get(netloc)
            if last_request is not None:
                elapsed_secs = (datetime.now() - last_request).total_seconds()
                delta = min_secs_between_requests - elapsed_secs
                if delta > 0:
                    time.sleep(delta)

            response = fn(request, *args, **kwargs)
            last_requests[netloc] = datetime.now()

            return response

        return inner_wrapper

    return outer_wrapper


urlopen = limit_requests_per_time()(urlopen)
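# From here on, every request made through this module's `urlopen` goes
# through the rate limiter above, so the tests do not hammer the dataset
# servers.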


@contextlib.contextmanager
def log_download_attempts(
    urls_and_md5s=None,
    file="utils",
    patch=True,
    mock_auxiliaries=None,
):
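    # Record every (url, md5) pair that the download helpers in
    # torchvision.datasets.<file> are called with. With patch=True (default)
    # the helpers are replaced by plain mocks and nothing is downloaded; with
    # patch=False the real helpers still run, but their calls are logged.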
    def add_mock(stack, name, file, **kwargs):
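        # Patch the helper in the dataset-specific module first; if it is not
        # defined there, fall back to torchvision.datasets.utils.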
        try:
            return stack.enter_context(unittest.mock.patch(f"torchvision.datasets.{file}.{name}", **kwargs))
        except AttributeError as error:
            if file != "utils":
                return add_mock(stack, name, "utils", **kwargs)
            else:
                raise pytest.UsageError from error

    if urls_and_md5s is None:
        urls_and_md5s = set()
    if mock_auxiliaries is None:
        mock_auxiliaries = patch

    with contextlib.ExitStack() as stack:
        url_mock = add_mock(stack, "download_url", file, wraps=None if patch else download_url)
        google_drive_mock = add_mock(
            stack, "download_file_from_google_drive", file, wraps=None if patch else download_file_from_google_drive
        )

        if mock_auxiliaries:
            add_mock(stack, "extract_archive", file)

        try:
            yield urls_and_md5s
        finally:
            for args, kwargs in url_mock.call_args_list:
                url = args[0]
                md5 = args[-1] if len(args) == 4 else kwargs.get("md5")
                urls_and_md5s.add((url, md5))

            for args, kwargs in google_drive_mock.call_args_list:
                id = args[0]
                url = f"https://drive.google.com/file/d/{id}"
                md5 = args[3] if len(args) == 4 else kwargs.get("md5")
                urls_and_md5s.add((url, md5))


def retry(fn, times=1, wait=5.0):
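    # Call `fn` up to `times + 1` times, sleeping `wait` seconds after each
    # failed attempt. If every attempt raises an AssertionError, re-raise
    # with all collected messages.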
    msgs = []
    for _ in range(times + 1):
        try:
            return fn()
        except AssertionError as error:
            msgs.append(str(error))
            time.sleep(wait)
    else:
        raise AssertionError(
            "\n".join(
                (
                    f"Assertion failed {times + 1} times with {wait:.1f} seconds intermediate wait time.\n",
                    *(f"{idx}: {error}" for idx, error in enumerate(msgs, 1)),
                )
            )
        )


@contextlib.contextmanager
def assert_server_response_ok():
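    # Convert network errors into AssertionErrors so that `retry` treats
    # them as retryable test failures.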
    try:
        yield
    except URLError as error:
        raise AssertionError("The request timed out.") from error
    except HTTPError as error:
        raise AssertionError(f"The server returned {error.code}: {error.reason}.") from error


def assert_url_is_accessible(url):
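    # A HEAD request checks availability without transferring the payload.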
    request = Request(url, method="HEAD")
    with assert_server_response_ok():
        urlopen(request, timeout=5.0)


def assert_file_downloads_correctly(url, md5):
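    # Download the file into a temporary directory and verify its MD5
    # checksum.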
    with get_tmp_dir() as root:
        file = path.join(root, path.basename(url))
        with assert_server_response_ok():
            with urlopen(url, timeout=5.0) as response, open(file, "wb") as fh:
                fh.write(response.read())

        assert check_integrity(file, md5=md5), "The MD5 checksums do not match."


class DownloadConfig:
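    # Lightweight record tying a URL to its expected MD5; `id` doubles as
    # the pytest test id.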
    def __init__(self, url, md5=None, id=None):
        self.url = url
        self.md5 = md5
        self.id = id or url

    def __repr__(self):
        return self.id


def make_download_configs(urls_and_md5s, name=None):
    return [
        DownloadConfig(url, md5=md5, id=f"{name}, {url}" if name is not None else None) for url, md5 in urls_and_md5s
    ]


def collect_download_configs(dataset_loader, name=None, **kwargs):
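    # Instantiate the dataset with the download helpers mocked out and
    # collect every URL it tried to download. Construction is allowed to
    # fail, since the mocked downloads produce no usable files.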
    urls_and_md5s = set()
    try:
        with log_download_attempts(urls_and_md5s=urls_and_md5s, **kwargs):
            dataset = dataset_loader()
    except Exception:
        dataset = None

    if name is None and dataset is not None:
        name = type(dataset).__name__

    return make_download_configs(urls_and_md5s, name)


def places365():
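    # Places365 is constructed against fake on-disk data from places365_root,
    # so patch=False keeps the real download helpers in place while their
    # calls are still recorded.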
    with log_download_attempts(patch=False) as urls_and_md5s:
        for split, small in itertools.product(("train-standard", "train-challenge", "val"), (False, True)):
            with places365_root(split=split, small=small) as places365:
                root, data = places365

                datasets.Places365(root, split=split, small=small, download=True)

    return make_download_configs(urls_and_md5s, name="Places365")


def caltech101():
    return collect_download_configs(lambda: datasets.Caltech101(".", download=True), name="Caltech101")


def caltech256():
    return collect_download_configs(lambda: datasets.Caltech256(".", download=True), name="Caltech256")


def cifar10():
    return collect_download_configs(lambda: datasets.CIFAR10(".", download=True), name="CIFAR10")


def cifar100():
    return collect_download_configs(lambda: datasets.CIFAR100(".", download=True), name="CIFAR100")


def voc():
    return itertools.chain(
        *[
            collect_download_configs(
                lambda: datasets.VOCSegmentation(".", year=year, download=True),
                name=f"VOC, {year}",
                file="voc",
            )
            for year in ("2007", "2007-test", "2008", "2009", "2010", "2011", "2012")
        ]
    )


def mnist():
    return collect_download_configs(lambda: datasets.MNIST(".", download=True), name="MNIST")


def fashion_mnist():
    return collect_download_configs(lambda: datasets.FashionMNIST(".", download=True), name="FashionMNIST")


def kmnist():
    return collect_download_configs(lambda: datasets.KMNIST(".", download=True), name="KMNIST")


def emnist():
    # The 'split' argument can be any valid value, since everything is downloaded anyway.
    return collect_download_configs(lambda: datasets.EMNIST(".", split="byclass", download=True), name="EMNIST")


def qmnist():
    return itertools.chain(
        *[
            collect_download_configs(
                lambda: datasets.QMNIST(".", what=what, download=True),
                name=f"QMNIST, {what}",
                file="mnist",
            )
            for what in ("train", "test", "nist")
        ]
    )


def omniglot():
    return itertools.chain(
        *[
            collect_download_configs(
                lambda: datasets.Omniglot(".", background=background, download=True),
                name=f"Omniglot, {'background' if background else 'evaluation'}",
            )
            for background in (True, False)
        ]
    )


def phototour():
    return itertools.chain(
        *[
            collect_download_configs(
                lambda: datasets.PhotoTour(".", name=name, download=True),
                name=f"PhotoTour, {name}",
                file="phototour",
            )
            # The names suffixed with '_harris' point to the domain 'matthewalunbrown.com'. For some reason, all
            # requests from within CI time out. They are disabled until this is resolved.
            for name in ("notredame", "yosemite", "liberty")  # "notredame_harris", "yosemite_harris", "liberty_harris"
        ]
    )


def sbdataset():
    return collect_download_configs(
        lambda: datasets.SBDataset(".", download=True),
        name="SBDataset",
        file="voc",
    )


def sbu():
    return collect_download_configs(
        lambda: datasets.SBU(".", download=True),
        name="SBU",
        file="sbu",
    )


def semeion():
    return collect_download_configs(
        lambda: datasets.SEMEION(".", download=True),
        name="SEMEION",
        file="semeion",
    )


def stl10():
    return collect_download_configs(
        lambda: datasets.STL10(".", download=True),
        name="STL10",
    )


def svhn():
    return itertools.chain(
        *[
            collect_download_configs(
                lambda: datasets.SVHN(".", split=split, download=True),
                name=f"SVHN, {split}",
                file="svhn",
            )
            for split in ("train", "test", "extra")
        ]
    )


def usps():
    return itertools.chain(
        *[
            collect_download_configs(
                lambda: datasets.USPS(".", train=train, download=True),
                name=f"USPS, {'train' if train else 'test'}",
                file="usps",
            )
            for train in (True, False)
        ]
    )


def celeba():
    return collect_download_configs(
        lambda: datasets.CelebA(".", download=True),
        name="CelebA",
        file="celeba",
    )


def widerface():
    return collect_download_configs(
        lambda: datasets.WIDERFace(".", download=True),
        name="WIDERFace",
        file="widerface",
    )


def make_parametrize_kwargs(download_configs):
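    # Translate DownloadConfig records into the argnames/argvalues/ids
    # kwargs expected by pytest.mark.parametrize.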
    argvalues = []
    ids = []
    for config in download_configs:
        argvalues.append((config.url, config.md5))
        ids.append(config.id)

    return dict(argnames=("url", "md5"), argvalues=argvalues, ids=ids)


@pytest.mark.parametrize(
    **make_parametrize_kwargs(
        itertools.chain(
            places365(),
            caltech101(),
            caltech256(),
            cifar10(),
            cifar100(),
            # The VOC download server is unstable. See https://github.com/pytorch/vision/issues/2953 for details.
            # voc(),
            mnist(),
            fashion_mnist(),
            kmnist(),
            emnist(),
            qmnist(),
            omniglot(),
            phototour(),
            sbdataset(),
            sbu(),
            semeion(),
            stl10(),
            svhn(),
            usps(),
            celeba(),
            widerface(),
        )
    )
)
def test_url_is_accessible(url, md5):
    retry(lambda: assert_url_is_accessible(url))


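# The chain below is currently empty, so no full downloads are verified by
# default; chain in the collectors above (e.g. cifar10()) to enable the check.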
@pytest.mark.parametrize(**make_parametrize_kwargs(itertools.chain()))
def test_file_downloads_correctly(url, md5):
    retry(lambda: assert_file_downloads_correctly(url, md5))