Unverified Commit cd927a47 authored by Matthijs Hollemans, committed by GitHub

add word-level timestamps to Whisper (#23205)

* let's go!

* initial implementation of token-level timestamps

* only return a single timestamp per token

* remove token probabilities

* fix return type

* fix doc comment

* strip special tokens

* rename

* revert to not stripping special tokens

* only support models that have alignment_heads

* add integration test

* consistently name it token-level timestamps

* small DTW tweak

* initial support for ASR pipeline

* fix pipeline doc comments

* resolve token timestamps in pipeline with chunking

* change warning when no final timestamp is found

* return word-level timestamps

* fixup

* fix bug that skipped final word in each chunk

* fix failing unit tests

* merge punctuations into the words

* also return word tokens

* also return token indices

* add (failing) unit test for combine_tokens_into_words

* make combine_tokens_into_words private

* restore OpenAI's punctuation rules

* add pipeline tests

* make requested changes

* PR review changes

* fix failing pipeline test

* small stuff from PR

* only return words and their timestamps, not segments

* move alignment_heads into generation config

* forgot to set alignment_heads in pipeline tests

* tiny comment fix

* grr
parent 0f968dda
@@ -171,7 +171,9 @@ class WhisperConfig(PretrainedConfig):
The minimum number of masks of length `mask_feature_length` generated along the feature axis, each time
step, irrespectively of `mask_feature_prob`. Only relevant if
`mask_feature_prob*len(feature_axis)/mask_feature_length < mask_feature_min_masks`.
median_filter_width (`int`, *optional*, defaults to 7):
Width of the median filter used to smoothen the cross-attention outputs when computing token timestamps.
Should be an odd number.
Example:
@@ -229,6 +231,7 @@ class WhisperConfig(PretrainedConfig):
mask_feature_prob=0.0,
mask_feature_length=10,
mask_feature_min_masks=0,
median_filter_width=7,
**kwargs,
):
self.vocab_size = vocab_size
@@ -265,6 +268,9 @@ class WhisperConfig(PretrainedConfig):
self.mask_feature_prob = mask_feature_prob
self.mask_feature_length = mask_feature_length
self.mask_feature_min_masks = mask_feature_min_masks
self.median_filter_width = median_filter_width
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
......
@@ -227,6 +227,81 @@ def _compute_mask_indices(
return spec_aug_mask
def _median_filter(inputs: torch.Tensor, filter_width: int) -> torch.Tensor:
"""
Applies a median filter of width `filter_width` along the last dimension of the input.
The `inputs` tensor is assumed to be 3- or 4-dimensional.
"""
if filter_width <= 0 or filter_width % 2 != 1:
raise ValueError("`filter_width` should be an odd number")
pad_width = filter_width // 2
if inputs.shape[-1] <= pad_width:
return inputs
# Pad the left and right edges.
inputs = nn.functional.pad(inputs, (pad_width, pad_width, 0, 0), mode="reflect")
# sort() is faster than torch.median (https://github.com/pytorch/pytorch/issues/51450)
result = inputs.unfold(-1, filter_width, 1).sort()[0][..., pad_width]
return result
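For context, a minimal usage sketch of the helper above (made-up shapes, not part of this diff): only the last dimension is filtered and the tensor shape is preserved.

```python
import torch

# e.g. cross-attention weights: (batch, selected heads, output tokens, audio frames)
weights = torch.rand(1, 6, 10, 1500)
smoothed = _median_filter(weights, filter_width=7)
assert smoothed.shape == weights.shape  # only the last dimension is smoothed
```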
def _dynamic_time_warping(matrix: np.ndarray):
"""
Measures similarity between two temporal sequences: the input audio and the output tokens. Used to generate
token-level timestamps.
"""
output_length, input_length = matrix.shape
cost = np.ones((output_length + 1, input_length + 1), dtype=np.float32) * np.inf
trace = -np.ones((output_length + 1, input_length + 1), dtype=np.float32)
cost[0, 0] = 0
for j in range(1, input_length + 1):
for i in range(1, output_length + 1):
c0 = cost[i - 1, j - 1]
c1 = cost[i - 1, j]
c2 = cost[i, j - 1]
if c0 < c1 and c0 < c2:
c, t = c0, 0
elif c1 < c0 and c1 < c2:
c, t = c1, 1
else:
c, t = c2, 2
cost[i, j] = matrix[i - 1, j - 1] + c
trace[i, j] = t
# backtrace
i = trace.shape[0] - 1
j = trace.shape[1] - 1
trace[0, :] = 2
trace[:, 0] = 1
text_indices = []
time_indices = []
while i > 0 or j > 0:
text_indices.append(i - 1)
time_indices.append(j - 1)
if trace[i, j] == 0:
i -= 1
j -= 1
elif trace[i, j] == 1:
i -= 1
elif trace[i, j] == 2:
j -= 1
else:
raise RuntimeError(
f"Internal error in dynamic time warping. Unexpected trace[{i}, {j}]. Please file a bug report."
)
text_indices = np.array(text_indices)[::-1]
time_indices = np.array(time_indices)[::-1]
return text_indices, time_indices
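A small illustrative sketch of how this DTW helper is meant to be called (hypothetical similarity values): rows correspond to output tokens, columns to audio frames, and the matrix is negated because the routine minimizes cost, as `_extract_token_timestamps` does below.

```python
import numpy as np

# hypothetical token-vs-frame similarity (3 tokens, 5 audio frames)
similarity = np.array([
    [0.9, 0.1, 0.0, 0.0, 0.0],
    [0.1, 0.8, 0.7, 0.1, 0.0],
    [0.0, 0.0, 0.1, 0.9, 0.8],
])
text_indices, time_indices = _dynamic_time_warping(-similarity)
# The two arrays describe a monotonic alignment path; the first frame at which each
# token index appears is later converted to seconds in _extract_token_timestamps.
```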
class WhisperPositionalEmbedding(nn.Embedding):
def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None):
super().__init__(num_positions, embedding_dim)
@@ -1472,6 +1547,7 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
language=None,
is_multilingual=None,
prompt_ids: Optional[torch.Tensor] = None,
return_token_timestamps=None,
**kwargs,
):
"""
@@ -1534,6 +1610,10 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
provided as a prompt to each chunk. This can be used to provide or "prompt-engineer" a context for
transcription, e.g. custom vocabularies or proper nouns to make it more likely to predict those words
correctly. It cannot be used in conjunction with `decoder_start_token_id` as it overwrites this value.
return_token_timestamps (`bool`, *optional*):
Whether to return token-level timestamps with the text. This can be used with or without the
`return_timestamps` option. To get word-level timestamps, use the tokenizer to group the tokens into
words.
kwargs:
Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be
forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder
@@ -1662,7 +1742,19 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
if generation_config.return_timestamps:
logits_processor = [WhisperTimeStampLogitsProcessor(generation_config)]
if return_token_timestamps:
kwargs["output_attentions"] = True
kwargs["return_dict_in_generate"] = True
if getattr(generation_config, "task", None) == "translate":
logger.warning("Token-level timestamps may not be reliable for task 'translate'.")
if not hasattr(generation_config, "alignment_heads"):
raise ValueError(
"Model generation config has no `alignment_heads`, token-level timestamps not available. "
"See https://gist.github.com/hollance/42e32852f24243b748ae6bc1f985b13a on how to add this property to the generation config."
)
outputs = super().generate(
inputs,
generation_config,
logits_processor,
@@ -1672,6 +1764,11 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
**kwargs,
)
if return_token_timestamps and hasattr(generation_config, "alignment_heads"):
outputs["token_timestamps"] = self._extract_token_timestamps(outputs, generation_config.alignment_heads)
return outputs
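A hedged end-to-end sketch of the new `generate` flag. The checkpoint and the `alignment_heads` values are copied from the tests added further down in this PR, and `audio` stands for a 16 kHz waveform array (a placeholder, not defined here).

```python
from transformers import WhisperProcessor, WhisperForConditionalGeneration

processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
# alignment_heads must be set on the generation config (values taken from the pipeline tests below)
model.generation_config.alignment_heads = [[2, 2], [3, 0], [3, 2], [3, 3], [3, 4], [3, 5]]

input_features = processor(audio, sampling_rate=16_000, return_tensors="pt").input_features
outputs = model.generate(input_features, return_timestamps=True, return_token_timestamps=True)
# outputs.token_timestamps has the same shape as outputs.sequences:
# one start time in seconds per generated token.
```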
def prepare_inputs_for_generation(
self,
decoder_input_ids,
@@ -1693,7 +1790,6 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
"decoder_attention_mask": None,
}
@staticmethod
def _reorder_cache(past_key_values, beam_idx):
reordered_past = ()
@@ -1701,6 +1797,44 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
return reordered_past
def _extract_token_timestamps(self, generate_outputs, alignment_heads, time_precision=0.02):
"""
Calculates token-level timestamps using the encoder-decoder cross-attentions and dynamic time-warping (DTW) to
map each output token to a position in the input audio.
Returns:
tensor containing the timestamps in seconds for each predicted token
"""
# Create a list with `decoder_layers` elements, each a tensor of shape
# (batch size, attention_heads, output length, input length).
cross_attentions = []
for i in range(self.config.decoder_layers):
cross_attentions.append(torch.cat([x[i] for x in generate_outputs.cross_attentions], dim=2))
# Select specific cross-attention layers and heads. This is a tensor
# of shape (batch size, num selected, output length, input length).
weights = torch.stack([cross_attentions[l][:, h] for l, h in alignment_heads])
weights = weights.permute([1, 0, 2, 3])
# Normalize and smoothen the weights.
std, mean = torch.std_mean(weights, dim=-2, keepdim=True, unbiased=False)
weights = (weights - mean) / std
weights = _median_filter(weights, self.config.median_filter_width)
# Average the different cross-attention heads.
matrix = weights.mean(dim=1)
timestamps = torch.zeros_like(generate_outputs.sequences, dtype=torch.float32)
# Perform dynamic time warping on each element of the batch.
for batch_idx in range(timestamps.shape[0]):
text_indices, time_indices = _dynamic_time_warping(-matrix[batch_idx].double().cpu().numpy())
jumps = np.pad(np.diff(text_indices), (1, 0), constant_values=1).astype(bool)
jump_times = time_indices[jumps] * time_precision
timestamps[batch_idx, 1:] = torch.tensor(jump_times)
return timestamps
@add_start_docstrings(
"""
......
@@ -585,7 +585,7 @@ class WhisperTokenizer(PreTrainedTokenizer):
Whether or not to output the offsets of the tokens. This should only be set if the model predicted
timestamps.
decode_with_timestamps (`bool`, *optional*, defaults to `False`):
Whether or not to decode with timestamps included in the raw text.
Returns:
`str`: The decoded sentence.
"""
@@ -779,6 +779,7 @@ def _decode_asr(tokenizer, model_outputs, *, return_timestamps, return_language,
time_offset = 0.0
timestamp_begin = tokenizer.convert_tokens_to_ids("<|notimestamps|>") + 1
previous_tokens = []
previous_token_timestamps = []
skip = False
right_stride_start = None
@@ -788,6 +789,8 @@ def _decode_asr(tokenizer, model_outputs, *, return_timestamps, return_language,
# We can drop everything to Python list, it's going to make
# our lives easier
token_ids = output["tokens"][0].tolist()
if return_timestamps == "word":
token_timestamps = output["token_timestamps"][0].tolist()
# Those keep track of timestamps within strides
# Which need to be skipped and resolve all tokens in a single
@@ -820,6 +823,7 @@ def _decode_asr(tokenizer, model_outputs, *, return_timestamps, return_language,
last_timestamp = token
current_tokens = []
current_token_timestamps = []
# - all tokens within output
for i, token in enumerate(token_ids):
@@ -883,20 +887,37 @@ def _decode_asr(tokenizer, model_outputs, *, return_timestamps, return_language,
chunk["timestamp"][1] = time
# Handling merges.
previous_tokens.append(current_tokens)
if return_timestamps == "word":
previous_token_timestamps.append(current_token_timestamps)
resolved_tokens, resolved_token_timestamps = _find_longest_common_sequence(
previous_tokens, previous_token_timestamps
)
resolved_text = tokenizer.decode(resolved_tokens)
chunk["text"] = resolved_text
if return_timestamps == "word":
chunk["words"] = _collate_word_timestamps(
tokenizer, resolved_tokens, resolved_token_timestamps, last_language
)
chunks.append(chunk)
# Flush all our temporary context
previous_tokens = []
current_tokens = []
previous_token_timestamps = []
current_token_timestamps = []
chunk = new_chunk()
else:
# 4/ Regular token
# We just append to the list of all tokens so we can handle
# merges later and decode into text.
current_tokens.append(token)
if return_timestamps == "word":
start_time = round(token_timestamps[i] + time_offset, 2)
if i + 1 < len(token_timestamps):
end_time = round(token_timestamps[i + 1] + time_offset, 2)
else:
end_time = None # should never happen
current_token_timestamps.append((start_time, end_time))
if "stride" in output: if "stride" in output:
time_offset += chunk_len - stride_right time_offset += chunk_len - stride_right
...@@ -904,21 +925,31 @@ def _decode_asr(tokenizer, model_outputs, *, return_timestamps, return_language, ...@@ -904,21 +925,31 @@ def _decode_asr(tokenizer, model_outputs, *, return_timestamps, return_language,
# Leftover tokens # Leftover tokens
if current_tokens: if current_tokens:
previous_tokens.append(current_tokens) previous_tokens.append(current_tokens)
if return_timestamps == "word":
previous_token_timestamps.append(current_token_timestamps)
elif not (any(p for p in previous_tokens)):
chunk = new_chunk()
previous_tokens = []
current_tokens = []
previous_token_timestamps = []
current_token_timestamps = []
if previous_tokens:
if return_timestamps:
logger.warning(
"Whisper did not predict an ending timestamp, which can happen if audio is cut off in the middle of a word. "
"Also make sure WhisperTimeStampLogitsProcessor was used during generation."
)
# Happens when we don't use timestamps
resolved_tokens, resolved_token_timestamps = _find_longest_common_sequence(
previous_tokens, previous_token_timestamps
)
resolved_text = tokenizer.decode(resolved_tokens)
chunk["text"] = resolved_text
if return_timestamps == "word":
chunk["words"] = _collate_word_timestamps(
tokenizer, resolved_tokens, resolved_token_timestamps, last_language
)
chunks.append(chunk)
# Preparing and cleaning up the pipeline output
@@ -931,20 +962,35 @@ def _decode_asr(tokenizer, model_outputs, *, return_timestamps, return_language,
chunk["timestamp"] = tuple(chunk["timestamp"])
if not return_language:
chunk.pop("language")
if return_timestamps == "word":
new_chunks = []
for chunk in chunks:
new_chunks.extend(chunk["words"])
optional = {"chunks": new_chunks}
else:
optional = {"chunks": chunks}
else:
optional = {}
return full_text, optional
def _find_longest_common_sequence(sequences, token_timestamp_sequences=None):
# It would be much harder to do O(n) because of fault tolerance.
# We actually have a really good property which is that the total sequence
# MUST be those subsequences in order.
# If token_timestamp_sequences is provided, will split those sequences in
# exactly the same way.
left_sequence = sequences[0]
left_length = len(left_sequence)
total_sequence = []
if token_timestamp_sequences:
left_token_timestamp_sequence = token_timestamp_sequences[0]
total_token_timestamp_sequence = []
for seq_idx, right_sequence in enumerate(sequences[1:]):
# index = 0
max_ = 0.0
max_indices = (left_length, left_length, 0, 0)
@@ -1018,6 +1064,148 @@ def _find_longest_common_sequence(sequences):
left_sequence = right_sequence[right_mid:]
left_length = len(left_sequence)
if token_timestamp_sequences:
total_token_timestamp_sequence.extend(left_token_timestamp_sequence[:left_mid])
left_token_timestamp_sequence = token_timestamp_sequences[seq_idx + 1][right_mid:]
total_sequence.extend(left_sequence)
if token_timestamp_sequences is None:
return total_sequence
if len(token_timestamp_sequences) > 0:
total_token_timestamp_sequence.extend(left_token_timestamp_sequence)
return total_sequence, total_token_timestamp_sequence
else:
return total_sequence, []
def _collate_word_timestamps(tokenizer, tokens, token_timestamps, language):
words, _, token_indices = _combine_tokens_into_words(tokenizer, tokens, language)
timings = [
{
"text": word,
"timestamp": (token_timestamps[indices[0]][0], token_timestamps[indices[-1]][1]),
}
for word, indices in zip(words, token_indices)
]
return timings
def _combine_tokens_into_words(
tokenizer,
tokens: List[int],
language: str = None,
prepend_punctuations: str = "\"'“¡¿([{-",
append_punctuations: str = "\"'.。,,!!??::”)]}、",
):
"""
Groups tokens by word. Returns a tuple containing a list of strings with the words, and a list of `token_id`
sequences with the tokens making up each word.
"""
if language is None:
language = tokenizer.language
if language is None:
language = "english"
if language in {"chinese", "japanese", "thai", "lao", "myanmar"}:
# These languages don't typically use spaces.
words, word_tokens, token_indices = _split_tokens_on_unicode(tokenizer, tokens)
else:
words, word_tokens, token_indices = _split_tokens_on_spaces(tokenizer, tokens)
_merge_punctuations(words, word_tokens, token_indices, prepend_punctuations, append_punctuations)
return words, word_tokens, token_indices
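For reference, a short sketch of the helper above in use; the input tokens and the expected groupings are copied from the tokenizer test added later in this PR, and `tokenizer` is assumed to be a `WhisperTokenizer` instance.

```python
# 'whatever "whatever" said someone, clever!?'
tokens = [1363, 7969, 503, 1363, 7969, 1, 848, 1580, 11, 13494, 7323]
words, word_tokens, token_indices = _combine_tokens_into_words(tokenizer, tokens)
# words         -> ["whatever", ' "whatever"', " said", " someone,", " clever!?"]
# token_indices -> [[0, 1], [2, 3, 4, 5], [6], [7, 8], [9, 10]]
```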
def _split_tokens_on_unicode(tokenizer, tokens: List[int]):
"""Combine tokens into words by splitting at any position where the tokens are decoded as valid unicode points."""
decoded_full = tokenizer.decode(tokens, decode_with_timestamps=True)
replacement_char = "\ufffd"
words = []
word_tokens = []
token_indices = []
current_tokens = []
current_indices = []
unicode_offset = 0
for token_idx, token in enumerate(tokens):
current_tokens.append(token)
current_indices.append(token_idx)
decoded = tokenizer.decode(current_tokens, decode_with_timestamps=True)
if (
replacement_char not in decoded
or decoded_full[unicode_offset + decoded.index(replacement_char)] == replacement_char
):
words.append(decoded)
word_tokens.append(current_tokens)
token_indices.append(current_indices)
current_tokens = []
current_indices = []
unicode_offset += len(decoded)
return words, word_tokens, token_indices
def _split_tokens_on_spaces(tokenizer, tokens: List[int]):
"""Combine tokens into words by splitting at whitespace and punctuation tokens."""
subwords, subword_tokens_list, subword_indices_list = _split_tokens_on_unicode(tokenizer, tokens)
words = []
word_tokens = []
token_indices = []
for subword, subword_tokens, subword_indices in zip(subwords, subword_tokens_list, subword_indices_list):
special = subword_tokens[0] >= tokenizer.eos_token_id
with_space = subword.startswith(" ")
punctuation = subword.strip() in "!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~"
if special or with_space or punctuation or len(words) == 0:
words.append(subword)
word_tokens.append(subword_tokens)
token_indices.append(subword_indices)
else:
words[-1] = words[-1] + subword
word_tokens[-1].extend(subword_tokens)
token_indices[-1].extend(subword_indices)
return words, word_tokens, token_indices
def _merge_punctuations(words, tokens, indices, prepended, appended):
"""Merges punctuation tokens with neighboring words."""
# prepend punctuations
i = len(words) - 2
j = len(words) - 1
while i >= 0:
if words[i].startswith(" ") and words[i].strip() in prepended:
words[j] = words[i] + words[j]
tokens[j] = tokens[i] + tokens[j]
indices[j] = indices[i] + indices[j]
words[i] = ""
tokens[i] = []
indices[i] = []
else:
j = i
i -= 1
# append punctuations
i = 0
j = 1
while j < len(words):
if not words[i].endswith(" ") and words[j] in appended:
words[i] += words[j]
tokens[i] += tokens[j]
indices[i] += indices[j]
words[j] = ""
tokens[j] = []
indices[j] = []
else:
i = j
j += 1
# remove elements that are now empty
words[:] = [word for word in words if word]
tokens[:] = [token for token in tokens if token]
indices[:] = [idx for idx in indices if idx]
@@ -295,7 +295,7 @@ class WhisperTokenizerFast(PreTrainedTokenizerFast):
Whether or not to output the offsets of the tokens. This should only be set if the model predicted
timestamps.
decode_with_timestamps (`bool`, *optional*, defaults to `False`):
Whether or not to decode with timestamps included in the raw text.
Returns:
`str`: The decoded sentence.
"""
......
@@ -246,12 +246,12 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
treat the first `left` samples and last `right` samples to be ignored in decoding (but used at
inference to provide more context to the model). Only use `stride` with CTC models.
return_timestamps (*optional*, `str`):
Only available for pure CTC models. If set to `"char"`, the pipeline will return timestamps along the
text for every character in the text. For instance if you get `[{"text": "h", "timestamp": (0.5, 0.6)},
{"text": "i", "timestamp": (0.7, 0.9)}]`, then it means the model predicts that the letter "h" was
pronounced after `0.5` and before `0.6` seconds. If set to `"word"`, the pipeline will return
timestamps along the text for every word in the text. For instance if you get `[{"text": "hi ",
"timestamp": (0.5, 0.9)}, {"text": "there", "timestamp": (1.0, 1.5)}]`, then it means the model
predicts that the word "hi" was pronounced after `0.5` and before `0.9` seconds.
generate_kwargs (`dict`, *optional*):
The dictionary of ad-hoc parametrization of `generate_config` to be used for the generation call. For a
@@ -265,8 +265,8 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
- **text** (`str`) -- The recognized text.
- **chunks** (*optional*, `List[Dict]`)
When using `return_timestamps`, the `chunks` will become a list containing all the various text
chunks identified by the model, *e.g.* `[{"text": "hi ", "timestamp": (0.5, 0.9)}, {"text":
"there", "timestamp": (1.0, 1.5)}]`. The original full text can roughly be recovered by doing
`"".join(chunk["text"] for chunk in output["chunks"])`.
"""
return super().__call__(inputs, **kwargs)
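For illustration, a hedged sketch of requesting word-level timestamps through the pipeline; the `alignment_heads` values are copied from the pipeline tests added below, and `"sample.wav"` is a placeholder path.

```python
from transformers import pipeline

asr = pipeline("automatic-speech-recognition", model="openai/whisper-tiny")
# set alignment_heads on the generation config, as required for token-level timestamps
asr.model.generation_config.alignment_heads = [[2, 2], [3, 0], [3, 2], [3, 3], [3, 4], [3, 5]]

out = asr("sample.wav", return_timestamps="word")
# out["chunks"] is a list such as [{"text": " hi", "timestamp": (0.5, 0.9)}, ...]
```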
@@ -421,6 +421,8 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
generate_kwargs = {}
if return_timestamps and self.type == "seq2seq_whisper":
generate_kwargs["return_timestamps"] = return_timestamps
if return_timestamps == "word":
generate_kwargs["return_token_timestamps"] = True
is_last = model_inputs.pop("is_last")
if self.type in {"seq2seq", "seq2seq_whisper"}:
@@ -447,7 +449,10 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
attention_mask=attention_mask,
**generate_kwargs,
)
if return_timestamps == "word" and self.type == "seq2seq_whisper":
out = {"tokens": tokens["sequences"], "token_timestamps": tokens["token_timestamps"]}
else:
out = {"tokens": tokens}
if self.type == "seq2seq_whisper": if self.type == "seq2seq_whisper":
stride = model_inputs.pop("stride", None) stride = model_inputs.pop("stride", None)
if stride is not None: if stride is not None:
...@@ -486,9 +491,9 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline): ...@@ -486,9 +491,9 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
if return_timestamps and self.type == "seq2seq": if return_timestamps and self.type == "seq2seq":
raise ValueError("We cannot return_timestamps yet on non-ctc models apart from Whisper !") raise ValueError("We cannot return_timestamps yet on non-ctc models apart from Whisper !")
if return_timestamps == "char" and self.type == "ctc_with_lm": if return_timestamps == "char" and self.type == "ctc_with_lm":
raise ValueError("CTC with LM cannot return `char` timestamps, only `words`") raise ValueError("CTC with LM cannot return `char` timestamps, only `word`")
if return_timestamps in {"char", "words"} and self.type == "seq2seq_whisper": if return_timestamps == "char" and self.type == "seq2seq_whisper":
raise ValueError("Whisper cannot return `char` nor `words` timestamps, use `True` instead.") raise ValueError("Whisper cannot return `char` timestamps, use `True` or `word` instead.")
if return_language is not None and self.type != "seq2seq_whisper":
raise ValueError("Only whisper can return language for now.")
@@ -574,6 +579,7 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
output.pop("logits", None)
output.pop("is_last", None)
output.pop("stride", None)
output.pop("token_timestamps", None)
for k, v in output.items():
extra[k].append(v)
return {"text": text, **optional, **extra}
......
@@ -1436,6 +1436,35 @@ class WhisperModelIntegrationTests(unittest.TestCase):
transcript = processor.batch_decode(generated_ids, skip_special_tokens=True, output_offsets=True)
self.assertEqual(transcript, EXPECTED_TRANSCRIPT)
@slow
def test_tiny_token_timestamp_generation(self):
set_seed(0)
processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
model.to(torch_device)
input_speech = self._load_datasamples(4)
input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="pt").input_features.to(
torch_device
)
generate_outputs = model.generate(
input_features, max_length=448, return_timestamps=True, return_token_timestamps=True
)
self.assertEqual(generate_outputs.sequences.shape, generate_outputs.token_timestamps.shape)
# fmt: off
EXPECTED_OUTPUT = torch.tensor([
[ 0.0000, 0.0000, 0.0000, 0.0000, 0.4800, 0.8200, 0.9600, 1.1200, 1.1200, 1.2200, 1.5000, 1.7200, 2.0000, 2.3400, 2.5000, 2.6600, 3.1800, 3.5600, 3.6800, 3.8000, 4.1000, 4.3000, 4.5800, 4.9400, 5.3800, 12.4200, 12.8400, 26.9200, 26.9200, 26.9200, 26.9200, 26.9200, 26.9200, 26.9200, 26.9200, 26.9200, 26.9200, 26.9200, 26.9400, 26.9400, 26.9400, 26.9400, 29.8400 ],
[ 0.0000, 0.0000, 0.0000, 0.0000, 0.5200, 0.9000, 1.1400, 1.4200, 1.5200, 1.6800, 1.6800, 1.8800, 2.1000, 2.2200, 2.6200, 3.1400, 3.5800, 3.9600, 4.4000, 17.3000, 17.3000, 26.7200, 26.7200, 26.7200, 26.7200, 26.7200, 26.7200, 26.7200, 26.7200, 26.7200, 26.7200, 26.7200, 26.7200, 26.7200, 26.7200, 26.7200, 26.7400, 26.7400, 26.7400, 26.7400, 26.7400, 26.7400, 28.0000 ],
[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.7600, 1.0000, 1.4200, 1.8000, 1.9400, 2.1800, 2.5200, 3.0200, 3.3200, 3.5400, 3.9400, 4.5600, 4.9200, 5.2800, 5.5600, 5.9000, 6.1600, 6.3000, 6.4800, 6.4800, 6.6400, 7.8200, 7.9600, 8.2200, 8.6000, 8.9200, 9.2200, 9.5200, 9.7200, 10.0600, 10.5400, 10.8800, 11.2600, 11.5400, 11.7400, 12.0800, 15.6800, 15.6800],
[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.7400, 1.0400, 1.3200, 1.6800, 2.1400, 2.4800, 2.7800, 3.0800, 3.1600, 3.4000, 3.6000, 4.0200, 4.2200, 4.8600, 5.2400, 5.7400, 6.3400, 6.6200, 6.7600, 6.7600, 6.8600, 7.2400, 7.4200, 7.6800, 7.9200, 8.4800, 8.7600, 9.2000, 9.2000, 9.4200, 15.8200, 15.8200, 29.6400, 29.6600, 29.6600, 29.6600, 29.6600, 29.7600]
])
# fmt: on
self.assertTrue(torch.allclose(generate_outputs.token_timestamps.to("cpu"), EXPECTED_OUTPUT))
@slow
def test_tiny_specaugment_librispeech(self):
torch_device = "cpu"
......
@@ -15,7 +15,7 @@
import unittest
from transformers.models.whisper import WhisperTokenizer, WhisperTokenizerFast
from transformers.models.whisper.tokenization_whisper import _combine_tokens_into_words, _find_longest_common_sequence
from transformers.testing_utils import slow
from ...test_tokenization_common import TokenizerTesterMixin
@@ -255,6 +255,24 @@ class WhisperTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
self.assertListEqual(tokenizer_prompt_ids.tolist(), fast_tokenizer_prompt_ids.tolist())
def test_combine_tokens_into_words(self):
tokenizer = self.get_tokenizer()
rust_tokenizer = self.get_rust_tokenizer()
# 'whatever "whatever" said someone, clever!?'
encoded_input = [1363, 7969, 503, 1363, 7969, 1, 848, 1580, 11, 13494, 7323]
expected_words = ["whatever", ' "whatever"', " said", " someone,", " clever!?"]
expected_tokens = [[1363, 7969], [503, 1363, 7969, 1], [848], [1580, 11], [13494, 7323]]
expected_indices = [[0, 1], [2, 3, 4, 5], [6], [7, 8], [9, 10]]
output = _combine_tokens_into_words(tokenizer, encoded_input)
self.assertEqual(expected_words, output[0])
self.assertEqual(expected_tokens, output[1])
self.assertEqual(expected_indices, output[2])
output_rust = _combine_tokens_into_words(rust_tokenizer, encoded_input)
self.assertEqual(expected_words, output_rust[0])
self.assertEqual(expected_tokens, output_rust[1])
self.assertEqual(expected_indices, output_rust[2])
class SpeechToTextTokenizerMultilinguialTest(unittest.TestCase):
checkpoint_name = "openai/whisper-small.en"
......
@@ -316,6 +316,27 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
"chunks": [{"text": " Conquered returned to its place amidst the tents.", "timestamp": (0.0, 3.36)}],
},
)
pipe.model.generation_config.alignment_heads = [[2, 2], [3, 0], [3, 2], [3, 3], [3, 4], [3, 5]]
res = pipe(sample["audio"]["array"], return_timestamps="word")
# fmt: off
# Note that the word-level timestamps predicted here are pretty bad.
self.assertEqual(
res,
{
"text": " Conquered returned to its place amidst the tents.",
"chunks": [
{'text': ' Conquered', 'timestamp': (29.78, 29.9)},
{'text': ' returned', 'timestamp': (29.9, 29.9)},
{'text': ' to', 'timestamp': (29.9, 29.9)},
{'text': ' its', 'timestamp': (29.9, 29.9)},
{'text': ' place', 'timestamp': (29.9, 29.9)},
{'text': ' amidst', 'timestamp': (29.9, 29.9)},
{'text': ' the', 'timestamp': (29.9, 29.9)},
{'text': ' tents.', 'timestamp': (29.9, 29.9)}
]
}
)
# fmt: on
@require_torch
@slow
@@ -699,6 +720,35 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
],
},
)
speech_recognizer.model.generation_config.alignment_heads = [[2, 2], [3, 0], [3, 2], [3, 3], [3, 4], [3, 5]]
output = speech_recognizer(filename, return_timestamps="word")
# fmt: off
self.assertEqual(
output,
{
"text": " Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.",
"chunks": [
{'text': ' Mr.', 'timestamp': (0.0, 1.02)},
{'text': ' Quilter', 'timestamp': (1.02, 1.18)},
{'text': ' is', 'timestamp': (1.18, 1.44)},
{'text': ' the', 'timestamp': (1.44, 1.58)},
{'text': ' apostle', 'timestamp': (1.58, 1.98)},
{'text': ' of', 'timestamp': (1.98, 2.3)},
{'text': ' the', 'timestamp': (2.3, 2.46)},
{'text': ' middle', 'timestamp': (2.46, 2.56)},
{'text': ' classes,', 'timestamp': (2.56, 3.38)},
{'text': ' and', 'timestamp': (3.38, 3.52)},
{'text': ' we', 'timestamp': (3.52, 3.6)},
{'text': ' are', 'timestamp': (3.6, 3.72)},
{'text': ' glad', 'timestamp': (3.72, 4.0)},
{'text': ' to', 'timestamp': (4.0, 4.26)},
{'text': ' welcome', 'timestamp': (4.26, 4.54)},
{'text': ' his', 'timestamp': (4.54, 4.92)},
{'text': ' gospel.', 'timestamp': (4.92, 6.66)},
],
},
)
# fmt: on
@slow
@require_torch
......