Unverified Commit dedd1116 authored by Sanchit Gandhi's avatar Sanchit Gandhi Committed by GitHub
Browse files

[ASR Pipeline] Clarify return timestamps (#25344)

* [ASR Pipeline] Clarify return timestamps

* fix indentation

* fix ctc check

* fix ctc error message!

* fix test

* fix other test

* add new tests

* final comment
parent 5ea2595e
......@@ -156,8 +156,7 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
feature_extractor ([`SequenceFeatureExtractor`]):
The feature extractor that will be used by the pipeline to encode waveform for the model.
chunk_length_s (`float`, *optional*, defaults to 0):
The input length for in each chunk. If `chunk_length_s = 0` then chunking is disabled (default). Only
available for CTC models, e.g. [`Wav2Vec2ForCTC`].
The input length for each chunk. If `chunk_length_s = 0` then chunking is disabled (default).
<Tip>
......@@ -247,14 +246,29 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
np.array}` with optionally a `"stride": (left: int, right: int)` that can ask the pipeline to
treat the first `left` samples and last `right` samples to be ignored in decoding (but used at
inference to provide more context to the model). Only use `stride` with CTC models.
return_timestamps (*optional*, `str`):
Only available for pure CTC models. If set to `"char"`, the pipeline will return timestamps along the
text for every character in the text. For instance if you get `[{"text": "h", "timestamp": (0.5, 0.6)},
{"text": "i", "timestamp": (0.7, 0.9)}]`, then it means the model predicts that the letter "h" was
pronounced after `0.5` and before `0.6` seconds. If set to `"word"`, the pipeline will return
timestamps along the text for every word in the text. For instance if you get `[{"text": "hi ",
"timestamp": (0.5, 0.9)}, {"text": "there", "timestamp": (1.0, 1.5)}]`, then it means the model
predicts that the word "hi" was pronounced after `0.5` and before `0.9` seconds.
return_timestamps (*optional*, `str` or `bool`):
Only available for pure CTC models (Wav2Vec2, HuBERT, etc.) and the Whisper model. Not available for
other sequence-to-sequence models.
For CTC models, timestamps can take one of two formats:
- `"char"`: the pipeline will return timestamps along the text for every character in the text. For
instance, if you get `[{"text": "h", "timestamp": (0.5, 0.6)}, {"text": "i", "timestamp": (0.7,
0.9)}]`, then it means the model predicts that the letter "h" was spoken after `0.5` and before
`0.6` seconds.
- `"word"`: the pipeline will return timestamps along the text for every word in the text. For
instance, if you get `[{"text": "hi ", "timestamp": (0.5, 0.9)}, {"text": "there", "timestamp":
(1.0, 1.5)}]`, then it means the model predicts that the word "hi" was spoken after `0.5` and
before `0.9` seconds.
For the Whisper model, timestamps can take one of two formats:
- `"word"`: same as above for word-level CTC timestamps. Word-level timestamps are predicted
through the *dynamic-time warping (DTW)* algorithm, which approximates word-level timestamps
by inspecting the cross-attention weights.
- `True`: the pipeline will return timestamps along the text for *segments* of words in the text.
For instance, if you get `[{"text": " Hi there!", "timestamp": (0.5, 1.5)}]`, then it means the
model predicts that the segment "Hi there!" was spoken after `0.5` and before `1.5` seconds.
Note that a segment of text refers to a sequence of one or more words, rather than individual
words as with word-level timestamps.
generate_kwargs (`dict`, *optional*):
The dictionary of ad-hoc parametrization of `generate_config` to be used for the generation call. For a
complete overview of generate, check the [following
......@@ -264,7 +278,7 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
Return:
`Dict`: A dictionary with the following keys:
- **text** (`str` ) -- The recognized text.
- **text** (`str`): The recognized text.
- **chunks** (*optional*, `List[Dict]`)
When using `return_timestamps`, the `chunks` will become a list containing all the various text
chunks identified by the model, *e.g.* `[{"text": "hi ", "timestamp": (0.5, 0.9)}, {"text":
......@@ -308,6 +322,20 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
if decoder_kwargs is not None:
postprocess_params["decoder_kwargs"] = decoder_kwargs
if return_timestamps is not None:
if self.type == "seq2seq" and return_timestamps:
raise ValueError("We cannot return_timestamps yet on non-CTC models apart from Whisper!")
if self.type == "ctc_with_lm" and return_timestamps != "word":
raise ValueError("CTC with LM can only predict word level timestamps, set `return_timestamps='word'`")
if self.type == "ctc" and return_timestamps not in ["char", "word"]:
raise ValueError(
"CTC can either predict character (char) level timestamps, or word level timestamps."
"Set `return_timestamps='char'` or `return_timestamps='word'` as required."
)
if self.type == "seq2seq_whisper" and return_timestamps == "char":
raise ValueError(
"Whisper cannot return `char` timestamps, only word level or segment level timestamps. "
"Use `return_timestamps='word'` or `return_timestamps=True` respectively."
)
forward_params["return_timestamps"] = return_timestamps
postprocess_params["return_timestamps"] = return_timestamps
if return_language is not None:
......@@ -497,13 +525,6 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
# Optional return types
optional = {}
if return_timestamps and self.type == "seq2seq":
raise ValueError("We cannot return_timestamps yet on non-ctc models apart from Whisper !")
if return_timestamps == "char" and self.type == "ctc_with_lm":
raise ValueError("CTC with LM cannot return `char` timestamps, only `word`")
if return_timestamps == "char" and self.type == "seq2seq_whisper":
raise ValueError("Whisper cannot return `char` timestamps, use `True` or `word` instead.")
if return_language is not None and self.type != "seq2seq_whisper":
raise ValueError("Only whisper can return language for now.")
......
......@@ -136,7 +136,7 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
else:
# Non CTC models cannot use return_timestamps
with self.assertRaisesRegex(
ValueError, "^We cannot return_timestamps yet on non-ctc models apart from Whisper !$"
ValueError, "^We cannot return_timestamps yet on non-CTC models apart from Whisper!$"
):
outputs = speech_recognizer(audio, return_timestamps="char")
......@@ -161,7 +161,7 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
# Non CTC models cannot use return_timestamps
with self.assertRaisesRegex(
ValueError, "^We cannot return_timestamps yet on non-ctc models apart from Whisper !$"
ValueError, "^We cannot return_timestamps yet on non-CTC models apart from Whisper!$"
):
_ = speech_recognizer(waveform, return_timestamps="char")
......@@ -261,6 +261,11 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
],
},
)
# CTC + LM models cannot use return_timestamps="char"
with self.assertRaisesRegex(
ValueError, "^CTC with LM can only predict word level timestamps, set `return_timestamps='word'`$"
):
_ = speech_recognizer(filename, return_timestamps="char")
@require_tf
def test_small_model_tf(self):
......@@ -750,6 +755,14 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
)
# fmt: on
# Whisper can only predict segment level timestamps or word level, not character level
with self.assertRaisesRegex(
ValueError,
"^Whisper cannot return `char` timestamps, only word level or segment level timestamps. "
"Use `return_timestamps='word'` or `return_timestamps=True` respectively.$",
):
_ = speech_recognizer(filename, return_timestamps="char")
@slow
@require_torch
@require_torchaudio
......@@ -1082,6 +1095,13 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
],
},
)
# CTC models must specify return_timestamps type - cannot set `return_timestamps=True` blindly
with self.assertRaisesRegex(
ValueError,
"^CTC can either predict character (char) level timestamps, or word level timestamps."
"Set `return_timestamps='char'` or `return_timestamps='word'` as required.$",
):
_ = speech_recognizer(audio, return_timestamps=True)
@require_torch
@slow
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment