"examples/vscode:/vscode.git/clone" did not exist on "385eeea357771824f64c31265500a0f6d0c45cd7"
test_datasets_download.py 13.5 KB
Newer Older
1
2
import contextlib
import itertools
import tempfile
import time
import unittest.mock
import warnings
from datetime import datetime
from distutils import dir_util
from os import path
from urllib.error import HTTPError, URLError
from urllib.parse import urlparse
from urllib.request import urlopen, Request

import pytest

from torchvision import datasets
from torchvision.datasets.utils import (
    download_url,
    check_integrity,
    download_file_from_google_drive,
    _get_redirect_url,
    USER_AGENT,
)

from common_utils import get_tmp_dir


def limit_requests_per_time(min_secs_between_requests=2.0):
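    """Decorator factory that throttles calls to the wrapped request function.

    Requests to the same network location (netloc) are delayed until at least
    ``min_secs_between_requests`` seconds have passed since the previous one,
    so the download servers are not hit by back-to-back test requests.
    """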
    last_requests = {}

    def outer_wrapper(fn):
        def inner_wrapper(request, *args, **kwargs):
            url = request.full_url if isinstance(request, Request) else request

            netloc = urlparse(url).netloc
            last_request = last_requests.get(netloc)
            if last_request is not None:
                elapsed_secs = (datetime.now() - last_request).total_seconds()
                delta = min_secs_between_requests - elapsed_secs
                if delta > 0:
                    time.sleep(delta)

            response = fn(request, *args, **kwargs)
            last_requests[netloc] = datetime.now()

            return response

        return inner_wrapper

    return outer_wrapper


urlopen = limit_requests_per_time()(urlopen)


def resolve_redirects(max_hops=3):
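    """Decorator factory that resolves up to ``max_hops`` redirects upfront.

    If the initial URL redirects, a warning is emitted and the wrapped function
    is called with an equivalent request that targets the final URL directly.
    """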
    def outer_wrapper(fn):
        def inner_wrapper(request, *args, **kwargs):
            initial_url = request.full_url if isinstance(request, Request) else request
            url = _get_redirect_url(initial_url, max_hops=max_hops)

            if url == initial_url:
                return fn(request, *args, **kwargs)

            warnings.warn(f"The URL {initial_url} ultimately redirects to {url}.")

            if not isinstance(request, Request):
                return fn(url, *args, **kwargs)

            request_attrs = {
                attr: getattr(request, attr) for attr in ("data", "headers", "origin_req_host", "unverifiable")
            }
            # the 'method' attribute only exists if the request was created with one
            if hasattr(request, "method"):
                request_attrs["method"] = request.method

            return fn(Request(url, **request_attrs), *args, **kwargs)

        return inner_wrapper

    return outer_wrapper


urlopen = resolve_redirects()(urlopen)
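# From this point on, every urlopen call in this module is rate-limited per
# network location and has its redirects resolved before hitting the server.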


@contextlib.contextmanager
def log_download_attempts(
    urls_and_md5s=None,
    file="utils",
    patch=True,
    mock_auxiliaries=None,
):
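    """Context manager that collects the (url, md5) pairs of all download calls.

    ``download_url`` and ``download_file_from_google_drive`` are mocked in the
    given dataset module, falling back to ``torchvision.datasets.utils``. With
    ``patch=True`` the downloads are suppressed entirely; otherwise the real
    functions are wrapped so the calls go through but are still recorded.
    """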
    def add_mock(stack, name, file, **kwargs):
        try:
            return stack.enter_context(unittest.mock.patch(f"torchvision.datasets.{file}.{name}", **kwargs))
        except AttributeError as error:
            if file != "utils":
                return add_mock(stack, name, "utils", **kwargs)
            else:
                raise pytest.UsageError from error

    if urls_and_md5s is None:
        urls_and_md5s = set()
    if mock_auxiliaries is None:
        mock_auxiliaries = patch

    with contextlib.ExitStack() as stack:
        url_mock = add_mock(stack, "download_url", file, wraps=None if patch else download_url)
        google_drive_mock = add_mock(
            stack, "download_file_from_google_drive", file, wraps=None if patch else download_file_from_google_drive
        )

        if mock_auxiliaries:
            add_mock(stack, "extract_archive", file)

        try:
            yield urls_and_md5s
        finally:
            for args, kwargs in url_mock.call_args_list:
                url = args[0]
                md5 = args[-1] if len(args) == 4 else kwargs.get("md5")
                urls_and_md5s.add((url, md5))

            for args, kwargs in google_drive_mock.call_args_list:
                id = args[0]
                url = f"https://drive.google.com/file/d/{id}"
                md5 = args[3] if len(args) == 4 else kwargs.get("md5")
                urls_and_md5s.add((url, md5))


def retry(fn, times=1, wait=5.0):
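    """Calls ``fn`` and retries it up to ``times`` more times on AssertionError.

    The ``for ... else`` clause only runs once every attempt has failed, in
    which case all collected messages are re-raised as one AssertionError.
    """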
    msgs = []
    for _ in range(times + 1):
        try:
            return fn()
        except AssertionError as error:
            msgs.append(str(error))
            time.sleep(wait)
    else:
        raise AssertionError(
            "\n".join(
                (
                    f"Assertion failed {times + 1} times with {wait:.1f} seconds intermediate wait time.\n",
                    *(f"{idx}: {error}" for idx, error in enumerate(msgs, 1)),
                )
            )
        )


@contextlib.contextmanager
def assert_server_response_ok():
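    """Translates common network errors into assertion failures."""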
    try:
        yield
    # HTTPError is a subclass of URLError, so it has to be caught first
    except HTTPError as error:
        raise AssertionError(f"The server returned {error.code}: {error.reason}.") from error
    except URLError as error:
        raise AssertionError("The request timed out.") from error
    except RecursionError as error:
        raise AssertionError(str(error)) from error


def assert_url_is_accessible(url, timeout=5.0):
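    """Checks reachability with a HEAD request, so no payload is transferred."""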
    request = Request(url, headers={"User-Agent": USER_AGENT}, method="HEAD")
    with assert_server_response_ok():
        urlopen(request, timeout=timeout)


def assert_file_downloads_correctly(url, md5, timeout=5.0):
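    """Downloads ``url`` into a temporary directory and verifies its MD5 checksum."""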
    with get_tmp_dir() as root:
        file = path.join(root, path.basename(url))
        with assert_server_response_ok():
            with open(file, "wb") as fh:
                request = Request(url, headers={"User-Agent": USER_AGENT})
                response = urlopen(request, timeout=timeout)
                fh.write(response.read())

        assert check_integrity(file, md5=md5), "The MD5 checksums do not match."


class DownloadConfig:
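    """Bundles a download URL with its expected MD5 and a readable test id."""
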
    def __init__(self, url, md5=None, id=None):
        self.url = url
        self.md5 = md5
        self.id = id or url

    def __repr__(self):
        return self.id


def make_download_configs(urls_and_md5s, name=None):
    return [
        DownloadConfig(url, md5=md5, id=f"{name}, {url}" if name is not None else None) for url, md5 in urls_and_md5s
    ]


def collect_download_configs(dataset_loader, name=None, **kwargs):
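    """Instantiates a dataset with mocked downloads and collects the requested URLs.

    Exceptions during instantiation are swallowed on purpose: with the download
    functions mocked out, most datasets fail shortly after the download step,
    but the (url, md5) pairs have already been logged by then.
    """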
    urls_and_md5s = set()
    try:
        with log_download_attempts(urls_and_md5s=urls_and_md5s, **kwargs):
            dataset = dataset_loader()
    except Exception:
        dataset = None

    if name is None and dataset is not None:
        name = type(dataset).__name__

    return make_download_configs(urls_and_md5s, name)


# This is a workaround since fixtures, such as the built-in tmpdir, can only be used within a test but not within a
# parametrization. Thus, we use a single root directory for all datasets and remove it once all download tests have run.
ROOT = tempfile.mkdtemp()


@pytest.fixture(scope="module", autouse=True)
def root():
    yield ROOT
    dir_util.remove_tree(ROOT)


def places365():
    return itertools.chain(
        *[
            collect_download_configs(
                lambda: datasets.Places365(ROOT, split=split, small=small, download=True),
                name=f"Places365, {split}, {'small' if small else 'large'}",
                file="places365",
            )
            for split, small in itertools.product(("train-standard", "train-challenge", "val"), (False, True))
        ]
    )


def caltech101():
    return collect_download_configs(lambda: datasets.Caltech101(ROOT, download=True), name="Caltech101")


def caltech256():
    return collect_download_configs(lambda: datasets.Caltech256(ROOT, download=True), name="Caltech256")


def cifar10():
    return collect_download_configs(lambda: datasets.CIFAR10(ROOT, download=True), name="CIFAR10")


def cifar100():
    return collect_download_configs(lambda: datasets.CIFAR100(ROOT, download=True), name="CIFAR100")


def voc():
    return itertools.chain(
        *[
            collect_download_configs(
                lambda: datasets.VOCSegmentation(ROOT, year=year, download=True),
                name=f"VOC, {year}",
                file="voc",
            )
            for year in ("2007", "2007-test", "2008", "2009", "2010", "2011", "2012")
        ]
    )


def mnist():
    with unittest.mock.patch.object(datasets.MNIST, "mirrors", datasets.MNIST.mirrors[-1:]):
        return collect_download_configs(lambda: datasets.MNIST(ROOT, download=True), name="MNIST")


def fashion_mnist():
    return collect_download_configs(lambda: datasets.FashionMNIST(ROOT, download=True), name="FashionMNIST")


def kmnist():
    return collect_download_configs(lambda: datasets.KMNIST(ROOT, download=True), name="KMNIST")


def emnist():
    # any valid 'split' value works here, since EMNIST downloads everything regardless
    return collect_download_configs(lambda: datasets.EMNIST(ROOT, split="byclass", download=True), name="EMNIST")


def qmnist():
    return itertools.chain(
        *[
            collect_download_configs(
                lambda: datasets.QMNIST(ROOT, what=what, download=True),
                name=f"QMNIST, {what}",
                file="mnist",
            )
            for what in ("train", "test", "nist")
        ]
    )


def omniglot():
    return itertools.chain(
        *[
            collect_download_configs(
                lambda: datasets.Omniglot(ROOT, background=background, download=True),
                name=f"Omniglot, {'background' if background else 'evaluation'}",
            )
            for background in (True, False)
        ]
    )


def phototour():
    return itertools.chain(
        *[
            collect_download_configs(
                lambda: datasets.PhotoTour(ROOT, name=name, download=True),
                name=f"PhotoTour, {name}",
                file="phototour",
            )
            # The names suffixed with '_harris' point to the domain 'matthewalunbrown.com'. For some reason, all
            # requests to it time out from within CI. They are disabled until this is resolved.
            for name in ("notredame", "yosemite", "liberty")  # "notredame_harris", "yosemite_harris", "liberty_harris"
        ]
    )


def sbdataset():
    return collect_download_configs(
        lambda: datasets.SBDataset(ROOT, download=True),
        name="SBDataset",
        file="voc",
    )


def sbu():
    return collect_download_configs(
        lambda: datasets.SBU(ROOT, download=True),
        name="SBU",
        file="sbu",
    )


def semeion():
    return collect_download_configs(
        lambda: datasets.SEMEION(ROOT, download=True),
        name="SEMEION",
        file="semeion",
    )


def stl10():
    return collect_download_configs(
        lambda: datasets.STL10(ROOT, download=True),
        name="STL10",
    )


def svhn():
    return itertools.chain(
        *[
            collect_download_configs(
                lambda: datasets.SVHN(ROOT, split=split, download=True),
                name=f"SVHN, {split}",
                file="svhn",
            )
            for split in ("train", "test", "extra")
        ]
    )


def usps():
    return itertools.chain(
        *[
            collect_download_configs(
                lambda: datasets.USPS(ROOT, train=train, download=True),
                name=f"USPS, {'train' if train else 'test'}",
                file="usps",
            )
            for train in (True, False)
        ]
    )


def celeba():
    return collect_download_configs(
        lambda: datasets.CelebA(ROOT, download=True),
        name="CelebA",
        file="celeba",
    )


def widerface():
    return collect_download_configs(
        lambda: datasets.WIDERFace(ROOT, download=True),
        name="WIDERFace",
        file="widerface",
    )


def kinetics():
    return itertools.chain(
        *[
            collect_download_configs(
                lambda: datasets.Kinetics(
                    path.join(ROOT, f"Kinetics{num_classes}"),
                    frames_per_clip=1,
                    num_classes=num_classes,
                    split=split,
                    download=True,
                ),
                name=f"Kinetics, {num_classes}, {split}",
                file="kinetics",
            )
            for num_classes, split in itertools.product(("400", "600", "700"), ("train", "val"))
        ]
    )


def kitti():
    return itertools.chain(
        *[
            collect_download_configs(
                lambda train=train: datasets.Kitti(ROOT, train=train, download=True),
                name=f"Kitti, {'train' if train else 'test'}",
                file="kitti",
            )
            for train in (True, False)
        ]
    )


def make_parametrize_kwargs(download_configs):
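    """Converts DownloadConfigs into keyword arguments for pytest.mark.parametrize."""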
    argvalues = []
    ids = []
    for config in download_configs:
        argvalues.append((config.url, config.md5))
        ids.append(config.id)

    return dict(argnames=("url", "md5"), argvalues=argvalues, ids=ids)


@pytest.mark.parametrize(
    **make_parametrize_kwargs(
        itertools.chain(
            places365(),
            caltech101(),
            caltech256(),
            cifar10(),
            cifar100(),
            # The VOC download server is unstable. See https://github.com/pytorch/vision/issues/2953 for details.
            # voc(),
            mnist(),
            fashion_mnist(),
            kmnist(),
            emnist(),
            qmnist(),
            omniglot(),
            phototour(),
            sbdataset(),
            sbu(),
            semeion(),
            stl10(),
            svhn(),
            usps(),
            celeba(),
            widerface(),
            kinetics(),
            kitti(),
        )
    )
)
def test_url_is_accessible(url, md5):
    retry(lambda: assert_url_is_accessible(url))


@pytest.mark.parametrize(**make_parametrize_kwargs(itertools.chain()))
def test_file_downloads_correctly(url, md5):
    retry(lambda: assert_file_downloads_correctly(url, md5))
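

# A hypothetical invocation sketch (assuming this file lives next to
# common_utils.py in torchvision's test directory and pytest is installed):
#
#   pytest test_datasets_download.py -k test_url_is_accessible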