"src/vscode:/vscode.git/clone" did not exist on "f0fd29955bd1e7e39318a052c3259e4f51a706e7"
test_datasets_download.py 11.4 KB
Newer Older
1
2
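"""Download tests for the torchvision datasets.

Each dataset is instantiated with its download helpers mocked out, the URLs it
would have fetched are collected, and every URL is then checked for
accessibility with rate-limited, redirect-resolving HEAD requests.
"""
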
import contextlib
import itertools
import shutil
import tempfile
import time
import traceback
import unittest.mock
import warnings
from datetime import datetime
from os import path
from urllib.error import HTTPError, URLError
from urllib.parse import urlparse
from urllib.request import Request, urlopen

import pytest
from torchvision import datasets
from torchvision.datasets.utils import _get_redirect_url, USER_AGENT


def limit_requests_per_time(min_secs_between_requests=2.0):
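    """Decorator factory that spaces out requests to the same host.

    Before delegating to the wrapped function, this sleeps just long enough so
    that at least ``min_secs_between_requests`` seconds pass between two
    requests to the same network location.
    """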
    last_requests = {}

    def outer_wrapper(fn):
        def inner_wrapper(request, *args, **kwargs):
            url = request.full_url if isinstance(request, Request) else request

            netloc = urlparse(url).netloc
            last_request = last_requests.get(netloc)
            if last_request is not None:
                elapsed_secs = (datetime.now() - last_request).total_seconds()
                delta = min_secs_between_requests - elapsed_secs
                if delta > 0:
                    time.sleep(delta)

            response = fn(request, *args, **kwargs)
            last_requests[netloc] = datetime.now()

            return response

        return inner_wrapper

    return outer_wrapper


urlopen = limit_requests_per_time()(urlopen)


def resolve_redirects(max_hops=3):
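    """Decorator factory that resolves redirects (up to ``max_hops``) before the actual request is made.

    If the URL redirects, a warning is emitted and the request is re-issued
    against the final URL, preserving the original request attributes.
    """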
    def outer_wrapper(fn):
        def inner_wrapper(request, *args, **kwargs):
            initial_url = request.full_url if isinstance(request, Request) else request
            url = _get_redirect_url(initial_url, max_hops=max_hops)

            if url == initial_url:
                return fn(request, *args, **kwargs)

            warnings.warn(f"The URL {initial_url} ultimately redirects to {url}.")

            if not isinstance(request, Request):
                return fn(url, *args, **kwargs)

            request_attrs = {
                attr: getattr(request, attr) for attr in ("data", "headers", "origin_req_host", "unverifiable")
            }
            # the 'method' attribute only exists if the request was created with it
            if hasattr(request, "method"):
                request_attrs["method"] = request.method

            return fn(Request(url, **request_attrs), *args, **kwargs)

        return inner_wrapper

    return outer_wrapper


urlopen = resolve_redirects()(urlopen)


@contextlib.contextmanager
def log_download_attempts(
    urls,
    *,
    dataset_module,
):
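    """Collect the URLs a dataset would download into ``urls`` instead of actually downloading them.

    ``download_url`` and ``download_file_from_google_drive`` are mocked in both
    ``dataset_module`` and ``utils``, and the URLs they were called with are
    appended to ``urls`` on exit. ``extract_archive`` is mocked as well, since
    nothing is actually downloaded.
    """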
    def maybe_add_mock(*, module, name, stack, lst=None):
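        """Patch ``torchvision.datasets.{module}.{name}`` if it exists and optionally collect the mock in ``lst``."""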
        patcher = unittest.mock.patch(f"torchvision.datasets.{module}.{name}")

        try:
            mock = stack.enter_context(patcher)
        except AttributeError:
            return

        if lst is not None:
            lst.append(mock)

    with contextlib.ExitStack() as stack:
        download_url_mocks = []
        download_file_from_google_drive_mocks = []
        for module in [dataset_module, "utils"]:
            maybe_add_mock(module=module, name="download_url", stack=stack, lst=download_url_mocks)
            maybe_add_mock(
                module=module,
                name="download_file_from_google_drive",
                stack=stack,
                lst=download_file_from_google_drive_mocks,
            )
            maybe_add_mock(module=module, name="extract_archive", stack=stack)

        try:
            yield
        finally:
            for download_url_mock in download_url_mocks:
                for args, kwargs in download_url_mock.call_args_list:
                    urls.append(args[0] if args else kwargs["url"])

            for download_file_from_google_drive_mock in download_file_from_google_drive_mocks:
                for args, kwargs in download_file_from_google_drive_mock.call_args_list:
                    file_id = args[0] if args else kwargs["file_id"]
                    urls.append(f"https://drive.google.com/file/d/{file_id}")


def retry(fn, times=1, wait=5.0):
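    """Call ``fn`` and retry up to ``times`` times on ``AssertionError``, waiting ``wait`` seconds between attempts.

    If every attempt fails, all collected tracebacks are raised together as a
    single ``AssertionError``.
    """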
    tbs = []
    for _ in range(times + 1):
        try:
            return fn()
        except AssertionError as error:
            tbs.append("".join(traceback.format_exception(type(error), error, error.__traceback__)))
            time.sleep(wait)
    else:
        raise AssertionError(
            "\n".join(
                (
                    "\n",
                    *[f"{'_' * 40}  {idx:2d}  {'_' * 40}\n\n{tb}" for idx, tb in enumerate(tbs, 1)],
                    (
                        f"Assertion failed {times + 1} times with {wait:.1f} seconds intermediate wait time. "
                        f"You can find the the full tracebacks above."
                    ),
                )
            )
        )


@contextlib.contextmanager
def assert_server_response_ok():
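    """Re-raise common network errors as ``AssertionError`` with a readable message."""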
    try:
        yield
    except HTTPError as error:
        raise AssertionError(f"The server returned {error.code}: {error.reason}.") from error
    except URLError as error:
        raise AssertionError(
            "Connection not possible due to SSL." if "SSL" in str(error) else "The request timed out."
        ) from error
    except RecursionError as error:
        raise AssertionError(str(error)) from error


def assert_url_is_accessible(url, timeout=5.0):
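    """Send a HEAD request to ``url`` and fail if the server does not respond successfully."""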
    request = Request(url, headers={"User-Agent": USER_AGENT}, method="HEAD")
    with assert_server_response_ok():
        urlopen(request, timeout=timeout)


def collect_urls(dataset_cls, *args, **kwargs):
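    """Instantiate ``dataset_cls`` with downloads mocked and return ``(url, test_id)`` pairs for parametrization."""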
    urls = []
    with contextlib.suppress(Exception), log_download_attempts(
        urls, dataset_module=dataset_cls.__module__.split(".")[-1]
    ):
        dataset_cls(*args, **kwargs)

    return [(url, f"{dataset_cls.__name__}, {url}") for url in urls]


# This is a workaround since fixtures, such as the built-in tmpdir, can only be used within a test but not within a
# parametrization. Thus, we use a single root directory for all datasets and remove it when all download tests are run.
ROOT = tempfile.mkdtemp()


@pytest.fixture(scope="module", autouse=True)
def root():
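    """Shared root directory for all download tests, removed after the last test in this module has run."""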
    yield ROOT
    shutil.rmtree(ROOT)


def places365():
    return itertools.chain.from_iterable(
        [
            collect_urls(
                datasets.Places365,
                ROOT,
                split=split,
                small=small,
                download=True,
            )
            for split, small in itertools.product(("train-standard", "train-challenge", "val"), (False, True))
        ]
    )


def caltech101():
    return collect_urls(datasets.Caltech101, ROOT, download=True)


def caltech256():
    return collect_urls(datasets.Caltech256, ROOT, download=True)


def cifar10():
    return collect_urls(datasets.CIFAR10, ROOT, download=True)


def cifar100():
    return collect_urls(datasets.CIFAR100, ROOT, download=True)


def voc():
    # TODO: Also test the "2007-test" key
    return itertools.chain.from_iterable(
        [
            collect_urls(datasets.VOCSegmentation, ROOT, year=year, download=True)
            for year in ("2007", "2008", "2009", "2010", "2011", "2012")
        ]
    )


def mnist():
    with unittest.mock.patch.object(datasets.MNIST, "mirrors", datasets.MNIST.mirrors[-1:]):
        return collect_urls(datasets.MNIST, ROOT, download=True)


def fashion_mnist():
    return collect_urls(datasets.FashionMNIST, ROOT, download=True)


def kmnist():
    return collect_urls(datasets.KMNIST, ROOT, download=True)


def emnist():
    # the 'split' argument can be any valid one, since everything is downloaded anyway
    return collect_urls(datasets.EMNIST, ROOT, split="byclass", download=True)


def qmnist():
    return itertools.chain.from_iterable(
        [collect_urls(datasets.QMNIST, ROOT, what=what, download=True) for what in ("train", "test", "nist")]
    )


def moving_mnist():
    return collect_urls(datasets.MovingMNIST, ROOT, download=True)


def omniglot():
    return itertools.chain.from_iterable(
        [collect_urls(datasets.Omniglot, ROOT, background=background, download=True) for background in (True, False)]
    )


def phototour():
    return itertools.chain.from_iterable(
        [
            collect_urls(datasets.PhotoTour, ROOT, name=name, download=True)
            # The names postfixed with '_harris' point to the domain 'matthewalunbrown.com'. For some reason all
            # requests time out from within CI. They are disabled until this is resolved.
            for name in ("notredame", "yosemite", "liberty")  # "notredame_harris", "yosemite_harris", "liberty_harris"
        ]
    )


def sbdataset():
    return collect_urls(datasets.SBDataset, ROOT, download=True)


def sbu():
    return collect_urls(datasets.SBU, ROOT, download=True)


def semeion():
    return collect_urls(datasets.SEMEION, ROOT, download=True)


def stl10():
    return collect_urls(datasets.STL10, ROOT, download=True)


def svhn():
    return itertools.chain.from_iterable(
        [collect_urls(datasets.SVHN, ROOT, split=split, download=True) for split in ("train", "test", "extra")]
    )


def usps():
    return itertools.chain.from_iterable(
        [collect_urls(datasets.USPS, ROOT, train=train, download=True) for train in (True, False)]
    )


def celeba():
    return collect_urls(datasets.CelebA, ROOT, download=True)


def widerface():
    return collect_urls(datasets.WIDERFace, ROOT, download=True)


def kinetics():
    return itertools.chain.from_iterable(
        [
            collect_urls(
                datasets.Kinetics,
                path.join(ROOT, f"Kinetics{num_classes}"),
                frames_per_clip=1,
                num_classes=num_classes,
                split=split,
                download=True,
            )
            for num_classes, split in itertools.product(("400", "600", "700"), ("train", "val"))
        ]
    )


def kitti():
    return itertools.chain.from_iterable(
        [collect_urls(datasets.Kitti, ROOT, train=train, download=True) for train in (True, False)]
    )


def url_parametrization(*dataset_urls_and_ids_fns):
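    """Build a ``pytest.mark.parametrize`` decorator over the deduplicated, sorted URLs of the given collector functions."""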
    return pytest.mark.parametrize(
        "url",
        [
            pytest.param(url, id=id)
            for dataset_urls_and_ids_fn in dataset_urls_and_ids_fns
            for url, id in sorted(set(dataset_urls_and_ids_fn()))
        ],
    )


@url_parametrization(
    caltech101,
    caltech256,
    cifar10,
    cifar100,
    # The VOC download server is unstable. See https://github.com/pytorch/vision/issues/2953 for details.
    # voc,
    mnist,
    fashion_mnist,
    kmnist,
    emnist,
    qmnist,
    omniglot,
    phototour,
    sbdataset,
    semeion,
    stl10,
    svhn,
    usps,
    celeba,
    widerface,
    kinetics,
    kitti,
    places365,
    sbu,
366
)
def test_url_is_accessible(url):
    """
    If you see this test failing, find the offending dataset in the parametrization and move it to
    ``test_url_is_not_accessible`` and link an issue detailing the problem.
    """
    retry(lambda: assert_url_is_accessible(url))


# TODO: if e.g. caltech101 starts failing, remove the pytest.mark.parametrize below and use
# @url_parametrization(caltech101)
@pytest.mark.parametrize("url", ("http://url_that_doesnt_exist.com",))  # here until we actually have a failing dataset
@pytest.mark.xfail
def test_url_is_not_accessible(url):
    """
    As the name implies, this test is the 'inverse' of ``test_url_is_accessible``. Since the download servers are
    beyond our control, some files might not be accessible for longer stretches of time. Still, we want to know if they
    come back up, or if we need to remove the download functionality of the dataset for good.

    If you see this test failing, find the offending dataset in the parametrization and move it to
    ``test_url_is_accessible``.
    """
    assert_url_is_accessible(url)