Unverified commit dedd1116 authored by Sanchit Gandhi, committed by GitHub

[ASR Pipeline] Clarify return timestamps (#25344)

* [ASR Pipeline] Clarify return timestamps

* fix indentation

* fix ctc check

* fix ctc error message!

* fix test

* fix other test

* add new tests

* final comment
parent 5ea2595e
@@ -156,8 +156,7 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
         feature_extractor ([`SequenceFeatureExtractor`]):
             The feature extractor that will be used by the pipeline to encode waveform for the model.
         chunk_length_s (`float`, *optional*, defaults to 0):
-            The input length for in each chunk. If `chunk_length_s = 0` then chunking is disabled (default). Only
-            available for CTC models, e.g. [`Wav2Vec2ForCTC`].
+            The input length for in each chunk. If `chunk_length_s = 0` then chunking is disabled (default).

         <Tip>
@@ -247,14 +246,29 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
                 np.array}` with optionally a `"stride": (left: int, right: int)` than can ask the pipeline to
                 treat the first `left` samples and last `right` samples to be ignored in decoding (but used at
                 inference to provide more context to the model). Only use `stride` with CTC models.
-            return_timestamps (*optional*, `str`):
-                Only available for pure CTC models. If set to `"char"`, the pipeline will return timestamps along the
-                text for every character in the text. For instance if you get `[{"text": "h", "timestamp": (0.5, 0.6)},
-                {"text": "i", "timestamp": (0.7, 0.9)}]`, then it means the model predicts that the letter "h" was
-                pronounced after `0.5` and before `0.6` seconds. If set to `"word"`, the pipeline will return
-                timestamps along the text for every word in the text. For instance if you get `[{"text": "hi ",
-                "timestamp": (0.5, 0.9)}, {"text": "there", "timestamp": (1.0, 1.5)}]`, then it means the model
-                predicts that the word "hi" was pronounced after `0.5` and before `0.9` seconds.
+            return_timestamps (*optional*, `str` or `bool`):
+                Only available for pure CTC models (Wav2Vec2, HuBERT, etc) and the Whisper model. Not available for
+                other sequence-to-sequence models.
+
+                For CTC models, timestamps can take one of two formats:
+                    - `"char"`: the pipeline will return timestamps along the text for every character in the text. For
+                        instance, if you get `[{"text": "h", "timestamp": (0.5, 0.6)}, {"text": "i", "timestamp": (0.7,
+                        0.9)}]`, then it means the model predicts that the letter "h" was spoken after `0.5` and before
+                        `0.6` seconds.
+                    - `"word"`: the pipeline will return timestamps along the text for every word in the text. For
+                        instance, if you get `[{"text": "hi ", "timestamp": (0.5, 0.9)}, {"text": "there", "timestamp":
+                        (1.0, 1.5)}]`, then it means the model predicts that the word "hi" was spoken after `0.5` and
+                        before `0.9` seconds.
+
+                For the Whisper model, timestamps can take one of two formats:
+                    - `"word"`: same as above for word-level CTC timestamps. Word-level timestamps are predicted
+                        through the *dynamic-time warping (DTW)* algorithm, an approximation to word-level timestamps
+                        by inspecting the cross-attention weights.
+                    - `True`: the pipeline will return timestamps along the text for *segments* of words in the text.
+                        For instance, if you get `[{"text": " Hi there!", "timestamp": (0.5, 1.5)}]`, then it means the
+                        model predicts that the segment "Hi there!" was spoken after `0.5` and before `1.5` seconds.
+                        Note that a segment of text refers to a sequence of one or more words, rather than individual
+                        words as with word-level timestamps.
             generate_kwargs (`dict`, *optional*):
                 The dictionary of ad-hoc parametrization of `generate_config` to be used for the generation call. For a
                 complete overview of generate, check the [following
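To make the documented options concrete, here is a minimal usage sketch for both model families. The checkpoints (`openai/whisper-tiny`, `facebook/wav2vec2-base-960h`) and the file name `sample.flac` are illustrative placeholders; any Whisper or CTC checkpoint on the Hub works the same way.

```python
from transformers import pipeline

# Whisper: segment-level timestamps with True, word-level with "word"
whisper_asr = pipeline("automatic-speech-recognition", model="openai/whisper-tiny")
out = whisper_asr("sample.flac", return_timestamps=True)
print(out["text"])    # full transcription
print(out["chunks"])  # e.g. [{"text": " Hi there!", "timestamp": (0.5, 1.5)}]

# CTC (e.g. Wav2Vec2): the granularity must be chosen explicitly
ctc_asr = pipeline("automatic-speech-recognition", model="facebook/wav2vec2-base-960h")
out = ctc_asr("sample.flac", return_timestamps="word")
print(out["chunks"])  # e.g. [{"text": "HI ", "timestamp": (0.5, 0.9)}, ...]
```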
@@ -264,12 +278,12 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
         Return:
             `Dict`: A dictionary with the following keys:
-                - **text** (`str` ) -- The recognized text.
+                - **text** (`str`): The recognized text.
                 - **chunks** (*optional(, `List[Dict]`)
                     When using `return_timestamps`, the `chunks` will become a list containing all the various text
                     chunks identified by the model, *e.g.* `[{"text": "hi ", "timestamp": (0.5, 0.9)}, {"text":
                     "there", "timestamp": (1.0, 1.5)}]`. The original full text can roughly be recovered by doing
                     `"".join(chunk["text"] for chunk in output["chunks"])`.
         """
         return super().__call__(inputs, **kwargs)
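The recovery trick in the docstring works because each chunk's text carries its own spacing, so a plain join approximates the full transcription (exact whitespace may differ around chunk boundaries). Using the `out` from the sketch above:

```python
# Reassemble the transcription from the timestamped chunks
full_text = "".join(chunk["text"] for chunk in out["chunks"])
```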
@@ -308,6 +322,20 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
         if decoder_kwargs is not None:
             postprocess_params["decoder_kwargs"] = decoder_kwargs
         if return_timestamps is not None:
+            if self.type == "seq2seq" and return_timestamps:
+                raise ValueError("We cannot return_timestamps yet on non-CTC models apart from Whisper!")
+            if self.type == "ctc_with_lm" and return_timestamps != "word":
+                raise ValueError("CTC with LM can only predict word level timestamps, set `return_timestamps='word'`")
+            if self.type == "ctc" and return_timestamps not in ["char", "word"]:
+                raise ValueError(
+                    "CTC can either predict character (char) level timestamps, or word level timestamps."
+                    "Set `return_timestamps='char'` or `return_timestamps='word'` as required."
+                )
+            if self.type == "seq2seq_whisper" and return_timestamps == "char":
+                raise ValueError(
+                    "Whisper cannot return `char` timestamps, only word level or segment level timestamps. "
+                    "Use `return_timestamps='word'` or `return_timestamps=True` respectively."
+                )
             forward_params["return_timestamps"] = return_timestamps
             postprocess_params["return_timestamps"] = return_timestamps
         if return_language is not None:
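The four new checks amount to a small validity matrix over pipeline types. The sketch below is a standalone paraphrase of that logic, not pipeline code: the names `VALID_TIMESTAMPS` and `check_return_timestamps` are invented here, and it only covers the truthy values the checks guard against.

```python
# Hypothetical summary of the accepted return_timestamps values per pipeline
# type, mirroring the four ValueError branches added in the diff above.
VALID_TIMESTAMPS = {
    "seq2seq": set(),                   # plain seq2seq models: no timestamps at all
    "ctc": {"char", "word"},            # CTC: granularity must be chosen explicitly
    "ctc_with_lm": {"word"},            # CTC + LM decoding: word level only
    "seq2seq_whisper": {"word", True},  # Whisper: word level (DTW) or segment level
}


def check_return_timestamps(pipeline_type: str, return_timestamps) -> None:
    # Raise for any truthy value outside the allowed set, as the pipeline does
    if return_timestamps and return_timestamps not in VALID_TIMESTAMPS[pipeline_type]:
        raise ValueError(
            f"return_timestamps={return_timestamps!r} is not supported for {pipeline_type!r} models"
        )
```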
@@ -497,13 +525,6 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
         # Optional return types
         optional = {}

-        if return_timestamps and self.type == "seq2seq":
-            raise ValueError("We cannot return_timestamps yet on non-ctc models apart from Whisper !")
-        if return_timestamps == "char" and self.type == "ctc_with_lm":
-            raise ValueError("CTC with LM cannot return `char` timestamps, only `word`")
-        if return_timestamps == "char" and self.type == "seq2seq_whisper":
-            raise ValueError("Whisper cannot return `char` timestamps, use `True` or `word` instead.")
-
         if return_language is not None and self.type != "seq2seq_whisper":
             raise ValueError("Only whisper can return language for now.")
...
@@ -136,7 +136,7 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
         else:
             # Non CTC models cannot use return_timestamps
             with self.assertRaisesRegex(
-                ValueError, "^We cannot return_timestamps yet on non-ctc models apart from Whisper !$"
+                ValueError, "^We cannot return_timestamps yet on non-CTC models apart from Whisper!$"
             ):
                 outputs = speech_recognizer(audio, return_timestamps="char")
@@ -161,7 +161,7 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
             # Non CTC models cannot use return_timestamps
             with self.assertRaisesRegex(
-                ValueError, "^We cannot return_timestamps yet on non-ctc models apart from Whisper !$"
+                ValueError, "^We cannot return_timestamps yet on non-CTC models apart from Whisper!$"
             ):
                 _ = speech_recognizer(waveform, return_timestamps="char")
@@ -261,6 +261,11 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
                 ],
             },
         )
+        # CTC + LM models cannot use return_timestamps="char"
+        with self.assertRaisesRegex(
+            ValueError, "^CTC with LM can only predict word level timestamps, set `return_timestamps='word'`$"
+        ):
+            _ = speech_recognizer(filename, return_timestamps="char")

     @require_tf
     def test_small_model_tf(self):
@@ -750,6 +755,14 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
         )
         # fmt: on
+        # Whisper can only predict segment level timestamps or word level, not character level
+        with self.assertRaisesRegex(
+            ValueError,
+            "^Whisper cannot return `char` timestamps, only word level or segment level timestamps. "
+            "Use `return_timestamps='word'` or `return_timestamps=True` respectively.$",
+        ):
+            _ = speech_recognizer(filename, return_timestamps="char")
+
     @slow
     @require_torch
     @require_torchaudio
@@ -1082,6 +1095,13 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
                 ],
            },
         )
+        # CTC models must specify return_timestamps type - cannot set `return_timestamps=True` blindly
+        with self.assertRaisesRegex(
+            ValueError,
+            "^CTC can either predict character (char) level timestamps, or word level timestamps."
+            "Set `return_timestamps='char'` or `return_timestamps='word'` as required.$",
+        ):
+            _ = speech_recognizer(audio, return_timestamps=True)

     @require_torch
     @slow
...