Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
858bddce
Commit
858bddce
authored
Apr 09, 2026
by
luopl
Browse files
feat:add gemma4
parent
40faaf0c
Changes
14
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
14 changed files
with
4233 additions
and
1 deletion
+4233
-1
vllm/model_executor/layers/rotary_embedding/__init__.py
vllm/model_executor/layers/rotary_embedding/__init__.py
+12
-0
vllm/model_executor/layers/rotary_embedding/gemma4_rope.py
vllm/model_executor/layers/rotary_embedding/gemma4_rope.py
+84
-0
vllm/model_executor/models/config.py
vllm/model_executor/models/config.py
+54
-0
vllm/model_executor/models/gemma4.py
vllm/model_executor/models/gemma4.py
+1581
-0
vllm/model_executor/models/gemma4_mm.py
vllm/model_executor/models/gemma4_mm.py
+1339
-0
vllm/model_executor/models/registry.py
vllm/model_executor/models/registry.py
+2
-0
vllm/model_executor/models/utils.py
vllm/model_executor/models/utils.py
+8
-1
vllm/reasoning/__init__.py
vllm/reasoning/__init__.py
+4
-0
vllm/reasoning/gemma4_reasoning_parser.py
vllm/reasoning/gemma4_reasoning_parser.py
+225
-0
vllm/reasoning/gemma4_utils.py
vllm/reasoning/gemma4_utils.py
+130
-0
vllm/tool_parsers/__init__.py
vllm/tool_parsers/__init__.py
+4
-0
vllm/tool_parsers/gemma4_tool_parser.py
vllm/tool_parsers/gemma4_tool_parser.py
+754
-0
vllm/transformers_utils/model_arch_config_convertor.py
vllm/transformers_utils/model_arch_config_convertor.py
+34
-0
vllm/v1/attention/ops/triton_unified_attention.py
vllm/v1/attention/ops/triton_unified_attention.py
+2
-0
No files found.
vllm/model_executor/layers/rotary_embedding/__init__.py
View file @
858bddce
...
@@ -13,6 +13,7 @@ from .dynamic_ntk_alpha_rope import DynamicNTKAlphaRotaryEmbedding
...
@@ -13,6 +13,7 @@ from .dynamic_ntk_alpha_rope import DynamicNTKAlphaRotaryEmbedding
from
.dynamic_ntk_scaling_rope
import
DynamicNTKScalingRotaryEmbedding
from
.dynamic_ntk_scaling_rope
import
DynamicNTKScalingRotaryEmbedding
from
.fope
import
FourierRotaryEmbedding
from
.fope
import
FourierRotaryEmbedding
from
.linear_scaling_rope
import
LinearScalingRotaryEmbedding
from
.linear_scaling_rope
import
LinearScalingRotaryEmbedding
from
.gemma4_rope
import
Gemma4RotaryEmbedding
from
.llama3_rope
import
Llama3RotaryEmbedding
from
.llama3_rope
import
Llama3RotaryEmbedding
from
.llama4_vision_rope
import
Llama4VisionRotaryEmbedding
from
.llama4_vision_rope
import
Llama4VisionRotaryEmbedding
from
.mrope
import
MRotaryEmbedding
from
.mrope
import
MRotaryEmbedding
...
@@ -134,6 +135,17 @@ def get_rope(
...
@@ -134,6 +135,17 @@ def get_rope(
is_neox_style
,
is_neox_style
,
dtype
,
dtype
,
)
)
elif
scaling_type
==
"proportional"
:
# Proportional RoPE is used by Gemma4 for global (full) attention.
# Gemma4 uses a sparse/fractional RoPE with cross-mixing between halves.
rotary_emb
=
Gemma4RotaryEmbedding
(
head_size
,
rotary_dim
,
max_position
,
base
,
is_neox_style
,
dtype
,
)
elif
scaling_type
==
"llama3"
:
elif
scaling_type
==
"llama3"
:
scaling_factor
=
rope_parameters
[
"factor"
]
scaling_factor
=
rope_parameters
[
"factor"
]
low_freq_factor
=
rope_parameters
[
"low_freq_factor"
]
low_freq_factor
=
rope_parameters
[
"low_freq_factor"
]
...
...
vllm/model_executor/layers/rotary_embedding/gemma4_rope.py
0 → 100644
View file @
858bddce
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Gemma4-specific Rotary Positional Embeddings (proportional scaling).
Gemma4 uses "proportional" RoPE which computes inv_freq frequencies scaled
by head_dim (not rotary_dim), and zero-pads for non-rotated dimensions when
partial_rotary_factor < 1. The actual rotation uses standard neox-style
rotate_half, matching HF transformers' apply_rotary_pos_emb.
"""
import
torch
from
.base
import
RotaryEmbedding
class
Gemma4RotaryEmbedding
(
RotaryEmbedding
):
"""Gemma4 proportional RoPE.
Extends RotaryEmbedding (which provides standard neox-style rotation
via ops.rotary_embedding CUDA kernel) but overrides the inv_freq
computation to match HF's _compute_proportional_rope_parameters:
- Frequency exponents use head_dim (not rotary_dim) as denominator
- Non-rotated dims are zero-padded (cos=1, sin=0 = identity rotation)
When partial_rotary_factor=1.0 (the default for some variants), ALL dims are
rotated and this is equivalent to standard RotaryEmbedding with
head_dim-scaled frequencies.
"""
def
__init__
(
self
,
head_size
:
int
,
rotary_dim
:
int
,
max_position_embeddings
:
int
,
base
:
float
,
is_neox_style
:
bool
,
dtype
:
torch
.
dtype
,
)
->
None
:
# Number of rotation angle pairs (from partial_rotary_factor)
self
.
rope_angles
=
rotary_dim
//
2
# Non-rotated angle pairs per half
self
.
nope_angles
=
(
head_size
//
2
)
-
self
.
rope_angles
# Important: set rotary_dim = head_size so the base class's
# forward_static applies rotation to ALL dims of the cos/sin cache.
# The non-rotated dims will have cos=1, sin=0 (identity) thanks
# to our _compute_inv_freq zero-padding.
super
().
__init__
(
head_size
,
head_size
,
# rotary_dim = head_size (full application)
max_position_embeddings
,
base
,
is_neox_style
,
dtype
,
)
def
_compute_inv_freq
(
self
,
base
:
float
)
->
torch
.
Tensor
:
"""Compute frequencies matching HF proportional RoPE.
Key difference from base: exponent denominator is head_size (not
rotary_dim), and non-rotated dims are zero-padded.
"""
# HF formula: base ** (arange(0, 2*rope_angles, 2) / head_dim)
freq_exponents
=
(
torch
.
arange
(
0
,
2
*
self
.
rope_angles
,
2
,
dtype
=
torch
.
float
)
/
self
.
head_size
)
inv_freq
=
1.0
/
(
base
**
freq_exponents
)
# Zero-pad for non-rotated dims (identity rotation: cos=1, sin=0)
if
self
.
nope_angles
>
0
:
inv_freq
=
torch
.
cat
(
[
inv_freq
,
torch
.
zeros
(
self
.
nope_angles
,
dtype
=
torch
.
float
),
]
)
return
inv_freq
def
extra_repr
(
self
)
->
str
:
s
=
f
"head_size=
{
self
.
head_size
}
, rotary_dim=
{
self
.
rotary_dim
}
"
s
+=
f
", rope_angles=
{
self
.
rope_angles
}
, nope_angles=
{
self
.
nope_angles
}
"
s
+=
f
", max_position_embeddings=
{
self
.
max_position_embeddings
}
"
s
+=
f
", base=
{
self
.
base
}
, is_neox_style=
{
self
.
is_neox_style
}
"
return
s
vllm/model_executor/models/config.py
View file @
858bddce
...
@@ -56,6 +56,57 @@ class Gemma3TextModelConfig(VerifyAndUpdateConfig):
...
@@ -56,6 +56,57 @@ class Gemma3TextModelConfig(VerifyAndUpdateConfig):
hf_config
=
model_config
.
hf_config
hf_config
=
model_config
.
hf_config
hf_config
.
is_causal
=
not
hf_config
.
use_bidirectional_attention
hf_config
.
is_causal
=
not
hf_config
.
use_bidirectional_attention
class
Gemma4Config
(
VerifyAndUpdateConfig
):
@
staticmethod
def
verify_and_update_config
(
vllm_config
:
"VllmConfig"
)
->
None
:
"""Force unified attention backend for models with heterogeneous
head dimensions.
Some Gemma4 variants use different head dimensions for
sliding window (head_dim) vs full attention (global_head_dim) layers.
When global_head_dim > 256, FlashAttention rejects those layers
(head_size <= 256 kernel limit), causing vLLM to select a different
backend for each layer type. This mixed-backend execution produces
numerical divergence and output corruption.
The fix detects heterogeneous head dimensions from the model config
and forces TRITON_ATTN (which has no head_size ceiling) for all
layers when the user hasn't explicitly chosen a backend.
TODO: Heterogeneous head_sizes (head_dim != global_head_dim)
require NixlConnector changes to support per-layer KV transfer
with different head dimensions for prefill-decode disaggregation.
"""
hf_text_config
=
vllm_config
.
model_config
.
hf_text_config
head_dim
=
getattr
(
hf_text_config
,
"head_dim"
,
None
)
global_head_dim
=
getattr
(
hf_text_config
,
"global_head_dim"
,
None
)
# Only force Triton when head dimensions actually differ AND the
# larger one exceeds FlashAttention's kernel limit (head_size <= 256).
# This avoids unnecessary backend forcing on smaller models where
# the config carries global_head_dim but all layers can still use
# the same FA backend.
max_head_dim
=
max
(
head_dim
or
0
,
global_head_dim
or
0
)
if
(
head_dim
is
not
None
and
global_head_dim
is
not
None
and
head_dim
!=
global_head_dim
and
max_head_dim
>
256
and
vllm_config
.
attention_config
.
backend
is
None
):
from
vllm.v1.attention.backends.registry
import
(
AttentionBackendEnum
,
)
vllm_config
.
attention_config
.
backend
=
AttentionBackendEnum
.
TRITON_ATTN
logger
.
info
(
"Gemma4 model has heterogeneous head dimensions "
"(head_dim=%d, global_head_dim=%d). Forcing TRITON_ATTN "
"backend to prevent mixed-backend numerical divergence."
,
head_dim
,
global_head_dim
,
)
class
GptOssForCausalLMConfig
(
VerifyAndUpdateConfig
):
class
GptOssForCausalLMConfig
(
VerifyAndUpdateConfig
):
@
staticmethod
@
staticmethod
...
@@ -647,10 +698,13 @@ class VoyageQwen3BidirectionalEmbedModelConfig(VerifyAndUpdateConfig):
...
@@ -647,10 +698,13 @@ class VoyageQwen3BidirectionalEmbedModelConfig(VerifyAndUpdateConfig):
MODELS_CONFIG_MAP
:
dict
[
str
,
type
[
VerifyAndUpdateConfig
]]
=
{
MODELS_CONFIG_MAP
:
dict
[
str
,
type
[
VerifyAndUpdateConfig
]]
=
{
"ColBERTJinaRobertaModel"
:
JinaRobertaModelConfig
,
"ColBERTJinaRobertaModel"
:
JinaRobertaModelConfig
,
"ColQwen3_5"
:
Qwen3_5ForConditionalGenerationConfig
,
"DeepseekV32ForCausalLM"
:
DeepseekV32ForCausalLM
,
"DeepseekV32ForCausalLM"
:
DeepseekV32ForCausalLM
,
"Ernie4_5_VLMoeForConditionalGeneration"
:
Ernie4_5_VLMoeForConditionalGenerationConfig
,
# noqa: E501
"Ernie4_5_VLMoeForConditionalGeneration"
:
Ernie4_5_VLMoeForConditionalGenerationConfig
,
# noqa: E501
"FalconMambaForCausalLM"
:
MambaModelConfig
,
"FalconMambaForCausalLM"
:
MambaModelConfig
,
"Gemma3TextModel"
:
Gemma3TextModelConfig
,
"Gemma3TextModel"
:
Gemma3TextModelConfig
,
"Gemma4ForCausalLM"
:
Gemma4Config
,
"Gemma4ForConditionalGeneration"
:
Gemma4Config
,
"GptOssForCausalLM"
:
GptOssForCausalLMConfig
,
"GptOssForCausalLM"
:
GptOssForCausalLMConfig
,
"GteModel"
:
SnowflakeGteNewModelConfig
,
"GteModel"
:
SnowflakeGteNewModelConfig
,
"GteNewForSequenceClassification"
:
GteNewModelConfig
,
"GteNewForSequenceClassification"
:
GteNewModelConfig
,
...
...
vllm/model_executor/models/gemma4.py
0 → 100644
View file @
858bddce
This diff is collapsed.
Click to expand it.
vllm/model_executor/models/gemma4_mm.py
0 → 100644
View file @
858bddce
This diff is collapsed.
Click to expand it.
vllm/model_executor/models/registry.py
View file @
858bddce
...
@@ -111,6 +111,7 @@ _TEXT_GENERATION_MODELS = {
...
@@ -111,6 +111,7 @@ _TEXT_GENERATION_MODELS = {
"Gemma2ForCausalLM"
:
(
"gemma2"
,
"Gemma2ForCausalLM"
),
"Gemma2ForCausalLM"
:
(
"gemma2"
,
"Gemma2ForCausalLM"
),
"Gemma3ForCausalLM"
:
(
"gemma3"
,
"Gemma3ForCausalLM"
),
"Gemma3ForCausalLM"
:
(
"gemma3"
,
"Gemma3ForCausalLM"
),
"Gemma3nForCausalLM"
:
(
"gemma3n"
,
"Gemma3nForCausalLM"
),
"Gemma3nForCausalLM"
:
(
"gemma3n"
,
"Gemma3nForCausalLM"
),
"Gemma4ForCausalLM"
:
(
"gemma4"
,
"Gemma4ForCausalLM"
),
"Qwen3NextForCausalLM"
:
(
"qwen3_next"
,
"Qwen3NextForCausalLM"
),
"Qwen3NextForCausalLM"
:
(
"qwen3_next"
,
"Qwen3NextForCausalLM"
),
"GlmForCausalLM"
:
(
"glm"
,
"GlmForCausalLM"
),
"GlmForCausalLM"
:
(
"glm"
,
"GlmForCausalLM"
),
"Glm4ForCausalLM"
:
(
"glm4"
,
"Glm4ForCausalLM"
),
"Glm4ForCausalLM"
:
(
"glm4"
,
"Glm4ForCausalLM"
),
...
@@ -377,6 +378,7 @@ _MULTIMODAL_MODELS = {
...
@@ -377,6 +378,7 @@ _MULTIMODAL_MODELS = {
"gemma3n_mm"
,
"gemma3n_mm"
,
"Gemma3nForConditionalGeneration"
,
"Gemma3nForConditionalGeneration"
,
),
),
"Gemma4ForConditionalGeneration"
:
(
"gemma4_mm"
,
"Gemma4ForConditionalGeneration"
),
"GlmAsrForConditionalGeneration"
:
(
"glmasr"
,
"GlmAsrForConditionalGeneration"
),
"GlmAsrForConditionalGeneration"
:
(
"glmasr"
,
"GlmAsrForConditionalGeneration"
),
"GLM4VForCausalLM"
:
(
"glm4v"
,
"GLM4VForCausalLM"
),
"GLM4VForCausalLM"
:
(
"glm4v"
,
"GLM4VForCausalLM"
),
"Glm4vForConditionalGeneration"
:
(
"glm4_1v"
,
"Glm4vForConditionalGeneration"
),
"Glm4vForConditionalGeneration"
:
(
"glm4_1v"
,
"Glm4vForConditionalGeneration"
),
...
...
vllm/model_executor/models/utils.py
View file @
858bddce
...
@@ -233,8 +233,15 @@ class AutoWeightsLoader:
...
@@ -233,8 +233,15 @@ class AutoWeightsLoader:
):
):
"""
"""
Add tensor names that are not in the model params that may be in the
Add tensor names that are not in the model params that may be in the
safetensors, e.g., batch normalization stats.
safetensors, e.g., batch normalization stats
and registered buffers
.
"""
"""
# Add persistent registered buffers.
# Non-persistent buffers are excluded, matching PyTorch state_dict().
non_persistent
=
getattr
(
module
,
"_non_persistent_buffers_set"
,
set
())
for
buf_name
,
buf
in
module
.
named_buffers
(
recurse
=
False
):
if
buf_name
not
in
child_params
and
buf_name
not
in
non_persistent
:
child_params
[
buf_name
]
=
buf
if
isinstance
(
if
isinstance
(
module
,
module
,
(
(
...
...
vllm/reasoning/__init__.py
View file @
858bddce
...
@@ -32,6 +32,10 @@ _REASONING_PARSERS_TO_REGISTER = {
...
@@ -32,6 +32,10 @@ _REASONING_PARSERS_TO_REGISTER = {
"ernie45_reasoning_parser"
,
"ernie45_reasoning_parser"
,
"Ernie45ReasoningParser"
,
"Ernie45ReasoningParser"
,
),
),
"gemma4"
:
(
"gemma4_reasoning_parser"
,
"Gemma4ReasoningParser"
,
),
"glm45"
:
(
"glm45"
:
(
"deepseek_v3_reasoning_parser"
,
"deepseek_v3_reasoning_parser"
,
"DeepSeekV3ReasoningWithThinkingParser"
,
"DeepSeekV3ReasoningWithThinkingParser"
,
...
...
vllm/reasoning/gemma4_reasoning_parser.py
0 → 100644
View file @
858bddce
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
collections.abc
import
Sequence
from
typing
import
TYPE_CHECKING
from
vllm.entrypoints.openai.engine.protocol
import
DeltaMessage
from
vllm.reasoning.basic_parsers
import
BaseThinkingReasoningParser
from
vllm.tokenizers
import
TokenizerLike
if
TYPE_CHECKING
:
from
vllm.entrypoints.openai.chat_completion.protocol
import
(
ChatCompletionRequest
,
)
from
vllm.entrypoints.openai.responses.protocol
import
ResponsesRequest
# Role label that Gemma4 emits at the start of the thinking channel.
# The model generates: <|channel>thought\n...reasoning...<channel|>
# This prefix must be stripped to expose only the actual reasoning content.
_THOUGHT_PREFIX
=
"thought
\n
"
class
Gemma4ReasoningParser
(
BaseThinkingReasoningParser
):
"""
Reasoning parser for Google Gemma4 thinking models.
Gemma4 uses <|channel>...<channel|> tokens to delimit reasoning/thinking
content within its output. Thinking mode is activated by passing
``enable_thinking=True`` in the chat template kwargs, which injects a
system turn containing <|think|> (token 98) to trigger chain-of-thought
reasoning.
Output pattern when thinking is enabled::
<|channel>thought
...chain of thought reasoning...<channel|>
Final answer text here.
The ``thought
\\
n`` role label inside the channel delimiters is a
structural artefact (analogous to ``user
\\
n`` in ``<|turn>user
\\
n...``).
This parser strips it so that downstream consumers see only the
actual reasoning text, consistent with the offline parser
(``vllm.reasoning.gemma4_utils._strip_thought_label``).
"""
def
__init__
(
self
,
tokenizer
:
TokenizerLike
,
*
args
,
**
kwargs
):
super
().
__init__
(
tokenizer
,
*
args
,
**
kwargs
)
# Instance state for streaming prefix stripping.
# Tracks only the reasoning text received from the base parser,
# independent of current_text (which may contain pre-reasoning
# content and lacks special token text due to
# skip_special_tokens=True).
self
.
_reasoning_text
:
str
=
""
self
.
_prefix_stripped
:
bool
=
False
self
.
new_turn_token_id
=
self
.
vocab
[
"<|turn>"
]
self
.
tool_call_token_id
=
self
.
vocab
[
"<|tool_call>"
]
self
.
tool_response_token_id
=
self
.
vocab
[
"<|tool_response>"
]
def
adjust_request
(
self
,
request
:
"ChatCompletionRequest | ResponsesRequest"
)
->
"ChatCompletionRequest | ResponsesRequest"
:
"""Disable special-token stripping to preserve boundary tokens."""
request
.
skip_special_tokens
=
False
return
request
@
property
def
start_token
(
self
)
->
str
:
"""The token that starts reasoning content."""
return
"<|channel>"
@
property
def
end_token
(
self
)
->
str
:
"""The token that ends reasoning content."""
return
"<channel|>"
def
is_reasoning_end
(
self
,
input_ids
:
Sequence
[
int
])
->
bool
:
start_token_id
=
self
.
start_token_id
end_token_id
=
self
.
end_token_id
new_turn_token_id
=
self
.
new_turn_token_id
tool_call_token_id
=
self
.
tool_call_token_id
tool_response_token_id
=
self
.
tool_response_token_id
# Search from the end of input_ids to find the last match.
for
i
in
range
(
len
(
input_ids
)
-
1
,
-
1
,
-
1
):
if
input_ids
[
i
]
==
start_token_id
:
return
False
if
input_ids
[
i
]
==
tool_call_token_id
:
# We're generating a tool call, so reasoning must be ended.
return
True
if
input_ids
[
i
]
in
(
new_turn_token_id
,
tool_response_token_id
):
# We found a new turn or tool response token so don't consider
# reasoning ended yet, since the model starts new reasoning
# after these tokens.
return
False
if
input_ids
[
i
]
==
end_token_id
:
return
True
return
False
# ------------------------------------------------------------------
# Non-streaming path
# ------------------------------------------------------------------
def
extract_reasoning
(
self
,
model_output
:
str
,
request
:
"ChatCompletionRequest | ResponsesRequest"
,
)
->
tuple
[
str
|
None
,
str
|
None
]:
"""Extract reasoning, stripping the ``thought
\\
n`` role label."""
if
self
.
start_token
not
in
model_output
and
self
.
end_token
not
in
model_output
:
# Default to content history if no tags are present
# (or if they were stripped)
return
None
,
model_output
reasoning
,
content
=
super
().
extract_reasoning
(
model_output
,
request
)
if
reasoning
is
not
None
:
reasoning
=
_strip_thought_label
(
reasoning
)
return
reasoning
,
content
# ------------------------------------------------------------------
# Streaming path
# ------------------------------------------------------------------
def
extract_reasoning_streaming
(
self
,
previous_text
:
str
,
current_text
:
str
,
delta_text
:
str
,
previous_token_ids
:
Sequence
[
int
],
current_token_ids
:
Sequence
[
int
],
delta_token_ids
:
Sequence
[
int
],
)
->
DeltaMessage
|
None
:
"""Extract streaming reasoning, stripping ``thought
\\
n`` from the
first reasoning delta(s).
The ``thought
\\
n`` prefix may arrive as a single delta or split
across multiple deltas (e.g. ``"thought"`` then ``"
\\
n"``). We
buffer early reasoning tokens until we can determine whether the
prefix is present, then emit the buffered content minus the
prefix.
Unlike the previous implementation which reconstructed accumulated
reasoning from ``current_text``, this uses instance state
(``_reasoning_text``) to track only the reasoning content returned
by the base parser. This is necessary because
``skip_special_tokens=True`` (the vLLM default) causes the
``<|channel>`` delimiter to be invisible in ``current_text``,
making it impossible to separate pre-reasoning content from
reasoning content via string matching.
"""
result
=
super
().
extract_reasoning_streaming
(
previous_text
,
current_text
,
delta_text
,
previous_token_ids
,
current_token_ids
,
delta_token_ids
,
)
if
result
is
None
:
return
None
if
result
.
reasoning
is
None
:
return
result
# Accumulate ONLY the reasoning text from base parser results.
# This is immune to pre-reasoning content pollution.
self
.
_reasoning_text
+=
result
.
reasoning
# Once the prefix has been handled, all subsequent reasoning
# deltas pass through unchanged.
if
self
.
_prefix_stripped
:
return
result
# ---- Prefix stripping logic ----
# Case 1: We've accumulated enough to confirm the prefix is
# present. Strip it and pass through the remainder.
if
self
.
_reasoning_text
.
startswith
(
_THOUGHT_PREFIX
):
prefix_len
=
len
(
_THOUGHT_PREFIX
)
# How much reasoning was accumulated before this delta?
prev_reasoning_len
=
len
(
self
.
_reasoning_text
)
-
len
(
result
.
reasoning
)
if
prev_reasoning_len
>=
prefix_len
:
# Prefix was already consumed by prior deltas; this
# delta is entirely real content — pass through.
self
.
_prefix_stripped
=
True
return
result
else
:
# Part or all of the prefix is in this delta.
chars_of_prefix_in_delta
=
prefix_len
-
prev_reasoning_len
stripped
=
result
.
reasoning
[
chars_of_prefix_in_delta
:]
if
stripped
:
self
.
_prefix_stripped
=
True
result
.
reasoning
=
stripped
return
result
else
:
if
len
(
self
.
_reasoning_text
)
>=
prefix_len
:
self
.
_prefix_stripped
=
True
result
.
reasoning
=
""
return
result
return
None
# Case 2: Accumulated text is a strict prefix of
# _THOUGHT_PREFIX (e.g. we've only seen "thou" so far).
# Buffer by suppressing — we can't yet tell if this will
# become the full prefix or diverge.
if
_THOUGHT_PREFIX
.
startswith
(
self
.
_reasoning_text
):
return
None
# Case 3: Accumulated text doesn't match the thought prefix
# at all. This means prior deltas were buffered (suppressed
# by Case 2) but the text diverged. Re-emit the full
# accumulated text to avoid data loss.
self
.
_prefix_stripped
=
True
result
.
reasoning
=
self
.
_reasoning_text
return
result
def
_strip_thought_label
(
text
:
str
)
->
str
:
"""Remove the ``thought
\\
n`` role label from the beginning of text.
Mirrors ``vllm.reasoning.gemma4_utils._strip_thought_label`` from the
offline parser.
"""
if
text
.
startswith
(
_THOUGHT_PREFIX
):
return
text
[
len
(
_THOUGHT_PREFIX
)
:]
return
text
vllm/reasoning/gemma4_utils.py
0 → 100644
View file @
858bddce
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved.
"""Gemma4 thinking/reasoning output parsing utilities for offline inference.
Standalone functions that parse decoded model text to extract structured
thinking content from Gemma4 models. These are pure-Python utilities with
zero heavy dependencies — they work on raw decoded strings from any
inference backend (vLLM, HuggingFace, TGI, etc.).
For the OpenAI-compatible API reasoning parser (streaming +
non-streaming), see ``vllm.reasoning.gemma4_reasoning_parser``.
For tool call parsing, see ``vllm.tool_parsers.gemma4_utils``.
Usage with vLLM offline inference::
from vllm import LLM, SamplingParams
from vllm.reasoning.gemma4_utils import parse_thinking_output
llm = LLM(model="google/gemma-4-it")
outputs = llm.generate(prompt, SamplingParams(...))
text = tokenizer.decode(outputs[0].outputs[0].token_ids, skip_special_tokens=False)
# Extract thinking / answer (works with or without enable_thinking)
result = parse_thinking_output(text)
print(result["thinking"]) # chain-of-thought or None
print(result["answer"]) # final answer
Ported from ``transformers.models.gemma4.utils_gemma4`` so that vLLM users
do not need a transformers dependency for output parsing.
"""
# ---- Thinking Mode Utility ----
# Thinking delimiter tokens as they appear in decoded text.
# Gemma4 uses <|channel> (start) and <channel|> (end) as thinking delimiters.
_THINKING_START_TAG
=
"<|channel>"
_THINKING_END_TAG
=
"<channel|>"
# Sentinel tokens that may appear in decoded output.
_TURN_END_TAG
=
"<turn|>"
def
parse_thinking_output
(
text
:
str
)
->
dict
[
str
,
str
|
None
]:
"""Parse decoded Gemma4 model output.
Use this on **all** Gemma4 output regardless of whether thinking mode
was enabled. It handles three cases:
1. **Thinking enabled, tags present** — splits on ``<|channel>``/
``<channel|>`` to separate chain-of-thought from the answer and
strips the ``thought
\\
n`` role label.
2. **Thinking disabled, spurious label** — strips the bare
``thought
\\
n`` prefix that some Gemma4 models emit even
without thinking mode.
3. **Clean output** — returns the text unchanged.
The answer text is always cleaned of trailing sentinel tokens
(``<turn|>``, ``<eos>``, etc.).
Args:
text: Decoded model output text (from ``tokenizer.decode(...)``).
Returns:
A dict with keys:
- ``"thinking"``: The chain-of-thought text, or ``None`` if no
thinking delimiters were found.
- ``"answer"``: The final answer text.
Example::
>>> from vllm.reasoning.gemma4_utils import parse_thinking_output
>>> output_text = tokenizer.decode(outputs[0], skip_special_tokens=False)
>>> result = parse_thinking_output(output_text)
>>> print(result["thinking"]) # chain-of-thought reasoning or None
>>> print(result["answer"]) # final answer
"""
if
_THINKING_END_TAG
in
text
:
parts
=
text
.
split
(
_THINKING_END_TAG
,
1
)
thinking_block
=
parts
[
0
]
answer
=
_clean_answer
(
parts
[
1
])
# Extract thinking content: strip the start tag if present
if
_THINKING_START_TAG
in
thinking_block
:
thinking
=
thinking_block
.
split
(
_THINKING_START_TAG
,
1
)[
1
]
else
:
thinking
=
thinking_block
# Strip the "thought\n" channel role label the model emits inside
# <|channel>thought\n...<channel|> (analogous to "user\n" in
# <|turn>user\n...<turn|>).
thinking
=
_strip_thought_label
(
thinking
.
strip
())
thinking
=
thinking
.
strip
()
return
{
"thinking"
:
thinking
,
"answer"
:
answer
}
# No thinking delimiters found.
# Strip spurious "thought\n" role label that some Gemma4 models sometimes
# emit even without thinking mode enabled, then clean trailing tokens.
answer
=
_strip_thought_label
(
text
)
answer
=
_clean_answer
(
answer
)
return
{
"thinking"
:
None
,
"answer"
:
answer
}
def
_strip_thought_label
(
text
:
str
)
->
str
:
"""Strip the spurious ``thought
\\
n`` label from the start of text.
Only strips when ``thought`` appears as the very first word followed by
a newline — preserving the word ``thought`` in any other context.
"""
if
text
.
startswith
(
"thought
\n
"
):
return
text
[
len
(
"thought
\n
"
)
:]
return
text
def
_clean_answer
(
text
:
str
)
->
str
:
"""Clean trailing sentinel tokens from the answer text.
Strips ``<turn|>``, ``<eos>``, and surrounding whitespace that the
model appends at the end of its response.
"""
text
=
text
.
strip
()
# Strip trailing <turn|> (Gemma4 turn-end marker)
if
text
.
endswith
(
_TURN_END_TAG
):
text
=
text
[:
-
len
(
_TURN_END_TAG
)].
rstrip
()
# Strip trailing <eos> if present
if
text
.
endswith
(
"<eos>"
):
text
=
text
[:
-
5
].
rstrip
()
return
text
vllm/tool_parsers/__init__.py
View file @
858bddce
...
@@ -154,6 +154,10 @@ _TOOL_PARSERS_TO_REGISTER = {
...
@@ -154,6 +154,10 @@ _TOOL_PARSERS_TO_REGISTER = {
"functiongemma_tool_parser"
,
"functiongemma_tool_parser"
,
"FunctionGemmaToolParser"
,
"FunctionGemmaToolParser"
,
),
),
"gemma4"
:
(
"gemma4_tool_parser"
,
"Gemma4ToolParser"
,
),
}
}
...
...
vllm/tool_parsers/gemma4_tool_parser.py
0 → 100644
View file @
858bddce
This diff is collapsed.
Click to expand it.
vllm/transformers_utils/model_arch_config_convertor.py
View file @
858bddce
...
@@ -300,6 +300,28 @@ class ModelArchConfigConvertorBase:
...
@@ -300,6 +300,28 @@ class ModelArchConfigConvertorBase:
return
model_arch_config
return
model_arch_config
class
CohereAsrModelArchConfigConvertor
(
ModelArchConfigConvertorBase
):
def
get_total_num_attention_heads
(
self
)
->
int
:
return
self
.
hf_text_config
.
transf_decoder
[
"config_dict"
][
"num_attention_heads"
]
def
get_head_size
(
self
)
->
int
:
hidden_size
=
self
.
hf_text_config
.
transf_decoder
[
"config_dict"
][
"hidden_size"
]
num_attention_heads
=
self
.
hf_text_config
.
transf_decoder
[
"config_dict"
][
"num_attention_heads"
]
return
hidden_size
//
num_attention_heads
def
get_total_num_kv_heads
(
self
)
->
int
:
enc_num_kv_heads
=
self
.
hf_text_config
.
encoder
[
"n_heads"
]
dec_num_kv_heads
=
self
.
hf_text_config
.
transf_decoder
[
"config_dict"
][
"num_attention_heads"
]
assert
enc_num_kv_heads
==
dec_num_kv_heads
,
(
"Encoder and decoder must have the same number of kv heads"
)
return
enc_num_kv_heads
class
MambaModelArchConfigConvertor
(
ModelArchConfigConvertorBase
):
class
MambaModelArchConfigConvertor
(
ModelArchConfigConvertorBase
):
def
get_head_size
(
self
)
->
int
:
def
get_head_size
(
self
)
->
int
:
return
0
return
0
...
@@ -423,6 +445,16 @@ class LongCatFlashMTPModelArchConfigConvertor(ModelArchConfigConvertorBase):
...
@@ -423,6 +445,16 @@ class LongCatFlashMTPModelArchConfigConvertor(ModelArchConfigConvertorBase):
return
getattr
(
self
.
hf_text_config
,
"num_nextn_predict_layers"
,
1
)
return
getattr
(
self
.
hf_text_config
,
"num_nextn_predict_layers"
,
1
)
class
Gemma4ModelArchConfigConvertor
(
ModelArchConfigConvertorBase
):
def
get_head_size
(
self
)
->
int
:
# Gemma4 uses dual head dimensions: head_dim (sliding attention)
# and global_head_dim (full attention). Return the largest so
# that attention backends allocate buffers large enough for both.
head_dim
=
getattr
(
self
.
hf_text_config
,
"head_dim"
,
0
)
global_head_dim
=
getattr
(
self
.
hf_text_config
,
"global_head_dim"
,
0
)
return
max
(
head_dim
,
global_head_dim
)
or
super
().
get_head_size
()
# hf_config.model_type -> convertor class
# hf_config.model_type -> convertor class
MODEL_ARCH_CONFIG_CONVERTORS
=
{
MODEL_ARCH_CONFIG_CONVERTORS
=
{
"mamba"
:
MambaModelArchConfigConvertor
,
"mamba"
:
MambaModelArchConfigConvertor
,
...
@@ -433,6 +465,8 @@ MODEL_ARCH_CONFIG_CONVERTORS = {
...
@@ -433,6 +465,8 @@ MODEL_ARCH_CONFIG_CONVERTORS = {
"mpt"
:
MPTModelArchConfigConvertor
,
"mpt"
:
MPTModelArchConfigConvertor
,
"dbrx"
:
DbrxModelArchConfigConvertor
,
"dbrx"
:
DbrxModelArchConfigConvertor
,
"falcon"
:
FalconModelArchConfigConvertor
,
"falcon"
:
FalconModelArchConfigConvertor
,
"gemma4"
:
Gemma4ModelArchConfigConvertor
,
"gemma4_text"
:
Gemma4ModelArchConfigConvertor
,
"RefinedWeb"
:
FalconModelArchConfigConvertor
,
"RefinedWeb"
:
FalconModelArchConfigConvertor
,
"RefinedWebModel"
:
FalconModelArchConfigConvertor
,
"RefinedWebModel"
:
FalconModelArchConfigConvertor
,
"nemotron-nas"
:
NemotronNasModelArchConfigConvertor
,
"nemotron-nas"
:
NemotronNasModelArchConfigConvertor
,
...
...
vllm/v1/attention/ops/triton_unified_attention.py
View file @
858bddce
...
@@ -1040,6 +1040,8 @@ def unified_attention(
...
@@ -1040,6 +1040,8 @@ def unified_attention(
num_seqs
=
num_seqs
,
num_seqs
=
num_seqs
,
BLOCK_M
=
BLOCK_M
,
BLOCK_M
=
BLOCK_M
,
USE_FP8
=
output_scale
is
not
None
,
USE_FP8
=
output_scale
is
not
None
,
num_stages
=
1
)
)
else
:
else
:
kernel_unified_attention_3d
[
kernel_unified_attention_3d
[
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment