Unverified commit 9047288b, authored by AllenDou and committed by GitHub
Browse files

support hotwords for FunASR model (#39674)


Signed-off-by: zixiao <shunli.dsl@alibaba-inc.com>
Co-authored-by: zixiao <shunli.dsl@alibaba-inc.com>
parent ed6d3037
......@@ -27,7 +27,12 @@ from vllm.assets.audio import AudioAsset
def sync_openai(
audio_path: str, client: OpenAI, model: str, *, repetition_penalty: float = 1.3
audio_path: str,
client: OpenAI,
model: str,
*,
repetition_penalty: float = 1.3,
hotwords: str = None,
):
"""
Perform synchronous transcription using OpenAI-compatible API.
......@@ -43,12 +48,15 @@ def sync_openai(
extra_body=dict(
seed=4419,
repetition_penalty=repetition_penalty,
hotwords=hotwords,
),
)
print("transcription result [sync]:", transcription.text)
async def stream_openai_response(audio_path: str, client: AsyncOpenAI, model: str):
async def stream_openai_response(
audio_path: str, client: AsyncOpenAI, model: str, hotwords: str = None
):
"""
Perform asynchronous transcription using OpenAI-compatible API.
"""
......@@ -64,6 +72,7 @@ async def stream_openai_response(audio_path: str, client: AsyncOpenAI, model: st
extra_body=dict(
seed=420,
top_p=0.6,
hotwords=hotwords,
),
stream=True,
)
......@@ -136,6 +145,7 @@ def main(args):
client=client,
model=model,
repetition_penalty=args.repetition_penalty,
hotwords=args.hotwords,
)
# Run the asynchronous function
......@@ -146,7 +156,10 @@ def main(args):
)
asyncio.run(
stream_openai_response(
args.audio_path if args.audio_path else winning_call, client, model
args.audio_path if args.audio_path else winning_call,
client,
model,
hotwords=args.hotwords,
)
)
else:
......@@ -174,5 +187,11 @@ if __name__ == "__main__":
default=1.3,
help="repetition penalty",
)
parser.add_argument(
"--hotwords",
type=str,
default=None,
help="hotwords",
)
args = parser.parse_args()
main(args)
......@@ -35,6 +35,12 @@ class SpeechToTextParams:
language: str | None = None
"""ISO 639-1 language code (validated / auto-detected)."""
hotwords: str | None = None
"""
hotwords refers to a list of important words or phrases that the model
should pay extra attention to during transcription.
"""
task_type: str = "transcribe"
"""``"transcribe"`` or ``"translate"``."""
......
......@@ -78,6 +78,12 @@ class TranscriptionRequest(OpenAIBaseModel):
will improve accuracy and latency.
"""
hotwords: str | None = None
"""
hotwords refers to a list of important words or phrases that the model
should pay extra attention to during transcription.
"""
prompt: str = Field(default="")
"""An optional text to guide the model's style or continue a previous audio
segment.
......@@ -205,6 +211,7 @@ class TranscriptionRequest(OpenAIBaseModel):
task_type=task_type,
request_prompt=self.prompt,
to_language=self.to_language,
hotwords=self.hotwords,
)
def to_beam_search_params(
......@@ -481,6 +488,12 @@ class TranslationRequest(OpenAIBaseModel):
will improve accuracy.
"""
hotwords: str | None = None
"""
hotwords refers to a list of important words or phrases that the model
should pay extra attention to during transcription.
"""
to_language: str | None = None
"""The language of the input audio we translate to.
......@@ -522,6 +535,7 @@ class TranslationRequest(OpenAIBaseModel):
task_type=task_type,
request_prompt=self.prompt,
to_language=self.to_language,
hotwords=self.hotwords,
)
def to_beam_search_params(
......
......@@ -881,13 +881,20 @@ class FunASRForConditionalGeneration(
audio = stt_params.audio
stt_config = stt_params.stt_config
language = stt_params.language
hotwords = stt_params.hotwords
if language is None:
raise ValueError(
"Language must be specified when creating the funasr prompt"
)
funasr_prompt = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n语音转写:<|AUDIO|><|im_end|>\n<|im_start|>assistant\n" # noqa: E501
if hotwords is not None:
funasr_prompt = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n请结合上下文信息,更加准确地完成语音转写任务。如果没有相关信息,我们会留空。\n\n\n**上下文信息:**\n\n\n热词列表:[{}]\n语音转写:<|AUDIO|><|im_end|>\n<|im_start|>assistant\n".format( # noqa: E501
hotwords
)
else:
funasr_prompt = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n语音转写:<|AUDIO|><|im_end|>\n<|im_start|>assistant\n" # noqa: E501
prompt = {
"prompt": funasr_prompt,
"multi_modal_data": {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment