Commit 858bddce authored by luopl's avatar luopl
Browse files

feat:add gemma4

parent 40faaf0c
...@@ -13,6 +13,7 @@ from .dynamic_ntk_alpha_rope import DynamicNTKAlphaRotaryEmbedding ...@@ -13,6 +13,7 @@ from .dynamic_ntk_alpha_rope import DynamicNTKAlphaRotaryEmbedding
from .dynamic_ntk_scaling_rope import DynamicNTKScalingRotaryEmbedding from .dynamic_ntk_scaling_rope import DynamicNTKScalingRotaryEmbedding
from .fope import FourierRotaryEmbedding from .fope import FourierRotaryEmbedding
from .linear_scaling_rope import LinearScalingRotaryEmbedding from .linear_scaling_rope import LinearScalingRotaryEmbedding
from .gemma4_rope import Gemma4RotaryEmbedding
from .llama3_rope import Llama3RotaryEmbedding from .llama3_rope import Llama3RotaryEmbedding
from .llama4_vision_rope import Llama4VisionRotaryEmbedding from .llama4_vision_rope import Llama4VisionRotaryEmbedding
from .mrope import MRotaryEmbedding from .mrope import MRotaryEmbedding
...@@ -134,6 +135,17 @@ def get_rope( ...@@ -134,6 +135,17 @@ def get_rope(
is_neox_style, is_neox_style,
dtype, dtype,
) )
elif scaling_type == "proportional":
# Proportional RoPE is used by Gemma4 for global (full) attention.
# Gemma4 uses a sparse/fractional RoPE with cross-mixing between halves.
rotary_emb = Gemma4RotaryEmbedding(
head_size,
rotary_dim,
max_position,
base,
is_neox_style,
dtype,
)
elif scaling_type == "llama3": elif scaling_type == "llama3":
scaling_factor = rope_parameters["factor"] scaling_factor = rope_parameters["factor"]
low_freq_factor = rope_parameters["low_freq_factor"] low_freq_factor = rope_parameters["low_freq_factor"]
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Gemma4-specific Rotary Positional Embeddings (proportional scaling).
Gemma4 uses "proportional" RoPE which computes inv_freq frequencies scaled
by head_dim (not rotary_dim), and zero-pads for non-rotated dimensions when
partial_rotary_factor < 1. The actual rotation uses standard neox-style
rotate_half, matching HF transformers' apply_rotary_pos_emb.
"""
import torch
from .base import RotaryEmbedding
class Gemma4RotaryEmbedding(RotaryEmbedding):
    """Rotary embedding implementing Gemma4's "proportional" RoPE.

    The base class is always initialised with ``rotary_dim == head_size``;
    partial rotation is expressed through the frequency table instead:

    - the frequency exponents are divided by ``head_size`` rather than by
      the rotary dimension, and
    - the angle pairs that must not be rotated get a frequency of zero,
      i.e. cos=1 / sin=0, which is an identity rotation.

    When the requested ``rotary_dim`` equals ``head_size`` no padding is
    added and every dimension is rotated.
    """

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: float,
        is_neox_style: bool,
        dtype: torch.dtype,
    ) -> None:
        """Record the rotated / non-rotated pair counts and build the base.

        Args:
            head_size: Size of one attention head.
            rotary_dim: Number of head dimensions that are really rotated.
            max_position_embeddings: Forwarded unchanged to the base class.
            base: RoPE base (theta), forwarded unchanged.
            is_neox_style: Forwarded unchanged to the base class.
            dtype: Forwarded unchanged to the base class.
        """
        # Both counters are assigned before super().__init__() on purpose:
        # _compute_inv_freq reads them, and the base constructor presumably
        # calls it while building its cos/sin cache (base class not shown
        # here -- TODO confirm).
        half_head = head_size // 2
        self.rope_angles = rotary_dim // 2  # rotated angle pairs
        self.nope_angles = half_head - self.rope_angles  # identity pairs

        # rotary_dim is deliberately passed as head_size so the rotation is
        # applied across the whole head; the zero frequencies produced by
        # _compute_inv_freq keep the non-rotated dims untouched.
        super().__init__(
            head_size,
            head_size,
            max_position_embeddings,
            base,
            is_neox_style,
            dtype,
        )

    def _compute_inv_freq(self, base: float) -> torch.Tensor:
        """Return the inverse-frequency vector for proportional RoPE.

        The first ``rope_angles`` entries are
        ``1 / base ** (2i / head_size)``; they are followed by
        ``nope_angles`` zeros when only part of the head is rotated.
        """
        exponents = (
            torch.arange(0, 2 * self.rope_angles, 2, dtype=torch.float)
            / self.head_size
        )
        rotated = 1.0 / (base**exponents)
        if self.nope_angles <= 0:
            # Full rotation: nothing to pad.
            return rotated
        # A zero frequency yields cos=1, sin=0 for every position.
        identity_pad = torch.zeros(self.nope_angles, dtype=torch.float)
        return torch.cat([rotated, identity_pad])

    def extra_repr(self) -> str:
        """Return a comma-separated summary of the embedding settings."""
        parts = (
            f"head_size={self.head_size}",
            f"rotary_dim={self.rotary_dim}",
            f"rope_angles={self.rope_angles}",
            f"nope_angles={self.nope_angles}",
            f"max_position_embeddings={self.max_position_embeddings}",
            f"base={self.base}",
            f"is_neox_style={self.is_neox_style}",
        )
        return ", ".join(parts)
...@@ -56,6 +56,57 @@ class Gemma3TextModelConfig(VerifyAndUpdateConfig): ...@@ -56,6 +56,57 @@ class Gemma3TextModelConfig(VerifyAndUpdateConfig):
hf_config = model_config.hf_config hf_config = model_config.hf_config
hf_config.is_causal = not hf_config.use_bidirectional_attention hf_config.is_causal = not hf_config.use_bidirectional_attention
class Gemma4Config(VerifyAndUpdateConfig):
    @staticmethod
    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
        """Pin TRITON_ATTN when Gemma4 layers have mismatched head sizes.

        Some Gemma4 variants use ``head_dim`` for sliding-window layers and
        a different ``global_head_dim`` for full-attention layers. Once the
        larger of the two exceeds 256 (FlashAttention's head_size limit),
        backend selection would end up per layer type, and that mixed
        execution produces numerical divergence and corrupted output.

        To avoid this, TRITON_ATTN (no head_size ceiling) is selected for
        all layers, but only when the two dimensions really differ, the
        larger one is above 256, and the user has not chosen a backend.
        Smaller models that merely carry ``global_head_dim`` in their
        config are left untouched.

        TODO: Heterogeneous head_sizes (head_dim != global_head_dim)
        require NixlConnector changes to support per-layer KV transfer
        with different head dimensions for prefill-decode disaggregation.
        """
        text_config = vllm_config.model_config.hf_text_config
        local_dim = getattr(text_config, "head_dim", None)
        global_dim = getattr(text_config, "global_head_dim", None)

        # Each guard below leaves the configuration exactly as it was.
        if local_dim is None or global_dim is None:
            return
        if local_dim == global_dim:
            return
        if max(local_dim or 0, global_dim or 0) <= 256:
            return
        if vllm_config.attention_config.backend is not None:
            # An explicit user choice always wins.
            return

        from vllm.v1.attention.backends.registry import (
            AttentionBackendEnum,
        )

        vllm_config.attention_config.backend = AttentionBackendEnum.TRITON_ATTN
        logger.info(
            "Gemma4 model has heterogeneous head dimensions "
            "(head_dim=%d, global_head_dim=%d). Forcing TRITON_ATTN "
            "backend to prevent mixed-backend numerical divergence.",
            local_dim,
            global_dim,
        )
class GptOssForCausalLMConfig(VerifyAndUpdateConfig): class GptOssForCausalLMConfig(VerifyAndUpdateConfig):
@staticmethod @staticmethod
...@@ -647,10 +698,13 @@ class VoyageQwen3BidirectionalEmbedModelConfig(VerifyAndUpdateConfig): ...@@ -647,10 +698,13 @@ class VoyageQwen3BidirectionalEmbedModelConfig(VerifyAndUpdateConfig):
MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = { MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = {
"ColBERTJinaRobertaModel": JinaRobertaModelConfig, "ColBERTJinaRobertaModel": JinaRobertaModelConfig,
"ColQwen3_5": Qwen3_5ForConditionalGenerationConfig,
"DeepseekV32ForCausalLM": DeepseekV32ForCausalLM, "DeepseekV32ForCausalLM": DeepseekV32ForCausalLM,
"Ernie4_5_VLMoeForConditionalGeneration": Ernie4_5_VLMoeForConditionalGenerationConfig, # noqa: E501 "Ernie4_5_VLMoeForConditionalGeneration": Ernie4_5_VLMoeForConditionalGenerationConfig, # noqa: E501
"FalconMambaForCausalLM": MambaModelConfig, "FalconMambaForCausalLM": MambaModelConfig,
"Gemma3TextModel": Gemma3TextModelConfig, "Gemma3TextModel": Gemma3TextModelConfig,
"Gemma4ForCausalLM": Gemma4Config,
"Gemma4ForConditionalGeneration": Gemma4Config,
"GptOssForCausalLM": GptOssForCausalLMConfig, "GptOssForCausalLM": GptOssForCausalLMConfig,
"GteModel": SnowflakeGteNewModelConfig, "GteModel": SnowflakeGteNewModelConfig,
"GteNewForSequenceClassification": GteNewModelConfig, "GteNewForSequenceClassification": GteNewModelConfig,
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright 2025 The vLLM team.
# Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Gemma 4 model implementation for vLLM."""
from collections.abc import Iterable
from dataclasses import replace
from itertools import islice
import regex as re
import torch
from torch import nn
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (
get_pp_group,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from vllm.forward_context import get_forward_context
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import GeluAndMul
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.fused_moe import FusedMoE, GateLinear
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
MergedColumnParallelLinear,
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear,
)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader,
maybe_remap_kv_scale_name,
)
from vllm.sequence import IntermediateTensors
from vllm.v1.attention.backends.utils import KVSharingFastPrefillMetadata
from .interfaces import MixtureOfExperts, SupportsLoRA, SupportsPP
from .utils import (
AutoWeightsLoader,
extract_layer_index,
is_pp_missing_parameter,
make_layers,
maybe_prefix,
)
logger = init_logger(__name__)
def _get_text_config(config):
    """Return the text config, unwrapping a nested ``text_config`` if any.

    Checkpoints whose architecture is "Gemma4ForConditionalGeneration"
    carry the text settings in a nested ``text_config`` attribute; a
    config without that attribute is returned unchanged.
    """
    return getattr(config, "text_config", config)
class Gemma4MLP(nn.Module):
    """Gated feed-forward block: merged gate/up projection, GELU-tanh
    gating (``GeluAndMul``), then a down projection."""

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_activation: str,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        """Build the projections.

        Args:
            hidden_size: Input and output width of the block.
            intermediate_size: Width of each of the gate and up projections.
            hidden_activation: Must be ``"gelu_pytorch_tanh"``.
            quant_config: Optional quantization config passed to the
                linear layers.
            prefix: Module name prefix used for the sub-layer names.

        Raises:
            ValueError: If ``hidden_activation`` is not
                ``"gelu_pytorch_tanh"``.
        """
        super().__init__()
        # Validate first so that an unsupported activation fails before the
        # (potentially large) projection layers are constructed.
        if hidden_activation != "gelu_pytorch_tanh":
            raise ValueError(
                "Gemma4 uses `gelu_pytorch_tanh` as the hidden activation "
                "function. Please set `hidden_act` and `hidden_activation` to "
                "`gelu_pytorch_tanh`."
            )
        # Gate and up projections are fused into a single matmul.
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            [intermediate_size] * 2,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.down_proj",
        )
        self.act_fn = GeluAndMul(approximate="tanh")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply gate/up projection, gated activation and down projection."""
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x
class Gemma4Router(nn.Module):
    """Gemma4 MoE router: input preprocessing followed by a projection.

    The input is RMS-normalised (no learned weight), multiplied by the
    constant ``hidden_size ** -0.5`` and by a learned per-dimension scale,
    and finally projected to one logit per expert.

    Only the router input goes through this preprocessing; the expert
    MLPs receive their own input.
    """

    def __init__(
        self,
        config,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        width = config.hidden_size
        self.hidden_size = width

        # Weight-free RMSNorm: normalisation only.
        self.norm = RMSNorm(width, eps=config.rms_norm_eps, has_weight=False)

        # Learned per-dimension scale applied after norm and root_size.
        self.scale = nn.Parameter(torch.ones(width))

        # Constant 1/sqrt(hidden_size); non-persistent, so it is not part
        # of the state dict.
        self.register_buffer(
            "root_size",
            torch.tensor(width**-0.5),
            persistent=False,
        )

        # Expert-logit projection, replicated across TP for consistent
        # routing. GateLinear supports bf16 weights/activations with fp32
        # output, which matters because the top-k kernel often needs fp32
        # for stable routing.
        self.proj = GateLinear(
            width,
            config.num_experts,
            bias=False,
            out_dtype=torch.float32,
            prefix=f"{prefix}.proj",
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Returns raw router logits [T, E]."""
        normed = self.norm(x)
        act_dtype = normed.dtype
        # The two multiplications are kept separate and in this order.
        scaled = normed * self.root_size.to(act_dtype)
        scaled = scaled * self.scale.to(act_dtype)
        logits, _ = self.proj(scaled)
        return logits
class Gemma4MoE(nn.Module):
    """Gemma4 expert block built on vLLM's FusedMoE.

    The router projection lives outside this class (Gemma4Router); this
    module only dispatches tokens to experts, using a custom routing
    function: softmax over ALL experts, then top-k, then renormalisation
    over the selected experts. ``per_expert_scale`` is multiplied into the
    routing weights so the fused kernel applies it together with them.
    """

    def __init__(
        self,
        config,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        self.num_experts = config.num_experts

        # Per-expert output scale, folded into the routing weights so that
        # FusedMoE computes: sum_e (expert_e * w_e * scale_e).
        self.per_expert_scale = nn.Parameter(torch.ones(config.num_experts))

        # The closure below captures the parameter object itself.
        expert_scale = self.per_expert_scale

        def routing_function(
            hidden_states: torch.Tensor,
            gating_output: torch.Tensor,
            topk: int,
            renormalize: bool,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            """Return (fp32 weights, int32 ids) of the top-k experts.

            ``hidden_states`` and ``renormalize`` are accepted but not
            read; renormalisation is always applied.
            """
            expert_count = gating_output.size(-1)
            # Softmax spans every expert, not just the selected ones.
            probs = torch.nn.functional.softmax(gating_output, dim=-1)
            _, topk_ids = torch.topk(gating_output, k=topk, dim=-1)
            # 0/1 mask over experts marking the selected ones.
            selected = torch.nn.functional.one_hot(
                topk_ids, num_classes=expert_count
            ).sum(dim=-2)
            masked_probs = selected * probs
            # Renormalise over the selected experts; a zero sum divides by 1.
            total = torch.sum(masked_probs, dim=-1, keepdim=True)
            total = torch.where(total > 0.0, total, 1.0)
            normalized = masked_probs / total
            topk_weights = normalized.gather(1, topk_ids)
            # Fold per_expert_scale into the routing weights.
            topk_weights = topk_weights * expert_scale[topk_ids].to(
                topk_weights.dtype
            )
            return topk_weights.to(torch.float32), topk_ids.to(torch.int32)

        # Expert width: prefer moe_intermediate_size, then
        # expert_intermediate_size, else None.
        fallback_width = getattr(config, "expert_intermediate_size", None)
        expert_width = getattr(config, "moe_intermediate_size", fallback_width)

        # FusedMoE experts driven by the custom Gemma4 routing above.
        self.experts = FusedMoE(
            num_experts=config.num_experts,
            top_k=config.top_k_experts,
            hidden_size=config.hidden_size,
            intermediate_size=expert_width,
            reduce_results=True,
            renormalize=True,
            quant_config=quant_config,
            prefix=f"{prefix}.experts",
            custom_routing_function=routing_function,
            activation="gelu",
        )

    def forward(self, x: torch.Tensor, router_logits: torch.Tensor) -> torch.Tensor:
        """Dispatch ``x`` to the experts using externally computed logits."""
        return self.experts(x, router_logits)
class Gemma4Attention(nn.Module):
    """Gemma4 self-attention with per-head Q/K/V norms, RoPE chosen per
    layer type, and optional KV-cache sharing with an earlier layer."""

    def __init__(
        self,
        config,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        head_dim: int,
        max_position_embeddings: int,
        use_k_eq_v: bool = False,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        attn_logits_soft_cap: float | None = None,
        prefix: str = "",
    ) -> None:
        """Build projections, norms, RoPE and the attention backend layer.

        Args:
            config: Text config; ``layer_types``, ``rope_parameters``,
                ``sliding_window``, ``attention_bias``, ``rms_norm_eps``
                and the KV-sharing fields are read from it.
            hidden_size: Model hidden size.
            num_heads: Total number of query heads (before TP split).
            num_kv_heads: Total number of KV heads (before TP split).
            head_dim: Head dimension used by this layer.
            max_position_embeddings: Maximum position passed to RoPE.
            use_k_eq_v: Stored on the module; not otherwise read here.
            cache_config: Optional KV-cache config for the attention layer.
            quant_config: Optional quantization config.
            attn_logits_soft_cap: Optional logits soft cap.
            prefix: Module name; the layer index is extracted from it.

        Raises:
            ValueError: If this is a KV-shared layer and either no earlier
                non-shared layer has the same attention type, or ``prefix``
                does not contain ``".layers."``.
        """
        super().__init__()
        self.config = config
        self.hidden_size = hidden_size
        self.use_k_eq_v = use_k_eq_v

        tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # KV heads are partitioned across TP ranks.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Fewer KV heads than ranks: heads are replicated.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = head_dim
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim

        # Gemma4 uses scaling=1.0.
        # Unlike Gemma2/3, query_pre_attn_scalar is NOT used here;
        # Q/K norms with learnable weights handle scaling implicitly.
        self.scaling = 1.0

        # A single QKVParallelLinear is used for every layer type. For
        # k_eq_v layers the K weights are expected to be loaded into both
        # the K and V slots by the weight-loading code (not part of this
        # class), so no structural difference is needed here.
        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=config.attention_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=config.attention_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        # Q/K norms: output = norm(x) * weight (learnable per-head scale)
        self.q_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps)
        self.k_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps)
        # V norm: no learnable scale (pure normalization only)
        self.v_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps, has_weight=False)

        # Determine layer type and sliding window
        layer_idx = extract_layer_index(prefix)
        layer_type = config.layer_types[layer_idx]
        self.is_sliding = layer_type == "sliding_attention"
        sliding_window = config.sliding_window if self.is_sliding else None

        # Gemma4 uses different RoPE parameters for sliding vs full attention.
        if layer_type in config.rope_parameters:
            # Per-layer-type format: the entry for this layer type is used
            # as-is (copied so the config is never mutated). It is expected
            # to already carry the right partial_rotary_factor, so no
            # global override is applied.
            rope_parameters = dict(config.rope_parameters[layer_type])
        else:
            # Legacy flat format: copy it and, for sliding layers, swap in
            # the local base frequency.
            rope_parameters = dict(config.rope_parameters)
            if self.is_sliding:
                rope_parameters["rope_theta"] = getattr(
                    config, "rope_local_base_freq", 10000.0
                )

        # KV sharing: layers in the last `num_kv_shared_layers` share KV
        # cache with earlier layers of the same type.
        kv_sharing_target_layer_name = self._find_kv_sharing_target(
            config, layer_idx, prefix
        )
        self.is_kv_shared_layer = kv_sharing_target_layer_name is not None

        self.rotary_emb = get_rope(
            self.head_dim,
            max_position=max_position_embeddings,
            rope_parameters=rope_parameters,
            is_neox_style=True,
        )
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            logits_soft_cap=attn_logits_soft_cap,
            per_layer_sliding_window=sliding_window,
            kv_sharing_target_layer_name=kv_sharing_target_layer_name,
            prefix=f"{prefix}.attn",
        )

    @staticmethod
    def _find_kv_sharing_target(config, layer_idx: int, prefix: str) -> str | None:
        """Return the attention-layer name whose KV cache is reused.

        Returns ``None`` when ``layer_idx`` is not among the last
        ``num_kv_shared_layers`` layers. Otherwise the target is the last
        non-shared layer with the same entry in ``config.layer_types``.

        Raises:
            ValueError: If no such earlier layer exists, or if ``prefix``
                does not contain ``".layers."``.
        """
        num_kv_shared_layers = getattr(config, "num_kv_shared_layers", 0)
        if not num_kv_shared_layers > 0:
            return None
        first_kv_shared_layer_idx = config.num_hidden_layers - num_kv_shared_layers
        if layer_idx < first_kv_shared_layer_idx:
            return None

        layer_type = config.layer_types[layer_idx]
        donor_types = config.layer_types[:first_kv_shared_layer_idx]
        # Scan backwards for the closest earlier layer of the same type.
        donor_idx = next(
            (
                idx
                for idx in reversed(range(len(donor_types)))
                if donor_types[idx] == layer_type
            ),
            None,
        )
        if donor_idx is None:
            # Without a donor the layer would skip K/V processing while
            # owning an unshared cache, so fail loudly with context.
            raise ValueError(
                f"Gemma4 layer {layer_idx} ('{layer_type}') is KV-shared but "
                f"none of the first {len(donor_types)} non-shared layers has "
                "the same attention type."
            )

        param_name_before_layers, separator, _ = prefix.partition(".layers.")
        if not separator:
            raise ValueError(
                "Unexpected prefix format for Gemma4Attention: "
                f"'{prefix}'. Expected to contain '.layers.'."
            )
        return f"{param_name_before_layers}.layers.{donor_idx}.self_attn.attn"

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        """Run attention on ``hidden_states`` and return the projected output.

        On KV-shared layers the freshly projected K and V are passed to the
        attention layer untouched (no norm, no RoPE on K); only Q is
        normalised and rotated.
        """
        # One fused projection for all layer types (see the k_eq_v note in
        # __init__).
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)

        # Q norm (always applied), per head.
        q = q.unflatten(-1, (self.num_heads, self.head_dim))
        q = self.q_norm(q)
        q = q.flatten(-2, -1)

        if not self.is_kv_shared_layer:
            # Non-shared: apply K norm + RoPE, V norm
            k = k.unflatten(-1, (self.num_kv_heads, self.head_dim))
            k = self.k_norm(k)
            k = k.flatten(-2, -1)
            q, k = self.rotary_emb(positions, q, k)
            v = v.unflatten(-1, (self.num_kv_heads, self.head_dim))
            v = self.v_norm(v)
            v = v.flatten(-2, -1)
        else:
            # Shared: only the rotated Q is kept; the rotated K is dropped.
            q = self.rotary_emb(positions, q, k)[0]

        attn_output = self.attn(q, k, v)
        output, _ = self.o_proj(attn_output)
        return output
class Gemma4DecoderLayer(nn.Module):
    """One Gemma4 decoder layer.

    Contains self-attention, a dense MLP, an optional router + MoE branch
    that is summed with the MLP output, an optional Per-Layer Embedding
    (PLE) contribution, and a final multiplication by ``layer_scalar``.
    """

    def __init__(
        self,
        config,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        def make_norm() -> RMSNorm:
            # Every weighted norm in this layer has the same shape and eps.
            return RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        self.hidden_size = config.hidden_size
        self.hidden_size_per_layer_input = getattr(
            config, "hidden_size_per_layer_input", 0
        )
        layer_idx = extract_layer_index(prefix)
        self.layer_idx = layer_idx

        # Full-attention layers may use a different head dimension
        # (global_head_dim) than sliding-window layers (head_dim).
        layer_type = config.layer_types[layer_idx]
        self.is_full_attention = layer_type == "full_attention"
        head_dim = (
            getattr(config, "global_head_dim", config.head_dim)
            if self.is_full_attention
            else config.head_dim
        )

        # k_eq_v is only considered on full-attention layers; when active,
        # num_global_key_value_heads (if present) is the KV head count.
        use_k_eq_v = self.is_full_attention and getattr(
            config, "attention_k_eq_v", False
        )
        num_kv_heads = config.num_key_value_heads
        if use_k_eq_v:
            num_kv_heads = getattr(
                config, "num_global_key_value_heads", num_kv_heads
            )

        self.self_attn = Gemma4Attention(
            config=config,
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=num_kv_heads,
            head_dim=head_dim,
            max_position_embeddings=config.max_position_embeddings,
            use_k_eq_v=use_k_eq_v,
            cache_config=cache_config,
            quant_config=quant_config,
            attn_logits_soft_cap=getattr(config, "attn_logit_softcapping", None),
            prefix=f"{prefix}.self_attn",
        )

        # MLP width: doubled on KV-shared layers (index at or beyond
        # num_hidden_layers - num_kv_shared_layers) when
        # use_double_wide_mlp is set.
        shared_start = config.num_hidden_layers - getattr(
            config, "num_kv_shared_layers", 0
        )
        in_shared_tail = layer_idx >= shared_start > 0
        widen_mlp = getattr(config, "use_double_wide_mlp", False) and in_shared_tail
        width_factor = 2 if widen_mlp else 1
        self.mlp = Gemma4MLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size * width_factor,
            hidden_activation=config.hidden_activation,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp",
        )

        # Layer norms: output = norm(x) * weight
        self.input_layernorm = make_norm()
        self.post_attention_layernorm = make_norm()
        self.pre_feedforward_layernorm = make_norm()
        self.post_feedforward_layernorm = make_norm()

        # Optional MoE branch (router + experts) that runs next to the MLP.
        # Either config flag enables it.
        self.enable_moe_block = getattr(config, "enable_moe_block", False) or getattr(
            config, "use_second_mlp_block", False
        )
        if self.enable_moe_block:
            self.router = Gemma4Router(
                config,
                quant_config=quant_config,
                prefix=f"{prefix}.router",
            )
            self.moe = Gemma4MoE(
                config,
                quant_config=quant_config,
                prefix=f"{prefix}.moe",
            )
            self.post_feedforward_layernorm_1 = make_norm()
            self.post_feedforward_layernorm_2 = make_norm()
            self.pre_feedforward_layernorm_2 = make_norm()
        else:
            self.router = None
            self.moe = None
            self.post_feedforward_layernorm_1 = None
            self.post_feedforward_layernorm_2 = None
            self.pre_feedforward_layernorm_2 = None

        # Per-Layer Embedding (PLE) modules exist only for a positive
        # per-layer input width.
        ple_dim = self.hidden_size_per_layer_input
        if ple_dim is not None and ple_dim > 0:
            # Gate: hidden_size -> per-layer width.
            self.per_layer_input_gate = ReplicatedLinear(
                self.hidden_size,
                ple_dim,
                bias=False,
                quant_config=quant_config,
                prefix=f"{prefix}.per_layer_input_gate",
                return_bias=False,
            )
            # Projection: per-layer width -> hidden_size.
            self.per_layer_projection = ReplicatedLinear(
                ple_dim,
                self.hidden_size,
                bias=False,
                quant_config=quant_config,
                prefix=f"{prefix}.per_layer_projection",
                return_bias=False,
            )
            # Post-PLE norm: output = norm(x) * weight
            self.post_per_layer_input_norm = make_norm()
        else:
            self.per_layer_input_gate = None
            self.per_layer_projection = None
            self.post_per_layer_input_norm = None

        # Per-layer scalar, registered on every layer and initialised to 1
        # (described upstream as loaded from the checkpoint).
        self.register_buffer("layer_scalar", torch.ones(1))

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: torch.Tensor | None,
        per_layer_input: torch.Tensor | None = None,
        **kwargs,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run the layer and return ``(hidden_states, None)``.

        The incoming ``residual`` argument is not read: the layer adds its
        own residuals internally and always returns ``None`` in the second
        slot.

        Residual pattern:
            1. input_norm(x) -> attn -> post_attn_norm -> ADD residual
            2. pre_ff_norm -> mlp (+ MoE) -> post_ff_norm -> ADD residual
        """
        # --- Attention sub-block ---
        skip = hidden_states
        attn_out = self.input_layernorm(skip)
        attn_out = self.self_attn(
            positions=positions,
            hidden_states=attn_out,
            **kwargs,
        )
        attn_out = self.post_attention_layernorm(attn_out)
        hidden_states = attn_out + skip

        # --- Feed-forward sub-block ---
        skip = hidden_states
        # The dense MLP always runs, with or without the MoE branch.
        ff_out = self.pre_feedforward_layernorm(hidden_states)
        ff_out = self.mlp(ff_out)
        if self.enable_moe_block:
            dense_branch = self.post_feedforward_layernorm_1(ff_out)
            # Router and experts consume the pre-MLP state (the residual),
            # not the MLP output.
            router_logits = self.router(skip)
            expert_branch = self.pre_feedforward_layernorm_2(skip)
            expert_branch = self.moe(expert_branch, router_logits)
            expert_branch = self.post_feedforward_layernorm_2(expert_branch)
            # Sum the dense and expert branches.
            ff_out = dense_branch + expert_branch
        ff_out = self.post_feedforward_layernorm(ff_out)
        hidden_states = ff_out + skip

        # --- Per-Layer Embedding, when both input and modules exist ---
        if per_layer_input is not None and self.per_layer_input_gate is not None:
            gate = self.per_layer_input_gate(hidden_states)
            gate = torch.nn.functional.gelu(gate, approximate="tanh")
            ple_update = self.per_layer_projection(gate * per_layer_input)
            ple_update = self.post_per_layer_input_norm(ple_update)
            hidden_states = hidden_states + ple_update

        # The per-layer scalar is applied unconditionally.
        hidden_states = hidden_states * self.layer_scalar
        return hidden_states, None
def _run_decoder_layers(
    decoder_layers: list[Gemma4DecoderLayer],
    layer_idx_start: int,
    positions: torch.Tensor,
    hidden_states: torch.Tensor,
    per_layer_inputs: torch.Tensor | None = None,
    **kwargs,
) -> torch.Tensor:
    """Apply ``decoder_layers`` in order and return the final hidden states.

    ``layer_idx_start`` is the global index of the first layer in the
    slice. When ``per_layer_inputs`` is given, each layer receives
    ``per_layer_inputs[:, global_index, :]`` as its per-layer input;
    otherwise it receives ``None``.
    """
    residual = None
    for global_idx, layer in enumerate(decoder_layers, start=layer_idx_start):
        if per_layer_inputs is None:
            ple_slice = None
        else:
            # Select this layer's entry along the second axis.
            ple_slice = per_layer_inputs[:, global_idx, :]
        hidden_states, residual = layer(
            positions,
            hidden_states,
            residual,
            per_layer_input=ple_slice,
            **kwargs,
        )
    return hidden_states
@support_torch_compile(
enable_if=lambda vllm_config: vllm_config.cache_config.kv_sharing_fast_prefill
)
class Gemma4SelfDecoderLayers(nn.Module):
"""Compiled wrapper: embedding + non-KV-shared layers (YOCO first half).
Owns the embedding and PLE modules so they are inside the compiled
graph. Gemma4Model delegates embedding methods here.
"""
def __init__(
self,
*,
vllm_config: VllmConfig,
prefix: str = "",
decoder_layers: list[Gemma4DecoderLayer],
layer_idx_start: int,
embed_tokens: VocabParallelEmbedding,
normalizer: torch.Tensor,
embed_tokens_per_layer: VocabParallelEmbedding | None,
embed_scale_per_layer: torch.Tensor | None,
per_layer_model_projection: ColumnParallelLinear | None,
per_layer_projection_norm: RMSNorm | None,
per_layer_input_scale: torch.Tensor | None,
per_layer_projection_scale: torch.Tensor | None,
):
super().__init__()
self.decoder_layers = decoder_layers
self.layer_idx_start = layer_idx_start
config = _get_text_config(vllm_config.model_config.hf_config)
self.config = config
self.hidden_size_per_layer_input = getattr(
config, "hidden_size_per_layer_input", 0
)
self.vocab_size_per_layer_input = getattr(
config, "vocab_size_per_layer_input", config.vocab_size
)
# Shared references to modules owned by Gemma4Model — must be
# inside this nn.Module so torch.compile captures them.
self.embed_tokens = embed_tokens
self.normalizer = normalizer
self.embed_tokens_per_layer = embed_tokens_per_layer
self.embed_scale_per_layer = embed_scale_per_layer
self.per_layer_model_projection = per_layer_model_projection
self.per_layer_projection_norm = per_layer_projection_norm
self.per_layer_input_scale = per_layer_input_scale
self.per_layer_projection_scale = per_layer_projection_scale
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids) * self.normalizer
def get_per_layer_inputs(self, input_ids: torch.Tensor) -> torch.Tensor | None:
"""Get per-layer embeddings from embed_tokens_per_layer.
Returns:
Per-layer embeddings (num_tokens, num_layers,
hidden_size_per_layer_input)
"""
if self.embed_tokens_per_layer is None:
return None
per_layer_inputs_mask = torch.logical_and(
input_ids >= 0,
input_ids < self.vocab_size_per_layer_input,
)
per_layer_inputs_tokens = torch.where(
per_layer_inputs_mask, input_ids, torch.zeros_like(input_ids)
)
per_layer_embeds = self.embed_tokens_per_layer(per_layer_inputs_tokens)
per_layer_embeds = per_layer_embeds * self.embed_scale_per_layer
return per_layer_embeds.reshape(
*input_ids.shape,
self.config.num_hidden_layers,
self.hidden_size_per_layer_input,
)
def project_per_layer_inputs(
self,
inputs_embeds: torch.Tensor,
per_layer_inputs: torch.Tensor | None,
) -> torch.Tensor | None:
"""Project inputs_embeds and combine with per_layer_inputs.
Steps:
1. Project inputs_embeds: hidden_size → total_ple_dim
2. Scale by hidden_size^{-0.5}
3. Reshape to (num_tokens, num_layers, per_layer_dim)
4. Normalize with per_layer_projection_norm
5. Combine: (projection + per_layer_inputs) * 1/sqrt(2)
"""
if self.per_layer_model_projection is None:
return None
per_layer_projection = self.per_layer_model_projection(inputs_embeds)
per_layer_projection = per_layer_projection * self.per_layer_projection_scale
per_layer_projection = per_layer_projection.reshape(
*inputs_embeds.shape[:-1],
self.config.num_hidden_layers,
self.hidden_size_per_layer_input,
)
per_layer_projection = self.per_layer_projection_norm(per_layer_projection)
if per_layer_inputs is None:
return per_layer_projection
return (per_layer_projection + per_layer_inputs) * self.per_layer_input_scale
def forward(
    self,
    input_ids: torch.Tensor | None,
    positions: torch.Tensor,
    inputs_embeds: torch.Tensor | None = None,
    per_layer_inputs: torch.Tensor | None = None,
    **kwargs,
) -> tuple[torch.Tensor, torch.Tensor | None]:
    """Run the self-decoder layers.

    When ``inputs_embeds`` is supplied it is used as the initial hidden
    state and the caller-provided ``per_layer_inputs`` are projected.
    Otherwise both the token embeddings and the per-layer embeddings are
    computed from ``input_ids``.

    Returns:
        ``(hidden_states, per_layer_inputs)`` where ``per_layer_inputs`` is
        the projected tensor (or ``None``) that was fed to the layers.
    """
    if inputs_embeds is None:
        hidden_states = self.embed_input_ids(input_ids)
        raw_per_layer = self.get_per_layer_inputs(input_ids)
    else:
        hidden_states = inputs_embeds
        raw_per_layer = per_layer_inputs

    per_layer_inputs = self.project_per_layer_inputs(hidden_states, raw_per_layer)

    hidden_states = _run_decoder_layers(
        self.decoder_layers,
        self.layer_idx_start,
        positions,
        hidden_states,
        per_layer_inputs,
        **kwargs,
    )
    return hidden_states, per_layer_inputs
@support_torch_compile(
    enable_if=lambda vllm_config: vllm_config.cache_config.kv_sharing_fast_prefill
)
class Gemma4CrossDecoderLayers(nn.Module):
    """Cross-decoder half of the YOCO split (the KV-shared layers).

    Holds a reference to decoder layers that were constructed elsewhere,
    together with the absolute index of the first one. The
    ``enable_if`` predicate above turns compilation on only when
    ``cache_config.kv_sharing_fast_prefill`` is set.
    """

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
        decoder_layers: list[Gemma4DecoderLayer],
        layer_idx_start: int,
    ):
        """Store the layer list and its starting layer index.

        ``vllm_config`` and ``prefix`` are accepted but not read in this
        body.
        """
        super().__init__()
        self.layer_idx_start = layer_idx_start
        self.decoder_layers = decoder_layers

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        per_layer_inputs: torch.Tensor | None = None,
        **kwargs,
    ) -> torch.Tensor:
        """Feed ``hidden_states`` through the stored layers and return the result."""
        cross_hidden_states = _run_decoder_layers(
            self.decoder_layers,
            self.layer_idx_start,
            positions,
            hidden_states,
            per_layer_inputs,
            **kwargs,
        )
        return cross_hidden_states
@support_torch_compile(
enable_if=lambda vllm_config: not vllm_config.cache_config.kv_sharing_fast_prefill
)
class Gemma4Model(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
    """Build the Gemma4 text backbone.

    Creates the token embedding, the optional Per-Layer Embedding (PLE)
    modules and their scalar buffers, the decoder layers, the final norm,
    the self-/cross-decoder wrappers used for the YOCO fast-prefill split,
    the static buffers needed when fast prefill is enabled, and a custom
    factory for empty pipeline-parallel intermediate tensors.

    Args:
        vllm_config: Engine configuration; the model, cache, quantization
            and scheduler sub-configs are read from it.
        prefix: Dotted name prefix used when naming sub-modules.
    """
    super().__init__()
    config = _get_text_config(vllm_config.model_config.hf_config)
    cache_config = vllm_config.cache_config
    quant_config = vllm_config.quant_config
    self.config = config
    self.quant_config = quant_config
    # PLE config values (default to 0 if not present — disables PLE)
    self.hidden_size_per_layer_input = getattr(
        config, "hidden_size_per_layer_input", 0
    )
    # Falls back to the main vocabulary size when the config has no
    # dedicated per-layer vocabulary size.
    self.vocab_size_per_layer_input = getattr(
        config, "vocab_size_per_layer_input", config.vocab_size
    )
    self.embed_tokens = VocabParallelEmbedding(
        config.vocab_size,
        config.hidden_size,
        quant_config=quant_config,
        prefix=f"{prefix}.embed_tokens",
    )
    # Per-Layer Embedding (PLE) components
    if (
        self.hidden_size_per_layer_input is not None
        and self.hidden_size_per_layer_input > 0
    ):
        # One embedding row holds the per-layer vectors of every layer
        # back to back; it is reshaped to (…, num_layers, per_layer_dim)
        # by get_per_layer_inputs.
        total_ple_dim = self.hidden_size_per_layer_input * config.num_hidden_layers
        self.embed_tokens_per_layer = VocabParallelEmbedding(
            self.vocab_size_per_layer_input,
            total_ple_dim,
            quant_config=quant_config,
            prefix=f"{prefix}.embed_tokens_per_layer",
        )
        # Scaled embedding factor (from config, not hardcoded)
        # Register as buffer so it moves to GPU with the model
        # and interacts correctly with torch.compile AOT caching.
        self.register_buffer(
            "embed_scale_per_layer",
            torch.tensor(self.hidden_size_per_layer_input**0.5),
            persistent=False,
        )
        # Projection: hidden_size → total_ple_dim
        # ColumnParallelLinear with gather_output=True
        self.per_layer_model_projection = ColumnParallelLinear(
            config.hidden_size,
            total_ple_dim,
            bias=False,
            gather_output=True,
            return_bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.per_layer_model_projection",
        )
        # PLE projection norm: output = norm(x) * weight
        self.per_layer_projection_norm = RMSNorm(
            self.hidden_size_per_layer_input,
            eps=config.rms_norm_eps,
        )
        # Scale factor for combining projection + per_layer_inputs
        # Register as buffer so it moves to GPU with the model
        # and interacts correctly with torch.compile AOT caching.
        self.register_buffer(
            "per_layer_input_scale",
            torch.rsqrt(torch.tensor(2.0)),
            persistent=False,
        )
        # Scaled projection: multiply output by hidden_size**-0.5.
        # Register as buffer for GPU placement and torch.compile.
        self.register_buffer(
            "per_layer_projection_scale",
            torch.tensor(config.hidden_size**-0.5),
            persistent=False,
        )
    else:
        # PLE disabled: every PLE attribute still exists, set to None, so
        # the helper methods can short-circuit on a simple None check.
        self.embed_tokens_per_layer = None
        self.embed_scale_per_layer = None
        self.per_layer_model_projection = None
        self.per_layer_projection_norm = None
        self.per_layer_input_scale = None
        self.per_layer_projection_scale = None
    self.start_layer, self.end_layer, self.layers = make_layers(
        config.num_hidden_layers,
        lambda prefix: Gemma4DecoderLayer(
            config,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=prefix,
        ),
        prefix=f"{prefix}.layers",
    )
    # Final norm: output = norm(x) * weight
    self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
    # Embedding scale = sqrt(hidden_size)
    # NOTE(review): no explicit dtype is passed, so the buffer takes
    # torch's default dtype at construction time; presumably that is the
    # model dtype (bfloat16 etc.) — confirm.
    self.register_buffer(
        "normalizer",
        torch.tensor(config.hidden_size**0.5),
        persistent=False,
    )
    # --- You Only Cache Once (YOCO) split for fast prefill ---
    # With num_kv_shared_layers missing (default 0) this index equals
    # num_hidden_layers, so the cross-decoder receives an empty slice.
    first_kv_shared_layer_idx = config.num_hidden_layers - getattr(
        config, "num_kv_shared_layers", 0
    )
    from vllm.compilation.backends import set_model_tag
    # Layers 0..(K-1) are self-decoder layers in YOCO
    with set_model_tag("self_decoder"):
        # The wrapper receives references to modules/buffers already
        # created above; nothing is duplicated here.
        self.self_decoder = Gemma4SelfDecoderLayers(
            vllm_config=vllm_config,
            prefix=f"{prefix}.self_decoder",
            decoder_layers=self.layers[:first_kv_shared_layer_idx],
            layer_idx_start=0,
            embed_tokens=self.embed_tokens,
            normalizer=self.normalizer,
            embed_tokens_per_layer=getattr(self, "embed_tokens_per_layer", None),
            embed_scale_per_layer=getattr(self, "embed_scale_per_layer", None),
            per_layer_model_projection=getattr(
                self, "per_layer_model_projection", None
            ),
            per_layer_projection_norm=getattr(
                self, "per_layer_projection_norm", None
            ),
            per_layer_input_scale=getattr(self, "per_layer_input_scale", None),
            per_layer_projection_scale=getattr(
                self, "per_layer_projection_scale", None
            ),
        )
    # Layers K..(N-1) are cross-decoder layers in YOCO
    with set_model_tag("cross_decoder"):
        self.cross_decoder = Gemma4CrossDecoderLayers(
            vllm_config=vllm_config,
            prefix=f"{prefix}.cross_decoder",
            decoder_layers=self.layers[first_kv_shared_layer_idx:],
            layer_idx_start=first_kv_shared_layer_idx,
        )
    self.fast_prefill_enabled = cache_config.kv_sharing_fast_prefill
    if self.fast_prefill_enabled:
        # Allocate static buffers for CUDAGraph
        # They are sized for the largest batch the scheduler can issue and
        # are written in place by fast_prefill_forward.
        max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens
        device = next(self.parameters()).device
        self.positions = torch.zeros(
            max_num_tokens, dtype=torch.int64, device=device
        )
        self.hidden_states = torch.zeros(
            (max_num_tokens, config.hidden_size),
            dtype=self.embed_tokens.weight.dtype,
            device=device,
        )
        if (
            self.hidden_size_per_layer_input
            and self.hidden_size_per_layer_input > 0
        ):
            self.per_layer_inputs = torch.zeros(
                (
                    max_num_tokens,
                    config.num_hidden_layers,
                    self.hidden_size_per_layer_input,
                ),
                dtype=self.embed_tokens.weight.dtype,
                device=device,
            )
        else:
            self.per_layer_inputs = None
    # Custom factory that includes per_layer_inputs for PLE-enabled PP.
    # per_layer_inputs has shape (batch, num_layers, per_layer_dim),
    # which differs from the standard (batch, hidden_size) shape,
    # so we can't use the default factory.
    # Plain locals are captured (instead of self) by the closure below.
    ple_dim = self.hidden_size_per_layer_input
    num_layers = config.num_hidden_layers
    hidden_size = config.hidden_size
    def _make_empty_intermediate_tensors(
        batch_size: int,
        dtype: torch.dtype,
        device: torch.device,
    ) -> IntermediateTensors:
        # Zero-filled placeholders; "per_layer_inputs" is only present
        # when PLE is enabled.
        tensors: dict[str, torch.Tensor] = {
            "hidden_states": torch.zeros(
                (batch_size, hidden_size),
                dtype=dtype,
                device=device,
            ),
            "residual": torch.zeros(
                (batch_size, hidden_size),
                dtype=dtype,
                device=device,
            ),
        }
        if ple_dim and ple_dim > 0:
            tensors["per_layer_inputs"] = torch.zeros(
                (batch_size, num_layers, ple_dim),
                dtype=dtype,
                device=device,
            )
        return IntermediateTensors(tensors)
    self.make_empty_intermediate_tensors = _make_empty_intermediate_tensors
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
    """Delegate token embedding to the self-decoder wrapper."""
    self_decoder = self.self_decoder
    return self_decoder.embed_input_ids(input_ids)
def get_per_layer_inputs(self, input_ids: torch.Tensor) -> torch.Tensor | None:
"""Get per-layer embeddings from embed_tokens_per_layer.
Returns:
Per-layer embeddings (num_tokens, num_layers,
hidden_size_per_layer_input)
"""
return self.self_decoder.get_per_layer_inputs(input_ids)
def project_per_layer_inputs(
self,
inputs_embeds: torch.Tensor,
per_layer_inputs: torch.Tensor | None,
) -> torch.Tensor | None:
"""Project inputs_embeds and combine with per_layer_inputs.
Steps:
1. Project inputs_embeds: hidden_size → total_ple_dim
2. Scale by hidden_size^{-0.5}
3. Reshape to (num_tokens, num_layers, per_layer_dim)
4. Normalize with per_layer_projection_norm
5. Combine: (projection + per_layer_inputs) * 1/sqrt(2)
"""
return self.self_decoder.project_per_layer_inputs(
inputs_embeds, per_layer_inputs
)
def fast_prefill_forward(
    self,
    input_ids: torch.Tensor | None,
    positions: torch.Tensor,
    inputs_embeds: torch.Tensor | None = None,
    per_layer_inputs: torch.Tensor | None = None,
    **kwargs,
) -> torch.Tensor:
    """Forward pass used when KV-sharing fast prefill is enabled.

    The self-decoder runs over all tokens; the cross-decoder then runs
    only over the rows selected by ``logits_indices_padded`` (taken from
    the attention metadata of the last layer, or all rows when that
    metadata is unavailable). Inputs are staged through the static
    ``self.positions`` / ``self.hidden_states`` / ``self.per_layer_inputs``
    buffers allocated in ``__init__``.

    Returns:
        Hidden states before the final norm. When the fast-prefill
        metadata is present, this is the self-decoder output with the
        selected rows overwritten by the cross-decoder output; otherwise
        it is the cross-decoder output itself.
    """
    logits_indices_padded, num_logits_indices = None, None
    attn_metadata = get_forward_context().attn_metadata
    if attn_metadata is not None:
        assert isinstance(attn_metadata, dict)
        # Metadata is looked up under the last decoder layer's attention
        # layer name.
        layer_attn_metadata = attn_metadata[
            self.layers[-1].self_attn.attn.layer_name
        ]
        if isinstance(layer_attn_metadata, KVSharingFastPrefillMetadata):
            logits_indices_padded = layer_attn_metadata.logits_indices_padded
            num_logits_indices = layer_attn_metadata.num_logits_indices
    batch_size = positions.size(0)
    # Stage positions in the static buffer before running the self-decoder.
    self.positions[:batch_size].copy_(positions)
    self_decoder_hidden_states, per_layer_inputs = self.self_decoder(
        input_ids=input_ids,
        positions=self.positions[:batch_size],
        inputs_embeds=inputs_embeds,
        per_layer_inputs=per_layer_inputs,
        **kwargs,
    )
    if logits_indices_padded is None:
        # No fast-prefill metadata: select every row.
        logits_indices_padded = torch.arange(
            batch_size,
            dtype=positions.dtype,
            device=positions.device,
        )
    # NOTE: Keep .clone() until fix in
    # https://github.com/vllm-project/vllm/pull/22282
    hidden_states = self_decoder_hidden_states.clone()
    num_padded = logits_indices_padded.size(0)
    # Gather the selected rows into the front of the static buffers. The
    # positions buffer is overwritten here, after the self-decoder has
    # already consumed it; the source is the caller's `positions` tensor.
    self.positions[:num_padded].copy_(positions[logits_indices_padded])
    self.hidden_states[:num_padded].copy_(
        self_decoder_hidden_states[logits_indices_padded]
    )
    if self.per_layer_inputs is not None and per_layer_inputs is not None:
        self.per_layer_inputs[:num_padded].copy_(
            per_layer_inputs[logits_indices_padded]
        )
    # Update batch_descriptor so the cross-decoder's piecewise
    # CUDAGraphWrapper dispatches to the correct (reduced) batch size.
    forward_context = get_forward_context()
    orig_batch_desc = forward_context.batch_descriptor
    if orig_batch_desc is not None:
        forward_context.batch_descriptor = replace(
            orig_batch_desc, num_tokens=num_padded
        )
    cross_per_layer = (
        self.per_layer_inputs[:num_padded]
        if self.per_layer_inputs is not None
        else None
    )
    cross_hidden_states = self.cross_decoder(
        self.positions[:num_padded],
        self.hidden_states[:num_padded],
        cross_per_layer,
        **kwargs,
    )
    # Restore the original batch_descriptor
    # NOTE(review): not restored if the cross-decoder raises.
    forward_context.batch_descriptor = orig_batch_desc
    if num_logits_indices is not None:
        assert num_logits_indices > 0
        # Scatter only the first num_logits_indices cross-decoder rows
        # back; the remaining padded indices are ignored.
        hidden_states[logits_indices_padded[:num_logits_indices]] = (
            cross_hidden_states[:num_logits_indices]
        )
    else:
        hidden_states = cross_hidden_states
    return hidden_states
def forward(
    self,
    input_ids: torch.Tensor | None,
    positions: torch.Tensor,
    intermediate_tensors: IntermediateTensors | None,
    inputs_embeds: torch.Tensor | None = None,
    per_layer_inputs: torch.Tensor | None = None,
    **kwargs,
) -> torch.Tensor | IntermediateTensors:
    """Run the text backbone.

    With fast prefill enabled the work is delegated to
    ``fast_prefill_forward`` and the final norm is applied. Otherwise the
    layers owned by this pipeline-parallel rank are run in order.

    Args:
        input_ids: Token ids; used on the first PP rank when
            ``inputs_embeds`` is ``None``.
        positions: Position ids forwarded to every layer.
        intermediate_tensors: Output of the previous PP rank; required on
            every rank except the first.
        inputs_embeds: Pre-computed input embeddings (e.g. from the
            multimodal wrapper). When given, ``per_layer_inputs`` is
            treated as raw PLE embeddings and projected here.
        per_layer_inputs: Raw per-layer embeddings accompanying
            ``inputs_embeds``; may be ``None``.

    Returns:
        Normalized hidden states on the last PP rank, or an
        ``IntermediateTensors`` for the next rank otherwise.
    """
    if self.fast_prefill_enabled:
        hidden_states = self.fast_prefill_forward(
            input_ids,
            positions,
            inputs_embeds,
            per_layer_inputs,
            **kwargs,
        )
        return self.norm(hidden_states)

    # Normal (non-fast-prefill) path with PP support
    if get_pp_group().is_first_rank:
        if inputs_embeds is not None:
            hidden_states = inputs_embeds
            # When called from the multimodal wrapper, raw PLE embeddings
            # are pre-computed and passed explicitly; project them here.
            per_layer_inputs = self.project_per_layer_inputs(
                hidden_states, per_layer_inputs
            )
        else:
            hidden_states = self.embed_input_ids(input_ids)
            per_layer_embeds = self.get_per_layer_inputs(input_ids)
            per_layer_inputs = self.project_per_layer_inputs(
                hidden_states, per_layer_embeds
            )
        residual = None
    else:
        assert intermediate_tensors is not None
        hidden_states = intermediate_tensors["hidden_states"]
        residual = intermediate_tensors["residual"]
        per_layer_inputs = intermediate_tensors.get("per_layer_inputs")

    # Enumerate with the absolute layer index so it can directly select
    # this layer's slice of per_layer_inputs.
    for actual_layer_idx, layer in enumerate(
        islice(self.layers, self.start_layer, self.end_layer),
        start=self.start_layer,
    ):
        if per_layer_inputs is not None:
            # (num_tokens, per_layer_dim)
            layer_per_input = per_layer_inputs[:, actual_layer_idx, :]
        else:
            layer_per_input = None
        hidden_states, residual = layer(
            positions,
            hidden_states,
            residual,
            per_layer_input=layer_per_input,
            **kwargs,
        )

    if not get_pp_group().is_last_rank:
        tensors = {
            "hidden_states": hidden_states,
            "residual": residual,
        }
        # Only forward per_layer_inputs when PLE is enabled. This matches
        # make_empty_intermediate_tensors, which omits the key in that
        # case, and keeps None out of the tensor dict.
        if per_layer_inputs is not None:
            tensors["per_layer_inputs"] = per_layer_inputs
        return IntermediateTensors(tensors)

    # Gemma4 incorporates residual into hidden_states directly.
    # Apply norm without residual fusion when possible.
    if residual is None:
        hidden_states = self.norm(hidden_states)
    else:
        hidden_states, _ = self.norm(hidden_states, residual)
    return hidden_states
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
    """Load checkpoint tensors into this module's parameters and buffers.

    Each ``(name, tensor)`` pair is tried against, in order: the
    quantization cache-scale name, the k/v/q/prob scale remap, the stacked
    (qkv / gate_up) mapping, the per-expert MoE mapping, and finally a
    direct name lookup.

    Args:
        weights: Iterable of ``(checkpoint_name, tensor)`` pairs.

    Returns:
        The set of parameter/buffer names that received a weight.
    """
    stacked_params_mapping = [
        # (param_name, shard_name, shard_id)
        ("qkv_proj", "q_proj", "q"),
        ("qkv_proj", "k_proj", "k"),
        ("qkv_proj", "v_proj", "v"),
        ("gate_up_proj", "gate_proj", 0),
        ("gate_up_proj", "up_proj", 1),
    ]
    # MoE expert weight mapping: checkpoint can have either:
    # 1. 3D packed tensors (exploded in _weight_iterator to per-expert 2D)
    # 2. Already per-expert 2D weights (if quantized)
    # Map to FusedMoE parameters:
    # moe.experts.{id}.gate_proj → FusedMoE w1 (shard of w13)
    # moe.experts.{id}.up_proj → FusedMoE w3 (shard of w13)
    # moe.experts.{id}.down_proj → FusedMoE w2
    #
    # Use prefix matching to handle both weights and
    # quantization scale parameters. The param_name is a prefix ending
    # in underscore, and weight_name ends with a dot, so that:
    # "experts.0.gate_proj.weight_scale" -> "experts.w13_weight_scale"
    # "experts.0.gate_proj.weight" -> "experts.w13_weight"
    # A missing or falsy num_experts yields an empty expert mapping.
    num_experts = getattr(self.config, "num_experts", None) or 0
    expert_params_mapping = [
        # (param_name, weight_name, expert_id, shard_id)
        (
            "experts.w13_"
            if proj_name in ["gate_proj", "up_proj"]
            else "experts.w2_",
            f"experts.{expert_id}.{proj_name}.",
            expert_id,
            shard_id,
        )
        for expert_id in range(num_experts)
        for shard_id, proj_name in [
            ("w1", "gate_proj"),
            ("w2", "down_proj"),
            ("w3", "up_proj"),
        ]
    ]
    params_dict = dict(self.named_parameters())
    # Include buffers (e.g. layer_scalar) so they can be loaded too
    params_dict.update(dict(self.named_buffers()))
    loaded_params: set[str] = set()
    for name, loaded_weight in weights:
        # 1) Quantization cache scales, identified by the quant config.
        if self.quant_config is not None and (
            scale_name := self.quant_config.get_cache_scale(name)
        ):
            param = params_dict[scale_name]
            weight_loader = getattr(param, "weight_loader", default_weight_loader)
            # Only the first element of the checkpoint tensor is loaded.
            loaded_weight = loaded_weight[0]
            weight_loader(param, loaded_weight)
            loaded_params.add(scale_name)
            continue
        # 2) Attention scale tensors whose names may need remapping. When
        #    the remap does not resolve to a known parameter, the weight
        #    falls through to the mappings below.
        if name.endswith((".k_scale", ".v_scale", ".q_scale", ".prob_scale")):
            remapped_name = maybe_remap_kv_scale_name(name, params_dict)
            if remapped_name is not None and remapped_name in params_dict:
                param = params_dict[remapped_name]
                weight_loader = getattr(
                    param, "weight_loader", default_weight_loader
                )
                weight_loader(param, loaded_weight)
                loaded_params.add(remapped_name)
                continue
        # 3) Shards of stacked parameters (qkv_proj / gate_up_proj).
        for param_name, shard_name, shard_id in stacked_params_mapping:
            if shard_name not in name:
                continue
            stacked_name = name.replace(shard_name, param_name)
            # k_eq_v layers use separate q_proj/k_proj instead of
            # packed qkv_proj. If the stacked param doesn't exist,
            # skip this mapping and fall through to direct load.
            if stacked_name not in params_dict:
                continue
            if is_pp_missing_parameter(stacked_name, self):
                continue
            param = params_dict[stacked_name]
            weight_loader = param.weight_loader
            weight_loader(param, loaded_weight, shard_id)
            loaded_params.add(stacked_name)
            break
        else:
            # 4) Per-expert MoE weights (only reached when no stacked
            #    mapping consumed the weight).
            for (
                param_name,
                weight_name,
                expert_id,
                shard_id,
            ) in expert_params_mapping:
                # Match both:
                # - Bare weights: "experts.0.down_proj" (from 3D explosion)
                # - With suffix: "experts.0.down_proj.weight_scale" (2D quantized)
                # weight_name has trailing dot, so check with and without it
                weight_name_base = weight_name.rstrip(".")
                if weight_name in name:
                    # Has suffix (e.g., .weight_scale)
                    moe_name = name.replace(weight_name, param_name)
                elif name.endswith(weight_name_base):
                    # Bare weight (no suffix)
                    moe_name = name.replace(
                        weight_name_base, param_name.rstrip("_") + "_weight"
                    )
                else:
                    continue
                if moe_name not in params_dict:
                    continue
                if is_pp_missing_parameter(moe_name, self):
                    continue
                param = params_dict[moe_name]
                # Expert weights are already in the correct
                # orientation for FusedMoE after _weight_iterator:
                # gate/up: [I, H] → w1/w3 expects [I, H]
                # down: [H, I] → w2 expects [H, I]
                # Scales and other quantization params may be 1D or scalar.
                weight_loader = param.weight_loader
                weight_loader(
                    param,
                    loaded_weight,
                    moe_name,  # Pass mapped name (handles both weights and scales)
                    shard_id=shard_id,
                    expert_id=expert_id,
                )
                loaded_params.add(moe_name)
                break
            else:
                # 5) Direct load by (possibly remapped) name.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                # NOTE(review): raises KeyError for a checkpoint name that
                # matches no parameter or buffer.
                param = params_dict[name]
                weight_loader = getattr(
                    param, "weight_loader", default_weight_loader
                )
                weight_loader(param, loaded_weight)
                loaded_params.add(name)
    return loaded_params
class Gemma4ForCausalLM(nn.Module, SupportsLoRA, SupportsPP, MixtureOfExperts):
# Maps each fused module name to the checkpoint sub-module names it packs.
# Note: qkv_proj packing applies to non-k_eq_v layers (sliding attention
# and full attention without k_eq_v). k_eq_v layers use separate
# q_proj + k_proj without packing.
packed_modules_mapping = {
    "qkv_proj": ["q_proj", "k_proj", "v_proj"],
    "gate_up_proj": ["gate_proj", "up_proj"],
}
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
    """Build the causal-LM wrapper: backbone, LM head and logits processor.

    Also fills in the attributes read by the ``MixtureOfExperts``
    protocol from the MoE blocks found on the backbone's layers; every
    expert count is 0 when no layer carries a ``Gemma4MoE``.
    """
    config = _get_text_config(vllm_config.model_config.hf_config)
    quant_config = vllm_config.quant_config
    super().__init__()
    self.config = config
    self.quant_config = quant_config

    self.model = Gemma4Model(
        vllm_config=vllm_config,
        prefix=maybe_prefix(prefix, "model"),
    )

    self.lm_head = ParallelLMHead(
        config.vocab_size,
        config.hidden_size,
        quant_config=quant_config,
        prefix=maybe_prefix(prefix, "lm_head"),
    )
    if config.tie_word_embeddings:
        self.lm_head = self.lm_head.tie_weights(self.model.embed_tokens)

    soft_cap = getattr(config, "final_logit_softcapping", None)
    self.logits_processor = LogitsProcessor(config.vocab_size, soft_cap=soft_cap)

    # Reuse the backbone's PLE-aware factory for PP placeholders.
    self.make_empty_intermediate_tensors = (
        self.model.make_empty_intermediate_tensors
    )

    # --- MixtureOfExperts protocol ---
    self.expert_weights: list[list[torch.Tensor]] = []
    moe_blocks = [
        layer.moe
        for layer in self.model.layers
        if isinstance(getattr(layer, "moe", None), Gemma4MoE)
    ]
    self.moe_layers: list[nn.Module] = [block.experts for block in moe_blocks]
    self.num_moe_layers = len(self.moe_layers)

    # Expert count is read from the last MoE block encountered.
    num_experts = moe_blocks[-1].num_experts if moe_blocks else 0
    self.num_logical_experts = num_experts
    self.num_physical_experts = num_experts
    self.num_local_physical_experts = num_experts
    self.num_routed_experts = num_experts

    self.num_expert_groups = 1
    self.num_shared_experts = 0
    self.num_redundant_experts = 0
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
    """Delegate token embedding to the wrapped ``Gemma4Model``."""
    backbone = self.model
    return backbone.embed_input_ids(input_ids)
def forward(
    self,
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    intermediate_tensors: IntermediateTensors | None = None,
    inputs_embeds: torch.Tensor | None = None,
    **kwargs,
) -> torch.Tensor | IntermediateTensors:
    """Run the backbone and return its output unchanged.

    The four named arguments are passed positionally, in order, to
    ``self.model``; extra keyword arguments are forwarded as-is.
    """
    return self.model(
        input_ids, positions, intermediate_tensors, inputs_embeds, **kwargs
    )
def compute_logits(
self,
hidden_states: torch.Tensor,
) -> torch.Tensor | None:
return self.logits_processor(self.lm_head, hidden_states)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
    """Load a Gemma4 checkpoint into this text-only model.

    Checkpoint names are rewritten on the fly (see ``_weight_iterator``)
    and then handed to ``AutoWeightsLoader``. Multimodal tower/embedder
    weights are skipped, as is ``lm_head`` when embeddings are tied.

    Returns:
        The set of names reported as loaded by ``AutoWeightsLoader``.
    """

    def _weight_iterator():
        """Yield ``(name, tensor)`` pairs renamed to this model's tree."""
        # Layers whose checkpoint has k_proj but no v_proj: the
        # full_attention layers when attention_k_eq_v is enabled.
        if getattr(self.config, "attention_k_eq_v", False):
            k_eq_v_layer_indices: set[int] = {
                idx
                for idx, layer_type in enumerate(self.config.layer_types)
                if layer_type == "full_attention"
            }
        else:
            k_eq_v_layer_indices = set()

        for name, weight in weights:
            # Checkpoint: model.language_model.layers.X.*
            # Our model:  model.layers.X.*
            name = name.replace("language_model.", "")

            # Newer HF checkpoints moved per_expert_scale to the router
            # and renamed moe → experts in the MoE block; map both back
            # to the internal naming.
            name = name.replace(".router.per_expert_scale", ".moe.per_expert_scale")
            if ".experts.gate_up_proj" in name:
                name = name.replace(".experts.gate_up_proj", ".moe.gate_up_proj")
            elif ".experts.down_proj" in name:
                name = name.replace(".experts.down_proj", ".moe.down_proj")

            # Individual 2D expert weights (e.g. quantized checkpoints):
            # .experts.{id}.{proj} → .moe.experts.{id}.{proj}
            name = re.sub(r"\.experts\.(\d+)\.", r".moe.experts.\1.", name)

            # 3D packed expert tensors are exploded into per-expert 2D
            # weights; no transpose is applied.
            #   moe.gate_up_proj: [E, 2*I, H] → gate = first half of
            #                     dim 1, up = second half
            #   moe.down_proj:    [E, H, I]   → passed through per expert
            if "moe.gate_up_proj" in name and weight.dim() == 3:
                half = weight.size(1) // 2
                for expert_id in range(weight.size(0)):
                    per_expert_name = name.replace(
                        "moe.", f"moe.experts.{expert_id}."
                    )
                    yield (
                        per_expert_name.replace("gate_up_proj", "gate_proj"),
                        weight[expert_id, :half, :],
                    )
                    yield (
                        per_expert_name.replace("gate_up_proj", "up_proj"),
                        weight[expert_id, half:, :],
                    )
                continue

            if "moe.down_proj" in name and weight.dim() == 3:
                for expert_id in range(weight.size(0)):
                    yield (
                        name.replace("moe.", f"moe.experts.{expert_id}."),
                        weight[expert_id],
                    )
                continue

            # k_eq_v layers: emit k_proj and a cloned copy named v_proj so
            # V receives the same weights as K. Only for the collected
            # full_attention layers; other layers keep their own v_proj.
            if "self_attn.k_proj" in name and k_eq_v_layer_indices:
                layer_match = re.search(r"layers\.(\d+)\.", name)
                if (
                    layer_match
                    and int(layer_match.group(1)) in k_eq_v_layer_indices
                ):
                    yield name, weight
                    yield name.replace("k_proj", "v_proj"), weight.clone()
                    continue

            yield name, weight

    # Multimodal weights are left out here; lm_head is also left out
    # when weights are tied.
    skip_substrs = [
        "audio_tower.",
        "vision_tower.",
        "embed_audio.",
        "embed_vision.",
    ]
    if self.config.tie_word_embeddings:
        skip_substrs.append("lm_head.")

    return AutoWeightsLoader(self, skip_substrs=skip_substrs).load_weights(
        _weight_iterator()
    )
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Gemma 4 multimodal model (image + audio + video support).
Adds vision tower, audio tower, and multimodal embedders on top of the
text-only Gemma4ForCausalLM. The vision/audio encoders are loaded via
AutoModel.from_config and run in eager mode while the language model uses
the vLLM-optimized path.
Video support: Gemma4 does **not** have a native video tower. Videos are
decomposed into timestamped image frames (up to 32 frames at 70 soft tokens
each) and fed through the same vision tower as regular images. The
processor inserts ``mm:ss`` timestamps between frames so the model can
reason about temporal order.
"""
import math
from collections.abc import Iterable, Mapping, Sequence
from typing import Annotated, Any, Literal
import numpy as np
import torch
from PIL import Image as PILImage
from torch import nn
from transformers import AutoModel, BatchFeature
from transformers.models.gemma4 import (
Gemma4Config,
Gemma4Processor,
Gemma4VisionConfig,
)
from transformers.models.gemma4.configuration_gemma4 import (
Gemma4AudioConfig,
Gemma4TextConfig,
)
from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions, VideoDummyOptions
# from vllm.inputs import MultiModalDataDict
from vllm.multimodal.inputs import MultiModalDataDict
from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.models.gemma4 import Gemma4ForCausalLM
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (
MultiModalFieldConfig,
MultiModalKwargsItems,
VideoItem,
)
from vllm.multimodal.parse import (
AudioProcessorItems,
ImageProcessorItems,
MultiModalDataItems,
MultiModalDataParser,
)
from vllm.multimodal.processing import BaseDummyInputsBuilder
from vllm.multimodal.processing.processor import (
BaseMultiModalProcessor,
BaseProcessingInfo,
PromptReplacement,
PromptUpdate,
PromptUpdateDetails,
)
from vllm.sequence import IntermediateTensors
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .utils import (
AutoWeightsLoader,
WeightsMapper,
init_vllm_registered_model,
maybe_prefix,
)
logger = init_logger(__name__)
# Video constants — match transformers Gemma4VideoProcessor defaults.
# Both feed the per-video token upper bound in
# Gemma4ProcessingInfo.get_mm_max_tokens_per_item.
_VIDEO_MAX_SOFT_TOKENS = 70 # soft tokens per video frame (vs 280 for images)
_VIDEO_MAX_FRAMES = 32 # max sampled frames per video
# ---------------------------------------------------------------------------
# Input schema
# ---------------------------------------------------------------------------
class Gemma4ImagePixelInputs(TensorSchema):
    """Schema for pre-patchified image inputs from the Gemma4 image processor.

    Dimensions:
        - bn: Batch size * number of images
        - np: Number of patches
          (max_patches = max_soft_tokens * pooling_kernel_size²)
        - pp: Patch pixels (patch_size² * 3)

    The HF Gemma4ImageProcessor emits ``pixel_values`` as
    (batch, max_patches, patch_pixels): already patchified, with
    zero-padding for patches beyond the real image content.
    ``pixel_position_ids`` carries the (x, y) coordinates of each patch,
    using (-1, -1) for padding patches.
    """

    type: Literal["pixel_values"] = "pixel_values"
    pixel_values: Annotated[torch.Tensor, TensorShape("bn", "np", "pp")]
    pixel_position_ids: Annotated[torch.Tensor, TensorShape("bn", "np", 2)]
class Gemma4AudioInputs(TensorSchema):
    """Schema for padded audio features and their validity mask.

    Dimensions:
        - bn: Batch size * number of audios
        - s: Sequence length (MEL spectrogram frames)
        - f: Number of features (MEL bins)
    """

    type: Literal["audio"] = "audio"
    input_features_padded: Annotated[
        torch.Tensor,
        TensorShape("bn", "s", "f"),
    ]
    input_features_mask: Annotated[
        torch.Tensor,
        TensorShape("bn", "s"),
    ]
# Alias: pixel inputs are the only image input variant defined here.
Gemma4ImageInputs = Gemma4ImagePixelInputs
class Gemma4VideoInputs(TensorSchema):
    """Schema for video frame inputs (same tensor layout as image inputs).

    Gemma4 has no separate video tower; video frames go through the
    vision tower at lower resolution (max_soft_tokens=70).
    """

    type: Literal["pixel_values_videos"] = "pixel_values_videos"
    pixel_values_videos: Annotated[torch.Tensor, TensorShape("bn", "np", "pp")]
    pixel_position_ids_videos: Annotated[torch.Tensor, TensorShape("bn", "np", 2)]
# ---------------------------------------------------------------------------
# Processing info
# ---------------------------------------------------------------------------
class Gemma4ProcessingInfo(BaseProcessingInfo):
def get_hf_config(self):
    """Return the HF config obtained from the context, requested as ``Gemma4Config``."""
    hf_config = self.ctx.get_hf_config(Gemma4Config)
    return hf_config
def get_default_tok_params(self):
    """Return the base tokenization params with ``add_special_tokens=False``.

    Gemma4's chat template already embeds a literal ``<bos>`` token in the
    rendered text. With ``add_special_tokens=True`` (the base-class
    default) the tokenizer would prepend *another* BOS, giving a
    ``[2, 2, ...]`` double-BOS sequence the model was not trained on.
    Disabling it here avoids the duplicate for both ``llm.generate()`` and
    the chat/completions API.
    """
    base_params = super().get_default_tok_params()
    return base_params.with_kwargs(add_special_tokens=False)
def get_hf_processor(self, **kwargs: object) -> Gemma4Processor:
    """Return the HF ``Gemma4Processor`` from the context, forwarding ``kwargs``."""
    hf_processor = self.ctx.get_hf_processor(Gemma4Processor, **kwargs)
    return hf_processor
def validate_num_items(self, modality: str, num_items: int) -> None:
    """Reject audio inputs for checkpoints without an audio tower.

    After this check, validation is delegated to the base class.

    Args:
        modality: Modality name being validated (e.g. ``"audio"``).
        num_items: Number of items supplied for that modality.

    Raises:
        ValueError: If at least one audio item is supplied but the HF
            config has ``audio_config`` set to ``None``.
    """
    if (
        modality == "audio"
        and num_items > 0
        and self.get_hf_config().audio_config is None
    ):
        model = self.ctx.model_config.model
        # The message previously repeated the "include an audio_config"
        # clause in a parenthetical; state it once.
        raise ValueError(
            f"Audio input was provided but the model "
            f"'{model}' does not have an audio tower. "
            f"Audio inference is only supported for Gemma4 "
            f"models that include an audio_config."
        )
    super().validate_num_items(modality, num_items)
def get_supported_mm_limits(self) -> Mapping[str, int | None]:
    """Return the supported modalities, each mapped to ``None``.

    ``"image"`` and ``"video"`` are always present; ``"audio"`` is added
    only when the HF config carries an ``audio_config``.
    """
    has_audio = self.get_hf_config().audio_config is not None
    modalities = ("image", "audio", "video") if has_audio else ("image", "video")
    return dict.fromkeys(modalities)
def get_mm_max_tokens_per_item(
    self, seq_len: int, mm_counts: Mapping[str, int]
) -> Mapping[str, int] | None:
    """Return an upper bound on placeholder tokens per item, by modality.

    ``seq_len`` and ``mm_counts`` are not used by this implementation.
    """
    hf_config = self.get_hf_config()

    # Upper bound: the pooler outputs default_output_length slots per
    # image (280). After padding is stripped the actual count is ≤ this
    # value, but vLLM needs the max for memory planning.
    max_tokens: dict[str, int] = {
        "image": hf_config.vision_config.default_output_length,
    }

    if hf_config.audio_config is not None:
        # Audio max tokens from the processor's audio_seq_length.
        max_tokens["audio"] = self.get_hf_processor().audio_seq_length

    # Video: each frame ≤ 70 soft tokens + boi + eoi + ~6 ts tokens.
    per_frame_budget = _VIDEO_MAX_SOFT_TOKENS + 2 + 6
    max_tokens["video"] = _VIDEO_MAX_FRAMES * per_frame_budget
    return max_tokens
def get_data_parser(self) -> MultiModalDataParser:
    """Build the multimodal data parser for this model.

    Video metadata is always requested. When the config has an
    ``audio_config``, the parser's ``target_sr`` is set to the feature
    extractor's sampling rate.
    """
    parser_kwargs: dict[str, Any] = {"video_needs_metadata": True}
    hf_config = self.get_hf_config()
    if getattr(hf_config, "audio_config", None) is not None:
        feature_extractor = self.get_hf_processor().feature_extractor
        parser_kwargs["target_sr"] = feature_extractor.sampling_rate
    return MultiModalDataParser(**parser_kwargs)
def _compute_num_soft_tokens(
self,
image_width: int,
image_height: int,
max_soft_tokens: int | None = None,
) -> int:
"""Compute the number of soft tokens the vision tower produces
for an image of the given dimensions, after padding is stripped.
Args:
max_soft_tokens: Override for the vision config's
``default_output_length``. When *None*, the value from
the model config is used.
"""
vision_cfg = self.get_hf_config().vision_config
patch_size = vision_cfg.patch_size
pooling_kernel_size = vision_cfg.pooling_kernel_size
if max_soft_tokens is None:
max_soft_tokens = vision_cfg.default_output_length
unit = patch_size * pooling_kernel_size
max_patches = max_soft_tokens * pooling_kernel_size**2
num_patches_orig = (image_height / patch_size) * (image_width / patch_size)
scale = math.sqrt(max_patches / num_patches_orig)
target_h = max(unit, int(math.floor(image_height * scale / unit)) * unit)
target_w = max(unit, int(math.floor(image_width * scale / unit)) * unit)
num_patches = (target_h // patch_size) * (target_w // patch_size)
return num_patches // (pooling_kernel_size**2)
def get_image_repl(
    self,
    *,
    image_width: int,
    image_height: int,
    processor: Gemma4Processor | None,
    max_soft_tokens: int | None = None,
) -> PromptUpdateDetails[list[int]]:
    """Build the token sequence that replaces one image placeholder.

    The sequence is ``boi + image_token * N + eoi`` where ``N`` is the
    value returned by ``_compute_num_soft_tokens`` for this image size.
    Only the image-token positions are selected as embedding slots.

    Args:
        image_width: Width of the image in pixels.
        image_height: Height of the image in pixels.
        processor: HF processor to read ``image_token_id`` from; fetched
            via ``get_hf_processor()`` when *None*.
        max_soft_tokens: Override for the default token budget.
            When *None*, falls back to the model config value.
    """
    hf_processor = processor if processor is not None else self.get_hf_processor()
    soft_token_count = self._compute_num_soft_tokens(
        image_width,
        image_height,
        max_soft_tokens=max_soft_tokens,
    )
    hf_config = self.get_hf_config()
    image_token_id = hf_processor.image_token_id
    sequence = [
        hf_config.boi_token_id,
        *([image_token_id] * soft_token_count),
        hf_config.eoi_token_id,
    ]
    return PromptUpdateDetails.select_token_id(sequence, image_token_id)
def get_audio_repl(
    self,
    *,
    audio_len: int,
    processor: Gemma4Processor | None,
) -> PromptUpdateDetails[list[int]]:
    """Build the token sequence that replaces one audio placeholder.

    The soft-token count is delegated to the HF processor's
    ``_compute_audio_num_tokens``, which is handed a zero waveform of
    ``audio_len`` samples together with the feature extractor's sampling
    rate. The result is ``boa + audio_token * N + eoa`` with only the
    audio-token positions selected as embedding slots.

    Args:
        audio_len: Length of the audio waveform in samples.
        processor: HF processor to use; fetched via
            ``get_hf_processor()`` when *None*.
    """
    hf_processor = processor if processor is not None else self.get_hf_processor()
    sample_rate = hf_processor.feature_extractor.sampling_rate
    # Only the waveform length matters here, so a zero tensor stands in.
    placeholder_waveform = torch.zeros(audio_len)
    soft_token_count = hf_processor._compute_audio_num_tokens(
        placeholder_waveform, sample_rate
    )
    hf_config = self.get_hf_config()
    audio_token_id = hf_processor.audio_token_id
    sequence = [
        hf_config.boa_token_id,
        *([audio_token_id] * soft_token_count),
        hf_config.eoa_token_id,
    ]
    return PromptUpdateDetails.select_token_id(sequence, audio_token_id)
def get_video_repl(
    self,
    *,
    timestamps: list[float],
    num_soft_tokens_per_frame: list[int],
    processor: Gemma4Processor,
) -> PromptUpdateDetails[list[int]]:
    """Build the full token replacement for one video.

    Each frame contributes a tokenized ``mm:ss`` timestamp followed by
    ``<boi><|video|>*N<eoi>``; frames after the first are preceded by a
    single space, giving the interleaved layout

        mm:ss <boi><|video|>*N<eoi> mm:ss <boi><|video|>*N<eoi> ...

    Args:
        timestamps: Per-frame timestamps in seconds.
        num_soft_tokens_per_frame: Per-frame soft-token counts, paired
            positionally with ``timestamps``.
        processor: HF processor providing ``video_token_id``.
    """
    tokenizer = self.ctx.get_tokenizer()
    hf_config = self.get_hf_config()
    boi = hf_config.boi_token_id
    eoi = hf_config.eoi_token_id
    video_token_id = processor.video_token_id

    sequence: list[int] = []
    is_first_frame = True
    for ts, soft_count in zip(timestamps, num_soft_tokens_per_frame):
        # mm:ss label, int-truncated and zero-padded.
        minutes = int(ts // 60)
        seconds = int(ts % 60)
        label = f"{minutes:02d}:{seconds:02d}"
        text = f"{label} " if is_first_frame else f" {label} "
        is_first_frame = False

        sequence.extend(tokenizer.encode(text, add_special_tokens=False))
        sequence.append(boi)
        sequence.extend([video_token_id] * soft_count)
        sequence.append(eoi)
    return PromptUpdateDetails.select_token_id(sequence, video_token_id)
# ---------------------------------------------------------------------------
# Dummy inputs builder
# ---------------------------------------------------------------------------
class Gemma4DummyInputsBuilder(BaseDummyInputsBuilder[Gemma4ProcessingInfo]):
    """Produces placeholder prompts and media used to exercise Gemma4."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return a prompt holding one placeholder per requested item.

        Images use the tab-prefixed image token (``\\t<|image|>``), the
        form the Gemma4 chat template inserts per image. Audio and video
        placeholders are appended without a prefix.
        """
        processor = self.info.get_hf_processor()
        image_count = mm_counts.get("image", 0)
        audio_count = mm_counts.get("audio", 0)
        video_count = mm_counts.get("video", 0)

        # _get_prompt_updates targets image_token and expands it to the
        # full image sequence.
        pieces = [("\t" + processor.image_token) * image_count]
        if audio_count > 0 and processor.audio_token:
            pieces.append(processor.audio_token * audio_count)
        if video_count > 0:
            pieces.append(processor.video_token * video_count)
        return "".join(pieces)

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
    ) -> MultiModalDataDict:
        """Return dummy media for each requested modality.

        The ``"image"`` entry is always present; ``"audio"`` and
        ``"video"`` are added only when their requested count is positive.
        """
        image_count = mm_counts.get("image", 0)
        audio_count = mm_counts.get("audio", 0)
        video_count = mm_counts.get("video", 0)
        processor = self.info.get_hf_processor()

        # Gemma4ImageProcessor sets size=None (it uses patch_size /
        # max_soft_tokens instead of the standard size dict), hence the
        # `or {}` guard and the 224 fallbacks.
        configured_size = getattr(processor.image_processor, "size", None) or {}
        width = configured_size.get("width", 224)
        height = configured_size.get("height", 224)

        def override_for(modality: str):
            return mm_options.get(modality) if mm_options else None

        image_overrides = override_for("image")
        audio_overrides = override_for("audio")
        video_overrides = override_for("video")

        dummy_data: MultiModalDataDict = {
            "image": self._get_dummy_images(
                width=width,
                height=height,
                num_images=image_count,
                overrides=image_overrides,
            ),
        }
        if audio_count > 0:
            dummy_data["audio"] = self._get_dummy_audios(
                length=processor.feature_extractor.fft_length,
                num_audios=audio_count,
                overrides=audio_overrides,
            )
        if video_count > 0:
            dummy_data["video"] = self._get_dummy_videos(
                width=width,
                height=height,
                num_frames=_VIDEO_MAX_FRAMES,
                num_videos=video_count,
                overrides=video_overrides,
            )
        return dummy_data

    def _get_dummy_videos(
        self,
        *,
        width: int,
        height: int,
        num_frames: int,
        num_videos: int,
        overrides: VideoDummyOptions | None = None,
    ) -> list[VideoItem]:
        """Return dummy videos paired with synthetic metadata.

        At least two frames are requested from the base implementation.
        Each returned video is copied and paired with metadata describing
        a 2 fps clip whose frame indices are ``0..num_frames-1``.
        """
        base_videos = super()._get_dummy_videos(
            width=width,
            height=height,
            num_frames=max(num_frames, 2),
            num_videos=num_videos,
            overrides=overrides,
        )

        items: list[VideoItem] = []
        for base_video in base_videos:
            frames = base_video.copy()
            frame_count = frames.shape[0]
            metadata = {
                "fps": 2.0,
                "duration": frame_count / 2.0,
                "total_num_frames": frame_count,
                "frames_indices": list(range(frame_count)),
                "video_backend": "opencv",
                "do_sample_frames": False,
            }
            items.append((frames, metadata))
        return items
# ---------------------------------------------------------------------------
# Multimodal processor
# ---------------------------------------------------------------------------
class Gemma4MultiModalProcessor(BaseMultiModalProcessor[Gemma4ProcessingInfo]):
    """Adapts the HF Gemma4 processor to vLLM's multimodal pipeline.

    Videos are decomposed into timestamped image frames before the HF
    processor runs, HF output keys are renamed to the names used by the
    model, and prompt placeholders are expanded to per-item soft-token
    sequences.
    """

    def _resolve_max_soft_tokens(self, mm_kwargs: Mapping[str, object]) -> Any:
        """Return the effective ``max_soft_tokens`` for a request, or None.

        Per-prompt kwargs are merged with the config-level defaults. The
        top-level ``max_soft_tokens`` key is preferred; when it is absent
        the nested ``images_kwargs["max_soft_tokens"]`` is used. Sharing
        this lookup keeps preprocessing and placeholder expansion in
        agreement on the token budget.
        """
        merged_kwargs = self.info.ctx.get_merged_mm_kwargs(mm_kwargs)
        value = merged_kwargs.get("max_soft_tokens")
        if value is None:
            value = merged_kwargs.get("images_kwargs", {}).get("max_soft_tokens")
        return value

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Run the HF processor and post-process its output.

        Args:
            prompt: Prompt text containing modality placeholders.
            mm_data: Multimodal inputs keyed by ``"images"``,
                ``"videos"`` and/or ``"audios"``.
            mm_kwargs: Per-request processor kwargs.
            tok_kwargs: Tokenizer kwargs forwarded to the HF processor.

        Returns:
            A ``BatchFeature`` holding the HF outputs (with
            ``image_position_ids`` renamed to ``pixel_position_ids`` and
            audio features stored both padded and unpadded) merged with
            the video outputs built here.

        Raises:
            ValueError: If the resolved ``max_soft_tokens`` is not one of
                the supported values.
        """
        # Validate max_soft_tokens early and exit cleanly on bad values.
        _SUPPORTED_SOFT_TOKENS = (70, 140, 280, 560, 1120)
        val = self._resolve_max_soft_tokens(mm_kwargs)
        if val is not None and val not in _SUPPORTED_SOFT_TOKENS:
            raise ValueError(
                f"Unsupported max_soft_tokens value: {val}. "
                f"Valid values are {_SUPPORTED_SOFT_TOKENS}."
            )

        mm_data = dict(mm_data)

        # ---- VIDEO HANDLING ----
        # Gemma4 decomposes video into timestamped image frames.
        # Each frame is processed with max_soft_tokens=70 through the
        # same vision tower, matching transformers processing_gemma4.py.
        video_outputs: dict[str, Any] = {}
        if videos := mm_data.pop("videos", []):
            processor = self.info.get_hf_processor()
            video_token = processor.video_token
            all_video_pixel_values: list[torch.Tensor] = []
            all_video_position_ids: list[torch.Tensor] = []
            video_num_soft_tokens_per_video: list[list[int]] = []
            video_timestamps_per_video: list[list[float]] = []
            video_frame_counts: list[int] = []
            # Offset at which to look for the next unexpanded <|video|>
            # placeholder. Each expansion itself contains <|video|>
            # tokens, so later searches must start after it.
            search_from = 0

            for item in videos:
                video_array, metadata = item

                # Convert frames to PIL images
                if isinstance(video_array, np.ndarray):
                    frames = [PILImage.fromarray(frame) for frame in video_array]
                else:
                    frames = list(video_array)

                # Compute timestamps from metadata (same as transformers)
                fps = metadata.get("fps") or 24
                frame_indices = metadata.get(
                    "frames_indices", list(range(len(frames)))
                )
                timestamps = [idx / fps for idx in frame_indices]

                # Process frames as images with max_soft_tokens=70
                video_mm_kwargs = dict(mm_kwargs)
                video_mm_kwargs["max_soft_tokens"] = _VIDEO_MAX_SOFT_TOKENS
                dummy_prompt = ("\t" + processor.image_token) * len(frames)
                frame_outputs = super()._call_hf_processor(
                    prompt=dummy_prompt,
                    mm_data={"images": frames},
                    mm_kwargs=video_mm_kwargs,
                    tok_kwargs=tok_kwargs,
                )

                # Remap HF key name
                if "image_position_ids" in frame_outputs:
                    frame_outputs["pixel_position_ids"] = frame_outputs.pop(
                        "image_position_ids"
                    )
                all_video_pixel_values.append(frame_outputs["pixel_values"])
                all_video_position_ids.append(frame_outputs["pixel_position_ids"])

                # Compute soft tokens per frame
                num_soft_per_frame = [
                    self.info._compute_num_soft_tokens(
                        img.size[0],
                        img.size[1],
                        max_soft_tokens=_VIDEO_MAX_SOFT_TOKENS,
                    )
                    for img in frames
                ]
                video_num_soft_tokens_per_video.append(num_soft_per_frame)
                video_timestamps_per_video.append(timestamps)
                video_frame_counts.append(len(frames))

                # Build the expanded replacement text for this video.
                ts_strs = [
                    f"{int(s // 60):02d}:{int(s % 60):02d}" for s in timestamps
                ]
                replacement = " ".join(
                    f"{t} {processor.boi_token}"
                    f"{processor.video_token * n}"
                    f"{processor.eoi_token}"
                    for t, n in zip(ts_strs, num_soft_per_frame)
                )

                # Replace the next unexpanded placeholder only. Searching
                # from the cursor (not from the start of the prompt)
                # keeps a later video from being spliced into an earlier
                # video's expansion.
                placeholder_at = prompt.find(video_token, search_from)
                if placeholder_at != -1:
                    prompt = (
                        prompt[:placeholder_at]
                        + replacement
                        + prompt[placeholder_at + len(video_token):]
                    )
                    search_from = placeholder_at + len(replacement)

            video_outputs = {
                "pixel_values_videos": torch.cat(all_video_pixel_values, dim=0),
                "pixel_position_ids_videos": torch.cat(
                    all_video_position_ids, dim=0
                ),
                "video_frame_counts": torch.tensor(video_frame_counts),
                "video_num_soft_tokens": video_num_soft_tokens_per_video,
                "video_timestamps": video_timestamps_per_video,
            }

        # The processor accepts 'audio' not 'audios'.
        if "audios" in mm_data:
            mm_data["audio"] = mm_data.pop("audios")

        # Warn if any audio waveform exceeds the model's max duration.
        if "audio" in mm_data:
            processor = self.info.get_hf_processor()
            sr = processor.feature_extractor.sampling_rate
            max_tokens = processor.audio_seq_length
            ms_per_tok = processor.audio_ms_per_token
            max_duration_s = max_tokens * ms_per_tok / 1000.0
            audios = mm_data["audio"]
            if not isinstance(audios, (list, tuple)):
                audios = [audios]
            for waveform in audios:
                duration_s = len(waveform) / sr
                if duration_s > max_duration_s:
                    logger.warning(
                        "Audio duration exceeds max: %f > %f seconds",
                        duration_s,
                        max_duration_s,
                    )

        # vLLM's call_hf_processor (context.py) re-merges
        # mm_processor_kwargs from the model config on every call via:
        #   config_kwargs | incoming_kwargs   (right side wins)
        #
        # If we strip max_soft_tokens from incoming, the re-merge puts
        # back the config's global default (e.g. 280), ignoring any
        # per-prompt override. Instead, we keep it in the kwargs with
        # the validated per-prompt value so it wins during the merge.
        #
        # NOTE: This requires a corresponding type annotation on the
        # HF side (Gemma4ProcessorKwargs.images_kwargs) so that
        # _merge_kwargs routes max_soft_tokens into images_kwargs.
        patched_mm_kwargs = dict(mm_kwargs)
        if val is not None:
            patched_mm_kwargs["max_soft_tokens"] = val

        processed_outputs = super()._call_hf_processor(
            prompt,
            mm_data,
            patched_mm_kwargs,
            tok_kwargs,
        )

        # HF uses 'image_position_ids'; vLLM uses 'pixel_position_ids'.
        # Remap here to keep a single translation point.
        if "image_position_ids" in processed_outputs:
            processed_outputs["pixel_position_ids"] = processed_outputs.pop(
                "image_position_ids"
            )

        if "input_features" in processed_outputs:
            # Keep padded features for batched audio tower execution.
            processed_outputs["input_features_padded"] = processed_outputs[
                "input_features"
            ]
            # Unpad per-item so each item's cache entry is self-contained.
            unpadded_features = [
                f[mask]
                for f, mask in zip(
                    processed_outputs["input_features"],
                    processed_outputs["input_features_mask"],
                )
            ]
            processed_outputs["input_features"] = unpadded_features

        # Merge video outputs into the final result
        combined_outputs = dict(processed_outputs, **video_outputs)
        return BatchFeature(combined_outputs)

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Describe how each processor output maps onto multimodal items.

        Image and audio fields are batched per item. Video fields are
        only declared when ``video_frame_counts`` is present in
        ``hf_inputs``; the flat per-frame tensors are then split per
        video using those counts.
        """
        fields = dict(
            pixel_values=MultiModalFieldConfig.batched("image"),
            pixel_position_ids=MultiModalFieldConfig.batched("image"),
            input_features_padded=MultiModalFieldConfig.batched("audio"),
            input_features_mask=MultiModalFieldConfig.batched("audio"),
        )

        # Video fields: frames stored flat, split per video by
        # video_frame_counts.
        video_frame_counts = hf_inputs.get("video_frame_counts")
        if video_frame_counts is not None:
            vfc = video_frame_counts
            if not isinstance(vfc, torch.Tensor):
                vfc = torch.tensor(vfc)
            fields.update(
                pixel_values_videos=(
                    MultiModalFieldConfig.flat_from_sizes("video", vfc)
                ),
                pixel_position_ids_videos=(
                    MultiModalFieldConfig.flat_from_sizes("video", vfc)
                ),
                video_frame_counts=MultiModalFieldConfig.batched(
                    "video",
                ),
                video_num_soft_tokens=MultiModalFieldConfig.batched(
                    "video", keep_on_cpu=True
                ),
                video_timestamps=MultiModalFieldConfig.batched(
                    "video", keep_on_cpu=True
                ),
            )
        return fields

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, Any],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Return one ``PromptReplacement`` per modality present.

        Each replacement targets the modality's single placeholder token
        and expands it to the item-specific token sequence built by the
        processing info (``get_image_repl`` / ``get_video_repl`` /
        ``get_audio_repl``).
        """
        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        prompt_updates = []

        if "image" in mm_items:
            # Target image_token (<|image|>) — the single placeholder the
            # Gemma4 chat template inserts once per image in the prompt.
            # vLLM tokenizes the prompt without token expansion, so only
            # one image_token exists per image in the token stream.
            # The replacement expands it to the full image sequence
            # (boi + N×image_token + eoi).
            image_token = hf_processor.image_token

            # Resolve the budget exactly as _call_hf_processor does
            # (merged kwargs, with the images_kwargs fallback), so the
            # placeholder count matches what the HF processor produced.
            # The value is the same for every item, so compute it once.
            max_soft_tokens = self._resolve_max_soft_tokens(hf_processor_mm_kwargs)

            def get_replacement_image(item_idx: int):
                images = mm_items.get_items("image", ImageProcessorItems)
                image_size = images.get_image_size(item_idx)
                return self.info.get_image_repl(
                    image_width=image_size.width,
                    image_height=image_size.height,
                    processor=hf_processor,
                    max_soft_tokens=max_soft_tokens,
                )

            prompt_updates.append(
                PromptReplacement(
                    modality="image",
                    target=image_token,
                    replacement=get_replacement_image,
                )
            )

        if "video" in mm_items:
            video_token = hf_processor.video_token

            def get_replacement_video(item_idx: int):
                # Timestamps and per-frame token counts were recorded by
                # _call_hf_processor for each video.
                out_item = out_mm_kwargs["video"][item_idx]
                timestamps = out_item["video_timestamps"].data
                num_soft = out_item["video_num_soft_tokens"].data
                return self.info.get_video_repl(
                    timestamps=timestamps,
                    num_soft_tokens_per_frame=num_soft,
                    processor=hf_processor,
                )

            prompt_updates.append(
                PromptReplacement(
                    modality="video",
                    target=video_token,
                    replacement=get_replacement_video,
                )
            )

        if "audio" in mm_items:
            audio_token = hf_processor.audio_token

            def get_replacement_audio(item_idx: int):
                audios = mm_items.get_items("audio", AudioProcessorItems)
                audio_len = audios.get_audio_length(item_idx)
                return self.info.get_audio_repl(
                    audio_len=audio_len,
                    processor=hf_processor,
                )

            prompt_updates.append(
                PromptReplacement(
                    modality="audio",
                    target=audio_token,
                    replacement=get_replacement_audio,
                )
            )

        return prompt_updates

    # NOTE: Gemma3/Gemma3n override _apply_token_matches and
    # _find_mm_placeholders to merge adjacent newline tokens that arise
    # when full_image_sequence contains "\n\n" wrappers. Gemma4's
    # full_image_sequence has NO newlines (just BOI + 280×image_token +
    # EOI), so the base class implementations work correctly as-is.
# ---------------------------------------------------------------------------
# Multimodal embedder
# ---------------------------------------------------------------------------
class Gemma4MultimodalEmbedder(nn.Module):
    """Maps soft tokens from a vision/audio tower into LM embedding space.

    Pipeline:
        inputs_embeds → embedding_projection → embedding_post_projection_norm

    Gemma3n keeps separate hard/soft embedding paths with per-path
    normalization and a learned embedding table. Gemma4 instead uses two
    stages only: a bias-free linear projection followed by an RMSNorm
    without a learnable scale. This matches the checkpoint, which holds
    only ``embedding_projection.weight`` — no embedding table and no
    pre-projection norm weights.
    """

    def __init__(
        self,
        multimodal_config: Gemma4VisionConfig | Gemma4AudioConfig,
        text_config: Gemma4TextConfig,
    ):
        """Create the projection and the post-projection norm.

        Args:
            multimodal_config: Config of the tower whose outputs are
                projected; supplies ``rms_norm_eps`` and the input width.
            text_config: Language-model config; supplies ``hidden_size``,
                the output width.
        """
        super().__init__()
        self.eps = multimodal_config.rms_norm_eps
        self.text_hidden_size = text_config.hidden_size

        # Audio tower uses output_proj_dims (1536) rather than hidden_size
        # (1024); vision uses hidden_size (768) directly.
        tower_output_dim = getattr(multimodal_config, "output_proj_dims", None)
        if not tower_output_dim:
            tower_output_dim = multimodal_config.hidden_size

        self.embedding_projection = ReplicatedLinear(
            tower_output_dim,
            self.text_hidden_size,
            bias=False,
        )
        self.embedding_post_projection_norm = RMSNorm(
            self.text_hidden_size,
            eps=self.eps,
            has_weight=False,
        )

    def forward(self, inputs_embeds: torch.Tensor) -> torch.Tensor:
        """Project tower outputs into LM space and normalize them."""
        # The linear layer returns a pair; only the first element is used.
        projected, _ = self.embedding_projection(inputs_embeds)
        normalized = self.embedding_post_projection_norm(projected)
        return normalized
# ---------------------------------------------------------------------------
# Main model
# ---------------------------------------------------------------------------
@MULTIMODAL_REGISTRY.register_processor(
Gemma4MultiModalProcessor,
info=Gemma4ProcessingInfo,
dummy_inputs=Gemma4DummyInputsBuilder,
)
class Gemma4ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"gate_up_proj": [
"gate_proj",
"up_proj",
],
}
# Maps checkpoint prefixes to vLLM module paths.
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={
"model.embed_audio.": "embed_audio.",
"model.embed_vision.": "embed_vision.",
"model.language_model.": "language_model.model.",
"model.vision_tower.": "vision_tower.",
"model.audio_tower.": "audio_tower.",
"lm_head.": "language_model.lm_head.",
"model": "language_model.model",
}
)
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
multimodal_config = vllm_config.model_config.multimodal_config
self.config = config
self.quant_config = quant_config
self.multimodal_config = multimodal_config
# ---- Vision tower (shared by image and video) ----
with self._mark_tower_model(vllm_config, {"image", "video"}):
self.vision_tower = AutoModel.from_config(config=config.vision_config)
self.embed_vision = Gemma4MultimodalEmbedder(
config.vision_config, config.text_config
)
# ---- Audio tower (variants with audio_config) ----
if config.audio_config is not None:
with self._mark_tower_model(vllm_config, "audio"):
self.audio_tower = AutoModel.from_config(config=config.audio_config)
# AutoModel.from_config does NOT call post_init(),
# which is needed to initialize buffers that are absent
# from the checkpoint (e.g. inv_timescales for relative
# position embeddings, softcap, gradient_clipping).
self.audio_tower.post_init()
self.embed_audio = Gemma4MultimodalEmbedder(
config.audio_config, config.text_config
)
else:
self.audio_tower = None
self.embed_audio = None
# ---- Language model (vLLM optimised) ----
with self._mark_language_model(vllm_config):
self.language_model: Gemma4ForCausalLM = init_vllm_registered_model(
vllm_config=vllm_config,
hf_config=config.text_config,
prefix=maybe_prefix(prefix, "language_model"),
architectures=["Gemma4ForCausalLM"],
)
# Pre-allocate PLE buffer for CUDA graph compatibility.
# Some variants have hidden_size_per_layer_input=None (no PLE).
ple_dim = config.text_config.hidden_size_per_layer_input
if ple_dim is not None:
self.per_layer_embeddings = torch.zeros(
vllm_config.scheduler_config.max_num_batched_tokens,
config.text_config.num_hidden_layers,
ple_dim,
device=(self.language_model.model.embed_tokens.weight.device),
dtype=(self.language_model.model.embed_tokens.weight.dtype),
)
else:
self.per_layer_embeddings = None
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors
)
# --- MixtureOfExperts delegation to language_model ---
self.expert_weights = self.language_model.expert_weights
self.moe_layers = self.language_model.moe_layers
self.num_moe_layers = self.language_model.num_moe_layers
self.num_logical_experts = self.language_model.num_logical_experts
self.num_physical_experts = self.language_model.num_physical_experts
self.num_local_physical_experts = self.language_model.num_local_physical_experts
self.num_routed_experts = self.language_model.num_routed_experts
self.num_expert_groups = self.language_model.num_expert_groups
self.num_shared_experts = self.language_model.num_shared_experts
self.num_redundant_experts = self.language_model.num_redundant_experts
# ------------------------------------------------------------------ #
# Input parsing
# ------------------------------------------------------------------ #
def _parse_and_validate_image_input(
self, **kwargs: object
) -> Gemma4ImageInputs | None:
pixel_values = kwargs.pop("pixel_values", None)
pixel_position_ids = kwargs.pop("pixel_position_ids", None)
image_embeds = kwargs.pop("image_embeds", None)
assert image_embeds is None, "Gemma4 does not support image_embeds."
if pixel_values is None:
return None
return Gemma4ImagePixelInputs(
pixel_values=pixel_values,
pixel_position_ids=pixel_position_ids,
)
def _parse_and_validate_audio_input(
self, **kwargs: object
) -> Gemma4AudioInputs | None:
input_features_padded = kwargs.pop("input_features_padded", None)
if input_features_padded is None:
return None
input_features_mask = kwargs.pop("input_features_mask", None)
if input_features_mask is None:
return None
return Gemma4AudioInputs(
input_features_padded=input_features_padded,
input_features_mask=input_features_mask,
)
def _parse_and_validate_video_input(
self, **kwargs: object
) -> dict[str, torch.Tensor] | None:
pixel_values_videos = kwargs.pop("pixel_values_videos", None)
pixel_position_ids_videos = kwargs.pop("pixel_position_ids_videos", None)
video_frame_counts = kwargs.pop("video_frame_counts", None)
if pixel_values_videos is None:
return None
return {
"pixel_values_videos": pixel_values_videos,
"pixel_position_ids_videos": pixel_position_ids_videos,
"video_frame_counts": video_frame_counts,
}
def _parse_and_validate_multimodal_inputs(
self, **kwargs: object
) -> dict[str, Gemma4ImageInputs | Gemma4AudioInputs | Gemma4VideoInputs | None]:
mm_input_by_modality = {}
for input_key in list(kwargs):
if (
input_key in ("pixel_values", "image_embeds")
and "image" not in mm_input_by_modality
):
mm_input_by_modality["image"] = self._parse_and_validate_image_input(
**kwargs
)
if (
input_key == "pixel_values_videos"
and "video" not in mm_input_by_modality
):
mm_input_by_modality["video"] = self._parse_and_validate_video_input(
**kwargs
)
if (
input_key == "input_features_padded"
and "audio" not in mm_input_by_modality
):
mm_input_by_modality["audio"] = self._parse_and_validate_audio_input(
**kwargs
)
return mm_input_by_modality
# ------------------------------------------------------------------ #
# Image processing
# ------------------------------------------------------------------ #
def _process_image_input(
self,
image_input: Gemma4ImageInputs,
) -> list[torch.Tensor]:
pixel_values = image_input["pixel_values"]
pixel_position_ids = image_input["pixel_position_ids"]
# The HF image processor now outputs pre-patchified data:
# pixel_values: (num_images, max_patches, patch_pixels)
# pixel_position_ids: (num_images, max_patches, 2)
# We call the vision tower's forward() directly, which handles
# patch embedding, encoding, pooling, padding removal, and
# optional standardization internally.
vt = self.vision_tower
pooling_k2 = self.config.vision_config.pooling_kernel_size**2
# TODO: Move this per-image loop into the input processor to
# reduce dynamism at the model runner / engine core. This
# requires spatially padding all images to uniform (H_max,
# W_max) in _call_hf_processor() so they arrive as a single
# stacked tensor, tracking padded regions via image_sizes
# metadata, and validating numerical equivalence with the
# current per-image path.
#
# Process each image individually through the vision tower.
# The vision tower's forward() strips padding and returns a
# flat tensor of valid tokens. We process per-image to get
# variable-length outputs matching the dynamic token count
# from get_image_repl.
per_image_features = []
for i in range(pixel_values.shape[0]):
pv = pixel_values[i].unsqueeze(0) # (1, max_patches, patch_pixels)
pp = pixel_position_ids[i].unsqueeze(0) # (1, max_patches, 2)
# Derive the pooler's output_length from the total patch
# count (including padding). The vision tower encoder
# processes ALL patches — padding patches get zero hidden
# states but still occupy sequence positions. The pooler's
# _avg_pool_by_positions requires:
# input_seq_len / output_length == k²
# where k == pooling_kernel_size. The image processor
# allocates max_patches = max_soft_tokens * k² total slots,
# so output_length = max_patches / k² == max_soft_tokens.
# Without this, the pooler falls back to
# config.image_seq_length (e.g. 280), which fails when a
# different max_soft_tokens was used at preprocessing time.
max_patches = pv.shape[1]
output_length = max_patches // pooling_k2
vt_output = vt(pv, pp, output_length=output_length)
# last_hidden_state: (num_valid_tokens, hidden_size)
# — already flat with padding stripped by the vision tower
per_image_features.append(vt_output.last_hidden_state)
# Project each image's features into LM embedding space.
# Per-image loop is required because images have variable
# token counts after padding removal.
# Cast to match the projection layer's dtype (model may be
# bf16 while the vision tower outputs fp32).
target_dtype = self.embed_vision.embedding_projection.weight.dtype
return [
self.embed_vision(inputs_embeds=img.unsqueeze(0).to(target_dtype)).squeeze(
0
)
for img in per_image_features
]
# ------------------------------------------------------------------ #
# Video processing (frames through vision tower)
# ------------------------------------------------------------------ #
def _process_video_input(
self,
video_input: dict[str, torch.Tensor],
) -> list[torch.Tensor]:
"""Process video frames through the vision tower.
Reuses the image processing pipeline — Gemma4 has no separate
video tower; video frames are just images at lower resolution
(max_soft_tokens=70).
Returns one concatenated embedding tensor per video (not per
frame), because vLLM treats one video as one multimodal item.
The flat_from_sizes field config groups all frames of a video
together, so embed_multimodal must return one tensor per video.
"""
pixel_values = video_input["pixel_values_videos"]
pixel_position_ids = video_input["pixel_position_ids_videos"]
frame_counts = video_input["video_frame_counts"]
vt = self.vision_tower
pooling_k2 = self.config.vision_config.pooling_kernel_size**2
target_dtype = self.embed_vision.embedding_projection.weight.dtype
# Split flat tensors into per-video chunks
if isinstance(frame_counts, torch.Tensor):
fc_list = frame_counts.tolist()
else:
fc_list = list(frame_counts)
pv_per_video = torch.split(pixel_values, fc_list, dim=0)
pp_per_video = torch.split(pixel_position_ids, fc_list, dim=0)
per_video_embeddings = []
for pv_chunk, pp_chunk in zip(pv_per_video, pp_per_video):
frame_embs = []
for i in range(pv_chunk.shape[0]):
pv = pv_chunk[i].unsqueeze(0)
pp = pp_chunk[i].unsqueeze(0)
max_patches = pv.shape[1]
output_length = max_patches // pooling_k2
vt_output = vt(pv, pp, output_length=output_length)
frame_emb = self.embed_vision(
inputs_embeds=(
vt_output.last_hidden_state.unsqueeze(0).to(target_dtype)
)
).squeeze(0)
frame_embs.append(frame_emb)
# Concatenate all frames of this video into one tensor.
per_video_embeddings.append(torch.cat(frame_embs, dim=0))
return per_video_embeddings
# ------------------------------------------------------------------ #
# Audio processing
# ------------------------------------------------------------------ #
def _process_audio_input(
self,
audio_input: Gemma4AudioInputs,
) -> list[torch.Tensor]:
input_features = audio_input["input_features_padded"].squeeze(1)
input_features_mask = audio_input["input_features_mask"].squeeze(1)
# Run audio tower — mask uses standard HF convention
# (True=valid, False=padding).
audio_outputs = self.audio_tower(input_features, input_features_mask)
if isinstance(audio_outputs, tuple):
audio_encodings, audio_mask = audio_outputs
else:
audio_encodings = audio_outputs.last_hidden_state
audio_mask = audio_outputs.attention_mask
# Project into LM embedding space.
audio_features = self.embed_audio(inputs_embeds=audio_encodings)
# Strip padding per-batch element: only keep real (non-padding)
# tokens. audio_mask is True for valid positions (HF convention).
per_audio = []
for enc, mask in zip(audio_features, audio_mask, strict=True):
per_audio.append(enc[mask]) # [num_real, hidden_size]
return per_audio
# ------------------------------------------------------------------ #
# MultiModalEmbeddings interface
# ------------------------------------------------------------------ #
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
mm_input_by_modality = self._parse_and_validate_multimodal_inputs(**kwargs)
multimodal_embeddings: list[torch.Tensor] = []
for modality, multimodal_input in mm_input_by_modality.items():
if multimodal_input is None:
continue
if modality == "image":
multimodal_embeddings.extend(
self._process_image_input(multimodal_input)
)
elif modality == "video":
multimodal_embeddings.extend(
self._process_video_input(multimodal_input)
)
elif modality == "audio":
multimodal_embeddings.extend(
self._process_audio_input(multimodal_input)
)
return multimodal_embeddings
def embed_input_ids(
self,
input_ids: torch.Tensor,
multimodal_embeddings: MultiModalEmbeddings | None = None,
*,
is_multimodal: torch.Tensor | None = None,
) -> torch.Tensor:
# Cache per-layer embeddings (PLE) for the language model's
# forward pass. During profiling embed_input_ids is not called,
# so the pre-allocated zeros are used instead.
if self.per_layer_embeddings is not None:
# Mask multimodal tokens (image/audio) to 0 for PLE
# computation (using token_type_ids == 0 as text_mask).
# Replicate this: map image token positions to token 0.
if is_multimodal is not None:
is_multimodal = is_multimodal.to(input_ids.device)
ple_input_ids = torch.where(
is_multimodal, torch.zeros_like(input_ids), input_ids
)
else:
ple_input_ids = input_ids
per_layer_inputs = self.language_model.model.get_per_layer_inputs(
ple_input_ids
)
if per_layer_inputs is not None:
per_layer_inputs = per_layer_inputs.reshape(
-1,
self.config.text_config.num_hidden_layers,
self.config.text_config.hidden_size_per_layer_input,
)
self.per_layer_embeddings[: per_layer_inputs.shape[0]].copy_(
per_layer_inputs
)
if multimodal_embeddings is None or is_multimodal is None:
return super().embed_input_ids(input_ids)
return super().embed_input_ids(
input_ids,
multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_multimodal,
)
# ------------------------------------------------------------------ #
# Forward
# ------------------------------------------------------------------ #
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
**kwargs: object,
) -> IntermediateTensors:
if intermediate_tensors is not None:
inputs_embeds = None
# Select the pre-cached PLEs for this batch (None when PLE
# is disabled for variants without PLE).
per_layer_inputs = (
self.per_layer_embeddings[: inputs_embeds.shape[0]]
if self.per_layer_embeddings is not None and inputs_embeds is not None
else None
)
hidden_states = self.language_model.model(
input_ids,
positions,
per_layer_inputs=per_layer_inputs,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds,
**kwargs,
)
return hidden_states
def compute_logits(
    self,
    hidden_states: torch.Tensor,
) -> torch.Tensor | None:
    """Delegate logit computation to the wrapped language model."""
    language_model = self.language_model
    return language_model.compute_logits(hidden_states)
# ------------------------------------------------------------------ #
# Weight loading
# ------------------------------------------------------------------ #
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
    """Load checkpoint weights through ``AutoWeightsLoader``.

    Weight names under a fixed set of prefixes are declared as ignorable
    "unexpected" entries; names are translated with
    ``self.hf_to_vllm_mapper``.

    Returns:
        The value returned by ``AutoWeightsLoader.load_weights``.
    """
    # Some checkpoints have vestigial embed_vision.embedding and
    # embed_audio.embedding weights from the Gemma3n architecture that are
    # not used by Gemma4's MultimodalEmbedder (which only has
    # embedding_projection + embedding_post_projection_norm).
    skipped_prefixes = ["embed_vision.embedding.", "embed_audio.embedding."]

    # Models without audio tower should skip audio weights entirely.
    if self.audio_tower is None:
        skipped_prefixes += ["audio_tower.", "embed_audio."]

    weights_loader = AutoWeightsLoader(
        self,
        ignore_unexpected_prefixes=skipped_prefixes,
    )
    return weights_loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
# ------------------------------------------------------------------ #
# LoRA / multimodal mapping
# ------------------------------------------------------------------ #
def get_mm_mapping(self) -> MultiModelKeys:
    """Get the module prefix mapping for multimodal models."""
    connector_prefixes = ["embed_vision", "embed_audio"]
    tower_prefixes = ["vision_tower", "audio_tower"]
    return MultiModelKeys.from_string_field(
        language_model="language_model",
        connector=connector_prefixes,
        tower_model=tower_prefixes,
    )
@classmethod
def get_placeholder_str(cls, modality: str, i: int) -> str | None:
    """Return the prompt placeholder string for ``modality``.

    Args:
        modality: One of ``"image"``, ``"audio"`` or ``"video"``.
        i: Item index; not used by this implementation.

    Raises:
        ValueError: If ``modality`` is not one of the supported names.
    """
    known_placeholders = (
        ("image", "<image_soft_token>"),
        ("audio", "<audio_soft_token>"),
        ("video", "<|video|>"),
    )
    for known_modality, placeholder in known_placeholders:
        if modality == known_modality:
            return placeholder
    raise ValueError(f"Unsupported modality: {modality}")
...@@ -111,6 +111,7 @@ _TEXT_GENERATION_MODELS = { ...@@ -111,6 +111,7 @@ _TEXT_GENERATION_MODELS = {
"Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"), "Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"),
"Gemma3ForCausalLM": ("gemma3", "Gemma3ForCausalLM"), "Gemma3ForCausalLM": ("gemma3", "Gemma3ForCausalLM"),
"Gemma3nForCausalLM": ("gemma3n", "Gemma3nForCausalLM"), "Gemma3nForCausalLM": ("gemma3n", "Gemma3nForCausalLM"),
"Gemma4ForCausalLM": ("gemma4", "Gemma4ForCausalLM"),
"Qwen3NextForCausalLM": ("qwen3_next", "Qwen3NextForCausalLM"), "Qwen3NextForCausalLM": ("qwen3_next", "Qwen3NextForCausalLM"),
"GlmForCausalLM": ("glm", "GlmForCausalLM"), "GlmForCausalLM": ("glm", "GlmForCausalLM"),
"Glm4ForCausalLM": ("glm4", "Glm4ForCausalLM"), "Glm4ForCausalLM": ("glm4", "Glm4ForCausalLM"),
...@@ -377,6 +378,7 @@ _MULTIMODAL_MODELS = { ...@@ -377,6 +378,7 @@ _MULTIMODAL_MODELS = {
"gemma3n_mm", "gemma3n_mm",
"Gemma3nForConditionalGeneration", "Gemma3nForConditionalGeneration",
), ),
"Gemma4ForConditionalGeneration": ("gemma4_mm", "Gemma4ForConditionalGeneration"),
"GlmAsrForConditionalGeneration": ("glmasr", "GlmAsrForConditionalGeneration"), "GlmAsrForConditionalGeneration": ("glmasr", "GlmAsrForConditionalGeneration"),
"GLM4VForCausalLM": ("glm4v", "GLM4VForCausalLM"), "GLM4VForCausalLM": ("glm4v", "GLM4VForCausalLM"),
"Glm4vForConditionalGeneration": ("glm4_1v", "Glm4vForConditionalGeneration"), "Glm4vForConditionalGeneration": ("glm4_1v", "Glm4vForConditionalGeneration"),
......
...@@ -233,8 +233,15 @@ class AutoWeightsLoader: ...@@ -233,8 +233,15 @@ class AutoWeightsLoader:
): ):
""" """
Add tensor names that are not in the model params that may be in the Add tensor names that are not in the model params that may be in the
safetensors, e.g., batch normalization stats. safetensors, e.g., batch normalization stats and registered buffers.
""" """
# Add persistent registered buffers.
# Non-persistent buffers are excluded, matching PyTorch state_dict().
non_persistent = getattr(module, "_non_persistent_buffers_set", set())
for buf_name, buf in module.named_buffers(recurse=False):
if buf_name not in child_params and buf_name not in non_persistent:
child_params[buf_name] = buf
if isinstance( if isinstance(
module, module,
( (
......
...@@ -32,6 +32,10 @@ _REASONING_PARSERS_TO_REGISTER = { ...@@ -32,6 +32,10 @@ _REASONING_PARSERS_TO_REGISTER = {
"ernie45_reasoning_parser", "ernie45_reasoning_parser",
"Ernie45ReasoningParser", "Ernie45ReasoningParser",
), ),
"gemma4": (
"gemma4_reasoning_parser",
"Gemma4ReasoningParser",
),
"glm45": ( "glm45": (
"deepseek_v3_reasoning_parser", "deepseek_v3_reasoning_parser",
"DeepSeekV3ReasoningWithThinkingParser", "DeepSeekV3ReasoningWithThinkingParser",
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Sequence
from typing import TYPE_CHECKING
from vllm.entrypoints.openai.engine.protocol import DeltaMessage
from vllm.reasoning.basic_parsers import BaseThinkingReasoningParser
from vllm.tokenizers import TokenizerLike
if TYPE_CHECKING:
from vllm.entrypoints.openai.chat_completion.protocol import (
ChatCompletionRequest,
)
from vllm.entrypoints.openai.responses.protocol import ResponsesRequest
# Role label that Gemma4 emits at the start of the thinking channel.
# The model generates: <|channel>thought\n...reasoning...<channel|>
# This prefix must be stripped to expose only the actual reasoning content.
_THOUGHT_PREFIX = "thought\n"
class Gemma4ReasoningParser(BaseThinkingReasoningParser):
    """
    Reasoning parser for Google Gemma4 thinking models.

    Gemma4 uses <|channel>...<channel|> tokens to delimit reasoning/thinking
    content within its output. Thinking mode is activated by passing
    ``enable_thinking=True`` in the chat template kwargs, which injects a
    system turn containing <|think|> (token 98) to trigger chain-of-thought
    reasoning.

    Output pattern when thinking is enabled::

        <|channel>thought
        ...chain of thought reasoning...<channel|>
        Final answer text here.

    The ``thought\\n`` role label inside the channel delimiters is a
    structural artefact (analogous to ``user\\n`` in ``<|turn>user\\n...``).
    This parser strips it so that downstream consumers see only the
    actual reasoning text, consistent with the offline parser
    (``vllm.reasoning.gemma4_utils._strip_thought_label``).
    """

    def __init__(self, tokenizer: TokenizerLike, *args, **kwargs):
        """Initialise streaming state and look up Gemma4 boundary token ids.

        Raises:
            KeyError: If ``<|turn>``, ``<|tool_call>`` or
                ``<|tool_response>`` is not a key of ``self.vocab``.
        """
        super().__init__(tokenizer, *args, **kwargs)
        # Instance state for streaming prefix stripping.
        # ``_reasoning_text`` buffers only the reasoning text returned by the
        # base parser (never ``current_text``, which may contain
        # pre-reasoning content) until the ``thought\n`` label has been
        # confirmed or ruled out.
        self._reasoning_text: str = ""
        self._prefix_stripped: bool = False
        self.new_turn_token_id = self.vocab["<|turn>"]
        self.tool_call_token_id = self.vocab["<|tool_call>"]
        self.tool_response_token_id = self.vocab["<|tool_response>"]

    def adjust_request(
        self, request: "ChatCompletionRequest | ResponsesRequest"
    ) -> "ChatCompletionRequest | ResponsesRequest":
        """Disable special-token stripping to preserve boundary tokens."""
        request.skip_special_tokens = False
        return request

    @property
    def start_token(self) -> str:
        """The token that starts reasoning content."""
        return "<|channel>"

    @property
    def end_token(self) -> str:
        """The token that ends reasoning content."""
        return "<channel|>"

    def is_reasoning_end(self, input_ids: Sequence[int]) -> bool:
        """Decide from the most recent boundary token whether reasoning ended.

        ``input_ids`` is scanned from the end; the first boundary token found
        determines the answer. Returns ``False`` when no boundary token is
        present.
        """
        start_token_id = self.start_token_id
        end_token_id = self.end_token_id
        new_turn_token_id = self.new_turn_token_id
        tool_call_token_id = self.tool_call_token_id
        tool_response_token_id = self.tool_response_token_id

        # Search from the end of input_ids to find the last match.
        for i in range(len(input_ids) - 1, -1, -1):
            if input_ids[i] == start_token_id:
                return False
            if input_ids[i] == tool_call_token_id:
                # We're generating a tool call, so reasoning must be ended.
                return True
            if input_ids[i] in (new_turn_token_id, tool_response_token_id):
                # We found a new turn or tool response token so don't consider
                # reasoning ended yet, since the model starts new reasoning
                # after these tokens.
                return False
            if input_ids[i] == end_token_id:
                return True
        return False

    # ------------------------------------------------------------------
    # Non-streaming path
    # ------------------------------------------------------------------
    def extract_reasoning(
        self,
        model_output: str,
        request: "ChatCompletionRequest | ResponsesRequest",
    ) -> tuple[str | None, str | None]:
        """Extract reasoning, stripping the ``thought\\n`` role label.

        Returns:
            ``(reasoning, content)``. When neither delimiter occurs in
            ``model_output`` the whole output is returned as content with
            ``None`` reasoning.
        """
        if self.start_token not in model_output and self.end_token not in model_output:
            # Default to content history if no tags are present
            # (or if they were stripped)
            return None, model_output

        reasoning, content = super().extract_reasoning(model_output, request)
        if reasoning is not None:
            reasoning = _strip_thought_label(reasoning)
        return reasoning, content

    # ------------------------------------------------------------------
    # Streaming path
    # ------------------------------------------------------------------
    def extract_reasoning_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
    ) -> DeltaMessage | None:
        """Extract streaming reasoning, stripping ``thought\\n`` from the
        first reasoning delta(s).

        The ``thought\\n`` prefix may arrive as a single delta or split
        across multiple deltas (e.g. ``"thought"`` then ``"\\n"``). Early
        reasoning text is buffered in ``_reasoning_text`` until it either
        contains the full prefix (the prefix is removed and the remainder
        emitted) or diverges from it (the whole buffer is emitted).

        If the reasoning block closes while the buffer is still only a
        partial match, the buffer is flushed together with any answer
        content carried by the same delta instead of being dropped.

        Returns:
            The base parser's message, possibly with an adjusted
            ``reasoning`` field, or ``None`` while the prefix decision is
            still pending.
        """
        result = super().extract_reasoning_streaming(
            previous_text,
            current_text,
            delta_text,
            previous_token_ids,
            current_token_ids,
            delta_token_ids,
        )
        if result is None or result.reasoning is None:
            return result

        # Once the prefix has been handled, all subsequent reasoning deltas
        # pass through unchanged; nothing more needs to be buffered.
        if self._prefix_stripped:
            return result

        # Accumulate ONLY the reasoning text from base parser results.
        self._reasoning_text += result.reasoning
        buffered = self._reasoning_text

        # Case 1: the buffer starts with the full label. Remove whatever part
        # of the label lies inside this delta and emit the rest (possibly "").
        if buffered.startswith(_THOUGHT_PREFIX):
            self._prefix_stripped = True
            seen_before = len(buffered) - len(result.reasoning)
            label_chars_in_delta = len(_THOUGHT_PREFIX) - seen_before
            if label_chars_in_delta > 0:
                result.reasoning = result.reasoning[label_chars_in_delta:]
            return result

        # Case 2: the buffer is a strict prefix of the label (e.g. "thou").
        # Keep buffering, unless this delta also closes the reasoning block.
        # In that case no later reasoning delta will arrive to resolve the
        # ambiguity, so suppressing here would lose the buffered text and any
        # answer content in ``result``.
        if _THOUGHT_PREFIX.startswith(buffered):
            reasoning_closed = (
                result.content is not None or self.end_token_id in delta_token_ids
            )
            if not reasoning_closed:
                return None

        # Case 3 (and flushed Case 2): the text is not the label. Re-emit
        # everything that was buffered so no reasoning is lost.
        self._prefix_stripped = True
        result.reasoning = buffered
        return result
def _strip_thought_label(text: str) -> str:
    """Return ``text`` without a leading ``thought\\n`` role label.

    Text that does not start with the label is returned unchanged. Mirrors
    ``vllm.reasoning.gemma4_utils._strip_thought_label`` from the offline
    parser.
    """
    return text.removeprefix(_THOUGHT_PREFIX)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved.
"""Gemma4 thinking/reasoning output parsing utilities for offline inference.
Standalone functions that parse decoded model text to extract structured
thinking content from Gemma4 models. These are pure-Python utilities with
zero heavy dependencies — they work on raw decoded strings from any
inference backend (vLLM, HuggingFace, TGI, etc.).
For the OpenAI-compatible API reasoning parser (streaming +
non-streaming), see ``vllm.reasoning.gemma4_reasoning_parser``.
For tool call parsing, see ``vllm.tool_parsers.gemma4_utils``.
Usage with vLLM offline inference::
from vllm import LLM, SamplingParams
from vllm.reasoning.gemma4_utils import parse_thinking_output
llm = LLM(model="google/gemma-4-it")
outputs = llm.generate(prompt, SamplingParams(...))
text = tokenizer.decode(outputs[0].outputs[0].token_ids, skip_special_tokens=False)
# Extract thinking / answer (works with or without enable_thinking)
result = parse_thinking_output(text)
print(result["thinking"]) # chain-of-thought or None
print(result["answer"]) # final answer
Ported from ``transformers.models.gemma4.utils_gemma4`` so that vLLM users
do not need a transformers dependency for output parsing.
"""
# ---- Thinking Mode Utility ----
# Thinking delimiter tokens as they appear in decoded text.
# Gemma4 uses <|channel> (start) and <channel|> (end) as thinking delimiters.
_THINKING_START_TAG = "<|channel>"
_THINKING_END_TAG = "<channel|>"

# Sentinel tokens that may appear in decoded output.
_TURN_END_TAG = "<turn|>"


def parse_thinking_output(text: str) -> dict[str, str | None]:
    """Parse decoded Gemma4 model output.

    Use this on **all** Gemma4 output regardless of whether thinking mode
    was enabled. It handles three cases:

    1. **Thinking enabled, tags present** — splits on ``<|channel>``/
       ``<channel|>`` to separate chain-of-thought from the answer and
       strips the ``thought\\n`` role label. An empty thinking block
       (``<|channel>thought\\n<channel|>``) yields ``""``.
    2. **Thinking disabled, spurious label** — strips the bare
       ``thought\\n`` prefix that some Gemma4 models emit even
       without thinking mode.
    3. **Clean output** — returns the text unchanged.

    The answer text is always cleaned of trailing sentinel tokens
    (``<turn|>``, ``<eos>``).

    Args:
        text: Decoded model output text (from ``tokenizer.decode(...)``).

    Returns:
        A dict with keys:

        - ``"thinking"``: The chain-of-thought text, or ``None`` if no
          ``<channel|>`` end delimiter was found.
        - ``"answer"``: The final answer text.

    Example::

        >>> from vllm.reasoning.gemma4_utils import parse_thinking_output
        >>> output_text = tokenizer.decode(outputs[0], skip_special_tokens=False)
        >>> result = parse_thinking_output(output_text)
        >>> print(result["thinking"])  # chain-of-thought reasoning or None
        >>> print(result["answer"])  # final answer
    """
    if _THINKING_END_TAG in text:
        thinking_block, _, remainder = text.partition(_THINKING_END_TAG)
        answer = _clean_answer(remainder)

        # Extract thinking content: strip the start tag if present
        if _THINKING_START_TAG in thinking_block:
            thinking = thinking_block.split(_THINKING_START_TAG, 1)[1]
        else:
            thinking = thinking_block

        # Strip the "thought\n" channel role label the model emits inside
        # <|channel>thought\n...<channel|> (analogous to "user\n" in
        # <|turn>user\n...<turn|>). Only leading whitespace is removed before
        # the label check: a full strip() would eat the label's own newline
        # when the thinking block is empty and leave the word "thought".
        thinking = _strip_thought_label(thinking.lstrip()).strip()
        return {"thinking": thinking, "answer": answer}

    # No thinking delimiters found.
    # Strip spurious "thought\n" role label that some Gemma4 models sometimes
    # emit even without thinking mode enabled, then clean trailing tokens.
    answer = _clean_answer(_strip_thought_label(text))
    return {"thinking": None, "answer": answer}


def _strip_thought_label(text: str) -> str:
    """Strip the spurious ``thought\\n`` label from the start of text.

    Only strips when ``thought`` appears as the very first word followed by
    a newline — preserving the word ``thought`` in any other context.
    """
    if text.startswith("thought\n"):
        return text[len("thought\n") :]
    return text


def _clean_answer(text: str) -> str:
    """Clean trailing sentinel tokens from the answer text.

    Strips surrounding whitespace, then repeatedly removes a trailing
    ``<turn|>`` or ``<eos>`` (and the whitespace before it) until none is
    left, so the result does not depend on the order in which the sentinels
    were emitted.
    """
    text = text.strip()
    sentinels = (_TURN_END_TAG, "<eos>")
    removed_one = True
    while removed_one:
        removed_one = False
        for sentinel in sentinels:
            if text.endswith(sentinel):
                text = text[: -len(sentinel)].rstrip()
                removed_one = True
    return text
...@@ -154,6 +154,10 @@ _TOOL_PARSERS_TO_REGISTER = { ...@@ -154,6 +154,10 @@ _TOOL_PARSERS_TO_REGISTER = {
"functiongemma_tool_parser", "functiongemma_tool_parser",
"FunctionGemmaToolParser", "FunctionGemmaToolParser",
), ),
"gemma4": (
"gemma4_tool_parser",
"Gemma4ToolParser",
),
} }
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Tool call parser for Google Gemma4 models.
Gemma4 uses a custom serialization format (not JSON) for tool calls::
<|tool_call>call:func_name{key:<|"|>value<|"|>,num:42}<tool_call|>
Strings are delimited by ``<|"|>`` (token 52), keys are unquoted, and
multiple tool calls are concatenated without separators.
Used when ``--enable-auto-tool-choice --tool-call-parser gemma4`` are set.
For offline inference tool call parsing (direct ``tokenizer.decode()`` output),
see ``vllm.tool_parsers.gemma4_utils.parse_tool_calls``.
"""
import json
from collections.abc import Sequence
import regex as re
from vllm.entrypoints.chat_utils import make_tool_call_id
from vllm.entrypoints.openai.chat_completion.protocol import (
ChatCompletionRequest,
)
from vllm.entrypoints.openai.engine.protocol import (
DeltaFunctionCall,
DeltaMessage,
DeltaToolCall,
ExtractedToolCallInformation,
FunctionCall,
ToolCall,
)
from vllm.entrypoints.openai.responses.protocol import (
ResponsesRequest,
)
from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers.abstract_tool_parser import Tool, ToolParser
from vllm.tool_parsers.utils import find_common_prefix
logger = init_logger(__name__)
# Gemma4 special tokens for tool calls
TOOL_CALL_START = "<|tool_call>"
TOOL_CALL_END = "<tool_call|>"
STRING_DELIM = '<|"|>'


# ---------------------------------------------------------------------------
# Gemma4 argument parser (used by both streaming and non-streaming paths)
# ---------------------------------------------------------------------------
def _parse_gemma4_value(value_str: str) -> object:
    """Parse a single Gemma4 bare value (after ``key:``) into a Python object.

    Returns ``True``/``False`` for ``true``/``false``, a ``float`` when the
    text contains ``.`` and parses as one, an ``int`` when it parses as one,
    and otherwise the whitespace-stripped string itself (including ``""``).
    """
    value_str = value_str.strip()
    if not value_str:
        return value_str

    # Boolean
    if value_str == "true":
        return True
    if value_str == "false":
        return False

    # Number (int or float)
    try:
        if "." in value_str:
            return float(value_str)
        return int(value_str)
    except ValueError:
        pass

    # Bare string (no <|"|> delimiters — shouldn't happen but be safe)
    return value_str


def _scan_balanced(
    text: str, start: int, open_ch: str, close_ch: str
) -> tuple[int, bool]:
    """Find the end of a bracketed region whose opener sits just before ``start``.

    ``<|"|>``-delimited strings are skipped so brackets inside them are not
    counted; an unterminated string consumes the rest of ``text``.

    Returns:
        ``(end, closed)``. When ``closed`` is True, ``text[end - 1]`` is the
        matching ``close_ch``; otherwise the input ran out first and ``end``
        equals ``len(text)``.
    """
    n = len(text)
    delim_len = len(STRING_DELIM)
    depth = 1
    i = start
    while i < n and depth > 0:
        if text.startswith(STRING_DELIM, i):
            closing = text.find(STRING_DELIM, i + delim_len)
            i = n if closing == -1 else closing + delim_len
            continue
        ch = text[i]
        if ch == open_ch:
            depth += 1
        elif ch == close_ch:
            depth -= 1
        i += 1
    return i, depth == 0


def _parse_gemma4_args(args_str: str, *, partial: bool = False) -> dict:
    """Parse Gemma4's custom key:value format into a Python dict.

    Format examples::

        location:<|"|>Tokyo<|"|>
        location:<|"|>San Francisco<|"|>,unit:<|"|>celsius<|"|>
        count:42,flag:true
        nested:{inner_key:<|"|>val<|"|>}
        items:[<|"|>a<|"|>,<|"|>b<|"|>]

    Args:
        args_str: The raw Gemma4 argument string.
        partial: When True (streaming), bare values at end of string are
            omitted because they may be incomplete and type-unstable
            (e.g. partial boolean parsed as bare string).

    Returns a dict ready for ``json.dumps()``.
    """
    if not args_str or not args_str.strip():
        return {}

    result: dict = {}
    i = 0
    n = len(args_str)
    delim_len = len(STRING_DELIM)

    while i < n:
        # Skip whitespace and commas
        while i < n and args_str[i] in (" ", ",", "\n", "\t"):
            i += 1
        if i >= n:
            break

        # Parse key (unquoted, ends at ':'); a trailing key without ':' is
        # dropped.
        colon = args_str.find(":", i)
        if colon == -1:
            break
        key = args_str[i:colon].strip()
        i = colon + 1

        # Skip whitespace after ':'
        while i < n and args_str[i] in (" ", "\n", "\t"):
            i += 1
        if i >= n:
            # Key with no value yet: empty string, unless still streaming.
            if not partial:
                result[key] = ""
            break

        # String value: <|"|>...<|"|>
        if args_str.startswith(STRING_DELIM, i):
            i += delim_len
            end_pos = args_str.find(STRING_DELIM, i)
            if end_pos == -1:
                # Unterminated string — take rest
                result[key] = args_str[i:]
                break
            result[key] = args_str[i:end_pos]
            i = end_pos + delim_len

        # Nested object: {...}
        elif args_str[i] == "{":
            body_start = i + 1
            i, closed = _scan_balanced(args_str, body_start, "{", "}")
            if closed:
                result[key] = _parse_gemma4_args(args_str[body_start : i - 1])
            else:
                # Incomplete nested object — keep the last char and recurse
                # as partial.
                result[key] = _parse_gemma4_args(args_str[body_start:i], partial=True)

        # Array: [...]
        elif args_str[i] == "[":
            body_start = i + 1
            i, closed = _scan_balanced(args_str, body_start, "[", "]")
            if closed:
                result[key] = _parse_gemma4_array(args_str[body_start : i - 1])
            else:
                result[key] = _parse_gemma4_array(args_str[body_start:i], partial=True)

        # Bare value (number, boolean, etc.)
        else:
            val_start = i
            while i < n and args_str[i] not in (",", "}", "]"):
                i += 1
            if partial and i >= n:
                # Value may be incomplete (e.g. partial boolean) —
                # withhold to avoid type instability during streaming.
                break
            result[key] = _parse_gemma4_value(args_str[val_start:i])

    return result


def _parse_gemma4_array(arr_str: str, *, partial: bool = False) -> list:
    """Parse a Gemma4 array content string into a Python list.

    Args:
        arr_str: The text between an array's ``[`` and ``]``.
        partial: When True (streaming), a bare value that runs to the end of
            the string is withheld because it may be incomplete.

    Returns:
        The parsed elements: strings, nested dicts/lists, or bare values
        converted by ``_parse_gemma4_value``.
    """
    items: list = []
    i = 0
    n = len(arr_str)
    delim_len = len(STRING_DELIM)

    while i < n:
        while i < n and arr_str[i] in (" ", ",", "\n", "\t"):
            i += 1
        if i >= n:
            break

        # String element
        if arr_str.startswith(STRING_DELIM, i):
            i += delim_len
            end_pos = arr_str.find(STRING_DELIM, i)
            if end_pos == -1:
                items.append(arr_str[i:])
                break
            items.append(arr_str[i:end_pos])
            i = end_pos + delim_len

        # Nested object
        elif arr_str[i] == "{":
            body_start = i + 1
            i, closed = _scan_balanced(arr_str, body_start, "{", "}")
            if closed:
                items.append(_parse_gemma4_args(arr_str[body_start : i - 1]))
            else:
                items.append(_parse_gemma4_args(arr_str[body_start:i], partial=True))

        # Nested array (string-aware, like every other bracket scan, so a
        # ']' inside a string element does not end the array early)
        elif arr_str[i] == "[":
            body_start = i + 1
            i, closed = _scan_balanced(arr_str, body_start, "[", "]")
            if closed:
                items.append(_parse_gemma4_array(arr_str[body_start : i - 1]))
            else:
                items.append(_parse_gemma4_array(arr_str[body_start:i], partial=True))

        # Bare value
        else:
            val_start = i
            while i < n and arr_str[i] not in (",", "]"):
                i += 1
            if i == val_start:
                # Stray ']' from malformed input: step over it so the loop
                # always makes progress instead of spinning on this index.
                i += 1
                continue
            if partial and i >= n:
                break
            items.append(_parse_gemma4_value(arr_str[val_start:i]))

    return items
# ---------------------------------------------------------------------------
# Parser
# ---------------------------------------------------------------------------
class Gemma4ToolParser(ToolParser):
"""
Tool call parser for Google Gemma4 models.
Handles the Gemma4 function call format::
<|tool_call>call:func_name{key:<|"|>value<|"|>}<tool_call|>
Used when ``--enable-auto-tool-choice --tool-call-parser gemma4``
are set.
Streaming strategy: **accumulate-then-parse-then-diff**
Instead of trying to convert Gemma4's custom format to JSON
token-by-token (which fails because Gemma4 uses bare keys, custom
delimiters, and structural braces that differ from JSON), this parser:
1. Accumulates the raw Gemma4 argument string during streaming
2. Parses it with ``_parse_gemma4_args()`` into a Python dict
3. Converts to JSON with ``json.dumps()``
4. Diffs against the previously-streamed JSON string
5. Emits only the new JSON fragment as the delta
This follows the same pattern used by FunctionGemma, Hermes, and Llama
tool parsers.
"""
def __init__(self, tokenizer: TokenizerLike, tools: list[Tool] | None = None):
    """Set up token lookups, the non-streaming regex and streaming state.

    Raises:
        ValueError: If ``self.model_tokenizer`` is falsy after the base
            initialiser has run.
        RuntimeError: If the tool call start token is not in the vocabulary.
    """
    super().__init__(tokenizer, tools)

    if not self.model_tokenizer:
        raise ValueError(
            "The model tokenizer must be passed to the ToolParser "
            "constructor during construction."
        )

    # Token strings
    self.tool_call_start_token = TOOL_CALL_START
    self.tool_call_end_token = TOOL_CALL_END

    # Token IDs (``None`` when the vocabulary has no such entry)
    start_token_id = self.vocab.get(TOOL_CALL_START)
    end_token_id = self.vocab.get(TOOL_CALL_END)
    self.tool_call_start_token_id = start_token_id
    self.tool_call_end_token_id = end_token_id

    if start_token_id is None:
        raise RuntimeError(
            "Gemma4 ToolParser could not locate the tool call start "
            f"token '{TOOL_CALL_START}' in the tokenizer!"
        )

    # Regex for non-streaming: extract complete tool calls.
    # Supports function names with letters, digits, underscores,
    # hyphens, and dots (e.g. "get-weather", "module.func").
    complete_call_pattern = r"<\|tool_call>call:([\w\-\.]+)\{(.*?)\}<tool_call\|>"
    self.tool_call_regex = re.compile(complete_call_pattern, re.DOTALL)

    # Streaming state — reset per-request via _reset_streaming_state()
    self._reset_streaming_state()

    # Delta buffer for handling multi-token special sequences
    self.buffered_delta_text = ""
def _reset_streaming_state(self) -> None:
    """Reset the per-request tool-call streaming bookkeeping.

    Clears the tool index, the "name already sent" flag and the two
    per-tool history lists. ``buffered_delta_text`` is not touched here.
    """
    self.prev_tool_call_arr: list[dict] = []
    self.streamed_args_for_tool: list[str] = []
    self.current_tool_name_sent = False
    self.current_tool_id = -1
def adjust_request(
    self, request: ChatCompletionRequest | ResponsesRequest
) -> ChatCompletionRequest | ResponsesRequest:
    """Apply the base adjustments, then keep special tokens when tools are used.

    ``skip_special_tokens`` is forced to ``False`` only for a
    ``ChatCompletionRequest`` that carries tools and whose ``tool_choice``
    is not ``"none"``.
    """
    request = super().adjust_request(request)
    tools_requested = (
        isinstance(request, ChatCompletionRequest)
        and request.tools
        and request.tool_choice != "none"
    )
    if tools_requested:
        # Don't skip special tokens — <|tool_call> etc. are needed
        request.skip_special_tokens = False
    return request
# ------------------------------------------------------------------
# Delta buffering for multi-token special sequences
# ------------------------------------------------------------------
def _buffer_delta_text(self, delta_text: str) -> str:
    """Hold back a trailing fragment that could be the start of a tool tag.

    The previously held-back text is prepended to ``delta_text``. If the
    combination ends with a complete ``<|tool_call>`` or ``<tool_call|>``
    tag, or with nothing resembling the beginning of one, all of it is
    returned and the buffer is emptied. Otherwise the trailing partial tag
    is kept in ``self.buffered_delta_text`` and only the text before it is
    returned, so a fragment such as ``<|tool`` is not emitted as content.
    """
    pending = self.buffered_delta_text + delta_text

    # A complete special token at the end: nothing to hold back.
    if pending.endswith((TOOL_CALL_START, TOOL_CALL_END)):
        self.buffered_delta_text = ""
        return pending

    # A proper prefix of a special token at the end: hold it back.
    for tag in (TOOL_CALL_START, TOOL_CALL_END):
        for size in range(1, len(tag)):
            if pending.endswith(tag[:size]):
                self.buffered_delta_text = pending[-size:]
                return pending[:-size]

    # No partial match — flush everything
    self.buffered_delta_text = ""
    return pending
# ------------------------------------------------------------------
# Non-streaming extraction
# ------------------------------------------------------------------
def extract_tool_calls(
    self,
    model_output: str,
    request: ChatCompletionRequest,
) -> ExtractedToolCallInformation:
    """Extract every complete tool call from a finished model output.

    Returns:
        With ``tools_called=True``: one ``ToolCall`` per regex match, with
        JSON-encoded arguments, and ``content`` set to the stripped text
        before the first ``<|tool_call>`` (``None`` when empty).
        With ``tools_called=False``: the unmodified ``model_output`` as
        content, when no start token or complete call is present or when
        parsing raises.
    """
    start_token = self.tool_call_start_token
    if start_token not in model_output:
        return ExtractedToolCallInformation(
            tools_called=False, tool_calls=[], content=model_output
        )

    try:
        matches = self.tool_call_regex.findall(model_output)
        if not matches:
            return ExtractedToolCallInformation(
                tools_called=False, tool_calls=[], content=model_output
            )

        tool_calls: list[ToolCall] = [
            ToolCall(
                type="function",
                function=FunctionCall(
                    name=func_name,
                    arguments=json.dumps(
                        _parse_gemma4_args(args_str), ensure_ascii=False
                    ),
                ),
            )
            for func_name, args_str in matches
        ]

        # Content = text before first tool call (if any)
        first_call_at = model_output.find(start_token)
        leading_text = None
        if first_call_at > 0:
            leading_text = model_output[:first_call_at].strip()

        return ExtractedToolCallInformation(
            tools_called=True,
            tool_calls=tool_calls,
            content=leading_text or None,
        )
    except Exception:
        logger.exception("Error extracting tool calls from Gemma4 response")
        return ExtractedToolCallInformation(
            tools_called=False, tool_calls=[], content=model_output
        )
# ------------------------------------------------------------------
# Streaming extraction — accumulate-then-parse-then-diff
# ------------------------------------------------------------------
def extract_tool_calls_streaming(
    self,
    previous_text: str,
    current_text: str,
    delta_text: str,
    previous_token_ids: Sequence[int],
    current_token_ids: Sequence[int],
    delta_token_ids: Sequence[int],
    request: ChatCompletionRequest,
) -> DeltaMessage | None:
    """Turn one streamed delta into a content or tool-call delta message.

    Returns:
        A ``DeltaMessage``, or ``None`` when there is nothing to emit for
        this delta or when the streaming extraction raised (the error is
        logged).
    """
    # Buffer delta text to handle multi-token special sequences
    emit_text = self._buffer_delta_text(delta_text)

    # Keep current_text from the upstream stream state. The buffered delta
    # is only for emission, and must not be stitched back into the
    # accumulated model text or normal content like "<div>" can be
    # duplicated into "<<div>" when a tool call just ended.

    if self.tool_call_start_token in current_text:
        try:
            return self._extract_streaming(
                previous_text=previous_text,
                current_text=current_text,
                delta_text=emit_text,
            )
        except Exception:
            logger.exception("Error in Gemma4 streaming tool call extraction")
            return None

    # No tool call token seen yet: emit as content
    return DeltaMessage(content=emit_text) if emit_text else None
def _extract_streaming(
    self,
    previous_text: str,
    current_text: str,
    delta_text: str,
) -> DeltaMessage | None:
    """Tag-counting streaming parser.

    Counts start/end tags in the previous vs. current accumulated text to
    decide which phase this delta belongs to, then dispatches to the
    matching handler. The checks run in a fixed order and Case 2 may fall
    through into Cases 3/4, so the order of the statements matters.

    Format: ``<|tool_call>call:name{args}<tool_call|>``

    Args:
        previous_text: Accumulated model text before this delta.
        current_text: Accumulated model text including this delta.
        delta_text: The text to emit for this delta, as passed in by the
            caller.

    Returns:
        A content or tool-call ``DeltaMessage``, or ``None`` when nothing
        should be emitted for this delta.
    """
    start_count = current_text.count(self.tool_call_start_token)
    end_count = current_text.count(self.tool_call_end_token)
    prev_start_count = previous_text.count(self.tool_call_start_token)
    prev_end_count = previous_text.count(self.tool_call_end_token)

    # Case 1: Not inside any tool call — emit as content
    # (all calls are closed, none closed in this step, and the delta does
    # not itself carry an end tag).
    if (
        start_count == end_count
        and prev_end_count == end_count
        and self.tool_call_end_token not in delta_text
    ):
        if delta_text:
            return DeltaMessage(content=delta_text)
        return None

    # Case 2: Starting a new tool call
    if start_count > prev_start_count and start_count > end_count:
        # Open a fresh slot in the per-tool bookkeeping lists.
        self.current_tool_id += 1
        self.current_tool_name_sent = False
        self.streamed_args_for_tool.append("")
        self.prev_tool_call_arr.append({})
        logger.debug("Starting new tool call %d", self.current_tool_id)
        # Don't return yet — fall through to try parsing if there's
        # content after <|tool_call> in this same delta
        # (but usually it's just the token itself, so return None)
        if len(delta_text) <= len(self.tool_call_start_token):
            return None

    # Case 3: Tool call just ended
    if end_count > prev_end_count:
        return self._handle_tool_call_end(current_text)

    # Case 4: In the middle of a tool call — parse partial content
    if start_count > end_count:
        return self._handle_tool_call_middle(current_text)

    # Default: generate text outside tool calls
    # (any tag text is removed before the delta is emitted as content).
    if delta_text:
        text = delta_text.replace(self.tool_call_start_token, "")
        text = text.replace(self.tool_call_end_token, "")
        if text:
            return DeltaMessage(content=text)
    return None
def _extract_partial_call(self, current_text: str) -> tuple[str | None, str]:
"""Extract function name and raw argument string from partial text.
Returns (func_name, raw_args_str) or (None, "") if not parseable yet.
"""
# Get the text after the last <|tool_call> token
last_start = current_text.rfind(self.tool_call_start_token)
if last_start == -1:
return None, ""
partial_call = current_text[last_start + len(self.tool_call_start_token) :]
# Strip end token if present
if self.tool_call_end_token in partial_call:
partial_call = partial_call.split(self.tool_call_end_token)[0]
# Expect "call:name{args...}" or "call:name{args...}"
if not partial_call.startswith("call:"):
return None, ""
func_part = partial_call[5:] # skip "call:"
if "{" not in func_part:
# Still accumulating function name, not ready yet
return None, ""
func_name, _, args_part = func_part.partition("{")
func_name = func_name.strip()
# Strip trailing '}' if present (Gemma4 structural brace)
if args_part.endswith("}"):
args_part = args_part[:-1]
return func_name, args_part
def _handle_tool_call_middle(self, current_text: str) -> DeltaMessage | None:
    """Produce the next delta while a tool call is still open.

    The first time a function name can be extracted it is sent on its
    own (with empty arguments and a fresh tool-call id). On later calls
    the accumulated raw arguments are handed to ``_emit_argument_diff``
    so that only the new JSON fragment is streamed.

    Args:
        current_text: All text generated so far.

    Returns:
        A ``DeltaMessage`` carrying the name or an argument fragment,
        or ``None`` when nothing can be emitted yet.
    """
    name, raw_args = self._extract_partial_call(current_text)
    if name is None:
        return None

    if self.current_tool_name_sent:
        # Name already announced: stream argument progress, if any.
        return self._emit_argument_diff(raw_args) if raw_args else None

    if not name:
        # Empty name (e.g. "call:{"): nothing to announce yet.
        return None

    # Announce the function name exactly once.
    self.current_tool_name_sent = True
    self.prev_tool_call_arr[self.current_tool_id] = {
        "name": name,
        "arguments": {},
    }
    header = DeltaToolCall(
        index=self.current_tool_id,
        type="function",
        id=make_tool_call_id(),
        function=DeltaFunctionCall(
            name=name,
            arguments="",
        ).model_dump(exclude_none=True),
    )
    return DeltaMessage(tool_calls=[header])
def _handle_tool_call_end(self, current_text: str) -> DeltaMessage | None:
    """Flush the remaining argument JSON once a tool call has closed.

    The complete call is re-parsed through ``tool_call_regex`` and the
    part of its JSON arguments that has not been streamed yet (by
    length) is emitted.

    Args:
        current_text: All text generated so far, including the end token.

    Returns:
        A ``DeltaMessage`` with the outstanding argument suffix, or
        ``None`` if there is no active tool call, no regex match for it,
        or nothing left to send.
    """
    tool_idx = self.current_tool_id
    if not 0 <= tool_idx < len(self.prev_tool_call_arr):
        logger.debug(
            "Tool call end detected but no active tool call (current_tool_id=%d)",
            self.current_tool_id,
        )
        return None

    # Completed calls are matched in order, so the active call is the
    # match at its own index.
    matches = self.tool_call_regex.findall(current_text)
    if tool_idx >= len(matches):
        return None

    # Each match is a 2-tuple; the second element is the raw argument text.
    _, args_str = matches[tool_idx]
    parsed_args = _parse_gemma4_args(args_str)
    complete_json = json.dumps(parsed_args, ensure_ascii=False)

    already_sent = self.streamed_args_for_tool[tool_idx]
    if len(complete_json) <= len(already_sent):
        return None

    remainder = complete_json[len(already_sent) :]
    self.streamed_args_for_tool[tool_idx] = complete_json
    self.prev_tool_call_arr[tool_idx]["arguments"] = parsed_args
    tail_call = DeltaToolCall(
        index=tool_idx,
        function=DeltaFunctionCall(arguments=remainder).model_dump(
            exclude_none=True
        ),
    )
    return DeltaMessage(tool_calls=[tail_call])
def _emit_argument_diff(self, raw_args_str: str) -> DeltaMessage | None:
    """Stream the newly stabilised part of the current call's arguments.

    Strategy (accumulate, parse, diff):

    1. Parse ``raw_args_str`` with ``_parse_gemma4_args(partial=True)``.
    2. Serialise the result with ``json.dumps()``.
    3. Cut off every trailing closing character so that only a "safe
       prefix" remains.
    4. Emit whatever part of that prefix has not been streamed yet.

    **Why the trailing characters are withheld:**
    mid-stream, the Gemma4 text is structurally incomplete. A string
    value such as ``<|"|>Paris`` with no closing delimiter parses to
    ``{"location": "Paris"}``; if ``, France<|"|>`` follows, the JSON
    becomes ``{"location": "Paris, France"}``. Had the closing ``"}``
    been sent after the first parse, the client would have assembled
    ``{"location": "Paris"}France"}``. Closing characters are therefore
    left for ``_handle_tool_call_end()`` to flush when the
    ``<tool_call|>`` marker arrives.

    Args:
        raw_args_str: Raw Gemma4 argument text accumulated so far
            (without the surrounding ``{`` ``}``).

    Returns:
        DeltaMessage with the argument fragment, or None when there is
        nothing new to send.
    """
    try:
        parsed_args = _parse_gemma4_args(raw_args_str, partial=True)
    except Exception:
        logger.debug(
            "Could not parse partial Gemma4 args yet: %s",
            raw_args_str[:100],
        )
        return None
    if not parsed_args:
        return None

    # Safe prefix: strip all trailing '}', '"', ']' and the characters
    # that can belong to a partial string delimiter ('<', '|', '\', '>').
    stable_json = json.dumps(parsed_args, ensure_ascii=False).rstrip('}"]<|\\>')

    tool_idx = self.current_tool_id
    already_sent = self.streamed_args_for_tool[tool_idx]
    if not stable_json or stable_json == already_sent:
        return None

    if already_sent:
        shared = find_common_prefix(already_sent, stable_json)
        if len(shared) < len(already_sent):
            # The new JSON diverges from what was already streamed.
            # Roll the bookkeeping back to the shared prefix and leave
            # the rest to the final flush in _handle_tool_call_end.
            self.streamed_args_for_tool[tool_idx] = shared
            return None

    # With nothing sent yet this slice is the whole safe prefix.
    fragment = stable_json[len(already_sent) :]
    if not fragment:
        return None

    self.streamed_args_for_tool[tool_idx] = stable_json
    self.prev_tool_call_arr[tool_idx]["arguments"] = parsed_args
    arg_call = DeltaToolCall(
        index=tool_idx,
        function=DeltaFunctionCall(arguments=fragment).model_dump(
            exclude_none=True
        ),
    )
    return DeltaMessage(tool_calls=[arg_call])
...@@ -300,6 +300,28 @@ class ModelArchConfigConvertorBase: ...@@ -300,6 +300,28 @@ class ModelArchConfigConvertorBase:
return model_arch_config return model_arch_config
class CohereAsrModelArchConfigConvertor(ModelArchConfigConvertorBase):
    """Arch-config convertor that reads Cohere ASR encoder/decoder settings.

    Decoder values come from
    ``hf_text_config.transf_decoder["config_dict"]`` and the encoder head
    count from ``hf_text_config.encoder["n_heads"]``.
    """

    def get_total_num_attention_heads(self) -> int:
        """Return the decoder's ``num_attention_heads``."""
        decoder_cfg = self.hf_text_config.transf_decoder["config_dict"]
        return decoder_cfg["num_attention_heads"]

    def get_head_size(self) -> int:
        """Return the decoder ``hidden_size`` floor-divided by its head count."""
        decoder_cfg = self.hf_text_config.transf_decoder["config_dict"]
        hidden = decoder_cfg["hidden_size"]
        heads = decoder_cfg["num_attention_heads"]
        return hidden // heads

    def get_total_num_kv_heads(self) -> int:
        """Return the KV head count shared by encoder and decoder.

        Raises:
            AssertionError: If the encoder and decoder head counts differ.
        """
        encoder_heads = self.hf_text_config.encoder["n_heads"]
        decoder_cfg = self.hf_text_config.transf_decoder["config_dict"]
        decoder_heads = decoder_cfg["num_attention_heads"]
        assert encoder_heads == decoder_heads, (
            "Encoder and decoder must have the same number of kv heads"
        )
        return encoder_heads
class MambaModelArchConfigConvertor(ModelArchConfigConvertorBase): class MambaModelArchConfigConvertor(ModelArchConfigConvertorBase):
def get_head_size(self) -> int: def get_head_size(self) -> int:
return 0 return 0
...@@ -423,6 +445,16 @@ class LongCatFlashMTPModelArchConfigConvertor(ModelArchConfigConvertorBase): ...@@ -423,6 +445,16 @@ class LongCatFlashMTPModelArchConfigConvertor(ModelArchConfigConvertorBase):
return getattr(self.hf_text_config, "num_nextn_predict_layers", 1) return getattr(self.hf_text_config, "num_nextn_predict_layers", 1)
class Gemma4ModelArchConfigConvertor(ModelArchConfigConvertorBase):
    """Arch-config convertor for Gemma4's dual attention head dimensions."""

    def get_head_size(self) -> int:
        """Return the larger of ``head_dim`` and ``global_head_dim``.

        Gemma4 uses dual head dimensions: ``head_dim`` (sliding
        attention) and ``global_head_dim`` (full attention). The largest
        is returned so that attention backends allocate buffers large
        enough for both. If neither is set to a non-zero value, the base
        class implementation is used.
        """
        text_config = self.hf_text_config
        # ``or 0`` also covers attributes that exist but are None;
        # getattr's default only applies when the attribute is missing,
        # and max(None, int) would raise TypeError.
        head_dim = getattr(text_config, "head_dim", None) or 0
        global_head_dim = getattr(text_config, "global_head_dim", None) or 0
        return max(head_dim, global_head_dim) or super().get_head_size()
# hf_config.model_type -> convertor class # hf_config.model_type -> convertor class
MODEL_ARCH_CONFIG_CONVERTORS = { MODEL_ARCH_CONFIG_CONVERTORS = {
"mamba": MambaModelArchConfigConvertor, "mamba": MambaModelArchConfigConvertor,
...@@ -433,6 +465,8 @@ MODEL_ARCH_CONFIG_CONVERTORS = { ...@@ -433,6 +465,8 @@ MODEL_ARCH_CONFIG_CONVERTORS = {
"mpt": MPTModelArchConfigConvertor, "mpt": MPTModelArchConfigConvertor,
"dbrx": DbrxModelArchConfigConvertor, "dbrx": DbrxModelArchConfigConvertor,
"falcon": FalconModelArchConfigConvertor, "falcon": FalconModelArchConfigConvertor,
"gemma4": Gemma4ModelArchConfigConvertor,
"gemma4_text": Gemma4ModelArchConfigConvertor,
"RefinedWeb": FalconModelArchConfigConvertor, "RefinedWeb": FalconModelArchConfigConvertor,
"RefinedWebModel": FalconModelArchConfigConvertor, "RefinedWebModel": FalconModelArchConfigConvertor,
"nemotron-nas": NemotronNasModelArchConfigConvertor, "nemotron-nas": NemotronNasModelArchConfigConvertor,
......
...@@ -1040,6 +1040,8 @@ def unified_attention( ...@@ -1040,6 +1040,8 @@ def unified_attention(
num_seqs=num_seqs, num_seqs=num_seqs,
BLOCK_M=BLOCK_M, BLOCK_M=BLOCK_M,
USE_FP8=output_scale is not None, USE_FP8=output_scale is not None,
num_stages =1
) )
else: else:
kernel_unified_attention_3d[ kernel_unified_attention_3d[
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment