Unverified commit c9cf3377, authored by Sanchit Gandhi, committed by GitHub

[Whisper Tokenizer] Skip special tokens when decoding with timestamps (#23945)

parent 8940d315
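In short, the change forwards `skip_special_tokens` through the timestamp decoding path so the two flags can be combined. A minimal usage sketch follows; the checkpoint name is an assumption (not part of this diff), and the token IDs are a truncated slice of the sequence used in the new test further down:

from transformers import WhisperTokenizer

# Assumed checkpoint for illustration; any multilingual Whisper tokenizer should behave similarly.
tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny")

# Truncated slice of the IDs from the new test: <|startoftranscript|>, <|notimestamps|>,
# <|0.00|>, a few text tokens, <|6.24|>, <|endoftext|>.
token_ids = [50258, 50363, 50364, 634, 575, 12525, 50676, 50257]

# Before this patch, skip_special_tokens was not forwarded when decode_with_timestamps=True,
# so "<|startoftranscript|>" and "<|endoftext|>" always appeared in the output.
print(tokenizer.decode(token_ids, decode_with_timestamps=True, skip_special_tokens=False))

# With this patch, the flag is honoured while the "<|x.xx|>" timestamp markers are kept.
print(tokenizer.decode(token_ids, decode_with_timestamps=True, skip_special_tokens=True))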
@@ -491,7 +491,7 @@ class WhisperTokenizer(PreTrainedTokenizer):
         normalizer = EnglishTextNormalizer(self.english_spelling_normalizer)
         return normalizer(text)
 
-    def _decode_with_timestamps(self, token_ids, time_precision=0.02) -> str:
+    def _decode_with_timestamps(self, token_ids, skip_special_tokens=False, time_precision=0.02) -> str:
         """
         Timestamp tokens are above the special tokens' id range and are ignored by `decode()`. This method decodes
         given tokens with timestamps tokens annotated, e.g. "<|1.08|>".
@@ -505,7 +505,9 @@ class WhisperTokenizer(PreTrainedTokenizer):
                 outputs.append([])
             else:
                 outputs[-1].append(token)
-        outputs = [s if isinstance(s, str) else self.decode(s) for s in outputs]
+        outputs = [
+            s if isinstance(s, str) else self.decode(s, skip_special_tokens=skip_special_tokens) for s in outputs
+        ]
         return "".join(outputs)
 
     def _compute_offsets(self, token_ids, time_precision=0.02):
@@ -593,7 +595,9 @@ class WhisperTokenizer(PreTrainedTokenizer):
             **kwargs,
         )
         if decode_with_timestamps:
-            text = self._decode_with_timestamps(token_ids, time_precision=time_precision)
+            text = self._decode_with_timestamps(
+                token_ids, time_precision=time_precision, skip_special_tokens=skip_special_tokens
+            )
         # retrieve offsets
         if output_offsets:
             offsets = None
...
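For context before the mirrored change in the fast tokenizer below: `_decode_with_timestamps` splits the sequence at timestamp tokens (IDs at or above the first timestamp ID), renders each timestamp as "<|x.xx|>", and decodes the text chunks in between. A simplified, standalone sketch of that logic follows; the function and parameter names (`decode_with_timestamps_sketch`, `decode_fn`, the default `timestamp_begin=50364`) are assumptions for illustration, not the library's API:

# Hypothetical sketch of the splitting logic patched above.
def decode_with_timestamps_sketch(token_ids, decode_fn, timestamp_begin=50364,
                                  time_precision=0.02, skip_special_tokens=False):
    chunks = [[]]
    for token in token_ids:
        if token >= timestamp_begin:
            # Render the timestamp token as "<|x.xx|>" and start a new text chunk.
            timestamp = f"<|{(token - timestamp_begin) * time_precision:.2f}|>"
            chunks.append(timestamp)
            chunks.append([])
        else:
            chunks[-1].append(token)
    # Decode each text chunk, forwarding skip_special_tokens.
    # Dropping the flag at this step was the bug fixed by this commit.
    rendered = [
        c if isinstance(c, str) else decode_fn(c, skip_special_tokens=skip_special_tokens)
        for c in chunks
    ]
    return "".join(rendered)

# Tiny demo with a stand-in decode function (real use would pass tokenizer.decode).
demo = decode_with_timestamps_sketch(
    [50364, 1, 2, 50676],
    decode_fn=lambda ids, skip_special_tokens=False: "".join(f"[{i}]" for i in ids),
)
print(demo)  # prints "<|0.00|>[1][2]<|6.24|>"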
@@ -199,7 +199,7 @@ class WhisperTokenizerFast(PreTrainedTokenizerFast):
         return super()._encode_plus(*args, **kwargs)
 
     # Copied from transformers.models.whisper.tokenization_whisper.WhisperTokenizer._decode_with_timestamps
-    def _decode_with_timestamps(self, token_ids, time_precision=0.02) -> str:
+    def _decode_with_timestamps(self, token_ids, skip_special_tokens=False, time_precision=0.02) -> str:
         """
         Timestamp tokens are above the special tokens' id range and are ignored by `decode()`. This method decodes
         given tokens with timestamps tokens annotated, e.g. "<|1.08|>".
@@ -213,7 +213,9 @@ class WhisperTokenizerFast(PreTrainedTokenizerFast):
                 outputs.append([])
             else:
                 outputs[-1].append(token)
-        outputs = [s if isinstance(s, str) else self.decode(s) for s in outputs]
+        outputs = [
+            s if isinstance(s, str) else self.decode(s, skip_special_tokens=skip_special_tokens) for s in outputs
+        ]
         return "".join(outputs)
 
     # Copied from transformers.models.whisper.tokenization_whisper.WhisperTokenizer._compute_offsets
@@ -303,7 +305,9 @@ class WhisperTokenizerFast(PreTrainedTokenizerFast):
             **kwargs,
         )
         if decode_with_timestamps:
-            text = self._decode_with_timestamps(token_ids, time_precision=time_precision)
+            text = self._decode_with_timestamps(
+                token_ids, time_precision=time_precision, skip_special_tokens=skip_special_tokens
+            )
         # retrieve offsets
         if output_offsets:
             offsets = None
...
@@ -213,6 +213,38 @@ class WhisperTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
             rust_tokenizer.decode(encoded_input, skip_special_tokens=True), expected_without_special_tokens
         )
 
+    def test_skip_special_tokens_with_timestamps(self):
+        tokenizer = self.get_tokenizer()
+        rust_tokenizer = self.get_rust_tokenizer()
+
+        # fmt: off
+        encoded_input = [
+            50258, 50363, 50364, 634, 575, 12525, 22618, 1968, 6144,
+            35617, 20084, 1756, 311, 589, 307, 534, 10281, 934,
+            439, 293, 50676, 50676, 393, 4411, 294, 309, 457,
+            707, 295, 33301, 286, 392, 6628, 13, 50836, 50257,
+        ]
+        # fmt: on
+
+        expected_with_special_tokens = "<|startoftranscript|><|notimestamps|><|0.00|> He has grave doubts whether Sir Frederick Layton's work is really Greek after all and<|6.24|><|6.24|> can discover in it but little of rocky Ithaca.<|9.44|><|endoftext|>"
+        expected_without_special_tokens = "<|0.00|> He has grave doubts whether Sir Frederick Layton's work is really Greek after all and<|6.24|><|6.24|> can discover in it but little of rocky Ithaca.<|9.44|>"
+        self.assertEqual(
+            tokenizer.decode(encoded_input, decode_with_timestamps=True, skip_special_tokens=False),
+            expected_with_special_tokens,
+        )
+        self.assertEqual(
+            tokenizer.decode(encoded_input, decode_with_timestamps=True, skip_special_tokens=True),
+            expected_without_special_tokens,
+        )
+        self.assertEqual(
+            rust_tokenizer.decode(encoded_input, decode_with_timestamps=True, skip_special_tokens=False),
+            expected_with_special_tokens,
+        )
+        self.assertEqual(
+            rust_tokenizer.decode(encoded_input, decode_with_timestamps=True, skip_special_tokens=True),
+            expected_without_special_tokens,
+        )
+
     def test_fast_tokenizer_get_prompt_ids(self):
         tokenizer = self.get_tokenizer()
         rust_tokenizer = self.get_rust_tokenizer()
...
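The same flags also propagate through `batch_decode`, which hands its keyword arguments to `decode` for each sequence. A short sketch, reusing the tokenizer and `token_ids` from the example near the top (still under the assumed checkpoint):

# Each sequence in the batch is decoded with timestamp markers but without special tokens.
texts = tokenizer.batch_decode(
    [token_ids, token_ids],
    decode_with_timestamps=True,
    skip_special_tokens=True,
)
print(texts)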