"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "f861504466bf19c607e3c8407be6194a565afc00"
Unverified commit ad1d3c4d, authored by Yih-Dar, committed by GitHub
Browse files

Make TF Wav2Vec2 outputs the same as PT's version (#15530)



* fix outputs

* fix for CTC

* fix doc

* make style
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
parent 131e2584
......@@ -16,6 +16,7 @@
import inspect
import warnings
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple, Union
import numpy as np
......@@ -43,6 +44,9 @@ from .configuration_wav2vec2 import Wav2Vec2Config
logger = logging.get_logger(__name__)
_HIDDEN_STATES_START_POSITION = 2
_CHECKPOINT_FOR_DOC = "facebook/wav2vec2-base-960h"
_CONFIG_FOR_DOC = "Wav2Vec2Config"
_TOKENIZER_FOR_DOC = "Wav2Vec2Tokenizer"
......@@ -58,6 +62,35 @@ TF_WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST = [
LARGE_NEGATIVE = -1e8
@dataclass
class TFWav2Vec2BaseModelOutput(ModelOutput):
"""
Output type of [`TFWav2Vec2BaseModelOutput`], with potential hidden states and attentions.
Args:
last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
extract_features (`tf.Tensor` of shape `(batch_size, sequence_length, conv_dim[-1])`):
Sequence of extracted feature vectors of the last convolutional layer of the model.
hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
last_hidden_state: tf.Tensor = None
extract_features: tf.Tensor = None
hidden_states: Optional[Tuple[tf.Tensor]] = None
attentions: Optional[Tuple[tf.Tensor]] = None
def input_values_processing(func, config, input_values, **kwargs):
"""
Process the input of each TensorFlow model including the booleans. In case of a list of symbolic inputs, each input
......@@ -707,10 +740,10 @@ class TFWav2Vec2FeatureProjection(tf.keras.layers.Layer):
self.dropout = tf.keras.layers.Dropout(rate=config.feat_proj_dropout)
def call(self, hidden_states: tf.Tensor, training: bool = False) -> tf.Tensor:
hidden_states = self.layer_norm(hidden_states)
hidden_states = self.projection(hidden_states)
norm_hidden_states = self.layer_norm(hidden_states)
hidden_states = self.projection(norm_hidden_states)
hidden_states = self.dropout(hidden_states, training=training)
return hidden_states
return hidden_states, norm_hidden_states
# Copied from transformers.models.bart.modeling_tf_bart.TFBartAttention with TFBart->TFWav2Vec2
......@@ -1222,19 +1255,20 @@ class TFWav2Vec2MainLayer(tf.keras.layers.Layer):
kwargs_call=kwargs,
)
hidden_states = self.feature_extractor(
extract_features = self.feature_extractor(
tf.cast(inputs["input_values"], tf.float32), training=inputs["training"]
)
# extract_features = tf.transpose(extract_features, perm=(0, 2, 1))
if inputs["attention_mask"] is not None:
# compute real output lengths according to convolution formula
output_lengths = self._get_feat_extract_output_lengths(tf.reduce_sum(inputs["attention_mask"], -1))
attention_mask = tf.sequence_mask(
output_lengths, maxlen=shape_list(hidden_states)[1], dtype=hidden_states.dtype
output_lengths, maxlen=shape_list(extract_features)[1], dtype=extract_features.dtype
)
hidden_states = self.feature_projection(hidden_states, training=inputs["training"])
hidden_states, extract_features = self.feature_projection(extract_features, training=inputs["training"])
mask_time_indices = kwargs.get("mask_time_indices", None)
if inputs["training"]:
......@@ -1251,10 +1285,11 @@ class TFWav2Vec2MainLayer(tf.keras.layers.Layer):
hidden_states = encoder_outputs[0]
if not inputs["return_dict"]:
return (hidden_states,) + encoder_outputs[1:]
return (hidden_states, extract_features) + encoder_outputs[1:]
return TFBaseModelOutput(
return TFWav2Vec2BaseModelOutput(
last_hidden_state=hidden_states,
extract_features=extract_features,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
)
......@@ -1635,7 +1670,7 @@ class TFWav2Vec2ForCTC(TFWav2Vec2PreTrainedModel):
loss = None
if not inputs["return_dict"]:
output = (logits,) + outputs[1:]
output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
return ((loss,) + output) if loss is not None else output
return TFCausalLMOutput(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment