Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
a5c642fe
Unverified
Commit
a5c642fe
authored
Jul 14, 2024
by
Joao Gante
Committed by
GitHub
Jul 14, 2024
Browse files
Whisper: move tensor to CPU before converting to np array at decode time (#31954)
parent
df1c248a
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
10 additions
and
4 deletions
+10
-4
src/transformers/models/whisper/tokenization_whisper.py
src/transformers/models/whisper/tokenization_whisper.py
+5
-2
src/transformers/models/whisper/tokenization_whisper_fast.py
src/transformers/models/whisper/tokenization_whisper_fast.py
+5
-2
No files found.
src/transformers/models/whisper/tokenization_whisper.py
View file @
a5c642fe
...
...
@@ -872,8 +872,11 @@ class WhisperTokenizer(PreTrainedTokenizer):
@
staticmethod
def
_convert_to_list
(
token_ids
):
# convert type to ndarray if necessary
if
"torch"
in
str
(
type
(
token_ids
))
or
"tensorflow"
in
str
(
type
(
token_ids
))
and
hasattr
(
token_ids
,
"numpy"
):
token_ids
=
token_ids
.
numpy
()
if
hasattr
(
token_ids
,
"numpy"
):
if
"torch"
in
str
(
type
(
token_ids
)):
token_ids
=
token_ids
.
cpu
().
numpy
()
elif
"tensorflow"
in
str
(
type
(
token_ids
)):
token_ids
=
token_ids
.
numpy
()
# now the token ids are either a numpy array, or a list of lists
if
isinstance
(
token_ids
,
np
.
ndarray
):
token_ids
=
token_ids
.
tolist
()
...
...
src/transformers/models/whisper/tokenization_whisper_fast.py
View file @
a5c642fe
...
...
@@ -605,8 +605,11 @@ class WhisperTokenizerFast(PreTrainedTokenizerFast):
# Copied from transformers.models.whisper.tokenization_whisper.WhisperTokenizer._convert_to_list
def
_convert_to_list
(
token_ids
):
# convert type to ndarray if necessary
if
"torch"
in
str
(
type
(
token_ids
))
or
"tensorflow"
in
str
(
type
(
token_ids
))
and
hasattr
(
token_ids
,
"numpy"
):
token_ids
=
token_ids
.
numpy
()
if
hasattr
(
token_ids
,
"numpy"
):
if
"torch"
in
str
(
type
(
token_ids
)):
token_ids
=
token_ids
.
cpu
().
numpy
()
elif
"tensorflow"
in
str
(
type
(
token_ids
)):
token_ids
=
token_ids
.
numpy
()
# now the token ids are either a numpy array, or a list of lists
if
isinstance
(
token_ids
,
np
.
ndarray
):
token_ids
=
token_ids
.
tolist
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment