test_datasets_download.py 12.6 KB
Newer Older
1
2
import contextlib
import itertools
import shutil
import tempfile
import time
import unittest.mock
import warnings
from datetime import datetime
from distutils import dir_util
from os import path
from urllib.error import HTTPError, URLError
from urllib.parse import urlparse
from urllib.request import urlopen, Request

import pytest

from torchvision import datasets
from torchvision.datasets.utils import (
    download_url,
    check_integrity,
    download_file_from_google_drive,
    _get_redirect_url,
    USER_AGENT,
)

from common_utils import get_tmp_dir
from fakedata_generation import places365_root


29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
def limit_requests_per_time(min_secs_between_requests=2.0):
    """Decorator factory that rate-limits a request function per network host.

    The wrapped function must take a URL string or a ``urllib.request.Request``
    as its first argument. Consecutive calls targeting the same host (netloc)
    are spaced at least ``min_secs_between_requests`` seconds apart by
    sleeping before the call.

    Args:
        min_secs_between_requests (float): minimum delay between two requests
            to the same host.
    """
    # Maps netloc -> monotonic timestamp of the last request to that host.
    last_requests = {}

    def outer_wrapper(fn):
        def inner_wrapper(request, *args, **kwargs):
            url = request.full_url if isinstance(request, Request) else request

            netloc = urlparse(url).netloc
            last_request = last_requests.get(netloc)
            if last_request is not None:
                # Use a monotonic clock to measure elapsed time: wall-clock
                # time (datetime.now()) can jump backwards or forwards (NTP
                # adjustments, DST) and would distort the throttling delay.
                elapsed_secs = time.monotonic() - last_request
                delta = min_secs_between_requests - elapsed_secs
                if delta > 0:
                    time.sleep(delta)

            response = fn(request, *args, **kwargs)
            last_requests[netloc] = time.monotonic()

            return response

        return inner_wrapper

    return outer_wrapper


urlopen = limit_requests_per_time()(urlopen)


57
def resolve_redirects(max_hops=3):
    """Decorator factory that resolves HTTP redirects before the request is issued.

    The wrapped function must take a URL string or a ``urllib.request.Request``
    as its first argument. The redirect chain is followed (up to ``max_hops``
    hops) and the call is re-targeted at the final URL, emitting a warning so
    the stale URL can be updated in the dataset definition.
    """

    def decorator(fn):
        def wrapper(request, *args, **kwargs):
            is_request_obj = isinstance(request, Request)
            initial_url = request.full_url if is_request_obj else request
            final_url = _get_redirect_url(initial_url, max_hops=max_hops)

            # No redirect happened: pass the original argument through untouched.
            if final_url == initial_url:
                return fn(request, *args, **kwargs)

            warnings.warn(f"The URL {initial_url} ultimately redirects to {final_url}.")

            if not is_request_obj:
                return fn(final_url, *args, **kwargs)

            # Rebuild the Request against the final URL, carrying over its
            # configurable attributes.
            attrs = {}
            for attr in ("data", "headers", "origin_req_host", "unverifiable"):
                attrs[attr] = getattr(request, attr)
            # the 'method' attribute does only exist if the request was created with it
            if hasattr(request, "method"):
                attrs["method"] = request.method

            return fn(Request(final_url, **attrs), *args, **kwargs)

        return wrapper

    return decorator


urlopen = resolve_redirects()(urlopen)


88
@contextlib.contextmanager
def log_download_attempts(
    urls_and_md5s=None,
    file="utils",
    patch=True,
    mock_auxiliaries=None,
):
    """Context manager that records the ``(url, md5)`` pairs of attempted downloads.

    Patches ``download_url`` and ``download_file_from_google_drive`` in the
    ``torchvision.datasets.<file>`` module (falling back to ``utils`` when the
    name is not found there) and, on exit, turns every recorded call into a
    ``(url, md5)`` tuple added to the yielded set.

    Args:
        urls_and_md5s: set collecting the recorded pairs; a fresh set is
            created when omitted.
        file: name of the ``torchvision.datasets`` submodule in which to patch
            the download helpers.
        patch: if True, the helpers are fully mocked so no actual download
            happens; otherwise they are wrapped and still executed.
        mock_auxiliaries: if True, also mock ``extract_archive``; defaults to
            the value of ``patch``.
    """

    def add_mock(stack, name, file, **kwargs):
        # Patch `name` inside the given submodule; if the attribute is missing
        # there, retry once in `utils` before treating it as a usage error.
        try:
            return stack.enter_context(unittest.mock.patch(f"torchvision.datasets.{file}.{name}", **kwargs))
        except AttributeError as error:
            if file != "utils":
                return add_mock(stack, name, "utils", **kwargs)
            else:
                raise pytest.UsageError from error

    if urls_and_md5s is None:
        urls_and_md5s = set()
    if mock_auxiliaries is None:
        mock_auxiliaries = patch

    with contextlib.ExitStack() as stack:
        url_mock = add_mock(stack, "download_url", file, wraps=None if patch else download_url)
        google_drive_mock = add_mock(
            stack, "download_file_from_google_drive", file, wraps=None if patch else download_file_from_google_drive
        )

        if mock_auxiliaries:
            add_mock(stack, "extract_archive", file)

        try:
            yield urls_and_md5s
        finally:
            # Harvest the recorded calls even if the body raised. The md5 is
            # taken as the last of 4 positional arguments or the `md5` keyword
            # — assumes download_url(url, root, filename, md5); TODO confirm
            # against the current torchvision signature.
            for args, kwargs in url_mock.call_args_list:
                url = args[0]
                md5 = args[-1] if len(args) == 4 else kwargs.get("md5")
                urls_and_md5s.add((url, md5))

            # For Google Drive the first argument is the file id, from which a
            # canonical URL is reconstructed.
            for args, kwargs in google_drive_mock.call_args_list:
                id = args[0]
                url = f"https://drive.google.com/file/d/{id}"
                md5 = args[3] if len(args) == 4 else kwargs.get("md5")
                urls_and_md5s.add((url, md5))

132
133
134
135
136
137
138
139
140
141
142
143
144
145
146

def retry(fn, times=1, wait=5.0):
    """Call ``fn`` until it succeeds, retrying on ``AssertionError``.

    Args:
        fn: zero-argument callable.
        times (int): number of retries, i.e. ``fn`` is invoked at most
            ``times + 1`` times.
        wait (float): seconds to sleep after each failed attempt.

    Returns:
        Whatever ``fn`` returns on the first successful attempt.

    Raises:
        AssertionError: if every attempt fails; the message aggregates the
            messages of all individual failures.
    """
    msgs = []
    for _ in range(times + 1):
        try:
            return fn()
        except AssertionError as error:
            msgs.append(str(error))
            time.sleep(wait)
    # All attempts failed. (The previous `for ... else` was misleading: the
    # loop contains no `break`, so the raise is unconditional here anyway.)
    raise AssertionError(
        "\n".join(
            (
                f"Assertion failed {times + 1} times with {wait:.1f} seconds intermediate wait time.\n",
                *(f"{idx}: {error}" for idx, error in enumerate(msgs, 1)),
            )
        )
    )


152
153
154
155
@contextlib.contextmanager
def assert_server_response_ok():
    """Context manager that converts URL errors into test failures.

    Any ``HTTPError``, ``URLError``, or ``RecursionError`` raised inside the
    ``with`` block is re-raised as an ``AssertionError`` with a readable
    message, chaining the original exception.
    """
    try:
        yield
    # HTTPError is a subclass of URLError and therefore MUST be caught first;
    # otherwise the URLError branch swallows it and every HTTP status error
    # is misreported as a timeout.
    except HTTPError as error:
        raise AssertionError(f"The server returned {error.code}: {error.reason}.") from error
    except URLError as error:
        raise AssertionError("The request timed out.") from error
    except RecursionError as error:
        raise AssertionError(str(error)) from error
162
163


Philip Meier's avatar
Philip Meier committed
164
def assert_url_is_accessible(url, timeout=5.0):
    """Issue a HEAD request to ``url`` and fail if the server errors out."""
    head_request = Request(url, headers={"User-Agent": USER_AGENT}, method="HEAD")
    with assert_server_response_ok():
        urlopen(head_request, timeout=timeout)
168
169


Philip Meier's avatar
Philip Meier committed
170
def assert_file_downloads_correctly(url, md5, timeout=5.0):
    """Download ``url`` into a temporary directory and verify its MD5 checksum."""
    with get_tmp_dir() as tmp_dir:
        target = path.join(tmp_dir, path.basename(url))
        with assert_server_response_ok():
            download_request = Request(url, headers={"User-Agent": USER_AGENT})
            with open(target, "wb") as fh:
                fh.write(urlopen(download_request, timeout=timeout).read())

        assert check_integrity(target, md5=md5), "The MD5 checksums mismatch"


class DownloadConfig:
    """Parametrization record for one download: URL, optional MD5, and a test id."""

    def __init__(self, url, md5=None, id=None):
        self.id = id or url  # fall back to the URL when no explicit id is given
        self.url = url
        self.md5 = md5

    def __repr__(self):
        return self.id
190
191


192
193
194
195
def make_download_configs(urls_and_md5s, name=None):
    """Build a ``DownloadConfig`` per ``(url, md5)`` pair, optionally prefixing ids with ``name``."""
    configs = []
    for url, md5 in urls_and_md5s:
        config_id = f"{name}, {url}" if name is not None else None
        configs.append(DownloadConfig(url, md5=md5, id=config_id))
    return configs
196
197


198
199
200
201
202
203
204
205
206
207
208
def collect_download_configs(dataset_loader, name=None, **kwargs):
    """Instantiate a dataset while recording the download attempts it makes.

    ``dataset_loader`` is a zero-argument callable constructing the dataset.
    Failures during construction are swallowed on purpose: the ``(url, md5)``
    pairs recorded up to that point are still returned. When ``name`` is not
    given it is derived from the dataset's class name, if available.
    """
    urls_and_md5s = set()
    dataset = None
    try:
        with log_download_attempts(urls_and_md5s=urls_and_md5s, **kwargs):
            dataset = dataset_loader()
    except Exception:
        pass

    if name is None and dataset is not None:
        name = type(dataset).__name__

    return make_download_configs(urls_and_md5s, name)


212
213
214
215
216
217
218
219
220
221
222
# This is a workaround since fixtures, such as the built-in tmp_dir, can only be used within a test but not within a
# parametrization. Thus, we use a single root directory for all datasets and remove it when all download tests are run.
ROOT = tempfile.mkdtemp()


@pytest.fixture(scope="module", autouse=True)
def root():
    """Module-scoped autouse fixture that removes ROOT once all download tests ran."""
    yield ROOT
    # `distutils.dir_util.remove_tree` relied on distutils, which is deprecated
    # (PEP 632) and removed in Python 3.12; `shutil.rmtree` is the supported
    # stdlib replacement with the same effect here.
    shutil.rmtree(ROOT)


223
224
225
226
227
228
229
230
def places365():
    """Collect download configs for Places365 across every split/small combination."""
    with log_download_attempts(patch=False) as urls_and_md5s:
        for split in ("train-standard", "train-challenge", "val"):
            for small in (False, True):
                with places365_root(split=split, small=small) as fakedata:
                    root, data = fakedata
                    datasets.Places365(root, split=split, small=small, download=True)

    return make_download_configs(urls_and_md5s, name="Places365")
232
233
234


def caltech101():
    """Collect download configs for Caltech101."""
    def loader():
        return datasets.Caltech101(ROOT, download=True)

    return collect_download_configs(loader, name="Caltech101")
236
237
238


def caltech256():
    """Collect download configs for Caltech256."""
    def loader():
        return datasets.Caltech256(ROOT, download=True)

    return collect_download_configs(loader, name="Caltech256")
240
241
242


def cifar10():
    """Collect download configs for CIFAR10."""
    def loader():
        return datasets.CIFAR10(ROOT, download=True)

    return collect_download_configs(loader, name="CIFAR10")
244

245

246
def cifar100():
    """Collect download configs for CIFAR100."""
    def loader():
        return datasets.CIFAR100(ROOT, download=True)

    return collect_download_configs(loader, name="CIFAR100")
248
249


250
def voc():
    """Collect download configs for VOCSegmentation for every supported year."""
    years = ("2007", "2007-test", "2008", "2009", "2010", "2011", "2012")
    per_year = [
        collect_download_configs(
            lambda year=year: datasets.VOCSegmentation(ROOT, year=year, download=True),
            name=f"VOC, {year}",
            file="voc",
        )
        for year in years
    ]
    return itertools.chain.from_iterable(per_year)


def mnist():
    """Collect download configs for MNIST while only its last mirror is active."""
    last_mirror_only = datasets.MNIST.mirrors[-1:]
    with unittest.mock.patch.object(datasets.MNIST, "mirrors", last_mirror_only):
        return collect_download_configs(lambda: datasets.MNIST(ROOT, download=True), name="MNIST")
266
267


Philip Meier's avatar
Philip Meier committed
268
def fashion_mnist():
    """Collect download configs for FashionMNIST."""
    def loader():
        return datasets.FashionMNIST(ROOT, download=True)

    return collect_download_configs(loader, name="FashionMNIST")
Philip Meier's avatar
Philip Meier committed
270
271


272
def kmnist():
    """Collect download configs for KMNIST."""
    def loader():
        return datasets.KMNIST(ROOT, download=True)

    return collect_download_configs(loader, name="KMNIST")
274
275
276
277


def emnist():
    """Collect download configs for EMNIST."""
    # the 'split' argument can be any valid one, since everything is downloaded anyway
    def loader():
        return datasets.EMNIST(ROOT, split="byclass", download=True)

    return collect_download_configs(loader, name="EMNIST")
279
280
281
282
283
284


def qmnist():
    """Collect download configs for every QMNIST variant."""
    per_variant = [
        collect_download_configs(
            lambda what=what: datasets.QMNIST(ROOT, what=what, download=True),
            name=f"QMNIST, {what}",
            file="mnist",
        )
        for what in ("train", "test", "nist")
    ]
    return itertools.chain.from_iterable(per_variant)


def omniglot():
    """Collect download configs for both Omniglot subsets."""
    per_subset = [
        collect_download_configs(
            lambda background=background: datasets.Omniglot(ROOT, background=background, download=True),
            name=f"Omniglot, {'background' if background else 'evaluation'}",
        )
        for background in (True, False)
    ]
    return itertools.chain.from_iterable(per_subset)


def phototour():
    """Collect download configs for the enabled PhotoTour subsets."""
    # The names postfixed with '_harris' point to the domain 'matthewalunbrown.com'. For some reason all
    # requests timeout from within CI. They are disabled until this is resolved.
    enabled_names = ("notredame", "yosemite", "liberty")  # "notredame_harris", "yosemite_harris", "liberty_harris"
    per_name = [
        collect_download_configs(
            lambda name=name: datasets.PhotoTour(ROOT, name=name, download=True),
            name=f"PhotoTour, {name}",
            file="phototour",
        )
        for name in enabled_names
    ]
    return itertools.chain.from_iterable(per_name)


def sbdataset():
    """Collect download configs for SBDataset."""
    def loader():
        return datasets.SBDataset(ROOT, download=True)

    return collect_download_configs(loader, name="SBDataset", file="voc")


def sbu():
    """Collect download configs for SBU."""
    def loader():
        return datasets.SBU(ROOT, download=True)

    return collect_download_configs(loader, name="SBU", file="sbu")


def semeion():
    """Collect download configs for SEMEION."""
    def loader():
        return datasets.SEMEION(ROOT, download=True)

    return collect_download_configs(loader, name="SEMEION", file="semeion")


def stl10():
    """Collect download configs for STL10."""
    def loader():
        return datasets.STL10(ROOT, download=True)

    return collect_download_configs(loader, name="STL10")


def svhn():
    """Collect download configs for all SVHN splits."""
    per_split = [
        collect_download_configs(
            lambda split=split: datasets.SVHN(ROOT, split=split, download=True),
            name=f"SVHN, {split}",
            file="svhn",
        )
        for split in ("train", "test", "extra")
    ]
    return itertools.chain.from_iterable(per_split)


def usps():
    """Collect download configs for the USPS train and test sets."""
    per_mode = [
        collect_download_configs(
            lambda train=train: datasets.USPS(ROOT, train=train, download=True),
            name=f"USPS, {'train' if train else 'test'}",
            file="usps",
        )
        for train in (True, False)
    ]
    return itertools.chain.from_iterable(per_mode)


def celeba():
    """Collect download configs for CelebA."""
    def loader():
        return datasets.CelebA(ROOT, download=True)

    return collect_download_configs(loader, name="CelebA", file="celeba")


def widerface():
    """Collect download configs for WIDERFace."""
    def loader():
        return datasets.WIDERFace(ROOT, download=True)

    return collect_download_configs(loader, name="WIDERFace", file="widerface")


394
395
396
397
398
399
400
401
def make_parametrize_kwargs(download_configs):
    """Convert DownloadConfig objects into keyword arguments for pytest.mark.parametrize."""
    # `download_configs` may be a one-shot iterator (e.g. itertools.chain), so
    # materialize it exactly once before iterating twice.
    configs = list(download_configs)
    argvalues = [(config.url, config.md5) for config in configs]
    ids = [config.id for config in configs]
    return dict(argnames=("url", "md5"), argvalues=argvalues, ids=ids)
402
403


404
# One test case per (url, md5) pair collected from the dataset definitions
# above; a failure means the download URL is dead or the server is unreachable.
@pytest.mark.parametrize(
    **make_parametrize_kwargs(
        itertools.chain(
            places365(),
            caltech101(),
            caltech256(),
            cifar10(),
            cifar100(),
            # The VOC download server is unstable. See https://github.com/pytorch/vision/issues/2953 for details.
            # voc(),
            mnist(),
            fashion_mnist(),
            kmnist(),
            emnist(),
            qmnist(),
            omniglot(),
            phototour(),
            sbdataset(),
            sbu(),
            semeion(),
            stl10(),
            svhn(),
            usps(),
            celeba(),
            widerface(),
        )
    )
)
def test_url_is_accessible(url, md5):
    # Only checks reachability (HEAD request); `md5` is unused here but kept so
    # this test shares its parametrization shape with the download test below.
    retry(lambda: assert_url_is_accessible(url))
434
435


436
437
438
# Parametrized with an empty chain, so no test cases are generated by default.
# NOTE(review): presumably download configs are meant to be added here ad hoc
# when a specific file needs a full download + checksum check — confirm.
@pytest.mark.parametrize(**make_parametrize_kwargs(itertools.chain()))
def test_file_downloads_correctly(url, md5):
    retry(lambda: assert_file_downloads_correctly(url, md5))