Unverified Commit e3139ad3 authored by Sreyan Ghosh's avatar Sreyan Ghosh Committed by GitHub
Browse files

fixed calculation of ctc loss in TFWav2Vec2ForCTC (#18014)


Co-authored-by: default avatarSreyan-G@NVIDIA <sreyang@nvidia.com>
parent 96d833b2
......@@ -25,7 +25,13 @@ import tensorflow as tf
from ...activations_tf import get_tf_activation
from ...modeling_tf_outputs import TFBaseModelOutput, TFCausalLMOutput
from ...modeling_tf_utils import TFPreTrainedModel, booleans_processing, get_initializer, keras_serializable
from ...modeling_tf_utils import (
TFPreTrainedModel,
booleans_processing,
get_initializer,
keras_serializable,
unpack_inputs,
)
from ...tf_utils import shape_list, stable_softmax
from ...utils import (
ModelOutput,
......@@ -1580,6 +1586,7 @@ class TFWav2Vec2ForCTC(TFWav2Vec2PreTrainedModel):
"""
self.wav2vec2.feature_extractor.trainable = False
@unpack_inputs
@add_start_docstrings_to_model_forward(WAV_2_VEC_2_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=TFCausalLMOutput, config_class=_CONFIG_FOR_DOC)
def call(
......@@ -1702,6 +1709,8 @@ class TFWav2Vec2ForCTC(TFWav2Vec2PreTrainedModel):
loss = tf.reduce_sum(loss)
if self.config.ctc_loss_reduction == "mean":
loss = tf.reduce_mean(loss)
loss = tf.reshape(loss, (1,))
else:
loss = None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment