Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
daceac57
Unverified
Commit
daceac57
authored
Jun 28, 2025
by
Nicolò Lucchesi
Committed by
GitHub
Jun 28, 2025
Browse files
[Frontend] Generalize `v1/audio/transcriptions` endpoint (#20179)
Signed-off-by:
NickLucche
<
nlucches@redhat.com
>
parent
8615d977
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
154 additions
and
128 deletions
+154
-128
vllm/entrypoints/openai/speech_to_text.py
vllm/entrypoints/openai/speech_to_text.py
+14
-128
vllm/model_executor/models/interfaces.py
vllm/model_executor/models/interfaces.py
+11
-0
vllm/model_executor/models/whisper.py
vllm/model_executor/models/whisper.py
+129
-0
No files found.
vllm/entrypoints/openai/speech_to_text.py
View file @
daceac57
...
@@ -24,6 +24,7 @@ from vllm.entrypoints.openai.serving_engine import (OpenAIServing,
...
@@ -24,6 +24,7 @@ from vllm.entrypoints.openai.serving_engine import (OpenAIServing,
from
vllm.entrypoints.openai.serving_models
import
OpenAIServingModels
from
vllm.entrypoints.openai.serving_models
import
OpenAIServingModels
from
vllm.inputs.data
import
PromptType
from
vllm.inputs.data
import
PromptType
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.model_loader.utils
import
get_model_architecture
from
vllm.outputs
import
RequestOutput
from
vllm.outputs
import
RequestOutput
from
vllm.transformers_utils.processor
import
cached_get_processor
from
vllm.transformers_utils.processor
import
cached_get_processor
from
vllm.utils
import
PlaceholderModule
from
vllm.utils
import
PlaceholderModule
...
@@ -38,118 +39,10 @@ T = TypeVar("T", bound=SpeechToTextResponse)
...
@@ -38,118 +39,10 @@ T = TypeVar("T", bound=SpeechToTextResponse)
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
# From https://platform.openai.com/docs/guides/speech-to-text/supported-languages
# TODO these configs should live somewhere with the model so we can support
# additional ones
ISO639_1_SUPPORTED_LANGS
=
{
"af"
:
"Afrikaans"
,
"ar"
:
"Arabic"
,
"hy"
:
"Armenian"
,
"az"
:
"Azerbaijani"
,
"be"
:
"Belarusian"
,
"bs"
:
"Bosnian"
,
"bg"
:
"Bulgarian"
,
"ca"
:
"Catalan"
,
"zh"
:
"Chinese"
,
"hr"
:
"Croatian"
,
"cs"
:
"Czech"
,
"da"
:
"Danish"
,
"nl"
:
"Dutch"
,
"en"
:
"English"
,
"et"
:
"Estonian"
,
"fi"
:
"Finnish"
,
"fr"
:
"French"
,
"gl"
:
"Galician"
,
"de"
:
"German"
,
"el"
:
"Greek"
,
"he"
:
"Hebrew"
,
"hi"
:
"Hindi"
,
"hu"
:
"Hungarian"
,
"is"
:
"Icelandic"
,
"id"
:
"Indonesian"
,
"it"
:
"Italian"
,
"ja"
:
"Japanese"
,
"kn"
:
"Kannada"
,
"kk"
:
"Kazakh"
,
"ko"
:
"Korean"
,
"lv"
:
"Latvian"
,
"lt"
:
"Lithuanian"
,
"mk"
:
"Macedonian"
,
"ms"
:
"Malay"
,
"mr"
:
"Marathi"
,
"mi"
:
"Maori"
,
"ne"
:
"Nepali"
,
"no"
:
"Norwegian"
,
"fa"
:
"Persian"
,
"pl"
:
"Polish"
,
"pt"
:
"Portuguese"
,
"ro"
:
"Romanian"
,
"ru"
:
"Russian"
,
"sr"
:
"Serbian"
,
"sk"
:
"Slovak"
,
"sl"
:
"Slovenian"
,
"es"
:
"Spanish"
,
"sw"
:
"Swahili"
,
"sv"
:
"Swedish"
,
"tl"
:
"Tagalog"
,
"ta"
:
"Tamil"
,
"th"
:
"Thai"
,
"tr"
:
"Turkish"
,
"uk"
:
"Ukrainian"
,
"ur"
:
"Urdu"
,
"vi"
:
"Vietnamese"
,
"cy"
:
"Welsh"
}
ISO639_1_OTHER_LANGS
=
{
"lo"
:
"Lao"
,
"jw"
:
"Javanese"
,
"tk"
:
"Turkmen"
,
"yi"
:
"Yiddish"
,
"so"
:
"Somali"
,
"bn"
:
"Bengali"
,
"nn"
:
"Norwegian Nynorsk"
,
"si"
:
"Sinhala"
,
"yo"
:
"Yoruba"
,
"sa"
:
"Sanskrit"
,
"mi"
:
"Māori"
,
"fo"
:
"Faroese"
,
# codespell:ignore
"mt"
:
"Maltese"
,
"tg"
:
"Tajik"
,
"mg"
:
"Malagasy"
,
"haw"
:
"Hawaiian"
,
"km"
:
"Khmer"
,
"br"
:
"Breton"
,
"ps"
:
"Pashto"
,
"ln"
:
"Lingala"
,
"la"
:
"Latin"
,
"ml"
:
"Malayalam"
,
"sq"
:
"Albanian"
,
"su"
:
"Sundanese"
,
"eu"
:
"Basque"
,
"ka"
:
"Georgian"
,
"uz"
:
"Uzbek"
,
"sn"
:
"Shona"
,
"ht"
:
"Haitian"
,
"as"
:
"Assamese"
,
"mn"
:
"Mongolian"
,
"te"
:
"Telugu"
,
"pa"
:
"Panjabi"
,
"tt"
:
"Tatar"
,
"gu"
:
"Gujarati"
,
"oc"
:
"Occitan"
,
"ha"
:
"Hausa"
,
"ba"
:
"Bashkir"
,
"my"
:
"Burmese"
,
"sd"
:
"Sindhi"
,
"am"
:
"Amharic"
,
"lb"
:
"Luxembourgish"
,
"bo"
:
"Tibetan"
}
# As per https://platform.openai.com/docs/guides/speech-to-text#overview.
# As per https://platform.openai.com/docs/guides/speech-to-text#overview.
# TODO configurable
# TODO configurable
MAX_AUDIO_CLIP_FILESIZE_MB
=
25
MAX_AUDIO_CLIP_FILESIZE_MB
=
25
MAX_AUDIO_CLIP_SECONDS
=
30
OVERLAP_CHUNK_SECOND
=
1
OVERLAP_CHUNK_SECOND
=
1
MIN_ENERGY_WINDOW_SIZE
=
1600
# 1600 ~ 100ms for 16000 Hz audio
MIN_ENERGY_WINDOW_SIZE
=
1600
# 1600 ~ 100ms for 16000 Hz audio
...
@@ -177,10 +70,13 @@ class OpenAISpeechToText(OpenAIServing):
...
@@ -177,10 +70,13 @@ class OpenAISpeechToText(OpenAIServing):
self
.
default_sampling_params
=
(
self
.
default_sampling_params
=
(
self
.
model_config
.
get_diff_sampling_param
())
self
.
model_config
.
get_diff_sampling_param
())
processor
=
cached_get_processor
(
model_config
.
model
)
processor
=
cached_get_processor
(
model_config
.
model
)
self
.
max_audio_clip_s
=
processor
.
feature_extractor
.
chunk_length
self
.
max_audio_clip_s
=
processor
.
feature_extractor
.
chunk_length
\
if
hasattr
(
processor
.
feature_extractor
,
'chunk_length'
)
\
else
MAX_AUDIO_CLIP_SECONDS
self
.
model_sr
=
processor
.
feature_extractor
.
sampling_rate
self
.
model_sr
=
processor
.
feature_extractor
.
sampling_rate
self
.
hop_length
=
processor
.
feature_extractor
.
hop_length
self
.
hop_length
=
processor
.
feature_extractor
.
hop_length
self
.
task_type
=
task_type
self
.
task_type
=
task_type
self
.
model_cls
,
_
=
get_model_architecture
(
model_config
)
if
self
.
default_sampling_params
:
if
self
.
default_sampling_params
:
logger
.
info
(
logger
.
info
(
...
@@ -196,21 +92,8 @@ class OpenAISpeechToText(OpenAIServing):
...
@@ -196,21 +92,8 @@ class OpenAISpeechToText(OpenAIServing):
# TODO language should be optional and can be guessed.
# TODO language should be optional and can be guessed.
# For now we default to en. See
# For now we default to en. See
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/generation_whisper.py#L1520
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/generation_whisper.py#L1520
lang_token
=
f
"<|
{
request
.
language
}
|>"
if
request
.
language
else
"<|en|>"
lang
=
request
.
language
or
"en"
if
request
.
language
:
self
.
model_cls
.
validate_language
(
lang
)
# type: ignore[attr-defined]
if
request
.
language
in
ISO639_1_SUPPORTED_LANGS
:
pass
elif
request
.
language
in
ISO639_1_OTHER_LANGS
:
logger
.
warning
(
"The selected language %s has limited accuracy with"
" reported WER>=0.5. Results may be less accurate "
"for this choice."
,
request
.
language
)
else
:
raise
ValueError
(
f
"Unsupported language:
{
request
.
language
}
."
"Language should be one of:"
+
f
"
{
list
(
ISO639_1_SUPPORTED_LANGS
.
values
())
}
"
+
f
"or
{
list
(
ISO639_1_OTHER_LANGS
.
values
())
}
"
)
if
len
(
audio_data
)
/
1024
**
2
>
MAX_AUDIO_CLIP_FILESIZE_MB
:
if
len
(
audio_data
)
/
1024
**
2
>
MAX_AUDIO_CLIP_FILESIZE_MB
:
raise
ValueError
(
"Maximum file size exceeded."
)
raise
ValueError
(
"Maximum file size exceeded."
)
...
@@ -221,7 +104,9 @@ class OpenAISpeechToText(OpenAIServing):
...
@@ -221,7 +104,9 @@ class OpenAISpeechToText(OpenAIServing):
y
,
sr
=
librosa
.
load
(
bytes_
,
sr
=
self
.
model_sr
)
y
,
sr
=
librosa
.
load
(
bytes_
,
sr
=
self
.
model_sr
)
duration
=
librosa
.
get_duration
(
y
=
y
,
sr
=
sr
)
duration
=
librosa
.
get_duration
(
y
=
y
,
sr
=
sr
)
chunks
=
[
y
]
if
duration
<
30
else
self
.
_split_audio
(
y
,
int
(
sr
))
chunks
=
[
y
]
if
duration
<
self
.
max_audio_clip_s
else
self
.
_split_audio
(
y
,
int
(
sr
))
prompts
=
[]
prompts
=
[]
for
chunk
in
chunks
:
for
chunk
in
chunks
:
prompt
=
{
prompt
=
{
...
@@ -232,8 +117,9 @@ class OpenAISpeechToText(OpenAIServing):
...
@@ -232,8 +117,9 @@ class OpenAISpeechToText(OpenAIServing):
},
},
},
},
"decoder_prompt"
:
"decoder_prompt"
:
(
f
"<|startoftranscript|>
{
lang_token
}
"
self
.
model_cls
.
f
"<|
{
self
.
task_type
}
|><|notimestamps|>
{
request
.
prompt
}
"
)
get_decoder_prompt
(
# type: ignore[attr-defined]
lang
,
self
.
task_type
,
request
.
prompt
)
}
}
prompts
.
append
(
cast
(
PromptType
,
prompt
))
prompts
.
append
(
cast
(
PromptType
,
prompt
))
return
prompts
,
duration
return
prompts
,
duration
...
...
vllm/model_executor/models/interfaces.py
View file @
daceac57
...
@@ -599,6 +599,17 @@ class SupportsTranscription(Protocol):
...
@@ -599,6 +599,17 @@ class SupportsTranscription(Protocol):
supports_transcription
:
ClassVar
[
Literal
[
True
]]
=
True
supports_transcription
:
ClassVar
[
Literal
[
True
]]
=
True
@classmethod
def get_decoder_prompt(cls, language: str, task_type: str,
                       prompt: str) -> str:
    """Get the decoder prompt for the ASR model.

    Implementations build the model-specific prompt string that is placed
    under the ``"decoder_prompt"`` key of each request prompt.

    Args:
        language: Language identifier for the request.
            NOTE(review): presumably an ISO 639-1 code such as ``"en"``;
            confirm against the caller.
        task_type: Task identifier for the request.
            NOTE(review): the accepted values are not visible here.
        prompt: Text supplied with the request.

    Returns:
        The decoder prompt as a single string.
    """
    ...
@classmethod
def validate_language(cls, language: str) -> bool:
    """Check if the model supports a specific ISO639_1 language.

    Args:
        language: Language code to validate.

    Returns:
        ``True`` when the language is supported.
        NOTE(review): the Whisper implementation raises ``ValueError`` for
        an unknown code instead of returning ``False``; confirm which
        contract implementers are expected to follow.
    """
    ...
@
overload
@
overload
def
supports_transcription
(
def
supports_transcription
(
...
...
vllm/model_executor/models/whisper.py
View file @
daceac57
...
@@ -41,6 +41,113 @@ from .utils import (AutoWeightsLoader, WeightsMapper, cast_overflow_tensors,
...
@@ -41,6 +41,113 @@ from .utils import (AutoWeightsLoader, WeightsMapper, cast_overflow_tensors,
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
# From https://platform.openai.com/docs/guides/speech-to-text/supported-languages
ISO639_1_SUPPORTED_LANGS
=
{
"af"
:
"Afrikaans"
,
"ar"
:
"Arabic"
,
"hy"
:
"Armenian"
,
"az"
:
"Azerbaijani"
,
"be"
:
"Belarusian"
,
"bs"
:
"Bosnian"
,
"bg"
:
"Bulgarian"
,
"ca"
:
"Catalan"
,
"zh"
:
"Chinese"
,
"hr"
:
"Croatian"
,
"cs"
:
"Czech"
,
"da"
:
"Danish"
,
"nl"
:
"Dutch"
,
"en"
:
"English"
,
"et"
:
"Estonian"
,
"fi"
:
"Finnish"
,
"fr"
:
"French"
,
"gl"
:
"Galician"
,
"de"
:
"German"
,
"el"
:
"Greek"
,
"he"
:
"Hebrew"
,
"hi"
:
"Hindi"
,
"hu"
:
"Hungarian"
,
"is"
:
"Icelandic"
,
"id"
:
"Indonesian"
,
"it"
:
"Italian"
,
"ja"
:
"Japanese"
,
"kn"
:
"Kannada"
,
"kk"
:
"Kazakh"
,
"ko"
:
"Korean"
,
"lv"
:
"Latvian"
,
"lt"
:
"Lithuanian"
,
"mk"
:
"Macedonian"
,
"ms"
:
"Malay"
,
"mr"
:
"Marathi"
,
"mi"
:
"Maori"
,
"ne"
:
"Nepali"
,
"no"
:
"Norwegian"
,
"fa"
:
"Persian"
,
"pl"
:
"Polish"
,
"pt"
:
"Portuguese"
,
"ro"
:
"Romanian"
,
"ru"
:
"Russian"
,
"sr"
:
"Serbian"
,
"sk"
:
"Slovak"
,
"sl"
:
"Slovenian"
,
"es"
:
"Spanish"
,
"sw"
:
"Swahili"
,
"sv"
:
"Swedish"
,
"tl"
:
"Tagalog"
,
"ta"
:
"Tamil"
,
"th"
:
"Thai"
,
"tr"
:
"Turkish"
,
"uk"
:
"Ukrainian"
,
"ur"
:
"Urdu"
,
"vi"
:
"Vietnamese"
,
"cy"
:
"Welsh"
}
ISO639_1_OTHER_LANGS
=
{
"lo"
:
"Lao"
,
"jw"
:
"Javanese"
,
"tk"
:
"Turkmen"
,
"yi"
:
"Yiddish"
,
"so"
:
"Somali"
,
"bn"
:
"Bengali"
,
"nn"
:
"Norwegian Nynorsk"
,
"si"
:
"Sinhala"
,
"yo"
:
"Yoruba"
,
"sa"
:
"Sanskrit"
,
"mi"
:
"Māori"
,
"fo"
:
"Faroese"
,
# codespell:ignore
"mt"
:
"Maltese"
,
"tg"
:
"Tajik"
,
"mg"
:
"Malagasy"
,
"haw"
:
"Hawaiian"
,
"km"
:
"Khmer"
,
"br"
:
"Breton"
,
"ps"
:
"Pashto"
,
"ln"
:
"Lingala"
,
"la"
:
"Latin"
,
"ml"
:
"Malayalam"
,
"sq"
:
"Albanian"
,
"su"
:
"Sundanese"
,
"eu"
:
"Basque"
,
"ka"
:
"Georgian"
,
"uz"
:
"Uzbek"
,
"sn"
:
"Shona"
,
"ht"
:
"Haitian"
,
"as"
:
"Assamese"
,
"mn"
:
"Mongolian"
,
"te"
:
"Telugu"
,
"pa"
:
"Panjabi"
,
"tt"
:
"Tatar"
,
"gu"
:
"Gujarati"
,
"oc"
:
"Occitan"
,
"ha"
:
"Hausa"
,
"ba"
:
"Bashkir"
,
"my"
:
"Burmese"
,
"sd"
:
"Sindhi"
,
"am"
:
"Amharic"
,
"lb"
:
"Luxembourgish"
,
"bo"
:
"Tibetan"
}
class
WhisperAudioInputs
(
TypedDict
):
class
WhisperAudioInputs
(
TypedDict
):
input_features
:
NestedTensors
input_features
:
NestedTensors
...
@@ -731,6 +838,28 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,
...
@@ -731,6 +838,28 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,
weights
=
_create_fake_bias_for_k_proj
(
weights
)
weights
=
_create_fake_bias_for_k_proj
(
weights
)
return
loader
.
load_weights
(
weights
,
mapper
=
self
.
hf_to_vllm_mapper
)
return
loader
.
load_weights
(
weights
,
mapper
=
self
.
hf_to_vllm_mapper
)
@classmethod
def validate_language(cls, language: str) -> bool:
    """Check that ``language`` is a code known to the Whisper language tables.

    Args:
        language: Language code looked up in ``ISO639_1_SUPPORTED_LANGS``
            and then in ``ISO639_1_OTHER_LANGS``.

    Returns:
        ``True`` when the code is present in either table. When it is only
        present in ``ISO639_1_OTHER_LANGS`` a warning about reduced
        accuracy is logged before returning.

    Raises:
        ValueError: If the code is in neither table.
    """
    if language in ISO639_1_SUPPORTED_LANGS:
        return True

    if language in ISO639_1_OTHER_LANGS:
        logger.warning(
            "The selected language %s has limited accuracy with"
            " reported WER>=0.5. Results may be less accurate "
            "for this choice.", language)
        return True

    # The message is assembled from adjacent literals, so every separator
    # has to be written out explicitly; without them the sentences and the
    # two lists run together ("xx.Language ...", "]or [").
    raise ValueError(
        f"Unsupported language: {language}. "
        "Language should be one of: "
        f"{list(ISO639_1_SUPPORTED_LANGS.values())} "
        f"or {list(ISO639_1_OTHER_LANGS.values())}")
@classmethod
def get_decoder_prompt(cls, language: str, task_type: str,
                       prompt: str) -> str:
    """Assemble the decoder prompt string for a request.

    The result is four special tokens in fixed order (start-of-transcript,
    language, task, no-timestamps) followed directly by the caller's
    prompt text, with no separator in between.

    Args:
        language: Value placed inside the language token ``<|...|>``.
        task_type: Value placed inside the task token ``<|...|>``.
        prompt: Text appended after the special tokens.

    Returns:
        The concatenated decoder prompt.
    """
    special_tokens = (
        "<|startoftranscript|>",
        f"<|{language}|>",
        f"<|{task_type}|>",
        "<|notimestamps|>",
    )
    # Format the prompt through an f-string so non-str values are rendered
    # the same way as in a single combined f-string.
    return "".join(special_tokens) + f"{prompt}"
def
_create_fake_bias_for_k_proj
(
def
_create_fake_bias_for_k_proj
(
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]]
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment