import contextlib
import itertools
import tempfile
import time
import unittest.mock
import warnings
from datetime import datetime
from distutils import dir_util
from os import path
from urllib.error import HTTPError, URLError
from urllib.parse import urlparse
from urllib.request import urlopen, Request

import pytest

from torchvision import datasets
from torchvision.datasets.utils import download_url, check_integrity, download_file_from_google_drive, USER_AGENT

from common_utils import get_tmp_dir
from fakedata_generation import places365_root


def limit_requests_per_time(min_secs_between_requests=2.0):
    """Decorator that enforces a minimum delay between consecutive requests to the same host (netloc)."""
    last_requests = {}

    def outer_wrapper(fn):
        def inner_wrapper(request, *args, **kwargs):
            url = request.full_url if isinstance(request, Request) else request

            netloc = urlparse(url).netloc
            last_request = last_requests.get(netloc)
            if last_request is not None:
                elapsed_secs = (datetime.now() - last_request).total_seconds()
                delta = min_secs_between_requests - elapsed_secs
                if delta > 0:
                    time.sleep(delta)

            response = fn(request, *args, **kwargs)
            last_requests[netloc] = datetime.now()

            return response

        return inner_wrapper

    return outer_wrapper


urlopen = limit_requests_per_time()(urlopen)
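# Hypothetical usage sketch: the decorator works for any urlopen-like callable that takes a str or Request
# as its first argument, e.g.
#
#   fetch = limit_requests_per_time(min_secs_between_requests=0.5)(fetch)
#
# where `fetch` is an assumed stand-in, not a function defined in this module.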


def resolve_redirects(max_redirects=3):
    """Decorator that follows a redirect chain and raises RecursionError if it exceeds max_redirects hops."""

    def outer_wrapper(fn):
        def inner_wrapper(request, *args, **kwargs):
            url = initial_url = request.full_url if isinstance(request, Request) else request

            for _ in range(max_redirects + 1):
                response = fn(request, *args, **kwargs)

                if response.url == url or response.url is None:
                    if url != initial_url:
                        warnings.warn(f"The URL {initial_url} ultimately redirects to {url}.")
                    return response

                url = response.url
            else:
                raise RecursionError(f"Request to {initial_url} exceeded {max_redirects} redirects.")

        return inner_wrapper

    return outer_wrapper


urlopen = resolve_redirects()(urlopen)
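# Note the wrapping order: resolve_redirects wraps the already rate-limited urlopen from above, so every
# redirect hop is subject to the per-host rate limit as well.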


@contextlib.contextmanager
def log_download_attempts(
    urls_and_md5s=None,
    file="utils",
    patch=True,
    mock_auxiliaries=None,
):
    """Record the (url, md5) pairs of all download attempts made within the context.

    If patch is True (default), the actual downloads are mocked out and only logged.
    """

    def add_mock(stack, name, file, **kwargs):
        try:
            return stack.enter_context(unittest.mock.patch(f"torchvision.datasets.{file}.{name}", **kwargs))
        except AttributeError as error:
            if file != "utils":
                return add_mock(stack, name, "utils", **kwargs)
            else:
                raise pytest.UsageError from error

    if urls_and_md5s is None:
        urls_and_md5s = set()
    if mock_auxiliaries is None:
        mock_auxiliaries = patch

    with contextlib.ExitStack() as stack:
        url_mock = add_mock(stack, "download_url", file, wraps=None if patch else download_url)
        google_drive_mock = add_mock(
            stack, "download_file_from_google_drive", file, wraps=None if patch else download_file_from_google_drive
        )

        if mock_auxiliaries:
            add_mock(stack, "extract_archive", file)

        try:
            yield urls_and_md5s
        finally:
            for args, kwargs in url_mock.call_args_list:
                url = args[0]
                md5 = args[-1] if len(args) == 4 else kwargs.get("md5")
                urls_and_md5s.add((url, md5))

            for args, kwargs in google_drive_mock.call_args_list:
                id = args[0]
                url = f"https://drive.google.com/file/d/{id}"
                md5 = args[3] if len(args) == 4 else kwargs.get("md5")
                urls_and_md5s.add((url, md5))
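# collect_download_configs below shows the typical usage: run a dataset constructor under this context
# manager (with the real downloads mocked out by default) and harvest the attempted (url, md5) pairs.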


def retry(fn, times=1, wait=5.0):
    """Call fn up to times + 1 times, waiting wait seconds between attempts, until it passes without AssertionError."""
    msgs = []
    for _ in range(times + 1):
        try:
            return fn()
        except AssertionError as error:
            msgs.append(str(error))
            time.sleep(wait)
    else:
        raise AssertionError(
            "\n".join(
                (
                    f"Assertion failed {times + 1} times with {wait:.1f} seconds intermediate wait time.\n",
                    *(f"{idx}: {error}" for idx, error in enumerate(msgs, 1)),
                )
            )
        )
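# Usage sketch: retry(lambda: assert_url_is_accessible(url), times=2) makes up to three attempts in total,
# sleeping five seconds after each failed one.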


@contextlib.contextmanager
def assert_server_response_ok():
    try:
        yield
    # HTTPError is a subclass of URLError and thus has to be caught first.
    except HTTPError as error:
        raise AssertionError(f"The server returned {error.code}: {error.reason}.") from error
    except URLError as error:
        raise AssertionError("The request timed out.") from error
    except RecursionError as error:
        raise AssertionError(str(error)) from error
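# The network errors are translated into AssertionError on purpose: retry() above only catches
# AssertionError, so flaky server responses get another chance instead of failing the test immediately.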


def assert_url_is_accessible(url, timeout=5.0):
    # "HEAD" has to be passed as the request method; putting it into the headers would silently perform a full GET.
    request = Request(url, headers={"User-Agent": USER_AGENT}, method="HEAD")
    with assert_server_response_ok():
        urlopen(request, timeout=timeout)


def assert_file_downloads_correctly(url, md5, timeout=5.0):
    with get_tmp_dir() as root:
        file = path.join(root, path.basename(url))
        with assert_server_response_ok():
            with open(file, "wb") as fh:
                request = Request(url, headers={"User-Agent": USER_AGENT})
                response = urlopen(request, timeout=timeout)
                fh.write(response.read())

        assert check_integrity(file, md5=md5), "The MD5 checksums do not match"


class DownloadConfig:
    def __init__(self, url, md5=None, id=None):
        self.url = url
        self.md5 = md5
        self.id = id or url

    def __repr__(self):
        return self.id


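# The id of a DownloadConfig doubles as the pytest id, so a failing parameter reads e.g. "CIFAR10, https://...".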
def make_download_configs(urls_and_md5s, name=None):
    return [
        DownloadConfig(url, md5=md5, id=f"{name}, {url}" if name is not None else None) for url, md5 in urls_and_md5s
    ]


def collect_download_configs(dataset_loader, name=None, **kwargs):
    urls_and_md5s = set()
    try:
        with log_download_attempts(urls_and_md5s=urls_and_md5s, **kwargs):
            dataset = dataset_loader()
    except Exception:
        # With the real download mocked out, constructing the dataset usually fails afterwards. We only care
        # about the recorded URLs, so the error is ignored.
        dataset = None

    if name is None and dataset is not None:
        name = type(dataset).__name__

    return make_download_configs(urls_and_md5s, name)


# This is a workaround since fixtures, such as the built-in tmpdir, can only be used within a test but not within a
# parametrization. Thus, we use a single root directory for all datasets and remove it when all download tests are run.
ROOT = tempfile.mkdtemp()
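# The autouse, module-scoped fixture below ensures ROOT is removed exactly once, after the last download test
# in this module has finished.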


@pytest.fixture(scope="module", autouse=True)
def root():
    yield ROOT
    dir_util.remove_tree(ROOT)


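# Places365 is special-cased: with patch=False the real download_url runs against the fake archives created
# by places365_root, which presumably satisfy the integrity check, so the URLs are recorded without any large
# downloads actually taking place.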
def places365():
    with log_download_attempts(patch=False) as urls_and_md5s:
        for split, small in itertools.product(("train-standard", "train-challenge", "val"), (False, True)):
            with places365_root(split=split, small=small) as places365:
                root, data = places365

                datasets.Places365(root, split=split, small=small, download=True)

    return make_download_configs(urls_and_md5s, name="Places365")


def caltech101():
    return collect_download_configs(lambda: datasets.Caltech101(ROOT, download=True), name="Caltech101")


def caltech256():
    return collect_download_configs(lambda: datasets.Caltech256(ROOT, download=True), name="Caltech256")


def cifar10():
    return collect_download_configs(lambda: datasets.CIFAR10(ROOT, download=True), name="CIFAR10")


def cifar100():
    return collect_download_configs(lambda: datasets.CIFAR100(ROOT, download=True), name="CIFAR100")


def voc():
    return itertools.chain(
        *[
            collect_download_configs(
                lambda: datasets.VOCSegmentation(ROOT, year=year, download=True),
                name=f"VOC, {year}",
                file="voc",
            )
            for year in ("2007", "2007-test", "2008", "2009", "2010", "2011", "2012")
        ]
    )


def mnist():
    # Check only a single mirror instead of hitting every one in the list.
    with unittest.mock.patch.object(datasets.MNIST, "mirrors", datasets.MNIST.mirrors[-1:]):
        return collect_download_configs(lambda: datasets.MNIST(ROOT, download=True), name="MNIST")


def fashion_mnist():
    return collect_download_configs(lambda: datasets.FashionMNIST(ROOT, download=True), name="FashionMNIST")


def kmnist():
    return collect_download_configs(lambda: datasets.KMNIST(ROOT, download=True), name="KMNIST")


def emnist():
    # the 'split' argument can be any valid one, since everything is downloaded anyway
    return collect_download_configs(lambda: datasets.EMNIST(ROOT, split="byclass", download=True), name="EMNIST")


def qmnist():
    return itertools.chain(
        *[
            collect_download_configs(
                lambda: datasets.QMNIST(ROOT, what=what, download=True),
                name=f"QMNIST, {what}",
                file="mnist",
            )
            for what in ("train", "test", "nist")
        ]
    )


def omniglot():
    return itertools.chain(
        *[
            collect_download_configs(
                lambda: datasets.Omniglot(ROOT, background=background, download=True),
                name=f"Omniglot, {'background' if background else 'evaluation'}",
            )
            for background in (True, False)
        ]
    )


def phototour():
    return itertools.chain(
        *[
            collect_download_configs(
                lambda: datasets.PhotoTour(ROOT, name=name, download=True),
                name=f"PhotoTour, {name}",
                file="phototour",
            )
            # The names postfixed with '_harris' point to the domain 'matthewalunbrown.com'. For some reason all
            # requests time out from within CI. They are disabled until this is resolved.
            for name in ("notredame", "yosemite", "liberty")  # "notredame_harris", "yosemite_harris", "liberty_harris"
        ]
    )


def sbdataset():
    return collect_download_configs(
        lambda: datasets.SBDataset(ROOT, download=True),
        name="SBDataset",
        file="voc",
    )


def sbu():
    return collect_download_configs(
        lambda: datasets.SBU(ROOT, download=True),
        name="SBU",
        file="sbu",
    )


def semeion():
    return collect_download_configs(
        lambda: datasets.SEMEION(ROOT, download=True),
        name="SEMEION",
        file="semeion",
    )


def stl10():
    return collect_download_configs(
        lambda: datasets.STL10(ROOT, download=True),
        name="STL10",
    )


def svhn():
    return itertools.chain(
        *[
            collect_download_configs(
                lambda: datasets.SVHN(ROOT, split=split, download=True),
                name=f"SVHN, {split}",
                file="svhn",
            )
            for split in ("train", "test", "extra")
        ]
    )


def usps():
    return itertools.chain(
        *[
            collect_download_configs(
                lambda: datasets.USPS(ROOT, train=train, download=True),
                name=f"USPS, {'train' if train else 'test'}",
                file="usps",
            )
            for train in (True, False)
        ]
    )


def celeba():
    return collect_download_configs(
        lambda: datasets.CelebA(ROOT, download=True),
        name="CelebA",
        file="celeba",
    )


def widerface():
    return collect_download_configs(
        lambda: datasets.WIDERFace(ROOT, download=True),
        name="WIDERFace",
        file="widerface",
    )


def make_parametrize_kwargs(download_configs):
    argvalues = []
    ids = []
    for config in download_configs:
        argvalues.append((config.url, config.md5))
        ids.append(config.id)

    return dict(argnames=("url", "md5"), argvalues=argvalues, ids=ids)
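# For example, two download configs would produce
#   dict(argnames=("url", "md5"), argvalues=[(url_1, md5_1), (url_2, md5_2)], ids=[id_1, id_2])
# which pytest.mark.parametrize consumes below.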


@pytest.mark.parametrize(
    **make_parametrize_kwargs(
        itertools.chain(
            places365(),
            caltech101(),
            caltech256(),
            cifar10(),
            cifar100(),
            # The VOC download server is unstable. See https://github.com/pytorch/vision/issues/2953 for details.
            # voc(),
            mnist(),
            fashion_mnist(),
            kmnist(),
            emnist(),
            qmnist(),
            omniglot(),
            phototour(),
            sbdataset(),
            sbu(),
            semeion(),
            stl10(),
            svhn(),
            usps(),
            celeba(),
            widerface(),
        )
    )
)
def test_url_is_accessible(url, md5):
    retry(lambda: assert_url_is_accessible(url))


# No download configs are registered here, so pytest reports this test as skipped (empty parameter set).
@pytest.mark.parametrize(**make_parametrize_kwargs(itertools.chain()))
def test_file_downloads_correctly(url, md5):
    retry(lambda: assert_file_downloads_correctly(url, md5))