Unverified commit 57f44dc4, authored by Sanchit Gandhi and committed by GitHub

[Whisper] Allow basic text normalization (#26149)

* [Whisper] Allow basic text normalization

* up

* style copies
parent bd620591
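
In practice, the new flags hang off `decode`. A minimal sketch of the added API (the `openai/whisper-tiny` checkpoint is an assumption, any Whisper checkpoint exposes the same method; the expected strings mirror the tests added in this commit):

```python
from transformers import WhisperTokenizer

tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny")
ids = tokenizer("Hola güey!").input_ids

# Plain decoding leaves the text untouched.
tokenizer.decode(ids, skip_special_tokens=True)  # "Hola güey!"

# Basic normalization: lowercases and strips punctuation, keeps diacritics.
tokenizer.decode(ids, skip_special_tokens=True, basic_normalize=True)  # "hola güey "

# Optionally strip diacritics as well (lossy, use with caution).
tokenizer.decode(
    ids, skip_special_tokens=True, basic_normalize=True, remove_diacritics=True
)  # "hola guey "
```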
@@ -23,7 +23,7 @@ import regex as re
 from ...tokenization_utils import AddedToken, PreTrainedTokenizer
 from ...utils import logging
-from .english_normalizer import EnglishTextNormalizer
+from .english_normalizer import BasicTextNormalizer, EnglishTextNormalizer


 VOCAB_FILES_NAMES = {
@@ -510,6 +510,15 @@ class WhisperTokenizer(PreTrainedTokenizer):
         normalizer = EnglishTextNormalizer(self.english_spelling_normalizer)
         return normalizer(text)

+    @staticmethod
+    def _basic_normalize(text, remove_diacritics=False):
+        """
+        Normalize a given string using the `BasicTextNormalizer` class, which performs common transformations on
+        multilingual text.
+        """
+        normalizer = BasicTextNormalizer(remove_diacritics=remove_diacritics)
+        return normalizer(text)
+
     def _decode_with_timestamps(self, token_ids, skip_special_tokens=False, time_precision=0.02) -> str:
         """
         Timestamp tokens are above the special tokens' id range and are ignored by `decode()`. This method decodes
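
Since `_basic_normalize` is a thin wrapper, `BasicTextNormalizer` can also be used standalone. A sketch (the absolute import path mirrors the relative import above; the outputs match the expectations in the tests below):

```python
from transformers.models.whisper.english_normalizer import BasicTextNormalizer

normalizer = BasicTextNormalizer()
normalizer("Hola güey!")  # "hola güey "

# remove_diacritics=True additionally folds accented characters to ASCII.
normalizer = BasicTextNormalizer(remove_diacritics=True)
normalizer("Hola güey!")  # "hola guey "
```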
@@ -617,6 +626,9 @@ class WhisperTokenizer(PreTrainedTokenizer):
         output_offsets: bool = False,
         time_precision=0.02,
         decode_with_timestamps: bool = False,
+        normalize: bool = False,
+        basic_normalize: bool = False,
+        remove_diacritics: bool = False,
         **kwargs,
     ) -> str:
         """
@@ -633,8 +645,6 @@ class WhisperTokenizer(PreTrainedTokenizer):
             clean_up_tokenization_spaces (`bool`, *optional*):
                 Whether or not to clean up the tokenization spaces. If `None`, will default to
                 `self.clean_up_tokenization_spaces` (available in the `tokenizer_config`).
-            kwargs (additional keyword arguments, *optional*):
-                Will be passed to the underlying model specific decode method.
             output_offsets (`bool`, *optional*, defaults to `False`):
                 Whether or not to output the offsets of the tokens. This should only be set if the model predicted
                 timestamps.
@@ -642,6 +652,17 @@ class WhisperTokenizer(PreTrainedTokenizer):
                 The time ratio to convert from token to time.
             decode_with_timestamps (`bool`, *optional*, defaults to `False`):
                 Whether or not to decode with timestamps included in the raw text.
+            normalize (`bool`, *optional*, defaults to `False`):
+                Whether or not to apply the English text normalizer to the decoded text. Only applicable when the
+                target text is in English; otherwise, the basic text normalizer should be applied.
+            basic_normalize (`bool`, *optional*, defaults to `False`):
+                Whether or not to apply the basic text normalizer to the decoded text. Applicable to multilingual
+                target text.
+            remove_diacritics (`bool`, *optional*, defaults to `False`):
+                Whether or not to remove diacritics when applying the basic text normalizer. Removing diacritics can
+                destroy information in the decoded text, so it should be used with caution.
+            kwargs (additional keyword arguments, *optional*):
+                Will be passed to the underlying model specific decode method.

         Returns:
             `str`: The decoded sentence.
         """
@@ -654,7 +675,9 @@ class WhisperTokenizer(PreTrainedTokenizer):
             filtered_ids,
             skip_special_tokens=skip_special_tokens,
             clean_up_tokenization_spaces=clean_up_tokenization_spaces,
-            decode_with_timestamps=decode_with_timestamps,
+            normalize=normalize,
+            basic_normalize=basic_normalize,
+            remove_diacritics=remove_diacritics,
             **kwargs,
         )
         if decode_with_timestamps:
@@ -676,7 +699,8 @@ class WhisperTokenizer(PreTrainedTokenizer):
         token_ids: Union[int, List[int]],
         skip_special_tokens: bool = False,
         normalize: bool = False,
-        decode_with_timestamps: bool = False,
+        basic_normalize: bool = False,
+        remove_diacritics: bool = False,
         **kwargs,
     ) -> str:
         self._decode_use_source_tokenizer = kwargs.pop("use_source_tokenizer", False)
@@ -705,6 +729,9 @@ class WhisperTokenizer(PreTrainedTokenizer):
         if normalize:
             clean_text = self._normalize(text)
             return clean_text
+        elif basic_normalize:
+            clean_text = self._basic_normalize(text, remove_diacritics=remove_diacritics)
+            return clean_text
         else:
             return text
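Note that `normalize` takes precedence: if both `normalize=True` and `basic_normalize=True` are passed, the `if`/`elif` above only runs the English normalizer.

The same API is mirrored in the fast tokenizer, `tokenization_whisper_fast.py`: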
@@ -25,7 +25,7 @@ from tokenizers import AddedToken, pre_tokenizers, processors
 from ...tokenization_utils_base import BatchEncoding
 from ...tokenization_utils_fast import PreTrainedTokenizerFast
 from ...utils import logging
-from .english_normalizer import EnglishTextNormalizer
+from .english_normalizer import BasicTextNormalizer, EnglishTextNormalizer
 from .tokenization_whisper import LANGUAGES, TASK_IDS, TO_LANGUAGE_CODE, WhisperTokenizer, _decode_asr
@@ -331,6 +331,9 @@ class WhisperTokenizerFast(PreTrainedTokenizerFast):
         output_offsets: bool = False,
         time_precision=0.02,
         decode_with_timestamps: bool = False,
+        normalize: bool = False,
+        basic_normalize: bool = False,
+        remove_diacritics: bool = False,
         **kwargs,
     ) -> str:
         """
@@ -347,8 +350,6 @@ class WhisperTokenizerFast(PreTrainedTokenizerFast):
             clean_up_tokenization_spaces (`bool`, *optional*):
                 Whether or not to clean up the tokenization spaces. If `None`, will default to
                 `self.clean_up_tokenization_spaces` (available in the `tokenizer_config`).
-            kwargs (additional keyword arguments, *optional*):
-                Will be passed to the underlying model specific decode method.
             output_offsets (`bool`, *optional*, defaults to `False`):
                 Whether or not to output the offsets of the tokens. This should only be set if the model predicted
                 timestamps.
@@ -356,6 +357,17 @@ class WhisperTokenizerFast(PreTrainedTokenizerFast):
                 The time ratio to convert from token to time.
             decode_with_timestamps (`bool`, *optional*, defaults to `False`):
                 Whether or not to decode with timestamps included in the raw text.
+            normalize (`bool`, *optional*, defaults to `False`):
+                Whether or not to apply the English text normalizer to the decoded text. Only applicable when the
+                target text is in English; otherwise, the basic text normalizer should be applied.
+            basic_normalize (`bool`, *optional*, defaults to `False`):
+                Whether or not to apply the basic text normalizer to the decoded text. Applicable to multilingual
+                target text.
+            remove_diacritics (`bool`, *optional*, defaults to `False`):
+                Whether or not to remove diacritics when applying the basic text normalizer. Removing diacritics can
+                destroy information in the decoded text, so it should be used with caution.
+            kwargs (additional keyword arguments, *optional*):
+                Will be passed to the underlying model specific decode method.

         Returns:
             `str`: The decoded sentence.
         """
@@ -368,7 +380,9 @@ class WhisperTokenizerFast(PreTrainedTokenizerFast):
             filtered_ids,
             skip_special_tokens=skip_special_tokens,
             clean_up_tokenization_spaces=clean_up_tokenization_spaces,
-            decode_with_timestamps=decode_with_timestamps,
+            normalize=normalize,
+            basic_normalize=basic_normalize,
+            remove_diacritics=remove_diacritics,
             **kwargs,
         )
         if decode_with_timestamps:
@@ -385,12 +399,17 @@ class WhisperTokenizerFast(PreTrainedTokenizerFast):
             return {"text": text, "offsets": offsets}
         return text

-    def _decode(self, *args, normalize: bool = False, **kwargs) -> str:
+    def _decode(
+        self, *args, normalize: bool = False, basic_normalize: bool = False, remove_diacritics: bool = False, **kwargs
+    ) -> str:
         text = super()._decode(*args, **kwargs)

         if normalize:
             clean_text = self._normalize(text)
             return clean_text
+        elif basic_normalize:
+            clean_text = self._basic_normalize(text, remove_diacritics=remove_diacritics)
+            return clean_text
         else:
             return text
@@ -403,6 +422,16 @@ class WhisperTokenizerFast(PreTrainedTokenizerFast):
         normalizer = EnglishTextNormalizer(self.english_spelling_normalizer)
         return normalizer(text)

+    @staticmethod
+    # Copied from transformers.models.whisper.tokenization_whisper.WhisperTokenizer._basic_normalize
+    def _basic_normalize(text, remove_diacritics=False):
+        """
+        Normalize a given string using the `BasicTextNormalizer` class, which performs common transformations on
+        multilingual text.
+        """
+        normalizer = BasicTextNormalizer(remove_diacritics=remove_diacritics)
+        return normalizer(text)
+
     def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
         files = self._tokenizer.model.save(save_directory, name=filename_prefix)
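The new flags are covered by a test that exercises both the slow and the fast (Rust-backed) tokenizer: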
@@ -273,6 +273,40 @@ class WhisperTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
         self.assertEqual(expected_tokens, output_rust[1])
         self.assertEqual(expected_indices, output_rust[2])

+    def test_basic_normalizer(self):
+        tokenizer = self.get_tokenizer()
+        rust_tokenizer = self.get_rust_tokenizer()
+
+        input_str = "Hola güey!"
+        expected_output_normalize = "hola güey "
+        expected_output_diacritics = "hola guey "
+
+        # tokenizer tests
+        encoded_input = tokenizer(input_str).input_ids
+        decoded_output = tokenizer.decode(encoded_input, skip_special_tokens=True, basic_normalize=False)
+        self.assertEqual(decoded_output, input_str)
+
+        decoded_output_normalize = tokenizer.decode(encoded_input, skip_special_tokens=True, basic_normalize=True)
+        self.assertEqual(decoded_output_normalize, expected_output_normalize)
+
+        decoded_output_diacritics = tokenizer.decode(
+            encoded_input, skip_special_tokens=True, basic_normalize=True, remove_diacritics=True
+        )
+        self.assertEqual(decoded_output_diacritics, expected_output_diacritics)
+
+        # fast tokenizer tests
+        encoded_input = rust_tokenizer(input_str).input_ids
+        decoded_output = rust_tokenizer.decode(encoded_input, skip_special_tokens=True, basic_normalize=False)
+        self.assertEqual(decoded_output, input_str)
+
+        decoded_output_normalize = rust_tokenizer.decode(encoded_input, skip_special_tokens=True, basic_normalize=True)
+        self.assertEqual(decoded_output_normalize, expected_output_normalize)
+
+        decoded_output_diacritics = rust_tokenizer.decode(
+            encoded_input, skip_special_tokens=True, basic_normalize=True, remove_diacritics=True
+        )
+        self.assertEqual(decoded_output_diacritics, expected_output_diacritics)
+

 class SpeechToTextTokenizerMultilinguialTest(unittest.TestCase):
     checkpoint_name = "openai/whisper-small.en"
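Note the trailing space in `expected_output_normalize` and `expected_output_diacritics`: `BasicTextNormalizer` maps punctuation to whitespace before collapsing whitespace runs, so the final `!` becomes a single trailing space rather than being deleted outright.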