"vscode:/vscode.git/clone" did not exist on "5e65d6b2adf64d1548c3b25ccc52ad55fcfe8044"
Unverified Commit 9e5452ee authored by sangbumlikeagod's avatar sangbumlikeagod Committed by GitHub
Browse files

[Bug][Frontend] Fix structure of transcription's decoder_prompt (#18809)


Signed-off-by: default avatarsangbumlikeagod <oironese@naver.com>
parent 0e3fe896
......@@ -37,7 +37,6 @@ async def test_basic_audio(mary_had_lamb):
model_name = "openai/whisper-large-v3-turbo"
server_args = ["--enforce-eager"]
# Based on https://github.com/openai/openai-cookbook/blob/main/examples/Whisper_prompting_guide.ipynb.
prompt = "THE FIRST WORDS I SPOKE"
with RemoteOpenAIServer(model_name, server_args) as remote_server:
client = remote_server.get_async_client()
transcription = await client.audio.transcriptions.create(
......@@ -48,16 +47,6 @@ async def test_basic_audio(mary_had_lamb):
temperature=0.0)
out = json.loads(transcription)['text']
assert "Mary had a little lamb," in out
# This should "force" whisper to continue prompt in all caps
transcription_wprompt = await client.audio.transcriptions.create(
model=model_name,
file=mary_had_lamb,
language="en",
response_format="text",
prompt=prompt,
temperature=0.0)
out_capital = json.loads(transcription_wprompt)['text']
assert prompt not in out_capital
@pytest.mark.asyncio
......@@ -238,3 +227,31 @@ async def test_sampling_params(mary_had_lamb):
extra_body=dict(seed=42))
assert greedy_transcription.text != transcription.text
@pytest.mark.asyncio
async def test_audio_prompt(mary_had_lamb):
model_name = "openai/whisper-large-v3-turbo"
server_args = ["--enforce-eager"]
prompt = "This is a speech, recorded in a phonograph."
with RemoteOpenAIServer(model_name, server_args) as remote_server:
#Prompts should not omit the part of original prompt while transcribing.
prefix = "The first words I spoke in the original phonograph"
client = remote_server.get_async_client()
transcription = await client.audio.transcriptions.create(
model=model_name,
file=mary_had_lamb,
language="en",
response_format="text",
temperature=0.0)
out = json.loads(transcription)['text']
assert prefix in out
transcription_wprompt = await client.audio.transcriptions.create(
model=model_name,
file=mary_had_lamb,
language="en",
response_format="text",
prompt=prompt,
temperature=0.0)
out_prompt = json.loads(transcription_wprompt)['text']
assert prefix in out_prompt
......@@ -780,8 +780,9 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,
@classmethod
def get_decoder_prompt(cls, language: str, task_type: str,
prompt: str) -> str:
return (f"<|startoftranscript|><|{language}|><|{task_type}|>"
f"<|notimestamps|>{prompt}")
return ((f"<|prev|>{prompt}" if prompt else "") +
f"<|startoftranscript|><|{language}|>" +
f"<|{task_type}|><|notimestamps|>")
@classmethod
def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment