Unverified Commit 5cd40323 authored by Sylvain Gugger, committed by GitHub
Browse files

Use new huggingface_hub tools for download models (#18438)

* Draft new cached_file

* Initial draft for config and model

* Small fixes

* Fix first batch of tests

* Look in cache when internet is down

* Fix last tests

* Bad black, not fixing all quality errors

* Make diff less

* Implement change for TF and Flax models

* Add tokenizer and feature extractor

* For compatibility with main

* Add utils to move the cache and auto-do it at first use.

* Quality

* Deal with empty commit shas

* Deal with empty etag

* Address review comments
parent 70fa1a8d
...@@ -25,25 +25,9 @@ from typing import Any, Dict, List, Optional, Tuple, Union ...@@ -25,25 +25,9 @@ from typing import Any, Dict, List, Optional, Tuple, Union
from packaging import version from packaging import version
from requests import HTTPError
from . import __version__ from . import __version__
from .dynamic_module_utils import custom_object_save from .dynamic_module_utils import custom_object_save
from .utils import ( from .utils import CONFIG_NAME, PushToHubMixin, cached_file, copy_func, is_torch_available, logging
CONFIG_NAME,
HUGGINGFACE_CO_RESOLVE_ENDPOINT,
EntryNotFoundError,
PushToHubMixin,
RepositoryNotFoundError,
RevisionNotFoundError,
cached_path,
copy_func,
hf_bucket_url,
is_offline_mode,
is_remote_url,
is_torch_available,
logging,
)
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
...@@ -591,33 +575,21 @@ class PretrainedConfig(PushToHubMixin): ...@@ -591,33 +575,21 @@ class PretrainedConfig(PushToHubMixin):
if from_pipeline is not None: if from_pipeline is not None:
user_agent["using_pipeline"] = from_pipeline user_agent["using_pipeline"] = from_pipeline
if is_offline_mode() and not local_files_only:
logger.info("Offline mode: forcing local_files_only=True")
local_files_only = True
pretrained_model_name_or_path = str(pretrained_model_name_or_path) pretrained_model_name_or_path = str(pretrained_model_name_or_path)
if os.path.isfile(os.path.join(subfolder, pretrained_model_name_or_path)) or is_remote_url(
pretrained_model_name_or_path
):
config_file = pretrained_model_name_or_path
else:
configuration_file = kwargs.pop("_configuration_file", CONFIG_NAME)
if os.path.isdir(os.path.join(pretrained_model_name_or_path, subfolder)): is_local = os.path.isdir(pretrained_model_name_or_path)
config_file = os.path.join(pretrained_model_name_or_path, subfolder, configuration_file) if os.path.isfile(os.path.join(subfolder, pretrained_model_name_or_path)):
# Special case when pretrained_model_name_or_path is a local file
resolved_config_file = pretrained_model_name_or_path
is_local = True
else: else:
config_file = hf_bucket_url( configuration_file = kwargs.pop("_configuration_file", CONFIG_NAME)
pretrained_model_name_or_path,
filename=configuration_file,
revision=revision,
subfolder=subfolder if len(subfolder) > 0 else None,
mirror=None,
)
try: try:
# Load from URL or cache if already cached # Load from local folder or from cache or download from model Hub and cache
resolved_config_file = cached_path( resolved_config_file = cached_file(
config_file, pretrained_model_name_or_path,
configuration_file,
cache_dir=cache_dir, cache_dir=cache_dir,
force_download=force_download, force_download=force_download,
proxies=proxies, proxies=proxies,
...@@ -625,42 +597,20 @@ class PretrainedConfig(PushToHubMixin): ...@@ -625,42 +597,20 @@ class PretrainedConfig(PushToHubMixin):
local_files_only=local_files_only, local_files_only=local_files_only,
use_auth_token=use_auth_token, use_auth_token=use_auth_token,
user_agent=user_agent, user_agent=user_agent,
) revision=revision,
subfolder=subfolder,
except RepositoryNotFoundError:
raise EnvironmentError(
f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier listed on "
"'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a token having "
"permission to this repo with `use_auth_token` or log in with `huggingface-cli login` and pass "
"`use_auth_token=True`."
)
except RevisionNotFoundError:
raise EnvironmentError(
f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for this "
f"model name. Check the model page at 'https://huggingface.co/{pretrained_model_name_or_path}' for "
"available revisions."
)
except EntryNotFoundError:
raise EnvironmentError(
f"{pretrained_model_name_or_path} does not appear to have a file named {configuration_file}."
)
except HTTPError as err:
raise EnvironmentError(
f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n{err}"
)
except ValueError:
raise EnvironmentError(
f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it in"
f" the cached files and it looks like {pretrained_model_name_or_path} is not the path to a directory"
f" containing a {configuration_file} file.\nCheckout your internet connection or see how to run the"
" library in offline mode at 'https://huggingface.co/docs/transformers/installation#offline-mode'."
) )
except EnvironmentError: except EnvironmentError:
# Raise any environment error raised by `cached_file`. It will have a helpful error message adapted to
# the original exception.
raise
except Exception:
# For any other exception, we throw a generic error.
raise EnvironmentError( raise EnvironmentError(
f"Can't load config for '{pretrained_model_name_or_path}'. If you were trying to load it from " f"Can't load the configuration of '{pretrained_model_name_or_path}'. If you were trying to load it"
"'https://huggingface.co/models', make sure you don't have a local directory with the same name. " " from 'https://huggingface.co/models', make sure you don't have a local directory with the same"
f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory " f" name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory"
f"containing a {configuration_file} file" f" containing a {configuration_file} file"
) )
try: try:
...@@ -671,10 +621,10 @@ class PretrainedConfig(PushToHubMixin): ...@@ -671,10 +621,10 @@ class PretrainedConfig(PushToHubMixin):
f"It looks like the config file at '{resolved_config_file}' is not a valid JSON file." f"It looks like the config file at '{resolved_config_file}' is not a valid JSON file."
) )
if resolved_config_file == config_file: if is_local:
logger.info(f"loading configuration file {config_file}") logger.info(f"loading configuration file {resolved_config_file}")
else: else:
logger.info(f"loading configuration file {config_file} from cache at {resolved_config_file}") logger.info(f"loading configuration file {configuration_file} from cache at {resolved_config_file}")
return config_dict, kwargs return config_dict, kwargs
......
...@@ -24,23 +24,15 @@ from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union ...@@ -24,23 +24,15 @@ from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
import numpy as np import numpy as np
from requests import HTTPError
from .dynamic_module_utils import custom_object_save from .dynamic_module_utils import custom_object_save
from .utils import ( from .utils import (
FEATURE_EXTRACTOR_NAME, FEATURE_EXTRACTOR_NAME,
HUGGINGFACE_CO_RESOLVE_ENDPOINT,
EntryNotFoundError,
PushToHubMixin, PushToHubMixin,
RepositoryNotFoundError,
RevisionNotFoundError,
TensorType, TensorType,
cached_path, cached_file,
copy_func, copy_func,
hf_bucket_url,
is_flax_available, is_flax_available,
is_offline_mode, is_offline_mode,
is_remote_url,
is_tf_available, is_tf_available,
is_torch_available, is_torch_available,
logging, logging,
...@@ -388,18 +380,18 @@ class FeatureExtractionMixin(PushToHubMixin): ...@@ -388,18 +380,18 @@ class FeatureExtractionMixin(PushToHubMixin):
local_files_only = True local_files_only = True
pretrained_model_name_or_path = str(pretrained_model_name_or_path) pretrained_model_name_or_path = str(pretrained_model_name_or_path)
is_local = os.path.isdir(pretrained_model_name_or_path)
if os.path.isdir(pretrained_model_name_or_path): if os.path.isdir(pretrained_model_name_or_path):
feature_extractor_file = os.path.join(pretrained_model_name_or_path, FEATURE_EXTRACTOR_NAME) feature_extractor_file = os.path.join(pretrained_model_name_or_path, FEATURE_EXTRACTOR_NAME)
elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path): if os.path.isfile(pretrained_model_name_or_path):
feature_extractor_file = pretrained_model_name_or_path resolved_feature_extractor_file = pretrained_model_name_or_path
is_local = True
else: else:
feature_extractor_file = hf_bucket_url( feature_extractor_file = FEATURE_EXTRACTOR_NAME
pretrained_model_name_or_path, filename=FEATURE_EXTRACTOR_NAME, revision=revision, mirror=None
)
try: try:
# Load from URL or cache if already cached # Load from local folder or from cache or download from model Hub and cache
resolved_feature_extractor_file = cached_path( resolved_feature_extractor_file = cached_file(
pretrained_model_name_or_path,
feature_extractor_file, feature_extractor_file,
cache_dir=cache_dir, cache_dir=cache_dir,
force_download=force_download, force_download=force_download,
...@@ -408,43 +400,19 @@ class FeatureExtractionMixin(PushToHubMixin): ...@@ -408,43 +400,19 @@ class FeatureExtractionMixin(PushToHubMixin):
local_files_only=local_files_only, local_files_only=local_files_only,
use_auth_token=use_auth_token, use_auth_token=use_auth_token,
user_agent=user_agent, user_agent=user_agent,
) revision=revision,
except RepositoryNotFoundError:
raise EnvironmentError(
f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier listed on "
"'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a token having "
"permission to this repo with `use_auth_token` or log in with `huggingface-cli login` and pass "
"`use_auth_token=True`."
)
except RevisionNotFoundError:
raise EnvironmentError(
f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for this "
f"model name. Check the model page at 'https://huggingface.co/{pretrained_model_name_or_path}' for "
"available revisions."
)
except EntryNotFoundError:
raise EnvironmentError(
f"{pretrained_model_name_or_path} does not appear to have a file named {FEATURE_EXTRACTOR_NAME}."
)
except HTTPError as err:
raise EnvironmentError(
f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n{err}"
)
except ValueError:
raise EnvironmentError(
f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it in"
f" the cached files and it looks like {pretrained_model_name_or_path} is not the path to a directory"
f" containing a {FEATURE_EXTRACTOR_NAME} file.\nCheckout your internet connection or see how to run"
" the library in offline mode at"
" 'https://huggingface.co/docs/transformers/installation#offline-mode'."
) )
except EnvironmentError: except EnvironmentError:
# Raise any environment error raised by `cached_file`. It will have a helpful error message adapted to
# the original exception.
raise
except Exception:
# For any other exception, we throw a generic error.
raise EnvironmentError( raise EnvironmentError(
f"Can't load feature extractor for '{pretrained_model_name_or_path}'. If you were trying to load it " f"Can't load feature extractor for '{pretrained_model_name_or_path}'. If you were trying to load"
"from 'https://huggingface.co/models', make sure you don't have a local directory with the same name. " " it from 'https://huggingface.co/models', make sure you don't have a local directory with the"
f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory " f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a"
f"containing a {FEATURE_EXTRACTOR_NAME} file" f" directory containing a {FEATURE_EXTRACTOR_NAME} file"
) )
try: try:
...@@ -458,12 +426,11 @@ class FeatureExtractionMixin(PushToHubMixin): ...@@ -458,12 +426,11 @@ class FeatureExtractionMixin(PushToHubMixin):
f"It looks like the config file at '{resolved_feature_extractor_file}' is not a valid JSON file." f"It looks like the config file at '{resolved_feature_extractor_file}' is not a valid JSON file."
) )
if resolved_feature_extractor_file == feature_extractor_file: if is_local:
logger.info(f"loading feature extractor configuration file {feature_extractor_file}") logger.info(f"loading configuration file {resolved_feature_extractor_file}")
else: else:
logger.info( logger.info(
f"loading feature extractor configuration file {feature_extractor_file} from cache at" f"loading configuration file {feature_extractor_file} from cache at {resolved_feature_extractor_file}"
f" {resolved_feature_extractor_file}"
) )
return feature_extractor_dict, kwargs return feature_extractor_dict, kwargs
......
...@@ -32,7 +32,6 @@ from flax.core.frozen_dict import FrozenDict, unfreeze ...@@ -32,7 +32,6 @@ from flax.core.frozen_dict import FrozenDict, unfreeze
from flax.serialization import from_bytes, to_bytes from flax.serialization import from_bytes, to_bytes
from flax.traverse_util import flatten_dict, unflatten_dict from flax.traverse_util import flatten_dict, unflatten_dict
from jax.random import PRNGKey from jax.random import PRNGKey
from requests import HTTPError
from .configuration_utils import PretrainedConfig from .configuration_utils import PretrainedConfig
from .dynamic_module_utils import custom_object_save from .dynamic_module_utils import custom_object_save
...@@ -41,20 +40,14 @@ from .modeling_flax_pytorch_utils import load_pytorch_checkpoint_in_flax_state_d ...@@ -41,20 +40,14 @@ from .modeling_flax_pytorch_utils import load_pytorch_checkpoint_in_flax_state_d
from .utils import ( from .utils import (
FLAX_WEIGHTS_INDEX_NAME, FLAX_WEIGHTS_INDEX_NAME,
FLAX_WEIGHTS_NAME, FLAX_WEIGHTS_NAME,
HUGGINGFACE_CO_RESOLVE_ENDPOINT,
WEIGHTS_NAME, WEIGHTS_NAME,
EntryNotFoundError,
PushToHubMixin, PushToHubMixin,
RepositoryNotFoundError,
RevisionNotFoundError,
add_code_sample_docstrings, add_code_sample_docstrings,
add_start_docstrings_to_model_forward, add_start_docstrings_to_model_forward,
cached_path, cached_file,
copy_func, copy_func,
has_file, has_file,
hf_bucket_url,
is_offline_mode, is_offline_mode,
is_remote_url,
logging, logging,
replace_return_docstrings, replace_return_docstrings,
) )
...@@ -557,6 +550,9 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin): ...@@ -557,6 +550,9 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
identifier allowed by git. identifier allowed by git.
subfolder (`str`, *optional*, defaults to `""`):
In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can
specify the folder name here.
kwargs (remaining dictionary of keyword arguments, *optional*): kwargs (remaining dictionary of keyword arguments, *optional*):
Can be used to update the configuration object (after it being loaded) and initiate the model (e.g., Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
`output_attentions=True`). Behaves differently depending on whether a `config` is provided or `output_attentions=True`). Behaves differently depending on whether a `config` is provided or
...@@ -598,6 +594,7 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin): ...@@ -598,6 +594,7 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
from_pipeline = kwargs.pop("_from_pipeline", None) from_pipeline = kwargs.pop("_from_pipeline", None)
from_auto_class = kwargs.pop("_from_auto", False) from_auto_class = kwargs.pop("_from_auto", False)
_do_init = kwargs.pop("_do_init", True) _do_init = kwargs.pop("_do_init", True)
subfolder = kwargs.pop("subfolder", "")
if trust_remote_code is True: if trust_remote_code is True:
logger.warning( logger.warning(
...@@ -642,6 +639,8 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin): ...@@ -642,6 +639,8 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
# Load model # Load model
if pretrained_model_name_or_path is not None: if pretrained_model_name_or_path is not None:
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
is_local = os.path.isdir(pretrained_model_name_or_path)
if os.path.isdir(pretrained_model_name_or_path): if os.path.isdir(pretrained_model_name_or_path):
if from_pt and os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)): if from_pt and os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
# Load from a PyTorch checkpoint # Load from a PyTorch checkpoint
...@@ -665,21 +664,14 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin): ...@@ -665,21 +664,14 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
f"Error no file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME} found in directory " f"Error no file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME} found in directory "
f"{pretrained_model_name_or_path}." f"{pretrained_model_name_or_path}."
) )
elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path): elif os.path.isfile(pretrained_model_name_or_path):
archive_file = pretrained_model_name_or_path archive_file = pretrained_model_name_or_path
is_local = True
else: else:
filename = WEIGHTS_NAME if from_pt else FLAX_WEIGHTS_NAME filename = WEIGHTS_NAME if from_pt else FLAX_WEIGHTS_NAME
archive_file = hf_bucket_url(
pretrained_model_name_or_path,
filename=filename,
revision=revision,
)
# redirect to the cache, if necessary
try: try:
resolved_archive_file = cached_path( # Load from URL or cache if already cached
archive_file, cached_file_kwargs = dict(
cache_dir=cache_dir, cache_dir=cache_dir,
force_download=force_download, force_download=force_download,
proxies=proxies, proxies=proxies,
...@@ -687,43 +679,29 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin): ...@@ -687,43 +679,29 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
local_files_only=local_files_only, local_files_only=local_files_only,
use_auth_token=use_auth_token, use_auth_token=use_auth_token,
user_agent=user_agent, user_agent=user_agent,
revision=revision,
subfolder=subfolder,
_raise_exceptions_for_missing_entries=False,
) )
resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs)
except RepositoryNotFoundError: # Since we set _raise_exceptions_for_missing_entries=False, we don't get an exception but a None
raise EnvironmentError( # result when internet is up, the repo and revision exist, but the file does not.
f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier " if resolved_archive_file is None and filename == FLAX_WEIGHTS_NAME:
"listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a "
"token having permission to this repo with `use_auth_token` or log in with `huggingface-cli "
"login` and pass `use_auth_token=True`."
)
except RevisionNotFoundError:
raise EnvironmentError(
f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for "
"this model name. Check the model page at "
f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
)
except EntryNotFoundError:
if filename == FLAX_WEIGHTS_NAME:
try:
# Maybe the checkpoint is sharded, we try to grab the index name in this case. # Maybe the checkpoint is sharded, we try to grab the index name in this case.
archive_file = hf_bucket_url( resolved_archive_file = cached_file(
pretrained_model_name_or_path, pretrained_model_name_or_path, FLAX_WEIGHTS_INDEX_NAME, **cached_file_kwargs
filename=FLAX_WEIGHTS_INDEX_NAME,
revision=revision,
)
resolved_archive_file = cached_path(
archive_file,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
user_agent=user_agent,
) )
if resolved_archive_file is not None:
is_sharded = True is_sharded = True
except EntryNotFoundError: if resolved_archive_file is None:
has_file_kwargs = {"revision": revision, "proxies": proxies, "use_auth_token": use_auth_token} # Otherwise, maybe there is a TF or Flax model file. We try those to give a helpful error
# message.
has_file_kwargs = {
"revision": revision,
"proxies": proxies,
"use_auth_token": use_auth_token,
}
if has_file(pretrained_model_name_or_path, WEIGHTS_NAME, **has_file_kwargs): if has_file(pretrained_model_name_or_path, WEIGHTS_NAME, **has_file_kwargs):
raise EnvironmentError( raise EnvironmentError(
f"{pretrained_model_name_or_path} does not appear to have a file named" f"{pretrained_model_name_or_path} does not appear to have a file named"
...@@ -735,35 +713,24 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin): ...@@ -735,35 +713,24 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
f"{pretrained_model_name_or_path} does not appear to have a file named" f"{pretrained_model_name_or_path} does not appear to have a file named"
f" {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME}." f" {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME}."
) )
else:
raise EnvironmentError(
f"{pretrained_model_name_or_path} does not appear to have a file named {filename}."
)
except HTTPError as err:
raise EnvironmentError(
f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n"
f"{err}"
)
except ValueError:
raise EnvironmentError(
f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
f" directory containing a file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME}.\nCheckout your"
" internet connection or see how to run the library in offline mode at"
" 'https://huggingface.co/docs/transformers/installation#offline-mode'."
)
except EnvironmentError: except EnvironmentError:
# Raise any environment error raised by `cached_file`. It will have a helpful error message adapted
# to the original exception.
raise
except Exception:
# For any other exception, we throw a generic error.
raise EnvironmentError( raise EnvironmentError(
f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it from " f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it"
"'https://huggingface.co/models', make sure you don't have a local directory with the same name. " " from 'https://huggingface.co/models', make sure you don't have a local directory with the"
f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory " f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a"
f"containing a file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME}." f" directory containing a file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME}."
) )
if resolved_archive_file == archive_file: if is_local:
logger.info(f"loading weights file {archive_file}") logger.info(f"loading weights file {archive_file}")
resolved_archive_file = archive_file
else: else:
logger.info(f"loading weights file {archive_file} from cache at {resolved_archive_file}") logger.info(f"loading weights file {filename} from cache at {resolved_archive_file}")
else: else:
resolved_archive_file = None resolved_archive_file = None
......
...@@ -37,7 +37,6 @@ from tensorflow.python.keras.saving import hdf5_format ...@@ -37,7 +37,6 @@ from tensorflow.python.keras.saving import hdf5_format
from huggingface_hub import Repository, list_repo_files from huggingface_hub import Repository, list_repo_files
from keras.saving.hdf5_format import save_attributes_to_hdf5_group from keras.saving.hdf5_format import save_attributes_to_hdf5_group
from requests import HTTPError
from transformers.utils.hub import convert_file_size_to_int, get_checkpoint_shard_files from transformers.utils.hub import convert_file_size_to_int, get_checkpoint_shard_files
from . import DataCollatorWithPadding, DefaultDataCollator from . import DataCollatorWithPadding, DefaultDataCollator
...@@ -48,22 +47,16 @@ from .generation_tf_utils import TFGenerationMixin ...@@ -48,22 +47,16 @@ from .generation_tf_utils import TFGenerationMixin
from .tf_utils import shape_list from .tf_utils import shape_list
from .utils import ( from .utils import (
DUMMY_INPUTS, DUMMY_INPUTS,
HUGGINGFACE_CO_RESOLVE_ENDPOINT,
TF2_WEIGHTS_INDEX_NAME, TF2_WEIGHTS_INDEX_NAME,
TF2_WEIGHTS_NAME, TF2_WEIGHTS_NAME,
WEIGHTS_INDEX_NAME, WEIGHTS_INDEX_NAME,
WEIGHTS_NAME, WEIGHTS_NAME,
EntryNotFoundError,
ModelOutput, ModelOutput,
PushToHubMixin, PushToHubMixin,
RepositoryNotFoundError, cached_file,
RevisionNotFoundError,
cached_path,
find_labels, find_labels,
has_file, has_file,
hf_bucket_url,
is_offline_mode, is_offline_mode,
is_remote_url,
logging, logging,
requires_backends, requires_backends,
working_or_temp_dir, working_or_temp_dir,
...@@ -2112,6 +2105,9 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu ...@@ -2112,6 +2105,9 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
Mirror source to accelerate downloads in China. If you are from China and have an accessibility Mirror source to accelerate downloads in China. If you are from China and have an accessibility
problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety. problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
Please refer to the mirror site for more information. Please refer to the mirror site for more information.
subfolder (`str`, *optional*, defaults to `""`):
In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can
specify the folder name here.
kwargs (remaining dictionary of keyword arguments, *optional*): kwargs (remaining dictionary of keyword arguments, *optional*):
Can be used to update the configuration object (after it being loaded) and initiate the model (e.g., Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
`output_attentions=True`). Behaves differently depending on whether a `config` is provided or `output_attentions=True`). Behaves differently depending on whether a `config` is provided or
...@@ -2164,6 +2160,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu ...@@ -2164,6 +2160,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
load_weight_prefix = kwargs.pop("load_weight_prefix", None) load_weight_prefix = kwargs.pop("load_weight_prefix", None)
from_pipeline = kwargs.pop("_from_pipeline", None) from_pipeline = kwargs.pop("_from_pipeline", None)
from_auto_class = kwargs.pop("_from_auto", False) from_auto_class = kwargs.pop("_from_auto", False)
subfolder = kwargs.pop("subfolder", "")
if trust_remote_code is True: if trust_remote_code is True:
logger.warning( logger.warning(
...@@ -2202,9 +2199,10 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu ...@@ -2202,9 +2199,10 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
# This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the # This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the
# index of the files. # index of the files.
is_sharded = False is_sharded = False
sharded_metadata = None
# Load model # Load model
if pretrained_model_name_or_path is not None: if pretrained_model_name_or_path is not None:
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
is_local = os.path.isdir(pretrained_model_name_or_path)
if os.path.isdir(pretrained_model_name_or_path): if os.path.isdir(pretrained_model_name_or_path):
if from_pt and os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)): if from_pt and os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
# Load from a PyTorch checkpoint in priority if from_pt # Load from a PyTorch checkpoint in priority if from_pt
...@@ -2232,23 +2230,19 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu ...@@ -2232,23 +2230,19 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
f"Error no file named {TF2_WEIGHTS_NAME} or {WEIGHTS_NAME} found in directory " f"Error no file named {TF2_WEIGHTS_NAME} or {WEIGHTS_NAME} found in directory "
f"{pretrained_model_name_or_path}." f"{pretrained_model_name_or_path}."
) )
elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path): elif os.path.isfile(pretrained_model_name_or_path):
archive_file = pretrained_model_name_or_path archive_file = pretrained_model_name_or_path
is_local = True
elif os.path.isfile(pretrained_model_name_or_path + ".index"): elif os.path.isfile(pretrained_model_name_or_path + ".index"):
archive_file = pretrained_model_name_or_path + ".index" archive_file = pretrained_model_name_or_path + ".index"
is_local = True
else: else:
# set correct filename
filename = WEIGHTS_NAME if from_pt else TF2_WEIGHTS_NAME filename = WEIGHTS_NAME if from_pt else TF2_WEIGHTS_NAME
archive_file = hf_bucket_url(
pretrained_model_name_or_path,
filename=filename,
revision=revision,
mirror=mirror,
)
try: try:
# Load from URL or cache if already cached # Load from URL or cache if already cached
resolved_archive_file = cached_path( cached_file_kwargs = dict(
archive_file,
cache_dir=cache_dir, cache_dir=cache_dir,
force_download=force_download, force_download=force_download,
proxies=proxies, proxies=proxies,
...@@ -2256,44 +2250,23 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu ...@@ -2256,44 +2250,23 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
local_files_only=local_files_only, local_files_only=local_files_only,
use_auth_token=use_auth_token, use_auth_token=use_auth_token,
user_agent=user_agent, user_agent=user_agent,
revision=revision,
subfolder=subfolder,
_raise_exceptions_for_missing_entries=False,
) )
resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs)
except RepositoryNotFoundError: # Since we set _raise_exceptions_for_missing_entries=False, we don't get an exception but a None
raise EnvironmentError( # result when internet is up, the repo and revision exist, but the file does not.
f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier " if resolved_archive_file is None and filename == TF2_WEIGHTS_NAME:
"listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a "
"token having permission to this repo with `use_auth_token` or log in with `huggingface-cli "
"login` and pass `use_auth_token=True`."
)
except RevisionNotFoundError:
raise EnvironmentError(
f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for "
"this model name. Check the model page at "
f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
)
except EntryNotFoundError:
if filename == TF2_WEIGHTS_NAME:
try:
# Maybe the checkpoint is sharded, we try to grab the index name in this case. # Maybe the checkpoint is sharded, we try to grab the index name in this case.
archive_file = hf_bucket_url( resolved_archive_file = cached_file(
pretrained_model_name_or_path, pretrained_model_name_or_path, TF2_WEIGHTS_INDEX_NAME, **cached_file_kwargs
filename=TF2_WEIGHTS_INDEX_NAME,
revision=revision,
mirror=mirror,
)
resolved_archive_file = cached_path(
archive_file,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
user_agent=user_agent,
) )
if resolved_archive_file is not None:
is_sharded = True is_sharded = True
except EntryNotFoundError: if resolved_archive_file is None:
# Otherwise, maybe there is a TF or Flax model file. We try those to give a helpful error # Otherwise, maybe there is a PyTorch or Flax model file. We try those to give a helpful error
# message. # message.
has_file_kwargs = { has_file_kwargs = {
"revision": revision, "revision": revision,
...@@ -2312,42 +2285,32 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu ...@@ -2312,42 +2285,32 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
f"{pretrained_model_name_or_path} does not appear to have a file named" f"{pretrained_model_name_or_path} does not appear to have a file named"
f" {TF2_WEIGHTS_NAME} or {WEIGHTS_NAME}." f" {TF2_WEIGHTS_NAME} or {WEIGHTS_NAME}."
) )
else:
raise EnvironmentError(
f"{pretrained_model_name_or_path} does not appear to have a file named {filename}."
)
except HTTPError as err:
raise EnvironmentError(
f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n"
f"{err}"
)
except ValueError:
raise EnvironmentError(
f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
f" directory containing a file named {TF2_WEIGHTS_NAME} or {WEIGHTS_NAME}.\nCheckout your internet"
" connection or see how to run the library in offline mode at"
" 'https://huggingface.co/docs/transformers/installation#offline-mode'."
)
except EnvironmentError: except EnvironmentError:
# Raise any environment error raised by `cached_file`. It will have a helpful error message adapted
# to the original exception.
raise
except Exception:
# For any other exception, we throw a generic error.
raise EnvironmentError( raise EnvironmentError(
f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it from " f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it"
"'https://huggingface.co/models', make sure you don't have a local directory with the same name. " " from 'https://huggingface.co/models', make sure you don't have a local directory with the"
f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory " f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a"
f"containing a file named {TF2_WEIGHTS_NAME} or {WEIGHTS_NAME}." f" directory containing a file named {TF2_WEIGHTS_NAME} or {WEIGHTS_NAME}."
) )
if is_local:
if resolved_archive_file == archive_file:
logger.info(f"loading weights file {archive_file}") logger.info(f"loading weights file {archive_file}")
resolved_archive_file = archive_file
else: else:
logger.info(f"loading weights file {archive_file} from cache at {resolved_archive_file}") logger.info(f"loading weights file {filename} from cache at {resolved_archive_file}")
else: else:
resolved_archive_file = None resolved_archive_file = None
# We'll need to download and cache each checkpoint shard if the checkpoint is sharded. # We'll need to download and cache each checkpoint shard if the checkpoint is sharded.
if is_sharded: if is_sharded:
# resolved_archive_file becomes a list of files that point to the different checkpoint shards in this case. # resolved_archive_file becomes a list of files that point to the different checkpoint shards in this case.
resolved_archive_file, sharded_metadata = get_checkpoint_shard_files( resolved_archive_file, _ = get_checkpoint_shard_files(
pretrained_model_name_or_path, pretrained_model_name_or_path,
resolved_archive_file, resolved_archive_file,
cache_dir=cache_dir, cache_dir=cache_dir,
......
...@@ -31,7 +31,6 @@ from packaging import version ...@@ -31,7 +31,6 @@ from packaging import version
from torch import Tensor, device, nn from torch import Tensor, device, nn
from torch.nn import CrossEntropyLoss from torch.nn import CrossEntropyLoss
from requests import HTTPError
from transformers.utils.hub import convert_file_size_to_int, get_checkpoint_shard_files from transformers.utils.hub import convert_file_size_to_int, get_checkpoint_shard_files
from transformers.utils.import_utils import is_sagemaker_mp_enabled from transformers.utils.import_utils import is_sagemaker_mp_enabled
...@@ -51,24 +50,18 @@ from .pytorch_utils import ( # noqa: F401 ...@@ -51,24 +50,18 @@ from .pytorch_utils import ( # noqa: F401
from .utils import ( from .utils import (
DUMMY_INPUTS, DUMMY_INPUTS,
FLAX_WEIGHTS_NAME, FLAX_WEIGHTS_NAME,
HUGGINGFACE_CO_RESOLVE_ENDPOINT,
TF2_WEIGHTS_NAME, TF2_WEIGHTS_NAME,
TF_WEIGHTS_NAME, TF_WEIGHTS_NAME,
WEIGHTS_INDEX_NAME, WEIGHTS_INDEX_NAME,
WEIGHTS_NAME, WEIGHTS_NAME,
ContextManagers, ContextManagers,
EntryNotFoundError,
ModelOutput, ModelOutput,
PushToHubMixin, PushToHubMixin,
RepositoryNotFoundError, cached_file,
RevisionNotFoundError,
cached_path,
copy_func, copy_func,
has_file, has_file,
hf_bucket_url,
is_accelerate_available, is_accelerate_available,
is_offline_mode, is_offline_mode,
is_remote_url,
logging, logging,
replace_return_docstrings, replace_return_docstrings,
) )
...@@ -1868,7 +1861,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix ...@@ -1868,7 +1861,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
if pretrained_model_name_or_path is not None: if pretrained_model_name_or_path is not None:
pretrained_model_name_or_path = str(pretrained_model_name_or_path) pretrained_model_name_or_path = str(pretrained_model_name_or_path)
if os.path.isdir(pretrained_model_name_or_path): is_local = os.path.isdir(pretrained_model_name_or_path)
if is_local:
if from_tf and os.path.isfile( if from_tf and os.path.isfile(
os.path.join(pretrained_model_name_or_path, subfolder, TF_WEIGHTS_NAME + ".index") os.path.join(pretrained_model_name_or_path, subfolder, TF_WEIGHTS_NAME + ".index")
): ):
...@@ -1911,10 +1905,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix ...@@ -1911,10 +1905,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
f"Error no file named {WEIGHTS_NAME}, {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME + '.index'} or " f"Error no file named {WEIGHTS_NAME}, {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME + '.index'} or "
f"{FLAX_WEIGHTS_NAME} found in directory {pretrained_model_name_or_path}." f"{FLAX_WEIGHTS_NAME} found in directory {pretrained_model_name_or_path}."
) )
elif os.path.isfile(os.path.join(subfolder, pretrained_model_name_or_path)) or is_remote_url( elif os.path.isfile(os.path.join(subfolder, pretrained_model_name_or_path)):
pretrained_model_name_or_path
):
archive_file = pretrained_model_name_or_path archive_file = pretrained_model_name_or_path
is_local = True
elif os.path.isfile(os.path.join(subfolder, pretrained_model_name_or_path + ".index")): elif os.path.isfile(os.path.join(subfolder, pretrained_model_name_or_path + ".index")):
if not from_tf: if not from_tf:
raise ValueError( raise ValueError(
...@@ -1922,6 +1915,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix ...@@ -1922,6 +1915,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
"from_tf to True to load from this checkpoint." "from_tf to True to load from this checkpoint."
) )
archive_file = os.path.join(subfolder, pretrained_model_name_or_path + ".index") archive_file = os.path.join(subfolder, pretrained_model_name_or_path + ".index")
is_local = True
else: else:
# set correct filename # set correct filename
if from_tf: if from_tf:
...@@ -1931,18 +1925,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix ...@@ -1931,18 +1925,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
else: else:
filename = WEIGHTS_NAME filename = WEIGHTS_NAME
archive_file = hf_bucket_url(
pretrained_model_name_or_path,
filename=filename,
revision=revision,
mirror=mirror,
subfolder=subfolder if len(subfolder) > 0 else None,
)
try: try:
# Load from URL or cache if already cached # Load from URL or cache if already cached
resolved_archive_file = cached_path( cached_file_kwargs = dict(
archive_file,
cache_dir=cache_dir, cache_dir=cache_dir,
force_download=force_download, force_download=force_download,
proxies=proxies, proxies=proxies,
...@@ -1950,44 +1935,22 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix ...@@ -1950,44 +1935,22 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
local_files_only=local_files_only, local_files_only=local_files_only,
use_auth_token=use_auth_token, use_auth_token=use_auth_token,
user_agent=user_agent, user_agent=user_agent,
revision=revision,
subfolder=subfolder,
_raise_exceptions_for_missing_entries=False,
) )
resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs)
except RepositoryNotFoundError: # Since we set _raise_exceptions_for_missing_entries=False, we don't get an exception but a None
raise EnvironmentError( # result when internet is up, the repo and revision exist, but the file does not.
f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier " if resolved_archive_file is None and filename == WEIGHTS_NAME:
"listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a "
"token having permission to this repo with `use_auth_token` or log in with `huggingface-cli "
"login` and pass `use_auth_token=True`."
)
except RevisionNotFoundError:
raise EnvironmentError(
f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for "
"this model name. Check the model page at "
f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
)
except EntryNotFoundError:
if filename == WEIGHTS_NAME:
try:
# Maybe the checkpoint is sharded, we try to grab the index name in this case. # Maybe the checkpoint is sharded, we try to grab the index name in this case.
archive_file = hf_bucket_url( resolved_archive_file = cached_file(
pretrained_model_name_or_path, pretrained_model_name_or_path, WEIGHTS_INDEX_NAME, **cached_file_kwargs
filename=WEIGHTS_INDEX_NAME,
revision=revision,
mirror=mirror,
subfolder=subfolder if len(subfolder) > 0 else None,
)
resolved_archive_file = cached_path(
archive_file,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
user_agent=user_agent,
) )
if resolved_archive_file is not None:
is_sharded = True is_sharded = True
except EntryNotFoundError: if resolved_archive_file is None:
# Otherwise, maybe there is a TF or Flax model file. We try those to give a helpful error # Otherwise, maybe there is a TF or Flax model file. We try those to give a helpful error
# message. # message.
has_file_kwargs = { has_file_kwargs = {
...@@ -2013,42 +1976,31 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix ...@@ -2013,42 +1976,31 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
f"{pretrained_model_name_or_path} does not appear to have a file named {WEIGHTS_NAME}," f"{pretrained_model_name_or_path} does not appear to have a file named {WEIGHTS_NAME},"
f" {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME} or {FLAX_WEIGHTS_NAME}." f" {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME} or {FLAX_WEIGHTS_NAME}."
) )
else:
raise EnvironmentError(
f"{pretrained_model_name_or_path} does not appear to have a file named {filename}."
)
except HTTPError as err:
raise EnvironmentError(
f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n"
f"{err}"
)
except ValueError:
raise EnvironmentError(
f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
f" directory containing a file named {WEIGHTS_NAME}, {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME} or"
f" {FLAX_WEIGHTS_NAME}.\nCheckout your internet connection or see how to run the library in"
" offline mode at 'https://huggingface.co/docs/transformers/installation#offline-mode'."
)
except EnvironmentError: except EnvironmentError:
# Raise any environment error raised by `cached_file`. It will have a helpful error message adapted
# to the original exception.
raise
except Exception:
# For any other exception, we throw a generic error.
raise EnvironmentError( raise EnvironmentError(
f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it from " f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it"
"'https://huggingface.co/models', make sure you don't have a local directory with the same name. " " from 'https://huggingface.co/models', make sure you don't have a local directory with the"
f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory " f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a"
f"containing a file named {WEIGHTS_NAME}, {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME} or " f" directory containing a file named {WEIGHTS_NAME}, {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME} or"
f"{FLAX_WEIGHTS_NAME}." f" {FLAX_WEIGHTS_NAME}."
) )
if resolved_archive_file == archive_file: if is_local:
logger.info(f"loading weights file {archive_file}") logger.info(f"loading weights file {archive_file}")
resolved_archive_file = archive_file
else: else:
logger.info(f"loading weights file {archive_file} from cache at {resolved_archive_file}") logger.info(f"loading weights file {filename} from cache at {resolved_archive_file}")
else: else:
resolved_archive_file = None resolved_archive_file = None
# We'll need to download and cache each checkpoint shard if the checkpoint is sharded. # We'll need to download and cache each checkpoint shard if the checkpoint is sharded.
if is_sharded: if is_sharded:
# resolved_archive_file becomes a list of files that point to the different checkpoint shards in this case. # resolved_archive_file becomes a list of files that point to the different checkpoint shards in this case.
resolved_archive_file, sharded_metadata = get_checkpoint_shard_files( resolved_archive_file, sharded_metadata = get_checkpoint_shard_files(
pretrained_model_name_or_path, pretrained_model_name_or_path,
resolved_archive_file, resolved_archive_file,
......
...@@ -35,21 +35,16 @@ from packaging import version ...@@ -35,21 +35,16 @@ from packaging import version
from . import __version__ from . import __version__
from .dynamic_module_utils import custom_object_save from .dynamic_module_utils import custom_object_save
from .utils import ( from .utils import (
EntryNotFoundError,
ExplicitEnum, ExplicitEnum,
PaddingStrategy, PaddingStrategy,
PushToHubMixin, PushToHubMixin,
RepositoryNotFoundError,
RevisionNotFoundError,
TensorType, TensorType,
add_end_docstrings, add_end_docstrings,
cached_path, cached_file,
copy_func, copy_func,
get_file_from_repo, get_file_from_repo,
hf_bucket_url,
is_flax_available, is_flax_available,
is_offline_mode, is_offline_mode,
is_remote_url,
is_tf_available, is_tf_available,
is_tokenizers_available, is_tokenizers_available,
is_torch_available, is_torch_available,
...@@ -1669,7 +1664,8 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin): ...@@ -1669,7 +1664,8 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
vocab_files = {} vocab_files = {}
init_configuration = {} init_configuration = {}
if os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path): is_local = os.path.isdir(pretrained_model_name_or_path)
if os.path.isfile(pretrained_model_name_or_path):
if len(cls.vocab_files_names) > 1: if len(cls.vocab_files_names) > 1:
raise ValueError( raise ValueError(
f"Calling {cls.__name__}.from_pretrained() with the path to a single file or url is not " f"Calling {cls.__name__}.from_pretrained() with the path to a single file or url is not "
...@@ -1689,9 +1685,9 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin): ...@@ -1689,9 +1685,9 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
"special_tokens_map_file": SPECIAL_TOKENS_MAP_FILE, "special_tokens_map_file": SPECIAL_TOKENS_MAP_FILE,
"tokenizer_config_file": TOKENIZER_CONFIG_FILE, "tokenizer_config_file": TOKENIZER_CONFIG_FILE,
} }
vocab_files_target = {**cls.vocab_files_names, **additional_files_names} vocab_files = {**cls.vocab_files_names, **additional_files_names}
if "tokenizer_file" in vocab_files_target: if "tokenizer_file" in vocab_files:
# Try to get the tokenizer config to see if there are versioned tokenizer files. # Try to get the tokenizer config to see if there are versioned tokenizer files.
fast_tokenizer_file = FULL_TOKENIZER_FILE fast_tokenizer_file = FULL_TOKENIZER_FILE
resolved_config_file = get_file_from_repo( resolved_config_file = get_file_from_repo(
...@@ -1704,44 +1700,25 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin): ...@@ -1704,44 +1700,25 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
use_auth_token=use_auth_token, use_auth_token=use_auth_token,
revision=revision, revision=revision,
local_files_only=local_files_only, local_files_only=local_files_only,
subfolder=subfolder,
) )
if resolved_config_file is not None: if resolved_config_file is not None:
with open(resolved_config_file, encoding="utf-8") as reader: with open(resolved_config_file, encoding="utf-8") as reader:
tokenizer_config = json.load(reader) tokenizer_config = json.load(reader)
if "fast_tokenizer_files" in tokenizer_config: if "fast_tokenizer_files" in tokenizer_config:
fast_tokenizer_file = get_fast_tokenizer_file(tokenizer_config["fast_tokenizer_files"]) fast_tokenizer_file = get_fast_tokenizer_file(tokenizer_config["fast_tokenizer_files"])
vocab_files_target["tokenizer_file"] = fast_tokenizer_file vocab_files["tokenizer_file"] = fast_tokenizer_file
# Look for the tokenizer files
for file_id, file_name in vocab_files_target.items():
if os.path.isdir(pretrained_model_name_or_path):
if subfolder is not None:
full_file_name = os.path.join(pretrained_model_name_or_path, subfolder, file_name)
else:
full_file_name = os.path.join(pretrained_model_name_or_path, file_name)
if not os.path.exists(full_file_name):
logger.info(f"Didn't find file {full_file_name}. We won't load it.")
full_file_name = None
else:
full_file_name = hf_bucket_url(
pretrained_model_name_or_path,
filename=file_name,
subfolder=subfolder,
revision=revision,
mirror=None,
)
vocab_files[file_id] = full_file_name
# Get files from url, cache, or disk depending on the case # Get files from url, cache, or disk depending on the case
resolved_vocab_files = {} resolved_vocab_files = {}
unresolved_files = [] unresolved_files = []
for file_id, file_path in vocab_files.items(): for file_id, file_path in vocab_files.items():
print(file_id, file_path)
if file_path is None: if file_path is None:
resolved_vocab_files[file_id] = None resolved_vocab_files[file_id] = None
else: else:
try: resolved_vocab_files[file_id] = cached_file(
resolved_vocab_files[file_id] = cached_path( pretrained_model_name_or_path,
file_path, file_path,
cache_dir=cache_dir, cache_dir=cache_dir,
force_download=force_download, force_download=force_download,
...@@ -1750,35 +1727,12 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin): ...@@ -1750,35 +1727,12 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
local_files_only=local_files_only, local_files_only=local_files_only,
use_auth_token=use_auth_token, use_auth_token=use_auth_token,
user_agent=user_agent, user_agent=user_agent,
revision=revision,
subfolder=subfolder,
_raise_exceptions_for_missing_entries=False,
_raise_exceptions_for_connection_errors=False,
) )
except FileNotFoundError as error:
if local_files_only:
unresolved_files.append(file_id)
else:
raise error
except RepositoryNotFoundError:
raise EnvironmentError(
f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier "
"listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to "
"pass a token having permission to this repo with `use_auth_token` or log in with "
"`huggingface-cli login` and pass `use_auth_token=True`."
)
except RevisionNotFoundError:
raise EnvironmentError(
f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists "
"for this model name. Check the model page at "
f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
)
except EntryNotFoundError:
logger.debug(f"{pretrained_model_name_or_path} does not contain a file named {file_path}.")
resolved_vocab_files[file_id] = None
except ValueError:
logger.debug(f"Connection problem to access {file_path} and it wasn't found in the cache.")
resolved_vocab_files[file_id] = None
if len(unresolved_files) > 0: if len(unresolved_files) > 0:
logger.info( logger.info(
f"Can't load following files from cache: {unresolved_files} and cannot check if these " f"Can't load following files from cache: {unresolved_files} and cannot check if these "
...@@ -1797,7 +1751,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin): ...@@ -1797,7 +1751,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
if file_id not in resolved_vocab_files: if file_id not in resolved_vocab_files:
continue continue
if file_path == resolved_vocab_files[file_id]: if is_local:
logger.info(f"loading file {file_path}") logger.info(f"loading file {file_path}")
else: else:
logger.info(f"loading file {file_path} from cache at {resolved_vocab_files[file_id]}") logger.info(f"loading file {file_path} from cache at {resolved_vocab_files[file_id]}")
......
...@@ -60,6 +60,7 @@ from .hub import ( ...@@ -60,6 +60,7 @@ from .hub import (
PushToHubMixin, PushToHubMixin,
RepositoryNotFoundError, RepositoryNotFoundError,
RevisionNotFoundError, RevisionNotFoundError,
cached_file,
cached_path, cached_path,
default_cache_path, default_cache_path,
define_sagemaker_information, define_sagemaker_information,
...@@ -76,6 +77,7 @@ from .hub import ( ...@@ -76,6 +77,7 @@ from .hub import (
is_local_clone, is_local_clone,
is_offline_mode, is_offline_mode,
is_remote_url, is_remote_url,
move_cache,
send_example_telemetry, send_example_telemetry,
url_to_filename, url_to_filename,
) )
......
This diff is collapsed.
...@@ -345,14 +345,14 @@ class ConfigTestUtils(unittest.TestCase): ...@@ -345,14 +345,14 @@ class ConfigTestUtils(unittest.TestCase):
# A mock response for an HTTP head request to emulate server down # A mock response for an HTTP head request to emulate server down
response_mock = mock.Mock() response_mock = mock.Mock()
response_mock.status_code = 500 response_mock.status_code = 500
response_mock.headers = [] response_mock.headers = {}
response_mock.raise_for_status.side_effect = HTTPError response_mock.raise_for_status.side_effect = HTTPError
# Download this model to make sure it's in the cache. # Download this model to make sure it's in the cache.
_ = BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert") _ = BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert")
# Under the mock environment we get a 500 error when trying to reach the model. # Under the mock environment we get a 500 error when trying to reach the model.
with mock.patch("transformers.utils.hub.requests.head", return_value=response_mock) as mock_head: with mock.patch("requests.request", return_value=response_mock) as mock_head:
_ = BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert") _ = BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert")
# This checks we did call the fake head request # This checks we did call the fake head request
mock_head.assert_called() mock_head.assert_called()
......
...@@ -170,13 +170,13 @@ class FeatureExtractorUtilTester(unittest.TestCase): ...@@ -170,13 +170,13 @@ class FeatureExtractorUtilTester(unittest.TestCase):
# A mock response for an HTTP head request to emulate server down # A mock response for an HTTP head request to emulate server down
response_mock = mock.Mock() response_mock = mock.Mock()
response_mock.status_code = 500 response_mock.status_code = 500
response_mock.headers = [] response_mock.headers = {}
response_mock.raise_for_status.side_effect = HTTPError response_mock.raise_for_status.side_effect = HTTPError
# Download this model to make sure it's in the cache. # Download this model to make sure it's in the cache.
_ = Wav2Vec2FeatureExtractor.from_pretrained("hf-internal-testing/tiny-random-wav2vec2") _ = Wav2Vec2FeatureExtractor.from_pretrained("hf-internal-testing/tiny-random-wav2vec2")
# Under the mock environment we get a 500 error when trying to reach the model. # Under the mock environment we get a 500 error when trying to reach the model.
with mock.patch("transformers.utils.hub.requests.head", return_value=response_mock) as mock_head: with mock.patch("requests.request", return_value=response_mock) as mock_head:
_ = Wav2Vec2FeatureExtractor.from_pretrained("hf-internal-testing/tiny-random-wav2vec2") _ = Wav2Vec2FeatureExtractor.from_pretrained("hf-internal-testing/tiny-random-wav2vec2")
# This checks we did call the fake head request # This checks we did call the fake head request
mock_head.assert_called() mock_head.assert_called()
......
...@@ -2925,14 +2925,14 @@ class ModelUtilsTest(TestCasePlus): ...@@ -2925,14 +2925,14 @@ class ModelUtilsTest(TestCasePlus):
# A mock response for an HTTP head request to emulate server down # A mock response for an HTTP head request to emulate server down
response_mock = mock.Mock() response_mock = mock.Mock()
response_mock.status_code = 500 response_mock.status_code = 500
response_mock.headers = [] response_mock.headers = {}
response_mock.raise_for_status.side_effect = HTTPError response_mock.raise_for_status.side_effect = HTTPError
# Download this model to make sure it's in the cache. # Download this model to make sure it's in the cache.
_ = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert") _ = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
# Under the mock environment we get a 500 error when trying to reach the model. # Under the mock environment we get a 500 error when trying to reach the model.
with mock.patch("transformers.utils.hub.requests.head", return_value=response_mock) as mock_head: with mock.patch("requests.request", return_value=response_mock) as mock_head:
_ = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert") _ = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
# This checks we did call the fake head request # This checks we did call the fake head request
mock_head.assert_called() mock_head.assert_called()
......
...@@ -1922,14 +1922,14 @@ class UtilsFunctionsTest(unittest.TestCase): ...@@ -1922,14 +1922,14 @@ class UtilsFunctionsTest(unittest.TestCase):
# A mock response for an HTTP head request to emulate server down # A mock response for an HTTP head request to emulate server down
response_mock = mock.Mock() response_mock = mock.Mock()
response_mock.status_code = 500 response_mock.status_code = 500
response_mock.headers = [] response_mock.headers = {}
response_mock.raise_for_status.side_effect = HTTPError response_mock.raise_for_status.side_effect = HTTPError
# Download this model to make sure it's in the cache. # Download this model to make sure it's in the cache.
_ = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert") _ = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
# Under the mock environment we get a 500 error when trying to reach the model. # Under the mock environment we get a 500 error when trying to reach the model.
with mock.patch("transformers.utils.hub.requests.head", return_value=response_mock) as mock_head: with mock.patch("requests.request", return_value=response_mock) as mock_head:
_ = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert") _ = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
# This checks we did call the fake head request # This checks we did call the fake head request
mock_head.assert_called() mock_head.assert_called()
......
...@@ -3829,14 +3829,14 @@ class TokenizerUtilTester(unittest.TestCase): ...@@ -3829,14 +3829,14 @@ class TokenizerUtilTester(unittest.TestCase):
# A mock response for an HTTP head request to emulate server down # A mock response for an HTTP head request to emulate server down
response_mock = mock.Mock() response_mock = mock.Mock()
response_mock.status_code = 500 response_mock.status_code = 500
response_mock.headers = [] response_mock.headers = {}
response_mock.raise_for_status.side_effect = HTTPError response_mock.raise_for_status.side_effect = HTTPError
# Download this model to make sure it's in the cache. # Download this model to make sure it's in the cache.
_ = BertTokenizer.from_pretrained("hf-internal-testing/tiny-random-bert") _ = BertTokenizer.from_pretrained("hf-internal-testing/tiny-random-bert")
# Under the mock environment we get a 500 error when trying to reach the model. # Under the mock environment we get a 500 error when trying to reach the model.
with mock.patch("transformers.utils.hub.requests.head", return_value=response_mock) as mock_head: with mock.patch("requests.request", return_value=response_mock) as mock_head:
_ = BertTokenizer.from_pretrained("hf-internal-testing/tiny-random-bert") _ = BertTokenizer.from_pretrained("hf-internal-testing/tiny-random-bert")
# This checks we did call the fake head request # This checks we did call the fake head request
mock_head.assert_called() mock_head.assert_called()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment