"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "c18b4fbe9f8d05c96deff23ee92037912f68a50c"
Unverified Commit a3dbbc34 authored by Nicolas Patry, committed by GitHub

Add `decoder_kwargs` to send to LM on asr pipeline. (#15646)


Co-authored-by: Giuseppe Attanasio <giuseppeattanasio6@gmail.com>
parent cdf19c50
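
For reference, a minimal usage sketch of the new call-time argument. The checkpoint name and audio path below are illustrative placeholders, not part of this commit; `beam_width` is a keyword of pyctcdecode's `decode_beams`, as exercised in the test change further down.

```python
# Minimal sketch: any Wav2Vec2-style checkpoint that ships a pyctcdecode language
# model should work; the model id and audio file here are placeholders.
from transformers import pipeline

asr = pipeline(
    "automatic-speech-recognition",
    model="patrickvonplaten/wav2vec2-base-100h-with-lm",  # assumed CTC + LM checkpoint
)

# `decoder_kwargs` is forwarded verbatim to `self.decoder.decode_beams(...)`,
# so any keyword accepted by pyctcdecode's beam search (e.g. `beam_width`) can be tuned.
output = asr("sample.wav", decoder_kwargs={"beam_width": 100})
print(output["text"])
```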
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from collections import defaultdict
-from typing import TYPE_CHECKING, Union
+from typing import TYPE_CHECKING, Dict, Optional, Union
 
 import numpy as np
@@ -180,7 +180,11 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
         if "stride_length_s" in kwargs:
             preprocess_params["stride_length_s"] = kwargs["stride_length_s"]
 
-        return preprocess_params, {}, {}
+        postprocess_params = {}
+        if "decoder_kwargs" in kwargs:
+            postprocess_params["decoder_kwargs"] = kwargs["decoder_kwargs"]
+
+        return preprocess_params, {}, postprocess_params
 
     def preprocess(self, inputs, chunk_length_s=0, stride_length_s=None):
         if isinstance(inputs, str):
@@ -319,7 +323,7 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
             extra = model_inputs
         return {"is_last": is_last, **out, **extra}
 
-    def postprocess(self, model_outputs):
+    def postprocess(self, model_outputs, decoder_kwargs: Optional[Dict] = None):
         if self.type == "ctc_with_lm":
             final_logits = []
             for outputs in model_outputs:
@@ -334,9 +338,11 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
                     right_n = total_n - right
                     logits = logits[:, left:right_n]
                 final_logits.append(logits)
+            if decoder_kwargs is None:
+                decoder_kwargs = {}
             logits = np.concatenate(final_logits, axis=1)
             logits = logits.squeeze(0)
-            text = self.decoder.decode_beams(logits)[0][0]
+            text = self.decoder.decode_beams(logits, **decoder_kwargs)[0][0]
         else:
             skip_special_tokens = self.type != "ctc"
             tokens = np.concatenate([outputs["tokens"].numpy() for outputs in model_outputs], axis=-1)
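
Why returning `decoder_kwargs` in the third dict of `_sanitize_parameters` is enough to reach `postprocess`: the pipeline base class splits call-time kwargs into preprocess/forward/postprocess parameter dicts and passes the third one to `postprocess`. Below is a simplified, self-contained stand-in for that routing pattern (illustrative only, not the transformers implementation).

```python
# Toy stand-in for the kwargs routing used by pipelines (illustrative only):
# the third dict returned by _sanitize_parameters ends up as keyword arguments
# of postprocess(), which is exactly how `decoder_kwargs` travels in this commit.
class ToyPipeline:
    def _sanitize_parameters(self, **kwargs):
        preprocess_params, forward_params, postprocess_params = {}, {}, {}
        if "decoder_kwargs" in kwargs:
            postprocess_params["decoder_kwargs"] = kwargs["decoder_kwargs"]
        return preprocess_params, forward_params, postprocess_params

    def __call__(self, inputs, **kwargs):
        _pre, _fwd, post = self._sanitize_parameters(**kwargs)
        model_outputs = inputs  # stand-in for the real preprocess + forward steps
        return self.postprocess(model_outputs, **post)

    def postprocess(self, model_outputs, decoder_kwargs=None):
        decoder_kwargs = decoder_kwargs or {}
        return {"text": f"decoded with {decoder_kwargs}"}


print(ToyPipeline()("dummy audio", decoder_kwargs={"beam_width": 2}))
# -> {'text': "decoded with {'beam_width': 2}"}
```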
@@ -365,10 +365,17 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta):
         audio_tiled = np.tile(audio, n_repeats)
 
         output = speech_recognizer([audio_tiled], batch_size=2)
         self.assertEqual(output, [{"text": ANY(str)}])
         self.assertEqual(output[0]["text"][:6], "<s> <s")
 
+        # Making sure the arguments are passed to the decoder.
+        # Since no change happens in the result, check the error comes from
+        # the `decode_beams` function.
+        with self.assertRaises(TypeError) as e:
+            output = speech_recognizer([audio_tiled], decoder_kwargs={"num_beams": 2})
+            self.assertContains(e.msg, "TypeError: decode_beams() got an unexpected keyword argument 'num_beams'")
+
+        output = speech_recognizer([audio_tiled], decoder_kwargs={"beam_width": 2})
 
     @require_torch
     @require_pyctcdecode
     def test_with_local_lm_fast(self):