Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
07286ec5
Unverified
Commit
07286ec5
authored
Jan 10, 2026
by
Jeremy Teboul
Committed by
GitHub
Jan 10, 2026
Browse files
[Bugfix] Fix integer overflow in Gemma3n audio processing (#31657)
Signed-off-by:
Jeremy Teboul
<
jeremyte@meta.com
>
parent
14fc7a68
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
239 additions
and
32 deletions
+239
-32
tests/models/multimodal/processing/test_gemma3.py
tests/models/multimodal/processing/test_gemma3.py
+141
-1
vllm/model_executor/models/gemma3n_audio_utils.py
vllm/model_executor/models/gemma3n_audio_utils.py
+57
-0
vllm/model_executor/models/gemma3n_mm.py
vllm/model_executor/models/gemma3n_mm.py
+41
-31
No files found.
tests/models/multimodal/processing/test_gemma3.py
View file @
07286ec5
...
...
@@ -2,14 +2,154 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
pytest
import
torch
from
vllm.model_executor.models.gemma3n_audio_utils
import
(
adjust_audio_features_to_expected_length
,
)
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
....conftest
import
ImageTestAssets
from
...utils
import
build_model_context
# Gemma3 (image) model
GEMMA3_MODEL_ID
=
"google/gemma-3-4b-it"
@
pytest
.
mark
.
parametrize
(
"model_id"
,
[
"google/gemma-3-4b-it"
])
# Gemma3n (multimodal with audio) model
GEMMA3N_MODEL_ID
=
"google/gemma-3n-E2B-it"
# Expected audio tokens for Gemma3n (audio_soft_tokens_per_image)
GEMMA3N_EXPECTED_AUDIO_TOKENS
=
188
class
TestGemma3nAudioTensorLogic
:
"""CPU-based tests for Gemma3n audio feature tensor manipulation.
These tests validate the padding/truncation logic in
adjust_audio_features_to_expected_length() which fixes the
integer overflow in _process_audio_input when audio_seq_len > 188.
"""
def
test_padding_when_audio_short
(
self
):
"""Test that short audio is padded to expected length."""
batch_size
,
seq_len
,
embed_dim
=
1
,
100
,
256
expected_tokens
=
GEMMA3N_EXPECTED_AUDIO_TOKENS
audio_features
=
torch
.
randn
(
batch_size
,
seq_len
,
embed_dim
)
padding_embs
=
torch
.
zeros
(
1
,
1
,
embed_dim
)
result
,
tokens_truncated
=
adjust_audio_features_to_expected_length
(
audio_features
,
expected_tokens
,
padding_embs
)
assert
result
.
shape
==
(
batch_size
,
expected_tokens
,
embed_dim
)
assert
tokens_truncated
==
0
# First 100 tokens should be original, rest should be padding (zeros)
assert
torch
.
allclose
(
result
[:,
:
seq_len
,
:],
audio_features
)
assert
torch
.
allclose
(
result
[:,
seq_len
:,
:],
torch
.
zeros
(
batch_size
,
expected_tokens
-
seq_len
,
embed_dim
),
)
def
test_truncation_when_audio_long
(
self
):
"""Test that long audio is truncated to expected length.
This is the key test for the overflow fix. Previously, when
audio_seq_len > expected_tokens, the code would compute a negative
padding value causing: RuntimeError: numel: integer multiplication overflow
"""
batch_size
,
seq_len
,
embed_dim
=
1
,
192
,
256
# 192 > 188
expected_tokens
=
GEMMA3N_EXPECTED_AUDIO_TOKENS
audio_features
=
torch
.
randn
(
batch_size
,
seq_len
,
embed_dim
)
padding_embs
=
torch
.
zeros
(
1
,
1
,
embed_dim
)
result
,
tokens_truncated
=
adjust_audio_features_to_expected_length
(
audio_features
,
expected_tokens
,
padding_embs
)
assert
result
.
shape
==
(
batch_size
,
expected_tokens
,
embed_dim
)
assert
tokens_truncated
==
seq_len
-
expected_tokens
# 192 - 188 = 4
# Result should be first 188 tokens of original
assert
torch
.
allclose
(
result
,
audio_features
[:,
:
expected_tokens
,
:])
def
test_no_change_when_exact_length
(
self
):
"""Test that exact-length audio passes through unchanged."""
batch_size
,
embed_dim
=
1
,
256
expected_tokens
=
GEMMA3N_EXPECTED_AUDIO_TOKENS
audio_features
=
torch
.
randn
(
batch_size
,
expected_tokens
,
embed_dim
)
padding_embs
=
torch
.
zeros
(
1
,
1
,
embed_dim
)
result
,
tokens_truncated
=
adjust_audio_features_to_expected_length
(
audio_features
,
expected_tokens
,
padding_embs
)
assert
result
.
shape
==
audio_features
.
shape
assert
tokens_truncated
==
0
assert
torch
.
allclose
(
result
,
audio_features
)
def
test_original_bug_would_fail
(
self
):
"""Verify the original buggy implementation would cause overflow.
The original code always tried to pad, which fails when
audio_seq_len > expected_tokens because expand() gets negative size.
"""
batch_size
,
seq_len
,
embed_dim
=
1
,
192
,
256
expected_tokens
=
GEMMA3N_EXPECTED_AUDIO_TOKENS
padding_embs
=
torch
.
zeros
(
1
,
1
,
embed_dim
)
# Original buggy logic (always pads, never truncates)
extra_padding_tokens
=
expected_tokens
-
seq_len
# = -4 (negative!)
with
pytest
.
raises
(
RuntimeError
):
# This should fail with negative size error
padding_embs
.
expand
(
batch_size
,
extra_padding_tokens
,
embed_dim
)
@
pytest
.
mark
.
parametrize
(
"seq_len"
,
[
50
,
100
,
150
,
187
,
188
,
189
,
192
,
200
,
300
],
)
def
test_various_audio_lengths
(
self
,
seq_len
:
int
):
"""Test padding/truncation with various audio lengths."""
batch_size
,
embed_dim
=
1
,
256
expected_tokens
=
GEMMA3N_EXPECTED_AUDIO_TOKENS
audio_features
=
torch
.
randn
(
batch_size
,
seq_len
,
embed_dim
)
padding_embs
=
torch
.
zeros
(
1
,
1
,
embed_dim
)
# Should not raise any errors
result
,
tokens_truncated
=
adjust_audio_features_to_expected_length
(
audio_features
,
expected_tokens
,
padding_embs
)
# Output should always be expected_tokens length
assert
result
.
shape
==
(
batch_size
,
expected_tokens
,
embed_dim
)
# Verify truncation count is correct
if
seq_len
>
expected_tokens
:
assert
tokens_truncated
==
seq_len
-
expected_tokens
else
:
assert
tokens_truncated
==
0
def
test_batch_processing
(
self
):
"""Test that batch processing works correctly."""
batch_size
,
seq_len
,
embed_dim
=
4
,
192
,
256
expected_tokens
=
GEMMA3N_EXPECTED_AUDIO_TOKENS
audio_features
=
torch
.
randn
(
batch_size
,
seq_len
,
embed_dim
)
padding_embs
=
torch
.
zeros
(
1
,
1
,
embed_dim
)
result
,
tokens_truncated
=
adjust_audio_features_to_expected_length
(
audio_features
,
expected_tokens
,
padding_embs
)
assert
result
.
shape
==
(
batch_size
,
expected_tokens
,
embed_dim
)
assert
tokens_truncated
==
seq_len
-
expected_tokens
@
pytest
.
mark
.
parametrize
(
"model_id"
,
[
GEMMA3_MODEL_ID
])
def
test_get_image_size_with_most_features
(
image_assets
:
ImageTestAssets
,
model_id
:
str
):
...
...
vllm/model_executor/models/gemma3n_audio_utils.py
0 → 100644
View file @
07286ec5
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Lightweight utility functions for Gemma3n audio processing.
This module is separate from gemma3n_mm.py to avoid heavy CUDA dependencies,
making it testable without a full vLLM build.
"""
import
torch
def
adjust_audio_features_to_expected_length
(
audio_features
:
torch
.
Tensor
,
expected_tokens
:
int
,
audio_padding_embs
:
torch
.
Tensor
,
)
->
tuple
[
torch
.
Tensor
,
int
]:
"""Adjust audio features to expected token length via padding or truncation.
The Gemma3nProcessor expects all audio will be ~30s in length and inserts
a fixed number of audio soft tokens into the text. However, the audio
preprocessing and encoder do not guarantee they will produce exactly that
many soft tokens; they may produce fewer tokens (for shorter audio) or more
tokens (for longer audio or due to BOA/EOA special tokens).
This function handles both cases:
- If fewer tokens: pad with the provided padding embeddings
- If more tokens: truncate to the expected count
Args:
audio_features: Audio embeddings tensor of shape
(batch_size, seq_len, embed_dim)
expected_tokens: The expected number of audio tokens (e.g., 188)
audio_padding_embs: Padding embeddings tensor of shape (1, 1, embed_dim)
Returns:
Tuple of:
- adjusted_features: Audio features adjusted to expected_tokens length
- tokens_truncated: Number of tokens truncated (0 if padding was applied)
"""
audio_batch_size
,
audio_seq_len
,
audio_embed_dim
=
audio_features
.
shape
tokens_truncated
=
0
if
audio_seq_len
<
expected_tokens
:
# Pad to expected length with padding embeddings
extra_padding_tokens
=
expected_tokens
-
audio_seq_len
extra_padding_features
=
audio_padding_embs
.
expand
(
audio_batch_size
,
extra_padding_tokens
,
audio_embed_dim
)
audio_features
=
torch
.
cat
((
audio_features
,
extra_padding_features
),
dim
=
1
)
elif
audio_seq_len
>
expected_tokens
:
# Truncate to expected length (audio encoder produced more tokens
# than expected, e.g., due to longer audio or placeholder mismatch)
tokens_truncated
=
audio_seq_len
-
expected_tokens
audio_features
=
audio_features
[:,
:
expected_tokens
,
:]
return
audio_features
,
tokens_truncated
vllm/model_executor/models/gemma3n_mm.py
View file @
07286ec5
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
collections.abc
import
Iterable
,
Mapping
,
Sequence
from
typing
import
Annotated
,
Any
,
Literal
,
Optional
,
Union
,
cast
from
typing
import
Annotated
,
Any
,
Literal
,
cast
import
numpy
as
np
import
torch
from
torch
import
nn
from
transformers
import
AutoModel
,
BatchFeature
from
transformers.models.gemma3n
import
(
...
...
@@ -26,6 +25,9 @@ from vllm.model_executor.layers.layernorm import RMSNorm
from
vllm.model_executor.layers.linear
import
RowParallelLinear
from
vllm.model_executor.layers.vocab_parallel_embedding
import
VocabParallelEmbedding
from
vllm.model_executor.models.gemma3n
import
Gemma3nForCausalLM
from
vllm.model_executor.models.gemma3n_audio_utils
import
(
adjust_audio_features_to_expected_length
,
)
from
vllm.model_executor.models.module_mapping
import
MultiModelKeys
from
vllm.model_executor.models.whisper
import
ISO639_1_SUPPORTED_LANGS
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
...
...
@@ -105,12 +107,12 @@ class Gemma3nProcessingInfo(BaseProcessingInfo):
def
get_hf_processor
(
self
,
**
kwargs
:
object
):
return
self
.
ctx
.
get_hf_processor
(
Gemma3nProcessor
,
**
kwargs
)
def
get_supported_mm_limits
(
self
)
->
Mapping
[
str
,
Optional
[
int
]
]:
def
get_supported_mm_limits
(
self
)
->
Mapping
[
str
,
int
|
None
]:
return
{
"image"
:
None
,
"audio"
:
None
}
def
get_max_tokens_per_item
(
self
,
seq_len
:
int
,
mm_counts
:
Mapping
[
str
,
int
]
)
->
Optional
[
Mapping
[
str
,
int
]
]
:
)
->
Mapping
[
str
,
int
]
|
None
:
return
{
"image"
:
TOKENS_PER_IMAGE
,
"audio"
:
TOKENS_PER_AUDIO
}
def
get_image_repl
(
...
...
@@ -118,7 +120,7 @@ class Gemma3nProcessingInfo(BaseProcessingInfo):
*
,
image_width
:
int
,
image_height
:
int
,
processor
:
Optional
[
Gemma3nProcessor
]
,
processor
:
Gemma3nProcessor
|
None
,
)
->
str
:
"""
Get the replacement text for image tokens.
...
...
@@ -136,7 +138,7 @@ class Gemma3nProcessingInfo(BaseProcessingInfo):
def
get_audio_repl
(
self
,
*
,
processor
:
Optional
[
Gemma3nProcessor
]
,
processor
:
Gemma3nProcessor
|
None
,
)
->
str
:
"""
Get the replacement text for audio tokens.
...
...
@@ -168,7 +170,7 @@ class Gemma3nDummyInputsBuilder(BaseDummyInputsBuilder[Gemma3nProcessingInfo]):
self
,
seq_len
:
int
,
mm_counts
:
Mapping
[
str
,
int
],
mm_options
:
Optional
[
Mapping
[
str
,
BaseDummyOptions
]
]
=
None
,
mm_options
:
Mapping
[
str
,
BaseDummyOptions
]
|
None
=
None
,
)
->
MultiModalDataDict
:
num_images
=
mm_counts
.
get
(
"image"
,
0
)
num_audios
=
mm_counts
.
get
(
"audio"
,
0
)
...
...
@@ -387,7 +389,7 @@ class Gemma3nMultimodalEmbedder(nn.Module):
def
__init__
(
self
,
multimodal_config
:
Union
[
Gemma3nAudioConfig
,
Gemma3nVisionConfig
]
,
multimodal_config
:
Gemma3nAudioConfig
|
Gemma3nVisionConfig
,
text_config
:
Gemma3nTextConfig
,
):
super
().
__init__
()
...
...
@@ -427,8 +429,8 @@ class Gemma3nMultimodalEmbedder(nn.Module):
def
forward
(
self
,
input_ids
:
Optional
[
torch
.
LongTensor
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
input_ids
:
torch
.
LongTensor
|
None
=
None
,
inputs_embeds
:
torch
.
Tensor
|
None
=
None
,
)
->
torch
.
Tensor
:
"""Embeds token ids or soft tokens for multimodal content into language model space.
...
...
@@ -529,7 +531,7 @@ class Gemma3nForConditionalGeneration(
def
_parse_and_validate_image_input
(
self
,
**
kwargs
:
object
)
->
Optional
[
Gemma3nImageInputs
]
:
)
->
Gemma3nImageInputs
|
None
:
pixel_values
=
kwargs
.
pop
(
"pixel_values"
,
None
)
image_embeds
=
kwargs
.
pop
(
"image_embeds"
,
None
)
# TODO is this the case?
...
...
@@ -541,7 +543,7 @@ class Gemma3nForConditionalGeneration(
def
_parse_and_validate_audio_input
(
self
,
**
kwargs
:
object
)
->
Optional
[
Gemma3nAudioInputs
]
:
)
->
Gemma3nAudioInputs
|
None
:
input_features_padded
=
kwargs
.
pop
(
"input_features_padded"
,
None
)
if
input_features_padded
is
None
:
return
None
...
...
@@ -616,12 +618,15 @@ class Gemma3nForConditionalGeneration(
)
audio_features
=
self
.
embed_audio
(
inputs_embeds
=
audio_outputs
)
# ruff: noqa
# The Gemma3nProcessor expects all audio will be 30s in length and inserts 188 audio soft tokens into the
# text to account for this. However, the audio preprocessing and encoder do not guarantee they will
# produce 188 soft tokens; they will produce at most that many tokens, but they may produce fewer tokens
# depending on the length of the longest audio input in the batch. When we encounter this situation, we pad
# the audio feature out to 188 soft tokens with the embedding of the last token in the embed_audio vocab.
# The Gemma3nProcessor expects all audio will be 30s in length and
# inserts 188 audio soft tokens into the text to account for this.
# However, the audio preprocessing and encoder do not guarantee they
# will produce exactly 188 soft tokens; they may produce fewer tokens
# (for shorter audio) or more tokens (for longer audio or due to
# BOA/EOA special tokens in the placeholder sequence).
# We handle both cases:
# - If fewer tokens: pad with the embedding of the last vocab token
# - If more tokens: truncate to the expected count
# TODO precompute and cache padding
audio_padding_toks
=
torch
.
tensor
(
[[
self
.
vocab_size
-
1
]],
dtype
=
torch
.
long
,
device
=
audio_features
.
device
...
...
@@ -631,13 +636,18 @@ class Gemma3nForConditionalGeneration(
audio_mask
.
unsqueeze
(
-
1
),
audio_padding_embs
,
audio_features
)
audio_batch_size
,
audio_seq_len
,
audio_embed_dim
=
audio_features
.
shape
extra_padding_tokens
=
self
.
config
.
audio_soft_tokens_per_image
-
audio_seq_len
# noqa: E501
extra_padding_features
=
audio_padding_embs
.
expand
(
audio_batch_size
,
extra_padding_tokens
,
audio_embed_dim
expected_tokens
=
self
.
config
.
audio_soft_tokens_per_image
audio_features
,
tokens_truncated
=
adjust_audio_features_to_expected_length
(
audio_features
,
expected_tokens
,
audio_padding_embs
)
if
tokens_truncated
>
0
:
logger
.
warning
(
"Gemma3n audio encoder produced %d extra tokens. "
"Truncating to match placeholder count of %d."
,
tokens_truncated
,
expected_tokens
,
)
audio_features
=
torch
.
cat
((
audio_features
,
extra_padding_features
),
dim
=
1
)
# Return a list of embeddings instead of a batched tensor
return
audio_features
.
unbind
(
0
)
...
...
@@ -666,9 +676,9 @@ class Gemma3nForConditionalGeneration(
def
embed_input_ids
(
self
,
input_ids
:
torch
.
Tensor
,
multimodal_embeddings
:
Optional
[
MultiModalEmbeddings
]
=
None
,
multimodal_embeddings
:
MultiModalEmbeddings
|
None
=
None
,
*
,
is_multimodal
:
Optional
[
torch
.
Tensor
]
=
None
,
is_multimodal
:
torch
.
Tensor
|
None
=
None
,
handle_oov_mm_token
:
bool
=
False
,
)
->
torch
.
Tensor
:
# NOTE (NickLucche) Each pass needs tokens to compute PLE so we cache
...
...
@@ -701,8 +711,8 @@ class Gemma3nForConditionalGeneration(
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
intermediate_tensors
:
IntermediateTensors
|
None
=
None
,
inputs_embeds
:
torch
.
Tensor
|
None
=
None
,
**
kwargs
:
object
,
)
->
IntermediateTensors
:
if
intermediate_tensors
is
not
None
:
...
...
@@ -729,7 +739,7 @@ class Gemma3nForConditionalGeneration(
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
)
->
Optional
[
torch
.
Tensor
]
:
)
->
torch
.
Tensor
|
None
:
return
self
.
language_model
.
compute_logits
(
hidden_states
)
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]])
->
set
[
str
]:
...
...
@@ -747,7 +757,7 @@ class Gemma3nForConditionalGeneration(
)
@
classmethod
def
get_placeholder_str
(
cls
,
modality
:
str
,
i
:
int
)
->
Optional
[
str
]
:
def
get_placeholder_str
(
cls
,
modality
:
str
,
i
:
int
)
->
str
|
None
:
if
modality
==
"image"
:
return
"<image_soft_token>"
elif
modality
==
"audio"
:
...
...
@@ -761,10 +771,10 @@ class Gemma3nForConditionalGeneration(
audio
:
np
.
ndarray
,
stt_config
:
SpeechToTextConfig
,
model_config
:
ModelConfig
,
language
:
Optional
[
str
]
,
language
:
str
|
None
,
task_type
:
Literal
[
"transcribe"
,
"translate"
],
request_prompt
:
str
,
to_language
:
Optional
[
str
]
,
to_language
:
str
|
None
,
)
->
PromptType
:
"""
Gemma3n supports "free-form" transcription.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment