chenpangpang / transformers · Commits · 57f44dc4

Unverified commit 57f44dc4, authored Oct 03, 2023 by Sanchit Gandhi, committed by GitHub on Oct 03, 2023

[Whisper] Allow basic text normalization (#26149)

* [Whisper] Allow basic text normalization
* up
* style copies

Parent: bd620591

Showing 3 changed files with 100 additions and 10 deletions (+100, -10)
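In short, the commit adds two opt-in flags, basic_normalize and remove_diacritics, to the slow and fast Whisper tokenizers' decode(), backed by a _basic_normalize static helper. A minimal usage sketch; the checkpoint name is illustrative, and the expected strings are taken from the new test added at the end of this diff:

    from transformers import WhisperTokenizer

    # Assumption: "openai/whisper-tiny" is illustrative; any Whisper checkpoint exposes the same decode() flags
    tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny")
    ids = tokenizer("Hola güey!").input_ids

    # Pre-existing English normalizer (intended for English-only target text)
    english = tokenizer.decode(ids, skip_special_tokens=True, normalize=True)

    # New in this commit: basic multilingual normalizer -> "hola güey " (expected value per the test below)
    basic = tokenizer.decode(ids, skip_special_tokens=True, basic_normalize=True)

    # New in this commit: additionally strip diacritics -> "hola guey " (lossy; use with caution)
    no_marks = tokenizer.decode(ids, skip_special_tokens=True, basic_normalize=True, remove_diacritics=True)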
src/transformers/models/whisper/tokenization_whisper.py       +32 -5
src/transformers/models/whisper/tokenization_whisper_fast.py  +34 -5
tests/models/whisper/test_tokenization_whisper.py             +34 -0
src/transformers/models/whisper/tokenization_whisper.py (view file @ 57f44dc4)
@@ -23,7 +23,7 @@ import regex as re
 from ...tokenization_utils import AddedToken, PreTrainedTokenizer
 from ...utils import logging
-from .english_normalizer import EnglishTextNormalizer
+from .english_normalizer import BasicTextNormalizer, EnglishTextNormalizer


 VOCAB_FILES_NAMES = {
@@ -510,6 +510,15 @@ class WhisperTokenizer(PreTrainedTokenizer):
         normalizer = EnglishTextNormalizer(self.english_spelling_normalizer)
         return normalizer(text)

+    @staticmethod
+    def _basic_normalize(text, remove_diacritics=False):
+        """
+        Normalize a given string using the `BasicTextNormalizer` class, which performs common transformations on
+        multilingual text.
+        """
+        normalizer = BasicTextNormalizer(remove_diacritics=remove_diacritics)
+        return normalizer(text)
+
     def _decode_with_timestamps(self, token_ids, skip_special_tokens=False, time_precision=0.02) -> str:
         """
         Timestamp tokens are above the special tokens' id range and are ignored by `decode()`. This method decodes
@@ -617,6 +626,9 @@ class WhisperTokenizer(PreTrainedTokenizer):
         output_offsets: bool = False,
         time_precision=0.02,
         decode_with_timestamps: bool = False,
+        normalize: bool = False,
+        basic_normalize: bool = False,
+        remove_diacritics: bool = False,
         **kwargs,
     ) -> str:
         """
@@ -633,8 +645,6 @@ class WhisperTokenizer(PreTrainedTokenizer):
             clean_up_tokenization_spaces (`bool`, *optional*):
                 Whether or not to clean up the tokenization spaces. If `None`, will default to
                 `self.clean_up_tokenization_spaces` (available in the `tokenizer_config`).
-            kwargs (additional keyword arguments, *optional*):
-                Will be passed to the underlying model specific decode method.
             output_offsets (`bool`, *optional*, defaults to `False`):
                 Whether or not to output the offsets of the tokens. This should only be set if the model predicted
                 timestamps.
@@ -642,6 +652,17 @@ class WhisperTokenizer(PreTrainedTokenizer):
                 The time ratio to convert from token to time.
             decode_with_timestamps (`bool`, *optional*, defaults to `False`):
                 Whether or not to decode with timestamps included in the raw text.
+            normalize (`bool`, *optional*, defaults to `False`):
+                Whether or not to apply the English text normalizer to the decoded text. Only applicable when the
+                target text is in English. Otherwise, the basic text normalizer should be applied.
+            basic_normalize (`bool`, *optional*, defaults to `False`):
+                Whether or not to apply the Basic text normalizer to the decoded text. Applicable to multilingual
+                target text.
+            remove_diacritics (`bool`, *optional*, defaults to `False`):
+                Whether or not to remove diacritics when applying the Basic text normalizer. Removing diacritics may
+                destroy information in the decoded text, hence it should be used with caution.
+            kwargs (additional keyword arguments, *optional*):
+                Will be passed to the underlying model specific decode method.
         Returns:
             `str`: The decoded sentence.
         """
@@ -654,7 +675,9 @@ class WhisperTokenizer(PreTrainedTokenizer):
             filtered_ids,
             skip_special_tokens=skip_special_tokens,
             clean_up_tokenization_spaces=clean_up_tokenization_spaces,
-            decode_with_timestamps=decode_with_timestamps,
+            normalize=normalize,
+            basic_normalize=basic_normalize,
+            remove_diacritics=remove_diacritics,
             **kwargs,
         )
         if decode_with_timestamps:
@@ -676,7 +699,8 @@ class WhisperTokenizer(PreTrainedTokenizer):
         token_ids: Union[int, List[int]],
         skip_special_tokens: bool = False,
         normalize: bool = False,
-        decode_with_timestamps: bool = False,
+        basic_normalize: bool = False,
+        remove_diacritics: bool = False,
         **kwargs,
     ) -> str:
         self._decode_use_source_tokenizer = kwargs.pop("use_source_tokenizer", False)
@@ -705,6 +729,9 @@ class WhisperTokenizer(PreTrainedTokenizer):
         if normalize:
             clean_text = self._normalize(text)
             return clean_text
+        elif basic_normalize:
+            clean_text = self._basic_normalize(text, remove_diacritics=remove_diacritics)
+            return clean_text
         else:
             return text
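Both the helper above and its fast-tokenizer twin below delegate to BasicTextNormalizer from english_normalizer.py. As a rough, self-contained sketch of what that normalization amounts to (lowercasing, mapping punctuation and symbols to spaces, optionally dropping combining marks); this is an approximation of the library's rules, not a copy of them:

    import re
    import unicodedata


    def basic_normalize(text: str, remove_diacritics: bool = False) -> str:
        # Lowercase first, as the library normalizers do
        text = text.lower()
        if remove_diacritics:
            # NFKD-decompose, then drop combining marks: "güey" -> "guey" (lossy)
            text = "".join(c for c in unicodedata.normalize("NFKD", text) if not unicodedata.combining(c))
        # Map punctuation and symbols to spaces; \w is Unicode-aware in Python 3, so accented letters survive
        text = re.sub(r"[^\w\s]", " ", text)
        # Collapse whitespace runs to single spaces (no strip, which is why the
        # expected test outputs below end in a trailing space)
        return re.sub(r"\s+", " ", text)


    assert basic_normalize("Hola güey!") == "hola güey "
    assert basic_normalize("Hola güey!", remove_diacritics=True) == "hola guey "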
src/transformers/models/whisper/tokenization_whisper_fast.py (view file @ 57f44dc4)
@@ -25,7 +25,7 @@ from tokenizers import AddedToken, pre_tokenizers, processors
 from ...tokenization_utils_base import BatchEncoding
 from ...tokenization_utils_fast import PreTrainedTokenizerFast
 from ...utils import logging
-from .english_normalizer import EnglishTextNormalizer
+from .english_normalizer import BasicTextNormalizer, EnglishTextNormalizer
 from .tokenization_whisper import LANGUAGES, TASK_IDS, TO_LANGUAGE_CODE, WhisperTokenizer, _decode_asr
@@ -331,6 +331,9 @@ class WhisperTokenizerFast(PreTrainedTokenizerFast):
         output_offsets: bool = False,
         time_precision=0.02,
         decode_with_timestamps: bool = False,
+        normalize: bool = False,
+        basic_normalize: bool = False,
+        remove_diacritics: bool = False,
         **kwargs,
     ) -> str:
         """
@@ -347,8 +350,6 @@ class WhisperTokenizerFast(PreTrainedTokenizerFast):
             clean_up_tokenization_spaces (`bool`, *optional*):
                 Whether or not to clean up the tokenization spaces. If `None`, will default to
                 `self.clean_up_tokenization_spaces` (available in the `tokenizer_config`).
-            kwargs (additional keyword arguments, *optional*):
-                Will be passed to the underlying model specific decode method.
             output_offsets (`bool`, *optional*, defaults to `False`):
                 Whether or not to output the offsets of the tokens. This should only be set if the model predicted
                 timestamps.
@@ -356,6 +357,17 @@ class WhisperTokenizerFast(PreTrainedTokenizerFast):
                 The time ratio to convert from token to time.
             decode_with_timestamps (`bool`, *optional*, defaults to `False`):
                 Whether or not to decode with timestamps included in the raw text.
+            normalize (`bool`, *optional*, defaults to `False`):
+                Whether or not to apply the English text normalizer to the decoded text. Only applicable when the
+                target text is in English. Otherwise, the basic text normalizer should be applied.
+            basic_normalize (`bool`, *optional*, defaults to `False`):
+                Whether or not to apply the Basic text normalizer to the decoded text. Applicable to multilingual
+                target text.
+            remove_diacritics (`bool`, *optional*, defaults to `False`):
+                Whether or not to remove diacritics when applying the Basic text normalizer. Removing diacritics may
+                destroy information in the decoded text, hence it should be used with caution.
+            kwargs (additional keyword arguments, *optional*):
+                Will be passed to the underlying model specific decode method.
         Returns:
             `str`: The decoded sentence.
         """
@@ -368,7 +380,9 @@ class WhisperTokenizerFast(PreTrainedTokenizerFast):
             filtered_ids,
             skip_special_tokens=skip_special_tokens,
             clean_up_tokenization_spaces=clean_up_tokenization_spaces,
-            decode_with_timestamps=decode_with_timestamps,
+            normalize=normalize,
+            basic_normalize=basic_normalize,
+            remove_diacritics=remove_diacritics,
             **kwargs,
         )
         if decode_with_timestamps:
@@ -385,12 +399,17 @@ class WhisperTokenizerFast(PreTrainedTokenizerFast):
             return {"text": text, "offsets": offsets}
         return text

-    def _decode(self, *args, normalize: bool = False, **kwargs) -> str:
+    def _decode(
+        self, *args, normalize: bool = False, basic_normalize: bool = False, remove_diacritics: bool = False, **kwargs
+    ) -> str:
         text = super()._decode(*args, **kwargs)

         if normalize:
             clean_text = self._normalize(text)
             return clean_text
+        elif basic_normalize:
+            clean_text = self._basic_normalize(text, remove_diacritics=remove_diacritics)
+            return clean_text
         else:
             return text
@@ -403,6 +422,16 @@ class WhisperTokenizerFast(PreTrainedTokenizerFast):
         normalizer = EnglishTextNormalizer(self.english_spelling_normalizer)
         return normalizer(text)

+    @staticmethod
+    # Copied from transformers.models.whisper.tokenization_whisper.WhisperTokenizer._basic_normalize
+    def _basic_normalize(text, remove_diacritics=False):
+        """
+        Normalize a given string using the `BasicTextNormalizer` class, which performs common transformations on
+        multilingual text.
+        """
+        normalizer = BasicTextNormalizer(remove_diacritics=remove_diacritics)
+        return normalizer(text)
+
     def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
         files = self._tokenizer.model.save(save_directory, name=filename_prefix)
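Worth noting from both _decode implementations: normalize takes precedence over basic_normalize when both are passed, since the basic branch sits behind an elif. A small illustration (same illustrative checkpoint assumption as the first sketch):

    from transformers import WhisperTokenizerFast

    tokenizer = WhisperTokenizerFast.from_pretrained("openai/whisper-tiny")
    ids = tokenizer("Hola güey!").input_ids

    # Both flags set: the `if normalize: ... elif basic_normalize:` chain means the
    # English normalizer runs and basic_normalize is silently ignored
    text = tokenizer.decode(ids, skip_special_tokens=True, normalize=True, basic_normalize=True)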
tests/models/whisper/test_tokenization_whisper.py (view file @ 57f44dc4)
@@ -273,6 +273,40 @@ class WhisperTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
         self.assertEqual(expected_tokens, output_rust[1])
         self.assertEqual(expected_indices, output_rust[2])

+    def test_basic_normalizer(self):
+        tokenizer = self.get_tokenizer()
+        rust_tokenizer = self.get_rust_tokenizer()
+
+        input_str = "Hola güey!"
+        expected_output_normalize = "hola güey "
+        expected_output_diacritics = "hola guey "
+
+        # tokenizer tests
+        encoded_input = tokenizer(input_str).input_ids
+        decoded_output = tokenizer.decode(encoded_input, skip_special_tokens=True, basic_normalize=False)
+        self.assertEqual(decoded_output, input_str)
+
+        decoded_output_normalize = tokenizer.decode(encoded_input, skip_special_tokens=True, basic_normalize=True)
+        self.assertEqual(decoded_output_normalize, expected_output_normalize)
+
+        decoded_output_diacritics = tokenizer.decode(
+            encoded_input, skip_special_tokens=True, basic_normalize=True, remove_diacritics=True
+        )
+        self.assertEqual(decoded_output_diacritics, expected_output_diacritics)
+
+        # fast tokenizer tests
+        encoded_input = rust_tokenizer(input_str).input_ids
+        decoded_output = rust_tokenizer.decode(encoded_input, skip_special_tokens=True, basic_normalize=False)
+        self.assertEqual(decoded_output, input_str)
+
+        decoded_output_normalize = rust_tokenizer.decode(encoded_input, skip_special_tokens=True, basic_normalize=True)
+        self.assertEqual(decoded_output_normalize, expected_output_normalize)
+
+        decoded_output_diacritics = rust_tokenizer.decode(
+            encoded_input, skip_special_tokens=True, basic_normalize=True, remove_diacritics=True
+        )
+        self.assertEqual(decoded_output_diacritics, expected_output_diacritics)
+

 class SpeechToTextTokenizerMultilinguialTest(unittest.TestCase):
     checkpoint_name = "openai/whisper-small.en"
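To exercise just the new test locally, standard pytest selection should work from the repository root (the command shape is the usual transformers workflow, not something stated in this commit):

    pytest tests/models/whisper/test_tokenization_whisper.py -k test_basic_normalizer -v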