Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
da3c79b2
Unverified
Commit
da3c79b2
authored
Jan 29, 2024
by
Sanchit Gandhi
Committed by
GitHub
Jan 29, 2024
Browse files
[Whisper] Make tokenizer normalization public (#28136)
* [Whisper] Make tokenizer normalization public * add to docs
parent
e694e985
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
41 additions
and
5 deletions
+41
-5
docs/source/en/model_doc/whisper.md
docs/source/en/model_doc/whisper.md
+4
-0
src/transformers/models/whisper/tokenization_whisper.py
src/transformers/models/whisper/tokenization_whisper.py
+18
-3
src/transformers/models/whisper/tokenization_whisper_fast.py
src/transformers/models/whisper/tokenization_whisper_fast.py
+19
-2
No files found.
docs/source/en/model_doc/whisper.md
View file @
da3c79b2
...
@@ -102,6 +102,8 @@ python convert_hf_to_openai.py \
...
@@ -102,6 +102,8 @@ python convert_hf_to_openai.py \
-
save_vocabulary
-
save_vocabulary
-
batch_decode
-
batch_decode
-
decode
-
decode
-
basic_normalize
-
normalize
## WhisperTokenizerFast
## WhisperTokenizerFast
...
@@ -113,6 +115,8 @@ python convert_hf_to_openai.py \
...
@@ -113,6 +115,8 @@ python convert_hf_to_openai.py \
-
save_vocabulary
-
save_vocabulary
-
batch_decode
-
batch_decode
-
decode
-
decode
-
basic_normalize
-
normalize
## WhisperFeatureExtractor
## WhisperFeatureExtractor
...
...
src/transformers/models/whisper/tokenization_whisper.py
View file @
da3c79b2
...
@@ -15,6 +15,7 @@
...
@@ -15,6 +15,7 @@
"""Tokenization classes for Whisper."""
"""Tokenization classes for Whisper."""
import
json
import
json
import
os
import
os
import
warnings
from
functools
import
lru_cache
from
functools
import
lru_cache
from
typing
import
List
,
Optional
,
Tuple
,
Union
from
typing
import
List
,
Optional
,
Tuple
,
Union
...
@@ -507,6 +508,20 @@ class WhisperTokenizer(PreTrainedTokenizer):
...
@@ -507,6 +508,20 @@ class WhisperTokenizer(PreTrainedTokenizer):
return
self
.
decoder
.
get
(
index
,
""
)
return
self
.
decoder
.
get
(
index
,
""
)
def
_normalize
(
self
,
text
):
def
_normalize
(
self
,
text
):
warnings
.
warn
(
"The private method `_normalize` is deprecated and will be removed in v5 of Transformers. "
"You can normalize an input string using the Whisper English normalizer using the `normalize` method."
)
return
self
.
normalize
(
text
)
def
_basic_normalize
(
self
,
text
,
remove_diacritics
=
False
):
warnings
.
warn
(
"The private method `_basic_normalize` is deprecated and will be removed in v5 of Transformers. "
"You can normalize an input string using the Whisper basic normalizer using the `basic_normalize` method."
)
return
self
.
basic_normalize
(
text
,
remove_diacritics
=
remove_diacritics
)
def
normalize
(
self
,
text
):
"""
"""
Normalize a given string using the `EnglishTextNormalizer` class, which performs common transformations on
Normalize a given string using the `EnglishTextNormalizer` class, which performs common transformations on
English text.
English text.
...
@@ -515,7 +530,7 @@ class WhisperTokenizer(PreTrainedTokenizer):
...
@@ -515,7 +530,7 @@ class WhisperTokenizer(PreTrainedTokenizer):
return
normalizer
(
text
)
return
normalizer
(
text
)
@
staticmethod
@
staticmethod
def
_
basic_normalize
(
text
,
remove_diacritics
=
False
):
def
basic_normalize
(
text
,
remove_diacritics
=
False
):
"""
"""
Normalize a given string using the `BasicTextNormalizer` class, which performs common transformations on
Normalize a given string using the `BasicTextNormalizer` class, which performs common transformations on
multilingual text.
multilingual text.
...
@@ -745,10 +760,10 @@ class WhisperTokenizer(PreTrainedTokenizer):
...
@@ -745,10 +760,10 @@ class WhisperTokenizer(PreTrainedTokenizer):
text
=
""
.
join
(
sub_texts
)
text
=
""
.
join
(
sub_texts
)
if
normalize
:
if
normalize
:
clean_text
=
self
.
_
normalize
(
text
)
clean_text
=
self
.
normalize
(
text
)
return
clean_text
return
clean_text
elif
basic_normalize
:
elif
basic_normalize
:
clean_text
=
self
.
_
basic_normalize
(
text
,
remove_diacritics
=
remove_diacritics
)
clean_text
=
self
.
basic_normalize
(
text
,
remove_diacritics
=
remove_diacritics
)
return
clean_text
return
clean_text
else
:
else
:
return
text
return
text
...
...
src/transformers/models/whisper/tokenization_whisper_fast.py
View file @
da3c79b2
...
@@ -16,6 +16,7 @@
...
@@ -16,6 +16,7 @@
import
json
import
json
import
os
import
os
import
re
import
re
import
warnings
from
functools
import
lru_cache
from
functools
import
lru_cache
from
typing
import
List
,
Optional
,
Tuple
from
typing
import
List
,
Optional
,
Tuple
...
@@ -427,6 +428,22 @@ class WhisperTokenizerFast(PreTrainedTokenizerFast):
...
@@ -427,6 +428,22 @@ class WhisperTokenizerFast(PreTrainedTokenizerFast):
# Copied from transformers.models.whisper.tokenization_whisper.WhisperTokenizer._normalize
# Copied from transformers.models.whisper.tokenization_whisper.WhisperTokenizer._normalize
def
_normalize
(
self
,
text
):
def
_normalize
(
self
,
text
):
warnings
.
warn
(
"The private method `_normalize` is deprecated and will be removed in v5 of Transformers. "
"You can normalize an input string using the Whisper English normalizer using the `normalize` method."
)
return
self
.
normalize
(
text
)
# Copied from transformers.models.whisper.tokenization_whisper.WhisperTokenizer._basic_normalize
def
_basic_normalize
(
self
,
text
,
remove_diacritics
=
False
):
warnings
.
warn
(
"The private method `_basic_normalize` is deprecated and will be removed in v5 of Transformers. "
"You can normalize an input string using the Whisper basic normalizer using the `basic_normalize` method."
)
return
self
.
basic_normalize
(
text
,
remove_diacritics
=
remove_diacritics
)
# Copied from transformers.models.whisper.tokenization_whisper.WhisperTokenizer.normalize
def
normalize
(
self
,
text
):
"""
"""
Normalize a given string using the `EnglishTextNormalizer` class, which performs common transformations on
Normalize a given string using the `EnglishTextNormalizer` class, which performs common transformations on
English text.
English text.
...
@@ -435,8 +452,8 @@ class WhisperTokenizerFast(PreTrainedTokenizerFast):
...
@@ -435,8 +452,8 @@ class WhisperTokenizerFast(PreTrainedTokenizerFast):
return
normalizer
(
text
)
return
normalizer
(
text
)
@
staticmethod
@
staticmethod
# Copied from transformers.models.whisper.tokenization_whisper.WhisperTokenizer.
_
basic_normalize
# Copied from transformers.models.whisper.tokenization_whisper.WhisperTokenizer.basic_normalize
def
_
basic_normalize
(
text
,
remove_diacritics
=
False
):
def
basic_normalize
(
text
,
remove_diacritics
=
False
):
"""
"""
Normalize a given string using the `BasicTextNormalizer` class, which performs common transformations on
Normalize a given string using the `BasicTextNormalizer` class, which performs common transformations on
multilingual text.
multilingual text.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment