Unverified Commit 7a32e472 authored by Sylvain Gugger, committed by GitHub

Custom feature extractor (#15630)

* Rework AutoFeatureExtractor.from_pretrained internals

* Custom feature extractor

* Add more tests

* Add support for custom feature extractor code

* Clean up
parent fcb0f743
@@ -26,6 +26,7 @@ import numpy as np
 from requests import HTTPError
 
+from .dynamic_module_utils import custom_object_save
 from .file_utils import (
     FEATURE_EXTRACTOR_NAME,
     EntryNotFoundError,
@@ -205,6 +206,8 @@ class FeatureExtractionMixin:
     extractors.
     """
 
+    _auto_class = None
+
     def __init__(self, **kwargs):
        """Set elements of `kwargs` as attributes."""
         # Pop "processor_class" as it should be saved as private attribute
@@ -316,6 +319,12 @@ class FeatureExtractionMixin:
        """
         if os.path.isfile(save_directory):
             raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")
+
+        # If we have a custom config, we copy the file defining it in the folder and set the attributes so it can be
+        # loaded from the Hub.
+        if self._auto_class is not None:
+            custom_object_save(self, save_directory, config=self)
+
         os.makedirs(save_directory, exist_ok=True)
         # If we save using the predefined names, we can load using `from_pretrained`
         output_feature_extractor_file = os.path.join(save_directory, FEATURE_EXTRACTOR_NAME)
@@ -539,3 +548,29 @@ class FeatureExtractionMixin:
     def __repr__(self):
         return f"{self.__class__.__name__} {self.to_json_string()}"
+
+    @classmethod
+    def register_for_auto_class(cls, auto_class="AutoFeatureExtractor"):
+        """
+        Register this class with a given auto class. This should only be used for custom feature extractors as the
+        ones in the library are already mapped with `AutoFeatureExtractor`.
+
+        <Tip warning={true}>
+
+        This API is experimental and may have some slight breaking changes in the next releases.
+
+        </Tip>
+
+        Args:
+            auto_class (`str` or `type`, *optional*, defaults to `"AutoFeatureExtractor"`):
+                The auto class to register this new feature extractor with.
+        """
+        if not isinstance(auto_class, str):
+            auto_class = auto_class.__name__
+
+        import transformers.models.auto as auto_module
+
+        if not hasattr(auto_module, auto_class):
+            raise ValueError(f"{auto_class} is not a valid auto class.")
+
+        cls._auto_class = auto_class
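
Taken together, the new `_auto_class` attribute, the `custom_object_save` call in `save_pretrained`, and `register_for_auto_class` let a user-defined extractor ship its own code alongside its config. A minimal sketch of that flow follows; the class `MyFeatureExtractor`, the module `my_feature_extraction.py`, and the output directory name are hypothetical and not part of this diff, and the sketch assumes the class lives in its own importable module so there is a file for `custom_object_save` to copy.

```python
# my_feature_extraction.py -- hypothetical module containing the custom class
from transformers import Wav2Vec2FeatureExtractor


class MyFeatureExtractor(Wav2Vec2FeatureExtractor):
    pass
```

```python
import json
import os

from my_feature_extraction import MyFeatureExtractor

# Register the custom class with AutoFeatureExtractor (a string or the auto class itself is accepted;
# an unknown auto class name raises a ValueError per the check above).
MyFeatureExtractor.register_for_auto_class("AutoFeatureExtractor")

feature_extractor = MyFeatureExtractor()
feature_extractor.save_pretrained("my-custom-extractor")

# save_pretrained should have copied my_feature_extraction.py next to the config and recorded an
# auto_map entry of the form "module.ClassName", mirroring the assertions in the new tests below.
print(sorted(os.listdir("my-custom-extractor")))
with open(os.path.join("my-custom-extractor", "preprocessor_config.json")) as f:
    print(json.load(f).get("auto_map"))
# expected: {"AutoFeatureExtractor": "my_feature_extraction.MyFeatureExtractor"}
```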
@@ -14,23 +14,28 @@
 # limitations under the License.
 """ AutoFeatureExtractor class."""
 import importlib
+import json
 import os
 from collections import OrderedDict
+from typing import Dict, Optional, Union
 
 # Build the list of all feature extractors
 from ...configuration_utils import PretrainedConfig
+from ...dynamic_module_utils import get_class_from_dynamic_module
 from ...feature_extraction_utils import FeatureExtractionMixin
-from ...file_utils import CONFIG_NAME, FEATURE_EXTRACTOR_NAME
+from ...file_utils import CONFIG_NAME, FEATURE_EXTRACTOR_NAME, get_file_from_repo
+from ...utils import logging
 from .auto_factory import _LazyAutoMapping
 from .configuration_auto import (
     CONFIG_MAPPING_NAMES,
     AutoConfig,
-    config_class_to_model_type,
     model_type_to_module_name,
     replace_list_option_in_docstrings,
 )
 
+logger = logging.get_logger(__name__)
+
 FEATURE_EXTRACTOR_MAPPING_NAMES = OrderedDict(
     [
         ("beit", "BeitFeatureExtractor"),
@@ -66,6 +71,96 @@ def feature_extractor_class_from_name(class_name: str):
     return None
 
 
+def get_feature_extractor_config(
+    pretrained_model_name_or_path: Union[str, os.PathLike],
+    cache_dir: Optional[Union[str, os.PathLike]] = None,
+    force_download: bool = False,
+    resume_download: bool = False,
+    proxies: Optional[Dict[str, str]] = None,
+    use_auth_token: Optional[Union[bool, str]] = None,
+    revision: Optional[str] = None,
+    local_files_only: bool = False,
+    **kwargs,
+):
"""
Loads the tokenizer configuration from a pretrained model tokenizer configuration.
Args:
pretrained_model_name_or_path (`str` or `os.PathLike`):
This can be either:
- a string, the *model id* of a pretrained model configuration hosted inside a model repo on
huggingface.co. Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced
under a user or organization name, like `dbmdz/bert-base-german-cased`.
- a path to a *directory* containing a configuration file saved using the
[`~PreTrainedTokenizer.save_pretrained`] method, e.g., `./my_model_directory/`.
cache_dir (`str` or `os.PathLike`, *optional*):
Path to a directory in which a downloaded pretrained model configuration should be cached if the standard
cache should not be used.
force_download (`bool`, *optional*, defaults to `False`):
Whether or not to force to (re-)download the configuration files and override the cached versions if they
exist.
resume_download (`bool`, *optional*, defaults to `False`):
Whether or not to delete incompletely received file. Attempts to resume the download if such a file exists.
proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
use_auth_token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
when running `transformers-cli login` (stored in `~/.huggingface`).
revision(`str`, *optional*, defaults to `"main"`):
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
identifier allowed by git.
local_files_only (`bool`, *optional*, defaults to `False`):
If `True`, will only try to load the tokenizer configuration from local files.
<Tip>
Passing `use_auth_token=True` is required when you want to use a private model.
</Tip>
Returns:
`Dict`: The configuration of the tokenizer.
Examples:
```python
# Download configuration from huggingface.co and cache.
tokenizer_config = get_tokenizer_config("bert-base-uncased")
# This model does not have a tokenizer config so the result will be an empty dict.
tokenizer_config = get_tokenizer_config("xlm-roberta-base")
# Save a pretrained tokenizer locally and you can reload its config
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
tokenizer.save_pretrained("tokenizer-test")
tokenizer_config = get_tokenizer_config("tokenizer-test")
```"""
+    resolved_config_file = get_file_from_repo(
+        pretrained_model_name_or_path,
+        FEATURE_EXTRACTOR_NAME,
+        cache_dir=cache_dir,
+        force_download=force_download,
+        resume_download=resume_download,
+        proxies=proxies,
+        use_auth_token=use_auth_token,
+        revision=revision,
+        local_files_only=local_files_only,
+    )
+    if resolved_config_file is None:
+        logger.info(
+            "Could not locate the feature extractor configuration file, will try to use the model config instead."
+        )
+        return {}
+
+    with open(resolved_config_file, encoding="utf-8") as reader:
+        return json.load(reader)
+
+
 class AutoFeatureExtractor:
     r"""
     This is a generic feature extractor class that will be instantiated as one of the feature extractor classes of the
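
As a side note, `get_feature_extractor_config` can also be used on its own to peek at a repo's `preprocessor_config.json`, for instance to check for an `auto_map` entry before opting into `trust_remote_code`. A small sketch, assuming the module layout implied by the relative imports above; the repo id is a placeholder, not from this diff:

```python
from transformers.models.auto.feature_extraction_auto import get_feature_extractor_config

# Placeholder repo id; the helper returns {} when the repo has no preprocessor_config.json.
config = get_feature_extractor_config("username/some-audio-model")

print(config.get("feature_extractor_type"))  # e.g. "Wav2Vec2FeatureExtractor", or None
print(config.get("auto_map", {}).get("AutoFeatureExtractor"))  # e.g. "my_module.MyFeatureExtractor"
```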
@@ -128,6 +223,10 @@ class AutoFeatureExtractor:
                 functions returns a `Tuple(feature_extractor, unused_kwargs)` where *unused_kwargs* is a dictionary
                 consisting of the key/value pairs whose keys are not feature extractor attributes: i.e., the part of
                 `kwargs` which has not been used to update `feature_extractor` and is otherwise ignored.
+            trust_remote_code (`bool`, *optional*, defaults to `False`):
+                Whether or not to allow for custom models defined on the Hub in their own modeling files. This option
+                should only be set to `True` for repositories you trust and in which you have read the code, as it
+                will execute code present on the Hub on your local machine.
             kwargs (`Dict[str, Any]`, *optional*):
                 The values in kwargs of any keys which are feature extractor attributes will be used to override the
                 loaded values. Behavior concerning key/value pairs whose keys are *not* feature extractor attributes is
@@ -151,35 +250,54 @@ class AutoFeatureExtractor:
         >>> feature_extractor = AutoFeatureExtractor.from_pretrained("./test/saved_model/")
         ```"""
         config = kwargs.pop("config", None)
+        trust_remote_code = kwargs.pop("trust_remote_code", False)
         kwargs["_from_auto"] = True
 
-        is_feature_extraction_file = os.path.isfile(pretrained_model_name_or_path)
-        is_directory = os.path.isdir(pretrained_model_name_or_path) and os.path.exists(
-            os.path.join(pretrained_model_name_or_path, FEATURE_EXTRACTOR_NAME)
-        )
-
-        has_local_config = (
-            os.path.exists(os.path.join(pretrained_model_name_or_path, CONFIG_NAME)) if is_directory else False
-        )
-
-        # load config, if it can be loaded
-        if not is_feature_extraction_file and (has_local_config or not is_directory):
+        config_dict, _ = FeatureExtractionMixin.get_feature_extractor_dict(pretrained_model_name_or_path, **kwargs)
+        feature_extractor_class = config_dict.get("feature_extractor_type", None)
+        feature_extractor_auto_map = None
+        if "AutoFeatureExtractor" in config_dict.get("auto_map", {}):
+            feature_extractor_auto_map = config_dict["auto_map"]["AutoFeatureExtractor"]
+
+        # If we don't find the feature extractor class in the feature extractor config, let's try the model config.
+        if feature_extractor_class is None and feature_extractor_auto_map is None:
             if not isinstance(config, PretrainedConfig):
                 config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
+            # It could be in `config.feature_extractor_type`
+            feature_extractor_class = getattr(config, "feature_extractor_type", None)
+            if hasattr(config, "auto_map") and "AutoFeatureExtractor" in config.auto_map:
+                feature_extractor_auto_map = config.auto_map["AutoFeatureExtractor"]
 
-        kwargs["_from_auto"] = True
-        config_dict, _ = FeatureExtractionMixin.get_feature_extractor_dict(pretrained_model_name_or_path, **kwargs)
+        if feature_extractor_class is not None:
+            # If we have custom code for a feature extractor, we get the proper class.
+            if feature_extractor_auto_map is not None:
+                if not trust_remote_code:
+                    raise ValueError(
+                        f"Loading {pretrained_model_name_or_path} requires you to execute the feature extractor file "
+                        "in that repo on your local machine. Make sure you have read the code there to avoid "
+                        "malicious use, then set the option `trust_remote_code=True` to remove this error."
+                    )
+                if kwargs.get("revision", None) is None:
+                    logger.warning(
+                        "Explicitly passing a `revision` is encouraged when loading a feature extractor with custom "
+                        "code to ensure no malicious code has been contributed in a newer revision."
+                    )
 
-        model_type = config_class_to_model_type(type(config).__name__)
+                module_file, class_name = feature_extractor_auto_map.split(".")
+                feature_extractor_class = get_class_from_dynamic_module(
+                    pretrained_model_name_or_path, module_file + ".py", class_name, **kwargs
+                )
+            else:
+                feature_extractor_class = feature_extractor_class_from_name(feature_extractor_class)
 
-        if "feature_extractor_type" in config_dict:
-            feature_extractor_class = feature_extractor_class_from_name(config_dict["feature_extractor_type"])
             return feature_extractor_class.from_dict(config_dict, **kwargs)
-        elif model_type is not None:
-            return FEATURE_EXTRACTOR_MAPPING[type(config)].from_dict(config_dict, **kwargs)
+        # Last try: we use the FEATURE_EXTRACTOR_MAPPING.
+        elif type(config) in FEATURE_EXTRACTOR_MAPPING:
+            feature_extractor_class = FEATURE_EXTRACTOR_MAPPING[type(config)]
+            return feature_extractor_class.from_dict(config_dict, **kwargs)
 
         raise ValueError(
-            f"Unrecognized feature extractor in {pretrained_model_name_or_path}. Should have a `feature_extractor_type` key in "
-            f"its {FEATURE_EXTRACTOR_NAME}, or one of the following `model_type` keys in its {CONFIG_NAME}: "
-            f"{', '.join(c for c in FEATURE_EXTRACTOR_MAPPING_NAMES.keys())}"
+            f"Unrecognized feature extractor in {pretrained_model_name_or_path}. Should have a "
+            f"`feature_extractor_type` key in its {FEATURE_EXTRACTOR_NAME} or {CONFIG_NAME}, or one of the following "
+            f"`model_type` keys in its {CONFIG_NAME}: {', '.join(c for c in FEATURE_EXTRACTOR_MAPPING_NAMES.keys())}"
         )
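
For reference, a usage sketch of the reworked loading path: when the resolved config carries an `auto_map` entry of the form `"module.ClassName"`, `from_pretrained` refuses to run the remote code unless `trust_remote_code=True`, and logs a warning when no `revision` is pinned. The repo id below is the one used by the new test added further down; the commented-out revision value is a placeholder.

```python
from transformers import AutoFeatureExtractor

feature_extractor = AutoFeatureExtractor.from_pretrained(
    "hf-internal-testing/test_dynamic_feature_extractor",
    trust_remote_code=True,  # required because the extractor class lives in the repo's own code file
    # revision="...",  # placeholder: pin a specific commit to silence the warning and freeze the reviewed code
)
print(type(feature_extractor).__name__)  # "NewFeatureExtractor", per the test below
```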
@@ -82,3 +82,9 @@ class AutoFeatureExtractorTest(unittest.TestCase):
             "hf-internal-testing/config-no-model does not appear to have a file named preprocessor_config.json.",
         ):
             _ = AutoFeatureExtractor.from_pretrained("hf-internal-testing/config-no-model")
+
+    def test_from_pretrained_dynamic_feature_extractor(self):
+        model = AutoFeatureExtractor.from_pretrained(
+            "hf-internal-testing/test_dynamic_feature_extractor", trust_remote_code=True
+        )
+        self.assertEqual(model.__class__.__name__, "NewFeatureExtractor")
@@ -16,9 +16,21 @@
 import json
 import os
+import sys
 import tempfile
+import unittest
+from pathlib import Path
+
+from huggingface_hub import Repository, delete_repo, login
+from requests.exceptions import HTTPError
 
+from transformers import AutoFeatureExtractor
 from transformers.file_utils import is_torch_available, is_vision_available
+from transformers.testing_utils import PASS, USER, is_staging_test
+
+sys.path.append(str(Path(__file__).parent.parent / "utils"))
+
+from test_module.custom_feature_extraction import CustomFeatureExtractor  # noqa E402
 
 
 if is_torch_available():
@@ -29,6 +41,9 @@ if is_vision_available():
     from PIL import Image
 
 
+SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures")
+
+
 def prepare_image_inputs(feature_extract_tester, equal_resolution=False, numpify=False, torchify=False):
     """This function prepares a list of PIL images, or a list of numpy arrays if one specifies numpify=True,
     or a list of PyTorch tensors if one specifies torchify=True.
@@ -99,3 +114,41 @@ class FeatureExtractionSavingTestMixin:
     def test_init_without_params(self):
         feat_extract = self.feature_extraction_class()
         self.assertIsNotNone(feat_extract)
+
+
+@is_staging_test
+class ConfigPushToHubTester(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        cls._token = login(username=USER, password=PASS)
+
+    @classmethod
+    def tearDownClass(cls):
+        try:
+            delete_repo(token=cls._token, name="test-dynamic-feature-extractor")
+        except HTTPError:
+            pass
+
+    def test_push_to_hub_dynamic_feature_extractor(self):
+        CustomFeatureExtractor.register_for_auto_class()
+        feature_extractor = CustomFeatureExtractor.from_pretrained(SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR)
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            repo = Repository(tmp_dir, clone_from=f"{USER}/test-dynamic-feature-extractor", use_auth_token=self._token)
+            feature_extractor.save_pretrained(tmp_dir)
+
+            # This has added the proper auto_map field to the config
+            self.assertDictEqual(
+                feature_extractor.auto_map,
+                {"AutoFeatureExtractor": "custom_feature_extraction.CustomFeatureExtractor"},
+            )
+            # The code has been copied from fixtures
+            self.assertTrue(os.path.isfile(os.path.join(tmp_dir, "custom_feature_extraction.py")))
+
+            repo.push_to_hub()
+
+        new_feature_extractor = AutoFeatureExtractor.from_pretrained(
+            f"{USER}/test-dynamic-feature-extractor", trust_remote_code=True
+        )
+        # Can't make an isinstance check because the new_feature_extractor is from the CustomFeatureExtractor class of a dynamic module
+        self.assertEqual(new_feature_extractor.__class__.__name__, "CustomFeatureExtractor")
+from transformers import Wav2Vec2FeatureExtractor
+
+
+class CustomFeatureExtractor(Wav2Vec2FeatureExtractor):
+    pass