Unverified commit 9c8979e3, authored by Kamil Akesbi and committed by GitHub
Browse files

Word-level timestamps broken for short-form audio (#30325)



* force chunk_length_s in AutomaticSpeechRecognitionPipeline

* compute num_frames even when stride is None

* add slow tests

* fix test

* Update src/transformers/pipelines/automatic_speech_recognition.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update tests/pipelines/test_pipelines_automatic_speech_recognition.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* add input validation

* fixup

* small fix

---------
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
parent 4fda78c3
......@@ -446,6 +446,8 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
processed = self.feature_extractor(
inputs, sampling_rate=self.feature_extractor.sampling_rate, return_tensors="pt"
)
if stride is None:
extra["segment_size"] = len(inputs)
if self.torch_dtype is not None:
processed = processed.to(dtype=self.torch_dtype)
......@@ -459,8 +461,12 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
def _forward(self, model_inputs, return_timestamps=False, **generate_kwargs):
attention_mask = model_inputs.pop("attention_mask", None)
stride = model_inputs.pop("stride", None)
segment_size = model_inputs.pop("segment_size", None)
is_last = model_inputs.pop("is_last")
if stride is not None and segment_size is not None:
raise ValueError("segment_size must be used only when stride is None")
if self.type in {"seq2seq", "seq2seq_whisper"}:
encoder = self.model.get_encoder()
# Consume values so we can let extra information flow freely through
......@@ -488,6 +494,12 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
else:
generate_kwargs["num_frames"] = [s[0] // self.feature_extractor.hop_length for s in stride]
else:
if isinstance(segment_size, int):
generate_kwargs["num_frames"] = segment_size // self.feature_extractor.hop_length
else:
generate_kwargs["num_frames"] = segment_size[0] // self.feature_extractor.hop_length
if self.type == "seq2seq_whisper" and inputs.shape[-1] > self.feature_extractor.nb_max_frames:
generate_kwargs["input_features"] = inputs
else:
......
......@@ -755,6 +755,94 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
},
)
@slow
@require_torch
def test_whisper_large_timestamp_prediction(self):
    """Slow integration test: segment-level timestamps with whisper-large-v3.

    Exercises three inference paths of the ASR pipeline:
      1. short-form audio (a single utterance, no chunking),
      2. chunked long-form audio (``chunk_length_s=10``),
      3. long-form audio without an explicit chunk length.

    Expected texts/timestamps are pinned model outputs; any drift in the
    timestamp post-processing shows up as an exact-dict mismatch.
    """
    # Four consecutive LibriSpeech utterances concatenated into one
    # long-form array so the audio exceeds a single 30s Whisper window.
    ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation").sort("id")
    array = np.concatenate(
        [ds[40]["audio"]["array"], ds[41]["audio"]["array"], ds[42]["audio"]["array"], ds[43]["audio"]["array"]]
    )
    pipe = pipeline(model="openai/whisper-large-v3", return_timestamps=True)

    # 1) Short-form: a single utterance yields one chunk covering the clip.
    output = pipe(ds[40]["audio"])
    self.assertDictEqual(
        output,
        {
            "text": " A man said to the universe, Sir, I exist.",
            "chunks": [{"text": " A man said to the universe, Sir, I exist.", "timestamp": (0.0, 4.08)}],
        },
    )

    # 2) Chunked long-form: nested_simplify rounds floats so the comparison
    # is stable across minor numeric differences.
    output = pipe(array, chunk_length_s=10)
    self.assertDictEqual(
        nested_simplify(output),
        {
            "chunks": [
                {"timestamp": (0.0, 2.0), "text": (" A man said to the universe,")},
                {"timestamp": (2.0, 4.1), "text": (" Sir, I exist.")},
                {"timestamp": (5.14, 5.96), "text": (" Sweat covered")},
                {"timestamp": (5.96, 8.02), "text": (" Breon's body, trickling into")},
                {"timestamp": (8.02, 10.67), "text": (" the tight loincloth that was the only garment he wore,")},
                {"timestamp": (10.67, 13.67), "text": (" the cut on his chest still dripping blood,")},
                {"timestamp": (13.67, 17.61), "text": (" the ache of his overstrained eyes.")},
                {
                    "timestamp": (17.61, 24.0),
                    "text": (
                        " Even the soaring arena around him with thousands of spectators were trivialities not worth thinking about."
                    ),
                },
                {
                    "timestamp": (24.0, 29.94),
                    "text": (" His instant of panic was followed by a small, sharp blow high on his chest."),
                },
            ],
            "text": (
                " A man said to the universe, Sir, I exist. Sweat covered Breon's"
                " body, trickling into the tight loincloth that was the only garment"
                " he wore, the cut on his chest still dripping blood, the ache of his"
                " overstrained eyes. Even the soaring arena around him with thousands"
                " of spectators were trivialities not worth thinking about. His "
                "instant of panic was followed by a small, sharp blow high on his chest."
            ),
        },
    )

    # 3) Long-form without chunk_length_s: the model's own sequential
    # decoding segments the audio, so chunk boundaries (and even some
    # transcribed words, e.g. "Brion's") differ from the chunked run above.
    output = pipe(array)
    self.assertDictEqual(
        output,
        {
            "chunks": [
                {"timestamp": (0.0, 1.96), "text": " A man said to the universe,"},
                {"timestamp": (2.7, 4.1), "text": " Sir, I exist."},
                {"timestamp": (5.14, 6.84), "text": " Sweat covered Brion's body,"},
                {
                    "timestamp": (7.4, 10.68),
                    "text": " trickling into the tight loincloth that was the only garment he wore,",
                },
                {"timestamp": (11.6, 13.94), "text": " the cut on his chest still dripping blood,"},
                {"timestamp": (14.78, 16.72), "text": " the ache of his overstrained eyes,"},
                {
                    "timestamp": (17.32, 21.16),
                    "text": " even the soaring arena around him with the thousands of spectators",
                },
                {"timestamp": (21.16, 23.94), "text": " were trivialities not worth thinking about."},
                {
                    "timestamp": (24.42, 29.94),
                    "text": " His instant panic was followed by a small sharp blow high on his chest.",
                },
            ],
            "text": (
                " A man said to the universe, Sir, I exist. Sweat covered Brion's body,"
                " trickling into the tight loincloth that was the only garment he wore, "
                "the cut on his chest still dripping blood, the ache of his overstrained "
                "eyes, even the soaring arena around him with the thousands of spectators "
                "were trivialities not worth thinking about. His instant panic was followed "
                "by a small sharp blow high on his chest."
            ),
        },
    )
@slow
@require_torch
def test_whisper_word_timestamps_batched(self):
......@@ -799,6 +887,49 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
output = pipe(sample, batch_size=2)
self.assertDictEqual(output, EXPECTED_OUTPUT)
@slow
@require_torch
def test_whisper_large_word_timestamps_batched(self):
    """Word-level timestamps from whisper-large-v3 must match with and without batching."""
    asr = pipeline(
        task="automatic-speech-recognition",
        model="openai/whisper-large-v3",
        return_timestamps="word",
    )
    dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
    audio = dataset[0]["audio"]

    # Differs from test_simple_whisper_asr because chunking is involved here.
    expected = {
        "text": " Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.",
        "chunks": [
            {"text": " Mr.", "timestamp": (0.0, 0.74)},
            {"text": " Quilter", "timestamp": (0.74, 1.04)},
            {"text": " is", "timestamp": (1.04, 1.3)},
            {"text": " the", "timestamp": (1.3, 1.44)},
            {"text": " apostle", "timestamp": (1.44, 1.74)},
            {"text": " of", "timestamp": (1.74, 2.18)},
            {"text": " the", "timestamp": (2.18, 2.28)},
            {"text": " middle", "timestamp": (2.28, 2.5)},
            {"text": " classes,", "timestamp": (2.5, 3.0)},
            {"text": " and", "timestamp": (3.0, 3.4)},
            {"text": " we", "timestamp": (3.4, 3.5)},
            {"text": " are", "timestamp": (3.5, 3.6)},
            {"text": " glad", "timestamp": (3.6, 3.84)},
            {"text": " to", "timestamp": (3.84, 4.1)},
            {"text": " welcome", "timestamp": (4.1, 4.4)},
            {"text": " his", "timestamp": (4.4, 4.7)},
            {"text": " gospel.", "timestamp": (4.7, 5.34)},
        ],
    }

    # batch_size=1: pass a copy because the pipeline consumes the audio dict.
    self.assertDictEqual(asr(audio.copy(), batch_size=1), expected)
    # batch_size=2: audio is split into smaller pieces, exercising batching.
    self.assertDictEqual(asr(audio, batch_size=2), expected)
@require_torch
@slow
def test_torch_speech_encoder_decoder(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment