"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "7bc6d76396f4a603161539aefaa6207d61260f60"
Unverified Commit 619ecfe2 authored by Sanchit Gandhi's avatar Sanchit Gandhi Committed by GitHub
Browse files

[Whisper Tok] Move token ids to CPU when computing offsets (#28485)

* move token ids to cpu

* check for torch attr
parent 0eaa5ea3
...@@ -553,6 +553,9 @@ class WhisperTokenizer(PreTrainedTokenizer): ...@@ -553,6 +553,9 @@ class WhisperTokenizer(PreTrainedTokenizer):
The time ratio to convert from token to time. The time ratio to convert from token to time.
""" """
offsets = [] offsets = []
# ensure torch tensor of token ids is placed on cpu
if "torch" in str(type(token_ids)) and (hasattr(token_ids, "cpu") and callable(token_ids.cpu)):
token_ids = token_ids.cpu()
token_ids = np.array(token_ids) token_ids = np.array(token_ids)
if token_ids.shape[0] > 1 and len(token_ids.shape) > 1: if token_ids.shape[0] > 1 and len(token_ids.shape) > 1:
raise ValueError("Can only process a single input at a time") raise ValueError("Can only process a single input at a time")
......
...@@ -248,6 +248,9 @@ class WhisperTokenizerFast(PreTrainedTokenizerFast): ...@@ -248,6 +248,9 @@ class WhisperTokenizerFast(PreTrainedTokenizerFast):
The time ratio to convert from token to time. The time ratio to convert from token to time.
""" """
offsets = [] offsets = []
# ensure torch tensor of token ids is placed on cpu
if "torch" in str(type(token_ids)) and (hasattr(token_ids, "cpu") and callable(token_ids.cpu)):
token_ids = token_ids.cpu()
token_ids = np.array(token_ids) token_ids = np.array(token_ids)
if token_ids.shape[0] > 1 and len(token_ids.shape) > 1: if token_ids.shape[0] > 1 and len(token_ids.shape) > 1:
raise ValueError("Can only process a single input at a time") raise ValueError("Can only process a single input at a time")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment