Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
3c7d942d
Unverified
Commit
3c7d942d
authored
Jul 12, 2025
by
Nicolò Lucchesi
Committed by
GitHub
Jul 11, 2025
Browse files
[Frontend] Abstract prompt and SpeechToTextConfig for transcriptions models (#20637)
Signed-off-by:
NickLucche
<
nlucches@redhat.com
>
parent
890323dc
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
141 additions
and
60 deletions
+141
-60
vllm/config.py
vllm/config.py
+31
-0
vllm/entrypoints/openai/speech_to_text.py
vllm/entrypoints/openai/speech_to_text.py
+33
-50
vllm/model_executor/models/interfaces.py
vllm/model_executor/models/interfaces.py
+29
-3
vllm/model_executor/models/whisper.py
vllm/model_executor/models/whisper.py
+48
-7
No files found.
vllm/config.py
View file @
3c7d942d
...
@@ -4958,3 +4958,34 @@ def get_layers_from_vllm_config(vllm_config: VllmConfig,
...
@@ -4958,3 +4958,34 @@ def get_layers_from_vllm_config(vllm_config: VllmConfig,
vllm_config
.
compilation_config
.
static_forward_context
.
items
()
vllm_config
.
compilation_config
.
static_forward_context
.
items
()
if
isinstance
(
layer
,
layer_type
)
if
isinstance
(
layer
,
layer_type
)
}
}
@
config
@
dataclass
class
SpeechToTextConfig
:
"""Configuration for speech-to-text models."""
sample_rate
:
float
=
16_000
"""Sample rate (Hz) to resample input audio to. Most speech models expect
16kHz audio input. The input audio will be automatically resampled to this
rate before processing."""
max_audio_clip_s
:
int
=
30
"""Maximum duration in seconds for a single audio clip without chunking.
Audio longer than this will be split into smaller chunks if
`allow_audio_chunking` evaluates to True, otherwise it will be rejected."""
overlap_chunk_second
:
int
=
1
"""Overlap duration in seconds between consecutive audio chunks when
splitting long audio. This helps maintain context across chunk boundaries
and improves transcription quality at split points."""
min_energy_split_window_size
:
Optional
[
int
]
=
1600
"""Window size in samples for finding low-energy (quiet) regions to split
audio chunks. The algorithm looks for the quietest moment within this
window to minimize cutting through speech. Default 1600 samples ≈ 100ms
at 16kHz. If None, no chunking will be done."""
@
property
def
allow_audio_chunking
(
self
)
->
bool
:
return
self
.
min_energy_split_window_size
is
not
None
\ No newline at end of file
vllm/entrypoints/openai/speech_to_text.py
View file @
3c7d942d
...
@@ -6,7 +6,6 @@ import math
...
@@ -6,7 +6,6 @@ import math
import
time
import
time
from
collections.abc
import
AsyncGenerator
from
collections.abc
import
AsyncGenerator
from
functools
import
cached_property
from
functools
import
cached_property
from
math
import
ceil
from
typing
import
Callable
,
Literal
,
Optional
,
TypeVar
,
Union
,
cast
from
typing
import
Callable
,
Literal
,
Optional
,
TypeVar
,
Union
,
cast
import
numpy
as
np
import
numpy
as
np
...
@@ -28,7 +27,6 @@ from vllm.logger import init_logger
...
@@ -28,7 +27,6 @@ from vllm.logger import init_logger
from
vllm.model_executor.model_loader
import
get_model_cls
from
vllm.model_executor.model_loader
import
get_model_cls
from
vllm.model_executor.models
import
SupportsTranscription
from
vllm.model_executor.models
import
SupportsTranscription
from
vllm.outputs
import
RequestOutput
from
vllm.outputs
import
RequestOutput
from
vllm.transformers_utils.processor
import
cached_get_processor
from
vllm.utils
import
PlaceholderModule
from
vllm.utils
import
PlaceholderModule
try
:
try
:
...
@@ -44,9 +42,6 @@ logger = init_logger(__name__)
...
@@ -44,9 +42,6 @@ logger = init_logger(__name__)
# As per https://platform.openai.com/docs/guides/speech-to-text#overview.
# As per https://platform.openai.com/docs/guides/speech-to-text#overview.
# TODO configurable
# TODO configurable
MAX_AUDIO_CLIP_FILESIZE_MB
=
25
MAX_AUDIO_CLIP_FILESIZE_MB
=
25
MAX_AUDIO_CLIP_SECONDS
=
30
OVERLAP_CHUNK_SECOND
=
1
MIN_ENERGY_WINDOW_SIZE
=
1600
# 1600 ~ 100ms for 16000 Hz audio
class
OpenAISpeechToText
(
OpenAIServing
):
class
OpenAISpeechToText
(
OpenAIServing
):
...
@@ -71,36 +66,32 @@ class OpenAISpeechToText(OpenAIServing):
...
@@ -71,36 +66,32 @@ class OpenAISpeechToText(OpenAIServing):
self
.
default_sampling_params
=
(
self
.
default_sampling_params
=
(
self
.
model_config
.
get_diff_sampling_param
())
self
.
model_config
.
get_diff_sampling_param
())
processor
=
cached_get_processor
(
model_config
.
model
)
self
.
max_audio_clip_s
=
processor
.
feature_extractor
.
chunk_length
\
if
hasattr
(
processor
.
feature_extractor
,
'chunk_length'
)
\
else
MAX_AUDIO_CLIP_SECONDS
self
.
model_sr
=
processor
.
feature_extractor
.
sampling_rate
self
.
hop_length
=
processor
.
feature_extractor
.
hop_length
self
.
task_type
=
task_type
self
.
task_type
=
task_type
self
.
asr_config
=
self
.
model_cls
.
get_speech_to_text_config
(
model_config
,
task_type
)
if
self
.
default_sampling_params
:
if
self
.
default_sampling_params
:
logger
.
info
(
logger
.
info
(
"Overwriting default completion sampling param with: %s"
,
"Overwriting default completion sampling param with: %s"
,
self
.
default_sampling_params
)
self
.
default_sampling_params
)
@
cached_property
@
cached_property
def
model_cls
(
self
):
def
model_cls
(
self
)
->
type
[
SupportsTranscription
]:
return
get_model_cls
(
self
.
model_config
)
model_cls
=
get_model_cls
(
self
.
model_config
)
return
cast
(
type
[
SupportsTranscription
],
model_cls
)
async
def
_preprocess_speech_to_text
(
async
def
_preprocess_speech_to_text
(
self
,
self
,
request
:
SpeechToTextRequest
,
request
:
SpeechToTextRequest
,
audio_data
:
bytes
,
audio_data
:
bytes
,
)
->
tuple
[
list
[
PromptType
],
float
]:
)
->
tuple
[
list
[
PromptType
],
float
]:
model_cls
=
cast
(
SupportsTranscription
,
self
.
model_cls
)
# Validate request
# Validate request
# TODO language should be optional and can be guessed.
# TODO language should be optional and can be guessed.
# For now we default to en. See
# For now we default to en. See
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/generation_whisper.py#L1520
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/generation_whisper.py#L1520
lang
=
request
.
language
or
"en"
lang
=
request
.
language
or
"en"
model_cls
.
validate_language
(
lang
)
self
.
model_cls
.
validate_language
(
lang
)
if
len
(
audio_data
)
/
1024
**
2
>
MAX_AUDIO_CLIP_FILESIZE_MB
:
if
len
(
audio_data
)
/
1024
**
2
>
MAX_AUDIO_CLIP_FILESIZE_MB
:
raise
ValueError
(
"Maximum file size exceeded."
)
raise
ValueError
(
"Maximum file size exceeded."
)
...
@@ -108,26 +99,23 @@ class OpenAISpeechToText(OpenAIServing):
...
@@ -108,26 +99,23 @@ class OpenAISpeechToText(OpenAIServing):
with
io
.
BytesIO
(
audio_data
)
as
bytes_
:
with
io
.
BytesIO
(
audio_data
)
as
bytes_
:
# NOTE resample to model SR here for efficiency. This is also a
# NOTE resample to model SR here for efficiency. This is also a
# pre-requisite for chunking, as it assumes Whisper SR.
# pre-requisite for chunking, as it assumes Whisper SR.
y
,
sr
=
librosa
.
load
(
bytes_
,
sr
=
self
.
model_sr
)
y
,
sr
=
librosa
.
load
(
bytes_
,
sr
=
self
.
asr_config
.
sample_rate
)
duration
=
librosa
.
get_duration
(
y
=
y
,
sr
=
sr
)
duration
=
librosa
.
get_duration
(
y
=
y
,
sr
=
sr
)
chunks
=
[
y
do_split_audio
=
(
self
.
asr_config
.
allow_audio_chunking
]
if
duration
<
self
.
max_audio_clip_s
else
self
.
_split_audio
(
and
duration
>
self
.
asr_config
.
max_audio_clip_s
)
y
,
int
(
sr
))
chunks
=
[
y
]
if
not
do_split_audio
else
self
.
_split_audio
(
y
,
int
(
sr
))
prompts
=
[]
prompts
=
[]
for
chunk
in
chunks
:
for
chunk
in
chunks
:
prompt
=
{
# The model has control over the construction, as long as it
"encoder_prompt"
:
{
# returns a valid PromptType.
"prompt"
:
""
,
prompt
=
self
.
model_cls
.
get_generation_prompt
(
"multi_modal_data"
:
{
audio
=
chunk
,
"audio"
:
(
chunk
,
sr
),
stt_config
=
self
.
asr_config
,
},
language
=
lang
,
},
task_type
=
self
.
task_type
,
"decoder_prompt"
:
request_prompt
=
request
.
prompt
)
model_cls
.
get_decoder_prompt
(
lang
,
self
.
task_type
,
prompts
.
append
(
prompt
)
request
.
prompt
)
}
prompts
.
append
(
cast
(
PromptType
,
prompt
))
return
prompts
,
duration
return
prompts
,
duration
async
def
_create_speech_to_text
(
async
def
_create_speech_to_text
(
...
@@ -196,7 +184,8 @@ class OpenAISpeechToText(OpenAIServing):
...
@@ -196,7 +184,8 @@ class OpenAISpeechToText(OpenAIServing):
self
.
_log_inputs
(
self
.
_log_inputs
(
request_id
,
request_id
,
prompts
[
0
][
'decoder_prompt'
],
# type: ignore
# It will not display special tokens like <|startoftranscript|>
request
.
prompt
,
params
=
sampling_params
,
params
=
sampling_params
,
lora_request
=
None
,
lora_request
=
None
,
prompt_adapter_request
=
None
)
prompt_adapter_request
=
None
)
...
@@ -261,17 +250,11 @@ class OpenAISpeechToText(OpenAIServing):
...
@@ -261,17 +250,11 @@ class OpenAISpeechToText(OpenAIServing):
async
for
res
in
result_generator
:
async
for
res
in
result_generator
:
# On first result.
# On first result.
if
res
.
prompt_token_ids
is
not
None
:
if
res
.
prompt_token_ids
is
not
None
:
# Do not account the 4-tokens `<|startoftranscript|>..`
num_prompt_tokens
=
len
(
res
.
prompt_token_ids
)
# Could be negative when language token
if
audio_tokens
:
=
self
.
model_cls
.
get_num_audio_tokens
(
# is not specified.
audio_duration_s
,
self
.
asr_config
,
num_prompt_tokens
=
max
(
self
.
model_config
):
len
(
res
.
prompt_token_ids
)
-
4
,
0
)
num_prompt_tokens
+=
audio_tokens
# NOTE(NickLucche) user can't pass encoder
# prompts directly, at least not to Whisper.
# One indicator of the encoder amount of processing
# is the log-mel spectrogram length.
num_prompt_tokens
+=
ceil
(
audio_duration_s
*
self
.
model_sr
/
self
.
hop_length
)
# We need to do it here, because if there are exceptions in
# We need to do it here, because if there are exceptions in
# the result_generator, it needs to be sent as the FIRST
# the result_generator, it needs to be sent as the FIRST
...
@@ -347,8 +330,8 @@ class OpenAISpeechToText(OpenAIServing):
...
@@ -347,8 +330,8 @@ class OpenAISpeechToText(OpenAIServing):
def
_split_audio
(
self
,
audio_data
:
np
.
ndarray
,
def
_split_audio
(
self
,
audio_data
:
np
.
ndarray
,
sample_rate
:
int
)
->
list
[
np
.
ndarray
]:
sample_rate
:
int
)
->
list
[
np
.
ndarray
]:
chunk_size
=
sample_rate
*
self
.
max_audio_clip_s
chunk_size
=
sample_rate
*
self
.
asr_config
.
max_audio_clip_s
overlap_size
=
sample_rate
*
OVERLAP_CHUNK_SECOND
overlap_size
=
sample_rate
*
self
.
asr_config
.
overlap_chunk_second
chunks
=
[]
chunks
=
[]
i
=
0
i
=
0
while
i
<
audio_data
.
shape
[
-
1
]:
while
i
<
audio_data
.
shape
[
-
1
]:
...
@@ -384,10 +367,10 @@ class OpenAISpeechToText(OpenAIServing):
...
@@ -384,10 +367,10 @@ class OpenAISpeechToText(OpenAIServing):
# Calculate RMS energy in small windows
# Calculate RMS energy in small windows
min_energy
=
math
.
inf
min_energy
=
math
.
inf
quietest_idx
=
0
quietest_idx
=
0
for
i
in
range
(
0
,
min_energy_window
=
self
.
asr_config
.
min_energy_split_window_size
len
(
segment
)
-
MIN_ENERGY_WINDOW_SIZE
,
assert
min_energy_window
is
not
None
MIN_ENERGY_WINDOW_SIZE
):
for
i
in
range
(
0
,
len
(
segment
)
-
min_energy_window
,
min_energy_window
):
window
=
segment
[
i
:
i
+
MIN_ENERGY_WINDOW_SIZE
]
window
=
segment
[
i
:
i
+
min_energy_window
]
energy
=
(
window
**
2
).
mean
()
**
0.5
energy
=
(
window
**
2
).
mean
()
**
0.5
if
energy
<
min_energy
:
if
energy
<
min_energy
:
quietest_idx
=
i
+
start_idx
quietest_idx
=
i
+
start_idx
...
...
vllm/model_executor/models/interfaces.py
View file @
3c7d942d
...
@@ -5,11 +5,14 @@ from collections.abc import Iterable, MutableSequence
...
@@ -5,11 +5,14 @@ from collections.abc import Iterable, MutableSequence
from
typing
import
(
TYPE_CHECKING
,
ClassVar
,
Literal
,
Optional
,
Protocol
,
from
typing
import
(
TYPE_CHECKING
,
ClassVar
,
Literal
,
Optional
,
Protocol
,
Union
,
overload
,
runtime_checkable
)
Union
,
overload
,
runtime_checkable
)
import
numpy
as
np
import
torch
import
torch
from
torch
import
Tensor
from
torch
import
Tensor
from
typing_extensions
import
Self
,
TypeIs
from
typing_extensions
import
Self
,
TypeIs
from
vllm.config
import
ModelConfig
,
SpeechToTextConfig
from
vllm.inputs
import
TokensPrompt
from
vllm.inputs
import
TokensPrompt
from
vllm.inputs.data
import
PromptType
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.quantization.base_config
import
(
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
QuantizationConfig
)
...
@@ -692,9 +695,13 @@ class SupportsTranscription(Protocol):
...
@@ -692,9 +695,13 @@ class SupportsTranscription(Protocol):
supports_transcription
:
ClassVar
[
Literal
[
True
]]
=
True
supports_transcription
:
ClassVar
[
Literal
[
True
]]
=
True
@
classmethod
@
classmethod
def
get_decoder_prompt
(
cls
,
language
:
str
,
task_type
:
str
,
def
get_generation_prompt
(
cls
,
audio
:
np
.
ndarray
,
prompt
:
str
)
->
str
:
stt_config
:
SpeechToTextConfig
,
language
:
str
,
"""Get the decoder prompt for the ASR model."""
task_type
:
str
,
request_prompt
:
str
)
->
PromptType
:
"""Get the prompt for the ASR model.
The model has control over the construction, as long as it
returns a valid PromptType."""
...
...
@
classmethod
@
classmethod
...
@@ -702,6 +709,25 @@ class SupportsTranscription(Protocol):
...
@@ -702,6 +709,25 @@ class SupportsTranscription(Protocol):
"""Check if the model supports a specific ISO639_1 language."""
"""Check if the model supports a specific ISO639_1 language."""
...
...
@
classmethod
def
get_speech_to_text_config
(
cls
,
model_config
:
ModelConfig
,
task_type
:
Literal
[
"transcribe"
,
"translate"
])
->
SpeechToTextConfig
:
"""Get the speech to text config for the ASR model."""
...
@
classmethod
def
get_num_audio_tokens
(
cls
,
audio_duration_s
:
float
,
stt_config
:
SpeechToTextConfig
,
model_config
:
ModelConfig
)
->
Optional
[
int
]:
"""
Map from audio duration to number of audio tokens produced by the ASR
model, without running a forward pass.
This is used for estimating the amount of processing for this audio.
"""
return
None
@
overload
@
overload
def
supports_transcription
(
def
supports_transcription
(
...
...
vllm/model_executor/models/whisper.py
View file @
3c7d942d
...
@@ -3,8 +3,9 @@
...
@@ -3,8 +3,9 @@
import
math
import
math
from
collections.abc
import
Iterable
,
Mapping
,
Sequence
from
collections.abc
import
Iterable
,
Mapping
,
Sequence
from
typing
import
Optional
,
TypedDict
,
Union
from
typing
import
Optional
,
TypedDict
,
Union
,
cast
import
numpy
as
np
import
torch
import
torch
from
torch
import
nn
from
torch
import
nn
from
transformers
import
(
BatchFeature
,
WhisperConfig
,
WhisperFeatureExtractor
,
from
transformers
import
(
BatchFeature
,
WhisperConfig
,
WhisperFeatureExtractor
,
...
@@ -12,8 +13,10 @@ from transformers import (BatchFeature, WhisperConfig, WhisperFeatureExtractor,
...
@@ -12,8 +13,10 @@ from transformers import (BatchFeature, WhisperConfig, WhisperFeatureExtractor,
from
transformers.models.whisper.modeling_whisper
import
sinusoids
from
transformers.models.whisper.modeling_whisper
import
sinusoids
from
vllm.attention
import
Attention
,
AttentionType
from
vllm.attention
import
Attention
,
AttentionType
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.config
import
(
CacheConfig
,
ModelConfig
,
SpeechToTextConfig
,
VllmConfig
)
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.inputs.data
import
PromptType
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
...
@@ -33,6 +36,7 @@ from vllm.multimodal.processing import (BaseProcessingInfo,
...
@@ -33,6 +36,7 @@ from vllm.multimodal.processing import (BaseProcessingInfo,
EncDecMultiModalProcessor
,
EncDecMultiModalProcessor
,
PromptReplacement
,
PromptUpdate
)
PromptReplacement
,
PromptUpdate
)
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
from
vllm.transformers_utils.processor
import
cached_get_processor
from
.interfaces
import
(
MultiModalEmbeddings
,
SupportsMultiModal
,
from
.interfaces
import
(
MultiModalEmbeddings
,
SupportsMultiModal
,
SupportsTranscription
,
SupportsV0Only
)
SupportsTranscription
,
SupportsV0Only
)
...
@@ -785,11 +789,24 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,
...
@@ -785,11 +789,24 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,
f
"or
{
list
(
ISO639_1_OTHER_LANGS
.
values
())
}
"
)
f
"or
{
list
(
ISO639_1_OTHER_LANGS
.
values
())
}
"
)
@
classmethod
@
classmethod
def
get_decoder_prompt
(
cls
,
language
:
str
,
task_type
:
str
,
def
get_generation_prompt
(
cls
,
audio
:
np
.
ndarray
,
prompt
:
str
)
->
str
:
stt_config
:
SpeechToTextConfig
,
language
:
str
,
return
((
f
"<|prev|>
{
prompt
}
"
if
prompt
else
""
)
+
task_type
:
str
,
f
"<|startoftranscript|><|
{
language
}
|>"
+
request_prompt
:
str
)
->
PromptType
:
f
"<|
{
task_type
}
|><|notimestamps|>"
)
prompt
=
{
"encoder_prompt"
:
{
# Whisper does not support encoder prompt.
"prompt"
:
""
,
"multi_modal_data"
:
{
"audio"
:
(
audio
,
stt_config
.
sample_rate
),
},
},
"decoder_prompt"
:
((
f
"<|prev|>
{
request_prompt
}
"
if
request_prompt
else
""
)
+
f
"<|startoftranscript|><|
{
language
}
|>"
+
f
"<|
{
task_type
}
|><|notimestamps|>"
)
}
return
cast
(
PromptType
,
prompt
)
@
classmethod
@
classmethod
def
get_placeholder_str
(
cls
,
modality
:
str
,
i
:
int
)
->
Optional
[
str
]:
def
get_placeholder_str
(
cls
,
modality
:
str
,
i
:
int
)
->
Optional
[
str
]:
...
@@ -798,6 +815,30 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,
...
@@ -798,6 +815,30 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,
raise
ValueError
(
"Only audio modality is supported"
)
raise
ValueError
(
"Only audio modality is supported"
)
@
classmethod
def
get_speech_to_text_config
(
cls
,
model_config
:
ModelConfig
,
task_type
:
str
)
->
SpeechToTextConfig
:
processor
=
cached_get_processor
(
model_config
.
model
)
return
SpeechToTextConfig
(
max_audio_clip_s
=
processor
.
feature_extractor
.
chunk_length
,
sample_rate
=
processor
.
feature_extractor
.
sampling_rate
,
)
@
classmethod
def
get_num_audio_tokens
(
cls
,
audio_duration_s
:
float
,
stt_config
:
SpeechToTextConfig
,
model_config
:
ModelConfig
)
->
Optional
[
int
]:
processor
=
cached_get_processor
(
model_config
.
model
)
hop_length
=
processor
.
feature_extractor
.
hop_length
assert
hop_length
is
not
None
# NOTE(NickLucche) user can't pass encoder
# prompts directly, at least not to Whisper.
# One indicator of the encoder amount of processing
# is the log-mel spectrogram length.
return
math
.
ceil
(
audio_duration_s
*
stt_config
.
sample_rate
/
hop_length
)
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
config
=
vllm_config
.
model_config
.
hf_config
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment