Unverified Commit 9a06b6b1 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[FeatureExtractorSavingUtils] Refactor PretrainedFeatureExtractor (#10594)



* save first version

* finish refactor

* finish refactor

* correct naming

* correct naming

* shorter names

* Update src/transformers/feature_extraction_common_utils.py
Co-authored-by: default avatarLysandre Debut <lysandre@huggingface.co>

* change name

* finish
Co-authored-by: default avatarLysandre Debut <lysandre@huggingface.co>
parent b6a28e9a
...@@ -14,16 +14,24 @@ ...@@ -14,16 +14,24 @@
Feature Extractor Feature Extractor
----------------------------------------------------------------------------------------------------------------------- -----------------------------------------------------------------------------------------------------------------------
A feature extractor is in charge of preparing read-in audio files for a speech model. This includes feature extraction, A feature extractor is in charge of preparing input features for a multi-modal model. This includes feature extraction
such as processing audio files to, *e.g.*, Log-Mel Spectrogram features, but also padding, normalization, and from sequences, *e.g.*, pre-processing audio files to Log-Mel Spectrogram features, feature extraction from images
conversion to Numpy, PyTorch, and TensorFlow tensors. *e.g.* cropping image image files, but also padding, normalization, and conversion to Numpy, PyTorch, and TensorFlow
tensors.
PreTrainedFeatureExtractor FeatureExtractionMixin
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.PreTrainedFeatureExtractor .. autoclass:: transformers.feature_extraction_utils.FeatureExtractionMixin
:members: from_pretrained, save_pretrained, pad :members: from_pretrained, save_pretrained
SequenceFeatureExtractor
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.SequenceFeatureExtractor
:members: pad
BatchFeature BatchFeature
......
...@@ -246,7 +246,7 @@ _import_structure = { ...@@ -246,7 +246,7 @@ _import_structure = {
"SpecialTokensMixin", "SpecialTokensMixin",
"TokenSpan", "TokenSpan",
], ],
"feature_extraction_utils": ["PreTrainedFeatureExtractor", "BatchFeature"], "feature_extraction_sequence_utils": ["SequenceFeatureExtractor", "BatchFeature"],
"trainer_callback": [ "trainer_callback": [
"DefaultFlowCallback", "DefaultFlowCallback",
"EarlyStoppingCallback", "EarlyStoppingCallback",
...@@ -1257,7 +1257,7 @@ if TYPE_CHECKING: ...@@ -1257,7 +1257,7 @@ if TYPE_CHECKING:
) )
# Feature Extractor # Feature Extractor
from .feature_extraction_utils import BatchFeature, PreTrainedFeatureExtractor from .feature_extraction_utils import BatchFeature, SequenceFeatureExtractor
# Files and general utilities # Files and general utilities
from .file_utils import ( from .file_utils import (
......
# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Sequence feature extraction class for common feature extrcactors to preprocess sequences.
"""
from typing import Dict, List, Optional, Union
import numpy as np
from .feature_extraction_utils import BatchFeature, FeatureExtractionMixin
from .file_utils import (
PaddingStrategy,
TensorType,
_is_tensorflow,
_is_torch,
is_tf_available,
is_torch_available,
to_py_obj,
)
from .utils import logging
logger = logging.get_logger(__name__)
class SequenceFeatureExtractor(FeatureExtractionMixin):
"""
This is a general feature extraction class for speech recognition.
Args:
feature_size (:obj:`int`):
The feature dimension of the extracted features.
sampling_rate (:obj:`int`):
The sampling rate at which the audio files should be digitalized expressed in Hertz per second (Hz).
padding_value (:obj:`float`):
The value that is used to fill the padding values / vectors.
"""
def __init__(self, feature_size: int, sampling_rate: int, padding_value: float, **kwargs):
self.feature_size = feature_size
self.sampling_rate = sampling_rate
self.padding_value = padding_value
self.padding_side = kwargs.pop("padding_side", "right")
self.return_attention_mask = kwargs.pop("return_attention_mask", True)
# Additional attributes without default values
for key, value in kwargs.items():
try:
setattr(self, key, value)
except AttributeError as err:
logger.error(f"Can't set {key} with value {value} for {self}")
raise err
def pad(
self,
processed_features: Union[
BatchFeature,
List[BatchFeature],
Dict[str, BatchFeature],
Dict[str, List[BatchFeature]],
List[Dict[str, BatchFeature]],
],
padding: Union[bool, str, PaddingStrategy] = True,
max_length: Optional[int] = None,
pad_to_multiple_of: Optional[int] = None,
return_attention_mask: Optional[bool] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
) -> BatchFeature:
"""
Pad input values / input vectors or a batch of input values / input vectors up to predefined length or to the
max sequence length in the batch.
Padding side (left/right) padding values are defined at the feature extractor level (with
``self.padding_side``, ``self.padding_value``)
.. note::
If the ``processed_features`` passed are dictionary of numpy arrays, PyTorch tensors or TensorFlow tensors,
the result will use the same type unless you provide a different tensor type with ``return_tensors``. In
the case of PyTorch tensors, you will lose the specific device of your tensors however.
Args:
processed_features (:class:`~transformers.BatchFeature`, list of :class:`~transformers.BatchFeature`, :obj:`Dict[str, List[float]]`, :obj:`Dict[str, List[List[float]]` or :obj:`List[Dict[str, List[float]]]`):
Processed inputs. Can represent one input (:class:`~transformers.BatchFeature` or :obj:`Dict[str,
List[float]]`) or a batch of input values / vectors (list of :class:`~transformers.BatchFeature`,
`Dict[str, List[List[float]]]` or `List[Dict[str, List[float]]]`) so you can use this method during
preprocessing as well as in a PyTorch Dataloader collate function.
Instead of :obj:`List[float]` you can have tensors (numpy arrays, PyTorch tensors or TensorFlow
tensors), see the note above for the return type.
padding (:obj:`bool`, :obj:`str` or :class:`~transformers.file_utils.PaddingStrategy`, `optional`, defaults to :obj:`True`):
Select a strategy to pad the returned sequences (according to the model's padding side and padding
index) among:
* :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a
single sequence if provided).
* :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
maximum acceptable input length for the model if that argument is not provided.
* :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
different lengths).
max_length (:obj:`int`, `optional`):
Maximum length of the returned list and optionally padding length (see above).
pad_to_multiple_of (:obj:`int`, `optional`):
If set will pad the sequence to a multiple of the provided value.
This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability
>= 7.5 (Volta), or on TPUs which benefit from having sequence lengths be a multiple of 128.
return_attention_mask (:obj:`bool`, `optional`):
Whether to return the attention mask. If left to the default, will return the attention mask according
to the specific feature_extractor's default.
`What are attention masks? <../glossary.html#attention-mask>`__
return_tensors (:obj:`str` or :class:`~transformers.file_utils.TensorType`, `optional`):
If set, will return tensors instead of list of python integers. Acceptable values are:
* :obj:`'tf'`: Return TensorFlow :obj:`tf.constant` objects.
* :obj:`'pt'`: Return PyTorch :obj:`torch.Tensor` objects.
* :obj:`'np'`: Return Numpy :obj:`np.ndarray` objects.
"""
# If we have a list of dicts, let's convert it in a dict of lists
# We do this to allow using this method as a collate_fn function in PyTorch Dataloader
if isinstance(processed_features, (list, tuple)) and isinstance(processed_features[0], (dict, BatchFeature)):
processed_features = {
key: [example[key] for example in processed_features] for key in processed_features[0].keys()
}
# The model's main input name, usually `input_values`, has be passed for padding
if self.model_input_names[0] not in processed_features:
raise ValueError(
"You should supply an instance of :class:`~transformers.BatchFeature` or list of :class:`~transformers.BatchFeature` to this method"
f"that includes {self.model_input_names[0]}, but you provided {list(processed_features.keys())}"
)
required_input = processed_features[self.model_input_names[0]]
return_attention_mask = (
return_attention_mask if return_attention_mask is not None else self.return_attention_mask
)
if not required_input:
if return_attention_mask:
processed_features["attention_mask"] = []
return processed_features
# If we have PyTorch/TF/NumPy tensors/arrays as inputs, we cast them as python objects
# and rebuild them afterwards if no return_tensors is specified
# Note that we lose the specific device the tensor may be on for PyTorch
first_element = required_input[0]
if isinstance(first_element, (list, tuple)):
# first_element might be an empty list/tuple in some edge cases so we grab the first non empty element.
index = 0
while len(required_input[index]) == 0:
index += 1
if index < len(required_input):
first_element = required_input[index][0]
# At this state, if `first_element` is still a list/tuple, it's an empty one so there is nothing to do.
if not isinstance(first_element, (float, int, list, tuple)):
if is_tf_available() and _is_tensorflow(first_element):
return_tensors = "tf" if return_tensors is None else return_tensors
elif is_torch_available() and _is_torch(first_element):
return_tensors = "pt" if return_tensors is None else return_tensors
elif isinstance(first_element, np.ndarray):
return_tensors = "np" if return_tensors is None else return_tensors
else:
raise ValueError(
f"type of {first_element} unknown: {type(first_element)}. "
f"Should be one of a python, numpy, pytorch or tensorflow object."
)
for key, value in processed_features.items():
processed_features[key] = to_py_obj(value)
# Convert padding_strategy in PaddingStrategy
padding_strategy, max_length, _ = self._get_padding_strategies(padding=padding, max_length=max_length)
required_input = processed_features[self.model_input_names[0]]
if required_input and not isinstance(required_input[0], (list, tuple)):
processed_features = self._pad(
processed_features,
max_length=max_length,
padding_strategy=padding_strategy,
pad_to_multiple_of=pad_to_multiple_of,
return_attention_mask=return_attention_mask,
)
return BatchFeature(processed_features, tensor_type=return_tensors)
batch_size = len(required_input)
assert all(
len(v) == batch_size for v in processed_features.values()
), "Some items in the output dictionary have a different batch size than others."
if padding_strategy == PaddingStrategy.LONGEST:
max_length = max(len(inputs) for inputs in required_input)
padding_strategy = PaddingStrategy.MAX_LENGTH
batch_outputs = {}
for i in range(batch_size):
inputs = dict((k, v[i]) for k, v in processed_features.items())
outputs = self._pad(
inputs,
max_length=max_length,
padding_strategy=padding_strategy,
pad_to_multiple_of=pad_to_multiple_of,
return_attention_mask=return_attention_mask,
)
for key, value in outputs.items():
if key not in batch_outputs:
batch_outputs[key] = []
batch_outputs[key].append(value)
return BatchFeature(batch_outputs, tensor_type=return_tensors)
def _pad(
self,
processed_features: Union[Dict[str, List[float]], BatchFeature],
max_length: Optional[int] = None,
padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
pad_to_multiple_of: Optional[int] = None,
return_attention_mask: Optional[bool] = None,
) -> dict:
"""
Pad inputs (on left/right and up to predefined length or max length in the batch)
Args:
processed_features: Dictionary of input values (`List[float]`) / input vectors (`List[List[float]]`) or batch of inputs values (`List[List[int]]`) / input vectors (`List[List[List[int]]]`)
max_length: maximum length of the returned list and optionally padding length (see below)
padding_strategy: PaddingStrategy to use for padding.
- PaddingStrategy.LONGEST Pad to the longest sequence in the batch
- PaddingStrategy.MAX_LENGTH: Pad to the max length (default)
- PaddingStrategy.DO_NOT_PAD: Do not pad
The feature_extractor padding sides are defined in self.padding_side:
- 'left': pads on the left of the sequences
- 'right': pads on the right of the sequences
pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value.
This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability
>= 7.5 (Volta), or on TPUs which benefit from having sequence lengths be a multiple of 128.
return_attention_mask: (optional) Set to False to avoid returning attention mask (default: set to model specifics)
"""
required_input = processed_features[self.model_input_names[0]]
if padding_strategy == PaddingStrategy.LONGEST:
max_length = len(required_input)
if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len(required_input) != max_length
if needs_to_be_padded:
difference = max_length - len(required_input)
padding_vector = self.feature_size * [self.padding_value] if self.feature_size > 1 else self.padding_value
if self.padding_side == "right":
if return_attention_mask:
processed_features["attention_mask"] = [1] * len(required_input) + [0] * difference
processed_features[self.model_input_names[0]] = required_input + [
padding_vector for _ in range(difference)
]
elif self.padding_side == "left":
if return_attention_mask:
processed_features["attention_mask"] = [0] * difference + [1] * len(required_input)
processed_features[self.model_input_names[0]] = [
padding_vector for _ in range(difference)
] + required_input
else:
raise ValueError("Invalid padding strategy:" + str(self.padding_side))
elif return_attention_mask and "attention_mask" not in processed_features:
processed_features["attention_mask"] = [1] * len(required_input)
return processed_features
def _get_padding_strategies(self, padding=False, max_length=None, pad_to_multiple_of=None, **kwargs):
"""
Find the correct padding strategy
"""
# Get padding strategy
if padding is not False:
if padding is True:
padding_strategy = PaddingStrategy.LONGEST # Default to pad to the longest sequence in the batch
elif not isinstance(padding, PaddingStrategy):
padding_strategy = PaddingStrategy(padding)
elif isinstance(padding, PaddingStrategy):
padding_strategy = padding
else:
padding_strategy = PaddingStrategy.DO_NOT_PAD
# Set max length if needed
if max_length is None:
if padding_strategy == PaddingStrategy.MAX_LENGTH:
raise ValueError(
f"When setting ``padding={PaddingStrategy.MAX_LENGTH}``, make sure that" f" max_length is defined"
)
# Test if we have a padding value
if padding_strategy != PaddingStrategy.DO_NOT_PAD and (self.padding_value is None):
raise ValueError(
"Asking to pad but the feature_extractor does not have a padding value. "
"Please select a value to use as `padding_value`. For example: `feature_extractor.padding_value = 0.0`."
)
return padding_strategy, max_length, kwargs
...@@ -13,24 +13,22 @@ ...@@ -13,24 +13,22 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
""" """
Feature extraction common class for python feature extractors. Feature extraction saving/loading class for common feature extractors.
""" """
import copy import copy
import json import json
import os import os
from collections import UserDict from collections import UserDict
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
import numpy as np import numpy as np
from .file_utils import ( from .file_utils import (
FEATURE_EXTRACTOR_NAME, FEATURE_EXTRACTOR_NAME,
PaddingStrategy,
TensorType, TensorType,
_is_jax, _is_jax,
_is_numpy, _is_numpy,
_is_tensorflow,
_is_torch,
_is_torch_device, _is_torch_device,
cached_path, cached_path,
hf_bucket_url, hf_bucket_url,
...@@ -39,23 +37,24 @@ from .file_utils import ( ...@@ -39,23 +37,24 @@ from .file_utils import (
is_remote_url, is_remote_url,
is_tf_available, is_tf_available,
is_torch_available, is_torch_available,
to_py_obj,
torch_required, torch_required,
) )
from .utils import logging from .utils import logging
logger = logging.get_logger(__name__)
if TYPE_CHECKING: if TYPE_CHECKING:
if is_torch_available(): if is_torch_available():
import torch import torch
logger = logging.get_logger(__name__)
PreTrainedFeatureExtractor = Union["SequenceFeatureExtractor"] # noqa: F821
class BatchFeature(UserDict): class BatchFeature(UserDict):
r""" r"""
Holds the output of the :meth:`~transformers.PreTrainedFeatureExtractor.pad` and feature extractor specific Holds the output of the :meth:`~transformers.SequenceFeatureExtractor.pad` and feature extractor specific
``__call__`` methods. ``__call__`` methods.
This class is derived from a python dictionary and can be used as a dictionary. This class is derived from a python dictionary and can be used as a dictionary.
...@@ -179,8 +178,7 @@ class BatchFeature(UserDict): ...@@ -179,8 +178,7 @@ class BatchFeature(UserDict):
device (:obj:`str` or :obj:`torch.device`): The device to put the tensors on. device (:obj:`str` or :obj:`torch.device`): The device to put the tensors on.
Returns: Returns:
:class:`~transformers.BatchFeature`: The same instance of :class:`~transformers.BatchFeature` after :class:`~transformers.BatchFeature`: The same instance after modification.
modification.
""" """
# This check catches things like APEX blindly calling "to" on all inputs to a module # This check catches things like APEX blindly calling "to" on all inputs to a module
...@@ -193,42 +191,19 @@ class BatchFeature(UserDict): ...@@ -193,42 +191,19 @@ class BatchFeature(UserDict):
return self return self
class PreTrainedFeatureExtractor: class FeatureExtractionMixin:
""" """
This is a general feature extraction class for speech recognition. This is a feature extraction mixin used to provide saving/loading functionality for sequential and image feature
extractors.
Args:
feature_size (:obj:`int`):
The feature dimension of the extracted features.
sampling_rate (:obj:`int`):
The sampling rate at which the audio files should be digitalized expressed in Hertz per second (Hz).
padding_value (:obj:`float`):
The value that is used to fill the padding values / vectors.
""" """
def __init__(self, feature_size: int, sampling_rate: int, padding_value: float, **kwargs):
self.feature_size = feature_size
self.sampling_rate = sampling_rate
self.padding_value = padding_value
self.padding_side = kwargs.pop("padding_side", "right")
self.return_attention_mask = kwargs.pop("return_attention_mask", True)
# Additional attributes without default values
for key, value in kwargs.items():
try:
setattr(self, key, value)
except AttributeError as err:
logger.error(f"Can't set {key} with value {value} for {self}")
raise err
@classmethod @classmethod
def from_pretrained( def from_pretrained(
cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
) -> "PreTrainedFeatureExtractor": ) -> PreTrainedFeatureExtractor:
r""" r"""
Instantiate a :class:`~transformers.PreTrainedFeatureExtractor` (or a derived class) from a pretrained feature Instantiate a type of :class:`~transformers.feature_extraction_utils.FeatureExtractionMixin` from a feature
extractor. extractor, *e.g.* a derived class of :class:`~transformers.SequenceFeatureExtractor`.
Args: Args:
pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`): pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`):
...@@ -238,7 +213,7 @@ class PreTrainedFeatureExtractor: ...@@ -238,7 +213,7 @@ class PreTrainedFeatureExtractor:
huggingface.co. Valid model ids can be located at the root-level, like ``bert-base-uncased``, or huggingface.co. Valid model ids can be located at the root-level, like ``bert-base-uncased``, or
namespaced under a user or organization name, like ``dbmdz/bert-base-german-cased``. namespaced under a user or organization name, like ``dbmdz/bert-base-german-cased``.
- a path to a `directory` containing a feature extractor file saved using the - a path to a `directory` containing a feature extractor file saved using the
:func:`~transformers.PreTrainedFeatureExtractor.save_pretrained` method, e.g., :func:`~transformers.feature_extraction_utils.FeatureExtractionMixin.save_pretrained` method, e.g.,
``./my_model_directory/``. ``./my_model_directory/``.
- a path or url to a saved feature extractor JSON `file`, e.g., - a path or url to a saved feature extractor JSON `file`, e.g.,
``./my_model_directory/feature_extraction_config.json``. ``./my_model_directory/feature_extraction_config.json``.
...@@ -262,12 +237,10 @@ class PreTrainedFeatureExtractor: ...@@ -262,12 +237,10 @@ class PreTrainedFeatureExtractor:
git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any
identifier allowed by git. identifier allowed by git.
return_unused_kwargs (:obj:`bool`, `optional`, defaults to :obj:`False`): return_unused_kwargs (:obj:`bool`, `optional`, defaults to :obj:`False`):
If :obj:`False`, then this function returns just the final feature extractor object. If :obj:`False`, then this function returns just the final feature extractor object. If :obj:`True`,
then this functions returns a :obj:`Tuple(feature_extractor, unused_kwargs)` where `unused_kwargs` is a
If :obj:`True`, then this functions returns a :obj:`Tuple(feature_extractor, unused_kwargs)` where dictionary consisting of the key/value pairs whose keys are not feature extractor attributes: i.e., the
`unused_kwargs` is a dictionary consisting of the key/value pairs whose keys are not feature extractor part of ``kwargs`` which has not been used to update ``feature_extractor`` and is otherwise ignored.
attributes: i.e., the part of ``kwargs`` which has not been used to update ``feature_extractor`` and is
otherwise ignored.
kwargs (:obj:`Dict[str, Any]`, `optional`): kwargs (:obj:`Dict[str, Any]`, `optional`):
The values in kwargs of any keys which are feature extractor attributes will be used to override the The values in kwargs of any keys which are feature extractor attributes will be used to override the
loaded values. Behavior concerning key/value pairs whose keys are *not* feature extractor attributes is loaded values. Behavior concerning key/value pairs whose keys are *not* feature extractor attributes is
...@@ -279,13 +252,12 @@ class PreTrainedFeatureExtractor: ...@@ -279,13 +252,12 @@ class PreTrainedFeatureExtractor:
Returns: Returns:
:class:`~transformers.PreTrainedFeatureExtractor`: The feature extractor object instantiated from this A feature extractor of type :class:`~transformers.feature_extraction_utils.FeatureExtractionMixin`.
pretrained model.
Examples:: Examples::
# We can't instantiate directly the base class `PreTrainedFeatureExtractor` so let's show the examples on a # We can't instantiate directly the base class `FeatureExtractionMixin` nor `SequenceFeatureExtractor` so let's show the examples on a
# derived class: Wav2Vec2FeatureExtractor # derived class: `Wav2Vec2FeatureExtractor`
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained('facebook/wav2vec2-base-960h') # Download feature_extraction_config from huggingface.co and cache. feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained('facebook/wav2vec2-base-960h') # Download feature_extraction_config from huggingface.co and cache.
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained('./test/saved_model/') # E.g. feature_extractor (or model) was saved using `save_pretrained('./test/saved_model/')` feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained('./test/saved_model/') # E.g. feature_extractor (or model) was saved using `save_pretrained('./test/saved_model/')`
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained('./test/saved_model/preprocessor_config.json') feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained('./test/saved_model/preprocessor_config.json')
...@@ -295,7 +267,6 @@ class PreTrainedFeatureExtractor: ...@@ -295,7 +267,6 @@ class PreTrainedFeatureExtractor:
foo=False, return_unused_kwargs=True) foo=False, return_unused_kwargs=True)
assert feature_extractor.return_attention_mask is False assert feature_extractor.return_attention_mask is False
assert unused_kwargs == {'foo': False} assert unused_kwargs == {'foo': False}
""" """
feature_extractor_dict, kwargs = cls.get_feature_extractor_dict(pretrained_model_name_or_path, **kwargs) feature_extractor_dict, kwargs = cls.get_feature_extractor_dict(pretrained_model_name_or_path, **kwargs)
...@@ -304,7 +275,7 @@ class PreTrainedFeatureExtractor: ...@@ -304,7 +275,7 @@ class PreTrainedFeatureExtractor:
def save_pretrained(self, save_directory: Union[str, os.PathLike]): def save_pretrained(self, save_directory: Union[str, os.PathLike]):
""" """
Save a feature_extractor object to the directory ``save_directory``, so that it can be re-loaded using the Save a feature_extractor object to the directory ``save_directory``, so that it can be re-loaded using the
:func:`~transformers.PreTrainedFeatureExtractor.from_pretrained` class method. :func:`~transformers.feature_extraction_utils.FeatureExtractionMixin.from_pretrained` class method.
Args: Args:
save_directory (:obj:`str` or :obj:`os.PathLike`): save_directory (:obj:`str` or :obj:`os.PathLike`):
...@@ -325,7 +296,8 @@ class PreTrainedFeatureExtractor: ...@@ -325,7 +296,8 @@ class PreTrainedFeatureExtractor:
) -> Tuple[Dict[str, Any], Dict[str, Any]]: ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
""" """
From a ``pretrained_model_name_or_path``, resolve to a dictionary of parameters, to be used for instantiating a From a ``pretrained_model_name_or_path``, resolve to a dictionary of parameters, to be used for instantiating a
:class:`~transformers.PreTrainedFeatureExtractor` using ``from_dict``. feature extractor of type :class:`~transformers.feature_extraction_utils.FeatureExtractionMixin` using
``from_dict``.
Parameters: Parameters:
pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`): pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`):
...@@ -400,21 +372,22 @@ class PreTrainedFeatureExtractor: ...@@ -400,21 +372,22 @@ class PreTrainedFeatureExtractor:
return feature_extractor_dict, kwargs return feature_extractor_dict, kwargs
@classmethod @classmethod
def from_dict(cls, feature_extractor_dict: Dict[str, Any], **kwargs) -> "PreTrainedFeatureExtractor": def from_dict(cls, feature_extractor_dict: Dict[str, Any], **kwargs) -> PreTrainedFeatureExtractor:
""" """
Instantiates a :class:`~transformers.PreTrainedFeatureExtractor` from a Python dictionary of parameters. Instantiates a type of :class:`~transformers.feature_extraction_utils.FeatureExtractionMixin` from a Python
dictionary of parameters.
Args: Args:
feature_extractor_dict (:obj:`Dict[str, Any]`): feature_extractor_dict (:obj:`Dict[str, Any]`):
Dictionary that will be used to instantiate the feature extractor object. Such a dictionary can be Dictionary that will be used to instantiate the feature extractor object. Such a dictionary can be
retrieved from a pretrained checkpoint by leveraging the retrieved from a pretrained checkpoint by leveraging the
:func:`~transformers.PreTrainedFeatureExtractor.to_dict` method. :func:`~transformers.feature_extraction_utils.FeatureExtractionMixin.to_dict` method.
kwargs (:obj:`Dict[str, Any]`): kwargs (:obj:`Dict[str, Any]`):
Additional parameters from which to initialize the feature extractor object. Additional parameters from which to initialize the feature extractor object.
Returns: Returns:
:class:`~transformers.PreTrainedFeatureExtractor`: The feature extractor object instantiated from those :class:`~transformers.feature_extraction_utils.FeatureExtractionMixin`: The feature extractor object
parameters. instantiated from those parameters.
""" """
return_unused_kwargs = kwargs.pop("return_unused_kwargs", False) return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)
...@@ -447,18 +420,18 @@ class PreTrainedFeatureExtractor: ...@@ -447,18 +420,18 @@ class PreTrainedFeatureExtractor:
return output return output
@classmethod @classmethod
def from_json_file(cls, json_file: Union[str, os.PathLike]) -> "PreTrainedFeatureExtractor": def from_json_file(cls, json_file: Union[str, os.PathLike]) -> PreTrainedFeatureExtractor:
""" """
Instantiates a :class:`~transformers.PreTrainedFeatureExtractor` from the path to a JSON file of parameters. Instantiates a feature extractor of type :class:`~transformers.feature_extraction_utils.FeatureExtractionMixin`
from the path to a JSON file of parameters.
Args: Args:
json_file (:obj:`str` or :obj:`os.PathLike`): json_file (:obj:`str` or :obj:`os.PathLike`):
Path to the JSON file containing the parameters. Path to the JSON file containing the parameters.
Returns: Returns:
:class:`~transformers.PreTrainedFeatureExtractor`: The feature_extractor object instantiated from that JSON A feature extractor of type :class:`~transformers.feature_extraction_utils.FeatureExtractionMixin`: The
file. feature_extractor object instantiated from that JSON file.
""" """
with open(json_file, "r", encoding="utf-8") as reader: with open(json_file, "r", encoding="utf-8") as reader:
text = reader.read() text = reader.read()
...@@ -488,255 +461,3 @@ class PreTrainedFeatureExtractor: ...@@ -488,255 +461,3 @@ class PreTrainedFeatureExtractor:
def __repr__(self): def __repr__(self):
return f"{self.__class__.__name__} {self.to_json_string()}" return f"{self.__class__.__name__} {self.to_json_string()}"
def pad(
self,
processed_features: Union[
BatchFeature,
List[BatchFeature],
Dict[str, BatchFeature],
Dict[str, List[BatchFeature]],
List[Dict[str, BatchFeature]],
],
padding: Union[bool, str, PaddingStrategy] = True,
max_length: Optional[int] = None,
pad_to_multiple_of: Optional[int] = None,
return_attention_mask: Optional[bool] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
) -> BatchFeature:
"""
Pad input values / input vectors or a batch of input values / input vectors up to predefined length or to the
max sequence length in the batch.
Padding side (left/right) padding values are defined at the feature extractor level (with
``self.padding_side``, ``self.padding_value``)
.. note::
If the ``processed_features`` passed are dictionary of numpy arrays, PyTorch tensors or TensorFlow tensors,
the result will use the same type unless you provide a different tensor type with ``return_tensors``. In
the case of PyTorch tensors, you will lose the specific device of your tensors however.
Args:
processed_features (:class:`~transformers.BatchFeature`, list of :class:`~transformers.BatchFeature`, :obj:`Dict[str, List[float]]`, :obj:`Dict[str, List[List[float]]` or :obj:`List[Dict[str, List[float]]]`):
Processed inputs. Can represent one input (:class:`~transformers.BatchFeature` or :obj:`Dict[str,
List[float]]`) or a batch of input values / vectors (list of :class:`~transformers.BatchFeature`,
`Dict[str, List[List[float]]]` or `List[Dict[str, List[float]]]`) so you can use this method during
preprocessing as well as in a PyTorch Dataloader collate function.
Instead of :obj:`List[float]` you can have tensors (numpy arrays, PyTorch tensors or TensorFlow
tensors), see the note above for the return type.
padding (:obj:`bool`, :obj:`str` or :class:`~transformers.file_utils.PaddingStrategy`, `optional`, defaults to :obj:`True`):
Select a strategy to pad the returned sequences (according to the model's padding side and padding
index) among:
* :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a
single sequence if provided).
* :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
maximum acceptable input length for the model if that argument is not provided.
* :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
different lengths).
max_length (:obj:`int`, `optional`):
Maximum length of the returned list and optionally padding length (see above).
pad_to_multiple_of (:obj:`int`, `optional`):
If set will pad the sequence to a multiple of the provided value.
This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability
>= 7.5 (Volta), or on TPUs which benefit from having sequence lengths be a multiple of 128.
return_attention_mask (:obj:`bool`, `optional`):
Whether to return the attention mask. If left to the default, will return the attention mask according
to the specific feature_extractor's default.
`What are attention masks? <../glossary.html#attention-mask>`__
return_tensors (:obj:`str` or :class:`~transformers.file_utils.TensorType`, `optional`):
If set, will return tensors instead of list of python integers. Acceptable values are:
* :obj:`'tf'`: Return TensorFlow :obj:`tf.constant` objects.
* :obj:`'pt'`: Return PyTorch :obj:`torch.Tensor` objects.
* :obj:`'np'`: Return Numpy :obj:`np.ndarray` objects.
"""
# If we have a list of dicts, let's convert it in a dict of lists
# We do this to allow using this method as a collate_fn function in PyTorch Dataloader
if isinstance(processed_features, (list, tuple)) and isinstance(processed_features[0], (dict, BatchFeature)):
processed_features = {
key: [example[key] for example in processed_features] for key in processed_features[0].keys()
}
# The model's main input name, usually `input_values`, has be passed for padding
if self.model_input_names[0] not in processed_features:
raise ValueError(
"You should supply an instance of :class:`~transformers.BatchFeature` or list of :class:`~transformers.BatchFeature` to this method"
f"that includes {self.model_input_names[0]}, but you provided {list(processed_features.keys())}"
)
required_input = processed_features[self.model_input_names[0]]
return_attention_mask = (
return_attention_mask if return_attention_mask is not None else self.return_attention_mask
)
if not required_input:
if return_attention_mask:
processed_features["attention_mask"] = []
return processed_features
# If we have PyTorch/TF/NumPy tensors/arrays as inputs, we cast them as python objects
# and rebuild them afterwards if no return_tensors is specified
# Note that we lose the specific device the tensor may be on for PyTorch
first_element = required_input[0]
if isinstance(first_element, (list, tuple)):
# first_element might be an empty list/tuple in some edge cases so we grab the first non empty element.
index = 0
while len(required_input[index]) == 0:
index += 1
if index < len(required_input):
first_element = required_input[index][0]
# At this state, if `first_element` is still a list/tuple, it's an empty one so there is nothing to do.
if not isinstance(first_element, (float, int, list, tuple)):
if is_tf_available() and _is_tensorflow(first_element):
return_tensors = "tf" if return_tensors is None else return_tensors
elif is_torch_available() and _is_torch(first_element):
return_tensors = "pt" if return_tensors is None else return_tensors
elif isinstance(first_element, np.ndarray):
return_tensors = "np" if return_tensors is None else return_tensors
else:
raise ValueError(
f"type of {first_element} unknown: {type(first_element)}. "
f"Should be one of a python, numpy, pytorch or tensorflow object."
)
for key, value in processed_features.items():
processed_features[key] = to_py_obj(value)
# Convert padding_strategy in PaddingStrategy
padding_strategy, max_length, _ = self._get_padding_strategies(padding=padding, max_length=max_length)
required_input = processed_features[self.model_input_names[0]]
if required_input and not isinstance(required_input[0], (list, tuple)):
processed_features = self._pad(
processed_features,
max_length=max_length,
padding_strategy=padding_strategy,
pad_to_multiple_of=pad_to_multiple_of,
return_attention_mask=return_attention_mask,
)
return BatchFeature(processed_features, tensor_type=return_tensors)
batch_size = len(required_input)
assert all(
len(v) == batch_size for v in processed_features.values()
), "Some items in the output dictionary have a different batch size than others."
if padding_strategy == PaddingStrategy.LONGEST:
max_length = max(len(inputs) for inputs in required_input)
padding_strategy = PaddingStrategy.MAX_LENGTH
batch_outputs = {}
for i in range(batch_size):
inputs = dict((k, v[i]) for k, v in processed_features.items())
outputs = self._pad(
inputs,
max_length=max_length,
padding_strategy=padding_strategy,
pad_to_multiple_of=pad_to_multiple_of,
return_attention_mask=return_attention_mask,
)
for key, value in outputs.items():
if key not in batch_outputs:
batch_outputs[key] = []
batch_outputs[key].append(value)
return BatchFeature(batch_outputs, tensor_type=return_tensors)
def _pad(
self,
processed_features: Union[Dict[str, List[float]], BatchFeature],
max_length: Optional[int] = None,
padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
pad_to_multiple_of: Optional[int] = None,
return_attention_mask: Optional[bool] = None,
) -> dict:
"""
Pad inputs (on left/right and up to predefined length or max length in the batch)
Args:
processed_features: Dictionary of input values (`List[float]`) / input vectors (`List[List[float]]`) or batch of inputs values (`List[List[int]]`) / input vectors (`List[List[List[int]]]`)
max_length: maximum length of the returned list and optionally padding length (see below)
padding_strategy: PaddingStrategy to use for padding.
- PaddingStrategy.LONGEST Pad to the longest sequence in the batch
- PaddingStrategy.MAX_LENGTH: Pad to the max length (default)
- PaddingStrategy.DO_NOT_PAD: Do not pad
The feature_extractor padding sides are defined in self.padding_side:
- 'left': pads on the left of the sequences
- 'right': pads on the right of the sequences
pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value.
This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability
>= 7.5 (Volta), or on TPUs which benefit from having sequence lengths be a multiple of 128.
return_attention_mask: (optional) Set to False to avoid returning attention mask (default: set to model specifics)
"""
required_input = processed_features[self.model_input_names[0]]
if padding_strategy == PaddingStrategy.LONGEST:
max_length = len(required_input)
if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len(required_input) != max_length
if needs_to_be_padded:
difference = max_length - len(required_input)
padding_vector = self.feature_size * [self.padding_value] if self.feature_size > 1 else self.padding_value
if self.padding_side == "right":
if return_attention_mask:
processed_features["attention_mask"] = [1] * len(required_input) + [0] * difference
processed_features[self.model_input_names[0]] = required_input + [
padding_vector for _ in range(difference)
]
elif self.padding_side == "left":
if return_attention_mask:
processed_features["attention_mask"] = [0] * difference + [1] * len(required_input)
processed_features[self.model_input_names[0]] = [
padding_vector for _ in range(difference)
] + required_input
else:
raise ValueError("Invalid padding strategy:" + str(self.padding_side))
elif return_attention_mask and "attention_mask" not in processed_features:
processed_features["attention_mask"] = [1] * len(required_input)
return processed_features
def _get_padding_strategies(self, padding=False, max_length=None, pad_to_multiple_of=None, **kwargs):
"""
Find the correct padding strategy
"""
# Get padding strategy
if padding is not False:
if padding is True:
padding_strategy = PaddingStrategy.LONGEST # Default to pad to the longest sequence in the batch
elif not isinstance(padding, PaddingStrategy):
padding_strategy = PaddingStrategy(padding)
elif isinstance(padding, PaddingStrategy):
padding_strategy = padding
else:
padding_strategy = PaddingStrategy.DO_NOT_PAD
# Set max length if needed
if max_length is None:
if padding_strategy == PaddingStrategy.MAX_LENGTH:
raise ValueError(
f"When setting ``padding={PaddingStrategy.MAX_LENGTH}``, make sure that" f" max_length is defined"
)
# Test if we have a padding value
if padding_strategy != PaddingStrategy.DO_NOT_PAD and (self.padding_value is None):
raise ValueError(
"Asking to pad but the feature_extractor does not have a padding value. "
"Please select a value to use as `padding_value`. For example: `feature_extractor.padding_value = 0.0`."
)
return padding_strategy, max_length, kwargs
...@@ -20,7 +20,8 @@ from typing import List, Optional, Union ...@@ -20,7 +20,8 @@ from typing import List, Optional, Union
import numpy as np import numpy as np
from ...feature_extraction_utils import BatchFeature, PreTrainedFeatureExtractor from ...feature_extraction_sequence_utils import SequenceFeatureExtractor
from ...feature_extraction_utils import BatchFeature
from ...file_utils import PaddingStrategy, TensorType from ...file_utils import PaddingStrategy, TensorType
from ...utils import logging from ...utils import logging
...@@ -28,7 +29,7 @@ from ...utils import logging ...@@ -28,7 +29,7 @@ from ...utils import logging
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
class Wav2Vec2FeatureExtractor(PreTrainedFeatureExtractor): class Wav2Vec2FeatureExtractor(SequenceFeatureExtractor):
r""" r"""
Constructs a Wav2Vec2 feature extractor. Constructs a Wav2Vec2 feature extractor.
......
...@@ -59,7 +59,8 @@ class Wav2Vec2Processor: ...@@ -59,7 +59,8 @@ class Wav2Vec2Processor:
.. note:: .. note::
This class method is simply calling :meth:`~transformers.PreTrainedFeatureExtractor.save_pretrained` and This class method is simply calling
:meth:`~transformers.feature_extraction_utils.FeatureExtractionMixin.save_pretrained` and
:meth:`~transformers.tokenization_utils_base.PreTrainedTokenizer.save_pretrained`. Please refer to the :meth:`~transformers.tokenization_utils_base.PreTrainedTokenizer.save_pretrained`. Please refer to the
docstrings of the methods above for more information. docstrings of the methods above for more information.
...@@ -80,9 +81,9 @@ class Wav2Vec2Processor: ...@@ -80,9 +81,9 @@ class Wav2Vec2Processor:
.. note:: .. note::
This class method is simply calling Wav2Vec2FeatureExtractor's This class method is simply calling Wav2Vec2FeatureExtractor's
:meth:`~transformers.PreTrainedFeatureExtractor.from_pretrained` and Wav2Vec2CTCTokenizer's :meth:`~transformers.feature_extraction_utils.FeatureExtractionMixin.from_pretrained` and
:meth:`~transformers.tokenization_utils_base.PreTrainedTokenizer.from_pretrained`. Please refer to the Wav2Vec2CTCTokenizer's :meth:`~transformers.tokenization_utils_base.PreTrainedTokenizer.from_pretrained`.
docstrings of the methods above for more information. Please refer to the docstrings of the methods above for more information.
Args: Args:
pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`): pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`):
...@@ -92,12 +93,12 @@ class Wav2Vec2Processor: ...@@ -92,12 +93,12 @@ class Wav2Vec2Processor:
huggingface.co. Valid model ids can be located at the root-level, like ``bert-base-uncased``, or huggingface.co. Valid model ids can be located at the root-level, like ``bert-base-uncased``, or
namespaced under a user or organization name, like ``dbmdz/bert-base-german-cased``. namespaced under a user or organization name, like ``dbmdz/bert-base-german-cased``.
- a path to a `directory` containing a feature extractor file saved using the - a path to a `directory` containing a feature extractor file saved using the
:meth:`~transformers.PreTrainedFeatureExtractor.save_pretrained` method, e.g., :meth:`~transformers.SequenceFeatureExtractor.save_pretrained` method, e.g.,
``./my_model_directory/``. ``./my_model_directory/``.
- a path or url to a saved feature extractor JSON `file`, e.g., - a path or url to a saved feature extractor JSON `file`, e.g.,
``./my_model_directory/feature_extraction_config.json``. ``./my_model_directory/feature_extraction_config.json``.
**kwargs **kwargs
Additional keyword arguments passed along to both :class:`~transformers.PreTrainedFeatureExtractor` and Additional keyword arguments passed along to both :class:`~transformers.SequenceFeatureExtractor` and
:class:`~transformers.PreTrainedTokenizer` :class:`~transformers.PreTrainedTokenizer`
""" """
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(pretrained_model_name_or_path, **kwargs) feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(pretrained_model_name_or_path, **kwargs)
......
...@@ -727,8 +727,7 @@ class BatchEncoding(UserDict): ...@@ -727,8 +727,7 @@ class BatchEncoding(UserDict):
device (:obj:`str` or :obj:`torch.device`): The device to put the tensors on. device (:obj:`str` or :obj:`torch.device`): The device to put the tensors on.
Returns: Returns:
:class:`~transformers.BatchEncoding`: The same instance of :class:`~transformers.BatchEncoding` after :class:`~transformers.BatchEncoding`: The same instance after modification.
modification.
""" """
# This check catches things like APEX blindly calling "to" on all inputs to a module # This check catches things like APEX blindly calling "to" on all inputs to a module
......
...@@ -18,28 +18,8 @@ import json ...@@ -18,28 +18,8 @@ import json
import os import os
import tempfile import tempfile
import numpy as np
from transformers import BatchFeature
from transformers.testing_utils import require_tf, require_torch
class FeatureExtractionMixin:
# to overwrite at feature extractactor specific tests
feat_extract_tester = None
feature_extraction_class = None
@property
def feat_extract_dict(self):
return self.feat_extract_tester.prepare_feat_extract_dict()
def test_feat_extract_common_properties(self):
feat_extract = self.feature_extraction_class(**self.feat_extract_dict)
self.assertTrue(hasattr(feat_extract, "feature_size"))
self.assertTrue(hasattr(feat_extract, "sampling_rate"))
self.assertTrue(hasattr(feat_extract, "padding_value"))
class FeatureExtractionSavingTestMixin:
def test_feat_extract_to_json_string(self): def test_feat_extract_to_json_string(self):
feat_extract = self.feature_extraction_class(**self.feat_extract_dict) feat_extract = self.feature_extraction_class(**self.feat_extract_dict)
obj = json.loads(feat_extract.to_json_string()) obj = json.loads(feat_extract.to_json_string())
...@@ -68,217 +48,3 @@ class FeatureExtractionMixin: ...@@ -68,217 +48,3 @@ class FeatureExtractionMixin:
def test_init_without_params(self): def test_init_without_params(self):
feat_extract = self.feature_extraction_class() feat_extract = self.feature_extraction_class()
self.assertIsNotNone(feat_extract) self.assertIsNotNone(feat_extract)
def test_batch_feature(self):
speech_inputs = self.feat_extract_tester.prepare_inputs_for_common()
feat_extract = self.feature_extraction_class(**self.feat_extract_dict)
input_name = feat_extract.model_input_names[0]
processed_features = BatchFeature({input_name: speech_inputs})
self.assertTrue(all(len(x) == len(y) for x, y in zip(speech_inputs, processed_features[input_name])))
speech_inputs = self.feat_extract_tester.prepare_inputs_for_common(equal_length=True)
processed_features = BatchFeature({input_name: speech_inputs}, tensor_type="np")
batch_features_input = processed_features[input_name]
if len(batch_features_input.shape) < 3:
batch_features_input = batch_features_input[:, :, None]
self.assertTrue(
batch_features_input.shape
== (self.feat_extract_tester.batch_size, len(speech_inputs[0]), self.feat_extract_tester.feature_size)
)
@require_torch
def test_batch_feature_pt(self):
speech_inputs = self.feat_extract_tester.prepare_inputs_for_common(equal_length=True)
feat_extract = self.feature_extraction_class(**self.feat_extract_dict)
input_name = feat_extract.model_input_names[0]
processed_features = BatchFeature({input_name: speech_inputs}, tensor_type="pt")
batch_features_input = processed_features[input_name]
if len(batch_features_input.shape) < 3:
batch_features_input = batch_features_input[:, :, None]
self.assertTrue(
batch_features_input.shape
== (self.feat_extract_tester.batch_size, len(speech_inputs[0]), self.feat_extract_tester.feature_size)
)
@require_tf
def test_batch_feature_tf(self):
speech_inputs = self.feat_extract_tester.prepare_inputs_for_common(equal_length=True)
feat_extract = self.feature_extraction_class(**self.feat_extract_dict)
input_name = feat_extract.model_input_names[0]
processed_features = BatchFeature({input_name: speech_inputs}, tensor_type="tf")
batch_features_input = processed_features[input_name]
if len(batch_features_input.shape) < 3:
batch_features_input = batch_features_input[:, :, None]
self.assertTrue(
batch_features_input.shape
== (self.feat_extract_tester.batch_size, len(speech_inputs[0]), self.feat_extract_tester.feature_size)
)
def _check_padding(self, numpify=False):
def _inputs_have_equal_length(input):
length = len(input[0])
for input_slice in input[1:]:
if len(input_slice) != length:
return False
return True
def _inputs_are_equal(input_1, input_2):
if len(input_1) != len(input_2):
return False
for input_slice_1, input_slice_2 in zip(input_1, input_2):
if not np.allclose(np.asarray(input_slice_1), np.asarray(input_slice_2), atol=1e-3):
return False
return True
feat_extract = self.feature_extraction_class(**self.feat_extract_dict)
speech_inputs = self.feat_extract_tester.prepare_inputs_for_common(numpify=numpify)
input_name = feat_extract.model_input_names[0]
processed_features = BatchFeature({input_name: speech_inputs})
pad_diff = self.feat_extract_tester.seq_length_diff
pad_max_length = self.feat_extract_tester.max_seq_length + pad_diff
pad_min_length = self.feat_extract_tester.min_seq_length
batch_size = self.feat_extract_tester.batch_size
feature_size = self.feat_extract_tester.feature_size
# test padding for List[int] + numpy
input_1 = feat_extract.pad(processed_features, padding=False)[input_name]
input_2 = feat_extract.pad(processed_features, padding="longest")[input_name]
input_3 = feat_extract.pad(processed_features, padding="max_length", max_length=len(speech_inputs[-1]))[
input_name
]
input_4 = feat_extract.pad(processed_features, padding="longest", return_tensors="np")[input_name]
# max_length parameter has to be provided when setting `padding="max_length"`
with self.assertRaises(ValueError):
feat_extract.pad(processed_features, padding="max_length")[input_name]
input_5 = feat_extract.pad(
processed_features, padding="max_length", max_length=pad_max_length, return_tensors="np"
)[input_name]
self.assertFalse(_inputs_have_equal_length(input_1))
self.assertTrue(_inputs_have_equal_length(input_2))
self.assertTrue(_inputs_have_equal_length(input_3))
self.assertTrue(_inputs_are_equal(input_2, input_3))
self.assertTrue(len(input_1[0]) == pad_min_length)
self.assertTrue(len(input_1[1]) == pad_min_length + pad_diff)
self.assertTrue(input_4.shape[:2] == (batch_size, len(input_3[0])))
self.assertTrue(input_5.shape[:2] == (batch_size, pad_max_length))
if feature_size > 1:
self.assertTrue(input_4.shape[2] == input_5.shape[2] == feature_size)
# test padding for `pad_to_multiple_of` for List[int] + numpy
input_6 = feat_extract.pad(processed_features, pad_to_multiple_of=10)[input_name]
input_7 = feat_extract.pad(processed_features, padding="longest", pad_to_multiple_of=10)[input_name]
input_8 = feat_extract.pad(
processed_features, padding="max_length", pad_to_multiple_of=10, max_length=pad_max_length
)[input_name]
input_9 = feat_extract.pad(
processed_features,
padding="max_length",
pad_to_multiple_of=10,
max_length=pad_max_length,
return_tensors="np",
)[input_name]
self.assertTrue(all(len(x) % 10 == 0 for x in input_6))
self.assertTrue(_inputs_are_equal(input_6, input_7))
expected_mult_pad_length = pad_max_length if pad_max_length % 10 == 0 else (pad_max_length // 10 + 1) * 10
self.assertTrue(all(len(x) == expected_mult_pad_length for x in input_8))
self.assertTrue(input_9.shape[:2], (batch_size, expected_mult_pad_length))
if feature_size > 1:
self.assertTrue(input_9.shape[2] == feature_size)
# Check padding value is correct
padding_vector_sum = (np.ones(self.feat_extract_tester.feature_size) * feat_extract.padding_value).sum()
self.assertTrue(
abs(np.asarray(input_2[0])[pad_min_length:].sum() - padding_vector_sum * (pad_max_length - pad_min_length))
< 1e-3
)
self.assertTrue(
abs(
np.asarray(input_2[1])[pad_min_length + pad_diff :].sum()
- padding_vector_sum * (pad_max_length - pad_min_length - pad_diff)
)
< 1e-3
)
self.assertTrue(
abs(
np.asarray(input_2[2])[pad_min_length + 2 * pad_diff :].sum()
- padding_vector_sum * (pad_max_length - pad_min_length - 2 * pad_diff)
)
< 1e-3
)
self.assertTrue(
abs(input_5[0, pad_min_length:].sum() - padding_vector_sum * (pad_max_length - pad_min_length)) < 1e-3
)
self.assertTrue(
abs(input_9[0, pad_min_length:].sum() - padding_vector_sum * (expected_mult_pad_length - pad_min_length))
< 1e-3
)
def test_padding_from_list(self):
self._check_padding(numpify=False)
def test_padding_from_array(self):
self._check_padding(numpify=True)
@require_torch
def test_padding_accepts_tensors_pt(self):
feat_extract = self.feature_extraction_class(**self.feat_extract_dict)
speech_inputs = self.feat_extract_tester.prepare_inputs_for_common()
input_name = feat_extract.model_input_names[0]
processed_features = BatchFeature({input_name: speech_inputs})
input_np = feat_extract.pad(processed_features, padding="longest", return_tensors="np")[input_name]
input_pt = feat_extract.pad(processed_features, padding="longest", return_tensors="pt")[input_name]
self.assertTrue(abs(input_np.sum() - input_pt.numpy().sum()) < 1e-2)
@require_tf
def test_padding_accepts_tensors_tf(self):
feat_extract = self.feature_extraction_class(**self.feat_extract_dict)
speech_inputs = self.feat_extract_tester.prepare_inputs_for_common()
input_name = feat_extract.model_input_names[0]
processed_features = BatchFeature({input_name: speech_inputs})
input_np = feat_extract.pad(processed_features, padding="longest", return_tensors="np")[input_name]
input_tf = feat_extract.pad(processed_features, padding="longest", return_tensors="tf")[input_name]
self.assertTrue(abs(input_np.sum() - input_tf.numpy().sum()) < 1e-2)
def test_attention_mask(self):
feat_dict = self.feat_extract_dict
feat_dict["return_attention_mask"] = True
feat_extract = self.feature_extraction_class(**feat_dict)
speech_inputs = self.feat_extract_tester.prepare_inputs_for_common()
input_lenghts = [len(x) for x in speech_inputs]
input_name = feat_extract.model_input_names[0]
processed = BatchFeature({input_name: speech_inputs})
processed = feat_extract.pad(processed, padding="longest", return_tensors="np")
self.assertIn("attention_mask", processed)
self.assertListEqual(list(processed.attention_mask.shape), list(processed[input_name].shape[:2]))
self.assertListEqual(processed.attention_mask.sum(-1).tolist(), input_lenghts)
...@@ -23,7 +23,7 @@ import numpy as np ...@@ -23,7 +23,7 @@ import numpy as np
from transformers import WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST, Wav2Vec2Config, Wav2Vec2FeatureExtractor from transformers import WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST, Wav2Vec2Config, Wav2Vec2FeatureExtractor
from transformers.testing_utils import slow from transformers.testing_utils import slow
from .test_feature_extraction_common import FeatureExtractionMixin from .test_sequence_feature_extraction_common import SequenceFeatureExtractionTestMixin
global_rng = random.Random() global_rng = random.Random()
...@@ -94,7 +94,7 @@ class Wav2Vec2FeatureExtractionTester(unittest.TestCase): ...@@ -94,7 +94,7 @@ class Wav2Vec2FeatureExtractionTester(unittest.TestCase):
return speech_inputs return speech_inputs
class Wav2Vec2FeatureExtractionTest(FeatureExtractionMixin, unittest.TestCase): class Wav2Vec2FeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.TestCase):
feature_extraction_class = Wav2Vec2FeatureExtractor feature_extraction_class = Wav2Vec2FeatureExtractor
......
# coding=utf-8
# Copyright 2021 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from transformers import BatchFeature
from transformers.testing_utils import require_tf, require_torch
from .test_feature_extraction_common import FeatureExtractionSavingTestMixin
class SequenceFeatureExtractionTestMixin(FeatureExtractionSavingTestMixin):
# to overwrite at feature extractactor specific tests
feat_extract_tester = None
feature_extraction_class = None
@property
def feat_extract_dict(self):
return self.feat_extract_tester.prepare_feat_extract_dict()
def test_feat_extract_common_properties(self):
feat_extract = self.feature_extraction_class(**self.feat_extract_dict)
self.assertTrue(hasattr(feat_extract, "feature_size"))
self.assertTrue(hasattr(feat_extract, "sampling_rate"))
self.assertTrue(hasattr(feat_extract, "padding_value"))
def test_batch_feature(self):
speech_inputs = self.feat_extract_tester.prepare_inputs_for_common()
feat_extract = self.feature_extraction_class(**self.feat_extract_dict)
input_name = feat_extract.model_input_names[0]
processed_features = BatchFeature({input_name: speech_inputs})
self.assertTrue(all(len(x) == len(y) for x, y in zip(speech_inputs, processed_features[input_name])))
speech_inputs = self.feat_extract_tester.prepare_inputs_for_common(equal_length=True)
processed_features = BatchFeature({input_name: speech_inputs}, tensor_type="np")
batch_features_input = processed_features[input_name]
if len(batch_features_input.shape) < 3:
batch_features_input = batch_features_input[:, :, None]
self.assertTrue(
batch_features_input.shape
== (self.feat_extract_tester.batch_size, len(speech_inputs[0]), self.feat_extract_tester.feature_size)
)
@require_torch
def test_batch_feature_pt(self):
speech_inputs = self.feat_extract_tester.prepare_inputs_for_common(equal_length=True)
feat_extract = self.feature_extraction_class(**self.feat_extract_dict)
input_name = feat_extract.model_input_names[0]
processed_features = BatchFeature({input_name: speech_inputs}, tensor_type="pt")
batch_features_input = processed_features[input_name]
if len(batch_features_input.shape) < 3:
batch_features_input = batch_features_input[:, :, None]
self.assertTrue(
batch_features_input.shape
== (self.feat_extract_tester.batch_size, len(speech_inputs[0]), self.feat_extract_tester.feature_size)
)
@require_tf
def test_batch_feature_tf(self):
speech_inputs = self.feat_extract_tester.prepare_inputs_for_common(equal_length=True)
feat_extract = self.feature_extraction_class(**self.feat_extract_dict)
input_name = feat_extract.model_input_names[0]
processed_features = BatchFeature({input_name: speech_inputs}, tensor_type="tf")
batch_features_input = processed_features[input_name]
if len(batch_features_input.shape) < 3:
batch_features_input = batch_features_input[:, :, None]
self.assertTrue(
batch_features_input.shape
== (self.feat_extract_tester.batch_size, len(speech_inputs[0]), self.feat_extract_tester.feature_size)
)
def _check_padding(self, numpify=False):
def _inputs_have_equal_length(input):
length = len(input[0])
for input_slice in input[1:]:
if len(input_slice) != length:
return False
return True
def _inputs_are_equal(input_1, input_2):
if len(input_1) != len(input_2):
return False
for input_slice_1, input_slice_2 in zip(input_1, input_2):
if not np.allclose(np.asarray(input_slice_1), np.asarray(input_slice_2), atol=1e-3):
return False
return True
feat_extract = self.feature_extraction_class(**self.feat_extract_dict)
speech_inputs = self.feat_extract_tester.prepare_inputs_for_common(numpify=numpify)
input_name = feat_extract.model_input_names[0]
processed_features = BatchFeature({input_name: speech_inputs})
pad_diff = self.feat_extract_tester.seq_length_diff
pad_max_length = self.feat_extract_tester.max_seq_length + pad_diff
pad_min_length = self.feat_extract_tester.min_seq_length
batch_size = self.feat_extract_tester.batch_size
feature_size = self.feat_extract_tester.feature_size
# test padding for List[int] + numpy
input_1 = feat_extract.pad(processed_features, padding=False)[input_name]
input_2 = feat_extract.pad(processed_features, padding="longest")[input_name]
input_3 = feat_extract.pad(processed_features, padding="max_length", max_length=len(speech_inputs[-1]))[
input_name
]
input_4 = feat_extract.pad(processed_features, padding="longest", return_tensors="np")[input_name]
# max_length parameter has to be provided when setting `padding="max_length"`
with self.assertRaises(ValueError):
feat_extract.pad(processed_features, padding="max_length")[input_name]
input_5 = feat_extract.pad(
processed_features, padding="max_length", max_length=pad_max_length, return_tensors="np"
)[input_name]
self.assertFalse(_inputs_have_equal_length(input_1))
self.assertTrue(_inputs_have_equal_length(input_2))
self.assertTrue(_inputs_have_equal_length(input_3))
self.assertTrue(_inputs_are_equal(input_2, input_3))
self.assertTrue(len(input_1[0]) == pad_min_length)
self.assertTrue(len(input_1[1]) == pad_min_length + pad_diff)
self.assertTrue(input_4.shape[:2] == (batch_size, len(input_3[0])))
self.assertTrue(input_5.shape[:2] == (batch_size, pad_max_length))
if feature_size > 1:
self.assertTrue(input_4.shape[2] == input_5.shape[2] == feature_size)
# test padding for `pad_to_multiple_of` for List[int] + numpy
input_6 = feat_extract.pad(processed_features, pad_to_multiple_of=10)[input_name]
input_7 = feat_extract.pad(processed_features, padding="longest", pad_to_multiple_of=10)[input_name]
input_8 = feat_extract.pad(
processed_features, padding="max_length", pad_to_multiple_of=10, max_length=pad_max_length
)[input_name]
input_9 = feat_extract.pad(
processed_features,
padding="max_length",
pad_to_multiple_of=10,
max_length=pad_max_length,
return_tensors="np",
)[input_name]
self.assertTrue(all(len(x) % 10 == 0 for x in input_6))
self.assertTrue(_inputs_are_equal(input_6, input_7))
expected_mult_pad_length = pad_max_length if pad_max_length % 10 == 0 else (pad_max_length // 10 + 1) * 10
self.assertTrue(all(len(x) == expected_mult_pad_length for x in input_8))
self.assertTrue(input_9.shape[:2], (batch_size, expected_mult_pad_length))
if feature_size > 1:
self.assertTrue(input_9.shape[2] == feature_size)
# Check padding value is correct
padding_vector_sum = (np.ones(self.feat_extract_tester.feature_size) * feat_extract.padding_value).sum()
self.assertTrue(
abs(np.asarray(input_2[0])[pad_min_length:].sum() - padding_vector_sum * (pad_max_length - pad_min_length))
< 1e-3
)
self.assertTrue(
abs(
np.asarray(input_2[1])[pad_min_length + pad_diff :].sum()
- padding_vector_sum * (pad_max_length - pad_min_length - pad_diff)
)
< 1e-3
)
self.assertTrue(
abs(
np.asarray(input_2[2])[pad_min_length + 2 * pad_diff :].sum()
- padding_vector_sum * (pad_max_length - pad_min_length - 2 * pad_diff)
)
< 1e-3
)
self.assertTrue(
abs(input_5[0, pad_min_length:].sum() - padding_vector_sum * (pad_max_length - pad_min_length)) < 1e-3
)
self.assertTrue(
abs(input_9[0, pad_min_length:].sum() - padding_vector_sum * (expected_mult_pad_length - pad_min_length))
< 1e-3
)
def test_padding_from_list(self):
self._check_padding(numpify=False)
def test_padding_from_array(self):
self._check_padding(numpify=True)
@require_torch
def test_padding_accepts_tensors_pt(self):
feat_extract = self.feature_extraction_class(**self.feat_extract_dict)
speech_inputs = self.feat_extract_tester.prepare_inputs_for_common()
input_name = feat_extract.model_input_names[0]
processed_features = BatchFeature({input_name: speech_inputs})
input_np = feat_extract.pad(processed_features, padding="longest", return_tensors="np")[input_name]
input_pt = feat_extract.pad(processed_features, padding="longest", return_tensors="pt")[input_name]
self.assertTrue(abs(input_np.sum() - input_pt.numpy().sum()) < 1e-2)
@require_tf
def test_padding_accepts_tensors_tf(self):
feat_extract = self.feature_extraction_class(**self.feat_extract_dict)
speech_inputs = self.feat_extract_tester.prepare_inputs_for_common()
input_name = feat_extract.model_input_names[0]
processed_features = BatchFeature({input_name: speech_inputs})
input_np = feat_extract.pad(processed_features, padding="longest", return_tensors="np")[input_name]
input_tf = feat_extract.pad(processed_features, padding="longest", return_tensors="tf")[input_name]
self.assertTrue(abs(input_np.sum() - input_tf.numpy().sum()) < 1e-2)
def test_attention_mask(self):
feat_dict = self.feat_extract_dict
feat_dict["return_attention_mask"] = True
feat_extract = self.feature_extraction_class(**feat_dict)
speech_inputs = self.feat_extract_tester.prepare_inputs_for_common()
input_lenghts = [len(x) for x in speech_inputs]
input_name = feat_extract.model_input_names[0]
processed = BatchFeature({input_name: speech_inputs})
processed = feat_extract.pad(processed, padding="longest", return_tensors="np")
self.assertIn("attention_mask", processed)
self.assertListEqual(list(processed.attention_mask.shape), list(processed[input_name].shape[:2]))
self.assertListEqual(processed.attention_mask.sum(-1).tolist(), input_lenghts)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment