Unverified Commit a5c642fe authored by Joao Gante's avatar Joao Gante Committed by GitHub
Browse files

Whisper: move to tensor cpu before converting to np array at decode time (#31954)

parent df1c248a
......@@ -872,7 +872,10 @@ class WhisperTokenizer(PreTrainedTokenizer):
@staticmethod
def _convert_to_list(token_ids):
# convert type to ndarray if necessary
if "torch" in str(type(token_ids)) or "tensorflow" in str(type(token_ids)) and hasattr(token_ids, "numpy"):
if hasattr(token_ids, "numpy"):
if "torch" in str(type(token_ids)):
token_ids = token_ids.cpu().numpy()
elif "tensorflow" in str(type(token_ids)):
token_ids = token_ids.numpy()
# now the token ids are either a numpy array, or a list of lists
if isinstance(token_ids, np.ndarray):
......
......@@ -605,7 +605,10 @@ class WhisperTokenizerFast(PreTrainedTokenizerFast):
# Copied from transformers.models.whisper.tokenization_whisper.WhisperTokenizer._convert_to_list
def _convert_to_list(token_ids):
# convert type to ndarray if necessary
if "torch" in str(type(token_ids)) or "tensorflow" in str(type(token_ids)) and hasattr(token_ids, "numpy"):
if hasattr(token_ids, "numpy"):
if "torch" in str(type(token_ids)):
token_ids = token_ids.cpu().numpy()
elif "tensorflow" in str(type(token_ids)):
token_ids = token_ids.numpy()
# now the token ids are either a numpy array, or a list of lists
if isinstance(token_ids, np.ndarray):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment