Unverified Commit 0d0aada5 authored by Sylvain Gugger, committed by GitHub
Use commit hash to look in cache instead of calling head (#18534)



* Use commit hash to look in cache instead of calling head

* Add tests

* Add attr for local configs too

* Stupid typos

* Fix tests

* Update src/transformers/utils/hub.py
Co-authored-by: Julien Chaumond <julien@huggingface.co>

* Address Julien's comments
Co-authored-by: Julien Chaumond <julien@huggingface.co>
parent 6eb51450
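In short: `from_pretrained` used to issue an HTTP HEAD request per file just to resolve a revision like `main` to a commit sha before checking the cache. After this change, the first resolved file pins a commit hash, and subsequent files from the same repo are looked up directly under `snapshots/<commit_hash>/` in the local cache. A minimal sketch of the lookup idea (the cache layout is the one used by huggingface_hub; the helper name here is hypothetical):

```python
import os

def lookup_in_cache(model_cache: str, commit_hash: str, filename: str):
    """Sketch: return the cached file for a known commit hash, with no network call.

    `model_cache` is a folder like ~/.cache/huggingface/hub/models--bert-base-uncased,
    which contains refs/ (branch -> sha) and snapshots/<sha>/ subfolders.
    """
    cached = os.path.join(model_cache, "snapshots", commit_hash, filename)
    return cached if os.path.isfile(cached) else None
```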
src/transformers/configuration_utils.py
@@ -27,7 +27,15 @@ from packaging import version
 from . import __version__
 from .dynamic_module_utils import custom_object_save
-from .utils import CONFIG_NAME, PushToHubMixin, cached_file, copy_func, is_torch_available, logging
+from .utils import (
+    CONFIG_NAME,
+    PushToHubMixin,
+    cached_file,
+    copy_func,
+    extract_commit_hash,
+    is_torch_available,
+    logging,
+)

 logger = logging.get_logger(__name__)
@@ -343,6 +351,8 @@ class PretrainedConfig(PushToHubMixin):
         # Name or path to the pretrained checkpoint
         self._name_or_path = str(kwargs.pop("name_or_path", ""))
+        # Config hash
+        self._commit_hash = kwargs.pop("_commit_hash", None)

         # Drop the transformers version info
         self.transformers_version = kwargs.pop("transformers_version", None)
@@ -539,6 +549,8 @@ class PretrainedConfig(PushToHubMixin):
         original_kwargs = copy.deepcopy(kwargs)
         # Get config dict associated with the base config file
         config_dict, kwargs = cls._get_config_dict(pretrained_model_name_or_path, **kwargs)
+        if "_commit_hash" in config_dict:
+            original_kwargs["_commit_hash"] = config_dict["_commit_hash"]

         # That config file may point us toward another config file to use.
         if "configuration_files" in config_dict:
@@ -564,6 +576,7 @@ class PretrainedConfig(PushToHubMixin):
         subfolder = kwargs.pop("subfolder", "")
         from_pipeline = kwargs.pop("_from_pipeline", None)
         from_auto_class = kwargs.pop("_from_auto", False)
+        commit_hash = kwargs.pop("_commit_hash", None)

         if trust_remote_code is True:
             logger.warning(
@@ -599,7 +612,9 @@ class PretrainedConfig(PushToHubMixin):
                 user_agent=user_agent,
                 revision=revision,
                 subfolder=subfolder,
+                _commit_hash=commit_hash,
             )
+            commit_hash = extract_commit_hash(resolved_config_file, commit_hash)
         except EnvironmentError:
             # Raise any environment error raised by `cached_file`. It will have a helpful error message adapted to
             # the original exception.
@@ -616,6 +631,7 @@ class PretrainedConfig(PushToHubMixin):
         try:
             # Load config dict
             config_dict = cls._dict_from_json_file(resolved_config_file)
+            config_dict["_commit_hash"] = commit_hash
         except (json.JSONDecodeError, UnicodeDecodeError):
             raise EnvironmentError(
                 f"It looks like the config file at '{resolved_config_file}' is not a valid JSON file."
@@ -648,6 +664,9 @@ class PretrainedConfig(PushToHubMixin):
         # We remove them so they don't appear in `return_unused_kwargs`.
         kwargs.pop("_from_auto", None)
         kwargs.pop("_from_pipeline", None)
+        # The commit hash might have been updated in the `config_dict`; we don't want the kwargs to erase that update.
+        if "_commit_hash" in kwargs and "_commit_hash" in config_dict:
+            kwargs["_commit_hash"] = config_dict["_commit_hash"]

         config = cls(**config_dict)
@@ -751,6 +770,8 @@ class PretrainedConfig(PushToHubMixin):
             output["model_type"] = self.__class__.model_type
         if "_auto_class" in output:
             del output["_auto_class"]
+        if "_commit_hash" in output:
+            del output["_commit_hash"]

         # Transformers version when serializing the model
         output["transformers_version"] = __version__
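With the new `_commit_hash` attribute, a config loaded from the Hub remembers exactly which revision it came from, and the model/tokenizer loading paths below reuse it. For illustration (the attribute is private API and its value depends on the repo state):

```python
from transformers import AutoConfig

config = AutoConfig.from_pretrained("bert-base-uncased")
# Private attribute added by this PR: the 40-character sha of the resolved
# revision, or None for a purely local config.
print(config._commit_hash)
```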
src/transformers/modeling_flax_utils.py
@@ -595,6 +595,7 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
         from_auto_class = kwargs.pop("_from_auto", False)
         _do_init = kwargs.pop("_do_init", True)
         subfolder = kwargs.pop("subfolder", "")
+        commit_hash = kwargs.pop("_commit_hash", None)

         if trust_remote_code is True:
             logger.warning(
@@ -625,11 +626,15 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
                 revision=revision,
                 _from_auto=from_auto_class,
                 _from_pipeline=from_pipeline,
+                _commit_hash=commit_hash,
                 **kwargs,
             )
         else:
             model_kwargs = kwargs

+        if commit_hash is None:
+            commit_hash = getattr(config, "_commit_hash", None)
+
         # Add the dtype to model_kwargs
         model_kwargs["dtype"] = dtype
@@ -682,6 +687,7 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
                 revision=revision,
                 subfolder=subfolder,
                 _raise_exceptions_for_missing_entries=False,
+                _commit_hash=commit_hash,
             )
             resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs)
@@ -748,6 +754,7 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
                 use_auth_token=use_auth_token,
                 user_agent=user_agent,
                 revision=revision,
+                _commit_hash=commit_hash,
             )

     # init random models
src/transformers/modeling_tf_utils.py
@@ -2161,6 +2161,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, PushToHubMixin):
         from_pipeline = kwargs.pop("_from_pipeline", None)
         from_auto_class = kwargs.pop("_from_auto", False)
         subfolder = kwargs.pop("subfolder", "")
+        commit_hash = kwargs.pop("_commit_hash", None)

         if trust_remote_code is True:
             logger.warning(
@@ -2191,11 +2192,15 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, PushToHubMixin):
                 revision=revision,
                 _from_auto=from_auto_class,
                 _from_pipeline=from_pipeline,
+                _commit_hash=commit_hash,
                 **kwargs,
             )
         else:
             model_kwargs = kwargs

+        if commit_hash is None:
+            commit_hash = getattr(config, "_commit_hash", None)
+
         # This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the
         # index of the files.
         is_sharded = False
@@ -2253,6 +2258,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, PushToHubMixin):
                 revision=revision,
                 subfolder=subfolder,
                 _raise_exceptions_for_missing_entries=False,
+                _commit_hash=commit_hash,
             )
             resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs)
@@ -2320,6 +2326,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, PushToHubMixin):
                 use_auth_token=use_auth_token,
                 user_agent=user_agent,
                 revision=revision,
+                _commit_hash=commit_hash,
             )

         config.name_or_path = pretrained_model_name_or_path
src/transformers/modeling_utils.py
@@ -1840,6 +1840,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
         load_in_8bit = kwargs.pop("load_in_8bit", False)
         int8_threshold = kwargs.pop("int8_threshold", 6.0)
         subfolder = kwargs.pop("subfolder", "")
+        commit_hash = kwargs.pop("_commit_hash", None)

         if trust_remote_code is True:
             logger.warning(
@@ -1918,6 +1919,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
         else:
             model_kwargs = kwargs

+        if commit_hash is None:
+            commit_hash = getattr(config, "_commit_hash", None)
+
         # This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the
         # index of the files.
         is_sharded = False
@@ -2004,6 +2008,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
                 revision=revision,
                 subfolder=subfolder,
                 _raise_exceptions_for_missing_entries=False,
+                _commit_hash=commit_hash,
             )
             resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs)
@@ -2078,6 +2083,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
                 user_agent=user_agent,
                 revision=revision,
                 subfolder=subfolder,
+                _commit_hash=commit_hash,
             )

         # load pt weights early so that we know which dtype to init the model under
src/transformers/models/auto/tokenization_auto.py
@@ -25,7 +25,7 @@ from ...dynamic_module_utils import get_class_from_dynamic_module
 from ...tokenization_utils import PreTrainedTokenizer
 from ...tokenization_utils_base import TOKENIZER_CONFIG_FILE
 from ...tokenization_utils_fast import PreTrainedTokenizerFast
-from ...utils import get_file_from_repo, is_sentencepiece_available, is_tokenizers_available, logging
+from ...utils import cached_file, extract_commit_hash, is_sentencepiece_available, is_tokenizers_available, logging
 from ..encoder_decoder import EncoderDecoderConfig
 from .auto_factory import _LazyAutoMapping
 from .configuration_auto import (
@@ -389,7 +389,8 @@ def get_tokenizer_config(
     tokenizer.save_pretrained("tokenizer-test")
     tokenizer_config = get_tokenizer_config("tokenizer-test")
     ```"""
-    resolved_config_file = get_file_from_repo(
+    commit_hash = kwargs.get("_commit_hash", None)
+    resolved_config_file = cached_file(
         pretrained_model_name_or_path,
         TOKENIZER_CONFIG_FILE,
         cache_dir=cache_dir,
@@ -399,13 +400,19 @@ def get_tokenizer_config(
         use_auth_token=use_auth_token,
         revision=revision,
         local_files_only=local_files_only,
+        _raise_exceptions_for_missing_entries=False,
+        _raise_exceptions_for_connection_errors=False,
+        _commit_hash=commit_hash,
     )
     if resolved_config_file is None:
         logger.info("Could not locate the tokenizer configuration file, will try to use the model config instead.")
         return {}
+    commit_hash = extract_commit_hash(resolved_config_file, commit_hash)

     with open(resolved_config_file, encoding="utf-8") as reader:
-        return json.load(reader)
+        result = json.load(reader)
+    result["_commit_hash"] = commit_hash
+    return result


 class AutoTokenizer:
@@ -532,6 +539,8 @@ class AutoTokenizer:
         # Next, let's try to use the tokenizer_config file to get the tokenizer class.
         tokenizer_config = get_tokenizer_config(pretrained_model_name_or_path, **kwargs)
+        if "_commit_hash" in tokenizer_config:
+            kwargs["_commit_hash"] = tokenizer_config["_commit_hash"]
         config_tokenizer_class = tokenizer_config.get("tokenizer_class")
         tokenizer_auto_map = None
         if "auto_map" in tokenizer_config:
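`get_tokenizer_config` now smuggles the resolved hash into its return value so `AutoTokenizer.from_pretrained` can forward it to every later file lookup (the `kwargs["_commit_hash"]` line above). A small illustration, assuming a reachable Hub repo:

```python
from transformers.models.auto.tokenization_auto import get_tokenizer_config

tokenizer_config = get_tokenizer_config("bert-base-cased")
# The extra key is internal bookkeeping; pop it before inspecting the config itself.
commit_hash = tokenizer_config.pop("_commit_hash", None)
print(commit_hash)
```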
src/transformers/pipelines/__init__.py
@@ -557,7 +557,12 @@ def pipeline(
     # Make sure we only pass use_auth_token once as a kwarg (it used to be possible to pass it in model_kwargs,
     # this is to keep BC).
     use_auth_token = model_kwargs.pop("use_auth_token", use_auth_token)
-    hub_kwargs = {"revision": revision, "use_auth_token": use_auth_token, "trust_remote_code": trust_remote_code}
+    hub_kwargs = {
+        "revision": revision,
+        "use_auth_token": use_auth_token,
+        "trust_remote_code": trust_remote_code,
+        "_commit_hash": None,
+    }

     if task is None and model is None:
         raise RuntimeError(
@@ -583,8 +588,10 @@ def pipeline(
     # Instantiate config if needed
     if isinstance(config, str):
         config = AutoConfig.from_pretrained(config, _from_pipeline=task, **hub_kwargs, **model_kwargs)
+        hub_kwargs["_commit_hash"] = config._commit_hash
     elif config is None and isinstance(model, str):
         config = AutoConfig.from_pretrained(model, _from_pipeline=task, **hub_kwargs, **model_kwargs)
+        hub_kwargs["_commit_hash"] = config._commit_hash

     custom_tasks = {}
     if config is not None and len(getattr(config, "custom_pipelines", {})) > 0:
@@ -639,6 +646,7 @@ def pipeline(
         )
         if config is None and isinstance(model, str):
             config = AutoConfig.from_pretrained(model, _from_pipeline=task, **hub_kwargs, **model_kwargs)
+            hub_kwargs["_commit_hash"] = config._commit_hash

     if device_map is not None:
         if "device_map" in model_kwargs:
@@ -672,6 +680,7 @@ def pipeline(
         )

     model_config = model.config
+    hub_kwargs["_commit_hash"] = model.config._commit_hash
     load_tokenizer = type(model_config) in TOKENIZER_MAPPING or model_config.tokenizer_class is not None
     load_feature_extractor = type(model_config) in FEATURE_EXTRACTOR_MAPPING or feature_extractor is not None
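The net effect in `pipeline()`: once the config is resolved, its hash is written into `hub_kwargs`, so the model weights and tokenizer files are fetched at that exact commit, mostly straight from the cache. A rough sketch of the same pattern outside `pipeline()` (example repo id; `_commit_hash` is a private argument and may change):

```python
from transformers import AutoConfig, AutoModel, AutoTokenizer

model_name = "hf-internal-testing/tiny-random-bert"  # example repo
hub_kwargs = {"_commit_hash": None}

config = AutoConfig.from_pretrained(model_name, **hub_kwargs)
hub_kwargs["_commit_hash"] = config._commit_hash  # pin every later download

model = AutoModel.from_pretrained(model_name, config=config, **hub_kwargs)
tokenizer = AutoTokenizer.from_pretrained(model_name, **hub_kwargs)
```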
src/transformers/testing_utils.py
@@ -31,6 +31,7 @@ from pathlib import Path
 from typing import Iterator, List, Union
 from unittest import mock

+import huggingface_hub

 from transformers import logging as transformers_logging

 from .deepspeed import is_deepspeed_available
@@ -1588,3 +1589,30 @@ def run_command(command: List[str], return_stdout=False):
         raise SubprocessCallException(
             f"Command `{' '.join(command)}` failed with the following error:\n\n{e.output.decode()}"
         ) from e
+
+
+class RequestCounter:
+    """
+    Helper class that will count all requests made online.
+    """
+
+    def __enter__(self):
+        self.head_request_count = 0
+        self.get_request_count = 0
+        self.other_request_count = 0
+        self.old_request = huggingface_hub.file_download.requests.request
+        huggingface_hub.file_download.requests.request = self.new_request
+        return self
+
+    def __exit__(self, *args, **kwargs):
+        huggingface_hub.file_download.requests.request = self.old_request
+
+    def new_request(self, method, **kwargs):
+        if method == "GET":
+            self.get_request_count += 1
+        elif method == "HEAD":
+            self.head_request_count += 1
+        else:
+            self.other_request_count += 1
+
+        return self.old_request(method=method, **kwargs)
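`RequestCounter` monkey-patches the `requests.request` function used by `huggingface_hub.file_download`, so the tests further down can assert exactly how many HEAD/GET calls a cached load still makes. A minimal usage sketch (example repo id):

```python
from transformers import AutoConfig
from transformers.testing_utils import RequestCounter

_ = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-bert")  # warm the cache
with RequestCounter() as counter:
    _ = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-bert")
print(counter.head_request_count, counter.get_request_count, counter.other_request_count)
```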
src/transformers/tokenization_utils_base.py
@@ -42,7 +42,7 @@ from .utils import (
     add_end_docstrings,
     cached_file,
     copy_func,
-    get_file_from_repo,
+    extract_commit_hash,
     is_flax_available,
     is_offline_mode,
     is_tf_available,
@@ -1651,6 +1651,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
         subfolder = kwargs.pop("subfolder", None)
         from_pipeline = kwargs.pop("_from_pipeline", None)
         from_auto_class = kwargs.pop("_from_auto", False)
+        commit_hash = kwargs.pop("_commit_hash", None)

         user_agent = {"file_type": "tokenizer", "from_auto_class": from_auto_class, "is_fast": "Fast" in cls.__name__}
         if from_pipeline is not None:
@@ -1690,7 +1691,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
         if "tokenizer_file" in vocab_files:
             # Try to get the tokenizer config to see if there are versioned tokenizer files.
             fast_tokenizer_file = FULL_TOKENIZER_FILE
-            resolved_config_file = get_file_from_repo(
+            resolved_config_file = cached_file(
                 pretrained_model_name_or_path,
                 TOKENIZER_CONFIG_FILE,
                 cache_dir=cache_dir,
@@ -1701,7 +1702,12 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
                 revision=revision,
                 local_files_only=local_files_only,
                 subfolder=subfolder,
+                user_agent=user_agent,
+                _raise_exceptions_for_missing_entries=False,
+                _raise_exceptions_for_connection_errors=False,
+                _commit_hash=commit_hash,
             )
+            commit_hash = extract_commit_hash(resolved_config_file, commit_hash)
             if resolved_config_file is not None:
                 with open(resolved_config_file, encoding="utf-8") as reader:
                     tokenizer_config = json.load(reader)
@@ -1730,7 +1736,9 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
                     subfolder=subfolder,
                     _raise_exceptions_for_missing_entries=False,
                     _raise_exceptions_for_connection_errors=False,
+                    _commit_hash=commit_hash,
                 )
+                commit_hash = extract_commit_hash(resolved_vocab_files[file_id], commit_hash)

         if len(unresolved_files) > 0:
             logger.info(
@@ -1763,6 +1771,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
             use_auth_token=use_auth_token,
             cache_dir=cache_dir,
             local_files_only=local_files_only,
+            _commit_hash=commit_hash,
             **kwargs,
         )
@@ -1776,6 +1785,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
         use_auth_token=None,
         cache_dir=None,
         local_files_only=False,
+        _commit_hash=None,
         **kwargs
     ):
         # We instantiate fast tokenizers based on a slow tokenizer if we don't have access to the tokenizer.json
@@ -1791,6 +1801,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
                 use_auth_token=use_auth_token,
                 cache_dir=cache_dir,
                 local_files_only=local_files_only,
+                _commit_hash=_commit_hash,
                 **(copy.deepcopy(kwargs)),
             )
         else:
@@ -1823,6 +1834,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
                 use_auth_token=use_auth_token,
                 cache_dir=cache_dir,
                 local_files_only=local_files_only,
+                _commit_hash=_commit_hash,
             )
             config_tokenizer_class = config.tokenizer_class
         except (OSError, ValueError, KeyError):
src/transformers/utils/__init__.py
@@ -63,6 +63,7 @@ from .hub import (
     cached_file,
     default_cache_path,
     define_sagemaker_information,
+    extract_commit_hash,
     get_cached_models,
     get_file_from_repo,
     get_full_repo_name,
src/transformers/utils/hub.py
@@ -38,6 +38,7 @@ from huggingface_hub import (
     whoami,
 )
 from huggingface_hub.constants import HUGGINGFACE_HEADER_X_LINKED_ETAG, HUGGINGFACE_HEADER_X_REPO_COMMIT
+from huggingface_hub.file_download import REGEX_COMMIT_HASH
 from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
 from requests.exceptions import HTTPError
 from transformers.utils.logging import tqdm
@@ -200,11 +201,27 @@ def http_user_agent(user_agent: Union[Dict, str, None] = None) -> str:
     return ua


+def extract_commit_hash(resolved_file: Optional[str], commit_hash: Optional[str]):
+    """
+    Extracts the commit hash from a resolved filename pointing to a cache file.
+    """
+    if resolved_file is None or commit_hash is not None:
+        return commit_hash
+    search = re.search(r"snapshots/([^/]+)/", resolved_file)
+    if search is None:
+        return None
+    commit_hash = search.groups()[0]
+    return commit_hash if REGEX_COMMIT_HASH.match(commit_hash) else None
+
+
-def try_to_load_from_cache(cache_dir, repo_id, filename, revision=None):
+def try_to_load_from_cache(cache_dir, repo_id, filename, revision=None, commit_hash=None):
     """
     Explores the cache to return the latest cached file for a given revision.
     """
-    if revision is None:
+    if commit_hash is not None and revision is not None:
+        raise ValueError("`commit_hash` and `revision` are mutually exclusive, pick one only.")
+    if revision is None and commit_hash is None:
         revision = "main"

     model_id = repo_id.replace("/", "--")
@@ -216,18 +233,19 @@ def try_to_load_from_cache(cache_dir, repo_id, filename, revision=None):
     if not os.path.isdir(os.path.join(model_cache, subfolder)):
         return None

-    # Resolve refs (for instance to convert main to the associated commit sha)
-    cached_refs = os.listdir(os.path.join(model_cache, "refs"))
-    if revision in cached_refs:
-        with open(os.path.join(model_cache, "refs", revision)) as f:
-            revision = f.read()
+    if commit_hash is None:
+        # Resolve refs (for instance to convert main to the associated commit sha)
+        cached_refs = os.listdir(os.path.join(model_cache, "refs"))
+        if revision in cached_refs:
+            with open(os.path.join(model_cache, "refs", revision)) as f:
+                commit_hash = f.read()

     cached_shas = os.listdir(os.path.join(model_cache, "snapshots"))
-    if revision not in cached_shas:
+    if commit_hash not in cached_shas:
         # No cache for this revision and we won't try to return a random revision
         return None

-    cached_file = os.path.join(model_cache, "snapshots", revision, filename)
+    cached_file = os.path.join(model_cache, "snapshots", commit_hash, filename)
     return cached_file if os.path.isfile(cached_file) else None
@@ -265,8 +283,9 @@ def cached_file(
     local_files_only: bool = False,
     subfolder: str = "",
     user_agent: Optional[Union[str, Dict[str, str]]] = None,
-    _raise_exceptions_for_missing_entries=True,
-    _raise_exceptions_for_connection_errors=True,
+    _raise_exceptions_for_missing_entries: bool = True,
+    _raise_exceptions_for_connection_errors: bool = True,
+    _commit_hash: Optional[str] = None,
 ):
     """
     Tries to locate a file in a local folder and repo, downloads and caches it if necessary.
@@ -318,6 +337,13 @@ def cached_file(
     # Download a model weight from the Hub and cache it.
     model_weights_file = cached_file("bert-base-uncased", "pytorch_model.bin")
     ```"""
+    # Private arguments
+    #     _raise_exceptions_for_missing_entries: if False, do not raise an exception for missing entries but return
+    #         None.
+    #     _raise_exceptions_for_connection_errors: if False, do not raise an exception for connection errors but return
+    #         None.
+    #     _commit_hash: passed when we are chaining several calls to various files (e.g. when loading a tokenizer or
+    #         a pipeline). If files are cached for this commit hash, avoid calls to head and get from the cache.
     if is_offline_mode() and not local_files_only:
         logger.info("Offline mode: forcing local_files_only=True")
         local_files_only = True
@@ -339,6 +365,13 @@ def cached_file(
         cache_dir = TRANSFORMERS_CACHE
     if isinstance(cache_dir, Path):
         cache_dir = str(cache_dir)

+    if _commit_hash is not None:
+        # If the file is cached under that commit hash, we return it directly.
+        resolved_file = try_to_load_from_cache(cache_dir, path_or_repo_id, full_filename, commit_hash=_commit_hash)
+        if resolved_file is not None:
+            return resolved_file
+
     user_agent = http_user_agent(user_agent)
     try:
         # Load from URL or cache if already cached
@@ -803,6 +836,7 @@ def get_checkpoint_shard_files(
     user_agent=None,
     revision=None,
     subfolder="",
+    _commit_hash=None,
 ):
     """
     For a given model:
@@ -848,6 +882,7 @@ def get_checkpoint_shard_files(
             user_agent=user_agent,
             revision=revision,
             subfolder=subfolder,
+            _commit_hash=_commit_hash,
         )
     # We have already dealt with RepositoryNotFoundError and RevisionNotFoundError when getting the index, so
     # we don't have to catch them here.
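`extract_commit_hash` relies on the `snapshots/<sha>/` segment that huggingface_hub puts in every cached path. A worked example (the sha below is made up, but it matches the 40-hex-character pattern `REGEX_COMMIT_HASH` expects):

```python
from transformers.utils.hub import extract_commit_hash

resolved = (
    "/home/user/.cache/huggingface/hub/models--bert-base-uncased/snapshots/"
    "0123456789abcdef0123456789abcdef01234567/config.json"
)
print(extract_commit_hash(resolved, None))
# -> "0123456789abcdef0123456789abcdef01234567"

# A hash that is already known is returned unchanged, without parsing the path:
print(extract_commit_hash(resolved, "deadbeef" * 5))
```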
tests/models/auto/test_modeling_auto.py
@@ -24,6 +24,7 @@ from transformers.models.auto.configuration_auto import CONFIG_MAPPING
 from transformers.testing_utils import (
     DUMMY_UNKNOWN_IDENTIFIER,
     SMALL_MODEL_IDENTIFIER,
+    RequestCounter,
     require_scatter,
     require_torch,
     slow,
@@ -354,3 +355,21 @@ class AutoModelTest(unittest.TestCase):
     def test_model_from_flax_suggestion(self):
         with self.assertRaisesRegex(EnvironmentError, "Use `from_flax=True` to load this model"):
             _ = AutoModel.from_pretrained("hf-internal-testing/tiny-bert-flax-only")
+
+    def test_cached_model_has_minimum_calls_to_head(self):
+        # Make sure we have cached the model.
+        _ = AutoModel.from_pretrained("hf-internal-testing/tiny-random-bert")
+        with RequestCounter() as counter:
+            _ = AutoModel.from_pretrained("hf-internal-testing/tiny-random-bert")
+        self.assertEqual(counter.get_request_count, 0)
+        self.assertEqual(counter.head_request_count, 1)
+        self.assertEqual(counter.other_request_count, 0)
+
+        # With a sharded checkpoint
+        _ = AutoModel.from_pretrained("hf-internal-testing/tiny-random-bert-sharded")
+        with RequestCounter() as counter:
+            _ = AutoModel.from_pretrained("hf-internal-testing/tiny-random-bert-sharded")
+        self.assertEqual(counter.get_request_count, 0)
+        # There is no pytorch_model.bin so we still get one call for this one.
+        self.assertEqual(counter.head_request_count, 2)
+        self.assertEqual(counter.other_request_count, 0)
tests/models/auto/test_modeling_tf_auto.py
@@ -21,6 +21,7 @@ from transformers import CONFIG_MAPPING, AutoConfig, BertConfig, GPT2Config, T5Config
 from transformers.testing_utils import (
     DUMMY_UNKNOWN_IDENTIFIER,
     SMALL_MODEL_IDENTIFIER,
+    RequestCounter,
     require_tensorflow_probability,
     require_tf,
     slow,
@@ -287,3 +288,21 @@ class TFAutoModelTest(unittest.TestCase):
     def test_model_from_pt_suggestion(self):
         with self.assertRaisesRegex(EnvironmentError, "Use `from_pt=True` to load this model"):
             _ = TFAutoModel.from_pretrained("hf-internal-testing/tiny-bert-pt-only")
+
+    def test_cached_model_has_minimum_calls_to_head(self):
+        # Make sure we have cached the model.
+        _ = TFAutoModel.from_pretrained("hf-internal-testing/tiny-random-bert")
+        with RequestCounter() as counter:
+            _ = TFAutoModel.from_pretrained("hf-internal-testing/tiny-random-bert")
+        self.assertEqual(counter.get_request_count, 0)
+        self.assertEqual(counter.head_request_count, 1)
+        self.assertEqual(counter.other_request_count, 0)
+
+        # With a sharded checkpoint
+        _ = TFAutoModel.from_pretrained("ArthurZ/tiny-random-bert-sharded")
+        with RequestCounter() as counter:
+            _ = TFAutoModel.from_pretrained("ArthurZ/tiny-random-bert-sharded")
+        self.assertEqual(counter.get_request_count, 0)
+        # There is no tf_model.h5 so we still get one call for this one.
+        self.assertEqual(counter.head_request_count, 2)
+        self.assertEqual(counter.other_request_count, 0)
tests/models/auto/test_tokenization_auto.py
@@ -48,6 +48,7 @@ from transformers.testing_utils import (
     DUMMY_DIFF_TOKENIZER_IDENTIFIER,
     DUMMY_UNKNOWN_IDENTIFIER,
     SMALL_MODEL_IDENTIFIER,
+    RequestCounter,
     require_tokenizers,
     slow,
 )
@@ -213,6 +214,7 @@ class AutoTokenizerTest(unittest.TestCase):
     def test_get_tokenizer_config(self):
         # Check we can load the tokenizer config of an online model.
         config = get_tokenizer_config("bert-base-cased")
+        _ = config.pop("_commit_hash", None)
         # If we ever update bert-base-cased tokenizer config, this dict here will need to be updated.
         self.assertEqual(config, {"do_lower_case": False})
@@ -340,3 +342,13 @@ class AutoTokenizerTest(unittest.TestCase):
             EnvironmentError, r"aaaaaa is not a valid git identifier \(branch name, tag name or commit id\)"
         ):
             _ = AutoTokenizer.from_pretrained(DUMMY_UNKNOWN_IDENTIFIER, revision="aaaaaa")
+
+    def test_cached_tokenizer_has_minimum_calls_to_head(self):
+        # Make sure we have cached the tokenizer.
+        _ = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bert")
+        with RequestCounter() as counter:
+            _ = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bert")
+        self.assertEqual(counter.get_request_count, 0)
+        # We still have one extra call because the model does not have an added_tokens.json file.
+        self.assertEqual(counter.head_request_count, 2)
+        self.assertEqual(counter.other_request_count, 0)
tests/pipelines/test_pipelines_common.py
@@ -49,6 +49,7 @@ from transformers.testing_utils import (
     TOKEN,
     USER,
     CaptureLogger,
+    RequestCounter,
     is_pipeline_test,
     is_staging_test,
     nested_simplify,
@@ -877,6 +878,16 @@ class CustomPipelineTest(unittest.TestCase):
             [{"label": "LABEL_0", "score": 0.505}],
         )

+    def test_cached_pipeline_has_minimum_calls_to_head(self):
+        # Make sure we have cached the pipeline.
+        _ = pipeline("text-classification", model="hf-internal-testing/tiny-random-bert")
+        with RequestCounter() as counter:
+            _ = pipeline("text-classification", model="hf-internal-testing/tiny-random-bert")
+        self.assertEqual(counter.get_request_count, 0)
+        # We still have one extra call because the model does not have an added_tokens.json file.
+        self.assertEqual(counter.head_request_count, 2)
+        self.assertEqual(counter.other_request_count, 0)
+

 @require_torch
 @is_staging_test
tests/test_configuration_common.py
@@ -246,7 +246,7 @@ class ConfigPushToHubTester(unittest.TestCase):
         config.push_to_hub("test-config", use_auth_token=self._token)

         new_config = BertConfig.from_pretrained(f"{USER}/test-config")
-        for k, v in config.__dict__.items():
+        for k, v in config.to_dict().items():
             if k != "transformers_version":
                 self.assertEqual(v, getattr(new_config, k))
@@ -258,7 +258,7 @@ class ConfigPushToHubTester(unittest.TestCase):
             config.save_pretrained(tmp_dir, repo_id="test-config", push_to_hub=True, use_auth_token=self._token)

         new_config = BertConfig.from_pretrained(f"{USER}/test-config")
-        for k, v in config.__dict__.items():
+        for k, v in config.to_dict().items():
             if k != "transformers_version":
                 self.assertEqual(v, getattr(new_config, k))
@@ -269,7 +269,7 @@ class ConfigPushToHubTester(unittest.TestCase):
         config.push_to_hub("valid_org/test-config-org", use_auth_token=self._token)

         new_config = BertConfig.from_pretrained("valid_org/test-config-org")
-        for k, v in config.__dict__.items():
+        for k, v in config.to_dict().items():
             if k != "transformers_version":
                 self.assertEqual(v, getattr(new_config, k))
@@ -283,7 +283,7 @@ class ConfigPushToHubTester(unittest.TestCase):
             )

         new_config = BertConfig.from_pretrained("valid_org/test-config-org")
-        for k, v in config.__dict__.items():
+        for k, v in config.to_dict().items():
             if k != "transformers_version":
                 self.assertEqual(v, getattr(new_config, k))
@@ -323,7 +323,9 @@ class ConfigTestUtils(unittest.TestCase):
         base_config = PretrainedConfig()
         missing_keys = [key for key in base_config.__dict__ if key not in config_common_kwargs]
         # If this part of the test fails, you have arguments to add in config_common_kwargs above.
-        self.assertListEqual(missing_keys, ["is_encoder_decoder", "_name_or_path", "transformers_version"])
+        self.assertListEqual(
+            missing_keys, ["is_encoder_decoder", "_name_or_path", "_commit_hash", "transformers_version"]
+        )
         keys_with_defaults = [key for key, value in config_common_kwargs.items() if value == getattr(base_config, key)]
         if len(keys_with_defaults) > 0:
             raise ValueError(