Unverified Commit d438eee0 authored by Will Rice's avatar Will Rice Committed by GitHub
Browse files

Adding TFWav2Vec2Model (#11617)



* [WIP] Add TFWav2Vec2Model

Work in progress for adding a tensorflow version of Wav2Vec2

* feedback changes

* small fix

* Test Feedback Round 1

* Add SpecAugment and CTC Loss

* correct spec augment mask creation

* docstring and correct copyright

* correct bugs

* remove bogus file

* finish tests correction

* del unnecessary layers

* Update src/transformers/models/wav2vec2/modeling_tf_wav2vec2.py
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>

* make style

* correct final bug

* Feedback Changes
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>
parent 1ed2ebf6
...@@ -399,7 +399,7 @@ Flax), PyTorch, and/or TensorFlow. ...@@ -399,7 +399,7 @@ Flax), PyTorch, and/or TensorFlow.
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+ +-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| VisualBert | ❌ | ❌ | ✅ | ❌ | ❌ | | VisualBert | ❌ | ❌ | ✅ | ❌ | ❌ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+ +-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| Wav2Vec2 | ✅ | ❌ | ✅ | | ❌ | | Wav2Vec2 | ✅ | ❌ | ✅ | | ❌ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+ +-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| XLM | ✅ | ❌ | ✅ | ✅ | ❌ | | XLM | ✅ | ❌ | ✅ | ✅ | ❌ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+ +-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
......
...@@ -80,9 +80,22 @@ Wav2Vec2ForCTC ...@@ -80,9 +80,22 @@ Wav2Vec2ForCTC
.. autoclass:: transformers.Wav2Vec2ForCTC .. autoclass:: transformers.Wav2Vec2ForCTC
:members: forward :members: forward
Wav2Vec2ForPreTraining Wav2Vec2ForPreTraining
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.Wav2Vec2ForPreTraining .. autoclass:: transformers.Wav2Vec2ForPreTraining
:members: forward :members: forward
TFWav2Vec2Model
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.TFWav2Vec2Model
:members: call
TFWav2Vec2ForCTC
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.TFWav2Vec2ForCTC
:members: call
...@@ -1430,6 +1430,14 @@ if is_tf_available(): ...@@ -1430,6 +1430,14 @@ if is_tf_available():
"TFTransfoXLPreTrainedModel", "TFTransfoXLPreTrainedModel",
] ]
) )
_import_structure["models.wav2vec2"].extend(
[
"TF_WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST",
"TFWav2Vec2ForCTC",
"TFWav2Vec2Model",
"TFWav2Vec2PreTrainedModel",
]
)
_import_structure["models.xlm"].extend( _import_structure["models.xlm"].extend(
[ [
"TF_XLM_PRETRAINED_MODEL_ARCHIVE_LIST", "TF_XLM_PRETRAINED_MODEL_ARCHIVE_LIST",
...@@ -2743,6 +2751,12 @@ if TYPE_CHECKING: ...@@ -2743,6 +2751,12 @@ if TYPE_CHECKING:
TFTransfoXLModel, TFTransfoXLModel,
TFTransfoXLPreTrainedModel, TFTransfoXLPreTrainedModel,
) )
from .models.wav2vec2 import (
TF_WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST,
TFWav2Vec2ForCTC,
TFWav2Vec2Model,
TFWav2Vec2PreTrainedModel,
)
from .models.xlm import ( from .models.xlm import (
TF_XLM_PRETRAINED_MODEL_ARCHIVE_LIST, TF_XLM_PRETRAINED_MODEL_ARCHIVE_LIST,
TFXLMForMultipleChoice, TFXLMForMultipleChoice,
......
...@@ -37,6 +37,7 @@ from . import ( ...@@ -37,6 +37,7 @@ from . import (
ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
T5_PRETRAINED_CONFIG_ARCHIVE_MAP, T5_PRETRAINED_CONFIG_ARCHIVE_MAP,
TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP, TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP,
WAV_2_VEC_2_PRETRAINED_CONFIG_ARCHIVE_MAP,
WEIGHTS_NAME, WEIGHTS_NAME,
XLM_PRETRAINED_CONFIG_ARCHIVE_MAP, XLM_PRETRAINED_CONFIG_ARCHIVE_MAP,
XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
...@@ -79,10 +80,13 @@ from . import ( ...@@ -79,10 +80,13 @@ from . import (
TFRobertaForSequenceClassification, TFRobertaForSequenceClassification,
TFT5ForConditionalGeneration, TFT5ForConditionalGeneration,
TFTransfoXLLMHeadModel, TFTransfoXLLMHeadModel,
TFWav2Vec2Model,
TFXLMRobertaForMaskedLM, TFXLMRobertaForMaskedLM,
TFXLMWithLMHeadModel, TFXLMWithLMHeadModel,
TFXLNetLMHeadModel, TFXLNetLMHeadModel,
TransfoXLConfig, TransfoXLConfig,
Wav2Vec2Config,
Wav2Vec2Model,
XLMConfig, XLMConfig,
XLMRobertaConfig, XLMRobertaConfig,
XLNetConfig, XLNetConfig,
...@@ -287,6 +291,12 @@ MODEL_CLASSES = { ...@@ -287,6 +291,12 @@ MODEL_CLASSES = {
ElectraForPreTraining, ElectraForPreTraining,
ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP, ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP,
), ),
"wav2vec2": (
Wav2Vec2Config,
TFWav2Vec2Model,
Wav2Vec2Model,
WAV_2_VEC_2_PRETRAINED_CONFIG_ARCHIVE_MAP,
),
} }
......
...@@ -163,6 +163,7 @@ from ..transfo_xl.modeling_tf_transfo_xl import ( ...@@ -163,6 +163,7 @@ from ..transfo_xl.modeling_tf_transfo_xl import (
TFTransfoXLLMHeadModel, TFTransfoXLLMHeadModel,
TFTransfoXLModel, TFTransfoXLModel,
) )
from ..wav2vec2.modeling_tf_wav2vec2 import TFWav2Vec2Model
from ..xlm.modeling_tf_xlm import ( from ..xlm.modeling_tf_xlm import (
TFXLMForMultipleChoice, TFXLMForMultipleChoice,
TFXLMForQuestionAnsweringSimple, TFXLMForQuestionAnsweringSimple,
...@@ -218,6 +219,7 @@ from .configuration_auto import ( ...@@ -218,6 +219,7 @@ from .configuration_auto import (
RoFormerConfig, RoFormerConfig,
T5Config, T5Config,
TransfoXLConfig, TransfoXLConfig,
Wav2Vec2Config,
XLMConfig, XLMConfig,
XLMRobertaConfig, XLMRobertaConfig,
XLNetConfig, XLNetConfig,
...@@ -263,6 +265,7 @@ TF_MODEL_MAPPING = OrderedDict( ...@@ -263,6 +265,7 @@ TF_MODEL_MAPPING = OrderedDict(
(PegasusConfig, TFPegasusModel), (PegasusConfig, TFPegasusModel),
(BlenderbotConfig, TFBlenderbotModel), (BlenderbotConfig, TFBlenderbotModel),
(BlenderbotSmallConfig, TFBlenderbotSmallModel), (BlenderbotSmallConfig, TFBlenderbotSmallModel),
(Wav2Vec2Config, TFWav2Vec2Model),
] ]
) )
......
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
# limitations under the License. # limitations under the License.
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from ...file_utils import _BaseLazyModule, is_tokenizers_available, is_torch_available from ...file_utils import _BaseLazyModule, is_tf_available, is_torch_available
_import_structure = { _import_structure = {
...@@ -38,6 +38,15 @@ if is_torch_available(): ...@@ -38,6 +38,15 @@ if is_torch_available():
] ]
if is_tf_available():
_import_structure["modeling_tf_wav2vec2"] = [
"TF_WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST",
"TFWav2Vec2ForCTC",
"TFWav2Vec2Model",
"TFWav2Vec2PreTrainedModel",
]
if TYPE_CHECKING: if TYPE_CHECKING:
from .configuration_wav2vec2 import WAV_2_VEC_2_PRETRAINED_CONFIG_ARCHIVE_MAP, Wav2Vec2Config from .configuration_wav2vec2 import WAV_2_VEC_2_PRETRAINED_CONFIG_ARCHIVE_MAP, Wav2Vec2Config
from .feature_extraction_wav2vec2 import Wav2Vec2FeatureExtractor from .feature_extraction_wav2vec2 import Wav2Vec2FeatureExtractor
...@@ -54,6 +63,14 @@ if TYPE_CHECKING: ...@@ -54,6 +63,14 @@ if TYPE_CHECKING:
Wav2Vec2PreTrainedModel, Wav2Vec2PreTrainedModel,
) )
if is_tf_available():
from .modeling_tf_wav2vec2 import (
TF_WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST,
TFWav2Vec2ForCTC,
TFWav2Vec2Model,
TFWav2Vec2PreTrainedModel,
)
else: else:
import importlib import importlib
......
# coding=utf-8
# Copyright 2021 The Fairseq Authors and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" TensorFlow Wav2Vec2 model. """
import inspect
import warnings
from typing import Any, Dict, Optional, Tuple, Union
import numpy as np
import tensorflow as tf
from ...activations_tf import get_tf_activation
from ...file_utils import (
ModelOutput,
add_start_docstrings,
add_start_docstrings_to_model_forward,
replace_return_docstrings,
)
from ...modeling_tf_outputs import TFBaseModelOutput, TFCausalLMOutput
from ...modeling_tf_utils import (
TFPreTrainedModel,
booleans_processing,
get_initializer,
keras_serializable,
shape_list,
)
from ...tokenization_utils_base import BatchEncoding
from ...utils import logging
from .configuration_wav2vec2 import Wav2Vec2Config
logger = logging.get_logger(__name__)
_CHECKPOINT_FOR_DOC = "facebook/wav2vec2-base-960h"
_CONFIG_FOR_DOC = "Wav2Vec2Config"
_TOKENIZER_FOR_DOC = "Wav2Vec2Tokenizer"
TF_WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST = [
"facebook/wav2vec2-base-960h",
"facebook/wav2vec2-large-960h",
"facebook/wav2vec2-large-960h-lv60",
"facebook/wav2vec2-large-960h-lv60-self",
# See all Wav2Vec2 models at https://huggingface.co/models?filter=wav2vec2
]
LARGE_NEGATIVE = -1e8
def input_values_processing(func, config, input_values, **kwargs):
"""
Process the input of each TensorFlow model including the booleans. In case of a list of symbolic inputs, each input
has to be named accordingly to the parameters name, i.e. :obj:`input_values = tf.keras.Input(shape=(128,),
dtype='float32', name="input_values")` otherwise the order of the tensors will not be guaranteed during the
training.
Args:
func (:obj:`callable`):
The callable function of the TensorFlow model.
config (:class:`~transformers.PretrainedConfig`):
The config of the running model.
**kwargs:
The inputs of the model.
Returns:
Two lists, one for the missing layers, and another one for the unexpected layers.
"""
signature = dict(inspect.signature(func).parameters)
signature.pop("kwargs", None)
signature.pop("self", None)
parameter_names = list(signature.keys())
output = {}
allowed_types = (tf.Tensor, bool, int, ModelOutput, tuple, list, dict, np.ndarray)
if len(kwargs["kwargs_call"]) > 0:
raise ValueError(
f"The following keyword arguments are not supported by this model: {list(kwargs['kwargs_call'].keys())}."
)
kwargs.pop("kwargs_call")
for k, v in kwargs.items():
if isinstance(v, allowed_types) or v is None:
output[k] = v
else:
raise ValueError(f"Data of type {type(v)} is not allowed only {allowed_types} is accepted for {k}.")
if isinstance(input_values, (tuple, list)):
for i, input in enumerate(input_values):
# EagerTensors don't allow to use the .name property so we check for a real Tensor
if type(input) == tf.Tensor:
# Tensor names have always the pattern `name:id` then we check only the
# `name` part
tensor_name = input.name.split(":")[0]
if tensor_name in parameter_names:
output[tensor_name] = input
else:
output[parameter_names[i]] = input
elif isinstance(input, allowed_types) or input is None:
output[parameter_names[i]] = input
else:
raise ValueError(
f"Data of type {type(input)} is not allowed only {allowed_types} is accepted for {parameter_names[i]}."
)
elif isinstance(input_values, (dict, BatchEncoding)):
if "inputs" in input_values:
warnings.warn(
"The `inputs` argument is deprecated and will be removed in a future version, use `input_values` instead.",
FutureWarning,
)
output["input_values"] = input_values.pop("inputs")
if "decoder_cached_states" in input_values:
warnings.warn(
"The `decoder_cached_states` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
FutureWarning,
)
output["past_key_values"] = input_values.pop("decoder_cached_states")
for k, v in dict(input_values).items():
if isinstance(v, allowed_types) or v is None:
output[k] = v
elif k not in parameter_names and "args" not in parameter_names:
logger.warning(
f"The parameter {k} does not belongs to the parameter list {parameter_names} and will be ignored."
)
continue
else:
raise ValueError(f"Data of type {type(v)} is not allowed only {allowed_types} is accepted for {k}.")
else:
if isinstance(input_values, tf.Tensor) or input_values is None:
output[parameter_names[0]] = input_values
else:
raise ValueError(
f"Data of type {type(input_values)} is not allowed only {allowed_types} is accepted for {parameter_names[0]}."
)
for name in parameter_names:
if name not in list(output.keys()) and name != "args":
output[name] = kwargs.pop(name, signature[name].default)
# When creating a SavedModel TF calls the method with LayerCall.__call__(args, **kwargs)
# So to respect the proper output we have to add this exception
if "args" in output:
if output["args"] is not None and type(output["args"]) == tf.Tensor:
tensor_name = output["args"].name.split(":")[0]
output[tensor_name] = output["args"]
else:
# `args` in this case is always the first parameter, then `input_values`
output["input_values"] = output["args"]
del output["args"]
if "kwargs" in output:
del output["kwargs"]
boolean_dict = {
k: v
for k, v in output.items()
if k in ["return_dict", "output_attentions", "output_hidden_states", "use_cache"]
}
output.update(booleans_processing(config=config, **boolean_dict))
return output
def _sample_without_replacement(distribution, num_samples):
"""
Categorical sampling without replacement is currently not implemented. The gumbel-max trick will do for now - see
https://github.com/tensorflow/tensorflow/issues/9260 for more info
"""
z = -tf.math.log(tf.random.uniform(shape_list(distribution), 0, 1))
_, indices = tf.nn.top_k(distribution + z, num_samples)
return indices
def _scatter_values_on_batch_indices(values, batch_indices, output_shape):
"""
Scatter function as in PyTorch with indices in format (batch_dim, indixes)
"""
indices_shape = shape_list(batch_indices)
# broadcast batch dim to indices_shape
broad_casted_batch_dims = tf.reshape(
tf.broadcast_to(tf.expand_dims(tf.range(indices_shape[0]), axis=-1), indices_shape), [1, -1]
)
# transform batch_indices to pair_indices
pair_indices = tf.transpose(tf.concat([broad_casted_batch_dims, tf.reshape(batch_indices, [1, -1])], 0))
# scatter values to pair indices
return tf.scatter_nd(pair_indices, tf.reshape(values, [-1]), output_shape)
def _compute_mask_indices(
shape: Tuple[int, int],
mask_prob: float,
mask_length: int,
min_masks: int = 0,
) -> tf.Tensor:
"""
Computes random mask spans for a given shape
Args:
shape: the the shape for which to compute masks.
should be of size 2 where first element is batch size and 2nd is timesteps
attention_mask: optional padding mask of the same size as shape, which will prevent masking padded elements
mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by
number of timesteps divided by length of mask span to mask approximately this percentage of all elements.
however due to overlaps, the actual number will be smaller (unless no_overlap is True)
mask_length: size of the mask
min_masks: minimum number of masked spans
Adapted from `fairseq's data_utils.py
<https://github.com/pytorch/fairseq/blob/e0788f7007a8473a76db573985031f3c94201e79/fairseq/data/data_utils.py#L376>`__.
"""
batch_size, sequence_length = shape
if mask_length < 1:
raise ValueError("`mask_length` has to be bigger than 0.")
if mask_length > sequence_length:
raise ValueError(
f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length} and `sequence_length`: {sequence_length}`"
)
# compute number of masked spans in batch
num_masked_spans = int(mask_prob * sequence_length / mask_length + tf.random.uniform((1,)))
num_masked_spans = max(num_masked_spans, min_masks)
# make sure num masked indices <= sequence_length
if num_masked_spans * mask_length > sequence_length:
num_masked_spans = sequence_length // mask_length
# SpecAugment mask to fill
spec_aug_mask = tf.zeros((batch_size, sequence_length), dtype=tf.int32)
# uniform distribution to sample from, make sure that offset samples are < sequence_length
uniform_dist = tf.ones((batch_size, sequence_length - (mask_length - 1)))
# get random indices to mask
spec_aug_mask_idxs = _sample_without_replacement(uniform_dist, num_masked_spans)
# expand masked indices to masked spans
spec_aug_mask_idxs = tf.expand_dims(spec_aug_mask_idxs, -1)
spec_aug_mask_idxs = tf.tile(spec_aug_mask_idxs, (1, 1, mask_length))
spec_aug_mask_idxs = tf.reshape(spec_aug_mask_idxs, (batch_size, num_masked_spans * mask_length))
offsets = tf.range(mask_length)[tf.newaxis, tf.newaxis, :]
offsets = tf.tile(offsets, (batch_size, num_masked_spans, 1))
offsets = tf.reshape(offsets, (batch_size, num_masked_spans * mask_length))
spec_aug_mask_idxs = spec_aug_mask_idxs + offsets
# scatter indices to mask
spec_aug_mask = _scatter_values_on_batch_indices(
tf.ones_like(spec_aug_mask_idxs), spec_aug_mask_idxs, spec_aug_mask.shape
)
return tf.cast(spec_aug_mask, tf.float32)
def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None, past_key_values_length: int = 0):
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
"""
src_len = shape_list(mask)[1]
tgt_len = tgt_len if tgt_len is not None else src_len
one_cst = tf.constant(1.0)
mask = tf.cast(mask, dtype=one_cst.dtype)
expanded_mask = tf.tile(mask[:, None, None, :], (1, 1, tgt_len, 1))
return (one_cst - expanded_mask) * LARGE_NEGATIVE
class TFWav2Vec2GroupNorm(tf.keras.layers.Layer):
"""
From tensorflow-addons https://www.tensorflow.org/addons/api_docs/python/tfa/layers/GroupNormalization
"""
def __init__(
self,
groups: int = 32,
axis: int = -1,
epsilon: float = 1e-3,
center: bool = True,
scale: bool = True,
beta_initializer: tf.keras.initializers.Initializer = "zeros",
gamma_initializer: tf.keras.initializers.Initializer = "ones",
beta_regularizer: tf.keras.regularizers.Regularizer = None,
gamma_regularizer: tf.keras.regularizers.Regularizer = None,
beta_constraint: tf.keras.constraints.Constraint = None,
gamma_constraint: tf.keras.constraints.Constraint = None,
**kwargs,
):
super().__init__(**kwargs)
self.supports_masking = True
self.groups = groups
self.axis = axis
self.epsilon = epsilon
self.center = center
self.scale = scale
self.beta_initializer = tf.keras.initializers.get(beta_initializer)
self.gamma_initializer = tf.keras.initializers.get(gamma_initializer)
self.beta_regularizer = tf.keras.regularizers.get(beta_regularizer)
self.gamma_regularizer = tf.keras.regularizers.get(gamma_regularizer)
self.beta_constraint = tf.keras.constraints.get(beta_constraint)
self.gamma_constraint = tf.keras.constraints.get(gamma_constraint)
self._check_axis()
def build(self, input_shape):
self._check_if_input_shape_is_none(input_shape)
self._set_number_of_groups_for_instance_norm(input_shape)
self._check_size_of_dimensions(input_shape)
self._create_input_spec(input_shape)
self._add_gamma_weight(input_shape)
self._add_beta_weight(input_shape)
self.built = True
super().build(input_shape)
def call(self, inputs):
input_shape = tf.keras.backend.int_shape(inputs)
tensor_input_shape = tf.shape(inputs)
reshaped_inputs, group_shape = self._reshape_into_groups(inputs, input_shape, tensor_input_shape)
normalized_inputs = self._apply_normalization(reshaped_inputs, input_shape)
is_instance_norm = (input_shape[self.axis] // self.groups) == 1
if not is_instance_norm:
outputs = tf.reshape(normalized_inputs, tensor_input_shape)
else:
outputs = normalized_inputs
return outputs
def get_config(self):
config = {
"groups": self.groups,
"axis": self.axis,
"epsilon": self.epsilon,
"center": self.center,
"scale": self.scale,
"beta_initializer": tf.keras.initializers.serialize(self.beta_initializer),
"gamma_initializer": tf.keras.initializers.serialize(self.gamma_initializer),
"beta_regularizer": tf.keras.regularizers.serialize(self.beta_regularizer),
"gamma_regularizer": tf.keras.regularizers.serialize(self.gamma_regularizer),
"beta_constraint": tf.keras.constraints.serialize(self.beta_constraint),
"gamma_constraint": tf.keras.constraints.serialize(self.gamma_constraint),
}
base_config = super().get_config()
return {**base_config, **config}
def compute_output_shape(self, input_shape):
return input_shape
def _reshape_into_groups(self, inputs, input_shape, tensor_input_shape):
group_shape = [tensor_input_shape[i] for i in range(len(input_shape))]
is_instance_norm = (input_shape[self.axis] // self.groups) == 1
if not is_instance_norm:
group_shape[self.axis] = input_shape[self.axis] // self.groups
group_shape.insert(self.axis, self.groups)
group_shape = tf.stack(group_shape)
reshaped_inputs = tf.reshape(inputs, group_shape)
return reshaped_inputs, group_shape
else:
return inputs, group_shape
def _apply_normalization(self, reshaped_inputs, input_shape):
group_shape = tf.keras.backend.int_shape(reshaped_inputs)
group_reduction_axes = list(range(1, len(group_shape)))
is_instance_norm = (input_shape[self.axis] // self.groups) == 1
if not is_instance_norm:
axis = -2 if self.axis == -1 else self.axis - 1
else:
axis = -1 if self.axis == -1 else self.axis - 1
group_reduction_axes.pop(axis)
mean, variance = tf.nn.moments(reshaped_inputs, group_reduction_axes, keepdims=True)
gamma, beta = self._get_reshaped_weights(input_shape)
normalized_inputs = tf.nn.batch_normalization(
reshaped_inputs,
mean=mean,
variance=variance,
scale=gamma,
offset=beta,
variance_epsilon=self.epsilon,
)
return normalized_inputs
def _get_reshaped_weights(self, input_shape):
broadcast_shape = self._create_broadcast_shape(input_shape)
gamma = None
beta = None
if self.scale:
gamma = tf.reshape(self.gamma, broadcast_shape)
if self.center:
beta = tf.reshape(self.beta, broadcast_shape)
return gamma, beta
def _check_if_input_shape_is_none(self, input_shape):
dim = input_shape[self.axis]
if dim is None:
raise ValueError(
"Axis " + str(self.axis) + " of "
"input tensor should have a defined dimension "
"but the layer received an input with shape " + str(input_shape) + "."
)
def _set_number_of_groups_for_instance_norm(self, input_shape):
dim = input_shape[self.axis]
if self.groups == -1:
self.groups = dim
def _check_size_of_dimensions(self, input_shape):
dim = input_shape[self.axis]
if dim < self.groups:
raise ValueError(
"Number of groups (" + str(self.groups) + ") cannot be "
"more than the number of channels (" + str(dim) + ")."
)
if dim % self.groups != 0:
raise ValueError(
"Number of groups (" + str(self.groups) + ") must be a "
"multiple of the number of channels (" + str(dim) + ")."
)
def _check_axis(self):
if self.axis == 0:
raise ValueError(
"You are trying to normalize your batch axis. Do you want to "
"use tf.layer.batch_normalization instead"
)
def _create_input_spec(self, input_shape):
dim = input_shape[self.axis]
self.input_spec = tf.keras.layers.InputSpec(ndim=len(input_shape), axes={self.axis: dim})
def _add_gamma_weight(self, input_shape):
dim = input_shape[self.axis]
shape = (dim,)
if self.scale:
self.gamma = self.add_weight(
shape=shape,
name="gamma",
initializer=self.gamma_initializer,
regularizer=self.gamma_regularizer,
constraint=self.gamma_constraint,
)
else:
self.gamma = None
def _add_beta_weight(self, input_shape):
dim = input_shape[self.axis]
shape = (dim,)
if self.center:
self.beta = self.add_weight(
shape=shape,
name="beta",
initializer=self.beta_initializer,
regularizer=self.beta_regularizer,
constraint=self.beta_constraint,
)
else:
self.beta = None
def _create_broadcast_shape(self, input_shape):
broadcast_shape = [1] * len(input_shape)
is_instance_norm = (input_shape[self.axis] // self.groups) == 1
if not is_instance_norm:
broadcast_shape[self.axis] = input_shape[self.axis] // self.groups
broadcast_shape.insert(self.axis, self.groups)
else:
broadcast_shape[self.axis] = self.groups
return broadcast_shape
class TFWav2Vec2WeightNormConv1D(tf.keras.layers.Conv1D):
"""Adapted from https://www.tensorflow.org/probability/api_docs/python/tfp/layers/weight_norm/WeightNorm"""
def __init__(self, filters, kernel_size, groups, explicit_padding, **kwargs):
super().__init__(
filters=filters,
kernel_size=kernel_size,
groups=groups,
padding="valid",
use_bias=True,
bias_initializer="he_normal",
**kwargs,
)
self.explicit_padding = explicit_padding
self.filter_axis = 2
self.initialized = False
self.kernel_norm_axes = tf.constant([0, 1])
def _init_norm(self):
"""Set the norm of the weight vector."""
kernel_norm = tf.sqrt(tf.reduce_sum(tf.square(self.weight_v), axis=self.kernel_norm_axes))
self.weight_g.assign(kernel_norm[:, tf.newaxis, tf.newaxis])
def _normalize_kernel(self):
"""Generate normalized weights."""
kernel = tf.nn.l2_normalize(self.weight_v, axis=self.kernel_norm_axes) * tf.transpose(self.weight_g)
self.kernel = tf.transpose(kernel)
def build(self, input_shape):
if not self.built:
super().build(input_shape)
self.kernel = tf.Variable(tf.transpose(self.kernel), name="weight_v", trainable=True)
self.weight_v = self.kernel
self.weight_g = self.add_weight(
name="weight_g",
shape=(int(self.weight_v.shape[self.filter_axis]), 1, 1),
initializer="ones",
dtype=self.weight_v.dtype,
trainable=True,
)
self.bias = self.add_weight(name="bias", shape=(self.filters,), initializer="zeros", trainable=True)
def call(self, inputs):
if not self.initialized:
self._init_norm()
self.initialized = True
self._normalize_kernel()
padded_inputs = tf.pad(inputs, ((0, 0), (self.explicit_padding, self.explicit_padding), (0, 0)))
output = super().call(padded_inputs)
return output
class TFWav2Vec2NoLayerNormConvLayer(tf.keras.layers.Layer):
def __init__(self, config: Wav2Vec2Config, layer_id: int = 0, **kwargs: Any) -> None:
super().__init__(**kwargs)
self.in_conv_dim = config.conv_dim[layer_id] if layer_id > 0 else 1
self.out_conv_dim = config.conv_dim[layer_id]
self.conv = tf.keras.layers.Conv1D(
filters=self.out_conv_dim,
kernel_size=config.conv_kernel[layer_id],
strides=config.conv_stride[layer_id],
use_bias=config.conv_bias,
name="conv",
)
self.activation = get_tf_activation(config.feat_extract_activation)
def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
hidden_states = self.conv(hidden_states)
hidden_states = self.activation(hidden_states)
return hidden_states
class TFWav2Vec2LayerNormConvLayer(tf.keras.layers.Layer):
def __init__(self, config: Wav2Vec2Config, layer_id: int = 0, **kwargs: Any) -> None:
super().__init__(**kwargs)
self.in_conv_dim = config.conv_dim[layer_id] if layer_id > 0 else 1
self.out_conv_dim = config.conv_dim[layer_id]
self.conv = tf.keras.layers.Conv1D(
filters=self.out_conv_dim,
kernel_size=config.conv_kernel[layer_id],
strides=config.conv_stride[layer_id],
use_bias=config.conv_bias,
name="conv",
)
self.layer_norm = tf.keras.layers.LayerNormalization(name="layer_norm", epsilon=config.layer_norm_eps)
self.activation = get_tf_activation(config.feat_extract_activation)
def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
hidden_states = self.conv(hidden_states)
hidden_states = self.layer_norm(hidden_states)
hidden_states = self.activation(hidden_states)
return hidden_states
class TFWav2Vec2GroupNormConvLayer(tf.keras.layers.Layer):
def __init__(self, config: Wav2Vec2Config, layer_id: int = 0, **kwargs: Any) -> None:
super().__init__(**kwargs)
self.in_conv_dim = config.conv_dim[layer_id] if layer_id > 0 else 1
self.out_conv_dim = config.conv_dim[layer_id]
self.conv = tf.keras.layers.Conv1D(
filters=self.out_conv_dim,
kernel_size=config.conv_kernel[layer_id],
strides=config.conv_stride[layer_id],
use_bias=config.conv_bias,
name="conv",
)
self.activation = get_tf_activation(config.feat_extract_activation)
self.layer_norm = TFWav2Vec2GroupNorm(
groups=self.out_conv_dim, epsilon=config.layer_norm_eps, name="layer_norm"
)
def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
hidden_states = self.conv(hidden_states)
hidden_states = self.layer_norm(hidden_states)
hidden_states = self.activation(hidden_states)
return hidden_states
class TFWav2Vec2PositionalConvEmbedding(tf.keras.layers.Layer):
def __init__(self, config: Wav2Vec2Config, **kwargs: Any) -> None:
super().__init__(**kwargs)
self.conv = TFWav2Vec2WeightNormConv1D(
filters=config.hidden_size,
kernel_size=config.num_conv_pos_embeddings,
groups=config.num_conv_pos_embedding_groups,
explicit_padding=config.num_conv_pos_embeddings // 2,
name="conv",
)
self.padding = TFWav2Vec2SamePadLayer(config.num_conv_pos_embeddings)
self.activation = get_tf_activation(config.feat_extract_activation)
def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
hidden_states = self.conv(hidden_states)
hidden_states = self.padding(hidden_states)
hidden_states = self.activation(hidden_states)
return hidden_states
class TFWav2Vec2SamePadLayer(tf.keras.layers.Layer):
def __init__(self, num_conv_pos_embeddings, **kwargs):
super().__init__(**kwargs)
self.num_pad_remove = 1 if num_conv_pos_embeddings % 2 == 0 else 0
def call(self, hidden_states):
if self.num_pad_remove > 0:
hidden_states = hidden_states[:, : -self.num_pad_remove, :]
return hidden_states
class TFWav2Vec2FeatureExtractor(tf.keras.layers.Layer):
def __init__(self, config: Wav2Vec2Config, **kwargs: Any) -> None:
super().__init__(**kwargs)
if config.feat_extract_norm == "group":
conv_layers = [TFWav2Vec2GroupNormConvLayer(config, layer_id=0, name=f"conv_layers.{0}")] + [
TFWav2Vec2NoLayerNormConvLayer(config, layer_id=i + 1, name=f"conv_layers.{i+1}")
for i in range(config.num_feat_extract_layers - 1)
]
elif config.feat_extract_norm == "layer":
conv_layers = [
TFWav2Vec2LayerNormConvLayer(config, layer_id=i, name=f"conv_layers.{i}")
for i in range(config.num_feat_extract_layers)
]
else:
raise ValueError(
f"`config.feat_extract_norm` is {config.feat_extract_norm}, but has to be one of ['group', 'layer']"
)
self.conv_layers = conv_layers
def call(self, input_values):
hidden_states = tf.expand_dims(input_values, -1)
for conv_layer in self.conv_layers:
hidden_states = conv_layer(hidden_states)
return hidden_states
class TFWav2Vec2FeatureProjection(tf.keras.layers.Layer):
def __init__(self, config: Wav2Vec2Config, **kwargs):
super().__init__(**kwargs)
self.layer_norm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm")
self.projection = tf.keras.layers.Dense(
units=config.hidden_size,
kernel_initializer=get_initializer(config.initializer_range),
bias_initializer="zeros",
name="projection",
)
self.dropout = tf.keras.layers.Dropout(rate=config.feat_proj_dropout)
def call(self, hidden_states: tf.Tensor, training: bool = False) -> tf.Tensor:
hidden_states = self.layer_norm(hidden_states)
hidden_states = self.projection(hidden_states)
hidden_states = self.dropout(hidden_states, training=training)
return hidden_states
# Copied from transformers.models.bart.modeling_tf_bart.TFBartAttention with TFBart->TFWav2Vec2
class TFWav2Vec2Attention(tf.keras.layers.Layer):
"""Multi-headed attention from "Attention Is All You Need"""
def __init__(
self,
embed_dim: int,
num_heads: int,
dropout: float = 0.0,
is_decoder: bool = False,
bias: bool = True,
**kwargs,
):
super().__init__(**kwargs)
self.embed_dim = embed_dim
self.num_heads = num_heads
self.dropout = tf.keras.layers.Dropout(dropout)
self.head_dim = embed_dim // num_heads
assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
self.scaling = self.head_dim ** -0.5
self.is_decoder = is_decoder
self.k_proj = tf.keras.layers.Dense(embed_dim, use_bias=bias, name="k_proj")
self.q_proj = tf.keras.layers.Dense(embed_dim, use_bias=bias, name="q_proj")
self.v_proj = tf.keras.layers.Dense(embed_dim, use_bias=bias, name="v_proj")
self.out_proj = tf.keras.layers.Dense(embed_dim, use_bias=bias, name="out_proj")
def _shape(self, tensor: tf.Tensor, seq_len: int, bsz: int):
return tf.transpose(tf.reshape(tensor, (bsz, seq_len, self.num_heads, self.head_dim)), (0, 2, 1, 3))
def call(
self,
hidden_states: tf.Tensor,
key_value_states: Optional[tf.Tensor] = None,
past_key_value: Optional[Tuple[Tuple[tf.Tensor]]] = None,
attention_mask: Optional[tf.Tensor] = None,
layer_head_mask: Optional[tf.Tensor] = None,
training=False,
) -> Tuple[tf.Tensor, Optional[tf.Tensor]]:
"""Input shape: Batch x Time x Channel"""
# if key_value_states are provided this layer is used as a cross-attention layer
# for the decoder
is_cross_attention = key_value_states is not None
bsz, tgt_len, embed_dim = shape_list(hidden_states)
# get query proj
query_states = self.q_proj(hidden_states) * self.scaling
# get key, value proj
if is_cross_attention and past_key_value is not None:
# reuse k,v, cross_attentions
key_states = past_key_value[0]
value_states = past_key_value[1]
elif is_cross_attention:
# cross_attentions
key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
elif past_key_value is not None:
# reuse k, v, self_attention
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
key_states = tf.concat([past_key_value[0], key_states], axis=2)
value_states = tf.concat([past_key_value[1], value_states], axis=2)
else:
# self_attention
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
if self.is_decoder:
# if cross_attention save Tuple(tf.Tensor, tf.Tensor) of all cross attention key/value_states.
# Further calls to cross_attention layer can then reuse all cross-attention
# key/value_states (first "if" case)
# if uni-directional self-attention (decoder) save Tuple(tf.Tensor, tf.Tensor) of
# all previous decoder key/value_states. Further calls to uni-directional self-attention
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_states, value_states)
proj_shape = (bsz * self.num_heads, -1, self.head_dim)
query_states = tf.reshape(self._shape(query_states, tgt_len, bsz), proj_shape)
key_states = tf.reshape(key_states, proj_shape)
value_states = tf.reshape(value_states, proj_shape)
src_len = shape_list(key_states)[1]
attn_weights = tf.matmul(query_states, key_states, transpose_b=True)
# The tf.debugging asserts are not compliant with XLA then they
# have to be disabled in other modes than eager.
if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(attn_weights),
[bsz * self.num_heads, tgt_len, src_len],
message=f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {shape_list(attn_weights)}",
)
if attention_mask is not None:
# The tf.debugging asserts are not compliant with XLA then they
# have to be disabled in other modes than eager.
if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(attention_mask),
[bsz, 1, tgt_len, src_len],
message=f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {shape_list(attention_mask)}",
)
attention_mask = tf.cast(attention_mask, dtype=attn_weights.dtype)
attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask
attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))
attn_weights = tf.nn.softmax(attn_weights, axis=-1)
if layer_head_mask is not None:
# The tf.debugging asserts are not compliant with XLA then they
# have to be disabled in other modes than eager.
if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(layer_head_mask),
[self.num_heads],
message=f"Head mask for a single layer should be of size {(self.num_heads)}, but is {shape_list(layer_head_mask)}",
)
attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
attn_weights, (bsz, self.num_heads, tgt_len, src_len)
)
attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))
attn_probs = self.dropout(attn_weights, training=training)
attn_output = tf.matmul(attn_probs, value_states)
# The tf.debugging asserts are not compliant with XLA then they
# have to be disabled in other modes than eager.
if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(attn_output),
[bsz * self.num_heads, tgt_len, self.head_dim],
message=f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {shape_list(attn_output)}",
)
attn_output = tf.transpose(
tf.reshape(attn_output, (bsz, self.num_heads, tgt_len, self.head_dim)), (0, 2, 1, 3)
)
attn_output = tf.reshape(attn_output, (bsz, tgt_len, embed_dim))
attn_output = self.out_proj(attn_output)
attn_weights: tf.Tensor = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len))
return attn_output, attn_weights, past_key_value
class TFWav2Vec2FeedForward(tf.keras.layers.Layer):
def __init__(self, config: Wav2Vec2Config, **kwargs):
super().__init__(**kwargs)
self.intermediate_dropout = tf.keras.layers.Dropout(config.activation_dropout)
self.intermediate_dense = tf.keras.layers.Dense(
units=config.intermediate_size,
kernel_initializer=get_initializer(config.initializer_range),
bias_initializer="zeros",
name="intermediate_dense",
)
self.intermediate_act_fn = get_tf_activation(config.hidden_act)
self.output_dense = tf.keras.layers.Dense(
units=config.hidden_size,
kernel_initializer=get_initializer(config.initializer_range),
bias_initializer="zeros",
name="output_dense",
)
self.output_dropout = tf.keras.layers.Dropout(config.hidden_dropout)
def call(self, hidden_states: tf.Tensor, training: bool = False) -> tf.Tensor:
hidden_states = self.intermediate_dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states)
hidden_states = self.intermediate_dropout(hidden_states, training=training)
hidden_states = self.output_dense(hidden_states)
hidden_states = self.output_dropout(hidden_states, training=training)
return hidden_states
class TFWav2Vec2EncoderLayer(tf.keras.layers.Layer):
def __init__(self, config: Wav2Vec2Config, **kwargs):
super().__init__(**kwargs)
self.attention = TFWav2Vec2Attention(
embed_dim=config.hidden_size,
num_heads=config.num_attention_heads,
dropout=config.attention_dropout,
is_decoder=False,
name="attention",
)
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout)
self.layer_norm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm")
self.feed_forward = TFWav2Vec2FeedForward(config, name="feed_forward")
self.final_layer_norm = tf.keras.layers.LayerNormalization(
epsilon=config.layer_norm_eps, name="final_layer_norm"
)
def call(
self,
hidden_states: tf.Tensor,
attention_mask: Optional[tf.Tensor] = None,
output_attentions: Optional[bool] = False,
training: bool = False,
) -> Tuple[tf.Tensor]:
attn_residual = hidden_states
hidden_states, attn_weights, _ = self.attention(
hidden_states, attention_mask=attention_mask, training=training
)
hidden_states = self.dropout(hidden_states, training=training)
hidden_states = attn_residual + hidden_states
hidden_states = self.layer_norm(hidden_states)
hidden_states = hidden_states + self.feed_forward(hidden_states)
hidden_states = self.final_layer_norm(hidden_states)
outputs = (hidden_states,)
if output_attentions:
outputs += (attn_weights,)
return outputs
class TFWav2Vec2EncoderLayerStableLayerNorm(tf.keras.layers.Layer):
def __init__(self, config: Wav2Vec2Config, **kwargs):
super().__init__(**kwargs)
self.attention = TFWav2Vec2Attention(
embed_dim=config.hidden_size,
num_heads=config.num_attention_heads,
dropout=config.attention_dropout,
is_decoder=False,
name="attention",
)
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout)
self.layer_norm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm")
self.feed_forward = TFWav2Vec2FeedForward(config, name="feed_forward")
self.final_layer_norm = tf.keras.layers.LayerNormalization(
epsilon=config.layer_norm_eps, name="final_layer_norm"
)
def call(
self,
hidden_states: tf.Tensor,
attention_mask: Optional[tf.Tensor] = None,
output_attentions: Optional[bool] = False,
training: bool = False,
) -> Tuple[tf.Tensor]:
attn_residual = hidden_states
hidden_states = self.layer_norm(hidden_states)
hidden_states, attn_weights, _ = self.attention(
hidden_states, attention_mask=attention_mask, training=training
)
hidden_states = self.dropout(hidden_states, training=training)
hidden_states = attn_residual + hidden_states
hidden_states = hidden_states + self.feed_forward(self.final_layer_norm(hidden_states))
outputs = (hidden_states,)
if output_attentions:
outputs += (attn_weights,)
return outputs
class TFWav2Vec2Encoder(tf.keras.layers.Layer):
def __init__(self, config: Wav2Vec2Config, **kwargs):
super().__init__(**kwargs)
self.config = config
self.pos_conv_embed = TFWav2Vec2PositionalConvEmbedding(config, name="pos_conv_embed")
self.layer_norm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm")
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout)
self.layer = [TFWav2Vec2EncoderLayer(config, name=f"layers.{i}") for i in range(config.num_hidden_layers)]
def call(
self,
hidden_states: tf.Tensor,
attention_mask: Optional[tf.Tensor] = None,
output_attentions: Optional[bool] = False,
output_hidden_states: Optional[bool] = False,
return_dict: Optional[bool] = True,
training: Optional[bool] = False,
) -> Union[TFBaseModelOutput, Tuple[tf.Tensor]]:
all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None
if attention_mask is not None:
hidden_states = hidden_states * tf.expand_dims(attention_mask, -1)
attention_mask = _expand_mask(attention_mask)
else:
attention_mask = None
position_embeddings = self.pos_conv_embed(hidden_states)
hidden_states = hidden_states + position_embeddings
hidden_states = self.layer_norm(hidden_states)
hidden_states = self.dropout(hidden_states, training=training)
for i, layer_module in enumerate(self.layer):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
dropout_probability = np.random.uniform(0, 1)
if training and (dropout_probability < self.config.layerdrop): # skip the layer
continue
layer_outputs = layer_module(
hidden_states=hidden_states,
attention_mask=attention_mask,
output_attentions=output_attentions,
training=training,
)
hidden_states = layer_outputs[0]
if output_attentions:
all_self_attentions = all_self_attentions + (layer_outputs[1],)
# Add last layer
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if not return_dict:
return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
return TFBaseModelOutput(
last_hidden_state=hidden_states,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
)
class TFWav2Vec2EncoderStableLayerNorm(tf.keras.layers.Layer):
def __init__(self, config: Wav2Vec2Config, **kwargs):
super().__init__(**kwargs)
self.config = config
self.pos_conv_embed = TFWav2Vec2PositionalConvEmbedding(config, name="pos_conv_embed")
self.layer_norm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm")
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout)
self.layer = [
TFWav2Vec2EncoderLayerStableLayerNorm(config, name=f"layers.{i}") for i in range(config.num_hidden_layers)
]
def call(
self,
hidden_states: tf.Tensor,
attention_mask: Optional[tf.Tensor] = None,
output_attentions: Optional[bool] = False,
output_hidden_states: Optional[bool] = False,
return_dict: Optional[bool] = True,
training: Optional[bool] = False,
) -> Union[TFBaseModelOutput, Tuple[tf.Tensor]]:
all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None
if attention_mask is not None:
hidden_states = hidden_states * tf.expand_dims(attention_mask, -1)
attention_mask = _expand_mask(attention_mask)
else:
attention_mask = None
position_embeddings = self.pos_conv_embed(hidden_states)
hidden_states = hidden_states + position_embeddings
hidden_states = self.dropout(hidden_states, training=training)
for i, layer_module in enumerate(self.layer):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
dropout_probability = np.random.uniform(0, 1)
if training and (dropout_probability < self.config.layerdrop): # skip the layer
continue
layer_outputs = layer_module(
hidden_states=hidden_states,
attention_mask=attention_mask,
output_attentions=output_attentions,
training=training,
)
hidden_states = layer_outputs[0]
if output_attentions:
all_self_attentions = all_self_attentions + (layer_outputs[1],)
hidden_states = self.layer_norm(hidden_states)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if not return_dict:
return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
return TFBaseModelOutput(
last_hidden_state=hidden_states,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
)
@keras_serializable
class TFWav2Vec2MainLayer(tf.keras.layers.Layer):
config_class = Wav2Vec2Config
def __init__(self, config: Wav2Vec2Config, **kwargs):
super().__init__(**kwargs)
self.config = config
self.feature_extractor = TFWav2Vec2FeatureExtractor(config, name="feature_extractor")
self.feature_projection = TFWav2Vec2FeatureProjection(config, name="feature_projection")
if config.do_stable_layer_norm:
self.encoder = TFWav2Vec2EncoderStableLayerNorm(config, name="encoder")
else:
self.encoder = TFWav2Vec2Encoder(config, name="encoder")
def build(self, input_shape: tf.TensorShape):
self.masked_spec_embed = self.add_weight(
shape=(self.config.hidden_size,), initializer="uniform", trainable=True, name="masked_spec_embed"
)
super().build(input_shape)
def _get_feat_extract_output_lengths(self, input_lengths: tf.Tensor):
"""
Computes the output length of the convolutional layers
"""
def _conv_out_length(input_length, kernel_size, stride):
# 1D convolutional layer output length formula taken
# from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
return (input_length - kernel_size) // stride + 1
for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
input_lengths = _conv_out_length(input_lengths, kernel_size, stride)
return input_lengths
def _mask_hidden_states(
self, hidden_states: tf.Tensor, mask_time_indices: Optional[tf.Tensor] = None, training: bool = False
):
"""
Masks extracted features along time axis and/or along feature axis according to `SpecAugment
<https://arxiv.org/abs/1904.08779>`__ .
"""
# `config.apply_spec_augment` can set masking to False
if not getattr(self.config, "apply_spec_augment", True):
return hidden_states
if mask_time_indices is not None:
# apply SpecAugment along time axis with given mask_time_indices
hidden_states = tf.tensor_scatter_nd_update(hidden_states, mask_time_indices, self.masked_spec_embed)
elif self.config.mask_time_prob > 0 and training:
# generate indices & apply SpecAugment along time axis
batch_size, sequence_length, hidden_size = hidden_states.shape
mask_time_indices = _compute_mask_indices(
(batch_size, sequence_length),
self.config.mask_time_prob,
self.config.mask_time_length,
min_masks=2,
)
hidden_states = tf.tensor_scatter_nd_update(hidden_states, mask_time_indices, self.masked_spec_embed)
# apply SpecAugment along feature axis
if self.config.mask_feature_prob > 0 and training:
mask_feature_indices = _compute_mask_indices(
(batch_size, hidden_size),
mask_prob=self.config.mask_feature_prob,
mask_length=self.config.mask_feature_length,
)
hidden_states = tf.tensor_scatter_nd_update(hidden_states, mask_feature_indices, self.masked_spec_embed)
return hidden_states
def call(
self,
input_values: tf.Tensor,
attention_mask: Optional[tf.Tensor] = None,
token_type_ids: Optional[tf.Tensor] = None,
position_ids: Optional[tf.Tensor] = None,
head_mask: Optional[tf.Tensor] = None,
inputs_embeds: Optional[tf.Tensor] = None,
output_attentions: Optional[tf.Tensor] = None,
output_hidden_states: Optional[tf.Tensor] = None,
return_dict: Optional[bool] = None,
training: bool = False,
**kwargs: Any,
):
inputs = input_values_processing(
func=self.call,
config=self.config,
input_values=input_values,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
kwargs_call=kwargs,
)
hidden_states = self.feature_extractor(
tf.cast(inputs["input_values"], tf.float32), training=inputs["training"]
)
if inputs["attention_mask"] is not None:
# compute real output lengths according to convolution formula
output_lengths = self._get_feat_extract_output_lengths(tf.reduce_sum(inputs["attention_mask"], -1))
attention_mask = tf.sequence_mask(output_lengths, dtype=hidden_states.dtype)
hidden_states = self.feature_projection(hidden_states, training=inputs["training"])
mask_time_indices = kwargs.get("mask_time_indices", None)
if mask_time_indices is not None: # apply SpecAugment along time axis with given indices
hidden_states = tf.tensor_scatter_nd_update(hidden_states, mask_time_indices, self.mask_spec_embed)
hidden_states = self._mask_hidden_states(hidden_states)
encoder_outputs = self.encoder(
hidden_states,
attention_mask=attention_mask,
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=inputs["return_dict"],
training=inputs["training"],
)
hidden_states = encoder_outputs[0]
if not inputs["return_dict"]:
return (hidden_states,) + encoder_outputs[1:]
return TFBaseModelOutput(
last_hidden_state=hidden_states,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
)
class TFWav2Vec2PreTrainedModel(TFPreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""
config_class = Wav2Vec2Config
base_model_prefix = "wav2vec2"
@property
def dummy_inputs(self) -> Dict[str, tf.Tensor]:
pad_token = 0.0
input_values = tf.convert_to_tensor(np.random.rand(1, 16000), tf.float32)
dummy_inputs = {
"input_values": input_values,
"attention_mask": tf.cast(tf.not_equal(input_values, pad_token), tf.float32),
}
return dummy_inputs
@tf.function
def serving(self, inputs):
output = self.call(input_values=inputs, training=False)
return self.serving_output(output)
WAV_2_VEC_2_START_DOCSTRING = r"""
This model inherits from :class:`~transformers.TFPreTrainedModel`. Check the superclass documentation for the
generic methods the library implements for all its model (such as downloading or saving, resizing the input
embeddings, pruning heads etc.)
This model is also a `tf.keras.Model <https://www.tensorflow.org/api_docs/python/tf/keras/Model>`__ subclass. Use
it as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage
and behavior.
.. note::
TF 2.0 models accepts two formats as inputs:
- having all inputs as keyword arguments (like PyTorch models), or
- having all inputs as a list, tuple or dict in the first positional arguments.
This second option is useful when using :meth:`tf.keras.Model.fit` method which currently requires having all
the tensors in the first argument of the model call function: :obj:`model(inputs)`.
If you choose this second option, there are three possibilities you can use to gather all the input Tensors in
the first positional argument :
- a single Tensor with :obj:`input_values` only and nothing else: :obj:`model(inputs_ids)`
- a list of varying length with one or several input Tensors IN THE ORDER given in the docstring:
:obj:`model([input_values, attention_mask])` or :obj:`model([input_values, attention_mask, token_type_ids])`
- a dictionary with one or several input Tensors associated to the input names given in the docstring:
:obj:`model({"input_values": input_values, "token_type_ids": token_type_ids})`
Args:
config (:class:`~transformers.Wav2Vec2Config`): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model
weights.
"""
WAV_2_VEC_2_INPUTS_DOCSTRING = r"""
Args:
input_values (:obj:`np.ndarray`, :obj:`tf.Tensor`, :obj:`List[tf.Tensor]` :obj:`Dict[str, tf.Tensor]` or :obj:`Dict[str, np.ndarray]` and each example must have the shape :obj:`({0})`):
Indices of input sequence tokens in the vocabulary.
Indices can be obtained using :class:`~transformers.BertTokenizer`. See
:func:`transformers.PreTrainedTokenizer.__call__` and :func:`transformers.PreTrainedTokenizer.encode` for
details.
`What are input IDs? <../glossary.html#input-ids>`__
attention_mask (:obj:`np.ndarray` or :obj:`tf.Tensor` of shape :obj:`({0})`, `optional`):
Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
`What are attention masks? <../glossary.html#attention-mask>`__
token_type_ids (:obj:`np.ndarray` or :obj:`tf.Tensor` of shape :obj:`({0})`, `optional`):
Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0,
1]``:
- 0 corresponds to a `sentence A` token,
- 1 corresponds to a `sentence B` token.
`What are token type IDs? <../glossary.html#token-type-ids>`__
position_ids (:obj:`np.ndarray` or :obj:`tf.Tensor` of shape :obj:`({0})`, `optional`):
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0,
config.max_position_embeddings - 1]``.
`What are position IDs? <../glossary.html#position-ids>`__
head_mask (:obj:`np.ndarray` or :obj:`tf.Tensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`):
Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
inputs_embeds (:obj:`np.ndarray` or :obj:`tf.Tensor` of shape :obj:`({0}, hidden_size)`, `optional`):
Optionally, instead of passing :obj:`input_values` you can choose to directly pass an embedded
representation. This is useful if you want more control over how to convert :obj:`input_values` indices
into associated vectors than the model's internal embedding lookup matrix.
output_attentions (:obj:`bool`, `optional`):
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
tensors for more detail. This argument can be used only in eager mode, in graph mode the value in the
config will be used instead.
output_hidden_states (:obj:`bool`, `optional`):
Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for
more detail. This argument can be used only in eager mode, in graph mode the value in the config will be
used instead.
return_dict (:obj:`bool`, `optional`):
Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple. This
argument can be used in eager mode, in graph mode the value will always be set to True.
training (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to use the model in training mode (some modules like dropout modules have different
behaviors between training and evaluation).
"""
@add_start_docstrings(
"The bare TFWav2Vec2 Model transformer outputing raw hidden-states without any specific head on top.",
WAV_2_VEC_2_START_DOCSTRING,
)
class TFWav2Vec2Model(TFWav2Vec2PreTrainedModel):
def __init__(self, config: Wav2Vec2Config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
self.config = config
self.wav2vec2 = TFWav2Vec2MainLayer(config, name="wav2vec2")
@add_start_docstrings_to_model_forward(WAV_2_VEC_2_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=TFBaseModelOutput, config_class=_CONFIG_FOR_DOC)
def call(
self,
input_values: tf.Tensor,
attention_mask: Optional[tf.Tensor] = None,
token_type_ids: Optional[tf.Tensor] = None,
position_ids: Optional[tf.Tensor] = None,
head_mask: Optional[tf.Tensor] = None,
inputs_embeds: Optional[tf.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
training: bool = False,
**kwargs: Any,
) -> Union[TFBaseModelOutput, Tuple[tf.Tensor]]:
"""
Returns:
Example::
>>> from transformers import Wav2Vec2Processor, TFWav2Vec2Model
>>> from datasets import load_dataset
>>> import soundfile as sf
>>> processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
>>> model = TFWav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h")
>>> def map_to_array(batch):
>>> speech, _ = sf.read(batch["file"])
>>> batch["speech"] = speech
>>> return batch
>>> ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation")
>>> ds = ds.map(map_to_array)
>>> input_values = processor(ds["speech"][0], return_tensors="tf").input_values # Batch size 1
>>> hidden_states = model(input_values).last_hidden_state
"""
inputs = input_values_processing(
func=self.call,
config=self.config,
input_values=input_values,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
kwargs_call=kwargs,
)
inputs["output_hidden_states"] = (
inputs["output_hidden_states"] if inputs["output_hidden_states"] else self.config.output_hidden_states
)
inputs["output_attentions"] = (
inputs["output_attentions"] if inputs["output_attentions"] else self.config.output_attentions
)
inputs["return_dict"] = inputs["return_dict"] if inputs["return_dict"] else self.config.return_dict
outputs = self.wav2vec2(
input_values=inputs["input_values"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=inputs["return_dict"],
training=inputs["training"],
)
return outputs
def serving_output(self, output):
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
return TFBaseModelOutput(last_hidden_state=output.last_hidden_state, hidden_states=hs, attentions=attns)
@add_start_docstrings(
"""TFWav2Vec2 Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC). """,
WAV_2_VEC_2_START_DOCSTRING,
)
class TFWav2Vec2ForCTC(TFWav2Vec2PreTrainedModel):
def __init__(self, config: Wav2Vec2Config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
self.wav2vec2 = TFWav2Vec2MainLayer(config, name="wav2vec2")
self.dropout = tf.keras.layers.Dropout(config.final_dropout)
self.lm_head = tf.keras.layers.Dense(config.vocab_size, name="lm_head")
def freeze_feature_extractor(self):
"""
Calling this function will disable the gradient computation for the feature extractor so that its parameter
will not be updated during training.
"""
self.wav2vec2.feature_extractor.trainable = False
@add_start_docstrings_to_model_forward(WAV_2_VEC_2_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=TFCausalLMOutput, config_class=_CONFIG_FOR_DOC)
def call(
self,
input_values: tf.Tensor,
attention_mask: Optional[tf.Tensor] = None,
token_type_ids: Optional[tf.Tensor] = None,
position_ids: Optional[tf.Tensor] = None,
head_mask: Optional[tf.Tensor] = None,
inputs_embeds: Optional[tf.Tensor] = None,
output_attentions: Optional[bool] = None,
labels: Optional[tf.Tensor] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
training: Optional[bool] = False,
**kwargs: Any,
) -> Union[TFCausalLMOutput, Tuple[tf.Tensor]]:
r"""
labels (:obj:`tf.Tensor` or :obj:`np.ndarray` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ...,
config.vocab_size]`` (see ``input_values`` docstring) Tokens with indices set to ``-100`` are ignored
(masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
Returns:
Example::
>>> import tensorflow as tf
>>> from transformers import Wav2Vec2Processor, TFWav2Vec2ForCTC
>>> from datasets import load_dataset
>>> import soundfile as sf
>>> processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
>>> model = TFWav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
>>> def map_to_array(batch):
>>> speech, _ = sf.read(batch["file"])
>>> batch["speech"] = speech
>>> return batch
>>> ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation")
>>> ds = ds.map(map_to_array)
>>> input_values = processor(ds["speech"][0], return_tensors="tf").input_values # Batch size 1
>>> logits = model(input_values).logits >>> predicted_ids = tf.argmax(logits, axis=-1)
>>> transcription = processor.decode(predicted_ids[0])
>>> # compute loss
>>> target_transcription = "A MAN SAID TO THE UNIVERSE SIR I EXIST"
>>> # wrap processor as target processor to encode labels
>>> with processor.as_target_processor():
>>> labels = processor(transcription, return_tensors="tf").input_values
>>> loss = model(input_values, labels=labels).loss
"""
inputs = input_values_processing(
func=self.call,
config=self.config,
input_values=input_values,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
kwargs_call=kwargs,
)
outputs = self.wav2vec2(
input_values=inputs["input_values"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=inputs["return_dict"],
training=inputs["training"],
)
hidden_states = outputs[0]
hidden_states = self.dropout(hidden_states, training=inputs["training"])
logits = self.lm_head(hidden_states)
if labels is not None:
attention_mask = (
inputs["attention_mask"]
if inputs["attention_mask"] is not None
else tf.ones_like(inputs["input_values"], dtype=tf.float32)
)
input_lengths = self.wav2vec2._get_feat_extract_output_lengths(tf.reduce_sum(attention_mask, axis=-1))
# assuming that padded tokens are filled with -100
# when not being attended to
labels_mask = tf.cast(labels >= 0, tf.int32)
target_lengths = tf.reduce_sum(labels_mask, axis=-1)
flattened_labels = tf.boolean_mask(labels, labels_mask)
flattened_labels = tf.reshape(flattened_labels, [labels.shape[0], -1])
loss = tf.nn.ctc_loss(
logits=logits,
labels=flattened_labels,
logit_length=input_lengths,
label_length=target_lengths,
blank_index=self.config.pad_token_id,
logits_time_major=False,
)
if self.config.ctc_loss_reduction == "sum":
loss = tf.reduce_sum(loss)
if self.config.ctc_loss_reduction == "mean":
loss = tf.reduce_mean(loss)
else:
loss = None
if not inputs["return_dict"]:
output = (logits,) + outputs[1:]
return ((loss,) + output) if loss is not None else output
return TFCausalLMOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def serving_output(self, output: TFCausalLMOutput) -> TFCausalLMOutput:
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
return TFCausalLMOutput(logits=output.logits, hidden_states=hs, attentions=attns)
...@@ -510,14 +510,6 @@ class Wav2Vec2FeedForward(nn.Module): ...@@ -510,14 +510,6 @@ class Wav2Vec2FeedForward(nn.Module):
return hidden_states return hidden_states
class Wav2Vec2Output(nn.Module):
def __init__(self, config):
super().__init__()
def forward(self, hidden_states, input_tensor):
return hidden_states
class Wav2Vec2EncoderLayer(nn.Module): class Wav2Vec2EncoderLayer(nn.Module):
def __init__(self, config): def __init__(self, config):
super().__init__() super().__init__()
......
...@@ -1647,6 +1647,32 @@ class TFTransfoXLPreTrainedModel: ...@@ -1647,6 +1647,32 @@ class TFTransfoXLPreTrainedModel:
requires_backends(cls, ["tf"]) requires_backends(cls, ["tf"])
TF_WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST = None
class TFWav2Vec2ForCTC:
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
class TFWav2Vec2Model:
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["tf"])
class TFWav2Vec2PreTrainedModel:
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["tf"])
TF_XLM_PRETRAINED_MODEL_ARCHIVE_LIST = None TF_XLM_PRETRAINED_MODEL_ARCHIVE_LIST = None
......
...@@ -445,6 +445,8 @@ class TFModelTesterMixin: ...@@ -445,6 +445,8 @@ class TFModelTesterMixin:
for name, key in self._prepare_for_class(inputs_dict, model_class).items(): for name, key in self._prepare_for_class(inputs_dict, model_class).items():
if type(key) == bool: if type(key) == bool:
pt_inputs_dict[name] = key pt_inputs_dict[name] = key
elif name == "input_values":
pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.float32)
else: else:
pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.long) pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.long)
...@@ -455,6 +457,7 @@ class TFModelTesterMixin: ...@@ -455,6 +457,7 @@ class TFModelTesterMixin:
with torch.no_grad(): with torch.no_grad():
pto = pt_model(**pt_inputs_dict) pto = pt_model(**pt_inputs_dict)
tfo = tf_model(self._prepare_for_class(inputs_dict, model_class), training=False) tfo = tf_model(self._prepare_for_class(inputs_dict, model_class), training=False)
tf_hidden_states = tfo[0].numpy() tf_hidden_states = tfo[0].numpy()
pt_hidden_states = pto[0].numpy() pt_hidden_states = pto[0].numpy()
...@@ -486,6 +489,8 @@ class TFModelTesterMixin: ...@@ -486,6 +489,8 @@ class TFModelTesterMixin:
if type(key) == bool: if type(key) == bool:
key = np.array(key, dtype=bool) key = np.array(key, dtype=bool)
pt_inputs_dict[name] = torch.from_numpy(key).to(torch.long) pt_inputs_dict[name] = torch.from_numpy(key).to(torch.long)
elif name == "input_values":
pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.float32)
else: else:
pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.long) pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.long)
# need to rename encoder-decoder "inputs" for PyTorch # need to rename encoder-decoder "inputs" for PyTorch
...@@ -1061,7 +1066,7 @@ class TFModelTesterMixin: ...@@ -1061,7 +1066,7 @@ class TFModelTesterMixin:
def test_lm_head_model_random_no_beam_search_generate(self): def test_lm_head_model_random_no_beam_search_generate(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
input_ids = inputs_dict["input_ids"] input_ids = inputs_dict.get("input_ids", None)
# iterate over all generative models # iterate over all generative models
for model_class in self.all_generative_model_classes: for model_class in self.all_generative_model_classes:
...@@ -1097,7 +1102,7 @@ class TFModelTesterMixin: ...@@ -1097,7 +1102,7 @@ class TFModelTesterMixin:
def test_lm_head_model_random_beam_search_generate(self): def test_lm_head_model_random_beam_search_generate(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
input_ids = inputs_dict["input_ids"] input_ids = inputs_dict.get("input_ids", None)
for model_class in self.all_generative_model_classes: for model_class in self.all_generative_model_classes:
model = model_class(config) model = model_class(config)
......
# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import inspect
import math
import unittest
import numpy as np
from transformers import Wav2Vec2Config, is_tf_available
from transformers.testing_utils import require_datasets, require_soundfile, require_tf, slow
from .test_configuration_common import ConfigTester
from .test_modeling_tf_common import TFModelTesterMixin, ids_tensor
if is_tf_available():
import tensorflow as tf
from transformers import TFWav2Vec2ForCTC, TFWav2Vec2Model, Wav2Vec2Processor
from transformers.models.wav2vec2.modeling_tf_wav2vec2 import _compute_mask_indices
@require_tf
class TFWav2Vec2ModelTester:
def __init__(
self,
parent,
batch_size=13,
seq_length=1024,
is_training=False,
hidden_size=16,
feat_extract_norm="group",
feat_extract_dropout=0.0,
feat_extract_activation="gelu",
conv_dim=(32, 32, 32),
conv_stride=(4, 4, 4),
conv_kernel=(8, 8, 8),
conv_bias=False,
num_conv_pos_embeddings=16,
num_conv_pos_embedding_groups=2,
num_hidden_layers=4,
num_attention_heads=2,
hidden_dropout_prob=0.1, # this is most likely not correctly set yet
intermediate_size=20,
layer_norm_eps=1e-5,
hidden_act="gelu",
initializer_range=0.02,
vocab_size=32,
do_stable_layer_norm=False,
scope=None,
):
self.parent = parent
self.batch_size = batch_size
self.seq_length = seq_length
self.is_training = is_training
self.hidden_size = hidden_size
self.feat_extract_norm = feat_extract_norm
self.feat_extract_dropout = feat_extract_dropout
self.feat_extract_activation = feat_extract_activation
self.conv_dim = conv_dim
self.conv_stride = conv_stride
self.conv_kernel = conv_kernel
self.conv_bias = conv_bias
self.num_conv_pos_embeddings = num_conv_pos_embeddings
self.num_conv_pos_embedding_groups = num_conv_pos_embedding_groups
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.hidden_dropout_prob = hidden_dropout_prob
self.intermediate_size = intermediate_size
self.layer_norm_eps = layer_norm_eps
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.vocab_size = vocab_size
self.do_stable_layer_norm = do_stable_layer_norm
self.scope = scope
output_seq_length = self.seq_length
for kernel, stride in zip(self.conv_kernel, self.conv_stride):
output_seq_length = (output_seq_length - (kernel - 1)) / stride
self.output_seq_length = int(math.ceil(output_seq_length))
self.encoder_seq_length = self.output_seq_length
def prepare_config_and_inputs(self):
input_values = tf.cast(ids_tensor([self.batch_size, self.seq_length], 32768), tf.float32) / 32768.0
attention_mask = tf.ones_like(input_values)
config = Wav2Vec2Config(
hidden_size=self.hidden_size,
feat_extract_norm=self.feat_extract_norm,
feat_extract_dropout=self.feat_extract_dropout,
feat_extract_activation=self.feat_extract_activation,
conv_dim=self.conv_dim,
conv_stride=self.conv_stride,
conv_kernel=self.conv_kernel,
conv_bias=self.conv_bias,
num_conv_pos_embeddings=self.num_conv_pos_embeddings,
num_conv_pos_embedding_groups=self.num_conv_pos_embedding_groups,
num_hidden_layers=self.num_hidden_layers,
num_attention_heads=self.num_attention_heads,
hidden_dropout_prob=self.hidden_dropout_prob,
intermediate_size=self.intermediate_size,
layer_norm_eps=self.layer_norm_eps,
hidden_act=self.hidden_act,
initializer_range=self.initializer_range,
vocab_size=self.vocab_size,
do_stable_layer_norm=self.do_stable_layer_norm,
)
return config, input_values, attention_mask
def create_and_check_model(self, config, input_values, attention_mask):
model = TFWav2Vec2Model(config)
result = model(input_values, attention_mask=attention_mask)
self.parent.assertEqual(
result.last_hidden_state.shape, (self.batch_size, self.output_seq_length, self.hidden_size)
)
def create_and_check_batch_inference(self, config, input_values, *args):
# test does not pass for models making use of `group_norm`
# check: https://github.com/pytorch/fairseq/issues/3227
config.layerdrop = 0.0
model = TFWav2Vec2Model(config)
input_values = input_values[:3]
attention_mask = tf.ones_like(input_values)
input_lengths = tf.constant([input_values.shape[-1] // i for i in [4, 2, 1]])
length_mask = tf.sequence_mask(input_lengths, dtype=tf.float32)
# convert values that are over input_lengths to padding
input_values = input_values * length_mask
attention_mask = attention_mask * length_mask
batch_outputs = model(input_values, attention_mask=attention_mask, training=False).last_hidden_state
for i in range(input_values.shape[0]):
input_slice = input_values[i : i + 1, : input_lengths[i]]
output = model(input_slice, training=False).last_hidden_state
batch_output = batch_outputs[i : i + 1, : output.shape[1]]
self.parent.assertTrue(np.allclose(output, batch_output, atol=1e-3))
def check_ctc_loss(self, config, input_values, *args):
model = TFWav2Vec2ForCTC(config)
input_values = input_values[:3]
attention_mask = tf.ones_like(input_values)
input_lengths = tf.constant([input_values.shape[-1] // i for i in [4, 2, 1]])
max_length_labels = model.wav2vec2._get_feat_extract_output_lengths(input_lengths)
labels = ids_tensor((input_values.shape[0], min(max_length_labels) - 1), model.config.vocab_size)
length_mask = tf.sequence_mask(input_lengths, dtype=tf.float32)
# convert values that are over input_lengths to padding
input_values = input_values * length_mask
attention_mask = attention_mask * length_mask
model.config.ctc_loss_reduction = "sum"
sum_loss = model(input_values, attention_mask=attention_mask, labels=labels).loss
model.config.ctc_loss_reduction = "mean"
mean_loss = model(input_values, attention_mask=attention_mask, labels=labels).loss
self.parent.assertTrue(abs(labels.shape[0] * mean_loss - sum_loss) < 1e-2)
def check_training(self, config, input_values, *args):
model = TFWav2Vec2ForCTC(config)
# freeze feature encoder
model.freeze_feature_extractor()
input_values = input_values[:3]
input_lengths = tf.constant([input_values.shape[-1] // i for i in [4, 2, 1]])
max_length_labels = model.wav2vec2._get_feat_extract_output_lengths(input_lengths)
labels = ids_tensor((input_values.shape[0], max(max_length_labels) - 2), model.config.vocab_size)
length_mask = tf.sequence_mask(input_lengths, dtype=tf.float32)
input_values = input_values * length_mask
pad_size = max(max_length_labels) - labels.shape[1]
labels = tf.pad(labels, ((0, 0), (0, pad_size)), constant_values=-100)
loss = model(input_values, labels=labels, training=True).loss
self.parent.assertFalse(tf.math.is_inf(loss))
def prepare_config_and_inputs_for_common(self):
config, input_values, attention_mask = self.prepare_config_and_inputs()
inputs_dict = {"input_values": input_values, "attention_mask": attention_mask}
return config, inputs_dict
@require_tf
class TFWav2Vec2ModelTest(TFModelTesterMixin, unittest.TestCase):
all_model_classes = (TFWav2Vec2Model, TFWav2Vec2ForCTC) if is_tf_available() else ()
test_resize_embeddings = False
test_head_masking = False
test_onnx = False
def setUp(self):
self.model_tester = TFWav2Vec2ModelTester(self)
self.config_tester = ConfigTester(self, config_class=Wav2Vec2Config, hidden_size=37)
def test_config(self):
self.config_tester.run_common_tests()
# overwrite because input_values != input_ids
def test_forward_signature(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
signature = inspect.signature(model.call)
# signature.parameters is an OrderedDict => so arg_names order is deterministic
arg_names = [*signature.parameters.keys()]
expected_arg_names = ["input_values"]
self.assertListEqual(arg_names[:1], expected_arg_names)
# overwrite because input_values != input_ids
def test_keyword_and_dict_args(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
inputs = self._prepare_for_class(inputs_dict, model_class)
outputs_dict = model(inputs)
inputs_keywords = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class))
input_values = inputs_keywords.pop("input_values", None)
outputs_keywords = model(input_values, **inputs_keywords)
output_dict = outputs_dict[0].numpy()
output_keywords = outputs_keywords[0].numpy()
self.assertLess(np.sum(np.abs(output_dict - output_keywords)), 1e-6)
def test_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
def test_hidden_states_output(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
def check_hidden_states_output(config, inputs_dict, model_class):
model = model_class(config)
outputs = model(self._prepare_for_class(inputs_dict, model_class))
expected_num_layers = getattr(
self.model_tester, "expected_num_hidden_layers", self.model_tester.num_hidden_layers + 1
)
hidden_states = outputs.hidden_states
self.assertEqual(config.output_attentions, False)
self.assertEqual(len(hidden_states), expected_num_layers)
self.assertListEqual(
list(hidden_states[0].shape[-2:]),
[self.model_tester.output_seq_length, self.model_tester.hidden_size],
)
for model_class in self.all_model_classes:
inputs_dict["output_hidden_states"] = True
check_hidden_states_output(config, inputs_dict, model_class)
del inputs_dict["output_hidden_states"]
config.output_hidden_states = True
check_hidden_states_output(config, inputs_dict, model_class)
def test_ctc_loss_inference(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.check_ctc_loss(*config_and_inputs)
def test_train(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.check_training(*config_and_inputs)
# Wav2Vec2 has no inputs_embeds
def test_inputs_embeds(self):
pass
# Wav2Vec2 cannot resize token embeddings
# since it has no tokens embeddings
def test_resize_tokens_embeddings(self):
pass
# Wav2Vec2 has no inputs_embeds
# and thus the `get_input_embeddings` fn
# is not implemented
def test_model_common_attributes(self):
pass
@slow
def test_model_from_pretrained(self):
model = TFWav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h")
self.assertIsNotNone(model)
@require_tf
class TFWav2Vec2RobustModelTest(TFModelTesterMixin, unittest.TestCase):
all_model_classes = (TFWav2Vec2Model, TFWav2Vec2ForCTC) if is_tf_available() else ()
test_resize_embeddings = False
test_head_masking = False
test_onnx = False
def setUp(self):
self.model_tester = TFWav2Vec2ModelTester(
self,
conv_stride=(3, 3, 3),
feat_extract_norm="layer",
do_stable_layer_norm=True,
scope="robust",
)
self.config_tester = ConfigTester(self, config_class=Wav2Vec2Config, hidden_size=37)
# overwrite because input_values != input_ids
def test_forward_signature(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
signature = inspect.signature(model.call)
# signature.parameters is an OrderedDict => so arg_names order is deterministic
arg_names = [*signature.parameters.keys()]
expected_arg_names = ["input_values"]
self.assertListEqual(arg_names[:1], expected_arg_names)
# overwrite because input_values != input_ids
def test_keyword_and_dict_args(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
inputs = self._prepare_for_class(inputs_dict, model_class)
outputs_dict = model(inputs)
inputs_keywords = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class))
input_values = inputs_keywords.pop("input_values", None)
outputs_keywords = model(input_values, **inputs_keywords)
output_dict = outputs_dict[0].numpy()
output_keywords = outputs_keywords[0].numpy()
self.assertLess(np.sum(np.abs(output_dict - output_keywords)), 1e-6)
def test_config(self):
self.config_tester.run_common_tests()
def test_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
def test_hidden_states_output(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
def check_hidden_states_output(config, inputs_dict, model_class):
model = model_class(config)
outputs = model(self._prepare_for_class(inputs_dict, model_class))
expected_num_layers = getattr(
self.model_tester, "expected_num_hidden_layers", self.model_tester.num_hidden_layers + 1
)
hidden_states = outputs.hidden_states
self.assertEqual(config.output_attentions, False)
self.assertEqual(len(hidden_states), expected_num_layers)
self.assertListEqual(
list(hidden_states[0].shape[-2:]),
[self.model_tester.output_seq_length, self.model_tester.hidden_size],
)
for model_class in self.all_model_classes:
inputs_dict["output_hidden_states"] = True
check_hidden_states_output(config, inputs_dict, model_class)
del inputs_dict["output_hidden_states"]
config.output_hidden_states = True
check_hidden_states_output(config, inputs_dict, model_class)
def test_batched_inference(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_batch_inference(*config_and_inputs)
def test_ctc_loss_inference(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.check_ctc_loss(*config_and_inputs)
def test_train(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.check_training(*config_and_inputs)
# Wav2Vec2 has no inputs_embeds
def test_inputs_embeds(self):
pass
# Wav2Vec2 cannot resize token embeddings
# since it has no tokens embeddings
def test_resize_tokens_embeddings(self):
pass
# Wav2Vec2 has no inputs_embeds
# and thus the `get_input_embeddings` fn
# is not implemented
def test_model_common_attributes(self):
pass
@slow
def test_model_from_pretrained(self):
model = TFWav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h")
self.assertIsNotNone(model)
@require_tf
class TFWav2Vec2UtilsTest(unittest.TestCase):
def test_compute_mask_indices(self):
batch_size = 4
sequence_length = 60
mask_prob = 0.5
mask_length = 1
mask = _compute_mask_indices((batch_size, sequence_length), mask_prob, mask_length)
self.assertListEqual(
tf.reduce_sum(mask, -1).numpy().tolist(), [mask_prob * sequence_length for _ in range(batch_size)]
)
def test_compute_mask_indices_overlap(self):
batch_size = 4
sequence_length = 80
mask_prob = 0.5
mask_length = 4
mask = _compute_mask_indices((batch_size, sequence_length), mask_prob, mask_length)
# because of overlap mask don't have to add up exactly to `mask_prob * sequence_length`, but have to be smaller or equal
for batch_sum in tf.reduce_sum(mask, -1):
self.assertTrue(int(batch_sum) <= mask_prob * sequence_length)
@require_tf
@slow
@require_datasets
@require_soundfile
class TFWav2Vec2ModelIntegrationTest(unittest.TestCase):
def _load_datasamples(self, num_samples):
from datasets import load_dataset
import soundfile as sf
ids = [f"1272-141231-000{i}" for i in range(num_samples)]
# map files to raw
def map_to_array(batch):
speech, _ = sf.read(batch["file"])
batch["speech"] = speech
return batch
ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation")
ds = ds.filter(lambda x: x["id"] in ids).sort("id").map(map_to_array)
return ds["speech"][:num_samples]
def test_inference_ctc_normal(self):
model = TFWav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h", do_lower_case=True)
input_speech = self._load_datasamples(1)
input_values = processor(input_speech, return_tensors="tf", sampling_rate=16000).input_values
logits = model(input_values).logits
predicted_ids = tf.argmax(logits, axis=-1)
predicted_trans = processor.batch_decode(predicted_ids)
EXPECTED_TRANSCRIPTIONS = ["a man said to the universe sir i exist"]
self.assertListEqual(predicted_trans, EXPECTED_TRANSCRIPTIONS)
def test_inference_ctc_normal_batched(self):
model = TFWav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h", do_lower_case=True)
input_speech = self._load_datasamples(2)
input_values = processor(
input_speech, return_tensors="tf", padding=True, truncation=True, sampling_rate=16000
).input_values
logits = model(input_values).logits
predicted_ids = tf.argmax(logits, axis=-1)
predicted_trans = processor.batch_decode(predicted_ids)
EXPECTED_TRANSCRIPTIONS = [
"a man said to the universe sir i exist",
"sweat covered brion's body trickling into the tight lowing cloth that was the only garment he wore",
]
self.assertListEqual(predicted_trans, EXPECTED_TRANSCRIPTIONS)
def test_inference_ctc_robust_batched(self):
model = TFWav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-large-960h-lv60-self")
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-large-960h-lv60-self", do_lower_case=True)
input_speech = self._load_datasamples(4)
inputs = processor(input_speech, return_tensors="tf", padding=True, truncation=True)
input_values = inputs.input_values
attention_mask = inputs.attention_mask
logits = model(input_values, attention_mask=attention_mask).logits
predicted_ids = tf.argmax(logits, dim=-1)
predicted_trans = processor.batch_decode(predicted_ids)
EXPECTED_TRANSCRIPTIONS = [
"a man said to the universe sir i exist",
"sweat covered brion's body trickling into the tight loin cloth that was the only garment he wore",
"the cut on his chest still dripping blood the ache of his overstrained eyes even the soaring arena around him with the thousands of spectators were trivialities not worth thinking about",
"his instant panic was followed by a small sharp blow high on his chest",
]
self.assertListEqual(predicted_trans, EXPECTED_TRANSCRIPTIONS)
...@@ -126,6 +126,7 @@ IGNORE_NON_AUTO_CONFIGURED = [ ...@@ -126,6 +126,7 @@ IGNORE_NON_AUTO_CONFIGURED = [
"VisualBertForVisualReasoning", "VisualBertForVisualReasoning",
"VisualBertForQuestionAnswering", "VisualBertForQuestionAnswering",
"VisualBertForMultipleChoice", "VisualBertForMultipleChoice",
"TFWav2Vec2ForCTC",
] ]
# This is to make sure the transformers module imported is the one in the repo. # This is to make sure the transformers module imported is the one in the repo.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment