Unverified Commit 5a16fa61 authored by Nicolò Lucchesi's avatar Nicolò Lucchesi Committed by GitHub
Browse files

[Model] Gemma3n MM (#20495)


Signed-off-by: default avatarShriKode <shrikode@gmail.com>
Signed-off-by: default avatarNickLucche <nlucches@redhat.com>
Signed-off-by: default avatarRoger Wang <hey@rogerw.me>
Co-authored-by: default avatarShriKode <shrikode@gmail.com>
Co-authored-by: default avatarRoger Wang <hey@rogerw.me>
parent 2d18256e
...@@ -349,7 +349,7 @@ th { ...@@ -349,7 +349,7 @@ th {
| `GemmaForCausalLM` | Gemma | `google/gemma-2b`, `google/gemma-1.1-2b-it`, etc. | ✅︎ | ✅︎ | ✅︎ | | `GemmaForCausalLM` | Gemma | `google/gemma-2b`, `google/gemma-1.1-2b-it`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `Gemma2ForCausalLM` | Gemma 2 | `google/gemma-2-9b`, `google/gemma-2-27b`, etc. | ✅︎ | ✅︎ | ✅︎ | | `Gemma2ForCausalLM` | Gemma 2 | `google/gemma-2-9b`, `google/gemma-2-27b`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `Gemma3ForCausalLM` | Gemma 3 | `google/gemma-3-1b-it`, etc. | ✅︎ | ✅︎ | ✅︎ | | `Gemma3ForCausalLM` | Gemma 3 | `google/gemma-3-1b-it`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `Gemma3nForConditionalGeneration` | Gemma 3n | `google/gemma-3n-E2B-it`, `google/gemma-3n-E4B-it`, etc. | | | ✅︎ | | `Gemma3nForCausalLM` | Gemma 3n | `google/gemma-3n-E2B-it`, `google/gemma-3n-E4B-it`, etc. | | | ✅︎ |
| `GlmForCausalLM` | GLM-4 | `zai-org/glm-4-9b-chat-hf`, etc. | ✅︎ | ✅︎ | ✅︎ | | `GlmForCausalLM` | GLM-4 | `zai-org/glm-4-9b-chat-hf`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `Glm4ForCausalLM` | GLM-4-0414 | `zai-org/GLM-4-32B-0414`, etc. | ✅︎ | ✅︎ | ✅︎ | | `Glm4ForCausalLM` | GLM-4-0414 | `zai-org/GLM-4-32B-0414`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `Glm4MoeForCausalLM` | GLM-4.5 | `zai-org/GLM-4.5`, etc. | ✅︎ | ✅︎ | ✅︎ | | `Glm4MoeForCausalLM` | GLM-4.5 | `zai-org/GLM-4.5`, etc. | ✅︎ | ✅︎ | ✅︎ |
...@@ -412,9 +412,6 @@ th { ...@@ -412,9 +412,6 @@ th {
!!! note !!! note
Currently, the ROCm version of vLLM supports Mistral and Mixtral only for context lengths up to 4096. Currently, the ROCm version of vLLM supports Mistral and Mixtral only for context lengths up to 4096.
!!! note
Only text inputs are currently supported for `Gemma3nForConditionalGeneration`. To use this model, please upgrade Hugging Face Transformers to version 4.53.0.
### Pooling Models ### Pooling Models
See [this page](./pooling_models.md) for more information on how to use pooling models. See [this page](./pooling_models.md) for more information on how to use pooling models.
...@@ -608,6 +605,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen ...@@ -608,6 +605,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen
| `Florence2ForConditionalGeneration` | Florence-2 | T + I | `microsoft/Florence-2-base`, `microsoft/Florence-2-large`, etc. | | | | | `Florence2ForConditionalGeneration` | Florence-2 | T + I | `microsoft/Florence-2-base`, `microsoft/Florence-2-large`, etc. | | | |
| `FuyuForCausalLM` | Fuyu | T + I | `adept/fuyu-8b`, etc. | | ✅︎ | ✅︎ | | `FuyuForCausalLM` | Fuyu | T + I | `adept/fuyu-8b`, etc. | | ✅︎ | ✅︎ |
| `Gemma3ForConditionalGeneration` | Gemma 3 | T + I<sup>+</sup> | `google/gemma-3-4b-it`, `google/gemma-3-27b-it`, etc. | ✅︎ | ✅︎ | ⚠️ | | `Gemma3ForConditionalGeneration` | Gemma 3 | T + I<sup>+</sup> | `google/gemma-3-4b-it`, `google/gemma-3-27b-it`, etc. | ✅︎ | ✅︎ | ⚠️ |
| `Gemma3nForConditionalGeneration` | Gemma 3n | T + I + A | `google/gemma-3n-E2B-it`, `google/gemma-3n-E4B-it`, etc. | | | ✅︎ |
| `GLM4VForCausalLM`<sup>^</sup> | GLM-4V | T + I | `zai-org/glm-4v-9b`, `zai-org/cogagent-9b-20241220`, etc. | ✅︎ | ✅︎ | ✅︎ | | `GLM4VForCausalLM`<sup>^</sup> | GLM-4V | T + I | `zai-org/glm-4v-9b`, `zai-org/cogagent-9b-20241220`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `Glm4vForConditionalGeneration` | GLM-4.1V-Thinking | T + I<sup>E+</sup> + V<sup>E+</sup> | `zai-org/GLM-4.1V-9B-Thinking`, etc. | ✅︎ | ✅︎ | ✅︎ | | `Glm4vForConditionalGeneration` | GLM-4.1V-Thinking | T + I<sup>E+</sup> + V<sup>E+</sup> | `zai-org/GLM-4.1V-9B-Thinking`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `Glm4vMoeForConditionalGeneration` | GLM-4.5V | T + I<sup>E+</sup> + V<sup>E+</sup> | `zai-org/GLM-4.5V`, etc. | ✅︎ | ✅︎ | ✅︎ | | `Glm4vMoeForConditionalGeneration` | GLM-4.5V | T + I<sup>E+</sup> + V<sup>E+</sup> | `zai-org/GLM-4.5V`, etc. | ✅︎ | ✅︎ | ✅︎ |
...@@ -677,6 +675,15 @@ Some models are supported only via the [Transformers backend](#transformers). Th ...@@ -677,6 +675,15 @@ Some models are supported only via the [Transformers backend](#transformers). Th
This limitation exists because the model's mixed attention pattern (bidirectional for images, causal otherwise) is not yet supported by vLLM's attention backends. This limitation exists because the model's mixed attention pattern (bidirectional for images, causal otherwise) is not yet supported by vLLM's attention backends.
!!! note
`Gemma3nForConditionalGeneration` is only supported on V1 due to shared KV caching and it depends on `timm>=1.0.17` to make use of its
MobileNet-v5 vision backbone.
Performance is not yet fully optimized mainly due to:
- Both audio and vision MM encoders use `transformers.AutoModel` implementation.
- There's no PLE caching or out-of-memory swapping support, as described in [Google's blog](https://developers.googleblog.com/en/introducing-gemma-3n/). These features might be too model-specific for vLLM, and swapping in particular may be better suited for constrained setups.
!!! note !!! note
Only `InternVLChatModel` with Qwen2.5 text backbone (`OpenGVLab/InternVL3-2B`, `OpenGVLab/InternVL2.5-1B` etc) has video inputs support currently. Only `InternVLChatModel` with Qwen2.5 text backbone (`OpenGVLab/InternVL3-2B`, `OpenGVLab/InternVL2.5-1B` etc) has video inputs support currently.
......
...@@ -96,6 +96,25 @@ def run_voxtral(question: str, audio_count: int) -> ModelRequestData: ...@@ -96,6 +96,25 @@ def run_voxtral(question: str, audio_count: int) -> ModelRequestData:
) )
# Gemma3N
def run_gemma3n(question: str, audio_count: int) -> ModelRequestData:
model_name = "google/gemma-3n-E2B-it"
engine_args = EngineArgs(
model=model_name,
max_model_len=2048,
max_num_batched_tokens=2048,
max_num_seqs=2,
limit_mm_per_prompt={"audio": audio_count},
enforce_eager=True,
)
prompt = f"<start_of_turn>user\n<audio_soft_token>{question}"
"<end_of_turn>\n<start_of_turn>model\n"
return ModelRequestData(
engine_args=engine_args,
prompt=prompt,
)
# Granite Speech # Granite Speech
def run_granite_speech(question: str, audio_count: int) -> ModelRequestData: def run_granite_speech(question: str, audio_count: int) -> ModelRequestData:
# NOTE - the setting in this example are somehat different than what is # NOTE - the setting in this example are somehat different than what is
...@@ -331,6 +350,7 @@ def run_whisper(question: str, audio_count: int) -> ModelRequestData: ...@@ -331,6 +350,7 @@ def run_whisper(question: str, audio_count: int) -> ModelRequestData:
model_example_map = { model_example_map = {
"voxtral": run_voxtral, "voxtral": run_voxtral,
"gemma3n": run_gemma3n,
"granite_speech": run_granite_speech, "granite_speech": run_granite_speech,
"minicpmo": run_minicpmo, "minicpmo": run_minicpmo,
"phi4_mm": run_phi4mm, "phi4_mm": run_phi4mm,
......
...@@ -211,7 +211,33 @@ def run_gemma3(questions: list[str], modality: str) -> ModelRequestData: ...@@ -211,7 +211,33 @@ def run_gemma3(questions: list[str], modality: str) -> ModelRequestData:
) )
for question in questions for question in questions
] ]
return ModelRequestData(
engine_args=engine_args,
prompts=prompts,
)
# Gemma3N
def run_gemma3n(questions: list[str], modality: str) -> ModelRequestData:
assert modality == "image"
model_name = "google/gemma-3n-E2B-it"
engine_args = EngineArgs(
model=model_name,
max_model_len=2048,
max_num_seqs=2,
limit_mm_per_prompt={modality: 1},
enforce_eager=True,
)
prompts = [
(
"<start_of_turn>user\n"
f"<image_soft_token>{question}<end_of_turn>\n"
"<start_of_turn>model\n"
)
for question in questions
]
return ModelRequestData( return ModelRequestData(
engine_args=engine_args, engine_args=engine_args,
prompts=prompts, prompts=prompts,
...@@ -1395,6 +1421,7 @@ model_example_map = { ...@@ -1395,6 +1421,7 @@ model_example_map = {
"florence2": run_florence2, "florence2": run_florence2,
"fuyu": run_fuyu, "fuyu": run_fuyu,
"gemma3": run_gemma3, "gemma3": run_gemma3,
"gemma3n": run_gemma3n,
"glm4v": run_glm4v, "glm4v": run_glm4v,
"glm4_1v": run_glm4_1v, "glm4_1v": run_glm4_1v,
"h2ovl_chat": run_h2ovl, "h2ovl_chat": run_h2ovl,
......
...@@ -21,7 +21,7 @@ ray[cgraph,default]>=2.48.0 # Ray Compiled Graph, required by pipeline paralleli ...@@ -21,7 +21,7 @@ ray[cgraph,default]>=2.48.0 # Ray Compiled Graph, required by pipeline paralleli
sentence-transformers # required for embedding tests sentence-transformers # required for embedding tests
soundfile # required for audio tests soundfile # required for audio tests
jiwer # required for audio tests jiwer # required for audio tests
timm # required for internvl test timm >=1.0.17 # required for internvl and gemma3n-mm test
torch==2.7.1 torch==2.7.1
torchaudio==2.7.1 torchaudio==2.7.1
torchvision==0.22.1 torchvision==0.22.1
......
...@@ -1051,7 +1051,7 @@ tiktoken==0.7.0 ...@@ -1051,7 +1051,7 @@ tiktoken==0.7.0
# via # via
# lm-eval # lm-eval
# mistral-common # mistral-common
timm==1.0.15 timm==1.0.17
# via # via
# -r requirements/test.in # -r requirements/test.in
# open-clip-torch # open-clip-torch
......
...@@ -271,6 +271,7 @@ def _test_processing_correctness_one( ...@@ -271,6 +271,7 @@ def _test_processing_correctness_one(
"microsoft/Florence-2-base", "microsoft/Florence-2-base",
"adept/fuyu-8b", "adept/fuyu-8b",
"google/gemma-3-4b-it", "google/gemma-3-4b-it",
"google/gemma-3n-E2B-it",
"zai-org/glm-4v-9b", "zai-org/glm-4v-9b",
"zai-org/GLM-4.1V-9B-Thinking", "zai-org/GLM-4.1V-9B-Thinking",
"ibm-granite/granite-speech-3.3-2b", "ibm-granite/granite-speech-3.3-2b",
...@@ -315,7 +316,7 @@ def _test_processing_correctness_one( ...@@ -315,7 +316,7 @@ def _test_processing_correctness_one(
"fixie-ai/ultravox-v0_5-llama-3_2-1b", "fixie-ai/ultravox-v0_5-llama-3_2-1b",
"openai/whisper-large-v3", "openai/whisper-large-v3",
"omni-research/Tarsier-7b", "omni-research/Tarsier-7b",
"omni-research/Tarsier2-Recap-7b" "omni-research/Tarsier2-Recap-7b",
]) ])
@pytest.mark.parametrize("hit_rate", [0.3, 0.5, 1.0]) @pytest.mark.parametrize("hit_rate", [0.3, 0.5, 1.0])
@pytest.mark.parametrize("num_batches", [32]) @pytest.mark.parametrize("num_batches", [32])
...@@ -327,6 +328,8 @@ def test_processing_correctness( ...@@ -327,6 +328,8 @@ def test_processing_correctness(
num_batches: int, num_batches: int,
simplify_rate: float, simplify_rate: float,
): ):
if model_id == "google/gemma-3n-E2B-it":
pytest.skip("Skipping gemma-3n-E2B-it due to transformers #39911 bug.")
_test_processing_correctness( _test_processing_correctness(
model_id, model_id,
hit_rate=hit_rate, hit_rate=hit_rate,
......
...@@ -186,7 +186,7 @@ _TEXT_GENERATION_EXAMPLE_MODELS = { ...@@ -186,7 +186,7 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
"GemmaForCausalLM": _HfExamplesInfo("google/gemma-1.1-2b-it"), "GemmaForCausalLM": _HfExamplesInfo("google/gemma-1.1-2b-it"),
"Gemma2ForCausalLM": _HfExamplesInfo("google/gemma-2-9b"), "Gemma2ForCausalLM": _HfExamplesInfo("google/gemma-2-9b"),
"Gemma3ForCausalLM": _HfExamplesInfo("google/gemma-3-1b-it"), "Gemma3ForCausalLM": _HfExamplesInfo("google/gemma-3-1b-it"),
"Gemma3nForConditionalGeneration": _HfExamplesInfo("google/gemma-3n-E2B-it", # noqa: E501 "Gemma3nForCausalLM": _HfExamplesInfo("google/gemma-3n-E2B-it",
min_transformers_version="4.53"), min_transformers_version="4.53"),
"GlmForCausalLM": _HfExamplesInfo("zai-org/glm-4-9b-chat-hf"), "GlmForCausalLM": _HfExamplesInfo("zai-org/glm-4-9b-chat-hf"),
"Glm4ForCausalLM": _HfExamplesInfo("zai-org/GLM-4-9B-0414"), "Glm4ForCausalLM": _HfExamplesInfo("zai-org/GLM-4-9B-0414"),
...@@ -391,6 +391,8 @@ _MULTIMODAL_EXAMPLE_MODELS = { ...@@ -391,6 +391,8 @@ _MULTIMODAL_EXAMPLE_MODELS = {
"Emu3ForConditionalGeneration": _HfExamplesInfo("BAAI/Emu3-Chat-hf"), "Emu3ForConditionalGeneration": _HfExamplesInfo("BAAI/Emu3-Chat-hf"),
"FuyuForCausalLM": _HfExamplesInfo("adept/fuyu-8b"), "FuyuForCausalLM": _HfExamplesInfo("adept/fuyu-8b"),
"Gemma3ForConditionalGeneration": _HfExamplesInfo("google/gemma-3-4b-it"), "Gemma3ForConditionalGeneration": _HfExamplesInfo("google/gemma-3-4b-it"),
"Gemma3nForConditionalGeneration": _HfExamplesInfo("google/gemma-3n-E2B-it", # noqa: E501
min_transformers_version="4.53"),
"GraniteSpeechForConditionalGeneration": _HfExamplesInfo("ibm-granite/granite-speech-3.3-2b"), # noqa: E501 "GraniteSpeechForConditionalGeneration": _HfExamplesInfo("ibm-granite/granite-speech-3.3-2b"), # noqa: E501
"GLM4VForCausalLM": _HfExamplesInfo("zai-org/glm-4v-9b", "GLM4VForCausalLM": _HfExamplesInfo("zai-org/glm-4v-9b",
trust_remote_code=True, trust_remote_code=True,
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
from vllm import LLM, envs
from vllm.sampling_params import SamplingParams
if not envs.VLLM_USE_V1:
pytest.skip(
"Skipping V1 tests. Rerun with `VLLM_USE_V1=1` to test.",
allow_module_level=True,
)
@pytest.mark.parametrize("model_name", ["Qwen/Qwen2.5-1.5B-Instruct"])
# TODO TPU will appear busy if we fan-out test params here
@pytest.mark.parametrize("n_prompts", [1])
def test_logprobs(model_name: str, n_prompts: int):
"""
Request top logprobs with different sampling settings and check
that results contains the requested number, ordered ascendingly.
"""
def check_num_logprobs(logprobs, expected_num: int):
for step in logprobs:
prev_logp = 1.0
# order by rank
sorted_step = dict(
sorted(step.items(), key=lambda item: item[1].rank))
if len(step) != expected_num:
print("watch out", sorted_step)
# check results are ordered by prob value
# assert len(step) == expected_num
for rankno, (tid, logp) in enumerate(sorted_step.items()):
assert logp.logprob <= prev_logp
prev_logp = logp.logprob
assert logp.rank == rankno + 1
llm = LLM(model_name,
enforce_eager=False,
max_num_seqs=1,
max_model_len=128,
max_num_batched_tokens=128)
prompts = [
"Write a short story about a robot that dreams for the first time."
] * n_prompts
greedy_sampling_params = SamplingParams(temperature=0.0, max_tokens=64,\
logprobs=4)
regular_sampling_params = SamplingParams(temperature=0.4, max_tokens=64,\
logprobs=4)
topkp_sampling_params = SamplingParams(temperature=0.4, max_tokens=64,\
logprobs=4, top_k=12, top_p=0.5)
for sp in [greedy_sampling_params, regular_sampling_params, \
topkp_sampling_params]:
output = llm.generate(prompts, sp)
for o in output:
check_num_logprobs(o.outputs[0].logprobs, 4)
...@@ -331,14 +331,15 @@ class Gemma3nAttention(nn.Module): ...@@ -331,14 +331,15 @@ class Gemma3nAttention(nn.Module):
config.num_kv_shared_layers) config.num_kv_shared_layers)
self.is_kv_shared = layer_idx >= first_kv_shared_layer_idx self.is_kv_shared = layer_idx >= first_kv_shared_layer_idx
kv_sharing_target_layer_name = None
if self.is_kv_shared: if self.is_kv_shared:
# Last full attention layer is 1 before sharing # Last full attention layer is 1 before sharing
# Last sliding attention layer is 2 before sharing # Last sliding attention layer is 2 before sharing
offset = 2 if self.sliding_window is not None else 1 offset = 2 if self.sliding_window is not None else 1
kv_shared_layer_index = first_kv_shared_layer_idx - offset kv_shared_layer_index = first_kv_shared_layer_idx - offset
kv_sharing_target_layer_name = f"model.language_model.layers.{kv_shared_layer_index}.self_attn.attn" # noqa: E501 if kv_shared_layer_index >= 0:
else: # Only the greater layer is required to specify sharing.
kv_sharing_target_layer_name = None kv_sharing_target_layer_name = f"language_model.model.layers.{kv_shared_layer_index}.self_attn.attn" # noqa: E501
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.head_dim, self.head_dim,
...@@ -396,6 +397,7 @@ class Gemma3nDecoderLayer(nn.Module): ...@@ -396,6 +397,7 @@ class Gemma3nDecoderLayer(nn.Module):
prefix: str = "", prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
assert isinstance(config, Gemma3nTextConfig)
self.altup_active_idx = config.altup_active_idx self.altup_active_idx = config.altup_active_idx
assert config.altup_correct_scale assert config.altup_correct_scale
...@@ -537,7 +539,7 @@ class Gemma3nTextModel(nn.Module, SupportsQuant): ...@@ -537,7 +539,7 @@ class Gemma3nTextModel(nn.Module, SupportsQuant):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config.text_config config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
self.config = config self.config = config
...@@ -553,6 +555,7 @@ class Gemma3nTextModel(nn.Module, SupportsQuant): ...@@ -553,6 +555,7 @@ class Gemma3nTextModel(nn.Module, SupportsQuant):
config.hidden_size**0.5, config.hidden_size**0.5,
dtype=self.embed_tokens.weight.dtype, dtype=self.embed_tokens.weight.dtype,
) )
# Additional per-layer embeddings (PLE)
self.embed_tokens_per_layer = VocabParallelEmbedding( self.embed_tokens_per_layer = VocabParallelEmbedding(
config.vocab_size_per_layer_input, config.vocab_size_per_layer_input,
config.num_hidden_layers * config.hidden_size_per_layer_input, config.num_hidden_layers * config.hidden_size_per_layer_input,
...@@ -636,6 +639,8 @@ class Gemma3nTextModel(nn.Module, SupportsQuant): ...@@ -636,6 +639,8 @@ class Gemma3nTextModel(nn.Module, SupportsQuant):
self, self,
input_ids: Optional[torch.Tensor], input_ids: Optional[torch.Tensor],
positions: torch.Tensor, positions: torch.Tensor,
per_layer_inputs: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
**kwargs, **kwargs,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
...@@ -644,13 +649,6 @@ class Gemma3nTextModel(nn.Module, SupportsQuant): ...@@ -644,13 +649,6 @@ class Gemma3nTextModel(nn.Module, SupportsQuant):
else: else:
hidden_states_0 = self.get_input_embeddings(input_ids) hidden_states_0 = self.get_input_embeddings(input_ids)
# Per layer inputs.
if input_ids is None:
raise ValueError("Passing None for input ids is not supported.")
per_layer_inputs = self.get_per_layer_input_embeddings(input_ids)
per_layer_inputs = per_layer_inputs.reshape(
-1, self.config.num_hidden_layers,
self.config.hidden_size_per_layer_input)
per_layer_projection = self.per_layer_model_projection(hidden_states_0) per_layer_projection = self.per_layer_model_projection(hidden_states_0)
per_layer_projection = per_layer_projection.reshape( per_layer_projection = per_layer_projection.reshape(
*hidden_states_0.shape[:-1], *hidden_states_0.shape[:-1],
...@@ -659,8 +657,13 @@ class Gemma3nTextModel(nn.Module, SupportsQuant): ...@@ -659,8 +657,13 @@ class Gemma3nTextModel(nn.Module, SupportsQuant):
) )
per_layer_projection = self.per_layer_projection_norm( per_layer_projection = self.per_layer_projection_norm(
per_layer_projection) per_layer_projection)
per_layer_inputs = per_layer_projection + per_layer_inputs
per_layer_inputs *= self.per_layer_input_scale if per_layer_inputs is not None:
# Profiling run does not compute per_layer_inputs
per_layer_inputs = per_layer_projection + per_layer_inputs
per_layer_inputs *= self.per_layer_input_scale
else:
per_layer_inputs = per_layer_projection
# Altup embed. # Altup embed.
hidden_states = [hidden_states_0] * self.config.altup_num_inputs hidden_states = [hidden_states_0] * self.config.altup_num_inputs
...@@ -760,29 +763,7 @@ class Gemma3nTextModel(nn.Module, SupportsQuant): ...@@ -760,29 +763,7 @@ class Gemma3nTextModel(nn.Module, SupportsQuant):
return loaded_params return loaded_params
class Gemma3nModel(nn.Module): class Gemma3nForCausalLM(nn.Module):
def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
self.language_model = Gemma3nTextModel(vllm_config=vllm_config,
prefix=maybe_prefix(
prefix, "language_model"))
def forward(
self,
input_ids: Optional[torch.Tensor],
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
return self.language_model(input_ids=input_ids,
positions=positions,
inputs_embeds=inputs_embeds,
**kwargs)
class Gemma3nForConditionalGeneration(nn.Module, SupportsQuant):
packed_modules_mapping = { packed_modules_mapping = {
"qkv_proj": [ "qkv_proj": [
"q_proj", "q_proj",
...@@ -802,25 +783,33 @@ class Gemma3nForConditionalGeneration(nn.Module, SupportsQuant): ...@@ -802,25 +783,33 @@ class Gemma3nForConditionalGeneration(nn.Module, SupportsQuant):
super().__init__() super().__init__()
self.config = config self.config = config
self.cache_config = vllm_config.cache_config self.cache_config = vllm_config.cache_config
self.model = Gemma3nModel(vllm_config=vllm_config, self.model = Gemma3nTextModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model")) prefix=maybe_prefix(prefix, "model"))
self.logits_processor = LogitsProcessor( self.logits_processor = LogitsProcessor(
config.text_config.vocab_size, config.vocab_size, soft_cap=config.final_logit_softcapping)
soft_cap=config.text_config.final_logit_softcapping)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.language_model.get_input_embeddings(input_ids) return self.model.get_input_embeddings(input_ids)
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
*,
per_layer_inputs: Optional[torch.Tensor] = None,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
**kwargs, **kwargs,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, intermediate_tensors,
inputs_embeds, **kwargs) hidden_states = self.model(
input_ids,
positions,
per_layer_inputs=per_layer_inputs,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds,
**kwargs,
)
return hidden_states return hidden_states
def compute_logits( def compute_logits(
...@@ -828,8 +817,8 @@ class Gemma3nForConditionalGeneration(nn.Module, SupportsQuant): ...@@ -828,8 +817,8 @@ class Gemma3nForConditionalGeneration(nn.Module, SupportsQuant):
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
sampling_metadata: Optional[SamplingMetadata], sampling_metadata: Optional[SamplingMetadata],
) -> Optional[torch.Tensor]: ) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.model.language_model.embed_tokens, logits = self.logits_processor(self.model.embed_tokens, hidden_states,
hidden_states, sampling_metadata) sampling_metadata)
return logits return logits
def load_weights(self, weights: Iterable[tuple[str, def load_weights(self, weights: Iterable[tuple[str,
......
This diff is collapsed.
...@@ -69,8 +69,7 @@ _TEXT_GENERATION_MODELS = { ...@@ -69,8 +69,7 @@ _TEXT_GENERATION_MODELS = {
"GemmaForCausalLM": ("gemma", "GemmaForCausalLM"), "GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
"Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"), "Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"),
"Gemma3ForCausalLM": ("gemma3", "Gemma3ForCausalLM"), "Gemma3ForCausalLM": ("gemma3", "Gemma3ForCausalLM"),
#TODO(ywang96): Support multimodal gemma3n "Gemma3nForCausalLM": ("gemma3n", "Gemma3nForCausalLM"),
"Gemma3nForConditionalGeneration": ("gemma3n", "Gemma3nForConditionalGeneration"), # noqa: E501
"GlmForCausalLM": ("glm", "GlmForCausalLM"), "GlmForCausalLM": ("glm", "GlmForCausalLM"),
"Glm4ForCausalLM": ("glm4", "Glm4ForCausalLM"), "Glm4ForCausalLM": ("glm4", "Glm4ForCausalLM"),
"Glm4MoeForCausalLM": ("glm4_moe", "Glm4MoeForCausalLM"), "Glm4MoeForCausalLM": ("glm4_moe", "Glm4MoeForCausalLM"),
...@@ -205,6 +204,7 @@ _MULTIMODAL_MODELS = { ...@@ -205,6 +204,7 @@ _MULTIMODAL_MODELS = {
"DeepseekVLV2ForCausalLM": ("deepseek_vl2", "DeepseekVLV2ForCausalLM"), "DeepseekVLV2ForCausalLM": ("deepseek_vl2", "DeepseekVLV2ForCausalLM"),
"FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"), "FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"),
"Gemma3ForConditionalGeneration": ("gemma3_mm", "Gemma3ForConditionalGeneration"), # noqa: E501 "Gemma3ForConditionalGeneration": ("gemma3_mm", "Gemma3ForConditionalGeneration"), # noqa: E501
"Gemma3nForConditionalGeneration": ("gemma3n_mm", "Gemma3nForConditionalGeneration"), # noqa: E501
"GLM4VForCausalLM": ("glm4v", "GLM4VForCausalLM"), "GLM4VForCausalLM": ("glm4v", "GLM4VForCausalLM"),
"Glm4vForConditionalGeneration": ("glm4_1v", "Glm4vForConditionalGeneration"), # noqa: E501 "Glm4vForConditionalGeneration": ("glm4_1v", "Glm4vForConditionalGeneration"), # noqa: E501
"Glm4vMoeForConditionalGeneration": ("glm4_1v", "Glm4vForConditionalGeneration"), # noqa: E501 "Glm4vMoeForConditionalGeneration": ("glm4_1v", "Glm4vForConditionalGeneration"), # noqa: E501
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment