import contextlib
import itertools
import tempfile
import time
import unittest.mock
import warnings
from datetime import datetime
from distutils import dir_util
from os import path
from urllib.error import HTTPError, URLError
from urllib.parse import urlparse
from urllib.request import urlopen, Request

import pytest

from torchvision import datasets
from torchvision.datasets.utils import (
    download_url,
    check_integrity,
    download_file_from_google_drive,
    _get_redirect_url,
    USER_AGENT,
)


def limit_requests_per_time(min_secs_between_requests=2.0):
    last_requests = {}

    def outer_wrapper(fn):
        def inner_wrapper(request, *args, **kwargs):
            url = request.full_url if isinstance(request, Request) else request

            netloc = urlparse(url).netloc
            last_request = last_requests.get(netloc)
            if last_request is not None:
                elapsed_secs = (datetime.now() - last_request).total_seconds()
                delta = min_secs_between_requests - elapsed_secs
                if delta > 0:
                    time.sleep(delta)

            response = fn(request, *args, **kwargs)
            last_requests[netloc] = datetime.now()

            return response

        return inner_wrapper

    return outer_wrapper


urlopen = limit_requests_per_time()(urlopen)
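
# ``urlopen`` is monkey-patched module-wide, so every request below passes
# through the rate limiter. A minimal sketch of the effect (hypothetical URLs),
# keyed on the network location:
#
#   urlopen("https://example.com/a")  # returns immediately
#   urlopen("https://example.com/b")  # same host: sleeps until ~2s have elapsed
#   urlopen("https://other.org/c")    # different host: returns immediately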


def resolve_redirects(max_hops=3):
    def outer_wrapper(fn):
        def inner_wrapper(request, *args, **kwargs):
            initial_url = request.full_url if isinstance(request, Request) else request
            url = _get_redirect_url(initial_url, max_hops=max_hops)

            if url == initial_url:
                return fn(request, *args, **kwargs)

            warnings.warn(f"The URL {initial_url} ultimately redirects to {url}.")

            if not isinstance(request, Request):
                return fn(url, *args, **kwargs)

            request_attrs = {
                attr: getattr(request, attr) for attr in ("data", "headers", "origin_req_host", "unverifiable")
            }
            # the 'method' attribute only exists if the request was created with it
            if hasattr(request, "method"):
                request_attrs["method"] = request.method

            return fn(Request(url, **request_attrs), *args, **kwargs)

        return inner_wrapper

    return outer_wrapper


urlopen = resolve_redirects()(urlopen)
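
# ``urlopen`` is now wrapped twice: a call first resolves any redirect chain via
# ``_get_redirect_url`` and only then goes through the rate limiter above, e.g.
# (hypothetical URL):
#
#   urlopen("http://example.com/moved")  # warns, then requests the final URL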


@contextlib.contextmanager
def log_download_attempts(
    urls_and_md5s=None,
    file="utils",
    patch=True,
    mock_auxiliaries=None,
):
    def add_mock(stack, name, file, **kwargs):
        try:
            return stack.enter_context(unittest.mock.patch(f"torchvision.datasets.{file}.{name}", **kwargs))
        except AttributeError as error:
            if file != "utils":
                return add_mock(stack, name, "utils", **kwargs)
            else:
                raise pytest.UsageError from error

    if urls_and_md5s is None:
        urls_and_md5s = set()
    if mock_auxiliaries is None:
        mock_auxiliaries = patch

    with contextlib.ExitStack() as stack:
        url_mock = add_mock(stack, "download_url", file, wraps=None if patch else download_url)
        google_drive_mock = add_mock(
            stack, "download_file_from_google_drive", file, wraps=None if patch else download_file_from_google_drive
        )

        if mock_auxiliaries:
            add_mock(stack, "extract_archive", file)

        try:
            yield urls_and_md5s
        finally:
            for args, kwargs in url_mock.call_args_list:
                url = args[0]
                md5 = args[-1] if len(args) == 4 else kwargs.get("md5")
                urls_and_md5s.add((url, md5))

            for args, kwargs in google_drive_mock.call_args_list:
                id = args[0]
                url = f"https://drive.google.com/file/d/{id}"
                md5 = args[3] if len(args) == 4 else kwargs.get("md5")
                urls_and_md5s.add((url, md5))
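
# Illustrative usage: record the (url, md5) pairs a dataset constructor would
# download without hitting the network. Since the patched downloads produce no
# files, the constructor itself may still raise afterwards:
#
#   with log_download_attempts() as urls_and_md5s:
#       datasets.MNIST(ROOT, download=True)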


def retry(fn, times=1, wait=5.0):
    msgs = []
    for _ in range(times + 1):
        try:
            return fn()
        except AssertionError as error:
            msgs.append(str(error))
            time.sleep(wait)
    else:
        raise AssertionError(
            "\n".join(
                (
                    f"Assertion failed {times + 1} times with {wait:.1f} seconds intermediate wait time.\n",
                    *(f"{idx}: {error}" for idx, error in enumerate(msgs, 1)),
                )
            )
        )
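
# With the defaults, ``retry`` makes ``times + 1`` (i.e. two) attempts, waiting
# ``wait`` seconds after each failure, and raises a combined message if all
# attempts fail.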


@contextlib.contextmanager
def assert_server_response_ok():
    try:
        yield
    except URLError as error:
        raise AssertionError("The request timed out.") from error
    except HTTPError as error:
        raise AssertionError(f"The server returned {error.code}: {error.reason}.") from error
    except RecursionError as error:
        raise AssertionError(str(error)) from error


def assert_url_is_accessible(url, timeout=5.0):
    request = Request(url, headers={"User-Agent": USER_AGENT}, method="HEAD")
    with assert_server_response_ok():
        urlopen(request, timeout=timeout)


def assert_file_downloads_correctly(url, md5, tmpdir, timeout=5.0):
    file = path.join(tmpdir, path.basename(url))
    with assert_server_response_ok():
        with open(file, "wb") as fh:
            request = Request(url, headers={"User-Agent": USER_AGENT})
            response = urlopen(request, timeout=timeout)
            fh.write(response.read())

    assert check_integrity(file, md5=md5), "The MD5 checksums do not match."
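
# ``check_integrity`` hashes the downloaded file and compares it against the
# expected MD5; with ``md5=None`` it only checks that the file exists.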


class DownloadConfig:
    def __init__(self, url, md5=None, id=None):
        self.url = url
        self.md5 = md5
        self.id = id or url

    def __repr__(self):
        return self.id


def make_download_configs(urls_and_md5s, name=None):
    return [
        DownloadConfig(url, md5=md5, id=f"{name}, {url}" if name is not None else None) for url, md5 in urls_and_md5s
    ]


def collect_download_configs(dataset_loader, name=None, **kwargs):
    urls_and_md5s = set()
    try:
        with log_download_attempts(urls_and_md5s=urls_and_md5s, **kwargs):
            dataset = dataset_loader()
    except Exception:
        dataset = None

    if name is None and dataset is not None:
        name = type(dataset).__name__

    return make_download_configs(urls_and_md5s, name)
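
# ``collect_download_configs`` drives a dataset constructor under
# ``log_download_attempts`` with the downloads patched out, swallows the failure
# that usually follows, and turns each recorded (url, md5) pair into a
# DownloadConfig named after the dataset class.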


# This is a workaround since fixtures, such as the built-in ``tmpdir``, can only be used within a test but not within
# a parametrization. Thus, we use a single root directory for all datasets and remove it after all download tests have
# run.
ROOT = tempfile.mkdtemp()


@pytest.fixture(scope="module", autouse=True)
def root():
    yield ROOT
    dir_util.remove_tree(ROOT)


def places365():
    return itertools.chain(
        *[
            collect_download_configs(
                lambda: datasets.Places365(ROOT, split=split, small=small, download=True),
                name=f"Places365, {split}, {'small' if small else 'large'}",
                file="places365",
            )
            for split, small in itertools.product(("train-standard", "train-challenge", "val"), (False, True))
        ]
    )


def caltech101():
    return collect_download_configs(lambda: datasets.Caltech101(ROOT, download=True), name="Caltech101")


def caltech256():
    return collect_download_configs(lambda: datasets.Caltech256(ROOT, download=True), name="Caltech256")


def cifar10():
    return collect_download_configs(lambda: datasets.CIFAR10(ROOT, download=True), name="CIFAR10")


def cifar100():
    return collect_download_configs(lambda: datasets.CIFAR100(ROOT, download=True), name="CIFAR100")


def voc():
    return itertools.chain(
        *[
            collect_download_configs(
                lambda: datasets.VOCSegmentation(ROOT, year=year, download=True),
                name=f"VOC, {year}",
                file="voc",
            )
            for year in ("2007", "2007-test", "2008", "2009", "2010", "2011", "2012")
        ]
    )


def mnist():
    with unittest.mock.patch.object(datasets.MNIST, "mirrors", datasets.MNIST.mirrors[-1:]):
        return collect_download_configs(lambda: datasets.MNIST(ROOT, download=True), name="MNIST")


def fashion_mnist():
    return collect_download_configs(lambda: datasets.FashionMNIST(ROOT, download=True), name="FashionMNIST")


def kmnist():
    return collect_download_configs(lambda: datasets.KMNIST(ROOT, download=True), name="KMNIST")


def emnist():
    # the 'split' argument can be any valid one, since everything is downloaded anyway
    return collect_download_configs(lambda: datasets.EMNIST(ROOT, split="byclass", download=True), name="EMNIST")


def qmnist():
    return itertools.chain(
        *[
            collect_download_configs(
                lambda: datasets.QMNIST(ROOT, what=what, download=True),
                name=f"QMNIST, {what}",
                file="mnist",
            )
            for what in ("train", "test", "nist")
        ]
    )


def omniglot():
    return itertools.chain(
        *[
            collect_download_configs(
                lambda: datasets.Omniglot(ROOT, background=background, download=True),
                name=f"Omniglot, {'background' if background else 'evaluation'}",
            )
            for background in (True, False)
        ]
    )


def phototour():
    return itertools.chain(
        *[
            collect_download_configs(
                lambda: datasets.PhotoTour(ROOT, name=name, download=True),
                name=f"PhotoTour, {name}",
                file="phototour",
            )
            # The names postfixed with '_harris' point to the domain 'matthewalunbrown.com'. For some reason all
            # requests time out from within CI. They are disabled until this is resolved.
            for name in ("notredame", "yosemite", "liberty")  # "notredame_harris", "yosemite_harris", "liberty_harris"
        ]
    )


def sbdataset():
    return collect_download_configs(
        lambda: datasets.SBDataset(ROOT, download=True),
        name="SBDataset",
        file="voc",
    )


def sbu():
    return collect_download_configs(
        lambda: datasets.SBU(ROOT, download=True),
        name="SBU",
        file="sbu",
    )


def semeion():
    return collect_download_configs(
        lambda: datasets.SEMEION(ROOT, download=True),
        name="SEMEION",
        file="semeion",
    )


def stl10():
    return collect_download_configs(
        lambda: datasets.STL10(ROOT, download=True),
        name="STL10",
    )


def svhn():
    return itertools.chain(
        *[
            collect_download_configs(
                lambda: datasets.SVHN(ROOT, split=split, download=True),
                name=f"SVHN, {split}",
                file="svhn",
            )
            for split in ("train", "test", "extra")
        ]
    )


def usps():
    return itertools.chain(
        *[
            collect_download_configs(
                lambda: datasets.USPS(ROOT, train=train, download=True),
                name=f"USPS, {'train' if train else 'test'}",
                file="usps",
            )
            for train in (True, False)
        ]
    )


def celeba():
    return collect_download_configs(
        lambda: datasets.CelebA(ROOT, download=True),
        name="CelebA",
        file="celeba",
    )


def widerface():
    return collect_download_configs(
        lambda: datasets.WIDERFace(ROOT, download=True),
        name="WIDERFace",
        file="widerface",
    )


def kinetics():
    return itertools.chain(
        *[
            collect_download_configs(
                lambda: datasets.Kinetics(
                    path.join(ROOT, f"Kinetics{num_classes}"),
                    frames_per_clip=1,
                    num_classes=num_classes,
                    split=split,
                    download=True,
                ),
                name=f"Kinetics, {num_classes}, {split}",
                file="kinetics",
            )
            for num_classes, split in itertools.product(("400", "600", "700"), ("train", "val"))
        ]
    )


def kitti():
    return itertools.chain(
        *[
            collect_download_configs(
                lambda train=train: datasets.Kitti(ROOT, train=train, download=True),
                name=f"Kitti, {'train' if train else 'test'}",
                file="kitti",
            )
            for train in (True, False)
        ]
    )


def make_parametrize_kwargs(download_configs):
    argvalues = []
    ids = []
    for config in download_configs:
        argvalues.append((config.url, config.md5))
        ids.append(config.id)

    return dict(argnames=("url", "md5"), argvalues=argvalues, ids=ids)
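
# ``make_parametrize_kwargs`` expands the configs into the arguments that
# ``pytest.mark.parametrize`` expects, e.g. for a hypothetical config:
#
#   make_parametrize_kwargs([DownloadConfig("https://example.com/x.tar", md5="abc")])
#   # -> dict(argnames=("url", "md5"),
#   #         argvalues=[("https://example.com/x.tar", "abc")],
#   #         ids=["https://example.com/x.tar"])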


@pytest.mark.parametrize(
    **make_parametrize_kwargs(
        itertools.chain(
            places365(),
            caltech101(),
            caltech256(),
            cifar10(),
            cifar100(),
            # The VOC download server is unstable. See https://github.com/pytorch/vision/issues/2953 for details.
            # voc(),
            mnist(),
            fashion_mnist(),
            kmnist(),
            emnist(),
            qmnist(),
            omniglot(),
            phototour(),
            sbdataset(),
            sbu(),
            semeion(),
            stl10(),
            svhn(),
            usps(),
            celeba(),
            widerface(),
            kinetics(),
            kitti(),
        )
    )
)
def test_url_is_accessible(url, md5):
    retry(lambda: assert_url_is_accessible(url))


@pytest.mark.parametrize(**make_parametrize_kwargs(itertools.chain()))
def test_file_downloads_correctly(url, md5):
    retry(lambda: assert_file_downloads_correctly(url, md5))