"vscode:/vscode.git/clone" did not exist on "391177441b133645c02181b57370ab12f71b88c4"
Unverified Commit 463226e2 authored by Nicolas Patry's avatar Nicolas Patry Committed by GitHub
Browse files

Improve error messaging for ASR pipeline. (#19570)

* Improve error messaging for ASR pipeline.

- Raise error early (in `_sanitize`) so users don't waste time trying to
  run queries with invalid params.

- Fix: the error was previously raised only after accessing
  `config.inputs_to_logits_ratio`, so our check was masked by the failure on a
  property that does not exist.

- Added a manual check on s2t for the error message.
  No non-CTC model seems to be used by the default runner (they are all
  skipped).

* Removing pdb.

* Stop raising the error early; it doesn't really work :(.
parent 5ef21866
...@@ -250,6 +250,10 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline): ...@@ -250,6 +250,10 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
raise ValueError("We expect a single channel audio input for AutomaticSpeechRecognitionPipeline") raise ValueError("We expect a single channel audio input for AutomaticSpeechRecognitionPipeline")
if chunk_length_s: if chunk_length_s:
if self.type not in {"ctc", "ctc_with_lm"}:
raise ValueError(
"`chunk_length_s` is only valid for CTC models, use other chunking options for other models"
)
if stride_length_s is None: if stride_length_s is None:
stride_length_s = chunk_length_s / 6 stride_length_s = chunk_length_s / 6
...@@ -264,10 +268,6 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline): ...@@ -264,10 +268,6 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
stride_left = int(round(stride_length_s[0] * self.feature_extractor.sampling_rate / align_to) * align_to) stride_left = int(round(stride_length_s[0] * self.feature_extractor.sampling_rate / align_to) * align_to)
stride_right = int(round(stride_length_s[1] * self.feature_extractor.sampling_rate / align_to) * align_to) stride_right = int(round(stride_length_s[1] * self.feature_extractor.sampling_rate / align_to) * align_to)
if self.type not in {"ctc", "ctc_with_lm"}:
raise ValueError(
"`chunk_length_s` is only valid for CTC models, use other chunking options for other models"
)
if chunk_len < stride_left + stride_right: if chunk_len < stride_left + stride_right:
raise ValueError("Chunk length must be superior to stride length") raise ValueError("Chunk length must be superior to stride length")
......
...@@ -118,9 +118,15 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel ...@@ -118,9 +118,15 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
}, },
) )
else: else:
# Non CTC models cannot use chunk_length
with self.assertRaises(ValueError) as v:
outputs = speech_recognizer(audio, chunk_length_s=10)
self.assertEqual(v.exception, "")
# Non CTC models cannot use return_timestamps # Non CTC models cannot use return_timestamps
with self.assertRaises(ValueError): with self.assertRaises(ValueError) as v:
outputs = speech_recognizer(audio, return_timestamps="char") outputs = speech_recognizer(audio, return_timestamps="char")
self.assertEqual(v.exception, "")
@require_torch @require_torch
@slow @slow
...@@ -138,6 +144,17 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel ...@@ -138,6 +144,17 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
waveform = np.tile(np.arange(1000, dtype=np.float32), 34) waveform = np.tile(np.arange(1000, dtype=np.float32), 34)
output = speech_recognizer(waveform) output = speech_recognizer(waveform)
self.assertEqual(output, {"text": "(Applaudissements)"}) self.assertEqual(output, {"text": "(Applaudissements)"})
with self.assertRaises(ValueError) as v:
_ = speech_recognizer(waveform, chunk_length_s=10)
self.assertEqual(
str(v.exception),
"`chunk_length_s` is only valid for CTC models, use other chunking options for other models",
)
# Non CTC models cannot use return_timestamps
with self.assertRaises(ValueError) as v:
_ = speech_recognizer(waveform, return_timestamps="char")
self.assertEqual(str(v.exception), "We cannot return_timestamps yet on non-ctc models !")
@require_torch @require_torch
def test_small_model_pt_seq2seq(self): def test_small_model_pt_seq2seq(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment