Unverified Commit 6ac77534 authored by Sylvain Gugger, committed by GitHub

Refine errors for pretrained objects (#15261)

* Refine errors for pretrained objects

* PoC to avoid using get_list_of_files

* Adapt tests to use new errors

* Quality + Fix PoC

* Revert "PoC to avoid using get_list_of_files"

This reverts commit cb93b7cae8504ef837c2a7663cb7955e714f323e.

* Revert "Quality + Fix PoC"

This reverts commit 3ba6d0d4ca546708b31d355baa9e68ba9736508f.

* Fix doc

* Revert PoC

* Add feature extractors

* More tests and PT model

* Adapt error message

* Feature extractor tests

* TF model

* Flax model and test

* Merge flax auto tests

* Add tokenization

* Fix test
parent 80af1048
...@@ -25,10 +25,15 @@ from typing import Any, Dict, Optional, Tuple, Union ...@@ -25,10 +25,15 @@ from typing import Any, Dict, Optional, Tuple, Union
from packaging import version from packaging import version
from requests import HTTPError
from . import __version__ from . import __version__
from .file_utils import ( from .file_utils import (
CONFIG_NAME, CONFIG_NAME,
EntryNotFoundError,
PushToHubMixin, PushToHubMixin,
RepositoryNotFoundError,
RevisionNotFoundError,
cached_path, cached_path,
copy_func, copy_func,
get_list_of_files, get_list_of_files,
...@@ -520,8 +525,6 @@ class PretrainedConfig(PushToHubMixin): ...@@ -520,8 +525,6 @@ class PretrainedConfig(PushToHubMixin):
From a `pretrained_model_name_or_path`, resolve to a dictionary of parameters, to be used for instantiating a From a `pretrained_model_name_or_path`, resolve to a dictionary of parameters, to be used for instantiating a
[`PretrainedConfig`] using `from_dict`. [`PretrainedConfig`] using `from_dict`.
Parameters: Parameters:
pretrained_model_name_or_path (`str` or `os.PathLike`): pretrained_model_name_or_path (`str` or `os.PathLike`):
The identifier of the pre-trained checkpoint from which we want the dictionary of parameters. The identifier of the pre-trained checkpoint from which we want the dictionary of parameters.
...@@ -578,30 +581,51 @@ class PretrainedConfig(PushToHubMixin): ...@@ -578,30 +581,51 @@ class PretrainedConfig(PushToHubMixin):
use_auth_token=use_auth_token, use_auth_token=use_auth_token,
user_agent=user_agent, user_agent=user_agent,
) )
# Load config dict
config_dict = cls._dict_from_json_file(resolved_config_file)
except RepositoryNotFoundError as err:
logger.error(err)
raise EnvironmentError(
f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier listed on "
"'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a token having "
"permission to this repo with `use_auth_token` or log in with `huggingface-cli login` and pass "
"`use_auth_token=True`."
)
except RevisionNotFoundError as err:
logger.error(err)
raise EnvironmentError(
f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for this "
f"model name. Check the model page at 'https://huggingface.co/{pretrained_model_name_or_path}' for "
"available revisions."
)
except EntryNotFoundError as err:
logger.error(err)
raise EnvironmentError(
f"{pretrained_model_name_or_path} does not appear to have a file named {configuration_file}."
)
except HTTPError as err:
logger.error(err)
raise EnvironmentError(
"We couldn't connect to 'https://huggingface.co/' to load this model and it looks like "
f"{pretrained_model_name_or_path} is not the path to a directory conaining a {configuration_file} "
"file.\nCheckout your internet connection or see how to run the library in offline mode at "
"'https://huggingface.co/docs/transformers/installation#offline-mode'."
)
except EnvironmentError as err: except EnvironmentError as err:
logger.error(err) logger.error(err)
msg = ( raise EnvironmentError(
f"Can't load config for '{pretrained_model_name_or_path}'. Make sure that:\n\n" f"Can't load config for '{pretrained_model_name_or_path}'. If you were trying to load it from "
f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n" "'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
f" (make sure '{pretrained_model_name_or_path}' is not a path to a local directory with something else, in that case)\n\n" f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
f"- or '{pretrained_model_name_or_path}' is the correct path to a directory containing a {CONFIG_NAME} file\n\n" f"containing a {configuration_file} file"
) )
if revision is not None: try:
msg += f"- or '{revision}' is a valid git identifier (branch name, a tag name, or a commit id) that exists for this model name as listed on its model page on 'https://huggingface.co/models'\n\n" # Load config dict
config_dict = cls._dict_from_json_file(resolved_config_file)
raise EnvironmentError(msg)
except (json.JSONDecodeError, UnicodeDecodeError): except (json.JSONDecodeError, UnicodeDecodeError):
msg = ( raise EnvironmentError(
f"Couldn't reach server at '{config_file}' to download configuration file or " f"It looks like the config file at '{resolved_config_file}' is not a valid JSON file."
"configuration file is not a valid JSON file. "
f"Please check network or file content here: {resolved_config_file}."
) )
raise EnvironmentError(msg)
if resolved_config_file == config_file: if resolved_config_file == config_file:
logger.info(f"loading configuration file {config_file}") logger.info(f"loading configuration file {config_file}")
...@@ -842,9 +866,13 @@ def get_configuration_file( ...@@ -842,9 +866,13 @@ def get_configuration_file(
`str`: The configuration file to use. `str`: The configuration file to use.
""" """
# Inspect all files from the repo/folder. # Inspect all files from the repo/folder.
all_files = get_list_of_files( try:
path_or_repo, revision=revision, use_auth_token=use_auth_token, local_files_only=local_files_only all_files = get_list_of_files(
) path_or_repo, revision=revision, use_auth_token=use_auth_token, local_files_only=local_files_only
)
except Exception:
return FULL_CONFIGURATION_FILE
configuration_files_map = {} configuration_files_map = {}
for file_name in all_files: for file_name in all_files:
search = _re_configuration_file.search(file_name) search = _re_configuration_file.search(file_name)
......
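For reference, here is a minimal usage sketch (not part of the diff) of what these refined config errors look like from the user side. It mirrors the tests added further down; `bert-base-cased` simply stands in for any existing public repo, and network access is assumed:

from transformers import AutoConfig

# A repository that does not exist on the Hub now surfaces the RepositoryNotFoundError-based message.
try:
    AutoConfig.from_pretrained("bert-base")
except EnvironmentError as err:
    print(err)  # "bert-base is not a local folder and is not a valid model identifier listed on ..."

# An existing repository queried with a bogus revision points the user at the model page instead.
try:
    AutoConfig.from_pretrained("bert-base-cased", revision="aaaaaa")
except EnvironmentError as err:
    print(err)  # "aaaaaa is not a valid git identifier (branch name, tag name or commit id) ..."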
...@@ -24,8 +24,13 @@ from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union ...@@ -24,8 +24,13 @@ from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
import numpy as np import numpy as np
from requests import HTTPError
from .file_utils import ( from .file_utils import (
FEATURE_EXTRACTOR_NAME, FEATURE_EXTRACTOR_NAME,
EntryNotFoundError,
RepositoryNotFoundError,
RevisionNotFoundError,
TensorType, TensorType,
_is_jax, _is_jax,
_is_numpy, _is_numpy,
...@@ -374,28 +379,54 @@ class FeatureExtractionMixin: ...@@ -374,28 +379,54 @@ class FeatureExtractionMixin:
use_auth_token=use_auth_token, use_auth_token=use_auth_token,
user_agent=user_agent, user_agent=user_agent,
) )
# Load feature_extractor dict
with open(resolved_feature_extractor_file, "r", encoding="utf-8") as reader:
text = reader.read()
feature_extractor_dict = json.loads(text)
except RepositoryNotFoundError as err:
logger.error(err)
raise EnvironmentError(
f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier listed on "
"'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a token having "
"permission to this repo with `use_auth_token` or log in with `huggingface-cli login` and pass "
"`use_auth_token=True`."
)
except RevisionNotFoundError as err:
logger.error(err)
raise EnvironmentError(
f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for this "
f"model name. Check the model page at 'https://huggingface.co/{pretrained_model_name_or_path}' for "
"available revisions."
)
except EntryNotFoundError as err:
logger.error(err)
raise EnvironmentError(
f"{pretrained_model_name_or_path} does not appear to have a file named {FEATURE_EXTRACTOR_NAME}."
)
except HTTPError as err:
logger.error(err)
raise EnvironmentError(
"We couldn't connect to 'https://huggingface.co/' to load this model and it looks like "
f"{pretrained_model_name_or_path} is not the path to a directory conaining a "
f"{FEATURE_EXTRACTOR_NAME} file.\nCheckout your internet connection or see how to run the library in "
"offline mode at 'https://huggingface.co/docs/transformers/installation#offline-mode'."
)
except EnvironmentError as err: except EnvironmentError as err:
logger.error(err) logger.error(err)
msg = ( raise EnvironmentError(
f"Can't load feature extractor for '{pretrained_model_name_or_path}'. Make sure that:\n\n" f"Can't load feature extractor for '{pretrained_model_name_or_path}'. If you were trying to load it "
f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n" "from 'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
f" (make sure '{pretrained_model_name_or_path}' is not a path to a local directory with something else, in that case)\n\n" f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
f"- or '{pretrained_model_name_or_path}' is the correct path to a directory containing a {FEATURE_EXTRACTOR_NAME} file\n\n" f"containing a {FEATURE_EXTRACTOR_NAME} file"
) )
raise EnvironmentError(msg)
try:
# Load feature_extractor dict
with open(resolved_feature_extractor_file, "r", encoding="utf-8") as reader:
text = reader.read()
feature_extractor_dict = json.loads(text)
except json.JSONDecodeError: except json.JSONDecodeError:
msg = ( raise EnvironmentError(
f"Couldn't reach server at '{feature_extractor_file}' to download feature extractor configuration file or " f"It looks like the config file at '{resolved_feature_extractor_file}' is not a valid JSON file."
"feature extractor configuration file is not a valid JSON file. "
f"Please check network or file content here: {resolved_feature_extractor_file}."
) )
raise EnvironmentError(msg)
if resolved_feature_extractor_file == feature_extractor_file: if resolved_feature_extractor_file == feature_extractor_file:
logger.info(f"loading feature extractor configuration file {feature_extractor_file}") logger.info(f"loading feature extractor configuration file {feature_extractor_file}")
......
...@@ -1900,6 +1900,37 @@ def http_user_agent(user_agent: Union[Dict, str, None] = None) -> str: ...@@ -1900,6 +1900,37 @@ def http_user_agent(user_agent: Union[Dict, str, None] = None) -> str:
return ua return ua
class RepositoryNotFoundError(HTTPError):
"""
Raised when trying to access a hf.co URL with an invalid repository name, or with a private repo name the user does
not have access to.
"""
class EntryNotFoundError(HTTPError):
"""Raised when trying to access a hf.co URL with a valid repository and revision but an invalid filename."""
class RevisionNotFoundError(HTTPError):
"""Raised when trying to access a hf.co URL with a valid repository but an invalid revision."""
def _raise_for_status(request):
"""
Internal version of `request.raise_for_status()` that will refine a potential HTTPError.
"""
if "X-Error-Code" in request.headers:
error_code = request.headers["X-Error-Code"]
if error_code == "RepoNotFound":
raise RepositoryNotFoundError(f"404 Client Error: Repository Not Found for url: {request.url}")
elif error_code == "EntryNotFound":
raise EntryNotFoundError(f"404 Client Error: Entry Not Found for url: {request.url}")
elif error_code == "RevisionNotFound":
raise RevisionNotFoundError((f"404 Client Error: Revision Not Found for url: {request.url}"))
request.raise_for_status()
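To make the dispatch above concrete, here is a small self-contained sketch (not part of the diff) that feeds `_raise_for_status` a hypothetical stand-in for a `requests.Response` carrying the Hub's `X-Error-Code` header; the `_FakeResponse` class exists only for this illustration:

from requests import HTTPError

from transformers.file_utils import RepositoryNotFoundError, _raise_for_status

class _FakeResponse:
    # Hypothetical stand-in for requests.Response, just enough for _raise_for_status.
    def __init__(self, url, error_code=None):
        self.url = url
        self.headers = {"X-Error-Code": error_code} if error_code is not None else {}

    def raise_for_status(self):
        raise HTTPError(f"404 Client Error for url: {self.url}")

try:
    _raise_for_status(_FakeResponse("https://huggingface.co/api/models/bert-base", error_code="RepoNotFound"))
except RepositoryNotFoundError as err:
    print(err)  # 404 Client Error: Repository Not Found for url: https://huggingface.co/api/models/bert-base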
def http_get(url: str, temp_file: BinaryIO, proxies=None, resume_size=0, headers: Optional[Dict[str, str]] = None): def http_get(url: str, temp_file: BinaryIO, proxies=None, resume_size=0, headers: Optional[Dict[str, str]] = None):
""" """
Download remote file. Do not gobble up errors. Download remote file. Do not gobble up errors.
...@@ -1908,7 +1939,7 @@ def http_get(url: str, temp_file: BinaryIO, proxies=None, resume_size=0, headers ...@@ -1908,7 +1939,7 @@ def http_get(url: str, temp_file: BinaryIO, proxies=None, resume_size=0, headers
if resume_size > 0: if resume_size > 0:
headers["Range"] = f"bytes={resume_size}-" headers["Range"] = f"bytes={resume_size}-"
r = requests.get(url, stream=True, proxies=proxies, headers=headers) r = requests.get(url, stream=True, proxies=proxies, headers=headers)
r.raise_for_status() _raise_for_status(r)
content_length = r.headers.get("Content-Length") content_length = r.headers.get("Content-Length")
total = resume_size + int(content_length) if content_length is not None else None total = resume_size + int(content_length) if content_length is not None else None
# `tqdm` behavior is determined by `utils.logging.is_progress_bar_enabled()` # `tqdm` behavior is determined by `utils.logging.is_progress_bar_enabled()`
...@@ -1970,7 +2001,7 @@ def get_from_cache( ...@@ -1970,7 +2001,7 @@ def get_from_cache(
if not local_files_only: if not local_files_only:
try: try:
r = requests.head(url, headers=headers, allow_redirects=False, proxies=proxies, timeout=etag_timeout) r = requests.head(url, headers=headers, allow_redirects=False, proxies=proxies, timeout=etag_timeout)
r.raise_for_status() _raise_for_status(r)
etag = r.headers.get("X-Linked-Etag") or r.headers.get("ETag") etag = r.headers.get("X-Linked-Etag") or r.headers.get("ETag")
# We favor a custom header indicating the etag of the linked resource, and # We favor a custom header indicating the etag of the linked resource, and
# we fallback to the regular etag header. # we fallback to the regular etag header.
...@@ -2081,6 +2112,56 @@ def get_from_cache( ...@@ -2081,6 +2112,56 @@ def get_from_cache(
return cache_path return cache_path
def has_file(
path_or_repo: Union[str, os.PathLike],
filename: str,
revision: Optional[str] = None,
mirror: Optional[str] = None,
proxies: Optional[Dict[str, str]] = None,
use_auth_token: Optional[Union[bool, str]] = None,
):
"""
Checks if a repo contains a given file without downloading it. Works for remote repos and local folders.
<Tip warning={false}>
This function will raise an error if the repository `path_or_repo` is not valid or if `revision` does not exist for
this repo, but will return False for regular connection errors.
</Tip>
"""
if os.path.isdir(path_or_repo):
return os.path.isfile(os.path.join(path_or_repo, filename))
url = hf_bucket_url(path_or_repo, filename=filename, revision=revision, mirror=mirror)
headers = {"user-agent": http_user_agent()}
if isinstance(use_auth_token, str):
headers["authorization"] = f"Bearer {use_auth_token}"
elif use_auth_token:
token = HfFolder.get_token()
if token is None:
raise EnvironmentError("You specified use_auth_token=True, but a huggingface token was not found.")
headers["authorization"] = f"Bearer {token}"
r = requests.head(url, headers=headers, allow_redirects=False, proxies=proxies, timeout=10)
try:
_raise_for_status(r)
return True
except RepositoryNotFoundError as e:
logger.error(e)
raise EnvironmentError(f"{path_or_repo} is not a local folder or a valid repository name on 'https://hf.co'.")
except RevisionNotFoundError as e:
logger.error(e)
raise EnvironmentError(
f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for this "
"model name. Check the model page at 'https://huggingface.co/{path_or_repo}' for available revisions."
)
except requests.HTTPError:
# We return false for EntryNotFoundError (logical) as well as any connection error.
return False
def get_list_of_files( def get_list_of_files(
path_or_repo: Union[str, os.PathLike], path_or_repo: Union[str, os.PathLike],
revision: Optional[str] = None, revision: Optional[str] = None,
......
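A short usage sketch (not part of the diff) for the new `has_file` helper, mirroring the `test_has_file` test added to the file_utils tests below; it assumes network access to the Hub:

from transformers.file_utils import FLAX_WEIGHTS_NAME, TF2_WEIGHTS_NAME, WEIGHTS_NAME, has_file

# The tiny test repo only ships PyTorch weights, so only WEIGHTS_NAME is found.
print(has_file("hf-internal-testing/tiny-bert-pt-only", WEIGHTS_NAME))      # True
print(has_file("hf-internal-testing/tiny-bert-pt-only", TF2_WEIGHTS_NAME))  # False
print(has_file("hf-internal-testing/tiny-bert-pt-only", FLAX_WEIGHTS_NAME)) # False

# An invalid repository name still raises, as stated in the docstring above.
has_file("bert-base", WEIGHTS_NAME)  # raises EnvironmentError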
...@@ -26,16 +26,21 @@ from flax.core.frozen_dict import FrozenDict, unfreeze ...@@ -26,16 +26,21 @@ from flax.core.frozen_dict import FrozenDict, unfreeze
from flax.serialization import from_bytes, to_bytes from flax.serialization import from_bytes, to_bytes
from flax.traverse_util import flatten_dict, unflatten_dict from flax.traverse_util import flatten_dict, unflatten_dict
from jax.random import PRNGKey from jax.random import PRNGKey
from requests import HTTPError
from .configuration_utils import PretrainedConfig from .configuration_utils import PretrainedConfig
from .file_utils import ( from .file_utils import (
FLAX_WEIGHTS_NAME, FLAX_WEIGHTS_NAME,
WEIGHTS_NAME, WEIGHTS_NAME,
EntryNotFoundError,
PushToHubMixin, PushToHubMixin,
RepositoryNotFoundError,
RevisionNotFoundError,
add_code_sample_docstrings, add_code_sample_docstrings,
add_start_docstrings_to_model_forward, add_start_docstrings_to_model_forward,
cached_path, cached_path,
copy_func, copy_func,
has_file,
hf_bucket_url, hf_bucket_url,
is_offline_mode, is_offline_mode,
is_remote_url, is_remote_url,
...@@ -450,17 +455,25 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin): ...@@ -450,17 +455,25 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
elif os.path.isfile(os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)): elif os.path.isfile(os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)):
# Load from a Flax checkpoint # Load from a Flax checkpoint
archive_file = os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME) archive_file = os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)
# At this stage we don't have a weight file so we will raise an error.
elif os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
raise EnvironmentError(
f"Error no file named {FLAX_WEIGHTS_NAME} found in directory {pretrained_model_name_or_path} "
"but there is a file for PyTorch weights. Use `from_pt=True` to load this model from those "
"weights."
)
else: else:
raise EnvironmentError( raise EnvironmentError(
f"Error no file named {[FLAX_WEIGHTS_NAME, WEIGHTS_NAME]} found in directory " f"Error no file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME} found in directory "
f"{pretrained_model_name_or_path} or `from_pt` set to False" f"{pretrained_model_name_or_path}."
) )
elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path): elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
archive_file = pretrained_model_name_or_path archive_file = pretrained_model_name_or_path
else: else:
filename = WEIGHTS_NAME if from_pt else FLAX_WEIGHTS_NAME
archive_file = hf_bucket_url( archive_file = hf_bucket_url(
pretrained_model_name_or_path, pretrained_model_name_or_path,
filename=WEIGHTS_NAME if from_pt else FLAX_WEIGHTS_NAME, filename=filename,
revision=revision, revision=revision,
) )
...@@ -476,15 +489,59 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin): ...@@ -476,15 +489,59 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
use_auth_token=use_auth_token, use_auth_token=use_auth_token,
user_agent=user_agent, user_agent=user_agent,
) )
except RepositoryNotFoundError as err:
logger.error(err)
raise EnvironmentError(
f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier "
"listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a "
"token having permission to this repo with `use_auth_token` or log in with `huggingface-cli "
"login` and pass `use_auth_token=True`."
)
except RevisionNotFoundError as err:
logger.error(err)
raise EnvironmentError(
f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for "
"this model name. Check the model page at "
f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
)
except EntryNotFoundError as err:
logger.error(err)
if filename == FLAX_WEIGHTS_NAME:
has_file_kwargs = {"revision": revision, "proxies": proxies, "use_auth_token": use_auth_token}
if has_file(pretrained_model_name_or_path, WEIGHTS_NAME, **has_file_kwargs):
raise EnvironmentError(
f"{pretrained_model_name_or_path} does not appear to have a file named {FLAX_WEIGHTS_NAME} "
"but there is a file for PyTorch weights. Use `from_pt=True` to load this model from "
"those weights."
)
else:
logger.error(err)
raise EnvironmentError(
f"{pretrained_model_name_or_path} does not appear to have a file named {FLAX_WEIGHTS_NAME} "
f"or {WEIGHTS_NAME}."
)
else:
raise EnvironmentError(
f"{pretrained_model_name_or_path} does not appear to have a file named {filename}."
)
except HTTPError as err:
logger.error(err)
raise EnvironmentError(
"We couldn't connect to 'https://huggingface.co/' to load this model and it looks like "
f"{pretrained_model_name_or_path} is not the path to a directory conaining a a file named "
f"{FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME}.\n"
"Checkout your internet connection or see how to run the library in offline mode at "
"'https://huggingface.co/docs/transformers/installation#offline-mode'."
)
except EnvironmentError as err: except EnvironmentError as err:
logger.error(err) logger.error(err)
msg = ( raise EnvironmentError(
f"Can't load weights for '{pretrained_model_name_or_path}'. Make sure that:\n\n" f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it from "
f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n" "'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
f" (make sure '{pretrained_model_name_or_path}' is not a path to a local directory with something else, in that case)\n\n" f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
f"- or '{pretrained_model_name_or_path}' is the correct path to a directory containing a file named {WEIGHTS_NAME}.\n\n" f"containing a file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME}."
) )
raise EnvironmentError(msg)
if resolved_archive_file == archive_file: if resolved_archive_file == archive_file:
logger.info(f"loading weights file {archive_file}") logger.info(f"loading weights file {archive_file}")
......
...@@ -32,16 +32,21 @@ from tensorflow.python.keras.engine.keras_tensor import KerasTensor ...@@ -32,16 +32,21 @@ from tensorflow.python.keras.engine.keras_tensor import KerasTensor
from tensorflow.python.keras.saving import hdf5_format from tensorflow.python.keras.saving import hdf5_format
from huggingface_hub import Repository, list_repo_files from huggingface_hub import Repository, list_repo_files
from requests import HTTPError
from .configuration_utils import PretrainedConfig from .configuration_utils import PretrainedConfig
from .file_utils import ( from .file_utils import (
DUMMY_INPUTS, DUMMY_INPUTS,
TF2_WEIGHTS_NAME, TF2_WEIGHTS_NAME,
WEIGHTS_NAME, WEIGHTS_NAME,
EntryNotFoundError,
ModelOutput, ModelOutput,
PushToHubMixin, PushToHubMixin,
RepositoryNotFoundError,
RevisionNotFoundError,
cached_path, cached_path,
copy_func, copy_func,
has_file,
hf_bucket_url, hf_bucket_url,
is_offline_mode, is_offline_mode,
is_remote_url, is_remote_url,
...@@ -1542,19 +1547,27 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu ...@@ -1542,19 +1547,27 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
elif os.path.isfile(os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)): elif os.path.isfile(os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)):
# Load from a TF 2.0 checkpoint # Load from a TF 2.0 checkpoint
archive_file = os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME) archive_file = os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)
# At this stage we don't have a weight file so we will raise an error.
elif os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
raise EnvironmentError(
f"Error no file named {TF2_WEIGHTS_NAME} found in directory {pretrained_model_name_or_path} "
"but there is a file for PyTorch weights. Use `from_pt=True` to load this model from those "
"weights."
)
else: else:
raise EnvironmentError( raise EnvironmentError(
f"Error no file named {[WEIGHTS_NAME, TF2_WEIGHTS_NAME]} found in directory " f"Error no file named {TF2_WEIGHTS_NAME} or {WEIGHTS_NAME} found in directory "
f"{pretrained_model_name_or_path} or `from_pt` set to False" f"{pretrained_model_name_or_path}."
) )
elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path): elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
archive_file = pretrained_model_name_or_path archive_file = pretrained_model_name_or_path
elif os.path.isfile(pretrained_model_name_or_path + ".index"): elif os.path.isfile(pretrained_model_name_or_path + ".index"):
archive_file = pretrained_model_name_or_path + ".index" archive_file = pretrained_model_name_or_path + ".index"
else: else:
filename = WEIGHTS_NAME if from_pt else TF2_WEIGHTS_NAME
archive_file = hf_bucket_url( archive_file = hf_bucket_url(
pretrained_model_name_or_path, pretrained_model_name_or_path,
filename=(WEIGHTS_NAME if from_pt else TF2_WEIGHTS_NAME), filename=filename,
revision=revision, revision=revision,
mirror=mirror, mirror=mirror,
) )
...@@ -1571,15 +1584,65 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu ...@@ -1571,15 +1584,65 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
use_auth_token=use_auth_token, use_auth_token=use_auth_token,
user_agent=user_agent, user_agent=user_agent,
) )
except RepositoryNotFoundError as err:
logger.error(err)
raise EnvironmentError(
f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier "
"listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a "
"token having permission to this repo with `use_auth_token` or log in with `huggingface-cli "
"login` and pass `use_auth_token=True`."
)
except RevisionNotFoundError as err:
logger.error(err)
raise EnvironmentError(
f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for "
"this model name. Check the model page at "
f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
)
except EntryNotFoundError as err:
logger.error(err)
if filename == TF2_WEIGHTS_NAME:
has_file_kwargs = {
"revision": revision,
"mirror": mirror,
"proxies": proxies,
"use_auth_token": use_auth_token,
}
if has_file(pretrained_model_name_or_path, WEIGHTS_NAME, **has_file_kwargs):
raise EnvironmentError(
f"{pretrained_model_name_or_path} does not appear to have a file named {TF2_WEIGHTS_NAME} "
"but there is a file for PyTorch weights. Use `from_pt=True` to load this model from "
"those weights."
)
else:
logger.error(err)
raise EnvironmentError(
f"{pretrained_model_name_or_path} does not appear to have a file named {TF2_WEIGHTS_NAME} "
f"or {WEIGHTS_NAME}."
)
else:
raise EnvironmentError(
f"{pretrained_model_name_or_path} does not appear to have a file named {filename}."
)
except HTTPError as err:
logger.error(err)
raise EnvironmentError(
"We couldn't connect to 'https://huggingface.co/' to load this model and it looks like "
f"{pretrained_model_name_or_path} is not the path to a directory conaining a a file named "
f"{TF2_WEIGHTS_NAME} or {WEIGHTS_NAME}.\n"
"Checkout your internet connection or see how to run the library in offline mode at "
"'https://huggingface.co/docs/transformers/installation#offline-mode'."
)
except EnvironmentError as err: except EnvironmentError as err:
logger.error(err) logger.error(err)
msg = ( raise EnvironmentError(
f"Can't load weights for '{pretrained_model_name_or_path}'. Make sure that:\n\n" f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it from "
f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n" "'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
f" (make sure '{pretrained_model_name_or_path}' is not a path to a local directory with something else, in that case)\n\n" f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
f"- or '{pretrained_model_name_or_path}' is the correct path to a directory containing a file named one of {TF2_WEIGHTS_NAME}, {WEIGHTS_NAME}.\n\n" f"containing a file named {TF2_WEIGHTS_NAME} or {WEIGHTS_NAME}."
) )
raise EnvironmentError(msg)
if resolved_archive_file == archive_file: if resolved_archive_file == archive_file:
logger.info(f"loading weights file {archive_file}") logger.info(f"loading weights file {archive_file}")
else: else:
......
...@@ -27,6 +27,8 @@ from packaging import version ...@@ -27,6 +27,8 @@ from packaging import version
from torch import Tensor, device, nn from torch import Tensor, device, nn
from torch.nn import CrossEntropyLoss from torch.nn import CrossEntropyLoss
from requests import HTTPError
from .activations import get_activation from .activations import get_activation
from .configuration_utils import PretrainedConfig from .configuration_utils import PretrainedConfig
from .deepspeed import deepspeed_config, is_deepspeed_zero3_enabled from .deepspeed import deepspeed_config, is_deepspeed_zero3_enabled
...@@ -36,10 +38,14 @@ from .file_utils import ( ...@@ -36,10 +38,14 @@ from .file_utils import (
TF2_WEIGHTS_NAME, TF2_WEIGHTS_NAME,
TF_WEIGHTS_NAME, TF_WEIGHTS_NAME,
WEIGHTS_NAME, WEIGHTS_NAME,
EntryNotFoundError,
ModelOutput, ModelOutput,
PushToHubMixin, PushToHubMixin,
RepositoryNotFoundError,
RevisionNotFoundError,
cached_path, cached_path,
copy_func, copy_func,
has_file,
hf_bucket_url, hf_bucket_url,
is_offline_mode, is_offline_mode,
is_remote_url, is_remote_url,
...@@ -1292,10 +1298,25 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix ...@@ -1292,10 +1298,25 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
elif os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)): elif os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
# Load from a PyTorch checkpoint # Load from a PyTorch checkpoint
archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME) archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
# At this stage we don't have a weight file so we will raise an error.
elif os.path.isfile(
os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME + ".index")
) or os.path.isfile(os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)):
raise EnvironmentError(
f"Error no file named {WEIGHTS_NAME} found in directory {pretrained_model_name_or_path} but "
"there is a file for TensorFlow weights. Use `from_tf=True` to load this model from those "
"weights."
)
elif os.path.isfile(os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)):
raise EnvironmentError(
f"Error no file named {WEIGHTS_NAME} found in directory {pretrained_model_name_or_path} but "
"there is a file for Flax weights. Use `from_flax=True` to load this model from those "
"weights."
)
else: else:
raise EnvironmentError( raise EnvironmentError(
f"Error no file named {[WEIGHTS_NAME, TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME + '.index', FLAX_WEIGHTS_NAME]} found in " f"Error no file named {WEIGHTS_NAME}, {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME + '.index'} or "
f"directory {pretrained_model_name_or_path} or `from_tf` and `from_flax` set to False." f"{FLAX_WEIGHTS_NAME} found in directory {pretrained_model_name_or_path}."
) )
elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path): elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
archive_file = pretrained_model_name_or_path archive_file = pretrained_model_name_or_path
...@@ -1334,20 +1355,72 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix ...@@ -1334,20 +1355,72 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
use_auth_token=use_auth_token, use_auth_token=use_auth_token,
user_agent=user_agent, user_agent=user_agent,
) )
except RepositoryNotFoundError as err:
logger.error(err)
raise EnvironmentError(
f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier "
"listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a "
"token having permission to this repo with `use_auth_token` or log in with `huggingface-cli "
"login` and pass `use_auth_token=True`."
)
except RevisionNotFoundError as err:
logger.error(err)
raise EnvironmentError(
f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for "
"this model name. Check the model page at "
f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
)
except EntryNotFoundError as err:
logger.error(err)
if filename == WEIGHTS_NAME:
has_file_kwargs = {
"revision": revision,
"mirror": mirror,
"proxies": proxies,
"use_auth_token": use_auth_token,
}
if has_file(pretrained_model_name_or_path, TF2_WEIGHTS_NAME, **has_file_kwargs):
raise EnvironmentError(
f"{pretrained_model_name_or_path} does not appear to have a file named {WEIGHTS_NAME} but "
"there is a file for TensorFlow weights. Use `from_tf=True` to load this model from those "
"weights."
)
elif has_file(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME, **has_file_kwargs):
raise EnvironmentError(
f"{pretrained_model_name_or_path} does not appear to have a file named {WEIGHTS_NAME} but "
"there is a file for Flax weights. Use `from_flax=True` to load this model from those "
"weights."
)
else:
logger.error(err)
raise EnvironmentError(
f"{pretrained_model_name_or_path} does not appear to have a file named {WEIGHTS_NAME}, "
f"{TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME} or {FLAX_WEIGHTS_NAME}."
)
else:
raise EnvironmentError(
f"{pretrained_model_name_or_path} does not appear to have a file named {filename}."
)
except HTTPError as err:
logger.error(err)
raise EnvironmentError(
"We couldn't connect to 'https://huggingface.co/' to load this model and it looks like "
f"{pretrained_model_name_or_path} is not the path to a directory conaining a a file named "
f"{WEIGHTS_NAME}, {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME} or {FLAX_WEIGHTS_NAME}.\n"
"Checkout your internet connection or see how to run the library in offline mode at "
"'https://huggingface.co/docs/transformers/installation#offline-mode'."
)
except EnvironmentError as err: except EnvironmentError as err:
logger.error(err) logger.error(err)
msg = ( raise EnvironmentError(
f"Can't load weights for '{pretrained_model_name_or_path}'. Make sure that:\n\n" f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it from "
f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n" "'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
f" (make sure '{pretrained_model_name_or_path}' is not a path to a local directory with something else, in that case)\n\n" f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
f"- or '{pretrained_model_name_or_path}' is the correct path to a directory containing a file named one of {WEIGHTS_NAME}, {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME}\n\n" f"containing a file named {WEIGHTS_NAME}, {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME} or "
f"{FLAX_WEIGHTS_NAME}."
) )
if revision is not None:
msg += f"- or '{revision}' is a valid git identifier (branch name, a tag name, or a commit id) that exists for this model name as listed on its model page on 'https://huggingface.co/models'\n\n"
raise EnvironmentError(msg)
if resolved_archive_file == archive_file: if resolved_archive_file == archive_file:
logger.info(f"loading weights file {archive_file}") logger.info(f"loading weights file {archive_file}")
else: else:
......
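A sketch (not part of the diff) of the new cross-framework suggestions, using the tiny test repos referenced in the PyTorch auto-model tests below; actually loading with `from_tf=True`/`from_flax=True` additionally requires TensorFlow/Flax to be installed:

from transformers import AutoModel

try:
    AutoModel.from_pretrained("hf-internal-testing/tiny-bert-tf-only")
except EnvironmentError as err:
    # "... does not appear to have a file named pytorch_model.bin but there is a file for TensorFlow
    # weights. Use `from_tf=True` to load this model from those weights."
    print(err)

try:
    AutoModel.from_pretrained("hf-internal-testing/tiny-bert-flax-only")
except EnvironmentError as err:
    print(err)  # same idea, suggesting `from_flax=True` for the Flax-only checkpoint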
...@@ -18,13 +18,13 @@ import importlib ...@@ -18,13 +18,13 @@ import importlib
import json import json
import os import os
from collections import OrderedDict from collections import OrderedDict
from pathlib import Path
from typing import TYPE_CHECKING, Dict, Optional, Tuple, Union from typing import TYPE_CHECKING, Dict, Optional, Tuple, Union
from ...configuration_utils import PretrainedConfig from ...configuration_utils import PretrainedConfig
from ...file_utils import ( from ...file_utils import (
RepositoryNotFoundError,
RevisionNotFoundError,
cached_path, cached_path,
get_list_of_files,
hf_bucket_url, hf_bucket_url,
is_offline_mode, is_offline_mode,
is_sentencepiece_available, is_sentencepiece_available,
...@@ -333,16 +333,6 @@ def get_tokenizer_config( ...@@ -333,16 +333,6 @@ def get_tokenizer_config(
logger.info("Offline mode: forcing local_files_only=True") logger.info("Offline mode: forcing local_files_only=True")
local_files_only = True local_files_only = True
# Will raise a ValueError if `pretrained_model_name_or_path` is not a valid path or model identifier
repo_files = get_list_of_files(
pretrained_model_name_or_path,
revision=revision,
use_auth_token=use_auth_token,
local_files_only=local_files_only,
)
if TOKENIZER_CONFIG_FILE not in [Path(f).name for f in repo_files]:
return {}
pretrained_model_name_or_path = str(pretrained_model_name_or_path) pretrained_model_name_or_path = str(pretrained_model_name_or_path)
if os.path.isdir(pretrained_model_name_or_path): if os.path.isdir(pretrained_model_name_or_path):
config_file = os.path.join(pretrained_model_name_or_path, TOKENIZER_CONFIG_FILE) config_file = os.path.join(pretrained_model_name_or_path, TOKENIZER_CONFIG_FILE)
...@@ -363,6 +353,21 @@ def get_tokenizer_config( ...@@ -363,6 +353,21 @@ def get_tokenizer_config(
use_auth_token=use_auth_token, use_auth_token=use_auth_token,
) )
except RepositoryNotFoundError as err:
logger.error(err)
raise EnvironmentError(
f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier "
"listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to "
"pass a token having permission to this repo with `use_auth_token` or log in with "
"`huggingface-cli login` and pass `use_auth_token=True`."
)
except RevisionNotFoundError as err:
logger.error(err)
raise EnvironmentError(
f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists "
"for this model name. Check the model page at "
f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
)
except EnvironmentError: except EnvironmentError:
logger.info("Could not locate the tokenizer configuration file, will try to use the model config instead.") logger.info("Could not locate the tokenizer configuration file, will try to use the model config instead.")
return {} return {}
......
...@@ -31,13 +31,16 @@ from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Sequenc ...@@ -31,13 +31,16 @@ from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Sequenc
import numpy as np import numpy as np
from packaging import version from packaging import version
import requests from requests import HTTPError
from . import __version__ from . import __version__
from .file_utils import ( from .file_utils import (
EntryNotFoundError,
ExplicitEnum, ExplicitEnum,
PaddingStrategy, PaddingStrategy,
PushToHubMixin, PushToHubMixin,
RepositoryNotFoundError,
RevisionNotFoundError,
TensorType, TensorType,
_is_jax, _is_jax,
_is_numpy, _is_numpy,
...@@ -1704,9 +1707,28 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin): ...@@ -1704,9 +1707,28 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
else: else:
raise error raise error
except requests.exceptions.HTTPError as err: except RepositoryNotFoundError as err:
logger.error(err)
raise EnvironmentError(
f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier "
"listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to "
"pass a token having permission to this repo with `use_auth_token` or log in with "
"`huggingface-cli login` and pass `use_auth_token=True`."
)
except RevisionNotFoundError as err:
logger.error(err)
raise EnvironmentError(
f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists "
"for this model name. Check the model page at "
f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
)
except EntryNotFoundError:
logger.debug(f"{pretrained_model_name_or_path} does not contain a file named {file_path}.")
resolved_vocab_files[file_id] = None
except HTTPError as err:
if "404 Client Error" in str(err): if "404 Client Error" in str(err):
logger.debug(err) logger.debug(f"Connection problem to access {file_path}.")
resolved_vocab_files[file_id] = None resolved_vocab_files[file_id] = None
else: else:
raise err raise err
...@@ -1718,18 +1740,13 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin): ...@@ -1718,18 +1740,13 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
) )
if all(full_file_name is None for full_file_name in resolved_vocab_files.values()): if all(full_file_name is None for full_file_name in resolved_vocab_files.values()):
msg = ( raise EnvironmentError(
f"Can't load tokenizer for '{pretrained_model_name_or_path}'. Make sure that:\n\n" f"Can't load tokenizer for '{pretrained_model_name_or_path}'. If you were trying to load it from "
f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n" "'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
f" (make sure '{pretrained_model_name_or_path}' is not a path to a local directory with something else, in that case)\n\n" f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
f"- or '{pretrained_model_name_or_path}' is the correct path to a directory containing relevant tokenizer files\n\n" f"containing all relevant tokenizer files."
) )
if revision is not None:
msg += f"- or '{revision}' is a valid git identifier (branch name, a tag name, or a commit id) that exists for this model name as listed on its model page on 'https://huggingface.co/models'\n\n"
raise EnvironmentError(msg)
for file_id, file_path in vocab_files.items(): for file_id, file_path in vocab_files.items():
if file_id not in resolved_vocab_files: if file_id not in resolved_vocab_files:
continue continue
...@@ -3504,9 +3521,13 @@ def get_fast_tokenizer_file( ...@@ -3504,9 +3521,13 @@ def get_fast_tokenizer_file(
`str`: The tokenizer file to use. `str`: The tokenizer file to use.
""" """
# Inspect all files from the repo/folder. # Inspect all files from the repo/folder.
all_files = get_list_of_files( try:
path_or_repo, revision=revision, use_auth_token=use_auth_token, local_files_only=local_files_only all_files = get_list_of_files(
) path_or_repo, revision=revision, use_auth_token=use_auth_token, local_files_only=local_files_only
)
except Exception:
return FULL_TOKENIZER_FILE
tokenizer_files_map = {} tokenizer_files_map = {}
for file_name in all_files: for file_name in all_files:
search = _re_tokenizer_file.search(file_name) search = _re_tokenizer_file.search(file_name)
......
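To make the tokenizer changes above concrete, a small sketch (not part of the diff): vocabulary files that are simply absent from a repo now resolve to None through `EntryNotFoundError` (e.g. `gpt2` ships no added_tokens.json, yet still loads), and the consolidated EnvironmentError is only raised when nothing usable is found, as in the updated test below:

from transformers import AutoTokenizer

# Missing optional files (e.g. added_tokens.json) are tolerated and resolve to None internally.
tokenizer = AutoTokenizer.from_pretrained("gpt2")

# A non-existent repo now fails with the refined message instead of a raw HTTPError.
try:
    AutoTokenizer.from_pretrained("julien-c/herlolip-not-exists")
except EnvironmentError as err:
    print(err)  # "julien-c/herlolip-not-exists is not a local folder and is not a valid model identifier ..."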
...@@ -83,3 +83,22 @@ class AutoConfigTest(unittest.TestCase): ...@@ -83,3 +83,22 @@ class AutoConfigTest(unittest.TestCase):
finally: finally:
if "new-model" in CONFIG_MAPPING._extra_content: if "new-model" in CONFIG_MAPPING._extra_content:
del CONFIG_MAPPING._extra_content["new-model"] del CONFIG_MAPPING._extra_content["new-model"]
def test_repo_not_found(self):
with self.assertRaisesRegex(
EnvironmentError, "bert-base is not a local folder and is not a valid model identifier"
):
_ = AutoConfig.from_pretrained("bert-base")
def test_revision_not_found(self):
with self.assertRaisesRegex(
EnvironmentError, r"aaaaaa is not a valid git identifier \(branch name, tag name or commit id\)"
):
_ = AutoConfig.from_pretrained(DUMMY_UNKNOWN_IDENTIFIER, revision="aaaaaa")
def test_configuration_not_found(self):
with self.assertRaisesRegex(
EnvironmentError,
"hf-internal-testing/no-config-test-repo does not appear to have a file named config.json.",
):
_ = AutoConfig.from_pretrained("hf-internal-testing/no-config-test-repo")
...@@ -19,6 +19,7 @@ import tempfile ...@@ -19,6 +19,7 @@ import tempfile
import unittest import unittest
from transformers import AutoFeatureExtractor, Wav2Vec2Config, Wav2Vec2FeatureExtractor from transformers import AutoFeatureExtractor, Wav2Vec2Config, Wav2Vec2FeatureExtractor
from transformers.testing_utils import DUMMY_UNKNOWN_IDENTIFIER
SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures") SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures")
...@@ -62,3 +63,22 @@ class AutoFeatureExtractorTest(unittest.TestCase): ...@@ -62,3 +63,22 @@ class AutoFeatureExtractorTest(unittest.TestCase):
def test_feature_extractor_from_local_file(self): def test_feature_extractor_from_local_file(self):
config = AutoFeatureExtractor.from_pretrained(SAMPLE_FEATURE_EXTRACTION_CONFIG) config = AutoFeatureExtractor.from_pretrained(SAMPLE_FEATURE_EXTRACTION_CONFIG)
self.assertIsInstance(config, Wav2Vec2FeatureExtractor) self.assertIsInstance(config, Wav2Vec2FeatureExtractor)
def test_repo_not_found(self):
with self.assertRaisesRegex(
EnvironmentError, "bert-base is not a local folder and is not a valid model identifier"
):
_ = AutoFeatureExtractor.from_pretrained("bert-base")
def test_revision_not_found(self):
with self.assertRaisesRegex(
EnvironmentError, r"aaaaaa is not a valid git identifier \(branch name, tag name or commit id\)"
):
_ = AutoFeatureExtractor.from_pretrained(DUMMY_UNKNOWN_IDENTIFIER, revision="aaaaaa")
def test_feature_extractor_not_found(self):
with self.assertRaisesRegex(
EnvironmentError,
"hf-internal-testing/config-no-model does not appear to have a file named preprocessor_config.json.",
):
_ = AutoFeatureExtractor.from_pretrained("hf-internal-testing/config-no-model")
...@@ -17,17 +17,22 @@ import importlib ...@@ -17,17 +17,22 @@ import importlib
import io import io
import unittest import unittest
import requests
import transformers import transformers
# Try to import everything from transformers to ensure every object can be loaded. # Try to import everything from transformers to ensure every object can be loaded.
from transformers import * # noqa F406 from transformers import * # noqa F406
from transformers.file_utils import ( from transformers.file_utils import (
CONFIG_NAME, CONFIG_NAME,
FLAX_WEIGHTS_NAME,
TF2_WEIGHTS_NAME,
WEIGHTS_NAME, WEIGHTS_NAME,
ContextManagers, ContextManagers,
EntryNotFoundError,
RepositoryNotFoundError,
RevisionNotFoundError,
filename_to_url, filename_to_url,
get_from_cache, get_from_cache,
has_file,
hf_bucket_url, hf_bucket_url,
) )
from transformers.testing_utils import DUMMY_UNKNOWN_IDENTIFIER from transformers.testing_utils import DUMMY_UNKNOWN_IDENTIFIER
...@@ -83,13 +88,19 @@ class GetFromCacheTests(unittest.TestCase): ...@@ -83,13 +88,19 @@ class GetFromCacheTests(unittest.TestCase):
def test_file_not_found(self): def test_file_not_found(self):
# Valid revision (None) but missing file. # Valid revision (None) but missing file.
url = hf_bucket_url(MODEL_ID, filename="missing.bin") url = hf_bucket_url(MODEL_ID, filename="missing.bin")
with self.assertRaisesRegex(requests.exceptions.HTTPError, "404 Client Error"): with self.assertRaisesRegex(EntryNotFoundError, "404 Client Error"):
_ = get_from_cache(url)
def test_model_not_found(self):
# Invalid model identifier.
url = hf_bucket_url("bert-base", filename="pytorch_model.bin")
with self.assertRaisesRegex(RepositoryNotFoundError, "404 Client Error"):
_ = get_from_cache(url) _ = get_from_cache(url)
def test_revision_not_found(self): def test_revision_not_found(self):
# Valid file but missing revision # Valid file but missing revision
url = hf_bucket_url(MODEL_ID, filename=CONFIG_NAME, revision=REVISION_ID_INVALID) url = hf_bucket_url(MODEL_ID, filename=CONFIG_NAME, revision=REVISION_ID_INVALID)
with self.assertRaisesRegex(requests.exceptions.HTTPError, "404 Client Error"): with self.assertRaisesRegex(RevisionNotFoundError, "404 Client Error"):
_ = get_from_cache(url) _ = get_from_cache(url)
def test_standard_object(self): def test_standard_object(self):
...@@ -112,6 +123,11 @@ class GetFromCacheTests(unittest.TestCase): ...@@ -112,6 +123,11 @@ class GetFromCacheTests(unittest.TestCase):
metadata = filename_to_url(filepath) metadata = filename_to_url(filepath)
self.assertEqual(metadata, (url, f'"{PINNED_SHA256}"')) self.assertEqual(metadata, (url, f'"{PINNED_SHA256}"'))
def test_has_file(self):
self.assertTrue(has_file("hf-internal-testing/tiny-bert-pt-only", WEIGHTS_NAME))
self.assertFalse(has_file("hf-internal-testing/tiny-bert-pt-only", TF2_WEIGHTS_NAME))
self.assertFalse(has_file("hf-internal-testing/tiny-bert-pt-only", FLAX_WEIGHTS_NAME))
class ContextManagerTests(unittest.TestCase): class ContextManagerTests(unittest.TestCase):
@unittest.mock.patch("sys.stdout", new_callable=io.StringIO) @unittest.mock.patch("sys.stdout", new_callable=io.StringIO)
......
...@@ -389,3 +389,30 @@ class AutoModelTest(unittest.TestCase): ...@@ -389,3 +389,30 @@ class AutoModelTest(unittest.TestCase):
): ):
if NewModelConfig in mapping._extra_content: if NewModelConfig in mapping._extra_content:
del mapping._extra_content[NewModelConfig] del mapping._extra_content[NewModelConfig]
def test_repo_not_found(self):
with self.assertRaisesRegex(
EnvironmentError, "bert-base is not a local folder and is not a valid model identifier"
):
_ = AutoModel.from_pretrained("bert-base")
def test_revision_not_found(self):
with self.assertRaisesRegex(
EnvironmentError, r"aaaaaa is not a valid git identifier \(branch name, tag name or commit id\)"
):
_ = AutoModel.from_pretrained(DUMMY_UNKNOWN_IDENTIFIER, revision="aaaaaa")
def test_model_file_not_found(self):
with self.assertRaisesRegex(
EnvironmentError,
"hf-internal-testing/config-no-model does not appear to have a file named pytorch_model.bin",
):
_ = AutoModel.from_pretrained("hf-internal-testing/config-no-model")
def test_model_from_tf_suggestion(self):
with self.assertRaisesRegex(EnvironmentError, "Use `from_tf=True` to load this model"):
_ = AutoModel.from_pretrained("hf-internal-testing/tiny-bert-tf-only")
def test_model_from_flax_suggestion(self):
with self.assertRaisesRegex(EnvironmentError, "Use `from_flax=True` to load this model"):
_ = AutoModel.from_pretrained("hf-internal-testing/tiny-bert-flax-only")
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
import unittest import unittest
from transformers import AutoConfig, AutoTokenizer, BertConfig, TensorType, is_flax_available from transformers import AutoConfig, AutoTokenizer, BertConfig, TensorType, is_flax_available
from transformers.testing_utils import require_flax, slow from transformers.testing_utils import DUMMY_UNKNOWN_IDENTIFIER, require_flax, slow
if is_flax_available(): if is_flax_available():
...@@ -76,3 +76,26 @@ class FlaxAutoModelTest(unittest.TestCase): ...@@ -76,3 +76,26 @@ class FlaxAutoModelTest(unittest.TestCase):
return model(**kwargs) return model(**kwargs)
eval(**tokens).block_until_ready() eval(**tokens).block_until_ready()
def test_repo_not_found(self):
with self.assertRaisesRegex(
EnvironmentError, "bert-base is not a local folder and is not a valid model identifier"
):
_ = FlaxAutoModel.from_pretrained("bert-base")
def test_revision_not_found(self):
with self.assertRaisesRegex(
EnvironmentError, r"aaaaaa is not a valid git identifier \(branch name, tag name or commit id\)"
):
_ = FlaxAutoModel.from_pretrained(DUMMY_UNKNOWN_IDENTIFIER, revision="aaaaaa")
def test_model_file_not_found(self):
with self.assertRaisesRegex(
EnvironmentError,
"hf-internal-testing/config-no-model does not appear to have a file named flax_model.msgpack",
):
_ = FlaxAutoModel.from_pretrained("hf-internal-testing/config-no-model")
def test_model_from_pt_suggestion(self):
with self.assertRaisesRegex(EnvironmentError, "Use `from_pt=True` to load this model"):
_ = FlaxAutoModel.from_pretrained("hf-internal-testing/tiny-bert-pt-only")
...@@ -309,3 +309,26 @@ class TFAutoModelTest(unittest.TestCase): ...@@ -309,3 +309,26 @@ class TFAutoModelTest(unittest.TestCase):
): ):
if NewModelConfig in mapping._extra_content: if NewModelConfig in mapping._extra_content:
del mapping._extra_content[NewModelConfig] del mapping._extra_content[NewModelConfig]
def test_repo_not_found(self):
with self.assertRaisesRegex(
EnvironmentError, "bert-base is not a local folder and is not a valid model identifier"
):
_ = TFAutoModel.from_pretrained("bert-base")
def test_revision_not_found(self):
with self.assertRaisesRegex(
EnvironmentError, r"aaaaaa is not a valid git identifier \(branch name, tag name or commit id\)"
):
_ = TFAutoModel.from_pretrained(DUMMY_UNKNOWN_IDENTIFIER, revision="aaaaaa")
def test_model_file_not_found(self):
with self.assertRaisesRegex(
EnvironmentError,
"hf-internal-testing/config-no-model does not appear to have a file named tf_model.h5",
):
_ = TFAutoModel.from_pretrained("hf-internal-testing/config-no-model")
def test_model_from_pt_suggestion(self):
with self.assertRaisesRegex(EnvironmentError, "Use `from_pt=True` to load this model"):
_ = TFAutoModel.from_pretrained("hf-internal-testing/tiny-bert-pt-only")
...@@ -150,7 +150,8 @@ class AutoTokenizerTest(unittest.TestCase): ...@@ -150,7 +150,8 @@ class AutoTokenizerTest(unittest.TestCase):
def test_tokenizer_identifier_non_existent(self): def test_tokenizer_identifier_non_existent(self):
for tokenizer_class in [BertTokenizer, BertTokenizerFast, AutoTokenizer]: for tokenizer_class in [BertTokenizer, BertTokenizerFast, AutoTokenizer]:
with self.assertRaisesRegex( with self.assertRaisesRegex(
ValueError, ".*is not a local path or a model identifier on the model Hub. Did you make a typo?" EnvironmentError,
"julien-c/herlolip-not-exists is not a local folder and is not a valid model identifier",
): ):
_ = tokenizer_class.from_pretrained("julien-c/herlolip-not-exists") _ = tokenizer_class.from_pretrained("julien-c/herlolip-not-exists")
...@@ -310,3 +311,15 @@ class AutoTokenizerTest(unittest.TestCase): ...@@ -310,3 +311,15 @@ class AutoTokenizerTest(unittest.TestCase):
del CONFIG_MAPPING._extra_content["new-model"] del CONFIG_MAPPING._extra_content["new-model"]
if NewConfig in TOKENIZER_MAPPING._extra_content: if NewConfig in TOKENIZER_MAPPING._extra_content:
del TOKENIZER_MAPPING._extra_content[NewConfig] del TOKENIZER_MAPPING._extra_content[NewConfig]
def test_repo_not_found(self):
with self.assertRaisesRegex(
EnvironmentError, "bert-base is not a local folder and is not a valid model identifier"
):
_ = AutoTokenizer.from_pretrained("bert-base")
def test_revision_not_found(self):
with self.assertRaisesRegex(
EnvironmentError, r"aaaaaa is not a valid git identifier \(branch name, tag name or commit id\)"
):
_ = AutoTokenizer.from_pretrained(DUMMY_UNKNOWN_IDENTIFIER, revision="aaaaaa")
...@@ -255,7 +255,7 @@ SPECIAL_MODULE_TO_TEST_MAP = { ...@@ -255,7 +255,7 @@ SPECIAL_MODULE_TO_TEST_MAP = {
"modeling_tf_utils.py": ["test_modeling_tf_common.py", "test_modeling_tf_core.py"], "modeling_tf_utils.py": ["test_modeling_tf_common.py", "test_modeling_tf_core.py"],
"modeling_utils.py": ["test_modeling_common.py", "test_offline.py"], "modeling_utils.py": ["test_modeling_common.py", "test_offline.py"],
"models/auto/modeling_auto.py": ["test_modeling_auto.py", "test_modeling_tf_pytorch.py", "test_modeling_bort.py"], "models/auto/modeling_auto.py": ["test_modeling_auto.py", "test_modeling_tf_pytorch.py", "test_modeling_bort.py"],
"models/auto/modeling_flax_auto.py": "test_flax_auto.py", "models/auto/modeling_flax_auto.py": "test_modeling_flax_auto.py",
"models/auto/modeling_tf_auto.py": [ "models/auto/modeling_tf_auto.py": [
"test_modeling_tf_auto.py", "test_modeling_tf_auto.py",
"test_modeling_tf_pytorch.py", "test_modeling_tf_pytorch.py",
......