Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
b7cbc254
Unverified
Commit
b7cbc254
authored
Nov 05, 2025
by
Alex Brooks
Committed by
GitHub
Nov 05, 2025
Browse files
[Model, Core] Support Granite Speech & LoRA for STT (#24455)
parent
d43ad5a7
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
169 additions
and
8 deletions
+169
-8
docs/models/supported_models.md
docs/models/supported_models.md
+1
-0
tests/entrypoints/openai/test_transcription_validation.py
tests/entrypoints/openai/test_transcription_validation.py
+35
-0
tests/entrypoints/openai/test_translation_validation.py
tests/entrypoints/openai/test_translation_validation.py
+34
-0
vllm/entrypoints/openai/speech_to_text.py
vllm/entrypoints/openai/speech_to_text.py
+2
-6
vllm/model_executor/models/granite_speech.py
vllm/model_executor/models/granite_speech.py
+97
-2
No files found.
docs/models/supported_models.md
View file @
b7cbc254
...
@@ -761,6 +761,7 @@ Speech2Text models trained specifically for Automatic Speech Recognition.
...
@@ -761,6 +761,7 @@ Speech2Text models trained specifically for Automatic Speech Recognition.
|
`WhisperForConditionalGeneration`
| Whisper |
`openai/whisper-small`
,
`openai/whisper-large-v3-turbo`
, etc. | | |
|
`WhisperForConditionalGeneration`
| Whisper |
`openai/whisper-small`
,
`openai/whisper-large-v3-turbo`
, etc. | | |
|
`VoxtralForConditionalGeneration`
| Voxtral (Mistral format) |
`mistralai/Voxtral-Mini-3B-2507`
,
`mistralai/Voxtral-Small-24B-2507`
, etc. | ✅︎ | ✅︎ |
|
`VoxtralForConditionalGeneration`
| Voxtral (Mistral format) |
`mistralai/Voxtral-Mini-3B-2507`
,
`mistralai/Voxtral-Small-24B-2507`
, etc. | ✅︎ | ✅︎ |
|
`Gemma3nForConditionalGeneration`
| Gemma3n |
`google/gemma-3n-E2B-it`
,
`google/gemma-3n-E4B-it`
, etc. | | |
|
`Gemma3nForConditionalGeneration`
| Gemma3n |
`google/gemma-3n-E2B-it`
,
`google/gemma-3n-E4B-it`
, etc. | | |
|
`GraniteSpeechForConditionalGeneration`
| Granite Speech |
`ibm-granite/granite-speech-3.3-2b`
,
`ibm-granite/granite-speech-3.3-8b`
, etc. | ✅︎ | ✅︎ |
### Pooling Models
### Pooling Models
...
...
tests/entrypoints/openai/test_transcription_validation.py
View file @
b7cbc254
...
@@ -65,6 +65,41 @@ async def test_basic_audio(mary_had_lamb, model_name):
...
@@ -65,6 +65,41 @@ async def test_basic_audio(mary_had_lamb, model_name):
assert
out_usage
[
"seconds"
]
==
16
,
out_usage
[
"seconds"
]
assert
out_usage
[
"seconds"
]
==
16
,
out_usage
[
"seconds"
]
@
pytest
.
mark
.
asyncio
async
def
test_basic_audio_with_lora
(
mary_had_lamb
):
"""Ensure STT (transcribe) requests can pass LoRA through to generate."""
model_name
=
"ibm-granite/granite-speech-3.3-2b"
lora_model_name
=
"speech"
server_args
=
[
"--enforce-eager"
,
"--enable-lora"
,
"--max-lora-rank"
,
"64"
,
"--lora-modules"
,
f
"
{
lora_model_name
}
=
{
model_name
}
"
,
"--max-model-len"
,
"2048"
,
"--max-num-seqs"
,
"1"
,
]
# Based on https://github.com/openai/openai-cookbook/blob/main/examples/Whisper_prompting_guide.ipynb.
with
RemoteOpenAIServer
(
model_name
,
server_args
)
as
remote_server
:
client
=
remote_server
.
get_async_client
()
transcription
=
await
client
.
audio
.
transcriptions
.
create
(
model
=
lora_model_name
,
file
=
mary_had_lamb
,
language
=
"en"
,
response_format
=
"text"
,
temperature
=
0.0
,
)
out
=
json
.
loads
(
transcription
)
out_text
=
out
[
"text"
]
out_usage
=
out
[
"usage"
]
assert
"mary had a little lamb"
in
out_text
assert
out_usage
[
"seconds"
]
==
16
,
out_usage
[
"seconds"
]
@
pytest
.
mark
.
asyncio
@
pytest
.
mark
.
asyncio
async
def
test_basic_audio_gemma
(
foscolo
):
async
def
test_basic_audio_gemma
(
foscolo
):
# Gemma accuracy on some of the audio samples we use is particularly bad,
# Gemma accuracy on some of the audio samples we use is particularly bad,
...
...
tests/entrypoints/openai/test_translation_validation.py
View file @
b7cbc254
...
@@ -48,6 +48,40 @@ async def test_non_asr_model(foscolo):
...
@@ -48,6 +48,40 @@ async def test_non_asr_model(foscolo):
assert
err
[
"message"
]
==
"The model does not support Translations API"
assert
err
[
"message"
]
==
"The model does not support Translations API"
@
pytest
.
mark
.
asyncio
async
def
test_basic_audio_with_lora
(
mary_had_lamb
):
"""Ensure STT (translate) requests can pass LoRA through to generate."""
# NOTE - careful to call this test before the module scoped server
# fixture, otherwise it'll OOMkill the CI
model_name
=
"ibm-granite/granite-speech-3.3-2b"
lora_model_name
=
"speech"
server_args
=
[
"--enforce-eager"
,
"--enable-lora"
,
"--max-lora-rank"
,
"64"
,
"--lora-modules"
,
f
"
{
lora_model_name
}
=
{
model_name
}
"
,
"--max-model-len"
,
"2048"
,
"--max-num-seqs"
,
"1"
,
]
# Based on https://github.com/openai/openai-cookbook/blob/main/examples/Whisper_prompting_guide.ipynb.
with
RemoteOpenAIServer
(
model_name
,
server_args
)
as
remote_server
:
client
=
remote_server
.
get_async_client
()
translation
=
await
client
.
audio
.
translations
.
create
(
model
=
lora_model_name
,
file
=
mary_had_lamb
,
extra_body
=
dict
(
language
=
"en"
,
to_language
=
"es"
),
response_format
=
"text"
,
temperature
=
0.0
,
)
out
=
json
.
loads
(
translation
)[
"text"
].
strip
().
lower
()
assert
"mary tenía un pequeño cordero"
in
out
# NOTE: (NickLucche) the large-v3-turbo model was not trained on translation!
# NOTE: (NickLucche) the large-v3-turbo model was not trained on translation!
@
pytest
.
mark
.
asyncio
@
pytest
.
mark
.
asyncio
async
def
test_basic_audio
(
foscolo
,
client_and_model
):
async
def
test_basic_audio
(
foscolo
,
client_and_model
):
...
...
vllm/entrypoints/openai/speech_to_text.py
View file @
b7cbc254
...
@@ -170,11 +170,6 @@ class OpenAISpeechToText(OpenAIServing):
...
@@ -170,11 +170,6 @@ class OpenAISpeechToText(OpenAIServing):
try
:
try
:
lora_request
=
self
.
_maybe_get_adapters
(
request
)
lora_request
=
self
.
_maybe_get_adapters
(
request
)
if
lora_request
:
return
self
.
create_error_response
(
f
"Currently do not support LoRA for
{
self
.
task_type
.
title
()
}
."
)
prompts
,
duration_s
=
await
self
.
_preprocess_speech_to_text
(
prompts
,
duration_s
=
await
self
.
_preprocess_speech_to_text
(
request
=
request
,
request
=
request
,
audio_data
=
audio_data
,
audio_data
=
audio_data
,
...
@@ -199,7 +194,7 @@ class OpenAISpeechToText(OpenAIServing):
...
@@ -199,7 +194,7 @@ class OpenAISpeechToText(OpenAIServing):
# It will not display special tokens like <|startoftranscript|>
# It will not display special tokens like <|startoftranscript|>
request
.
prompt
,
request
.
prompt
,
params
=
sampling_params
,
params
=
sampling_params
,
lora_request
=
None
,
lora_request
=
lora_request
,
)
)
list_result_generator
=
[
list_result_generator
=
[
...
@@ -207,6 +202,7 @@ class OpenAISpeechToText(OpenAIServing):
...
@@ -207,6 +202,7 @@ class OpenAISpeechToText(OpenAIServing):
prompt
,
prompt
,
sampling_params
,
sampling_params
,
request_id
,
request_id
,
lora_request
=
lora_request
,
)
)
for
prompt
in
prompts
for
prompt
in
prompts
]
]
...
...
vllm/model_executor/models/granite_speech.py
View file @
b7cbc254
...
@@ -26,15 +26,17 @@
...
@@ -26,15 +26,17 @@
import
math
import
math
from
collections.abc
import
Iterable
,
Mapping
from
collections.abc
import
Iterable
,
Mapping
from
typing
import
Annotated
from
typing
import
Annotated
,
Literal
,
cast
import
numpy
as
np
import
torch
import
torch
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
from
torch
import
nn
from
torch
import
nn
from
transformers
import
BatchFeature
,
PretrainedConfig
from
transformers
import
BatchFeature
,
PretrainedConfig
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.config
import
CacheConfig
,
ModelConfig
,
SpeechToTextConfig
,
VllmConfig
from
vllm.config.multimodal
import
BaseDummyOptions
from
vllm.config.multimodal
import
BaseDummyOptions
from
vllm.inputs.data
import
PromptType
from
vllm.model_executor.layers.linear
import
ColumnParallelLinear
,
RowParallelLinear
from
vllm.model_executor.layers.linear
import
ColumnParallelLinear
,
RowParallelLinear
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.models.module_mapping
import
MultiModelKeys
from
vllm.model_executor.models.module_mapping
import
MultiModelKeys
...
@@ -57,6 +59,8 @@ from vllm.multimodal.processing import (
...
@@ -57,6 +59,8 @@ from vllm.multimodal.processing import (
)
)
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
vllm.transformers_utils.processor
import
cached_get_processor
from
vllm.transformers_utils.tokenizer
import
cached_get_tokenizer
from
vllm.utils.tensor_schema
import
TensorSchema
,
TensorShape
from
vllm.utils.tensor_schema
import
TensorSchema
,
TensorShape
from
.blip2
import
Blip2QFormerModel
from
.blip2
import
Blip2QFormerModel
...
@@ -65,9 +69,22 @@ from .interfaces import (
...
@@ -65,9 +69,22 @@ from .interfaces import (
SupportsLoRA
,
SupportsLoRA
,
SupportsMultiModal
,
SupportsMultiModal
,
SupportsPP
,
SupportsPP
,
SupportsTranscription
,
)
)
from
.utils
import
AutoWeightsLoader
,
init_vllm_registered_model
,
maybe_prefix
from
.utils
import
AutoWeightsLoader
,
init_vllm_registered_model
,
maybe_prefix
# NOTE lang support is based on what is written here:
# https://huggingface.co/ibm-granite/granite-speech-3.3-2b
# Though this may vary from model to model, and also many langs
# work pretty well with zero shot.
ISO639_1_SUPPORTED_LANGS
=
{
"en"
:
"English"
,
"fr"
:
"French"
,
"de"
:
"German"
,
"pt"
:
"Portuguese"
,
"es"
:
"Spanish"
,
}
### Audio Input
### Audio Input
class
GraniteSpeechAudioInputs
(
TensorSchema
):
class
GraniteSpeechAudioInputs
(
TensorSchema
):
...
@@ -545,8 +562,10 @@ class GraniteSpeechForConditionalGeneration(
...
@@ -545,8 +562,10 @@ class GraniteSpeechForConditionalGeneration(
SupportsMultiModal
,
SupportsMultiModal
,
SupportsPP
,
SupportsPP
,
SupportsLoRA
,
SupportsLoRA
,
SupportsTranscription
,
):
):
merge_by_field_config
=
True
merge_by_field_config
=
True
supported_languages
=
ISO639_1_SUPPORTED_LANGS
packed_modules_mapping
=
{
packed_modules_mapping
=
{
"qkv_proj"
:
[
"qkv_proj"
:
[
...
@@ -816,3 +835,79 @@ class GraniteSpeechForConditionalGeneration(
...
@@ -816,3 +835,79 @@ class GraniteSpeechForConditionalGeneration(
connector
=
"projector"
,
connector
=
"projector"
,
tower_model
=
"encoder"
,
tower_model
=
"encoder"
,
)
)
### Support for speech-to-text Transcription
@
classmethod
def
get_generation_prompt
(
cls
,
audio
:
np
.
ndarray
,
model_config
:
ModelConfig
,
stt_config
:
SpeechToTextConfig
,
language
:
str
|
None
,
task_type
:
Literal
[
"transcribe"
,
"translate"
],
request_prompt
:
str
,
to_language
:
str
|
None
,
)
->
PromptType
:
"""Get the generation prompt to be used for transcription requests."""
# Audio placeholders don't use an index, so value doesn't matter
audio_tok
=
cls
.
get_placeholder_str
(
"audio"
,
0
)
if
task_type
==
"translate"
:
full_lang_name_to
=
cls
.
supported_languages
.
get
(
to_language
,
to_language
)
user_prompt
=
f
"
{
audio_tok
}
translate the speech to
{
full_lang_name_to
}
"
# noqa: E501
elif
task_type
==
"transcribe"
:
user_prompt
=
(
f
"
{
audio_tok
}
can you transcribe the speech into a written format?"
# noqa: E501
)
else
:
raise
ValueError
(
f
"Unsupported task type
{
task_type
}
"
)
tokenizer
=
cached_get_tokenizer
(
model_config
.
model
)
chat
=
[
dict
(
role
=
"user"
,
content
=
user_prompt
)]
prompt
=
tokenizer
.
apply_chat_template
(
chat
,
tokenize
=
False
,
add_generation_prompt
=
True
,
)
prompt_token_ids
=
tokenizer
.
encode
(
prompt
)
prompt
=
{
"prompt_token_ids"
:
prompt_token_ids
,
"multi_modal_data"
:
{
"audio"
:
audio
},
}
return
cast
(
PromptType
,
prompt
)
# Adapted from https://github.com/huggingface/transformers/blob/v4.56.0/src/transformers/models/granite_speech/feature_extraction_granite_speech.py#L122 # noqa: E501
@
classmethod
def
get_num_audio_tokens
(
cls
,
audio_duration_s
:
float
,
stt_config
:
SpeechToTextConfig
,
model_config
:
ModelConfig
,
)
->
int
|
None
:
"""Get the number of audio tokens for an audio duration in sec."""
processor
=
cached_get_processor
(
model_config
.
model
)
hop_length
=
processor
.
audio_processor
.
melspec_kwargs
[
"hop_length"
]
proj_win_size
=
processor
.
audio_processor
.
projector_window_size
ds_rate
=
processor
.
audio_processor
.
projector_downsample_rate
effective_window_size
=
proj_win_size
//
ds_rate
raw_length
=
audio_duration_s
*
stt_config
.
sample_rate
# mel sequence length computation
mel_length
=
raw_length
//
hop_length
+
1
# encoder frame takes two mel features
encoder_length
=
mel_length
//
2
nblocks
=
math
.
ceil
(
encoder_length
/
proj_win_size
)
# projector output length
return
nblocks
*
effective_window_size
@
classmethod
def
get_speech_to_text_config
(
cls
,
model_config
:
ModelConfig
,
task_type
:
str
)
->
SpeechToTextConfig
:
"""Get the stt config for this model."""
# Default settings are reasonable for this model and we don't currently
# expose this information in the model configs, but this may change in
# the future
return
SpeechToTextConfig
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment