Unverified Commit 5a16fa61 authored by Nicolò Lucchesi's avatar Nicolò Lucchesi Committed by GitHub
Browse files

[Model] Gemma3n MM (#20495)


Signed-off-by: default avatarShriKode <shrikode@gmail.com>
Signed-off-by: default avatarNickLucche <nlucches@redhat.com>
Signed-off-by: default avatarRoger Wang <hey@rogerw.me>
Co-authored-by: default avatarShriKode <shrikode@gmail.com>
Co-authored-by: default avatarRoger Wang <hey@rogerw.me>
parent 2d18256e
......@@ -349,7 +349,7 @@ th {
| `GemmaForCausalLM` | Gemma | `google/gemma-2b`, `google/gemma-1.1-2b-it`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `Gemma2ForCausalLM` | Gemma 2 | `google/gemma-2-9b`, `google/gemma-2-27b`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `Gemma3ForCausalLM` | Gemma 3 | `google/gemma-3-1b-it`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `Gemma3nForConditionalGeneration` | Gemma 3n | `google/gemma-3n-E2B-it`, `google/gemma-3n-E4B-it`, etc. | | | ✅︎ |
| `Gemma3nForCausalLM` | Gemma 3n | `google/gemma-3n-E2B-it`, `google/gemma-3n-E4B-it`, etc. | | | ✅︎ |
| `GlmForCausalLM` | GLM-4 | `zai-org/glm-4-9b-chat-hf`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `Glm4ForCausalLM` | GLM-4-0414 | `zai-org/GLM-4-32B-0414`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `Glm4MoeForCausalLM` | GLM-4.5 | `zai-org/GLM-4.5`, etc. | ✅︎ | ✅︎ | ✅︎ |
......@@ -412,9 +412,6 @@ th {
!!! note
Currently, the ROCm version of vLLM supports Mistral and Mixtral only for context lengths up to 4096.
!!! note
Only text inputs are currently supported for `Gemma3nForConditionalGeneration`. To use this model, please upgrade Hugging Face Transformers to version 4.53.0.
### Pooling Models
See [this page](./pooling_models.md) for more information on how to use pooling models.
......@@ -608,6 +605,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen
| `Florence2ForConditionalGeneration` | Florence-2 | T + I | `microsoft/Florence-2-base`, `microsoft/Florence-2-large`, etc. | | | |
| `FuyuForCausalLM` | Fuyu | T + I | `adept/fuyu-8b`, etc. | | ✅︎ | ✅︎ |
| `Gemma3ForConditionalGeneration` | Gemma 3 | T + I<sup>+</sup> | `google/gemma-3-4b-it`, `google/gemma-3-27b-it`, etc. | ✅︎ | ✅︎ | ⚠️ |
| `Gemma3nForConditionalGeneration` | Gemma 3n | T + I + A | `google/gemma-3n-E2B-it`, `google/gemma-3n-E4B-it`, etc. | | | ✅︎ |
| `GLM4VForCausalLM`<sup>^</sup> | GLM-4V | T + I | `zai-org/glm-4v-9b`, `zai-org/cogagent-9b-20241220`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `Glm4vForConditionalGeneration` | GLM-4.1V-Thinking | T + I<sup>E+</sup> + V<sup>E+</sup> | `zai-org/GLM-4.1V-9B-Thinking`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `Glm4vMoeForConditionalGeneration` | GLM-4.5V | T + I<sup>E+</sup> + V<sup>E+</sup> | `zai-org/GLM-4.5V`, etc. | ✅︎ | ✅︎ | ✅︎ |
......@@ -677,6 +675,15 @@ Some models are supported only via the [Transformers backend](#transformers). Th
This limitation exists because the model's mixed attention pattern (bidirectional for images, causal otherwise) is not yet supported by vLLM's attention backends.
!!! note
`Gemma3nForConditionalGeneration` is only supported on V1 due to shared KV caching, and it requires `timm>=1.0.17` to make use of its
MobileNet-v5 vision backbone.
Performance is not yet fully optimized, mainly because:
- Both the audio and vision MM encoders use the `transformers.AutoModel` implementation.
- There's no PLE caching or out-of-memory swapping support, as described in [Google's blog](https://developers.googleblog.com/en/introducing-gemma-3n/). These features might be too model-specific for vLLM, and swapping in particular may be better suited for constrained setups.
!!! note
Currently, only `InternVLChatModel` with a Qwen2.5 text backbone (`OpenGVLab/InternVL3-2B`, `OpenGVLab/InternVL2.5-1B`, etc.) supports video inputs.
......
......@@ -96,6 +96,25 @@ def run_voxtral(question: str, audio_count: int) -> ModelRequestData:
)
# Gemma3N
def run_gemma3n(question: str, audio_count: int) -> ModelRequestData:
    """Build the engine arguments and prompt for the Gemma 3n audio example.

    Args:
        question: Text of the user question, placed right after the audio
            placeholder token in the prompt.
        audio_count: Per-prompt audio limit, forwarded to
            ``limit_mm_per_prompt``.

    Returns:
        A ``ModelRequestData`` carrying the engine arguments and the
        chat-formatted prompt.
    """
    model_name = "google/gemma-3n-E2B-it"
    engine_args = EngineArgs(
        model=model_name,
        max_model_len=2048,
        max_num_batched_tokens=2048,
        max_num_seqs=2,
        limit_mm_per_prompt={"audio": audio_count},
        enforce_eager=True,
    )
    # The parentheses are required: without them the second literal is a
    # separate (no-op) expression statement and the prompt would end right
    # after the question, missing the end-of-turn / model-turn markers.
    # NOTE: the prompt contains a single audio placeholder regardless of
    # ``audio_count``.
    prompt = (
        f"<start_of_turn>user\n<audio_soft_token>{question}"
        "<end_of_turn>\n<start_of_turn>model\n"
    )
    return ModelRequestData(
        engine_args=engine_args,
        prompt=prompt,
    )
# Granite Speech
def run_granite_speech(question: str, audio_count: int) -> ModelRequestData:
# NOTE - the settings in this example are somewhat different than what is
......@@ -331,6 +350,7 @@ def run_whisper(question: str, audio_count: int) -> ModelRequestData:
model_example_map = {
"voxtral": run_voxtral,
"gemma3n": run_gemma3n,
"granite_speech": run_granite_speech,
"minicpmo": run_minicpmo,
"phi4_mm": run_phi4mm,
......
......@@ -211,7 +211,33 @@ def run_gemma3(questions: list[str], modality: str) -> ModelRequestData:
)
for question in questions
]
return ModelRequestData(
engine_args=engine_args,
prompts=prompts,
)
# Gemma3N
def run_gemma3n(questions: list[str], modality: str) -> ModelRequestData:
assert modality == "image"
model_name = "google/gemma-3n-E2B-it"
engine_args = EngineArgs(
model=model_name,
max_model_len=2048,
max_num_seqs=2,
limit_mm_per_prompt={modality: 1},
enforce_eager=True,
)
prompts = [
(
"<start_of_turn>user\n"
f"<image_soft_token>{question}<end_of_turn>\n"
"<start_of_turn>model\n"
)
for question in questions
]
return ModelRequestData(
engine_args=engine_args,
prompts=prompts,
......@@ -1395,6 +1421,7 @@ model_example_map = {
"florence2": run_florence2,
"fuyu": run_fuyu,
"gemma3": run_gemma3,
"gemma3n": run_gemma3n,
"glm4v": run_glm4v,
"glm4_1v": run_glm4_1v,
"h2ovl_chat": run_h2ovl,
......
......@@ -21,7 +21,7 @@ ray[cgraph,default]>=2.48.0 # Ray Compiled Graph, required by pipeline paralleli
sentence-transformers # required for embedding tests
soundfile # required for audio tests
jiwer # required for audio tests
timm # required for internvl test
timm >=1.0.17 # required for internvl and gemma3n-mm test
torch==2.7.1
torchaudio==2.7.1
torchvision==0.22.1
......
......@@ -1051,7 +1051,7 @@ tiktoken==0.7.0
# via
# lm-eval
# mistral-common
timm==1.0.15
timm==1.0.17
# via
# -r requirements/test.in
# open-clip-torch
......
......@@ -271,6 +271,7 @@ def _test_processing_correctness_one(
"microsoft/Florence-2-base",
"adept/fuyu-8b",
"google/gemma-3-4b-it",
"google/gemma-3n-E2B-it",
"zai-org/glm-4v-9b",
"zai-org/GLM-4.1V-9B-Thinking",
"ibm-granite/granite-speech-3.3-2b",
......@@ -315,7 +316,7 @@ def _test_processing_correctness_one(
"fixie-ai/ultravox-v0_5-llama-3_2-1b",
"openai/whisper-large-v3",
"omni-research/Tarsier-7b",
"omni-research/Tarsier2-Recap-7b"
"omni-research/Tarsier2-Recap-7b",
])
@pytest.mark.parametrize("hit_rate", [0.3, 0.5, 1.0])
@pytest.mark.parametrize("num_batches", [32])
......@@ -327,6 +328,8 @@ def test_processing_correctness(
num_batches: int,
simplify_rate: float,
):
if model_id == "google/gemma-3n-E2B-it":
pytest.skip("Skipping gemma-3n-E2B-it due to transformers #39911 bug.")
_test_processing_correctness(
model_id,
hit_rate=hit_rate,
......
......@@ -186,7 +186,7 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
"GemmaForCausalLM": _HfExamplesInfo("google/gemma-1.1-2b-it"),
"Gemma2ForCausalLM": _HfExamplesInfo("google/gemma-2-9b"),
"Gemma3ForCausalLM": _HfExamplesInfo("google/gemma-3-1b-it"),
"Gemma3nForConditionalGeneration": _HfExamplesInfo("google/gemma-3n-E2B-it", # noqa: E501
"Gemma3nForCausalLM": _HfExamplesInfo("google/gemma-3n-E2B-it",
min_transformers_version="4.53"),
"GlmForCausalLM": _HfExamplesInfo("zai-org/glm-4-9b-chat-hf"),
"Glm4ForCausalLM": _HfExamplesInfo("zai-org/GLM-4-9B-0414"),
......@@ -391,6 +391,8 @@ _MULTIMODAL_EXAMPLE_MODELS = {
"Emu3ForConditionalGeneration": _HfExamplesInfo("BAAI/Emu3-Chat-hf"),
"FuyuForCausalLM": _HfExamplesInfo("adept/fuyu-8b"),
"Gemma3ForConditionalGeneration": _HfExamplesInfo("google/gemma-3-4b-it"),
"Gemma3nForConditionalGeneration": _HfExamplesInfo("google/gemma-3n-E2B-it", # noqa: E501
min_transformers_version="4.53"),
"GraniteSpeechForConditionalGeneration": _HfExamplesInfo("ibm-granite/granite-speech-3.3-2b"), # noqa: E501
"GLM4VForCausalLM": _HfExamplesInfo("zai-org/glm-4v-9b",
trust_remote_code=True,
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
from vllm import LLM, envs
from vllm.sampling_params import SamplingParams
# Module-level guard: when `envs.VLLM_USE_V1` is falsy, skip every test in
# this file at collection time (`allow_module_level=True` permits calling
# `pytest.skip` outside of a test function).
if not envs.VLLM_USE_V1:
    pytest.skip(
        "Skipping V1 tests. Rerun with `VLLM_USE_V1=1` to test.",
        allow_module_level=True,
    )
@pytest.mark.parametrize("model_name", ["Qwen/Qwen2.5-1.5B-Instruct"])
# TODO TPU will appear busy if we fan-out test params here
@pytest.mark.parametrize("n_prompts", [1])
def test_logprobs(model_name: str, n_prompts: int):
    """
    Generate with greedy, temperature-only and top-k/top-p sampling while
    requesting 4 logprobs per token, then verify for every generated step
    that the entries, once sorted by rank, have non-increasing logprob
    values and consecutive ranks starting at 1.
    """
    num_logprobs = 4

    def check_num_logprobs(logprobs, expected_num: int):
        for step in logprobs:
            # Re-key the step's entries in ascending rank order.
            by_rank = dict(sorted(step.items(), key=lambda kv: kv[1].rank))
            # A size mismatch is only reported, never asserted.
            # NOTE(review): the exact-count assertion is disabled in this
            # test; confirm whether extra entries are expected.
            if len(step) != expected_num:
                print("watch out", by_rank)
            # Walk the entries best-rank first: each logprob must not exceed
            # the previous one and ranks must be 1, 2, 3, ...
            upper_bound = 1.0
            for expected_rank, entry in enumerate(by_rank.values(), start=1):
                assert entry.logprob <= upper_bound
                upper_bound = entry.logprob
                assert entry.rank == expected_rank

    llm = LLM(
        model_name,
        enforce_eager=False,
        max_num_seqs=1,
        max_model_len=128,
        max_num_batched_tokens=128,
    )
    prompts = [
        "Write a short story about a robot that dreams for the first time."
    ] * n_prompts

    # Settings shared by all three sampling configurations.
    shared = dict(max_tokens=64, logprobs=num_logprobs)
    sampling_configs = (
        SamplingParams(temperature=0.0, **shared),  # greedy
        SamplingParams(temperature=0.4, **shared),  # plain sampling
        SamplingParams(temperature=0.4, top_k=12, top_p=0.5, **shared),
    )

    for params in sampling_configs:
        for request_output in llm.generate(prompts, params):
            check_num_logprobs(request_output.outputs[0].logprobs,
                               num_logprobs)
......@@ -331,14 +331,15 @@ class Gemma3nAttention(nn.Module):
config.num_kv_shared_layers)
self.is_kv_shared = layer_idx >= first_kv_shared_layer_idx
kv_sharing_target_layer_name = None
if self.is_kv_shared:
# Last full attention layer is 1 before sharing
# Last sliding attention layer is 2 before sharing
offset = 2 if self.sliding_window is not None else 1
kv_shared_layer_index = first_kv_shared_layer_idx - offset
kv_sharing_target_layer_name = f"model.language_model.layers.{kv_shared_layer_index}.self_attn.attn" # noqa: E501
else:
kv_sharing_target_layer_name = None
if kv_shared_layer_index >= 0:
# Only the greater layer is required to specify sharing.
kv_sharing_target_layer_name = f"language_model.model.layers.{kv_shared_layer_index}.self_attn.attn" # noqa: E501
self.rotary_emb = get_rope(
self.head_dim,
......@@ -396,6 +397,7 @@ class Gemma3nDecoderLayer(nn.Module):
prefix: str = "",
) -> None:
super().__init__()
assert isinstance(config, Gemma3nTextConfig)
self.altup_active_idx = config.altup_active_idx
assert config.altup_correct_scale
......@@ -537,7 +539,7 @@ class Gemma3nTextModel(nn.Module, SupportsQuant):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config.text_config
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config
......@@ -553,6 +555,7 @@ class Gemma3nTextModel(nn.Module, SupportsQuant):
config.hidden_size**0.5,
dtype=self.embed_tokens.weight.dtype,
)
# Additional per-layer embeddings (PLE)
self.embed_tokens_per_layer = VocabParallelEmbedding(
config.vocab_size_per_layer_input,
config.num_hidden_layers * config.hidden_size_per_layer_input,
......@@ -636,6 +639,8 @@ class Gemma3nTextModel(nn.Module, SupportsQuant):
self,
input_ids: Optional[torch.Tensor],
positions: torch.Tensor,
per_layer_inputs: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs,
) -> Union[torch.Tensor, IntermediateTensors]:
......@@ -644,13 +649,6 @@ class Gemma3nTextModel(nn.Module, SupportsQuant):
else:
hidden_states_0 = self.get_input_embeddings(input_ids)
# Per layer inputs.
if input_ids is None:
raise ValueError("Passing None for input ids is not supported.")
per_layer_inputs = self.get_per_layer_input_embeddings(input_ids)
per_layer_inputs = per_layer_inputs.reshape(
-1, self.config.num_hidden_layers,
self.config.hidden_size_per_layer_input)
per_layer_projection = self.per_layer_model_projection(hidden_states_0)
per_layer_projection = per_layer_projection.reshape(
*hidden_states_0.shape[:-1],
......@@ -659,8 +657,13 @@ class Gemma3nTextModel(nn.Module, SupportsQuant):
)
per_layer_projection = self.per_layer_projection_norm(
per_layer_projection)
if per_layer_inputs is not None:
# Profiling run does not compute per_layer_inputs
per_layer_inputs = per_layer_projection + per_layer_inputs
per_layer_inputs *= self.per_layer_input_scale
else:
per_layer_inputs = per_layer_projection
# Altup embed.
hidden_states = [hidden_states_0] * self.config.altup_num_inputs
......@@ -760,29 +763,7 @@ class Gemma3nTextModel(nn.Module, SupportsQuant):
return loaded_params
class Gemma3nModel(nn.Module):
    """Thin wrapper that holds a ``Gemma3nTextModel`` as ``language_model``."""

    def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        # The wrapped text model is registered under "<prefix>language_model"
        # as computed by `maybe_prefix`.
        text_model_prefix = maybe_prefix(prefix, "language_model")
        self.language_model = Gemma3nTextModel(
            vllm_config=vllm_config,
            prefix=text_model_prefix,
        )

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        """Delegate the forward pass to the wrapped text model.

        ``intermediate_tensors`` is accepted for interface compatibility but
        is not forwarded to ``language_model``; every other argument
        (including extra keyword arguments) is passed through by keyword.
        """
        text_model = self.language_model
        return text_model(
            input_ids=input_ids,
            positions=positions,
            inputs_embeds=inputs_embeds,
            **kwargs,
        )
class Gemma3nForConditionalGeneration(nn.Module, SupportsQuant):
class Gemma3nForCausalLM(nn.Module):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
......@@ -802,25 +783,33 @@ class Gemma3nForConditionalGeneration(nn.Module, SupportsQuant):
super().__init__()
self.config = config
self.cache_config = vllm_config.cache_config
self.model = Gemma3nModel(vllm_config=vllm_config,
self.model = Gemma3nTextModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
self.logits_processor = LogitsProcessor(
config.text_config.vocab_size,
soft_cap=config.text_config.final_logit_softcapping)
config.vocab_size, soft_cap=config.final_logit_softcapping)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.language_model.get_input_embeddings(input_ids)
return self.model.get_input_embeddings(input_ids)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
*,
per_layer_inputs: Optional[torch.Tensor] = None,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs,
) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, intermediate_tensors,
inputs_embeds, **kwargs)
hidden_states = self.model(
input_ids,
positions,
per_layer_inputs=per_layer_inputs,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds,
**kwargs,
)
return hidden_states
def compute_logits(
......@@ -828,8 +817,8 @@ class Gemma3nForConditionalGeneration(nn.Module, SupportsQuant):
hidden_states: torch.Tensor,
sampling_metadata: Optional[SamplingMetadata],
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.model.language_model.embed_tokens,
hidden_states, sampling_metadata)
logits = self.logits_processor(self.model.embed_tokens, hidden_states,
sampling_metadata)
return logits
def load_weights(self, weights: Iterable[tuple[str,
......
This diff is collapsed.
......@@ -69,8 +69,7 @@ _TEXT_GENERATION_MODELS = {
"GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
"Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"),
"Gemma3ForCausalLM": ("gemma3", "Gemma3ForCausalLM"),
#TODO(ywang96): Support multimodal gemma3n
"Gemma3nForConditionalGeneration": ("gemma3n", "Gemma3nForConditionalGeneration"), # noqa: E501
"Gemma3nForCausalLM": ("gemma3n", "Gemma3nForCausalLM"),
"GlmForCausalLM": ("glm", "GlmForCausalLM"),
"Glm4ForCausalLM": ("glm4", "Glm4ForCausalLM"),
"Glm4MoeForCausalLM": ("glm4_moe", "Glm4MoeForCausalLM"),
......@@ -205,6 +204,7 @@ _MULTIMODAL_MODELS = {
"DeepseekVLV2ForCausalLM": ("deepseek_vl2", "DeepseekVLV2ForCausalLM"),
"FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"),
"Gemma3ForConditionalGeneration": ("gemma3_mm", "Gemma3ForConditionalGeneration"), # noqa: E501
"Gemma3nForConditionalGeneration": ("gemma3n_mm", "Gemma3nForConditionalGeneration"), # noqa: E501
"GLM4VForCausalLM": ("glm4v", "GLM4VForCausalLM"),
"Glm4vForConditionalGeneration": ("glm4_1v", "Glm4vForConditionalGeneration"), # noqa: E501
"Glm4vMoeForConditionalGeneration": ("glm4_1v", "Glm4vForConditionalGeneration"), # noqa: E501
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment