chenpangpang / transformers · Commits

Commit 211f93aa (unverified), authored Sep 28, 2023 by Sanchit Gandhi, committed by GitHub on Sep 28, 2023.
[Whisper Tokenizer] Make decoding faster after adding timestamps (#26299)
make decoding faster
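In short: the tokenizer previously stripped timestamps before decoding by testing every generated token id against the 1501-element list returned by `timestamp_ids()`; this commit decodes first and then removes the timestamp markers from the resulting string with a single pre-compiled regex (`timestamp_pat`). A minimal sketch of the two paths, using stand-in helper names (`filter_old`, `filter_new`) that are not part of the library:

import re

# The pattern the commit registers as `self.timestamp_pat`; it matches
# Whisper timestamp tokens such as "<|0.00|>" or "<|29.98|>".
timestamp_pat = re.compile(r"<\|(\d+\.\d+)\|>")

def filter_old(token_ids, timestamp_ids):
    # Old path (sketch): membership test against a 1501-element list for
    # every generated token id, i.e. ~len(token_ids) * 1501 comparisons.
    return [t for t in token_ids if t not in timestamp_ids]

def filter_new(decoded_text):
    # New path (sketch): decode everything, then drop the "<|...|>" markers
    # from the text in one linear regex substitution.
    return re.sub(timestamp_pat, "", decoded_text)

print(filter_new("<|0.00|> Hello world.<|1.24|>"))  # -> " Hello world."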
parent 4e931a8e

Showing 2 changed files with 30 additions and 34 deletions (+30 −34):

  src/transformers/models/whisper/tokenization_whisper.py       (+14 −17)
  src/transformers/models/whisper/tokenization_whisper_fast.py  (+16 −17)
src/transformers/models/whisper/tokenization_whisper.py
@@ -314,6 +314,7 @@ class WhisperTokenizer(PreTrainedTokenizer):
         # Should have added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions
         self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
+        self.timestamp_pat = re.compile(r"<\|(\d+\.\d+)\|>")
 
         self.language = language
         super().__init__(
@@ -560,10 +561,12 @@ class WhisperTokenizer(PreTrainedTokenizer):
                 start_timestamp_position = sliced_tokens[0].item() - timestamp_begin
                 end_timestamp_position = sliced_tokens[-1].item() - timestamp_begin
                 # strip timestamp tokens from the text output
-                sliced_tokens = self._preprocess_token_ids(sliced_tokens, decode_with_timestamps=False)
+                sliced_tokens = self._preprocess_token_ids(sliced_tokens)
+                text = self._decode(sliced_tokens)
+                text = self._filter_timestamp_ids(text)
                 offsets.append(
                     {
-                        "text": self._decode(sliced_tokens),
+                        "text": text,
                         "timestamp": (
                             start_timestamp_position * time_precision,
                             end_timestamp_position * time_precision,
@@ -585,9 +588,7 @@ class WhisperTokenizer(PreTrainedTokenizer):
         """
         return self.convert_tokens_to_ids([("<|%.2f|>" % (i * time_precision)) for i in range(1500 + 1)])
 
-    def _preprocess_token_ids(
-        self, token_ids, skip_special_tokens: bool = False, decode_with_timestamps: bool = False, time_precision=0.02
-    ):
+    def _preprocess_token_ids(self, token_ids, skip_special_tokens: bool = False):
         """
         Pre-process the token ids for decoding by removing the prompt tokens ids and timestamp token ids.
@@ -597,24 +598,17 @@ class WhisperTokenizer(PreTrainedTokenizer):
             skip_special_tokens (`bool`, *optional*, defaults to `False`):
                 Whether or not to remove special tokens from the token ids. If `True`, the prompt token ids will be
                 removed.
-            decode_with_timestamps (`bool`, *optional*, defaults to `False`):
-                Whether or not to decode with timestamps included in the raw text. If `False`, timestamps will be
-                filtered out from the token ids.
-            time_precision (`float`, `optional`, defaults to 0.02):
-                The time ratio to convert from token to time.
         """
         if skip_special_tokens:
             prompt_token_id = self.convert_tokens_to_ids("<|startofprev|>")
             decoder_start_token_id = self.convert_tokens_to_ids("<|startoftranscript|>")
             token_ids = self._strip_prompt(token_ids, prompt_token_id, decoder_start_token_id)
 
-        if not decode_with_timestamps:
-            # filter timestamp tokens if they are contained in the vocab
-            timestamp_ids = self.timestamp_ids(time_precision=time_precision)
-            token_ids = [token for token in token_ids if token not in timestamp_ids]
-
         return token_ids
 
+    def _filter_timestamp_ids(self, token_ids):
+        return re.sub(self.timestamp_pat, "", token_ids)
+
     def decode(
         self,
         token_ids,
@@ -644,6 +638,8 @@ class WhisperTokenizer(PreTrainedTokenizer):
             output_offsets (`bool`, *optional*, defaults to `False`):
                 Whether or not to output the offsets of the tokens. This should only be set if the model predicted
                 timestamps.
+            time_precision (`float`, `optional`, defaults to 0.02):
+                The time ratio to convert from token to time.
             decode_with_timestamps (`bool`, *optional*, defaults to `False`):
                 Whether or not to decode with timestamps included in the raw text.
         Returns:
@@ -652,8 +648,6 @@ class WhisperTokenizer(PreTrainedTokenizer):
         filtered_ids = self._preprocess_token_ids(
             token_ids,
             skip_special_tokens=skip_special_tokens,
-            decode_with_timestamps=decode_with_timestamps,
-            time_precision=time_precision,
         )
         text = super().decode(
@@ -668,6 +662,9 @@ class WhisperTokenizer(PreTrainedTokenizer):
             text = self._decode_with_timestamps(
                 filtered_ids, time_precision=time_precision, skip_special_tokens=skip_special_tokens
             )
+        else:
+            text = self._filter_timestamp_ids(text)
+
         # retrieve offsets
         if output_offsets:
             offsets = self._compute_offsets(token_ids, time_precision=time_precision)
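For context, a short usage sketch of the public `decode` API these hunks touch. It assumes a checkpoint such as `openai/whisper-tiny` whose tokenizer already contains the timestamp tokens (the "after adding timestamps" in the commit title); the hand-built id sequence below stands in for real `model.generate()` output:

from transformers import WhisperTokenizer

tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny")

# Stand-in for generated ids: <|0.00|>, the text " Hello", then <|1.00|>.
ids = (
    [tokenizer.convert_tokens_to_ids("<|0.00|>")]
    + tokenizer(" Hello", add_special_tokens=False).input_ids
    + [tokenizer.convert_tokens_to_ids("<|1.00|>")]
)

# Default decode: timestamp markers are stripped (now via the regex path).
print(tokenizer.decode(ids))

# Keep the timestamp markers in the decoded text.
print(tokenizer.decode(ids, decode_with_timestamps=True))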
src/transformers/models/whisper/tokenization_whisper_fast.py
@@ -15,6 +15,7 @@
 """Tokenization classes for Whisper."""
 import json
 import os
+import re
 from functools import lru_cache
 from typing import List, Optional, Tuple
@@ -190,6 +191,7 @@ class WhisperTokenizerFast(PreTrainedTokenizerFast):
             self.english_spelling_normalizer = None
 
         self.add_prefix_space = add_prefix_space
+        self.timestamp_pat = re.compile(r"<\|(\d+\.\d+)\|>")
 
         self.language = language
         self.task = task
@@ -269,10 +271,12 @@ class WhisperTokenizerFast(PreTrainedTokenizerFast):
                 start_timestamp_position = sliced_tokens[0].item() - timestamp_begin
                 end_timestamp_position = sliced_tokens[-1].item() - timestamp_begin
                 # strip timestamp tokens from the text output
-                sliced_tokens = self._preprocess_token_ids(sliced_tokens, decode_with_timestamps=False)
+                sliced_tokens = self._preprocess_token_ids(sliced_tokens)
+                text = self._decode(sliced_tokens)
+                text = self._filter_timestamp_ids(text)
                 offsets.append(
                     {
-                        "text": self._decode(sliced_tokens),
+                        "text": text,
                         "timestamp": (
                             start_timestamp_position * time_precision,
                             end_timestamp_position * time_precision,
@@ -296,9 +300,7 @@ class WhisperTokenizerFast(PreTrainedTokenizerFast):
         return self.convert_tokens_to_ids([("<|%.2f|>" % (i * time_precision)) for i in range(1500 + 1)])
 
     # Copied from transformers.models.whisper.tokenization_whisper.WhisperTokenizer._preprocess_token_ids
-    def _preprocess_token_ids(
-        self, token_ids, skip_special_tokens: bool = False, decode_with_timestamps: bool = False, time_precision=0.02
-    ):
+    def _preprocess_token_ids(self, token_ids, skip_special_tokens: bool = False):
         """
         Pre-process the token ids for decoding by removing the prompt tokens ids and timestamp token ids.
@@ -308,24 +310,18 @@ class WhisperTokenizerFast(PreTrainedTokenizerFast):
             skip_special_tokens (`bool`, *optional*, defaults to `False`):
                 Whether or not to remove special tokens from the token ids. If `True`, the prompt token ids will be
                 removed.
-            decode_with_timestamps (`bool`, *optional*, defaults to `False`):
-                Whether or not to decode with timestamps included in the raw text. If `False`, timestamps will be
-                filtered out from the token ids.
-            time_precision (`float`, `optional`, defaults to 0.02):
-                The time ratio to convert from token to time.
         """
         if skip_special_tokens:
             prompt_token_id = self.convert_tokens_to_ids("<|startofprev|>")
             decoder_start_token_id = self.convert_tokens_to_ids("<|startoftranscript|>")
             token_ids = self._strip_prompt(token_ids, prompt_token_id, decoder_start_token_id)
 
-        if not decode_with_timestamps:
-            # filter timestamp tokens if they are contained in the vocab
-            timestamp_ids = self.timestamp_ids(time_precision=time_precision)
-            token_ids = [token for token in token_ids if token not in timestamp_ids]
-
         return token_ids
 
+    # Copied from transformers.models.whisper.tokenization_whisper.WhisperTokenizer._filter_timestamp_ids
+    def _filter_timestamp_ids(self, token_ids):
+        return re.sub(self.timestamp_pat, "", token_ids)
+
     # Copied from transformers.models.whisper.tokenization_whisper.WhisperTokenizer.decode
     def decode(
         self,
@@ -356,6 +352,8 @@ class WhisperTokenizerFast(PreTrainedTokenizerFast):
             output_offsets (`bool`, *optional*, defaults to `False`):
                 Whether or not to output the offsets of the tokens. This should only be set if the model predicted
                 timestamps.
+            time_precision (`float`, `optional`, defaults to 0.02):
+                The time ratio to convert from token to time.
             decode_with_timestamps (`bool`, *optional*, defaults to `False`):
                 Whether or not to decode with timestamps included in the raw text.
         Returns:
@@ -364,8 +362,6 @@ class WhisperTokenizerFast(PreTrainedTokenizerFast):
         filtered_ids = self._preprocess_token_ids(
             token_ids,
             skip_special_tokens=skip_special_tokens,
-            decode_with_timestamps=decode_with_timestamps,
-            time_precision=time_precision,
         )
         text = super().decode(
@@ -380,6 +376,9 @@ class WhisperTokenizerFast(PreTrainedTokenizerFast):
             text = self._decode_with_timestamps(
                 filtered_ids, time_precision=time_precision, skip_special_tokens=skip_special_tokens
             )
+        else:
+            text = self._filter_timestamp_ids(text)
+
         # retrieve offsets
         if output_offsets:
             offsets = self._compute_offsets(token_ids, time_precision=time_precision)
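Why the new path is faster: the old filter ran a `not in` check against a plain Python list of 1501 timestamp ids for each generated token, while the new path is a single regex pass over the already-decoded string. A rough micro-benchmark sketch of that difference (the id range, input sizes, and any resulting timings are illustrative, not taken from the PR):

import re
import timeit

timestamp_pat = re.compile(r"<\|(\d+\.\d+)\|>")
timestamp_ids = list(range(50365, 50365 + 1501))  # illustrative id range
token_ids = list(range(50000, 52000))             # illustrative "generated" ids
text = "<|0.00|> hello there.<|1.00|>" * 200      # illustrative decoded text

old = timeit.timeit(lambda: [t for t in token_ids if t not in timestamp_ids], number=100)
new = timeit.timeit(lambda: re.sub(timestamp_pat, "", text), number=100)
print(f"list-membership filter: {old:.4f}s   regex substitution: {new:.4f}s")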