"sgl-kernel/git@developer.sourcefind.cn:zhaoyu6/sglang.git" did not exist on "0a56b721d5531d674bd332b4d051a75857033e6c"
test_datasets_download.py 11.6 KB
Newer Older
1
2
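"""Smoke tests for the download URLs used by the torchvision datasets.

test_url_is_accessible issues a HEAD request for every URL a dataset constructor would
download, and test_file_downloads_correctly additionally downloads the selected files
and verifies their MD5 checksums.
"""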
import contextlib
import itertools
import time
import unittest.mock
import warnings
from datetime import datetime
from os import path
from urllib.error import HTTPError, URLError
from urllib.parse import urlparse
from urllib.request import urlopen, Request

import pytest

from torchvision import datasets
from torchvision.datasets.utils import download_url, check_integrity, download_file_from_google_drive

from common_utils import get_tmp_dir
from fakedata_generation import places365_root


def limit_requests_per_time(min_secs_between_requests=2.0):
    last_requests = {}

    def outer_wrapper(fn):
        def inner_wrapper(request, *args, **kwargs):
            url = request.full_url if isinstance(request, Request) else request

            netloc = urlparse(url).netloc
            last_request = last_requests.get(netloc)
            if last_request is not None:
                elapsed_secs = (datetime.now() - last_request).total_seconds()
                delta = min_secs_between_requests - elapsed_secs
                if delta > 0:
                    time.sleep(delta)

            response = fn(request, *args, **kwargs)
            last_requests[netloc] = datetime.now()

            return response

        return inner_wrapper

    return outer_wrapper


# Rebind urlopen so that consecutive requests to the same host are spaced at least
# min_secs_between_requests seconds apart.
urlopen = limit_requests_per_time()(urlopen)


def resolve_redirects(max_redirects=3):
    def outer_wrapper(fn):
        def inner_wrapper(request, *args, **kwargs):
            url = initial_url = request.full_url if isinstance(request, Request) else request

            for _ in range(max_redirects + 1):
                response = fn(request, *args, **kwargs)

                if response.url == url or response.url is None:
                    if url != initial_url:
                        warnings.warn(f"The URL {initial_url} ultimately redirects to {url}.")
                    return response

                url = response.url
            else:
                raise RecursionError(f"Request to {initial_url} exceeded {max_redirects} redirects.")

        return inner_wrapper

    return outer_wrapper


# Rebind urlopen again so that redirects are resolved up to max_redirects hops and a
# warning is emitted when the initial URL ultimately redirects elsewhere.
urlopen = resolve_redirects()(urlopen)


@contextlib.contextmanager
def log_download_attempts(
    urls_and_md5s=None,
    file="utils",
    patch=True,
    mock_auxiliaries=None,
):
    def add_mock(stack, name, file, **kwargs):
        try:
            return stack.enter_context(unittest.mock.patch(f"torchvision.datasets.{file}.{name}", **kwargs))
        except AttributeError as error:
            if file != "utils":
                return add_mock(stack, name, "utils", **kwargs)
            else:
                raise pytest.UsageError from error

    if urls_and_md5s is None:
        urls_and_md5s = set()

    if mock_auxiliaries is None:
        mock_auxiliaries = patch

    with contextlib.ExitStack() as stack:
        url_mock = add_mock(stack, "download_url", file, wraps=None if patch else download_url)
        google_drive_mock = add_mock(
            stack, "download_file_from_google_drive", file, wraps=None if patch else download_file_from_google_drive
        )

        if mock_auxiliaries:
            add_mock(stack, "extract_archive", file)

        try:
            yield urls_and_md5s
        finally:
            for args, kwargs in url_mock.call_args_list:
                url = args[0]
                md5 = args[-1] if len(args) == 4 else kwargs.get("md5")
                urls_and_md5s.add((url, md5))

            for args, kwargs in google_drive_mock.call_args_list:
                id = args[0]
                url = f"https://drive.google.com/file/d/{id}"
                md5 = args[3] if len(args) == 4 else kwargs.get("md5")
                urls_and_md5s.add((url, md5))


def retry(fn, times=1, wait=5.0):
    msgs = []
    for _ in range(times + 1):
        try:
            return fn()
        except AssertionError as error:
            msgs.append(str(error))
            time.sleep(wait)
    else:
        raise AssertionError(
            "\n".join(
                (
                    f"Assertion failed {times + 1} times with {wait:.1f} seconds intermediate wait time.\n",
                    *(f"{idx}: {error}" for idx, error in enumerate(msgs, 1)),
                )
            )
        )


@contextlib.contextmanager
def assert_server_response_ok():
    try:
        yield
    except URLError as error:
        raise AssertionError("The request timed out.") from error
    except HTTPError as error:
        raise AssertionError(f"The server returned {error.code}: {error.reason}.") from error
    except RecursionError as error:
        raise AssertionError(str(error)) from error


def assert_url_is_accessible(url, timeout=5.0):
    # A HEAD request is sufficient to check that the URL is reachable without downloading the file.
    request = Request(url, method="HEAD")
    with assert_server_response_ok():
        urlopen(request, timeout=timeout)


def assert_file_downloads_correctly(url, md5, timeout=5.0):
    with get_tmp_dir() as root:
        file = path.join(root, path.basename(url))
        with assert_server_response_ok():
            with open(file, "wb") as fh:
                response = urlopen(url, timeout=timeout)
                fh.write(response.read())

        assert check_integrity(file, md5=md5), "The MD5 checksums mismatch"


class DownloadConfig:
    def __init__(self, url, md5=None, id=None):
        self.url = url
        self.md5 = md5
        self.id = id or url

    def __repr__(self):
        return self.id


def make_download_configs(urls_and_md5s, name=None):
    return [
        DownloadConfig(url, md5=md5, id=f"{name}, {url}" if name is not None else None) for url, md5 in urls_and_md5s
    ]


def collect_download_configs(dataset_loader, name=None, **kwargs):
    urls_and_md5s = set()
    try:
        with log_download_attempts(urls_and_md5s=urls_and_md5s, **kwargs):
            dataset = dataset_loader()
    except Exception:
        dataset = None

    if name is None and dataset is not None:
        name = type(dataset).__name__

    return make_download_configs(urls_and_md5s, name)


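# Each helper below constructs one dataset with download=True under log_download_attempts
# and returns the download configs (URL and MD5) that were recorded in the process.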
def places365():
    with log_download_attempts(patch=False) as urls_and_md5s:
        for split, small in itertools.product(("train-standard", "train-challenge", "val"), (False, True)):
            with places365_root(split=split, small=small) as places365:
                root, data = places365

                datasets.Places365(root, split=split, small=small, download=True)

    return make_download_configs(urls_and_md5s, name="Places365")


def caltech101():
    return collect_download_configs(lambda: datasets.Caltech101(".", download=True), name="Caltech101")


def caltech256():
    return collect_download_configs(lambda: datasets.Caltech256(".", download=True), name="Caltech256")


def cifar10():
    return collect_download_configs(lambda: datasets.CIFAR10(".", download=True), name="CIFAR10")


def cifar100():
    return collect_download_configs(lambda: datasets.CIFAR100(".", download=True), name="CIFAR100")


def voc():
    return itertools.chain(
        *[
            collect_download_configs(
                lambda: datasets.VOCSegmentation(".", year=year, download=True),
                name=f"VOC, {year}",
                file="voc",
            )
            for year in ("2007", "2007-test", "2008", "2009", "2010", "2011", "2012")
        ]
    )


def mnist():
    return collect_download_configs(lambda: datasets.MNIST(".", download=True), name="MNIST")


def fashion_mnist():
    return collect_download_configs(lambda: datasets.FashionMNIST(".", download=True), name="FashionMNIST")


def kmnist():
    return collect_download_configs(lambda: datasets.KMNIST(".", download=True), name="KMNIST")


def emnist():
    # the 'split' argument can be any valid one, since everything is downloaded anyway
    return collect_download_configs(lambda: datasets.EMNIST(".", split="byclass", download=True), name="EMNIST")


def qmnist():
    return itertools.chain(
        *[
            collect_download_configs(
                lambda: datasets.QMNIST(".", what=what, download=True),
                name=f"QMNIST, {what}",
                file="mnist",
            )
            for what in ("train", "test", "nist")
        ]
    )


def omniglot():
    return itertools.chain(
        *[
            collect_download_configs(
                lambda: datasets.Omniglot(".", background=background, download=True),
                name=f"Omniglot, {'background' if background else 'evaluation'}",
            )
            for background in (True, False)
        ]
    )


def phototour():
    return itertools.chain(
        *[
            collect_download_configs(
                lambda: datasets.PhotoTour(".", name=name, download=True),
                name=f"PhotoTour, {name}",
                file="phototour",
            )
            # The names postfixed with '_harris' point to the domain 'matthewalunbrown.com'. For some reason all
            # requests time out from within CI. They are disabled until this is resolved.
            for name in ("notredame", "yosemite", "liberty")  # "notredame_harris", "yosemite_harris", "liberty_harris"
        ]
    )


def sbdataset():
    return collect_download_configs(
        lambda: datasets.SBDataset(".", download=True),
        name="SBDataset",
        file="voc",
    )


def sbu():
    return collect_download_configs(
        lambda: datasets.SBU(".", download=True),
        name="SBU",
        file="sbu",
    )


def semeion():
    return collect_download_configs(
        lambda: datasets.SEMEION(".", download=True),
        name="SEMEION",
        file="semeion",
    )


def stl10():
    return collect_download_configs(
        lambda: datasets.STL10(".", download=True),
        name="STL10",
    )


def svhn():
    return itertools.chain(
        *[
            collect_download_configs(
                lambda: datasets.SVHN(".", split=split, download=True),
                name=f"SVHN, {split}",
                file="svhn",
            )
            for split in ("train", "test", "extra")
        ]
    )


def usps():
    return itertools.chain(
        *[
            collect_download_configs(
                lambda: datasets.USPS(".", train=train, download=True),
                name=f"USPS, {'train' if train else 'test'}",
                file="usps",
            )
            for train in (True, False)
        ]
    )


def celeba():
    return collect_download_configs(
        lambda: datasets.CelebA(".", download=True),
        name="CelebA",
        file="celeba",
    )


def widerface():
    return collect_download_configs(
        lambda: datasets.WIDERFace(".", download=True),
        name="WIDERFace",
        file="widerface",
    )


def make_parametrize_kwargs(download_configs):
    argvalues = []
    ids = []
    for config in download_configs:
        argvalues.append((config.url, config.md5))
        ids.append(config.id)

    return dict(argnames=("url", "md5"), argvalues=argvalues, ids=ids)


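# Parametrize the accessibility test over the download configs collected from all datasets.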
@pytest.mark.parametrize(
    **make_parametrize_kwargs(
        itertools.chain(
            places365(),
            caltech101(),
            caltech256(),
            cifar10(),
            cifar100(),
            # The VOC download server is unstable. See https://github.com/pytorch/vision/issues/2953 for details.
            # voc(),
            mnist(),
            fashion_mnist(),
            kmnist(),
            emnist(),
            qmnist(),
            omniglot(),
            phototour(),
            sbdataset(),
            sbu(),
            semeion(),
            stl10(),
            svhn(),
            usps(),
            celeba(),
            widerface(),
        )
    )
)
def test_url_is_accessible(url, md5):
    retry(lambda: assert_url_is_accessible(url))


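# The chain below is currently empty, so no file is downloaded in full by default. Add
# download configs here to exercise the full download for selected datasets.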
@pytest.mark.parametrize(**make_parametrize_kwargs(itertools.chain()))
def test_file_downloads_correctly(url, md5):
    retry(lambda: assert_file_downloads_correctly(url, md5))