Unverified commit 0d0aada5, authored by Sylvain Gugger, committed by GitHub

Use commit hash to look in cache instead of calling head (#18534)



* Use commit hash to look in cache instead of calling head

* Add tests

* Add attr for local configs too

* Stupid typos

* Fix tests

* Update src/transformers/utils/hub.py
Co-authored-by: Julien Chaumond <julien@huggingface.co>

* Address Julien's comments
Co-authored-by: Julien Chaumond <julien@huggingface.co>
parent 6eb51450
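For orientation before the diff: the first download still resolves the revision with a single HEAD call, the commit hash is then read back from the snapshot path, and every later file request short-circuits to the cache. A minimal sketch of that flow, using the helpers added in this PR (the repo name is just an example, and it assumes the files end up cached after the first call):

from transformers.utils import cached_file, extract_commit_hash

# First file: one HEAD call resolves "main" to a commit sha and caches the
# file under .../snapshots/<sha>/config.json.
resolved = cached_file("bert-base-uncased", "config.json")
commit_hash = extract_commit_hash(resolved, None)

# Subsequent files reuse the sha: if they are already cached under that
# commit, cached_file returns the local path without any network call.
weights = cached_file("bert-base-uncased", "pytorch_model.bin", _commit_hash=commit_hash)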
......@@ -27,7 +27,15 @@ from packaging import version
from . import __version__
from .dynamic_module_utils import custom_object_save
-from .utils import CONFIG_NAME, PushToHubMixin, cached_file, copy_func, is_torch_available, logging
+from .utils import (
+    CONFIG_NAME,
+    PushToHubMixin,
+    cached_file,
+    copy_func,
+    extract_commit_hash,
+    is_torch_available,
+    logging,
+)
logger = logging.get_logger(__name__)
......@@ -343,6 +351,8 @@ class PretrainedConfig(PushToHubMixin):
# Name or path to the pretrained checkpoint
self._name_or_path = str(kwargs.pop("name_or_path", ""))
# Config hash
self._commit_hash = kwargs.pop("_commit_hash", None)
# Drop the transformers version info
self.transformers_version = kwargs.pop("transformers_version", None)
......@@ -539,6 +549,8 @@ class PretrainedConfig(PushToHubMixin):
original_kwargs = copy.deepcopy(kwargs)
# Get config dict associated with the base config file
config_dict, kwargs = cls._get_config_dict(pretrained_model_name_or_path, **kwargs)
if "_commit_hash" in config_dict:
original_kwargs["_commit_hash"] = config_dict["_commit_hash"]
# That config file may point us toward another config file to use.
if "configuration_files" in config_dict:
......@@ -564,6 +576,7 @@ class PretrainedConfig(PushToHubMixin):
subfolder = kwargs.pop("subfolder", "")
from_pipeline = kwargs.pop("_from_pipeline", None)
from_auto_class = kwargs.pop("_from_auto", False)
commit_hash = kwargs.pop("_commit_hash", None)
if trust_remote_code is True:
logger.warning(
......@@ -599,7 +612,9 @@ class PretrainedConfig(PushToHubMixin):
user_agent=user_agent,
revision=revision,
subfolder=subfolder,
_commit_hash=commit_hash,
)
commit_hash = extract_commit_hash(resolved_config_file, commit_hash)
except EnvironmentError:
# Raise any environment error raised by `cached_file`. It will have a helpful error message adapted to
# the original exception.
......@@ -616,6 +631,7 @@ class PretrainedConfig(PushToHubMixin):
try:
# Load config dict
config_dict = cls._dict_from_json_file(resolved_config_file)
config_dict["_commit_hash"] = commit_hash
except (json.JSONDecodeError, UnicodeDecodeError):
raise EnvironmentError(
f"It looks like the config file at '{resolved_config_file}' is not a valid JSON file."
......@@ -648,6 +664,9 @@ class PretrainedConfig(PushToHubMixin):
# We remove them so they don't appear in `return_unused_kwargs`.
kwargs.pop("_from_auto", None)
kwargs.pop("_from_pipeline", None)
# The commit hash might have been updated in the `config_dict`; we don't want the kwargs to erase that update.
if "_commit_hash" in kwargs and "_commit_hash" in config_dict:
kwargs["_commit_hash"] = config_dict["_commit_hash"]
config = cls(**config_dict)
......@@ -751,6 +770,8 @@ class PretrainedConfig(PushToHubMixin):
output["model_type"] = self.__class__.model_type
if "_auto_class" in output:
del output["_auto_class"]
if "_commit_hash" in output:
del output["_commit_hash"]
# Transformers version when serializing the model
output["transformers_version"] = __version__
......
......@@ -595,6 +595,7 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
from_auto_class = kwargs.pop("_from_auto", False)
_do_init = kwargs.pop("_do_init", True)
subfolder = kwargs.pop("subfolder", "")
commit_hash = kwargs.pop("_commit_hash", None)
if trust_remote_code is True:
logger.warning(
......@@ -625,11 +626,15 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
revision=revision,
_from_auto=from_auto_class,
_from_pipeline=from_pipeline,
_commit_hash=commit_hash,
**kwargs,
)
else:
model_kwargs = kwargs
if commit_hash is None:
commit_hash = getattr(config, "_commit_hash", None)
# Add the dtype to model_kwargs
model_kwargs["dtype"] = dtype
......@@ -682,6 +687,7 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
revision=revision,
subfolder=subfolder,
_raise_exceptions_for_missing_entries=False,
_commit_hash=commit_hash,
)
resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs)
......@@ -748,6 +754,7 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
use_auth_token=use_auth_token,
user_agent=user_agent,
revision=revision,
_commit_hash=commit_hash,
)
# init random models
......
......@@ -2161,6 +2161,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
from_pipeline = kwargs.pop("_from_pipeline", None)
from_auto_class = kwargs.pop("_from_auto", False)
subfolder = kwargs.pop("subfolder", "")
commit_hash = kwargs.pop("_commit_hash", None)
if trust_remote_code is True:
logger.warning(
......@@ -2191,11 +2192,15 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
revision=revision,
_from_auto=from_auto_class,
_from_pipeline=from_pipeline,
_commit_hash=commit_hash,
**kwargs,
)
else:
model_kwargs = kwargs
if commit_hash is None:
commit_hash = getattr(config, "_commit_hash", None)
# This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the
# index of the files.
is_sharded = False
......@@ -2253,6 +2258,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
revision=revision,
subfolder=subfolder,
_raise_exceptions_for_missing_entries=False,
_commit_hash=commit_hash,
)
resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs)
......@@ -2320,6 +2326,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
use_auth_token=use_auth_token,
user_agent=user_agent,
revision=revision,
_commit_hash=commit_hash,
)
config.name_or_path = pretrained_model_name_or_path
......
......@@ -1840,6 +1840,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
load_in_8bit = kwargs.pop("load_in_8bit", False)
int8_threshold = kwargs.pop("int8_threshold", 6.0)
subfolder = kwargs.pop("subfolder", "")
commit_hash = kwargs.pop("_commit_hash", None)
if trust_remote_code is True:
logger.warning(
......@@ -1918,6 +1919,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
else:
model_kwargs = kwargs
if commit_hash is None:
commit_hash = getattr(config, "_commit_hash", None)
# This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the
# index of the files.
is_sharded = False
......@@ -2004,6 +2008,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
revision=revision,
subfolder=subfolder,
_raise_exceptions_for_missing_entries=False,
_commit_hash=commit_hash,
)
resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs)
......@@ -2078,6 +2083,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
user_agent=user_agent,
revision=revision,
subfolder=subfolder,
_commit_hash=commit_hash,
)
# load pt weights early so that we know which dtype to init the model under
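The same pattern repeats in the Flax, TF and PyTorch loaders above; condensed into one sketch (`kwargs`, `config` and `pretrained_model_name_or_path` are the enclosing `from_pretrained` locals, and the filename is illustrative):

commit_hash = kwargs.pop("_commit_hash", None)
if commit_hash is None:
    # Reuse the hash recorded when the config was resolved.
    commit_hash = getattr(config, "_commit_hash", None)
resolved_archive_file = cached_file(
    pretrained_model_name_or_path,
    "pytorch_model.bin",
    _commit_hash=commit_hash,  # cache hit -> no HEAD request
    _raise_exceptions_for_missing_entries=False,
)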
......
......@@ -25,7 +25,7 @@ from ...dynamic_module_utils import get_class_from_dynamic_module
from ...tokenization_utils import PreTrainedTokenizer
from ...tokenization_utils_base import TOKENIZER_CONFIG_FILE
from ...tokenization_utils_fast import PreTrainedTokenizerFast
-from ...utils import get_file_from_repo, is_sentencepiece_available, is_tokenizers_available, logging
+from ...utils import cached_file, extract_commit_hash, is_sentencepiece_available, is_tokenizers_available, logging
from ..encoder_decoder import EncoderDecoderConfig
from .auto_factory import _LazyAutoMapping
from .configuration_auto import (
......@@ -389,7 +389,8 @@ def get_tokenizer_config(
tokenizer.save_pretrained("tokenizer-test")
tokenizer_config = get_tokenizer_config("tokenizer-test")
```"""
-resolved_config_file = get_file_from_repo(
+commit_hash = kwargs.get("_commit_hash", None)
+resolved_config_file = cached_file(
pretrained_model_name_or_path,
TOKENIZER_CONFIG_FILE,
cache_dir=cache_dir,
......@@ -399,13 +400,19 @@ def get_tokenizer_config(
use_auth_token=use_auth_token,
revision=revision,
local_files_only=local_files_only,
_raise_exceptions_for_missing_entries=False,
_raise_exceptions_for_connection_errors=False,
_commit_hash=commit_hash,
)
if resolved_config_file is None:
logger.info("Could not locate the tokenizer configuration file, will try to use the model config instead.")
return {}
commit_hash = extract_commit_hash(resolved_config_file, commit_hash)
with open(resolved_config_file, encoding="utf-8") as reader:
-return json.load(reader)
+result = json.load(reader)
+result["_commit_hash"] = commit_hash
+return result
class AutoTokenizer:
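Since `get_tokenizer_config` returns a plain dict, the hash simply piggybacks on it as a `_commit_hash` key. A short usage sketch:

from transformers.models.auto.tokenization_auto import get_tokenizer_config

tokenizer_config = get_tokenizer_config("bert-base-cased")
commit_hash = tokenizer_config.get("_commit_hash")  # None when resolved from a local folder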
......@@ -532,6 +539,8 @@ class AutoTokenizer:
# Next, let's try to use the tokenizer_config file to get the tokenizer class.
tokenizer_config = get_tokenizer_config(pretrained_model_name_or_path, **kwargs)
if "_commit_hash" in tokenizer_config:
kwargs["_commit_hash"] = tokenizer_config["_commit_hash"]
config_tokenizer_class = tokenizer_config.get("tokenizer_class")
tokenizer_auto_map = None
if "auto_map" in tokenizer_config:
......
......@@ -557,7 +557,12 @@ def pipeline(
# Make sure we only pass use_auth_token once as a kwarg (it used to be possible to pass it in model_kwargs;
# this is kept for backward compatibility).
use_auth_token = model_kwargs.pop("use_auth_token", use_auth_token)
hub_kwargs = {"revision": revision, "use_auth_token": use_auth_token, "trust_remote_code": trust_remote_code}
hub_kwargs = {
"revision": revision,
"use_auth_token": use_auth_token,
"trust_remote_code": trust_remote_code,
"_commit_hash": None,
}
if task is None and model is None:
raise RuntimeError(
......@@ -583,8 +588,10 @@ def pipeline(
# Instantiate config if needed
if isinstance(config, str):
config = AutoConfig.from_pretrained(config, _from_pipeline=task, **hub_kwargs, **model_kwargs)
hub_kwargs["_commit_hash"] = config._commit_hash
elif config is None and isinstance(model, str):
config = AutoConfig.from_pretrained(model, _from_pipeline=task, **hub_kwargs, **model_kwargs)
hub_kwargs["_commit_hash"] = config._commit_hash
custom_tasks = {}
if config is not None and len(getattr(config, "custom_pipelines", {})) > 0:
......@@ -639,6 +646,7 @@ def pipeline(
)
if config is None and isinstance(model, str):
config = AutoConfig.from_pretrained(model, _from_pipeline=task, **hub_kwargs, **model_kwargs)
hub_kwargs["_commit_hash"] = config._commit_hash
if device_map is not None:
if "device_map" in model_kwargs:
......@@ -672,6 +680,7 @@ def pipeline(
)
model_config = model.config
hub_kwargs["_commit_hash"] = model.config._commit_hash
load_tokenizer = type(model_config) in TOKENIZER_MAPPING or model_config.tokenizer_class is not None
load_feature_extractor = type(model_config) in FEATURE_EXTRACTOR_MAPPING or feature_extractor is not None
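Condensed, the pipeline changes amount to threading one value through `hub_kwargs` (`model`, `task` and `model_kwargs` are the surrounding locals in `pipeline()`):

config = AutoConfig.from_pretrained(model, _from_pipeline=task, **hub_kwargs, **model_kwargs)
hub_kwargs["_commit_hash"] = config._commit_hash  # reused for model, tokenizer, feature extractor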
......
......@@ -31,6 +31,7 @@ from pathlib import Path
from typing import Iterator, List, Union
from unittest import mock
import huggingface_hub
from transformers import logging as transformers_logging
from .deepspeed import is_deepspeed_available
......@@ -1588,3 +1589,30 @@ def run_command(command: List[str], return_stdout=False):
raise SubprocessCallException(
f"Command `{' '.join(command)}` failed with the following error:\n\n{e.output.decode()}"
) from e
class RequestCounter:
"""
Helper class that will count all requests made online.
"""
def __enter__(self):
self.head_request_count = 0
self.get_request_count = 0
self.other_request_count = 0
self.old_request = huggingface_hub.file_download.requests.request
huggingface_hub.file_download.requests.request = self.new_request
return self
def __exit__(self, *args, **kwargs):
huggingface_hub.file_download.requests.request = self.old_request
def new_request(self, method, **kwargs):
if method == "GET":
self.get_request_count += 1
elif method == "HEAD":
self.head_request_count += 1
else:
self.other_request_count += 1
return self.old_request(method=method, **kwargs)
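A minimal usage sketch, mirroring the tests added at the bottom of this diff and assuming the model is already cached:

from transformers import AutoModel

with RequestCounter() as counter:
    _ = AutoModel.from_pretrained("hf-internal-testing/tiny-random-bert")
assert counter.get_request_count == 0  # warm cache: nothing re-downloaded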
......@@ -42,7 +42,7 @@ from .utils import (
add_end_docstrings,
cached_file,
copy_func,
-get_file_from_repo,
+extract_commit_hash,
is_flax_available,
is_offline_mode,
is_tf_available,
......@@ -1651,6 +1651,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
subfolder = kwargs.pop("subfolder", None)
from_pipeline = kwargs.pop("_from_pipeline", None)
from_auto_class = kwargs.pop("_from_auto", False)
commit_hash = kwargs.pop("_commit_hash", None)
user_agent = {"file_type": "tokenizer", "from_auto_class": from_auto_class, "is_fast": "Fast" in cls.__name__}
if from_pipeline is not None:
......@@ -1690,7 +1691,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
if "tokenizer_file" in vocab_files:
# Try to get the tokenizer config to see if there are versioned tokenizer files.
fast_tokenizer_file = FULL_TOKENIZER_FILE
-resolved_config_file = get_file_from_repo(
+resolved_config_file = cached_file(
pretrained_model_name_or_path,
TOKENIZER_CONFIG_FILE,
cache_dir=cache_dir,
......@@ -1701,7 +1702,12 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
revision=revision,
local_files_only=local_files_only,
subfolder=subfolder,
user_agent=user_agent,
_raise_exceptions_for_missing_entries=False,
_raise_exceptions_for_connection_errors=False,
_commit_hash=commit_hash,
)
commit_hash = extract_commit_hash(resolved_config_file, commit_hash)
if resolved_config_file is not None:
with open(resolved_config_file, encoding="utf-8") as reader:
tokenizer_config = json.load(reader)
......@@ -1730,7 +1736,9 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
subfolder=subfolder,
_raise_exceptions_for_missing_entries=False,
_raise_exceptions_for_connection_errors=False,
_commit_hash=commit_hash,
)
commit_hash = extract_commit_hash(resolved_vocab_files[file_id], commit_hash)
if len(unresolved_files) > 0:
logger.info(
......@@ -1763,6 +1771,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
use_auth_token=use_auth_token,
cache_dir=cache_dir,
local_files_only=local_files_only,
_commit_hash=commit_hash,
**kwargs,
)
......@@ -1776,6 +1785,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
use_auth_token=None,
cache_dir=None,
local_files_only=False,
_commit_hash=None,
**kwargs
):
# We instantiate fast tokenizers based on a slow tokenizer if we don't have access to the tokenizer.json
......@@ -1791,6 +1801,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
use_auth_token=use_auth_token,
cache_dir=cache_dir,
local_files_only=local_files_only,
_commit_hash=_commit_hash,
**(copy.deepcopy(kwargs)),
)
else:
......@@ -1823,6 +1834,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
use_auth_token=use_auth_token,
cache_dir=cache_dir,
local_files_only=local_files_only,
_commit_hash=_commit_hash,
)
config_tokenizer_class = config.tokenizer_class
except (OSError, ValueError, KeyError):
......
......@@ -63,6 +63,7 @@ from .hub import (
cached_file,
default_cache_path,
define_sagemaker_information,
extract_commit_hash,
get_cached_models,
get_file_from_repo,
get_full_repo_name,
......
......@@ -38,6 +38,7 @@ from huggingface_hub import (
whoami,
)
from huggingface_hub.constants import HUGGINGFACE_HEADER_X_LINKED_ETAG, HUGGINGFACE_HEADER_X_REPO_COMMIT
from huggingface_hub.file_download import REGEX_COMMIT_HASH
from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
from requests.exceptions import HTTPError
from transformers.utils.logging import tqdm
......@@ -200,11 +201,27 @@ def http_user_agent(user_agent: Union[Dict, str, None] = None) -> str:
return ua
-def try_to_load_from_cache(cache_dir, repo_id, filename, revision=None):
def extract_commit_hash(resolved_file: Optional[str], commit_hash: Optional[str]):
"""
Extracts the commit hash from the resolved path of a file in the cache.
"""
if resolved_file is None or commit_hash is not None:
return commit_hash
search = re.search(r"snapshots/([^/]+)/", resolved_file)
if search is None:
return None
commit_hash = search.groups()[0]
return commit_hash if REGEX_COMMIT_HASH.match(commit_hash) else None
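A few illustrative cases for the helper above (paths and hashes are hypothetical):

sha = "a" * 40
path = f"/cache/models--bert-base-uncased/snapshots/{sha}/config.json"
assert extract_commit_hash(path, None) == sha
assert extract_commit_hash(path, "b" * 40) == "b" * 40  # an explicit hash always wins
assert extract_commit_hash("/some/local/dir/config.json", None) is None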
+def try_to_load_from_cache(cache_dir, repo_id, filename, revision=None, commit_hash=None):
"""
Explores the cache to return the latest cached file for a given revision.
"""
-if revision is None:
+if commit_hash is not None and revision is not None:
+    raise ValueError("`commit_hash` and `revision` are mutually exclusive, pick one only.")
+if revision is None and commit_hash is None:
revision = "main"
model_id = repo_id.replace("/", "--")
......@@ -216,18 +233,19 @@ def try_to_load_from_cache(cache_dir, repo_id, filename, revision=None):
if not os.path.isdir(os.path.join(model_cache, subfolder)):
return None
if commit_hash is None:
# Resolve refs (for instance to convert main to the associated commit sha)
cached_refs = os.listdir(os.path.join(model_cache, "refs"))
if revision in cached_refs:
with open(os.path.join(model_cache, "refs", revision)) as f:
-revision = f.read()
+commit_hash = f.read()
cached_shas = os.listdir(os.path.join(model_cache, "snapshots"))
-if revision not in cached_shas:
+if commit_hash not in cached_shas:
# No cache for this revision and we won't try to return a random revision
return None
-cached_file = os.path.join(model_cache, "snapshots", revision, filename)
+cached_file = os.path.join(model_cache, "snapshots", commit_hash, filename)
return cached_file if os.path.isfile(cached_file) else None
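For reference, the cache layout the function walks, plus a hedged usage sketch (`cache_dir` and `commit_hash` are assumed to be bound by the caller):

# <cache_dir>/models--bert-base-uncased/
#     refs/main                    <- text file holding the sha that "main" points to
#     snapshots/<sha>/config.json  <- the actual cached file
hit = try_to_load_from_cache(cache_dir, "bert-base-uncased", "config.json", commit_hash=commit_hash)
if hit is not None:
    config_path = hit  # served entirely from disk, no HEAD call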
......@@ -265,8 +283,9 @@ def cached_file(
local_files_only: bool = False,
subfolder: str = "",
user_agent: Optional[Union[str, Dict[str, str]]] = None,
-_raise_exceptions_for_missing_entries=True,
-_raise_exceptions_for_connection_errors=True,
+_raise_exceptions_for_missing_entries: bool = True,
+_raise_exceptions_for_connection_errors: bool = True,
+_commit_hash: Optional[str] = None,
):
"""
Tries to locate a file in a local folder and repo, downloads and caches it if necessary.
......@@ -318,6 +337,13 @@ def cached_file(
# Download a model weight from the Hub and cache it.
model_weights_file = cached_file("bert-base-uncased", "pytorch_model.bin")
```"""
# Private arguments
# _raise_exceptions_for_missing_entries: if False, do not raise an exception for missing entries but return
# None.
# _raise_exceptions_for_connection_errors: if False, do not raise an exception for connection errors but return
# None.
# _commit_hash: passed when we are chaining several calls to various files (e.g. when loading a tokenizer or
# a pipeline). If the files are already cached for this commit hash, skip the calls to head and load them from the cache.
if is_offline_mode() and not local_files_only:
logger.info("Offline mode: forcing local_files_only=True")
local_files_only = True
......@@ -339,6 +365,13 @@ def cached_file(
cache_dir = TRANSFORMERS_CACHE
if isinstance(cache_dir, Path):
cache_dir = str(cache_dir)
if _commit_hash is not None:
# If the file is cached under that commit hash, we return it directly.
resolved_file = try_to_load_from_cache(cache_dir, path_or_repo_id, full_filename, commit_hash=_commit_hash)
if resolved_file is not None:
return resolved_file
user_agent = http_user_agent(user_agent)
try:
# Load from URL or cache if already cached
......@@ -803,6 +836,7 @@ def get_checkpoint_shard_files(
user_agent=None,
revision=None,
subfolder="",
_commit_hash=None,
):
"""
For a given model:
......@@ -848,6 +882,7 @@ def get_checkpoint_shard_files(
user_agent=user_agent,
revision=revision,
subfolder=subfolder,
_commit_hash=_commit_hash,
)
# We have already dealt with RepositoryNotFoundError and RevisionNotFoundError when getting the index, so
# we don't have to catch them here.
......
......@@ -24,6 +24,7 @@ from transformers.models.auto.configuration_auto import CONFIG_MAPPING
from transformers.testing_utils import (
DUMMY_UNKNOWN_IDENTIFIER,
SMALL_MODEL_IDENTIFIER,
RequestCounter,
require_scatter,
require_torch,
slow,
......@@ -354,3 +355,21 @@ class AutoModelTest(unittest.TestCase):
def test_model_from_flax_suggestion(self):
with self.assertRaisesRegex(EnvironmentError, "Use `from_flax=True` to load this model"):
_ = AutoModel.from_pretrained("hf-internal-testing/tiny-bert-flax-only")
def test_cached_model_has_minimum_calls_to_head(self):
# Make sure we have cached the model.
_ = AutoModel.from_pretrained("hf-internal-testing/tiny-random-bert")
with RequestCounter() as counter:
_ = AutoModel.from_pretrained("hf-internal-testing/tiny-random-bert")
self.assertEqual(counter.get_request_count, 0)
self.assertEqual(counter.head_request_count, 1)
self.assertEqual(counter.other_request_count, 0)
# With a sharded checkpoint
_ = AutoModel.from_pretrained("hf-internal-testing/tiny-random-bert-sharded")
with RequestCounter() as counter:
_ = AutoModel.from_pretrained("hf-internal-testing/tiny-random-bert-sharded")
self.assertEqual(counter.get_request_count, 0)
# There is no pytorch_model.bin so we still get one call for this one.
self.assertEqual(counter.head_request_count, 2)
self.assertEqual(counter.other_request_count, 0)
......@@ -21,6 +21,7 @@ from transformers import CONFIG_MAPPING, AutoConfig, BertConfig, GPT2Config, T5C
from transformers.testing_utils import (
DUMMY_UNKNOWN_IDENTIFIER,
SMALL_MODEL_IDENTIFIER,
RequestCounter,
require_tensorflow_probability,
require_tf,
slow,
......@@ -287,3 +288,21 @@ class TFAutoModelTest(unittest.TestCase):
def test_model_from_pt_suggestion(self):
with self.assertRaisesRegex(EnvironmentError, "Use `from_pt=True` to load this model"):
_ = TFAutoModel.from_pretrained("hf-internal-testing/tiny-bert-pt-only")
def test_cached_model_has_minimum_calls_to_head(self):
# Make sure we have cached the model.
_ = TFAutoModel.from_pretrained("hf-internal-testing/tiny-random-bert")
with RequestCounter() as counter:
_ = TFAutoModel.from_pretrained("hf-internal-testing/tiny-random-bert")
self.assertEqual(counter.get_request_count, 0)
self.assertEqual(counter.head_request_count, 1)
self.assertEqual(counter.other_request_count, 0)
# With a sharded checkpoint
_ = TFAutoModel.from_pretrained("ArthurZ/tiny-random-bert-sharded")
with RequestCounter() as counter:
_ = TFAutoModel.from_pretrained("ArthurZ/tiny-random-bert-sharded")
self.assertEqual(counter.get_request_count, 0)
# There is no pytorch_model.bin so we still get one call for this one.
self.assertEqual(counter.head_request_count, 2)
self.assertEqual(counter.other_request_count, 0)
......@@ -48,6 +48,7 @@ from transformers.testing_utils import (
DUMMY_DIFF_TOKENIZER_IDENTIFIER,
DUMMY_UNKNOWN_IDENTIFIER,
SMALL_MODEL_IDENTIFIER,
RequestCounter,
require_tokenizers,
slow,
)
......@@ -213,6 +214,7 @@ class AutoTokenizerTest(unittest.TestCase):
def test_get_tokenizer_config(self):
# Check we can load the tokenizer config of an online model.
config = get_tokenizer_config("bert-base-cased")
_ = config.pop("_commit_hash", None)
# If we ever update bert-base-cased tokenizer config, this dict here will need to be updated.
self.assertEqual(config, {"do_lower_case": False})
......@@ -340,3 +342,13 @@ class AutoTokenizerTest(unittest.TestCase):
EnvironmentError, r"aaaaaa is not a valid git identifier \(branch name, tag name or commit id\)"
):
_ = AutoTokenizer.from_pretrained(DUMMY_UNKNOWN_IDENTIFIER, revision="aaaaaa")
def test_cached_tokenizer_has_minimum_calls_to_head(self):
# Make sure we have cached the tokenizer.
_ = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bert")
with RequestCounter() as counter:
_ = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bert")
self.assertEqual(counter.get_request_count, 0)
# We still have one extra call because the model does not have an added_tokens.json file
self.assertEqual(counter.head_request_count, 2)
self.assertEqual(counter.other_request_count, 0)
......@@ -49,6 +49,7 @@ from transformers.testing_utils import (
TOKEN,
USER,
CaptureLogger,
RequestCounter,
is_pipeline_test,
is_staging_test,
nested_simplify,
......@@ -877,6 +878,16 @@ class CustomPipelineTest(unittest.TestCase):
[{"label": "LABEL_0", "score": 0.505}],
)
def test_cached_pipeline_has_minimum_calls_to_head(self):
# Make sure we have cached the pipeline.
_ = pipeline("text-classification", model="hf-internal-testing/tiny-random-bert")
with RequestCounter() as counter:
_ = pipeline("text-classification", model="hf-internal-testing/tiny-random-bert")
self.assertEqual(counter.get_request_count, 0)
# We still have one extra call because the model does not have an added_tokens.json file
self.assertEqual(counter.head_request_count, 2)
self.assertEqual(counter.other_request_count, 0)
@require_torch
@is_staging_test
......
......@@ -246,7 +246,7 @@ class ConfigPushToHubTester(unittest.TestCase):
config.push_to_hub("test-config", use_auth_token=self._token)
new_config = BertConfig.from_pretrained(f"{USER}/test-config")
-for k, v in config.__dict__.items():
+for k, v in config.to_dict().items():
if k != "transformers_version":
self.assertEqual(v, getattr(new_config, k))
......@@ -258,7 +258,7 @@ class ConfigPushToHubTester(unittest.TestCase):
config.save_pretrained(tmp_dir, repo_id="test-config", push_to_hub=True, use_auth_token=self._token)
new_config = BertConfig.from_pretrained(f"{USER}/test-config")
-for k, v in config.__dict__.items():
+for k, v in config.to_dict().items():
if k != "transformers_version":
self.assertEqual(v, getattr(new_config, k))
......@@ -269,7 +269,7 @@ class ConfigPushToHubTester(unittest.TestCase):
config.push_to_hub("valid_org/test-config-org", use_auth_token=self._token)
new_config = BertConfig.from_pretrained("valid_org/test-config-org")
-for k, v in config.__dict__.items():
+for k, v in config.to_dict().items():
if k != "transformers_version":
self.assertEqual(v, getattr(new_config, k))
......@@ -283,7 +283,7 @@ class ConfigPushToHubTester(unittest.TestCase):
)
new_config = BertConfig.from_pretrained("valid_org/test-config-org")
-for k, v in config.__dict__.items():
+for k, v in config.to_dict().items():
if k != "transformers_version":
self.assertEqual(v, getattr(new_config, k))
......@@ -323,7 +323,9 @@ class ConfigTestUtils(unittest.TestCase):
base_config = PretrainedConfig()
missing_keys = [key for key in base_config.__dict__ if key not in config_common_kwargs]
# If this part of the test fails, you have arguments to add in config_common_kwargs above.
-self.assertListEqual(missing_keys, ["is_encoder_decoder", "_name_or_path", "transformers_version"])
+self.assertListEqual(
+    missing_keys, ["is_encoder_decoder", "_name_or_path", "_commit_hash", "transformers_version"]
+)
keys_with_defaults = [key for key, value in config_common_kwargs.items() if value == getattr(base_config, key)]
if len(keys_with_defaults) > 0:
raise ValueError(
......