Unverified Commit e77f162c authored by Nicolò Lucchesi's avatar Nicolò Lucchesi Committed by GitHub
Browse files

[Bugfix] Fix `Qwen3ASR` language asr tag in output (#33410)


Signed-off-by: default avatarNickLucche <nlucches@redhat.com>
parent 8ecd213c
...@@ -518,7 +518,8 @@ class OpenAISpeechToText(OpenAIServing): ...@@ -518,7 +518,8 @@ class OpenAISpeechToText(OpenAIServing):
total_segments.extend(segments) total_segments.extend(segments)
text_parts.extend([seg.text for seg in segments]) text_parts.extend([seg.text for seg in segments])
else: else:
text_parts.append(op.outputs[0].text) raw_text = op.outputs[0].text
text_parts.append(self.model_cls.post_process_output(raw_text))
text = "".join(text_parts) text = "".join(text_parts)
if self.task_type == "transcribe": if self.task_type == "transcribe":
final_response: ResponseType final_response: ResponseType
...@@ -607,6 +608,10 @@ class OpenAISpeechToText(OpenAIServing): ...@@ -607,6 +608,10 @@ class OpenAISpeechToText(OpenAIServing):
assert len(res.outputs) == 1 assert len(res.outputs) == 1
output = res.outputs[0] output = res.outputs[0]
# TODO: For models that output structured formats (e.g.,
# Qwen3-ASR with "language X<asr_text>" prefix), streaming
# would need buffering to strip the prefix properly since
# deltas may split the tag across chunks.
delta_message = DeltaMessage(content=output.text) delta_message = DeltaMessage(content=output.text)
completion_tokens += len(output.token_ids) completion_tokens += len(output.token_ids)
......
...@@ -1145,6 +1145,22 @@ class SupportsTranscription(Protocol): ...@@ -1145,6 +1145,22 @@ class SupportsTranscription(Protocol):
""" """
return None return None
@classmethod
def post_process_output(cls, text: str) -> str:
"""
Post-process the raw model output text.
Some ASR models output structured formats (e.g., language tags,
special tokens) that need to be stripped before returning to the user.
Args:
text: Raw decoded text from the model.
Returns:
Cleaned transcription text.
"""
return text
@overload @overload
def supports_transcription( def supports_transcription(
......
...@@ -90,6 +90,7 @@ from vllm.transformers_utils.processors.qwen3_asr import ( ...@@ -90,6 +90,7 @@ from vllm.transformers_utils.processors.qwen3_asr import (
) )
logger = init_logger(__name__) logger = init_logger(__name__)
_ASR_TEXT_TAG = "<asr_text>"
def _get_feat_extract_output_lengths(input_lengths: torch.Tensor): def _get_feat_extract_output_lengths(input_lengths: torch.Tensor):
...@@ -556,7 +557,7 @@ class Qwen3ASRForConditionalGeneration( ...@@ -556,7 +557,7 @@ class Qwen3ASRForConditionalGeneration(
else: else:
prompt = ( prompt = (
f"<|im_start|>user\n{audio_placeholder}<|im_end|>\n" f"<|im_start|>user\n{audio_placeholder}<|im_end|>\n"
f"<|im_start|>assistant\nlanguage {full_lang_name_to}<asr_text>" f"<|im_start|>assistant\nlanguage {full_lang_name_to}{_ASR_TEXT_TAG}"
) )
prompt_token_ids = tokenizer.encode(prompt) prompt_token_ids = tokenizer.encode(prompt)
...@@ -565,3 +566,21 @@ class Qwen3ASRForConditionalGeneration( ...@@ -565,3 +566,21 @@ class Qwen3ASRForConditionalGeneration(
"multi_modal_data": {"audio": audio}, "multi_modal_data": {"audio": audio},
} }
return cast(PromptType, prompt_dict) return cast(PromptType, prompt_dict)
@classmethod
def post_process_output(cls, text: str) -> str:
"""
Post-process Qwen3-ASR raw output to extract clean transcription.
The model outputs in format: "language {lang}<asr_text>{transcription}"
This method strips the language prefix and asr_text tags.
"""
if not text:
return ""
if _ASR_TEXT_TAG not in text:
return text
# Split on <asr_text> tag and take the transcription part
_, text_part = text.rsplit(_ASR_TEXT_TAG, 1)
return text_part
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment