"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "22b41b3f8a5cdb37e686d18d8d9a24eb98a331ec"
Unverified Commit ad1d3c4d authored by Yih-Dar, committed by GitHub

Make TF Wav2Vec2 outputs the same as PT's version (#15530)



* fix outputs

* fix for CTC

* fix doc

* make style

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
parent 131e2584
@@ -16,6 +16,7 @@
 import inspect
 import warnings
+from dataclasses import dataclass
 from typing import Any, Dict, Optional, Tuple, Union
 
 import numpy as np
@@ -43,6 +44,9 @@ from .configuration_wav2vec2 import Wav2Vec2Config
 
 logger = logging.get_logger(__name__)
 
+_HIDDEN_STATES_START_POSITION = 2
+
 _CHECKPOINT_FOR_DOC = "facebook/wav2vec2-base-960h"
 _CONFIG_FOR_DOC = "Wav2Vec2Config"
 _TOKENIZER_FOR_DOC = "Wav2Vec2Tokenizer"
@@ -58,6 +62,35 @@ TF_WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST = [
 
 LARGE_NEGATIVE = -1e8
 
+
+@dataclass
+class TFWav2Vec2BaseModelOutput(ModelOutput):
+    """
+    Output type of [`TFWav2Vec2BaseModelOutput`], with potential hidden states and attentions.
+
+    Args:
+        last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`):
+            Sequence of hidden-states at the output of the last layer of the model.
+        extract_features (`tf.Tensor` of shape `(batch_size, sequence_length, conv_dim[-1])`):
+            Sequence of extracted feature vectors of the last convolutional layer of the model.
+        hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
+            `(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+        attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attentions weights after the attention softmax, used to compute the weighted average in the
+            self-attention heads.
+    """
+
+    last_hidden_state: tf.Tensor = None
+    extract_features: tf.Tensor = None
+    hidden_states: Optional[Tuple[tf.Tensor]] = None
+    attentions: Optional[Tuple[tf.Tensor]] = None
+
+
 def input_values_processing(func, config, input_values, **kwargs):
     """
     Process the input of each TensorFlow model including the booleans. In case of a list of symbolic inputs, each input
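
For orientation, a minimal sketch of what the new output type means for callers, assuming the facebook/wav2vec2-base-960h checkpoint named in _CHECKPOINT_FOR_DOC above (the random waveform is only a placeholder):

import numpy as np
from transformers import Wav2Vec2Processor, TFWav2Vec2Model

processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
model = TFWav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h")

# placeholder one-second waveform; any 16 kHz mono float array works
speech = np.random.randn(16000).astype(np.float32)
inputs = processor(speech, sampling_rate=16000, return_tensors="tf")

outputs = model(inputs.input_values)
# the TF model now exposes the conv features alongside last_hidden_state,
# matching PT's Wav2Vec2BaseModelOutput
print(outputs.last_hidden_state.shape)  # (1, frames, hidden_size)
print(outputs.extract_features.shape)   # (1, frames, conv_dim[-1])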
@@ -707,10 +740,10 @@ class TFWav2Vec2FeatureProjection(tf.keras.layers.Layer):
         self.dropout = tf.keras.layers.Dropout(rate=config.feat_proj_dropout)
 
     def call(self, hidden_states: tf.Tensor, training: bool = False) -> tf.Tensor:
-        hidden_states = self.layer_norm(hidden_states)
-        hidden_states = self.projection(hidden_states)
+        norm_hidden_states = self.layer_norm(hidden_states)
+        hidden_states = self.projection(norm_hidden_states)
         hidden_states = self.dropout(hidden_states, training=training)
-        return hidden_states
+        return hidden_states, norm_hidden_states
 
 
 # Copied from transformers.models.bart.modeling_tf_bart.TFBartAttention with TFBart->TFWav2Vec2
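
To make the new two-value contract of TFWav2Vec2FeatureProjection.call explicit, here is a standalone sketch with hypothetical layer sizes (768 hidden, 512 conv features), not the real class: the layer keeps the pre-projection, layer-normed features around so the caller can expose them as extract_features, mirroring the PT version.

import tensorflow as tf

layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5)
projection = tf.keras.layers.Dense(768)
dropout = tf.keras.layers.Dropout(0.1)

def feature_projection(hidden_states, training=False):
    # same shape as the edited call(): return both the projected states
    # and the normed conv features
    norm_hidden_states = layer_norm(hidden_states)
    hidden_states = projection(norm_hidden_states)
    hidden_states = dropout(hidden_states, training=training)
    return hidden_states, norm_hidden_states

features = tf.random.normal([1, 49, 512])  # (batch, frames, conv_dim[-1])
projected, extract_features = feature_projection(features)
print(projected.shape, extract_features.shape)  # (1, 49, 768) (1, 49, 512)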
@@ -1222,19 +1255,20 @@ class TFWav2Vec2MainLayer(tf.keras.layers.Layer):
             kwargs_call=kwargs,
         )
 
-        hidden_states = self.feature_extractor(
+        extract_features = self.feature_extractor(
             tf.cast(inputs["input_values"], tf.float32), training=inputs["training"]
         )
+        # extract_features = tf.transpose(extract_features, perm=(0, 2, 1))
 
         if inputs["attention_mask"] is not None:
             # compute real output lengths according to convolution formula
             output_lengths = self._get_feat_extract_output_lengths(tf.reduce_sum(inputs["attention_mask"], -1))
 
             attention_mask = tf.sequence_mask(
-                output_lengths, maxlen=shape_list(hidden_states)[1], dtype=hidden_states.dtype
+                output_lengths, maxlen=shape_list(extract_features)[1], dtype=extract_features.dtype
             )
 
-        hidden_states = self.feature_projection(hidden_states, training=inputs["training"])
+        hidden_states, extract_features = self.feature_projection(extract_features, training=inputs["training"])
 
         mask_time_indices = kwargs.get("mask_time_indices", None)
         if inputs["training"]:
@@ -1251,10 +1285,11 @@ class TFWav2Vec2MainLayer(tf.keras.layers.Layer):
         hidden_states = encoder_outputs[0]
 
         if not inputs["return_dict"]:
-            return (hidden_states,) + encoder_outputs[1:]
+            return (hidden_states, extract_features) + encoder_outputs[1:]
 
-        return TFBaseModelOutput(
+        return TFWav2Vec2BaseModelOutput(
             last_hidden_state=hidden_states,
+            extract_features=extract_features,
             hidden_states=encoder_outputs.hidden_states,
             attentions=encoder_outputs.attentions,
         )
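
Continuing the sketch above, the non-return_dict path changes shape too: the output tuple gains extract_features at index 1, which is exactly what makes _HIDDEN_STATES_START_POSITION = 2 below correct.

# model and inputs as in the earlier sketch
tuple_outputs = model(inputs.input_values, output_hidden_states=True, return_dict=False)
last_hidden_state = tuple_outputs[0]
extract_features = tuple_outputs[1]  # new in this commit
hidden_states = tuple_outputs[2]     # per-layer states start here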
@@ -1635,7 +1670,7 @@ class TFWav2Vec2ForCTC(TFWav2Vec2PreTrainedModel):
         loss = None
 
         if not inputs["return_dict"]:
-            output = (logits,) + outputs[1:]
+            output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
             return ((loss,) + output) if loss is not None else output
 
         return TFCausalLMOutput(
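
And the reason the CTC head now slices from _HIDDEN_STATES_START_POSITION rather than 1: its tuple output replaces last_hidden_state with logits and must also drop extract_features. A toy illustration with strings standing in for tensors:

_HIDDEN_STATES_START_POSITION = 2

# placeholder strings, only to show the index arithmetic
outputs = ("last_hidden_state", "extract_features", "hidden_states", "attentions")
logits = "logits"

# slicing from 1 would wrongly keep extract_features in the CTC output
output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
print(output)  # ('logits', 'hidden_states', 'attentions')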