Unverified Commit 2fbbcf50 authored by Yoach Lacombe's avatar Yoach Lacombe Committed by GitHub
Browse files

Fix M4T for ASR pipeline (#32296)

* tentative fix

* do the same for M4T
parent 084b5094
......@@ -3154,6 +3154,7 @@ class SeamlessM4TForSpeechToText(SeamlessM4TPreTrainedModel):
"""
text_decoder_input_ids = kwargs.pop("decoder_input_ids", None)
# overwrite text_decoder_input_ids if tgt_lang is passed. The latter gets priority over decoder_input_ids.
input_features = input_features if input_features is not None else kwargs.pop("inputs")
if tgt_lang is not None:
inputs = kwargs.get("input_embeds") if input_features is None else input_features
inputs = (
......
......@@ -3422,6 +3422,7 @@ class SeamlessM4Tv2ForSpeechToText(SeamlessM4Tv2PreTrainedModel):
"""
text_decoder_input_ids = kwargs.pop("decoder_input_ids", None)
# overwrite text_decoder_input_ids if tgt_lang is passed. The latter gets priority over decoder_input_ids.
input_features = input_features if input_features is not None else kwargs.pop("inputs")
if tgt_lang is not None:
inputs = kwargs.get("input_embeds") if input_features is None else input_features
inputs = (
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment