Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
e7e6d181
Unverified
Commit
e7e6d181
authored
Dec 05, 2022
by
Sanchit Gandhi
Committed by
GitHub
Dec 05, 2022
Browse files
[Whisper] Move decoder id method to tokenizer (#20589)
parent
9ffbed26
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
10 additions
and
32 deletions
+10
-32
src/transformers/models/whisper/processing_whisper.py
src/transformers/models/whisper/processing_whisper.py
+1
-31
src/transformers/models/whisper/tokenization_whisper.py
src/transformers/models/whisper/tokenization_whisper.py
+9
-1
No files found.
src/transformers/models/whisper/processing_whisper.py
View file @
e7e6d181
...
...
@@ -42,37 +42,7 @@ class WhisperProcessor(ProcessorMixin):
self
.
_in_target_context_manager
=
False
def get_decoder_prompt_ids(self, task=None, language=None, no_timestamps=True):
    """Build the forced decoder ids used to prime Whisper generation.

    This is a thin delegation: the actual token lookup and validation
    live on the tokenizer (this commit moved the logic there so it is
    shared with tokenizer-only users).

    Args:
        task: Optional task token name, e.g. ``"transcribe"`` or
            ``"translate"``. Validated by the tokenizer.
        language: Optional language token name, e.g. ``"en"``. Validated
            by the tokenizer.
        no_timestamps: When ``True``, timestamp prediction is suppressed
            (the ``<|notimestamps|>`` token is forced).

    Returns:
        Whatever ``tokenizer.get_decoder_prompt_ids`` returns — the
        decoder prompt ids derived from the requested task/language.
    """
    return self.tokenizer.get_decoder_prompt_ids(
        task=task, language=language, no_timestamps=no_timestamps
    )
def
__call__
(
self
,
*
args
,
**
kwargs
):
"""
...
...
src/transformers/models/whisper/tokenization_whisper.py
View file @
e7e6d181
...
...
@@ -399,9 +399,13 @@ class WhisperTokenizer(PreTrainedTokenizer):
self
.
language
=
self
.
language
.
lower
()
if
self
.
language
in
TO_LANGUAGE_CODE
:
language_id
=
TO_LANGUAGE_CODE
[
self
.
language
]
elif
self
.
language
in
TO_LANGUAGE_CODE
.
values
():
language_id
=
self
.
language
else
:
is_language_code
=
len
(
self
.
language
)
==
2
raise
ValueError
(
f
"Unsupported language:
{
self
.
language
}
. Language should be in:
{
TO_LANGUAGE_CODE
.
keys
()
}
"
f
"Unsupported language:
{
self
.
language
}
. Language should be one of:"
f
"
{
list
(
TO_LANGUAGE_CODE
.
values
())
if
is_language_code
else
list
(
TO_LANGUAGE_CODE
.
keys
())
}
."
)
if
self
.
task
is
not
None
:
...
...
@@ -577,3 +581,7 @@ class WhisperTokenizer(PreTrainedTokenizer):
if
len
(
input_ids
)
>
self
.
model_max_length
:
input_ids
=
input_ids
[
-
self
.
model_max_length
:]
return
input_ids
def get_decoder_prompt_ids(self, task=None, language=None, no_timestamps=True):
    """Return the decoder prefix token ids for the given task/language.

    Args:
        task: Optional task token name (``"transcribe"`` or
            ``"translate"``). ``None`` leaves the current task unchanged.
        language: Optional language name or code. ``None`` leaves the
            current language unchanged.
        no_timestamps: When ``True`` (the default), timestamp prediction
            is disabled, i.e. the ``<|notimestamps|>`` token is part of
            the prefix.

    Returns:
        ``self.prefix_tokens`` after updating the tokenizer's prefix
        state via ``set_prefix_tokens``.
    """
    # BUG FIX: `no_timestamps=True` means timestamps must NOT be
    # predicted, so the flag has to be inverted before forwarding —
    # `set_prefix_tokens` takes the positive `predict_timestamps` flag.
    self.set_prefix_tokens(task=task, language=language, predict_timestamps=not no_timestamps)
    return self.prefix_tokens
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment