Unverified Commit 6a43a1f8 authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

limit requests per time in download tests (#2699)

parent 1b415254
import contextlib import contextlib
import itertools import itertools
import time
import unittest import unittest
import unittest.mock import unittest.mock
from datetime import datetime
from os import path from os import path
from time import sleep from urllib.parse import urlparse
from urllib.request import urlopen, Request from urllib.request import urlopen, Request
from torchvision import datasets from torchvision import datasets
...@@ -13,6 +15,34 @@ from common_utils import get_tmp_dir ...@@ -13,6 +15,34 @@ from common_utils import get_tmp_dir
from fakedata_generation import places365_root from fakedata_generation import places365_root
def limit_requests_per_time(min_secs_between_requests=2.0):
last_requests = {}
def outer_wrapper(fn):
def inner_wrapper(request, *args, **kwargs):
url = request.full_url if isinstance(request, Request) else request
netloc = urlparse(url).netloc
last_request = last_requests.get(netloc)
if last_request is not None:
elapsed_secs = (datetime.now() - last_request).total_seconds()
delta = min_secs_between_requests - elapsed_secs
if delta > 0:
time.sleep(delta)
response = fn(request, *args, **kwargs)
last_requests[netloc] = datetime.now()
return response
return inner_wrapper
return outer_wrapper
urlopen = limit_requests_per_time()(urlopen)
class DownloadTester(unittest.TestCase): class DownloadTester(unittest.TestCase):
@staticmethod @staticmethod
@contextlib.contextmanager @contextlib.contextmanager
...@@ -37,7 +67,7 @@ class DownloadTester(unittest.TestCase): ...@@ -37,7 +67,7 @@ class DownloadTester(unittest.TestCase):
return fn() return fn()
except AssertionError as error: except AssertionError as error:
msgs.append(str(error)) msgs.append(str(error))
sleep(wait) time.sleep(wait)
else: else:
raise AssertionError( raise AssertionError(
"\n".join( "\n".join(
...@@ -80,7 +110,6 @@ class DownloadTester(unittest.TestCase): ...@@ -80,7 +110,6 @@ class DownloadTester(unittest.TestCase):
for url, md5 in self.collect_urls_and_md5s(): for url, md5 in self.collect_urls_and_md5s():
with self.subTest(url=url, md5=md5): with self.subTest(url=url, md5=md5):
self.retry(lambda: assert_fn(url, md5)) self.retry(lambda: assert_fn(url, md5))
sleep(2.0)
def collect_urls_and_md5s(self): def collect_urls_and_md5s(self):
raise NotImplementedError raise NotImplementedError
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment