Unverified commit b10d05b8, authored by Cyrus Leung, committed by GitHub
Browse files

[Model] Use explicit types in `get_generation_prompt` (#33551)


Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent b398e5c8
......@@ -37,7 +37,8 @@ from vllm.entrypoints.openai.translations.protocol import (
TranslationStreamResponse,
)
from vllm.exceptions import VLLMValidationError
from vllm.inputs.data import PromptType
from vllm.inputs.data import ExplicitEncoderDecoderPrompt, PromptType
from vllm.inputs.parse import is_explicit_encoder_decoder_prompt
from vllm.logger import init_logger
from vllm.logprobs import FlatLogprobs, Logprob
from vllm.model_executor.models import SupportsTranscription, supports_transcription
......@@ -296,25 +297,36 @@ class OpenAISpeechToText(OpenAIServing):
to_language=to_language,
)
if request.response_format == "verbose_json":
if not isinstance(prompt, dict):
if not is_explicit_encoder_decoder_prompt(prompt):
raise VLLMValidationError(
"Expected prompt to be a dict",
"Expected prompt to be an encoder-decoder prompt",
parameter="prompt",
value=type(prompt).__name__,
)
prompt_dict = cast(dict, prompt)
decoder_prompt = prompt.get("decoder_prompt")
if not isinstance(decoder_prompt, str):
prompt = self._preprocess_verbose_prompt(prompt)
prompts.append(prompt)
return prompts, duration
def _repl_verbose_text(self, text: str):
return text.replace("<|notimestamps|>", "<|0.00|>")
# Rewrite the decoder side of an explicit encoder-decoder prompt by passing
# its text through `_repl_verbose_text`.
#   - `str` decoder prompt: the rewritten string is stored back on `prompt`.
#   - dict decoder prompt with a "prompt" key: that entry is rewritten in place.
#   - anything else: VLLMValidationError is raised.
# The same `prompt` object (mutated) is returned.
# NOTE(review): this span comes from a rendered diff without +/- markers, so it
# appears to interleave removed lines with their replacements (e.g. the two
# consecutive message strings, and the `decoder_prompt` / `prompt_dict`
# references that are not defined in this function) — confirm against the
# actual file before relying on it.
def _preprocess_verbose_prompt(self, prompt: ExplicitEncoderDecoderPrompt):
dec_prompt = prompt["decoder_prompt"]
if isinstance(dec_prompt, str):
prompt["decoder_prompt"] = self._repl_verbose_text(dec_prompt)
elif isinstance(dec_prompt, dict) and "prompt" in dec_prompt:
dec_prompt["prompt"] = self._repl_verbose_text(dec_prompt["prompt"])
else:
raise VLLMValidationError(
"Expected decoder_prompt to be str",
"Expected decoder_prompt to contain text",
parameter="decoder_prompt",
value=type(decoder_prompt).__name__,
)
prompt_dict["decoder_prompt"] = decoder_prompt.replace(
"<|notimestamps|>", "<|0.00|>"
value=type(dec_prompt).__name__,
)
prompts.append(prompt)
return prompts, duration
return prompt
def _get_verbose_segments(
self,
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable, Mapping, Sequence
from typing import Annotated, Any, Literal, cast
from typing import Annotated, Any, Literal
import numpy as np
import torch
......@@ -19,7 +19,7 @@ from transformers.models.siglip import SiglipImageProcessorFast
from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.inputs.data import PromptType
from vllm.inputs.data import PromptType, TextPrompt
from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import RowParallelLinear
......@@ -807,9 +807,10 @@ class Gemma3nForConditionalGeneration(
prompt += ": <audio_soft_token><end_of_turn>\n<start_of_turn>model\n"
audio = (audio, stt_config.sample_rate)
prompts_dict = {"multi_modal_data": {"audio": audio}, "prompt": prompt}
return cast(PromptType, prompts_dict)
return TextPrompt(
prompt=prompt,
multi_modal_data={"audio": (audio, stt_config.sample_rate)},
)
@classmethod
def get_speech_to_text_config(
......
......@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable, Mapping, Sequence
from typing import Annotated, Any, Literal, TypeAlias, cast
from typing import Annotated, Any, Literal, TypeAlias
import numpy as np
import torch
......@@ -14,7 +14,7 @@ from transformers.models.whisper import WhisperFeatureExtractor
from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed.parallel_state import get_tensor_model_parallel_world_size
from vllm.inputs.data import PromptType
from vllm.inputs.data import PromptType, TokensPrompt
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.attention import MMEncoderAttention
from vllm.model_executor.layers.linear import (
......@@ -1159,8 +1159,8 @@ class GlmAsrForConditionalGeneration(
)
prompt_token_ids = tokenizer.encode(prompt)
prompt_dict = {
"prompt_token_ids": prompt_token_ids,
"multi_modal_data": {"audio": audio},
}
return cast(PromptType, prompt_dict)
return TokensPrompt(
prompt_token_ids=prompt_token_ids,
multi_modal_data={"audio": audio},
)
......@@ -26,7 +26,7 @@
import math
from collections.abc import Iterable, Mapping
from typing import Annotated, Literal, cast
from typing import Annotated, Literal
import numpy as np
import torch
......@@ -36,7 +36,7 @@ from transformers import BatchFeature, PretrainedConfig
from vllm.config import CacheConfig, ModelConfig, SpeechToTextConfig, VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.inputs.data import PromptType
from vllm.inputs.data import PromptType, TokensPrompt
from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.models.module_mapping import MultiModelKeys
......@@ -879,11 +879,11 @@ class GraniteSpeechForConditionalGeneration(
)
prompt_token_ids = tokenizer.encode(prompt)
prompt = {
"prompt_token_ids": prompt_token_ids,
"multi_modal_data": {"audio": audio},
}
return cast(PromptType, prompt)
return TokensPrompt(
prompt_token_ids=prompt_token_ids,
multi_modal_data={"audio": audio},
)
# Adapted from https://github.com/huggingface/transformers/blob/v4.56.0/src/transformers/models/granite_speech/feature_extraction_granite_speech.py#L122 # noqa: E501
@classmethod
......
......@@ -23,7 +23,7 @@
"""Inference-only Qwen3-ASR model."""
from collections.abc import Iterable, Mapping, Sequence
from typing import Any, Literal, cast
from typing import Any, Literal
import numpy as np
import torch
......@@ -33,7 +33,7 @@ from transformers.models.whisper import WhisperFeatureExtractor
from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.inputs.data import PromptType
from vllm.inputs.data import PromptType, TokensPrompt
from vllm.logger import init_logger
from vllm.model_executor.models.interfaces import (
MultiModalEmbeddings,
......@@ -561,11 +561,11 @@ class Qwen3ASRForConditionalGeneration(
)
prompt_token_ids = tokenizer.encode(prompt)
prompt_dict = {
"prompt_token_ids": prompt_token_ids,
"multi_modal_data": {"audio": audio},
}
return cast(PromptType, prompt_dict)
return TokensPrompt(
prompt_token_ids=prompt_token_ids,
multi_modal_data={"audio": audio},
)
@classmethod
def post_process_output(cls, text: str) -> str:
......
......@@ -25,7 +25,7 @@ from transformers.tokenization_utils_base import TextInput
from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.inputs.data import PromptType
from vllm.inputs.data import PromptType, TokensPrompt
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
......@@ -488,10 +488,13 @@ class VoxtralForConditionalGeneration(
)
tokenized = tokenizer.instruct.encode_transcription(req)
audio = (tokenized.audios[0].audio_array, stt_config.sample_rate)
prompts_dict = {"multi_modal_data": {"audio": audio}}
prompts_dict["prompt_token_ids"] = tokenized.tokens
return cast(PromptType, prompts_dict)
return TokensPrompt(
prompt_token_ids=tokenized.tokens,
multi_modal_data={
"audio": (tokenized.audios[0].audio_array, stt_config.sample_rate)
},
)
@classmethod
def get_num_audio_tokens(
......
......@@ -4,7 +4,7 @@
import asyncio
import math
from collections.abc import AsyncGenerator, Mapping
from typing import Literal, cast
from typing import Literal
import numpy as np
import torch
......@@ -453,7 +453,10 @@ class VoxtralRealtimeGeneration(VoxtralForConditionalGeneration, SupportsRealtim
)
tokenized = tokenizer.instruct.encode_transcription(req)
audio = (tokenized.audios[0].audio_array, stt_config.sample_rate)
prompts_dict = {"multi_modal_data": {"audio": audio}}
prompts_dict["prompt_token_ids"] = tokenized.tokens
return cast(PromptType, prompts_dict)
return TokensPrompt(
prompt_token_ids=tokenized.tokens,
multi_modal_data={
"audio": (tokenized.audios[0].audio_array, stt_config.sample_rate)
},
)
......@@ -5,7 +5,7 @@ import enum
import math
from collections.abc import Iterable, Mapping, Sequence
from contextlib import nullcontext
from typing import Annotated, Literal, cast
from typing import Annotated, Literal
import numpy as np
import torch
......@@ -21,7 +21,7 @@ from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, ModelConfig, SpeechToTextConfig, VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.inputs.data import PromptType
from vllm.inputs.data import ExplicitEncoderDecoderPrompt, PromptType, TextPrompt
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.attention import (
......@@ -815,21 +815,18 @@ class WhisperForConditionalGeneration(
raise ValueError(
"Language must be specified when creating the Whisper prompt"
)
prompt = {
"encoder_prompt": {
# Whisper does not support encoder prompt.
"prompt": "",
"multi_modal_data": {
"audio": (audio, stt_config.sample_rate),
},
},
"decoder_prompt": (
(f"<|prev|>{request_prompt}" if request_prompt else "")
+ f"<|startoftranscript|><|{language}|>"
+ f"<|{task_type}|><|notimestamps|>"
decoder_text = (
f"<|prev|>{request_prompt}" if request_prompt else ""
) + f"<|startoftranscript|><|{language}|><|{task_type}|><|notimestamps|>"
return ExplicitEncoderDecoderPrompt(
encoder_prompt=TextPrompt(
prompt="", # Whisper does not support encoder prompt.
multi_modal_data={"audio": (audio, stt_config.sample_rate)},
),
}
return cast(PromptType, prompt)
decoder_prompt=TextPrompt(prompt=decoder_text),
)
@classmethod
def get_placeholder_str(cls, modality: str, i: int) -> str | None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment