Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
0a74e9d0
Unverified
Commit
0a74e9d0
authored
Sep 02, 2025
by
Nicolò Lucchesi
Committed by
GitHub
Sep 02, 2025
Browse files
[Gemma3n] Fix audio batching (#24052)
Signed-off-by:
NickLucche
<
nlucches@redhat.com
>
parent
8bd58449
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
63 additions
and
7 deletions
+63
-7
examples/online_serving/openai_chat_completion_client_for_multimodal.py
...e_serving/openai_chat_completion_client_for_multimodal.py
+42
-0
vllm/model_executor/models/gemma3n_mm.py
vllm/model_executor/models/gemma3n_mm.py
+21
-7
No files found.
examples/online_serving/openai_chat_completion_client_for_multimodal.py
View file @
0a74e9d0
...
...
@@ -266,10 +266,52 @@ def run_audio(model: str) -> None:
print
(
"Chat completion output from base64 encoded audio:"
,
result
)
def
run_multi_audio
(
model
:
str
)
->
None
:
from
vllm.assets.audio
import
AudioAsset
# Two different audios to showcase batched inference.
audio_url
=
AudioAsset
(
"winning_call"
).
url
audio_base64
=
encode_base64_content_from_url
(
audio_url
)
audio_url2
=
AudioAsset
(
"azacinto_foscolo"
).
url
audio_base64_2
=
encode_base64_content_from_url
(
audio_url2
)
# OpenAI-compatible schema (`input_audio`)
chat_completion_from_base64
=
client
.
chat
.
completions
.
create
(
messages
=
[
{
"role"
:
"user"
,
"content"
:
[
{
"type"
:
"text"
,
"text"
:
"Are these two audios the same?"
},
{
"type"
:
"input_audio"
,
"input_audio"
:
{
"data"
:
audio_base64
,
"format"
:
"wav"
,
},
},
{
"type"
:
"input_audio"
,
"input_audio"
:
{
"data"
:
audio_base64_2
,
"format"
:
"wav"
,
},
},
],
}
],
model
=
model
,
max_completion_tokens
=
64
,
)
result
=
chat_completion_from_base64
.
choices
[
0
].
message
.
content
print
(
"Chat completion output from input audio:"
,
result
)
example_function_map
=
{
"text-only"
:
run_text_only
,
"single-image"
:
run_single_image
,
"multi-image"
:
run_multi_image
,
"multi-audio"
:
run_multi_audio
,
"video"
:
run_video
,
"audio"
:
run_audio
,
}
...
...
vllm/model_executor/models/gemma3n_mm.py
View file @
0a74e9d0
...
...
@@ -5,6 +5,7 @@ from typing import Any, Literal, Optional, TypedDict, Union, cast
import
numpy
as
np
import
torch
# yapf: disable
from
torch
import
nn
from
transformers
import
AutoModel
,
BatchFeature
from
transformers.models.gemma3n
import
(
Gemma3nAudioConfig
,
...
...
@@ -30,7 +31,6 @@ from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalKwargsItems
)
from
vllm.multimodal.parse
import
(
ImageProcessorItems
,
MultiModalDataItems
,
MultiModalDataParser
)
# yapf: disable
from
vllm.multimodal.processing
import
(
BaseMultiModalProcessor
,
BaseProcessingInfo
,
MultiModalPromptUpdates
,
...
...
@@ -62,7 +62,8 @@ class Gemma3nImagePixelInputs(TypedDict):
class
Gemma3nAudioInputs
(
TypedDict
):
input_features
:
torch
.
Tensor
input_features
:
Union
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]]
input_features_padded
:
torch
.
Tensor
"""Shape: `(batch_size * num_audio, seq_length, num_features)`"""
input_features_mask
:
torch
.
Tensor
"""Shape: `(batch_size * num_audio, seq_length)`"""
...
...
@@ -188,8 +189,13 @@ class Gemma3nMultiModalProcessor(BaseMultiModalProcessor[Gemma3nProcessingInfo]
mm_kwargs
,
tok_kwargs
,
)
if
'input_features'
in
processed_outputs
:
# Avoid padding since we need the output of each item to be
# Padding enables audio_tower to run in batched mode
processed_outputs
[
"input_features_padded"
]
=
\
processed_outputs
[
"input_features"
]
# Unpad features here since we need the output of each item to be
# independent of other items for the cache to work correctly
unpadded_features
=
[
f
[
mask
]
for
f
,
mask
in
zip
(
...
...
@@ -206,9 +212,11 @@ class Gemma3nMultiModalProcessor(BaseMultiModalProcessor[Gemma3nProcessingInfo]
hf_processor_mm_kwargs
:
Mapping
[
str
,
object
],
)
->
Mapping
[
str
,
MultiModalFieldConfig
]:
return
dict
(
pixel_values
=
MultiModalFieldConfig
.
batched
(
"image"
),
input_features
=
MultiModalFieldConfig
.
batched
(
"audio"
),
input_features_mask
=
MultiModalFieldConfig
.
batched
(
"audio"
))
return
dict
(
pixel_values
=
MultiModalFieldConfig
.
batched
(
"image"
),
input_features
=
MultiModalFieldConfig
.
batched
(
"audio"
),
input_features_padded
=
MultiModalFieldConfig
.
batched
(
"audio"
),
input_features_mask
=
MultiModalFieldConfig
.
batched
(
"audio"
))
def
_get_prompt_updates
(
self
,
...
...
@@ -516,9 +524,14 @@ class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal,
if
input_features_mask
is
None
:
return
None
input_features_padded
=
kwargs
.
pop
(
"input_features_padded"
,
None
)
if
input_features_padded
is
None
:
return
None
return
Gemma3nAudioInputs
(
input_features
=
input_features
,
input_features_mask
=
input_features_mask
,
input_features_padded
=
input_features_padded
,
)
def
_parse_and_validate_multimodal_inputs
(
self
,
**
kwargs
:
object
)
->
dict
:
...
...
@@ -564,7 +577,8 @@ class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal,
audio_input
:
Gemma3nAudioInputs
,
)
->
list
[
torch
.
Tensor
]:
assert
self
.
audio_tower
is
not
None
input_features
=
audio_input
[
"input_features"
].
squeeze
(
1
)
# Run on padded features to enable batching
input_features
=
audio_input
[
"input_features_padded"
].
squeeze
(
1
)
input_features_mask
=
audio_input
[
"input_features_mask"
].
squeeze
(
1
)
audio_outputs
,
audio_mask
=
self
.
audio_tower
(
input_features
,
~
input_features_mask
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment