# test_datasets_download.py
1
2
import contextlib
import itertools
3
import time
4
5
import unittest
import unittest.mock
6
from datetime import datetime
7
from os import path
8
from urllib.parse import urlparse
9
10
11
12
13
14
15
16
17
from urllib.request import urlopen, Request

from torchvision import datasets
from torchvision.datasets.utils import download_url, check_integrity

from common_utils import get_tmp_dir
from fakedata_generation import places365_root


18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
def limit_requests_per_time(min_secs_between_requests=2.0):
    """Build a decorator that throttles calls on a per-host basis.

    The decorated callable must take the request as its first positional
    argument, either as a URL string or as a ``urllib.request.Request``.
    Before forwarding a call, the wrapper sleeps for as long as needed so that
    at least ``min_secs_between_requests`` seconds lie between the end of the
    previous call to the same network location (``netloc``) and the start of
    this one.

    Args:
        min_secs_between_requests: Minimum spacing, in seconds, between two
            calls that target the same ``netloc``.

    Returns:
        A decorator. All callables decorated with the *same* returned
        decorator share one per-``netloc`` timestamp table.
    """
    # Imported locally so the block does not depend on a file-level import.
    import functools

    # Maps netloc -> monotonic timestamp of the last finished call.
    last_requests = {}

    def outer_wrapper(fn):
        @functools.wraps(fn)
        def inner_wrapper(request, *args, **kwargs):
            url = request.full_url if isinstance(request, Request) else request

            netloc = urlparse(url).netloc
            last_request = last_requests.get(netloc)
            if last_request is not None:
                # A monotonic clock is immune to wall-clock adjustments, which
                # could otherwise yield a negative or oversized elapsed time.
                elapsed_secs = time.monotonic() - last_request
                delta = min_secs_between_requests - elapsed_secs
                if delta > 0:
                    time.sleep(delta)

            try:
                return fn(request, *args, **kwargs)
            finally:
                # Record the attempt even if ``fn`` raised: a failed call may
                # still have reached the host and must count for throttling.
                last_requests[netloc] = time.monotonic()

        return inner_wrapper

    return outer_wrapper


# Rebind the module-level ``urlopen`` to a throttled version, so every request
# issued through this name in this file goes through ``limit_requests_per_time``
# with its default arguments (``min_secs_between_requests=2.0``).
urlopen = limit_requests_per_time()(urlopen)


46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
class DownloadTester(unittest.TestCase):
    @staticmethod
    @contextlib.contextmanager
    def log_download_attempts(patch=True):
        urls_and_md5s = set()
        with unittest.mock.patch(
            "torchvision.datasets.utils.download_url", wraps=None if patch else download_url
        ) as mock:
            try:
                yield urls_and_md5s
            finally:
                for args, kwargs in mock.call_args_list:
                    url = args[0]
                    md5 = args[-1] if len(args) == 4 else kwargs.get("md5")
                    urls_and_md5s.add((url, md5))

    @staticmethod
    def retry(fn, times=1, wait=5.0):
        msgs = []
        for _ in range(times + 1):
            try:
                return fn()
            except AssertionError as error:
                msgs.append(str(error))
70
                time.sleep(wait)
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
        else:
            raise AssertionError(
                "\n".join(
                    (
                        f"Assertion failed {times + 1} times with {wait:.1f} seconds intermediate wait time.\n",
                        *(f"{idx}: {error}" for idx, error in enumerate(msgs, 1)),
                    )
                )
            )

    @staticmethod
    def assert_response_ok(response, url=None, ok=200):
        msg = f"The server returned status code {response.code}"
        if url is not None:
            msg += f"for the the URL {url}"
        assert response.code == ok, msg

    @staticmethod
    def assert_is_downloadable(url):
        request = Request(url, headers=dict(method="HEAD"))
        response = urlopen(request)
        DownloadTester.assert_response_ok(response, url)

    @staticmethod
    def assert_downloads_correctly(url, md5):
        with get_tmp_dir() as root:
            file = path.join(root, path.basename(url))
            with urlopen(url) as response, open(file, "wb") as fh:
                DownloadTester.assert_response_ok(response, url)
                fh.write(response.read())

            assert check_integrity(file, md5=md5), "The MD5 checksums mismatch"

    def test_download(self):
        assert_fn = (
            lambda url, _: self.assert_is_downloadable(url)
            if self.only_test_downloadability
            else self.assert_downloads_correctly
        )
        for url, md5 in self.collect_urls_and_md5s():
            with self.subTest(url=url, md5=md5):
                self.retry(lambda: assert_fn(url, md5))

    def collect_urls_and_md5s(self):
        raise NotImplementedError

    @property
    def only_test_downloadability(self):
        return True


class Places365Tester(DownloadTester):
    """Download test for the URLs requested by ``datasets.Places365``."""

    def collect_urls_and_md5s(self):
        """Instantiate ``Places365`` for each split/size and log its downloads.

        Every combination of the three splits with ``small`` in
        ``(False, True)`` is constructed with ``download=True`` inside
        ``log_download_attempts(patch=False)``.

        Returns:
            The set of ``(url, md5)`` tuples recorded by the logging context.
        """
        splits = ("train-standard", "train-challenge", "val")
        with self.log_download_attempts(patch=False) as urls_and_md5s:
            for split in splits:
                for small in (False, True):
                    with places365_root(split=split, small=small) as (root, _data):
                        datasets.Places365(root, split=split, small=small, download=True)

        return urls_and_md5s