"cacheflow/vscode:/vscode.git/clone" did not exist on "b7955ef17b8d899327b25564f20665ec3ffa71cb"
Unverified Commit 9047288b authored by AllenDou's avatar AllenDou Committed by GitHub
Browse files

support hotwords for FunASR model (#39674)


Signed-off-by: default avatarzixiao <shunli.dsl@alibaba-inc.com>
Co-authored-by: default avatarzixiao <shunli.dsl@alibaba-inc.com>
parent ed6d3037
...@@ -27,7 +27,12 @@ from vllm.assets.audio import AudioAsset ...@@ -27,7 +27,12 @@ from vllm.assets.audio import AudioAsset
def sync_openai( def sync_openai(
audio_path: str, client: OpenAI, model: str, *, repetition_penalty: float = 1.3 audio_path: str,
client: OpenAI,
model: str,
*,
repetition_penalty: float = 1.3,
hotwords: str = None,
): ):
""" """
Perform synchronous transcription using OpenAI-compatible API. Perform synchronous transcription using OpenAI-compatible API.
...@@ -43,12 +48,15 @@ def sync_openai( ...@@ -43,12 +48,15 @@ def sync_openai(
extra_body=dict( extra_body=dict(
seed=4419, seed=4419,
repetition_penalty=repetition_penalty, repetition_penalty=repetition_penalty,
hotwords=hotwords,
), ),
) )
print("transcription result [sync]:", transcription.text) print("transcription result [sync]:", transcription.text)
async def stream_openai_response(audio_path: str, client: AsyncOpenAI, model: str): async def stream_openai_response(
audio_path: str, client: AsyncOpenAI, model: str, hotwords: str = None
):
""" """
Perform asynchronous transcription using OpenAI-compatible API. Perform asynchronous transcription using OpenAI-compatible API.
""" """
...@@ -64,6 +72,7 @@ async def stream_openai_response(audio_path: str, client: AsyncOpenAI, model: st ...@@ -64,6 +72,7 @@ async def stream_openai_response(audio_path: str, client: AsyncOpenAI, model: st
extra_body=dict( extra_body=dict(
seed=420, seed=420,
top_p=0.6, top_p=0.6,
hotwords=hotwords,
), ),
stream=True, stream=True,
) )
...@@ -136,6 +145,7 @@ def main(args): ...@@ -136,6 +145,7 @@ def main(args):
client=client, client=client,
model=model, model=model,
repetition_penalty=args.repetition_penalty, repetition_penalty=args.repetition_penalty,
hotwords=args.hotwords,
) )
# Run the asynchronous function # Run the asynchronous function
...@@ -146,7 +156,10 @@ def main(args): ...@@ -146,7 +156,10 @@ def main(args):
) )
asyncio.run( asyncio.run(
stream_openai_response( stream_openai_response(
args.audio_path if args.audio_path else winning_call, client, model args.audio_path if args.audio_path else winning_call,
client,
model,
hotwords=args.hotwords,
) )
) )
else: else:
...@@ -174,5 +187,11 @@ if __name__ == "__main__": ...@@ -174,5 +187,11 @@ if __name__ == "__main__":
default=1.3, default=1.3,
help="repetition penalty", help="repetition penalty",
) )
parser.add_argument(
"--hotwords",
type=str,
default=None,
help="hotwords",
)
args = parser.parse_args() args = parser.parse_args()
main(args) main(args)
...@@ -35,6 +35,12 @@ class SpeechToTextParams: ...@@ -35,6 +35,12 @@ class SpeechToTextParams:
language: str | None = None language: str | None = None
"""ISO 639-1 language code (validated / auto-detected).""" """ISO 639-1 language code (validated / auto-detected)."""
hotwords: str | None = None
"""
hotwords refers to a list of important words or phrases that the model
should pay extra attention to during transcription.
"""
task_type: str = "transcribe" task_type: str = "transcribe"
"""``"transcribe"`` or ``"translate"``.""" """``"transcribe"`` or ``"translate"``."""
......
...@@ -78,6 +78,12 @@ class TranscriptionRequest(OpenAIBaseModel): ...@@ -78,6 +78,12 @@ class TranscriptionRequest(OpenAIBaseModel):
will improve accuracy and latency. will improve accuracy and latency.
""" """
hotwords: str | None = None
"""
hotwords refers to a list of important words or phrases that the model
should pay extra attention to during transcription.
"""
prompt: str = Field(default="") prompt: str = Field(default="")
"""An optional text to guide the model's style or continue a previous audio """An optional text to guide the model's style or continue a previous audio
segment. segment.
...@@ -205,6 +211,7 @@ class TranscriptionRequest(OpenAIBaseModel): ...@@ -205,6 +211,7 @@ class TranscriptionRequest(OpenAIBaseModel):
task_type=task_type, task_type=task_type,
request_prompt=self.prompt, request_prompt=self.prompt,
to_language=self.to_language, to_language=self.to_language,
hotwords=self.hotwords,
) )
def to_beam_search_params( def to_beam_search_params(
...@@ -481,6 +488,12 @@ class TranslationRequest(OpenAIBaseModel): ...@@ -481,6 +488,12 @@ class TranslationRequest(OpenAIBaseModel):
will improve accuracy. will improve accuracy.
""" """
hotwords: str | None = None
"""
hotwords refers to a list of important words or phrases that the model
should pay extra attention to during transcription.
"""
to_language: str | None = None to_language: str | None = None
"""The language of the input audio we translate to. """The language of the input audio we translate to.
...@@ -522,6 +535,7 @@ class TranslationRequest(OpenAIBaseModel): ...@@ -522,6 +535,7 @@ class TranslationRequest(OpenAIBaseModel):
task_type=task_type, task_type=task_type,
request_prompt=self.prompt, request_prompt=self.prompt,
to_language=self.to_language, to_language=self.to_language,
hotwords=self.hotwords,
) )
def to_beam_search_params( def to_beam_search_params(
......
...@@ -881,13 +881,20 @@ class FunASRForConditionalGeneration( ...@@ -881,13 +881,20 @@ class FunASRForConditionalGeneration(
audio = stt_params.audio audio = stt_params.audio
stt_config = stt_params.stt_config stt_config = stt_params.stt_config
language = stt_params.language language = stt_params.language
hotwords = stt_params.hotwords
if language is None: if language is None:
raise ValueError( raise ValueError(
"Language must be specified when creating the funasr prompt" "Language must be specified when creating the funasr prompt"
) )
funasr_prompt = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n语音转写:<|AUDIO|><|im_end|>\n<|im_start|>assistant\n" # noqa: E501 if hotwords is not None:
funasr_prompt = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n请结合上下文信息,更加准确地完成语音转写任务。如果没有相关信息,我们会留空。\n\n\n**上下文信息:**\n\n\n热词列表:[{}]\n语音转写:<|AUDIO|><|im_end|>\n<|im_start|>assistant\n".format( # noqa: E501
hotwords
)
else:
funasr_prompt = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n语音转写:<|AUDIO|><|im_end|>\n<|im_start|>assistant\n" # noqa: E501
prompt = { prompt = {
"prompt": funasr_prompt, "prompt": funasr_prompt,
"multi_modal_data": { "multi_modal_data": {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment