Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
858bddce
Commit
858bddce
authored
Apr 09, 2026
by
luopl
Browse files
feat:add gemma4
parent
40faaf0c
Changes
14
Hide whitespace changes
Inline
Side-by-side
Showing
14 changed files
with
4233 additions
and
1 deletion
+4233
-1
vllm/model_executor/layers/rotary_embedding/__init__.py
vllm/model_executor/layers/rotary_embedding/__init__.py
+12
-0
vllm/model_executor/layers/rotary_embedding/gemma4_rope.py
vllm/model_executor/layers/rotary_embedding/gemma4_rope.py
+84
-0
vllm/model_executor/models/config.py
vllm/model_executor/models/config.py
+54
-0
vllm/model_executor/models/gemma4.py
vllm/model_executor/models/gemma4.py
+1581
-0
vllm/model_executor/models/gemma4_mm.py
vllm/model_executor/models/gemma4_mm.py
+1339
-0
vllm/model_executor/models/registry.py
vllm/model_executor/models/registry.py
+2
-0
vllm/model_executor/models/utils.py
vllm/model_executor/models/utils.py
+8
-1
vllm/reasoning/__init__.py
vllm/reasoning/__init__.py
+4
-0
vllm/reasoning/gemma4_reasoning_parser.py
vllm/reasoning/gemma4_reasoning_parser.py
+225
-0
vllm/reasoning/gemma4_utils.py
vllm/reasoning/gemma4_utils.py
+130
-0
vllm/tool_parsers/__init__.py
vllm/tool_parsers/__init__.py
+4
-0
vllm/tool_parsers/gemma4_tool_parser.py
vllm/tool_parsers/gemma4_tool_parser.py
+754
-0
vllm/transformers_utils/model_arch_config_convertor.py
vllm/transformers_utils/model_arch_config_convertor.py
+34
-0
vllm/v1/attention/ops/triton_unified_attention.py
vllm/v1/attention/ops/triton_unified_attention.py
+2
-0
No files found.
vllm/model_executor/layers/rotary_embedding/__init__.py
View file @
858bddce
...
...
@@ -13,6 +13,7 @@ from .dynamic_ntk_alpha_rope import DynamicNTKAlphaRotaryEmbedding
from
.dynamic_ntk_scaling_rope
import
DynamicNTKScalingRotaryEmbedding
from
.fope
import
FourierRotaryEmbedding
from
.linear_scaling_rope
import
LinearScalingRotaryEmbedding
from
.gemma4_rope
import
Gemma4RotaryEmbedding
from
.llama3_rope
import
Llama3RotaryEmbedding
from
.llama4_vision_rope
import
Llama4VisionRotaryEmbedding
from
.mrope
import
MRotaryEmbedding
...
...
@@ -134,6 +135,17 @@ def get_rope(
is_neox_style
,
dtype
,
)
elif
scaling_type
==
"proportional"
:
# Proportional RoPE is used by Gemma4 for global (full) attention.
# Gemma4 uses a sparse/fractional RoPE with cross-mixing between halves.
rotary_emb
=
Gemma4RotaryEmbedding
(
head_size
,
rotary_dim
,
max_position
,
base
,
is_neox_style
,
dtype
,
)
elif
scaling_type
==
"llama3"
:
scaling_factor
=
rope_parameters
[
"factor"
]
low_freq_factor
=
rope_parameters
[
"low_freq_factor"
]
...
...
vllm/model_executor/layers/rotary_embedding/gemma4_rope.py
0 → 100644
View file @
858bddce
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Gemma4-specific Rotary Positional Embeddings (proportional scaling).
Gemma4 uses "proportional" RoPE which computes inv_freq frequencies scaled
by head_dim (not rotary_dim), and zero-pads for non-rotated dimensions when
partial_rotary_factor < 1. The actual rotation uses standard neox-style
rotate_half, matching HF transformers' apply_rotary_pos_emb.
"""
import
torch
from
.base
import
RotaryEmbedding
class Gemma4RotaryEmbedding(RotaryEmbedding):
    """Rotary embedding with Gemma4 "proportional" frequency scaling.

    The rotation itself is inherited from ``RotaryEmbedding``; only the
    inverse-frequency table is customised:

    - exponents are divided by ``head_size`` rather than ``rotary_dim``;
    - the table is zero-padded up to ``head_size // 2`` entries, so the
      padded angle pairs get a zero frequency (cos=1, sin=0), i.e. an
      identity rotation.

    The base class is always initialised with ``rotary_dim == head_size``;
    partial rotation is expressed purely through the zero padding.
    """

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: float,
        is_neox_style: bool,
        dtype: torch.dtype,
    ) -> None:
        """Record the rotated/non-rotated pair counts and build the base.

        Args:
            head_size: Size of one attention head.
            rotary_dim: Number of head dimensions that are actually rotated.
            max_position_embeddings: Forwarded unchanged to the base class.
            base: RoPE base, forwarded unchanged to the base class.
            is_neox_style: Forwarded unchanged to the base class.
            dtype: Forwarded unchanged to the base class.
        """
        half_head = head_size // 2
        # Angle pairs that receive a real rotation.
        self.rope_angles = rotary_dim // 2
        # Remaining angle pairs, later padded with zero frequency.
        self.nope_angles = half_head - self.rope_angles

        # Both counters are assigned before delegating because
        # _compute_inv_freq reads them (presumably it is invoked while the
        # base constructor builds its cache -- confirm against the base).
        # head_size is passed twice on purpose: the second occurrence is
        # the base class's rotary_dim, so rotation covers the whole head.
        super().__init__(
            head_size,
            head_size,
            max_position_embeddings,
            base,
            is_neox_style,
            dtype,
        )

    def _compute_inv_freq(self, base: float) -> torch.Tensor:
        """Return the inverse frequencies, zero-padded for non-rotated pairs.

        The rotated part is ``1 / base ** (arange(0, 2 * rope_angles, 2) /
        head_size)``; ``nope_angles`` zeros are appended when positive.
        """
        even_steps = torch.arange(0, 2 * self.rope_angles, 2, dtype=torch.float)
        rotated = 1.0 / (base ** (even_steps / self.head_size))
        if self.nope_angles <= 0:
            return rotated
        identity_pad = torch.zeros(self.nope_angles, dtype=torch.float)
        return torch.cat([rotated, identity_pad])

    def extra_repr(self) -> str:
        """Describe the configured sizes for the module's repr."""
        return (
            f"head_size={self.head_size}, rotary_dim={self.rotary_dim}"
            f", rope_angles={self.rope_angles}, nope_angles={self.nope_angles}"
            f", max_position_embeddings={self.max_position_embeddings}"
            f", base={self.base}, is_neox_style={self.is_neox_style}"
        )
vllm/model_executor/models/config.py
View file @
858bddce
...
...
@@ -56,6 +56,57 @@ class Gemma3TextModelConfig(VerifyAndUpdateConfig):
hf_config
=
model_config
.
hf_config
hf_config
.
is_causal
=
not
hf_config
.
use_bidirectional_attention
class Gemma4Config(VerifyAndUpdateConfig):
    """Config hook that selects one attention backend for Gemma4 models."""

    @staticmethod
    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
        """Force TRITON_ATTN when Gemma4 layers use mismatched head sizes.

        Some Gemma4 variants use ``head_dim`` for sliding-window layers and
        ``global_head_dim`` for full-attention layers. According to the
        original rationale, FlashAttention only accepts head_size <= 256,
        so a larger global head size would make vLLM choose a different
        backend per layer type, and that mixed execution diverges
        numerically and corrupts output.

        The backend is overridden only when all of the following hold:
        both head sizes are present in the text config, they differ, the
        larger one exceeds 256, and the user has not chosen a backend.

        TODO: Heterogeneous head_sizes (head_dim != global_head_dim)
        require NixlConnector changes to support per-layer KV transfer
        with different head dimensions for prefill-decode disaggregation.

        Args:
            vllm_config: Configuration object; its
                ``attention_config.backend`` may be mutated in place.
        """
        text_cfg = vllm_config.model_config.hf_text_config
        head_dim = getattr(text_cfg, "head_dim", None)
        global_head_dim = getattr(text_cfg, "global_head_dim", None)
        largest_head_dim = max(head_dim or 0, global_head_dim or 0)

        # Nothing to do unless both sizes are known and actually differ.
        if head_dim is None or global_head_dim is None:
            return
        if head_dim == global_head_dim:
            return
        # Small heads can stay on a shared backend; avoid needless forcing.
        if largest_head_dim <= 256:
            return
        attention_config = vllm_config.attention_config
        # Respect an explicit user choice.
        if attention_config.backend is not None:
            return

        from vllm.v1.attention.backends.registry import AttentionBackendEnum

        attention_config.backend = AttentionBackendEnum.TRITON_ATTN
        logger.info(
            "Gemma4 model has heterogeneous head dimensions "
            "(head_dim=%d, global_head_dim=%d). Forcing TRITON_ATTN "
            "backend to prevent mixed-backend numerical divergence.",
            head_dim,
            global_head_dim,
        )
class
GptOssForCausalLMConfig
(
VerifyAndUpdateConfig
):
@
staticmethod
...
...
@@ -647,10 +698,13 @@ class VoyageQwen3BidirectionalEmbedModelConfig(VerifyAndUpdateConfig):
MODELS_CONFIG_MAP
:
dict
[
str
,
type
[
VerifyAndUpdateConfig
]]
=
{
"ColBERTJinaRobertaModel"
:
JinaRobertaModelConfig
,
"ColQwen3_5"
:
Qwen3_5ForConditionalGenerationConfig
,
"DeepseekV32ForCausalLM"
:
DeepseekV32ForCausalLM
,
"Ernie4_5_VLMoeForConditionalGeneration"
:
Ernie4_5_VLMoeForConditionalGenerationConfig
,
# noqa: E501
"FalconMambaForCausalLM"
:
MambaModelConfig
,
"Gemma3TextModel"
:
Gemma3TextModelConfig
,
"Gemma4ForCausalLM"
:
Gemma4Config
,
"Gemma4ForConditionalGeneration"
:
Gemma4Config
,
"GptOssForCausalLM"
:
GptOssForCausalLMConfig
,
"GteModel"
:
SnowflakeGteNewModelConfig
,
"GteNewForSequenceClassification"
:
GteNewModelConfig
,
...
...
vllm/model_executor/models/gemma4.py
0 → 100644
View file @
858bddce
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright 2025 The vLLM team.
# Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Gemma 4 model implementation for vLLM."""
from
collections.abc
import
Iterable
from
dataclasses
import
replace
from
itertools
import
islice
import
regex
as
re
import
torch
from
torch
import
nn
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.distributed
import
(
get_pp_group
,
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
,
)
from
vllm.forward_context
import
get_forward_context
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.activation
import
GeluAndMul
from
vllm.model_executor.layers.attention
import
Attention
from
vllm.model_executor.layers.fused_moe
import
FusedMoE
,
GateLinear
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
MergedColumnParallelLinear
,
QKVParallelLinear
,
ReplicatedLinear
,
RowParallelLinear
,
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
ParallelLMHead
,
VocabParallelEmbedding
,
)
from
vllm.model_executor.model_loader.weight_utils
import
(
default_weight_loader
,
maybe_remap_kv_scale_name
,
)
from
vllm.sequence
import
IntermediateTensors
from
vllm.v1.attention.backends.utils
import
KVSharingFastPrefillMetadata
from
.interfaces
import
MixtureOfExperts
,
SupportsLoRA
,
SupportsPP
from
.utils
import
(
AutoWeightsLoader
,
extract_layer_index
,
is_pp_missing_parameter
,
make_layers
,
maybe_prefix
,
)
# Module-level logger, created with vLLM's init_logger for this module's name.
logger
=
init_logger
(
__name__
)
def
_get_text_config
(
config
):
"""Dereference text_config if config is a nested Gemma4Config.
Gemma4 checkpoints use architectures=["Gemma4ForConditionalGeneration"]
which yields a Gemma4Config with nested text_config. This function
transparently returns the text config regardless of nesting.
"""
if
hasattr
(
config
,
"text_config"
):
return
config
.
text_config
return
config
class Gemma4MLP(nn.Module):
    """Gated feed-forward block: fused gate/up projection, GeLU-and-multiply,
    then a down projection back to the hidden size."""

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_activation: str,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        """Build the projections and the activation.

        Args:
            hidden_size: Input and output width of the block.
            intermediate_size: Width of each of the gate and up projections.
            hidden_activation: Must be ``"gelu_pytorch_tanh"``.
            quant_config: Optional quantization settings handed to both
                linear layers.
            prefix: Parameter-name prefix of this module.

        Raises:
            ValueError: If ``hidden_activation`` is anything other than
                ``"gelu_pytorch_tanh"``.
        """
        super().__init__()
        # Gate and up projections share one merged, bias-free matmul.
        merged_output_sizes = [intermediate_size, intermediate_size]
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            merged_output_sizes,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.down_proj",
        )
        if hidden_activation != "gelu_pytorch_tanh":
            raise ValueError(
                "Gemma4 uses `gelu_pytorch_tanh` as the hidden activation "
                "function. Please set `hidden_act` and `hidden_activation` to "
                "`gelu_pytorch_tanh`."
            )
        self.act_fn = GeluAndMul(approximate="tanh")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply gate/up projection, activation and down projection to ``x``."""
        # Both linear layers return a 2-tuple; only the first item is used.
        fused_gate_up, _ = self.gate_up_proj(x)
        activated = self.act_fn(fused_gate_up)
        projected, _ = self.down_proj(activated)
        return projected
class Gemma4Router(nn.Module):
    """Produces expert logits for the Gemma4 MoE block.

    The input is preprocessed before the projection: a weightless RMSNorm,
    multiplication by the constant ``hidden_size ** -0.5``, and
    multiplication by a learned per-dimension ``scale``. This preprocessing
    lives entirely inside the router; it is not applied to what the expert
    MLPs receive.
    """

    def __init__(
        self,
        config,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        """Create the normalisation, scaling tensors and logit projection.

        Args:
            config: Text config; ``hidden_size``, ``rms_norm_eps`` and
                ``num_experts`` are read from it.
            quant_config: Accepted for signature parity with other layers;
                not used by this module.
            prefix: Parameter-name prefix of this module.
        """
        super().__init__()
        width = config.hidden_size
        self.hidden_size = width

        # Pure normalisation: no learned weight on this RMSNorm.
        self.norm = RMSNorm(width, eps=config.rms_norm_eps, has_weight=False)

        # Learned per-dimension scale, applied after norm and root_size.
        self.scale = nn.Parameter(torch.ones(width))

        # Constant 1/sqrt(hidden_size); non-persistent, so it is not part of
        # the state dict.
        inv_sqrt_width = torch.tensor(width**-0.5)
        self.register_buffer("root_size", inv_sqrt_width, persistent=False)

        # Bias-free projection to one logit per expert, with float32 output
        # requested (the original note says top-k routing often needs fp32
        # for stability).
        self.proj = GateLinear(
            width,
            config.num_experts,
            bias=False,
            out_dtype=torch.float32,
            prefix=f"{prefix}.proj",
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Returns raw router logits [T, E]."""
        normed = self.norm(x)
        scaled = normed * self.root_size.to(normed.dtype)
        scaled = scaled * self.scale.to(scaled.dtype)
        # The projection returns a 2-tuple; only the logits are needed.
        logits, _ = self.proj(scaled)
        return logits
class Gemma4MoE(nn.Module):
    """Expert dispatch for Gemma4, built on vLLM's FusedMoE.

    Router logits are computed outside this class (see Gemma4Router) and
    passed to ``forward``. Expert selection and weighting are done by a
    custom routing function: softmax over all experts, keep the top-k,
    renormalise the kept probabilities to sum to one, then multiply each by
    the learned ``per_expert_scale`` of its expert.
    """

    def __init__(
        self,
        config,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        """Create the per-expert scale and the fused expert layer.

        Args:
            config: Text config; reads ``hidden_size``, ``num_experts``,
                ``top_k_experts`` and ``moe_intermediate_size`` (falling
                back to ``expert_intermediate_size``, then ``None``).
            quant_config: Optional quantization settings for the experts.
            prefix: Parameter-name prefix of this module.
        """
        super().__init__()
        self.hidden_size = config.hidden_size
        self.num_experts = config.num_experts

        # One learned output scale per expert. It is folded into the routing
        # weights so the fused kernel effectively computes
        # sum_e(expert_e * w_e * scale_e).
        self.per_expert_scale = nn.Parameter(torch.ones(config.num_experts))
        # Captured by the closure below; it is the same Parameter object, so
        # values written into it later are seen by the routing function.
        expert_scale_param = self.per_expert_scale

        def routing_function(
            hidden_states: torch.Tensor,
            gating_output: torch.Tensor,
            topk: int,
            renormalize: bool,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            """Return (float32 weights, int32 expert ids) for the top-k experts.

            ``hidden_states`` and ``renormalize`` are part of the callback
            signature but are not used: renormalisation is always applied.
            """
            expert_count = gating_output.size(-1)
            topk_ids = torch.topk(gating_output, k=topk, dim=-1)[1]
            # Probabilities over the full expert set, not just the top-k.
            all_probs = torch.nn.functional.softmax(gating_output, dim=-1)
            # 1 where an expert was selected for a token, 0 elsewhere.
            selected = torch.nn.functional.one_hot(
                topk_ids, num_classes=expert_count
            ).sum(dim=-2)
            kept_probs = selected * all_probs
            kept_total = kept_probs.sum(dim=-1, keepdim=True)
            # Guard against dividing by zero when nothing was kept.
            kept_total = torch.where(kept_total > 0.0, kept_total, 1.0)
            topk_weights = (kept_probs / kept_total).gather(1, topk_ids)

            # Fold per_expert_scale into routing weights
            scale_for_choice = expert_scale_param[topk_ids].to(topk_weights.dtype)
            topk_weights = topk_weights * scale_for_choice
            return topk_weights.to(torch.float32), topk_ids.to(torch.int32)

        # Prefer moe_intermediate_size, else expert_intermediate_size, else None.
        fallback_size = getattr(config, "expert_intermediate_size", None)
        expert_width = getattr(config, "moe_intermediate_size", fallback_size)

        # FusedMoE experts with custom Gemma4 routing
        self.experts = FusedMoE(
            num_experts=config.num_experts,
            top_k=config.top_k_experts,
            hidden_size=config.hidden_size,
            intermediate_size=expert_width,
            reduce_results=True,
            renormalize=True,
            quant_config=quant_config,
            prefix=f"{prefix}.experts",
            custom_routing_function=routing_function,
            activation="gelu",
        )

    def forward(self, x: torch.Tensor, router_logits: torch.Tensor) -> torch.Tensor:
        """Run the fused experts on ``x`` using precomputed ``router_logits``."""
        return self.experts(x, router_logits)
class Gemma4Attention(nn.Module):
    """Gemma4 self-attention: fused QKV projection, per-head Q/K/V RMSNorm,
    rotary embedding, attention and output projection.

    Layers in the trailing ``num_kv_shared_layers`` block are "KV-shared":
    they are pointed at the attention layer of an earlier layer of the same
    attention type and only prepare the query themselves.
    """

    def __init__(
        self,
        config,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        head_dim: int,
        max_position_embeddings: int,
        use_k_eq_v: bool = False,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        attn_logits_soft_cap: float | None = None,
        prefix: str = "",
    ) -> None:
        """Build projections, norms, RoPE and the attention layer.

        Args:
            config: Text config; reads ``attention_bias``, ``rms_norm_eps``,
                ``layer_types``, ``sliding_window``, ``rope_parameters``,
                ``num_hidden_layers`` and, optionally,
                ``rope_local_base_freq`` and ``num_kv_shared_layers``.
            hidden_size: Model hidden width.
            num_heads: Total query heads across all tensor-parallel ranks.
            num_kv_heads: Total key/value heads across all ranks.
            head_dim: Per-head dimension for this layer.
            max_position_embeddings: Maximum position passed to ``get_rope``.
            use_k_eq_v: Stored on the instance; the module structure does
                not depend on it.
            cache_config: Forwarded to the attention layer.
            quant_config: Forwarded to the linear and attention layers.
            attn_logits_soft_cap: Forwarded as ``logits_soft_cap``.
            prefix: Parameter-name prefix; the layer index is extracted
                from it and it must contain ``".layers."`` for KV-shared
                layers.

        Raises:
            ValueError: If this is a KV-shared layer and no sharing target
                can be determined (see ``_find_kv_sharing_target``).
        """
        super().__init__()
        self.config = config
        self.hidden_size = hidden_size
        self.use_k_eq_v = use_k_eq_v

        tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        # Either the KV heads divide evenly across ranks, or there are fewer
        # KV heads than ranks and the rank count must be a multiple of them.
        if self.total_num_kv_heads >= tp_size:
            assert self.total_num_kv_heads % tp_size == 0
        else:
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = head_dim
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        # Gemma4 uses scaling=1.0.
        # Unlike Gemma2/3, query_pre_attn_scalar is NOT used here;
        # Q/K norms with learnable weights handle scaling implicitly.
        self.scaling = 1.0

        # One fused QKV projection is used for every layer type. Per the
        # original note, k_eq_v layers get their K weights loaded into both
        # the K and V slots by the weight-loading code, so no structural
        # difference is needed here.
        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=config.attention_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=config.attention_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        # Q/K norms: output = norm(x) * weight (learnable per-head scale)
        self.q_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps)
        self.k_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps)
        # V norm: no learnable scale (pure normalization only)
        self.v_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps, has_weight=False)

        # Determine layer type and sliding window
        layer_idx = extract_layer_index(prefix)
        layer_type = config.layer_types[layer_idx]
        self.is_sliding = layer_type == "sliding_attention"
        sliding_window = config.sliding_window if self.is_sliding else None

        # Gemma4 uses different RoPE parameters for sliding vs full attention.
        if layer_type in config.rope_parameters:
            # Per-layer-type format: use this layer type's entry as-is. Per
            # the original note, each entry already carries its own
            # partial_rotary_factor, so no global override is applied.
            rope_parameters = dict(config.rope_parameters[layer_type])
        else:
            # Legacy flat format: copy it, then substitute the local base
            # frequency for sliding-window layers.
            rope_parameters = dict(config.rope_parameters)
            if self.is_sliding:
                rope_parameters["rope_theta"] = getattr(
                    config, "rope_local_base_freq", 10000.0
                )

        # KV sharing: layers in the last `num_kv_shared_layers` share KV
        # cache with earlier layers of the same type.
        kv_sharing_target_layer_name = self._find_kv_sharing_target(
            config, layer_idx, prefix
        )
        self.is_kv_shared_layer = kv_sharing_target_layer_name is not None

        self.rotary_emb = get_rope(
            self.head_dim,
            max_position=max_position_embeddings,
            rope_parameters=rope_parameters,
            is_neox_style=True,
        )
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            logits_soft_cap=attn_logits_soft_cap,
            per_layer_sliding_window=sliding_window,
            kv_sharing_target_layer_name=kv_sharing_target_layer_name,
            prefix=f"{prefix}.attn",
        )

    @staticmethod
    def _find_kv_sharing_target(config, layer_idx: int, prefix: str) -> str | None:
        """Return the name of the attention layer whose KV cache is reused.

        Args:
            config: Text config; reads ``num_hidden_layers``, ``layer_types``
                and the optional ``num_kv_shared_layers`` (default 0).
            layer_idx: Index of the layer being constructed.
            prefix: Parameter-name prefix of that layer's attention module.

        Returns:
            ``None`` when the layer is not in the trailing KV-shared block;
            otherwise ``"<root>.layers.<i>.self_attn.attn"`` where ``<root>``
            is the part of ``prefix`` before ``".layers."`` and ``<i>`` is
            the last non-shared layer with the same attention type.

        Raises:
            ValueError: If no non-shared layer of the same type exists, or
                if ``prefix`` does not contain ``".layers."``.
        """
        num_kv_shared_layers = getattr(config, "num_kv_shared_layers", 0)
        if num_kv_shared_layers <= 0:
            return None
        first_kv_shared_layer_idx = config.num_hidden_layers - num_kv_shared_layers
        if layer_idx < first_kv_shared_layer_idx:
            return None

        # Find the last non-shared layer of the same attention type.
        layer_type = config.layer_types[layer_idx]
        prev_layers = config.layer_types[:first_kv_shared_layer_idx]
        target_idx = next(
            (
                idx
                for idx in range(len(prev_layers) - 1, -1, -1)
                if prev_layers[idx] == layer_type
            ),
            None,
        )
        if target_idx is None:
            raise ValueError(
                f"Gemma4Attention layer {layer_idx} is KV-shared but none of "
                f"the first {first_kv_shared_layer_idx} (non-shared) layers "
                f"has attention type '{layer_type}' to share a KV cache with."
            )
        if ".layers." not in prefix:
            raise ValueError(
                "Unexpected prefix format for Gemma4Attention: "
                f"'{prefix}'. Expected to contain '.layers.'."
            )
        param_name_before_layers = prefix.split(".layers.")[0]
        return f"{param_name_before_layers}.layers.{target_idx}.self_attn.attn"

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        """Compute attention output for ``hidden_states`` at ``positions``.

        Extra keyword arguments are accepted and ignored.
        """
        # Unified QKV path (works for both k_eq_v and standard layers).
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)

        # Q norm (always applied), computed per head.
        q = q.unflatten(-1, (self.num_heads, self.head_dim))
        q = self.q_norm(q)
        q = q.flatten(-2, -1)

        if not self.is_kv_shared_layer:
            # Non-shared: apply K norm + RoPE, V norm
            k = k.unflatten(-1, (self.num_kv_heads, self.head_dim))
            k = self.k_norm(k)
            k = k.flatten(-2, -1)
            q, k = self.rotary_emb(positions, q, k)

            v = v.unflatten(-1, (self.num_kv_heads, self.head_dim))
            v = self.v_norm(v)
            v = v.flatten(-2, -1)
        else:
            # Shared: only the rotated query is kept; k and v are passed on
            # without normalisation (presumably the attention layer reads
            # the target layer's cache instead -- confirm in Attention).
            q = self.rotary_emb(positions, q, k)[0]

        attn_output = self.attn(q, k, v)
        output, _ = self.o_proj(attn_output)
        return output
class Gemma4DecoderLayer(nn.Module):
    """One Gemma4 decoder layer.

    Contains self-attention, a dense MLP, an optional MoE branch that runs
    alongside the MLP, an optional Per-Layer Embedding (PLE) gate, and a
    final multiplication by the ``layer_scalar`` buffer.
    """

    def __init__(
        self,
        config,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        """Build all sub-modules for the layer identified by ``prefix``.

        Args:
            config: Text config describing layer types, head sizes, MLP
                width and the optional MoE / PLE features.
            cache_config: Forwarded to the attention module.
            quant_config: Forwarded to every quantizable sub-module.
            prefix: Parameter-name prefix; the layer index is extracted
                from it.
        """
        super().__init__()
        self.hidden_size = config.hidden_size
        self.hidden_size_per_layer_input = getattr(
            config, "hidden_size_per_layer_input", 0
        )
        layer_idx = extract_layer_index(prefix)
        self.layer_idx = layer_idx

        # Full-attention layers may use a separate (global) head dimension.
        self.is_full_attention = config.layer_types[layer_idx] == "full_attention"
        head_dim = (
            getattr(config, "global_head_dim", config.head_dim)
            if self.is_full_attention
            else config.head_dim
        )

        # k_eq_v only applies to full-attention layers; when active, the KV
        # head count comes from num_global_key_value_heads if present.
        use_k_eq_v = self.is_full_attention and getattr(
            config, "attention_k_eq_v", False
        )
        num_kv_heads = (
            getattr(config, "num_global_key_value_heads", config.num_key_value_heads)
            if use_k_eq_v
            else config.num_key_value_heads
        )

        self.self_attn = Gemma4Attention(
            config=config,
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=num_kv_heads,
            head_dim=head_dim,
            max_position_embeddings=config.max_position_embeddings,
            use_k_eq_v=use_k_eq_v,
            cache_config=cache_config,
            quant_config=quant_config,
            attn_logits_soft_cap=getattr(config, "attn_logit_softcapping", None),
            prefix=f"{prefix}.self_attn",
        )

        # The MLP is twice as wide on KV-shared layers when
        # use_double_wide_mlp is set. A layer counts as KV-shared here only
        # if the first shared index is strictly positive.
        first_kv_shared_layer_idx = config.num_hidden_layers - getattr(
            config, "num_kv_shared_layers", 0
        )
        is_kv_shared_layer = layer_idx >= first_kv_shared_layer_idx > 0
        use_double_wide_mlp = (
            getattr(config, "use_double_wide_mlp", False) and is_kv_shared_layer
        )
        width_multiplier = 2 if use_double_wide_mlp else 1
        self.mlp = Gemma4MLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size * width_multiplier,
            hidden_activation=config.hidden_activation,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp",
        )

        def hidden_norm() -> RMSNorm:
            # Every weighted norm in this layer has the same shape and eps.
            return RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        # Layer norms: output = norm(x) * weight
        self.input_layernorm = hidden_norm()
        self.post_attention_layernorm = hidden_norm()
        self.pre_feedforward_layernorm = hidden_norm()
        self.post_feedforward_layernorm = hidden_norm()

        # MoE: router + expert block running in parallel to the dense MLP.
        # Either config flag enables it.
        self.enable_moe_block = getattr(config, "enable_moe_block", False) or getattr(
            config, "use_second_mlp_block", False
        )
        if self.enable_moe_block:
            self.router = Gemma4Router(
                config,
                quant_config=quant_config,
                prefix=f"{prefix}.router",
            )
            self.moe = Gemma4MoE(
                config,
                quant_config=quant_config,
                prefix=f"{prefix}.moe",
            )
            self.post_feedforward_layernorm_1 = hidden_norm()
            self.post_feedforward_layernorm_2 = hidden_norm()
            self.pre_feedforward_layernorm_2 = hidden_norm()
        else:
            self.router = None
            self.moe = None
            self.post_feedforward_layernorm_1 = None
            self.post_feedforward_layernorm_2 = None
            self.pre_feedforward_layernorm_2 = None

        # Per-Layer Embedding (PLE) components, built only when a positive
        # per-layer input width is configured.
        ple_width = self.hidden_size_per_layer_input
        if ple_width is not None and ple_width > 0:
            # Gate: hidden_states -> per-layer width.
            self.per_layer_input_gate = ReplicatedLinear(
                self.hidden_size,
                self.hidden_size_per_layer_input,
                bias=False,
                quant_config=quant_config,
                prefix=f"{prefix}.per_layer_input_gate",
                return_bias=False,
            )
            # Projection: gated per-layer input -> hidden size.
            self.per_layer_projection = ReplicatedLinear(
                self.hidden_size_per_layer_input,
                self.hidden_size,
                bias=False,
                quant_config=quant_config,
                prefix=f"{prefix}.per_layer_projection",
                return_bias=False,
            )
            # Post-PLE norm: output = norm(x) * weight
            self.post_per_layer_input_norm = hidden_norm()
        else:
            self.per_layer_input_gate = None
            self.per_layer_projection = None
            self.post_per_layer_input_norm = None

        # Per-layer output scalar, created as ones for every layer
        # (expected to be overwritten from the checkpoint).
        self.register_buffer("layer_scalar", torch.ones(1))

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: torch.Tensor | None,
        per_layer_input: torch.Tensor | None = None,
        **kwargs,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run the layer and return ``(hidden_states, None)``.

        Residual pattern:
          1. input_norm(x) -> attn -> post_attn_norm -> add residual
          2. pre_ff_norm -> mlp (plus optional MoE branch) -> post_ff_norm
             -> add residual

        Args:
            positions: Token positions, forwarded to attention.
            hidden_states: Layer input.
            residual: Ignored; the residual stream is taken from
                ``hidden_states``.
            per_layer_input: Optional PLE input for this layer.
            **kwargs: Forwarded to the attention module.

        Returns:
            The layer output and ``None`` in the residual slot.
        """
        # --- attention sub-block ---
        attn_in = self.input_layernorm(hidden_states)
        attn_out = self.self_attn(
            positions=positions,
            hidden_states=attn_in,
            **kwargs,
        )
        post_attn = self.post_attention_layernorm(attn_out) + hidden_states

        # --- feed-forward sub-block ---
        # The dense MLP always runs, with or without the MoE branch.
        ff_out = self.mlp(self.pre_feedforward_layernorm(post_attn))

        if self.enable_moe_block:
            dense_branch = self.post_feedforward_layernorm_1(ff_out)
            # Router and experts read the post-attention state (the input
            # of the MLP), not the MLP output; the original note says this
            # matches the HF transformers forward path.
            router_logits = self.router(post_attn)
            moe_in = self.pre_feedforward_layernorm_2(post_attn)
            moe_out = self.moe(moe_in, router_logits)
            moe_branch = self.post_feedforward_layernorm_2(moe_out)
            # Combine MLP and MoE outputs
            ff_out = dense_branch + moe_branch

        output = self.post_feedforward_layernorm(ff_out) + post_attn

        # --- optional Per-Layer Embedding contribution ---
        if per_layer_input is not None and self.per_layer_input_gate is not None:
            gate = torch.nn.functional.gelu(
                self.per_layer_input_gate(output), approximate="tanh"
            )
            ple_out = self.per_layer_projection(gate * per_layer_input)
            output = output + self.post_per_layer_input_norm(ple_out)

        # The per-layer scalar is applied unconditionally, on every layer.
        return output * self.layer_scalar, None
def
_run_decoder_layers
(
decoder_layers
:
list
[
Gemma4DecoderLayer
],
layer_idx_start
:
int
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
per_layer_inputs
:
torch
.
Tensor
|
None
=
None
,
**
kwargs
,
)
->
torch
.
Tensor
:
"""Run a slice of decoder layers with PLE extraction."""
residual
=
None
for
idx
,
layer
in
enumerate
(
decoder_layers
):
layer_idx
=
idx
+
layer_idx_start
layer_per_input
=
(
per_layer_inputs
[:,
layer_idx
,
:]
if
per_layer_inputs
is
not
None
else
None
)
hidden_states
,
residual
=
layer
(
positions
,
hidden_states
,
residual
,
per_layer_input
=
layer_per_input
,
**
kwargs
,
)
return
hidden_states
@support_torch_compile(
    enable_if=lambda vllm_config: vllm_config.cache_config.kv_sharing_fast_prefill
)
class Gemma4SelfDecoderLayers(nn.Module):
    """Compiled wrapper: embedding + non-KV-shared layers (YOCO first half).

    Holds references to the embedding and PLE (Per-Layer Embedding) modules
    created by Gemma4Model so that they are part of this module's compiled
    graph. Gemma4Model forwards its embedding helpers to this class.
    Compilation is only enabled when ``kv_sharing_fast_prefill`` is set.
    """

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
        decoder_layers: list[Gemma4DecoderLayer],
        layer_idx_start: int,
        embed_tokens: VocabParallelEmbedding,
        normalizer: torch.Tensor,
        embed_tokens_per_layer: VocabParallelEmbedding | None,
        embed_scale_per_layer: torch.Tensor | None,
        per_layer_model_projection: ColumnParallelLinear | None,
        per_layer_projection_norm: RMSNorm | None,
        per_layer_input_scale: torch.Tensor | None,
        per_layer_projection_scale: torch.Tensor | None,
    ):
        super().__init__()
        text_config = _get_text_config(vllm_config.model_config.hf_config)

        # Kept as a plain list: the layers themselves belong to Gemma4Model.
        self.decoder_layers = decoder_layers
        self.layer_idx_start = layer_idx_start

        self.config = text_config
        # A missing PLE width defaults to 0; a missing PLE vocab size falls
        # back to the regular vocabulary size.
        self.hidden_size_per_layer_input = getattr(
            text_config, "hidden_size_per_layer_input", 0
        )
        self.vocab_size_per_layer_input = getattr(
            text_config, "vocab_size_per_layer_input", text_config.vocab_size
        )

        # Shared references to modules owned by Gemma4Model — must be
        # inside this nn.Module so torch.compile captures them.
        self.embed_tokens = embed_tokens
        self.normalizer = normalizer
        self.embed_tokens_per_layer = embed_tokens_per_layer
        self.embed_scale_per_layer = embed_scale_per_layer
        self.per_layer_model_projection = per_layer_model_projection
        self.per_layer_projection_norm = per_layer_projection_norm
        self.per_layer_input_scale = per_layer_input_scale
        self.per_layer_projection_scale = per_layer_projection_scale

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the token embeddings multiplied by ``self.normalizer``."""
        token_embeds = self.embed_tokens(input_ids)
        return token_embeds * self.normalizer

    def get_per_layer_inputs(self, input_ids: torch.Tensor) -> torch.Tensor | None:
        """Look up the raw per-layer embeddings for ``input_ids``.

        Returns:
            ``None`` when no PLE table is configured; otherwise a tensor of
            shape ``(*input_ids.shape, num_hidden_layers,
            hidden_size_per_layer_input)``.
        """
        ple_table = self.embed_tokens_per_layer
        if ple_table is None:
            return None

        # Ids outside [0, vocab_size_per_layer_input) are replaced by id 0
        # before the lookup.
        in_vocab = torch.logical_and(
            input_ids >= 0,
            input_ids < self.vocab_size_per_layer_input,
        )
        lookup_ids = torch.where(in_vocab, input_ids, torch.zeros_like(input_ids))

        scaled_embeds = ple_table(lookup_ids) * self.embed_scale_per_layer
        target_shape = (
            *input_ids.shape,
            self.config.num_hidden_layers,
            self.hidden_size_per_layer_input,
        )
        return scaled_embeds.reshape(target_shape)

    def project_per_layer_inputs(
        self,
        inputs_embeds: torch.Tensor,
        per_layer_inputs: torch.Tensor | None,
    ) -> torch.Tensor | None:
        """Project ``inputs_embeds`` into PLE space and merge with the lookup.

        Steps:
        1. Project inputs_embeds with per_layer_model_projection
        2. Multiply by per_layer_projection_scale
        3. Reshape to (..., num_hidden_layers, hidden_size_per_layer_input)
        4. Normalize with per_layer_projection_norm
        5. If per_layer_inputs is given:
           (projection + per_layer_inputs) * per_layer_input_scale

        Returns:
            ``None`` when no projection module is configured, the normalized
            projection alone when ``per_layer_inputs`` is ``None``, and the
            scaled sum otherwise.
        """
        projector = self.per_layer_model_projection
        if projector is None:
            return None

        projected = projector(inputs_embeds) * self.per_layer_projection_scale
        per_layer_shape = (
            *inputs_embeds.shape[:-1],
            self.config.num_hidden_layers,
            self.hidden_size_per_layer_input,
        )
        projected = self.per_layer_projection_norm(projected.reshape(per_layer_shape))

        if per_layer_inputs is not None:
            return (projected + per_layer_inputs) * self.per_layer_input_scale
        return projected

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        inputs_embeds: torch.Tensor | None = None,
        per_layer_inputs: torch.Tensor | None = None,
        **kwargs,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Embed the inputs (unless given) and run this wrapper's layers.

        When ``inputs_embeds`` is supplied it is used as-is and the incoming
        ``per_layer_inputs`` are treated as raw PLE embeddings; otherwise both
        are derived from ``input_ids``.

        Returns:
            ``(hidden_states, per_layer_inputs)`` where the second element is
            the projected PLE tensor (or ``None`` when PLE is disabled).
        """
        if inputs_embeds is None:
            hidden_states = self.embed_input_ids(input_ids)
            raw_ple = self.get_per_layer_inputs(input_ids)
        else:
            hidden_states = inputs_embeds
            raw_ple = per_layer_inputs
        projected_ple = self.project_per_layer_inputs(hidden_states, raw_ple)

        hidden_states = _run_decoder_layers(
            self.decoder_layers,
            self.layer_idx_start,
            positions,
            hidden_states,
            projected_ple,
            **kwargs,
        )
        return hidden_states, projected_ple
@support_torch_compile(
    enable_if=lambda vllm_config: vllm_config.cache_config.kv_sharing_fast_prefill
)
class Gemma4CrossDecoderLayers(nn.Module):
    """Cross-decoder layers (YOCO second half, KV-shared).

    Thin compiled wrapper around a slice of decoder layers; compilation is
    only enabled when ``kv_sharing_fast_prefill`` is set.
    """

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
        decoder_layers: list[Gemma4DecoderLayer],
        layer_idx_start: int,
    ):
        super().__init__()
        # Global index of the first wrapped layer, used to select the right
        # slice of per_layer_inputs for each layer.
        self.layer_idx_start = layer_idx_start
        self.decoder_layers = decoder_layers

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        per_layer_inputs: torch.Tensor | None = None,
        **kwargs,
    ) -> torch.Tensor:
        """Run the wrapped layers and return the final hidden states."""
        layers, first_idx = self.decoder_layers, self.layer_idx_start
        output = _run_decoder_layers(
            layers,
            first_idx,
            positions,
            hidden_states,
            per_layer_inputs,
            **kwargs,
        )
        return output
@support_torch_compile(
    enable_if=lambda vllm_config: not vllm_config.cache_config.kv_sharing_fast_prefill
)
class Gemma4Model(nn.Module):
    """Gemma 4 text decoder: embeddings, optional PLE, layers and final norm.

    The module is compiled as a whole only when ``kv_sharing_fast_prefill``
    is disabled. When it is enabled, the ``self_decoder`` / ``cross_decoder``
    wrappers are the compiled units and ``forward`` dispatches to
    ``fast_prefill_forward``.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = _get_text_config(vllm_config.model_config.hf_config)
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        self.config = config
        self.quant_config = quant_config

        # PLE config values (default to 0 if not present — disables PLE)
        self.hidden_size_per_layer_input = getattr(
            config, "hidden_size_per_layer_input", 0
        )
        self.vocab_size_per_layer_input = getattr(
            config, "vocab_size_per_layer_input", config.vocab_size
        )
        ple_dim = self.hidden_size_per_layer_input
        ple_enabled = ple_dim is not None and ple_dim > 0

        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
            quant_config=quant_config,
            prefix=f"{prefix}.embed_tokens",
        )

        # Per-Layer Embedding (PLE) components
        if ple_enabled:
            total_ple_dim = ple_dim * config.num_hidden_layers
            self.embed_tokens_per_layer = VocabParallelEmbedding(
                self.vocab_size_per_layer_input,
                total_ple_dim,
                quant_config=quant_config,
                prefix=f"{prefix}.embed_tokens_per_layer",
            )
            # Scaled embedding factor (from config, not hardcoded)
            # Register as buffer so it moves to GPU with the model
            # and interacts correctly with torch.compile AOT caching.
            self.register_buffer(
                "embed_scale_per_layer",
                torch.tensor(ple_dim**0.5),
                persistent=False,
            )
            # Projection: hidden_size → total_ple_dim
            # ColumnParallelLinear with gather_output=True
            self.per_layer_model_projection = ColumnParallelLinear(
                config.hidden_size,
                total_ple_dim,
                bias=False,
                gather_output=True,
                return_bias=False,
                quant_config=quant_config,
                prefix=f"{prefix}.per_layer_model_projection",
            )
            # PLE projection norm: output = norm(x) * weight
            self.per_layer_projection_norm = RMSNorm(
                ple_dim,
                eps=config.rms_norm_eps,
            )
            # Scale factor for combining projection + per_layer_inputs
            # Register as buffer so it moves to GPU with the model
            # and interacts correctly with torch.compile AOT caching.
            self.register_buffer(
                "per_layer_input_scale",
                torch.rsqrt(torch.tensor(2.0)),
                persistent=False,
            )
            # Scaled projection: multiply output by hidden_size**-0.5.
            # Register as buffer for GPU placement and torch.compile.
            self.register_buffer(
                "per_layer_projection_scale",
                torch.tensor(config.hidden_size**-0.5),
                persistent=False,
            )
        else:
            self.embed_tokens_per_layer = None
            self.embed_scale_per_layer = None
            self.per_layer_model_projection = None
            self.per_layer_projection_norm = None
            self.per_layer_input_scale = None
            self.per_layer_projection_scale = None

        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: Gemma4DecoderLayer(
                config,
                cache_config=cache_config,
                quant_config=quant_config,
                prefix=prefix,
            ),
            prefix=f"{prefix}.layers",
        )
        # Final norm: output = norm(x) * weight
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        # Embedding scale = sqrt(hidden_size)
        self.register_buffer(
            "normalizer",
            torch.tensor(config.hidden_size**0.5),
            persistent=False,
        )

        # --- You Only Cache Once (YOCO) split for fast prefill ---
        first_kv_shared_layer_idx = config.num_hidden_layers - getattr(
            config, "num_kv_shared_layers", 0
        )

        from vllm.compilation.backends import set_model_tag

        # Layers 0..(K-1) are self-decoder layers in YOCO
        with set_model_tag("self_decoder"):
            # Every PLE attribute is assigned above (module, buffer or
            # None), so they can be passed through directly.
            self.self_decoder = Gemma4SelfDecoderLayers(
                vllm_config=vllm_config,
                prefix=f"{prefix}.self_decoder",
                decoder_layers=self.layers[:first_kv_shared_layer_idx],
                layer_idx_start=0,
                embed_tokens=self.embed_tokens,
                normalizer=self.normalizer,
                embed_tokens_per_layer=self.embed_tokens_per_layer,
                embed_scale_per_layer=self.embed_scale_per_layer,
                per_layer_model_projection=self.per_layer_model_projection,
                per_layer_projection_norm=self.per_layer_projection_norm,
                per_layer_input_scale=self.per_layer_input_scale,
                per_layer_projection_scale=self.per_layer_projection_scale,
            )
        # Layers K..(N-1) are cross-decoder layers in YOCO
        with set_model_tag("cross_decoder"):
            self.cross_decoder = Gemma4CrossDecoderLayers(
                vllm_config=vllm_config,
                prefix=f"{prefix}.cross_decoder",
                decoder_layers=self.layers[first_kv_shared_layer_idx:],
                layer_idx_start=first_kv_shared_layer_idx,
            )

        self.fast_prefill_enabled = cache_config.kv_sharing_fast_prefill
        # Always define the static PLE buffer attribute so that
        # fast_prefill_forward can test it against None on every
        # configuration instead of depending on which branch ran.
        self.per_layer_inputs = None
        if self.fast_prefill_enabled:
            # Allocate static buffers for CUDAGraph
            max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens
            device = next(self.parameters()).device
            buffer_dtype = self.embed_tokens.weight.dtype
            self.positions = torch.zeros(
                max_num_tokens, dtype=torch.int64, device=device
            )
            self.hidden_states = torch.zeros(
                (max_num_tokens, config.hidden_size),
                dtype=buffer_dtype,
                device=device,
            )
            if ple_enabled:
                self.per_layer_inputs = torch.zeros(
                    (max_num_tokens, config.num_hidden_layers, ple_dim),
                    dtype=buffer_dtype,
                    device=device,
                )

        # Custom factory that includes per_layer_inputs for PLE-enabled PP.
        # per_layer_inputs has shape (batch, num_layers, per_layer_dim),
        # which differs from the standard (batch, hidden_size) shape,
        # so we can't use the default factory.
        num_layers = config.num_hidden_layers
        hidden_size = config.hidden_size

        def _make_empty_intermediate_tensors(
            batch_size: int,
            dtype: torch.dtype,
            device: torch.device,
        ) -> IntermediateTensors:
            """Build zero-filled placeholders for pipeline-parallel hand-off."""
            tensors: dict[str, torch.Tensor] = {
                "hidden_states": torch.zeros(
                    (batch_size, hidden_size),
                    dtype=dtype,
                    device=device,
                ),
                "residual": torch.zeros(
                    (batch_size, hidden_size),
                    dtype=dtype,
                    device=device,
                ),
            }
            if ple_enabled:
                tensors["per_layer_inputs"] = torch.zeros(
                    (batch_size, num_layers, ple_dim),
                    dtype=dtype,
                    device=device,
                )
            return IntermediateTensors(tensors)

        self.make_empty_intermediate_tensors = _make_empty_intermediate_tensors

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Embed ``input_ids``; delegates to ``self.self_decoder``."""
        return self.self_decoder.embed_input_ids(input_ids)

    def get_per_layer_inputs(self, input_ids: torch.Tensor) -> torch.Tensor | None:
        """Get per-layer embeddings from embed_tokens_per_layer.

        Delegates to ``self.self_decoder``.

        Returns:
            Per-layer embeddings (num_tokens, num_layers,
            hidden_size_per_layer_input), or ``None`` when PLE is disabled.
        """
        return self.self_decoder.get_per_layer_inputs(input_ids)

    def project_per_layer_inputs(
        self,
        inputs_embeds: torch.Tensor,
        per_layer_inputs: torch.Tensor | None,
    ) -> torch.Tensor | None:
        """Project inputs_embeds and combine with per_layer_inputs.

        Delegates to ``self.self_decoder``.

        Steps:
        1. Project inputs_embeds: hidden_size → total_ple_dim
        2. Scale by hidden_size^{-0.5}
        3. Reshape to (num_tokens, num_layers, per_layer_dim)
        4. Normalize with per_layer_projection_norm
        5. Combine: (projection + per_layer_inputs) * 1/sqrt(2)
        """
        return self.self_decoder.project_per_layer_inputs(
            inputs_embeds, per_layer_inputs
        )

    def fast_prefill_forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        inputs_embeds: torch.Tensor | None = None,
        per_layer_inputs: torch.Tensor | None = None,
        **kwargs,
    ) -> torch.Tensor:
        """Run the YOCO split: all tokens through the self-decoder, then only
        the selected token rows through the cross-decoder.

        The rows fed to the cross-decoder come from
        ``logits_indices_padded`` of the last layer's attention metadata
        when that metadata is a ``KVSharingFastPrefillMetadata``; otherwise
        every row is used.

        Returns:
            Hidden states before the final norm. When the metadata provides
            ``num_logits_indices``, this is the self-decoder output with the
            selected rows overwritten by cross-decoder results; otherwise it
            is the cross-decoder output.
        """
        forward_context = get_forward_context()

        logits_indices_padded, num_logits_indices = None, None
        attn_metadata = forward_context.attn_metadata
        if attn_metadata is not None:
            assert isinstance(attn_metadata, dict)
            layer_attn_metadata = attn_metadata[
                self.layers[-1].self_attn.attn.layer_name
            ]
            if isinstance(layer_attn_metadata, KVSharingFastPrefillMetadata):
                logits_indices_padded = layer_attn_metadata.logits_indices_padded
                num_logits_indices = layer_attn_metadata.num_logits_indices

        # Copy into the pre-allocated buffer so the self-decoder always
        # reads positions from the same storage.
        batch_size = positions.size(0)
        self.positions[:batch_size].copy_(positions)
        self_decoder_hidden_states, per_layer_inputs = self.self_decoder(
            input_ids=input_ids,
            positions=self.positions[:batch_size],
            inputs_embeds=inputs_embeds,
            per_layer_inputs=per_layer_inputs,
            **kwargs,
        )

        if logits_indices_padded is None:
            logits_indices_padded = torch.arange(
                batch_size,
                dtype=positions.dtype,
                device=positions.device,
            )

        # NOTE: Keep .clone() until fix in
        # https://github.com/vllm-project/vllm/pull/22282
        hidden_states = self_decoder_hidden_states.clone()

        # Gather only the selected rows into the static buffers that the
        # cross-decoder consumes.
        num_padded = logits_indices_padded.size(0)
        self.positions[:num_padded].copy_(positions[logits_indices_padded])
        self.hidden_states[:num_padded].copy_(
            self_decoder_hidden_states[logits_indices_padded]
        )
        if self.per_layer_inputs is not None and per_layer_inputs is not None:
            self.per_layer_inputs[:num_padded].copy_(
                per_layer_inputs[logits_indices_padded]
            )
        cross_per_layer = (
            self.per_layer_inputs[:num_padded]
            if self.per_layer_inputs is not None
            else None
        )

        # Update batch_descriptor so the cross-decoder's piecewise
        # CUDAGraphWrapper dispatches to the correct (reduced) batch size.
        orig_batch_desc = forward_context.batch_descriptor
        if orig_batch_desc is not None:
            forward_context.batch_descriptor = replace(
                orig_batch_desc, num_tokens=num_padded
            )
        try:
            cross_hidden_states = self.cross_decoder(
                self.positions[:num_padded],
                self.hidden_states[:num_padded],
                cross_per_layer,
                **kwargs,
            )
        finally:
            # Restore the original batch_descriptor even if the
            # cross-decoder raises: the forward context is shared state and
            # must not keep the temporary override.
            forward_context.batch_descriptor = orig_batch_desc

        if num_logits_indices is not None:
            assert num_logits_indices > 0
            # Scatter the real (unpadded) cross-decoder rows back.
            hidden_states[logits_indices_padded[:num_logits_indices]] = (
                cross_hidden_states[:num_logits_indices]
            )
        else:
            hidden_states = cross_hidden_states
        return hidden_states

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None,
        inputs_embeds: torch.Tensor | None = None,
        per_layer_inputs: torch.Tensor | None = None,
        **kwargs,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the decoder stack.

        Args:
            input_ids: Token ids; used when ``inputs_embeds`` is ``None``.
            positions: Forwarded to every decoder layer.
            intermediate_tensors: Output of the previous pipeline stage;
                required on every rank except the first.
            inputs_embeds: Pre-computed input embeddings, if any.
            per_layer_inputs: Raw PLE embeddings accompanying
                ``inputs_embeds``; they are projected here.
            **kwargs: Forwarded to every decoder layer.

        Returns:
            Normalized hidden states, or ``IntermediateTensors`` on a
            pipeline rank that is not the last one.
        """
        if self.fast_prefill_enabled:
            hidden_states = self.fast_prefill_forward(
                input_ids,
                positions,
                inputs_embeds,
                per_layer_inputs,
                **kwargs,
            )
            hidden_states = self.norm(hidden_states)
            return hidden_states

        # Normal (non-fast-prefill) path with PP support
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
                # When called from the multimodal wrapper, raw PLE
                # embeddings are pre-computed and passed explicitly.
                # Project them through per_layer_model_projection.
                per_layer_inputs = self.project_per_layer_inputs(
                    hidden_states, per_layer_inputs
                )
            else:
                hidden_states = self.embed_input_ids(input_ids)
                # Compute per-layer inputs for PLE
                per_layer_embeds = self.get_per_layer_inputs(input_ids)
                per_layer_inputs = self.project_per_layer_inputs(
                    hidden_states, per_layer_embeds
                )
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]
            # NOTE(review): relies on IntermediateTensors offering a
            # dict-like ``get`` — confirm against its definition.
            per_layer_inputs = intermediate_tensors.get("per_layer_inputs")

        for layer_idx, layer in enumerate(
            islice(self.layers, self.start_layer, self.end_layer)
        ):
            # Extract the per-layer embedding for this specific layer
            if per_layer_inputs is not None:
                actual_layer_idx = self.start_layer + layer_idx
                layer_per_input = per_layer_inputs[
                    :, actual_layer_idx, :
                ]  # (num_tokens, per_layer_dim)
            else:
                layer_per_input = None
            hidden_states, residual = layer(
                positions,
                hidden_states,
                residual,
                per_layer_input=layer_per_input,
                **kwargs,
            )

        if not get_pp_group().is_last_rank:
            return IntermediateTensors(
                {
                    "hidden_states": hidden_states,
                    "residual": residual,
                    "per_layer_inputs": per_layer_inputs,
                }
            )

        # Gemma4 incorporates residual into hidden_states directly
        # Apply norm without residual fusion when possible.
        if residual is None:
            hidden_states = self.norm(hidden_states)
        else:
            hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this module's parameters and buffers.

        Each ``(name, tensor)`` pair is tried, in order, as: a KV-cache
        scale reported by the quantization config, a ``.k/.v/.q/.prob_scale``
        entry, a shard of a stacked parameter (qkv_proj / gate_up_proj), a
        per-expert MoE weight, and finally a direct parameter.

        Returns:
            The names of all parameters/buffers that received a weight.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]

        # MoE expert weight mapping: checkpoint can have either:
        # 1. 3D packed tensors (exploded in _weight_iterator to per-expert 2D)
        # 2. Already per-expert 2D weights (if quantized)
        # Map to FusedMoE parameters:
        #   moe.experts.{id}.gate_proj → FusedMoE w1 (shard of w13)
        #   moe.experts.{id}.up_proj   → FusedMoE w3 (shard of w13)
        #   moe.experts.{id}.down_proj → FusedMoE w2
        #
        # Use prefix matching to handle both weights and
        # quantization scale parameters. The param_name is a prefix ending
        # in underscore, and weight_name ends with a dot, so that:
        #   "experts.0.gate_proj.weight_scale" -> "experts.w13_weight_scale"
        #   "experts.0.gate_proj.weight"       -> "experts.w13_weight"
        num_experts = getattr(self.config, "num_experts", None) or 0
        expert_params_mapping = [
            # (param_name, weight_name, expert_id, shard_id)
            (
                "experts.w13_"
                if proj_name in ["gate_proj", "up_proj"]
                else "experts.w2_",
                f"experts.{expert_id}.{proj_name}.",
                expert_id,
                shard_id,
            )
            for expert_id in range(num_experts)
            for shard_id, proj_name in [
                ("w1", "gate_proj"),
                ("w2", "down_proj"),
                ("w3", "up_proj"),
            ]
        ]

        params_dict = dict(self.named_parameters())
        # Include buffers (e.g. layer_scalar) so they can be loaded too
        params_dict.update(dict(self.named_buffers()))

        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            if self.quant_config is not None and (
                scale_name := self.quant_config.get_cache_scale(name)
            ):
                param = params_dict[scale_name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                loaded_weight = loaded_weight[0]
                weight_loader(param, loaded_weight)
                loaded_params.add(scale_name)
                continue

            if name.endswith((".k_scale", ".v_scale", ".q_scale", ".prob_scale")):
                remapped_name = maybe_remap_kv_scale_name(name, params_dict)
                if remapped_name is not None and remapped_name in params_dict:
                    param = params_dict[remapped_name]
                    weight_loader = getattr(
                        param, "weight_loader", default_weight_loader
                    )
                    weight_loader(param, loaded_weight)
                    loaded_params.add(remapped_name)
                    continue
                # Remapping failed: fall through to the generic handling
                # below.

            for param_name, shard_name, shard_id in stacked_params_mapping:
                if shard_name not in name:
                    continue
                stacked_name = name.replace(shard_name, param_name)
                # k_eq_v layers use separate q_proj/k_proj instead of
                # packed qkv_proj. If the stacked param doesn't exist,
                # skip this mapping and fall through to direct load.
                if stacked_name not in params_dict:
                    continue
                if is_pp_missing_parameter(stacked_name, self):
                    continue
                param = params_dict[stacked_name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                loaded_params.add(stacked_name)
                break
            else:
                for (
                    param_name,
                    weight_name,
                    expert_id,
                    shard_id,
                ) in expert_params_mapping:
                    # Match both:
                    # - Bare weights: "experts.0.down_proj" (from 3D explosion)
                    # - With suffix: "experts.0.down_proj.weight_scale" (2D quantized)
                    # weight_name has trailing dot, so check with and without it
                    weight_name_base = weight_name.rstrip(".")
                    if weight_name in name:
                        # Has suffix (e.g., .weight_scale)
                        moe_name = name.replace(weight_name, param_name)
                    elif name.endswith(weight_name_base):
                        # Bare weight (no suffix)
                        moe_name = name.replace(
                            weight_name_base, param_name.rstrip("_") + "_weight"
                        )
                    else:
                        continue
                    if moe_name not in params_dict:
                        continue
                    if is_pp_missing_parameter(moe_name, self):
                        continue
                    param = params_dict[moe_name]
                    # Expert weights are already in the correct
                    # orientation for FusedMoE after _weight_iterator:
                    #   gate/up: [I, H] → w1/w3 expects [I, H]
                    #   down:    [H, I] → w2 expects [H, I]
                    # Scales and other quantization params may be 1D or scalar.
                    weight_loader = param.weight_loader
                    weight_loader(
                        param,
                        loaded_weight,
                        moe_name,  # Pass mapped name (handles both weights and scales)
                        shard_id=shard_id,
                        expert_id=expert_id,
                    )
                    loaded_params.add(moe_name)
                    break
                else:
                    if name.endswith(".bias") and name not in params_dict:
                        continue
                    name = maybe_remap_kv_scale_name(name, params_dict)
                    if name is None:
                        continue
                    if is_pp_missing_parameter(name, self):
                        continue
                    param = params_dict[name]
                    weight_loader = getattr(
                        param, "weight_loader", default_weight_loader
                    )
                    weight_loader(param, loaded_weight)
                    loaded_params.add(name)

        return loaded_params
class Gemma4ForCausalLM(nn.Module, SupportsLoRA, SupportsPP, MixtureOfExperts):
    """Gemma 4 text model with LM head and logits processor.

    Wraps ``Gemma4Model``, optionally ties the LM head to the input
    embeddings, and exposes the attributes of the MixtureOfExperts protocol
    collected from the decoder layers that carry a ``Gemma4MoE`` block.
    """

    # Note: qkv_proj packing applies to non-k_eq_v layers (sliding
    # attention and full attention without k_eq_v). k_eq_v layers use
    # separate q_proj + k_proj without packing.
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        config = _get_text_config(vllm_config.model_config.hf_config)
        quant_config = vllm_config.quant_config
        super().__init__()
        self.config = config
        self.quant_config = quant_config
        self.model = Gemma4Model(
            vllm_config=vllm_config,
            prefix=maybe_prefix(prefix, "model"),
        )
        self.lm_head = ParallelLMHead(
            config.vocab_size,
            config.hidden_size,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "lm_head"),
        )
        if config.tie_word_embeddings:
            self.lm_head = self.lm_head.tie_weights(self.model.embed_tokens)
        self.logits_processor = LogitsProcessor(
            config.vocab_size,
            soft_cap=getattr(config, "final_logit_softcapping", None),
        )
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors
        )

        # --- MixtureOfExperts protocol ---
        self.expert_weights: list[list[torch.Tensor]] = []
        self.moe_layers: list[nn.Module] = []
        example_moe: Gemma4MoE | None = None
        for layer in self.model.layers:
            if hasattr(layer, "moe") and isinstance(layer.moe, Gemma4MoE):
                example_moe = layer.moe
                self.moe_layers.append(layer.moe.experts)
        self.num_moe_layers = len(self.moe_layers)

        # All expert counts are taken from the last MoE block found; they
        # are 0 when the model has no MoE layers.
        expert_count = example_moe.num_experts if example_moe is not None else 0
        self.num_logical_experts = expert_count
        self.num_physical_experts = expert_count
        self.num_local_physical_experts = expert_count
        self.num_routed_experts = expert_count
        self.num_expert_groups = 1
        self.num_shared_experts = 0
        self.num_redundant_experts = 0

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Embed ``input_ids``; delegates to the wrapped ``Gemma4Model``."""
        return self.model.embed_input_ids(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the wrapped ``Gemma4Model`` and return its output unchanged."""
        hidden_states = self.model(
            input_ids, positions, intermediate_tensors, inputs_embeds, **kwargs
        )
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Apply the logits processor to ``hidden_states`` via the LM head."""
        return self.logits_processor(self.lm_head, hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Rename/expand checkpoint tensors and load them.

        The incoming names are normalized to this model's parameter tree
        (see ``_weight_iterator``), multimodal tower/embedder weights are
        skipped, and ``lm_head.`` is skipped when word embeddings are tied.

        Returns:
            The set of loaded parameter names reported by AutoWeightsLoader.
        """
        # Compiled once instead of on every checkpoint entry.
        layer_idx_re = re.compile(r"layers\.(\d+)\.")
        # The negative lookbehind makes the per-expert remap idempotent:
        # names already spelled ".moe.experts.{id}." are left alone instead
        # of being rewritten to ".moe.moe.experts.{id}.", mirroring how the
        # packed 3D remap below accepts both spellings.
        expert_re = re.compile(r"(?<!\.moe)\.experts\.(\d+)\.")

        # Checkpoint weight names use "language_model." prefix (from the
        # Gemma4ForConditionalGeneration wrapper). Strip it to map to our
        # model tree which is just "model.*".
        def _weight_iterator():
            use_k_eq_v = getattr(self.config, "attention_k_eq_v", False)
            # Build set of k_eq_v layer indices (full_attention layers
            # when attention_k_eq_v is enabled). These layers have k_proj
            # but no v_proj in checkpoint — we duplicate k_proj as v_proj.
            k_eq_v_layer_indices: set[int] = set()
            if use_k_eq_v:
                k_eq_v_layer_indices = {
                    idx
                    for idx, lt in enumerate(self.config.layer_types)
                    if lt == "full_attention"
                }

            for name, weight in weights:
                # Remap "language_model." → "" to match our model tree.
                # Checkpoint: model.language_model.layers.X.*
                # Our model:  model.layers.X.*
                name = name.replace("language_model.", "")

                # Remap new HF checkpoint naming to internal vLLM
                # naming: HF moved per_expert_scale to router and
                # renamed moe → experts in the MoE block.
                name = name.replace(
                    ".router.per_expert_scale",
                    ".moe.per_expert_scale",
                )
                if ".experts.gate_up_proj" in name:
                    name = name.replace(
                        ".experts.gate_up_proj",
                        ".moe.gate_up_proj",
                    )
                elif ".experts.down_proj" in name:
                    name = name.replace(
                        ".experts.down_proj",
                        ".moe.down_proj",
                    )
                # Remap individual 2D expert weights:
                #   .experts.{id}.{proj} → .moe.experts.{id}.{proj}
                # (This handles per-expert 2D quantized weights)
                name = expert_re.sub(r".moe.experts.\1.", name)

                # MoE expert weights: checkpoint stores as 3D packed
                # tensors. Explode into per-expert 2D weights for
                # FusedMoE weight_loader.
                #
                # Checkpoint format:
                #   moe.gate_up_proj: [E, 2*I, H]  (fused gate + up)
                #   moe.down_proj:    [E, H, I]
                #
                # FusedMoE expects per-expert:
                #   w1 (gate): [I, H]  — first half of gate_up
                #   w3 (up):   [I, H]  — second half of gate_up
                #   w2 (down): [H, I]  — as-is from checkpoint
                #
                # No transpose needed: checkpoint orientation already
                # matches FusedMoE's expected layout.
                if "moe.gate_up_proj" in name and weight.dim() == 3:
                    num_experts = weight.size(0)
                    intermediate_size = weight.size(1) // 2
                    for expert_id in range(num_experts):
                        gate_weight = weight[expert_id, :intermediate_size, :]
                        up_weight = weight[expert_id, intermediate_size:, :]
                        base = name.replace("moe.", f"moe.experts.{expert_id}.")
                        yield base.replace("gate_up_proj", "gate_proj"), gate_weight
                        yield base.replace("gate_up_proj", "up_proj"), up_weight
                    continue

                if "moe.down_proj" in name and weight.dim() == 3:
                    num_experts = weight.size(0)
                    for expert_id in range(num_experts):
                        expert_name = name.replace(
                            "moe.", f"moe.experts.{expert_id}."
                        )
                        yield expert_name, weight[expert_id]
                    continue

                # k_eq_v layers: checkpoint has k_proj but no v_proj.
                # QKVParallelLinear expects both, so duplicate k_proj
                # as v_proj so V gets identical weights to K.
                # ONLY for full_attention layers — sliding layers have
                # their own real v_proj weights.
                if "self_attn.k_proj" in name and k_eq_v_layer_indices:
                    m = layer_idx_re.search(name)
                    if m and int(m.group(1)) in k_eq_v_layer_indices:
                        yield name, weight
                        yield name.replace("k_proj", "v_proj"), weight.clone()
                        continue

                yield name, weight

        # Skip multimodal weights — handled by the multimodal wrapper.
        # Also skip lm_head when weights are tied.
        skip = [
            "audio_tower.",
            "vision_tower.",
            "embed_audio.",
            "embed_vision.",
        ]
        if self.config.tie_word_embeddings:
            skip.append("lm_head.")
        loader = AutoWeightsLoader(self, skip_substrs=skip)
        return loader.load_weights(_weight_iterator())
vllm/model_executor/models/gemma4_mm.py
0 → 100644
View file @
858bddce
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Gemma 4 multimodal model (image + audio + video support).
Adds vision tower, audio tower, and multimodal embedders on top of the
text-only Gemma4ForCausalLM. The vision/audio encoders are loaded via
AutoModel.from_config and run in eager mode while the language model uses
the vLLM-optimized path.
Video support: Gemma4 does **not** have a native video tower. Videos are
decomposed into timestamped image frames (up to 32 frames at 70 soft tokens
each) and fed through the same vision tower as regular images. The
processor inserts ``mm:ss`` timestamps between frames so the model can
reason about temporal order.
"""
import
math
from
collections.abc
import
Iterable
,
Mapping
,
Sequence
from
typing
import
Annotated
,
Any
,
Literal
import
numpy
as
np
import
torch
from
PIL
import
Image
as
PILImage
from
torch
import
nn
from
transformers
import
AutoModel
,
BatchFeature
from
transformers.models.gemma4
import
(
Gemma4Config
,
Gemma4Processor
,
Gemma4VisionConfig
,
)
from
transformers.models.gemma4.configuration_gemma4
import
(
Gemma4AudioConfig
,
Gemma4TextConfig
,
)
from
vllm.config
import
VllmConfig
from
vllm.config.multimodal
import
BaseDummyOptions
,
VideoDummyOptions
# from vllm.inputs import MultiModalDataDict
from
vllm.multimodal.inputs
import
MultiModalDataDict
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
ReplicatedLinear
from
vllm.model_executor.models.gemma4
import
Gemma4ForCausalLM
from
vllm.model_executor.models.module_mapping
import
MultiModelKeys
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.inputs
import
(
MultiModalFieldConfig
,
MultiModalKwargsItems
,
VideoItem
,
)
from
vllm.multimodal.parse
import
(
AudioProcessorItems
,
ImageProcessorItems
,
MultiModalDataItems
,
MultiModalDataParser
,
)
from
vllm.multimodal.processing
import
BaseDummyInputsBuilder
from
vllm.multimodal.processing.processor
import
(
BaseMultiModalProcessor
,
BaseProcessingInfo
,
PromptReplacement
,
PromptUpdate
,
PromptUpdateDetails
,
)
from
vllm.sequence
import
IntermediateTensors
from
vllm.utils.tensor_schema
import
TensorSchema
,
TensorShape
from
.interfaces
import
MultiModalEmbeddings
,
SupportsMultiModal
,
SupportsPP
from
.utils
import
(
AutoWeightsLoader
,
WeightsMapper
,
init_vllm_registered_model
,
maybe_prefix
,
)
# Logger for this module.
logger = init_logger(__name__)

# Video constants — match transformers Gemma4VideoProcessor defaults.
# Soft-token budget per video frame (vs 280 for images).
_VIDEO_MAX_SOFT_TOKENS = 70
# Maximum number of sampled frames per video.
_VIDEO_MAX_FRAMES = 32
# ---------------------------------------------------------------------------
# Input schema
# ---------------------------------------------------------------------------
class Gemma4ImagePixelInputs(TensorSchema):
    """Schema for pre-patchified image tensors from the Gemma4 image processor.

    Dimensions:
        - bn: Batch size * number of images
        - np: Number of patches
          (max_patches = max_soft_tokens * pooling_kernel_size²)
        - pp: Patch pixels (patch_size² * 3)

    ``pixel_values`` arrives from the HF Gemma4ImageProcessor already
    patchified as (batch, max_patches, patch_pixels), zero-padded for
    patches beyond the real image content. ``pixel_position_ids`` carries
    the (x, y) coordinate of each patch, with (-1, -1) marking padding
    patches.
    """

    type: Literal["pixel_values"] = "pixel_values"
    pixel_values: Annotated[torch.Tensor, TensorShape("bn", "np", "pp")]
    pixel_position_ids: Annotated[torch.Tensor, TensorShape("bn", "np", 2)]
class Gemma4AudioInputs(TensorSchema):
    """Schema for padded audio features and their validity mask.

    Dimensions:
        - bn: Batch size * number of audios
        - s: Sequence length (MEL spectrogram frames)
        - f: Number of features (MEL bins)
    """

    type: Literal["audio"] = "audio"
    input_features_padded: Annotated[
        torch.Tensor,
        TensorShape("bn", "s", "f"),
    ]
    input_features_mask: Annotated[
        torch.Tensor,
        TensorShape("bn", "s"),
    ]
# Alias: the pixel-value schema is the only image input schema defined here.
Gemma4ImageInputs = Gemma4ImagePixelInputs
class Gemma4VideoInputs(TensorSchema):
    """Schema for video frames, laid out exactly like image inputs.

    There is no dedicated video tower in Gemma4: frames go through the
    vision tower at a lower resolution (max_soft_tokens=70).
    """

    type: Literal["pixel_values_videos"] = "pixel_values_videos"
    pixel_values_videos: Annotated[
        torch.Tensor, TensorShape("bn", "np", "pp")
    ]
    pixel_position_ids_videos: Annotated[
        torch.Tensor, TensorShape("bn", "np", 2)
    ]
# ---------------------------------------------------------------------------
# Processing info
# ---------------------------------------------------------------------------
class Gemma4ProcessingInfo(BaseProcessingInfo):
    """Processing metadata and prompt-expansion helpers for Gemma4.

    Exposes the HF config/processor, the supported modalities and their
    token budgets, and builds the placeholder token sequences that replace
    the image, audio and video markers in a prompt.
    """

    def get_hf_config(self):
        """Return the HF config, requested from the context as Gemma4Config."""
        return self.ctx.get_hf_config(Gemma4Config)

    def get_default_tok_params(self):
        """Gemma4's chat template already embeds a literal ``<bos>`` token in
        the rendered text.  If ``add_special_tokens=True`` (the base-class
        default), the tokenizer prepends *another* BOS, producing a
        ``[2, 2, ...]`` double-BOS sequence that the model was not trained on.

        Setting ``add_special_tokens=False`` here prevents the duplicate and
        ensures both ``llm.generate()`` and the chat/completions API behave
        correctly.
        """
        params = super().get_default_tok_params()
        return params.with_kwargs(add_special_tokens=False)

    def get_hf_processor(self, **kwargs: object) -> Gemma4Processor:
        """Return the HF ``Gemma4Processor``, forwarding ``kwargs`` to the context."""
        return self.ctx.get_hf_processor(Gemma4Processor, **kwargs)

    def validate_num_items(self, modality: str, num_items: int) -> None:
        """Reject audio items for checkpoints that have no audio tower.

        Raises:
            ValueError: If audio input is supplied but the HF config has
                ``audio_config is None``.
        """
        if (
            modality == "audio"
            and num_items > 0
            and self.get_hf_config().audio_config is None
        ):
            model = self.ctx.model_config.model
            raise ValueError(
                f"Audio input was provided but the model "
                f"'{model}' does not have an audio tower. "
                f"Audio inference is only supported for Gemma4 "
                f"models that include an audio_config."
            )
        super().validate_num_items(modality, num_items)

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        """Return the supported modalities; ``None`` means no fixed item limit.

        Audio is only advertised when the config carries an ``audio_config``.
        """
        limits: dict[str, int | None] = {"image": None}
        if self.get_hf_config().audio_config is not None:
            limits["audio"] = None
        limits["video"] = None
        return limits

    def get_mm_max_tokens_per_item(
        self, seq_len: int, mm_counts: Mapping[str, int]
    ) -> Mapping[str, int] | None:
        """Return the per-item upper bound on placeholder tokens per modality."""
        config = self.get_hf_config()
        # Upper bound: the pooler outputs default_output_length slots
        # per image (280).  After padding is stripped the actual count
        # is ≤ this value, but vLLM needs the max for memory planning.
        tokens_per_image = config.vision_config.default_output_length
        tokens: dict[str, int] = {"image": tokens_per_image}
        if config.audio_config is not None:
            # Audio max tokens from the processor's audio_seq_length.
            processor = self.get_hf_processor()
            tokens["audio"] = processor.audio_seq_length
        # Video: each frame ≤ 70 soft tokens + boi + eoi + ~6 ts tokens.
        tokens["video"] = _VIDEO_MAX_FRAMES * (_VIDEO_MAX_SOFT_TOKENS + 2 + 6)
        return tokens

    def get_data_parser(self) -> MultiModalDataParser:
        """Build the data parser; videos always carry metadata.

        When an audio config is present, the feature extractor's sampling
        rate is passed as ``target_sr``.
        """
        config = self.get_hf_config()
        kwargs: dict[str, Any] = {"video_needs_metadata": True}
        if getattr(config, "audio_config", None) is not None:
            processor = self.get_hf_processor()
            kwargs["target_sr"] = processor.feature_extractor.sampling_rate
        return MultiModalDataParser(**kwargs)

    def _compute_num_soft_tokens(
        self,
        image_width: int,
        image_height: int,
        max_soft_tokens: int | None = None,
    ) -> int:
        """Compute the number of soft tokens the vision tower produces
        for an image of the given dimensions, after padding is stripped.

        Args:
            image_width: Image width in pixels.
            image_height: Image height in pixels.
            max_soft_tokens: Override for the vision config's
                ``default_output_length``.  When *None*, the value from
                the model config is used.

        Raises:
            ValueError: If either dimension is not positive (the scale
                factor below would otherwise divide by zero).
        """
        if image_width <= 0 or image_height <= 0:
            raise ValueError(
                "Image dimensions must be positive, got "
                f"width={image_width}, height={image_height}."
            )

        vision_cfg = self.get_hf_config().vision_config
        patch_size = vision_cfg.patch_size
        pooling_kernel_size = vision_cfg.pooling_kernel_size
        if max_soft_tokens is None:
            max_soft_tokens = vision_cfg.default_output_length

        # Each output side must be a multiple of one pooled patch group.
        unit = patch_size * pooling_kernel_size
        max_patches = max_soft_tokens * pooling_kernel_size**2

        # Aspect-ratio-preserving scale that fits the patch budget.
        num_patches_orig = (image_height / patch_size) * (image_width / patch_size)
        scale = math.sqrt(max_patches / num_patches_orig)

        # Round each side down to a multiple of `unit`, but never below it.
        target_h = max(unit, math.floor(image_height * scale / unit) * unit)
        target_w = max(unit, math.floor(image_width * scale / unit) * unit)

        num_patches = (target_h // patch_size) * (target_w // patch_size)
        return num_patches // (pooling_kernel_size**2)

    def get_image_repl(
        self,
        *,
        image_width: int,
        image_height: int,
        processor: Gemma4Processor | None,
        max_soft_tokens: int | None = None,
    ) -> PromptUpdateDetails[list[int]]:
        """Return the dynamic image token sequence for this image.

        Computes the exact number of soft tokens the vision tower will
        produce after stripping padding, wrapped as
        ``boi + N*image_token + eoi``.

        Args:
            image_width: Image width in pixels.
            image_height: Image height in pixels.
            processor: HF processor; fetched from the context when *None*.
            max_soft_tokens: Override for the default token budget.
                When *None*, falls back to the model config value.
        """
        if processor is None:
            processor = self.get_hf_processor()

        num_soft = self._compute_num_soft_tokens(
            image_width,
            image_height,
            max_soft_tokens=max_soft_tokens,
        )

        config = self.get_hf_config()
        token_ids = (
            [config.boi_token_id]
            + [processor.image_token_id] * num_soft
            + [config.eoi_token_id]
        )
        return PromptUpdateDetails.select_token_id(
            token_ids, processor.image_token_id
        )

    def get_audio_repl(
        self,
        *,
        audio_len: int,
        processor: Gemma4Processor | None,
    ) -> PromptUpdateDetails[list[int]]:
        """Return the dynamic audio token sequence for this audio.

        The soft-token count is delegated to the HF processor's
        ``_compute_audio_num_tokens`` for a waveform of ``audio_len``
        samples; the result is wrapped as ``boa + N*audio_token + eoa``.
        """
        if processor is None:
            processor = self.get_hf_processor()

        sampling_rate = processor.feature_extractor.sampling_rate
        # Only the length of the waveform is known here, so a zero
        # waveform of that length is handed to the processor.
        num_tokens = processor._compute_audio_num_tokens(
            torch.zeros(audio_len), sampling_rate
        )

        config = self.get_hf_config()
        token_ids = (
            [config.boa_token_id]
            + [processor.audio_token_id] * num_tokens
            + [config.eoa_token_id]
        )
        return PromptUpdateDetails.select_token_id(
            token_ids, processor.audio_token_id
        )

    def get_video_repl(
        self,
        *,
        timestamps: list[float],
        num_soft_tokens_per_frame: list[int],
        processor: Gemma4Processor,
    ) -> PromptUpdateDetails[list[int]]:
        """Build the full token replacement for one video.

        Produces the same interleaved sequence as the HF Gemma4Processor:
            mm:ss <boi><|video|>*N<eoi> mm:ss <boi><|video|>*N<eoi> ...

        Args:
            timestamps: Frame timestamps in seconds, one per frame.
            num_soft_tokens_per_frame: Soft-token count for each frame.
            processor: HF processor supplying ``video_token_id``.
        """
        tokenizer = self.ctx.get_tokenizer()
        config = self.get_hf_config()

        boi_token_id = config.boi_token_id
        eoi_token_id = config.eoi_token_id
        video_token_id = processor.video_token_id

        all_token_ids: list[int] = []
        for i, (ts, n_tokens) in enumerate(
            zip(timestamps, num_soft_tokens_per_frame)
        ):
            # mm:ss timestamp — matches transformers: int-truncated,
            # zero-padded.
            minutes = int(ts // 60)
            seconds = int(ts % 60)
            ts_str = f"{minutes:02d}:{seconds:02d}"

            # Frames are space-separated and the timestamp is followed by
            # a space, so every frame after the first also needs the
            # leading separator.
            prefix = f" {ts_str} " if i > 0 else f"{ts_str} "
            ts_token_ids = tokenizer.encode(prefix, add_special_tokens=False)

            all_token_ids.extend(ts_token_ids)
            all_token_ids.append(boi_token_id)
            all_token_ids.extend([video_token_id] * n_tokens)
            all_token_ids.append(eoi_token_id)

        return PromptUpdateDetails.select_token_id(all_token_ids, video_token_id)
# ---------------------------------------------------------------------------
# Dummy inputs builder
# ---------------------------------------------------------------------------
class Gemma4DummyInputsBuilder(BaseDummyInputsBuilder[Gemma4ProcessingInfo]):
    """Builds placeholder prompts and media used for profiling Gemma4."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return a prompt holding one placeholder per requested item."""
        image_count = mm_counts.get("image", 0)
        audio_count = mm_counts.get("audio", 0)
        video_count = mm_counts.get("video", 0)

        hf_processor = self.info.get_hf_processor()

        # Use image_token (<|image|>) with tab prefix — this is what the
        # Gemma4 chat template inserts per image (\t<|image|>).
        # _get_prompt_updates targets image_token and expands it to the
        # full image sequence.
        pieces = [("\t" + hf_processor.image_token) * image_count]
        if audio_count > 0 and hf_processor.audio_token:
            pieces.append(hf_processor.audio_token * audio_count)
        if video_count > 0:
            pieces.append(hf_processor.video_token * video_count)
        return "".join(pieces)

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
    ) -> MultiModalDataDict:
        """Return dummy media for every modality requested in ``mm_counts``."""
        image_count = mm_counts.get("image", 0)
        audio_count = mm_counts.get("audio", 0)
        video_count = mm_counts.get("video", 0)

        hf_processor = self.info.get_hf_processor()

        # Use processor's configured image size for dummies.
        # Gemma4ImageProcessor sets size=None (it uses patch_size /
        # max_soft_tokens instead of the standard size dict), so a
        # missing/None size falls back to 224x224.
        configured_size = getattr(hf_processor.image_processor, "size", None)
        if not configured_size:
            configured_size = {}
        dummy_width = configured_size.get("width", 224)
        dummy_height = configured_size.get("height", 224)

        # A missing or empty options mapping means "no overrides".
        options = mm_options if mm_options else {}
        image_overrides = options.get("image")
        audio_overrides = options.get("audio")
        video_overrides = options.get("video")

        dummy_data: MultiModalDataDict = {
            "image": self._get_dummy_images(
                width=dummy_width,
                height=dummy_height,
                num_images=image_count,
                overrides=image_overrides,
            ),
        }

        if audio_count > 0:
            dummy_data["audio"] = self._get_dummy_audios(
                length=hf_processor.feature_extractor.fft_length,
                num_audios=audio_count,
                overrides=audio_overrides,
            )

        if video_count > 0:
            dummy_data["video"] = self._get_dummy_videos(
                width=dummy_width,
                height=dummy_height,
                num_frames=_VIDEO_MAX_FRAMES,
                num_videos=video_count,
                overrides=video_overrides,
            )

        return dummy_data

    def _get_dummy_videos(
        self,
        *,
        width: int,
        height: int,
        num_frames: int,
        num_videos: int,
        overrides: VideoDummyOptions | None = None,
    ) -> list[VideoItem]:
        """Return ``(frames, metadata)`` pairs for the dummy videos.

        The base-class dummy videos are copied and paired with synthetic
        metadata (2 fps, every frame selected, sampling disabled).
        """
        base_videos = super()._get_dummy_videos(
            width=width,
            height=height,
            num_frames=max(num_frames, 2),
            num_videos=num_videos,
            overrides=overrides,
        )

        items: list[VideoItem] = []
        for frames in [clip.copy() for clip in base_videos]:
            frame_count = frames.shape[0]
            metadata = {
                "fps": 2.0,
                "duration": frame_count / 2.0,
                "total_num_frames": frame_count,
                "frames_indices": list(range(frame_count)),
                "video_backend": "opencv",
                "do_sample_frames": False,
            }
            items.append((frames, metadata))
        return items
# ---------------------------------------------------------------------------
# Multimodal processor
# ---------------------------------------------------------------------------
class Gemma4MultiModalProcessor(BaseMultiModalProcessor[Gemma4ProcessingInfo]):
    """Multimodal processor for Gemma4 (image, audio and video inputs).

    Videos are decomposed into timestamped frames that are run through the
    image path with a reduced soft-token budget; images and audio are
    delegated to the HF processor.
    """

    def _resolve_max_soft_tokens(self, mm_kwargs: Mapping[str, object]) -> Any:
        """Return the effective ``max_soft_tokens`` for ``mm_kwargs``.

        The per-prompt kwargs are merged with the config-level defaults;
        a top-level ``max_soft_tokens`` wins, otherwise the value nested
        under ``images_kwargs`` is used. Returns *None* when neither is set.
        """
        merged_kwargs = self.info.ctx.get_merged_mm_kwargs(mm_kwargs)
        val = merged_kwargs.get("max_soft_tokens")
        if val is None:
            images_kwargs = merged_kwargs.get("images_kwargs") or {}
            val = images_kwargs.get("max_soft_tokens")
        return val

    def _process_videos(
        self,
        prompt: str,
        videos: Any,
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> tuple[str, dict[str, Any]]:
        """Run every video's frames through the image path.

        Each ``<|video|>`` placeholder in ``prompt`` is expanded, in order,
        to the timestamped per-frame text for the matching video.

        Returns:
            The updated prompt and the video tensors/metadata to merge into
            the processor outputs.
        """
        processor = self.info.get_hf_processor()
        video_token = processor.video_token

        all_video_pixel_values: list[torch.Tensor] = []
        all_video_position_ids: list[torch.Tensor] = []
        video_num_soft_tokens_per_video: list[list[int]] = []
        video_timestamps_per_video: list[list[float]] = []
        video_frame_counts: list[int] = []

        # Where to look for the next placeholder. Each replacement itself
        # contains <|video|> tokens, so the search must resume after the
        # text inserted for the previous video.
        search_from = 0

        for video_array, metadata in videos:
            # Convert frames to PIL images
            if isinstance(video_array, np.ndarray):
                frames = [PILImage.fromarray(frame) for frame in video_array]
            else:
                frames = list(video_array)

            # Compute timestamps from metadata (same as transformers)
            fps = metadata.get("fps") or 24
            frame_indices = metadata.get(
                "frames_indices", list(range(len(frames)))
            )
            timestamps = [idx / fps for idx in frame_indices]

            # Process frames as images with max_soft_tokens=70
            video_mm_kwargs = dict(mm_kwargs)
            video_mm_kwargs["max_soft_tokens"] = _VIDEO_MAX_SOFT_TOKENS

            dummy_prompt = ("\t" + processor.image_token) * len(frames)
            frame_outputs = super()._call_hf_processor(
                prompt=dummy_prompt,
                mm_data={"images": frames},
                mm_kwargs=video_mm_kwargs,
                tok_kwargs=tok_kwargs,
            )

            # Remap HF key name
            if "image_position_ids" in frame_outputs:
                frame_outputs["pixel_position_ids"] = frame_outputs.pop(
                    "image_position_ids"
                )

            all_video_pixel_values.append(frame_outputs["pixel_values"])
            all_video_position_ids.append(frame_outputs["pixel_position_ids"])

            # Compute soft tokens per frame (img.size is (width, height)).
            num_soft_per_frame = [
                self.info._compute_num_soft_tokens(
                    *img.size, max_soft_tokens=_VIDEO_MAX_SOFT_TOKENS
                )
                for img in frames
            ]

            video_num_soft_tokens_per_video.append(num_soft_per_frame)
            video_timestamps_per_video.append(timestamps)
            video_frame_counts.append(len(frames))

            # Build the expanded replacement text for this video.
            ts_strs = [
                f"{int(s // 60):02d}:{int(s % 60):02d}" for s in timestamps
            ]
            replacement = " ".join(
                f"{t} {processor.boi_token}"
                f"{processor.video_token * n}"
                f"{processor.eoi_token}"
                for t, n in zip(ts_strs, num_soft_per_frame)
            )

            # Replace the next untouched <|video|> placeholder; leave the
            # prompt unchanged when none is left.
            placeholder_at = prompt.find(video_token, search_from)
            if placeholder_at != -1:
                prompt = (
                    prompt[:placeholder_at]
                    + replacement
                    + prompt[placeholder_at + len(video_token):]
                )
                search_from = placeholder_at + len(replacement)

        video_outputs: dict[str, Any] = {
            "pixel_values_videos": torch.cat(all_video_pixel_values, dim=0),
            "pixel_position_ids_videos": torch.cat(
                all_video_position_ids, dim=0
            ),
            "video_frame_counts": torch.tensor(video_frame_counts),
            "video_num_soft_tokens": video_num_soft_tokens_per_video,
            "video_timestamps": video_timestamps_per_video,
        }
        return prompt, video_outputs

    def _warn_if_audio_too_long(self, audios: Any) -> None:
        """Log a warning for every waveform longer than the model's maximum."""
        processor = self.info.get_hf_processor()
        sr = processor.feature_extractor.sampling_rate
        max_tokens = processor.audio_seq_length
        ms_per_tok = processor.audio_ms_per_token
        max_duration_s = max_tokens * ms_per_tok / 1000.0

        if not isinstance(audios, (list, tuple)):
            audios = [audios]

        for waveform in audios:
            duration_s = len(waveform) / sr
            if duration_s > max_duration_s:
                logger.warning(
                    "Audio duration exceeds max: %f > %f seconds",
                    duration_s,
                    max_duration_s,
                )

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Run the HF processor, handling videos and audio specially.

        Raises:
            ValueError: If the resolved ``max_soft_tokens`` is not one of
                the supported values.
        """
        # Validate max_soft_tokens early and exit cleanly on bad values.
        _SUPPORTED_SOFT_TOKENS = (70, 140, 280, 560, 1120)
        val = self._resolve_max_soft_tokens(mm_kwargs)
        if val is not None and val not in _SUPPORTED_SOFT_TOKENS:
            raise ValueError(
                f"Unsupported max_soft_tokens value: {val}. "
                f"Valid values are {_SUPPORTED_SOFT_TOKENS}."
            )

        mm_data = dict(mm_data)

        # ---- VIDEO HANDLING ----
        # Gemma4 decomposes video into timestamped image frames.
        # Each frame is processed with max_soft_tokens=70 through the
        # same vision tower, matching transformers processing_gemma4.py.
        video_outputs: dict[str, Any] = {}
        if videos := mm_data.pop("videos", []):
            prompt, video_outputs = self._process_videos(
                prompt, videos, mm_kwargs, tok_kwargs
            )

        # The processor accepts 'audio' not 'audios'.
        if "audios" in mm_data:
            mm_data["audio"] = mm_data.pop("audios")

        # Warn if any audio waveform exceeds the model's max duration.
        if "audio" in mm_data:
            self._warn_if_audio_too_long(mm_data["audio"])

        # vLLM's call_hf_processor (context.py) re-merges
        # mm_processor_kwargs from the model config on every call via:
        #   config_kwargs | incoming_kwargs   (right side wins)
        #
        # If we strip max_soft_tokens from incoming, the re-merge puts
        # back the config's global default (e.g. 280), ignoring any
        # per-prompt override.  Instead, we keep it in the kwargs with
        # the validated per-prompt value so it wins during the merge.
        #
        # NOTE: This requires a corresponding type annotation on the
        # HF side (Gemma4ProcessorKwargs.images_kwargs) so that
        # _merge_kwargs routes max_soft_tokens into images_kwargs.
        patched_mm_kwargs = dict(mm_kwargs)
        if val is not None:
            patched_mm_kwargs["max_soft_tokens"] = val

        processed_outputs = super()._call_hf_processor(
            prompt,
            mm_data,
            patched_mm_kwargs,
            tok_kwargs,
        )

        # HF uses 'image_position_ids'; vLLM uses 'pixel_position_ids'.
        # Remap here to keep a single translation point.
        if "image_position_ids" in processed_outputs:
            processed_outputs["pixel_position_ids"] = processed_outputs.pop(
                "image_position_ids"
            )

        if "input_features" in processed_outputs:
            # Keep padded features for batched audio tower execution.
            processed_outputs["input_features_padded"] = processed_outputs[
                "input_features"
            ]
            # Unpad per-item so each item's cache entry is self-contained.
            unpadded_features = [
                f[mask]
                for f, mask in zip(
                    processed_outputs["input_features"],
                    processed_outputs["input_features_mask"],
                )
            ]
            processed_outputs["input_features"] = unpadded_features

        # Merge video outputs into the final result
        combined_outputs = dict(processed_outputs, **video_outputs)
        return BatchFeature(combined_outputs)

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Describe how each processor output is split across items."""
        fields = dict(
            pixel_values=MultiModalFieldConfig.batched("image"),
            pixel_position_ids=MultiModalFieldConfig.batched("image"),
            input_features_padded=MultiModalFieldConfig.batched("audio"),
            input_features_mask=MultiModalFieldConfig.batched("audio"),
        )

        # Video fields: frames stored flat, split per video by
        # video_frame_counts.
        video_frame_counts = hf_inputs.get("video_frame_counts")
        if video_frame_counts is not None:
            vfc = video_frame_counts
            if not isinstance(vfc, torch.Tensor):
                vfc = torch.tensor(vfc)
            fields.update(
                pixel_values_videos=(
                    MultiModalFieldConfig.flat_from_sizes("video", vfc)
                ),
                pixel_position_ids_videos=(
                    MultiModalFieldConfig.flat_from_sizes("video", vfc)
                ),
                video_frame_counts=MultiModalFieldConfig.batched("video"),
                video_num_soft_tokens=MultiModalFieldConfig.batched(
                    "video", keep_on_cpu=True
                ),
                video_timestamps=MultiModalFieldConfig.batched(
                    "video", keep_on_cpu=True
                ),
            )

        return fields

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, Any],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Return one placeholder replacement per modality present."""
        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)

        prompt_updates = []

        if "image" in mm_items:
            # Target image_token (<|image|>) — the single placeholder the
            # Gemma4 chat template inserts once per image in the prompt.
            # vLLM tokenizes the prompt without token expansion, so only
            # one image_token exists per image in the token stream.
            # The replacement expands it to the full image sequence
            # (boi + N×image_token + eoi, where N = max_soft_tokens).
            image_token = hf_processor.image_token

            # Resolve the effective max_soft_tokens the same way
            # _call_hf_processor does (per-prompt kwargs merged with the
            # config-level defaults, including the images_kwargs
            # fallback), so the placeholder count matches the features.
            max_soft_tokens = self._resolve_max_soft_tokens(
                hf_processor_mm_kwargs
            )

            def get_replacement_image(item_idx: int):
                images = mm_items.get_items("image", ImageProcessorItems)
                image_size = images.get_image_size(item_idx)
                return self.info.get_image_repl(
                    image_width=image_size.width,
                    image_height=image_size.height,
                    processor=hf_processor,
                    max_soft_tokens=max_soft_tokens,
                )

            prompt_updates.append(
                PromptReplacement(
                    modality="image",
                    target=image_token,
                    replacement=get_replacement_image,
                )
            )

        if "video" in mm_items:
            video_token = hf_processor.video_token

            def get_replacement_video(item_idx: int):
                out_item = out_mm_kwargs["video"][item_idx]
                timestamps = out_item["video_timestamps"].data
                num_soft = out_item["video_num_soft_tokens"].data
                return self.info.get_video_repl(
                    timestamps=timestamps,
                    num_soft_tokens_per_frame=num_soft,
                    processor=hf_processor,
                )

            prompt_updates.append(
                PromptReplacement(
                    modality="video",
                    target=video_token,
                    replacement=get_replacement_video,
                )
            )

        if "audio" in mm_items:
            audio_token = hf_processor.audio_token

            def get_replacement_audio(item_idx: int):
                audios = mm_items.get_items("audio", AudioProcessorItems)
                audio_len = audios.get_audio_length(item_idx)
                return self.info.get_audio_repl(
                    audio_len=audio_len,
                    processor=hf_processor,
                )

            prompt_updates.append(
                PromptReplacement(
                    modality="audio",
                    target=audio_token,
                    replacement=get_replacement_audio,
                )
            )

        return prompt_updates

    # NOTE: Gemma3/Gemma3n override _apply_token_matches and
    # _find_mm_placeholders to merge adjacent newline tokens that arise
    # when full_image_sequence contains "\n\n" wrappers. Gemma4's
    # full_image_sequence has NO newlines (just BOI + 280×image_token +
    # EOI), so the base class implementations work correctly as-is.
# ---------------------------------------------------------------------------
# Multimodal embedder
# ---------------------------------------------------------------------------
class Gemma4MultimodalEmbedder(nn.Module):
    """Maps vision/audio soft tokens into the language model's embedding space.

    Pipeline:
        inputs_embeds → embedding_projection → embedding_post_projection_norm

    Gemma3n keeps separate hard/soft embedding paths with per-path
    normalization and a learned embedding table. Gemma4 is simpler: one
    bias-free linear projection followed by an RMSNorm without a learnable
    scale. The checkpoint only contains ``embedding_projection.weight``;
    it has no embedding table or pre-projection norm weights.
    """

    def __init__(
        self,
        multimodal_config: Gemma4VisionConfig | Gemma4AudioConfig,
        text_config: Gemma4TextConfig,
    ):
        super().__init__()

        self.eps = multimodal_config.rms_norm_eps
        self.text_hidden_size = text_config.hidden_size

        # Audio tower uses output_proj_dims (1536) rather than hidden_size
        # (1024); vision uses hidden_size (768) directly.
        tower_output_dim = getattr(multimodal_config, "output_proj_dims", None)
        if not tower_output_dim:
            tower_output_dim = multimodal_config.hidden_size

        self.embedding_projection = ReplicatedLinear(
            tower_output_dim,
            self.text_hidden_size,
            bias=False,
        )
        self.embedding_post_projection_norm = RMSNorm(
            self.text_hidden_size,
            eps=self.eps,
            has_weight=False,
        )

    def forward(self, inputs_embeds: torch.Tensor) -> torch.Tensor:
        """Project soft tokens from a multimodal tower into LM space."""
        # The projection layer returns a pair; only the first element
        # (the projected tensor) is used.
        projected, _ = self.embedding_projection(inputs_embeds)
        normalized = self.embedding_post_projection_norm(projected)
        return normalized
# ---------------------------------------------------------------------------
# Main model
# ---------------------------------------------------------------------------
@MULTIMODAL_REGISTRY.register_processor(
    Gemma4MultiModalProcessor,
    info=Gemma4ProcessingInfo,
    dummy_inputs=Gemma4DummyInputsBuilder,
)
class Gemma4ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
    """Gemma4 multimodal model registered with ``MULTIMODAL_REGISTRY``.

    The class-level tables below describe how checkpoint weight names relate
    to this module's parameter names.
    """

    # Fused parameter name -> the per-projection names it is packed from.
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    # Maps checkpoint prefixes to vLLM module paths.
    # NOTE(review): the final "model" key has no trailing dot, so it matches
    # any name that starts with "model"; confirm how WeightsMapper orders and
    # applies these prefix rules before relying on it as a catch-all.
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            "model.embed_audio.": "embed_audio.",
            "model.embed_vision.": "embed_vision.",
            "model.language_model.": "language_model.model.",
            "model.vision_tower.": "vision_tower.",
            "model.audio_tower.": "audio_tower.",
            "lm_head.": "language_model.lm_head.",
            "model": "language_model.model",
        }
    )
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
    """Build the vision tower, optional audio tower and the language model.

    Args:
        vllm_config: Engine configuration; the HF config, quantisation
            config, multimodal config and scheduler limits are read from it.
        prefix: Module-path prefix forwarded (via ``maybe_prefix``) to the
            language model.

    NOTE(review): the source this was reconstructed from had its
    indentation stripped; the extents of the ``with`` blocks below follow
    the most plausible reading and should be confirmed against the
    original file.
    """
    super().__init__()
    config = vllm_config.model_config.hf_config
    quant_config = vllm_config.quant_config
    multimodal_config = vllm_config.model_config.multimodal_config
    self.config = config
    self.quant_config = quant_config
    self.multimodal_config = multimodal_config

    # ---- Vision tower (shared by image and video) ----
    with self._mark_tower_model(vllm_config, {"image", "video"}):
        self.vision_tower = AutoModel.from_config(config=config.vision_config)
        self.embed_vision = Gemma4MultimodalEmbedder(
            config.vision_config, config.text_config
        )

    # ---- Audio tower (variants with audio_config) ----
    if config.audio_config is not None:
        with self._mark_tower_model(vllm_config, "audio"):
            self.audio_tower = AutoModel.from_config(config=config.audio_config)
            # AutoModel.from_config does NOT call post_init(),
            # which is needed to initialize buffers that are absent
            # from the checkpoint (e.g. inv_timescales for relative
            # position embeddings, softcap, gradient_clipping).
            self.audio_tower.post_init()
            self.embed_audio = Gemma4MultimodalEmbedder(
                config.audio_config, config.text_config
            )
    else:
        # Both attributes are kept (as None) so later code can test
        # ``self.audio_tower is None`` instead of using hasattr().
        self.audio_tower = None
        self.embed_audio = None

    # ---- Language model (vLLM optimised) ----
    with self._mark_language_model(vllm_config):
        self.language_model: Gemma4ForCausalLM = init_vllm_registered_model(
            vllm_config=vllm_config,
            hf_config=config.text_config,
            prefix=maybe_prefix(prefix, "language_model"),
            architectures=["Gemma4ForCausalLM"],
        )

        # Pre-allocate PLE buffer for CUDA graph compatibility.
        # Some variants have hidden_size_per_layer_input=None (no PLE).
        ple_dim = config.text_config.hidden_size_per_layer_input
        if ple_dim is not None:
            # One row per schedulable token; device and dtype follow the
            # language model's token-embedding weight.
            self.per_layer_embeddings = torch.zeros(
                vllm_config.scheduler_config.max_num_batched_tokens,
                config.text_config.num_hidden_layers,
                ple_dim,
                device=(self.language_model.model.embed_tokens.weight.device),
                dtype=(self.language_model.model.embed_tokens.weight.dtype),
            )
        else:
            self.per_layer_embeddings = None

    # Expose the language model's factory under this wrapper's name.
    self.make_empty_intermediate_tensors = (
        self.language_model.make_empty_intermediate_tensors
    )

    # --- MixtureOfExperts delegation to language_model ---
    # Plain attribute copies taken once at construction time.
    # NOTE(review): assumes the language model always defines these
    # attributes, including for non-MoE variants — confirm.
    self.expert_weights = self.language_model.expert_weights
    self.moe_layers = self.language_model.moe_layers
    self.num_moe_layers = self.language_model.num_moe_layers
    self.num_logical_experts = self.language_model.num_logical_experts
    self.num_physical_experts = self.language_model.num_physical_experts
    self.num_local_physical_experts = self.language_model.num_local_physical_experts
    self.num_routed_experts = self.language_model.num_routed_experts
    self.num_expert_groups = self.language_model.num_expert_groups
    self.num_shared_experts = self.language_model.num_shared_experts
    self.num_redundant_experts = self.language_model.num_redundant_experts
# ------------------------------------------------------------------ #
# Input parsing
# ------------------------------------------------------------------ #
def _parse_and_validate_image_input(
    self, **kwargs: object
) -> Gemma4ImageInputs | None:
    """Extract image tensors from the multimodal keyword arguments.

    Args:
        **kwargs: Multimodal inputs; ``pixel_values``,
            ``pixel_position_ids`` and ``image_embeds`` are consumed.

    Returns:
        A ``Gemma4ImagePixelInputs`` built from the pixel values and their
        position ids, or ``None`` when no ``pixel_values`` were supplied.

    Raises:
        ValueError: If ``image_embeds`` is supplied; Gemma4 does not accept
            precomputed image embeddings.
    """
    pixel_values = kwargs.pop("pixel_values", None)
    pixel_position_ids = kwargs.pop("pixel_position_ids", None)
    image_embeds = kwargs.pop("image_embeds", None)

    # Reject unsupported input explicitly. An ``assert`` would vanish under
    # ``python -O`` and let the embeds be silently ignored.
    if image_embeds is not None:
        raise ValueError("Gemma4 does not support image_embeds.")

    if pixel_values is None:
        return None

    return Gemma4ImagePixelInputs(
        pixel_values=pixel_values,
        pixel_position_ids=pixel_position_ids,
    )
def _parse_and_validate_audio_input(
    self, **kwargs: object
) -> Gemma4AudioInputs | None:
    """Extract audio features and their mask from the keyword arguments.

    Returns ``None`` when either ``input_features_padded`` or
    ``input_features_mask`` is missing; otherwise a ``Gemma4AudioInputs``
    carrying both.
    """
    if (features := kwargs.pop("input_features_padded", None)) is None:
        return None
    # The mask is only looked up once features are known to be present.
    if (features_mask := kwargs.pop("input_features_mask", None)) is None:
        return None
    return Gemma4AudioInputs(
        input_features_padded=features,
        input_features_mask=features_mask,
    )
def _parse_and_validate_video_input(
    self, **kwargs: object
) -> dict[str, torch.Tensor] | None:
    """Collect the video-related keyword arguments into one dict.

    Returns ``None`` when ``pixel_values_videos`` is absent. Otherwise
    returns a dict with the keys ``pixel_values_videos``,
    ``pixel_position_ids_videos`` and ``video_frame_counts`` (the latter two
    are ``None`` if they were not supplied).
    """
    video_keys = (
        "pixel_values_videos",
        "pixel_position_ids_videos",
        "video_frame_counts",
    )
    collected = {key: kwargs.pop(key, None) for key in video_keys}
    if collected["pixel_values_videos"] is None:
        return None
    return collected
def _parse_and_validate_multimodal_inputs(
    self, **kwargs: object
) -> dict[str, Gemma4ImageInputs | Gemma4AudioInputs | Gemma4VideoInputs | None]:
    """Group the raw keyword arguments by modality.

    The keyword arguments are scanned in their given order. The first key
    that identifies a modality triggers that modality's parser (called with
    the full set of keyword arguments), so the returned dict preserves the
    order in which modalities first appear. A parser may return ``None``,
    which is stored as-is.
    """
    # Trigger key -> (modality name, name of the parser method on self).
    triggers = {
        "pixel_values": ("image", "_parse_and_validate_image_input"),
        "image_embeds": ("image", "_parse_and_validate_image_input"),
        "pixel_values_videos": ("video", "_parse_and_validate_video_input"),
        "input_features_padded": ("audio", "_parse_and_validate_audio_input"),
    }

    parsed = {}
    for key in list(kwargs):
        trigger = triggers.get(key)
        if trigger is None:
            continue
        modality, parser_name = trigger
        if modality in parsed:
            # Each modality is parsed at most once.
            continue
        parsed[modality] = getattr(self, parser_name)(**kwargs)
    return parsed
# ------------------------------------------------------------------ #
# Image processing
# ------------------------------------------------------------------ #
def _process_image_input(
    self,
    image_input: Gemma4ImageInputs,
) -> list[torch.Tensor]:
    """Encode each image and project its features into LM embedding space.

    Per the original author's notes, the HF image processor delivers
    pre-patchified data (``pixel_values`` as (num_images, max_patches,
    patch_pixels) and ``pixel_position_ids`` as (num_images, max_patches,
    2)), and the vision tower's forward handles patch embedding, encoding,
    pooling and padding removal, returning a flat tensor of valid tokens.

    Returns:
        One projected tensor per image, in input order. Token counts can
        differ between images, which is why images are handled one by one.
    """
    pixel_values = image_input["pixel_values"]
    pixel_position_ids = image_input["pixel_position_ids"]

    tower = self.vision_tower
    patches_per_token = self.config.vision_config.pooling_kernel_size**2

    # TODO: Move this per-image loop into the input processor to reduce
    # dynamism at the model runner / engine core (requires padding all
    # images to a uniform size in _call_hf_processor() and validating
    # numerical equivalence with this per-image path).
    def encode_one(index):
        patches = pixel_values[index].unsqueeze(0)  # (1, max_patches, patch_pixels)
        positions = pixel_position_ids[index].unsqueeze(0)  # (1, max_patches, 2)
        # The pooler's output_length is derived from the total patch count
        # (padding included): the processor allocates
        # max_patches = max_soft_tokens * k**2 slots, so dividing by k**2
        # recovers max_soft_tokens. Without it the pooler would fall back
        # to config.image_seq_length, which breaks when a different
        # max_soft_tokens was used during preprocessing.
        tower_out = tower(
            patches,
            positions,
            output_length=patches.shape[1] // patches_per_token,
        )
        # Already flat with padding stripped by the vision tower.
        return tower_out.last_hidden_state

    tower_features = [encode_one(i) for i in range(pixel_values.shape[0])]

    # Cast to the projection's dtype before projecting (the model may be
    # bf16 while the vision tower emits fp32).
    proj_dtype = self.embed_vision.embedding_projection.weight.dtype
    projected = []
    for features in tower_features:
        batched = features.unsqueeze(0).to(proj_dtype)
        projected.append(self.embed_vision(inputs_embeds=batched).squeeze(0))
    return projected
# ------------------------------------------------------------------ #
# Video processing (frames through vision tower)
# ------------------------------------------------------------------ #
def _process_video_input(
    self,
    video_input: dict[str, torch.Tensor],
) -> list[torch.Tensor]:
    """Encode video frames with the vision tower, one result per video.

    Frames go through the same vision tower and projection as still images
    (the original author notes Gemma4 has no dedicated video tower and that
    frames are processed at a lower resolution). The flat frame tensors are
    split by ``video_frame_counts`` and the frame embeddings of each video
    are concatenated, because one video is treated as a single multimodal
    item.

    Returns:
        One concatenated embedding tensor per video, in input order.
    """
    all_frames = video_input["pixel_values_videos"]
    all_positions = video_input["pixel_position_ids_videos"]
    counts = video_input["video_frame_counts"]

    tower = self.vision_tower
    patches_per_token = self.config.vision_config.pooling_kernel_size**2
    proj_dtype = self.embed_vision.embedding_projection.weight.dtype

    # torch.split needs plain Python ints for the section sizes.
    sizes = counts.tolist() if isinstance(counts, torch.Tensor) else list(counts)

    def embed_frame(patches, positions):
        # Same output_length derivation as for images: total patch slots
        # divided by the squared pooling kernel size.
        tower_out = tower(
            patches,
            positions,
            output_length=patches.shape[1] // patches_per_token,
        )
        batched = tower_out.last_hidden_state.unsqueeze(0).to(proj_dtype)
        return self.embed_vision(inputs_embeds=batched).squeeze(0)

    frame_groups = torch.split(all_frames, sizes, dim=0)
    position_groups = torch.split(all_positions, sizes, dim=0)

    videos = []
    for video_frames, video_positions in zip(frame_groups, position_groups):
        per_frame = [
            embed_frame(
                video_frames[idx].unsqueeze(0),
                video_positions[idx].unsqueeze(0),
            )
            for idx in range(video_frames.shape[0])
        ]
        # Concatenate all frames of this video into one tensor.
        videos.append(torch.cat(per_frame, dim=0))
    return videos
# ------------------------------------------------------------------ #
# Audio processing
# ------------------------------------------------------------------ #
def
_process_audio_input
(
self
,
audio_input
:
Gemma4AudioInputs
,
)
->
list
[
torch
.
Tensor
]:
input_features
=
audio_input
[
"input_features_padded"
].
squeeze
(
1
)
input_features_mask
=
audio_input
[
"input_features_mask"
].
squeeze
(
1
)
# Run audio tower — mask uses standard HF convention
# (True=valid, False=padding).
audio_outputs
=
self
.
audio_tower
(
input_features
,
input_features_mask
)
if
isinstance
(
audio_outputs
,
tuple
):
audio_encodings
,
audio_mask
=
audio_outputs
else
:
audio_encodings
=
audio_outputs
.
last_hidden_state
audio_mask
=
audio_outputs
.
attention_mask
# Project into LM embedding space.
audio_features
=
self
.
embed_audio
(
inputs_embeds
=
audio_encodings
)
# Strip padding per-batch element: only keep real (non-padding)
# tokens. audio_mask is True for valid positions (HF convention).
per_audio
=
[]
for
enc
,
mask
in
zip
(
audio_features
,
audio_mask
,
strict
=
True
):
per_audio
.
append
(
enc
[
mask
])
# [num_real, hidden_size]
return
per_audio
# ------------------------------------------------------------------ #
# MultiModalEmbeddings interface
# ------------------------------------------------------------------ #
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
    """Turn raw multimodal keyword arguments into a flat list of embeddings.

    Inputs are grouped by modality first; each non-``None`` group is then
    handed to its modality-specific processor and the resulting tensors are
    appended in the order the modalities were discovered.
    """
    # Modality name -> name of the processor method on self.
    processors = {
        "image": "_process_image_input",
        "video": "_process_video_input",
        "audio": "_process_audio_input",
    }

    grouped = self._parse_and_validate_multimodal_inputs(**kwargs)
    embeddings: list[torch.Tensor] = []
    for modality, payload in grouped.items():
        if payload is None:
            continue
        processor_name = processors.get(modality)
        if processor_name is None:
            # Unknown modalities are ignored.
            continue
        embeddings.extend(getattr(self, processor_name)(payload))
    return embeddings
def embed_input_ids(
    self,
    input_ids: torch.Tensor,
    multimodal_embeddings: MultiModalEmbeddings | None = None,
    *,
    is_multimodal: torch.Tensor | None = None,
) -> torch.Tensor:
    """Embed token ids and refresh the cached per-layer embeddings (PLE).

    Args:
        input_ids: Token ids to embed.
        multimodal_embeddings: Embeddings to merge in at multimodal
            positions, if any.
        is_multimodal: Mask selecting the multimodal positions in
            ``input_ids``; presumably boolean with the same shape as
            ``input_ids`` — confirm against the caller.

    Returns:
        The result of the parent class's ``embed_input_ids``.

    Side effects:
        When the PLE buffer exists and the language model yields per-layer
        inputs, the leading rows of ``self.per_layer_embeddings`` are
        overwritten in place.
    """
    # Cache per-layer embeddings (PLE) for the language model's
    # forward pass. During profiling embed_input_ids is not called,
    # so the pre-allocated zeros are used instead.
    if self.per_layer_embeddings is not None:
        # Mask multimodal tokens (image/audio) to 0 for PLE
        # computation (using token_type_ids == 0 as text_mask).
        # Replicate this: map image token positions to token 0.
        if is_multimodal is not None:
            # The device-moved mask replaces the argument, so it is also
            # the one passed to the parent call at the end.
            is_multimodal = is_multimodal.to(input_ids.device)
            ple_input_ids = torch.where(
                is_multimodal, torch.zeros_like(input_ids), input_ids
            )
        else:
            ple_input_ids = input_ids

        per_layer_inputs = self.language_model.model.get_per_layer_inputs(
            ple_input_ids
        )
        if per_layer_inputs is not None:
            # Reshape to (tokens, num_hidden_layers, per-layer width), the
            # same trailing dims the buffer was allocated with.
            per_layer_inputs = per_layer_inputs.reshape(
                -1,
                self.config.text_config.num_hidden_layers,
                self.config.text_config.hidden_size_per_layer_input,
            )
            # In-place copy into the leading rows keeps the buffer object
            # itself unchanged.
            self.per_layer_embeddings[: per_layer_inputs.shape[0]].copy_(
                per_layer_inputs
            )

    # Without both embeddings and a mask there is nothing to merge.
    if multimodal_embeddings is None or is_multimodal is None:
        return super().embed_input_ids(input_ids)

    return super().embed_input_ids(
        input_ids,
        multimodal_embeddings=multimodal_embeddings,
        is_multimodal=is_multimodal,
    )
# ------------------------------------------------------------------ #
# Forward
# ------------------------------------------------------------------ #
def forward(
    self,
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    intermediate_tensors: IntermediateTensors | None = None,
    inputs_embeds: torch.Tensor | None = None,
    **kwargs: object,
) -> IntermediateTensors:
    """Run the language model, feeding it the cached per-layer embeddings.

    When ``intermediate_tensors`` is given, ``inputs_embeds`` is discarded.
    The per-layer inputs are the leading ``inputs_embeds.shape[0]`` rows of
    the PLE buffer, or ``None`` when the buffer is disabled or no input
    embeddings are in play.

    Returns:
        Whatever ``self.language_model.model`` returns.
    """
    if intermediate_tensors is not None:
        inputs_embeds = None

    # Select the pre-cached PLEs for this batch (None for variants
    # without PLE).
    ple_slice = None
    if self.per_layer_embeddings is not None and inputs_embeds is not None:
        ple_slice = self.per_layer_embeddings[: inputs_embeds.shape[0]]

    return self.language_model.model(
        input_ids,
        positions,
        per_layer_inputs=ple_slice,
        intermediate_tensors=intermediate_tensors,
        inputs_embeds=inputs_embeds,
        **kwargs,
    )
def compute_logits(
    self,
    hidden_states: torch.Tensor,
) -> torch.Tensor | None:
    """Delegate logit computation to the wrapped language model."""
    language_model = self.language_model
    logits = language_model.compute_logits(hidden_states)
    return logits
# ------------------------------------------------------------------ #
# Weight loading
# ------------------------------------------------------------------ #
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
    """Load checkpoint weights through ``AutoWeightsLoader``.

    Args:
        weights: ``(name, tensor)`` pairs as stored in the checkpoint.

    Returns:
        The set returned by the loader's ``load_weights`` call.
    """
    # Per the original author's note, some checkpoints still carry
    # embed_vision.embedding / embed_audio.embedding weights left over from
    # Gemma3n; this model has no such parameters, so they are ignored.
    skipped_prefixes = [
        "embed_vision.embedding.",
        "embed_audio.embedding.",
    ]

    # Without an audio tower, every audio-related weight is ignored too.
    if self.audio_tower is None:
        skipped_prefixes += [
            "audio_tower.",
            "embed_audio.",
        ]

    weights_loader = AutoWeightsLoader(
        self,
        ignore_unexpected_prefixes=skipped_prefixes,
    )
    return weights_loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
# ------------------------------------------------------------------ #
# LoRA / multimodal mapping
# ------------------------------------------------------------------ #
def get_mm_mapping(self) -> MultiModelKeys:
    """Return the module-prefix mapping used for multimodal models.

    Names the language model, the two embedders acting as connectors, and
    the two tower modules.
    """
    connector_modules = ["embed_vision", "embed_audio"]
    tower_modules = ["vision_tower", "audio_tower"]
    return MultiModelKeys.from_string_field(
        language_model="language_model",
        connector=connector_modules,
        tower_model=tower_modules,
    )
@
classmethod
def
get_placeholder_str
(
cls
,
modality
:
str
,
i
:
int
)
->
str
|
None
:
if
modality
==
"image"
:
return
"<image_soft_token>"
if
modality
==
"audio"
:
return
"<audio_soft_token>"
if
modality
==
"video"
:
return
"<|video|>"
raise
ValueError
(
f
"Unsupported modality:
{
modality
}
"
)
vllm/model_executor/models/registry.py
View file @
858bddce
...
...
@@ -111,6 +111,7 @@ _TEXT_GENERATION_MODELS = {
"Gemma2ForCausalLM"
:
(
"gemma2"
,
"Gemma2ForCausalLM"
),
"Gemma3ForCausalLM"
:
(
"gemma3"
,
"Gemma3ForCausalLM"
),
"Gemma3nForCausalLM"
:
(
"gemma3n"
,
"Gemma3nForCausalLM"
),
"Gemma4ForCausalLM"
:
(
"gemma4"
,
"Gemma4ForCausalLM"
),
"Qwen3NextForCausalLM"
:
(
"qwen3_next"
,
"Qwen3NextForCausalLM"
),
"GlmForCausalLM"
:
(
"glm"
,
"GlmForCausalLM"
),
"Glm4ForCausalLM"
:
(
"glm4"
,
"Glm4ForCausalLM"
),
...
...
@@ -377,6 +378,7 @@ _MULTIMODAL_MODELS = {
"gemma3n_mm"
,
"Gemma3nForConditionalGeneration"
,
),
"Gemma4ForConditionalGeneration"
:
(
"gemma4_mm"
,
"Gemma4ForConditionalGeneration"
),
"GlmAsrForConditionalGeneration"
:
(
"glmasr"
,
"GlmAsrForConditionalGeneration"
),
"GLM4VForCausalLM"
:
(
"glm4v"
,
"GLM4VForCausalLM"
),
"Glm4vForConditionalGeneration"
:
(
"glm4_1v"
,
"Glm4vForConditionalGeneration"
),
...
...
vllm/model_executor/models/utils.py
View file @
858bddce
...
...
@@ -233,8 +233,15 @@ class AutoWeightsLoader:
):
"""
Add tensor names that are not in the model params that may be in the
safetensors, e.g., batch normalization stats.
safetensors, e.g., batch normalization stats
and registered buffers
.
"""
# Add persistent registered buffers.
# Non-persistent buffers are excluded, matching PyTorch state_dict().
non_persistent
=
getattr
(
module
,
"_non_persistent_buffers_set"
,
set
())
for
buf_name
,
buf
in
module
.
named_buffers
(
recurse
=
False
):
if
buf_name
not
in
child_params
and
buf_name
not
in
non_persistent
:
child_params
[
buf_name
]
=
buf
if
isinstance
(
module
,
(
...
...
vllm/reasoning/__init__.py
View file @
858bddce
...
...
@@ -32,6 +32,10 @@ _REASONING_PARSERS_TO_REGISTER = {
"ernie45_reasoning_parser"
,
"Ernie45ReasoningParser"
,
),
"gemma4"
:
(
"gemma4_reasoning_parser"
,
"Gemma4ReasoningParser"
,
),
"glm45"
:
(
"deepseek_v3_reasoning_parser"
,
"DeepSeekV3ReasoningWithThinkingParser"
,
...
...
vllm/reasoning/gemma4_reasoning_parser.py
0 → 100644
View file @
858bddce
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
collections.abc
import
Sequence
from
typing
import
TYPE_CHECKING
from
vllm.entrypoints.openai.engine.protocol
import
DeltaMessage
from
vllm.reasoning.basic_parsers
import
BaseThinkingReasoningParser
from
vllm.tokenizers
import
TokenizerLike
if
TYPE_CHECKING
:
from
vllm.entrypoints.openai.chat_completion.protocol
import
(
ChatCompletionRequest
,
)
from
vllm.entrypoints.openai.responses.protocol
import
ResponsesRequest
# Role label that Gemma4 emits at the start of the thinking channel.
# The model generates: <|channel>thought\n...reasoning...<channel|>
# This prefix must be stripped to expose only the actual reasoning content.
_THOUGHT_PREFIX
=
"thought
\n
"
class Gemma4ReasoningParser(BaseThinkingReasoningParser):
    """
    Reasoning parser for Google Gemma4 thinking models.

    Gemma4 uses <|channel>...<channel|> tokens to delimit reasoning/thinking
    content within its output. Thinking mode is activated by passing
    ``enable_thinking=True`` in the chat template kwargs, which injects a
    system turn containing <|think|> (token 98) to trigger chain-of-thought
    reasoning.

    Output pattern when thinking is enabled::

        <|channel>thought
        ...chain of thought reasoning...<channel|>
        Final answer text here.

    The ``thought\\n`` role label inside the channel delimiters is a
    structural artefact (analogous to ``user\\n`` in ``<|turn>user\\n...``).
    This parser strips it so that downstream consumers see only the
    actual reasoning text, consistent with the offline parser
    (``vllm.reasoning.gemma4_utils._strip_thought_label``).
    """

    def __init__(self, tokenizer: TokenizerLike, *args, **kwargs):
        """Set up streaming state and look up the boundary token ids.

        Raises:
            KeyError: If the vocabulary lacks ``<|turn>``, ``<|tool_call>``
                or ``<|tool_response>``.
        """
        super().__init__(tokenizer, *args, **kwargs)
        # Instance state for streaming prefix stripping. Only reasoning text
        # handed back by the base parser is accumulated here, so it is
        # independent of whatever else ``current_text`` contains.
        self._reasoning_text: str = ""
        self._prefix_stripped: bool = False

        # Boundary tokens consulted by is_reasoning_end().
        self.new_turn_token_id = self.vocab["<|turn>"]
        self.tool_call_token_id = self.vocab["<|tool_call>"]
        self.tool_response_token_id = self.vocab["<|tool_response>"]

    def adjust_request(
        self, request: "ChatCompletionRequest | ResponsesRequest"
    ) -> "ChatCompletionRequest | ResponsesRequest":
        """Disable special-token stripping to preserve boundary tokens."""
        request.skip_special_tokens = False
        return request

    @property
    def start_token(self) -> str:
        """The token that starts reasoning content."""
        return "<|channel>"

    @property
    def end_token(self) -> str:
        """The token that ends reasoning content."""
        return "<channel|>"

    def is_reasoning_end(self, input_ids: Sequence[int]) -> bool:
        """Decide from the most recent boundary token whether reasoning ended.

        The ids are scanned from the end; the first boundary token found
        determines the answer. Returns ``False`` when none is present.
        """
        start_token_id = self.start_token_id
        end_token_id = self.end_token_id
        tool_call_token_id = self.tool_call_token_id
        reopening_token_ids = (
            self.new_turn_token_id,
            self.tool_response_token_id,
        )

        # Search from the end of input_ids to find the last match.
        for token_id in reversed(input_ids):
            if token_id == start_token_id:
                return False
            if token_id == tool_call_token_id:
                # We're generating a tool call, so reasoning must be ended.
                return True
            if token_id in reopening_token_ids:
                # A new turn or tool response: the model starts new
                # reasoning after these, so reasoning is not ended yet.
                return False
            if token_id == end_token_id:
                return True
        return False

    # ------------------------------------------------------------------
    # Non-streaming path
    # ------------------------------------------------------------------

    def extract_reasoning(
        self,
        model_output: str,
        request: "ChatCompletionRequest | ResponsesRequest",
    ) -> tuple[str | None, str | None]:
        """Extract reasoning, stripping the ``thought\\n`` role label.

        Returns:
            ``(reasoning, content)``. When neither delimiter occurs in
            ``model_output`` the whole output is returned as content.
        """
        if (
            self.start_token not in model_output
            and self.end_token not in model_output
        ):
            # Default to content history if no tags are present
            # (or if they were stripped)
            return None, model_output

        reasoning, content = super().extract_reasoning(model_output, request)
        if reasoning is not None:
            reasoning = _strip_thought_label(reasoning)
        return reasoning, content

    # ------------------------------------------------------------------
    # Streaming path
    # ------------------------------------------------------------------

    def extract_reasoning_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
    ) -> DeltaMessage | None:
        """Extract streaming reasoning, stripping ``thought\\n`` from the
        first reasoning delta(s).

        The ``thought\\n`` prefix may arrive as a single delta or split
        across multiple deltas (e.g. ``"thought"`` then ``"\\n"``). Early
        reasoning text is buffered (by returning ``None``) until it is clear
        whether the prefix is present; then the buffered text minus the
        prefix is emitted.

        Accumulation uses instance state (``_reasoning_text``) fed only with
        the reasoning returned by the base parser, so it does not depend on
        locating the ``<|channel>`` delimiter inside ``current_text``.

        Returns:
            The (possibly modified) delta from the base parser, or ``None``
            when the delta is suppressed.
        """
        result = super().extract_reasoning_streaming(
            previous_text,
            current_text,
            delta_text,
            previous_token_ids,
            current_token_ids,
            delta_token_ids,
        )
        if result is None or result.reasoning is None:
            return result

        # Accumulate ONLY the reasoning text from base parser results.
        self._reasoning_text += result.reasoning

        # Once the prefix has been handled, all subsequent reasoning
        # deltas pass through unchanged.
        if self._prefix_stripped:
            return result

        # Case 1: enough text has arrived to confirm the prefix. Drop the
        # part of the prefix that falls inside this delta (the earlier part,
        # if any, was already suppressed by Case 2).
        if self._reasoning_text.startswith(_THOUGHT_PREFIX):
            seen_before = len(self._reasoning_text) - len(result.reasoning)
            prefix_chars_in_delta = max(len(_THOUGHT_PREFIX) - seen_before, 0)
            self._prefix_stripped = True
            result.reasoning = result.reasoning[prefix_chars_in_delta:]
            return result

        # Case 2: the accumulated text is a strict prefix of the label
        # (e.g. only "thou" so far); it may still become the full label.
        if _THOUGHT_PREFIX.startswith(self._reasoning_text):
            reasoning_closed = (
                result.content is not None
                or self.end_token_id in delta_token_ids
            )
            if not reasoning_closed:
                # Keep buffering.
                return None
            # The reasoning block ends within this delta, so the label can
            # never complete. Flush instead of suppressing; otherwise the
            # content carried by this delta (and the buffered text) is lost.
            self._prefix_stripped = True
            if not self._reasoning_text and result.content is None:
                return None
            result.reasoning = self._reasoning_text or None
            return result

        # Case 3: the text diverged from the label after earlier deltas were
        # buffered. Re-emit everything accumulated to avoid data loss.
        self._prefix_stripped = True
        result.reasoning = self._reasoning_text
        return result
def _strip_thought_label(text: str) -> str:
    """Drop a leading ``thought\\n`` role label from ``text``, if present.

    Text that does not start with the label is returned unchanged. Mirrors
    ``vllm.reasoning.gemma4_utils._strip_thought_label`` from the offline
    parser.
    """
    return text.removeprefix(_THOUGHT_PREFIX)
vllm/reasoning/gemma4_utils.py
0 → 100644
View file @
858bddce
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved.
"""Gemma4 thinking/reasoning output parsing utilities for offline inference.
Standalone functions that parse decoded model text to extract structured
thinking content from Gemma4 models. These are pure-Python utilities with
zero heavy dependencies — they work on raw decoded strings from any
inference backend (vLLM, HuggingFace, TGI, etc.).
For the OpenAI-compatible API reasoning parser (streaming +
non-streaming), see ``vllm.reasoning.gemma4_reasoning_parser``.
For tool call parsing, see ``vllm.tool_parsers.gemma4_utils``.
Usage with vLLM offline inference::
from vllm import LLM, SamplingParams
from vllm.reasoning.gemma4_utils import parse_thinking_output
llm = LLM(model="google/gemma-4-it")
outputs = llm.generate(prompt, SamplingParams(...))
text = tokenizer.decode(outputs[0].outputs[0].token_ids, skip_special_tokens=False)
# Extract thinking / answer (works with or without enable_thinking)
result = parse_thinking_output(text)
print(result["thinking"]) # chain-of-thought or None
print(result["answer"]) # final answer
Ported from ``transformers.models.gemma4.utils_gemma4`` so that vLLM users
do not need a transformers dependency for output parsing.
"""
# ---- Thinking Mode Utility ----
# Thinking delimiter tokens as they appear in decoded text.
# Gemma4 uses <|channel> (start) and <channel|> (end) as thinking delimiters.
_THINKING_START_TAG
=
"<|channel>"
_THINKING_END_TAG
=
"<channel|>"
# Sentinel tokens that may appear in decoded output.
_TURN_END_TAG
=
"<turn|>"
def
parse_thinking_output
(
text
:
str
)
->
dict
[
str
,
str
|
None
]:
"""Parse decoded Gemma4 model output.
Use this on **all** Gemma4 output regardless of whether thinking mode
was enabled. It handles three cases:
1. **Thinking enabled, tags present** — splits on ``<|channel>``/
``<channel|>`` to separate chain-of-thought from the answer and
strips the ``thought
\\
n`` role label.
2. **Thinking disabled, spurious label** — strips the bare
``thought
\\
n`` prefix that some Gemma4 models emit even
without thinking mode.
3. **Clean output** — returns the text unchanged.
The answer text is always cleaned of trailing sentinel tokens
(``<turn|>``, ``<eos>``, etc.).
Args:
text: Decoded model output text (from ``tokenizer.decode(...)``).
Returns:
A dict with keys:
- ``"thinking"``: The chain-of-thought text, or ``None`` if no
thinking delimiters were found.
- ``"answer"``: The final answer text.
Example::
>>> from vllm.reasoning.gemma4_utils import parse_thinking_output
>>> output_text = tokenizer.decode(outputs[0], skip_special_tokens=False)
>>> result = parse_thinking_output(output_text)
>>> print(result["thinking"]) # chain-of-thought reasoning or None
>>> print(result["answer"]) # final answer
"""
if
_THINKING_END_TAG
in
text
:
parts
=
text
.
split
(
_THINKING_END_TAG
,
1
)
thinking_block
=
parts
[
0
]
answer
=
_clean_answer
(
parts
[
1
])
# Extract thinking content: strip the start tag if present
if
_THINKING_START_TAG
in
thinking_block
:
thinking
=
thinking_block
.
split
(
_THINKING_START_TAG
,
1
)[
1
]
else
:
thinking
=
thinking_block
# Strip the "thought\n" channel role label the model emits inside
# <|channel>thought\n...<channel|> (analogous to "user\n" in
# <|turn>user\n...<turn|>).
thinking
=
_strip_thought_label
(
thinking
.
strip
())
thinking
=
thinking
.
strip
()
return
{
"thinking"
:
thinking
,
"answer"
:
answer
}
# No thinking delimiters found.
# Strip spurious "thought\n" role label that some Gemma4 models sometimes
# emit even without thinking mode enabled, then clean trailing tokens.
answer
=
_strip_thought_label
(
text
)
answer
=
_clean_answer
(
answer
)
return
{
"thinking"
:
None
,
"answer"
:
answer
}
def
_strip_thought_label
(
text
:
str
)
->
str
:
"""Strip the spurious ``thought
\\
n`` label from the start of text.
Only strips when ``thought`` appears as the very first word followed by
a newline — preserving the word ``thought`` in any other context.
"""
if
text
.
startswith
(
"thought
\n
"
):
return
text
[
len
(
"thought
\n
"
)
:]
return
text
def
_clean_answer
(
text
:
str
)
->
str
:
"""Clean trailing sentinel tokens from the answer text.
Strips ``<turn|>``, ``<eos>``, and surrounding whitespace that the
model appends at the end of its response.
"""
text
=
text
.
strip
()
# Strip trailing <turn|> (Gemma4 turn-end marker)
if
text
.
endswith
(
_TURN_END_TAG
):
text
=
text
[:
-
len
(
_TURN_END_TAG
)].
rstrip
()
# Strip trailing <eos> if present
if
text
.
endswith
(
"<eos>"
):
text
=
text
[:
-
5
].
rstrip
()
return
text
vllm/tool_parsers/__init__.py
View file @
858bddce
...
...
@@ -154,6 +154,10 @@ _TOOL_PARSERS_TO_REGISTER = {
"functiongemma_tool_parser"
,
"FunctionGemmaToolParser"
,
),
"gemma4"
:
(
"gemma4_tool_parser"
,
"Gemma4ToolParser"
,
),
}
...
...
vllm/tool_parsers/gemma4_tool_parser.py
0 → 100644
View file @
858bddce
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Tool call parser for Google Gemma4 models.
Gemma4 uses a custom serialization format (not JSON) for tool calls::
<|tool_call>call:func_name{key:<|"|>value<|"|>,num:42}<tool_call|>
Strings are delimited by ``<|"|>`` (token 52), keys are unquoted, and
multiple tool calls are concatenated without separators.
Used when ``--enable-auto-tool-choice --tool-call-parser gemma4`` are set.
For offline inference tool call parsing (direct ``tokenizer.decode()`` output),
see ``vllm.tool_parsers.gemma4_utils.parse_tool_calls``.
"""
import json
import math
from collections.abc import Sequence

import regex as re

from vllm.entrypoints.chat_utils import make_tool_call_id
from vllm.entrypoints.openai.chat_completion.protocol import (
    ChatCompletionRequest,
)
from vllm.entrypoints.openai.engine.protocol import (
    DeltaFunctionCall,
    DeltaMessage,
    DeltaToolCall,
    ExtractedToolCallInformation,
    FunctionCall,
    ToolCall,
)
from vllm.entrypoints.openai.responses.protocol import (
    ResponsesRequest,
)
from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers.abstract_tool_parser import Tool, ToolParser
from vllm.tool_parsers.utils import find_common_prefix
# Module-level logger named after this module.
logger = init_logger(__name__)

# Gemma4 special tokens for tool calls
# Marker that opens a tool call: ``<|tool_call>call:name{...}<tool_call|>``.
TOOL_CALL_START = "<|tool_call>"
# Marker that closes a tool call.
TOOL_CALL_END = "<tool_call|>"
# Delimiter placed on both sides of a string value inside the arguments.
STRING_DELIM = '<|"|>'
# ---------------------------------------------------------------------------
# Gemma4 argument parser (used by both streaming and non-streaming paths)
# ---------------------------------------------------------------------------
def
_parse_gemma4_value
(
value_str
:
str
)
->
object
:
"""Parse a single Gemma4 value (after key:) into a Python object."""
value_str
=
value_str
.
strip
()
if
not
value_str
:
return
value_str
# Boolean
if
value_str
==
"true"
:
return
True
if
value_str
==
"false"
:
return
False
# Number (int or float)
try
:
if
"."
in
value_str
:
return
float
(
value_str
)
return
int
(
value_str
)
except
ValueError
:
pass
# Bare string (no <|"|> delimiters — shouldn't happen but be safe)
return
value_str
def _parse_gemma4_args(args_str: str, *, partial: bool = False) -> dict:
    """Parse Gemma4's custom key:value format into a Python dict.

    Format examples::

        location:<|"|>Tokyo<|"|>
        location:<|"|>San Francisco<|"|>,unit:<|"|>celsius<|"|>
        count:42,flag:true
        nested:{inner_key:<|"|>val<|"|>}
        items:[<|"|>a<|"|>,<|"|>b<|"|>]

    The input is scanned left to right with offset-based lookups
    (``str.startswith(..., i)`` / ``str.find(..., i)``) so that no copy of
    the remaining text is made per character.

    Args:
        args_str: The raw Gemma4 argument string.
        partial: When True (streaming), bare values at end of string are
            omitted because they may be incomplete and type-unstable
            (e.g. partial boolean parsed as bare string). A key with no
            value yet is likewise omitted.

    Returns:
        A dict ready for ``json.dumps()``. Unterminated strings take the
        rest of the input; unterminated nested objects/arrays are parsed
        recursively in partial mode.
    """
    if not args_str or not args_str.strip():
        return {}

    result: dict = {}
    delim_len = len(STRING_DELIM)
    i = 0
    n = len(args_str)

    while i < n:
        # Skip whitespace and commas
        while i < n and args_str[i] in (" ", ",", "\n", "\t"):
            i += 1
        if i >= n:
            break

        # Parse key (unquoted, ends at ':'); a key without ':' is dropped.
        colon = args_str.find(":", i)
        if colon == -1:
            break
        key = args_str[i:colon].strip()
        i = colon + 1  # skip ':'

        # Skip whitespace after ':'
        while i < n and args_str[i] in (" ", "\n", "\t"):
            i += 1
        if i >= n:
            # Key present but value missing: only materialise it when the
            # input is known to be complete.
            if not partial:
                result[key] = ""
            break

        # String value: <|"|>...<|"|>
        if args_str.startswith(STRING_DELIM, i):
            val_start = i + delim_len
            end_pos = args_str.find(STRING_DELIM, val_start)
            if end_pos == -1:
                # Unterminated string — take rest
                result[key] = args_str[val_start:]
                break
            result[key] = args_str[val_start:end_pos]
            i = end_pos + delim_len

        # Nested object: {...}
        elif args_str[i] == "{":
            depth = 1
            obj_start = i + 1
            i += 1
            while i < n and depth > 0:
                if args_str.startswith(STRING_DELIM, i):
                    # Skip over string contents to avoid counting braces
                    # that appear inside strings.
                    next_delim = args_str.find(STRING_DELIM, i + delim_len)
                    i = n if next_delim == -1 else next_delim + delim_len
                    continue
                if args_str[i] == "{":
                    depth += 1
                elif args_str[i] == "}":
                    depth -= 1
                i += 1
            if depth > 0:
                # Incomplete nested object — use i (not i-1) to avoid
                # dropping the last char, and recurse as partial.
                result[key] = _parse_gemma4_args(
                    args_str[obj_start:i], partial=True
                )
            else:
                # i is one past the closing '}', which is excluded.
                result[key] = _parse_gemma4_args(args_str[obj_start : i - 1])

        # Array: [...]
        elif args_str[i] == "[":
            depth = 1
            arr_start = i + 1
            i += 1
            while i < n and depth > 0:
                if args_str.startswith(STRING_DELIM, i):
                    # Brackets inside strings must not affect the depth.
                    next_delim = args_str.find(STRING_DELIM, i + delim_len)
                    i = n if next_delim == -1 else next_delim + delim_len
                    continue
                if args_str[i] == "[":
                    depth += 1
                elif args_str[i] == "]":
                    depth -= 1
                i += 1
            if depth > 0:
                result[key] = _parse_gemma4_array(
                    args_str[arr_start:i], partial=True
                )
            else:
                result[key] = _parse_gemma4_array(args_str[arr_start : i - 1])

        # Bare value (number, boolean, etc.)
        else:
            val_start = i
            while i < n and args_str[i] not in (",", "}", "]"):
                i += 1
            if partial and i >= n:
                # Value may be incomplete (e.g. partial boolean) —
                # withhold to avoid type instability during streaming.
                break
            result[key] = _parse_gemma4_value(args_str[val_start:i])

    return result
def _parse_gemma4_array(arr_str: str, *, partial: bool = False) -> list:
    """Parse a Gemma4 array content string into a Python list.

    ``arr_str`` is the text between the outer ``[`` and ``]``. Elements are
    separated by commas and may be ``<|"|>``-delimited strings, nested
    objects, nested arrays, or bare values.

    Malformed input is tolerated: an unterminated string takes the rest of
    the input, unterminated nested containers are parsed in partial mode,
    and a stray ``]`` at the top level is skipped so that the scan always
    advances.

    Args:
        arr_str: Raw array content without the surrounding brackets.
        partial: When True (streaming), a bare value that runs to the end of
            the input is withheld because it may be incomplete.

    Returns:
        The parsed list of elements.
    """
    items: list = []
    delim_len = len(STRING_DELIM)
    i = 0
    n = len(arr_str)

    while i < n:
        # Skip separators between elements.
        while i < n and arr_str[i] in (" ", ",", "\n", "\t"):
            i += 1
        if i >= n:
            break

        # String element
        if arr_str.startswith(STRING_DELIM, i):
            i += delim_len
            end_pos = arr_str.find(STRING_DELIM, i)
            if end_pos == -1:
                # Unterminated string — take rest
                items.append(arr_str[i:])
                break
            items.append(arr_str[i:end_pos])
            i = end_pos + delim_len

        # Nested object
        elif arr_str[i] == "{":
            depth = 1
            obj_start = i + 1
            i += 1
            while i < n and depth > 0:
                if arr_str.startswith(STRING_DELIM, i):
                    # Braces inside strings must not affect the depth.
                    nd = arr_str.find(STRING_DELIM, i + delim_len)
                    i = nd + delim_len if nd != -1 else n
                    continue
                if arr_str[i] == "{":
                    depth += 1
                elif arr_str[i] == "}":
                    depth -= 1
                i += 1
            if depth > 0:
                items.append(
                    _parse_gemma4_args(arr_str[obj_start:i], partial=True)
                )
            else:
                items.append(_parse_gemma4_args(arr_str[obj_start : i - 1]))

        # Nested array
        elif arr_str[i] == "[":
            depth = 1
            sub_start = i + 1
            i += 1
            while i < n and depth > 0:
                if arr_str.startswith(STRING_DELIM, i):
                    # Same string-skipping as the other scans: a ']' inside
                    # a string must not close the nested array.
                    nd = arr_str.find(STRING_DELIM, i + delim_len)
                    i = nd + delim_len if nd != -1 else n
                    continue
                if arr_str[i] == "[":
                    depth += 1
                elif arr_str[i] == "]":
                    depth -= 1
                i += 1
            if depth > 0:
                items.append(
                    _parse_gemma4_array(arr_str[sub_start:i], partial=True)
                )
            else:
                items.append(_parse_gemma4_array(arr_str[sub_start : i - 1]))

        # Stray closing bracket (malformed input): skip it. Without this the
        # bare-value scan below would stop immediately, consume nothing, and
        # the outer loop would append empty values forever.
        elif arr_str[i] == "]":
            i += 1

        # Bare value
        else:
            val_start = i
            while i < n and arr_str[i] not in (",", "]"):
                i += 1
            if partial and i >= n:
                # Possibly incomplete during streaming — withhold.
                break
            items.append(_parse_gemma4_value(arr_str[val_start:i]))

    return items
# ---------------------------------------------------------------------------
# Parser
# ---------------------------------------------------------------------------
class Gemma4ToolParser(ToolParser):
    """
    Tool call parser for Google Gemma4 models.

    Handles the Gemma4 function call format::

        <|tool_call>call:func_name{key:<|"|>value<|"|>}<tool_call|>

    Used when ``--enable-auto-tool-choice --tool-call-parser gemma4``
    are set.

    Streaming strategy: **accumulate-then-parse-then-diff**

    Instead of trying to convert Gemma4's custom format to JSON
    token-by-token (which fails because Gemma4 uses bare keys, custom
    delimiters, and structural braces that differ from JSON), this parser:

    1. Accumulates the raw Gemma4 argument string during streaming
    2. Parses it with ``_parse_gemma4_args()`` into a Python dict
    3. Converts to JSON with ``json.dumps()``
    4. Diffs against the previously-streamed JSON string
    5. Emits only the new JSON fragment as the delta

    This follows the same pattern used by FunctionGemma, Hermes, and Llama
    tool parsers.
    """

    def __init__(self, tokenizer: TokenizerLike, tools: list[Tool] | None = None):
        """Set up token strings/ids, the extraction regex and streaming state.

        Raises:
            ValueError: If no model tokenizer is available after base
                initialisation.
            RuntimeError: If ``<|tool_call>`` is not found in ``self.vocab``.
        """
        super().__init__(tokenizer, tools)

        if not self.model_tokenizer:
            raise ValueError(
                "The model tokenizer must be passed to the ToolParser "
                "constructor during construction."
            )

        # Token strings
        self.tool_call_start_token = TOOL_CALL_START
        self.tool_call_end_token = TOOL_CALL_END

        # Token IDs
        self.tool_call_start_token_id = self.vocab.get(TOOL_CALL_START)
        self.tool_call_end_token_id = self.vocab.get(TOOL_CALL_END)

        if self.tool_call_start_token_id is None:
            raise RuntimeError(
                "Gemma4 ToolParser could not locate the tool call start "
                f"token '{TOOL_CALL_START}' in the tokenizer!"
            )

        # Regex for non-streaming: extract complete tool calls.
        # Supports function names with letters, digits, underscores,
        # hyphens, and dots (e.g. "get-weather", "module.func").
        self.tool_call_regex = re.compile(
            r"<\|tool_call>call:([\w\-\.]+)\{(.*?)\}<tool_call\|>",
            re.DOTALL,
        )

        # Streaming state (including the delta buffer) — reset per-request
        # via _reset_streaming_state()
        self._reset_streaming_state()

    def _reset_streaming_state(self) -> None:
        """Reset all streaming state for a new request."""
        self.current_tool_id = -1
        self.current_tool_name_sent = False
        self.prev_tool_call_arr: list[dict] = []
        self.streamed_args_for_tool: list[str] = []
        # Delta buffer for handling multi-token special sequences; cleared
        # here too so a reset never leaks a half-seen tag into new output.
        self.buffered_delta_text = ""

    def adjust_request(
        self, request: ChatCompletionRequest | ResponsesRequest
    ) -> ChatCompletionRequest | ResponsesRequest:
        """Apply the base adjustments and keep special tokens in the output.

        For chat completion requests that carry tools and whose
        ``tool_choice`` is not ``"none"``, ``skip_special_tokens`` is forced
        to ``False``.
        """
        request = super().adjust_request(request)
        if (
            isinstance(request, ChatCompletionRequest)
            and request.tools
            and request.tool_choice != "none"
        ):
            # Don't skip special tokens — <|tool_call> etc. are needed
            request.skip_special_tokens = False
        return request

    # ------------------------------------------------------------------
    # Delta buffering for multi-token special sequences
    # ------------------------------------------------------------------

    def _buffer_delta_text(self, delta_text: str) -> str:
        """Buffer incoming delta text to handle multi-token special sequences.

        Accumulates partial tokens that could be the start of
        ``<|tool_call>`` or ``<tool_call|>`` and only flushes them
        when the complete sequence is recognized or the sequence breaks.

        This prevents partial special tokens (e.g., ``<|tool``) from being
        emitted prematurely as content text.

        Args:
            delta_text: Newly generated text for this step.

        Returns:
            The text that is safe to process now (possibly empty); any
            withheld tail is kept in ``self.buffered_delta_text``.
        """
        combined = self.buffered_delta_text + delta_text

        # Check if combined ends with a complete special token
        if combined.endswith((TOOL_CALL_START, TOOL_CALL_END)):
            self.buffered_delta_text = ""
            return combined

        # Check if combined ends with a partial prefix of a special token
        for tag in (TOOL_CALL_START, TOOL_CALL_END):
            for i in range(1, len(tag)):
                if combined.endswith(tag[:i]):
                    self.buffered_delta_text = combined[-i:]
                    return combined[:-i]

        # No partial match — flush everything
        self.buffered_delta_text = ""
        return combined

    # ------------------------------------------------------------------
    # Non-streaming extraction
    # ------------------------------------------------------------------

    def extract_tool_calls(
        self,
        model_output: str,
        request: ChatCompletionRequest,
    ) -> ExtractedToolCallInformation:
        """Extract all complete tool calls from a finished model output.

        Returns ``tools_called=False`` with the untouched output as content
        when no start token is present, when no complete call matches, or
        when parsing raises (the error is logged). Otherwise the content is
        the stripped text before the first start token, or ``None`` when
        that text is empty.
        """
        if self.tool_call_start_token not in model_output:
            return ExtractedToolCallInformation(
                tools_called=False, tool_calls=[], content=model_output
            )

        try:
            matches = self.tool_call_regex.findall(model_output)
            if not matches:
                return ExtractedToolCallInformation(
                    tools_called=False, tool_calls=[], content=model_output
                )

            tool_calls: list[ToolCall] = []
            for func_name, args_str in matches:
                arguments = _parse_gemma4_args(args_str)
                tool_calls.append(
                    ToolCall(
                        type="function",
                        function=FunctionCall(
                            name=func_name,
                            arguments=json.dumps(arguments, ensure_ascii=False),
                        ),
                    )
                )

            # Content = text before first tool call (if any)
            content_end = model_output.find(self.tool_call_start_token)
            content = model_output[:content_end].strip() if content_end > 0 else None

            return ExtractedToolCallInformation(
                tools_called=True,
                tool_calls=tool_calls,
                content=content if content else None,
            )

        except Exception:
            logger.exception("Error extracting tool calls from Gemma4 response")
            return ExtractedToolCallInformation(
                tools_called=False, tool_calls=[], content=model_output
            )

    # ------------------------------------------------------------------
    # Streaming extraction — accumulate-then-parse-then-diff
    # ------------------------------------------------------------------

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> DeltaMessage | None:
        """Produce the next streaming delta (content or tool-call fragment).

        Only the text arguments are used; the token-id sequences and the
        request are accepted for interface compatibility.

        Returns:
            A ``DeltaMessage`` to emit, or ``None`` when nothing should be
            sent for this step (including when extraction raises, in which
            case the error is logged).
        """
        # Buffer delta text to handle multi-token special sequences
        delta_text = self._buffer_delta_text(delta_text)
        # Keep current_text from the upstream stream state. The buffered delta
        # is only for emission, and must not be stitched back into the
        # accumulated model text or normal content like "<div>" can be
        # duplicated into "<<div>" when a tool call just ended.

        # If no tool call token seen yet, emit as content
        if self.tool_call_start_token not in current_text:
            if delta_text:
                return DeltaMessage(content=delta_text)
            return None

        try:
            return self._extract_streaming(
                previous_text=previous_text,
                current_text=current_text,
                delta_text=delta_text,
            )
        except Exception:
            logger.exception("Error in Gemma4 streaming tool call extraction")
            return None

    def _extract_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
    ) -> DeltaMessage | None:
        """Tag-counting streaming parser.

        Uses the proven approach from FunctionGemma/Hermes: count start/end
        tags in previous vs current text to determine phase, then
        accumulate-parse-diff for arguments.

        Format: ``<|tool_call>call:name{args}<tool_call|>``
        """
        start_count = current_text.count(self.tool_call_start_token)
        end_count = current_text.count(self.tool_call_end_token)
        prev_start_count = previous_text.count(self.tool_call_start_token)
        prev_end_count = previous_text.count(self.tool_call_end_token)

        # Case 1: Not inside any tool call — emit as content
        if (
            start_count == end_count
            and prev_end_count == end_count
            and self.tool_call_end_token not in delta_text
        ):
            if delta_text:
                return DeltaMessage(content=delta_text)
            return None

        # Case 2: Starting a new tool call
        if start_count > prev_start_count and start_count > end_count:
            self.current_tool_id += 1
            self.current_tool_name_sent = False
            self.streamed_args_for_tool.append("")
            self.prev_tool_call_arr.append({})
            logger.debug("Starting new tool call %d", self.current_tool_id)
            # Don't return yet — fall through to try parsing if there's
            # content after <|tool_call> in this same delta
            # (but usually it's just the token itself, so return None)
            if len(delta_text) <= len(self.tool_call_start_token):
                return None

        # Case 3: Tool call just ended
        if end_count > prev_end_count:
            return self._handle_tool_call_end(current_text)

        # Case 4: In the middle of a tool call — parse partial content
        if start_count > end_count:
            return self._handle_tool_call_middle(current_text)

        # Default: generate text outside tool calls
        if delta_text:
            text = delta_text.replace(self.tool_call_start_token, "")
            text = text.replace(self.tool_call_end_token, "")
            if text:
                return DeltaMessage(content=text)
        return None

    def _extract_partial_call(self, current_text: str) -> tuple[str | None, str]:
        """Extract function name and raw argument string from partial text.

        Returns (func_name, raw_args_str) or (None, "") if not parseable yet.
        """
        # Get the text after the last <|tool_call> token
        last_start = current_text.rfind(self.tool_call_start_token)
        if last_start == -1:
            return None, ""

        partial_call = current_text[last_start + len(self.tool_call_start_token) :]

        # Strip end token if present
        if self.tool_call_end_token in partial_call:
            partial_call = partial_call.split(self.tool_call_end_token)[0]

        # Expect "call:name{args..." (still streaming) or "call:name{args...}"
        if not partial_call.startswith("call:"):
            return None, ""

        func_part = partial_call[5:]  # skip "call:"

        if "{" not in func_part:
            # Still accumulating function name, not ready yet
            return None, ""

        func_name, _, args_part = func_part.partition("{")
        func_name = func_name.strip()

        # Strip trailing '}' if present (Gemma4 structural brace)
        if args_part.endswith("}"):
            args_part = args_part[:-1]

        return func_name, args_part

    def _handle_tool_call_middle(self, current_text: str) -> DeltaMessage | None:
        """Handle streaming when we're inside an active tool call.

        Accumulates the raw Gemma4 arguments, parses them into JSON, and
        diffs against the previously-streamed JSON to emit only the new
        fragment.
        """
        func_name, args_part = self._extract_partial_call(current_text)

        if func_name is None:
            return None

        # Step 1: Send function name (once)
        if not self.current_tool_name_sent and func_name:
            self.current_tool_name_sent = True
            self.prev_tool_call_arr[self.current_tool_id] = {
                "name": func_name,
                "arguments": {},
            }
            return DeltaMessage(
                tool_calls=[
                    DeltaToolCall(
                        index=self.current_tool_id,
                        type="function",
                        id=make_tool_call_id(),
                        function=DeltaFunctionCall(
                            name=func_name,
                            arguments="",
                        ).model_dump(exclude_none=True),
                    )
                ]
            )

        # Step 2: Parse and diff arguments
        if self.current_tool_name_sent and args_part:
            return self._emit_argument_diff(args_part)

        return None

    def _handle_tool_call_end(self, current_text: str) -> DeltaMessage | None:
        """Handle streaming when a tool call has just completed.

        Performs a final parse of the complete tool call and flushes
        any remaining un-streamed argument fragments. If the function name
        was never emitted (the call opened and closed before any "middle"
        step ran), the name, id and complete arguments are sent together so
        the call is not left anonymous.
        """
        if self.current_tool_id < 0 or self.current_tool_id >= len(
            self.prev_tool_call_arr
        ):
            logger.debug(
                "Tool call end detected but no active tool call (current_tool_id=%d)",
                self.current_tool_id,
            )
            return None

        # Parse the complete tool call using regex for accuracy
        all_matches = self.tool_call_regex.findall(current_text)
        if self.current_tool_id >= len(all_matches):
            return None

        func_name, args_str = all_matches[self.current_tool_id]
        final_args = _parse_gemma4_args(args_str)
        final_args_json = json.dumps(final_args, ensure_ascii=False)

        if not self.current_tool_name_sent:
            # Nothing about this call has been streamed yet: emit the whole
            # call (header + full arguments) in a single delta.
            self.current_tool_name_sent = True
            self.streamed_args_for_tool[self.current_tool_id] = final_args_json
            self.prev_tool_call_arr[self.current_tool_id] = {
                "name": func_name,
                "arguments": final_args,
            }
            return DeltaMessage(
                tool_calls=[
                    DeltaToolCall(
                        index=self.current_tool_id,
                        type="function",
                        id=make_tool_call_id(),
                        function=DeltaFunctionCall(
                            name=func_name,
                            arguments=final_args_json,
                        ).model_dump(exclude_none=True),
                    )
                ]
            )

        prev_streamed = self.streamed_args_for_tool[self.current_tool_id]
        if len(final_args_json) > len(prev_streamed):
            # Flush whatever was withheld while streaming (closing chars etc.).
            diff = final_args_json[len(prev_streamed) :]
            self.streamed_args_for_tool[self.current_tool_id] = final_args_json
            self.prev_tool_call_arr[self.current_tool_id]["arguments"] = final_args

            return DeltaMessage(
                tool_calls=[
                    DeltaToolCall(
                        index=self.current_tool_id,
                        function=DeltaFunctionCall(arguments=diff).model_dump(
                            exclude_none=True
                        ),
                    )
                ]
            )

        return None

    def _emit_argument_diff(self, raw_args_str: str) -> DeltaMessage | None:
        """Parse raw Gemma4 arguments, convert to JSON, diff, and emit.

        This is the core of the accumulate-then-parse-then-diff strategy:

        1. Parse ``raw_args_str`` with ``_parse_gemma4_args()``
        2. Convert to JSON string with ``json.dumps()``
        3. Withhold trailing closing characters that may move as more
           tokens arrive
        4. Diff against previously streamed JSON and emit only new chars

        **Why withholding is necessary:**

        Gemma4's custom format produces *structurally incomplete* JSON
        during streaming. For example, when ``<|"|>Paris`` arrives
        without a closing delimiter, ``_parse_gemma4_args`` treats it
        as a complete value and produces ``{"location": "Paris"}``. But
        when ``, France<|"|>`` arrives next, the JSON becomes
        ``{"location": "Paris, France"}``. If the closing quote and brace
        from the first parse had been sent, the concatenated client output
        would be ``{"location": "Paris"}France"}``, which is garbage.

        The solution: **never send trailing closing chars during
        streaming**. They get flushed by ``_handle_tool_call_end()``
        when the ``<tool_call|>`` end marker arrives.

        Args:
            raw_args_str: The raw Gemma4 argument text accumulated so far
                (without the surrounding ``{`` ``}``).

        Returns:
            DeltaMessage with the argument diff, or None if no new content.
        """
        try:
            current_args = _parse_gemma4_args(raw_args_str, partial=True)
        except Exception:
            logger.debug(
                "Could not parse partial Gemma4 args yet: %s",
                raw_args_str[:100],
            )
            return None

        if not current_args:
            return None

        current_args_json = json.dumps(current_args, ensure_ascii=False)

        # Withhold trailing closing characters that may shift as more
        # tokens arrive. Strip trailing '}', '"', ']' and partial
        # STRING_DELIM fragments ('<', '|', '\\', '>') to get the
        # "safe prefix".
        safe_json = current_args_json.rstrip('}"]<|\\>')

        prev_streamed = self.streamed_args_for_tool[self.current_tool_id]

        if not safe_json or safe_json == prev_streamed:
            return None

        # Use find_common_prefix to handle cases where the value changed
        # structurally (e.g., a string grew).
        if prev_streamed:
            prefix = find_common_prefix(prev_streamed, safe_json)
            sent_len = len(prev_streamed)
            prefix_len = len(prefix)

            if prefix_len < sent_len:
                # Structure changed — we sent too much. Truncate our
                # tracking to the common prefix and wait for the final
                # flush in _handle_tool_call_end.
                self.streamed_args_for_tool[self.current_tool_id] = prefix
                return None

            # Stream the new stable portion
            diff = safe_json[sent_len:]
        else:
            # First emission
            diff = safe_json

        if diff:
            self.streamed_args_for_tool[self.current_tool_id] = safe_json
            self.prev_tool_call_arr[self.current_tool_id]["arguments"] = current_args

            return DeltaMessage(
                tool_calls=[
                    DeltaToolCall(
                        index=self.current_tool_id,
                        function=DeltaFunctionCall(arguments=diff).model_dump(
                            exclude_none=True
                        ),
                    )
                ]
            )

        return None
vllm/transformers_utils/model_arch_config_convertor.py
View file @
858bddce
...
...
@@ -300,6 +300,28 @@ class ModelArchConfigConvertorBase:
return
model_arch_config
class CohereAsrModelArchConfigConvertor(ModelArchConfigConvertorBase):
    """Architecture-config convertor for Cohere ASR models.

    Decoder values are read from
    ``hf_text_config.transf_decoder["config_dict"]`` and the encoder head
    count from ``hf_text_config.encoder["n_heads"]``.
    """

    def get_total_num_attention_heads(self) -> int:
        """Return the decoder's ``num_attention_heads``."""
        decoder_cfg = self.hf_text_config.transf_decoder["config_dict"]
        return decoder_cfg["num_attention_heads"]

    def get_head_size(self) -> int:
        """Return decoder ``hidden_size`` floor-divided by its head count."""
        decoder_cfg = self.hf_text_config.transf_decoder["config_dict"]
        hidden = decoder_cfg["hidden_size"]
        heads = decoder_cfg["num_attention_heads"]
        return hidden // heads

    def get_total_num_kv_heads(self) -> int:
        """Return the encoder head count, which must equal the decoder's.

        Raises:
            AssertionError: If encoder and decoder head counts differ.
        """
        encoder_heads = self.hf_text_config.encoder["n_heads"]
        decoder_cfg = self.hf_text_config.transf_decoder["config_dict"]
        decoder_heads = decoder_cfg["num_attention_heads"]
        assert encoder_heads == decoder_heads, (
            "Encoder and decoder must have the same number of kv heads"
        )
        return encoder_heads
class
MambaModelArchConfigConvertor
(
ModelArchConfigConvertorBase
):
def
get_head_size
(
self
)
->
int
:
return
0
...
...
@@ -423,6 +445,16 @@ class LongCatFlashMTPModelArchConfigConvertor(ModelArchConfigConvertorBase):
return
getattr
(
self
.
hf_text_config
,
"num_nextn_predict_layers"
,
1
)
class Gemma4ModelArchConfigConvertor(ModelArchConfigConvertorBase):
    """Architecture-config convertor for Gemma4 (``head_dim`` / ``global_head_dim``)."""

    def get_head_size(self) -> int:
        """Return the larger of ``head_dim`` and ``global_head_dim``.

        Falls back to the base-class computation when neither attribute
        yields a positive value.
        """
        # Gemma4 uses dual head dimensions: head_dim (sliding attention)
        # and global_head_dim (full attention). Return the largest so
        # that attention backends allocate buffers large enough for both.
        # ``or 0`` also covers attributes that exist but are set to None,
        # which would otherwise make max() raise TypeError.
        head_dim = getattr(self.hf_text_config, "head_dim", None) or 0
        global_head_dim = getattr(self.hf_text_config, "global_head_dim", None) or 0
        return max(head_dim, global_head_dim) or super().get_head_size()
# hf_config.model_type -> convertor class
MODEL_ARCH_CONFIG_CONVERTORS
=
{
"mamba"
:
MambaModelArchConfigConvertor
,
...
...
@@ -433,6 +465,8 @@ MODEL_ARCH_CONFIG_CONVERTORS = {
"mpt"
:
MPTModelArchConfigConvertor
,
"dbrx"
:
DbrxModelArchConfigConvertor
,
"falcon"
:
FalconModelArchConfigConvertor
,
"gemma4"
:
Gemma4ModelArchConfigConvertor
,
"gemma4_text"
:
Gemma4ModelArchConfigConvertor
,
"RefinedWeb"
:
FalconModelArchConfigConvertor
,
"RefinedWebModel"
:
FalconModelArchConfigConvertor
,
"nemotron-nas"
:
NemotronNasModelArchConfigConvertor
,
...
...
vllm/v1/attention/ops/triton_unified_attention.py
View file @
858bddce
...
...
@@ -1040,6 +1040,8 @@ def unified_attention(
num_seqs
=
num_seqs
,
BLOCK_M
=
BLOCK_M
,
USE_FP8
=
output_scale
is
not
None
,
num_stages
=
1
)
else
:
kernel_unified_attention_3d
[
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment