"...centerface_pytorch.git" did not exist on "718c73f033aad4a2a915eb94904389b9e0622f61"
Unverified Commit e3139ad3 authored by Sreyan Ghosh's avatar Sreyan Ghosh Committed by GitHub
Browse files

fixed calculation of ctc loss in TFWav2Vec2ForCTC (#18014)


Co-authored-by: default avatarSreyan-G@NVIDIA <sreyang@nvidia.com>
parent 96d833b2
...@@ -25,7 +25,13 @@ import tensorflow as tf ...@@ -25,7 +25,13 @@ import tensorflow as tf
from ...activations_tf import get_tf_activation from ...activations_tf import get_tf_activation
from ...modeling_tf_outputs import TFBaseModelOutput, TFCausalLMOutput from ...modeling_tf_outputs import TFBaseModelOutput, TFCausalLMOutput
from ...modeling_tf_utils import TFPreTrainedModel, booleans_processing, get_initializer, keras_serializable from ...modeling_tf_utils import (
TFPreTrainedModel,
booleans_processing,
get_initializer,
keras_serializable,
unpack_inputs,
)
from ...tf_utils import shape_list, stable_softmax from ...tf_utils import shape_list, stable_softmax
from ...utils import ( from ...utils import (
ModelOutput, ModelOutput,
...@@ -1580,6 +1586,7 @@ class TFWav2Vec2ForCTC(TFWav2Vec2PreTrainedModel): ...@@ -1580,6 +1586,7 @@ class TFWav2Vec2ForCTC(TFWav2Vec2PreTrainedModel):
""" """
self.wav2vec2.feature_extractor.trainable = False self.wav2vec2.feature_extractor.trainable = False
@unpack_inputs
@add_start_docstrings_to_model_forward(WAV_2_VEC_2_INPUTS_DOCSTRING) @add_start_docstrings_to_model_forward(WAV_2_VEC_2_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=TFCausalLMOutput, config_class=_CONFIG_FOR_DOC) @replace_return_docstrings(output_type=TFCausalLMOutput, config_class=_CONFIG_FOR_DOC)
def call( def call(
...@@ -1702,6 +1709,8 @@ class TFWav2Vec2ForCTC(TFWav2Vec2PreTrainedModel): ...@@ -1702,6 +1709,8 @@ class TFWav2Vec2ForCTC(TFWav2Vec2PreTrainedModel):
loss = tf.reduce_sum(loss) loss = tf.reduce_sum(loss)
if self.config.ctc_loss_reduction == "mean": if self.config.ctc_loss_reduction == "mean":
loss = tf.reduce_mean(loss) loss = tf.reduce_mean(loss)
loss = tf.reshape(loss, (1,))
else: else:
loss = None loss = None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment