"docs/source/performance/optimization.md" did not exist on "68d37809b9b52f4d012fa0dfbb187f0fe978bdbc"
Unverified Commit 07286ec5 authored by Jeremy Teboul's avatar Jeremy Teboul Committed by GitHub
Browse files

[Bugfix] Fix integer overflow in Gemma3n audio processing (#31657)


Signed-off-by: Jeremy Teboul <jeremyte@meta.com>
parent 14fc7a68
......@@ -2,14 +2,154 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch
from vllm.model_executor.models.gemma3n_audio_utils import (
adjust_audio_features_to_expected_length,
)
from vllm.multimodal import MULTIMODAL_REGISTRY
from ....conftest import ImageTestAssets
from ...utils import build_model_context
# Gemma3 (image) model
GEMMA3_MODEL_ID = "google/gemma-3-4b-it"
@pytest.mark.parametrize("model_id", ["google/gemma-3-4b-it"])
# Gemma3n (multimodal with audio) model identifier.
GEMMA3N_MODEL_ID = "google/gemma-3n-E2B-it"
# Number of audio soft tokens the Gemma3n tests pad/truncate to. Mirrors the
# model config value `audio_soft_tokens_per_image`, which the model code
# passes as `expected_tokens` to adjust_audio_features_to_expected_length().
GEMMA3N_EXPECTED_AUDIO_TOKENS = 188
class TestGemma3nAudioTensorLogic:
    """CPU-only checks of the Gemma3n audio feature length adjustment.

    Exercises the pad/truncate behaviour of
    adjust_audio_features_to_expected_length(), the helper that removes the
    integer overflow hit by _process_audio_input when audio_seq_len > 188.
    """

    @staticmethod
    def _adjust(n_batch: int, n_frames: int, width: int = 256):
        """Build random features plus zero padding and run the helper.

        Returns the input features, the adjusted features and the number of
        truncated tokens reported by the helper.
        """
        features = torch.randn(n_batch, n_frames, width)
        pad_emb = torch.zeros(1, 1, width)
        adjusted, dropped = adjust_audio_features_to_expected_length(
            features, GEMMA3N_EXPECTED_AUDIO_TOKENS, pad_emb
        )
        return features, adjusted, dropped

    def test_padding_when_audio_short(self):
        """Audio shorter than the target is right-padded up to it."""
        n_batch, n_frames, width = 1, 100, 256
        target = GEMMA3N_EXPECTED_AUDIO_TOKENS

        features, adjusted, dropped = self._adjust(n_batch, n_frames, width)

        assert adjusted.shape == (n_batch, target, width)
        assert dropped == 0

        # Leading 100 tokens are the input; everything after is zero padding.
        head = adjusted[:, :n_frames, :]
        tail = adjusted[:, n_frames:, :]
        assert torch.allclose(head, features)
        assert torch.allclose(
            tail, torch.zeros(n_batch, target - n_frames, width)
        )

    def test_truncation_when_audio_long(self):
        """Audio longer than the target is cut down to it.

        Core regression test for the overflow fix: with
        audio_seq_len > expected_tokens the old code derived a negative
        padding size, which ended in
        RuntimeError: numel: integer multiplication overflow
        """
        n_batch, n_frames, width = 1, 192, 256  # 192 > 188
        target = GEMMA3N_EXPECTED_AUDIO_TOKENS

        features, adjusted, dropped = self._adjust(n_batch, n_frames, width)

        assert adjusted.shape == (n_batch, target, width)
        assert dropped == n_frames - target  # 192 - 188 = 4
        # Only the first 188 input tokens survive.
        assert torch.allclose(adjusted, features[:, :target, :])

    def test_no_change_when_exact_length(self):
        """Audio already at the target length comes back untouched."""
        n_batch, width = 1, 256
        target = GEMMA3N_EXPECTED_AUDIO_TOKENS

        features, adjusted, dropped = self._adjust(n_batch, target, width)

        assert adjusted.shape == features.shape
        assert dropped == 0
        assert torch.allclose(adjusted, features)

    def test_original_bug_would_fail(self):
        """Show that the pre-fix logic overflows for over-long audio.

        The old implementation unconditionally padded, so for
        audio_seq_len > expected_tokens it handed expand() a negative size.
        """
        n_batch, n_frames, width = 1, 192, 256
        pad_emb = torch.zeros(1, 1, width)

        # Pre-fix computation: pad-only, no truncation branch.
        missing = GEMMA3N_EXPECTED_AUDIO_TOKENS - n_frames  # = -4 (negative!)

        # A negative target size must make expand() fail.
        with pytest.raises(RuntimeError):
            pad_emb.expand(n_batch, missing, width)

    @pytest.mark.parametrize(
        "seq_len",
        [50, 100, 150, 187, 188, 189, 192, 200, 300],
    )
    def test_various_audio_lengths(self, seq_len: int):
        """Sweep lengths below, at and above the target."""
        n_batch, width = 1, 256
        target = GEMMA3N_EXPECTED_AUDIO_TOKENS

        # Must complete without raising for every length.
        _, adjusted, dropped = self._adjust(n_batch, seq_len, width)

        # The token axis is always normalised to the target length.
        assert adjusted.shape == (n_batch, target, width)

        # Truncation is only reported when the input overshoots the target.
        assert dropped == max(seq_len - target, 0)

    def test_batch_processing(self):
        """Batches with more than one item are adjusted consistently."""
        n_batch, n_frames, width = 4, 192, 256
        target = GEMMA3N_EXPECTED_AUDIO_TOKENS

        _, adjusted, dropped = self._adjust(n_batch, n_frames, width)

        assert adjusted.shape == (n_batch, target, width)
        assert dropped == n_frames - target
@pytest.mark.parametrize("model_id", [GEMMA3_MODEL_ID])
def test_get_image_size_with_most_features(
image_assets: ImageTestAssets, model_id: str
):
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Lightweight utility functions for Gemma3n audio processing.
This module is separate from gemma3n_mm.py to avoid heavy CUDA dependencies,
making it testable without a full vLLM build.
"""
import torch
def adjust_audio_features_to_expected_length(
    audio_features: torch.Tensor,
    expected_tokens: int,
    audio_padding_embs: torch.Tensor,
) -> tuple[torch.Tensor, int]:
    """Adjust audio features to expected token length via padding or truncation.

    The Gemma3nProcessor expects all audio will be ~30s in length and inserts
    a fixed number of audio soft tokens into the text. However, the audio
    preprocessing and encoder do not guarantee they will produce exactly that
    many soft tokens; they may produce fewer tokens (for shorter audio) or more
    tokens (for longer audio or due to BOA/EOA special tokens).

    This function handles both cases:
    - If fewer tokens: pad with the provided padding embeddings
    - If more tokens: truncate to the expected count

    Args:
        audio_features: Audio embeddings tensor of shape
            (batch_size, seq_len, embed_dim)
        expected_tokens: The expected number of audio tokens (e.g., 188).
            Must be non-negative.
        audio_padding_embs: Padding embeddings tensor of shape (1, 1, embed_dim)

    Returns:
        Tuple of:
        - adjusted_features: Audio features adjusted to expected_tokens length
        - tokens_truncated: Number of tokens truncated (0 if padding was
          applied or the length already matched)

    Raises:
        ValueError: If expected_tokens is negative.
    """
    # A negative target would silently slice from the end of the sequence
    # (Python negative-index semantics) and report a bogus truncation count,
    # so reject it up front instead of returning a wrongly sized tensor.
    if expected_tokens < 0:
        raise ValueError(
            f"expected_tokens must be non-negative, got {expected_tokens}"
        )

    audio_batch_size, audio_seq_len, audio_embed_dim = audio_features.shape

    if audio_seq_len > expected_tokens:
        # Truncate to expected length (audio encoder produced more tokens
        # than expected, e.g., due to longer audio or placeholder mismatch).
        tokens_truncated = audio_seq_len - expected_tokens
        return audio_features[:, :expected_tokens, :], tokens_truncated

    if audio_seq_len < expected_tokens:
        # Pad to expected length by broadcasting the single padding embedding
        # over the missing positions; the size is strictly positive here, so
        # expand() can never receive a negative dimension.
        extra_padding_tokens = expected_tokens - audio_seq_len
        extra_padding_features = audio_padding_embs.expand(
            audio_batch_size, extra_padding_tokens, audio_embed_dim
        )
        audio_features = torch.cat(
            (audio_features, extra_padding_features), dim=1
        )

    # Either padded above or already exactly the expected length.
    return audio_features, 0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable, Mapping, Sequence
from typing import Annotated, Any, Literal, Optional, Union, cast
from typing import Annotated, Any, Literal, cast
import numpy as np
import torch
from torch import nn
from transformers import AutoModel, BatchFeature
from transformers.models.gemma3n import (
......@@ -26,6 +25,9 @@ from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import RowParallelLinear
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
from vllm.model_executor.models.gemma3n import Gemma3nForCausalLM
from vllm.model_executor.models.gemma3n_audio_utils import (
adjust_audio_features_to_expected_length,
)
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.models.whisper import ISO639_1_SUPPORTED_LANGS
from vllm.multimodal import MULTIMODAL_REGISTRY
......@@ -105,12 +107,12 @@ class Gemma3nProcessingInfo(BaseProcessingInfo):
def get_hf_processor(self, **kwargs: object):
return self.ctx.get_hf_processor(Gemma3nProcessor, **kwargs)
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
def get_supported_mm_limits(self) -> Mapping[str, int | None]:
return {"image": None, "audio": None}
def get_max_tokens_per_item(
self, seq_len: int, mm_counts: Mapping[str, int]
) -> Optional[Mapping[str, int]]:
) -> Mapping[str, int] | None:
return {"image": TOKENS_PER_IMAGE, "audio": TOKENS_PER_AUDIO}
def get_image_repl(
......@@ -118,7 +120,7 @@ class Gemma3nProcessingInfo(BaseProcessingInfo):
*,
image_width: int,
image_height: int,
processor: Optional[Gemma3nProcessor],
processor: Gemma3nProcessor | None,
) -> str:
"""
Get the replacement text for image tokens.
......@@ -136,7 +138,7 @@ class Gemma3nProcessingInfo(BaseProcessingInfo):
def get_audio_repl(
self,
*,
processor: Optional[Gemma3nProcessor],
processor: Gemma3nProcessor | None,
) -> str:
"""
Get the replacement text for audio tokens.
......@@ -168,7 +170,7 @@ class Gemma3nDummyInputsBuilder(BaseDummyInputsBuilder[Gemma3nProcessingInfo]):
self,
seq_len: int,
mm_counts: Mapping[str, int],
mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
mm_options: Mapping[str, BaseDummyOptions] | None = None,
) -> MultiModalDataDict:
num_images = mm_counts.get("image", 0)
num_audios = mm_counts.get("audio", 0)
......@@ -387,7 +389,7 @@ class Gemma3nMultimodalEmbedder(nn.Module):
def __init__(
self,
multimodal_config: Union[Gemma3nAudioConfig, Gemma3nVisionConfig],
multimodal_config: Gemma3nAudioConfig | Gemma3nVisionConfig,
text_config: Gemma3nTextConfig,
):
super().__init__()
......@@ -427,8 +429,8 @@ class Gemma3nMultimodalEmbedder(nn.Module):
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
input_ids: torch.LongTensor | None = None,
inputs_embeds: torch.Tensor | None = None,
) -> torch.Tensor:
"""Embeds token ids or soft tokens for multimodal content into language model space.
......@@ -529,7 +531,7 @@ class Gemma3nForConditionalGeneration(
def _parse_and_validate_image_input(
self, **kwargs: object
) -> Optional[Gemma3nImageInputs]:
) -> Gemma3nImageInputs | None:
pixel_values = kwargs.pop("pixel_values", None)
image_embeds = kwargs.pop("image_embeds", None)
# TODO is this the case?
......@@ -541,7 +543,7 @@ class Gemma3nForConditionalGeneration(
def _parse_and_validate_audio_input(
self, **kwargs: object
) -> Optional[Gemma3nAudioInputs]:
) -> Gemma3nAudioInputs | None:
input_features_padded = kwargs.pop("input_features_padded", None)
if input_features_padded is None:
return None
......@@ -616,12 +618,15 @@ class Gemma3nForConditionalGeneration(
)
audio_features = self.embed_audio(inputs_embeds=audio_outputs)
# ruff: noqa
# The Gemma3nProcessor expects all audio will be 30s in length and inserts 188 audio soft tokens into the
# text to account for this. However, the audio preprocessing and encoder do not guarantee they will
# produce 188 soft tokens; they will produce at most that many tokens, but they may produce fewer tokens
# depending on the length of the longest audio input in the batch. When we encounter this situation, we pad
# the audio feature out to 188 soft tokens with the embedding of the last token in the embed_audio vocab.
# The Gemma3nProcessor expects all audio will be 30s in length and
# inserts 188 audio soft tokens into the text to account for this.
# However, the audio preprocessing and encoder do not guarantee they
# will produce exactly 188 soft tokens; they may produce fewer tokens
# (for shorter audio) or more tokens (for longer audio or due to
# BOA/EOA special tokens in the placeholder sequence).
# We handle both cases:
# - If fewer tokens: pad with the embedding of the last vocab token
# - If more tokens: truncate to the expected count
# TODO precompute and cache padding
audio_padding_toks = torch.tensor(
[[self.vocab_size - 1]], dtype=torch.long, device=audio_features.device
......@@ -631,13 +636,18 @@ class Gemma3nForConditionalGeneration(
audio_mask.unsqueeze(-1), audio_padding_embs, audio_features
)
audio_batch_size, audio_seq_len, audio_embed_dim = audio_features.shape
extra_padding_tokens = self.config.audio_soft_tokens_per_image - audio_seq_len # noqa: E501
extra_padding_features = audio_padding_embs.expand(
audio_batch_size, extra_padding_tokens, audio_embed_dim
expected_tokens = self.config.audio_soft_tokens_per_image
audio_features, tokens_truncated = adjust_audio_features_to_expected_length(
audio_features, expected_tokens, audio_padding_embs
)
if tokens_truncated > 0:
logger.warning(
"Gemma3n audio encoder produced %d extra tokens. "
"Truncating to match placeholder count of %d.",
tokens_truncated,
expected_tokens,
)
audio_features = torch.cat((audio_features, extra_padding_features), dim=1)
# Return a list of embeddings instead of a batched tensor
return audio_features.unbind(0)
......@@ -666,9 +676,9 @@ class Gemma3nForConditionalGeneration(
def embed_input_ids(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
multimodal_embeddings: MultiModalEmbeddings | None = None,
*,
is_multimodal: Optional[torch.Tensor] = None,
is_multimodal: torch.Tensor | None = None,
handle_oov_mm_token: bool = False,
) -> torch.Tensor:
# NOTE (NickLucche) Each pass needs tokens to compute PLE so we cache
......@@ -701,8 +711,8 @@ class Gemma3nForConditionalGeneration(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
**kwargs: object,
) -> IntermediateTensors:
if intermediate_tensors is not None:
......@@ -729,7 +739,7 @@ class Gemma3nForConditionalGeneration(
def compute_logits(
self,
hidden_states: torch.Tensor,
) -> Optional[torch.Tensor]:
) -> torch.Tensor | None:
return self.language_model.compute_logits(hidden_states)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
......@@ -747,7 +757,7 @@ class Gemma3nForConditionalGeneration(
)
@classmethod
def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
def get_placeholder_str(cls, modality: str, i: int) -> str | None:
if modality == "image":
return "<image_soft_token>"
elif modality == "audio":
......@@ -761,10 +771,10 @@ class Gemma3nForConditionalGeneration(
audio: np.ndarray,
stt_config: SpeechToTextConfig,
model_config: ModelConfig,
language: Optional[str],
language: str | None,
task_type: Literal["transcribe", "translate"],
request_prompt: str,
to_language: Optional[str],
to_language: str | None,
) -> PromptType:
"""
Gemma3n supports "free-form" transcription.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment