Unverified commit 0948c827, authored by Sanchit Gandhi and committed by GitHub
Browse files

[Whisper] Strip prompt before finding common subsequence (#27836)

parent b1065aa0
...@@ -897,11 +897,15 @@ def _decode_asr(tokenizer, model_outputs, *, return_timestamps, return_language, ...@@ -897,11 +897,15 @@ def _decode_asr(tokenizer, model_outputs, *, return_timestamps, return_language,
right_stride_start = None right_stride_start = None
all_special_ids = set(tokenizer.all_special_ids) all_special_ids = set(tokenizer.all_special_ids)
prompt_token_id = tokenizer.convert_tokens_to_ids("<|startofprev|>")
decoder_start_token_id = tokenizer.convert_tokens_to_ids("<|startoftranscript|>")
# - iterate over all outputs # - iterate over all outputs
for chunk_id, output in enumerate(model_outputs): for chunk_id, output in enumerate(model_outputs):
# We can drop everything to Python list, it's going to make # We can drop everything to Python list, it's going to make
# our lives easier # our lives easier
token_ids = output["tokens"][0].tolist() token_ids = output["tokens"][0].tolist()
# (possibly) remove the prompt from the token ids
token_ids = tokenizer._strip_prompt(token_ids, prompt_token_id, decoder_start_token_id)
if return_timestamps == "word": if return_timestamps == "word":
token_timestamps = output["token_timestamps"][0].tolist() token_timestamps = output["token_timestamps"][0].tolist()
......
...@@ -1343,6 +1343,42 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase): ...@@ -1343,6 +1343,42 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
self.assertEqual(output, [{"text": ANY(str)}]) self.assertEqual(output, [{"text": ANY(str)}])
self.assertEqual(output[0]["text"][:6], "<s> <s") self.assertEqual(output[0]["text"][:6], "<s> <s")
@require_torch
@slow
def test_whisper_prompted(self):
    """Check that prompt ids steer chunked Whisper transcription in the ASR pipeline.

    The same long-form audio sample is transcribed twice through a chunked,
    batched pipeline: once without a prompt and once with ``prompt_ids``
    built from a deliberately misspelled name. Both transcriptions are
    compared against recorded reference strings, so the prompted run must
    carry the misspelling into the text without the prompt itself leaking
    into the output.
    """
    # Local import: `@require_torch` guarantees torch is installed when the
    # body runs, and this keeps the block independent of module-level imports.
    import torch

    # The model and the pipeline are pinned to CUDA below. Skip cleanly on
    # CPU-only machines instead of erroring out inside `model.to("cuda")`.
    if not torch.cuda.is_available():
        self.skipTest("test_whisper_prompted requires a CUDA device")

    processor = AutoProcessor.from_pretrained("openai/whisper-tiny")
    model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
    model = model.to("cuda")

    pipe = pipeline(
        "automatic-speech-recognition",
        model=model,
        tokenizer=processor.tokenizer,
        feature_extractor=processor.feature_extractor,
        max_new_tokens=128,
        chunk_length_s=30,
        batch_size=16,
        device="cuda:0",
    )

    dataset = load_dataset("distil-whisper/librispeech_long", "clean", split="validation")
    sample = dataset[0]["audio"]

    # prompt the model to misspell "Mr Quilter" as "Mr Quillter"
    whisper_prompt = "Mr. Quillter."
    prompt_ids = pipe.tokenizer.get_prompt_ids(whisper_prompt, return_tensors="pt")

    # The first call receives a copy so that `sample` is passed untouched to
    # the second call (presumably the pipeline mutates its input dict -- confirm).
    unprompted_result = pipe(sample.copy())["text"]
    prompted_result = pipe(sample, generate_kwargs={"prompt_ids": prompt_ids})["text"]

    # fmt: off
    EXPECTED_UNPROMPTED_RESULT = " Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel. Nor is Mr. Quilter's manner less interesting than his matter. He tells us that at this festive season of the year, with Christmas and roast beef looming before us, similarly drawn from eating and its results occur most readily to the mind. He has grave doubts whether Sir Frederick Latins work is really Greek after all and can discover in it but little of rocky Ithaca. Lennils, pictures are a sort of upguards and atom paintings and Mason's exquisite itals are as national as a jingo poem. Mr. Birkut Foster's landscapes smile at one much in the same way that Mr. Carker used to flash his teeth. And Mr. John Collier gives his sitter a cheerful slap on the back before he says like a shampoo or a Turkish bath. Next man"
    EXPECTED_PROMPTED_RESULT = " Mr. Quillter is the apostle of the middle classes, and we are glad to welcome his gospel. Nor is Mr. Quillter's manner less interesting than his matter. He tells us that at this festive season of the year, with Christmas and roast beef looming before us, similarly drawn from eating and its results occur most readily to the mind. He has grave doubts whether Sir Frederick Latins work is really great after all, and can discover in it but little of rocky Ithaca. Lennils, pictures are a sort of upguards and atom paintings, and Mason's exquisite itals are as national as a jingo poem. Mr. Birkut Foster's landscapes smile at one much in the same way that Mr. Carker used to flash his teeth. Mr. John Collier gives his sitter a cheerful slap on the back before he says like a shampoo or a Turkish bath. Next man."
    # fmt: on

    self.assertEqual(unprompted_result, EXPECTED_UNPROMPTED_RESULT)
    self.assertEqual(prompted_result, EXPECTED_PROMPTED_RESULT)
@require_torch @require_torch
@slow @slow
def test_whisper_longform(self): def test_whisper_longform(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment