Unverified commit 4eafc729, authored by Ekagra Ranjan, committed by GitHub
Browse files

[Audio] Bundle `get_generation_prompt()` params into `SpeechToTextParams` (#36268)


Signed-off-by: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
parent 6d097697
...@@ -66,7 +66,7 @@ This is for controlling general behavior of the API when serving your model: ...@@ -66,7 +66,7 @@ This is for controlling general behavior of the API when serving your model:
See [Audio preprocessing and chunking](#audio-preprocessing-and-chunking) for what each field controls. See [Audio preprocessing and chunking](#audio-preprocessing-and-chunking) for what each field controls.
Implement the prompt construction via [get_generation_prompt][vllm.model_executor.models.interfaces.SupportsTranscription.get_generation_prompt]. The server passes you the resampled waveform and task parameters; you return a valid [PromptType][vllm.inputs.llm.PromptType]. There are two common patterns: Implement the prompt construction via [get_generation_prompt][vllm.model_executor.models.interfaces.SupportsTranscription.get_generation_prompt]. The server builds a [SpeechToTextParams][vllm.config.speech_to_text.SpeechToTextParams] object that bundles the resampled waveform, task parameters, and request-specific options. Your model receives this single object and returns a valid [PromptType][vllm.inputs.llm.PromptType]. There are two common patterns:
#### Multimodal LLM with audio embeddings (e.g., Voxtral, Gemma3n) #### Multimodal LLM with audio embeddings (e.g., Voxtral, Gemma3n)
...@@ -75,21 +75,20 @@ Return a dict containing `multi_modal_data` with the audio, and either a `prompt ...@@ -75,21 +75,20 @@ Return a dict containing `multi_modal_data` with the audio, and either a `prompt
??? code "get_generation_prompt()" ??? code "get_generation_prompt()"
```python ```python
from vllm.config.speech_to_text import SpeechToTextParams
class YourASRModel(nn.Module, SupportsTranscription): class YourASRModel(nn.Module, SupportsTranscription):
... ...
@classmethod @classmethod
def get_generation_prompt( def get_generation_prompt(
cls, cls,
audio: np.ndarray, stt_params: SpeechToTextParams,
stt_config: SpeechToTextConfig,
model_config: ModelConfig,
language: str | None,
task_type: Literal["transcribe", "translate"],
request_prompt: str,
to_language: str | None,
) -> PromptType: ) -> PromptType:
# Example with a free-form instruction prompt audio = stt_params.audio
stt_config = stt_params.stt_config
task_type = stt_params.task_type
task_word = "Transcribe" if task_type == "transcribe" else "Translate" task_word = "Transcribe" if task_type == "transcribe" else "Translate"
prompt = ( prompt = (
"<start_of_turn>user\n" "<start_of_turn>user\n"
...@@ -112,20 +111,22 @@ Return a dict with separate `encoder_prompt` and `decoder_prompt` entries: ...@@ -112,20 +111,22 @@ Return a dict with separate `encoder_prompt` and `decoder_prompt` entries:
??? code "get_generation_prompt()" ??? code "get_generation_prompt()"
```python ```python
from vllm.config.speech_to_text import SpeechToTextParams
class YourASRModel(nn.Module, SupportsTranscription): class YourASRModel(nn.Module, SupportsTranscription):
... ...
@classmethod @classmethod
def get_generation_prompt( def get_generation_prompt(
cls, cls,
audio: np.ndarray, stt_params: SpeechToTextParams,
stt_config: SpeechToTextConfig,
model_config: ModelConfig,
language: str | None,
task_type: Literal["transcribe", "translate"],
request_prompt: str,
to_language: str | None,
) -> PromptType: ) -> PromptType:
audio = stt_params.audio
stt_config = stt_params.stt_config
language = stt_params.language
task_type = stt_params.task_type
request_prompt = stt_params.request_prompt
if language is None: if language is None:
raise ValueError("Language must be specified") raise ValueError("Language must be specified")
...@@ -213,15 +214,13 @@ Relevant server logic: ...@@ -213,15 +214,13 @@ Relevant server logic:
chunks = [y] if not do_split_audio else self._split_audio(y, int(sr)) chunks = [y] if not do_split_audio else self._split_audio(y, int(sr))
prompts = [] prompts = []
for chunk in chunks: for chunk in chunks:
prompt = self.model_cls.get_generation_prompt( stt_params = request.build_stt_params(
audio=chunk, audio=chunk,
stt_config=self.asr_config, stt_config=self.asr_config,
model_config=self.model_config, model_config=self.model_config,
language=language,
task_type=self.task_type, task_type=self.task_type,
request_prompt=request.prompt,
to_language=to_language,
) )
prompt = self.model_cls.get_generation_prompt(stt_params)
prompts.append(prompt) prompts.append(prompt)
return prompts, duration return prompts, duration
``` ```
......
...@@ -37,7 +37,7 @@ from vllm.config.profiler import ProfilerConfig ...@@ -37,7 +37,7 @@ from vllm.config.profiler import ProfilerConfig
from vllm.config.reasoning import ReasoningConfig from vllm.config.reasoning import ReasoningConfig
from vllm.config.scheduler import SchedulerConfig from vllm.config.scheduler import SchedulerConfig
from vllm.config.speculative import SpeculativeConfig from vllm.config.speculative import SpeculativeConfig
from vllm.config.speech_to_text import SpeechToTextConfig from vllm.config.speech_to_text import SpeechToTextConfig, SpeechToTextParams
from vllm.config.structured_outputs import StructuredOutputsConfig from vllm.config.structured_outputs import StructuredOutputsConfig
from vllm.config.utils import ( from vllm.config.utils import (
ConfigType, ConfigType,
...@@ -113,6 +113,7 @@ __all__ = [ ...@@ -113,6 +113,7 @@ __all__ = [
"SpeculativeConfig", "SpeculativeConfig",
# From vllm.config.speech_to_text # From vllm.config.speech_to_text
"SpeechToTextConfig", "SpeechToTextConfig",
"SpeechToTextParams",
# From vllm.config.structured_outputs # From vllm.config.structured_outputs
"StructuredOutputsConfig", "StructuredOutputsConfig",
# From vllm.config.profiler # From vllm.config.profiler
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
from dataclasses import dataclass
from typing import TYPE_CHECKING
from vllm.config.utils import config from vllm.config.utils import config
if TYPE_CHECKING:
import numpy as np
from vllm.config.model import ModelConfig
@dataclass
class SpeechToTextParams:
    """All parameters consumed by ``get_generation_prompt()``.

    ``TranscriptionRequest.build_stt_params()`` and
    ``TranslationRequest.build_stt_params()`` construct this object,
    mapping API-level fields into typed attributes. Models only receive
    this object, so new parameters can be added here without changing the
    ``get_generation_prompt`` signature.
    """

    # NOTE: ``np`` and ``ModelConfig`` are imported only under TYPE_CHECKING;
    # the annotations below are never evaluated at runtime because this
    # module uses ``from __future__ import annotations``.

    # Required fields: the serving layer supplies these for every audio chunk.
    audio: np.ndarray
    """Resampled audio waveform for a single chunk."""

    stt_config: SpeechToTextConfig
    """Server-level speech-to-text configuration."""

    model_config: ModelConfig
    """Model configuration."""

    # Defaulted fields: ``build_stt_params()`` always sets them explicitly.
    language: str | None = None
    """ISO 639-1 language code (validated / auto-detected)."""

    task_type: str = "transcribe"
    """``"transcribe"`` or ``"translate"``."""

    request_prompt: str = ""
    """Optional text prompt to guide the model."""

    to_language: str | None = None
    """Target language for translation (model-dependent)."""
@config @config
class SpeechToTextConfig: class SpeechToTextConfig:
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
import time import time
from http import HTTPStatus from http import HTTPStatus
from typing import Literal, TypeAlias from typing import TYPE_CHECKING, Literal, TypeAlias
import torch import torch
from fastapi import HTTPException, UploadFile from fastapi import HTTPException, UploadFile
...@@ -12,6 +13,7 @@ from pydantic import ( ...@@ -12,6 +13,7 @@ from pydantic import (
model_validator, model_validator,
) )
from vllm.config.speech_to_text import SpeechToTextParams
from vllm.entrypoints.openai.engine.protocol import ( from vllm.entrypoints.openai.engine.protocol import (
DeltaMessage, DeltaMessage,
OpenAIBaseModel, OpenAIBaseModel,
...@@ -26,6 +28,11 @@ from vllm.sampling_params import ( ...@@ -26,6 +28,11 @@ from vllm.sampling_params import (
) )
from vllm.utils import random_uuid from vllm.utils import random_uuid
if TYPE_CHECKING:
import numpy as np
from vllm.config import ModelConfig, SpeechToTextConfig
logger = init_logger(__name__) logger = init_logger(__name__)
_LONG_INFO = torch.iinfo(torch.long) _LONG_INFO = torch.iinfo(torch.long)
...@@ -183,6 +190,23 @@ class TranscriptionRequest(OpenAIBaseModel): ...@@ -183,6 +190,23 @@ class TranscriptionRequest(OpenAIBaseModel):
"min_p": 0.0, "min_p": 0.0,
} }
def build_stt_params(
    self,
    audio: "np.ndarray",
    stt_config: "SpeechToTextConfig",
    model_config: "ModelConfig",
    task_type: str,
) -> SpeechToTextParams:
    """Bundle this transcription request with the per-chunk server inputs.

    ``language``, ``prompt`` and ``to_language`` are read from the request
    itself; the audio chunk, both configs and the task type are supplied
    by the caller. The result is what ``get_generation_prompt()`` receives.
    """
    stt_params = SpeechToTextParams(
        # Inputs supplied by the caller.
        audio=audio,
        stt_config=stt_config,
        model_config=model_config,
        task_type=task_type,
        # Fields carried by the API request.
        language=self.language,
        request_prompt=self.prompt,
        to_language=self.to_language,
    )
    return stt_params
def to_beam_search_params( def to_beam_search_params(
self, self,
default_max_tokens: int, default_max_tokens: int,
...@@ -277,6 +301,17 @@ class TranscriptionRequest(OpenAIBaseModel): ...@@ -277,6 +301,17 @@ class TranscriptionRequest(OpenAIBaseModel):
parameter=invalid_param, parameter=invalid_param,
) )
# Parse vllm_xargs from JSON string (form data sends it as a string)
xargs = data.get("vllm_xargs")
if isinstance(xargs, str):
try:
data["vllm_xargs"] = json.loads(xargs)
except json.JSONDecodeError as e:
raise VLLMValidationError(
f"Failed to parse vllm_xargs. Must be valid JSON: {e}",
parameter="vllm_xargs",
) from e
return data return data
...@@ -472,6 +507,23 @@ class TranslationRequest(OpenAIBaseModel): ...@@ -472,6 +507,23 @@ class TranslationRequest(OpenAIBaseModel):
"temperature": 0, "temperature": 0,
} }
def build_stt_params(
    self,
    audio: "np.ndarray",
    stt_config: "SpeechToTextConfig",
    model_config: "ModelConfig",
    task_type: str,
) -> SpeechToTextParams:
    """Bundle this translation request with the per-chunk server inputs.

    ``language``, ``prompt`` and ``to_language`` are read from the request
    itself; the audio chunk, both configs and the task type are supplied
    by the caller. The result is what ``get_generation_prompt()`` receives.
    """
    stt_params = SpeechToTextParams(
        # Inputs supplied by the caller.
        audio=audio,
        stt_config=stt_config,
        model_config=model_config,
        task_type=task_type,
        # Fields carried by the API request.
        language=self.language,
        request_prompt=self.prompt,
        to_language=self.to_language,
    )
    return stt_params
def to_beam_search_params( def to_beam_search_params(
self, self,
default_max_tokens: int, default_max_tokens: int,
......
...@@ -184,9 +184,8 @@ class OpenAISpeechToText(OpenAIServing): ...@@ -184,9 +184,8 @@ class OpenAISpeechToText(OpenAIServing):
request_id: str, request_id: str,
) -> tuple[list[EngineInput], float]: ) -> tuple[list[EngineInput], float]:
# Validate request # Validate request
language = self.model_cls.validate_language(request.language) request.language = self.model_cls.validate_language(request.language)
# Skip to_language validation to avoid extra logging for Whisper. request.to_language = (
to_language = (
self.model_cls.validate_language(request.to_language) self.model_cls.validate_language(request.to_language)
if request.to_language if request.to_language
else None else None
...@@ -229,28 +228,23 @@ class OpenAISpeechToText(OpenAIServing): ...@@ -229,28 +228,23 @@ class OpenAISpeechToText(OpenAIServing):
min_energy_window_size=self.asr_config.min_energy_split_window_size, min_energy_window_size=self.asr_config.min_energy_split_window_size,
) )
if language is None and getattr( if request.language is None and getattr(
self.model_cls, "supports_explicit_language_detection", False self.model_cls, "supports_explicit_language_detection", False
): ):
# Auto-detect language from the first chunk. # Auto-detect language from the first chunk.
language = await self._detect_language( request.language = await self._detect_language(
chunks[0], f"{request_id}-lang_detect" chunks[0], f"{request_id}-lang_detect"
) )
request.language = language
parsed_prompts: list[DictPrompt] = [] parsed_prompts: list[DictPrompt] = []
for chunk in chunks: for chunk in chunks:
# The model has control over the construction, as long as it stt_params = request.build_stt_params(
# returns a valid PromptType.
prompt = self.model_cls.get_generation_prompt(
audio=chunk, audio=chunk,
stt_config=self.asr_config, stt_config=self.asr_config,
model_config=self.model_config, model_config=self.model_config,
language=language,
task_type=self.task_type, task_type=self.task_type,
request_prompt=request.prompt,
to_language=to_language,
) )
prompt = self.model_cls.get_generation_prompt(stt_params)
parsed_prompt: DictPrompt parsed_prompt: DictPrompt
if request.response_format == "verbose_json": if request.response_format == "verbose_json":
......
...@@ -3,9 +3,7 @@ ...@@ -3,9 +3,7 @@
import math import math
from collections.abc import Iterable, Mapping, Sequence from collections.abc import Iterable, Mapping, Sequence
from typing import Literal
import numpy as np
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from torch import nn from torch import nn
...@@ -13,6 +11,8 @@ from transformers import PretrainedConfig ...@@ -13,6 +11,8 @@ from transformers import PretrainedConfig
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, ModelConfig, SpeechToTextConfig, VllmConfig from vllm.config import CacheConfig, ModelConfig, SpeechToTextConfig, VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.config.speech_to_text import SpeechToTextParams
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.inputs import MultiModalDataDict, PromptType, TextPrompt from vllm.inputs import MultiModalDataDict, PromptType, TextPrompt
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -1900,7 +1900,7 @@ class CohereASRDummyInputsBuilder(BaseDummyInputsBuilder[CohereASRProcessingInfo ...@@ -1900,7 +1900,7 @@ class CohereASRDummyInputsBuilder(BaseDummyInputsBuilder[CohereASRProcessingInfo
self, self,
seq_len: int, seq_len: int,
mm_counts: Mapping[str, int], mm_counts: Mapping[str, int],
mm_options=None, mm_options: Mapping[str, BaseDummyOptions],
mm_processor_kwargs=None, mm_processor_kwargs=None,
) -> MultiModalDataDict: ) -> MultiModalDataDict:
feature_extractor = self.info.get_feature_extractor() feature_extractor = self.info.get_feature_extractor()
...@@ -2021,16 +2021,12 @@ class CohereAsrForConditionalGeneration( ...@@ -2021,16 +2021,12 @@ class CohereAsrForConditionalGeneration(
return super().validate_language(language) return super().validate_language(language)
@classmethod @classmethod
def get_generation_prompt( def get_generation_prompt(cls, stt_params: SpeechToTextParams) -> PromptType:
cls, audio = stt_params.audio
audio: np.ndarray, stt_config = stt_params.stt_config
model_config: ModelConfig, # not needed here language = stt_params.language
stt_config: SpeechToTextConfig, request_prompt = stt_params.request_prompt
language: str | None,
task_type: Literal["transcribe", "translate"],
request_prompt: str,
to_language: str | None,
) -> PromptType:
if language is None: if language is None:
raise ValueError( raise ValueError(
"Language must be specified when creating the CohereASR prompt" "Language must be specified when creating the CohereASR prompt"
......
...@@ -2,9 +2,8 @@ ...@@ -2,9 +2,8 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math import math
from collections.abc import Iterable, Mapping, Sequence from collections.abc import Iterable, Mapping, Sequence
from typing import Annotated, Literal, cast from typing import Annotated, cast
import numpy as np
import torch import torch
from torch import nn from torch import nn
from transformers import ( from transformers import (
...@@ -14,6 +13,7 @@ from transformers import ( ...@@ -14,6 +13,7 @@ from transformers import (
from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig
from vllm.config.multimodal import BaseDummyOptions from vllm.config.multimodal import BaseDummyOptions
from vllm.config.speech_to_text import SpeechToTextParams
from vllm.inputs import MultiModalDataDict, PromptType from vllm.inputs import MultiModalDataDict, PromptType
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY
...@@ -356,14 +356,12 @@ class FireRedASR2ForConditionalGeneration( ...@@ -356,14 +356,12 @@ class FireRedASR2ForConditionalGeneration(
@classmethod @classmethod
def get_generation_prompt( def get_generation_prompt(
cls, cls,
audio: np.ndarray, stt_params: SpeechToTextParams,
model_config: ModelConfig, # not needed here
stt_config: SpeechToTextConfig,
language: str | None,
task_type: Literal["transcribe", "translate"],
request_prompt: str,
to_language: str | None,
) -> PromptType: ) -> PromptType:
audio = stt_params.audio
stt_config = stt_params.stt_config
language = stt_params.language
if language is None: if language is None:
raise ValueError( raise ValueError(
"Language must be specified when creating the fireredasr2 prompt" "Language must be specified when creating the fireredasr2 prompt"
......
...@@ -3,9 +3,8 @@ ...@@ -3,9 +3,8 @@
import math import math
from collections.abc import Iterable, Mapping, Sequence from collections.abc import Iterable, Mapping, Sequence
from typing import Annotated, Literal, cast from typing import Annotated, cast
import numpy as np
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from torch import nn from torch import nn
...@@ -16,6 +15,7 @@ from transformers import ( ...@@ -16,6 +15,7 @@ from transformers import (
from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig
from vllm.config.multimodal import BaseDummyOptions from vllm.config.multimodal import BaseDummyOptions
from vllm.config.speech_to_text import SpeechToTextParams
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.inputs import MultiModalDataDict, PromptType from vllm.inputs import MultiModalDataDict, PromptType
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -876,14 +876,12 @@ class FunASRForConditionalGeneration( ...@@ -876,14 +876,12 @@ class FunASRForConditionalGeneration(
@classmethod @classmethod
def get_generation_prompt( def get_generation_prompt(
cls, cls,
audio: np.ndarray, stt_params: SpeechToTextParams,
model_config: ModelConfig, # not needed here
stt_config: SpeechToTextConfig,
language: str | None,
task_type: Literal["transcribe", "translate"],
request_prompt: str,
to_language: str | None,
) -> PromptType: ) -> PromptType:
audio = stt_params.audio
stt_config = stt_params.stt_config
language = stt_params.language
if language is None: if language is None:
raise ValueError( raise ValueError(
"Language must be specified when creating the funasr prompt" "Language must be specified when creating the funasr prompt"
......
...@@ -3,7 +3,6 @@ ...@@ -3,7 +3,6 @@
from collections.abc import Iterable, Mapping, Sequence from collections.abc import Iterable, Mapping, Sequence
from typing import Annotated, Any, Literal from typing import Annotated, Any, Literal
import numpy as np
import torch import torch
from torch import nn from torch import nn
from transformers import AutoModel, BatchFeature from transformers import AutoModel, BatchFeature
...@@ -19,6 +18,7 @@ from transformers.models.siglip import SiglipImageProcessorFast ...@@ -19,6 +18,7 @@ from transformers.models.siglip import SiglipImageProcessorFast
from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig
from vllm.config.multimodal import BaseDummyOptions from vllm.config.multimodal import BaseDummyOptions
from vllm.config.speech_to_text import SpeechToTextParams
from vllm.inputs import MultiModalDataDict, PromptType, TextPrompt from vllm.inputs import MultiModalDataDict, PromptType, TextPrompt
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
...@@ -769,21 +769,17 @@ class Gemma3nForConditionalGeneration( ...@@ -769,21 +769,17 @@ class Gemma3nForConditionalGeneration(
raise ValueError(f"Unsupported modality: {modality}") raise ValueError(f"Unsupported modality: {modality}")
@classmethod @classmethod
def get_generation_prompt( def get_generation_prompt(cls, stt_params: SpeechToTextParams) -> PromptType:
cls,
audio: np.ndarray,
stt_config: SpeechToTextConfig,
model_config: ModelConfig,
language: str | None,
task_type: Literal["transcribe", "translate"],
request_prompt: str,
to_language: str | None,
) -> PromptType:
""" """
Gemma3n supports "free-form" transcription. Gemma3n supports "free-form" transcription.
We fix its prompt here to standardize transcriptions/translations We fix its prompt here to standardize transcriptions/translations
requests. requests.
""" """
audio = stt_params.audio
stt_config = stt_params.stt_config
language = stt_params.language
task_type = stt_params.task_type
to_language = stt_params.to_language
# Transcribe this audio [into <>] | for transcription # Transcribe this audio [into <>] | for transcription
# Translate this audio [from <> into <>] | for translation # Translate this audio [from <> into <>] | for translation
prompt = "<start_of_turn>user\n" prompt = "<start_of_turn>user\n"
......
...@@ -4,7 +4,6 @@ ...@@ -4,7 +4,6 @@
from collections.abc import Iterable, Mapping, Sequence from collections.abc import Iterable, Mapping, Sequence
from typing import Annotated, Any, Literal, TypeAlias from typing import Annotated, Any, Literal, TypeAlias
import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
from transformers import BatchFeature from transformers import BatchFeature
...@@ -13,6 +12,7 @@ from transformers.models.whisper import WhisperFeatureExtractor ...@@ -13,6 +12,7 @@ from transformers.models.whisper import WhisperFeatureExtractor
from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig
from vllm.config.multimodal import BaseDummyOptions from vllm.config.multimodal import BaseDummyOptions
from vllm.config.speech_to_text import SpeechToTextParams
from vllm.distributed.parallel_state import get_tensor_model_parallel_world_size from vllm.distributed.parallel_state import get_tensor_model_parallel_world_size
from vllm.inputs import ModalityData, MultiModalDataDict, PromptType, TokensPrompt from vllm.inputs import ModalityData, MultiModalDataDict, PromptType, TokensPrompt
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
...@@ -1131,17 +1131,12 @@ class GlmAsrForConditionalGeneration( ...@@ -1131,17 +1131,12 @@ class GlmAsrForConditionalGeneration(
) )
@classmethod @classmethod
def get_generation_prompt( def get_generation_prompt(cls, stt_params: SpeechToTextParams) -> PromptType:
cls,
audio: np.ndarray,
model_config: ModelConfig,
stt_config: SpeechToTextConfig,
language: str | None,
task_type: Literal["transcribe", "translate"],
request_prompt: str,
to_language: str | None,
) -> PromptType:
"""Get the generation prompt to be used for transcription requests.""" """Get the generation prompt to be used for transcription requests."""
audio = stt_params.audio
model_config = stt_params.model_config
task_type = stt_params.task_type
to_language = stt_params.to_language
tokenizer = cached_tokenizer_from_config(model_config) tokenizer = cached_tokenizer_from_config(model_config)
audio_token = cls._get_audio_token(model_config) audio_token = cls._get_audio_token(model_config)
......
...@@ -26,9 +26,8 @@ ...@@ -26,9 +26,8 @@
import math import math
from collections.abc import Iterable, Mapping from collections.abc import Iterable, Mapping
from typing import Annotated, Literal from typing import Annotated
import numpy as np
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from torch import nn from torch import nn
...@@ -36,6 +35,7 @@ from transformers import BatchFeature, PretrainedConfig ...@@ -36,6 +35,7 @@ from transformers import BatchFeature, PretrainedConfig
from vllm.config import CacheConfig, ModelConfig, SpeechToTextConfig, VllmConfig from vllm.config import CacheConfig, ModelConfig, SpeechToTextConfig, VllmConfig
from vllm.config.multimodal import BaseDummyOptions from vllm.config.multimodal import BaseDummyOptions
from vllm.config.speech_to_text import SpeechToTextParams
from vllm.inputs import MultiModalDataDict, PromptType, TokensPrompt from vllm.inputs import MultiModalDataDict, PromptType, TokensPrompt
from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
...@@ -852,15 +852,14 @@ class GraniteSpeechForConditionalGeneration( ...@@ -852,15 +852,14 @@ class GraniteSpeechForConditionalGeneration(
@classmethod @classmethod
def get_generation_prompt( def get_generation_prompt(
cls, cls,
audio: np.ndarray, stt_params: SpeechToTextParams,
model_config: ModelConfig,
stt_config: SpeechToTextConfig,
language: str | None,
task_type: Literal["transcribe", "translate"],
request_prompt: str,
to_language: str | None,
) -> PromptType: ) -> PromptType:
"""Get the generation prompt to be used for transcription requests.""" """Get the generation prompt to be used for transcription requests."""
audio = stt_params.audio
model_config = stt_params.model_config
task_type = stt_params.task_type
to_language = stt_params.to_language
# Audio placeholders don't use an index, so value doesn't matter # Audio placeholders don't use an index, so value doesn't matter
audio_tok = cls.get_placeholder_str("audio", 0) audio_tok = cls.get_placeholder_str("audio", 0)
......
...@@ -29,7 +29,7 @@ from torch import Tensor ...@@ -29,7 +29,7 @@ from torch import Tensor
from transformers.models.whisper.tokenization_whisper import LANGUAGES from transformers.models.whisper.tokenization_whisper import LANGUAGES
from typing_extensions import Self, TypeIs from typing_extensions import Self, TypeIs
from vllm.config import ModelConfig, SpeechToTextConfig from vllm.config import ModelConfig, SpeechToTextConfig, SpeechToTextParams
from vllm.inputs import PromptType, TokensPrompt from vllm.inputs import PromptType, TokensPrompt
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.mamba.mamba_utils import MambaStateCopyFunc from vllm.model_executor.layers.mamba.mamba_utils import MambaStateCopyFunc
...@@ -1119,13 +1119,7 @@ class SupportsTranscription(Protocol): ...@@ -1119,13 +1119,7 @@ class SupportsTranscription(Protocol):
@classmethod @classmethod
def get_generation_prompt( def get_generation_prompt(
cls, cls,
audio: np.ndarray, stt_params: SpeechToTextParams,
stt_config: SpeechToTextConfig,
model_config: ModelConfig,
language: str | None,
task_type: Literal["transcribe", "translate"],
request_prompt: str,
to_language: str | None,
) -> PromptType: ) -> PromptType:
"""Get the prompt for the ASR model. """Get the prompt for the ASR model.
The model has control over the construction, as long as it The model has control over the construction, as long as it
......
...@@ -14,6 +14,7 @@ from transformers import WhisperConfig as HFWhisperConfig ...@@ -14,6 +14,7 @@ from transformers import WhisperConfig as HFWhisperConfig
from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig
from vllm.config.multimodal import BaseDummyOptions from vllm.config.multimodal import BaseDummyOptions
from vllm.config.speech_to_text import SpeechToTextParams
from vllm.inputs import PromptType, TokensPrompt from vllm.inputs import PromptType, TokensPrompt
from vllm.model_executor.model_loader import DefaultModelLoader from vllm.model_executor.model_loader import DefaultModelLoader
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
...@@ -626,16 +627,12 @@ class KimiAudioForConditionalGeneration( ...@@ -626,16 +627,12 @@ class KimiAudioForConditionalGeneration(
) )
@classmethod @classmethod
def get_generation_prompt( def get_generation_prompt(cls, stt_params: SpeechToTextParams) -> PromptType:
cls, audio = stt_params.audio
audio: np.ndarray, model_config = stt_params.model_config
model_config: ModelConfig, task_type = stt_params.task_type
stt_config: SpeechToTextConfig, request_prompt = stt_params.request_prompt
language: str | None,
task_type: Literal["transcribe", "translate"],
request_prompt: str,
to_language: str | None,
) -> PromptType:
tokenizer = cached_get_tokenizer( tokenizer = cached_get_tokenizer(
model_config.tokenizer, model_config.tokenizer,
tokenizer_cls=KimiAudioTokenizer, tokenizer_cls=KimiAudioTokenizer,
......
...@@ -23,9 +23,8 @@ ...@@ -23,9 +23,8 @@
"""Inference-only Qwen3-ASR model.""" """Inference-only Qwen3-ASR model."""
from collections.abc import Iterable, Mapping, Sequence from collections.abc import Iterable, Mapping, Sequence
from typing import Any, Literal from typing import Any
import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
from transformers.feature_extraction_utils import BatchFeature from transformers.feature_extraction_utils import BatchFeature
...@@ -33,6 +32,7 @@ from transformers.models.whisper import WhisperFeatureExtractor ...@@ -33,6 +32,7 @@ from transformers.models.whisper import WhisperFeatureExtractor
from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig
from vllm.config.multimodal import BaseDummyOptions from vllm.config.multimodal import BaseDummyOptions
from vllm.config.speech_to_text import SpeechToTextParams
from vllm.inputs import ModalityData, MultiModalDataDict, PromptType, TokensPrompt from vllm.inputs import ModalityData, MultiModalDataDict, PromptType, TokensPrompt
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.models.interfaces import ( from vllm.model_executor.models.interfaces import (
...@@ -549,17 +549,12 @@ class Qwen3ASRForConditionalGeneration( ...@@ -549,17 +549,12 @@ class Qwen3ASRForConditionalGeneration(
) )
@classmethod @classmethod
def get_generation_prompt( def get_generation_prompt(cls, stt_params: SpeechToTextParams) -> PromptType:
cls,
audio: np.ndarray,
model_config: ModelConfig,
stt_config: SpeechToTextConfig,
language: str | None,
task_type: Literal["transcribe", "translate"],
request_prompt: str,
to_language: str | None,
) -> PromptType:
"""Get the generation prompt to be used for transcription requests.""" """Get the generation prompt to be used for transcription requests."""
audio = stt_params.audio
model_config = stt_params.model_config
task_type = stt_params.task_type
to_language = stt_params.to_language
tokenizer = cached_tokenizer_from_config(model_config) tokenizer = cached_tokenizer_from_config(model_config)
audio_placeholder = cls.get_placeholder_str("audio", 0) audio_placeholder = cls.get_placeholder_str("audio", 0)
......
...@@ -24,7 +24,7 @@ ...@@ -24,7 +24,7 @@
from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence
from functools import partial from functools import partial
from typing import Any, Literal, cast from typing import Any, cast
import numpy as np import numpy as np
import torch import torch
...@@ -46,6 +46,7 @@ from transformers.models.whisper import WhisperFeatureExtractor ...@@ -46,6 +46,7 @@ from transformers.models.whisper import WhisperFeatureExtractor
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig
from vllm.config.speech_to_text import SpeechToTextParams
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.inputs import PromptType from vllm.inputs import PromptType
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -2201,19 +2202,17 @@ class Qwen3OmniMoeThinkerForConditionalGeneration( ...@@ -2201,19 +2202,17 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
) )
@classmethod @classmethod
def get_generation_prompt( def get_generation_prompt(cls, stt_params: SpeechToTextParams) -> PromptType:
cls,
audio: np.ndarray,
stt_config: SpeechToTextConfig,
model_config: ModelConfig,
language: str | None,
task_type: Literal["transcribe", "translate"],
request_prompt: str,
to_language: str | None,
) -> PromptType:
""" """
Construct a transcription/translation prompt for Qwen3-Omni. Construct a transcription/translation prompt for Qwen3-Omni.
""" """
audio = stt_params.audio
stt_config = stt_params.stt_config
model_config = stt_params.model_config
language = stt_params.language
task_type = stt_params.task_type
to_language = stt_params.to_language
request_prompt = stt_params.request_prompt
# Transcribe this audio [into <language>] | for transcription # Transcribe this audio [into <language>] | for transcription
# Translate this audio [from <language> into <to_language>] | for translation # Translate this audio [from <language> into <to_language>] | for translation
instruction = "Transcribe" if task_type == "transcribe" else "Translate" instruction = "Transcribe" if task_type == "transcribe" else "Translate"
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
import math import math
from collections.abc import Iterable, Mapping, Sequence from collections.abc import Iterable, Mapping, Sequence
from functools import partial from functools import partial
from typing import Literal, cast from typing import cast
import numpy as np import numpy as np
import regex as re import regex as re
...@@ -19,6 +19,7 @@ from transformers import BatchFeature, WhisperConfig ...@@ -19,6 +19,7 @@ from transformers import BatchFeature, WhisperConfig
from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig
from vllm.config.multimodal import BaseDummyOptions from vllm.config.multimodal import BaseDummyOptions
from vllm.config.speech_to_text import SpeechToTextParams
from vllm.inputs import MultiModalDataDict, PromptType, TokensPrompt from vllm.inputs import MultiModalDataDict, PromptType, TokensPrompt
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
...@@ -446,14 +447,13 @@ class VoxtralForConditionalGeneration( ...@@ -446,14 +447,13 @@ class VoxtralForConditionalGeneration(
# for speech-to-text transcription # for speech-to-text transcription
def get_generation_prompt( def get_generation_prompt(
cls, cls,
audio: np.ndarray, stt_params: SpeechToTextParams,
model_config: ModelConfig,
stt_config: SpeechToTextConfig,
language: str | None,
task_type: Literal["transcribe", "translate"],
request_prompt: str,
to_language: str | None,
) -> PromptType: ) -> PromptType:
audio = stt_params.audio
model_config = stt_params.model_config
stt_config = stt_params.stt_config
language = stt_params.language
tokenizer = cached_tokenizer_from_config(model_config) tokenizer = cached_tokenizer_from_config(model_config)
audio = Audio(audio, int(stt_config.sample_rate), format="wav") # lossless audio = Audio(audio, int(stt_config.sample_rate), format="wav") # lossless
req = TranscriptionRequest( req = TranscriptionRequest(
......
...@@ -4,7 +4,6 @@ ...@@ -4,7 +4,6 @@
import asyncio import asyncio
import math import math
from collections.abc import AsyncGenerator, Iterable, Iterator, Mapping from collections.abc import AsyncGenerator, Iterable, Iterator, Mapping
from typing import Literal
import numpy as np import numpy as np
import torch import torch
...@@ -18,6 +17,7 @@ from mistral_common.tokens.tokenizers.audio import AudioConfig ...@@ -18,6 +17,7 @@ from mistral_common.tokens.tokenizers.audio import AudioConfig
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig
from vllm.config.speech_to_text import SpeechToTextParams
from vllm.engine.protocol import StreamingInput from vllm.engine.protocol import StreamingInput
from vllm.envs import VLLM_ENGINE_ITERATION_TIMEOUT_S from vllm.envs import VLLM_ENGINE_ITERATION_TIMEOUT_S
from vllm.inputs import PromptType, TokensPrompt from vllm.inputs import PromptType, TokensPrompt
...@@ -465,14 +465,13 @@ class VoxtralRealtimeGeneration(VoxtralForConditionalGeneration, SupportsRealtim ...@@ -465,14 +465,13 @@ class VoxtralRealtimeGeneration(VoxtralForConditionalGeneration, SupportsRealtim
# for speech-to-text transcription # for speech-to-text transcription
def get_generation_prompt( def get_generation_prompt(
cls, cls,
audio: np.ndarray, stt_params: SpeechToTextParams,
model_config: ModelConfig,
stt_config: SpeechToTextConfig,
language: str | None,
task_type: Literal["transcribe", "translate"],
request_prompt: str,
to_language: str | None,
) -> PromptType: ) -> PromptType:
audio = stt_params.audio
model_config = stt_params.model_config
stt_config = stt_params.stt_config
language = stt_params.language
tokenizer = cached_tokenizer_from_config(model_config) tokenizer = cached_tokenizer_from_config(model_config)
audio = Audio(audio, int(stt_config.sample_rate), format="wav") # lossless audio = Audio(audio, int(stt_config.sample_rate), format="wav") # lossless
......
...@@ -5,7 +5,7 @@ import enum ...@@ -5,7 +5,7 @@ import enum
import math import math
from collections.abc import Iterable, Mapping, Sequence from collections.abc import Iterable, Mapping, Sequence
from contextlib import nullcontext from contextlib import nullcontext
from typing import Annotated, Literal from typing import Annotated
import numpy as np import numpy as np
import torch import torch
...@@ -20,6 +20,7 @@ from transformers.models.whisper.modeling_whisper import sinusoids ...@@ -20,6 +20,7 @@ from transformers.models.whisper.modeling_whisper import sinusoids
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, ModelConfig, SpeechToTextConfig, VllmConfig from vllm.config import CacheConfig, ModelConfig, SpeechToTextConfig, VllmConfig
from vllm.config.multimodal import BaseDummyOptions from vllm.config.multimodal import BaseDummyOptions
from vllm.config.speech_to_text import SpeechToTextParams
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.inputs import ( from vllm.inputs import (
ExplicitEncoderDecoderPrompt, ExplicitEncoderDecoderPrompt,
...@@ -830,14 +831,14 @@ class WhisperForConditionalGeneration( ...@@ -830,14 +831,14 @@ class WhisperForConditionalGeneration(
@classmethod @classmethod
def get_generation_prompt( def get_generation_prompt(
cls, cls,
audio: np.ndarray, stt_params: SpeechToTextParams,
model_config: ModelConfig, # not needed here
stt_config: SpeechToTextConfig,
language: str | None,
task_type: Literal["transcribe", "translate"],
request_prompt: str,
to_language: str | None,
) -> PromptType: ) -> PromptType:
audio = stt_params.audio
stt_config = stt_params.stt_config
language = stt_params.language
task_type = stt_params.task_type
request_prompt = stt_params.request_prompt
if language is None: if language is None:
raise ValueError( raise ValueError(
"Language must be specified when creating the Whisper prompt" "Language must be specified when creating the Whisper prompt"
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment