"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "24588c6731474030d9f91285f2c4472f1394df04"
Unverified Commit 211f93aa authored by Sanchit Gandhi, committed by GitHub

[Whisper Tokenizer] Make decoding faster after adding timestamps (#26299)

make decoding faster
parent 4e931a8e
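
The change in a nutshell: previously every decode call rebuilt the ~1500-entry timestamp-id list and checked each token against it; now the tokenizer decodes first and strips the timestamp markers from the resulting text with a single precompiled regex. A minimal standalone sketch of that idea, using the same pattern the diff introduces (the helper name here is illustrative, not part of the library API):

import re

# Pattern added by this commit: matches timestamp markers such as <|0.00|> or <|12.34|>.
timestamp_pat = re.compile(r"<\|(\d+\.\d+)\|>")

def strip_timestamp_markers(text: str) -> str:
    # One regex pass over the decoded string replaces the per-token membership test.
    return re.sub(timestamp_pat, "", text)

print(strip_timestamp_markers("<|0.00|> Hey, how are you doing?<|4.20|>"))
# " Hey, how are you doing?"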

tokenization_whisper.py

@@ -314,6 +314,7 @@ class WhisperTokenizer(PreTrainedTokenizer):
         # Should have added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions
         self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
+        self.timestamp_pat = re.compile(r"<\|(\d+\.\d+)\|>")
 
         self.language = language
         super().__init__(

@@ -560,10 +561,12 @@ class WhisperTokenizer(PreTrainedTokenizer):
                 start_timestamp_position = sliced_tokens[0].item() - timestamp_begin
                 end_timestamp_position = sliced_tokens[-1].item() - timestamp_begin
                 # strip timestamp tokens from the text output
-                sliced_tokens = self._preprocess_token_ids(sliced_tokens, decode_with_timestamps=False)
+                sliced_tokens = self._preprocess_token_ids(sliced_tokens)
+                text = self._decode(sliced_tokens)
+                text = self._filter_timestamp_ids(text)
                 offsets.append(
                     {
-                        "text": self._decode(sliced_tokens),
+                        "text": text,
                         "timestamp": (
                             start_timestamp_position * time_precision,
                             end_timestamp_position * time_precision,
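
For orientation, the offset arithmetic in the hunk above is unchanged: a timestamp token's position relative to the first timestamp token (<|0.00|>) multiplied by time_precision gives the time in seconds. A small worked example (the timestamp_begin id below is an assumption; the real value depends on the vocabulary):

time_precision = 0.02                 # seconds per timestamp step, as documented in the diff
timestamp_begin = 50364               # assumed id of <|0.00|>, for illustration only
start_token, end_token = timestamp_begin + 25, timestamp_begin + 150

start_time = (start_token - timestamp_begin) * time_precision  # 25 * 0.02 = 0.5 s
end_time = (end_token - timestamp_begin) * time_precision      # 150 * 0.02 = 3.0 s
print((round(start_time, 2), round(end_time, 2)))              # (0.5, 3.0)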

@@ -585,9 +588,7 @@ class WhisperTokenizer(PreTrainedTokenizer):
         """
         return self.convert_tokens_to_ids([("<|%.2f|>" % (i * time_precision)) for i in range(1500 + 1)])
 
-    def _preprocess_token_ids(
-        self, token_ids, skip_special_tokens: bool = False, decode_with_timestamps: bool = False, time_precision=0.02
-    ):
+    def _preprocess_token_ids(self, token_ids, skip_special_tokens: bool = False):
         """
         Pre-process the token ids for decoding by removing the prompt tokens ids and timestamp token ids.

@@ -597,24 +598,17 @@ class WhisperTokenizer(PreTrainedTokenizer):
             skip_special_tokens (`bool`, *optional*, defaults to `False`):
                 Whether or not to remove special tokens from the token ids. If `True`, the prompt token ids will be
                 removed.
-            decode_with_timestamps (`bool`, *optional*, defaults to `False`):
-                Whether or not to decode with timestamps included in the raw text. If `False`, timestamps will be
-                filtered out from the token ids.
-            time_precision (`float`, `optional`, defaults to 0.02):
-                The time ratio to convert from token to time.
         """
         if skip_special_tokens:
             prompt_token_id = self.convert_tokens_to_ids("<|startofprev|>")
             decoder_start_token_id = self.convert_tokens_to_ids("<|startoftranscript|>")
             token_ids = self._strip_prompt(token_ids, prompt_token_id, decoder_start_token_id)
 
-        if not decode_with_timestamps:
-            # filter timestamp tokens if they are contained in the vocab
-            timestamp_ids = self.timestamp_ids(time_precision=time_precision)
-            token_ids = [token for token in token_ids if token not in timestamp_ids]
-
         return token_ids
 
+    def _filter_timestamp_ids(self, token_ids):
+        return re.sub(self.timestamp_pat, "", token_ids)
+
     def decode(
         self,
         token_ids,

@@ -644,6 +638,8 @@ class WhisperTokenizer(PreTrainedTokenizer):
             output_offsets (`bool`, *optional*, defaults to `False`):
                 Whether or not to output the offsets of the tokens. This should only be set if the model predicted
                 timestamps.
+            time_precision (`float`, `optional`, defaults to 0.02):
+                The time ratio to convert from token to time.
             decode_with_timestamps (`bool`, *optional*, defaults to `False`):
                 Whether or not to decode with timestamps included in the raw text.
         Returns:

@@ -652,8 +648,6 @@ class WhisperTokenizer(PreTrainedTokenizer):
         filtered_ids = self._preprocess_token_ids(
             token_ids,
             skip_special_tokens=skip_special_tokens,
-            decode_with_timestamps=decode_with_timestamps,
-            time_precision=time_precision,
         )
 
         text = super().decode(

@@ -668,6 +662,9 @@ class WhisperTokenizer(PreTrainedTokenizer):
             text = self._decode_with_timestamps(
                 filtered_ids, time_precision=time_precision, skip_special_tokens=skip_special_tokens
             )
+        else:
+            text = self._filter_timestamp_ids(text)
+
         # retrieve offsets
         if output_offsets:
             offsets = self._compute_offsets(token_ids, time_precision=time_precision)
......
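
End to end, the user-visible behavior of decode should be unchanged: with decode_with_timestamps=False (the default) the timestamp markers are stripped from the text, now by the regex substitution rather than by filtering token ids before decoding; with decode_with_timestamps=True they are kept. A hedged usage sketch (it downloads a checkpoint and assumes its tokenizer already contains the timestamp tokens in the vocabulary, which is the premise of this commit; the outputs in the comments are approximate):

from transformers import WhisperTokenizer

tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny")

# Toy sequence: <|0.00|> + " Hey!" + <|2.00|>
ids = (
    tokenizer.convert_tokens_to_ids(["<|0.00|>"])
    + tokenizer(" Hey!", add_special_tokens=False).input_ids
    + tokenizer.convert_tokens_to_ids(["<|2.00|>"])
)

print(tokenizer.decode(ids, decode_with_timestamps=True))  # roughly "<|0.00|> Hey!<|2.00|>"
print(tokenizer.decode(ids))                               # roughly " Hey!" (markers filtered)

The diff below applies the same change to the fast tokenizer, with the new helper marked as copied from the slow tokenizer.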

tokenization_whisper_fast.py

@@ -15,6 +15,7 @@
 """Tokenization classes for Whisper."""
 import json
 import os
+import re
 from functools import lru_cache
 from typing import List, Optional, Tuple

@@ -190,6 +191,7 @@ class WhisperTokenizerFast(PreTrainedTokenizerFast):
         self.english_spelling_normalizer = None
 
         self.add_prefix_space = add_prefix_space
+        self.timestamp_pat = re.compile(r"<\|(\d+\.\d+)\|>")
 
         self.language = language
         self.task = task

@@ -269,10 +271,12 @@ class WhisperTokenizerFast(PreTrainedTokenizerFast):
                 start_timestamp_position = sliced_tokens[0].item() - timestamp_begin
                 end_timestamp_position = sliced_tokens[-1].item() - timestamp_begin
                 # strip timestamp tokens from the text output
-                sliced_tokens = self._preprocess_token_ids(sliced_tokens, decode_with_timestamps=False)
+                sliced_tokens = self._preprocess_token_ids(sliced_tokens)
+                text = self._decode(sliced_tokens)
+                text = self._filter_timestamp_ids(text)
                 offsets.append(
                     {
-                        "text": self._decode(sliced_tokens),
+                        "text": text,
                         "timestamp": (
                             start_timestamp_position * time_precision,
                             end_timestamp_position * time_precision,

@@ -296,9 +300,7 @@ class WhisperTokenizerFast(PreTrainedTokenizerFast):
         return self.convert_tokens_to_ids([("<|%.2f|>" % (i * time_precision)) for i in range(1500 + 1)])
 
     # Copied from transformers.models.whisper.tokenization_whisper.WhisperTokenizer._preprocess_token_ids
-    def _preprocess_token_ids(
-        self, token_ids, skip_special_tokens: bool = False, decode_with_timestamps: bool = False, time_precision=0.02
-    ):
+    def _preprocess_token_ids(self, token_ids, skip_special_tokens: bool = False):
         """
         Pre-process the token ids for decoding by removing the prompt tokens ids and timestamp token ids.

@@ -308,24 +310,18 @@ class WhisperTokenizerFast(PreTrainedTokenizerFast):
             skip_special_tokens (`bool`, *optional*, defaults to `False`):
                 Whether or not to remove special tokens from the token ids. If `True`, the prompt token ids will be
                 removed.
-            decode_with_timestamps (`bool`, *optional*, defaults to `False`):
-                Whether or not to decode with timestamps included in the raw text. If `False`, timestamps will be
-                filtered out from the token ids.
-            time_precision (`float`, `optional`, defaults to 0.02):
-                The time ratio to convert from token to time.
         """
         if skip_special_tokens:
             prompt_token_id = self.convert_tokens_to_ids("<|startofprev|>")
             decoder_start_token_id = self.convert_tokens_to_ids("<|startoftranscript|>")
             token_ids = self._strip_prompt(token_ids, prompt_token_id, decoder_start_token_id)
 
-        if not decode_with_timestamps:
-            # filter timestamp tokens if they are contained in the vocab
-            timestamp_ids = self.timestamp_ids(time_precision=time_precision)
-            token_ids = [token for token in token_ids if token not in timestamp_ids]
-
         return token_ids
 
+    # Copied from transformers.models.whisper.tokenization_whisper.WhisperTokenizer._filter_timestamp_ids
+    def _filter_timestamp_ids(self, token_ids):
+        return re.sub(self.timestamp_pat, "", token_ids)
+
     # Copied from transformers.models.whisper.tokenization_whisper.WhisperTokenizer.decode
     def decode(
         self,

@@ -356,6 +352,8 @@ class WhisperTokenizerFast(PreTrainedTokenizerFast):
             output_offsets (`bool`, *optional*, defaults to `False`):
                 Whether or not to output the offsets of the tokens. This should only be set if the model predicted
                 timestamps.
+            time_precision (`float`, `optional`, defaults to 0.02):
+                The time ratio to convert from token to time.
             decode_with_timestamps (`bool`, *optional*, defaults to `False`):
                 Whether or not to decode with timestamps included in the raw text.
         Returns:

@@ -364,8 +362,6 @@ class WhisperTokenizerFast(PreTrainedTokenizerFast):
         filtered_ids = self._preprocess_token_ids(
             token_ids,
             skip_special_tokens=skip_special_tokens,
-            decode_with_timestamps=decode_with_timestamps,
-            time_precision=time_precision,
         )
 
         text = super().decode(

@@ -380,6 +376,9 @@ class WhisperTokenizerFast(PreTrainedTokenizerFast):
             text = self._decode_with_timestamps(
                 filtered_ids, time_precision=time_precision, skip_special_tokens=skip_special_tokens
             )
+        else:
+            text = self._filter_timestamp_ids(text)
+
         # retrieve offsets
         if output_offsets:
             offsets = self._compute_offsets(token_ids, time_precision=time_precision)
......