Unverified Commit 5cd40323 authored by Sylvain Gugger, committed by GitHub

Use new huggingface_hub tools to download models (#18438)

* Draft new cached_file

* Initial draft for config and model

* Small fixes

* Fix first batch of tests

* Look in cache when internet is down

* Fix last tests

* Bad black, not fixing all quality errors

* Make the diff smaller

* Implement change for TF and Flax models

* Add tokenizer and feature extractor

* For compatibility with main

* Add utils to move the cache and run the migration automatically on first use.

* Quality

* Deal with empty commit shas

* Deal with empty etag

* Address review comments
parent 70fa1a8d
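At a glance, the change this commit makes everywhere: callers used to build a download URL with hf_bucket_url and then fetch it with cached_path; the new cached_file helper (built on huggingface_hub's download tools) folds both steps into one call that accepts a repo id or local folder plus a filename. A minimal sketch of the new call, not part of the diff itself; the repo id is one used by the tests below:

from transformers.utils import cached_file

# Resolves config.json from the cache if present, otherwise downloads it
# from the Hub and caches it; returns a local filesystem path.
resolved = cached_file(
    "hf-internal-testing/tiny-random-bert",  # repo id or local directory
    "config.json",
    revision="main",
)
print(resolved)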
......@@ -25,25 +25,9 @@ from typing import Any, Dict, List, Optional, Tuple, Union
from packaging import version
from requests import HTTPError
from . import __version__
from .dynamic_module_utils import custom_object_save
from .utils import (
CONFIG_NAME,
HUGGINGFACE_CO_RESOLVE_ENDPOINT,
EntryNotFoundError,
PushToHubMixin,
RepositoryNotFoundError,
RevisionNotFoundError,
cached_path,
copy_func,
hf_bucket_url,
is_offline_mode,
is_remote_url,
is_torch_available,
logging,
)
from .utils import CONFIG_NAME, PushToHubMixin, cached_file, copy_func, is_torch_available, logging
logger = logging.get_logger(__name__)
......@@ -591,33 +575,21 @@ class PretrainedConfig(PushToHubMixin):
if from_pipeline is not None:
user_agent["using_pipeline"] = from_pipeline
if is_offline_mode() and not local_files_only:
logger.info("Offline mode: forcing local_files_only=True")
local_files_only = True
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
if os.path.isfile(os.path.join(subfolder, pretrained_model_name_or_path)) or is_remote_url(
pretrained_model_name_or_path
):
config_file = pretrained_model_name_or_path
else:
configuration_file = kwargs.pop("_configuration_file", CONFIG_NAME)
if os.path.isdir(os.path.join(pretrained_model_name_or_path, subfolder)):
config_file = os.path.join(pretrained_model_name_or_path, subfolder, configuration_file)
is_local = os.path.isdir(pretrained_model_name_or_path)
if os.path.isfile(os.path.join(subfolder, pretrained_model_name_or_path)):
# Special case when pretrained_model_name_or_path is a local file
resolved_config_file = pretrained_model_name_or_path
is_local = True
else:
config_file = hf_bucket_url(
pretrained_model_name_or_path,
filename=configuration_file,
revision=revision,
subfolder=subfolder if len(subfolder) > 0 else None,
mirror=None,
)
configuration_file = kwargs.pop("_configuration_file", CONFIG_NAME)
try:
# Load from URL or cache if already cached
resolved_config_file = cached_path(
config_file,
# Load from a local folder, from the cache, or download from the model Hub and cache it
resolved_config_file = cached_file(
pretrained_model_name_or_path,
configuration_file,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
......@@ -625,42 +597,20 @@ class PretrainedConfig(PushToHubMixin):
local_files_only=local_files_only,
use_auth_token=use_auth_token,
user_agent=user_agent,
)
except RepositoryNotFoundError:
raise EnvironmentError(
f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier listed on "
"'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a token having "
"permission to this repo with `use_auth_token` or log in with `huggingface-cli login` and pass "
"`use_auth_token=True`."
)
except RevisionNotFoundError:
raise EnvironmentError(
f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for this "
f"model name. Check the model page at 'https://huggingface.co/{pretrained_model_name_or_path}' for "
"available revisions."
)
except EntryNotFoundError:
raise EnvironmentError(
f"{pretrained_model_name_or_path} does not appear to have a file named {configuration_file}."
)
except HTTPError as err:
raise EnvironmentError(
f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n{err}"
)
except ValueError:
raise EnvironmentError(
f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it in"
f" the cached files and it looks like {pretrained_model_name_or_path} is not the path to a directory"
f" containing a {configuration_file} file.\nCheckout your internet connection or see how to run the"
" library in offline mode at 'https://huggingface.co/docs/transformers/installation#offline-mode'."
revision=revision,
subfolder=subfolder,
)
except EnvironmentError:
# Raise any environment error raised by `cached_file`. It will have a helpful error message adapted to
# the original exception.
raise
except Exception:
# For any other exception, we throw a generic error.
raise EnvironmentError(
f"Can't load config for '{pretrained_model_name_or_path}'. If you were trying to load it from "
"'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
f"containing a {configuration_file} file"
f"Can't load the configuration of '{pretrained_model_name_or_path}'. If you were trying to load it"
" from 'https://huggingface.co/models', make sure you don't have a local directory with the same"
f" name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory"
f" containing a {configuration_file} file"
)
try:
......@@ -671,10 +621,10 @@ class PretrainedConfig(PushToHubMixin):
f"It looks like the config file at '{resolved_config_file}' is not a valid JSON file."
)
if resolved_config_file == config_file:
logger.info(f"loading configuration file {config_file}")
if is_local:
logger.info(f"loading configuration file {resolved_config_file}")
else:
logger.info(f"loading configuration file {config_file} from cache at {resolved_config_file}")
logger.info(f"loading configuration file {configuration_file} from cache at {resolved_config_file}")
return config_dict, kwargs
......
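For users, the visible effect is that PretrainedConfig.from_pretrained now funnels every input form through one resolution path. A short usage sketch; the Hub id comes from this PR's tests, and the local paths are placeholders:

from transformers import BertConfig

# Hub repo id: resolved via cached_file (downloaded on first use, then cached).
config = BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert")

# Local directory: cached_file short-circuits to the folder's config.json.
# config = BertConfig.from_pretrained("/path/to/checkpoint_dir")

# Single local file: the os.path.isfile special case above bypasses cached_file.
# config = BertConfig.from_pretrained("/path/to/config.json")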
......@@ -24,23 +24,15 @@ from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
import numpy as np
from requests import HTTPError
from .dynamic_module_utils import custom_object_save
from .utils import (
FEATURE_EXTRACTOR_NAME,
HUGGINGFACE_CO_RESOLVE_ENDPOINT,
EntryNotFoundError,
PushToHubMixin,
RepositoryNotFoundError,
RevisionNotFoundError,
TensorType,
cached_path,
cached_file,
copy_func,
hf_bucket_url,
is_flax_available,
is_offline_mode,
is_remote_url,
is_tf_available,
is_torch_available,
logging,
......@@ -388,18 +380,18 @@ class FeatureExtractionMixin(PushToHubMixin):
local_files_only = True
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
is_local = os.path.isdir(pretrained_model_name_or_path)
if os.path.isdir(pretrained_model_name_or_path):
feature_extractor_file = os.path.join(pretrained_model_name_or_path, FEATURE_EXTRACTOR_NAME)
elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
feature_extractor_file = pretrained_model_name_or_path
if os.path.isfile(pretrained_model_name_or_path):
resolved_feature_extractor_file = pretrained_model_name_or_path
is_local = True
else:
feature_extractor_file = hf_bucket_url(
pretrained_model_name_or_path, filename=FEATURE_EXTRACTOR_NAME, revision=revision, mirror=None
)
feature_extractor_file = FEATURE_EXTRACTOR_NAME
try:
# Load from URL or cache if already cached
resolved_feature_extractor_file = cached_path(
# Load from a local folder, from the cache, or download from the model Hub and cache it
resolved_feature_extractor_file = cached_file(
pretrained_model_name_or_path,
feature_extractor_file,
cache_dir=cache_dir,
force_download=force_download,
......@@ -408,43 +400,19 @@ class FeatureExtractionMixin(PushToHubMixin):
local_files_only=local_files_only,
use_auth_token=use_auth_token,
user_agent=user_agent,
)
except RepositoryNotFoundError:
raise EnvironmentError(
f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier listed on "
"'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a token having "
"permission to this repo with `use_auth_token` or log in with `huggingface-cli login` and pass "
"`use_auth_token=True`."
)
except RevisionNotFoundError:
raise EnvironmentError(
f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for this "
f"model name. Check the model page at 'https://huggingface.co/{pretrained_model_name_or_path}' for "
"available revisions."
)
except EntryNotFoundError:
raise EnvironmentError(
f"{pretrained_model_name_or_path} does not appear to have a file named {FEATURE_EXTRACTOR_NAME}."
)
except HTTPError as err:
raise EnvironmentError(
f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n{err}"
)
except ValueError:
raise EnvironmentError(
f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it in"
f" the cached files and it looks like {pretrained_model_name_or_path} is not the path to a directory"
f" containing a {FEATURE_EXTRACTOR_NAME} file.\nCheckout your internet connection or see how to run"
" the library in offline mode at"
" 'https://huggingface.co/docs/transformers/installation#offline-mode'."
revision=revision,
)
except EnvironmentError:
# Raise any environment error raised by `cached_file`. It will have a helpful error message adapted to
# the original exception.
raise
except Exception:
# For any other exception, we throw a generic error.
raise EnvironmentError(
f"Can't load feature extractor for '{pretrained_model_name_or_path}'. If you were trying to load it "
"from 'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
f"containing a {FEATURE_EXTRACTOR_NAME} file"
f"Can't load feature extractor for '{pretrained_model_name_or_path}'. If you were trying to load"
" it from 'https://huggingface.co/models', make sure you don't have a local directory with the"
f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a"
f" directory containing a {FEATURE_EXTRACTOR_NAME} file"
)
try:
......@@ -458,12 +426,11 @@ class FeatureExtractionMixin(PushToHubMixin):
f"It looks like the config file at '{resolved_feature_extractor_file}' is not a valid JSON file."
)
if resolved_feature_extractor_file == feature_extractor_file:
logger.info(f"loading feature extractor configuration file {feature_extractor_file}")
if is_local:
logger.info(f"loading configuration file {resolved_feature_extractor_file}")
else:
logger.info(
f"loading feature extractor configuration file {feature_extractor_file} from cache at"
f" {resolved_feature_extractor_file}"
f"loading configuration file {feature_extractor_file} from cache at {resolved_feature_extractor_file}"
)
return feature_extractor_dict, kwargs
......
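Both loaders above now lean on the same error contract: cached_file raises an EnvironmentError whose message is already adapted to the specific failure (unknown repo, bad revision, missing file, connection trouble), so from_pretrained re-raises it untouched and only wraps unexpected exceptions in a generic message. A distilled sketch of consuming that contract; the repo id is hypothetical and the filename is the value of the FEATURE_EXTRACTOR_NAME constant:

from transformers.utils import cached_file

try:
    path = cached_file("some-user/some-model", "preprocessor_config.json")
except EnvironmentError as err:
    # cached_file already produced a failure-specific, helpful message.
    print(f"could not resolve the file: {err}")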
......@@ -32,7 +32,6 @@ from flax.core.frozen_dict import FrozenDict, unfreeze
from flax.serialization import from_bytes, to_bytes
from flax.traverse_util import flatten_dict, unflatten_dict
from jax.random import PRNGKey
from requests import HTTPError
from .configuration_utils import PretrainedConfig
from .dynamic_module_utils import custom_object_save
......@@ -41,20 +40,14 @@ from .modeling_flax_pytorch_utils import load_pytorch_checkpoint_in_flax_state_d
from .utils import (
FLAX_WEIGHTS_INDEX_NAME,
FLAX_WEIGHTS_NAME,
HUGGINGFACE_CO_RESOLVE_ENDPOINT,
WEIGHTS_NAME,
EntryNotFoundError,
PushToHubMixin,
RepositoryNotFoundError,
RevisionNotFoundError,
add_code_sample_docstrings,
add_start_docstrings_to_model_forward,
cached_path,
cached_file,
copy_func,
has_file,
hf_bucket_url,
is_offline_mode,
is_remote_url,
logging,
replace_return_docstrings,
)
......@@ -557,6 +550,9 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
identifier allowed by git.
subfolder (`str`, *optional*, defaults to `""`):
In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can
specify the folder name here.
kwargs (remaining dictionary of keyword arguments, *optional*):
Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
`output_attentions=True`). Behaves differently depending on whether a `config` is provided or
......@@ -598,6 +594,7 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
from_pipeline = kwargs.pop("_from_pipeline", None)
from_auto_class = kwargs.pop("_from_auto", False)
_do_init = kwargs.pop("_do_init", True)
subfolder = kwargs.pop("subfolder", "")
if trust_remote_code is True:
logger.warning(
......@@ -642,6 +639,8 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
# Load model
if pretrained_model_name_or_path is not None:
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
is_local = os.path.isdir(pretrained_model_name_or_path)
if os.path.isdir(pretrained_model_name_or_path):
if from_pt and os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
# Load from a PyTorch checkpoint
......@@ -665,21 +664,14 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
f"Error no file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME} found in directory "
f"{pretrained_model_name_or_path}."
)
elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
elif os.path.isfile(pretrained_model_name_or_path):
archive_file = pretrained_model_name_or_path
is_local = True
else:
filename = WEIGHTS_NAME if from_pt else FLAX_WEIGHTS_NAME
archive_file = hf_bucket_url(
pretrained_model_name_or_path,
filename=filename,
revision=revision,
)
# redirect to the cache, if necessary
try:
resolved_archive_file = cached_path(
archive_file,
# Load from URL or cache if already cached
cached_file_kwargs = dict(
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
......@@ -687,43 +679,29 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
local_files_only=local_files_only,
use_auth_token=use_auth_token,
user_agent=user_agent,
revision=revision,
subfolder=subfolder,
_raise_exceptions_for_missing_entries=False,
)
resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs)
except RepositoryNotFoundError:
raise EnvironmentError(
f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier "
"listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a "
"token having permission to this repo with `use_auth_token` or log in with `huggingface-cli "
"login` and pass `use_auth_token=True`."
)
except RevisionNotFoundError:
raise EnvironmentError(
f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for "
"this model name. Check the model page at "
f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
)
except EntryNotFoundError:
if filename == FLAX_WEIGHTS_NAME:
try:
# Since we set _raise_exceptions_for_missing_entries=False, we don't get an exception but a None
# result when internet is up, the repo and revision exist, but the file does not.
if resolved_archive_file is None and filename == FLAX_WEIGHTS_NAME:
# Maybe the checkpoint is sharded, we try to grab the index name in this case.
archive_file = hf_bucket_url(
pretrained_model_name_or_path,
filename=FLAX_WEIGHTS_INDEX_NAME,
revision=revision,
)
resolved_archive_file = cached_path(
archive_file,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
user_agent=user_agent,
resolved_archive_file = cached_file(
pretrained_model_name_or_path, FLAX_WEIGHTS_INDEX_NAME, **cached_file_kwargs
)
if resolved_archive_file is not None:
is_sharded = True
except EntryNotFoundError:
has_file_kwargs = {"revision": revision, "proxies": proxies, "use_auth_token": use_auth_token}
if resolved_archive_file is None:
# Otherwise, maybe there is a PyTorch model file. We try it to give a helpful error
# message.
has_file_kwargs = {
"revision": revision,
"proxies": proxies,
"use_auth_token": use_auth_token,
}
if has_file(pretrained_model_name_or_path, WEIGHTS_NAME, **has_file_kwargs):
raise EnvironmentError(
f"{pretrained_model_name_or_path} does not appear to have a file named"
......@@ -735,35 +713,24 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
f"{pretrained_model_name_or_path} does not appear to have a file named"
f" {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME}."
)
else:
raise EnvironmentError(
f"{pretrained_model_name_or_path} does not appear to have a file named {filename}."
)
except HTTPError as err:
raise EnvironmentError(
f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n"
f"{err}"
)
except ValueError:
raise EnvironmentError(
f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
f" directory containing a file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME}.\nCheckout your"
" internet connection or see how to run the library in offline mode at"
" 'https://huggingface.co/docs/transformers/installation#offline-mode'."
)
except EnvironmentError:
# Raise any environment error raised by `cached_file`. It will have a helpful error message adapted
# to the original exception.
raise
except Exception:
# For any other exception, we throw a generic error.
raise EnvironmentError(
f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it from "
"'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
f"containing a file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME}."
f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it"
" from 'https://huggingface.co/models', make sure you don't have a local directory with the"
f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a"
f" directory containing a file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME}."
)
if resolved_archive_file == archive_file:
if is_local:
logger.info(f"loading weights file {archive_file}")
resolved_archive_file = archive_file
else:
logger.info(f"loading weights file {archive_file} from cache at {resolved_archive_file}")
logger.info(f"loading weights file {filename} from cache at {resolved_archive_file}")
else:
resolved_archive_file = None
......
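The sharded-checkpoint fallback above is the one spot where the loaders opt out of cached_file's exceptions: with _raise_exceptions_for_missing_entries=False a missing file comes back as None instead of raising, so the loader can quietly retry with the index filename. A condensed sketch of that flow, using the constants from the Flax hunk (illustrative, not the verbatim implementation; the repo id is hypothetical):

from transformers.utils import FLAX_WEIGHTS_INDEX_NAME, FLAX_WEIGHTS_NAME, cached_file

repo_id = "some-user/some-flax-model"  # hypothetical repo id
is_sharded = False
resolved = cached_file(repo_id, FLAX_WEIGHTS_NAME, _raise_exceptions_for_missing_entries=False)
if resolved is None:
    # No single-file checkpoint; maybe it is sharded, so try the index file.
    resolved = cached_file(repo_id, FLAX_WEIGHTS_INDEX_NAME, _raise_exceptions_for_missing_entries=False)
    is_sharded = resolved is not None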
......@@ -37,7 +37,6 @@ from tensorflow.python.keras.saving import hdf5_format
from huggingface_hub import Repository, list_repo_files
from keras.saving.hdf5_format import save_attributes_to_hdf5_group
from requests import HTTPError
from transformers.utils.hub import convert_file_size_to_int, get_checkpoint_shard_files
from . import DataCollatorWithPadding, DefaultDataCollator
......@@ -48,22 +47,16 @@ from .generation_tf_utils import TFGenerationMixin
from .tf_utils import shape_list
from .utils import (
DUMMY_INPUTS,
HUGGINGFACE_CO_RESOLVE_ENDPOINT,
TF2_WEIGHTS_INDEX_NAME,
TF2_WEIGHTS_NAME,
WEIGHTS_INDEX_NAME,
WEIGHTS_NAME,
EntryNotFoundError,
ModelOutput,
PushToHubMixin,
RepositoryNotFoundError,
RevisionNotFoundError,
cached_path,
cached_file,
find_labels,
has_file,
hf_bucket_url,
is_offline_mode,
is_remote_url,
logging,
requires_backends,
working_or_temp_dir,
......@@ -2112,6 +2105,9 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
Mirror source to accelerate downloads in China. If you are from China and have an accessibility
problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
Please refer to the mirror site for more information.
subfolder (`str`, *optional*, defaults to `""`):
In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can
specify the folder name here.
kwargs (remaining dictionary of keyword arguments, *optional*):
Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
`output_attentions=True`). Behaves differently depending on whether a `config` is provided or
......@@ -2164,6 +2160,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
load_weight_prefix = kwargs.pop("load_weight_prefix", None)
from_pipeline = kwargs.pop("_from_pipeline", None)
from_auto_class = kwargs.pop("_from_auto", False)
subfolder = kwargs.pop("subfolder", "")
if trust_remote_code is True:
logger.warning(
......@@ -2202,9 +2199,10 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
# This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the
# index of the files.
is_sharded = False
sharded_metadata = None
# Load model
if pretrained_model_name_or_path is not None:
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
is_local = os.path.isdir(pretrained_model_name_or_path)
if os.path.isdir(pretrained_model_name_or_path):
if from_pt and os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
# Load from a PyTorch checkpoint in priority if from_pt
......@@ -2232,23 +2230,19 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
f"Error no file named {TF2_WEIGHTS_NAME} or {WEIGHTS_NAME} found in directory "
f"{pretrained_model_name_or_path}."
)
elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
elif os.path.isfile(pretrained_model_name_or_path):
archive_file = pretrained_model_name_or_path
is_local = True
elif os.path.isfile(pretrained_model_name_or_path + ".index"):
archive_file = pretrained_model_name_or_path + ".index"
is_local = True
else:
# set correct filename
filename = WEIGHTS_NAME if from_pt else TF2_WEIGHTS_NAME
archive_file = hf_bucket_url(
pretrained_model_name_or_path,
filename=filename,
revision=revision,
mirror=mirror,
)
try:
# Load from URL or cache if already cached
resolved_archive_file = cached_path(
archive_file,
cached_file_kwargs = dict(
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
......@@ -2256,44 +2250,23 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
local_files_only=local_files_only,
use_auth_token=use_auth_token,
user_agent=user_agent,
revision=revision,
subfolder=subfolder,
_raise_exceptions_for_missing_entries=False,
)
resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs)
except RepositoryNotFoundError:
raise EnvironmentError(
f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier "
"listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a "
"token having permission to this repo with `use_auth_token` or log in with `huggingface-cli "
"login` and pass `use_auth_token=True`."
)
except RevisionNotFoundError:
raise EnvironmentError(
f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for "
"this model name. Check the model page at "
f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
)
except EntryNotFoundError:
if filename == TF2_WEIGHTS_NAME:
try:
# Since we set _raise_exceptions_for_missing_entries=False, we don't get an exception but a None
# result when internet is up, the repo and revision exist, but the file does not.
if resolved_archive_file is None and filename == TF2_WEIGHTS_NAME:
# Maybe the checkpoint is sharded, we try to grab the index name in this case.
archive_file = hf_bucket_url(
pretrained_model_name_or_path,
filename=TF2_WEIGHTS_INDEX_NAME,
revision=revision,
mirror=mirror,
)
resolved_archive_file = cached_path(
archive_file,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
user_agent=user_agent,
resolved_archive_file = cached_file(
pretrained_model_name_or_path, TF2_WEIGHTS_INDEX_NAME, **cached_file_kwargs
)
if resolved_archive_file is not None:
is_sharded = True
except EntryNotFoundError:
# Otherwise, maybe there is a TF or Flax model file. We try those to give a helpful error
if resolved_archive_file is None:
# Otherwise, maybe there is a PyTorch or Flax model file. We try those to give a helpful error
# message.
has_file_kwargs = {
"revision": revision,
......@@ -2312,42 +2285,32 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
f"{pretrained_model_name_or_path} does not appear to have a file named"
f" {TF2_WEIGHTS_NAME} or {WEIGHTS_NAME}."
)
else:
raise EnvironmentError(
f"{pretrained_model_name_or_path} does not appear to have a file named {filename}."
)
except HTTPError as err:
raise EnvironmentError(
f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n"
f"{err}"
)
except ValueError:
raise EnvironmentError(
f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
f" directory containing a file named {TF2_WEIGHTS_NAME} or {WEIGHTS_NAME}.\nCheckout your internet"
" connection or see how to run the library in offline mode at"
" 'https://huggingface.co/docs/transformers/installation#offline-mode'."
)
except EnvironmentError:
# Raise any environment error raised by `cached_file`. It will have a helpful error message adapted
# to the original exception.
raise
except Exception:
# For any other exception, we throw a generic error.
raise EnvironmentError(
f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it from "
"'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
f"containing a file named {TF2_WEIGHTS_NAME} or {WEIGHTS_NAME}."
f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it"
" from 'https://huggingface.co/models', make sure you don't have a local directory with the"
f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a"
f" directory containing a file named {TF2_WEIGHTS_NAME} or {WEIGHTS_NAME}."
)
if resolved_archive_file == archive_file:
if is_local:
logger.info(f"loading weights file {archive_file}")
resolved_archive_file = archive_file
else:
logger.info(f"loading weights file {archive_file} from cache at {resolved_archive_file}")
logger.info(f"loading weights file {filename} from cache at {resolved_archive_file}")
else:
resolved_archive_file = None
# We'll need to download and cache each checkpoint shard if the checkpoint is sharded.
if is_sharded:
# resolved_archive_file becomes a list of files that point to the different checkpoint shards in this case.
resolved_archive_file, sharded_metadata = get_checkpoint_shard_files(
resolved_archive_file, _ = get_checkpoint_shard_files(
pretrained_model_name_or_path,
resolved_archive_file,
cache_dir=cache_dir,
......
......@@ -31,7 +31,6 @@ from packaging import version
from torch import Tensor, device, nn
from torch.nn import CrossEntropyLoss
from requests import HTTPError
from transformers.utils.hub import convert_file_size_to_int, get_checkpoint_shard_files
from transformers.utils.import_utils import is_sagemaker_mp_enabled
......@@ -51,24 +50,18 @@ from .pytorch_utils import ( # noqa: F401
from .utils import (
DUMMY_INPUTS,
FLAX_WEIGHTS_NAME,
HUGGINGFACE_CO_RESOLVE_ENDPOINT,
TF2_WEIGHTS_NAME,
TF_WEIGHTS_NAME,
WEIGHTS_INDEX_NAME,
WEIGHTS_NAME,
ContextManagers,
EntryNotFoundError,
ModelOutput,
PushToHubMixin,
RepositoryNotFoundError,
RevisionNotFoundError,
cached_path,
cached_file,
copy_func,
has_file,
hf_bucket_url,
is_accelerate_available,
is_offline_mode,
is_remote_url,
logging,
replace_return_docstrings,
)
......@@ -1868,7 +1861,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
if pretrained_model_name_or_path is not None:
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
if os.path.isdir(pretrained_model_name_or_path):
is_local = os.path.isdir(pretrained_model_name_or_path)
if is_local:
if from_tf and os.path.isfile(
os.path.join(pretrained_model_name_or_path, subfolder, TF_WEIGHTS_NAME + ".index")
):
......@@ -1911,10 +1905,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
f"Error no file named {WEIGHTS_NAME}, {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME + '.index'} or "
f"{FLAX_WEIGHTS_NAME} found in directory {pretrained_model_name_or_path}."
)
elif os.path.isfile(os.path.join(subfolder, pretrained_model_name_or_path)) or is_remote_url(
pretrained_model_name_or_path
):
elif os.path.isfile(os.path.join(subfolder, pretrained_model_name_or_path)):
archive_file = pretrained_model_name_or_path
is_local = True
elif os.path.isfile(os.path.join(subfolder, pretrained_model_name_or_path + ".index")):
if not from_tf:
raise ValueError(
......@@ -1922,6 +1915,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
"from_tf to True to load from this checkpoint."
)
archive_file = os.path.join(subfolder, pretrained_model_name_or_path + ".index")
is_local = True
else:
# set correct filename
if from_tf:
......@@ -1931,18 +1925,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
else:
filename = WEIGHTS_NAME
archive_file = hf_bucket_url(
pretrained_model_name_or_path,
filename=filename,
revision=revision,
mirror=mirror,
subfolder=subfolder if len(subfolder) > 0 else None,
)
try:
# Load from URL or cache if already cached
resolved_archive_file = cached_path(
archive_file,
cached_file_kwargs = dict(
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
......@@ -1950,44 +1935,22 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
local_files_only=local_files_only,
use_auth_token=use_auth_token,
user_agent=user_agent,
revision=revision,
subfolder=subfolder,
_raise_exceptions_for_missing_entries=False,
)
resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs)
except RepositoryNotFoundError:
raise EnvironmentError(
f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier "
"listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a "
"token having permission to this repo with `use_auth_token` or log in with `huggingface-cli "
"login` and pass `use_auth_token=True`."
)
except RevisionNotFoundError:
raise EnvironmentError(
f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for "
"this model name. Check the model page at "
f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
)
except EntryNotFoundError:
if filename == WEIGHTS_NAME:
try:
# Since we set _raise_exceptions_for_missing_entries=False, we don't get an exception but a None
# result when internet is up, the repo and revision exist, but the file does not.
if resolved_archive_file is None and filename == WEIGHTS_NAME:
# Maybe the checkpoint is sharded, we try to grab the index name in this case.
archive_file = hf_bucket_url(
pretrained_model_name_or_path,
filename=WEIGHTS_INDEX_NAME,
revision=revision,
mirror=mirror,
subfolder=subfolder if len(subfolder) > 0 else None,
)
resolved_archive_file = cached_path(
archive_file,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
user_agent=user_agent,
resolved_archive_file = cached_file(
pretrained_model_name_or_path, WEIGHTS_INDEX_NAME, **cached_file_kwargs
)
if resolved_archive_file is not None:
is_sharded = True
except EntryNotFoundError:
if resolved_archive_file is None:
# Otherwise, maybe there is a TF or Flax model file. We try those to give a helpful error
# message.
has_file_kwargs = {
......@@ -2013,42 +1976,31 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
f"{pretrained_model_name_or_path} does not appear to have a file named {WEIGHTS_NAME},"
f" {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME} or {FLAX_WEIGHTS_NAME}."
)
else:
raise EnvironmentError(
f"{pretrained_model_name_or_path} does not appear to have a file named {filename}."
)
except HTTPError as err:
raise EnvironmentError(
f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n"
f"{err}"
)
except ValueError:
raise EnvironmentError(
f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
f" directory containing a file named {WEIGHTS_NAME}, {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME} or"
f" {FLAX_WEIGHTS_NAME}.\nCheckout your internet connection or see how to run the library in"
" offline mode at 'https://huggingface.co/docs/transformers/installation#offline-mode'."
)
except EnvironmentError:
# Raise any environment error raised by `cached_file`. It will have a helpful error message adapted
# to the original exception.
raise
except Exception:
# For any other exception, we throw a generic error.
raise EnvironmentError(
f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it from "
"'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
f"containing a file named {WEIGHTS_NAME}, {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME} or "
f"{FLAX_WEIGHTS_NAME}."
f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it"
" from 'https://huggingface.co/models', make sure you don't have a local directory with the"
f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a"
f" directory containing a file named {WEIGHTS_NAME}, {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME} or"
f" {FLAX_WEIGHTS_NAME}."
)
if resolved_archive_file == archive_file:
if is_local:
logger.info(f"loading weights file {archive_file}")
resolved_archive_file = archive_file
else:
logger.info(f"loading weights file {archive_file} from cache at {resolved_archive_file}")
logger.info(f"loading weights file {filename} from cache at {resolved_archive_file}")
else:
resolved_archive_file = None
# We'll need to download and cache each checkpoint shard if the checkpoint is sharded.
if is_sharded:
# resolved_archive_file becomes a list of files that point to the different checkpoint shards in this case.
resolved_archive_file, sharded_metadata = get_checkpoint_shard_files(
pretrained_model_name_or_path,
resolved_archive_file,
......
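Once an index file is resolved, get_checkpoint_shard_files (imported from transformers.utils.hub at the top of both modeling files) downloads and caches every shard the index lists, returning the local shard paths plus the index metadata; the TF loader discards the metadata with `resolved_archive_file, _ = ...` while the PyTorch loader keeps it. A hedged sketch; the repo id and index path are hypothetical placeholders:

from transformers.utils.hub import get_checkpoint_shard_files

# Assumed: index_path is the local index file returned by cached_file(...).
index_path = "/path/to/pytorch_model.bin.index.json"  # hypothetical
shard_files, sharded_metadata = get_checkpoint_shard_files(
    "some-user/some-sharded-model",  # hypothetical repo id
    index_path,
    cache_dir=None,
)
# shard_files is a list of local paths, one per checkpoint shard.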
......@@ -35,21 +35,16 @@ from packaging import version
from . import __version__
from .dynamic_module_utils import custom_object_save
from .utils import (
EntryNotFoundError,
ExplicitEnum,
PaddingStrategy,
PushToHubMixin,
RepositoryNotFoundError,
RevisionNotFoundError,
TensorType,
add_end_docstrings,
cached_path,
cached_file,
copy_func,
get_file_from_repo,
hf_bucket_url,
is_flax_available,
is_offline_mode,
is_remote_url,
is_tf_available,
is_tokenizers_available,
is_torch_available,
......@@ -1669,7 +1664,8 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
vocab_files = {}
init_configuration = {}
if os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
is_local = os.path.isdir(pretrained_model_name_or_path)
if os.path.isfile(pretrained_model_name_or_path):
if len(cls.vocab_files_names) > 1:
raise ValueError(
f"Calling {cls.__name__}.from_pretrained() with the path to a single file or url is not "
......@@ -1689,9 +1685,9 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
"special_tokens_map_file": SPECIAL_TOKENS_MAP_FILE,
"tokenizer_config_file": TOKENIZER_CONFIG_FILE,
}
vocab_files_target = {**cls.vocab_files_names, **additional_files_names}
vocab_files = {**cls.vocab_files_names, **additional_files_names}
if "tokenizer_file" in vocab_files_target:
if "tokenizer_file" in vocab_files:
# Try to get the tokenizer config to see if there are versioned tokenizer files.
fast_tokenizer_file = FULL_TOKENIZER_FILE
resolved_config_file = get_file_from_repo(
......@@ -1704,44 +1700,25 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
use_auth_token=use_auth_token,
revision=revision,
local_files_only=local_files_only,
subfolder=subfolder,
)
if resolved_config_file is not None:
with open(resolved_config_file, encoding="utf-8") as reader:
tokenizer_config = json.load(reader)
if "fast_tokenizer_files" in tokenizer_config:
fast_tokenizer_file = get_fast_tokenizer_file(tokenizer_config["fast_tokenizer_files"])
vocab_files_target["tokenizer_file"] = fast_tokenizer_file
# Look for the tokenizer files
for file_id, file_name in vocab_files_target.items():
if os.path.isdir(pretrained_model_name_or_path):
if subfolder is not None:
full_file_name = os.path.join(pretrained_model_name_or_path, subfolder, file_name)
else:
full_file_name = os.path.join(pretrained_model_name_or_path, file_name)
if not os.path.exists(full_file_name):
logger.info(f"Didn't find file {full_file_name}. We won't load it.")
full_file_name = None
else:
full_file_name = hf_bucket_url(
pretrained_model_name_or_path,
filename=file_name,
subfolder=subfolder,
revision=revision,
mirror=None,
)
vocab_files[file_id] = full_file_name
vocab_files["tokenizer_file"] = fast_tokenizer_file
# Get files from url, cache, or disk depending on the case
resolved_vocab_files = {}
unresolved_files = []
for file_id, file_path in vocab_files.items():
if file_path is None:
resolved_vocab_files[file_id] = None
else:
try:
resolved_vocab_files[file_id] = cached_path(
resolved_vocab_files[file_id] = cached_file(
pretrained_model_name_or_path,
file_path,
cache_dir=cache_dir,
force_download=force_download,
......@@ -1750,35 +1727,12 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
local_files_only=local_files_only,
use_auth_token=use_auth_token,
user_agent=user_agent,
revision=revision,
subfolder=subfolder,
_raise_exceptions_for_missing_entries=False,
_raise_exceptions_for_connection_errors=False,
)
except FileNotFoundError as error:
if local_files_only:
unresolved_files.append(file_id)
else:
raise error
except RepositoryNotFoundError:
raise EnvironmentError(
f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier "
"listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to "
"pass a token having permission to this repo with `use_auth_token` or log in with "
"`huggingface-cli login` and pass `use_auth_token=True`."
)
except RevisionNotFoundError:
raise EnvironmentError(
f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists "
"for this model name. Check the model page at "
f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
)
except EntryNotFoundError:
logger.debug(f"{pretrained_model_name_or_path} does not contain a file named {file_path}.")
resolved_vocab_files[file_id] = None
except ValueError:
logger.debug(f"Connection problem to access {file_path} and it wasn't found in the cache.")
resolved_vocab_files[file_id] = None
if len(unresolved_files) > 0:
logger.info(
f"Can't load following files from cache: {unresolved_files} and cannot check if these "
......@@ -1797,7 +1751,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
if file_id not in resolved_vocab_files:
continue
if file_path == resolved_vocab_files[file_id]:
if is_local:
logger.info(f"loading file {file_path}")
else:
logger.info(f"loading file {file_path} from cache at {resolved_vocab_files[file_id]}")
......
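The tokenizer is the one caller that resolves many candidate files in a single pass, so it disables both exception classes and treats None as "this file is not in the repo": _raise_exceptions_for_missing_entries=False and _raise_exceptions_for_connection_errors=False. A condensed sketch of the loop above, with names as defined in the surrounding code (illustrative only):

resolved_vocab_files = {}
for file_id, file_name in vocab_files.items():
    resolved_vocab_files[file_id] = cached_file(
        pretrained_model_name_or_path,
        file_name,
        _raise_exceptions_for_missing_entries=False,    # missing file -> None
        _raise_exceptions_for_connection_errors=False,  # offline/500 -> None
    )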
......@@ -60,6 +60,7 @@ from .hub import (
PushToHubMixin,
RepositoryNotFoundError,
RevisionNotFoundError,
cached_file,
cached_path,
default_cache_path,
define_sagemaker_information,
......@@ -76,6 +77,7 @@ from .hub import (
is_local_clone,
is_offline_mode,
is_remote_url,
move_cache,
send_example_telemetry,
url_to_filename,
)
......
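The newly exported move_cache utility implements the commit bullet about moving the cache: it migrates files from the legacy flat transformers cache into the new huggingface_hub layout, and the library triggers it automatically on first use. A hedged sketch, assuming the no-argument form operates on the default cache directory:

from transformers.utils.hub import move_cache

# One-time migration of the old-style cache to the new layout.
# Assumption: with no arguments it uses the default cache location.
move_cache()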
......@@ -345,14 +345,14 @@ class ConfigTestUtils(unittest.TestCase):
# A mock response for an HTTP head request to emulate server down
response_mock = mock.Mock()
response_mock.status_code = 500
response_mock.headers = []
response_mock.headers = {}
response_mock.raise_for_status.side_effect = HTTPError
# Download this model to make sure it's in the cache.
_ = BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert")
# Under the mock environment we get a 500 error when trying to reach the model.
with mock.patch("transformers.utils.hub.requests.head", return_value=response_mock) as mock_head:
with mock.patch("requests.request", return_value=response_mock) as mock_head:
_ = BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert")
# Check that we did call the fake head request
mock_head.assert_called()
......
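The same two-line update repeats in each test file below. response_mock.headers becomes a dict because the download code reads individual header keys, which a list cannot serve, and the patch target moves from transformers.utils.hub.requests.head to the top-level requests.request, the level at which huggingface_hub issues its HTTP calls. With the server mocked to return 500, the second from_pretrained in each test exercises the "look in cache when internet is down" behavior from the commit message.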
......@@ -170,13 +170,13 @@ class FeatureExtractorUtilTester(unittest.TestCase):
# A mock response for an HTTP head request to emulate server down
response_mock = mock.Mock()
response_mock.status_code = 500
response_mock.headers = []
response_mock.headers = {}
response_mock.raise_for_status.side_effect = HTTPError
# Download this model to make sure it's in the cache.
_ = Wav2Vec2FeatureExtractor.from_pretrained("hf-internal-testing/tiny-random-wav2vec2")
# Under the mock environment we get a 500 error when trying to reach the model.
with mock.patch("transformers.utils.hub.requests.head", return_value=response_mock) as mock_head:
with mock.patch("requests.request", return_value=response_mock) as mock_head:
_ = Wav2Vec2FeatureExtractor.from_pretrained("hf-internal-testing/tiny-random-wav2vec2")
# Check that we did call the fake head request
mock_head.assert_called()
......
......@@ -2925,14 +2925,14 @@ class ModelUtilsTest(TestCasePlus):
# A mock response for an HTTP head request to emulate server down
response_mock = mock.Mock()
response_mock.status_code = 500
response_mock.headers = []
response_mock.headers = {}
response_mock.raise_for_status.side_effect = HTTPError
# Download this model to make sure it's in the cache.
_ = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
# Under the mock environment we get a 500 error when trying to reach the model.
with mock.patch("transformers.utils.hub.requests.head", return_value=response_mock) as mock_head:
with mock.patch("requests.request", return_value=response_mock) as mock_head:
_ = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
# Check that we did call the fake head request
mock_head.assert_called()
......
......@@ -1922,14 +1922,14 @@ class UtilsFunctionsTest(unittest.TestCase):
# A mock response for an HTTP head request to emulate server down
response_mock = mock.Mock()
response_mock.status_code = 500
response_mock.headers = []
response_mock.headers = {}
response_mock.raise_for_status.side_effect = HTTPError
# Download this model to make sure it's in the cache.
_ = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
# Under the mock environment we get a 500 error when trying to reach the model.
with mock.patch("transformers.utils.hub.requests.head", return_value=response_mock) as mock_head:
with mock.patch("requests.request", return_value=response_mock) as mock_head:
_ = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
# Check that we did call the fake head request
mock_head.assert_called()
......
......@@ -3829,14 +3829,14 @@ class TokenizerUtilTester(unittest.TestCase):
# A mock response for an HTTP head request to emulate server down
response_mock = mock.Mock()
response_mock.status_code = 500
response_mock.headers = []
response_mock.headers = {}
response_mock.raise_for_status.side_effect = HTTPError
# Download this model to make sure it's in the cache.
_ = BertTokenizer.from_pretrained("hf-internal-testing/tiny-random-bert")
# Under the mock environment we get a 500 error when trying to reach the model.
with mock.patch("transformers.utils.hub.requests.head", return_value=response_mock) as mock_head:
with mock.patch("requests.request", return_value=response_mock) as mock_head:
_ = BertTokenizer.from_pretrained("hf-internal-testing/tiny-random-bert")
# Check that we did call the fake head request
mock_head.assert_called()
......