Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
858bddce
Commit
858bddce
authored
Apr 09, 2026
by
luopl
Browse files
feat:add gemma4
parent
40faaf0c
Changes
14
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
14 changed files
with
4233 additions
and
1 deletion
+4233
-1
vllm/model_executor/layers/rotary_embedding/__init__.py
vllm/model_executor/layers/rotary_embedding/__init__.py
+12
-0
vllm/model_executor/layers/rotary_embedding/gemma4_rope.py
vllm/model_executor/layers/rotary_embedding/gemma4_rope.py
+84
-0
vllm/model_executor/models/config.py
vllm/model_executor/models/config.py
+54
-0
vllm/model_executor/models/gemma4.py
vllm/model_executor/models/gemma4.py
+1581
-0
vllm/model_executor/models/gemma4_mm.py
vllm/model_executor/models/gemma4_mm.py
+1339
-0
vllm/model_executor/models/registry.py
vllm/model_executor/models/registry.py
+2
-0
vllm/model_executor/models/utils.py
vllm/model_executor/models/utils.py
+8
-1
vllm/reasoning/__init__.py
vllm/reasoning/__init__.py
+4
-0
vllm/reasoning/gemma4_reasoning_parser.py
vllm/reasoning/gemma4_reasoning_parser.py
+225
-0
vllm/reasoning/gemma4_utils.py
vllm/reasoning/gemma4_utils.py
+130
-0
vllm/tool_parsers/__init__.py
vllm/tool_parsers/__init__.py
+4
-0
vllm/tool_parsers/gemma4_tool_parser.py
vllm/tool_parsers/gemma4_tool_parser.py
+754
-0
vllm/transformers_utils/model_arch_config_convertor.py
vllm/transformers_utils/model_arch_config_convertor.py
+34
-0
vllm/v1/attention/ops/triton_unified_attention.py
vllm/v1/attention/ops/triton_unified_attention.py
+2
-0
No files found.
vllm/model_executor/layers/rotary_embedding/__init__.py
View file @
858bddce
...
...
@@ -13,6 +13,7 @@ from .dynamic_ntk_alpha_rope import DynamicNTKAlphaRotaryEmbedding
from
.dynamic_ntk_scaling_rope
import
DynamicNTKScalingRotaryEmbedding
from
.fope
import
FourierRotaryEmbedding
from
.linear_scaling_rope
import
LinearScalingRotaryEmbedding
from
.gemma4_rope
import
Gemma4RotaryEmbedding
from
.llama3_rope
import
Llama3RotaryEmbedding
from
.llama4_vision_rope
import
Llama4VisionRotaryEmbedding
from
.mrope
import
MRotaryEmbedding
...
...
@@ -134,6 +135,17 @@ def get_rope(
is_neox_style
,
dtype
,
)
elif
scaling_type
==
"proportional"
:
# Proportional RoPE is used by Gemma4 for global (full) attention.
# Gemma4 uses a sparse/fractional RoPE with cross-mixing between halves.
rotary_emb
=
Gemma4RotaryEmbedding
(
head_size
,
rotary_dim
,
max_position
,
base
,
is_neox_style
,
dtype
,
)
elif
scaling_type
==
"llama3"
:
scaling_factor
=
rope_parameters
[
"factor"
]
low_freq_factor
=
rope_parameters
[
"low_freq_factor"
]
...
...
vllm/model_executor/layers/rotary_embedding/gemma4_rope.py
0 → 100644
View file @
858bddce
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Gemma4-specific Rotary Positional Embeddings (proportional scaling).
Gemma4 uses "proportional" RoPE which computes inv_freq frequencies scaled
by head_dim (not rotary_dim), and zero-pads for non-rotated dimensions when
partial_rotary_factor < 1. The actual rotation uses standard neox-style
rotate_half, matching HF transformers' apply_rotary_pos_emb.
"""
import
torch
from
.base
import
RotaryEmbedding
class Gemma4RotaryEmbedding(RotaryEmbedding):
    """Rotary embedding implementing Gemma4 "proportional" RoPE scaling.

    The only behavioural difference from the parent ``RotaryEmbedding`` is
    how ``inv_freq`` is built:

    - frequency exponents are divided by ``head_size`` instead of
      ``rotary_dim``;
    - the frequency vector is padded with zeros up to ``head_size // 2``
      entries. A zero frequency yields angle 0 (cos=1, sin=0), i.e. an
      identity rotation for those slots.

    The parent is always constructed with ``rotary_dim == head_size``; the
    caller-supplied ``rotary_dim`` only decides how many angle pairs receive
    a non-zero frequency.
    """

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: float,
        is_neox_style: bool,
        dtype: torch.dtype,
    ) -> None:
        """Record the rotated/non-rotated pair counts, then build the parent.

        Args:
            head_size: Size of one attention head.
            rotary_dim: Number of dims that carry a real rotation.
            max_position_embeddings: Forwarded to the parent unchanged.
            base: RoPE base, forwarded to the parent unchanged.
            is_neox_style: Forwarded to the parent unchanged.
            dtype: Forwarded to the parent unchanged.
        """
        # Both counters are assigned before the parent constructor runs;
        # presumably the parent calls _compute_inv_freq() during its own
        # __init__ (verify against RotaryEmbedding), which reads them.
        half_head = head_size // 2
        # Angle pairs that are actually rotated.
        self.rope_angles = rotary_dim // 2
        # Angle pairs that stay un-rotated (zero frequency).
        self.nope_angles = half_head - self.rope_angles

        # head_size is deliberately passed in the rotary_dim position: the
        # parent then applies the cos/sin cache to every dim, and the
        # zero-frequency padding keeps the un-rotated dims unchanged.
        super().__init__(
            head_size,
            head_size,
            max_position_embeddings,
            base,
            is_neox_style,
            dtype,
        )

    def _compute_inv_freq(self, base: float) -> torch.Tensor:
        """Return inverse frequencies: rotated pairs first, then zero padding.

        The rotated part is ``1 / base ** (arange(0, 2*rope_angles, 2) /
        head_size)``. ``nope_angles`` zeros are appended when that count is
        positive.
        """
        pair_index = torch.arange(
            0, 2 * self.rope_angles, 2, dtype=torch.float
        )
        rotated = 1.0 / (base ** (pair_index / self.head_size))

        if self.nope_angles <= 0:
            # Every pair is rotated; nothing to pad.
            return rotated

        identity_pad = torch.zeros(self.nope_angles, dtype=torch.float)
        return torch.cat([rotated, identity_pad])

    def extra_repr(self) -> str:
        """Return a one-line summary of the embedding's configuration."""
        return (
            f"head_size={self.head_size}, rotary_dim={self.rotary_dim}"
            f", rope_angles={self.rope_angles}, nope_angles={self.nope_angles}"
            f", max_position_embeddings={self.max_position_embeddings}"
            f", base={self.base}, is_neox_style={self.is_neox_style}"
        )
vllm/model_executor/models/config.py
View file @
858bddce
...
...
@@ -56,6 +56,57 @@ class Gemma3TextModelConfig(VerifyAndUpdateConfig):
hf_config
=
model_config
.
hf_config
hf_config
.
is_causal
=
not
hf_config
.
use_bidirectional_attention
class Gemma4Config(VerifyAndUpdateConfig):
    """Config hook that pins the attention backend for some Gemma4 variants."""

    @staticmethod
    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
        """Select TRITON_ATTN when head dimensions are heterogeneous.

        Gemma4 configs may carry two head sizes: ``head_dim`` (sliding
        window layers) and ``global_head_dim`` (full attention layers).
        When they differ and the larger one is above 256 -- the
        FlashAttention head-size limit -- layers would otherwise end up on
        different backends, which produces numerical divergence and
        corrupted output. In that situation, and only if the user has not
        already chosen a backend, this sets
        ``vllm_config.attention_config.backend`` to ``TRITON_ATTN``.

        The config is left untouched when either dimension is missing, when
        both are equal, when neither exceeds 256, or when a backend is
        already set.

        TODO: Heterogeneous head_sizes (head_dim != global_head_dim)
        require NixlConnector changes to support per-layer KV transfer
        with different head dimensions for prefill-decode disaggregation.
        """
        text_cfg = vllm_config.model_config.hf_text_config
        head_dim = getattr(text_cfg, "head_dim", None)
        global_head_dim = getattr(text_cfg, "global_head_dim", None)

        # Nothing to compare unless both dimensions are declared.
        if head_dim is None or global_head_dim is None:
            return
        # Homogeneous heads: every layer can share one backend.
        if head_dim == global_head_dim:
            return
        # Both sizes fit the FlashAttention limit (head_size <= 256), so
        # there is no reason to force a different backend.
        if max(head_dim, global_head_dim) <= 256:
            return
        # Respect an explicit user choice.
        if vllm_config.attention_config.backend is not None:
            return

        from vllm.v1.attention.backends.registry import (
            AttentionBackendEnum,
        )

        vllm_config.attention_config.backend = AttentionBackendEnum.TRITON_ATTN
        logger.info(
            "Gemma4 model has heterogeneous head dimensions "
            "(head_dim=%d, global_head_dim=%d). Forcing TRITON_ATTN "
            "backend to prevent mixed-backend numerical divergence.",
            head_dim,
            global_head_dim,
        )
class
GptOssForCausalLMConfig
(
VerifyAndUpdateConfig
):
@
staticmethod
...
...
@@ -647,10 +698,13 @@ class VoyageQwen3BidirectionalEmbedModelConfig(VerifyAndUpdateConfig):
MODELS_CONFIG_MAP
:
dict
[
str
,
type
[
VerifyAndUpdateConfig
]]
=
{
"ColBERTJinaRobertaModel"
:
JinaRobertaModelConfig
,
"ColQwen3_5"
:
Qwen3_5ForConditionalGenerationConfig
,
"DeepseekV32ForCausalLM"
:
DeepseekV32ForCausalLM
,
"Ernie4_5_VLMoeForConditionalGeneration"
:
Ernie4_5_VLMoeForConditionalGenerationConfig
,
# noqa: E501
"FalconMambaForCausalLM"
:
MambaModelConfig
,
"Gemma3TextModel"
:
Gemma3TextModelConfig
,
"Gemma4ForCausalLM"
:
Gemma4Config
,
"Gemma4ForConditionalGeneration"
:
Gemma4Config
,
"GptOssForCausalLM"
:
GptOssForCausalLMConfig
,
"GteModel"
:
SnowflakeGteNewModelConfig
,
"GteNewForSequenceClassification"
:
GteNewModelConfig
,
...
...
vllm/model_executor/models/gemma4.py
0 → 100644
View file @
858bddce
This diff is collapsed.
Click to expand it.
vllm/model_executor/models/gemma4_mm.py
0 → 100644
View file @
858bddce
This diff is collapsed.
Click to expand it.
vllm/model_executor/models/registry.py
View file @
858bddce
...
...
@@ -111,6 +111,7 @@ _TEXT_GENERATION_MODELS = {
"Gemma2ForCausalLM"
:
(
"gemma2"
,
"Gemma2ForCausalLM"
),
"Gemma3ForCausalLM"
:
(
"gemma3"
,
"Gemma3ForCausalLM"
),
"Gemma3nForCausalLM"
:
(
"gemma3n"
,
"Gemma3nForCausalLM"
),
"Gemma4ForCausalLM"
:
(
"gemma4"
,
"Gemma4ForCausalLM"
),
"Qwen3NextForCausalLM"
:
(
"qwen3_next"
,
"Qwen3NextForCausalLM"
),
"GlmForCausalLM"
:
(
"glm"
,
"GlmForCausalLM"
),
"Glm4ForCausalLM"
:
(
"glm4"
,
"Glm4ForCausalLM"
),
...
...
@@ -377,6 +378,7 @@ _MULTIMODAL_MODELS = {
"gemma3n_mm"
,
"Gemma3nForConditionalGeneration"
,
),
"Gemma4ForConditionalGeneration"
:
(
"gemma4_mm"
,
"Gemma4ForConditionalGeneration"
),
"GlmAsrForConditionalGeneration"
:
(
"glmasr"
,
"GlmAsrForConditionalGeneration"
),
"GLM4VForCausalLM"
:
(
"glm4v"
,
"GLM4VForCausalLM"
),
"Glm4vForConditionalGeneration"
:
(
"glm4_1v"
,
"Glm4vForConditionalGeneration"
),
...
...
vllm/model_executor/models/utils.py
View file @
858bddce
...
...
@@ -233,8 +233,15 @@ class AutoWeightsLoader:
):
"""
Add tensor names that are not in the model params that may be in the
safetensors, e.g., batch normalization stats.
safetensors, e.g., batch normalization stats
and registered buffers
.
"""
# Add persistent registered buffers.
# Non-persistent buffers are excluded, matching PyTorch state_dict().
non_persistent
=
getattr
(
module
,
"_non_persistent_buffers_set"
,
set
())
for
buf_name
,
buf
in
module
.
named_buffers
(
recurse
=
False
):
if
buf_name
not
in
child_params
and
buf_name
not
in
non_persistent
:
child_params
[
buf_name
]
=
buf
if
isinstance
(
module
,
(
...
...
vllm/reasoning/__init__.py
View file @
858bddce
...
...
@@ -32,6 +32,10 @@ _REASONING_PARSERS_TO_REGISTER = {
"ernie45_reasoning_parser"
,
"Ernie45ReasoningParser"
,
),
"gemma4"
:
(
"gemma4_reasoning_parser"
,
"Gemma4ReasoningParser"
,
),
"glm45"
:
(
"deepseek_v3_reasoning_parser"
,
"DeepSeekV3ReasoningWithThinkingParser"
,
...
...
vllm/reasoning/gemma4_reasoning_parser.py
0 → 100644
View file @
858bddce
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
collections.abc
import
Sequence
from
typing
import
TYPE_CHECKING
from
vllm.entrypoints.openai.engine.protocol
import
DeltaMessage
from
vllm.reasoning.basic_parsers
import
BaseThinkingReasoningParser
from
vllm.tokenizers
import
TokenizerLike
if
TYPE_CHECKING
:
from
vllm.entrypoints.openai.chat_completion.protocol
import
(
ChatCompletionRequest
,
)
from
vllm.entrypoints.openai.responses.protocol
import
ResponsesRequest
# Role label that Gemma4 emits at the start of the thinking channel.
# The model generates: <|channel>thought\n...reasoning...<channel|>
# This prefix must be stripped to expose only the actual reasoning content.
_THOUGHT_PREFIX
=
"thought
\n
"
class Gemma4ReasoningParser(BaseThinkingReasoningParser):
    """Reasoning parser for Gemma4 thinking output.

    Reasoning is delimited by ``<|channel>`` (start) and ``<channel|>``
    (end). Inside the delimiters the model first emits the role label
    ``thought\\n``::

        <|channel>thought
        ...chain of thought reasoning...<channel|>
        Final answer text here.

    The label is a structural artefact, not reasoning, so this parser
    removes it in both the non-streaming and the streaming path. Splitting
    reasoning from content is delegated to the base class.
    """

    def __init__(self, tokenizer: TokenizerLike, *args, **kwargs):
        """Set up streaming state and resolve the extra boundary token ids.

        Raises:
            KeyError: If ``<|turn>``, ``<|tool_call>`` or
                ``<|tool_response>`` is missing from ``self.vocab``.
        """
        super().__init__(tokenizer, *args, **kwargs)

        # Streaming state for label removal. Only reasoning text returned by
        # the base parser is accumulated here, so content emitted before the
        # reasoning block cannot interfere with the prefix check.
        self._reasoning_text: str = ""
        self._prefix_stripped: bool = False

        # Extra boundary tokens consulted by is_reasoning_end().
        self.new_turn_token_id = self.vocab["<|turn>"]
        self.tool_call_token_id = self.vocab["<|tool_call>"]
        self.tool_response_token_id = self.vocab["<|tool_response>"]

    def adjust_request(
        self, request: "ChatCompletionRequest | ResponsesRequest"
    ) -> "ChatCompletionRequest | ResponsesRequest":
        """Turn off special-token skipping so the delimiters stay in the text.

        Mutates ``request`` in place and returns the same object.
        """
        request.skip_special_tokens = False
        return request

    @property
    def start_token(self) -> str:
        """Delimiter that opens a reasoning block."""
        return "<|channel>"

    @property
    def end_token(self) -> str:
        """Delimiter that closes a reasoning block."""
        return "<channel|>"

    def is_reasoning_end(self, input_ids: Sequence[int]) -> bool:
        """Decide, from the most recent boundary token, if reasoning is over.

        ``input_ids`` is scanned from last to first and the first boundary
        token found decides the result:

        - reasoning start -> ``False``
        - tool call -> ``True`` (a tool call is being generated)
        - new turn / tool response -> ``False`` (the model starts new
          reasoning after these)
        - reasoning end -> ``True``

        Returns ``False`` when no boundary token is present.
        """
        opens_reasoning = self.start_token_id
        closes_reasoning = self.end_token_id
        turn_id = self.new_turn_token_id
        tool_call_id = self.tool_call_token_id
        tool_response_id = self.tool_response_token_id

        for token_id in reversed(input_ids):
            if token_id == opens_reasoning:
                return False
            if token_id == tool_call_id:
                return True
            if token_id == turn_id or token_id == tool_response_id:
                return False
            if token_id == closes_reasoning:
                return True
        return False

    # ------------------------------------------------------------------
    # Non-streaming path
    # ------------------------------------------------------------------

    def extract_reasoning(
        self,
        model_output: str,
        request: "ChatCompletionRequest | ResponsesRequest",
    ) -> tuple[str | None, str | None]:
        """Split a complete output into ``(reasoning, content)``.

        When neither delimiter occurs in ``model_output`` the whole text is
        returned as content with ``None`` reasoning. Otherwise the base
        class does the split and the ``thought\\n`` label is removed from
        the reasoning part.
        """
        if not (
            self.start_token in model_output or self.end_token in model_output
        ):
            # No delimiters (or they were stripped): everything is content.
            return None, model_output

        reasoning, content = super().extract_reasoning(model_output, request)
        if reasoning is None:
            return reasoning, content
        return _strip_thought_label(reasoning), content

    # ------------------------------------------------------------------
    # Streaming path
    # ------------------------------------------------------------------

    def extract_reasoning_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
    ) -> DeltaMessage | None:
        """Stream reasoning deltas with the ``thought\\n`` label removed.

        The base parser produces the delta. Its reasoning text is appended
        to ``self._reasoning_text`` and compared with the label:

        1. The accumulated text starts with the label: the label characters
           contained in this delta are cut off and the rest is emitted.
        2. The accumulated text is still a prefix of the label (for example
           ``"thou"``): nothing is emitted yet (returns ``None``).
        3. The accumulated text has diverged from the label: the full
           accumulated text is emitted, so deltas withheld by step 2 are
           not lost.

        After any of the emitting outcomes ``self._prefix_stripped`` is set
        and later reasoning deltas pass through untouched. Deltas without
        reasoning are returned as produced by the base parser.
        """
        delta_msg = super().extract_reasoning_streaming(
            previous_text,
            current_text,
            delta_text,
            previous_token_ids,
            current_token_ids,
            delta_token_ids,
        )
        # Nothing from the base parser, or a delta without reasoning: no
        # label handling is needed.
        if delta_msg is None or delta_msg.reasoning is None:
            return delta_msg

        chunk = delta_msg.reasoning
        self._reasoning_text += chunk

        # The label was already dealt with on an earlier delta.
        if self._prefix_stripped:
            return delta_msg

        seen = self._reasoning_text

        # ---- Outcome 1: the label is fully present ----
        if seen.startswith(_THOUGHT_PREFIX):
            label_len = len(_THOUGHT_PREFIX)
            # Reasoning length accumulated before this delta arrived.
            seen_before = len(seen) - len(chunk)

            if seen_before >= label_len:
                # Label lies entirely in earlier deltas; this one is pure
                # reasoning.
                self._prefix_stripped = True
                return delta_msg

            # Drop the label characters that fall inside this delta.
            remainder = chunk[label_len - seen_before:]
            if remainder:
                self._prefix_stripped = True
                delta_msg.reasoning = remainder
                return delta_msg

            # The delta ended exactly on the label boundary.
            if len(seen) >= label_len:
                self._prefix_stripped = True
                delta_msg.reasoning = ""
                return delta_msg
            return None

        # ---- Outcome 2: could still become the label; keep buffering ----
        if _THOUGHT_PREFIX.startswith(seen):
            return None

        # ---- Outcome 3: diverged from the label; flush the buffer ----
        self._prefix_stripped = True
        delta_msg.reasoning = seen
        return delta_msg
def _strip_thought_label(text: str) -> str:
    """Return ``text`` without a leading ``thought\\n`` role label.

    The label is removed only when ``text`` begins with it; any other text
    is returned unchanged.
    """
    return text.removeprefix(_THOUGHT_PREFIX)
vllm/reasoning/gemma4_utils.py
0 → 100644
View file @
858bddce
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved.
"""Gemma4 thinking/reasoning output parsing utilities for offline inference.
Standalone functions that parse decoded model text to extract structured
thinking content from Gemma4 models. These are pure-Python utilities with
zero heavy dependencies — they work on raw decoded strings from any
inference backend (vLLM, HuggingFace, TGI, etc.).
For the OpenAI-compatible API reasoning parser (streaming +
non-streaming), see ``vllm.reasoning.gemma4_reasoning_parser``.
For tool call parsing, see ``vllm.tool_parsers.gemma4_tool_parser``.
Usage with vLLM offline inference::
from vllm import LLM, SamplingParams
from vllm.reasoning.gemma4_utils import parse_thinking_output
llm = LLM(model="google/gemma-4-it")
outputs = llm.generate(prompt, SamplingParams(...))
text = tokenizer.decode(outputs[0].outputs[0].token_ids, skip_special_tokens=False)
# Extract thinking / answer (works with or without enable_thinking)
result = parse_thinking_output(text)
print(result["thinking"]) # chain-of-thought or None
print(result["answer"]) # final answer
Ported from ``transformers.models.gemma4.utils_gemma4`` so that vLLM users
do not need a transformers dependency for output parsing.
"""
# ---- Thinking Mode Utility ----
# Thinking delimiter tokens as they appear in decoded text.
# Gemma4 uses <|channel> (start) and <channel|> (end) as thinking delimiters.
_THINKING_START_TAG
=
"<|channel>"
_THINKING_END_TAG
=
"<channel|>"
# Sentinel tokens that may appear in decoded output.
_TURN_END_TAG
=
"<turn|>"
def parse_thinking_output(text: str) -> dict[str, str | None]:
    """Split decoded Gemma4 output into chain-of-thought and final answer.

    Safe to call on any Gemma4 output, whether or not thinking mode was on:

    1. **End tag present** -- everything before the first ``<channel|>`` is
       the thinking block and everything after it is the answer. If the
       thinking block contains ``<|channel>``, only the text after its
       first occurrence is kept. The ``thought\\n`` role label and
       surrounding whitespace are removed from the thinking text.
    2. **No end tag** -- ``"thinking"`` is ``None``; a leading
       ``thought\\n`` label, if any, is removed from the text, which
       becomes the answer.

    In both cases the answer is passed through ``_clean_answer``.

    Args:
        text: Decoded model output text (from ``tokenizer.decode(...)``).

    Returns:
        A dict with keys ``"thinking"`` (``str`` or ``None``) and
        ``"answer"`` (``str``).

    Example::

        >>> from vllm.reasoning.gemma4_utils import parse_thinking_output
        >>> output_text = tokenizer.decode(outputs[0], skip_special_tokens=False)
        >>> result = parse_thinking_output(output_text)
        >>> print(result["thinking"])  # chain-of-thought reasoning or None
        >>> print(result["answer"])  # final answer
    """
    before_end, end_tag, after_end = text.partition(_THINKING_END_TAG)

    if not end_tag:
        # No thinking delimiters. Some outputs still start with a bare
        # "thought\n" label; remove it before cleaning trailing sentinels.
        plain_answer = _clean_answer(_strip_thought_label(text))
        return {"thinking": None, "answer": plain_answer}

    answer = _clean_answer(after_end)

    # Keep only what follows the start tag, when the start tag is there.
    _, start_tag, after_start = before_end.partition(_THINKING_START_TAG)
    raw_thinking = after_start if start_tag else before_end

    # Remove the "thought\n" role label emitted inside
    # <|channel>thought\n...<channel|>, then trim leftover whitespace.
    thinking = _strip_thought_label(raw_thinking.strip()).strip()
    return {"thinking": thinking, "answer": answer}
def
_strip_thought_label
(
text
:
str
)
->
str
:
"""Strip the spurious ``thought
\\
n`` label from the start of text.
Only strips when ``thought`` appears as the very first word followed by
a newline — preserving the word ``thought`` in any other context.
"""
if
text
.
startswith
(
"thought
\n
"
):
return
text
[
len
(
"thought
\n
"
)
:]
return
text
def _clean_answer(text: str) -> str:
    """Remove trailing sentinel tokens and surrounding whitespace.

    Strips ``<turn|>`` and ``<eos>`` from the end of ``text``, in whatever
    order and however many times they appear (for example both
    ``"...<eos><turn|>"`` and ``"...<turn|><eos>"`` are fully cleaned),
    together with the whitespace around them. Sentinels that are not at the
    end of the text are left alone.

    Args:
        text: Answer text as decoded from the model.

    Returns:
        The text with leading/trailing whitespace and trailing sentinel
        tokens removed.
    """
    trailing_sentinels = (_TURN_END_TAG, "<eos>")

    text = text.strip()
    # Repeat until a full pass removes nothing: stripping one sentinel can
    # expose another one behind it, in either order.
    removed_any = True
    while removed_any:
        removed_any = False
        for sentinel in trailing_sentinels:
            if text.endswith(sentinel):
                text = text[: -len(sentinel)].rstrip()
                removed_any = True
    return text
vllm/tool_parsers/__init__.py
View file @
858bddce
...
...
@@ -154,6 +154,10 @@ _TOOL_PARSERS_TO_REGISTER = {
"functiongemma_tool_parser"
,
"FunctionGemmaToolParser"
,
),
"gemma4"
:
(
"gemma4_tool_parser"
,
"Gemma4ToolParser"
,
),
}
...
...
vllm/tool_parsers/gemma4_tool_parser.py
0 → 100644
View file @
858bddce
This diff is collapsed.
Click to expand it.
vllm/transformers_utils/model_arch_config_convertor.py
View file @
858bddce
...
...
@@ -300,6 +300,28 @@ class ModelArchConfigConvertorBase:
return
model_arch_config
class CohereAsrModelArchConfigConvertor(ModelArchConfigConvertorBase):
    """Architecture-config convertor for Cohere ASR models.

    Decoder values are read from
    ``hf_text_config.transf_decoder["config_dict"]`` and the encoder head
    count from ``hf_text_config.encoder["n_heads"]``.
    """

    def get_total_num_attention_heads(self) -> int:
        """Return the decoder's ``num_attention_heads``."""
        decoder_cfg = self.hf_text_config.transf_decoder["config_dict"]
        return decoder_cfg["num_attention_heads"]

    def get_head_size(self) -> int:
        """Return decoder ``hidden_size`` floor-divided by its head count."""
        decoder_cfg = self.hf_text_config.transf_decoder["config_dict"]
        return decoder_cfg["hidden_size"] // decoder_cfg["num_attention_heads"]

    def get_total_num_kv_heads(self) -> int:
        """Return the KV head count shared by encoder and decoder.

        Raises:
            AssertionError: If the encoder ``n_heads`` differs from the
                decoder ``num_attention_heads``.
        """
        encoder_heads = self.hf_text_config.encoder["n_heads"]
        decoder_cfg = self.hf_text_config.transf_decoder["config_dict"]
        decoder_heads = decoder_cfg["num_attention_heads"]
        assert encoder_heads == decoder_heads, (
            "Encoder and decoder must have the same number of kv heads"
        )
        return encoder_heads
class
MambaModelArchConfigConvertor
(
ModelArchConfigConvertorBase
):
def
get_head_size
(
self
)
->
int
:
return
0
...
...
@@ -423,6 +445,16 @@ class LongCatFlashMTPModelArchConfigConvertor(ModelArchConfigConvertorBase):
return
getattr
(
self
.
hf_text_config
,
"num_nextn_predict_layers"
,
1
)
class Gemma4ModelArchConfigConvertor(ModelArchConfigConvertorBase):
    """Architecture-config convertor for Gemma4 (dual head dimensions)."""

    def get_head_size(self) -> int:
        """Return the largest head dimension declared by the text config.

        Gemma4 uses ``head_dim`` for sliding attention layers and
        ``global_head_dim`` for full attention layers. The larger of the two
        is returned so that attention backends allocate buffers big enough
        for both. An attribute that is absent or explicitly ``None`` counts
        as 0; when neither provides a positive value the base class
        implementation is used instead.
        """
        # `or 0` also covers attributes that exist but are set to None,
        # which would otherwise make max() raise TypeError.
        head_dim = getattr(self.hf_text_config, "head_dim", None) or 0
        global_head_dim = (
            getattr(self.hf_text_config, "global_head_dim", None) or 0
        )
        return max(head_dim, global_head_dim) or super().get_head_size()
# hf_config.model_type -> convertor class
MODEL_ARCH_CONFIG_CONVERTORS
=
{
"mamba"
:
MambaModelArchConfigConvertor
,
...
...
@@ -433,6 +465,8 @@ MODEL_ARCH_CONFIG_CONVERTORS = {
"mpt"
:
MPTModelArchConfigConvertor
,
"dbrx"
:
DbrxModelArchConfigConvertor
,
"falcon"
:
FalconModelArchConfigConvertor
,
"gemma4"
:
Gemma4ModelArchConfigConvertor
,
"gemma4_text"
:
Gemma4ModelArchConfigConvertor
,
"RefinedWeb"
:
FalconModelArchConfigConvertor
,
"RefinedWebModel"
:
FalconModelArchConfigConvertor
,
"nemotron-nas"
:
NemotronNasModelArchConfigConvertor
,
...
...
vllm/v1/attention/ops/triton_unified_attention.py
View file @
858bddce
...
...
@@ -1040,6 +1040,8 @@ def unified_attention(
num_seqs
=
num_seqs
,
BLOCK_M
=
BLOCK_M
,
USE_FP8
=
output_scale
is
not
None
,
num_stages
=
1
)
else
:
kernel_unified_attention_3d
[
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment