"docs/source/en/model_doc/wav2vec2.md" did not exist on "c44d3675c285278406722b0fa9eb7afff2a3d434"
Unverified Commit 44b21f11 authored by Sylvain Gugger, committed by GitHub

Save code of registered custom models (#15379)



* Allow dynamic modules to use relative imports

* Work for configs

* Fix last merge conflict

* Save code of registered custom objects

* Map strings to strings

* Fix test

* Add tokenizer

* Rework tests

* Tests

* Ignore fixtures py files for tests

* Tokenizer test + fix collection

* With full path

* Rework integration

* Fix typo

* Remove changes in conftest

* Test for tokenizers

* Add documentation

* Update docs/source/custom_models.mdx
Co-authored-by: Lysandre Debut <lysandre@huggingface.co>

* Add file structure and file content

* Add more doc

* Style

* Update docs/source/custom_models.mdx
Co-authored-by: Suraj Patil <surajp815@gmail.com>

* Address review comments
Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
Co-authored-by: Suraj Patil <surajp815@gmail.com>
parent 623d8cb4
......@@ -67,6 +67,8 @@
title: Debugging
- local: serialization
title: Exporting 🤗 Transformers models
- local: custom_models
title: Sharing custom models
- local: pr_checks
title: Checks on a Pull Request
title: Advanced guides
......
<!--Copyright 2020 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# Sharing custom models
The 🤗 Transformers library is designed to be easily extensible. Every model is fully coded in a given subfolder
of the repository with no abstraction, so you can easily copy a modeling file and tweak it to your needs.
Once you are happy with those tweaks and have trained a model you want to share with the community, a few simple steps
let you push not only the weights of your model to the Model Hub, but also the code it relies on, so that anyone in the
community can use it, even if it's not present in the 🤗 Transformers library.
This also applies to configurations and tokenizers (support for feature extractors and processors is coming soon).
## Sending the code to the Hub
First, make sure your model is fully defined in a `.py` file. It can rely on relative imports to some other files as
long as all the files are in the same directory (we don't support submodules for this feature yet). For instance,
let's say you have a `modeling.py` file and a `configuration.py` file in a folder of the current working directory
named `awesome_model`, where the modeling file defines an `AwesomeModel` and the configuration file an `AwesomeConfig`.
```
.
└── awesome_model
├── __init__.py
├── configuration.py
└── modeling.py
```
The `__init__.py` can be empty; it's just there so that Python detects that `awesome_model` can be used as a module.
Here is an example of what the configuration file could look like:
```py
from transformers import PretrainedConfig


class AwesomeConfig(PretrainedConfig):
    model_type = "awesome"

    def __init__(self, attribute=1, hidden_size=42, **kwargs):
        self.attribute = attribute
        self.hidden_size = hidden_size
        super().__init__(**kwargs)
```
and the modeling file could have content like this:
```py
import torch

from transformers import PreTrainedModel

from .configuration import AwesomeConfig


class AwesomeModel(PreTrainedModel):
    config_class = AwesomeConfig
    base_model_prefix = "base"

    def __init__(self, config):
        super().__init__(config)
        self.linear = torch.nn.Linear(config.hidden_size, config.hidden_size)

    def forward(self, x):
        return self.linear(x)
```
`AwesomeModel` should subclass [`PreTrainedModel`] and `AwesomeConfig` should subclass [`PretrainedConfig`]. The
easiest way to achieve this is to copy the modeling and configuration files of the model closest to the one you're
coding, and then tweak them.
<Tip warning={true}>
If copying a modeling file from the library, you will need to replace all the relative imports at the top of the file
with imports from the `transformers` package.
</Tip>
Note that you can re-use (or subclass) an existing configuration/model.
To share your model with the community, follow these steps: first, import the custom objects.
```py
from awesome_model.configuration import AwesomeConfig
from awesome_model.modeling import AwesomeModel
```
Then you have to tell the library you want to copy the code files of those objects when using the `save_pretrained`
method and register them properly with a given auto class (especially for models). Just run:
```py
AwesomeConfig.register_for_auto_class()
AwesomeModel.register_for_auto_class("AutoModel")
```
Note that there is no need to specify an auto class for the configuration (there is only one auto class for
configurations, [`AutoConfig`]), but it's different for models. Your custom model could be suitable for sequence
classification (in which case you should do `AwesomeModel.register_for_auto_class("AutoModelForSequenceClassification")`)
or any other task, so you have to specify which of the auto classes is the correct one for your model.
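For instance, a hypothetical `AwesomeModelForSequenceClassification` (not defined in the example above) carrying a classification head would be registered like this:
```py
from awesome_model.modeling import AwesomeModelForSequenceClassification  # hypothetical class

# A model with a sequence classification head belongs with the matching task-specific auto class.
AwesomeModelForSequenceClassification.register_for_auto_class("AutoModelForSequenceClassification")
```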
Next, create the config and model as you would for any other model in the library:
```py
config = AwesomeConfig()
model = AwesomeModel(config)
```
then train your model. Alternatively, you can load an existing checkpoint into your model.
Once everything is ready, you just have to do:
```py
model.save_pretrained("save_dir")
```
which will not only save the model weights and the configuration in JSON format, but also copy the modeling and
configuration `.py` files into this folder, so you can directly upload the result to the Hub.
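After the call, `save_dir` should contain something like the following (a sketch assuming the PyTorch example above; the exact weights file name depends on the framework):
```
save_dir/
├── config.json
├── configuration.py
├── modeling.py
└── pytorch_model.bin
```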
If you have already logged in to Hugging Face with
```bash
huggingface-cli login
```
or in a notebook with
```py
from huggingface_hub import notebook_login
notebook_login()
```
you can push your model and its code to the Hub with the following:
```py
model.push_to_hub("model-identifier")
```
See the [sharing tutorial](model_sharing) for more information on the push-to-Hub method.
## Using a model with custom code
You can use any configuration, model, or tokenizer with custom code files in its repository with the auto classes and
the `from_pretrained` method. The only requirement is that you pass an extra argument confirming you have read the
code online and trust the author of that model, to avoid executing malicious code on your machine:
```py
from transformers import AutoModel
model = AutoModel.from_pretrained("model-checkpoint", trust_remote_code=True)
```
It is also strongly encouraged to pass a commit hash as a `revision` to make sure the author of the model did not
update the code with malicious new lines (unless you fully trust the authors of the model).
```py
commit_hash = "b731e5fae6d80a4a775461251c4388886fb7a249"
model = AutoModel.from_pretrained("model-checkpoint", trust_remote_code=True, revision=commit_hash)
```
Note that when browsing the commit history of the model repo on the Hub, there is a button to easily copy the commit
hash of any commit.
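If you would rather pin the revision programmatically, the `huggingface_hub` client can fetch the hash of the latest commit; a minimal sketch (the checkpoint name is a placeholder):
```py
from huggingface_hub import model_info
from transformers import AutoModel

# `sha` is the hash of the most recent commit on the repo's main branch.
commit_hash = model_info("model-checkpoint").sha
model = AutoModel.from_pretrained("model-checkpoint", trust_remote_code=True, revision=commit_hash)
```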
......@@ -93,6 +93,7 @@ _import_structure = {
"debug_utils": [],
"dependency_versions_check": [],
"dependency_versions_table": [],
"dynamic_module_utils": [],
"feature_extraction_sequence_utils": ["SequenceFeatureExtractor"],
"feature_extraction_utils": ["BatchFeature"],
"file_utils": [
......
......@@ -21,13 +21,14 @@ import json
import os
import re
import warnings
from typing import Any, Dict, List, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Union
from packaging import version
from requests import HTTPError
from . import __version__
from .dynamic_module_utils import custom_object_save
from .file_utils import (
CONFIG_NAME,
EntryNotFoundError,
......@@ -238,6 +239,7 @@ class PretrainedConfig(PushToHubMixin):
model_type: str = ""
is_composition: bool = False
attribute_map: Dict[str, str] = {}
_auto_class: Optional[str] = None
def __setattr__(self, key, value):
if key in super().__getattribute__("attribute_map"):
......@@ -423,6 +425,12 @@ class PretrainedConfig(PushToHubMixin):
repo = self._create_or_get_repo(save_directory, **kwargs)
os.makedirs(save_directory, exist_ok=True)
# If we have a custom config, we copy the file defining it in the folder and set the attributes so it can be
# loaded from the Hub.
if self._auto_class is not None:
custom_object_save(self, save_directory, config=self)
# If we save using the predefined names, we can load using `from_pretrained`
output_config_file = os.path.join(save_directory, CONFIG_NAME)
......@@ -753,6 +761,8 @@ class PretrainedConfig(PushToHubMixin):
output = copy.deepcopy(self.__dict__)
if hasattr(self.__class__, "model_type"):
output["model_type"] = self.__class__.model_type
if "_auto_class" in output:
del output["_auto_class"]
# Transformers version when serializing the model
output["transformers_version"] = __version__
......@@ -850,6 +860,26 @@ class PretrainedConfig(PushToHubMixin):
if d.get("torch_dtype", None) is not None and not isinstance(d["torch_dtype"], str):
d["torch_dtype"] = str(d["torch_dtype"]).split(".")[1]
    @classmethod
    def register_for_auto_class(cls, auto_class="AutoConfig"):
        """
        Register this class with a given auto class. This should only be used for custom configurations as the ones
        in the library are already mapped with `AutoConfig`.

        Args:
            auto_class (`str` or `type`, *optional*, defaults to `"AutoConfig"`):
                The auto class to register this new configuration with.
        """
        if not isinstance(auto_class, str):
            auto_class = auto_class.__name__

        import transformers.models.auto as auto_module

        if not hasattr(auto_module, auto_class):
            raise ValueError(f"{auto_class} is not a valid auto class.")

        cls._auto_class = auto_class
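A minimal sketch of how this classmethod is meant to be called on a custom configuration (the class name is illustrative):
```py
from transformers import PretrainedConfig


class MyConfig(PretrainedConfig):
    model_type = "my-model"


# Mark the class so that `save_pretrained` copies its code and `AutoConfig` can load it back.
MyConfig.register_for_auto_class()
assert MyConfig._auto_class == "AutoConfig"
```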
def get_configuration_file(configuration_files: List[str]) -> str:
"""
......
......@@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities to dynamically load model and tokenizer from the Hub."""
"""Utilities to dynamically load objects from the Hub."""
import importlib
import os
......@@ -24,14 +24,8 @@ from typing import Dict, Optional, Union
from huggingface_hub import HfFolder, model_info
from ...file_utils import (
HF_MODULES_CACHE,
TRANSFORMERS_DYNAMIC_MODULE_NAME,
cached_path,
hf_bucket_url,
is_offline_mode,
)
from ...utils import logging
from .file_utils import HF_MODULES_CACHE, TRANSFORMERS_DYNAMIC_MODULE_NAME, cached_path, hf_bucket_url, is_offline_mode
from .utils import logging
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
......@@ -67,6 +61,53 @@ def create_dynamic_module(name: Union[str, os.PathLike]):
init_path.touch()
def get_relative_imports(module_file):
    """
    Get the list of modules that are relatively imported in a module file.

    Args:
        module_file (`str` or `os.PathLike`): The module file to inspect.
    """
    with open(module_file, "r", encoding="utf-8") as f:
        content = f.read()

    # Imports of the form `import .xxx`
    relative_imports = re.findall(r"^\s*import\s+\.(\S+)\s*$", content, flags=re.MULTILINE)
    # Imports of the form `from .xxx import yyy`
    relative_imports += re.findall(r"^\s*from\s+\.(\S+)\s+import", content, flags=re.MULTILINE)
    # Unique-ify
    return list(set(relative_imports))
def get_relative_import_files(module_file):
    """
    Get the list of all files that are needed for a given module. Note that this function recurses through the
    relative imports (if a imports b and b imports c, it will return module files for b and c).

    Args:
        module_file (`str` or `os.PathLike`): The module file to inspect.
    """
    no_change = False
    files_to_check = [module_file]
    all_relative_imports = []

    # Let's recurse through all relative imports
    while not no_change:
        new_imports = []
        for f in files_to_check:
            new_imports.extend(get_relative_imports(f))

        module_path = Path(module_file).parent
        # Resolve module names to `.py` files before deduplicating, so that files already
        # collected (e.g. through a circular import) are not checked again.
        new_import_files = [str(module_path / f"{m}.py") for m in new_imports]
        new_import_files = [f for f in new_import_files if f not in all_relative_imports]
        files_to_check = new_import_files

        no_change = len(new_import_files) == 0
        all_relative_imports.extend(files_to_check)

    return all_relative_imports
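As an illustration of what these helpers return on the `awesome_model` layout from the documentation above (paths are assumptions):
```py
# modeling.py contains `from .configuration import AwesomeConfig`.
get_relative_imports("awesome_model/modeling.py")
# -> ["configuration"]

# The recursive variant resolves module names to .py files next to the inspected file.
get_relative_import_files("awesome_model/modeling.py")
# -> ["awesome_model/configuration.py"]
```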
def check_imports(filename):
"""
Check if the current Python environment contains all the libraries that are imported in a file.
......@@ -81,12 +122,6 @@ def check_imports(filename):
# Only keep the top-level module
imports = [imp.split(".")[0] for imp in imports if not imp.startswith(".")]
# Imports of the form `import .xxx`
relative_imports = re.findall("^\s*import\s+\.(\S+)\s*$", content, flags=re.MULTILINE)
# Imports of the form `from .xxx import yyy`
relative_imports += re.findall("^\s*from\s+\.(\S+)\s+import", content, flags=re.MULTILINE)
relative_imports = list(set(relative_imports))
# Unique-ify and test we got them all
imports = list(set(imports))
missing_packages = []
......@@ -102,7 +137,7 @@ def check_imports(filename):
f"{', '.join(missing_packages)}. Run `pip install {' '.join(missing_packages)}`"
)
return relative_imports
return get_relative_imports(filename)
def get_class_in_module(class_name, module_path):
......@@ -169,7 +204,8 @@ def get_cached_module_file(
</Tip>
Returns:
`str`: The path to the module inside the cache."""
`str`: The path to the module inside the cache.
"""
if is_offline_mode() and not local_files_only:
logger.info("Offline mode: forcing local_files_only=True")
local_files_only = True
......@@ -218,7 +254,7 @@ def get_cached_module_file(
shutil.copy(os.path.join(pretrained_model_name_or_path, module_needed), submodule_path / module_needed)
else:
# Get the commit hash
# TODO: we will get this info in the etag soon, so retrieve it from there.
# TODO: we will get this info in the etag soon, so retrieve it from there and not here.
if isinstance(use_auth_token, str):
token = use_auth_token
elif use_auth_token is True:
......@@ -301,7 +337,7 @@ def get_class_from_dynamic_module(
proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
use_auth_token (`str` or *bool*, *optional*):
use_auth_token (`str` or `bool`, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
when running `transformers-cli login` (stored in `~/.huggingface`).
revision(`str`, *optional*, defaults to `"main"`):
......@@ -323,7 +359,7 @@ def get_class_from_dynamic_module(
Examples:
```python
# Download module *modeling.py* from huggingface.co and cache then extract the class *MyBertModel* from this
# Download module `modeling.py` from huggingface.co and cache then extract the class `MyBertModel` from this
# module.
cls = get_class_from_dynamic_module("sgugger/my-bert-model", "modeling.py", "MyBertModel")
```"""
......@@ -340,3 +376,61 @@ def get_class_from_dynamic_module(
local_files_only=local_files_only,
)
return get_class_in_module(class_name, final_module.replace(".py", ""))
def custom_object_save(obj, folder, config=None):
    """
    Save the modeling files corresponding to a custom model/configuration/tokenizer etc. in a given folder. Optionally
    adds the proper fields in a config.

    Args:
        obj (`Any`): The object for which to save the module files.
        folder (`str` or `os.PathLike`): The folder where to save.
        config (`PretrainedConfig` or dictionary, *optional*):
            A config in which to register the auto_map corresponding to this custom object.
    """
    if obj.__module__ == "__main__":
        logger.warning(
            f"We can't save the code defining {obj} in {folder} as it's been defined in __main__. You should put "
            "this code in a separate module so we can include it in the saved folder and make it easier to share "
            "via the Hub."
        )

    # Add object class to the config auto_map
    if config is not None:
        module_name = obj.__class__.__module__
        last_module = module_name.split(".")[-1]
        full_name = f"{last_module}.{obj.__class__.__name__}"
        # Special handling for tokenizers
        if "Tokenizer" in full_name:
            slow_tokenizer_class = None
            fast_tokenizer_class = None
            if obj.__class__.__name__.endswith("Fast"):
                # Fast tokenizer: we have the fast tokenizer class and we may have the slow one as an attribute.
                fast_tokenizer_class = f"{last_module}.{obj.__class__.__name__}"
                if getattr(obj, "slow_tokenizer_class", None) is not None:
                    slow_tokenizer = getattr(obj, "slow_tokenizer_class")
                    slow_tok_module_name = slow_tokenizer.__module__
                    last_slow_tok_module = slow_tok_module_name.split(".")[-1]
                    slow_tokenizer_class = f"{last_slow_tok_module}.{slow_tokenizer.__name__}"
            else:
                # Slow tokenizer: there is no way to recover the fast class from it.
                slow_tokenizer_class = f"{last_module}.{obj.__class__.__name__}"

            full_name = (slow_tokenizer_class, fast_tokenizer_class)

        if isinstance(config, dict):
            config["auto_map"] = full_name
        elif getattr(config, "auto_map", None) is not None:
            config.auto_map[obj._auto_class] = full_name
        else:
            config.auto_map = {obj._auto_class: full_name}

    # Copy module file to the output folder.
    object_file = sys.modules[obj.__module__].__file__
    dest_file = Path(folder) / (Path(object_file).name)
    shutil.copy(object_file, dest_file)

    # Gather all relative imports recursively and make sure they are copied as well.
    for needed_file in get_relative_import_files(object_file):
        dest_file = Path(folder) / (Path(needed_file).name)
        shutil.copy(needed_file, dest_file)
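For reference, a sketch of the effect on a saved custom model's `config.json` (names follow the documentation example above):
```py
import json

# After `model.save_pretrained("save_dir")` on a registered custom model, the saved
# config carries an `auto_map` pointing at the copied module file.
with open("save_dir/config.json") as f:
    print(json.load(f)["auto_map"])  # e.g. {"AutoModel": "modeling.AwesomeModel"}
```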
......@@ -29,6 +29,7 @@ from jax.random import PRNGKey
from requests import HTTPError
from .configuration_utils import PretrainedConfig
from .dynamic_module_utils import custom_object_save
from .file_utils import (
FLAX_WEIGHTS_NAME,
WEIGHTS_NAME,
......@@ -87,6 +88,7 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
config_class = None
base_model_prefix = ""
main_input_name = "input_ids"
_auto_class = None
def __init__(
self,
......@@ -696,6 +698,12 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
save_directory = os.path.abspath(save_directory)
# save config as well
self.config.architectures = [self.__class__.__name__[4:]]
# If we have a custom model, we copy the file defining it in the folder and set the attributes so it can be
# loaded from the Hub.
if self._auto_class is not None:
custom_object_save(self, save_directory, config=self.config)
self.config.save_pretrained(save_directory)
# save model
......@@ -711,6 +719,26 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
url = self._push_to_hub(repo, commit_message=commit_message)
logger.info(f"Model pushed to the hub in this commit: {url}")
    @classmethod
    def register_for_auto_class(cls, auto_class="FlaxAutoModel"):
        """
        Register this class with a given auto class. This should only be used for custom models as the ones in the
        library are already mapped with an auto class.

        Args:
            auto_class (`str` or `type`, *optional*, defaults to `"FlaxAutoModel"`):
                The auto class to register this new model with.
        """
        if not isinstance(auto_class, str):
            auto_class = auto_class.__name__

        import transformers.models.auto as auto_module

        if not hasattr(auto_module, auto_class):
            raise ValueError(f"{auto_class} is not a valid auto class.")

        cls._auto_class = auto_class
# To update the docstring, we need to copy the method, otherwise we change the original docstring.
FlaxPreTrainedModel.push_to_hub = copy_func(FlaxPreTrainedModel.push_to_hub)
......
......@@ -35,6 +35,7 @@ from huggingface_hub import Repository, list_repo_files
from requests import HTTPError
from .configuration_utils import PretrainedConfig
from .dynamic_module_utils import custom_object_save
from .file_utils import (
DUMMY_INPUTS,
TF2_WEIGHTS_NAME,
......@@ -661,6 +662,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
config_class = None
base_model_prefix = ""
main_input_name = "input_ids"
_auto_class = None
# a list of re pattern of tensor names to ignore from the model when loading the model weights
# (and avoid unnecessary warnings).
......@@ -1359,6 +1361,12 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
# Save configuration file
self.config.architectures = [self.__class__.__name__[2:]]
# If we have a custom model, we copy the file defining it in the folder and set the attributes so it can be
# loaded from the Hub.
if self._auto_class is not None:
custom_object_save(self, save_directory, config=self.config)
self.config.save_pretrained(save_directory)
# If we save using the predefined names, we can load using `from_pretrained`
......@@ -2007,6 +2015,26 @@ class TFSequenceSummary(tf.keras.layers.Layer):
return output
    @classmethod
    def register_for_auto_class(cls, auto_class="TFAutoModel"):
        """
        Register this class with a given auto class. This should only be used for custom models as the ones in the
        library are already mapped with an auto class.

        Args:
            auto_class (`str` or `type`, *optional*, defaults to `"TFAutoModel"`):
                The auto class to register this new model with.
        """
        if not isinstance(auto_class, str):
            auto_class = auto_class.__name__

        import transformers.models.auto as auto_module

        if not hasattr(auto_module, auto_class):
            raise ValueError(f"{auto_class} is not a valid auto class.")

        cls._auto_class = auto_class
def shape_list(tensor: Union[tf.Tensor, np.ndarray]) -> List[int]:
"""
......
......@@ -32,6 +32,7 @@ from requests import HTTPError
from .activations import get_activation
from .configuration_utils import PretrainedConfig
from .deepspeed import deepspeed_config, is_deepspeed_zero3_enabled
from .dynamic_module_utils import custom_object_save
from .file_utils import (
DUMMY_INPUTS,
FLAX_WEIGHTS_NAME,
......@@ -446,6 +447,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
config_class = None
base_model_prefix = ""
main_input_name = "input_ids"
_auto_class = None
# a list of re pattern of tensor names to ignore from the model when loading the model weights
# (and avoid unnecessary warnings).
......@@ -1053,6 +1055,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
# Attach architecture to the config
model_to_save.config.architectures = [model_to_save.__class__.__name__]
# If we have a custom model, we copy the file defining it in the folder and set the attributes so it can be
# loaded from the Hub.
if self._auto_class is not None:
custom_object_save(self, save_directory, config=self.config)
# Save the config
if save_config:
model_to_save.config.save_pretrained(save_directory)
......@@ -1805,6 +1812,26 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
del state_dict
    @classmethod
    def register_for_auto_class(cls, auto_class="AutoModel"):
        """
        Register this class with a given auto class. This should only be used for custom models as the ones in the
        library are already mapped with an auto class.

        Args:
            auto_class (`str` or `type`, *optional*, defaults to `"AutoModel"`):
                The auto class to register this new model with.
        """
        if not isinstance(auto_class, str):
            auto_class = auto_class.__name__

        import transformers.models.auto as auto_module

        if not hasattr(auto_module, auto_class):
            raise ValueError(f"{auto_class} is not a valid auto class.")

        cls._auto_class = auto_class
# To update the docstring, we need to copy the method, otherwise we change the original docstring.
PreTrainedModel.push_to_hub = copy_func(PreTrainedModel.push_to_hub)
......
......@@ -17,10 +17,10 @@ import importlib
from collections import OrderedDict
from ...configuration_utils import PretrainedConfig
from ...dynamic_module_utils import get_class_from_dynamic_module
from ...file_utils import copy_func
from ...utils import logging
from .configuration_auto import AutoConfig, model_type_to_module_name, replace_list_option_in_docstrings
from .dynamic import get_class_from_dynamic_module
logger = logging.get_logger(__name__)
......
......@@ -20,9 +20,9 @@ from collections import OrderedDict
from typing import List, Union
from ...configuration_utils import PretrainedConfig
from ...dynamic_module_utils import get_class_from_dynamic_module
from ...file_utils import CONFIG_NAME
from ...utils import logging
from .dynamic import get_class_from_dynamic_module
logger = logging.get_logger(__name__)
......
......@@ -21,6 +21,7 @@ from collections import OrderedDict
from typing import TYPE_CHECKING, Dict, Optional, Tuple, Union
from ...configuration_utils import PretrainedConfig
from ...dynamic_module_utils import get_class_from_dynamic_module
from ...file_utils import get_file_from_repo, is_sentencepiece_available, is_tokenizers_available
from ...tokenization_utils import PreTrainedTokenizer
from ...tokenization_utils_base import TOKENIZER_CONFIG_FILE
......@@ -35,7 +36,6 @@ from .configuration_auto import (
model_type_to_module_name,
replace_list_option_in_docstrings,
)
from .dynamic import get_class_from_dynamic_module
logger = logging.get_logger(__name__)
......
......@@ -34,6 +34,7 @@ from packaging import version
from requests import HTTPError
from . import __version__
from .dynamic_module_utils import custom_object_save
from .file_utils import (
EntryNotFoundError,
ExplicitEnum,
......@@ -1435,6 +1436,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
pretrained_vocab_files_map: Dict[str, Dict[str, str]] = {}
pretrained_init_configuration: Dict[str, Dict[str, Any]] = {}
max_model_input_sizes: Dict[str, Optional[int]] = {}
_auto_class: Optional[str] = None
# first name has to correspond to main model input name
# to make sure `tokenizer.pad(...)` works correctly
......@@ -2071,6 +2073,11 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
if getattr(self, "_processor_class", None) is not None:
tokenizer_config["processor_class"] = self._processor_class
# If we have a custom model, we copy the file defining it in the folder and set the attributes so it can be
# loaded from the Hub.
if self._auto_class is not None:
custom_object_save(self, save_directory, config=tokenizer_config)
with open(tokenizer_config_file, "w", encoding="utf-8") as f:
f.write(json.dumps(tokenizer_config, ensure_ascii=False))
logger.info(f"tokenizer config file saved in {tokenizer_config_file}")
......@@ -3391,6 +3398,26 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
"""
yield
    @classmethod
    def register_for_auto_class(cls, auto_class="AutoTokenizer"):
        """
        Register this class with a given auto class. This should only be used for custom tokenizers as the ones in
        the library are already mapped with `AutoTokenizer`.

        Args:
            auto_class (`str` or `type`, *optional*, defaults to `"AutoTokenizer"`):
                The auto class to register this new tokenizer with.
        """
        if not isinstance(auto_class, str):
            auto_class = auto_class.__name__

        import transformers.models.auto as auto_module

        if not hasattr(auto_module, auto_class):
            raise ValueError(f"{auto_class} is not a valid auto class.")

        cls._auto_class = auto_class
def prepare_seq2seq_batch(
self,
src_texts: List[str],
......
......@@ -15,8 +15,10 @@
import importlib
import os
import sys
import tempfile
import unittest
from pathlib import Path
import transformers.models.auto
from transformers.models.auto.configuration_auto import CONFIG_MAPPING, AutoConfig
......@@ -25,11 +27,12 @@ from transformers.models.roberta.configuration_roberta import RobertaConfig
from transformers.testing_utils import DUMMY_UNKNOWN_IDENTIFIER
SAMPLE_ROBERTA_CONFIG = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/dummy-config.json")
sys.path.append(str(Path(__file__).parent.parent / "utils"))
from test_module.custom_configuration import CustomConfig # noqa E402
class NewModelConfig(BertConfig):
model_type = "new-model"
SAMPLE_ROBERTA_CONFIG = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/dummy-config.json")
class AutoConfigTest(unittest.TestCase):
......@@ -65,24 +68,24 @@ class AutoConfigTest(unittest.TestCase):
def test_new_config_registration(self):
try:
AutoConfig.register("new-model", NewModelConfig)
AutoConfig.register("custom", CustomConfig)
# Wrong model type will raise an error
with self.assertRaises(ValueError):
AutoConfig.register("model", NewModelConfig)
AutoConfig.register("model", CustomConfig)
# Trying to register something existing in the Transformers library will raise an error
with self.assertRaises(ValueError):
AutoConfig.register("bert", BertConfig)
# Now that the config is registered, it can be used as any other config with the auto-API
config = NewModelConfig()
config = CustomConfig()
with tempfile.TemporaryDirectory() as tmp_dir:
config.save_pretrained(tmp_dir)
new_config = AutoConfig.from_pretrained(tmp_dir)
self.assertIsInstance(new_config, NewModelConfig)
self.assertIsInstance(new_config, CustomConfig)
finally:
if "new-model" in CONFIG_MAPPING._extra_content:
del CONFIG_MAPPING._extra_content["new-model"]
if "custom" in CONFIG_MAPPING._extra_content:
del CONFIG_MAPPING._extra_content["custom"]
def test_repo_not_found(self):
with self.assertRaisesRegex(
......
......@@ -17,9 +17,11 @@ import copy
import json
import os
import shutil
import sys
import tempfile
import unittest
import unittest.mock
from pathlib import Path
from huggingface_hub import Repository, delete_repo, login
from requests.exceptions import HTTPError
......@@ -28,6 +30,11 @@ from transformers.configuration_utils import PretrainedConfig
from transformers.testing_utils import PASS, USER, is_staging_test
sys.path.append(str(Path(__file__).parent.parent / "utils"))
from test_module.custom_configuration import CustomConfig # noqa E402
config_common_kwargs = {
"return_dict": False,
"output_hidden_states": True,
......@@ -192,23 +199,6 @@ class ConfigTester(object):
self.check_config_arguments_init()
class FakeConfig(PretrainedConfig):
def __init__(self, attribute=1, **kwargs):
self.attribute = attribute
super().__init__(**kwargs)
# Make sure this is synchronized with the config above.
FAKE_CONFIG_CODE = """
from transformers import PretrainedConfig
class FakeConfig(PretrainedConfig):
def __init__(self, attribute=1, **kwargs):
self.attribute = attribute
super().__init__(**kwargs)
"""
@is_staging_test
class ConfigPushToHubTester(unittest.TestCase):
@classmethod
......@@ -263,20 +253,23 @@ class ConfigPushToHubTester(unittest.TestCase):
self.assertEqual(v, getattr(new_config, k))
def test_push_to_hub_dynamic_config(self):
config = FakeConfig(attribute=42)
config.auto_map = {"AutoConfig": "configuration.FakeConfig"}
CustomConfig.register_for_auto_class()
config = CustomConfig(attribute=42)
with tempfile.TemporaryDirectory() as tmp_dir:
repo = Repository(tmp_dir, clone_from=f"{USER}/test-dynamic-config", use_auth_token=self._token)
config.save_pretrained(tmp_dir)
with open(os.path.join(tmp_dir, "configuration.py"), "w") as f:
f.write(FAKE_CONFIG_CODE)
# This has added the proper auto_map field to the config
self.assertDictEqual(config.auto_map, {"AutoConfig": "custom_configuration.CustomConfig"})
# The code has been copied from fixtures
self.assertTrue(os.path.isfile(os.path.join(tmp_dir, "custom_configuration.py")))
repo.push_to_hub()
new_config = AutoConfig.from_pretrained(f"{USER}/test-dynamic-config", trust_remote_code=True)
# Can't make an isinstance check because the new_config is from the FakeConfig class of a dynamic module
self.assertEqual(new_config.__class__.__name__, "FakeConfig")
self.assertEqual(new_config.__class__.__name__, "CustomConfig")
self.assertEqual(new_config.attribute, 42)
......
......@@ -14,9 +14,10 @@
# limitations under the License.
import copy
import os
import sys
import tempfile
import unittest
from pathlib import Path
from transformers import BertConfig, is_torch_available
from transformers.models.auto.configuration_auto import CONFIG_MAPPING
......@@ -31,9 +32,15 @@ from transformers.testing_utils import (
from .test_modeling_bert import BertModelTester
sys.path.append(str(Path(__file__).parent.parent / "utils"))
from test_module.custom_configuration import CustomConfig # noqa E402
if is_torch_available():
import torch
from test_module.custom_modeling import CustomModel
from transformers import (
AutoConfig,
AutoModel,
......@@ -56,7 +63,6 @@ if is_torch_available():
FunnelModel,
GPT2Config,
GPT2LMHeadModel,
PreTrainedModel,
RobertaForMaskedLM,
T5Config,
T5ForConditionalGeneration,
......@@ -81,51 +87,6 @@ if is_torch_available():
from transformers.models.tapas.modeling_tapas import TAPAS_PRETRAINED_MODEL_ARCHIVE_LIST
class NewModelConfig(BertConfig):
model_type = "new-model"
if is_torch_available():
class NewModel(BertModel):
config_class = NewModelConfig
class FakeModel(PreTrainedModel):
config_class = BertConfig
base_model_prefix = "fake"
def __init__(self, config):
super().__init__(config)
self.linear = torch.nn.Linear(config.hidden_size, config.hidden_size)
def forward(self, x):
return self.linear(x)
def _init_weights(self, module):
pass
# Make sure this is synchronized with the model above.
FAKE_MODEL_CODE = """
import torch
from transformers import BertConfig, PreTrainedModel
class FakeModel(PreTrainedModel):
config_class = BertConfig
base_model_prefix = "fake"
def __init__(self, config):
super().__init__(config)
self.linear = torch.nn.Linear(config.hidden_size, config.hidden_size)
def forward(self, x):
return self.linear(x)
def _init_weights(self, module):
pass
"""
@require_torch
class AutoModelTest(unittest.TestCase):
@slow
......@@ -325,21 +286,26 @@ class AutoModelTest(unittest.TestCase):
assert not issubclass(child, parent), f"{child.__name__} is child of {parent.__name__}"
def test_from_pretrained_dynamic_model_local(self):
config = BertConfig(
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
)
config.auto_map = {"AutoModel": "modeling.FakeModel"}
model = FakeModel(config)
try:
AutoConfig.register("custom", CustomConfig)
AutoModel.register(CustomConfig, CustomModel)
config = CustomConfig(hidden_size=32)
model = CustomModel(config)
with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(tmp_dir)
with open(os.path.join(tmp_dir, "modeling.py"), "w") as f:
f.write(FAKE_MODEL_CODE)
new_model = AutoModel.from_pretrained(tmp_dir, trust_remote_code=True)
for p1, p2 in zip(model.parameters(), new_model.parameters()):
self.assertTrue(torch.equal(p1, p2))
finally:
if "custom" in CONFIG_MAPPING._extra_content:
del CONFIG_MAPPING._extra_content["custom"]
if CustomConfig in MODEL_MAPPING._extra_content:
del MODEL_MAPPING._extra_content[CustomConfig]
def test_from_pretrained_dynamic_model_distant(self):
model = AutoModel.from_pretrained("hf-internal-testing/test_dynamic_model", trust_remote_code=True)
self.assertEqual(model.__class__.__name__, "NewModel")
......@@ -349,7 +315,7 @@ class AutoModelTest(unittest.TestCase):
self.assertEqual(model.__class__.__name__, "NewModel")
def test_new_model_registration(self):
AutoConfig.register("new-model", NewModelConfig)
AutoConfig.register("custom", CustomConfig)
auto_classes = [
AutoModel,
......@@ -366,26 +332,27 @@ class AutoModelTest(unittest.TestCase):
with self.subTest(auto_class.__name__):
# Wrong config class will raise an error
with self.assertRaises(ValueError):
auto_class.register(BertConfig, NewModel)
auto_class.register(NewModelConfig, NewModel)
auto_class.register(BertConfig, CustomModel)
auto_class.register(CustomConfig, CustomModel)
# Trying to register something existing in the Transformers library will raise an error
with self.assertRaises(ValueError):
auto_class.register(BertConfig, BertModel)
# Now that the config is registered, it can be used as any other config with the auto-API
tiny_config = BertModelTester(self).get_config()
config = NewModelConfig(**tiny_config.to_dict())
config = CustomConfig(**tiny_config.to_dict())
model = auto_class.from_config(config)
self.assertIsInstance(model, NewModel)
self.assertIsInstance(model, CustomModel)
with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(tmp_dir)
new_model = auto_class.from_pretrained(tmp_dir)
self.assertIsInstance(new_model, NewModel)
# The model is a CustomModel but from the new dynamically imported class.
self.assertIsInstance(new_model, CustomModel)
finally:
if "new-model" in CONFIG_MAPPING._extra_content:
del CONFIG_MAPPING._extra_content["new-model"]
if "custom" in CONFIG_MAPPING._extra_content:
del CONFIG_MAPPING._extra_content["custom"]
for mapping in (
MODEL_MAPPING,
MODEL_FOR_PRETRAINING_MAPPING,
......@@ -395,8 +362,8 @@ class AutoModelTest(unittest.TestCase):
MODEL_FOR_CAUSAL_LM_MAPPING,
MODEL_FOR_MASKED_LM_MAPPING,
):
if NewModelConfig in mapping._extra_content:
del mapping._extra_content[NewModelConfig]
if CustomConfig in mapping._extra_content:
del mapping._extra_content[CustomConfig]
def test_repo_not_found(self):
with self.assertRaisesRegex(
......
......@@ -20,9 +20,11 @@ import json
import os
import os.path
import random
import sys
import tempfile
import unittest
import warnings
from pathlib import Path
from typing import Dict, List, Tuple
import numpy as np
......@@ -55,10 +57,16 @@ from transformers.testing_utils import (
)
sys.path.append(str(Path(__file__).parent.parent / "utils"))
from test_module.custom_configuration import CustomConfig # noqa E402
if is_torch_available():
import torch
from torch import nn
from test_module.custom_modeling import CustomModel
from transformers import (
BERT_PRETRAINED_MODEL_ARCHIVE_LIST,
MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING,
......@@ -2109,61 +2117,6 @@ class ModelUtilsTest(TestCasePlus):
self.assertEqual(model.dtype, torch.float16)
class FakeConfig(PretrainedConfig):
def __init__(self, attribute=1, **kwargs):
self.attribute = attribute
super().__init__(**kwargs)
# Make sure this is synchronized with the config above.
FAKE_CONFIG_CODE = """
from transformers import PretrainedConfig
class FakeConfig(PretrainedConfig):
def __init__(self, attribute=1, **kwargs):
self.attribute = attribute
super().__init__(**kwargs)
"""
if is_torch_available():
class FakeModel(PreTrainedModel):
config_class = BertConfig
base_model_prefix = "fake"
def __init__(self, config):
super().__init__(config)
self.linear = torch.nn.Linear(config.hidden_size, config.hidden_size)
def forward(self, x):
return self.linear(x)
def _init_weights(self, module):
pass
# Make sure this is synchronized with the model above.
FAKE_MODEL_CODE = """
import torch
from transformers import BertConfig, PreTrainedModel
class FakeModel(PreTrainedModel):
config_class = BertConfig
base_model_prefix = "fake"
def __init__(self, config):
super().__init__(config)
self.linear = torch.nn.Linear(config.hidden_size, config.hidden_size)
def forward(self, x):
return self.linear(x)
def _init_weights(self, module):
pass
"""
@require_torch
@is_staging_test
class ModelPushToHubTester(unittest.TestCase):
......@@ -2223,62 +2176,29 @@ class ModelPushToHubTester(unittest.TestCase):
self.assertTrue(torch.equal(p1, p2))
def test_push_to_hub_dynamic_model(self):
config = BertConfig(
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
)
config.auto_map = {"AutoModel": "modeling.FakeModel"}
model = FakeModel(config)
CustomConfig.register_for_auto_class()
CustomModel.register_for_auto_class()
config = CustomConfig(hidden_size=32)
model = CustomModel(config)
with tempfile.TemporaryDirectory() as tmp_dir:
repo = Repository(tmp_dir, clone_from=f"{USER}/test-dynamic-model", use_auth_token=self._token)
model.save_pretrained(tmp_dir)
with open(os.path.join(tmp_dir, "modeling.py"), "w") as f:
f.write(FAKE_MODEL_CODE)
repo.push_to_hub()
new_model = AutoModel.from_pretrained(f"{USER}/test-dynamic-model", trust_remote_code=True)
# Can't make an isinstance check because the new_model is from the FakeModel class of a dynamic module
self.assertEqual(new_model.__class__.__name__, "FakeModel")
for p1, p2 in zip(model.parameters(), new_model.parameters()):
self.assertTrue(torch.equal(p1, p2))
config = AutoConfig.from_pretrained(f"{USER}/test-dynamic-model")
new_model = AutoModel.from_config(config, trust_remote_code=True)
self.assertEqual(new_model.__class__.__name__, "FakeModel")
def test_push_to_hub_dynamic_model_and_config(self):
config = FakeConfig(
attribute=42,
vocab_size=99,
hidden_size=32,
num_hidden_layers=5,
num_attention_heads=4,
intermediate_size=37,
# checks
self.assertDictEqual(
config.auto_map,
{"AutoConfig": "custom_configuration.CustomConfig", "AutoModel": "custom_modeling.CustomModel"},
)
config.auto_map = {"AutoConfig": "configuration.FakeConfig", "AutoModel": "modeling.FakeModel"}
model = FakeModel(config)
with tempfile.TemporaryDirectory() as tmp_dir:
repo = Repository(tmp_dir, clone_from=f"{USER}/test-dynamic-model-config", use_auth_token=self._token)
model.save_pretrained(tmp_dir)
with open(os.path.join(tmp_dir, "configuration.py"), "w") as f:
f.write(FAKE_CONFIG_CODE)
with open(os.path.join(tmp_dir, "modeling.py"), "w") as f:
f.write(FAKE_MODEL_CODE)
repo.push_to_hub()
new_model = AutoModel.from_pretrained(f"{USER}/test-dynamic-model-config", trust_remote_code=True)
# Can't make an isinstance check because the new_model.config is from the FakeConfig class of a dynamic module
self.assertEqual(new_model.config.__class__.__name__, "FakeConfig")
self.assertEqual(new_model.config.attribute, 42)
# Can't make an isinstance check because the new_model is from the FakeModel class of a dynamic module
self.assertEqual(new_model.__class__.__name__, "FakeModel")
new_model = AutoModel.from_pretrained(f"{USER}/test-dynamic-model", trust_remote_code=True)
# Can't make an isinstance check because the new_model is from the CustomModel class of a dynamic module
self.assertEqual(new_model.__class__.__name__, "CustomModel")
for p1, p2 in zip(model.parameters(), new_model.parameters()):
self.assertTrue(torch.equal(p1, p2))
config = AutoConfig.from_pretrained(f"{USER}/test-dynamic-model")
config = AutoConfig.from_pretrained(f"{USER}/test-dynamic-model", trust_remote_code=True)
new_model = AutoModel.from_config(config, trust_remote_code=True)
self.assertEqual(new_model.__class__.__name__, "FakeModel")
self.assertEqual(new_model.__class__.__name__, "CustomModel")
......@@ -15,8 +15,10 @@
import os
import shutil
import sys
import tempfile
import unittest
from pathlib import Path
import pytest
......@@ -30,7 +32,6 @@ from transformers import (
CTRLTokenizer,
GPT2Tokenizer,
GPT2TokenizerFast,
PretrainedConfig,
PreTrainedTokenizerFast,
RobertaTokenizer,
RobertaTokenizerFast,
......@@ -52,19 +53,14 @@ from transformers.testing_utils import (
)
class NewConfig(PretrainedConfig):
model_type = "new-model"
sys.path.append(str(Path(__file__).parent.parent / "utils"))
class NewTokenizer(BertTokenizer):
pass
from test_module.custom_configuration import CustomConfig # noqa E402
from test_module.custom_tokenization import CustomTokenizer # noqa E402
if is_tokenizers_available():
class NewTokenizerFast(BertTokenizerFast):
slow_tokenizer_class = NewTokenizer
pass
from test_module.custom_tokenization_fast import CustomTokenizerFast
class AutoTokenizerTest(unittest.TestCase):
......@@ -250,41 +246,43 @@ class AutoTokenizerTest(unittest.TestCase):
def test_new_tokenizer_registration(self):
try:
AutoConfig.register("new-model", NewConfig)
AutoConfig.register("custom", CustomConfig)
AutoTokenizer.register(NewConfig, slow_tokenizer_class=NewTokenizer)
AutoTokenizer.register(CustomConfig, slow_tokenizer_class=CustomTokenizer)
# Trying to register something existing in the Transformers library will raise an error
with self.assertRaises(ValueError):
AutoTokenizer.register(BertConfig, slow_tokenizer_class=BertTokenizer)
tokenizer = NewTokenizer.from_pretrained(SMALL_MODEL_IDENTIFIER)
tokenizer = CustomTokenizer.from_pretrained(SMALL_MODEL_IDENTIFIER)
with tempfile.TemporaryDirectory() as tmp_dir:
tokenizer.save_pretrained(tmp_dir)
new_tokenizer = AutoTokenizer.from_pretrained(tmp_dir)
self.assertIsInstance(new_tokenizer, NewTokenizer)
self.assertIsInstance(new_tokenizer, CustomTokenizer)
finally:
if "new-model" in CONFIG_MAPPING._extra_content:
del CONFIG_MAPPING._extra_content["new-model"]
if NewConfig in TOKENIZER_MAPPING._extra_content:
del TOKENIZER_MAPPING._extra_content[NewConfig]
if "custom" in CONFIG_MAPPING._extra_content:
del CONFIG_MAPPING._extra_content["custom"]
if CustomConfig in TOKENIZER_MAPPING._extra_content:
del TOKENIZER_MAPPING._extra_content[CustomConfig]
@require_tokenizers
def test_new_tokenizer_fast_registration(self):
try:
AutoConfig.register("new-model", NewConfig)
AutoConfig.register("custom", CustomConfig)
# Can register in two steps
AutoTokenizer.register(NewConfig, slow_tokenizer_class=NewTokenizer)
self.assertEqual(TOKENIZER_MAPPING[NewConfig], (NewTokenizer, None))
AutoTokenizer.register(NewConfig, fast_tokenizer_class=NewTokenizerFast)
self.assertEqual(TOKENIZER_MAPPING[NewConfig], (NewTokenizer, NewTokenizerFast))
AutoTokenizer.register(CustomConfig, slow_tokenizer_class=CustomTokenizer)
self.assertEqual(TOKENIZER_MAPPING[CustomConfig], (CustomTokenizer, None))
AutoTokenizer.register(CustomConfig, fast_tokenizer_class=CustomTokenizerFast)
self.assertEqual(TOKENIZER_MAPPING[CustomConfig], (CustomTokenizer, CustomTokenizerFast))
del TOKENIZER_MAPPING._extra_content[NewConfig]
del TOKENIZER_MAPPING._extra_content[CustomConfig]
# Can register in one step
AutoTokenizer.register(NewConfig, slow_tokenizer_class=NewTokenizer, fast_tokenizer_class=NewTokenizerFast)
self.assertEqual(TOKENIZER_MAPPING[NewConfig], (NewTokenizer, NewTokenizerFast))
AutoTokenizer.register(
CustomConfig, slow_tokenizer_class=CustomTokenizer, fast_tokenizer_class=CustomTokenizerFast
)
self.assertEqual(TOKENIZER_MAPPING[CustomConfig], (CustomTokenizer, CustomTokenizerFast))
# Trying to register something existing in the Transformers library will raise an error
with self.assertRaises(ValueError):
......@@ -295,22 +293,22 @@ class AutoTokenizerTest(unittest.TestCase):
with tempfile.TemporaryDirectory() as tmp_dir:
bert_tokenizer = BertTokenizerFast.from_pretrained(SMALL_MODEL_IDENTIFIER)
bert_tokenizer.save_pretrained(tmp_dir)
tokenizer = NewTokenizerFast.from_pretrained(tmp_dir)
tokenizer = CustomTokenizerFast.from_pretrained(tmp_dir)
with tempfile.TemporaryDirectory() as tmp_dir:
tokenizer.save_pretrained(tmp_dir)
new_tokenizer = AutoTokenizer.from_pretrained(tmp_dir)
self.assertIsInstance(new_tokenizer, NewTokenizerFast)
self.assertIsInstance(new_tokenizer, CustomTokenizerFast)
new_tokenizer = AutoTokenizer.from_pretrained(tmp_dir, use_fast=False)
self.assertIsInstance(new_tokenizer, NewTokenizer)
self.assertIsInstance(new_tokenizer, CustomTokenizer)
finally:
if "new-model" in CONFIG_MAPPING._extra_content:
del CONFIG_MAPPING._extra_content["new-model"]
if NewConfig in TOKENIZER_MAPPING._extra_content:
del TOKENIZER_MAPPING._extra_content[NewConfig]
if "custom" in CONFIG_MAPPING._extra_content:
del CONFIG_MAPPING._extra_content["custom"]
if CustomConfig in TOKENIZER_MAPPING._extra_content:
del TOKENIZER_MAPPING._extra_content[CustomConfig]
def test_repo_not_found(self):
with self.assertRaisesRegex(
......
......@@ -21,10 +21,12 @@ import os
import pickle
import re
import shutil
import sys
import tempfile
import unittest
from collections import OrderedDict
from itertools import takewhile
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Union
from huggingface_hub import Repository, delete_repo, login
......@@ -67,6 +69,15 @@ if TYPE_CHECKING:
from transformers import PretrainedConfig, PreTrainedModel, TFPreTrainedModel
sys.path.append(str(Path(__file__).parent.parent / "utils"))
from test_module.custom_tokenization import CustomTokenizer # noqa E402
if is_tokenizers_available():
from test_module.custom_tokenization_fast import CustomTokenizerFast
NON_ENGLISH_TAGS = ["chinese", "dutch", "french", "finnish", "german", "multilingual"]
SMALL_TRAINING_CORPUS = [
......@@ -3690,28 +3701,6 @@ class TokenizerTesterMixin:
self.rust_tokenizer_class.from_pretrained(tmp_dir_2)
class FakeTokenizer(BertTokenizer):
pass
if is_tokenizers_available():
class FakeTokenizerFast(BertTokenizerFast):
pass
# Make sure this is synchronized with the tokenizers above.
FAKE_TOKENIZER_CODE = """
from transformers import BertTokenizer, BertTokenizerFast
class FakeTokenizer(BertTokenizer):
pass
class FakeTokenizerFast(BertTokenizerFast):
pass
"""
@is_staging_test
class TokenizerPushToHubTester(unittest.TestCase):
vocab_tokens = ["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]", "bla", "blou"]
......@@ -3766,47 +3755,62 @@ class TokenizerPushToHubTester(unittest.TestCase):
new_tokenizer = BertTokenizer.from_pretrained("valid_org/test-tokenizer-org")
self.assertDictEqual(new_tokenizer.vocab, tokenizer.vocab)
@require_tokenizers
def test_push_to_hub_dynamic_tokenizer(self):
CustomTokenizer.register_for_auto_class()
with tempfile.TemporaryDirectory() as tmp_dir:
vocab_file = os.path.join(tmp_dir, "vocab.txt")
with open(vocab_file, "w", encoding="utf-8") as vocab_writer:
vocab_writer.write("".join([x + "\n" for x in self.vocab_tokens]))
tokenizer = FakeTokenizer(vocab_file)
tokenizer = CustomTokenizer(vocab_file)
# No fast custom tokenizer
tokenizer._auto_map = ("tokenizer.FakeTokenizer", None)
with tempfile.TemporaryDirectory() as tmp_dir:
repo = Repository(tmp_dir, clone_from=f"{USER}/test-dynamic-tokenizer", use_auth_token=self._token)
print(os.listdir(tmp_dir))
tokenizer.save_pretrained(tmp_dir)
with open(os.path.join(tmp_dir, "tokenizer.py"), "w") as f:
f.write(FAKE_TOKENIZER_CODE)
with open(os.path.join(tmp_dir, "tokenizer_config.json")) as f:
tokenizer_config = json.load(f)
self.assertEqual(tokenizer_config["auto_map"], ["custom_tokenization.CustomTokenizer", None])
repo.push_to_hub()
tokenizer = AutoTokenizer.from_pretrained(f"{USER}/test-dynamic-tokenizer", trust_remote_code=True)
# Can't make an isinstance check because the new_model.config is from the FakeConfig class of a dynamic module
self.assertEqual(tokenizer.__class__.__name__, "FakeTokenizer")
# Can't make an isinstance check because the new_model.config is from the CustomTokenizer class of a dynamic module
self.assertEqual(tokenizer.__class__.__name__, "CustomTokenizer")
# Fast and slow custom tokenizer
tokenizer._auto_map = ("tokenizer.FakeTokenizer", "tokenizer.FakeTokenizerFast")
CustomTokenizerFast.register_for_auto_class()
with tempfile.TemporaryDirectory() as tmp_dir:
vocab_file = os.path.join(tmp_dir, "vocab.txt")
with open(vocab_file, "w", encoding="utf-8") as vocab_writer:
vocab_writer.write("".join([x + "\n" for x in self.vocab_tokens]))
bert_tokenizer = BertTokenizerFast.from_pretrained(tmp_dir)
bert_tokenizer.save_pretrained(tmp_dir)
tokenizer = CustomTokenizerFast.from_pretrained(tmp_dir)
with tempfile.TemporaryDirectory() as tmp_dir:
repo = Repository(tmp_dir, clone_from=f"{USER}/test-dynamic-tokenizer", use_auth_token=self._token)
print(os.listdir(tmp_dir))
tokenizer.save_pretrained(tmp_dir)
with open(os.path.join(tmp_dir, "tokenizer.py"), "w") as f:
f.write(FAKE_TOKENIZER_CODE)
with open(os.path.join(tmp_dir, "tokenizer_config.json")) as f:
tokenizer_config = json.load(f)
self.assertEqual(
tokenizer_config["auto_map"],
["custom_tokenization.CustomTokenizer", "custom_tokenization_fast.CustomTokenizerFast"],
)
repo.push_to_hub()
tokenizer = AutoTokenizer.from_pretrained(f"{USER}/test-dynamic-tokenizer", trust_remote_code=True)
# Can't make an isinstance check because the new_model.config is from the FakeConfig class of a dynamic module
self.assertEqual(tokenizer.__class__.__name__, "FakeTokenizerFast")
self.assertEqual(tokenizer.__class__.__name__, "CustomTokenizerFast")
tokenizer = AutoTokenizer.from_pretrained(
f"{USER}/test-dynamic-tokenizer", use_fast=False, trust_remote_code=True
)
# Can't make an isinstance check because the new_model.config is from the FakeConfig class of a dynamic module
self.assertEqual(tokenizer.__class__.__name__, "FakeTokenizer")
self.assertEqual(tokenizer.__class__.__name__, "CustomTokenizer")
class TrieTest(unittest.TestCase):
......
from transformers import PretrainedConfig


class CustomConfig(PretrainedConfig):
    model_type = "custom"

    def __init__(self, attribute=1, **kwargs):
        self.attribute = attribute
        super().__init__(**kwargs)