chenpangpang / transformers · Commits

Commit e38348ae (unverified), authored Nov 09, 2023 by Lucain, committed by GitHub on Nov 09, 2023
Parent: c8b6052f

Fix RequestCounter to make it more future-proof (#27406)

* Fix RequestCounter to make it more future-proof
* code quality
Showing 5 changed files with 48 additions and 45 deletions:

  src/transformers/testing_utils.py            +30 -21
  tests/models/auto/test_modeling_auto.py       +6  -9
  tests/models/auto/test_modeling_tf_auto.py    +6  -6
  tests/models/auto/test_tokenization_auto.py   +3  -6
  tests/pipelines/test_pipelines_common.py      +3  -3
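In short: `RequestCounter` no longer counts calls by monkey-patching `huggingface_hub.utils._http.get_session`; it now wraps `urllib3.connectionpool.log.debug` and tallies one entry per HTTP method, exposed through dict-style access (`counter["HEAD"]`) and a `total_calls` property. The test diffs below swap the old `get_request_count` / `head_request_count` / `other_request_count` attributes for the new interface and drop the `@unittest.skip` markers that had been added while the counter was failing with the new huggingface_hub release. The following is a minimal sketch of a test written against the new interface, assuming only that `RequestCounter` stays importable from `transformers.testing_utils` (where the diff defines it); the test class and method names here are hypothetical, while the repo id and expected counts mirror the docstring example further down.

```python
import unittest

from transformers import AutoTokenizer
from transformers.testing_utils import RequestCounter


class RequestCounterUsageTest(unittest.TestCase):
    def test_cached_tokenizer_triggers_a_single_head_call(self):
        # Warm the local cache first so the second load only needs a freshness check.
        _ = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bert")

        with RequestCounter() as counter:
            _ = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bert")

        # Per-method counts are read with dict-style access; total_calls sums them.
        self.assertEqual(counter["GET"], 0)
        self.assertEqual(counter["HEAD"], 1)
        self.assertEqual(counter.total_calls, 1)


if __name__ == "__main__":
    unittest.main()
```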
src/transformers/testing_utils.py

```diff
@@ -29,14 +29,15 @@ import sys
 import tempfile
 import time
 import unittest
+from collections import defaultdict
 from collections.abc import Mapping
 from io import StringIO
 from pathlib import Path
 from typing import Callable, Dict, Iterable, Iterator, List, Optional, Union
 from unittest import mock
+from unittest.mock import patch
 
-import huggingface_hub
-import requests
+import urllib3
 
 from transformers import logging as transformers_logging
```
````diff
@@ -1983,32 +1984,40 @@ def run_command(command: List[str], return_stdout=False):
 class RequestCounter:
     """
     Helper class that will count all requests made online.
+
+    Might not be robust if urllib3 changes its logging format but should be good enough for us.
+
+    Usage:
+    ```py
+    with RequestCounter() as counter:
+        _ = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bert")
+    assert counter["GET"] == 0
+    assert counter["HEAD"] == 1
+    assert counter.total_calls == 1
+    ```
     """
 
     def __enter__(self):
-        self.head_request_count = 0
-        self.get_request_count = 0
-        self.other_request_count = 0
-
-        # Mock `get_session` to count HTTP calls.
-        self.old_get_session = huggingface_hub.utils._http.get_session
-        self.session = requests.Session()
-        self.session.request = self.new_request
-        huggingface_hub.utils._http.get_session = lambda: self.session
+        self._counter = defaultdict(int)
+        self.patcher = patch.object(urllib3.connectionpool.log, "debug", wraps=urllib3.connectionpool.log.debug)
+        self.mock = self.patcher.start()
         return self
 
-    def __exit__(self, *args, **kwargs):
-        huggingface_hub.utils._http.get_session = self.old_get_session
+    def __exit__(self, *args, **kwargs) -> None:
+        for call in self.mock.call_args_list:
+            log = call.args[0] % call.args[1:]
+            for method in ("HEAD", "GET", "POST", "PUT", "DELETE", "CONNECT", "OPTIONS", "TRACE", "PATCH"):
+                if method in log:
+                    self._counter[method] += 1
+                    break
+        self.patcher.stop()
 
-    def new_request(self, method, **kwargs):
-        if method == "GET":
-            self.get_request_count += 1
-        elif method == "HEAD":
-            self.head_request_count += 1
-        else:
-            self.other_request_count += 1
-
-        return requests.request(method=method, **kwargs)
+    def __getitem__(self, key: str) -> int:
+        return self._counter[key]
+
+    @property
+    def total_calls(self) -> int:
+        return sum(self._counter.values())
 
 
 def is_flaky(max_attempts: int = 5, wait_before_retry: Optional[float] = None, description: Optional[str] = None):
````
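For context, here is a standalone sketch (not part of the diff) of the mechanism the new `__enter__`/`__exit__` rely on: urllib3's connection pool logs each request it sends at DEBUG level with the HTTP verb in the message, so wrapping that logger's `debug` method with `unittest.mock.patch.object(..., wraps=...)` records one call per request while leaving normal logging intact. The target URL below is only an example; as the new docstring warns, the approach depends on urllib3 keeping its current log format.

```python
# Standalone sketch of the counting technique RequestCounter now uses.
from collections import defaultdict
from unittest.mock import patch

import requests
import urllib3

counter = defaultdict(int)

# Spy on urllib3's connection-pool logger; `wraps=` keeps the real debug() running.
with patch.object(
    urllib3.connectionpool.log, "debug", wraps=urllib3.connectionpool.log.debug
) as mock_debug:
    requests.head("https://huggingface.co")  # example request; any urllib3-backed call works

# Re-build each log message from its printf-style arguments and look for the HTTP verb.
for call in mock_debug.call_args_list:
    message = call.args[0] % call.args[1:]
    for method in ("HEAD", "GET", "POST", "PUT", "DELETE", "CONNECT", "OPTIONS", "TRACE", "PATCH"):
        if method in message:
            counter[method] += 1
            break

print(dict(counter))  # e.g. {'HEAD': 1}
```

Compared with replacing huggingface_hub's private `get_session`, this only hooks a logging call, which is presumably what the commit title means by "more future-proof" with respect to future huggingface_hub releases.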
tests/models/auto/test_modeling_auto.py

```diff
@@ -482,25 +482,22 @@ class AutoModelTest(unittest.TestCase):
         with self.assertRaisesRegex(EnvironmentError, "Use `from_flax=True` to load this model"):
             _ = AutoModel.from_pretrained("hf-internal-testing/tiny-bert-flax-only")
 
-    @unittest.skip(
-        "Currently failing with new huggingface_hub release. See: https://github.com/huggingface/transformers/pull/27389"
-    )
     def test_cached_model_has_minimum_calls_to_head(self):
         # Make sure we have cached the model.
         _ = AutoModel.from_pretrained("hf-internal-testing/tiny-random-bert")
         with RequestCounter() as counter:
             _ = AutoModel.from_pretrained("hf-internal-testing/tiny-random-bert")
-        self.assertEqual(counter.get_request_count, 0)
-        self.assertEqual(counter.head_request_count, 1)
-        self.assertEqual(counter.other_request_count, 0)
+        self.assertEqual(counter["GET"], 0)
+        self.assertEqual(counter["HEAD"], 1)
+        self.assertEqual(counter.total_calls, 1)
 
         # With a sharded checkpoint
         _ = AutoModel.from_pretrained("hf-internal-testing/tiny-random-bert-sharded")
         with RequestCounter() as counter:
             _ = AutoModel.from_pretrained("hf-internal-testing/tiny-random-bert-sharded")
-        self.assertEqual(counter.get_request_count, 0)
-        self.assertEqual(counter.head_request_count, 1)
-        self.assertEqual(counter.other_request_count, 0)
+        self.assertEqual(counter["GET"], 0)
+        self.assertEqual(counter["HEAD"], 1)
+        self.assertEqual(counter.total_calls, 1)
 
     def test_attr_not_existing(self):
         from transformers.models.auto.auto_factory import _LazyAutoMapping
```
tests/models/auto/test_modeling_tf_auto.py

```diff
@@ -301,14 +301,14 @@ class TFAutoModelTest(unittest.TestCase):
         _ = TFAutoModel.from_pretrained("hf-internal-testing/tiny-random-bert")
         with RequestCounter() as counter:
             _ = TFAutoModel.from_pretrained("hf-internal-testing/tiny-random-bert")
-        self.assertEqual(counter.get_request_count, 0)
-        self.assertEqual(counter.head_request_count, 1)
-        self.assertEqual(counter.other_request_count, 0)
+        self.assertEqual(counter["GET"], 0)
+        self.assertEqual(counter["HEAD"], 1)
+        self.assertEqual(counter.total_calls, 1)
 
         # With a sharded checkpoint
         _ = TFAutoModel.from_pretrained("ArthurZ/tiny-random-bert-sharded")
         with RequestCounter() as counter:
             _ = TFAutoModel.from_pretrained("ArthurZ/tiny-random-bert-sharded")
-        self.assertEqual(counter.get_request_count, 0)
-        self.assertEqual(counter.head_request_count, 1)
-        self.assertEqual(counter.other_request_count, 0)
+        self.assertEqual(counter["GET"], 0)
+        self.assertEqual(counter["HEAD"], 1)
+        self.assertEqual(counter.total_calls, 1)
```
tests/models/auto/test_tokenization_auto.py

```diff
@@ -419,14 +419,11 @@ class AutoTokenizerTest(unittest.TestCase):
         ):
             _ = AutoTokenizer.from_pretrained(DUMMY_UNKNOWN_IDENTIFIER, revision="aaaaaa")
 
-    @unittest.skip(
-        "Currently failing with new huggingface_hub release. See: https://github.com/huggingface/transformers/pull/27389"
-    )
     def test_cached_tokenizer_has_minimum_calls_to_head(self):
         # Make sure we have cached the tokenizer.
         _ = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bert")
         with RequestCounter() as counter:
             _ = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bert")
-        self.assertEqual(counter.get_request_count, 0)
-        self.assertEqual(counter.head_request_count, 1)
-        self.assertEqual(counter.other_request_count, 0)
+        self.assertEqual(counter["GET"], 0)
+        self.assertEqual(counter["HEAD"], 1)
+        self.assertEqual(counter.total_calls, 1)
```
tests/pipelines/test_pipelines_common.py

```diff
@@ -763,9 +763,9 @@ class CustomPipelineTest(unittest.TestCase):
         _ = pipeline("text-classification", model="hf-internal-testing/tiny-random-bert")
         with RequestCounter() as counter:
             _ = pipeline("text-classification", model="hf-internal-testing/tiny-random-bert")
-        self.assertEqual(counter.get_request_count, 0)
-        self.assertEqual(counter.head_request_count, 1)
-        self.assertEqual(counter.other_request_count, 0)
+        self.assertEqual(counter["GET"], 0)
+        self.assertEqual(counter["HEAD"], 1)
+        self.assertEqual(counter.total_calls, 1)
 
     @require_torch
     def test_chunk_pipeline_batching_single_file(self):
```