".github/vscode:/vscode.git/clone" did not exist on "d4564df1d4a6b355779f1a8ac250cb47cb4c38d8"
Unverified Commit ddf7ac42 authored by Raushan Turganbay, committed by GitHub

Token level timestamps for long-form generation in Whisper (#29148)

parent 8a1faf28
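Note (not part of the diff): a minimal sketch of the long-form usage this change enables, mirroring the integration test added below. The checkpoint, alignment_heads, and feature-extractor arguments are copied from that test; the zero waveform is only a placeholder for real speech longer than 30 seconds.

import numpy as np
from transformers import WhisperForConditionalGeneration, WhisperProcessor

processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
model.generation_config.alignment_heads = [[2, 2], [3, 0], [3, 2], [3, 3], [3, 4], [3, 5]]

# placeholder: substitute a real 16 kHz waveform longer than 30 s so the long-form path is taken
audio = np.zeros(16_000 * 60, dtype=np.float32)

inputs = processor.feature_extractor(
    raw_speech=audio,
    return_tensors="pt",
    truncation=False,           # keep the full audio instead of cutting it to 30 s
    return_attention_mask=True,
    padding=True,
)

outputs = model.generate(**inputs, return_segments=True, return_token_timestamps=True)

# each segment now carries per-token timestamps alongside its tokens
for segment in outputs["segments"][0]:
    print(segment["tokens"].shape, segment["token_timestamps"].shape)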
@@ -720,6 +720,7 @@ class WhisperGenerationMixin:
input_stride=input_stride,
prev_idx=prev_i,
idx=i,
return_token_timestamps=return_token_timestamps,
)
current_segments[prev_i] += segments
@@ -809,11 +810,15 @@ class WhisperGenerationMixin:
# remove eos token id
if is_not_final and seek_sequence[-1] == generation_config.eos_token_id:
seek_sequence = seek_sequence[:-1]
if return_token_timestamps:
seek_outputs[i]["token_timestamps"] = seek_outputs[i]["token_timestamps"][:-1]
# remove all padding tokens
if seek_sequence[-1] == generation_config.pad_token_id:
num_paddings = (seek_sequence == generation_config.pad_token_id).sum()
seek_sequence = seek_sequence[:-num_paddings]
if return_token_timestamps:
seek_outputs[i]["token_timestamps"] = seek_outputs[i]["token_timestamps"][:-num_paddings]
# check which sequences in batch need fallback & which should be skipped
needs_fallback[i], should_skip[i] = self._need_fallback(
@@ -878,15 +883,18 @@ class WhisperGenerationMixin:
seek_outputs["token_timestamps"] = self._extract_token_timestamps(
seek_outputs, generation_config.alignment_heads, num_frames=num_frames
)
seek_outputs["token_timestamps"] = seek_outputs["token_timestamps"][:, decoder_input_ids.shape[-1] :]
seek_outputs["sequences"] = seek_outputs["sequences"][:, decoder_input_ids.shape[-1] :]
def split_by_batch_index(values, key, batch_idx):
if key == "scores":
return [v[batch_idx].cpu() for v in values]
if key == "past_key_values":
elif key == "past_key_values":
# we don't save `past_key_values` as this is too costly
return None
elif isinstance(values[batch_idx], tuple) and torch.is_tensor(values[batch_idx][0]):
return tuple(tuple(w[batch_idx][None].cpu() for w in v) for v in values)
return values[batch_idx].cpu()
sequence_tokens = seek_outputs["sequences"]
@@ -1611,6 +1619,7 @@ class WhisperGenerationMixin:
input_stride,
prev_idx,
idx,
return_token_timestamps,
):
# find the "end of segment" predictions of Whisper
# "end of segment" predictions occur whenever Whisper predicts a timestamp token
@@ -1618,6 +1627,7 @@ class WhisperGenerationMixin:
single_timestamp_ending = timestamp_tokens[-2:].tolist() == [False, True]
timestamp_segment_indices = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0]
timestamp_segment_indices.add_(1)
token_timestamps = seek_outputs[idx]["token_timestamps"] if return_token_timestamps else []
# If Whisper predicted an "end of segment" via a timestamp token, let's go over each
# "end of segment" prediction and slice the decoding into segments accordingly
@@ -1642,6 +1652,10 @@ class WhisperGenerationMixin:
"result": seek_outputs[idx],
}
)
if return_token_timestamps:
segments[-1]["token_timestamps"] = (
token_timestamps[last_slice:current_slice] + time_offset[prev_idx]
)
last_slice = current_slice
if single_timestamp_ending:
@@ -1661,7 +1675,6 @@ class WhisperGenerationMixin:
if timestamps.numel() > 0 and timestamps[-1].item() != timestamp_begin:
# no consecutive timestamps but it has a timestamp; use the last one.
last_timestamp_pos = timestamps[-1].item() - timestamp_begin
segments = [
{
"start": time_offset[prev_idx],
@@ -1670,6 +1683,8 @@ class WhisperGenerationMixin:
"result": seek_outputs[idx],
}
]
if return_token_timestamps:
segments[-1]["token_timestamps"] = token_timestamps + time_offset[prev_idx]
segment_offset = seek_num_frames[prev_idx]
return segments, segment_offset
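Toy illustration (made-up values, not part of the diff): the new branches above slice the chunk-local token timestamps with the same last_slice/current_slice boundaries used for the tokens and shift them by the chunk's time offset.

import torch

token_timestamps = torch.tensor([0.00, 0.42, 0.82, 1.12, 1.50])  # seconds within the current 30 s chunk
time_offset = torch.tensor([30.0])                                # this chunk starts at 30 s of the full audio
prev_idx = 0
last_slice, current_slice = 1, 4                                  # boundaries of one predicted segment

segment_token_timestamps = token_timestamps[last_slice:current_slice] + time_offset[prev_idx]
print(segment_token_timestamps)  # tensor([30.4200, 30.8200, 31.1200])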
@@ -483,6 +483,7 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
generate_kwargs["return_timestamps"] = return_timestamps
if return_timestamps == "word":
generate_kwargs["return_token_timestamps"] = True
generate_kwargs["return_segments"] = True
if stride is not None:
if isinstance(stride, tuple):
@@ -499,8 +500,16 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
attention_mask=attention_mask,
**generate_kwargs,
)
# Whisper long-form generation stores timestamps in "segments"
if return_timestamps == "word" and self.type == "seq2seq_whisper":
if "segments" not in tokens:
out = {"tokens": tokens["sequences"], "token_timestamps": tokens["token_timestamps"]}
else:
token_timestamps = [
torch.cat([segment["token_timestamps"] for segment in segment_list])
for segment_list in tokens["segments"]
]
out = {"tokens": tokens["sequences"], "token_timestamps": token_timestamps}
else:
out = {"tokens": tokens}
if self.type == "seq2seq_whisper":
......
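Toy illustration (made-up tensors, not part of the diff) of the branch above: long-form Whisper output keeps one list of segments per batch item, and the pipeline flattens each list by concatenating the per-segment token timestamps.

import torch

tokens = {
    "sequences": torch.tensor([[50258, 50259, 50359]]),
    "segments": [[
        {"tokens": torch.tensor([50258]), "token_timestamps": torch.tensor([0.00])},
        {"tokens": torch.tensor([50259, 50359]), "token_timestamps": torch.tensor([0.42, 0.82])},
    ]],
}

token_timestamps = [
    torch.cat([segment["token_timestamps"] for segment in segment_list])
    for segment_list in tokens["segments"]
]
out = {"tokens": tokens["sequences"], "token_timestamps": token_timestamps}
print(out["token_timestamps"][0])  # tensor([0.0000, 0.4200, 0.8200])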
@@ -1969,6 +1969,56 @@ class WhisperModelIntegrationTests(unittest.TestCase):
self.assertEqual(len(generate_outputs.sequences), num_return_sequences * num_samples)
@slow
def test_tiny_token_timestamp_generation_longform(self):
set_seed(0)
processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
model.to(torch_device)
model.generation_config.alignment_heads = [[2, 2], [3, 0], [3, 2], [3, 3], [3, 4], [3, 5]]
input_speech = self._load_datasamples(5)
long_input_speech = np.concatenate(input_speech, dtype=np.float32)
inputs = processor.feature_extractor(
raw_speech=long_input_speech,
return_tensors="pt",
truncation=False,  # False so the audio is not truncated and the whole audio is sent to the model
return_attention_mask=True,
padding=True,
)
inputs = inputs.to(torch_device)
generate_outputs = model.generate(**inputs, return_segments=True, return_token_timestamps=True)
token_timestamps_shape = [
[segment["token_timestamps"].shape for segment in segment_list]
for segment_list in generate_outputs["segments"]
]
tokens_shape = [
[segment["tokens"].shape for segment in segment_list] for segment_list in generate_outputs["segments"]
]
self.assertListEqual(tokens_shape, token_timestamps_shape)
# fmt: off
EXPECTED_OUTPUT = [
torch.tensor([0.0000, 0.4200, 0.8200, 0.9400, 1.1200, 1.1200, 1.2200, 1.5000, 1.7200, 2.0400, 2.3400, 2.5200, 2.6600, 3.2000, 3.4400, 3.5600, 3.6800, 3.8200, 4.1000, 4.3000, 4.5800, 4.9400, 5.4000, 6.3600]),
torch.tensor([ 6.5400, 6.5400, 6.7400, 6.9600, 7.2600, 7.3400, 7.5800, 7.5800, 7.6400, 7.8400, 8.1000, 8.5000, 9.0000, 9.4800, 9.7200, 10.2600, 11.1000]),
torch.tensor([11.2200, 11.2200, 11.4200, 11.6600, 12.0800, 12.4400, 12.5800, 12.8400, 13.1800, 13.6800, 14.0000, 14.2200, 14.6200, 14.9800, 15.2200, 15.6000, 15.9400, 16.2000, 16.5600, 16.8400, 16.9800]),
torch.tensor([16.9800, 16.9800, 17.3200, 18.1600, 18.6400, 18.8600, 19.2800, 19.5600, 19.8800, 20.1800, 20.3800, 20.7200, 21.1600, 21.5400, 21.9000, 22.2000, 22.4200, 22.8600, 23.7000]),
torch.tensor([23.7000, 23.7000, 23.9400, 24.1800, 24.3800, 24.8400, 25.2800, 25.6600, 25.9200, 26.2600, 26.4000, 26.5800, 26.7600, 27.1400, 27.3800, 28.0400, 28.3800, 28.8200, 29.3400, 29.5200]),
torch.tensor([29.4400, 29.4400, 29.7000, 30.0800, 30.3800, 30.5400, 30.8200, 31.0600, 31.6600, 31.9200, 32.3000, 32.4800, 32.6200, 33.6800]),
torch.tensor([33.8000, 33.8000, 33.9800, 33.9800, 34.1800, 34.4400, 34.6200, 35.0000, 35.2200, 35.3200, 35.5600, 35.9200, 36.3800, 36.6200, 36.6600, 36.9600, 37.3400, 37.9800, 38.5800, 38.7200, 38.9800, 39.4400, 39.5800, 39.8000, 40.1200, 40.2600]),
torch.tensor([40.5200, 40.5200, 40.6200, 41.1000, 41.5400, 41.9200, 42.1000, 42.3200, 42.3200, 43.0600, 44.6000]),
torch.tensor([44.7000, 44.7000, 44.8600, 44.9400, 45.1400, 45.1400, 45.2800, 45.6200, 45.9000, 46.2600, 47.1600, 47.4800, 47.7400, 48.1000, 48.2800, 48.4000, 48.6200, 48.8400, 49.0400, 49.2800, 49.4800, 49.6600, 49.9400, 50.5400]),
torch.tensor([50.5400, 50.5400, 50.6600, 50.8800, 51.2400, 51.7200, 52.8400]),
torch.tensor([52.9600, 52.9600, 53.0400, 53.2600, 53.4200, 53.5800, 53.9200, 54.1200, 54.7200, 54.9400, 55.2600, 55.6200, 55.9800, 56.5600, 56.8000, 56.9200, 57.3600, 57.9200, 58.1800, 58.5000, 58.6400, 58.8200]),
torch.tensor([58.6800, 58.6800, 59.1400, 59.5400, 59.9200, 60.1600, 60.3800, 60.8200, 61.6200, 62.2600, 75.2000]),
]
# fmt: on
for segment, exp_segment in zip(generate_outputs["segments"][0], EXPECTED_OUTPUT):
self.assertTrue(torch.allclose(segment["token_timestamps"], exp_segment))
@slow
def test_tiny_specaugment_librispeech(self):
torch_device = "cpu"
......
@@ -361,6 +361,70 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
)
# fmt: on
@slow
@require_torch
def test_return_timestamps_in_preprocess_longform(self):
pipe = pipeline(
task="automatic-speech-recognition",
model="openai/whisper-tiny.en",
)
data = load_dataset("librispeech_asr", "clean", split="test", streaming=True)
samples = [next(iter(data)) for _ in range(8)]
audio = np.concatenate([sample["audio"]["array"] for sample in samples])
res = pipe(audio)
expected_output = {
"text": " Concord returned to its place amidst the tents. Concord returned to its place amidst the tents. Concord returned to its place amidst "
"the tents. Concord returned to its place amidst the tents. Concord returned to its place amidst the tents. Concord returned to its place amidst "
"the tents. Concord returned to its place amidst the tents. Concord returned to its place amidst the tents. Concord returned to its place amidst "
"the tents. Concord returned to its place amidst the tents."
}
self.assertEqual(res, expected_output)
res = pipe(audio, return_timestamps=True)
self.assertEqual(
res,
{
"text": " Concord returned to its place amidst the tents. Concord returned to its place amidst the tents. Concord returned to its place amidst the tents. Concord returned to its place amidst the tents. Concord returned to its place amidst the tents. Concord returned to its place amidst the tents. Concord returned to its place amidst the tents. Concord returned to its place amidst the tents.",
"chunks": [
{"timestamp": (0.0, 3.22), "text": " Concord returned to its place amidst the tents."},
{"timestamp": (3.22, 6.74), "text": " Concord returned to its place amidst the tents."},
{"timestamp": (6.74, 10.26), "text": " Concord returned to its place amidst the tents."},
{"timestamp": (10.26, 13.78), "text": " Concord returned to its place amidst the tents."},
{"timestamp": (13.78, 17.3), "text": " Concord returned to its place amidst the tents."},
{"timestamp": (17.3, 20.82), "text": " Concord returned to its place amidst the tents."},
{"timestamp": (20.82, 24.34), "text": " Concord returned to its place amidst the tents."},
{"timestamp": (24.34, 27.86), "text": " Concord returned to its place amidst the tents."},
],
},
)
pipe.model.generation_config.alignment_heads = [[2, 2], [3, 0], [3, 2], [3, 3], [3, 4], [3, 5]]
res = pipe(audio, return_timestamps="word")
# fmt: off
self.assertEqual(
res["chunks"][:15],
[
{"text": " Concord", "timestamp": (0.5, 0.94)},
{"text": " returned", "timestamp": (0.94, 1.52)},
{"text": " to", "timestamp": (1.52, 1.78)},
{"text": " its", "timestamp": (1.78, 1.98)},
{"text": " place", "timestamp": (1.98, 2.16)},
{"text": " amidst", "timestamp": (2.16, 2.5)},
{"text": " the", "timestamp": (2.5, 2.9)},
{"text": " tents.", "timestamp": (2.9, 4.2)},
{"text": " Concord", "timestamp": (4.2, 4.5)},
{"text": " returned", "timestamp": (4.5, 5.0)},
{"text": " to", "timestamp": (5.0, 5.28)},
{"text": " its", "timestamp": (5.28, 5.48)},
{"text": " place", "timestamp": (5.48, 5.7)},
{"text": " amidst", "timestamp": (5.7, 6.02)},
{"text": " the", "timestamp": (6.02, 6.4)}
],
)
# fmt: on
@require_torch
def test_return_timestamps_in_init(self):
# segment-level timestamps are accepted
......
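End-user sketch (mirrors the pipeline test above, not part of the diff): with this change, word-level timestamps also flow through the ASR pipeline for audio longer than 30 seconds. The alignment_heads assignment is copied from the test; the zero waveform is only a placeholder for real speech.

import numpy as np
from transformers import pipeline

pipe = pipeline(task="automatic-speech-recognition", model="openai/whisper-tiny.en")
pipe.model.generation_config.alignment_heads = [[2, 2], [3, 0], [3, 2], [3, 3], [3, 4], [3, 5]]

# placeholder: substitute a real 16 kHz waveform longer than 30 s
long_audio = np.zeros(16_000 * 45, dtype=np.float32)

result = pipe(long_audio, return_timestamps="word")
print(result["chunks"][:5])  # [{"text": ..., "timestamp": (start, end)}, ...]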