Unverified commit e6954707, authored by Sylvain Gugger, committed by GitHub

Avoid using get_list_of_files (#15287)

* Avoid using get_list_of_files in config

* WIP: change tokenizer file getter

* Remove call in tokenizer files

* Remove last call to get_list_of_files

* Better tests

* Unit tests for new function

* Document bad API
parent e65bfc09
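
In short, callers stop listing every file in a repo (`get_list_of_files`) and instead probe directly for the one file they need (`get_file_from_repo`). A minimal sketch of the new pattern, using the repo id from the new helper's own docstring example below (network access assumed):

```python
from transformers.file_utils import get_file_from_repo

# Old pattern (removed in this diff): fetch the full file listing, then test membership.
#     all_files = get_list_of_files(path_or_repo, revision=..., use_auth_token=..., local_files_only=...)
#     if "tokenizer_config.json" in all_files: ...
# New pattern: resolve the single file; the helper returns None when the file is absent.
tokenizer_config = get_file_from_repo("bert-base-uncased", "tokenizer_config.json")
if tokenizer_config is not None:
    print("Cached at:", tokenizer_config)
```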
@@ -21,7 +21,7 @@ import json
 import os
 import re
 import warnings
-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Any, Dict, List, Tuple, Union

 from packaging import version
@@ -36,7 +36,6 @@ from .file_utils import (
     RevisionNotFoundError,
     cached_path,
     copy_func,
-    get_list_of_files,
     hf_bucket_url,
     is_offline_mode,
     is_remote_url,
@@ -46,7 +45,7 @@ from .utils import logging
 logger = logging.get_logger(__name__)

-FULL_CONFIGURATION_FILE = "config.json"
 _re_configuration_file = re.compile(r"config\.(.*)\.json")
@@ -533,6 +532,23 @@ class PretrainedConfig(PushToHubMixin):
             `Tuple[Dict, Dict]`: The dictionary(ies) that will be used to instantiate the configuration object.
         """
+        original_kwargs = copy.deepcopy(kwargs)
+        # Get config dict associated with the base config file
+        config_dict, kwargs = cls._get_config_dict(pretrained_model_name_or_path, **kwargs)
+
+        # That config file may point us toward another config file to use.
+        if "configuration_files" in config_dict:
+            configuration_file = get_configuration_file(config_dict["configuration_files"])
+            config_dict, kwargs = cls._get_config_dict(
+                pretrained_model_name_or_path, _configuration_file=configuration_file, **original_kwargs
+            )
+
+        return config_dict, kwargs
+
+    @classmethod
+    def _get_config_dict(
+        cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
+    ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
         cache_dir = kwargs.pop("cache_dir", None)
         force_download = kwargs.pop("force_download", False)
         resume_download = kwargs.pop("resume_download", False)
@@ -555,12 +571,7 @@ class PretrainedConfig(PushToHubMixin):
         if os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
             config_file = pretrained_model_name_or_path
         else:
-            configuration_file = get_configuration_file(
-                pretrained_model_name_or_path,
-                revision=revision,
-                use_auth_token=use_auth_token,
-                local_files_only=local_files_only,
-            )
+            configuration_file = kwargs.get("_configuration_file", CONFIG_NAME)

             if os.path.isdir(pretrained_model_name_or_path):
                 config_file = os.path.join(pretrained_model_name_or_path, configuration_file)
@@ -840,41 +851,18 @@ class PretrainedConfig(PushToHubMixin):
             d["torch_dtype"] = str(d["torch_dtype"]).split(".")[1]


-def get_configuration_file(
-    path_or_repo: Union[str, os.PathLike],
-    revision: Optional[str] = None,
-    use_auth_token: Optional[Union[bool, str]] = None,
-    local_files_only: bool = False,
-) -> str:
+def get_configuration_file(configuration_files: List[str]) -> str:
     """
     Get the configuration file to use for this version of transformers.

     Args:
-        path_or_repo (`str` or `os.PathLike`):
-            Can be either the id of a repo on huggingface.co or a path to a *directory*.
-        revision (`str`, *optional*, defaults to `"main"`):
-            The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
-            git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
-            identifier allowed by git.
-        use_auth_token (`str` or *bool*, *optional*):
-            The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
-            when running `transformers-cli login` (stored in `~/.huggingface`).
-        local_files_only (`bool`, *optional*, defaults to `False`):
-            Whether or not to only rely on local files and not to attempt to download any files.
+        configuration_files (`List[str]`): The list of available configuration files.

     Returns:
         `str`: The configuration file to use.
     """
-    # Inspect all files from the repo/folder.
-    try:
-        all_files = get_list_of_files(
-            path_or_repo, revision=revision, use_auth_token=use_auth_token, local_files_only=local_files_only
-        )
-    except Exception:
-        return FULL_CONFIGURATION_FILE
     configuration_files_map = {}
-    for file_name in all_files:
+    for file_name in configuration_files:
         search = _re_configuration_file.search(file_name)
         if search is not None:
             v = search.groups()[0]
@@ -882,7 +870,7 @@ def get_configuration_file(
     available_versions = sorted(configuration_files_map.keys())

-    # Defaults to FULL_CONFIGURATION_FILE and then try to look at some newer versions.
-    configuration_file = FULL_CONFIGURATION_FILE
+    # Default to CONFIG_NAME and then try to look at some newer versions.
+    configuration_file = CONFIG_NAME
     transformers_version = version.parse(__version__)
     for v in available_versions:
         if version.parse(v) <= transformers_version:
...
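
The hunk above is truncated mid-loop. As a standalone sketch (a reimplementation for illustration only, pinning the running version to 4.16.0 for the demo), the selection rule is: map every `config.<version>.json` to its version, walk the sorted versions, and keep the newest one not exceeding the running `transformers` version, defaulting to plain `config.json`:

```python
import re

from packaging import version

_re_configuration_file = re.compile(r"config\.(.*)\.json")


def pick_configuration_file(configuration_files, current_version="4.16.0"):
    # Map each version string extracted from the file name to the file name.
    files_map = {}
    for file_name in configuration_files:
        search = _re_configuration_file.search(file_name)
        if search is not None:
            files_map[search.groups()[0]] = file_name

    # Default to the plain config.json, then look at newer versioned files.
    configuration_file = "config.json"
    current = version.parse(current_version)
    for v in sorted(files_map.keys()):
        if version.parse(v) <= current:
            configuration_file = files_map[v]
        else:
            break  # versions are sorted, so nothing later can match

    return configuration_file


# Running "4.16.0": the 4.0.0 file is picked, the 42.0.0 file waits for a future release.
assert pick_configuration_file(["config.4.0.0.json", "config.42.0.0.json"]) == "config.4.0.0.json"
```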
@@ -2112,6 +2112,112 @@ def get_from_cache(
     return cache_path
+def get_file_from_repo(
+    path_or_repo: Union[str, os.PathLike],
+    filename: str,
+    cache_dir: Optional[Union[str, os.PathLike]] = None,
+    force_download: bool = False,
+    resume_download: bool = False,
+    proxies: Optional[Dict[str, str]] = None,
+    use_auth_token: Optional[Union[bool, str]] = None,
+    revision: Optional[str] = None,
+    local_files_only: bool = False,
+):
+    """
+    Tries to locate a file in a local folder or repo, and downloads and caches it if necessary.
+
+    Args:
+        path_or_repo (`str` or `os.PathLike`):
+            This can be either:
+
+            - a string, the *model id* of a model repo on huggingface.co.
+            - a path to a *directory* potentially containing the file.
+        filename (`str`):
+            The name of the file to locate in `path_or_repo`.
+        cache_dir (`str` or `os.PathLike`, *optional*):
+            Path to a directory in which a downloaded pretrained model configuration should be cached if the
+            standard cache should not be used.
+        force_download (`bool`, *optional*, defaults to `False`):
+            Whether or not to force (re-)downloading the configuration files and override the cached versions if
+            they exist.
+        resume_download (`bool`, *optional*, defaults to `False`):
+            Whether or not to delete an incompletely received file. Attempts to resume the download if such a file
+            exists.
+        proxies (`Dict[str, str]`, *optional*):
+            A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
+            'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+        use_auth_token (`str` or *bool*, *optional*):
+            The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
+            when running `transformers-cli login` (stored in `~/.huggingface`).
+        revision (`str`, *optional*, defaults to `"main"`):
+            The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
+            git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
+            identifier allowed by git.
+        local_files_only (`bool`, *optional*, defaults to `False`):
+            If `True`, will only try to load the file from local files.
+
+    <Tip>
+
+    Passing `use_auth_token=True` is required when you want to use a private model.
+
+    </Tip>
+
+    Returns:
+        `Optional[str]`: Returns the resolved file (to the cache folder if downloaded from a repo) or `None` if the
+        file does not exist.
+
+    Examples:
+
+    ```python
+    # Download a tokenizer configuration from huggingface.co and cache.
+    tokenizer_config = get_file_from_repo("bert-base-uncased", "tokenizer_config.json")
+    # This model does not have a tokenizer config so the result will be None.
+    tokenizer_config = get_file_from_repo("xlm-roberta-base", "tokenizer_config.json")
+    ```"""
+    if is_offline_mode() and not local_files_only:
+        logger.info("Offline mode: forcing local_files_only=True")
+        local_files_only = True
+
+    path_or_repo = str(path_or_repo)
+    if os.path.isdir(path_or_repo):
+        resolved_file = os.path.join(path_or_repo, filename)
+        return resolved_file if os.path.isfile(resolved_file) else None
+    else:
+        resolved_file = hf_bucket_url(path_or_repo, filename=filename, revision=revision, mirror=None)
+
+    try:
+        # Load from URL or cache if already cached
+        resolved_file = cached_path(
+            resolved_file,
+            cache_dir=cache_dir,
+            force_download=force_download,
+            proxies=proxies,
+            resume_download=resume_download,
+            local_files_only=local_files_only,
+            use_auth_token=use_auth_token,
+        )
+    except RepositoryNotFoundError as err:
+        logger.error(err)
+        raise EnvironmentError(
+            f"{path_or_repo} is not a local folder and is not a valid model identifier "
+            "listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to "
+            "pass a token having permission to this repo with `use_auth_token` or log in with "
+            "`huggingface-cli login` and pass `use_auth_token=True`."
+        )
+    except RevisionNotFoundError as err:
+        logger.error(err)
+        raise EnvironmentError(
+            f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists "
+            "for this model name. Check the model page at "
+            f"'https://huggingface.co/{path_or_repo}' for available revisions."
+        )
+    except EnvironmentError:
+        # The repo and revision exist, but the file does not or there was a connection error fetching it.
+        return None
+
+    return resolved_file
 def has_file(
     path_or_repo: Union[str, os.PathLike],
     filename: str,
@@ -2184,6 +2290,12 @@ def get_list_of_files(
         local_files_only (`bool`, *optional*, defaults to `False`):
             Whether or not to only rely on local files and not to attempt to download any files.
+
+    <Tip warning={true}>
+
+    This API is not optimized, so calling it a lot may result in connection errors.
+
+    </Tip>
+
     Returns:
         `List[str]`: The list of files available in `path_or_repo`.
     """
...
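
Taken together with the tests added further down, the new helper's contract is: `None` for a missing file in an existing repo, an `EnvironmentError` for a bad repo id or revision. A hedged usage sketch (network access assumed):

```python
from transformers.file_utils import get_file_from_repo

# A file that does not exist in a valid repo: no exception, just None.
assert get_file_from_repo("bert-base-cased", "ahah.txt") is None

# An invalid revision (or repo id) still surfaces as an EnvironmentError.
try:
    get_file_from_repo("bert-base-cased", "config.json", revision="ahaha")
except EnvironmentError as err:
    print("Raised as expected:", err)
```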
@@ -14,12 +14,14 @@
 # limitations under the License.
 """ AutoProcessor class."""
 import importlib
+import inspect
+import json
 from collections import OrderedDict

 # Build the list of all feature extractors
 from ...configuration_utils import PretrainedConfig
 from ...feature_extraction_utils import FeatureExtractionMixin
-from ...file_utils import CONFIG_NAME, FEATURE_EXTRACTOR_NAME, get_list_of_files
+from ...file_utils import CONFIG_NAME, FEATURE_EXTRACTOR_NAME, get_file_from_repo
 from ...tokenization_utils import TOKENIZER_CONFIG_FILE
 from .auto_factory import _LazyAutoMapping
 from .configuration_auto import (
@@ -29,7 +31,6 @@ from .configuration_auto import (
     model_type_to_module_name,
     replace_list_option_in_docstrings,
 )
-from .tokenization_auto import get_tokenizer_config


 PROCESSOR_MAPPING_NAMES = OrderedDict(
@@ -145,24 +146,29 @@ class AutoProcessor:
         kwargs["_from_auto"] = True

         # First, let's see if we have a preprocessor config.
-        # get_list_of_files only takes three of the kwargs we have, so we filter them.
-        get_list_of_files_kwargs = {
-            key: kwargs[key] for key in ["revision", "use_auth_token", "local_files_only"] if key in kwargs
+        # Filter the kwargs for `get_file_from_repo`.
+        get_file_from_repo_kwargs = {
+            key: kwargs[key] for key in inspect.signature(get_file_from_repo).parameters.keys() if key in kwargs
         }
-        model_files = get_list_of_files(pretrained_model_name_or_path, **get_list_of_files_kwargs)
-        # strip to file name
-        model_files = [f.split("/")[-1] for f in model_files]

         # Let's start by checking whether the processor class is saved in a feature extractor
-        if FEATURE_EXTRACTOR_NAME in model_files:
+        preprocessor_config_file = get_file_from_repo(
+            pretrained_model_name_or_path, FEATURE_EXTRACTOR_NAME, **get_file_from_repo_kwargs
+        )
+        if preprocessor_config_file is not None:
             config_dict, _ = FeatureExtractionMixin.get_feature_extractor_dict(pretrained_model_name_or_path, **kwargs)
             if "processor_class" in config_dict:
                 processor_class = processor_class_from_name(config_dict["processor_class"])
                 return processor_class.from_pretrained(pretrained_model_name_or_path, **kwargs)

         # Next, let's check whether the processor class is saved in a tokenizer
-        if TOKENIZER_CONFIG_FILE in model_files:
-            config_dict = get_tokenizer_config(pretrained_model_name_or_path, **kwargs)
+        tokenizer_config_file = get_file_from_repo(
+            pretrained_model_name_or_path, TOKENIZER_CONFIG_FILE, **get_file_from_repo_kwargs
+        )
+        if tokenizer_config_file is not None:
+            with open(tokenizer_config_file, encoding="utf-8") as reader:
+                config_dict = json.load(reader)
             if "processor_class" in config_dict:
                 processor_class = processor_class_from_name(config_dict["processor_class"])
                 return processor_class.from_pretrained(pretrained_model_name_or_path, **kwargs)
...
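
The kwargs filter above swaps a hard-coded key list for signature introspection, so `AutoProcessor` stays in sync if `get_file_from_repo` ever grows parameters. A self-contained demo of the idiom (the stub and the kwargs dict are illustrative only, not the real helper):

```python
import inspect


def get_file_from_repo(path_or_repo, filename, revision=None, use_auth_token=None, local_files_only=False):
    """Stub standing in for the real helper; only the signature matters here."""


kwargs = {"revision": "main", "trust_remote_code": True, "local_files_only": False}

# Keep only the kwargs the callee actually accepts, whatever its current signature.
accepted = inspect.signature(get_file_from_repo).parameters.keys()
filtered = {key: kwargs[key] for key in accepted if key in kwargs}

assert filtered == {"revision": "main", "local_files_only": False}
```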
@@ -21,15 +21,7 @@ from collections import OrderedDict
 from typing import TYPE_CHECKING, Dict, Optional, Tuple, Union

 from ...configuration_utils import PretrainedConfig
-from ...file_utils import (
-    RepositoryNotFoundError,
-    RevisionNotFoundError,
-    cached_path,
-    hf_bucket_url,
-    is_offline_mode,
-    is_sentencepiece_available,
-    is_tokenizers_available,
-)
+from ...file_utils import get_file_from_repo, is_sentencepiece_available, is_tokenizers_available
 from ...tokenization_utils import PreTrainedTokenizer
 from ...tokenization_utils_base import TOKENIZER_CONFIG_FILE
 from ...tokenization_utils_fast import PreTrainedTokenizerFast
@@ -329,46 +321,18 @@ def get_tokenizer_config(
     tokenizer.save_pretrained("tokenizer-test")
     tokenizer_config = get_tokenizer_config("tokenizer-test")
     ```"""
-    if is_offline_mode() and not local_files_only:
-        logger.info("Offline mode: forcing local_files_only=True")
-        local_files_only = True
-
-    pretrained_model_name_or_path = str(pretrained_model_name_or_path)
-    if os.path.isdir(pretrained_model_name_or_path):
-        config_file = os.path.join(pretrained_model_name_or_path, TOKENIZER_CONFIG_FILE)
-    else:
-        config_file = hf_bucket_url(
-            pretrained_model_name_or_path, filename=TOKENIZER_CONFIG_FILE, revision=revision, mirror=None
-        )
-
-    try:
-        # Load from URL or cache if already cached
-        resolved_config_file = cached_path(
-            config_file,
-            cache_dir=cache_dir,
-            force_download=force_download,
-            proxies=proxies,
-            resume_download=resume_download,
-            local_files_only=local_files_only,
-            use_auth_token=use_auth_token,
-        )
-    except RepositoryNotFoundError as err:
-        logger.error(err)
-        raise EnvironmentError(
-            f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier "
-            "listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to "
-            "pass a token having permission to this repo with `use_auth_token` or log in with "
-            "`huggingface-cli login` and pass `use_auth_token=True`."
-        )
-    except RevisionNotFoundError as err:
-        logger.error(err)
-        raise EnvironmentError(
-            f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists "
-            "for this model name. Check the model page at "
-            f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
-        )
-    except EnvironmentError:
+    resolved_config_file = get_file_from_repo(
+        pretrained_model_name_or_path,
+        TOKENIZER_CONFIG_FILE,
+        cache_dir=cache_dir,
+        force_download=force_download,
+        resume_download=resume_download,
+        proxies=proxies,
+        use_auth_token=use_auth_token,
+        revision=revision,
+        local_files_only=local_files_only,
+    )
+    if resolved_config_file is None:
         logger.info("Could not locate the tokenizer configuration file, will try to use the model config instead.")
         return {}
...
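
After this rewrite, `get_tokenizer_config` is a thin wrapper: when `get_file_from_repo` cannot resolve `tokenizer_config.json`, it logs and returns an empty dict instead of raising. A short usage sketch (the second repo is the one the `get_file_from_repo` docstring cites as having no tokenizer config; network access assumed):

```python
from transformers.models.auto.tokenization_auto import get_tokenizer_config

# A repo with a tokenizer_config.json: the parsed dict is returned.
config = get_tokenizer_config("bert-base-uncased")

# A repo without one: {} rather than an error, so callers can fall back to the model config.
assert get_tokenizer_config("xlm-roberta-base") == {}
```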
@@ -50,7 +50,7 @@ from .file_utils import (
     add_end_docstrings,
     cached_path,
     copy_func,
-    get_list_of_files,
+    get_file_from_repo,
     hf_bucket_url,
     is_flax_available,
     is_offline_mode,
@@ -1649,12 +1649,26 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
             vocab_files[file_id] = pretrained_model_name_or_path
         else:
             # At this point pretrained_model_name_or_path is either a directory or a model identifier name
-            fast_tokenizer_file = get_fast_tokenizer_file(
+            # Try to get the tokenizer config to see if there are versioned tokenizer files.
+            fast_tokenizer_file = FULL_TOKENIZER_FILE
+            resolved_config_file = get_file_from_repo(
                 pretrained_model_name_or_path,
-                revision=revision,
+                TOKENIZER_CONFIG_FILE,
+                cache_dir=cache_dir,
+                force_download=force_download,
+                resume_download=resume_download,
+                proxies=proxies,
                 use_auth_token=use_auth_token,
+                revision=revision,
                 local_files_only=local_files_only,
             )
+            if resolved_config_file is not None:
+                with open(resolved_config_file, encoding="utf-8") as reader:
+                    tokenizer_config = json.load(reader)
+                if "fast_tokenizer_files" in tokenizer_config:
+                    fast_tokenizer_file = get_fast_tokenizer_file(tokenizer_config["fast_tokenizer_files"])

         additional_files_names = {
             "added_tokens_file": ADDED_TOKENS_FILE,
             "special_tokens_map_file": SPECIAL_TOKENS_MAP_FILE,
@@ -3495,41 +3509,18 @@ For a more complete example, see the implementation of `prepare_seq2seq_batch`.
         return model_inputs


-def get_fast_tokenizer_file(
-    path_or_repo: Union[str, os.PathLike],
-    revision: Optional[str] = None,
-    use_auth_token: Optional[Union[bool, str]] = None,
-    local_files_only: bool = False,
-) -> str:
+def get_fast_tokenizer_file(tokenization_files: List[str]) -> str:
     """
-    Get the tokenizer file to use for this version of transformers.
+    Get the tokenization file to use for this version of transformers.

     Args:
-        path_or_repo (`str` or `os.PathLike`):
-            Can be either the id of a repo on huggingface.co or a path to a *directory*.
-        revision (`str`, *optional*, defaults to `"main"`):
-            The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
-            git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
-            identifier allowed by git.
-        use_auth_token (`str` or *bool*, *optional*):
-            The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
-            when running `transformers-cli login` (stored in `~/.huggingface`).
-        local_files_only (`bool`, *optional*, defaults to `False`):
-            Whether or not to only rely on local files and not to attempt to download any files.
+        tokenization_files (`List[str]`): The list of available tokenizer files.

     Returns:
-        `str`: The tokenizer file to use.
+        `str`: The tokenization file to use.
     """
-    # Inspect all files from the repo/folder.
-    try:
-        all_files = get_list_of_files(
-            path_or_repo, revision=revision, use_auth_token=use_auth_token, local_files_only=local_files_only
-        )
-    except Exception:
-        return FULL_TOKENIZER_FILE
     tokenizer_files_map = {}
-    for file_name in all_files:
+    for file_name in tokenization_files:
         search = _re_tokenizer_file.search(file_name)
         if search is not None:
             v = search.groups()[0]
...
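
With its new signature, `get_fast_tokenizer_file` applies the same version walk as `get_configuration_file`, just over `tokenizer.<version>.json` names. A quick sketch against the function as it stands in this commit (assuming the running `transformers` is a 4.x release, i.e. below 42.0.0):

```python
from transformers.tokenization_utils_base import get_fast_tokenizer_file

# The 4.0.0 file is at or below the current version; the 42.0.0 one is not picked yet.
selected = get_fast_tokenizer_file(["tokenizer.4.0.0.json", "tokenizer.42.0.0.json"])
print(selected)  # tokenizer.4.0.0.json
```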
@@ -313,6 +313,7 @@ class ConfigTestUtils(unittest.TestCase):
 class ConfigurationVersioningTest(unittest.TestCase):
     def test_local_versioning(self):
         configuration = AutoConfig.from_pretrained("bert-base-cased")
+        configuration.configuration_files = ["config.4.0.0.json"]

         with tempfile.TemporaryDirectory() as tmp_dir:
             configuration.save_pretrained(tmp_dir)
@@ -325,23 +326,26 @@ class ConfigurationVersioningTest(unittest.TestCase):
             # Will need to be adjusted if we reach v42 and this test is still here.
             # Should pick the old configuration file as the version of Transformers is < 42.0.0.
+            configuration.configuration_files = ["config.42.0.0.json"]
+            configuration.hidden_size = 768
+            configuration.save_pretrained(tmp_dir)
             shutil.move(os.path.join(tmp_dir, "config.4.0.0.json"), os.path.join(tmp_dir, "config.42.0.0.json"))
             new_configuration = AutoConfig.from_pretrained(tmp_dir)
             self.assertEqual(new_configuration.hidden_size, 768)
     def test_repo_versioning_before(self):
-        # This repo has two configuration files, one for v5.0.0 and above with an added token, one for versions lower.
-        repo = "microsoft/layoutxlm-base"
+        # This repo has two configuration files, one for v4.0.0 and above with a different hidden size.
+        repo = "hf-internal-testing/test-two-configs"

         import transformers as new_transformers

-        new_transformers.configuration_utils.__version__ = "v5.0.0"
+        new_transformers.configuration_utils.__version__ = "v4.0.0"
         new_configuration = new_transformers.models.auto.AutoConfig.from_pretrained(repo)
-        self.assertEqual(new_configuration.tokenizer_class, None)
+        self.assertEqual(new_configuration.hidden_size, 2)

         # Testing an older version by monkey-patching the version in the module it's used.
         import transformers as old_transformers

         old_transformers.configuration_utils.__version__ = "v3.0.0"
         old_configuration = old_transformers.models.auto.AutoConfig.from_pretrained(repo)
-        self.assertEqual(old_configuration.tokenizer_class, "XLMRobertaTokenizer")
+        self.assertEqual(old_configuration.hidden_size, 768)
@@ -15,7 +15,10 @@
 import contextlib
 import importlib
 import io
+import json
+import tempfile
 import unittest
+from pathlib import Path

 import transformers
@@ -31,6 +34,7 @@ from transformers.file_utils import (
     RepositoryNotFoundError,
     RevisionNotFoundError,
     filename_to_url,
+    get_file_from_repo,
     get_from_cache,
     has_file,
     hf_bucket_url,
@@ -128,6 +132,31 @@ class GetFromCacheTests(unittest.TestCase):
         self.assertFalse(has_file("hf-internal-testing/tiny-bert-pt-only", TF2_WEIGHTS_NAME))
         self.assertFalse(has_file("hf-internal-testing/tiny-bert-pt-only", FLAX_WEIGHTS_NAME))
+    def test_get_file_from_repo_distant(self):
+        # `get_file_from_repo` returns None if the file does not exist.
+        self.assertIsNone(get_file_from_repo("bert-base-cased", "ahah.txt"))
+
+        # The function raises if the repository does not exist.
+        with self.assertRaisesRegex(EnvironmentError, "is not a valid model identifier"):
+            get_file_from_repo("bert-base-case", "config.json")
+
+        # The function raises if the revision does not exist.
+        with self.assertRaisesRegex(EnvironmentError, "is not a valid git identifier"):
+            get_file_from_repo("bert-base-cased", "config.json", revision="ahaha")
+
+        resolved_file = get_file_from_repo("bert-base-cased", "config.json")
+        # The name is the cached name, which is not very easy to test, so instead we load the content.
+        config = json.loads(open(resolved_file, "r").read())
+        self.assertEqual(config["hidden_size"], 768)
+
+    def test_get_file_from_repo_local(self):
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            filename = Path(tmp_dir) / "a.txt"
+            filename.touch()
+            self.assertEqual(get_file_from_repo(tmp_dir, "a.txt"), str(filename))
+
+            self.assertIsNone(get_file_from_repo(tmp_dir, "b.txt"))


 class ContextManagerTests(unittest.TestCase):
     @unittest.mock.patch("sys.stdout", new_callable=io.StringIO)
...
@@ -108,6 +108,8 @@ class TokenizerVersioningTest(unittest.TestCase):
         json_tokenizer["model"]["vocab"]["huggingface"] = len(tokenizer)

         with tempfile.TemporaryDirectory() as tmp_dir:
+            # Hack to save this in the tokenizer_config.json
+            tokenizer.init_kwargs["fast_tokenizer_files"] = ["tokenizer.4.0.0.json"]
             tokenizer.save_pretrained(tmp_dir)
             json.dump(json_tokenizer, open(os.path.join(tmp_dir, "tokenizer.4.0.0.json"), "w"))
@@ -120,6 +122,8 @@ class TokenizerVersioningTest(unittest.TestCase):
             # Will need to be adjusted if we reach v42 and this test is still here.
             # Should pick the old tokenizer file as the version of Transformers is < 42.0.0
             shutil.move(os.path.join(tmp_dir, "tokenizer.4.0.0.json"), os.path.join(tmp_dir, "tokenizer.42.0.0.json"))
+            tokenizer.init_kwargs["fast_tokenizer_files"] = ["tokenizer.42.0.0.json"]
+            tokenizer.save_pretrained(tmp_dir)
             new_tokenizer = AutoTokenizer.from_pretrained(tmp_dir)
             self.assertEqual(len(new_tokenizer), len(tokenizer))
             json_tokenizer = json.loads(new_tokenizer._tokenizer.to_str())
@@ -127,7 +131,7 @@ class TokenizerVersioningTest(unittest.TestCase):
     def test_repo_versioning(self):
         # This repo has two tokenizer files, one for v4.0.0 and above with an added token, one for versions lower.
-        repo = "sgugger/finetuned-bert-mrpc"
+        repo = "hf-internal-testing/test-two-tokenizers"

         # This should pick the new tokenizer file as the version of Transformers is > 4.0.0
         tokenizer = AutoTokenizer.from_pretrained(repo)
...