import contextlib
import itertools
import time
import unittest.mock
from datetime import datetime
from os import path
from urllib.error import HTTPError
from urllib.parse import urlparse
from urllib.request import urlopen, Request

import pytest

from torchvision import datasets
from torchvision.datasets.utils import download_url, check_integrity

from common_utils import get_tmp_dir
from fakedata_generation import places365_root


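# Some mirrors block clients that fire requests in quick succession. This
# decorator remembers the time of the last request per host (netloc) and
# sleeps until at least ``min_secs_between_requests`` seconds have elapsed.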
def limit_requests_per_time(min_secs_between_requests=2.0):
    last_requests = {}

    def outer_wrapper(fn):
        def inner_wrapper(request, *args, **kwargs):
            url = request.full_url if isinstance(request, Request) else request

            netloc = urlparse(url).netloc
            last_request = last_requests.get(netloc)
            if last_request is not None:
                elapsed_secs = (datetime.now() - last_request).total_seconds()
                delta = min_secs_between_requests - elapsed_secs
                if delta > 0:
                    time.sleep(delta)

            response = fn(request, *args, **kwargs)
            last_requests[netloc] = datetime.now()

            return response

        return inner_wrapper

    return outer_wrapper


urlopen = limit_requests_per_time()(urlopen)
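# From here on, the module-level ``urlopen`` is the rate-limited wrapper, so
# every request issued in this file is automatically throttled per host.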


@contextlib.contextmanager
def log_download_attempts(
    urls_and_md5s=None,
    patch=True,
    download_url_location=".utils",
    patch_auxiliaries=None,
):
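    """Collect the ``(url, md5)`` pairs of all ``download_url`` calls made in this block.

    With ``patch=True`` (the default), ``download_url`` is replaced by a plain
    mock and nothing is downloaded; with ``patch=False``, the mock wraps the
    real function, which is still executed.
    """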
    if urls_and_md5s is None:
        urls_and_md5s = set()
    if download_url_location.startswith("."):
        download_url_location = f"torchvision.datasets{download_url_location}"
    if patch_auxiliaries is None:
        patch_auxiliaries = patch

    with contextlib.ExitStack() as stack:
        download_url_mock = stack.enter_context(
            unittest.mock.patch(
                f"{download_url_location}.download_url",
                wraps=None if patch else download_url,
            )
        )
        if patch_auxiliaries:
            # ``extract_archive`` is called by ``download_and_extract_archive``;
            # patch it out so the (mocked) download is never actually extracted.
            stack.enter_context(unittest.mock.patch("torchvision.datasets.utils.extract_archive"))
        try:
            yield urls_and_md5s
        finally:
            for args, kwargs in download_url_mock.call_args_list:
                url = args[0]
                md5 = args[-1] if len(args) == 4 else kwargs.get("md5")
                urls_and_md5s.add((url, md5))


def retry(fn, times=1, wait=5.0):
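    # Call ``fn`` up to ``times + 1`` times, sleeping ``wait`` seconds after a
    # failure. A successful call returns from the loop immediately; the
    # ``for``/``else`` below fires only if every attempt raised.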
    msgs = []
    for _ in range(times + 1):
        try:
            return fn()
        except AssertionError as error:
            msgs.append(str(error))
            time.sleep(wait)
    else:
        raise AssertionError(
            "\n".join(
                (
                    f"Assertion failed {times + 1} times with {wait:.1f} seconds intermediate wait time.\n",
                    *(f"{idx}: {error}" for idx, error in enumerate(msgs, 1)),
                )
            )
        )


@contextlib.contextmanager
def assert_server_response_ok():
    try:
        yield
    except HTTPError as error:
        raise AssertionError(f"The server returned {error.code}: {error.reason}.") from error


def assert_url_is_accessible(url):
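    # A HEAD request suffices to probe availability without downloading the payload.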
    request = Request(url, method="HEAD")
    with assert_server_response_ok():
        urlopen(request)


def assert_file_downloads_correctly(url, md5):
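    # Download the file into a temporary directory and verify it against the
    # expected MD5 checksum.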
    with get_tmp_dir() as root:
        file = path.join(root, path.basename(url))
        with assert_server_response_ok():
            with urlopen(url) as response, open(file, "wb") as fh:
                fh.write(response.read())

        assert check_integrity(file, md5=md5), "The MD5 checksums do not match"


class DownloadConfig:
    def __init__(self, url, md5=None, id=None):
        self.url = url
        self.md5 = md5
        self.id = id or url

    def __repr__(self):
        return self.id


def make_download_configs(urls_and_md5s, name=None):
    return [
        DownloadConfig(url, md5=md5, id=f"{name}, {url}" if name is not None else None) for url, md5 in urls_and_md5s
    ]


def collect_download_configs(dataset_loader, name=None, **kwargs):
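    # With ``download_url`` mocked out, instantiating the dataset usually fails
    # at a later step. That is fine: only the recorded URLs and MD5s matter here.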
    urls_and_md5s = set()
    try:
        with log_download_attempts(urls_and_md5s=urls_and_md5s, **kwargs):
            dataset = dataset_loader()
    except Exception:
        dataset = None

    if name is None and dataset is not None:
        name = type(dataset).__name__

    return make_download_configs(urls_and_md5s, name)


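# Places365 is exercised against a fake data root. Since the fake files
# presumably pass the integrity check, the unpatched ``download_url``
# (``patch=False``) returns early and only the URL / MD5 pairs are recorded.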
def places365():
    with log_download_attempts(patch=False) as urls_and_md5s:
        for split, small in itertools.product(("train-standard", "train-challenge", "val"), (False, True)):
            with places365_root(split=split, small=small) as places365:
                root, data = places365

                datasets.Places365(root, split=split, small=small, download=True)

    return make_download_configs(urls_and_md5s, name="Places365")


def caltech101():
    return collect_download_configs(lambda: datasets.Caltech101(".", download=True), name="Caltech101")


def caltech256():
    return collect_download_configs(lambda: datasets.Caltech256(".", download=True), name="Caltech256")


def cifar10():
    return collect_download_configs(lambda: datasets.CIFAR10(".", download=True), name="CIFAR10")


def cifar100():
    return collect_download_configs(lambda: datasets.CIFAR100(".", download=True), name="CIFAR100")


def voc():
    return itertools.chain(
        *[
            collect_download_configs(
                lambda: datasets.VOCSegmentation(".", year=year, download=True),
                name=f"VOC, {year}",
                download_url_location=".voc",
            )
            for year in ("2007", "2007-test", "2008", "2009", "2010", "2011", "2012")
        ]
    )


def mnist():
    return collect_download_configs(lambda: datasets.MNIST(".", download=True), name="MNIST")


def fashion_mnist():
    return collect_download_configs(lambda: datasets.FashionMNIST(".", download=True), name="FashionMNIST")


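# Converts download configs into keyword arguments for
# ``pytest.mark.parametrize``. For example (hypothetical values):
#
#   make_parametrize_kwargs([DownloadConfig("https://example.com/foo.tar.gz", md5="abc123")])
#   == dict(
#       argnames=("url", "md5"),
#       argvalues=[("https://example.com/foo.tar.gz", "abc123")],
#       ids=["https://example.com/foo.tar.gz"],
#   )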
def make_parametrize_kwargs(download_configs):
    argvalues = []
    ids = []
    for config in download_configs:
        argvalues.append((config.url, config.md5))
        ids.append(config.id)

    return dict(argnames=("url", "md5"), argvalues=argvalues, ids=ids)


@pytest.mark.parametrize(
    **make_parametrize_kwargs(
        itertools.chain(
            places365(),
            caltech101(),
            caltech256(),
            cifar10(),
            cifar100(),
            # The VOC download server is unstable. See https://github.com/pytorch/vision/issues/2953 for details.
            # voc(),
            mnist(),
            fashion_mnist(),
        )
    )
)
def test_url_is_accessible(url, md5):
    retry(lambda: assert_url_is_accessible(url))


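# The parameter list below is empty, so pytest skips this test; presumably this
# avoids downloading every file in full on CI. Feed configs into the
# ``itertools.chain`` call to spot-check individual downloads locally.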
@pytest.mark.parametrize(**make_parametrize_kwargs(itertools.chain()))
def test_file_downloads_correctly(url, md5):
    retry(lambda: assert_file_downloads_correctly(url, md5))