Unverified commit ac957f69, authored by Sanchit Gandhi, committed by GitHub

[Whisper Tokenizer] Encode timestamps (#26054)

* [Whisper Tokenizer] Fix tests after adding timestamps

* fix s2t tokenizer tests

* fix vocab test

* backwards comp

* fix tests

* comment

* style

* fix last test

* fix fast

* make faster

* move logic to decode

* remove skip test

* fix decode with offsets

* fix special tokens

* empty commit to re-trigger ci

* use lru cache
parent 6d49b9dc
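
For context, a minimal sketch of what this change enables (hedged: the token ids below are built programmatically rather than copied from a real transcription, and running it requires downloading the openai/whisper-tiny checkpoint). Timestamp tokens such as <|0.00|> now live in the tokenizer vocabulary, so decode can either render them inline or strip them:

    from transformers import WhisperTokenizer

    tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny")

    # Build a toy sequence: <|0.00|> + text tokens + <|1.00|>.
    # timestamp_ids() is added by this PR: the ids for <|0.00|> ... <|30.00|>.
    ts = tokenizer.timestamp_ids()
    text_ids = tokenizer(" Hello world.", add_special_tokens=False).input_ids
    ids = [ts[0]] + text_ids + [ts[50]]  # 50 * 0.02 s = 1.0 s

    # timestamps rendered inline in the decoded string
    print(tokenizer.decode(ids, decode_with_timestamps=True))

    # timestamps stripped from the text (the new default path via _preprocess_token_ids)
    print(tokenizer.decode(ids))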
src/transformers/models/whisper/tokenization_whisper.py
@@ -15,6 +15,7 @@
 """Tokenization classes for Whisper."""
 import json
 import os
+from functools import lru_cache
 from typing import TYPE_CHECKING, List, Optional, Tuple, Union
 
 import numpy as np
@@ -546,6 +547,8 @@ class WhisperTokenizer(PreTrainedTokenizer):
         if len(sliced_tokens) > 1:
             start_timestamp_position = sliced_tokens[0].item() - timestamp_begin
             end_timestamp_position = sliced_tokens[-1].item() - timestamp_begin
+            # strip timestamp tokens from the text output
+            sliced_tokens = self._preprocess_token_ids(sliced_tokens, decode_with_timestamps=False)
             offsets.append(
                 {
                     "text": self._decode(sliced_tokens),
@@ -559,6 +562,47 @@ class WhisperTokenizer(PreTrainedTokenizer):
         return offsets
 
+    @lru_cache
+    def timestamp_ids(self, time_precision=0.02):
+        """
+        Compute the timestamp token ids for a given precision and save to least-recently used (LRU) cache.
+
+        Args:
+            time_precision (`float`, *optional*, defaults to 0.02):
+                The time ratio to convert from token to time.
+        """
+        return self.convert_tokens_to_ids([("<|%.2f|>" % (i * time_precision)) for i in range(1500 + 1)])
+
+    def _preprocess_token_ids(
+        self, token_ids, skip_special_tokens: bool = False, decode_with_timestamps: bool = False, time_precision=0.02
+    ):
+        """
+        Pre-process the token ids for decoding by removing the prompt token ids and timestamp token ids.
+
+        Args:
+            token_ids (`Union[int, List[int], np.ndarray, torch.Tensor, tf.Tensor]`):
+                List of tokenized input ids. Typically, obtained using the `__call__` method of the tokenizer.
+            skip_special_tokens (`bool`, *optional*, defaults to `False`):
+                Whether or not to remove special tokens from the token ids. If `True`, the prompt token ids will be
+                removed.
+            decode_with_timestamps (`bool`, *optional*, defaults to `False`):
+                Whether or not to decode with timestamps included in the raw text. If `False`, timestamps will be
+                filtered out from the token ids.
+            time_precision (`float`, *optional*, defaults to 0.02):
+                The time ratio to convert from token to time.
+        """
+        if skip_special_tokens:
+            prompt_token_id = self.convert_tokens_to_ids("<|startofprev|>")
+            decoder_start_token_id = self.convert_tokens_to_ids("<|startoftranscript|>")
+            token_ids = self._strip_prompt(token_ids, prompt_token_id, decoder_start_token_id)
+
+        if not decode_with_timestamps:
+            # filter timestamp tokens if they are contained in the vocab
+            timestamp_ids = self.timestamp_ids(time_precision=time_precision)
+            token_ids = [token for token in token_ids if token not in timestamp_ids]
+
+        return token_ids
+
     def decode(
         self,
         token_ids,
@@ -593,33 +637,40 @@
         Returns:
             `str`: The decoded sentence.
         """
-        text = super().decode(
+        filtered_ids = self._preprocess_token_ids(
             token_ids,
             skip_special_tokens=skip_special_tokens,
+            decode_with_timestamps=decode_with_timestamps,
+            time_precision=time_precision,
+        )
+
+        text = super().decode(
+            filtered_ids,
+            skip_special_tokens=skip_special_tokens,
             clean_up_tokenization_spaces=clean_up_tokenization_spaces,
-            decode_with_timestamps=decode_with_timestamps,
             **kwargs,
         )
         if decode_with_timestamps:
+            # legacy method to decode timestamps when not included in the tokenizer vocabulary
             text = self._decode_with_timestamps(
-                token_ids, time_precision=time_precision, skip_special_tokens=skip_special_tokens
+                filtered_ids, time_precision=time_precision, skip_special_tokens=skip_special_tokens
             )
         # retrieve offsets
         if output_offsets:
-            offsets = None
             offsets = self._compute_offsets(token_ids, time_precision=time_precision)
             return {"text": text, "offsets": offsets}
         return text
 
     def _decode(
-        self, token_ids: Union[int, List[int]], skip_special_tokens: bool = False, normalize: bool = False, **kwargs
+        self,
+        token_ids: Union[int, List[int]],
+        skip_special_tokens: bool = False,
+        normalize: bool = False,
+        decode_with_timestamps: bool = False,
+        **kwargs,
     ) -> str:
         self._decode_use_source_tokenizer = kwargs.pop("use_source_tokenizer", False)
-
-        if skip_special_tokens:
-            prompt_token_id = self.convert_tokens_to_ids("<|startofprev|>")
-            decoder_start_token_id = self.convert_tokens_to_ids("<|startoftranscript|>")
-            token_ids = self._strip_prompt(token_ids, prompt_token_id, decoder_start_token_id)
-
         filtered_tokens = self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens)
 
         # To avoid mixing byte-level and unicode for byte-level BPE
...
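
The two new helpers above boil down to string formatting plus a membership filter. A self-contained sketch of the same idea, with a made-up base id standing in for convert_tokens_to_ids (the real id depends on the checkpoint's vocab):

    from functools import lru_cache

    TIMESTAMP_BASE = 50364  # illustrative stand-in for the first timestamp token id

    @lru_cache
    def toy_timestamp_ids(time_precision=0.02):
        # 1500 steps of 0.02 s cover Whisper's 30-second window: <|0.00|> ... <|30.00|>
        tokens = ["<|%.2f|>" % (i * time_precision) for i in range(1500 + 1)]
        return {TIMESTAMP_BASE + i: tok for i, tok in enumerate(tokens)}

    def strip_timestamps(token_ids, time_precision=0.02):
        # mirrors the `not decode_with_timestamps` branch of _preprocess_token_ids
        timestamp_ids = toy_timestamp_ids(time_precision)
        return [t for t in token_ids if t not in timestamp_ids]

    print(strip_timestamps([50364, 11, 42, 51864]))  # -> [11, 42]

Note the range bound: 1500 + 1 tokens at 0.02 s precision run from <|0.00|> to exactly <|30.00|>, matching Whisper's 30-second audio window.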
src/transformers/models/whisper/tokenization_whisper_fast.py
@@ -15,6 +15,7 @@
 """Tokenization classes for Whisper."""
 import json
 import os
+from functools import lru_cache
 from typing import TYPE_CHECKING, List, Optional, Tuple
 
 import numpy as np
@@ -255,6 +256,8 @@ class WhisperTokenizerFast(PreTrainedTokenizerFast):
         if len(sliced_tokens) > 1:
             start_timestamp_position = sliced_tokens[0].item() - timestamp_begin
             end_timestamp_position = sliced_tokens[-1].item() - timestamp_begin
+            # strip timestamp tokens from the text output
+            sliced_tokens = self._preprocess_token_ids(sliced_tokens, decode_with_timestamps=False)
             offsets.append(
                 {
                     "text": self._decode(sliced_tokens),
@@ -268,6 +271,49 @@ class WhisperTokenizerFast(PreTrainedTokenizerFast):
         return offsets
 
+    @lru_cache
+    # Copied from transformers.models.whisper.tokenization_whisper.WhisperTokenizer.timestamp_ids
+    def timestamp_ids(self, time_precision=0.02):
+        """
+        Compute the timestamp token ids for a given precision and save to least-recently used (LRU) cache.
+
+        Args:
+            time_precision (`float`, *optional*, defaults to 0.02):
+                The time ratio to convert from token to time.
+        """
+        return self.convert_tokens_to_ids([("<|%.2f|>" % (i * time_precision)) for i in range(1500 + 1)])
+
+    # Copied from transformers.models.whisper.tokenization_whisper.WhisperTokenizer._preprocess_token_ids
+    def _preprocess_token_ids(
+        self, token_ids, skip_special_tokens: bool = False, decode_with_timestamps: bool = False, time_precision=0.02
+    ):
+        """
+        Pre-process the token ids for decoding by removing the prompt token ids and timestamp token ids.
+
+        Args:
+            token_ids (`Union[int, List[int], np.ndarray, torch.Tensor, tf.Tensor]`):
+                List of tokenized input ids. Typically, obtained using the `__call__` method of the tokenizer.
+            skip_special_tokens (`bool`, *optional*, defaults to `False`):
+                Whether or not to remove special tokens from the token ids. If `True`, the prompt token ids will be
+                removed.
+            decode_with_timestamps (`bool`, *optional*, defaults to `False`):
+                Whether or not to decode with timestamps included in the raw text. If `False`, timestamps will be
+                filtered out from the token ids.
+            time_precision (`float`, *optional*, defaults to 0.02):
+                The time ratio to convert from token to time.
+        """
+        if skip_special_tokens:
+            prompt_token_id = self.convert_tokens_to_ids("<|startofprev|>")
+            decoder_start_token_id = self.convert_tokens_to_ids("<|startoftranscript|>")
+            token_ids = self._strip_prompt(token_ids, prompt_token_id, decoder_start_token_id)
+
+        if not decode_with_timestamps:
+            # filter timestamp tokens if they are contained in the vocab
+            timestamp_ids = self.timestamp_ids(time_precision=time_precision)
+            token_ids = [token for token in token_ids if token not in timestamp_ids]
+
+        return token_ids
+
     # Copied from transformers.models.whisper.tokenization_whisper.WhisperTokenizer.decode
     def decode(
         self,
@@ -303,29 +349,32 @@ class WhisperTokenizerFast(PreTrainedTokenizerFast):
         Returns:
             `str`: The decoded sentence.
         """
-        text = super().decode(
+        filtered_ids = self._preprocess_token_ids(
             token_ids,
             skip_special_tokens=skip_special_tokens,
+            decode_with_timestamps=decode_with_timestamps,
+            time_precision=time_precision,
+        )
+
+        text = super().decode(
+            filtered_ids,
+            skip_special_tokens=skip_special_tokens,
             clean_up_tokenization_spaces=clean_up_tokenization_spaces,
-            decode_with_timestamps=decode_with_timestamps,
             **kwargs,
         )
         if decode_with_timestamps:
+            # legacy method to decode timestamps when not included in the tokenizer vocabulary
            text = self._decode_with_timestamps(
-                token_ids, time_precision=time_precision, skip_special_tokens=skip_special_tokens
+                filtered_ids, time_precision=time_precision, skip_special_tokens=skip_special_tokens
             )
         # retrieve offsets
         if output_offsets:
-            offsets = None
             offsets = self._compute_offsets(token_ids, time_precision=time_precision)
             return {"text": text, "offsets": offsets}
         return text
 
     def _decode(self, *args, normalize: bool = False, **kwargs) -> str:
-        if kwargs["skip_special_tokens"]:
-            prompt_token_id = self.convert_tokens_to_ids("<|startofprev|>")
-            decoder_start_token_id = self.convert_tokens_to_ids("<|startoftranscript|>")
-            kwargs["token_ids"] = self._strip_prompt(kwargs["token_ids"], prompt_token_id, decoder_start_token_id)
-
         text = super()._decode(*args, **kwargs)
 
         if normalize:
...
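
A side note on @lru_cache as used in both files: applied to an instance method, the cache key is the (self, time_precision) pair, so the 1501-id list is built once per tokenizer per precision. A toy demonstration of that behaviour (not the real tokenizer):

    from functools import lru_cache

    class Toy:
        calls = 0

        @lru_cache
        def timestamp_ids(self, time_precision=0.02):
            Toy.calls += 1
            return [round(i * time_precision, 2) for i in range(1500 + 1)]

    t = Toy()
    t.timestamp_ids()
    t.timestamp_ids()       # cache hit: same (self, time_precision) key
    t.timestamp_ids(0.04)   # new precision -> new cache entry
    print(Toy.calls)        # 2

The usual caveat applies: the cache holds a strong reference to self, so a tokenizer cached this way is not garbage-collected while the cache lives, which is acceptable for a handful of long-lived tokenizers.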
tests/models/whisper/test_tokenization_whisper.py
@@ -52,14 +52,13 @@ class WhisperTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
         self.assertEqual(self.get_tokenizer()._convert_token_to_id(token), token_id)
         self.assertEqual(self.get_tokenizer()._convert_id_to_token(token_id), token)
 
-    @unittest.skip("TODO @Sanchit. Let's make the CI green in the mean time")
     def test_get_vocab(self):
         vocab_keys = list(self.get_tokenizer().get_vocab().keys())
 
         self.assertEqual(vocab_keys[0], "!")
         self.assertEqual(vocab_keys[1], '"')
-        self.assertEqual(vocab_keys[-1], "<|notimestamps|>")
-        self.assertEqual(len(vocab_keys), 50364)
+        self.assertEqual(vocab_keys[-1], "<|30.00|>")
+        self.assertEqual(len(vocab_keys), 51865)
 
     def test_vocab_size(self):
         self.assertEqual(self.get_tokenizer().vocab_size, 50258)
@@ -117,7 +116,6 @@ class WhisperTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
             expected_encoding=expected_encoding, model_name="openai/whisper-tiny.en", padding=False
         )
 
-    @unittest.skip("TODO @Sanchit. Let's make the CI green in the mean time")
     def test_output_offsets(self):
         tokenizer = self.get_tokenizer()
         previous_sequence = [51492, 406, 3163, 1953, 466, 13, 51612, 51612]
@@ -400,7 +398,6 @@ class SpeechToTextTokenizerMultilinguialTest(unittest.TestCase):
         transcription = multilingual_tokenizer.batch_decode(batch_encoding, skip_special_tokens=True)
         self.assertListEqual(batch, transcription)
 
-    @unittest.skip("TODO @Sanchit. Let's make the CI green in the mean time")
     def test_offset_decoding(self):
         multilingual_tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny")
         # fmt: off
...
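
The re-enabled offset tests exercise _compute_offsets above: consecutive timestamp tokens mark segment boundaries, and each offset pairs a decoded segment with start/end times in seconds. A hedged sketch of the expected shape (token ids are built programmatically, and the exact dict layout follows _compute_offsets, so treat the commented output as approximate):

    from transformers import WhisperTokenizer

    tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny")
    ts = tokenizer.timestamp_ids()  # ids for <|0.00|> ... <|30.00|>
    hello = tokenizer(" Hello", add_special_tokens=False).input_ids
    world = tokenizer(" world", add_special_tokens=False).input_ids

    # <|0.00|> Hello <|1.00|><|1.00|> world <|2.00|>
    ids = [ts[0], *hello, ts[50], ts[50], *world, ts[100]]

    out = tokenizer.decode(ids, output_offsets=True)
    # out["text"] has the timestamp tokens stripped; out["offsets"] should look roughly like:
    # [{"text": " Hello", "timestamp": (0.0, 1.0)},
    #  {"text": " world", "timestamp": (1.0, 2.0)}]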