import contextlib
import itertools
import tempfile
import time
import traceback
import unittest.mock
import warnings
from datetime import datetime
from distutils import dir_util
from os import path
from urllib.error import HTTPError, URLError
from urllib.parse import urlparse
from urllib.request import Request, urlopen

import pytest
from torchvision import datasets
from torchvision.datasets.utils import _get_redirect_url, USER_AGENT


def limit_requests_per_time(min_secs_between_requests=2.0):
    last_requests = {}

    def outer_wrapper(fn):
        def inner_wrapper(request, *args, **kwargs):
            url = request.full_url if isinstance(request, Request) else request

            netloc = urlparse(url).netloc
            last_request = last_requests.get(netloc)
            if last_request is not None:
                elapsed_secs = (datetime.now() - last_request).total_seconds()
                delta = min_secs_between_requests - elapsed_secs
                if delta > 0:
                    time.sleep(delta)

            response = fn(request, *args, **kwargs)
            last_requests[netloc] = datetime.now()

            return response

        return inner_wrapper

    return outer_wrapper


urlopen = limit_requests_per_time()(urlopen)
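# every request made through ``urlopen`` below now goes through the rate limiter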


def resolve_redirects(max_hops=3):
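    # ``_get_redirect_url`` follows up to ``max_hops`` redirects and returns the final
    # URL; the wrapper below reissues the original request against that URL.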
    def outer_wrapper(fn):
        def inner_wrapper(request, *args, **kwargs):
            initial_url = request.full_url if isinstance(request, Request) else request
            url = _get_redirect_url(initial_url, max_hops=max_hops)

            if url == initial_url:
                return fn(request, *args, **kwargs)

            warnings.warn(f"The URL {initial_url} ultimately redirects to {url}.")

            if not isinstance(request, Request):
                return fn(url, *args, **kwargs)

            request_attrs = {
                attr: getattr(request, attr) for attr in ("data", "headers", "origin_req_host", "unverifiable")
            }
            # the 'method' attribute only exists if the request was created with it
            if hasattr(request, "method"):
                request_attrs["method"] = request.method

            return fn(Request(url, **request_attrs), *args, **kwargs)

        return inner_wrapper

    return outer_wrapper


urlopen = resolve_redirects()(urlopen)
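# ``urlopen`` is now wrapped twice: a call first resolves redirects and is then rate
# limited before hitting the network.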


@contextlib.contextmanager
def log_download_attempts(
    urls,
    *,
    dataset_module,
):
    def maybe_add_mock(*, module, name, stack, lst=None):
        patcher = unittest.mock.patch(f"torchvision.datasets.{module}.{name}")

        try:
            mock = stack.enter_context(patcher)
        except AttributeError:
            return

        if lst is not None:
            lst.append(mock)

    with contextlib.ExitStack() as stack:
        download_url_mocks = []
        download_file_from_google_drive_mocks = []
        for module in [dataset_module, "utils"]:
            maybe_add_mock(module=module, name="download_url", stack=stack, lst=download_url_mocks)
            maybe_add_mock(
                module=module,
                name="download_file_from_google_drive",
                stack=stack,
                lst=download_file_from_google_drive_mocks,
            )
            maybe_add_mock(module=module, name="extract_archive", stack=stack)

        try:
            yield
        finally:
            for download_url_mock in download_url_mocks:
                for args, kwargs in download_url_mock.call_args_list:
                    urls.append(args[0] if args else kwargs["url"])

            for download_file_from_google_drive_mock in download_file_from_google_drive_mocks:
                for args, kwargs in download_file_from_google_drive_mock.call_args_list:
                    file_id = args[0] if args else kwargs["file_id"]
                    urls.append(f"https://drive.google.com/file/d/{file_id}")


def retry(fn, times=1, wait=5.0):
    tbs = []
    for _ in range(times + 1):
        try:
            return fn()
        except AssertionError as error:
            tbs.append("".join(traceback.format_exception(type(error), error, error.__traceback__)))
            time.sleep(wait)
    else:
        raise AssertionError(
            "\n".join(
                (
                    "\n",
                    *[f"{'_' * 40}  {idx:2d}  {'_' * 40}\n\n{tb}" for idx, tb in enumerate(tbs, 1)],
                    (
                        f"Assertion failed {times + 1} times with {wait:.1f} seconds intermediate wait time. "
                        f"You can find the the full tracebacks above."
                    ),
                )
            )
        )
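
# ``retry(fn, times=1, wait=5.0)`` runs ``fn`` up to ``times + 1`` times in total and
# only raises after every attempt has failed with an AssertionError.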


@contextlib.contextmanager
def assert_server_response_ok():
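    # ``retry`` only catches AssertionError, so network-level errors are translated
    # into AssertionError here to make failed requests retryable.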
    try:
        yield
    except HTTPError as error:
        raise AssertionError(f"The server returned {error.code}: {error.reason}.") from error
    except URLError as error:
        raise AssertionError(
            "Connection not possible due to SSL." if "SSL" in str(error) else "The request timed out."
        ) from error
    except RecursionError as error:
        raise AssertionError(str(error)) from error


def assert_url_is_accessible(url, timeout=5.0):
    # a HEAD request checks accessibility without downloading the payload
    request = Request(url, headers={"User-Agent": USER_AGENT}, method="HEAD")
    with assert_server_response_ok():
        urlopen(request, timeout=timeout)


def collect_urls(dataset_cls, *args, **kwargs):
    urls = []
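    # dataset construction is expected to fail after the download helpers have been
    # mocked out, so all exceptions are suppressed; only the recorded URLs matter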
    with contextlib.suppress(Exception), log_download_attempts(
        urls, dataset_module=dataset_cls.__module__.split(".")[-1]
    ):
        dataset_cls(*args, **kwargs)

    return [(url, f"{dataset_cls.__name__}, {url}") for url in urls]


# This is a workaround since fixtures, such as the built-in tmp_path, can only be used within a test but not within
# a parametrization. Thus, we use a single root directory for all datasets and remove it after all download tests
# have run.
ROOT = tempfile.mkdtemp()


@pytest.fixture(scope="module", autouse=True)
def root():
    yield ROOT
    dir_util.remove_tree(ROOT)
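    # note: ``distutils`` was removed in Python 3.12; ``shutil.rmtree(ROOT)`` is the
    # drop-in replacement here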


def places365():
    return itertools.chain.from_iterable(
        [
            collect_urls(
                datasets.Places365,
                ROOT,
                split=split,
                small=small,
                download=True,
            )
            for split, small in itertools.product(("train-standard", "train-challenge", "val"), (False, True))
        ]
    )


def caltech101():
    return collect_urls(datasets.Caltech101, ROOT, download=True)


def caltech256():
    return collect_urls(datasets.Caltech256, ROOT, download=True)


def cifar10():
    return collect_urls(datasets.CIFAR10, ROOT, download=True)


def cifar100():
    return collect_urls(datasets.CIFAR100, ROOT, download=True)


def voc():
    # TODO: Also test the "2007-test" key
    return itertools.chain.from_iterable(
        [
            collect_urls(datasets.VOCSegmentation, ROOT, year=year, download=True)
            for year in ("2007", "2008", "2009", "2010", "2011", "2012")
        ]
    )


def mnist():
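    # restrict the mirror list to a single entry so that only one URL per file is recorded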
    with unittest.mock.patch.object(datasets.MNIST, "mirrors", datasets.MNIST.mirrors[-1:]):
        return collect_urls(datasets.MNIST, ROOT, download=True)


def fashion_mnist():
    return collect_urls(datasets.FashionMNIST, ROOT, download=True)


def kmnist():
    return collect_urls(datasets.KMNIST, ROOT, download=True)


def emnist():
    # the 'split' argument can be any valid value, since everything is downloaded anyway
    return collect_urls(datasets.EMNIST, ROOT, split="byclass", download=True)


def qmnist():
    return itertools.chain.from_iterable(
        [collect_urls(datasets.QMNIST, ROOT, what=what, download=True) for what in ("train", "test", "nist")]
    )


def moving_mnist():
    return collect_urls(datasets.MovingMNIST, ROOT, download=True)


def omniglot():
    return itertools.chain.from_iterable(
        [collect_urls(datasets.Omniglot, ROOT, background=background, download=True) for background in (True, False)]
    )


def phototour():
    return itertools.chain.from_iterable(
        [
            collect_urls(datasets.PhotoTour, ROOT, name=name, download=True)
            # The names postfixed with '_harris' point to the domain 'matthewalunbrown.com'. For some reason, all
            # requests time out from within CI. They are disabled until this is resolved.
            for name in ("notredame", "yosemite", "liberty")  # "notredame_harris", "yosemite_harris", "liberty_harris"
        ]
    )


def sbdataset():
    return collect_urls(datasets.SBDataset, ROOT, download=True)


def sbu():
    return collect_urls(datasets.SBU, ROOT, download=True)


def semeion():
    return collect_urls(datasets.SEMEION, ROOT, download=True)


def stl10():
    return collect_urls(datasets.STL10, ROOT, download=True)


def svhn():
    return itertools.chain.from_iterable(
        [collect_urls(datasets.SVHN, ROOT, split=split, download=True) for split in ("train", "test", "extra")]
    )


def usps():
    return itertools.chain.from_iterable(
        [collect_urls(datasets.USPS, ROOT, train=train, download=True) for train in (True, False)]
    )


def celeba():
    return collect_urls(datasets.CelebA, ROOT, download=True)


def widerface():
    return collect_urls(datasets.WIDERFace, ROOT, download=True)


def kinetics():
    return itertools.chain.from_iterable(
        [
            collect_urls(
                datasets.Kinetics,
                path.join(ROOT, f"Kinetics{num_classes}"),
                frames_per_clip=1,
                num_classes=num_classes,
                split=split,
                download=True,
            )
            for num_classes, split in itertools.product(("400", "600", "700"), ("train", "val"))
        ]
    )


def kitti():
    return itertools.chain.from_iterable(
        [collect_urls(datasets.Kitti, ROOT, train=train, download=True) for train in (True, False)]
    )


def stanford_cars():
    return itertools.chain.from_iterable(
        [collect_urls(datasets.StanfordCars, ROOT, split=split, download=True) for split in ["train", "test"]]
    )


def url_parametrization(*dataset_urls_and_ids_fns):
    return pytest.mark.parametrize(
        "url",
        [
            pytest.param(url, id=id)
            for dataset_urls_and_ids_fn in dataset_urls_and_ids_fns
            for url, id in sorted(set(dataset_urls_and_ids_fn()))
        ],
    )
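
# For example, ``@url_parametrization(mnist, cifar10)`` expands into
# ``pytest.mark.parametrize("url", ...)`` with one test case per unique URL and ids of
# the form "<DatasetName>, <url>".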


@url_parametrization(
    caltech101,
    caltech256,
    cifar10,
    cifar100,
    # The VOC download server is unstable. See https://github.com/pytorch/vision/issues/2953 for details.
    # voc,
    mnist,
    fashion_mnist,
    kmnist,
    emnist,
    qmnist,
    omniglot,
    phototour,
    sbdataset,
    semeion,
    stl10,
    svhn,
    usps,
    celeba,
    widerface,
    kinetics,
    kitti,
    places365,
    sbu,
)
def test_url_is_accessible(url):
    """
    If you see this test failing, find the offending dataset in the parametrization and move it to
    ``test_url_is_not_accessible`` and link an issue detailing the problem.
    """
    retry(lambda: assert_url_is_accessible(url))


@url_parametrization(
    stanford_cars,  # https://github.com/pytorch/vision/issues/7545
)
@pytest.mark.xfail
def test_url_is_not_accessible(url):
    """
    As the name implies, this test is the 'inverse' of ``test_url_is_accessible``. Since the download servers are
    beyond our control, some files might not be accessible for longer stretches of time. Still, we want to know if they
    come back up, or if we need to remove the download functionality of the dataset for good.

    If you see this test failing, find the offending dataset in the parametrization and move it to
    ``test_url_is_accessible``.
    """
    retry(lambda: assert_url_is_accessible(url))