Unverified Commit e38348ae authored by Lucain's avatar Lucain Committed by GitHub
Browse files

Fix RequestCounter to make it more future-proof (#27406)

* Fix RequestCounter to make it more future-proof

* code quality
parent c8b6052f
...@@ -29,14 +29,15 @@ import sys ...@@ -29,14 +29,15 @@ import sys
import tempfile import tempfile
import time import time
import unittest import unittest
from collections import defaultdict
from collections.abc import Mapping from collections.abc import Mapping
from io import StringIO from io import StringIO
from pathlib import Path from pathlib import Path
from typing import Callable, Dict, Iterable, Iterator, List, Optional, Union from typing import Callable, Dict, Iterable, Iterator, List, Optional, Union
from unittest import mock from unittest import mock
from unittest.mock import patch
import huggingface_hub import urllib3
import requests
from transformers import logging as transformers_logging from transformers import logging as transformers_logging
...@@ -1983,32 +1984,40 @@ def run_command(command: List[str], return_stdout=False): ...@@ -1983,32 +1984,40 @@ def run_command(command: List[str], return_stdout=False):
class RequestCounter:
    """
    Helper class that will count all requests made online.

    Might not be robust if urllib3 changes its logging format but should be good enough for us.

    Usage:
    ```py
    with RequestCounter() as counter:
        _ = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bert")
    assert counter["GET"] == 0
    assert counter["HEAD"] == 1
    assert counter.total_calls == 1
    ```
    """

    def __enter__(self):
        # Per-method request counts, keyed by HTTP verb (e.g. "GET", "HEAD").
        self._counter = defaultdict(int)
        # Spy on urllib3's connection-pool debug logger: every outgoing HTTP
        # request goes through it, so counting its log calls counts requests
        # without depending on which HTTP client (requests/hf_hub) made them.
        self.patcher = patch.object(urllib3.connectionpool.log, "debug", wraps=urllib3.connectionpool.log.debug)
        self.mock = self.patcher.start()
        return self

    def __exit__(self, *args, **kwargs) -> None:
        # Always undo the patch, even if parsing a log record raises (e.g. an
        # unexpected format makes the `%` interpolation fail) — otherwise the
        # urllib3 logger would stay wrapped for the rest of the test session.
        try:
            for call in self.mock.call_args_list:
                log = call.args[0] % call.args[1:]
                for method in ("HEAD", "GET", "POST", "PUT", "DELETE", "CONNECT", "OPTIONS", "TRACE", "PATCH"):
                    if method in log:
                        self._counter[method] += 1
                        break
        finally:
            self.patcher.stop()

    def __getitem__(self, key: str) -> int:
        # Unknown verbs return 0 thanks to the defaultdict.
        return self._counter[key]

    @property
    def total_calls(self) -> int:
        # Total number of HTTP requests observed, regardless of verb.
        return sum(self._counter.values())
def is_flaky(max_attempts: int = 5, wait_before_retry: Optional[float] = None, description: Optional[str] = None): def is_flaky(max_attempts: int = 5, wait_before_retry: Optional[float] = None, description: Optional[str] = None):
......
...@@ -482,25 +482,22 @@ class AutoModelTest(unittest.TestCase): ...@@ -482,25 +482,22 @@ class AutoModelTest(unittest.TestCase):
with self.assertRaisesRegex(EnvironmentError, "Use `from_flax=True` to load this model"): with self.assertRaisesRegex(EnvironmentError, "Use `from_flax=True` to load this model"):
_ = AutoModel.from_pretrained("hf-internal-testing/tiny-bert-flax-only") _ = AutoModel.from_pretrained("hf-internal-testing/tiny-bert-flax-only")
def test_cached_model_has_minimum_calls_to_head(self):
    """Loading an already-cached checkpoint should issue exactly one HEAD request and no GET."""
    # Make sure we have cached the model.
    _ = AutoModel.from_pretrained("hf-internal-testing/tiny-random-bert")
    with RequestCounter() as counter:
        _ = AutoModel.from_pretrained("hf-internal-testing/tiny-random-bert")
    self.assertEqual(counter["GET"], 0)
    self.assertEqual(counter["HEAD"], 1)
    self.assertEqual(counter.total_calls, 1)

    # With a sharded checkpoint
    _ = AutoModel.from_pretrained("hf-internal-testing/tiny-random-bert-sharded")
    with RequestCounter() as counter:
        _ = AutoModel.from_pretrained("hf-internal-testing/tiny-random-bert-sharded")
    self.assertEqual(counter["GET"], 0)
    self.assertEqual(counter["HEAD"], 1)
    self.assertEqual(counter.total_calls, 1)
def test_attr_not_existing(self): def test_attr_not_existing(self):
from transformers.models.auto.auto_factory import _LazyAutoMapping from transformers.models.auto.auto_factory import _LazyAutoMapping
......
...@@ -301,14 +301,14 @@ class TFAutoModelTest(unittest.TestCase): ...@@ -301,14 +301,14 @@ class TFAutoModelTest(unittest.TestCase):
_ = TFAutoModel.from_pretrained("hf-internal-testing/tiny-random-bert") _ = TFAutoModel.from_pretrained("hf-internal-testing/tiny-random-bert")
with RequestCounter() as counter: with RequestCounter() as counter:
_ = TFAutoModel.from_pretrained("hf-internal-testing/tiny-random-bert") _ = TFAutoModel.from_pretrained("hf-internal-testing/tiny-random-bert")
self.assertEqual(counter.get_request_count, 0) self.assertEqual(counter["GET"], 0)
self.assertEqual(counter.head_request_count, 1) self.assertEqual(counter["HEAD"], 1)
self.assertEqual(counter.other_request_count, 0) self.assertEqual(counter.total_calls, 1)
# With a sharded checkpoint # With a sharded checkpoint
_ = TFAutoModel.from_pretrained("ArthurZ/tiny-random-bert-sharded") _ = TFAutoModel.from_pretrained("ArthurZ/tiny-random-bert-sharded")
with RequestCounter() as counter: with RequestCounter() as counter:
_ = TFAutoModel.from_pretrained("ArthurZ/tiny-random-bert-sharded") _ = TFAutoModel.from_pretrained("ArthurZ/tiny-random-bert-sharded")
self.assertEqual(counter.get_request_count, 0) self.assertEqual(counter["GET"], 0)
self.assertEqual(counter.head_request_count, 1) self.assertEqual(counter["HEAD"], 1)
self.assertEqual(counter.other_request_count, 0) self.assertEqual(counter.total_calls, 1)
...@@ -419,14 +419,11 @@ class AutoTokenizerTest(unittest.TestCase): ...@@ -419,14 +419,11 @@ class AutoTokenizerTest(unittest.TestCase):
): ):
_ = AutoTokenizer.from_pretrained(DUMMY_UNKNOWN_IDENTIFIER, revision="aaaaaa") _ = AutoTokenizer.from_pretrained(DUMMY_UNKNOWN_IDENTIFIER, revision="aaaaaa")
def test_cached_tokenizer_has_minimum_calls_to_head(self):
    """Loading an already-cached tokenizer should issue exactly one HEAD request and no GET."""
    # Make sure we have cached the tokenizer.
    _ = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bert")
    with RequestCounter() as counter:
        _ = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bert")
    self.assertEqual(counter["GET"], 0)
    self.assertEqual(counter["HEAD"], 1)
    self.assertEqual(counter.total_calls, 1)
...@@ -763,9 +763,9 @@ class CustomPipelineTest(unittest.TestCase): ...@@ -763,9 +763,9 @@ class CustomPipelineTest(unittest.TestCase):
_ = pipeline("text-classification", model="hf-internal-testing/tiny-random-bert") _ = pipeline("text-classification", model="hf-internal-testing/tiny-random-bert")
with RequestCounter() as counter: with RequestCounter() as counter:
_ = pipeline("text-classification", model="hf-internal-testing/tiny-random-bert") _ = pipeline("text-classification", model="hf-internal-testing/tiny-random-bert")
self.assertEqual(counter.get_request_count, 0) self.assertEqual(counter["GET"], 0)
self.assertEqual(counter.head_request_count, 1) self.assertEqual(counter["HEAD"], 1)
self.assertEqual(counter.other_request_count, 0) self.assertEqual(counter.total_calls, 1)
@require_torch @require_torch
def test_chunk_pipeline_batching_single_file(self): def test_chunk_pipeline_batching_single_file(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment