Unverified commit 4eafc729, authored by Ekagra Ranjan and committed by GitHub
Browse files

[Audio] Bundle `get_generation_prompt()` params into `SpeechToTextParams` (#36268)


Signed-off-by: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
parent 6d097697
......@@ -66,7 +66,7 @@ This is for controlling general behavior of the API when serving your model:
See [Audio preprocessing and chunking](#audio-preprocessing-and-chunking) for what each field controls.
Implement the prompt construction via [get_generation_prompt][vllm.model_executor.models.interfaces.SupportsTranscription.get_generation_prompt]. The server passes you the resampled waveform and task parameters; you return a valid [PromptType][vllm.inputs.llm.PromptType]. There are two common patterns:
Implement the prompt construction via [get_generation_prompt][vllm.model_executor.models.interfaces.SupportsTranscription.get_generation_prompt]. The server builds a [SpeechToTextParams][vllm.config.speech_to_text.SpeechToTextParams] object that bundles the resampled waveform, task parameters, and request-specific options. Your model receives this single object and returns a valid [PromptType][vllm.inputs.llm.PromptType]. There are two common patterns:
#### Multimodal LLM with audio embeddings (e.g., Voxtral, Gemma3n)
......@@ -75,21 +75,20 @@ Return a dict containing `multi_modal_data` with the audio, and either a `prompt
??? code "get_generation_prompt()"
```python
from vllm.config.speech_to_text import SpeechToTextParams
class YourASRModel(nn.Module, SupportsTranscription):
...
@classmethod
def get_generation_prompt(
cls,
audio: np.ndarray,
stt_config: SpeechToTextConfig,
model_config: ModelConfig,
language: str | None,
task_type: Literal["transcribe", "translate"],
request_prompt: str,
to_language: str | None,
stt_params: SpeechToTextParams,
) -> PromptType:
# Example with a free-form instruction prompt
audio = stt_params.audio
stt_config = stt_params.stt_config
task_type = stt_params.task_type
task_word = "Transcribe" if task_type == "transcribe" else "Translate"
prompt = (
"<start_of_turn>user\n"
......@@ -112,20 +111,22 @@ Return a dict with separate `encoder_prompt` and `decoder_prompt` entries:
??? code "get_generation_prompt()"
```python
from vllm.config.speech_to_text import SpeechToTextParams
class YourASRModel(nn.Module, SupportsTranscription):
...
@classmethod
def get_generation_prompt(
cls,
audio: np.ndarray,
stt_config: SpeechToTextConfig,
model_config: ModelConfig,
language: str | None,
task_type: Literal["transcribe", "translate"],
request_prompt: str,
to_language: str | None,
stt_params: SpeechToTextParams,
) -> PromptType:
audio = stt_params.audio
stt_config = stt_params.stt_config
language = stt_params.language
task_type = stt_params.task_type
request_prompt = stt_params.request_prompt
if language is None:
raise ValueError("Language must be specified")
......@@ -213,15 +214,13 @@ Relevant server logic:
chunks = [y] if not do_split_audio else self._split_audio(y, int(sr))
prompts = []
for chunk in chunks:
prompt = self.model_cls.get_generation_prompt(
stt_params = request.build_stt_params(
audio=chunk,
stt_config=self.asr_config,
model_config=self.model_config,
language=language,
task_type=self.task_type,
request_prompt=request.prompt,
to_language=to_language,
)
prompt = self.model_cls.get_generation_prompt(stt_params)
prompts.append(prompt)
return prompts, duration
```
......
......@@ -37,7 +37,7 @@ from vllm.config.profiler import ProfilerConfig
from vllm.config.reasoning import ReasoningConfig
from vllm.config.scheduler import SchedulerConfig
from vllm.config.speculative import SpeculativeConfig
from vllm.config.speech_to_text import SpeechToTextConfig
from vllm.config.speech_to_text import SpeechToTextConfig, SpeechToTextParams
from vllm.config.structured_outputs import StructuredOutputsConfig
from vllm.config.utils import (
ConfigType,
......@@ -113,6 +113,7 @@ __all__ = [
"SpeculativeConfig",
# From vllm.config.speech_to_text
"SpeechToTextConfig",
"SpeechToTextParams",
# From vllm.config.structured_outputs
"StructuredOutputsConfig",
# From vllm.config.profiler
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
from dataclasses import dataclass
from typing import TYPE_CHECKING
from vllm.config.utils import config
if TYPE_CHECKING:
import numpy as np
from vllm.config.model import ModelConfig
@dataclass
class SpeechToTextParams:
    """Single argument bundle handed to ``get_generation_prompt()``.

    Instances are produced by ``TranscriptionRequest.build_stt_params()``,
    which copies API-level fields into the typed attributes below. Since a
    model is given nothing but this object, a new parameter can be introduced
    as an extra field here while the ``get_generation_prompt`` signature
    stays exactly as it is.
    """

    # --- Required fields: always supplied by the caller. ---

    audio: np.ndarray
    """Resampled audio waveform for a single chunk."""

    stt_config: SpeechToTextConfig
    """Server-level speech-to-text configuration."""

    model_config: ModelConfig
    """Model configuration."""

    # --- Optional fields: fall back to the defaults shown. ---

    language: str | None = None
    """ISO 639-1 language code (validated / auto-detected)."""

    task_type: str = "transcribe"
    """``"transcribe"`` or ``"translate"``."""

    request_prompt: str = ""
    """Optional text prompt to guide the model."""

    to_language: str | None = None
    """Target language for translation (model-dependent)."""
@config
class SpeechToTextConfig:
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
import time
from http import HTTPStatus
from typing import Literal, TypeAlias
from typing import TYPE_CHECKING, Literal, TypeAlias
import torch
from fastapi import HTTPException, UploadFile
......@@ -12,6 +13,7 @@ from pydantic import (
model_validator,
)
from vllm.config.speech_to_text import SpeechToTextParams
from vllm.entrypoints.openai.engine.protocol import (
DeltaMessage,
OpenAIBaseModel,
......@@ -26,6 +28,11 @@ from vllm.sampling_params import (
)
from vllm.utils import random_uuid
if TYPE_CHECKING:
import numpy as np
from vllm.config import ModelConfig, SpeechToTextConfig
logger = init_logger(__name__)
_LONG_INFO = torch.iinfo(torch.long)
......@@ -183,6 +190,23 @@ class TranscriptionRequest(OpenAIBaseModel):
"min_p": 0.0,
}
def build_stt_params(
    self,
    audio: "np.ndarray",
    stt_config: "SpeechToTextConfig",
    model_config: "ModelConfig",
    task_type: str,
) -> SpeechToTextParams:
    """Combine caller-supplied values with this request's own fields.

    Args:
        audio: Audio chunk to store on the resulting params object.
        stt_config: Speech-to-text configuration to store as-is.
        model_config: Model configuration to store as-is.
        task_type: Task identifier to store as-is.

    Returns:
        A ``SpeechToTextParams`` whose ``language``, ``request_prompt`` and
        ``to_language`` are read from this request (``self.language``,
        ``self.prompt`` and ``self.to_language`` respectively).
    """
    # Fields that originate from the API request rather than the caller.
    request_fields = {
        "language": self.language,
        "request_prompt": self.prompt,
        "to_language": self.to_language,
    }
    return SpeechToTextParams(
        audio=audio,
        stt_config=stt_config,
        model_config=model_config,
        task_type=task_type,
        **request_fields,
    )
def to_beam_search_params(
self,
default_max_tokens: int,
......@@ -277,6 +301,17 @@ class TranscriptionRequest(OpenAIBaseModel):
parameter=invalid_param,
)
# Parse vllm_xargs from JSON string (form data sends it as a string)
xargs = data.get("vllm_xargs")
if isinstance(xargs, str):
try:
data["vllm_xargs"] = json.loads(xargs)
except json.JSONDecodeError as e:
raise VLLMValidationError(
f"Failed to parse vllm_xargs. Must be valid JSON: {e}",
parameter="vllm_xargs",
) from e
return data
......@@ -472,6 +507,23 @@ class TranslationRequest(OpenAIBaseModel):
"temperature": 0,
}
def build_stt_params(
    self,
    audio: "np.ndarray",
    stt_config: "SpeechToTextConfig",
    model_config: "ModelConfig",
    task_type: str,
) -> SpeechToTextParams:
    """Create the ``SpeechToTextParams`` object for one audio chunk.

    The four arguments are forwarded unchanged; the remaining fields are
    taken from this request.

    Args:
        audio: Audio chunk to place on the params object.
        stt_config: Speech-to-text configuration, forwarded unchanged.
        model_config: Model configuration, forwarded unchanged.
        task_type: Task identifier, forwarded unchanged.

    Returns:
        The populated ``SpeechToTextParams`` instance.
    """
    # Read the request-owned values first, then assemble the bundle.
    source_language = self.language
    prompt_text = self.prompt
    target_language = self.to_language
    return SpeechToTextParams(
        audio=audio,
        stt_config=stt_config,
        model_config=model_config,
        language=source_language,
        task_type=task_type,
        request_prompt=prompt_text,
        to_language=target_language,
    )
def to_beam_search_params(
self,
default_max_tokens: int,
......
......@@ -184,9 +184,8 @@ class OpenAISpeechToText(OpenAIServing):
request_id: str,
) -> tuple[list[EngineInput], float]:
# Validate request
language = self.model_cls.validate_language(request.language)
# Skip to_language validation to avoid extra logging for Whisper.
to_language = (
request.language = self.model_cls.validate_language(request.language)
request.to_language = (
self.model_cls.validate_language(request.to_language)
if request.to_language
else None
......@@ -229,28 +228,23 @@ class OpenAISpeechToText(OpenAIServing):
min_energy_window_size=self.asr_config.min_energy_split_window_size,
)
if language is None and getattr(
if request.language is None and getattr(
self.model_cls, "supports_explicit_language_detection", False
):
# Auto-detect language from the first chunk.
language = await self._detect_language(
request.language = await self._detect_language(
chunks[0], f"{request_id}-lang_detect"
)
request.language = language
parsed_prompts: list[DictPrompt] = []
for chunk in chunks:
# The model has control over the construction, as long as it
# returns a valid PromptType.
prompt = self.model_cls.get_generation_prompt(
stt_params = request.build_stt_params(
audio=chunk,
stt_config=self.asr_config,
model_config=self.model_config,
language=language,
task_type=self.task_type,
request_prompt=request.prompt,
to_language=to_language,
)
prompt = self.model_cls.get_generation_prompt(stt_params)
parsed_prompt: DictPrompt
if request.response_format == "verbose_json":
......
......@@ -3,9 +3,7 @@
import math
from collections.abc import Iterable, Mapping, Sequence
from typing import Literal
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
......@@ -13,6 +11,8 @@ from transformers import PretrainedConfig
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, ModelConfig, SpeechToTextConfig, VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.config.speech_to_text import SpeechToTextParams
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.inputs import MultiModalDataDict, PromptType, TextPrompt
from vllm.logger import init_logger
......@@ -1900,7 +1900,7 @@ class CohereASRDummyInputsBuilder(BaseDummyInputsBuilder[CohereASRProcessingInfo
self,
seq_len: int,
mm_counts: Mapping[str, int],
mm_options=None,
mm_options: Mapping[str, BaseDummyOptions],
mm_processor_kwargs=None,
) -> MultiModalDataDict:
feature_extractor = self.info.get_feature_extractor()
......@@ -2021,16 +2021,12 @@ class CohereAsrForConditionalGeneration(
return super().validate_language(language)
@classmethod
def get_generation_prompt(
cls,
audio: np.ndarray,
model_config: ModelConfig, # not needed here
stt_config: SpeechToTextConfig,
language: str | None,
task_type: Literal["transcribe", "translate"],
request_prompt: str,
to_language: str | None,
) -> PromptType:
def get_generation_prompt(cls, stt_params: SpeechToTextParams) -> PromptType:
audio = stt_params.audio
stt_config = stt_params.stt_config
language = stt_params.language
request_prompt = stt_params.request_prompt
if language is None:
raise ValueError(
"Language must be specified when creating the CohereASR prompt"
......
......@@ -2,9 +2,8 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math
from collections.abc import Iterable, Mapping, Sequence
from typing import Annotated, Literal, cast
from typing import Annotated, cast
import numpy as np
import torch
from torch import nn
from transformers import (
......@@ -14,6 +13,7 @@ from transformers import (
from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.config.speech_to_text import SpeechToTextParams
from vllm.inputs import MultiModalDataDict, PromptType
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY
......@@ -356,14 +356,12 @@ class FireRedASR2ForConditionalGeneration(
@classmethod
def get_generation_prompt(
cls,
audio: np.ndarray,
model_config: ModelConfig, # not needed here
stt_config: SpeechToTextConfig,
language: str | None,
task_type: Literal["transcribe", "translate"],
request_prompt: str,
to_language: str | None,
stt_params: SpeechToTextParams,
) -> PromptType:
audio = stt_params.audio
stt_config = stt_params.stt_config
language = stt_params.language
if language is None:
raise ValueError(
"Language must be specified when creating the fireredasr2 prompt"
......
......@@ -3,9 +3,8 @@
import math
from collections.abc import Iterable, Mapping, Sequence
from typing import Annotated, Literal, cast
from typing import Annotated, cast
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
......@@ -16,6 +15,7 @@ from transformers import (
from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.config.speech_to_text import SpeechToTextParams
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.inputs import MultiModalDataDict, PromptType
from vllm.logger import init_logger
......@@ -876,14 +876,12 @@ class FunASRForConditionalGeneration(
@classmethod
def get_generation_prompt(
cls,
audio: np.ndarray,
model_config: ModelConfig, # not needed here
stt_config: SpeechToTextConfig,
language: str | None,
task_type: Literal["transcribe", "translate"],
request_prompt: str,
to_language: str | None,
stt_params: SpeechToTextParams,
) -> PromptType:
audio = stt_params.audio
stt_config = stt_params.stt_config
language = stt_params.language
if language is None:
raise ValueError(
"Language must be specified when creating the funasr prompt"
......
......@@ -3,7 +3,6 @@
from collections.abc import Iterable, Mapping, Sequence
from typing import Annotated, Any, Literal
import numpy as np
import torch
from torch import nn
from transformers import AutoModel, BatchFeature
......@@ -19,6 +18,7 @@ from transformers.models.siglip import SiglipImageProcessorFast
from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.config.speech_to_text import SpeechToTextParams
from vllm.inputs import MultiModalDataDict, PromptType, TextPrompt
from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import RMSNorm
......@@ -769,21 +769,17 @@ class Gemma3nForConditionalGeneration(
raise ValueError(f"Unsupported modality: {modality}")
@classmethod
def get_generation_prompt(
cls,
audio: np.ndarray,
stt_config: SpeechToTextConfig,
model_config: ModelConfig,
language: str | None,
task_type: Literal["transcribe", "translate"],
request_prompt: str,
to_language: str | None,
) -> PromptType:
def get_generation_prompt(cls, stt_params: SpeechToTextParams) -> PromptType:
"""
Gemma3n supports "free-form" transcription.
We fix its prompt here to standardize transcriptions/translations
requests.
"""
audio = stt_params.audio
stt_config = stt_params.stt_config
language = stt_params.language
task_type = stt_params.task_type
to_language = stt_params.to_language
# Transcribe this audio [into <>] | for transcription
# Translate this audio [from <> into <>] | for translation
prompt = "<start_of_turn>user\n"
......
......@@ -4,7 +4,6 @@
from collections.abc import Iterable, Mapping, Sequence
from typing import Annotated, Any, Literal, TypeAlias
import numpy as np
import torch
import torch.nn as nn
from transformers import BatchFeature
......@@ -13,6 +12,7 @@ from transformers.models.whisper import WhisperFeatureExtractor
from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.config.speech_to_text import SpeechToTextParams
from vllm.distributed.parallel_state import get_tensor_model_parallel_world_size
from vllm.inputs import ModalityData, MultiModalDataDict, PromptType, TokensPrompt
from vllm.model_executor.layers.activation import get_act_fn
......@@ -1131,17 +1131,12 @@ class GlmAsrForConditionalGeneration(
)
@classmethod
def get_generation_prompt(
cls,
audio: np.ndarray,
model_config: ModelConfig,
stt_config: SpeechToTextConfig,
language: str | None,
task_type: Literal["transcribe", "translate"],
request_prompt: str,
to_language: str | None,
) -> PromptType:
def get_generation_prompt(cls, stt_params: SpeechToTextParams) -> PromptType:
"""Get the generation prompt to be used for transcription requests."""
audio = stt_params.audio
model_config = stt_params.model_config
task_type = stt_params.task_type
to_language = stt_params.to_language
tokenizer = cached_tokenizer_from_config(model_config)
audio_token = cls._get_audio_token(model_config)
......
......@@ -26,9 +26,8 @@
import math
from collections.abc import Iterable, Mapping
from typing import Annotated, Literal
from typing import Annotated
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
......@@ -36,6 +35,7 @@ from transformers import BatchFeature, PretrainedConfig
from vllm.config import CacheConfig, ModelConfig, SpeechToTextConfig, VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.config.speech_to_text import SpeechToTextParams
from vllm.inputs import MultiModalDataDict, PromptType, TokensPrompt
from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear
from vllm.model_executor.layers.quantization import QuantizationConfig
......@@ -852,15 +852,14 @@ class GraniteSpeechForConditionalGeneration(
@classmethod
def get_generation_prompt(
cls,
audio: np.ndarray,
model_config: ModelConfig,
stt_config: SpeechToTextConfig,
language: str | None,
task_type: Literal["transcribe", "translate"],
request_prompt: str,
to_language: str | None,
stt_params: SpeechToTextParams,
) -> PromptType:
"""Get the generation prompt to be used for transcription requests."""
audio = stt_params.audio
model_config = stt_params.model_config
task_type = stt_params.task_type
to_language = stt_params.to_language
# Audio placeholders don't use an index, so value doesn't matter
audio_tok = cls.get_placeholder_str("audio", 0)
......
......@@ -29,7 +29,7 @@ from torch import Tensor
from transformers.models.whisper.tokenization_whisper import LANGUAGES
from typing_extensions import Self, TypeIs
from vllm.config import ModelConfig, SpeechToTextConfig
from vllm.config import ModelConfig, SpeechToTextConfig, SpeechToTextParams
from vllm.inputs import PromptType, TokensPrompt
from vllm.logger import init_logger
from vllm.model_executor.layers.mamba.mamba_utils import MambaStateCopyFunc
......@@ -1119,13 +1119,7 @@ class SupportsTranscription(Protocol):
@classmethod
def get_generation_prompt(
cls,
audio: np.ndarray,
stt_config: SpeechToTextConfig,
model_config: ModelConfig,
language: str | None,
task_type: Literal["transcribe", "translate"],
request_prompt: str,
to_language: str | None,
stt_params: SpeechToTextParams,
) -> PromptType:
"""Get the prompt for the ASR model.
The model has control over the construction, as long as it
......
......@@ -14,6 +14,7 @@ from transformers import WhisperConfig as HFWhisperConfig
from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.config.speech_to_text import SpeechToTextParams
from vllm.inputs import PromptType, TokensPrompt
from vllm.model_executor.model_loader import DefaultModelLoader
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
......@@ -626,16 +627,12 @@ class KimiAudioForConditionalGeneration(
)
@classmethod
def get_generation_prompt(
cls,
audio: np.ndarray,
model_config: ModelConfig,
stt_config: SpeechToTextConfig,
language: str | None,
task_type: Literal["transcribe", "translate"],
request_prompt: str,
to_language: str | None,
) -> PromptType:
def get_generation_prompt(cls, stt_params: SpeechToTextParams) -> PromptType:
audio = stt_params.audio
model_config = stt_params.model_config
task_type = stt_params.task_type
request_prompt = stt_params.request_prompt
tokenizer = cached_get_tokenizer(
model_config.tokenizer,
tokenizer_cls=KimiAudioTokenizer,
......
......@@ -23,9 +23,8 @@
"""Inference-only Qwen3-ASR model."""
from collections.abc import Iterable, Mapping, Sequence
from typing import Any, Literal
from typing import Any
import numpy as np
import torch
import torch.nn as nn
from transformers.feature_extraction_utils import BatchFeature
......@@ -33,6 +32,7 @@ from transformers.models.whisper import WhisperFeatureExtractor
from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.config.speech_to_text import SpeechToTextParams
from vllm.inputs import ModalityData, MultiModalDataDict, PromptType, TokensPrompt
from vllm.logger import init_logger
from vllm.model_executor.models.interfaces import (
......@@ -549,17 +549,12 @@ class Qwen3ASRForConditionalGeneration(
)
@classmethod
def get_generation_prompt(
cls,
audio: np.ndarray,
model_config: ModelConfig,
stt_config: SpeechToTextConfig,
language: str | None,
task_type: Literal["transcribe", "translate"],
request_prompt: str,
to_language: str | None,
) -> PromptType:
def get_generation_prompt(cls, stt_params: SpeechToTextParams) -> PromptType:
"""Get the generation prompt to be used for transcription requests."""
audio = stt_params.audio
model_config = stt_params.model_config
task_type = stt_params.task_type
to_language = stt_params.to_language
tokenizer = cached_tokenizer_from_config(model_config)
audio_placeholder = cls.get_placeholder_str("audio", 0)
......
......@@ -24,7 +24,7 @@
from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence
from functools import partial
from typing import Any, Literal, cast
from typing import Any, cast
import numpy as np
import torch
......@@ -46,6 +46,7 @@ from transformers.models.whisper import WhisperFeatureExtractor
from vllm.compilation.decorators import support_torch_compile
from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig
from vllm.config.speech_to_text import SpeechToTextParams
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.inputs import PromptType
from vllm.logger import init_logger
......@@ -2201,19 +2202,17 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
)
@classmethod
def get_generation_prompt(
cls,
audio: np.ndarray,
stt_config: SpeechToTextConfig,
model_config: ModelConfig,
language: str | None,
task_type: Literal["transcribe", "translate"],
request_prompt: str,
to_language: str | None,
) -> PromptType:
def get_generation_prompt(cls, stt_params: SpeechToTextParams) -> PromptType:
"""
Construct a transcription/translation prompt for Qwen3-Omni.
"""
audio = stt_params.audio
stt_config = stt_params.stt_config
model_config = stt_params.model_config
language = stt_params.language
task_type = stt_params.task_type
to_language = stt_params.to_language
request_prompt = stt_params.request_prompt
# Transcribe this audio [into <language>] | for transcription
# Translate this audio [from <language> into <to_language>] | for translation
instruction = "Transcribe" if task_type == "transcribe" else "Translate"
......
......@@ -4,7 +4,7 @@
import math
from collections.abc import Iterable, Mapping, Sequence
from functools import partial
from typing import Literal, cast
from typing import cast
import numpy as np
import regex as re
......@@ -19,6 +19,7 @@ from transformers import BatchFeature, WhisperConfig
from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.config.speech_to_text import SpeechToTextParams
from vllm.inputs import MultiModalDataDict, PromptType, TokensPrompt
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import QuantizationConfig
......@@ -446,14 +447,13 @@ class VoxtralForConditionalGeneration(
# for speech-to-text transcription
def get_generation_prompt(
cls,
audio: np.ndarray,
model_config: ModelConfig,
stt_config: SpeechToTextConfig,
language: str | None,
task_type: Literal["transcribe", "translate"],
request_prompt: str,
to_language: str | None,
stt_params: SpeechToTextParams,
) -> PromptType:
audio = stt_params.audio
model_config = stt_params.model_config
stt_config = stt_params.stt_config
language = stt_params.language
tokenizer = cached_tokenizer_from_config(model_config)
audio = Audio(audio, int(stt_config.sample_rate), format="wav") # lossless
req = TranscriptionRequest(
......
......@@ -4,7 +4,6 @@
import asyncio
import math
from collections.abc import AsyncGenerator, Iterable, Iterator, Mapping
from typing import Literal
import numpy as np
import torch
......@@ -18,6 +17,7 @@ from mistral_common.tokens.tokenizers.audio import AudioConfig
from vllm.compilation.decorators import support_torch_compile
from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig
from vllm.config.speech_to_text import SpeechToTextParams
from vllm.engine.protocol import StreamingInput
from vllm.envs import VLLM_ENGINE_ITERATION_TIMEOUT_S
from vllm.inputs import PromptType, TokensPrompt
......@@ -465,14 +465,13 @@ class VoxtralRealtimeGeneration(VoxtralForConditionalGeneration, SupportsRealtim
# for speech-to-text transcription
def get_generation_prompt(
cls,
audio: np.ndarray,
model_config: ModelConfig,
stt_config: SpeechToTextConfig,
language: str | None,
task_type: Literal["transcribe", "translate"],
request_prompt: str,
to_language: str | None,
stt_params: SpeechToTextParams,
) -> PromptType:
audio = stt_params.audio
model_config = stt_params.model_config
stt_config = stt_params.stt_config
language = stt_params.language
tokenizer = cached_tokenizer_from_config(model_config)
audio = Audio(audio, int(stt_config.sample_rate), format="wav") # lossless
......
......@@ -5,7 +5,7 @@ import enum
import math
from collections.abc import Iterable, Mapping, Sequence
from contextlib import nullcontext
from typing import Annotated, Literal
from typing import Annotated
import numpy as np
import torch
......@@ -20,6 +20,7 @@ from transformers.models.whisper.modeling_whisper import sinusoids
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, ModelConfig, SpeechToTextConfig, VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.config.speech_to_text import SpeechToTextParams
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.inputs import (
ExplicitEncoderDecoderPrompt,
......@@ -830,14 +831,14 @@ class WhisperForConditionalGeneration(
@classmethod
def get_generation_prompt(
cls,
audio: np.ndarray,
model_config: ModelConfig, # not needed here
stt_config: SpeechToTextConfig,
language: str | None,
task_type: Literal["transcribe", "translate"],
request_prompt: str,
to_language: str | None,
stt_params: SpeechToTextParams,
) -> PromptType:
audio = stt_params.audio
stt_config = stt_params.stt_config
language = stt_params.language
task_type = stt_params.task_type
request_prompt = stt_params.request_prompt
if language is None:
raise ValueError(
"Language must be specified when creating the Whisper prompt"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment