Unverified Commit 57f44dc4 authored by Sanchit Gandhi's avatar Sanchit Gandhi Committed by GitHub
Browse files

[Whisper] Allow basic text normalization (#26149)

* [Whisper] Allow basic text normalization

* up

* style copies
parent bd620591
...@@ -23,7 +23,7 @@ import regex as re ...@@ -23,7 +23,7 @@ import regex as re
from ...tokenization_utils import AddedToken, PreTrainedTokenizer from ...tokenization_utils import AddedToken, PreTrainedTokenizer
from ...utils import logging from ...utils import logging
from .english_normalizer import EnglishTextNormalizer from .english_normalizer import BasicTextNormalizer, EnglishTextNormalizer
VOCAB_FILES_NAMES = { VOCAB_FILES_NAMES = {
...@@ -510,6 +510,15 @@ class WhisperTokenizer(PreTrainedTokenizer): ...@@ -510,6 +510,15 @@ class WhisperTokenizer(PreTrainedTokenizer):
normalizer = EnglishTextNormalizer(self.english_spelling_normalizer) normalizer = EnglishTextNormalizer(self.english_spelling_normalizer)
return normalizer(text) return normalizer(text)
@staticmethod
def _basic_normalize(text, remove_diacritics=False):
    """
    Normalize a given string using the `BasicTextNormalizer` class, which performs common transformations on
    multilingual text.

    Args:
        text (`str`):
            The text to normalize.
        remove_diacritics (`bool`, *optional*, defaults to `False`):
            Whether to also strip diacritics during normalization. Removing diacritics may destroy information
            in the text, hence it should be used with caution.

    Returns:
        `str`: The normalized text.
    """
    # A fresh normalizer is built per call: `remove_diacritics` is a constructor argument, not a call argument.
    normalizer = BasicTextNormalizer(remove_diacritics=remove_diacritics)
    return normalizer(text)
def _decode_with_timestamps(self, token_ids, skip_special_tokens=False, time_precision=0.02) -> str: def _decode_with_timestamps(self, token_ids, skip_special_tokens=False, time_precision=0.02) -> str:
""" """
Timestamp tokens are above the special tokens' id range and are ignored by `decode()`. This method decodes Timestamp tokens are above the special tokens' id range and are ignored by `decode()`. This method decodes
...@@ -617,6 +626,9 @@ class WhisperTokenizer(PreTrainedTokenizer): ...@@ -617,6 +626,9 @@ class WhisperTokenizer(PreTrainedTokenizer):
output_offsets: bool = False, output_offsets: bool = False,
time_precision=0.02, time_precision=0.02,
decode_with_timestamps: bool = False, decode_with_timestamps: bool = False,
normalize: bool = False,
basic_normalize: bool = False,
remove_diacritics: bool = False,
**kwargs, **kwargs,
) -> str: ) -> str:
""" """
...@@ -633,8 +645,6 @@ class WhisperTokenizer(PreTrainedTokenizer): ...@@ -633,8 +645,6 @@ class WhisperTokenizer(PreTrainedTokenizer):
clean_up_tokenization_spaces (`bool`, *optional*): clean_up_tokenization_spaces (`bool`, *optional*):
Whether or not to clean up the tokenization spaces. If `None`, will default to Whether or not to clean up the tokenization spaces. If `None`, will default to
`self.clean_up_tokenization_spaces` (available in the `tokenizer_config`). `self.clean_up_tokenization_spaces` (available in the `tokenizer_config`).
kwargs (additional keyword arguments, *optional*):
Will be passed to the underlying model specific decode method.
output_offsets (`bool`, *optional*, defaults to `False`): output_offsets (`bool`, *optional*, defaults to `False`):
Whether or not to output the offsets of the tokens. This should only be set if the model predicted Whether or not to output the offsets of the tokens. This should only be set if the model predicted
timestamps. timestamps.
...@@ -642,6 +652,17 @@ class WhisperTokenizer(PreTrainedTokenizer): ...@@ -642,6 +652,17 @@ class WhisperTokenizer(PreTrainedTokenizer):
The time ratio to convert from token to time. The time ratio to convert from token to time.
decode_with_timestamps (`bool`, *optional*, defaults to `False`): decode_with_timestamps (`bool`, *optional*, defaults to `False`):
Whether or not to decode with timestamps included in the raw text. Whether or not to decode with timestamps included in the raw text.
normalize (`bool`, *optional*, defaults to `False`):
Whether or not to apply the English text normalizer to the decoded text. Only applicable when the
target text is in English. Otherwise, the basic text normalizer should be applied.
basic_normalize (`bool`, *optional*, defaults to `False`):
Whether or not to apply the Basic text normalizer to the decoded text. Applicable to multilingual
target text.
remove_diacritics (`bool`, *optional*, defaults to `False`):
Whether or not to remove diacritics when applying the Basic text normalizer. Removing diacritics may
destroy information in the decoded text, hence it should be used with caution.
kwargs (additional keyword arguments, *optional*):
Will be passed to the underlying model specific decode method.
Returns: Returns:
`str`: The decoded sentence. `str`: The decoded sentence.
""" """
...@@ -654,7 +675,9 @@ class WhisperTokenizer(PreTrainedTokenizer): ...@@ -654,7 +675,9 @@ class WhisperTokenizer(PreTrainedTokenizer):
filtered_ids, filtered_ids,
skip_special_tokens=skip_special_tokens, skip_special_tokens=skip_special_tokens,
clean_up_tokenization_spaces=clean_up_tokenization_spaces, clean_up_tokenization_spaces=clean_up_tokenization_spaces,
decode_with_timestamps=decode_with_timestamps, normalize=normalize,
basic_normalize=basic_normalize,
remove_diacritics=remove_diacritics,
**kwargs, **kwargs,
) )
if decode_with_timestamps: if decode_with_timestamps:
...@@ -676,7 +699,8 @@ class WhisperTokenizer(PreTrainedTokenizer): ...@@ -676,7 +699,8 @@ class WhisperTokenizer(PreTrainedTokenizer):
token_ids: Union[int, List[int]], token_ids: Union[int, List[int]],
skip_special_tokens: bool = False, skip_special_tokens: bool = False,
normalize: bool = False, normalize: bool = False,
decode_with_timestamps: bool = False, basic_normalize: bool = False,
remove_diacritics: bool = False,
**kwargs, **kwargs,
) -> str: ) -> str:
self._decode_use_source_tokenizer = kwargs.pop("use_source_tokenizer", False) self._decode_use_source_tokenizer = kwargs.pop("use_source_tokenizer", False)
...@@ -705,6 +729,9 @@ class WhisperTokenizer(PreTrainedTokenizer): ...@@ -705,6 +729,9 @@ class WhisperTokenizer(PreTrainedTokenizer):
if normalize: if normalize:
clean_text = self._normalize(text) clean_text = self._normalize(text)
return clean_text return clean_text
elif basic_normalize:
clean_text = self._basic_normalize(text, remove_diacritics=remove_diacritics)
return clean_text
else: else:
return text return text
......
...@@ -25,7 +25,7 @@ from tokenizers import AddedToken, pre_tokenizers, processors ...@@ -25,7 +25,7 @@ from tokenizers import AddedToken, pre_tokenizers, processors
from ...tokenization_utils_base import BatchEncoding from ...tokenization_utils_base import BatchEncoding
from ...tokenization_utils_fast import PreTrainedTokenizerFast from ...tokenization_utils_fast import PreTrainedTokenizerFast
from ...utils import logging from ...utils import logging
from .english_normalizer import EnglishTextNormalizer from .english_normalizer import BasicTextNormalizer, EnglishTextNormalizer
from .tokenization_whisper import LANGUAGES, TASK_IDS, TO_LANGUAGE_CODE, WhisperTokenizer, _decode_asr from .tokenization_whisper import LANGUAGES, TASK_IDS, TO_LANGUAGE_CODE, WhisperTokenizer, _decode_asr
...@@ -331,6 +331,9 @@ class WhisperTokenizerFast(PreTrainedTokenizerFast): ...@@ -331,6 +331,9 @@ class WhisperTokenizerFast(PreTrainedTokenizerFast):
output_offsets: bool = False, output_offsets: bool = False,
time_precision=0.02, time_precision=0.02,
decode_with_timestamps: bool = False, decode_with_timestamps: bool = False,
normalize: bool = False,
basic_normalize: bool = False,
remove_diacritics: bool = False,
**kwargs, **kwargs,
) -> str: ) -> str:
""" """
...@@ -347,8 +350,6 @@ class WhisperTokenizerFast(PreTrainedTokenizerFast): ...@@ -347,8 +350,6 @@ class WhisperTokenizerFast(PreTrainedTokenizerFast):
clean_up_tokenization_spaces (`bool`, *optional*): clean_up_tokenization_spaces (`bool`, *optional*):
Whether or not to clean up the tokenization spaces. If `None`, will default to Whether or not to clean up the tokenization spaces. If `None`, will default to
`self.clean_up_tokenization_spaces` (available in the `tokenizer_config`). `self.clean_up_tokenization_spaces` (available in the `tokenizer_config`).
kwargs (additional keyword arguments, *optional*):
Will be passed to the underlying model specific decode method.
output_offsets (`bool`, *optional*, defaults to `False`): output_offsets (`bool`, *optional*, defaults to `False`):
Whether or not to output the offsets of the tokens. This should only be set if the model predicted Whether or not to output the offsets of the tokens. This should only be set if the model predicted
timestamps. timestamps.
...@@ -356,6 +357,17 @@ class WhisperTokenizerFast(PreTrainedTokenizerFast): ...@@ -356,6 +357,17 @@ class WhisperTokenizerFast(PreTrainedTokenizerFast):
The time ratio to convert from token to time. The time ratio to convert from token to time.
decode_with_timestamps (`bool`, *optional*, defaults to `False`): decode_with_timestamps (`bool`, *optional*, defaults to `False`):
Whether or not to decode with timestamps included in the raw text. Whether or not to decode with timestamps included in the raw text.
normalize (`bool`, *optional*, defaults to `False`):
Whether or not to apply the English text normalizer to the decoded text. Only applicable when the
target text is in English. Otherwise, the basic text normalizer should be applied.
basic_normalize (`bool`, *optional*, defaults to `False`):
Whether or not to apply the Basic text normalizer to the decoded text. Applicable to multilingual
target text.
remove_diacritics (`bool`, *optional*, defaults to `False`):
Whether or not to remove diacritics when applying the Basic text normalizer. Removing diacritics may
destroy information in the decoded text, hence it should be used with caution.
kwargs (additional keyword arguments, *optional*):
Will be passed to the underlying model specific decode method.
Returns: Returns:
`str`: The decoded sentence. `str`: The decoded sentence.
""" """
...@@ -368,7 +380,9 @@ class WhisperTokenizerFast(PreTrainedTokenizerFast): ...@@ -368,7 +380,9 @@ class WhisperTokenizerFast(PreTrainedTokenizerFast):
filtered_ids, filtered_ids,
skip_special_tokens=skip_special_tokens, skip_special_tokens=skip_special_tokens,
clean_up_tokenization_spaces=clean_up_tokenization_spaces, clean_up_tokenization_spaces=clean_up_tokenization_spaces,
decode_with_timestamps=decode_with_timestamps, normalize=normalize,
basic_normalize=basic_normalize,
remove_diacritics=remove_diacritics,
**kwargs, **kwargs,
) )
if decode_with_timestamps: if decode_with_timestamps:
...@@ -385,12 +399,17 @@ class WhisperTokenizerFast(PreTrainedTokenizerFast): ...@@ -385,12 +399,17 @@ class WhisperTokenizerFast(PreTrainedTokenizerFast):
return {"text": text, "offsets": offsets} return {"text": text, "offsets": offsets}
return text return text
def _decode(self, *args, normalize: bool = False, **kwargs) -> str: def _decode(
self, *args, normalize: bool = False, basic_normalize: bool = False, remove_diacritics: bool = False, **kwargs
) -> str:
text = super()._decode(*args, **kwargs) text = super()._decode(*args, **kwargs)
if normalize: if normalize:
clean_text = self._normalize(text) clean_text = self._normalize(text)
return clean_text return clean_text
elif basic_normalize:
clean_text = self._basic_normalize(text, remove_diacritics=remove_diacritics)
return clean_text
else: else:
return text return text
...@@ -403,6 +422,16 @@ class WhisperTokenizerFast(PreTrainedTokenizerFast): ...@@ -403,6 +422,16 @@ class WhisperTokenizerFast(PreTrainedTokenizerFast):
normalizer = EnglishTextNormalizer(self.english_spelling_normalizer) normalizer = EnglishTextNormalizer(self.english_spelling_normalizer)
return normalizer(text) return normalizer(text)
@staticmethod
# Copied from transformers.models.whisper.tokenization_whisper.WhisperTokenizer._basic_normalize
def _basic_normalize(text, remove_diacritics=False):
    """
    Normalize a given string using the `BasicTextNormalizer` class, which performs common transformations on
    multilingual text.

    Args:
        text (`str`):
            The text to normalize.
        remove_diacritics (`bool`, *optional*, defaults to `False`):
            Whether to also strip diacritics during normalization. Removing diacritics may destroy information
            in the text, hence it should be used with caution.

    Returns:
        `str`: The normalized text.
    """
    # A fresh normalizer is built per call: `remove_diacritics` is a constructor argument, not a call argument.
    normalizer = BasicTextNormalizer(remove_diacritics=remove_diacritics)
    return normalizer(text)
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]: def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
files = self._tokenizer.model.save(save_directory, name=filename_prefix) files = self._tokenizer.model.save(save_directory, name=filename_prefix)
......
...@@ -273,6 +273,40 @@ class WhisperTokenizerTest(TokenizerTesterMixin, unittest.TestCase): ...@@ -273,6 +273,40 @@ class WhisperTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
self.assertEqual(expected_tokens, output_rust[1]) self.assertEqual(expected_tokens, output_rust[1])
self.assertEqual(expected_indices, output_rust[2]) self.assertEqual(expected_indices, output_rust[2])
def test_basic_normalizer(self):
    """Check `decode(..., basic_normalize=True)` on both the slow and the fast (rust) tokenizer."""
    input_str = "Hola güey!"
    expected_normalized = "hola güey "
    expected_without_diacritics = "hola guey "

    # Both tokenizer implementations must behave identically, so run the same checks against each.
    tokenizers_under_test = [self.get_tokenizer(), self.get_rust_tokenizer()]
    for current_tokenizer in tokenizers_under_test:
        token_ids = current_tokenizer(input_str).input_ids

        # With normalization disabled, decoding round-trips the original string.
        self.assertEqual(
            current_tokenizer.decode(token_ids, skip_special_tokens=True, basic_normalize=False),
            input_str,
        )

        # Basic normalization transforms the text but keeps diacritics.
        self.assertEqual(
            current_tokenizer.decode(token_ids, skip_special_tokens=True, basic_normalize=True),
            expected_normalized,
        )

        # `remove_diacritics=True` additionally strips the accents.
        self.assertEqual(
            current_tokenizer.decode(
                token_ids, skip_special_tokens=True, basic_normalize=True, remove_diacritics=True
            ),
            expected_without_diacritics,
        )
class SpeechToTextTokenizerMultilinguialTest(unittest.TestCase): class SpeechToTextTokenizerMultilinguialTest(unittest.TestCase):
checkpoint_name = "openai/whisper-small.en" checkpoint_name = "openai/whisper-small.en"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment