"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "268c6cc160ba046d6a91747c5f281f82bd88a4d8"
Unverified commit c99ddcc4, authored by Simon Brandeis and committed by GitHub
Browse files

🐛 Properly raise `RepoNotFoundError` when not authenticated (#17651)

* Raise RepoNotFoundError in case of 401

* Include changes from revert-17646-skip_repo_not_found

* Add a comment

* 💄 Code quality

* 💚 Update `get_from_cache` test

* 💚 Code quality & skip failing test
parent 35b16032
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"#%pip install-r requirements.txt" "# %pip install-r requirements.txt"
] ]
}, },
{ {
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
"cell_type": "code", "cell_type": "code",
"execution_count": 1, "execution_count": 1,
"source": [ "source": [
"#%pip install-r requirements.txt" "# %pip install-r requirements.txt"
], ],
"outputs": [], "outputs": [],
"metadata": {} "metadata": {}
......
...@@ -38,6 +38,7 @@ import requests ...@@ -38,6 +38,7 @@ import requests
from filelock import FileLock from filelock import FileLock
from huggingface_hub import HfFolder, Repository, create_repo, list_repo_files, whoami from huggingface_hub import HfFolder, Repository, create_repo, list_repo_files, whoami
from requests.exceptions import HTTPError from requests.exceptions import HTTPError
from requests.models import Response
from transformers.utils.logging import tqdm from transformers.utils.logging import tqdm
from . import __version__, logging from . import __version__, logging
...@@ -398,20 +399,27 @@ class RevisionNotFoundError(HTTPError): ...@@ -398,20 +399,27 @@ class RevisionNotFoundError(HTTPError):
"""Raised when trying to access a hf.co URL with a valid repository but an invalid revision.""" """Raised when trying to access a hf.co URL with a valid repository but an invalid revision."""
def _raise_for_status(request): def _raise_for_status(response: Response):
""" """
Internal version of `request.raise_for_status()` that will refine a potential HTTPError. Internal version of `request.raise_for_status()` that will refine a potential HTTPError.
""" """
if "X-Error-Code" in request.headers: if "X-Error-Code" in response.headers:
error_code = request.headers["X-Error-Code"] error_code = response.headers["X-Error-Code"]
if error_code == "RepoNotFound": if error_code == "RepoNotFound":
raise RepositoryNotFoundError(f"404 Client Error: Repository Not Found for url: {request.url}") raise RepositoryNotFoundError(f"404 Client Error: Repository Not Found for url: {response.url}")
elif error_code == "EntryNotFound": elif error_code == "EntryNotFound":
raise EntryNotFoundError(f"404 Client Error: Entry Not Found for url: {request.url}") raise EntryNotFoundError(f"404 Client Error: Entry Not Found for url: {response.url}")
elif error_code == "RevisionNotFound": elif error_code == "RevisionNotFound":
raise RevisionNotFoundError(f"404 Client Error: Revision Not Found for url: {request.url}") raise RevisionNotFoundError(f"404 Client Error: Revision Not Found for url: {response.url}")
request.raise_for_status() if response.status_code == 401:
# The repo was not found and the user is not Authenticated
raise RepositoryNotFoundError(
f"401 Client Error: Repository not found for url: {response.url}. "
"If the repo is private, make sure you are authenticated."
)
response.raise_for_status()
def http_get(url: str, temp_file: BinaryIO, proxies=None, resume_size=0, headers: Optional[Dict[str, str]] = None): def http_get(url: str, temp_file: BinaryIO, proxies=None, resume_size=0, headers: Optional[Dict[str, str]] = None):
......
...@@ -88,7 +88,6 @@ class AutoConfigTest(unittest.TestCase): ...@@ -88,7 +88,6 @@ class AutoConfigTest(unittest.TestCase):
if "custom" in CONFIG_MAPPING._extra_content: if "custom" in CONFIG_MAPPING._extra_content:
del CONFIG_MAPPING._extra_content["custom"] del CONFIG_MAPPING._extra_content["custom"]
@unittest.skip("Temp bug in the Hub not returning RepoNotFound errors.")
def test_repo_not_found(self): def test_repo_not_found(self):
with self.assertRaisesRegex( with self.assertRaisesRegex(
EnvironmentError, "bert-base is not a local folder and is not a valid model identifier" EnvironmentError, "bert-base is not a local folder and is not a valid model identifier"
......
...@@ -76,7 +76,6 @@ class AutoFeatureExtractorTest(unittest.TestCase): ...@@ -76,7 +76,6 @@ class AutoFeatureExtractorTest(unittest.TestCase):
config = AutoFeatureExtractor.from_pretrained(SAMPLE_FEATURE_EXTRACTION_CONFIG) config = AutoFeatureExtractor.from_pretrained(SAMPLE_FEATURE_EXTRACTION_CONFIG)
self.assertIsInstance(config, Wav2Vec2FeatureExtractor) self.assertIsInstance(config, Wav2Vec2FeatureExtractor)
@unittest.skip("Temp bug in the Hub not returning RepoNotFound errors.")
def test_repo_not_found(self): def test_repo_not_found(self):
with self.assertRaisesRegex( with self.assertRaisesRegex(
EnvironmentError, "bert-base is not a local folder and is not a valid model identifier" EnvironmentError, "bert-base is not a local folder and is not a valid model identifier"
......
...@@ -328,7 +328,6 @@ class AutoModelTest(unittest.TestCase): ...@@ -328,7 +328,6 @@ class AutoModelTest(unittest.TestCase):
if CustomConfig in mapping._extra_content: if CustomConfig in mapping._extra_content:
del mapping._extra_content[CustomConfig] del mapping._extra_content[CustomConfig]
@unittest.skip("Temp bug in the Hub not returning RepoNotFound errors.")
def test_repo_not_found(self): def test_repo_not_found(self):
with self.assertRaisesRegex( with self.assertRaisesRegex(
EnvironmentError, "bert-base is not a local folder and is not a valid model identifier" EnvironmentError, "bert-base is not a local folder and is not a valid model identifier"
......
...@@ -77,7 +77,6 @@ class FlaxAutoModelTest(unittest.TestCase): ...@@ -77,7 +77,6 @@ class FlaxAutoModelTest(unittest.TestCase):
eval(**tokens).block_until_ready() eval(**tokens).block_until_ready()
@unittest.skip("Temp bug in the Hub not returning RepoNotFound errors.")
def test_repo_not_found(self): def test_repo_not_found(self):
with self.assertRaisesRegex( with self.assertRaisesRegex(
EnvironmentError, "bert-base is not a local folder and is not a valid model identifier" EnvironmentError, "bert-base is not a local folder and is not a valid model identifier"
......
...@@ -265,7 +265,6 @@ class TFAutoModelTest(unittest.TestCase): ...@@ -265,7 +265,6 @@ class TFAutoModelTest(unittest.TestCase):
if NewModelConfig in mapping._extra_content: if NewModelConfig in mapping._extra_content:
del mapping._extra_content[NewModelConfig] del mapping._extra_content[NewModelConfig]
@unittest.skip("Temp bug in the Hub not returning RepoNotFound errors.")
def test_repo_not_found(self): def test_repo_not_found(self):
with self.assertRaisesRegex( with self.assertRaisesRegex(
EnvironmentError, "bert-base is not a local folder and is not a valid model identifier" EnvironmentError, "bert-base is not a local folder and is not a valid model identifier"
......
...@@ -142,7 +142,6 @@ class AutoTokenizerTest(unittest.TestCase): ...@@ -142,7 +142,6 @@ class AutoTokenizerTest(unittest.TestCase):
self.assertEqual(tokenizer.model_max_length, 512) self.assertEqual(tokenizer.model_max_length, 512)
@unittest.skip("Temp bug in the Hub not returning RepoNotFound errors.")
@require_tokenizers @require_tokenizers
def test_tokenizer_identifier_non_existent(self): def test_tokenizer_identifier_non_existent(self):
for tokenizer_class in [BertTokenizer, BertTokenizerFast, AutoTokenizer]: for tokenizer_class in [BertTokenizer, BertTokenizerFast, AutoTokenizer]:
...@@ -330,7 +329,6 @@ class AutoTokenizerTest(unittest.TestCase): ...@@ -330,7 +329,6 @@ class AutoTokenizerTest(unittest.TestCase):
else: else:
self.assertEqual(tokenizer.__class__.__name__, "NewTokenizer") self.assertEqual(tokenizer.__class__.__name__, "NewTokenizer")
@unittest.skip("Temp bug in the Hub not returning RepoNotFound errors.")
def test_repo_not_found(self): def test_repo_not_found(self):
with self.assertRaisesRegex( with self.assertRaisesRegex(
EnvironmentError, "bert-base is not a local folder and is not a valid model identifier" EnvironmentError, "bert-base is not a local folder and is not a valid model identifier"
......
...@@ -99,13 +99,20 @@ class GetFromCacheTests(unittest.TestCase): ...@@ -99,13 +99,20 @@ class GetFromCacheTests(unittest.TestCase):
with self.assertRaisesRegex(EntryNotFoundError, "404 Client Error"): with self.assertRaisesRegex(EntryNotFoundError, "404 Client Error"):
_ = get_from_cache(url) _ = get_from_cache(url)
@unittest.skip("Temp bug in the Hub not returning RepoNotFound errors.") def test_model_not_found_not_authenticated(self):
def test_model_not_found(self): # Invalid model id.
# Invalid model file.
url = hf_bucket_url("bert-base", filename="pytorch_model.bin") url = hf_bucket_url("bert-base", filename="pytorch_model.bin")
with self.assertRaisesRegex(RepositoryNotFoundError, "404 Client Error"): with self.assertRaisesRegex(RepositoryNotFoundError, "401 Client Error"):
_ = get_from_cache(url) _ = get_from_cache(url)
@unittest.skip("No authentication when testing against prod")
def test_model_not_found_authenticated(self):
# Invalid model id.
url = hf_bucket_url("bert-base", filename="pytorch_model.bin")
with self.assertRaisesRegex(RepositoryNotFoundError, "404 Client Error"):
_ = get_from_cache(url, use_auth_token="hf_sometoken")
# ^ TODO - if we decide to unskip this: use a real / functional token
def test_revision_not_found(self): def test_revision_not_found(self):
# Valid file but missing revision # Valid file but missing revision
url = hf_bucket_url(MODEL_ID, filename=CONFIG_NAME, revision=REVISION_ID_INVALID) url = hf_bucket_url(MODEL_ID, filename=CONFIG_NAME, revision=REVISION_ID_INVALID)
...@@ -142,9 +149,8 @@ class GetFromCacheTests(unittest.TestCase): ...@@ -142,9 +149,8 @@ class GetFromCacheTests(unittest.TestCase):
self.assertIsNone(get_file_from_repo("bert-base-cased", "ahah.txt")) self.assertIsNone(get_file_from_repo("bert-base-cased", "ahah.txt"))
# The function raises if the repository does not exist. # The function raises if the repository does not exist.
# Uncomment when bug is fixed. with self.assertRaisesRegex(EnvironmentError, "is not a valid model identifier"):
# with self.assertRaisesRegex(EnvironmentError, "is not a valid model identifier"): get_file_from_repo("bert-base-case", "config.json")
# get_file_from_repo("bert-base-case", "config.json")
# The function raises if the revision does not exist. # The function raises if the revision does not exist.
with self.assertRaisesRegex(EnvironmentError, "is not a valid git identifier"): with self.assertRaisesRegex(EnvironmentError, "is not a valid git identifier"):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment