Unverified Commit 34be08ef authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

More tests for regression in cached non existence (#19216)

* More tests for regression in cached non existence

* Style
parent e3a30e2b
...@@ -40,6 +40,7 @@ from transformers import ( ...@@ -40,6 +40,7 @@ from transformers import (
AutoTokenizer, AutoTokenizer,
BertTokenizer, BertTokenizer,
BertTokenizerFast, BertTokenizerFast,
GPT2TokenizerFast,
PreTrainedTokenizer, PreTrainedTokenizer,
PreTrainedTokenizerBase, PreTrainedTokenizerBase,
PreTrainedTokenizerFast, PreTrainedTokenizerFast,
...@@ -3884,12 +3885,30 @@ class TokenizerUtilTester(unittest.TestCase): ...@@ -3884,12 +3885,30 @@ class TokenizerUtilTester(unittest.TestCase):
# Download this model to make sure it's in the cache. # Download this model to make sure it's in the cache.
_ = BertTokenizer.from_pretrained("hf-internal-testing/tiny-random-bert") _ = BertTokenizer.from_pretrained("hf-internal-testing/tiny-random-bert")
# Under the mock environment we get a 500 error when trying to reach the model. # Under the mock environment we get a 500 error when trying to reach the tokenizer.
with mock.patch("requests.request", return_value=response_mock) as mock_head: with mock.patch("requests.request", return_value=response_mock) as mock_head:
_ = BertTokenizer.from_pretrained("hf-internal-testing/tiny-random-bert") _ = BertTokenizer.from_pretrained("hf-internal-testing/tiny-random-bert")
# This check we did call the fake head request # This check we did call the fake head request
mock_head.assert_called() mock_head.assert_called()
@require_tokenizers
def test_cached_files_are_used_when_internet_is_down_missing_files(self):
    """Load the fast GPT-2 tokenizer from cache while every HTTP request fails.

    The tokenizer is loaded once normally so its files are cached, then loaded
    again with ``requests.request`` patched to return a mocked 500 response.
    The second load must complete without raising, and the patched request
    must have been invoked.

    NOTE(review): the "missing_files" suffix presumably refers to optional
    tokenizer files that the "gpt2" repo does not ship -- confirm against the
    repo contents.
    """
    # Warm the cache with a regular load.
    GPT2TokenizerFast.from_pretrained("gpt2")

    # Stand-in HTTP response emulating a server that is down: status 500,
    # no headers, empty JSON payload, and raise_for_status() raising HTTPError.
    failing_response = mock.Mock(
        status_code=500,
        headers={},
        **{
            "raise_for_status.side_effect": HTTPError,
            "json.return_value": {},
        },
    )

    # Every call to requests.request now yields the failing response.
    with mock.patch("requests.request", return_value=failing_response) as fake_request:
        GPT2TokenizerFast.from_pretrained("gpt2")
        # Verify the patched request function was really called.
        fake_request.assert_called()
def test_legacy_load_from_one_file(self): def test_legacy_load_from_one_file(self):
# This test is for deprecated behavior and can be removed in v5 # This test is for deprecated behavior and can be removed in v5
try: try:
......
...@@ -15,8 +15,10 @@ import json ...@@ -15,8 +15,10 @@ import json
import os import os
import tempfile import tempfile
import unittest import unittest
import unittest.mock as mock
from pathlib import Path from pathlib import Path
from requests.exceptions import HTTPError
from transformers.utils import ( from transformers.utils import (
CONFIG_NAME, CONFIG_NAME,
FLAX_WEIGHTS_NAME, FLAX_WEIGHTS_NAME,
...@@ -79,6 +81,19 @@ class GetFromCacheTests(unittest.TestCase): ...@@ -79,6 +81,19 @@ class GetFromCacheTests(unittest.TestCase):
path = cached_file(RANDOM_BERT, "conf", local_files_only=True, _raise_exceptions_for_missing_entries=False) path = cached_file(RANDOM_BERT, "conf", local_files_only=True, _raise_exceptions_for_missing_entries=False)
self.assertIsNone(path) self.assertIsNone(path)
response_mock = mock.Mock()
response_mock.status_code = 500
response_mock.headers = {}
response_mock.raise_for_status.side_effect = HTTPError
response_mock.json.return_value = {}
# Under the mock environment every request gets a 500 error, so the file cannot be fetched.
with mock.patch("requests.request", return_value=response_mock) as mock_head:
path = cached_file(RANDOM_BERT, "conf", _raise_exceptions_for_connection_errors=False)
self.assertIsNone(path)
# This checks that we did call the fake head request
mock_head.assert_called()
def test_has_file(self): def test_has_file(self):
self.assertTrue(has_file("hf-internal-testing/tiny-bert-pt-only", WEIGHTS_NAME)) self.assertTrue(has_file("hf-internal-testing/tiny-bert-pt-only", WEIGHTS_NAME))
self.assertFalse(has_file("hf-internal-testing/tiny-bert-pt-only", TF2_WEIGHTS_NAME)) self.assertFalse(has_file("hf-internal-testing/tiny-bert-pt-only", TF2_WEIGHTS_NAME))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment