test_datasets_download.py 3.4 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import contextlib
import itertools
import shutil
import unittest
import unittest.mock
from os import path
from time import sleep
from urllib.request import urlopen, Request

from torchvision import datasets
from torchvision.datasets.utils import download_url, check_integrity

from common_utils import get_tmp_dir
from fakedata_generation import places365_root


class DownloadTester(unittest.TestCase):
    """Base class for tests that check whether dataset resources can be downloaded.

    Subclasses implement :meth:`collect_urls_and_md5s` to return the ``(url, md5)`` pairs to check.
    :meth:`test_download` then checks every pair, either by only probing the URL (the default, see
    :attr:`only_test_downloadability`) or by fully downloading it and verifying the MD5 checksum.
    """

    @staticmethod
    @contextlib.contextmanager
    def log_download_attempts(patch=True):
        """Record every ``(url, md5)`` pair passed to ``torchvision.datasets.utils.download_url``.

        Args:
            patch: If ``True``, ``download_url`` is replaced by a mock, so nothing is downloaded. If
                ``False``, the mock wraps the real ``download_url``, so the calls still go through.

        Yields:
            A set that is filled with the logged ``(url, md5)`` pairs once the context exits.
        """
        urls_and_md5s = set()
        with unittest.mock.patch(
            "torchvision.datasets.utils.download_url", wraps=None if patch else download_url
        ) as mock:
            try:
                yield urls_and_md5s
            finally:
                for args, kwargs in mock.call_args_list:
                    url = args[0] if args else kwargs.get("url")
                    # Assumes md5 is the fourth positional parameter of download_url. Read it by index
                    # so that calls passing extra positional arguments after it are still logged.
                    md5 = args[3] if len(args) >= 4 else kwargs.get("md5")
                    urls_and_md5s.add((url, md5))

    @staticmethod
    def retry(fn, times=1, wait=5.0):
        """Call ``fn`` until it no longer raises ``AssertionError``.

        Args:
            fn: Zero-argument callable to run.
            times: Number of retries after the first attempt, i.e. ``fn`` runs at most ``times + 1`` times.
            wait: Seconds to sleep between two consecutive attempts.

        Returns:
            Whatever the first successful call of ``fn`` returns.

        Raises:
            AssertionError: If every attempt failed; the message lists all individual failures.
        """
        attempts = times + 1
        msgs = []
        for attempt in range(attempts):
            if attempt:
                # Sleep only *between* attempts; sleeping after the last failure would just delay the report.
                sleep(wait)
            try:
                return fn()
            except AssertionError as error:
                msgs.append(str(error))

        raise AssertionError(
            "\n".join(
                (
                    f"Assertion failed {attempts} times with {wait:.1f} seconds intermediate wait time.\n",
                    *(f"{idx}: {error}" for idx, error in enumerate(msgs, 1)),
                )
            )
        )

    @staticmethod
    def assert_response_ok(response, url=None, ok=200):
        """Assert that ``response.code`` equals ``ok``, mentioning ``url`` in the failure message if given."""
        msg = f"The server returned status code {response.code}"
        if url is not None:
            msg += f" for the URL {url}"
        assert response.code == ok, msg

    @staticmethod
    def assert_is_downloadable(url):
        """Assert that ``url`` answers a HEAD request with status 200, without downloading the payload."""
        # The HTTP verb must be passed via ``method=``; putting it into ``headers`` only adds a bogus
        # "Method: HEAD" header and still performs a full GET.
        request = Request(url, method="HEAD")
        with urlopen(request) as response:
            DownloadTester.assert_response_ok(response, url)

    @staticmethod
    def assert_downloads_correctly(url, md5):
        """Download ``url`` into a temporary directory and assert that its MD5 checksum matches ``md5``."""
        with get_tmp_dir() as root:
            file = path.join(root, path.basename(url))
            with urlopen(url) as response, open(file, "wb") as fh:
                DownloadTester.assert_response_ok(response, url)
                # Stream to disk in chunks instead of reading the whole response into memory at once.
                shutil.copyfileobj(response, fh)

            assert check_integrity(file, md5=md5), "The MD5 checksums mismatch"

    def test_download(self):
        """Check every ``(url, md5)`` pair returned by :meth:`collect_urls_and_md5s` in its own subtest."""
        if type(self) is DownloadTester:
            # Test collectors also pick up this abstract base class itself; it has nothing to check.
            self.skipTest("DownloadTester is an abstract base class without URLs to check.")

        # Explicit branch: a conditional expression after ``lambda`` would become part of the lambda body.
        if self.only_test_downloadability:

            def assert_fn(url, _):
                self.assert_is_downloadable(url)

        else:
            assert_fn = self.assert_downloads_correctly

        for url, md5 in self.collect_urls_and_md5s():
            with self.subTest(url=url, md5=md5):
                self.retry(lambda: assert_fn(url, md5))
                # Pause between URLs, presumably to avoid hammering / being rate-limited by the servers.
                sleep(2.0)

    def collect_urls_and_md5s(self):
        """Return an iterable of ``(url, md5)`` pairs to check. Must be implemented by subclasses."""
        raise NotImplementedError

    @property
    def only_test_downloadability(self):
        """If ``True`` (default), only probe the URLs; otherwise download them and verify the MD5."""
        return True


class Places365Tester(DownloadTester):
    """Download checks for the resources requested by ``datasets.Places365``."""

    def collect_urls_and_md5s(self):
        """Instantiate ``Places365`` with ``download=True`` for every split/size variant.

        The download calls are logged with ``patch=False``, i.e. the logging mock wraps the real
        ``download_url`` instead of replacing it.

        Returns:
            The set of ``(url, md5)`` pairs logged while the variants were instantiated.
        """
        splits = ("train-standard", "train-challenge", "val")
        with self.log_download_attempts(patch=False) as logged:
            for split in splits:
                for small in (False, True):
                    # The fake-data fixture yields a pair; only its first element (the root) is needed here.
                    with places365_root(split=split, small=small) as (root, _):
                        datasets.Places365(root, split=split, small=small, download=True)
        return logged