Unverified Commit a3dbbc34 authored by Nicolas Patry's avatar Nicolas Patry Committed by GitHub
Browse files

Add `decoder_kwargs` to send to LM on asr pipeline. (#15646)


Co-authored-by: default avatarGiuseppe Attanasio <giuseppeattanasio6@gmail.com>
Co-authored-by: default avatarGiuseppe Attanasio <giuseppeattanasio6@gmail.com>
parent cdf19c50
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from collections import defaultdict from collections import defaultdict
from typing import TYPE_CHECKING, Union from typing import TYPE_CHECKING, Dict, Optional, Union
import numpy as np import numpy as np
...@@ -180,7 +180,11 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline): ...@@ -180,7 +180,11 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
if "stride_length_s" in kwargs: if "stride_length_s" in kwargs:
preprocess_params["stride_length_s"] = kwargs["stride_length_s"] preprocess_params["stride_length_s"] = kwargs["stride_length_s"]
return preprocess_params, {}, {} postprocess_params = {}
if "decoder_kwargs" in kwargs:
postprocess_params["decoder_kwargs"] = kwargs["decoder_kwargs"]
return preprocess_params, {}, postprocess_params
def preprocess(self, inputs, chunk_length_s=0, stride_length_s=None): def preprocess(self, inputs, chunk_length_s=0, stride_length_s=None):
if isinstance(inputs, str): if isinstance(inputs, str):
...@@ -319,7 +323,7 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline): ...@@ -319,7 +323,7 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
extra = model_inputs extra = model_inputs
return {"is_last": is_last, **out, **extra} return {"is_last": is_last, **out, **extra}
def postprocess(self, model_outputs): def postprocess(self, model_outputs, decoder_kwargs: Optional[Dict] = None):
if self.type == "ctc_with_lm": if self.type == "ctc_with_lm":
final_logits = [] final_logits = []
for outputs in model_outputs: for outputs in model_outputs:
...@@ -334,9 +338,11 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline): ...@@ -334,9 +338,11 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
right_n = total_n - right right_n = total_n - right
logits = logits[:, left:right_n] logits = logits[:, left:right_n]
final_logits.append(logits) final_logits.append(logits)
if decoder_kwargs is None:
decoder_kwargs = {}
logits = np.concatenate(final_logits, axis=1) logits = np.concatenate(final_logits, axis=1)
logits = logits.squeeze(0) logits = logits.squeeze(0)
text = self.decoder.decode_beams(logits)[0][0] text = self.decoder.decode_beams(logits, **decoder_kwargs)[0][0]
else: else:
skip_special_tokens = self.type != "ctc" skip_special_tokens = self.type != "ctc"
tokens = np.concatenate([outputs["tokens"].numpy() for outputs in model_outputs], axis=-1) tokens = np.concatenate([outputs["tokens"].numpy() for outputs in model_outputs], axis=-1)
......
...@@ -365,10 +365,17 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel ...@@ -365,10 +365,17 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
audio_tiled = np.tile(audio, n_repeats) audio_tiled = np.tile(audio, n_repeats)
output = speech_recognizer([audio_tiled], batch_size=2) output = speech_recognizer([audio_tiled], batch_size=2)
self.assertEqual(output, [{"text": ANY(str)}]) self.assertEqual(output, [{"text": ANY(str)}])
self.assertEqual(output[0]["text"][:6], "<s> <s") self.assertEqual(output[0]["text"][:6], "<s> <s")
# Make sure the arguments are passed to the decoder
# Since no change happens in the result, check the error comes from
# the `decode_beams` function.
with self.assertRaises(TypeError) as e:
output = speech_recognizer([audio_tiled], decoder_kwargs={"num_beams": 2})
self.assertContains(e.msg, "TypeError: decode_beams() got an unexpected keyword argument 'num_beams'")
output = speech_recognizer([audio_tiled], decoder_kwargs={"beam_width": 2})
@require_torch @require_torch
@require_pyctcdecode @require_pyctcdecode
def test_with_local_lm_fast(self): def test_with_local_lm_fast(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment