Unverified Commit 3aea38ce authored by Scruel Tao's avatar Scruel Tao Committed by GitHub
Browse files

fix: suppress `GatedRepoError` to use cache file (fix #28558). (#28566)

* fix: suppress `GatedRepoError` to use cache file (fix #28558).

* move condition_to_return parameter back to outside.
parent 708b19eb
......@@ -786,6 +786,7 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
"user_agent": user_agent,
"revision": revision,
"subfolder": subfolder,
"_raise_exceptions_for_gated_repo": False,
"_raise_exceptions_for_missing_entries": False,
"_commit_hash": commit_hash,
}
......
......@@ -2771,6 +2771,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
"user_agent": user_agent,
"revision": revision,
"subfolder": subfolder,
"_raise_exceptions_for_gated_repo": False,
"_raise_exceptions_for_missing_entries": False,
"_commit_hash": commit_hash,
}
......
......@@ -2938,6 +2938,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
token=token,
revision=revision,
subfolder=subfolder,
_raise_exceptions_for_gated_repo=False,
_raise_exceptions_for_missing_entries=False,
_raise_exceptions_for_connection_errors=False,
)
......@@ -3381,6 +3382,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
"user_agent": user_agent,
"revision": revision,
"subfolder": subfolder,
"_raise_exceptions_for_gated_repo": False,
"_raise_exceptions_for_missing_entries": False,
"_commit_hash": commit_hash,
}
......
......@@ -488,6 +488,7 @@ class _BaseAutoModelClass:
resolved_config_file = cached_file(
pretrained_model_name_or_path,
CONFIG_NAME,
_raise_exceptions_for_gated_repo=False,
_raise_exceptions_for_missing_entries=False,
_raise_exceptions_for_connection_errors=False,
**hub_kwargs,
......
......@@ -598,6 +598,7 @@ def get_tokenizer_config(
revision=revision,
local_files_only=local_files_only,
subfolder=subfolder,
_raise_exceptions_for_gated_repo=False,
_raise_exceptions_for_missing_entries=False,
_raise_exceptions_for_connection_errors=False,
_commit_hash=commit_hash,
......
......@@ -747,6 +747,7 @@ def pipeline(
resolved_config_file = cached_file(
pretrained_model_name_or_path,
CONFIG_NAME,
_raise_exceptions_for_gated_repo=False,
_raise_exceptions_for_missing_entries=False,
_raise_exceptions_for_connection_errors=False,
**hub_kwargs,
......
......@@ -1961,6 +1961,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
local_files_only=local_files_only,
subfolder=subfolder,
user_agent=user_agent,
_raise_exceptions_for_gated_repo=False,
_raise_exceptions_for_missing_entries=False,
_raise_exceptions_for_connection_errors=False,
_commit_hash=commit_hash,
......@@ -1997,6 +1998,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
user_agent=user_agent,
revision=revision,
subfolder=subfolder,
_raise_exceptions_for_gated_repo=False,
_raise_exceptions_for_missing_entries=False,
_raise_exceptions_for_connection_errors=False,
_commit_hash=commit_hash,
......
......@@ -229,6 +229,7 @@ class Tool:
TOOL_CONFIG_FILE,
token=token,
**hub_kwargs,
_raise_exceptions_for_gated_repo=False,
_raise_exceptions_for_missing_entries=False,
_raise_exceptions_for_connection_errors=False,
)
......@@ -239,6 +240,7 @@ class Tool:
CONFIG_NAME,
token=token,
**hub_kwargs,
_raise_exceptions_for_gated_repo=False,
_raise_exceptions_for_missing_entries=False,
_raise_exceptions_for_connection_errors=False,
)
......
......@@ -146,6 +146,16 @@ HUGGINGFACE_CO_PREFIX = HUGGINGFACE_CO_RESOLVE_ENDPOINT + "/{model_id}/resolve/{
HUGGINGFACE_CO_EXAMPLES_TELEMETRY = HUGGINGFACE_CO_RESOLVE_ENDPOINT + "/api/telemetry/examples"
def _get_cache_file_to_return(
path_or_repo_id: str, full_filename: str, cache_dir: Union[str, Path, None] = None, revision: Optional[str] = None
):
# We try to see if we have a cached version (not up to date):
resolved_file = try_to_load_from_cache(path_or_repo_id, full_filename, cache_dir=cache_dir, revision=revision)
if resolved_file is not None and resolved_file != _CACHED_NO_EXIST:
return resolved_file
return None
def is_remote_url(url_or_filename):
parsed = urlparse(url_or_filename)
return parsed.scheme in ("http", "https")
......@@ -266,6 +276,7 @@ def cached_file(
subfolder: str = "",
repo_type: Optional[str] = None,
user_agent: Optional[Union[str, Dict[str, str]]] = None,
_raise_exceptions_for_gated_repo: bool = True,
_raise_exceptions_for_missing_entries: bool = True,
_raise_exceptions_for_connection_errors: bool = True,
_commit_hash: Optional[str] = None,
......@@ -335,6 +346,8 @@ def cached_file(
token = use_auth_token
# Private arguments
# _raise_exceptions_for_gated_repo: if False, do not raise an exception for gated repo error but return
# None.
# _raise_exceptions_for_missing_entries: if False, do not raise an exception for missing entries but return
# None.
# _raise_exceptions_for_connection_errors: if False, do not raise an exception for connection errors but return
......@@ -397,6 +410,9 @@ def cached_file(
local_files_only=local_files_only,
)
except GatedRepoError as e:
resolved_file = _get_cache_file_to_return(path_or_repo_id, full_filename, cache_dir, revision)
if resolved_file is not None or not _raise_exceptions_for_gated_repo:
return resolved_file
raise EnvironmentError(
"You are trying to access a gated repo.\nMake sure to request access at "
f"https://huggingface.co/{path_or_repo_id} and pass a token having permission to this repo either "
......@@ -416,12 +432,13 @@ def cached_file(
f"'https://huggingface.co/{path_or_repo_id}' for available revisions."
) from e
except LocalEntryNotFoundError as e:
# We try to see if we have a cached version (not up to date):
resolved_file = try_to_load_from_cache(path_or_repo_id, full_filename, cache_dir=cache_dir, revision=revision)
if resolved_file is not None and resolved_file != _CACHED_NO_EXIST:
resolved_file = _get_cache_file_to_return(path_or_repo_id, full_filename, cache_dir, revision)
if (
resolved_file is not None
or not _raise_exceptions_for_missing_entries
or not _raise_exceptions_for_connection_errors
):
return resolved_file
if not _raise_exceptions_for_missing_entries or not _raise_exceptions_for_connection_errors:
return None
raise EnvironmentError(
f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this file, couldn't find it in the"
f" cached files and it looks like {path_or_repo_id} is not the path to a directory containing a file named"
......@@ -438,13 +455,9 @@ def cached_file(
f"'https://huggingface.co/{path_or_repo_id}/{revision}' for available files."
) from e
except HTTPError as err:
# First we try to see if we have a cached version (not up to date):
resolved_file = try_to_load_from_cache(path_or_repo_id, full_filename, cache_dir=cache_dir, revision=revision)
if resolved_file is not None and resolved_file != _CACHED_NO_EXIST:
resolved_file = _get_cache_file_to_return(path_or_repo_id, full_filename, cache_dir, revision)
if resolved_file is not None or not _raise_exceptions_for_connection_errors:
return resolved_file
if not _raise_exceptions_for_connection_errors:
return None
raise EnvironmentError(f"There was a specific connection error when trying to load {path_or_repo_id}:\n{err}")
except HFValidationError as e:
raise EnvironmentError(
......@@ -545,6 +558,7 @@ def get_file_from_repo(
revision=revision,
local_files_only=local_files_only,
subfolder=subfolder,
_raise_exceptions_for_gated_repo=False,
_raise_exceptions_for_missing_entries=False,
_raise_exceptions_for_connection_errors=False,
)
......
......@@ -96,6 +96,7 @@ def find_adapter_config_file(
local_files_only=local_files_only,
subfolder=subfolder,
_commit_hash=_commit_hash,
_raise_exceptions_for_gated_repo=False,
_raise_exceptions_for_missing_entries=False,
_raise_exceptions_for_connection_errors=False,
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment