Unverified Commit 403d530e authored by Sylvain Gugger, committed by GitHub

Auto feature extractor (#11097)

* AutoFeatureExtractor

* Init and first tests

* Tests

* Damn you gitignore

* Quality

* Defensive test for when not all backends are here

* Use pattern for Speech2Text models
parent 520198f5
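For context, a minimal usage sketch of the `AutoFeatureExtractor` API this commit introduces (the checkpoint is the one exercised in the new tests; treat the snippet as an illustration rather than part of the diff):

```python
import numpy as np
from transformers import AutoFeatureExtractor

# Resolution uses the `feature_extractor_type` key of the saved feature extractor config,
# falling back to pattern matching on the checkpoint name when that key is missing.
feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base-960h")
print(type(feature_extractor).__name__)  # Wav2Vec2FeatureExtractor

# The returned object behaves like the concrete feature extractor class.
inputs = feature_extractor(np.zeros(16000), sampling_rate=16000, return_tensors="np")
print(inputs.input_values.shape)  # (1, 16000)
```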
@@ -9,8 +9,7 @@ __pycache__/
*.so
# tests and logs
tests/fixtures/*
tests/fixtures/cached_*_text.txt
!tests/fixtures/sample_text_no_unicode.txt
logs/
lightning_logs/
lang_code_data/
...
@@ -44,6 +44,13 @@ AutoTokenizer
    :members:
AutoFeatureExtractor
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.AutoFeatureExtractor
:members:
AutoModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
...
@@ -45,6 +45,7 @@ from .file_utils import (
    _BaseLazyModule,
    is_flax_available,
    is_sentencepiece_available,
is_speech_available,
    is_tf_available,
    is_tokenizers_available,
    is_torch_available,
@@ -102,6 +103,7 @@ _import_structure = {
    "is_py3nvml_available",
    "is_sentencepiece_available",
    "is_sklearn_available",
    "is_speech_available",
    "is_tf_available",
    "is_tokenizers_available",
    "is_torch_available",
@@ -133,9 +135,11 @@ _import_structure = {
    "models.auto": [
        "ALL_PRETRAINED_CONFIG_ARCHIVE_MAP",
        "CONFIG_MAPPING",
        "FEATURE_EXTRACTOR_MAPPING",
        "MODEL_NAMES_MAPPING",
        "TOKENIZER_MAPPING",
        "AutoConfig",
        "AutoFeatureExtractor",
        "AutoTokenizer",
    ],
    "models.bart": ["BartConfig", "BartTokenizer"],
@@ -202,7 +206,6 @@ _import_structure = {
    "models.speech_to_text": [
        "SPEECH_TO_TEXT_PRETRAINED_CONFIG_ARCHIVE_MAP",
        "Speech2TextConfig",
        "Speech2TextFeatureExtractor",
    ],
    "models.squeezebert": ["SQUEEZEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "SqueezeBertConfig", "SqueezeBertTokenizer"],
    "models.t5": ["T5_PRETRAINED_CONFIG_ARCHIVE_MAP", "T5Config"],
@@ -288,7 +291,6 @@ if is_sentencepiece_available():
    _import_structure["models.pegasus"].append("PegasusTokenizer")
    _import_structure["models.reformer"].append("ReformerTokenizer")
    _import_structure["models.speech_to_text"].append("Speech2TextTokenizer")
    _import_structure["models.speech_to_text"].append("Speech2TextProcessor")
    _import_structure["models.t5"].append("T5Tokenizer")
    _import_structure["models.xlm_prophetnet"].append("XLMProphetNetTokenizer")
    _import_structure["models.xlm_roberta"].append("XLMRobertaTokenizer")
@@ -339,6 +341,7 @@ if is_tokenizers_available():
    if is_sentencepiece_available():
        _import_structure["convert_slow_tokenizer"] = ["SLOW_TO_FAST_CONVERTERS", "convert_slow_tokenizer"]
else:
    from .utils import dummy_tokenizers_objects
@@ -346,6 +349,20 @@ else:
        name for name in dir(dummy_tokenizers_objects) if not name.startswith("_")
    ]
# Speech-specific objects
if is_speech_available():
_import_structure["models.speech_to_text"].append("Speech2TextFeatureExtractor")
if is_sentencepiece_available():
_import_structure["models.speech_to_text"].append("Speech2TextProcessor")
else:
from .utils import dummy_speech_objects
_import_structure["utils.dummy_speech_objects"] = [
name for name in dir(dummy_speech_objects) if not name.startswith("_")
]
# Vision-specific objects
if is_vision_available():
    _import_structure["image_utils"] = ["ImageFeatureExtractionMixin"]
@@ -1394,6 +1411,7 @@ if TYPE_CHECKING:
        is_py3nvml_available,
        is_sentencepiece_available,
        is_sklearn_available,
is_speech_available,
        is_tf_available,
        is_tokenizers_available,
        is_torch_available,
@@ -1429,9 +1447,11 @@ if TYPE_CHECKING:
    from .models.auto import (
        ALL_PRETRAINED_CONFIG_ARCHIVE_MAP,
        CONFIG_MAPPING,
FEATURE_EXTRACTOR_MAPPING,
        MODEL_NAMES_MAPPING,
        TOKENIZER_MAPPING,
        AutoConfig,
AutoFeatureExtractor,
        AutoTokenizer,
    )
    from .models.bart import BartConfig, BartTokenizer
@@ -1494,11 +1514,7 @@ if TYPE_CHECKING:
    from .models.reformer import REFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, ReformerConfig
    from .models.retribert import RETRIBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, RetriBertConfig, RetriBertTokenizer
    from .models.roberta import ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, RobertaConfig, RobertaTokenizer
    from .models.speech_to_text import (
        SPEECH_TO_TEXT_PRETRAINED_CONFIG_ARCHIVE_MAP,
        Speech2TextConfig,
        Speech2TextFeatureExtractor,
    )
    from .models.speech_to_text import SPEECH_TO_TEXT_PRETRAINED_CONFIG_ARCHIVE_MAP, Speech2TextConfig
    from .models.squeezebert import SQUEEZEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, SqueezeBertConfig, SqueezeBertTokenizer
    from .models.t5 import T5_PRETRAINED_CONFIG_ARCHIVE_MAP, T5Config
    from .models.tapas import TAPAS_PRETRAINED_CONFIG_ARCHIVE_MAP, TapasConfig, TapasTokenizer
@@ -1585,7 +1601,7 @@ if TYPE_CHECKING:
    from .models.mt5 import MT5Tokenizer
    from .models.pegasus import PegasusTokenizer
    from .models.reformer import ReformerTokenizer
    from .models.speech_to_text import Speech2TextProcessor, Speech2TextTokenizer
    from .models.speech_to_text import Speech2TextTokenizer
    from .models.t5 import T5Tokenizer
    from .models.xlm_prophetnet import XLMProphetNetTokenizer
    from .models.xlm_roberta import XLMRobertaTokenizer
@@ -1627,9 +1643,19 @@ if TYPE_CHECKING:
    if is_sentencepiece_available():
        from .convert_slow_tokenizer import SLOW_TO_FAST_CONVERTERS, convert_slow_tokenizer
    else:
        from .utils.dummy_tokenizers_objects import *
if is_speech_available():
from .models.speech_to_text import Speech2TextFeatureExtractor
if is_sentencepiece_available():
from .models.speech_to_text import Speech2TextProcessor
else:
from .utils.dummy_speech_objects import *
    if is_vision_available():
        from .image_utils import ImageFeatureExtractionMixin
        from .models.vit import ViTFeatureExtractor
...
@@ -43,6 +43,7 @@ deps = {
    "sphinx-copybutton": "sphinx-copybutton",
    "sphinx-markdown-tables": "sphinx-markdown-tables",
    "sphinx-rtd-theme": "sphinx-rtd-theme==0.4.3",
    "sphinxext-opengraph": "sphinxext-opengraph==0.4.1",
    "sphinx": "sphinx==3.2.1",
    "starlette": "starlette",
    "tensorflow-cpu": "tensorflow-cpu>=2.3",
...
@@ -325,6 +325,13 @@ class FeatureExtractionMixin:
        local_files_only = kwargs.pop("local_files_only", False)
        revision = kwargs.pop("revision", None)
from_pipeline = kwargs.pop("_from_pipeline", None)
from_auto_class = kwargs.pop("_from_auto", False)
user_agent = {"file_type": "feature extractor", "from_auto_class": from_auto_class}
if from_pipeline is not None:
user_agent["using_pipeline"] = from_pipeline
        if is_offline_mode() and not local_files_only:
            logger.info("Offline mode: forcing local_files_only=True")
            local_files_only = True
@@ -349,6 +356,7 @@ class FeatureExtractionMixin:
                resume_download=resume_download,
                local_files_only=local_files_only,
                use_auth_token=use_auth_token,
user_agent=user_agent,
            )
        # Load feature_extractor dict
        with open(resolved_feature_extractor_file, "r", encoding="utf-8") as reader:
@@ -426,6 +434,7 @@ class FeatureExtractionMixin:
            :obj:`Dict[str, Any]`: Dictionary of all the attributes that make up this feature extractor instance.
        """
        output = copy.deepcopy(self.__dict__)
output["feature_extractor_type"] = self.__class__.__name__
        return output
...
@@ -397,6 +397,11 @@ def is_torchaudio_available():
    return _torchaudio_available
def is_speech_available():
# For now this depends on torchaudio but the exact dependency might evolve in the future.
return _torchaudio_available
def torch_only_method(fn):
    def wrapper(*args, **kwargs):
        if not _torch_available:
@@ -513,6 +518,13 @@ explained here: https://pandas.pydata.org/pandas-docs/stable/getting_started/ins
"""
# docstyle-ignore
SPEECH_IMPORT_ERROR = """
{0} requires the torchaudio library but it was not found in your environment. You can install it with pip:
`pip install torchaudio`
"""
# docstyle-ignore
VISION_IMPORT_ERROR = """
{0} requires the PIL library but it was not found in your environment. You can install it with pip:
@@ -586,6 +598,12 @@ def requires_scatter(obj):
        raise ImportError(SCATTER_IMPORT_ERROR.format(name))
def requires_speech(obj):
name = obj.__name__ if hasattr(obj, "__name__") else obj.__class__.__name__
if not is_speech_available():
raise ImportError(SPEECH_IMPORT_ERROR.format(name))
def requires_vision(obj):
    name = obj.__name__ if hasattr(obj, "__name__") else obj.__class__.__name__
    if not is_vision_available():
...
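A small sketch of how the new `speech` backend helpers above are meant to be consumed (the class name below is illustrative, not part of this diff):

```python
from transformers.file_utils import is_speech_available, requires_speech


class NeedsTorchaudio:
    """Illustrative placeholder, mirroring the generated dummy objects for the speech backend."""

    def __init__(self):
        # Raises ImportError with SPEECH_IMPORT_ERROR when torchaudio is not installed.
        requires_speech(self)


print(is_speech_available())  # True only if torchaudio can be imported
```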
@@ -23,6 +23,7 @@ from ...file_utils import _BaseLazyModule, is_flax_available, is_tf_available, i
_import_structure = {
    "configuration_auto": ["ALL_PRETRAINED_CONFIG_ARCHIVE_MAP", "CONFIG_MAPPING", "MODEL_NAMES_MAPPING", "AutoConfig"],
    "feature_extraction_auto": ["FEATURE_EXTRACTOR_MAPPING", "AutoFeatureExtractor"],
    "tokenization_auto": ["TOKENIZER_MAPPING", "AutoTokenizer"],
}
@@ -104,6 +105,7 @@ if is_flax_available():
if TYPE_CHECKING:
    from .configuration_auto import ALL_PRETRAINED_CONFIG_ARCHIVE_MAP, CONFIG_MAPPING, MODEL_NAMES_MAPPING, AutoConfig
    from .feature_extraction_auto import FEATURE_EXTRACTOR_MAPPING, AutoFeatureExtractor
    from .tokenization_auto import TOKENIZER_MAPPING, AutoTokenizer
if is_torch_available():
...
# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" AutoFeatureExtractor class. """
from collections import OrderedDict
from ...feature_extraction_utils import FeatureExtractionMixin
from ...file_utils import is_speech_available, is_vision_available
from ..wav2vec2.feature_extraction_wav2vec2 import Wav2Vec2FeatureExtractor
from .configuration_auto import replace_list_option_in_docstrings
if is_speech_available():
from ..speech_to_text.feature_extraction_speech_to_text import Speech2TextFeatureExtractor
else:
Speech2TextFeatureExtractor = None
if is_vision_available():
from ..vit.feature_extraction_vit import ViTFeatureExtractor
else:
ViTFeatureExtractor = None
# Build the list of all feature extractors
FEATURE_EXTRACTOR_MAPPING = OrderedDict(
[
("s2t", Speech2TextFeatureExtractor),
("vit", ViTFeatureExtractor),
("wav2vec2", Wav2Vec2FeatureExtractor),
]
)
def feature_extractor_class_from_name(class_name: str):
for c in FEATURE_EXTRACTOR_MAPPING.values():
if c is not None and c.__name__ == class_name:
return c
class AutoFeatureExtractor:
r"""
This is a generic feature extractor class that will be instantiated as one of the feature extractor classes of the
library when created with the :meth:`AutoFeatureExtractor.from_pretrained` class method.
This class cannot be instantiated directly using ``__init__()`` (throws an error).
"""
def __init__(self):
raise EnvironmentError(
"AutoFeatureExtractor is designed to be instantiated "
"using the `AutoFeatureExtractor.from_pretrained(pretrained_model_name_or_path)` method."
)
@classmethod
@replace_list_option_in_docstrings(FEATURE_EXTRACTOR_MAPPING)
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
r"""
        Instantiate one of the feature extractor classes of the library from a pretrained model feature extractor
        configuration.
        The feature extractor class to instantiate is selected based on the :obj:`feature_extractor_type` property of
        the feature extractor config (loaded from :obj:`pretrained_model_name_or_path`), or when it's missing, by
        falling back to pattern matching on :obj:`pretrained_model_name_or_path`:
List options
Params:
pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`):
This can be either:
- a string, the `model id` of a pretrained feature_extractor hosted inside a model repo on
huggingface.co. Valid model ids can be located at the root-level, like ``bert-base-uncased``, or
namespaced under a user or organization name, like ``dbmdz/bert-base-german-cased``.
- a path to a `directory` containing a feature extractor file saved using the
:func:`~transformers.feature_extraction_utils.FeatureExtractionMixin.save_pretrained` method, e.g.,
``./my_model_directory/``.
- a path or url to a saved feature extractor JSON `file`, e.g.,
``./my_model_directory/feature_extraction_config.json``.
cache_dir (:obj:`str` or :obj:`os.PathLike`, `optional`):
Path to a directory in which a downloaded pretrained model feature extractor should be cached if the
standard cache should not be used.
force_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to force the (re-)download of the feature extractor files, overriding any cached
                versions that exist.
resume_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to delete an incompletely received file. Attempts to resume the download if such a
                file exists.
proxies (:obj:`Dict[str, str]`, `optional`):
A dictionary of proxy servers to use by protocol or endpoint, e.g., :obj:`{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
use_auth_token (:obj:`str` or `bool`, `optional`):
The token to use as HTTP bearer authorization for remote files. If :obj:`True`, will use the token
generated when running :obj:`transformers-cli login` (stored in :obj:`~/.huggingface`).
revision(:obj:`str`, `optional`, defaults to :obj:`"main"`):
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any
identifier allowed by git.
return_unused_kwargs (:obj:`bool`, `optional`, defaults to :obj:`False`):
If :obj:`False`, then this function returns just the final feature extractor object. If :obj:`True`,
then this functions returns a :obj:`Tuple(feature_extractor, unused_kwargs)` where `unused_kwargs` is a
dictionary consisting of the key/value pairs whose keys are not feature extractor attributes: i.e., the
part of ``kwargs`` which has not been used to update ``feature_extractor`` and is otherwise ignored.
kwargs (:obj:`Dict[str, Any]`, `optional`):
The values in kwargs of any keys which are feature extractor attributes will be used to override the
loaded values. Behavior concerning key/value pairs whose keys are *not* feature extractor attributes is
controlled by the ``return_unused_kwargs`` keyword parameter.
.. note::
Passing :obj:`use_auth_token=True` is required when you want to use a private model.
Examples::
>>> from transformers import AutoFeatureExtractor
        >>> # Download the feature extractor configuration from huggingface.co and cache it.
        >>> feature_extractor = AutoFeatureExtractor.from_pretrained('facebook/wav2vec2-base-960h')
        >>> # If the feature extractor files are in a directory (e.g. the feature extractor was saved using `save_pretrained('./test/saved_model/')`)
        >>> feature_extractor = AutoFeatureExtractor.from_pretrained('./test/saved_model/')
"""
kwargs["_from_auto"] = True
config_dict, _ = FeatureExtractionMixin.get_feature_extractor_dict(pretrained_model_name_or_path, **kwargs)
if "feature_extractor_type" in config_dict:
feature_extractor_class = feature_extractor_class_from_name(config_dict["feature_extractor_type"])
return feature_extractor_class.from_dict(config_dict, **kwargs)
else:
# Fallback: use pattern matching on the string.
for pattern, feature_extractor_class in FEATURE_EXTRACTOR_MAPPING.items():
if pattern in str(pretrained_model_name_or_path):
return feature_extractor_class.from_dict(config_dict, **kwargs)
raise ValueError(
f"Unrecognized model in {pretrained_model_name_or_path}. Should have a `feature_extractor_type` key in "
"its feature_extraction_config.json, or contain one of the following strings "
f"in its name: {', '.join(FEATURE_EXTRACTOR_MAPPING.keys())}"
)
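The `feature_extractor_type` key that `to_dict()` now records (see the `FeatureExtractionMixin` hunk above) is what lets a saved extractor round-trip through `AutoFeatureExtractor` without name-based pattern matching. A hedged sketch, with an arbitrary save directory:

```python
from transformers import AutoFeatureExtractor, Wav2Vec2FeatureExtractor

extractor = Wav2Vec2FeatureExtractor()           # default configuration
extractor.save_pretrained("./saved_extractor")   # the saved config JSON now includes "feature_extractor_type"

# The path contains no mapping key such as "wav2vec2", so resolution relies on feature_extractor_type.
reloaded = AutoFeatureExtractor.from_pretrained("./saved_extractor")
assert isinstance(reloaded, Wav2Vec2FeatureExtractor)
```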
@@ -17,7 +17,7 @@
# limitations under the License.
from typing import TYPE_CHECKING
from ...file_utils import _BaseLazyModule, is_sentencepiece_available, is_torch_available
from ...file_utils import _BaseLazyModule, is_sentencepiece_available, is_speech_available, is_torch_available
_import_structure = {
@@ -25,13 +25,17 @@ _import_structure = {
        "SPEECH_TO_TEXT_PRETRAINED_CONFIG_ARCHIVE_MAP",
        "Speech2TextConfig",
    ],
"feature_extraction_speech_to_text": ["Speech2TextFeatureExtractor"],
}
if is_sentencepiece_available():
    _import_structure["processing_speech_to_text"] = ["Speech2TextProcessor"]
    _import_structure["tokenization_speech_to_text"] = ["Speech2TextTokenizer"]
if is_speech_available():
_import_structure["feature_extraction_speech_to_text"] = ["Speech2TextFeatureExtractor"]
if is_sentencepiece_available():
_import_structure["processing_speech_to_text"] = ["Speech2TextProcessor"]
if is_torch_available():
    _import_structure["modeling_speech_to_text"] = [
        "SPEECH_TO_TEXT_PRETRAINED_MODEL_ARCHIVE_LIST",
@@ -43,12 +47,16 @@ if is_torch_available():
if TYPE_CHECKING:
    from .configuration_speech_to_text import SPEECH_TO_TEXT_PRETRAINED_CONFIG_ARCHIVE_MAP, Speech2TextConfig
from .feature_extraction_speech_to_text import Speech2TextFeatureExtractor
    if is_sentencepiece_available():
        from .processing_speech_to_text import Speech2TextProcessor
        from .tokenization_speech_to_text import Speech2TextTokenizer
if is_speech_available():
from .feature_extraction_speech_to_text import Speech2TextFeatureExtractor
if is_sentencepiece_available():
from .processing_speech_to_text import Speech2TextProcessor
    if is_torch_available():
        from .modeling_speech_to_text import (
            SPEECH_TO_TEXT_PRETRAINED_MODEL_ARCHIVE_LIST,
...
@@ -19,19 +19,15 @@ Feature extractor class for Speech2Text
from typing import List, Optional, Union
import numpy as np
import torch
import torchaudio.compliance.kaldi as ta_kaldi
from ...feature_extraction_sequence_utils import SequenceFeatureExtractor
from ...feature_extraction_utils import BatchFeature
from ...file_utils import PaddingStrategy, TensorType, is_torch_available, is_torchaudio_available
from ...file_utils import PaddingStrategy, TensorType
from ...utils import logging
if is_torch_available():
    import torch
if is_torchaudio_available():
    import torchaudio.compliance.kaldi as ta_kaldi
logger = logging.get_logger(__name__)
@@ -75,8 +71,6 @@ class Speech2TextFeatureExtractor(SequenceFeatureExtractor):
        normalize_vars=True,
        **kwargs
    ):
if not is_torchaudio_available():
raise ImportError("`Speech2TextFeatureExtractor` requires torchaudio: `pip install torchaudio`.")
        super().__init__(feature_size=feature_size, sampling_rate=sampling_rate, padding_value=padding_value, **kwargs)
        self.num_mel_bins = num_mel_bins
        self.do_ceptral_normalize = do_ceptral_normalize
...
@@ -110,11 +110,6 @@ class ReformerTokenizer:
        requires_sentencepiece(self)
class Speech2TextProcessor:
def __init__(self, *args, **kwargs):
requires_sentencepiece(self)
class Speech2TextTokenizer:
    def __init__(self, *args, **kwargs):
        requires_sentencepiece(self)
...
# This file is autogenerated by the command `make fix-copies`, do not edit.
from ..file_utils import requires_speech
class Speech2TextFeatureExtractor:
def __init__(self, *args, **kwargs):
requires_speech(self)
class Speech2TextProcessor:
def __init__(self, *args, **kwargs):
requires_speech(self)
{
"feature_extractor_type": "Wav2Vec2FeatureExtractor"
}
\ No newline at end of file
# coding=utf-8
# Copyright 2021 the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import unittest
from transformers import FEATURE_EXTRACTOR_MAPPING, AutoFeatureExtractor, Wav2Vec2FeatureExtractor
SAMPLE_FEATURE_EXTRACTION_CONFIG = os.path.join(
os.path.dirname(os.path.abspath(__file__)), "fixtures/dummy_feature_extractor_config.json"
)
class AutoFeatureExtractorTest(unittest.TestCase):
def test_feature_extractor_from_model_shortcut(self):
config = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base-960h")
self.assertIsInstance(config, Wav2Vec2FeatureExtractor)
def test_feature_extractor_from_local_file(self):
config = AutoFeatureExtractor.from_pretrained(SAMPLE_FEATURE_EXTRACTION_CONFIG)
self.assertIsInstance(config, Wav2Vec2FeatureExtractor)
def test_pattern_matching_fallback(self):
"""
        When the saved feature extractor config doesn't include a `feature_extractor_type`, `AutoFeatureExtractor`
        falls back to pattern matching on the name, so perform a few safety checks on the mapping's key order.
"""
# no key string should be included in a later key string (typical failure case)
keys = list(FEATURE_EXTRACTOR_MAPPING.keys())
for i, key in enumerate(keys):
self.assertFalse(any(key in later_key for later_key in keys[i + 1 :]))
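To make the ordering concern concrete, a simplified sketch of the fallback loop this test guards (mirroring the loop in `feature_extraction_auto.py`; the mapping values are plain strings here and the checkpoint names are only examples):

```python
from collections import OrderedDict

# Stand-in for FEATURE_EXTRACTOR_MAPPING: only the key order matters for the fallback.
mapping = OrderedDict([("s2t", "Speech2Text"), ("vit", "ViT"), ("wav2vec2", "Wav2Vec2")])


def resolve(checkpoint_name):
    # The first key contained in the checkpoint name wins, so an early key that is a
    # substring of a later key would shadow the later feature extractor.
    for pattern, extractor in mapping.items():
        if pattern in checkpoint_name:
            return extractor


print(resolve("facebook/s2t-small-librispeech-asr"))  # Speech2Text
print(resolve("some-org/my-wav2vec2-finetune"))        # Wav2Vec2
```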
@@ -20,12 +20,15 @@ import unittest
import numpy as np
from transformers import Speech2TextFeatureExtractor
from transformers import is_speech_available
from transformers.testing_utils import require_torch, require_torchaudio
from .test_sequence_feature_extraction_common import SequenceFeatureExtractionTestMixin
if is_speech_available():
from transformers import Speech2TextFeatureExtractor
global_rng = random.Random()
@@ -101,7 +104,7 @@ class Speech2TextFeatureExtractionTester(unittest.TestCase):
@require_torchaudio
class Speech2TextFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.TestCase):
    feature_extraction_class = Speech2TextFeatureExtractor
    feature_extraction_class = Speech2TextFeatureExtractor if is_speech_available() else None
    def setUp(self):
        self.feat_extract_tester = Speech2TextFeatureExtractionTester(self)
...
@@ -19,7 +19,7 @@ import unittest
from pathlib import Path
from shutil import copyfile
from transformers import Speech2TextFeatureExtractor, Speech2TextProcessor, Speech2TextTokenizer
from transformers import Speech2TextTokenizer, is_speech_available
from transformers.file_utils import FEATURE_EXTRACTOR_NAME
from transformers.models.speech_to_text.tokenization_speech_to_text import VOCAB_FILES_NAMES, save_json
from transformers.testing_utils import require_sentencepiece, require_torch, require_torchaudio
@@ -27,6 +27,10 @@ from transformers.testing_utils import require_sentencepiece, require_torch, req
from .test_feature_extraction_speech_to_text import floats_list
if is_speech_available():
from transformers import Speech2TextFeatureExtractor, Speech2TextProcessor
SAMPLE_SP = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/test_sentencepiece.model")
...
@@ -26,7 +26,7 @@ _re_single_line_import = re.compile(r"\s+from\s+\S*\s+import\s+([^\(\s].*)\n")
_re_test_backend = re.compile(r"^\s+if\s+is\_([a-z]*)\_available\(\):\s*$")
BACKENDS = ["torch", "tf", "flax", "sentencepiece", "tokenizers", "vision"]
BACKENDS = ["torch", "tf", "flax", "sentencepiece", "speech", "tokenizers", "vision"]
DUMMY_CONSTANT = """
...
@@ -18,7 +18,7 @@ import re
PATH_TO_TRANSFORMERS = "src/transformers"
BACKENDS = ["torch", "tf", "flax", "sentencepiece", "tokenizers", "vision"]
BACKENDS = ["torch", "tf", "flax", "sentencepiece", "speech", "tokenizers", "vision"]
# Catches a line with a key-values pattern: "bla": ["foo", "bar"]
_re_import_struct_key_value = re.compile(r'\s+"\S*":\s+\[([^\]]*)\]')
...