Unverified Commit ddf7ac42 authored by Raushan Turganbay, committed by GitHub

Token level timestamps for long-form generation in Whisper (#29148)

parent 8a1faf28
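
A minimal sketch of how the new flags fit together for long-form audio, modeled on the integration test added further down in this commit. The checkpoint and alignment heads come from that test; the silent placeholder waveform `long_audio` is an assumption and should be replaced with real speech.

```python
# Sketch only: token-level timestamps during Whisper long-form (>30 s) generation.
import numpy as np
from transformers import WhisperForConditionalGeneration, WhisperProcessor

processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
model.generation_config.alignment_heads = [[2, 2], [3, 0], [3, 2], [3, 3], [3, 4], [3, 5]]

long_audio = np.zeros(16_000 * 60, dtype=np.float32)  # placeholder: 60 s at 16 kHz, substitute real speech

inputs = processor.feature_extractor(
    raw_speech=long_audio,
    return_tensors="pt",
    truncation=False,  # keep the full audio so the long-form (seek) loop is used
    return_attention_mask=True,
    padding=True,
)

out = model.generate(**inputs, return_segments=True, return_token_timestamps=True)
for segment in out["segments"][0]:
    # each segment now carries a "token_timestamps" tensor aligned with its "tokens"
    print(segment["tokens"].shape, segment["token_timestamps"])
```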
@@ -720,6 +720,7 @@ class WhisperGenerationMixin:
                     input_stride=input_stride,
                     prev_idx=prev_i,
                     idx=i,
+                    return_token_timestamps=return_token_timestamps,
                 )

                 current_segments[prev_i] += segments
@@ -809,11 +810,15 @@ class WhisperGenerationMixin:
                     # remove eos token id
                     if is_not_final and seek_sequence[-1] == generation_config.eos_token_id:
                         seek_sequence = seek_sequence[:-1]
+                        if return_token_timestamps:
+                            seek_outputs[i]["token_timestamps"] = seek_outputs[i]["token_timestamps"][:-1]

                     # remove all padding tokens
                     if seek_sequence[-1] == generation_config.pad_token_id:
                         num_paddings = (seek_sequence == generation_config.pad_token_id).sum()
                         seek_sequence = seek_sequence[:-num_paddings]
+                        if return_token_timestamps:
+                            seek_outputs[i]["token_timestamps"] = seek_outputs[i]["token_timestamps"][:-num_paddings]

                     # check which sequences in batch need fallback & which should be skipped
                     needs_fallback[i], should_skip[i] = self._need_fallback(
@@ -878,15 +883,18 @@ class WhisperGenerationMixin:
             seek_outputs["token_timestamps"] = self._extract_token_timestamps(
                 seek_outputs, generation_config.alignment_heads, num_frames=num_frames
             )
+            seek_outputs["token_timestamps"] = seek_outputs["token_timestamps"][:, decoder_input_ids.shape[-1] :]

         seek_outputs["sequences"] = seek_outputs["sequences"][:, decoder_input_ids.shape[-1] :]

         def split_by_batch_index(values, key, batch_idx):
             if key == "scores":
                 return [v[batch_idx].cpu() for v in values]
-            if key == "past_key_values":
+            elif key == "past_key_values":
                 # we don't save `past_key_values` as this is too costly
                 return None
+            elif isinstance(values[batch_idx], tuple) and torch.is_tensor(values[batch_idx][0]):
+                return tuple(tuple(w[batch_idx][None].cpu() for w in v) for v in values)
             return values[batch_idx].cpu()

         sequence_tokens = seek_outputs["sequences"]
@@ -1611,6 +1619,7 @@ class WhisperGenerationMixin:
         input_stride,
         prev_idx,
         idx,
+        return_token_timestamps,
     ):
         # find the predicted "end of segment" predictions of Whisper
         # "end of segment" predictions occur whenever Whisper predicts a timestamp token
@@ -1618,6 +1627,7 @@ class WhisperGenerationMixin:
         single_timestamp_ending = timestamp_tokens[-2:].tolist() == [False, True]
         timestamp_segment_indices = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0]
         timestamp_segment_indices.add_(1)
+        token_timestamps = seek_outputs[idx]["token_timestamps"] if return_token_timestamps else []

         # If whisper predicted a "end of segment" via a timestep token, let's go ever each
         # "end of segment" prediction and slice the decoding into segments accordingly
@@ -1642,6 +1652,10 @@ class WhisperGenerationMixin:
                         "result": seek_outputs[idx],
                     }
                 )
+                if return_token_timestamps:
+                    segments[-1]["token_timestamps"] = (
+                        token_timestamps[last_slice:current_slice] + time_offset[prev_idx]
+                    )
                 last_slice = current_slice

             if single_timestamp_ending:
@@ -1661,7 +1675,6 @@ class WhisperGenerationMixin:
             if timestamps.numel() > 0 and timestamps[-1].item() != timestamp_begin:
                 # no consecutive timestamps but it has a timestamp; use the last one.
                 last_timestamp_pos = timestamps[-1].item() - timestamp_begin
-
             segments = [
                 {
                     "start": time_offset[prev_idx],
@@ -1670,6 +1683,8 @@ class WhisperGenerationMixin:
                     "result": seek_outputs[idx],
                 }
             ]
+            if return_token_timestamps:
+                segments[-1]["token_timestamps"] = token_timestamps + time_offset[prev_idx]
             segment_offset = seek_num_frames[prev_idx]

         return segments, segment_offset
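
For reference, the offsetting applied above is plain tensor arithmetic: token timestamps are relative to the current 30-second chunk, so each slice is shifted by the chunk's absolute start time. A tiny illustration with made-up numbers (not library code):

```python
import torch

# Within-chunk token times (seconds) and the absolute start of the chunk being decoded.
token_timestamps = torch.tensor([0.00, 0.42, 0.82, 1.12, 1.50])
time_offset = torch.tensor([30.0])  # this chunk starts 30 s into the audio
prev_idx = 0
last_slice, current_slice = 0, 3  # segment boundaries found via consecutive timestamp tokens

segment_token_timestamps = token_timestamps[last_slice:current_slice] + time_offset[prev_idx]
print(segment_token_timestamps)  # tensor([30.0000, 30.4200, 30.8200])
```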
@@ -483,6 +483,7 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
                 generate_kwargs["return_timestamps"] = return_timestamps
                 if return_timestamps == "word":
                     generate_kwargs["return_token_timestamps"] = True
+                    generate_kwargs["return_segments"] = True

                     if stride is not None:
                         if isinstance(stride, tuple):
@@ -499,8 +500,16 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
                 attention_mask=attention_mask,
                 **generate_kwargs,
             )
+            # whisper longform generation stores timestamps in "segments"
             if return_timestamps == "word" and self.type == "seq2seq_whisper":
-                out = {"tokens": tokens["sequences"], "token_timestamps": tokens["token_timestamps"]}
+                if "segments" not in tokens:
+                    out = {"tokens": tokens["sequences"], "token_timestamps": tokens["token_timestamps"]}
+                else:
+                    token_timestamps = [
+                        torch.cat([segment["token_timestamps"] for segment in segment_list])
+                        for segment_list in tokens["segments"]
+                    ]
+                    out = {"tokens": tokens["sequences"], "token_timestamps": token_timestamps}
             else:
                 out = {"tokens": tokens}
             if self.type == "seq2seq_whisper":
...
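
On the pipeline side, `return_timestamps="word"` now also requests `return_segments=True` and concatenates the per-segment token timestamps, so word timestamps keep working past the 30-second mark. A hedged usage sketch, following the pipeline test added below; the placeholder waveform is an assumption, the added test concatenates LibriSpeech samples instead.

```python
import numpy as np
from transformers import pipeline

pipe = pipeline(task="automatic-speech-recognition", model="openai/whisper-tiny.en")
# Alignment heads as set in the pipeline test added in this commit.
pipe.model.generation_config.alignment_heads = [[2, 2], [3, 0], [3, 2], [3, 3], [3, 4], [3, 5]]

long_audio = np.zeros(16_000 * 60, dtype=np.float32)  # placeholder waveform, 60 s at 16 kHz

result = pipe(long_audio, return_timestamps="word")
print(result["chunks"][:5])  # [{"text": ..., "timestamp": (start, end)}, ...]
```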
@@ -1969,6 +1969,56 @@ class WhisperModelIntegrationTests(unittest.TestCase):
         self.assertEqual(len(generate_outputs.sequences), num_return_sequences * num_samples)

+    @slow
+    def test_tiny_token_timestamp_generation_longform(self):
+        set_seed(0)
+        processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
+        model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
+        model.to(torch_device)
+        model.generation_config.alignment_heads = [[2, 2], [3, 0], [3, 2], [3, 3], [3, 4], [3, 5]]
+
+        input_speech = self._load_datasamples(5)
+        long_input_speech = np.concatenate(input_speech, dtype=np.float32)
+        inputs = processor.feature_extractor(
+            raw_speech=long_input_speech,
+            return_tensors="pt",
+            truncation=False,  # False so the audio isn't truncated and whole audio is sent to the model
+            return_attention_mask=True,
+            padding=True,
+        )
+
+        inputs = inputs.to(torch_device)
+        generate_outputs = model.generate(**inputs, return_segments=True, return_token_timestamps=True)
+
+        token_timestamps_shape = [
+            [segment["token_timestamps"].shape for segment in segment_list]
+            for segment_list in generate_outputs["segments"]
+        ]
+        tokens_shape = [
+            [segment["tokens"].shape for segment in segment_list] for segment_list in generate_outputs["segments"]
+        ]
+        self.assertListEqual(tokens_shape, token_timestamps_shape)
+
+        # fmt: off
+        EXPECTED_OUTPUT = [
+            torch.tensor([0.0000, 0.4200, 0.8200, 0.9400, 1.1200, 1.1200, 1.2200, 1.5000, 1.7200, 2.0400, 2.3400, 2.5200, 2.6600, 3.2000, 3.4400, 3.5600, 3.6800, 3.8200, 4.1000, 4.3000, 4.5800, 4.9400, 5.4000, 6.3600]),
+            torch.tensor([ 6.5400, 6.5400, 6.7400, 6.9600, 7.2600, 7.3400, 7.5800, 7.5800, 7.6400, 7.8400, 8.1000, 8.5000, 9.0000, 9.4800, 9.7200, 10.2600, 11.1000]),
+            torch.tensor([11.2200, 11.2200, 11.4200, 11.6600, 12.0800, 12.4400, 12.5800, 12.8400, 13.1800, 13.6800, 14.0000, 14.2200, 14.6200, 14.9800, 15.2200, 15.6000, 15.9400, 16.2000, 16.5600, 16.8400, 16.9800]),
+            torch.tensor([16.9800, 16.9800, 17.3200, 18.1600, 18.6400, 18.8600, 19.2800, 19.5600, 19.8800, 20.1800, 20.3800, 20.7200, 21.1600, 21.5400, 21.9000, 22.2000, 22.4200, 22.8600, 23.7000]),
+            torch.tensor([23.7000, 23.7000, 23.9400, 24.1800, 24.3800, 24.8400, 25.2800, 25.6600, 25.9200, 26.2600, 26.4000, 26.5800, 26.7600, 27.1400, 27.3800, 28.0400, 28.3800, 28.8200, 29.3400, 29.5200]),
+            torch.tensor([29.4400, 29.4400, 29.7000, 30.0800, 30.3800, 30.5400, 30.8200, 31.0600, 31.6600, 31.9200, 32.3000, 32.4800, 32.6200, 33.6800]),
+            torch.tensor([33.8000, 33.8000, 33.9800, 33.9800, 34.1800, 34.4400, 34.6200, 35.0000, 35.2200, 35.3200, 35.5600, 35.9200, 36.3800, 36.6200, 36.6600, 36.9600, 37.3400, 37.9800, 38.5800, 38.7200, 38.9800, 39.4400, 39.5800, 39.8000, 40.1200, 40.2600]),
+            torch.tensor([40.5200, 40.5200, 40.6200, 41.1000, 41.5400, 41.9200, 42.1000, 42.3200, 42.3200, 43.0600, 44.6000]),
+            torch.tensor([44.7000, 44.7000, 44.8600, 44.9400, 45.1400, 45.1400, 45.2800, 45.6200, 45.9000, 46.2600, 47.1600, 47.4800, 47.7400, 48.1000, 48.2800, 48.4000, 48.6200, 48.8400, 49.0400, 49.2800, 49.4800, 49.6600, 49.9400, 50.5400]),
+            torch.tensor([50.5400, 50.5400, 50.6600, 50.8800, 51.2400, 51.7200, 52.8400]),
+            torch.tensor([52.9600, 52.9600, 53.0400, 53.2600, 53.4200, 53.5800, 53.9200, 54.1200, 54.7200, 54.9400, 55.2600, 55.6200, 55.9800, 56.5600, 56.8000, 56.9200, 57.3600, 57.9200, 58.1800, 58.5000, 58.6400, 58.8200]),
+            torch.tensor([58.6800, 58.6800, 59.1400, 59.5400, 59.9200, 60.1600, 60.3800, 60.8200, 61.6200, 62.2600, 75.2000]),
+        ]
+        # fmt: on
+
+        for segment, exp_segment in zip(generate_outputs["segments"][0], EXPECTED_OUTPUT):
+            self.assertTrue(torch.allclose(segment["token_timestamps"], exp_segment))
+
     @slow
     def test_tiny_specaugment_librispeech(self):
         torch_device = "cpu"
...
@@ -361,6 +361,70 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
         )
         # fmt: on

+    @slow
+    @require_torch
+    def test_return_timestamps_in_preprocess_longform(self):
+        pipe = pipeline(
+            task="automatic-speech-recognition",
+            model="openai/whisper-tiny.en",
+        )
+        data = load_dataset("librispeech_asr", "clean", split="test", streaming=True)
+        samples = [next(iter(data)) for _ in range(8)]
+        audio = np.concatenate([sample["audio"]["array"] for sample in samples])
+
+        res = pipe(audio)
+        expected_output = {
+            "text": " Concord returned to its place amidst the tents. Concord returned to its place amidst the tents. Concord returned to its place amidst "
+            "the tents. Concord returned to its place amidst the tents. Concord returned to its place amidst the tents. Concord returned to its place amidst "
+            "the tents. Concord returned to its place amidst the tents. Concord returned to its place amidst the tents. Concord returned to its place amidst "
+            "the tents. Concord returned to its place amidst the tents."
+        }
+        self.assertEqual(res, expected_output)
+
+        res = pipe(audio, return_timestamps=True)
+        self.assertEqual(
+            res,
+            {
+                "text": " Concord returned to its place amidst the tents. Concord returned to its place amidst the tents. Concord returned to its place amidst the tents. Concord returned to its place amidst the tents. Concord returned to its place amidst the tents. Concord returned to its place amidst the tents. Concord returned to its place amidst the tents. Concord returned to its place amidst the tents.",
+                "chunks": [
+                    {"timestamp": (0.0, 3.22), "text": " Concord returned to its place amidst the tents."},
+                    {"timestamp": (3.22, 6.74), "text": " Concord returned to its place amidst the tents."},
+                    {"timestamp": (6.74, 10.26), "text": " Concord returned to its place amidst the tents."},
+                    {"timestamp": (10.26, 13.78), "text": " Concord returned to its place amidst the tents."},
+                    {"timestamp": (13.78, 17.3), "text": " Concord returned to its place amidst the tents."},
+                    {"timestamp": (17.3, 20.82), "text": " Concord returned to its place amidst the tents."},
+                    {"timestamp": (20.82, 24.34), "text": " Concord returned to its place amidst the tents."},
+                    {"timestamp": (24.34, 27.86), "text": " Concord returned to its place amidst the tents."},
+                ],
+            },
+        )
+
+        pipe.model.generation_config.alignment_heads = [[2, 2], [3, 0], [3, 2], [3, 3], [3, 4], [3, 5]]
+        res = pipe(audio, return_timestamps="word")
+
+        # fmt: off
+        self.assertEqual(
+            res["chunks"][:15],
+            [
+                {"text": " Concord", "timestamp": (0.5, 0.94)},
+                {"text": " returned", "timestamp": (0.94, 1.52)},
+                {"text": " to", "timestamp": (1.52, 1.78)},
+                {"text": " its", "timestamp": (1.78, 1.98)},
+                {"text": " place", "timestamp": (1.98, 2.16)},
+                {"text": " amidst", "timestamp": (2.16, 2.5)},
+                {"text": " the", "timestamp": (2.5, 2.9)},
+                {"text": " tents.", "timestamp": (2.9, 4.2)},
+                {"text": " Concord", "timestamp": (4.2, 4.5)},
+                {"text": " returned", "timestamp": (4.5, 5.0)},
+                {"text": " to", "timestamp": (5.0, 5.28)},
+                {"text": " its", "timestamp": (5.28, 5.48)},
+                {"text": " place", "timestamp": (5.48, 5.7)},
+                {"text": " amidst", "timestamp": (5.7, 6.02)},
+                {"text": " the", "timestamp": (6.02, 6.4)},
+            ],
+        )
+        # fmt: on
+
     @require_torch
     def test_return_timestamps_in_init(self):
         # segment-level timestamps are accepted
...