Unverified commit 211f93aa, authored by Sanchit Gandhi, committed by GitHub

[Whisper Tokenizer] Make decoding faster after adding timestamps (#26299)

Make decoding faster: strip timestamp tokens from the decoded text with a single compiled regex (`_filter_timestamp_ids`) instead of filtering every token id against the full list of 1501 timestamp ids before decoding.
parent 4e931a8e
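
The patch replaces per-token filtering with one regex pass over the decoded string: previously `_preprocess_token_ids` rebuilt the 1501-element timestamp id list on every call and checked each token against it (a linear scan per token, since the ids live in a Python list), whereas the new `_filter_timestamp_ids` decodes first and removes `<|0.00|>`-style markers with a single compiled-pattern substitution. Below is a minimal standalone sketch of the two strategies; the token ids, vocabulary split, and repetition counts are illustrative, not the library's, and the printed timings are machine-dependent.

import re
import time

# Same pattern the patch compiles; matches markers like <|12.34|> in decoded text.
TIMESTAMP_PAT = re.compile(r"<\|(\d+\.\d+)\|>")


def filter_ids_old(token_ids, timestamp_ids):
    # Old path: test every token for membership in the ~1501-element timestamp
    # id list, i.e. O(len(token_ids) * len(timestamp_ids)) with a Python list.
    return [t for t in token_ids if t not in timestamp_ids]


def filter_text_new(text):
    # New path: a single regex substitution over the already-decoded string.
    return TIMESTAMP_PAT.sub("", text)


# Toy data: pretend ids >= 50000 are timestamp tokens.
timestamp_ids = list(range(50000, 50000 + 1501))
token_ids = [15496, 50000, 995, 50050, 0, 50100] * 10000

t0 = time.perf_counter()
filter_ids_old(token_ids, timestamp_ids)
print(f"old id-level filtering: {time.perf_counter() - t0:.3f}s")

t0 = time.perf_counter()
filter_text_new("<|0.00|> Hello world!<|1.24|>" * 10000)
print(f"new regex filtering:    {time.perf_counter() - t0:.3f}s")

On inputs like this the list-membership version is orders of magnitude slower; the fix also avoids rebuilding the timestamp id list (one `convert_tokens_to_ids` call per timestamp token) on every `decode`.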
src/transformers/models/whisper/tokenization_whisper.py

@@ -314,6 +314,7 @@ class WhisperTokenizer(PreTrainedTokenizer):
         # Should have added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions
         self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
+        self.timestamp_pat = re.compile(r"<\|(\d+\.\d+)\|>")
 
         self.language = language
         super().__init__(
@@ -560,10 +561,12 @@ class WhisperTokenizer(PreTrainedTokenizer):
             start_timestamp_position = sliced_tokens[0].item() - timestamp_begin
             end_timestamp_position = sliced_tokens[-1].item() - timestamp_begin
             # strip timestamp tokens from the text output
-            sliced_tokens = self._preprocess_token_ids(sliced_tokens, decode_with_timestamps=False)
+            sliced_tokens = self._preprocess_token_ids(sliced_tokens)
+            text = self._decode(sliced_tokens)
+            text = self._filter_timestamp_ids(text)
             offsets.append(
                 {
-                    "text": self._decode(sliced_tokens),
+                    "text": text,
                     "timestamp": (
                         start_timestamp_position * time_precision,
                         end_timestamp_position * time_precision,
@@ -585,9 +588,7 @@ class WhisperTokenizer(PreTrainedTokenizer):
         """
         return self.convert_tokens_to_ids([("<|%.2f|>" % (i * time_precision)) for i in range(1500 + 1)])
 
-    def _preprocess_token_ids(
-        self, token_ids, skip_special_tokens: bool = False, decode_with_timestamps: bool = False, time_precision=0.02
-    ):
+    def _preprocess_token_ids(self, token_ids, skip_special_tokens: bool = False):
         """
         Pre-process the token ids for decoding by removing the prompt tokens ids and timestamp token ids.
@@ -597,24 +598,17 @@ class WhisperTokenizer(PreTrainedTokenizer):
             skip_special_tokens (`bool`, *optional*, defaults to `False`):
                 Whether or not to remove special tokens from the token ids. If `True`, the prompt token ids will be
                 removed.
-            decode_with_timestamps (`bool`, *optional*, defaults to `False`):
-                Whether or not to decode with timestamps included in the raw text. If `False`, timestamps will be
-                filtered out from the token ids.
-            time_precision (`float`, `optional`, defaults to 0.02):
-                The time ratio to convert from token to time.
         """
         if skip_special_tokens:
             prompt_token_id = self.convert_tokens_to_ids("<|startofprev|>")
             decoder_start_token_id = self.convert_tokens_to_ids("<|startoftranscript|>")
             token_ids = self._strip_prompt(token_ids, prompt_token_id, decoder_start_token_id)
 
-        if not decode_with_timestamps:
-            # filter timestamp tokens if they are contained in the vocab
-            timestamp_ids = self.timestamp_ids(time_precision=time_precision)
-            token_ids = [token for token in token_ids if token not in timestamp_ids]
-
         return token_ids
 
+    def _filter_timestamp_ids(self, token_ids):
+        return re.sub(self.timestamp_pat, "", token_ids)
+
     def decode(
         self,
         token_ids,
@@ -644,6 +638,8 @@ class WhisperTokenizer(PreTrainedTokenizer):
             output_offsets (`bool`, *optional*, defaults to `False`):
                 Whether or not to output the offsets of the tokens. This should only be set if the model predicted
                 timestamps.
+            time_precision (`float`, `optional`, defaults to 0.02):
+                The time ratio to convert from token to time.
             decode_with_timestamps (`bool`, *optional*, defaults to `False`):
                 Whether or not to decode with timestamps included in the raw text.
         Returns:
@@ -652,8 +648,6 @@ class WhisperTokenizer(PreTrainedTokenizer):
         filtered_ids = self._preprocess_token_ids(
             token_ids,
             skip_special_tokens=skip_special_tokens,
-            decode_with_timestamps=decode_with_timestamps,
-            time_precision=time_precision,
         )
         text = super().decode(
@@ -668,6 +662,9 @@ class WhisperTokenizer(PreTrainedTokenizer):
             text = self._decode_with_timestamps(
                 filtered_ids, time_precision=time_precision, skip_special_tokens=skip_special_tokens
             )
+        else:
+            text = self._filter_timestamp_ids(text)
+
         # retrieve offsets
         if output_offsets:
             offsets = self._compute_offsets(token_ids, time_precision=time_precision)
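
For reference, a hedged usage sketch of the user-facing behavior after this patch. The checkpoint name is illustrative, and it assumes the tokenizer already contains the timestamp tokens added in the preceding timestamps PR; the expected outputs in the comments are what the diff implies, not verified results.

from transformers import WhisperTokenizer

tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny")

# Bracket " Hello world!" with <|0.00|> and <|1.00|>; timestamp_ids()[k] is
# the id of <|k * 0.02|> at the default 0.02 s time precision.
ts = tokenizer.timestamp_ids()
ids = [ts[0]] + tokenizer(" Hello world!", add_special_tokens=False).input_ids + [ts[50]]

print(tokenizer.decode(ids, decode_with_timestamps=True))
# expected: '<|0.00|> Hello world!<|1.00|>'
print(tokenizer.decode(ids))
# expected: ' Hello world!'  (timestamps now stripped by the regex path)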
src/transformers/models/whisper/tokenization_whisper_fast.py

@@ -15,6 +15,7 @@
 """Tokenization classes for Whisper."""
 import json
 import os
+import re
 from functools import lru_cache
 from typing import List, Optional, Tuple
@@ -190,6 +191,7 @@ class WhisperTokenizerFast(PreTrainedTokenizerFast):
         self.english_spelling_normalizer = None
 
         self.add_prefix_space = add_prefix_space
+        self.timestamp_pat = re.compile(r"<\|(\d+\.\d+)\|>")
 
         self.language = language
         self.task = task
@@ -269,10 +271,12 @@ class WhisperTokenizerFast(PreTrainedTokenizerFast):
             start_timestamp_position = sliced_tokens[0].item() - timestamp_begin
             end_timestamp_position = sliced_tokens[-1].item() - timestamp_begin
             # strip timestamp tokens from the text output
-            sliced_tokens = self._preprocess_token_ids(sliced_tokens, decode_with_timestamps=False)
+            sliced_tokens = self._preprocess_token_ids(sliced_tokens)
+            text = self._decode(sliced_tokens)
+            text = self._filter_timestamp_ids(text)
             offsets.append(
                 {
-                    "text": self._decode(sliced_tokens),
+                    "text": text,
                     "timestamp": (
                         start_timestamp_position * time_precision,
                         end_timestamp_position * time_precision,
@@ -296,9 +300,7 @@ class WhisperTokenizerFast(PreTrainedTokenizerFast):
         return self.convert_tokens_to_ids([("<|%.2f|>" % (i * time_precision)) for i in range(1500 + 1)])
 
     # Copied from transformers.models.whisper.tokenization_whisper.WhisperTokenizer._preprocess_token_ids
-    def _preprocess_token_ids(
-        self, token_ids, skip_special_tokens: bool = False, decode_with_timestamps: bool = False, time_precision=0.02
-    ):
+    def _preprocess_token_ids(self, token_ids, skip_special_tokens: bool = False):
         """
         Pre-process the token ids for decoding by removing the prompt tokens ids and timestamp token ids.
@@ -308,24 +310,18 @@ class WhisperTokenizerFast(PreTrainedTokenizerFast):
             skip_special_tokens (`bool`, *optional*, defaults to `False`):
                 Whether or not to remove special tokens from the token ids. If `True`, the prompt token ids will be
                 removed.
-            decode_with_timestamps (`bool`, *optional*, defaults to `False`):
-                Whether or not to decode with timestamps included in the raw text. If `False`, timestamps will be
-                filtered out from the token ids.
-            time_precision (`float`, `optional`, defaults to 0.02):
-                The time ratio to convert from token to time.
         """
         if skip_special_tokens:
             prompt_token_id = self.convert_tokens_to_ids("<|startofprev|>")
             decoder_start_token_id = self.convert_tokens_to_ids("<|startoftranscript|>")
             token_ids = self._strip_prompt(token_ids, prompt_token_id, decoder_start_token_id)
 
-        if not decode_with_timestamps:
-            # filter timestamp tokens if they are contained in the vocab
-            timestamp_ids = self.timestamp_ids(time_precision=time_precision)
-            token_ids = [token for token in token_ids if token not in timestamp_ids]
-
         return token_ids
 
+    # Copied from transformers.models.whisper.tokenization_whisper.WhisperTokenizer._filter_timestamp_ids
+    def _filter_timestamp_ids(self, token_ids):
+        return re.sub(self.timestamp_pat, "", token_ids)
+
     # Copied from transformers.models.whisper.tokenization_whisper.WhisperTokenizer.decode
     def decode(
         self,
@@ -356,6 +352,8 @@ class WhisperTokenizerFast(PreTrainedTokenizerFast):
             output_offsets (`bool`, *optional*, defaults to `False`):
                 Whether or not to output the offsets of the tokens. This should only be set if the model predicted
                 timestamps.
+            time_precision (`float`, `optional`, defaults to 0.02):
+                The time ratio to convert from token to time.
             decode_with_timestamps (`bool`, *optional*, defaults to `False`):
                 Whether or not to decode with timestamps included in the raw text.
         Returns:
@@ -364,8 +362,6 @@ class WhisperTokenizerFast(PreTrainedTokenizerFast):
         filtered_ids = self._preprocess_token_ids(
             token_ids,
             skip_special_tokens=skip_special_tokens,
-            decode_with_timestamps=decode_with_timestamps,
-            time_precision=time_precision,
         )
         text = super().decode(
@@ -380,6 +376,9 @@ class WhisperTokenizerFast(PreTrainedTokenizerFast):
             text = self._decode_with_timestamps(
                 filtered_ids, time_precision=time_precision, skip_special_tokens=skip_special_tokens
             )
+        else:
+            text = self._filter_timestamp_ids(text)
+
         # retrieve offsets
         if output_offsets:
             offsets = self._compute_offsets(token_ids, time_precision=time_precision)