Unverified commit b10d05b8, authored by Cyrus Leung, committed by GitHub
Browse files

[Model] Use explicit types in `get_generation_prompt` (#33551)


Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent b398e5c8
...@@ -37,7 +37,8 @@ from vllm.entrypoints.openai.translations.protocol import ( ...@@ -37,7 +37,8 @@ from vllm.entrypoints.openai.translations.protocol import (
TranslationStreamResponse, TranslationStreamResponse,
) )
from vllm.exceptions import VLLMValidationError from vllm.exceptions import VLLMValidationError
from vllm.inputs.data import PromptType from vllm.inputs.data import ExplicitEncoderDecoderPrompt, PromptType
from vllm.inputs.parse import is_explicit_encoder_decoder_prompt
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.logprobs import FlatLogprobs, Logprob from vllm.logprobs import FlatLogprobs, Logprob
from vllm.model_executor.models import SupportsTranscription, supports_transcription from vllm.model_executor.models import SupportsTranscription, supports_transcription
...@@ -296,25 +297,36 @@ class OpenAISpeechToText(OpenAIServing): ...@@ -296,25 +297,36 @@ class OpenAISpeechToText(OpenAIServing):
to_language=to_language, to_language=to_language,
) )
if request.response_format == "verbose_json": if request.response_format == "verbose_json":
if not isinstance(prompt, dict): if not is_explicit_encoder_decoder_prompt(prompt):
raise VLLMValidationError( raise VLLMValidationError(
"Expected prompt to be a dict", "Expected prompt to be an encoder-decoder prompt",
parameter="prompt", parameter="prompt",
value=type(prompt).__name__, value=type(prompt).__name__,
) )
prompt_dict = cast(dict, prompt)
decoder_prompt = prompt.get("decoder_prompt") prompt = self._preprocess_verbose_prompt(prompt)
if not isinstance(decoder_prompt, str):
prompts.append(prompt)
return prompts, duration
def _repl_verbose_text(self, text: str):
return text.replace("<|notimestamps|>", "<|0.00|>")
def _preprocess_verbose_prompt(self, prompt: ExplicitEncoderDecoderPrompt):
dec_prompt = prompt["decoder_prompt"]
if isinstance(dec_prompt, str):
prompt["decoder_prompt"] = self._repl_verbose_text(dec_prompt)
elif isinstance(dec_prompt, dict) and "prompt" in dec_prompt:
dec_prompt["prompt"] = self._repl_verbose_text(dec_prompt["prompt"])
else:
raise VLLMValidationError( raise VLLMValidationError(
"Expected decoder_prompt to be str", "Expected decoder_prompt to contain text",
parameter="decoder_prompt", parameter="decoder_prompt",
value=type(decoder_prompt).__name__, value=type(dec_prompt).__name__,
)
prompt_dict["decoder_prompt"] = decoder_prompt.replace(
"<|notimestamps|>", "<|0.00|>"
) )
prompts.append(prompt)
return prompts, duration return prompt
def _get_verbose_segments( def _get_verbose_segments(
self, self,
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable, Mapping, Sequence from collections.abc import Iterable, Mapping, Sequence
from typing import Annotated, Any, Literal, cast from typing import Annotated, Any, Literal
import numpy as np import numpy as np
import torch import torch
...@@ -19,7 +19,7 @@ from transformers.models.siglip import SiglipImageProcessorFast ...@@ -19,7 +19,7 @@ from transformers.models.siglip import SiglipImageProcessorFast
from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig
from vllm.config.multimodal import BaseDummyOptions from vllm.config.multimodal import BaseDummyOptions
from vllm.inputs.data import PromptType from vllm.inputs.data import PromptType, TextPrompt
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import RowParallelLinear from vllm.model_executor.layers.linear import RowParallelLinear
...@@ -807,9 +807,10 @@ class Gemma3nForConditionalGeneration( ...@@ -807,9 +807,10 @@ class Gemma3nForConditionalGeneration(
prompt += ": <audio_soft_token><end_of_turn>\n<start_of_turn>model\n" prompt += ": <audio_soft_token><end_of_turn>\n<start_of_turn>model\n"
audio = (audio, stt_config.sample_rate) return TextPrompt(
prompts_dict = {"multi_modal_data": {"audio": audio}, "prompt": prompt} prompt=prompt,
return cast(PromptType, prompts_dict) multi_modal_data={"audio": (audio, stt_config.sample_rate)},
)
@classmethod @classmethod
def get_speech_to_text_config( def get_speech_to_text_config(
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable, Mapping, Sequence from collections.abc import Iterable, Mapping, Sequence
from typing import Annotated, Any, Literal, TypeAlias, cast from typing import Annotated, Any, Literal, TypeAlias
import numpy as np import numpy as np
import torch import torch
...@@ -14,7 +14,7 @@ from transformers.models.whisper import WhisperFeatureExtractor ...@@ -14,7 +14,7 @@ from transformers.models.whisper import WhisperFeatureExtractor
from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig
from vllm.config.multimodal import BaseDummyOptions from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed.parallel_state import get_tensor_model_parallel_world_size from vllm.distributed.parallel_state import get_tensor_model_parallel_world_size
from vllm.inputs.data import PromptType from vllm.inputs.data import PromptType, TokensPrompt
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.attention import MMEncoderAttention from vllm.model_executor.layers.attention import MMEncoderAttention
from vllm.model_executor.layers.linear import ( from vllm.model_executor.layers.linear import (
...@@ -1159,8 +1159,8 @@ class GlmAsrForConditionalGeneration( ...@@ -1159,8 +1159,8 @@ class GlmAsrForConditionalGeneration(
) )
prompt_token_ids = tokenizer.encode(prompt) prompt_token_ids = tokenizer.encode(prompt)
prompt_dict = {
"prompt_token_ids": prompt_token_ids, return TokensPrompt(
"multi_modal_data": {"audio": audio}, prompt_token_ids=prompt_token_ids,
} multi_modal_data={"audio": audio},
return cast(PromptType, prompt_dict) )
...@@ -26,7 +26,7 @@ ...@@ -26,7 +26,7 @@
import math import math
from collections.abc import Iterable, Mapping from collections.abc import Iterable, Mapping
from typing import Annotated, Literal, cast from typing import Annotated, Literal
import numpy as np import numpy as np
import torch import torch
...@@ -36,7 +36,7 @@ from transformers import BatchFeature, PretrainedConfig ...@@ -36,7 +36,7 @@ from transformers import BatchFeature, PretrainedConfig
from vllm.config import CacheConfig, ModelConfig, SpeechToTextConfig, VllmConfig from vllm.config import CacheConfig, ModelConfig, SpeechToTextConfig, VllmConfig
from vllm.config.multimodal import BaseDummyOptions from vllm.config.multimodal import BaseDummyOptions
from vllm.inputs.data import PromptType from vllm.inputs.data import PromptType, TokensPrompt
from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.models.module_mapping import MultiModelKeys
...@@ -879,11 +879,11 @@ class GraniteSpeechForConditionalGeneration( ...@@ -879,11 +879,11 @@ class GraniteSpeechForConditionalGeneration(
) )
prompt_token_ids = tokenizer.encode(prompt) prompt_token_ids = tokenizer.encode(prompt)
prompt = {
"prompt_token_ids": prompt_token_ids, return TokensPrompt(
"multi_modal_data": {"audio": audio}, prompt_token_ids=prompt_token_ids,
} multi_modal_data={"audio": audio},
return cast(PromptType, prompt) )
# Adapted from https://github.com/huggingface/transformers/blob/v4.56.0/src/transformers/models/granite_speech/feature_extraction_granite_speech.py#L122 # noqa: E501 # Adapted from https://github.com/huggingface/transformers/blob/v4.56.0/src/transformers/models/granite_speech/feature_extraction_granite_speech.py#L122 # noqa: E501
@classmethod @classmethod
......
...@@ -23,7 +23,7 @@ ...@@ -23,7 +23,7 @@
"""Inference-only Qwen3-ASR model.""" """Inference-only Qwen3-ASR model."""
from collections.abc import Iterable, Mapping, Sequence from collections.abc import Iterable, Mapping, Sequence
from typing import Any, Literal, cast from typing import Any, Literal
import numpy as np import numpy as np
import torch import torch
...@@ -33,7 +33,7 @@ from transformers.models.whisper import WhisperFeatureExtractor ...@@ -33,7 +33,7 @@ from transformers.models.whisper import WhisperFeatureExtractor
from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig
from vllm.config.multimodal import BaseDummyOptions from vllm.config.multimodal import BaseDummyOptions
from vllm.inputs.data import PromptType from vllm.inputs.data import PromptType, TokensPrompt
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.models.interfaces import ( from vllm.model_executor.models.interfaces import (
MultiModalEmbeddings, MultiModalEmbeddings,
...@@ -561,11 +561,11 @@ class Qwen3ASRForConditionalGeneration( ...@@ -561,11 +561,11 @@ class Qwen3ASRForConditionalGeneration(
) )
prompt_token_ids = tokenizer.encode(prompt) prompt_token_ids = tokenizer.encode(prompt)
prompt_dict = {
"prompt_token_ids": prompt_token_ids, return TokensPrompt(
"multi_modal_data": {"audio": audio}, prompt_token_ids=prompt_token_ids,
} multi_modal_data={"audio": audio},
return cast(PromptType, prompt_dict) )
@classmethod @classmethod
def post_process_output(cls, text: str) -> str: def post_process_output(cls, text: str) -> str:
......
...@@ -25,7 +25,7 @@ from transformers.tokenization_utils_base import TextInput ...@@ -25,7 +25,7 @@ from transformers.tokenization_utils_base import TextInput
from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig
from vllm.config.multimodal import BaseDummyOptions from vllm.config.multimodal import BaseDummyOptions
from vllm.inputs.data import PromptType from vllm.inputs.data import PromptType, TokensPrompt
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
...@@ -488,10 +488,13 @@ class VoxtralForConditionalGeneration( ...@@ -488,10 +488,13 @@ class VoxtralForConditionalGeneration(
) )
tokenized = tokenizer.instruct.encode_transcription(req) tokenized = tokenizer.instruct.encode_transcription(req)
audio = (tokenized.audios[0].audio_array, stt_config.sample_rate)
prompts_dict = {"multi_modal_data": {"audio": audio}} return TokensPrompt(
prompts_dict["prompt_token_ids"] = tokenized.tokens prompt_token_ids=tokenized.tokens,
return cast(PromptType, prompts_dict) multi_modal_data={
"audio": (tokenized.audios[0].audio_array, stt_config.sample_rate)
},
)
@classmethod @classmethod
def get_num_audio_tokens( def get_num_audio_tokens(
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
import asyncio import asyncio
import math import math
from collections.abc import AsyncGenerator, Mapping from collections.abc import AsyncGenerator, Mapping
from typing import Literal, cast from typing import Literal
import numpy as np import numpy as np
import torch import torch
...@@ -453,7 +453,10 @@ class VoxtralRealtimeGeneration(VoxtralForConditionalGeneration, SupportsRealtim ...@@ -453,7 +453,10 @@ class VoxtralRealtimeGeneration(VoxtralForConditionalGeneration, SupportsRealtim
) )
tokenized = tokenizer.instruct.encode_transcription(req) tokenized = tokenizer.instruct.encode_transcription(req)
audio = (tokenized.audios[0].audio_array, stt_config.sample_rate)
prompts_dict = {"multi_modal_data": {"audio": audio}} return TokensPrompt(
prompts_dict["prompt_token_ids"] = tokenized.tokens prompt_token_ids=tokenized.tokens,
return cast(PromptType, prompts_dict) multi_modal_data={
"audio": (tokenized.audios[0].audio_array, stt_config.sample_rate)
},
)
...@@ -5,7 +5,7 @@ import enum ...@@ -5,7 +5,7 @@ import enum
import math import math
from collections.abc import Iterable, Mapping, Sequence from collections.abc import Iterable, Mapping, Sequence
from contextlib import nullcontext from contextlib import nullcontext
from typing import Annotated, Literal, cast from typing import Annotated, Literal
import numpy as np import numpy as np
import torch import torch
...@@ -21,7 +21,7 @@ from vllm.compilation.decorators import support_torch_compile ...@@ -21,7 +21,7 @@ from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, ModelConfig, SpeechToTextConfig, VllmConfig from vllm.config import CacheConfig, ModelConfig, SpeechToTextConfig, VllmConfig
from vllm.config.multimodal import BaseDummyOptions from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.inputs.data import PromptType from vllm.inputs.data import ExplicitEncoderDecoderPrompt, PromptType, TextPrompt
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.attention import ( from vllm.model_executor.layers.attention import (
...@@ -815,21 +815,18 @@ class WhisperForConditionalGeneration( ...@@ -815,21 +815,18 @@ class WhisperForConditionalGeneration(
raise ValueError( raise ValueError(
"Language must be specified when creating the Whisper prompt" "Language must be specified when creating the Whisper prompt"
) )
prompt = {
"encoder_prompt": { decoder_text = (
# Whisper does not support encoder prompt. f"<|prev|>{request_prompt}" if request_prompt else ""
"prompt": "", ) + f"<|startoftranscript|><|{language}|><|{task_type}|><|notimestamps|>"
"multi_modal_data": {
"audio": (audio, stt_config.sample_rate), return ExplicitEncoderDecoderPrompt(
}, encoder_prompt=TextPrompt(
}, prompt="", # Whisper does not support encoder prompt.
"decoder_prompt": ( multi_modal_data={"audio": (audio, stt_config.sample_rate)},
(f"<|prev|>{request_prompt}" if request_prompt else "")
+ f"<|startoftranscript|><|{language}|>"
+ f"<|{task_type}|><|notimestamps|>"
), ),
} decoder_prompt=TextPrompt(prompt=decoder_text),
return cast(PromptType, prompt) )
@classmethod @classmethod
def get_placeholder_str(cls, modality: str, i: int) -> str | None: def get_placeholder_str(cls, modality: str, i: int) -> str | None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment