"vscode:/vscode.git/clone" did not exist on "eb1493b15db2019c93e365219b517fb44e313aaf"
Unverified commit 7adce8b5, authored by Connor Henderson and committed by GitHub
Browse files

fix: Replace `add_prefix_space` in `get_prompt_ids` with manual space for...

fix: Replace `add_prefix_space` in `get_prompt_ids` with manual space for FastTokenizer compatibility (#23796)

* add ' ' replacement for add_prefix_space

* add fast tokenizer test
parent 84bac652
......@@ -721,7 +721,7 @@ class WhisperTokenizer(PreTrainedTokenizer):
def get_prompt_ids(self, text: str, return_tensors="np"):
"""Converts prompt text to IDs that can be passed to [`~WhisperForConditionalGeneration.generate`]."""
batch_encoding = self("<|startofprev|>", text.strip(), add_prefix_space=True, add_special_tokens=False)
batch_encoding = self("<|startofprev|>", " " + text.strip(), add_special_tokens=False)
# Check for special tokens
prompt_text_ids = batch_encoding["input_ids"][1:]
......
......@@ -494,7 +494,7 @@ class WhisperTokenizerFast(PreTrainedTokenizerFast):
# Copied from transformers.models.whisper.tokenization_whisper.WhisperTokenizer.get_prompt_ids
def get_prompt_ids(self, text: str, return_tensors="np"):
"""Converts prompt text to IDs that can be passed to [`~WhisperForConditionalGeneration.generate`]."""
batch_encoding = self("<|startofprev|>", text.strip(), add_prefix_space=True, add_special_tokens=False)
batch_encoding = self("<|startofprev|>", " " + text.strip(), add_special_tokens=False)
# Check for special tokens
prompt_text_ids = batch_encoding["input_ids"][1:]
......
......@@ -213,6 +213,16 @@ class WhisperTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
rust_tokenizer.decode(encoded_input, skip_special_tokens=True), expected_without_special_tokens
)
def test_fast_tokenizer_get_prompt_ids(self):
    """Check that the slow and fast tokenizers produce identical prompt IDs.

    Both tokenizers are asked for the prompt IDs of the same text via
    `get_prompt_ids`, and the results are compared as plain lists.
    """
    slow_tokenizer = self.get_tokenizer()
    fast_tokenizer = self.get_rust_tokenizer()
    prompt_text = "This is test prompt text."

    slow_ids = slow_tokenizer.get_prompt_ids(prompt_text)
    fast_ids = fast_tokenizer.get_prompt_ids(prompt_text)

    # Compare as lists so a mismatch reports the differing elements.
    self.assertListEqual(slow_ids.tolist(), fast_ids.tolist())
class SpeechToTextTokenizerMultilinguialTest(unittest.TestCase):
checkpoint_name = "openai/whisper-small.en"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment