import contextlib
import itertools
import tempfile
import time
import unittest.mock
import warnings
from datetime import datetime
from distutils import dir_util
from os import path
from urllib.error import HTTPError, URLError
from urllib.parse import urlparse
from urllib.request import Request, urlopen

import pytest
from torchvision import datasets
from torchvision.datasets.utils import (
    _get_redirect_url,
    check_integrity,
    download_file_from_google_drive,
    download_url,
    USER_AGENT,
)


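# Enforce a minimum delay between consecutive requests to the same host so the
# download servers are not hammered when many URL checks run back to back. The
# delay is tracked per network location (netloc), not globally.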
def limit_requests_per_time(min_secs_between_requests=2.0):
    last_requests = {}

    def outer_wrapper(fn):
        def inner_wrapper(request, *args, **kwargs):
            url = request.full_url if isinstance(request, Request) else request

            netloc = urlparse(url).netloc
            last_request = last_requests.get(netloc)
            if last_request is not None:
                elapsed_secs = (datetime.now() - last_request).total_seconds()
                delta = min_secs_between_requests - elapsed_secs
                if delta > 0:
                    time.sleep(delta)

            response = fn(request, *args, **kwargs)
            last_requests[netloc] = datetime.now()

            return response

        return inner_wrapper

    return outer_wrapper


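# All helpers below go through this rate-limited urlopen.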
urlopen = limit_requests_per_time()(urlopen)


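# Resolve redirect chains (up to max_hops) before issuing a request and warn if
# the URL ultimately points somewhere else. For Request objects, the relevant
# attributes are copied onto a new Request for the final URL.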
def resolve_redirects(max_hops=3):
    def outer_wrapper(fn):
        def inner_wrapper(request, *args, **kwargs):
            initial_url = request.full_url if isinstance(request, Request) else request
            url = _get_redirect_url(initial_url, max_hops=max_hops)

            if url == initial_url:
                return fn(request, *args, **kwargs)

            warnings.warn(f"The URL {initial_url} ultimately redirects to {url}.")

            if not isinstance(request, Request):
                return fn(url, *args, **kwargs)

            request_attrs = {
                attr: getattr(request, attr) for attr in ("data", "headers", "origin_req_host", "unverifiable")
            }
            # the 'method' attribute only exists if the request was created with it
            if hasattr(request, "method"):
                request_attrs["method"] = request.method

            return fn(Request(url, **request_attrs), *args, **kwargs)

        return inner_wrapper

    return outer_wrapper


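# Stack redirect resolution on top of the rate limiting above.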
urlopen = resolve_redirects()(urlopen)


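# Context manager that intercepts download_url and download_file_from_google_drive
# in the given dataset module and records every (url, md5) pair passed to them.
# With patch=True the downloads are mocked out entirely; otherwise the real
# functions are wrapped so calls are recorded but still executed.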
@contextlib.contextmanager
def log_download_attempts(
    urls_and_md5s=None,
    file="utils",
    patch=True,
    mock_auxiliaries=None,
):
    def add_mock(stack, name, file, **kwargs):
        try:
            return stack.enter_context(unittest.mock.patch(f"torchvision.datasets.{file}.{name}", **kwargs))
        except AttributeError as error:
            if file != "utils":
                return add_mock(stack, name, "utils", **kwargs)
            else:
                raise pytest.UsageError from error

    if urls_and_md5s is None:
        urls_and_md5s = set()
    if mock_auxiliaries is None:
        mock_auxiliaries = patch

    with contextlib.ExitStack() as stack:
        url_mock = add_mock(stack, "download_url", file, wraps=None if patch else download_url)
        google_drive_mock = add_mock(
            stack, "download_file_from_google_drive", file, wraps=None if patch else download_file_from_google_drive
        )

        if mock_auxiliaries:
            add_mock(stack, "extract_archive", file)

        try:
            yield urls_and_md5s
        finally:
            for args, kwargs in url_mock.call_args_list:
                url = args[0]
                md5 = args[-1] if len(args) == 4 else kwargs.get("md5")
                urls_and_md5s.add((url, md5))

            for args, kwargs in google_drive_mock.call_args_list:
                id = args[0]
                url = f"https://drive.google.com/file/d/{id}"
                md5 = args[3] if len(args) == 4 else kwargs.get("md5")
                urls_and_md5s.add((url, md5))


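# Re-run a flaky check a few times before giving up. Transient network failures
# are expected when talking to third-party download servers.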
def retry(fn, times=1, wait=5.0):
    msgs = []
    for _ in range(times + 1):
        try:
            return fn()
        except AssertionError as error:
            msgs.append(str(error))
            time.sleep(wait)
    else:
        raise AssertionError(
            "\n".join(
                (
                    f"Assertion failed {times + 1} times with {wait:.1f} seconds intermediate wait time.\n",
                    *(f"{idx}: {error}" for idx, error in enumerate(msgs, 1)),
                )
            )
        )


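# Translate network errors into AssertionErrors so that retry() treats them as
# retryable failures rather than hard errors.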
@contextlib.contextmanager
def assert_server_response_ok():
    try:
        yield
    except URLError as error:
        raise AssertionError("The request timed out.") from error
    except HTTPError as error:
        raise AssertionError(f"The server returned {error.code}: {error.reason}.") from error
    except RecursionError as error:
        raise AssertionError(str(error)) from error


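# A HEAD request is enough to check accessibility without downloading the payload.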
def assert_url_is_accessible(url, timeout=5.0):
    request = Request(url, headers={"User-Agent": USER_AGENT}, method="HEAD")
    with assert_server_response_ok():
        urlopen(request, timeout=timeout)


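# Download the file in full and verify its MD5 checksum against the expected value.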
def assert_file_downloads_correctly(url, md5, tmpdir, timeout=5.0):
    file = path.join(tmpdir, path.basename(url))
    with assert_server_response_ok():
        with open(file, "wb") as fh:
            request = Request(url, headers={"User-Agent": USER_AGENT})
            response = urlopen(request, timeout=timeout)
            fh.write(response.read())

    assert check_integrity(file, md5=md5), "The MD5 checksums do not match"


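# Pairs a download URL with its expected MD5 checksum. The id doubles as the
# human-readable test id in pytest's parametrization.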
class DownloadConfig:
    def __init__(self, url, md5=None, id=None):
        self.url = url
        self.md5 = md5
        self.id = id or url

    def __repr__(self) -> str:
        return self.id


def make_download_configs(urls_and_md5s, name=None):
    return [
        DownloadConfig(url, md5=md5, id=f"{name}, {url}" if name is not None else None) for url, md5 in urls_and_md5s
    ]


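# Instantiate a dataset with its download machinery mocked out and collect the
# (url, md5) pairs it tried to fetch. Exceptions are swallowed on purpose:
# constructing the dataset usually fails since nothing is actually downloaded,
# but the download calls have been recorded by then.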
def collect_download_configs(dataset_loader, name=None, **kwargs):
    urls_and_md5s = set()
    try:
        with log_download_attempts(urls_and_md5s=urls_and_md5s, **kwargs):
            dataset = dataset_loader()
    except Exception:
        dataset = None

    if name is None and dataset is not None:
        name = type(dataset).__name__

    return make_download_configs(urls_and_md5s, name)


# This is a workaround since fixtures, such as the built-in tmpdir, can only be used within a test but not within a
# parametrization. Thus, we use a single root directory for all datasets and remove it once all download tests have run.
ROOT = tempfile.mkdtemp()


@pytest.fixture(scope="module", autouse=True)
def root():
    yield ROOT
    dir_util.remove_tree(ROOT)


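# Each function below collects the download configurations for one dataset family.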
def places365():
    return itertools.chain(
        *[
            collect_download_configs(
                lambda: datasets.Places365(ROOT, split=split, small=small, download=True),
                name=f"Places365, {split}, {'small' if small else 'large'}",
                file="places365",
            )
            for split, small in itertools.product(("train-standard", "train-challenge", "val"), (False, True))
        ]
    )


def caltech101():
    return collect_download_configs(lambda: datasets.Caltech101(ROOT, download=True), name="Caltech101")


def caltech256():
    return collect_download_configs(lambda: datasets.Caltech256(ROOT, download=True), name="Caltech256")


def cifar10():
    return collect_download_configs(lambda: datasets.CIFAR10(ROOT, download=True), name="CIFAR10")


def cifar100():
    return collect_download_configs(lambda: datasets.CIFAR100(ROOT, download=True), name="CIFAR100")


def voc():
    return itertools.chain(
        *[
            collect_download_configs(
                lambda: datasets.VOCSegmentation(ROOT, year=year, download=True),
                name=f"VOC, {year}",
                file="voc",
            )
            for year in ("2007", "2007-test", "2008", "2009", "2010", "2011", "2012")
        ]
    )


def mnist():
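    # Restrict the mirror list to a single entry so only one host is checked.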
    with unittest.mock.patch.object(datasets.MNIST, "mirrors", datasets.MNIST.mirrors[-1:]):
        return collect_download_configs(lambda: datasets.MNIST(ROOT, download=True), name="MNIST")


def fashion_mnist():
    return collect_download_configs(lambda: datasets.FashionMNIST(ROOT, download=True), name="FashionMNIST")


def kmnist():
    return collect_download_configs(lambda: datasets.KMNIST(ROOT, download=True), name="KMNIST")


def emnist():
    # the 'split' argument can be any valid one, since everything is downloaded anyway
    return collect_download_configs(lambda: datasets.EMNIST(ROOT, split="byclass", download=True), name="EMNIST")


def qmnist():
    return itertools.chain(
        *[
            collect_download_configs(
                lambda: datasets.QMNIST(ROOT, what=what, download=True),
                name=f"QMNIST, {what}",
                file="mnist",
            )
            for what in ("train", "test", "nist")
        ]
    )


def omniglot():
    return itertools.chain(
        *[
            collect_download_configs(
                lambda: datasets.Omniglot(ROOT, background=background, download=True),
                name=f"Omniglot, {'background' if background else 'evaluation'}",
            )
            for background in (True, False)
        ]
    )


def phototour():
    return itertools.chain(
        *[
            collect_download_configs(
                lambda: datasets.PhotoTour(ROOT, name=name, download=True),
                name=f"PhotoTour, {name}",
                file="phototour",
            )
            # The names suffixed with '_harris' point to the domain 'matthewalunbrown.com'. For some reason all
            # requests time out from within CI. They are disabled until this is resolved.
            for name in ("notredame", "yosemite", "liberty")  # "notredame_harris", "yosemite_harris", "liberty_harris"
        ]
    )


def sbdataset():
    return collect_download_configs(
        lambda: datasets.SBDataset(ROOT, download=True),
        name="SBDataset",
        file="voc",
    )


def sbu():
    return collect_download_configs(
        lambda: datasets.SBU(ROOT, download=True),
        name="SBU",
        file="sbu",
    )


def semeion():
    return collect_download_configs(
        lambda: datasets.SEMEION(ROOT, download=True),
        name="SEMEION",
        file="semeion",
    )


def stl10():
    return collect_download_configs(
        lambda: datasets.STL10(ROOT, download=True),
        name="STL10",
    )


def svhn():
    return itertools.chain(
        *[
            collect_download_configs(
                lambda: datasets.SVHN(ROOT, split=split, download=True),
                name=f"SVHN, {split}",
                file="svhn",
            )
            for split in ("train", "test", "extra")
        ]
    )


def usps():
    return itertools.chain(
        *[
            collect_download_configs(
                lambda: datasets.USPS(ROOT, train=train, download=True),
                name=f"USPS, {'train' if train else 'test'}",
                file="usps",
            )
            for train in (True, False)
        ]
    )


def celeba():
    return collect_download_configs(
        lambda: datasets.CelebA(ROOT, download=True),
        name="CelebA",
        file="celeba",
    )


def widerface():
    return collect_download_configs(
        lambda: datasets.WIDERFace(ROOT, download=True),
        name="WIDERFace",
        file="widerface",
    )


def kinetics():
    return itertools.chain(
        *[
            collect_download_configs(
                lambda: datasets.Kinetics(
                    path.join(ROOT, f"Kinetics{num_classes}"),
                    frames_per_clip=1,
                    num_classes=num_classes,
                    split=split,
                    download=True,
                ),
                name=f"Kinetics, {num_classes}, {split}",
                file="kinetics",
            )
            for num_classes, split in itertools.product(("400", "600", "700"), ("train", "val"))
        ]
    )


def kitti():
    return itertools.chain(
        *[
            collect_download_configs(
                lambda train=train: datasets.Kitti(ROOT, train=train, download=True),
                name=f"Kitti, {'train' if train else 'test'}",
                file="kitti",
            )
            for train in (True, False)
        ]
    )


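# Turn DownloadConfig objects into the keyword arguments expected by
# pytest.mark.parametrize, using each config's id as the test id.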
def make_parametrize_kwargs(download_configs):
    argvalues = []
    ids = []
    for config in download_configs:
        argvalues.append((config.url, config.md5))
        ids.append(config.id)

    return dict(argnames=("url", "md5"), argvalues=argvalues, ids=ids)


@pytest.mark.parametrize(
    **make_parametrize_kwargs(
        itertools.chain(
            caltech101(),
            caltech256(),
            cifar10(),
            cifar100(),
            # The VOC download server is unstable. See https://github.com/pytorch/vision/issues/2953 for details.
            # voc(),
            mnist(),
            fashion_mnist(),
            kmnist(),
            emnist(),
            qmnist(),
            omniglot(),
            phototour(),
            sbdataset(),
            semeion(),
            stl10(),
            svhn(),
            usps(),
            celeba(),
            widerface(),
            kinetics(),
            kitti(),
        )
    )
)
def test_url_is_accessible(url, md5):
    """
    If you see this test failing, find the offending dataset in the parametrization and move it to
    ``test_url_is_not_accessible`` and link an issue detailing the problem.
    """
    retry(lambda: assert_url_is_accessible(url))


@pytest.mark.parametrize(
    **make_parametrize_kwargs(
        itertools.chain(
            places365(),  # https://github.com/pytorch/vision/issues/6268
            sbu(),  # https://github.com/pytorch/vision/issues/6390
        )
    )
)
@pytest.mark.xfail
def test_url_is_not_accessible(url, md5):
    """
    As the name implies, this test is the 'inverse' of ``test_url_is_accessible``. Since the download servers are
    beyond our control, some files might not be accessible for longer stretches of time. Still, we want to know if they
    come back up, or if we need to remove the download functionality of the dataset for good.

    If you see this test failing, find the offending dataset in the parametrization and move it to
    ``test_url_is_accessible``.
    """
    retry(lambda: assert_url_is_accessible(url))


@pytest.mark.parametrize(**make_parametrize_kwargs(itertools.chain()))
def test_file_downloads_correctly(url, md5):
    retry(lambda: assert_file_downloads_correctly(url, md5))