chenpangpang/transformers · Commit 211f93aa (Unverified)
Authored Sep 28, 2023 by Sanchit Gandhi, committed by GitHub on Sep 28, 2023
[Whisper Tokenizer] Make decoding faster after adding timestamps (#26299)
make decoding faster
Parent: 4e931a8e

Showing 2 changed files with 30 additions and 34 deletions:

src/transformers/models/whisper/tokenization_whisper.py (+14, -17)
src/transformers/models/whisper/tokenization_whisper_fast.py (+16, -17)
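What the change does: previously, decode stripped timestamps at the token-id level. _preprocess_token_ids called timestamp_ids(), which materializes all 1501 timestamp token ids, and then filtered with a list comprehension whose `not in` test scans that list once per token. After this commit, decoding happens first and the rendered <|...|> markers are removed from the output string in a single pass with a pre-compiled regex. A minimal standalone sketch of the two strategies (toy inputs, not the library code itself):

import re

# Timestamp tokens render as "<|0.00|>", "<|0.02|>", ..., "<|30.00|>" in decoded text.
timestamp_pat = re.compile(r"<\|(\d+\.\d+)\|>")

def strip_timestamps_old(token_ids, timestamp_ids):
    # Old approach: per-token membership test against a 1501-element list.
    return [t for t in token_ids if t not in timestamp_ids]

def strip_timestamps_new(text):
    # New approach: one regex substitution over the already-decoded string.
    return re.sub(timestamp_pat, "", text)

print(strip_timestamps_new("<|0.00|> Hey, what's up?<|2.60|>"))  # " Hey, what's up?"

At roughly 1501 comparisons per token, the old path dominated decoding time once timestamps were in the output; the regex path scales with the length of the decoded text instead.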
src/transformers/models/whisper/tokenization_whisper.py
@@ -314,6 +314,7 @@ class WhisperTokenizer(PreTrainedTokenizer):
         # Should have added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions
         self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
+        self.timestamp_pat = re.compile(r"<\|(\d+\.\d+)\|>")
 
         self.language = language
         super().__init__(
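The pattern registered here only matches rendered timestamp markers, which always contain a decimal point (they are produced as "<|%.2f|>" % (i * time_precision), see timestamp_ids further down), so other special tokens such as <|notimestamps|> are not touched. Illustration:

import re

timestamp_pat = re.compile(r"<\|(\d+\.\d+)\|>")
print(timestamp_pat.findall("<|0.00|> Hello world.<|4.32|>"))          # ['0.00', '4.32']
print(timestamp_pat.sub("", "<|notimestamps|><|0.00|> Hello world."))  # '<|notimestamps|> Hello world.'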
@@ -560,10 +561,12 @@ class WhisperTokenizer(PreTrainedTokenizer):
             start_timestamp_position = sliced_tokens[0].item() - timestamp_begin
             end_timestamp_position = sliced_tokens[-1].item() - timestamp_begin
             # strip timestamp tokens from the text output
-            sliced_tokens = self._preprocess_token_ids(sliced_tokens, decode_with_timestamps=False)
+            sliced_tokens = self._preprocess_token_ids(sliced_tokens)
+            text = self._decode(sliced_tokens)
+            text = self._filter_timestamp_ids(text)
             offsets.append(
                 {
-                    "text": self._decode(sliced_tokens),
+                    "text": text,
                     "timestamp": (
                         start_timestamp_position * time_precision,
                         end_timestamp_position * time_precision,
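Each timestamp-delimited slice is now decoded once into text and the markers are stripped from that string, so the offset dictionaries keep their previous shape. A sketch of the resulting behaviour, assuming the openai/whisper-tiny checkpoint is available and building the token sequence by hand instead of generating it:

from transformers import WhisperTokenizer

tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny")

# First id after the special tokens is <|0.00|>; 50 steps later is <|1.00|> at 0.02 s per step.
timestamp_begin = tokenizer.all_special_ids[-1] + 1
ids = [timestamp_begin] + tokenizer(" Hello world.", add_special_tokens=False).input_ids + [timestamp_begin + 50]

out = tokenizer.decode(ids, output_offsets=True)
print(out["offsets"])  # expected: [{'text': ' Hello world.', 'timestamp': (0.0, 1.0)}]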
@@ -585,9 +588,7 @@ class WhisperTokenizer(PreTrainedTokenizer):
         """
         return self.convert_tokens_to_ids([("<|%.2f|>" % (i * time_precision)) for i in range(1500 + 1)])
 
-    def _preprocess_token_ids(
-        self, token_ids, skip_special_tokens: bool = False, decode_with_timestamps: bool = False, time_precision=0.02
-    ):
+    def _preprocess_token_ids(self, token_ids, skip_special_tokens: bool = False):
         """
         Pre-process the token ids for decoding by removing the prompt tokens ids and timestamp token ids.
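The timestamp_ids context above also pins down the constants: with the default time_precision=0.02, timestamp token i represents i * 0.02 seconds, and range(1500 + 1) covers Whisper's full 30-second window (1500 * 0.02 s = 30 s). For example:

time_precision = 0.02
print("<|%.2f|>" % (130 * time_precision))  # '<|2.60|>', the marker for 2.6 seconds
print(1500 * time_precision)                # 30.0, one full Whisper chunk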
@@ -597,24 +598,17 @@ class WhisperTokenizer(PreTrainedTokenizer):
             skip_special_tokens (`bool`, *optional*, defaults to `False`):
                 Whether or not to remove special tokens from the token ids. If `True`, the prompt token ids will be
                 removed.
-            decode_with_timestamps (`bool`, *optional*, defaults to `False`):
-                Whether or not to decode with timestamps included in the raw text. If `False`, timestamps will be
-                filtered out from the token ids.
-            time_precision (`float`, `optional`, defaults to 0.02):
-                The time ratio to convert from token to time.
         """
         if skip_special_tokens:
             prompt_token_id = self.convert_tokens_to_ids("<|startofprev|>")
             decoder_start_token_id = self.convert_tokens_to_ids("<|startoftranscript|>")
             token_ids = self._strip_prompt(token_ids, prompt_token_id, decoder_start_token_id)
 
-        if not decode_with_timestamps:
-            # filter timestamp tokens if they are contained in the vocab
-            timestamp_ids = self.timestamp_ids(time_precision=time_precision)
-            token_ids = [token for token in token_ids if token not in timestamp_ids]
-
         return token_ids
 
+    def _filter_timestamp_ids(self, token_ids):
+        return re.sub(self.timestamp_pat, "", token_ids)
+
     def decode(
         self,
         token_ids,
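The deleted branch is the hot spot this PR targets: timestamp_ids() rebuilds a 1501-element list and the comprehension's `not in` test makes the filter O(n_tokens × 1501) on every decode call. The new _filter_timestamp_ids is one linear regex pass over the decoded string. A rough standalone timing sketch (toy inputs and a hard-coded timestamp_begin; numbers vary by machine and are not the PR's measurements):

import re
import timeit

timestamp_begin = 50364                                     # <|0.00|> for multilingual Whisper
timestamp_ids = [timestamp_begin + i for i in range(1501)]  # what timestamp_ids() returns
token_ids = list(range(200)) + [timestamp_begin, timestamp_begin + 130]

def old_way():
    # Per-token linear scan of a 1501-element list.
    return [t for t in token_ids if t not in timestamp_ids]

timestamp_pat = re.compile(r"<\|(\d+\.\d+)\|>")
text = "<|0.00|> some decoded text<|2.60|>" * 10

def new_way():
    # One pass over the rendered string.
    return re.sub(timestamp_pat, "", text)

print("id filtering:", timeit.timeit(old_way, number=1_000))
print("regex sub:   ", timeit.timeit(new_way, number=1_000))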
@@ -644,6 +638,8 @@ class WhisperTokenizer(PreTrainedTokenizer):
             output_offsets (`bool`, *optional*, defaults to `False`):
                 Whether or not to output the offsets of the tokens. This should only be set if the model predicted
                 timestamps.
+            time_precision (`float`, `optional`, defaults to 0.02):
+                The time ratio to convert from token to time.
             decode_with_timestamps (`bool`, *optional*, defaults to `False`):
                 Whether or not to decode with timestamps included in the raw text.
         Returns:
@@ -652,8 +648,6 @@ class WhisperTokenizer(PreTrainedTokenizer):
         filtered_ids = self._preprocess_token_ids(
             token_ids,
             skip_special_tokens=skip_special_tokens,
-            decode_with_timestamps=decode_with_timestamps,
-            time_precision=time_precision,
         )
 
         text = super().decode(
@@ -668,6 +662,9 @@ class WhisperTokenizer(PreTrainedTokenizer):
             text = self._decode_with_timestamps(
                 filtered_ids, time_precision=time_precision, skip_special_tokens=skip_special_tokens
             )
+        else:
+            text = self._filter_timestamp_ids(text)
+
         # retrieve offsets
         if output_offsets:
             offsets = self._compute_offsets(token_ids, time_precision=time_precision)
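The public behaviour of decode is intended to be unchanged by the refactor: markers are stripped by default and kept with decode_with_timestamps=True. A quick sanity check, under the same checkpoint assumption as above:

from transformers import WhisperTokenizer

tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny")
timestamp_begin = tokenizer.all_special_ids[-1] + 1
ids = [timestamp_begin] + tokenizer(" Hello world.", add_special_tokens=False).input_ids + [timestamp_begin + 50]

print(tokenizer.decode(ids))                               # expected: ' Hello world.'
print(tokenizer.decode(ids, decode_with_timestamps=True))  # expected: '<|0.00|> Hello world.<|1.00|>'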
src/transformers/models/whisper/tokenization_whisper_fast.py
@@ -15,6 +15,7 @@
 """Tokenization classes for Whisper."""
 import json
 import os
+import re
 from functools import lru_cache
 from typing import List, Optional, Tuple
@@ -190,6 +191,7 @@ class WhisperTokenizerFast(PreTrainedTokenizerFast):
         self.english_spelling_normalizer = None
 
         self.add_prefix_space = add_prefix_space
+        self.timestamp_pat = re.compile(r"<\|(\d+\.\d+)\|>")
 
         self.language = language
         self.task = task
@@ -269,10 +271,12 @@ class WhisperTokenizerFast(PreTrainedTokenizerFast):
             start_timestamp_position = sliced_tokens[0].item() - timestamp_begin
             end_timestamp_position = sliced_tokens[-1].item() - timestamp_begin
             # strip timestamp tokens from the text output
-            sliced_tokens = self._preprocess_token_ids(sliced_tokens, decode_with_timestamps=False)
+            sliced_tokens = self._preprocess_token_ids(sliced_tokens)
+            text = self._decode(sliced_tokens)
+            text = self._filter_timestamp_ids(text)
             offsets.append(
                 {
-                    "text": self._decode(sliced_tokens),
+                    "text": text,
                     "timestamp": (
                         start_timestamp_position * time_precision,
                         end_timestamp_position * time_precision,
@@ -296,9 +300,7 @@ class WhisperTokenizerFast(PreTrainedTokenizerFast):
         return self.convert_tokens_to_ids([("<|%.2f|>" % (i * time_precision)) for i in range(1500 + 1)])
 
     # Copied from transformers.models.whisper.tokenization_whisper.WhisperTokenizer._preprocess_token_ids
-    def _preprocess_token_ids(
-        self, token_ids, skip_special_tokens: bool = False, decode_with_timestamps: bool = False, time_precision=0.02
-    ):
+    def _preprocess_token_ids(self, token_ids, skip_special_tokens: bool = False):
         """
         Pre-process the token ids for decoding by removing the prompt tokens ids and timestamp token ids.
@@ -308,24 +310,18 @@ class WhisperTokenizerFast(PreTrainedTokenizerFast):
             skip_special_tokens (`bool`, *optional*, defaults to `False`):
                 Whether or not to remove special tokens from the token ids. If `True`, the prompt token ids will be
                 removed.
-            decode_with_timestamps (`bool`, *optional*, defaults to `False`):
-                Whether or not to decode with timestamps included in the raw text. If `False`, timestamps will be
-                filtered out from the token ids.
-            time_precision (`float`, `optional`, defaults to 0.02):
-                The time ratio to convert from token to time.
         """
         if skip_special_tokens:
             prompt_token_id = self.convert_tokens_to_ids("<|startofprev|>")
             decoder_start_token_id = self.convert_tokens_to_ids("<|startoftranscript|>")
             token_ids = self._strip_prompt(token_ids, prompt_token_id, decoder_start_token_id)
 
-        if not decode_with_timestamps:
-            # filter timestamp tokens if they are contained in the vocab
-            timestamp_ids = self.timestamp_ids(time_precision=time_precision)
-            token_ids = [token for token in token_ids if token not in timestamp_ids]
-
         return token_ids
 
+    # Copied from transformers.models.whisper.tokenization_whisper.WhisperTokenizer._filter_timestamp_ids
+    def _filter_timestamp_ids(self, token_ids):
+        return re.sub(self.timestamp_pat, "", token_ids)
+
     # Copied from transformers.models.whisper.tokenization_whisper.WhisperTokenizer.decode
     def decode(
         self,
@@ -356,6 +352,8 @@ class WhisperTokenizerFast(PreTrainedTokenizerFast):
             output_offsets (`bool`, *optional*, defaults to `False`):
                 Whether or not to output the offsets of the tokens. This should only be set if the model predicted
                 timestamps.
+            time_precision (`float`, `optional`, defaults to 0.02):
+                The time ratio to convert from token to time.
             decode_with_timestamps (`bool`, *optional*, defaults to `False`):
                 Whether or not to decode with timestamps included in the raw text.
         Returns:
@@ -364,8 +362,6 @@ class WhisperTokenizerFast(PreTrainedTokenizerFast):
         filtered_ids = self._preprocess_token_ids(
             token_ids,
             skip_special_tokens=skip_special_tokens,
-            decode_with_timestamps=decode_with_timestamps,
-            time_precision=time_precision,
         )
 
         text = super().decode(
@@ -380,6 +376,9 @@ class WhisperTokenizerFast(PreTrainedTokenizerFast):
             text = self._decode_with_timestamps(
                 filtered_ids, time_precision=time_precision, skip_special_tokens=skip_special_tokens
             )
+        else:
+            text = self._filter_timestamp_ids(text)
+
         # retrieve offsets
         if output_offsets:
             offsets = self._compute_offsets(token_ids, time_precision=time_precision)
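The fast tokenizer receives the mirror-image changes, kept in sync with the slow class through the "# Copied from" markers that the repository's `make fix-copies` tooling enforces. A parity sketch under the same assumptions as the earlier examples:

from transformers import WhisperTokenizer, WhisperTokenizerFast

slow = WhisperTokenizer.from_pretrained("openai/whisper-tiny")
fast = WhisperTokenizerFast.from_pretrained("openai/whisper-tiny")

timestamp_begin = slow.all_special_ids[-1] + 1
ids = [timestamp_begin] + slow(" Hello world.", add_special_tokens=False).input_ids + [timestamp_begin + 50]

for kwargs in ({}, {"decode_with_timestamps": True}, {"output_offsets": True}):
    assert slow.decode(ids, **kwargs) == fast.decode(ids, **kwargs), kwargs
print("slow and fast decodes agree")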