Unverified Commit 13254591 authored by Nicolas Patry, committed by GitHub

Refactor whisper asr pipeline to include language too. (#21427)



* [WIP] Whisper refactor to support language output.

* Handling merges.

* A bit more cleanup and comments.

* Many improvements.

Lots of details everywhere.

* Cleanup old code and tests.

* Handle lone timestamp tokens (just recover when something bad happens).

* Adding return_language example.

* No ffmpeg.

* Hmm.

* Some corrections.

* Both fast and slow.

* New black.

* Update src/transformers/models/whisper/tokenization_whisper.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/models/whisper/tokenization_whisper.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Remove print.

* Undoing tests modifications.

* Smaller test modifications.

* Rename.

* Remove maxDiff.

---------
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
parent 8e5a1b2a
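To illustrate the new option, here is a minimal usage sketch (the checkpoint name and the audio path below are placeholders, not taken from this PR):

from transformers import pipeline

# "openai/whisper-tiny" and "sample.flac" are stand-ins; substitute your own checkpoint and audio.
asr = pipeline(
    "automatic-speech-recognition",
    model="openai/whisper-tiny",
    chunk_length_s=30,
)
# `return_language=True` is the option added by this PR; each chunk then carries the language
# detected by Whisper in addition to its text (and its timestamp when `return_timestamps=True`).
output = asr("sample.flac", return_timestamps=True, return_language=True)
print(output["text"])
for chunk in output["chunks"]:
    print(chunk["language"], chunk["timestamp"], chunk["text"])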
@@ -703,3 +703,289 @@ class WhisperTokenizer(PreTrainedTokenizer):
forced_tokens = self.prefix_tokens[1:]
forced_decoder_ids = [(rank + 1, token) for rank, token in enumerate(forced_tokens)]
return forced_decoder_ids
def _decode_asr(self, model_outputs, *, return_timestamps, return_language, time_precision):
return _decode_asr(
self,
model_outputs,
return_timestamps=return_timestamps,
return_language=return_language,
time_precision=time_precision,
)
def _decode_asr(tokenizer, model_outputs, *, return_timestamps, return_language, time_precision):
"""
Internal method meant to only be used by asr pipeline. Handles all the little quirks specific to whisper to handle
the various options not allowed in other seq2seq models
"""
# =========== Overview ============
# - iterate over all outputs
# - all tokens within output
# - Each token can be
# - language token
# - special token
# - timestamp token
# - text token
# - We accumulate the text tokens.
# - We split on end timestamps
# - Lots of complexity comes from stride and timestamps
last_language = None
def new_chunk():
return {"language": last_language, "timestamp": [None, None], "text": ""}
# Welcome to the state machine!
chunks = []
chunk = new_chunk()
time_offset = 0.0
timestamp_begin = tokenizer.convert_tokens_to_ids("<|notimestamps|>") + 1
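# (For the standard Whisper vocabulary, the token right after <|notimestamps|> is <|0.00|>, so every
# token ID >= `timestamp_begin` encodes a timestamp, in increments of `time_precision` seconds.)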
previous_tokens = []
skip = False
right_stride_start = None
all_special_ids = set(tokenizer.all_special_ids)
# - iterate over all outputs
for chunk_id, output in enumerate(model_outputs):
# We can drop everything to a Python list, it's going to make
# our lives easier
token_ids = output["tokens"][0].tolist()
# These keep track of timestamps that fall within the strides,
# which need to be skipped so that all tokens end up resolved in a single
# chunk.
last_timestamp = None
first_timestamp = timestamp_begin
if "stride" in output:
chunk_len, stride_left, stride_right = output["stride"]
# Offset the timings to account for the other `model_outputs`.
time_offset -= stride_left
right_stride_start = chunk_len - stride_right
# Keeping track of timestamps within the strides:
# we're going to NOT split on those, and delay until we're
# out of BOTH strides, otherwise lots of issues and corner
# cases occur.
if stride_left:
first_timestamp = stride_left / time_precision + timestamp_begin
if stride_right:
for token in reversed(token_ids):
if token >= timestamp_begin:
# There can be several tokens in the right stride,
# but the last one is ALWAYS going to be skipped.
if (
last_timestamp is not None
and (token - timestamp_begin) * time_precision < right_stride_start
):
break
last_timestamp = token
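# After this backward scan, `last_timestamp` is the earliest timestamp token that still falls inside
# the right stride; timestamp tokens at or after it are skipped below and resolved on the next chunk.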
current_tokens = []
# - all tokens within output
for i, token in enumerate(token_ids):
# 4 possible states for each token
# - 1/ Language code
# - 2/ all other special tokens (which we ignore)
# - 3/ Timestamp
# - 4/ Regular text
if token in all_special_ids:
# Either language code or other
text = tokenizer.decode([token])
# Removing outer shell <|XX|>
text = text[2:-2]
language = LANGUAGES.get(text, None)
if language is not None:
# 1/ Indeed some language
# TODO Handle when language is different from the previous
# one, and we cannot use timestamped tokens to create chunks
if last_language and language != last_language and not return_timestamps:
previous_tokens.append(current_tokens)
resolved_tokens = _find_longest_common_sequence(previous_tokens)
resolved_text = tokenizer.decode(resolved_tokens)
chunk["text"] = resolved_text
chunks.append(chunk)
# Flush all our temporary context
previous_tokens = []
current_tokens = []
chunk = new_chunk()
chunk["language"] = language
last_language = language
else:
# 2/ This is a regular special token, ignoring it
pass
elif token >= timestamp_begin:
# 3/ Timestamp token
time = (token - timestamp_begin) * time_precision + time_offset
time = round(time, 2)
if last_timestamp and token >= last_timestamp:
# Whisper outputted a timestamp token, but it falls within
# our stride, so we're going to skip it for the time being
# and resolve this later
# Skipping is necessary because timestamp tokens always come
# in pairs, so we need to skip the next one too (which would mark the start of another chunk).
skip = True
elif skip or (previous_tokens and token < first_timestamp):
skip = False
elif chunk["timestamp"][0] is None:
chunk["timestamp"][0] = time
else:
# This is the end of the timestamp chunk
if time == chunk["timestamp"][0]:
# This is a bug in timestamp token output
# where we're taking the duplicate token
# as a stop where it should be a start.
# This is an issue in the underlying model output
# Let's just skip it so it becomes de facto
# a start again.
pass
else:
chunk["timestamp"][1] = time
# Handling merges.
previous_tokens.append(current_tokens)
resolved_tokens = _find_longest_common_sequence(previous_tokens)
resolved_text = tokenizer.decode(resolved_tokens)
chunk["text"] = resolved_text
chunks.append(chunk)
# Flush all our temporary context
previous_tokens = []
current_tokens = []
chunk = new_chunk()
else:
# 4/ Regular token
# We just append to the list of all tokens so we can handle
# merges later and decode into text.
current_tokens.append(token)
if "stride" in output:
time_offset += chunk_len - stride_right
# Leftover tokens
if current_tokens:
previous_tokens.append(current_tokens)
elif not (any(p for p in previous_tokens)):
# print("Flushing previous tokens (END)")
chunk = new_chunk()
previous_tokens = []
current_tokens = []
if previous_tokens:
if return_timestamps:
# The last token should always be a timestamp, so there shouldn't be
# any leftover.
raise ValueError(
"There was an error while processing timestamps, we haven't found a timestamp as last token. Was"
" WhisperTimeStampLogitsProcessor used?"
)
# Happens when we don't use timestamps
resolved_tokens = _find_longest_common_sequence(previous_tokens)
# print("Flushing previous tokens (FINAL)")
resolved_text = tokenizer.decode(resolved_tokens)
chunk["text"] = resolved_text
chunks.append(chunk)
# Preparing and cleaning up the pipeline output
full_text = "".join(chunk["text"] for chunk in chunks)
if return_timestamps or return_language:
for chunk in chunks:
if not return_timestamps:
chunk.pop("timestamp")
else:
chunk["timestamp"] = tuple(chunk["timestamp"])
if not return_language:
chunk.pop("language")
optional = {"chunks": chunks}
else:
optional = {}
return full_text, optional
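# For illustration only (hypothetical values): with `return_timestamps=True` and `return_language=True`,
# the pair returned above looks roughly like
#   " Hello world.", {"chunks": [{"language": "english", "timestamp": (0.0, 1.2), "text": " Hello world."}]}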
def _find_longest_common_sequence(sequences):
# It would be much harder to do this in O(n) because of the fault tolerance.
# We actually have a really good property, which is that the total sequence
# MUST contain those subsequences in order.
left_sequence = sequences[0]
left_length = len(left_sequence)
total_sequence = []
for right_sequence in sequences[1:]:
# index = 0
max_ = 0.0
max_indices = (left_length, left_length, 0, 0)
# Here we're sliding matches
# [a, b, c, d]
#          [c, d, f]
# =        [c] == [d]
#
# [a, b, c, d]
#       [c, d, f]
# =     [c, d] == [c, d]
#
#
# [a, b, c, d]
#    [c, d, f]
#
# =  [b, c, d] == [c, d, f]
#
# [a, b, c, d]
# [c, d, f]
#
# [a, b, c] == [c, d, f]
#
# [a, b, c, d]
# [d, f]
#
# [a, b] == [d, f]
#
# [a, b, c, d]
# [f]
#
# [a] == [f]
right_length = len(right_sequence)
for i in range(1, left_length + right_length):
# epsilon to favor long perfect matches
eps = i / 10000.0
# Slightly convoluted because we don't want out of bound indices
# This will be necessary for a small conflict resolution optimization
# later
left_start = max(0, left_length - i)
left_stop = min(left_length, left_length + right_length - i)
left = np.array(left_sequence[left_start:left_stop])
right_start = max(0, i - left_length)
right_stop = min(right_length, i)
right = np.array(right_sequence[right_start:right_stop])
# We can only match subsequences of the same size.
if len(left) != len(right):
raise RuntimeError(
"There is a bug within whisper `decode_asr` function, please report it. Dropping to prevent bad inference."
)
matches = np.sum(left == right)
matching = matches / i + eps
if matches > 1 and matching > max_:
max_ = matching
max_indices = (left_start, left_stop, right_start, right_stop)
(left_start, left_stop, right_start, right_stop) = max_indices
# This is a small conflict optimization since those sequences overlap
# in audio.
# We're going to give more confidence to the left sequence
# for the left of the overlap,
# and to the right sequence for the right of the overlap.
left_mid = (left_stop + left_start) // 2
right_mid = (right_stop + right_start) // 2
total_sequence.extend(left_sequence[:left_mid])
left_sequence = right_sequence[right_mid:]
left_length = len(left_sequence)
total_sequence.extend(left_sequence)
return total_sequence
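# A quick illustration of the merge behaviour (these mirror the unit tests added further down in this diff):
#   _find_longest_common_sequence([[1, 2, 3], [2, 3, 4, 5]]) -> [1, 2, 3, 4, 5]
#   _find_longest_common_sequence([[1, 2, 99, 4, 5], [2, 3, 4, 98, 6]]) -> [1, 2, 99, 4, 98, 6]
# When the overlap disagrees, the earlier sequence wins on the left half of the overlap and the later
# sequence wins on the right half.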
@@ -24,7 +24,7 @@ from ...tokenization_utils_base import BatchEncoding
from ...tokenization_utils_fast import PreTrainedTokenizerFast
from ...utils import logging
from .english_normalizer import EnglishTextNormalizer
from .tokenization_whisper import LANGUAGES, TASK_IDS, TO_LANGUAGE_CODE, WhisperTokenizer
from .tokenization_whisper import LANGUAGES, TASK_IDS, TO_LANGUAGE_CODE, WhisperTokenizer, _decode_asr
logger = logging.get_logger(__name__)
@@ -475,3 +475,12 @@ class WhisperTokenizerFast(PreTrainedTokenizerFast):
forced_tokens = self.prefix_tokens[1:]
forced_decoder_ids = [(rank + 1, token) for rank, token in enumerate(forced_tokens)]
return forced_decoder_ids
def _decode_asr(self, model_outputs, *, return_timestamps, return_language, time_precision):
return _decode_asr(
self,
model_outputs,
return_timestamps=return_timestamps,
return_language=return_language,
time_precision=time_precision,
)
@@ -82,112 +82,6 @@ def chunk_iter(inputs, feature_extractor, chunk_len, stride_left, stride_right,
break
def _find_timestamp_sequence(sequences, tokenizer, feature_extractor, max_source_positions):
"""
Computes the final sequences by merging the end of the nth sequence with the beginning of the n+1th sequence. Since
`WhisperForConditionalGeneration` produces the timestamps pairwise, we filter the consecutive timestamps and only
iterate over them. We keep track of the `time` which indicates the actual starting time of the chunk that is
processed. We need to make sure to offset the timestamp tokens by the `time` in order for the tokenizer to
properly compute the final `offset`.
"""
# index of the first timestamp token
timestamp_begin = tokenizer.convert_tokens_to_ids("<|notimestamps|>") + 1
items = []
# approximation of the token-to-time ratio: ~0.02 seconds
time_precision = feature_extractor.chunk_length / max_source_positions
time = 0
for seq_idx, item in enumerate(sequences):
sequence, stride = item
if isinstance(sequence, list):
sequence = np.array(sequence)
chunk_len, stride_left, stride_right = stride
sequence = sequence.squeeze(0)
# get rid of the `forced_decoder_idx` that are used to parametrize the generation
begin_idx = np.where(sequence == timestamp_begin)[0][0] if timestamp_begin in sequence else 0
sequence = sequence[begin_idx:]
timestamp_tokens = sequence >= timestamp_begin
if seq_idx != 0 and sum(timestamp_tokens) > 0:
consecutive = np.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0] + 1
last_timestamp = np.where(timestamp_tokens)[0][-1]
consecutive = np.append(consecutive, last_timestamp) if last_timestamp not in consecutive else consecutive
time -= stride_left + stride_right
offset = int((time / feature_extractor.sampling_rate) / time_precision)
overlap_time = int((stride_left / feature_extractor.sampling_rate) / time_precision)
# relevant timestamps are in the overlapping part
relevant_timestamp = np.where(sequence[consecutive] >= timestamp_begin + overlap_time)[0]
if relevant_timestamp.shape[0] > 0:
relevant_timestamp = (
consecutive[relevant_timestamp[0] - 1] if relevant_timestamp[0] > 0 else consecutive[0]
)
# if a big stride is used, we need to check some of the previous items for the best overlap
best_match = 0
sliced_sequence = []
for idx, previous_sequence in enumerate(reversed(items)):
previous_tokens = previous_sequence[1:-1]
if previous_sequence[0] < (timestamp_begin + offset - overlap_time) and idx != 0:
break # the previous sequence is too far in the past
if len(previous_tokens) > 0:
# find the longest common sequence between the overlapping parts
index_left, index_right, match_length = _fast_find_longest_common_sequence(
sequence[1:relevant_timestamp], previous_tokens
)
# don't do anything if only 1 token was matched
if match_length > 1 and match_length > best_match:
best_match = match_length
best_idx = idx
end_of_curr_sequence_idx = (
np.where(sequence[index_left + 1 :] >= timestamp_begin)[0][0] + 1
)
end_of_curr_sequence_idx = end_of_curr_sequence_idx + 1 + index_left
# if all the tokens are matched, suffix
if index_left == 0 and match_length == len(previous_tokens):
sliced_sequence = np.insert(
sequence[index_left + 1 : end_of_curr_sequence_idx], 0, previous_sequence[0]
)
sliced_sequence[-1] = previous_sequence[-1]
# if part of the previous sequence is not taken
elif index_left >= 0:
sliced_sequence = sequence[index_left + 1 : end_of_curr_sequence_idx]
# let's insert the missing part of the previous sequence
previous_slice = (
previous_sequence[: index_right + 1] if index_right > 0 else [previous_sequence[0]]
)
sliced_sequence = np.insert(sliced_sequence, 0, previous_slice)
sliced_sequence[-1] += offset
if len(sliced_sequence) > 0:
items[len(items) - best_idx - 1] = sliced_sequence
items = items[: len(items) - best_idx]
sequence = sequence[end_of_curr_sequence_idx:]
# sequence might have changed
timestamp_tokens = sequence >= timestamp_begin
consecutive = np.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0] + 1
if sum(timestamp_tokens) > 0:
last_timestamp = np.where(timestamp_tokens)[0][-1]
consecutive = (
np.append(consecutive, last_timestamp + 1) if last_timestamp not in consecutive else consecutive
)
if len(consecutive) > 0:
last_slice = 0
for current_slice in consecutive:
actual_offset = items[-1][-1] if seq_idx != 0 or last_slice != 0 else sequence[0]
sliced_tokens = sequence[last_slice:current_slice]
duration = sliced_tokens[-1] - sliced_tokens[0]
sliced_tokens[0] = actual_offset
sliced_tokens[-1] = actual_offset + duration
items.append(sliced_tokens)
last_slice = current_slice
time += chunk_len
result = []
for i in range(len(items)):
result += items[i].tolist()
return result
def _fast_find_longest_common_sequence(sequence_left, sequence_right):
seq_len_left = len(sequence_left)
seq_len_right = len(sequence_right)
@@ -384,6 +278,7 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
ignore_warning=None,
decoder_kwargs=None,
return_timestamps=None,
return_language=None,
generate_kwargs=None,
max_new_tokens=None,
):
@@ -413,6 +308,8 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
if return_timestamps is not None:
forward_params["return_timestamps"] = return_timestamps
postprocess_params["return_timestamps"] = return_timestamps
if return_language is not None:
postprocess_params["return_language"] = return_language
return preprocess_params, forward_params, postprocess_params
@@ -580,7 +477,9 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
extra = model_inputs
return {"is_last": is_last, **out, **extra}
def postprocess(self, model_outputs, decoder_kwargs: Optional[Dict] = None, return_timestamps=None):
def postprocess(
self, model_outputs, decoder_kwargs: Optional[Dict] = None, return_timestamps=None, return_language=None
):
# Optional return types
optional = {}
@@ -591,12 +490,15 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
if return_timestamps in {"char", "words"} and self.type == "seq2seq_whisper":
raise ValueError("Whisper cannot return `char` nor `words` timestamps, use `True` instead.")
if return_language is not None and self.type != "seq2seq_whisper":
raise ValueError("Only whisper can return language for now.")
final_items = []
key = "logits" if self.type == "ctc_with_lm" else "tokens"
stride = None
for outputs in model_outputs:
items = outputs[key].numpy()
stride = outputs.pop("stride", None)
stride = outputs.get("stride", None)
if stride is not None and self.type in {"ctc", "ctc_with_lm"}:
total_n, left, right = stride
# Total_n might be < logits.shape[1]
@@ -605,15 +507,28 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
# This won't work with left padding (which doesn't exist right now)
right_n = total_n - right
items = items[:, left:right_n]
if self.type == "seq2seq_whisper" and return_timestamps and stride is not None:
# Whisper needs the stride data
items = [items, stride]
final_items.append(items)
if stride and self.type in {"seq2seq", "seq2seq_whisper"} and not return_timestamps:
if stride and self.type == "seq2seq":
items = _find_longest_common_sequence(final_items, self.tokenizer)
elif stride and self.type == "seq2seq_whisper" and return_timestamps:
items = _find_timestamp_sequence(
final_items, self.tokenizer, self.feature_extractor, self.model.config.max_source_positions
elif self.type == "seq2seq_whisper":
time_precision = self.feature_extractor.chunk_length / self.model.config.max_source_positions
# Send the chunking back to seconds, it's easier to handle in whisper
sampling_rate = self.feature_extractor.sampling_rate
for output in model_outputs:
if "stride" in output:
chunk_len, stride_left, stride_right = output["stride"]
# Go back in seconds
chunk_len /= sampling_rate
stride_left /= sampling_rate
stride_right /= sampling_rate
output["stride"] = chunk_len, stride_left, stride_right
text, optional = self.tokenizer._decode_asr(
model_outputs,
return_timestamps=return_timestamps,
return_language=return_language,
time_precision=time_precision,
)
else:
items = np.concatenate(final_items, axis=1)
@@ -631,14 +546,10 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
offsets = []
for word, (start_offset, end_offset) in chunk_offset:
offsets.append({"word": word, "start_offset": start_offset, "end_offset": end_offset})
else:
elif self.type != "seq2seq_whisper":
skip_special_tokens = self.type != "ctc"
text = self.tokenizer.decode(items, skip_special_tokens=skip_special_tokens)
if return_timestamps and self.type == "seq2seq_whisper":
offsets = self.tokenizer.decode(items, skip_special_tokens=skip_special_tokens, output_offsets=True)[
"offsets"
]
elif return_timestamps:
if return_timestamps:
offsets = self.tokenizer.decode(
items, skip_special_tokens=skip_special_tokens, output_char_offsets=True
)["char_offsets"]
@@ -656,14 +567,119 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
chunks.append({"text": item[return_timestamps], "timestamp": (start, stop)})
optional["chunks"] = chunks
elif return_timestamps and self.type == "seq2seq_whisper":
optional["chunks"] = offsets
extra = defaultdict(list)
for output in model_outputs:
output.pop("tokens", None)
output.pop("logits", None)
output.pop("is_last", None)
output.pop("stride", None)
for k, v in output.items():
extra[k].append(v)
return {"text": text, **optional, **extra}
def _find_timestamp_sequence(sequences, tokenizer, feature_extractor, max_source_positions):
"""
Computes the final sequences by merging the end of the nth sequence with the beginning of the n+1th sequence. Since
`WhisperForConditionalGeneration` produces the timestamps pairwise, we filter the consecutive timestamps and only
iterate over them. We keep track of the `time` which indicates the actual starting time of the chunk that is
processed. We need to make sure to offset the timestamp tokens by the `time` in order for the tokenizer to
properly compute the final `offset`.
"""
# index of the first timestamp token
timestamp_begin = tokenizer.convert_tokens_to_ids("<|notimestamps|>") + 1
items = []
# approximation of the token-to-time ratio: ~0.02 seconds
time_precision = feature_extractor.chunk_length / max_source_positions
time = 0
for seq_idx, item in enumerate(sequences):
sequence, stride = item
if isinstance(sequence, list):
sequence = np.array(sequence)
chunk_len, stride_left, stride_right = stride
sequence = sequence.squeeze(0)
# get rid of the `forced_decoder_idx` that are used to parametrize the generation
begin_idx = np.where(sequence == timestamp_begin)[0][0] if timestamp_begin in sequence else 0
sequence = sequence[begin_idx:]
timestamp_tokens = sequence >= timestamp_begin
if seq_idx != 0 and sum(timestamp_tokens) > 0:
consecutive = np.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0] + 1
last_timestamp = np.where(timestamp_tokens)[0][-1]
consecutive = np.append(consecutive, last_timestamp) if last_timestamp not in consecutive else consecutive
time -= stride_left + stride_right
offset = int((time / feature_extractor.sampling_rate) / time_precision)
overlap_time = int((stride_left / feature_extractor.sampling_rate) / time_precision)
# relevant timestamps are in the overlapping part
relevant_timestamp = np.where(sequence[consecutive] >= timestamp_begin + overlap_time)[0]
if relevant_timestamp.shape[0] > 0:
relevant_timestamp = (
consecutive[relevant_timestamp[0] - 1] if relevant_timestamp[0] > 0 else consecutive[0]
)
# if a big stride is used, we need to check some of the previous items for the best overlap
best_match = 0
sliced_sequence = []
for idx, previous_sequence in enumerate(reversed(items)):
previous_tokens = previous_sequence[1:-1]
if previous_sequence[0] < (timestamp_begin + offset - overlap_time) and idx != 0:
break # the previous sequence is too far in the past
if len(previous_tokens) > 0:
# find the longest common sequence between the overlapping parts
index_left, index_right, match_length = _fast_find_longest_common_sequence(
sequence[1:relevant_timestamp], previous_tokens
)
# don't do anything if only 1 token was matched
if match_length > 1 and match_length > best_match:
best_match = match_length
best_idx = idx
end_of_curr_sequence_idx = (
np.where(sequence[index_left + 1 :] >= timestamp_begin)[0][0] + 1
)
end_of_curr_sequence_idx = end_of_curr_sequence_idx + 1 + index_left
# if all the tokens are matched, suffix
if index_left == 0 and match_length == len(previous_tokens):
sliced_sequence = np.insert(
sequence[index_left + 1 : end_of_curr_sequence_idx], 0, previous_sequence[0]
)
sliced_sequence[-1] = previous_sequence[-1]
# if part of the previous sequence is not taken
elif index_left >= 0:
sliced_sequence = sequence[index_left + 1 : end_of_curr_sequence_idx]
# let's insert the missing part of the previous sequence
previous_slice = (
previous_sequence[: index_right + 1] if index_right > 0 else [previous_sequence[0]]
)
sliced_sequence = np.insert(sliced_sequence, 0, previous_slice)
sliced_sequence[-1] += offset
if len(sliced_sequence) > 0:
items[len(items) - best_idx - 1] = sliced_sequence
items = items[: len(items) - best_idx]
sequence = sequence[end_of_curr_sequence_idx:]
# sequence might have changed
timestamp_tokens = sequence >= timestamp_begin
consecutive = np.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0] + 1
if sum(timestamp_tokens) > 0:
last_timestamp = np.where(timestamp_tokens)[0][-1]
consecutive = (
np.append(consecutive, last_timestamp + 1) if last_timestamp not in consecutive else consecutive
)
if len(consecutive) > 0:
last_slice = 0
for current_slice in consecutive:
actual_offset = items[-1][-1] if seq_idx != 0 or last_slice != 0 else sequence[0]
sliced_tokens = sequence[last_slice:current_slice]
duration = sliced_tokens[-1] - sliced_tokens[0]
sliced_tokens[0] = actual_offset
sliced_tokens[-1] = actual_offset + duration
items.append(sliced_tokens)
last_slice = current_slice
time += chunk_len
result = []
for i in range(len(items)):
result += items[i].tolist()
return result
@@ -15,6 +15,7 @@
import unittest
from transformers.models.whisper import WhisperTokenizer, WhisperTokenizerFast
from transformers.models.whisper.tokenization_whisper import _find_longest_common_sequence
from transformers.testing_utils import slow
from ...test_tokenization_common import TokenizerTesterMixin
@@ -115,6 +116,84 @@ class WhisperTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
expected_encoding=expected_encoding, model_name="openai/whisper-tiny.en", padding=False
)
def test_output_offsets(self):
tokenizer = self.get_tokenizer()
previous_sequence = [51492, 406, 3163, 1953, 466, 13, 51612, 51612]
self.assertEqual(
tokenizer.decode(previous_sequence, output_offsets=True),
{
"text": " not worth thinking about.",
"offsets": [{"text": " not worth thinking about.", "timestamp": (22.56, 24.96)}],
},
)
# Merge when the previous sequence is a suffix of the next sequence
# fmt: off
next_sequences_1 = [50364, 295, 6177, 3391, 11, 19817, 3337, 507, 307, 406, 3163, 1953, 466, 13, 50614, 50614, 2812, 9836, 14783, 390, 6263, 538, 257, 1359, 11, 8199, 6327, 1090, 322, 702, 7443, 13, 50834, 50257]
# fmt: on
self.assertEqual(
tokenizer.decode(next_sequences_1, output_offsets=True),
{
"text": (
" of spectators, retrievality is not worth thinking about. His instant panic was followed by a"
" small, sharp blow high on his chest.<|endoftext|>"
),
"offsets": [
{"text": " of spectators, retrievality is not worth thinking about.", "timestamp": (0.0, 5.0)},
{
"text": " His instant panic was followed by a small, sharp blow high on his chest.",
"timestamp": (5.0, 9.4),
},
],
},
)
def test_find_longest_common_subsequence(self):
previous_sequence = [1, 2, 3]
next_sequence = [2, 3, 4, 5]
merge = _find_longest_common_sequence([previous_sequence, next_sequence])
self.assertEqual(merge, [1, 2, 3, 4, 5])
# Now previous is larger than next.
# We merge what we can and remove the extra right side of the left sequence
previous_sequence = [1, 2, 3, 4, 5, 6, 7]
next_sequence = [2, 3, 4, 5]
merge = _find_longest_common_sequence([previous_sequence, next_sequence])
self.assertEqual(merge, [1, 2, 3, 4, 5])
# Nothing in common
previous_sequence = [1, 2, 3]
next_sequence = [4, 5, 6]
merge = _find_longest_common_sequence([previous_sequence, next_sequence])
self.assertEqual(merge, [1, 2, 3, 4, 5, 6])
# Some errors in the overlap.
# We take from previous on the left, from the next on the right of the overlap
previous_sequence = [1, 2, 3, 4, 99]
next_sequence = [2, 98, 4, 5, 6]
merge = _find_longest_common_sequence([previous_sequence, next_sequence])
self.assertEqual(merge, [1, 2, 3, 4, 5, 6])
# We take from previous on the left, from the next on the right of the overlap
previous_sequence = [1, 2, 99, 4, 5]
next_sequence = [2, 3, 4, 98, 6]
merge = _find_longest_common_sequence([previous_sequence, next_sequence])
self.assertEqual(merge, [1, 2, 99, 4, 98, 6])
# This works on 3 sequences
seq1 = [1, 2, 3]
seq2 = [2, 3, 4]
seq3 = [3, 4, 5]
merge = _find_longest_common_sequence([seq1, seq2, seq3])
self.assertEqual(merge, [1, 2, 3, 4, 5])
# This works on 3 sequences with errors
seq1 = [1, 2, 3, 98, 5]
seq2 = [2, 99, 4, 5, 6, 7]
seq3 = [4, 97, 6, 7, 8]
merge = _find_longest_common_sequence([seq1, seq2, seq3])
self.assertEqual(merge, [1, 2, 3, 4, 5, 6, 7, 8])
class SpeechToTextTokenizerMultilinguialTest(unittest.TestCase):
checkpoint_name = "openai/whisper-small.en"
@@ -538,7 +538,7 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
"tight-loan cloth that was the only garment he wore, the "
"cut"
),
"timestamp": (5.5, 11.94),
"timestamp": (5.5, 11.95),
},
{
"text": (
@@ -546,15 +546,15 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
"overstrained eyes, even the soaring arena around him "
"with"
),
"timestamp": (11.94, 19.6),
"timestamp": (11.95, 19.61),
},
{
"text": " the thousands of spectators, retrievality is not worth thinking about.",
"timestamp": (19.6, 26.66),
"timestamp": (19.61, 25.0),
},
{
"text": " His instant panic was followed by a small, sharp blow high on his chest.",
"timestamp": (26.66, 31.06),
"timestamp": (25.0, 29.4),
},
],
"text": (