import contextlib
import itertools
import time
import unittest.mock
from datetime import datetime
from os import path
from urllib.parse import urlparse
from urllib.request import urlopen, Request

import pytest

from torchvision import datasets
from torchvision.datasets.utils import download_url, check_integrity

from common_utils import get_tmp_dir
from fakedata_generation import places365_root


def limit_requests_per_time(min_secs_between_requests=2.0):
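    """Decorator factory that throttles the wrapped request function.

    Successive requests to the same network location are spaced at least
    ``min_secs_between_requests`` seconds apart so the servers are not
    hammered by the test suite.
    """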
    last_requests = {}

    def outer_wrapper(fn):
        def inner_wrapper(request, *args, **kwargs):
            url = request.full_url if isinstance(request, Request) else request

            netloc = urlparse(url).netloc
            last_request = last_requests.get(netloc)
            if last_request is not None:
                elapsed_secs = (datetime.now() - last_request).total_seconds()
                delta = min_secs_between_requests - elapsed_secs
                if delta > 0:
                    time.sleep(delta)

            response = fn(request, *args, **kwargs)
            last_requests[netloc] = datetime.now()

            return response

        return inner_wrapper

    return outer_wrapper


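# Rebind the module-level urlopen so every request made in this module goes
# through the throttled wrapper above.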
urlopen = limit_requests_per_time()(urlopen)


@contextlib.contextmanager
def log_download_attempts(urls_and_md5s=None, patch=True, patch_auxiliaries=None):
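    """Context manager that collects the ``(url, md5)`` pair of every ``download_url`` call.

    With ``patch=True`` (the default) ``download_url`` is replaced by a plain
    mock, so nothing is actually downloaded; with ``patch=False`` the real
    function is wrapped and still executed. ``patch_auxiliaries`` additionally
    patches ``extract_archive`` so that ``download_and_extract_archive`` does
    not try to unpack anything.
    """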
    if urls_and_md5s is None:
        urls_and_md5s = set()
    if patch_auxiliaries is None:
        patch_auxiliaries = patch

    with contextlib.ExitStack() as stack:
        download_url_mock = stack.enter_context(
            unittest.mock.patch("torchvision.datasets.utils.download_url", wraps=None if patch else download_url)
        )
        if patch_auxiliaries:
            # extract_archive is called by download_and_extract_archive;
            # patching it keeps the tests from unpacking anything
            stack.enter_context(unittest.mock.patch("torchvision.datasets.utils.extract_archive"))
        try:
            yield urls_and_md5s
        finally:
            for args, kwargs in download_url_mock.call_args_list:
                url = args[0]
                md5 = args[-1] if len(args) == 4 else kwargs.get("md5")
                urls_and_md5s.add((url, md5))


def retry(fn, times=1, wait=5.0):
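    """Call ``fn`` and retry up to ``times`` extra attempts on ``AssertionError``,
    waiting ``wait`` seconds between attempts. If every attempt fails, the
    collected failure messages are re-raised as a single ``AssertionError``.
    """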
    msgs = []
    for _ in range(times + 1):
        try:
            return fn()
        except AssertionError as error:
            msgs.append(str(error))
            time.sleep(wait)
    raise AssertionError(
        "\n".join(
            (
                f"Assertion failed {times + 1} times with {wait:.1f} seconds wait between attempts.\n",
                *(f"{idx}: {error}" for idx, error in enumerate(msgs, 1)),
            )
        )
    )


def assert_server_response_ok(response, url=None):
    msg = f"The server returned status code {response.code}"
    if url is not None:
        msg += f" for the URL {url}"
    assert 200 <= response.code < 300, msg


def assert_url_is_accessible(url):
    # urllib expects the HTTP method as a keyword argument, not as a header
    request = Request(url, method="HEAD")
    response = urlopen(request)
    assert_server_response_ok(response, url)


def assert_file_downloads_correctly(url, md5):
    with get_tmp_dir() as root:
        file = path.join(root, path.basename(url))
        with urlopen(url) as response, open(file, "wb") as fh:
            assert_server_response_ok(response, url)
            fh.write(response.read())

        assert check_integrity(file, md5=md5), "The MD5 checksums do not match"


class DownloadConfig:
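    """URL/MD5 pair of a single download, plus an ``id`` used as the pytest test id."""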
    def __init__(self, url, md5=None, id=None):
        self.url = url
        self.md5 = md5
        self.id = id or url

    def __repr__(self):
        return self.id


def make_download_configs(urls_and_md5s, name=None):
    return [
        DownloadConfig(url, md5=md5, id=f"{name}, {url}" if name is not None else None) for url, md5 in urls_and_md5s
    ]


def places365():
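    """Collect download configs for every Places365 split/size combination by
    instantiating the dataset against a fake root with ``download=True``."""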
    with log_download_attempts(patch=False) as urls_and_md5s:
        for split, small in itertools.product(("train-standard", "train-challenge", "val"), (False, True)):
            with places365_root(split=split, small=small) as (root, data):

                datasets.Places365(root, split=split, small=small, download=True)

    return make_download_configs(urls_and_md5s, "Places365")


def caltech101():
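    """Collect the Caltech101 download configs.

    ``download_url`` is mocked (the default ``patch=True``), so only the URLs
    are recorded; any exception the constructor raises because nothing was
    actually downloaded is deliberately swallowed.
    """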
    try:
        with log_download_attempts() as urls_and_md5s:
            datasets.Caltech101(".", download=True)
    except Exception:
        pass

    return make_download_configs(urls_and_md5s, "Caltech101")


def caltech256():
    try:
        with log_download_attempts() as urls_and_md5s:
            datasets.Caltech256(".", download=True)
    except Exception:
        pass

    return make_download_configs(urls_and_md5s, "Caltech256")


def make_parametrize_kwargs(download_configs):
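    """Turn download configs into the ``argnames``/``argvalues``/``ids`` kwargs
    expected by ``pytest.mark.parametrize``."""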
    argvalues = []
    ids = []
    for config in download_configs:
        argvalues.append((config.url, config.md5))
        ids.append(config.id)

    return dict(argnames=("url", "md5"), argvalues=argvalues, ids=ids)


@pytest.mark.parametrize(**make_parametrize_kwargs(itertools.chain(places365(), caltech101(), caltech256())))
def test_url_is_accessible(url, md5):
    retry(lambda: assert_url_is_accessible(url))


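# The empty chain yields no parameters, so this test currently collects zero
# cases; download config functions such as places365() can be chained in here
# to actually exercise file downloads.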
@pytest.mark.parametrize(**make_parametrize_kwargs(itertools.chain()))
def test_file_downloads_correctly(url, md5):
    retry(lambda: assert_file_downloads_correctly(url, md5))