import contextlib
import itertools
import tempfile
import time
import traceback
import unittest.mock
import warnings
from datetime import datetime
from distutils import dir_util
from os import path
from urllib.error import HTTPError, URLError
from urllib.parse import urlparse
from urllib.request import Request, urlopen

import pytest
from torchvision import datasets
from torchvision.datasets.utils import (
    _get_redirect_url,
    check_integrity,
    download_file_from_google_drive,
    download_url,
    USER_AGENT,
)


def limit_requests_per_time(min_secs_between_requests=2.0):
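    """Decorator factory that rate-limits the wrapped ``urlopen``-style callable.

    Consecutive requests to the same host (netloc) are spaced at least
    ``min_secs_between_requests`` seconds apart, so the test suite does not hammer
    the download servers.
    """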
    last_requests = {}

    def outer_wrapper(fn):
        def inner_wrapper(request, *args, **kwargs):
            url = request.full_url if isinstance(request, Request) else request

            netloc = urlparse(url).netloc
            last_request = last_requests.get(netloc)
            if last_request is not None:
                elapsed_secs = (datetime.now() - last_request).total_seconds()
                delta = min_secs_between_requests - elapsed_secs
                if delta > 0:
                    time.sleep(delta)

            response = fn(request, *args, **kwargs)
            last_requests[netloc] = datetime.now()

            return response

        return inner_wrapper

    return outer_wrapper


urlopen = limit_requests_per_time()(urlopen)


def resolve_redirects(max_hops=3):
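    """Decorator factory that eagerly resolves redirects for the wrapped callable.

    The request is rewritten to its final URL (following at most ``max_hops``
    redirects) and a warning is emitted, so datasets pointing to stale URLs are easy
    to spot in the test output.
    """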
    def outer_wrapper(fn):
        def inner_wrapper(request, *args, **kwargs):
            initial_url = request.full_url if isinstance(request, Request) else request
            url = _get_redirect_url(initial_url, max_hops=max_hops)

            if url == initial_url:
                return fn(request, *args, **kwargs)

            warnings.warn(f"The URL {initial_url} ultimately redirects to {url}.")

            if not isinstance(request, Request):
                return fn(url, *args, **kwargs)

            request_attrs = {
                attr: getattr(request, attr) for attr in ("data", "headers", "origin_req_host", "unverifiable")
            }
            # the 'method' attribute only exists if the request was created with it
            if hasattr(request, "method"):
                request_attrs["method"] = request.method

            return fn(Request(url, **request_attrs), *args, **kwargs)

        return inner_wrapper

    return outer_wrapper


urlopen = resolve_redirects()(urlopen)


@contextlib.contextmanager
def log_download_attempts(
    urls_and_md5s=None,
    file="utils",
    patch=True,
    mock_auxiliaries=None,
):
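    """Records the ``(url, md5)`` pair of every download a dataset constructor attempts.

    With ``patch=True`` (the default) the download functions are replaced by no-op
    mocks, so nothing is actually downloaded. A rough usage sketch:

        with log_download_attempts(file="voc") as urls_and_md5s:
            datasets.VOCSegmentation(ROOT, year="2012", download=True)
    """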
    def add_mock(stack, name, file, **kwargs):
        try:
            return stack.enter_context(unittest.mock.patch(f"torchvision.datasets.{file}.{name}", **kwargs))
        except AttributeError as error:
            if file != "utils":
                return add_mock(stack, name, "utils", **kwargs)
            else:
                raise pytest.UsageError from error

    if urls_and_md5s is None:
        urls_and_md5s = set()
    if mock_auxiliaries is None:
        mock_auxiliaries = patch

    with contextlib.ExitStack() as stack:
        url_mock = add_mock(stack, "download_url", file, wraps=None if patch else download_url)
        google_drive_mock = add_mock(
            stack, "download_file_from_google_drive", file, wraps=None if patch else download_file_from_google_drive
        )

        if mock_auxiliaries:
            add_mock(stack, "extract_archive", file)

        try:
            yield urls_and_md5s
        finally:
            for args, kwargs in url_mock.call_args_list:
                url = args[0]
                md5 = args[-1] if len(args) == 4 else kwargs.get("md5")
                urls_and_md5s.add((url, md5))

            for args, kwargs in google_drive_mock.call_args_list:
                id = args[0]
                url = f"https://drive.google.com/file/d/{id}"
                md5 = args[3] if len(args) == 4 else kwargs.get("md5")
                urls_and_md5s.add((url, md5))


def retry(fn, times=1, wait=5.0):
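    """Calls ``fn`` up to ``times + 1`` times, waiting ``wait`` seconds between attempts.

    ``AssertionError``s are collected along the way; if every attempt fails, they are
    re-raised as a single ``AssertionError`` carrying all tracebacks.
    """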
    tbs = []
    for _ in range(times + 1):
        try:
            return fn()
        except AssertionError as error:
            tbs.append("".join(traceback.format_exception(type(error), error, error.__traceback__)))
            time.sleep(wait)
    else:
        raise AssertionError(
            "\n".join(
                (
                    "\n",
                    *[f"{'_' * 40}  {idx:2d}  {'_' * 40}\n\n{tb}" for idx, tb in enumerate(tbs, 1)],
                    (
                        f"Assertion failed {times + 1} times with {wait:.1f} seconds intermediate wait time. "
                        "You can find the full tracebacks above."
                    ),
                )
            )
        )


@contextlib.contextmanager
def assert_server_response_ok():
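    """Translates the network errors raised by ``urlopen`` into ``AssertionError``s.

    This lets ``retry`` above catch them, so pytest reports them as plain test failures.
    """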
    try:
        yield
    except HTTPError as error:
        raise AssertionError(f"The server returned {error.code}: {error.reason}.") from error
    except URLError as error:
        raise AssertionError(
            "Connection not possible due to SSL." if "SSL" in str(error) else "The request timed out."
        ) from error
    except RecursionError as error:
        raise AssertionError(str(error)) from error


def assert_url_is_accessible(url, timeout=5.0):
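    """Checks accessibility with a HEAD request, i.e. without downloading the body."""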
    request = Request(url, headers={"User-Agent": USER_AGENT}, method="HEAD")
    with assert_server_response_ok():
        urlopen(request, timeout=timeout)


def assert_file_downloads_correctly(url, md5, tmpdir, timeout=5.0):
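    """Downloads ``url`` into ``tmpdir`` and verifies the file against the expected MD5."""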
    file = path.join(tmpdir, path.basename(url))
    with assert_server_response_ok():
        with open(file, "wb") as fh:
            request = Request(url, headers={"User-Agent": USER_AGENT})
            response = urlopen(request, timeout=timeout)
            fh.write(response.read())

    assert check_integrity(file, md5=md5), "The MD5 checksums do not match"


class DownloadConfig:
    def __init__(self, url, md5=None, id=None):
        self.url = url
        self.md5 = md5
        self.id = id or url

    def __repr__(self) -> str:
        return self.id


def make_download_configs(urls_and_md5s, name=None):
    return [
        DownloadConfig(url, md5=md5, id=f"{name}, {url}" if name is not None else None) for url, md5 in urls_and_md5s
    ]


def collect_download_configs(dataset_loader, name=None, **kwargs):
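    """Instantiates a dataset with its download functions mocked out (see
    ``log_download_attempts``) and collects every recorded download as a
    ``DownloadConfig``.

    Errors during instantiation are swallowed, since the URLs recorded up to that
    point are still worth testing.
    """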
    urls_and_md5s = set()
    try:
        with log_download_attempts(urls_and_md5s=urls_and_md5s, **kwargs):
            dataset = dataset_loader()
    except Exception:
        dataset = None

    if name is None and dataset is not None:
        name = type(dataset).__name__

    return make_download_configs(urls_and_md5s, name)


# This is a workaround since fixtures, such as the built-in tmpdir, can only be used within a test but not within a
# parametrization. Thus, we use a single root directory for all datasets and remove it when all download tests are run.
ROOT = tempfile.mkdtemp()


@pytest.fixture(scope="module", autouse=True)
def root():
    yield ROOT
    dir_util.remove_tree(ROOT)


def places365():
    return itertools.chain(
        *[
            collect_download_configs(
                lambda: datasets.Places365(ROOT, split=split, small=small, download=True),
                name=f"Places365, {split}, {'small' if small else 'large'}",
                file="places365",
            )
            for split, small in itertools.product(("train-standard", "train-challenge", "val"), (False, True))
        ]
    )


def caltech101():
    return collect_download_configs(lambda: datasets.Caltech101(ROOT, download=True), name="Caltech101")


def caltech256():
    return collect_download_configs(lambda: datasets.Caltech256(ROOT, download=True), name="Caltech256")


def cifar10():
    return collect_download_configs(lambda: datasets.CIFAR10(ROOT, download=True), name="CIFAR10")


def cifar100():
    return collect_download_configs(lambda: datasets.CIFAR100(ROOT, download=True), name="CIFAR100")


def voc():
    # TODO: Also test the "2007-test" key
    return itertools.chain(
        *[
            collect_download_configs(
                lambda: datasets.VOCSegmentation(ROOT, year=year, download=True),
                name=f"VOC, {year}",
                file="voc",
            )
            for year in ("2007", "2008", "2009", "2010", "2011", "2012")
        ]
    )


def mnist():
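    # Only probe the last mirror in the list (assumption: the earlier mirrors are
    # unreliable from within CI, so they are skipped here on purpose).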
    with unittest.mock.patch.object(datasets.MNIST, "mirrors", datasets.MNIST.mirrors[-1:]):
        return collect_download_configs(lambda: datasets.MNIST(ROOT, download=True), name="MNIST")


def fashion_mnist():
    return collect_download_configs(lambda: datasets.FashionMNIST(ROOT, download=True), name="FashionMNIST")


def kmnist():
    return collect_download_configs(lambda: datasets.KMNIST(ROOT, download=True), name="KMNIST")


def emnist():
    # the 'split' argument can be any valid one, since everything is downloaded anyway
    return collect_download_configs(lambda: datasets.EMNIST(ROOT, split="byclass", download=True), name="EMNIST")


def qmnist():
    return itertools.chain(
        *[
            collect_download_configs(
                lambda: datasets.QMNIST(ROOT, what=what, download=True),
                name=f"QMNIST, {what}",
                file="mnist",
            )
            for what in ("train", "test", "nist")
        ]
    )


def omniglot():
    return itertools.chain(
        *[
            collect_download_configs(
                lambda: datasets.Omniglot(ROOT, background=background, download=True),
                name=f"Omniglot, {'background' if background else 'evaluation'}",
            )
            for background in (True, False)
        ]
    )


def phototour():
    return itertools.chain(
        *[
            collect_download_configs(
                lambda: datasets.PhotoTour(ROOT, name=name, download=True),
                name=f"PhotoTour, {name}",
                file="phototour",
            )
            # The names suffixed with '_harris' point to the domain 'matthewalunbrown.com'. For some reason all
            # requests time out from within CI. They are disabled until this is resolved.
            for name in ("notredame", "yosemite", "liberty")  # "notredame_harris", "yosemite_harris", "liberty_harris"
        ]
    )


def sbdataset():
    return collect_download_configs(
        lambda: datasets.SBDataset(ROOT, download=True),
        name="SBDataset",
        file="voc",
    )


def sbu():
    return collect_download_configs(
        lambda: datasets.SBU(ROOT, download=True),
        name="SBU",
        file="sbu",
    )


def semeion():
    return collect_download_configs(
        lambda: datasets.SEMEION(ROOT, download=True),
        name="SEMEION",
        file="semeion",
    )


def stl10():
    return collect_download_configs(
        lambda: datasets.STL10(ROOT, download=True),
        name="STL10",
    )


def svhn():
    return itertools.chain(
        *[
            collect_download_configs(
                lambda: datasets.SVHN(ROOT, split=split, download=True),
                name=f"SVHN, {split}",
                file="svhn",
            )
            for split in ("train", "test", "extra")
        ]
    )


def usps():
    return itertools.chain(
        *[
            collect_download_configs(
                lambda: datasets.USPS(ROOT, train=train, download=True),
                name=f"USPS, {'train' if train else 'test'}",
                file="usps",
            )
            for train in (True, False)
        ]
    )


def celeba():
    return collect_download_configs(
        lambda: datasets.CelebA(ROOT, download=True),
        name="CelebA",
        file="celeba",
    )


def widerface():
    return collect_download_configs(
        lambda: datasets.WIDERFace(ROOT, download=True),
        name="WIDERFace",
        file="widerface",
    )


def kinetics():
    return itertools.chain(
        *[
            collect_download_configs(
                lambda: datasets.Kinetics(
                    path.join(ROOT, f"Kinetics{num_classes}"),
                    frames_per_clip=1,
                    num_classes=num_classes,
                    split=split,
                    download=True,
                ),
                name=f"Kinetics, {num_classes}, {split}",
                file="kinetics",
            )
            for num_classes, split in itertools.product(("400", "600", "700"), ("train", "val"))
        ]
    )


def kitti():
    return itertools.chain(
        *[
            collect_download_configs(
                lambda train=train: datasets.Kitti(ROOT, train=train, download=True),
                name=f"Kitti, {'train' if train else 'test'}",
                file="kitti",
            )
            for train in (True, False)
        ]
    )


def make_parametrize_kwargs(download_configs):
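    """Flattens download configs into the ``argnames``/``argvalues``/``ids`` kwargs
    expected by ``pytest.mark.parametrize``."""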
    argvalues = []
    ids = []
    for config in download_configs:
        argvalues.append((config.url, config.md5))
        ids.append(config.id)

    return dict(argnames=("url", "md5"), argvalues=argvalues, ids=ids)


@pytest.mark.parametrize(
    **make_parametrize_kwargs(
        itertools.chain(
            caltech101(),
            caltech256(),
            cifar10(),
            cifar100(),
            # The VOC download server is unstable. See https://github.com/pytorch/vision/issues/2953 for details.
            # voc(),
            mnist(),
            fashion_mnist(),
            kmnist(),
            emnist(),
            qmnist(),
            omniglot(),
            phototour(),
            sbdataset(),
            semeion(),
            stl10(),
            svhn(),
            usps(),
            celeba(),
            widerface(),
            kinetics(),
            kitti(),
        )
    )
)
def test_url_is_accessible(url, md5):
    """
    If you see this test failing, find the offending dataset in the parametrization and move it to
    ``test_url_is_not_accessible`` and link an issue detailing the problem.
    """
    retry(lambda: assert_url_is_accessible(url))


@pytest.mark.parametrize(
    **make_parametrize_kwargs(
        itertools.chain(
            places365(),  # https://github.com/pytorch/vision/issues/6268
            sbu(),  # https://github.com/pytorch/vision/issues/7005
        )
    )
)
@pytest.mark.xfail
def test_url_is_not_accessible(url, md5):
    """
    As the name implies, this test is the 'inverse' of ``test_url_is_accessible``. Since the download servers are
    beyond our control, some files might not be accessible for longer stretches of time. Still, we want to know if they
    come back up, or if we need to remove the download functionality of the dataset for good.

    If you see this test failing, find the offending dataset in the parametrization and move it to
    ``test_url_is_accessible``.
    """
    retry(lambda: assert_url_is_accessible(url))


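# The parametrization below is currently empty: ``itertools.chain()`` yields nothing,
# so no file download is actually checked. Add download configs to the chain to verify
# downloaded files against their MD5 checksums.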
@pytest.mark.parametrize(**make_parametrize_kwargs(itertools.chain()))
def test_file_downloads_correctly(url, md5, tmpdir):
    retry(lambda: assert_file_downloads_correctly(url, md5, tmpdir))