Unverified commit 08ed2b96, authored by Luciano Martins, committed by GitHub
Browse files

feat(models): implement Google Gemma 4 architecture support (MoE, Multimodal,...


feat(models): implement Google Gemma 4 architecture support (MoE, Multimodal, Reasoning, Tool-Use) (#38826)
Signed-off-by: Luciano Martins <lucianommartins@users.noreply.github.com>
Signed-off-by: Luciano Martins <lucianomartins@google.com>
Co-authored-by: Luciano Martins <lucianommartins@users.noreply.github.com>
Co-authored-by: Isotr0py <2037008807@qq.com>
parent ecd5443d
...@@ -394,6 +394,22 @@ VLM_TEST_SETTINGS = { ...@@ -394,6 +394,22 @@ VLM_TEST_SETTINGS = {
vllm_runner_kwargs={"mm_processor_kwargs": {"do_pan_and_scan": True}}, vllm_runner_kwargs={"mm_processor_kwargs": {"do_pan_and_scan": True}},
patch_hf_runner=model_utils.gemma3_patch_hf_runner, patch_hf_runner=model_utils.gemma3_patch_hf_runner,
), ),
"gemma4": VLMTestInfo(
models=["google/gemma-4-E2B-it"],
test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
prompt_formatter=lambda img_prompt: f"<bos><start_of_turn>user\n{img_prompt}<end_of_turn>\n<start_of_turn>model\n", # noqa: E501
single_image_prompts=IMAGE_ASSETS.prompts(
{
"stop_sign": "What's the content in the center of the image?",
"cherry_blossom": "What is the season?",
}
),
multi_image_prompt="Describe the two images in detail.",
max_model_len=4096,
max_num_seqs=2,
auto_cls=AutoModelForImageTextToText,
vllm_runner_kwargs={"limit_mm_per_prompt": {"image": 4}},
),
"granite_vision": VLMTestInfo( "granite_vision": VLMTestInfo(
models=["ibm-granite/granite-vision-3.3-2b"], models=["ibm-granite/granite-vision-3.3-2b"],
test_type=(VLMTestType.IMAGE), test_type=(VLMTestType.IMAGE),
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
from vllm.multimodal import MULTIMODAL_REGISTRY
from ....conftest import ImageTestAssets
from ...utils import build_model_context
# TODO: to be updated to "google/gemma-4-e2b-it" once the models are available
GEMMA4_MODEL_ID = "google/gemma-4-E2B-it"
@pytest.mark.parametrize("model_id", [GEMMA4_MODEL_ID])
def test_limit_mm_per_prompt(
    image_assets: ImageTestAssets,
    model_id: str,
):
    """Verify that passing more images than ``limit_mm_per_prompt`` allows fails.

    The processor is built with an image limit of 1 and is then invoked with
    a two-image prompt; the call is expected to raise ``ValueError``.
    """
    # Build a processor whose configuration permits a single image only.
    model_ctx = build_model_context(
        model_id,
        mm_processor_kwargs={},
        limit_mm_per_prompt={"image": 1},
    )
    mm_processor = MULTIMODAL_REGISTRY.create_processor(model_ctx.model_config)

    # Take two images from the shared assets; if fewer are available, reuse
    # the first one so the request still carries two images.
    pil_images = [asset.pil_image for asset in image_assets][:2]
    if len(pil_images) < 2:
        pil_images = [pil_images[0]] * 2

    # Both the parsing of the multimodal data and the processor call sit
    # inside the raises-block, so a rejection at either step satisfies it.
    with pytest.raises(ValueError, match="At most 1 image"):
        mm_processor(
            "<image><image>",
            mm_items=mm_processor.info.parse_mm_data({"image": pil_images}),
            hf_processor_mm_kwargs={},
        )
...@@ -277,6 +277,10 @@ _TEXT_GENERATION_EXAMPLE_MODELS = { ...@@ -277,6 +277,10 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
"google/gemma-2-9b", extras={"tiny": "google/gemma-2-2b-it"} "google/gemma-2-9b", extras={"tiny": "google/gemma-2-2b-it"}
), ),
"Gemma3ForCausalLM": _HfExamplesInfo("google/gemma-3-1b-it"), "Gemma3ForCausalLM": _HfExamplesInfo("google/gemma-3-1b-it"),
"Gemma4ForCausalLM": _HfExamplesInfo(
"google/gemma-4-E2B-it",
min_transformers_version="5.0.0",
),
"Gemma3nForCausalLM": _HfExamplesInfo("google/gemma-3n-E2B-it"), "Gemma3nForCausalLM": _HfExamplesInfo("google/gemma-3n-E2B-it"),
"GlmForCausalLM": _HfExamplesInfo("zai-org/glm-4-9b-chat-hf"), "GlmForCausalLM": _HfExamplesInfo("zai-org/glm-4-9b-chat-hf"),
"Glm4ForCausalLM": _HfExamplesInfo("zai-org/GLM-4-9B-0414"), "Glm4ForCausalLM": _HfExamplesInfo("zai-org/GLM-4-9B-0414"),
...@@ -813,6 +817,10 @@ _MULTIMODAL_EXAMPLE_MODELS = { ...@@ -813,6 +817,10 @@ _MULTIMODAL_EXAMPLE_MODELS = {
), ),
"FuyuForCausalLM": _HfExamplesInfo("adept/fuyu-8b"), "FuyuForCausalLM": _HfExamplesInfo("adept/fuyu-8b"),
"Gemma3ForConditionalGeneration": _HfExamplesInfo("google/gemma-3-4b-it"), "Gemma3ForConditionalGeneration": _HfExamplesInfo("google/gemma-3-4b-it"),
"Gemma4ForConditionalGeneration": _HfExamplesInfo(
"google/gemma-4-E2B-it",
min_transformers_version="5.5.0",
),
"Gemma3nForConditionalGeneration": _HfExamplesInfo("google/gemma-3n-E2B-it"), "Gemma3nForConditionalGeneration": _HfExamplesInfo("google/gemma-3n-E2B-it"),
"GlmAsrForConditionalGeneration": _HfExamplesInfo( "GlmAsrForConditionalGeneration": _HfExamplesInfo(
"zai-org/GLM-ASR-Nano-2512", "zai-org/GLM-ASR-Nano-2512",
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
from tests.reasoning.utils import run_reasoning_extraction
from vllm.reasoning import ReasoningParser, ReasoningParserManager
# The tokenizer fixture below loads the Gemma4 tokenizer so that the
# <|channel> / <channel|> token IDs can be looked up in its vocabulary.
from vllm.tokenizers.registry import get_tokenizer
parser_name = "gemma4"
@pytest.fixture(scope="module")
def generic_tokenizer():
    """Module-scoped tokenizer for ``google/gemma-4-E2B-it``."""
    tokenizer = get_tokenizer("google/gemma-4-E2B-it")
    return tokenizer
# Expected-result tables for the reasoning-parser test. Every table has:
#   "output"           - raw model text used as the test input
#   "reasoning"        - expected extracted reasoning (None when there is none)
#   "content"          - expected remaining content (None when there is none)
#   "is_reasoning_end" - expected value of the parser's is_reasoning_end()

# End marker without a start marker (non-streaming expectation): the text
# before <channel|> is expected to come back as reasoning.
INVALID_SIMPLE_NONSTREAMING = {
    "output": "This is a reasoning section<channel|>This is the rest",
    "reasoning": "This is a reasoning section",
    "content": "This is the rest",
    "is_reasoning_end": True,
}
# Same input, streaming expectation: no reasoning is reported and all text
# (minus the marker) is expected as content.
INVALID_SIMPLE_STREAMING = {
    "output": "This is a reasoning section<channel|>This is the rest",
    "reasoning": None,
    "content": "This is a reasoning sectionThis is the rest",
    "is_reasoning_end": True,
}
# End marker only, nothing after it (non-streaming expectation).
INVALID_COMPLETE_NONSTREAMING = {
    "output": "This is a reasoning section<channel|>",
    "reasoning": "This is a reasoning section",
    "content": None,
    "is_reasoning_end": True,
}
# End marker only, nothing after it (streaming expectation).
INVALID_COMPLETE_STREAMING = {
    "output": "This is a reasoning section<channel|>",
    "reasoning": None,
    "content": "This is a reasoning section",
    "is_reasoning_end": True,
}
# Start marker with reasoning text but no end marker and no content.
NO_CONTENT = {
    "output": "<|channel>This is reasoning",
    "reasoning": "This is reasoning",
    "content": None,
    "is_reasoning_end": False,
}
# No markers at all: everything is content.
NO_REASONING = {
    "output": "This is content",
    "reasoning": None,
    "content": "This is content",
    "is_reasoning_end": False,
}
# Well-formed case: start marker, reasoning, end marker, then content.
REASONING_WITH_CHANNEL = {
    "output": "<|channel>This is a reasoning section<channel|>This is the rest",
    "reasoning": "This is a reasoning section",
    "content": "This is the rest",
    "is_reasoning_end": True,
}
# Well-formed reasoning section with nothing after the end marker.
COMPLETE_REASONING_WITH_CHANNEL = {
    "output": "<|channel>This is a reasoning section<channel|>",
    "reasoning": "This is a reasoning section",
    "content": None,
    "is_reasoning_end": True,
}
# Newlines inside both the reasoning section and the content.
MULTIPLE_LINES_WITH_CHANNEL = {
    "output": "<|channel>This\nThat<channel|>This is the rest\nThat",
    "reasoning": "This\nThat",
    "content": "This is the rest\nThat",
    "is_reasoning_end": True,
}
# Start marker without a matching end marker.
CHANNEL_NO_END = {
    "output": "<|channel>This is a reasoning section",
    "reasoning": "This is a reasoning section",
    "content": None,
    "is_reasoning_end": False,
}
# Empty model output.
EMPTY = {
    "output": "",
    "reasoning": None,
    "content": "",
    "is_reasoning_end": False,
}
# Text before the start marker (non-streaming expectation): the leading
# "Before\n" is not part of the expected content.
NEW_LINE_NONSTREAMING = {
    "output": (
        "Before\n<|channel>This is a reasoning section<channel|>\nThis is the rest"
    ),
    "reasoning": "This is a reasoning section",
    "content": "\nThis is the rest",
    "is_reasoning_end": True,
}
# Same input (streaming expectation): the leading "Before\n" is expected to
# be kept at the front of the content.
NEW_LINE_STREAMING = {
    "output": (
        "Before\n<|channel>This is a reasoning section<channel|>\nThis is the rest"
    ),
    "reasoning": "This is a reasoning section",
    "content": "Before\n\nThis is the rest",
    "is_reasoning_end": True,
}
# Parametrization table: (streaming flag, expected-result table), with an id
# naming the scenario.
TEST_CASES = [
    pytest.param(False, INVALID_SIMPLE_NONSTREAMING, id="invalid_simple"),
    pytest.param(True, INVALID_SIMPLE_STREAMING, id="invalid_simple_streaming"),
    pytest.param(False, INVALID_COMPLETE_NONSTREAMING, id="invalid_complete"),
    pytest.param(True, INVALID_COMPLETE_STREAMING, id="invalid_complete_streaming"),
    pytest.param(False, NO_CONTENT, id="no_content"),
    pytest.param(False, NO_REASONING, id="no_reasoning"),
    pytest.param(False, REASONING_WITH_CHANNEL, id="reasoning"),
    pytest.param(True, REASONING_WITH_CHANNEL, id="reasoning_streaming"),
    pytest.param(False, COMPLETE_REASONING_WITH_CHANNEL, id="complete_reasoning"),
    pytest.param(
        True, COMPLETE_REASONING_WITH_CHANNEL, id="complete_reasoning_streaming"
    ),
    pytest.param(False, MULTIPLE_LINES_WITH_CHANNEL, id="multiple_lines"),
    pytest.param(True, MULTIPLE_LINES_WITH_CHANNEL, id="multiple_lines_streaming"),
    pytest.param(False, CHANNEL_NO_END, id="no_end"),
    pytest.param(True, CHANNEL_NO_END, id="no_end_streaming"),
    pytest.param(False, EMPTY, id="empty"),
    pytest.param(False, NEW_LINE_NONSTREAMING, id="new_line"),
    pytest.param(True, NEW_LINE_STREAMING, id="new_line_streaming"),
]
@pytest.mark.parametrize("streaming, param_dict", TEST_CASES)
def test_gemma4_reasoning(
    streaming: bool,
    param_dict: dict,
    generic_tokenizer,
):
    """Run one reasoning-extraction scenario against the ``gemma4`` parser.

    The raw output text is turned into token IDs, with the channel markers
    mapped to their vocabulary IDs and the surrounding text encoded piecewise.
    The per-token decoded strings are then fed through
    ``run_reasoning_extraction`` and compared with the expected table, and
    ``is_reasoning_end`` is checked on the token IDs.
    """
    start_marker = "<|channel>"
    end_marker = "<channel|>"
    raw_output = param_dict["output"]

    # Marker IDs are read from the tokenizer vocabulary, not hard-coded.
    token_vocab = generic_tokenizer.get_vocab()
    start_token_id = token_vocab[start_marker]
    end_token_id = token_vocab[end_marker]

    # First occurrence of each marker in the raw text (-1 when absent).
    start_pos = raw_output.find(start_marker)
    end_pos = raw_output.find(end_marker)

    def _encode(text: str) -> list[int]:
        if not text:
            return []
        # Use the wrapped tokenizer when the object exposes one; otherwise
        # the object itself is used.
        enc = getattr(generic_tokenizer, "tokenizer", generic_tokenizer)
        try:
            return enc.encode(text, add_special_tokens=False)
        except TypeError:
            # encode() without the add_special_tokens keyword.
            return enc.encode(text)

    # Describe the output as an ordered list of pieces: a ``str`` piece is
    # plain text still to be encoded, an ``int`` piece is a marker token ID.
    if start_pos != -1:
        reasoning_from = start_pos + len(start_marker)
        pieces = [raw_output[:start_pos], start_token_id]
        if end_pos != -1:
            pieces += [
                raw_output[reasoning_from:end_pos],
                end_token_id,
                raw_output[end_pos + len(end_marker) :],
            ]
        else:
            pieces.append(raw_output[reasoning_from:])
    elif end_pos != -1:
        pieces = [
            raw_output[:end_pos],
            end_token_id,
            raw_output[end_pos + len(end_marker) :],
        ]
    else:
        pieces = [raw_output]

    output_tokens = []
    for piece in pieces:
        if isinstance(piece, int):
            output_tokens.append(piece)
        else:
            output_tokens.extend(_encode(piece))

    parser_cls = ReasoningParserManager.get_reasoning_parser(parser_name)
    parser: ReasoningParser = parser_cls(generic_tokenizer)

    # Decode one token at a time so the strings carry standard spaces
    # instead of SentencePiece space characters.
    output_token_strings = [
        generic_tokenizer.decode([token_id]) for token_id in output_tokens
    ]
    reasoning, content = run_reasoning_extraction(
        parser, output_token_strings, streaming=streaming
    )
    assert reasoning == param_dict["reasoning"]
    assert content == param_dict["content"]

    # The end-of-reasoning check works on token IDs, not on text.
    is_reasoning_end = parser.is_reasoning_end(output_tokens)
    assert is_reasoning_end == param_dict["is_reasoning_end"]
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
from typing import Any
from unittest.mock import MagicMock
import pytest
from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionRequest
from vllm.tool_parsers.gemma4_tool_parser import (
TOOL_CALL_END,
TOOL_CALL_START,
Gemma4ToolParser,
_parse_gemma4_args,
_parse_gemma4_array,
)
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
@pytest.fixture
def mock_tokenizer():
    """Mock tokenizer whose vocab maps the tool-call start/end markers to IDs."""
    # The vocab carries the tool-call markers so the parser can look up
    # their IDs; encode() returns a fixed dummy sequence.
    return MagicMock(
        **{
            "encode.return_value": [1, 2, 3],
            "get_vocab.return_value": {TOOL_CALL_START: 48, TOOL_CALL_END: 49},
        }
    )
@pytest.fixture
def parser(mock_tokenizer):
    """Gemma4ToolParser built on top of the mock tokenizer fixture."""
    tool_parser = Gemma4ToolParser(mock_tokenizer)
    return tool_parser
@pytest.fixture
def mock_request():
    """Mock chat-completion request with no tools and ``tool_choice="auto"``."""
    return MagicMock(spec=ChatCompletionRequest, tools=[], tool_choice="auto")
# ---------------------------------------------------------------------------
# Unit tests for _parse_gemma4_args (shared parser logic)
# ---------------------------------------------------------------------------
class TestParseGemma4Args:
    """Checks of ``_parse_gemma4_args`` on ``key:value`` argument strings."""

    def test_empty_string(self):
        """An empty argument string yields an empty dict."""
        parsed = _parse_gemma4_args("")
        assert parsed == {}

    def test_whitespace_only(self):
        """A whitespace-only argument string yields an empty dict."""
        parsed = _parse_gemma4_args(" ")
        assert parsed == {}

    def test_single_string_value(self):
        """One delimited string value."""
        assert _parse_gemma4_args('location:<|"|>Paris<|"|>') == {"location": "Paris"}

    def test_string_value_with_comma(self):
        """A comma inside a delimited string stays part of the value."""
        expected = {"location": "Paris, France"}
        assert _parse_gemma4_args('location:<|"|>Paris, France<|"|>') == expected

    def test_multiple_string_values(self):
        """Two comma-separated delimited string values."""
        raw = 'location:<|"|>San Francisco<|"|>,unit:<|"|>celsius<|"|>'
        assert _parse_gemma4_args(raw) == {
            "location": "San Francisco",
            "unit": "celsius",
        }

    def test_integer_value(self):
        """A bare integer is returned as an int."""
        assert _parse_gemma4_args("count:42") == {"count": 42}

    def test_float_value(self):
        """A bare decimal is returned as a float."""
        assert _parse_gemma4_args("score:3.14") == {"score": 3.14}

    def test_boolean_true(self):
        """Bare ``true`` is returned as ``True``."""
        assert _parse_gemma4_args("flag:true") == {"flag": True}

    def test_boolean_false(self):
        """Bare ``false`` is returned as ``False``."""
        assert _parse_gemma4_args("flag:false") == {"flag": False}

    def test_mixed_types(self):
        """String, integer, boolean and float values in one argument string."""
        raw = 'name:<|"|>test<|"|>,count:42,active:true,score:3.14'
        expected = {
            "name": "test",
            "count": 42,
            "active": True,
            "score": 3.14,
        }
        assert _parse_gemma4_args(raw) == expected

    def test_nested_object(self):
        """A brace-wrapped value is parsed as a nested dict."""
        expected = {"nested": {"inner": "value"}}
        assert _parse_gemma4_args('nested:{inner:<|"|>value<|"|>}') == expected

    def test_array_of_strings(self):
        """A bracket-wrapped value is parsed as a list."""
        expected = {"items": ["a", "b"]}
        assert _parse_gemma4_args('items:[<|"|>a<|"|>,<|"|>b<|"|>]') == expected

    def test_unterminated_string(self):
        """An unterminated string takes everything after the opening delimiter."""
        assert _parse_gemma4_args('key:<|"|>unterminated') == {"key": "unterminated"}

    def test_empty_value(self):
        """A key with nothing after the colon maps to an empty string."""
        assert _parse_gemma4_args("key:") == {"key": ""}
class TestParseGemma4Array:
    """Checks of ``_parse_gemma4_array`` on comma-separated element strings."""

    def test_string_array(self):
        """Delimited string elements are returned as a list of str."""
        assert _parse_gemma4_array('<|"|>a<|"|>,<|"|>b<|"|>') == ["a", "b"]

    def test_empty_array(self):
        """An empty element string yields an empty list."""
        assert _parse_gemma4_array("") == []

    def test_bare_values(self):
        """Bare integer, boolean and float elements keep their types."""
        assert _parse_gemma4_array("42,true,3.14") == [42, True, 3.14]
# ---------------------------------------------------------------------------
# Non-streaming extraction tests
# ---------------------------------------------------------------------------
class TestExtractToolCalls:
    """Non-streaming checks of ``Gemma4ToolParser.extract_tool_calls``."""

    @staticmethod
    def _decoded_args(tool_call):
        """Return the JSON-decoded arguments of one extracted tool call."""
        return json.loads(tool_call.function.arguments)

    def test_no_tool_calls(self, parser, mock_request):
        """Plain text produces no tool calls and is returned as content."""
        text = "Hello, how can I help you today?"
        extracted = parser.extract_tool_calls(text, mock_request)
        assert extracted.tools_called is False
        assert extracted.tool_calls == []
        assert extracted.content == text

    def test_single_tool_call(self, parser, mock_request):
        """One complete tool call with a single string argument."""
        text = '<|tool_call>call:get_weather{location:<|"|>London<|"|>}<tool_call|>'
        extracted = parser.extract_tool_calls(text, mock_request)
        assert extracted.tools_called is True
        assert len(extracted.tool_calls) == 1
        first_call = extracted.tool_calls[0]
        assert first_call.function.name == "get_weather"
        assert self._decoded_args(first_call) == {"location": "London"}

    def test_multiple_arguments(self, parser, mock_request):
        """One tool call carrying two string arguments."""
        text = (
            "<|tool_call>call:get_weather{"
            'location:<|"|>San Francisco<|"|>,'
            'unit:<|"|>celsius<|"|>}'
            "<tool_call|>"
        )
        extracted = parser.extract_tool_calls(text, mock_request)
        assert extracted.tools_called is True
        assert len(extracted.tool_calls) == 1
        first_call = extracted.tool_calls[0]
        assert first_call.function.name == "get_weather"
        assert self._decoded_args(first_call) == {
            "location": "San Francisco",
            "unit": "celsius",
        }

    def test_text_before_tool_call(self, parser, mock_request):
        """Text ahead of the tool call is returned as (stripped) content."""
        text = (
            "Let me check the weather for you. "
            '<|tool_call>call:get_weather{location:<|"|>Paris<|"|>}'
            "<tool_call|>"
        )
        extracted = parser.extract_tool_calls(text, mock_request)
        assert extracted.tools_called is True
        assert extracted.content == "Let me check the weather for you."
        assert len(extracted.tool_calls) == 1
        assert extracted.tool_calls[0].function.name == "get_weather"

    def test_multiple_tool_calls(self, parser, mock_request):
        """Two back-to-back tool calls are both extracted, in order."""
        text = (
            '<|tool_call>call:get_weather{location:<|"|>London<|"|>}'
            "<tool_call|>"
            '<|tool_call>call:get_time{location:<|"|>London<|"|>}'
            "<tool_call|>"
        )
        extracted = parser.extract_tool_calls(text, mock_request)
        assert extracted.tools_called is True
        assert len(extracted.tool_calls) == 2
        weather_call, time_call = extracted.tool_calls[0], extracted.tool_calls[1]
        assert weather_call.function.name == "get_weather"
        assert time_call.function.name == "get_time"

    def test_nested_arguments(self, parser, mock_request):
        """Nested object and list arguments survive the JSON conversion."""
        text = (
            "<|tool_call>call:complex_function{"
            'nested:{inner:<|"|>value<|"|>},'
            'list:[<|"|>a<|"|>,<|"|>b<|"|>]}'
            "<tool_call|>"
        )
        extracted = parser.extract_tool_calls(text, mock_request)
        assert extracted.tools_called is True
        assert len(extracted.tool_calls) == 1
        first_call = extracted.tool_calls[0]
        assert first_call.function.name == "complex_function"
        assert self._decoded_args(first_call) == {
            "nested": {"inner": "value"},
            "list": ["a", "b"],
        }

    def test_tool_call_with_number_and_boolean(self, parser, mock_request):
        """Bare boolean, integer and float arguments keep their JSON types."""
        text = (
            "<|tool_call>call:set_status{"
            "is_active:true,"
            "count:42,"
            "score:3.14}"
            "<tool_call|>"
        )
        extracted = parser.extract_tool_calls(text, mock_request)
        assert extracted.tools_called is True
        assert len(extracted.tool_calls) == 1
        first_call = extracted.tool_calls[0]
        assert first_call.function.name == "set_status"
        assert self._decoded_args(first_call) == {
            "is_active": True,
            "count": 42,
            "score": 3.14,
        }

    def test_incomplete_tool_call(self, parser, mock_request):
        """A call lacking the <tool_call|> end marker is left as content."""
        text = '<|tool_call>call:get_weather{location:<|"|>London'
        extracted = parser.extract_tool_calls(text, mock_request)
        # Without the end marker nothing is extracted.
        assert extracted.tools_called is False
        assert extracted.content == text

    def test_hyphenated_function_name(self, parser, mock_request):
        """Function names containing hyphens are parsed correctly."""
        text = '<|tool_call>call:get-weather{location:<|"|>London<|"|>}<tool_call|>'
        extracted = parser.extract_tool_calls(text, mock_request)
        assert extracted.tools_called is True
        assert extracted.tool_calls[0].function.name == "get-weather"

    def test_dotted_function_name(self, parser, mock_request):
        """Function names containing dots are parsed correctly."""
        text = '<|tool_call>call:weather.get{location:<|"|>London<|"|>}<tool_call|>'
        extracted = parser.extract_tool_calls(text, mock_request)
        assert extracted.tools_called is True
        assert extracted.tool_calls[0].function.name == "weather.get"

    def test_no_arguments(self, parser, mock_request):
        """A tool call with an empty argument block decodes to ``{}``."""
        text = "<|tool_call>call:get_status{}<tool_call|>"
        extracted = parser.extract_tool_calls(text, mock_request)
        assert extracted.tools_called is True
        first_call = extracted.tool_calls[0]
        assert first_call.function.name == "get_status"
        assert self._decoded_args(first_call) == {}
# ---------------------------------------------------------------------------
# Streaming extraction tests
# ---------------------------------------------------------------------------
class TestStreamingExtraction:
    """Tests for the streaming tool call extraction.

    These simulate the token-by-token streaming that vLLM performs,
    feeding incremental text to extract_tool_calls_streaming() and
    verifying that the accumulated argument deltas form valid JSON.
    """

    def _simulate_streaming(
        self, parser: Gemma4ToolParser, mock_request: Any, chunks: list[str]
    ) -> list[tuple[Any, str]]:
        """Feed chunks through the streaming parser and collect results.

        Each chunk is represented by exactly one token ID: 48 when the chunk
        contains the tool-call start marker, 49 when it contains the end
        marker, and 0 otherwise (48/49 match the ``mock_tokenizer`` vocab).

        Returns a list of (delta_message, accumulated_text) tuples.
        """
        results: list[tuple[Any, str]] = []
        previous_text: str = ""
        previous_token_ids: list[int] = []
        for chunk in chunks:
            current_text = previous_text + chunk
            if TOOL_CALL_START in chunk:
                marker_id = 48
            elif TOOL_CALL_END in chunk:
                marker_id = 49
            else:
                marker_id = 0
            delta_token_ids: list[int] = [marker_id]
            current_token_ids = previous_token_ids + delta_token_ids
            delta = parser.extract_tool_calls_streaming(
                previous_text=previous_text,
                current_text=current_text,
                delta_text=chunk,
                previous_token_ids=tuple(previous_token_ids),
                current_token_ids=tuple(current_token_ids),
                delta_token_ids=tuple(delta_token_ids),
                request=mock_request,
            )
            results.append((delta, current_text))
            previous_text = current_text
            previous_token_ids = list(current_token_ids)
        return results

    @staticmethod
    def _function_field(tool_call, field: str):
        """Read ``field`` from a tool call's function payload.

        The payload may be a plain dict or an object with attributes; a
        missing field yields ``None`` in both cases.
        """
        func = tool_call.function
        if isinstance(func, dict):
            return func.get(field)
        return getattr(func, field, None)

    def _collect_arguments(self, results):
        """Collect all argument deltas from streaming results into one string."""
        fragments = []
        for delta, _ in results:
            if delta and delta.tool_calls:
                for tc in delta.tool_calls:
                    arg = self._function_field(tc, "arguments")
                    # Skip missing/empty fragments (None or "").
                    if arg:
                        fragments.append(arg)
        return "".join(fragments)

    def _collect_function_name(self, results):
        """Extract the function name from streaming results.

        Returns the first non-empty name found, or ``None`` when no delta
        carries one.
        """
        for delta, _ in results:
            if delta and delta.tool_calls:
                for tc in delta.tool_calls:
                    name = self._function_field(tc, "name")
                    if name:
                        return name
        return None

    def test_basic_streaming_single_tool(self, parser, mock_request):
        """Simulate the exact streaming scenario from the bug report.

        Model generates:
        <|tool_call>call:get_weather{location:<|"|>Paris, France<|"|>}<tool_call|>
        Expected: arguments should be valid JSON {"location": "Paris, France"}
        """
        chunks = [
            "<|tool_call>",
            "call:get_weather{",
            'location:<|"|>Paris',
            ", France",
            '<|"|>}',
            "<tool_call|>",
        ]
        results = self._simulate_streaming(parser, mock_request, chunks)
        # Verify function name
        name = self._collect_function_name(results)
        assert name == "get_weather", f"Expected 'get_weather', got '{name}'"
        # Verify arguments form valid JSON
        args_text = self._collect_arguments(results)
        assert args_text, "No arguments were streamed"
        parsed_args = json.loads(args_text)
        assert parsed_args == {"location": "Paris, France"}

    def test_streaming_multi_arg(self, parser, mock_request):
        """Streaming with multiple arguments."""
        chunks = [
            "<|tool_call>",
            "call:get_weather{",
            'location:<|"|>Tokyo<|"|>,',
            'unit:<|"|>celsius<|"|>}',
            "<tool_call|>",
        ]
        results = self._simulate_streaming(parser, mock_request, chunks)
        name = self._collect_function_name(results)
        assert name == "get_weather"
        args_text = self._collect_arguments(results)
        assert args_text
        parsed_args = json.loads(args_text)
        assert parsed_args == {"location": "Tokyo", "unit": "celsius"}

    def test_streaming_no_extra_brace(self, parser, mock_request):
        """Verify the closing } is NOT leaked into arguments (Bug #2)."""
        chunks = [
            "<|tool_call>",
            "call:get_weather{",
            'location:<|"|>London<|"|>}',
            "<tool_call|>",
        ]
        results = self._simulate_streaming(parser, mock_request, chunks)
        args_text = self._collect_arguments(results)
        assert args_text
        # The args text must be valid JSON (no extra })
        parsed = json.loads(args_text)
        assert parsed == {"location": "London"}
        # Specifically assert no double-brace
        assert args_text.count("}") <= 1, (
            f"Arguments contain extra closing brace: {args_text!r}"
        )

    def test_streaming_no_unquoted_keys(self, parser, mock_request):
        """Verify keys are properly quoted in JSON (Bug #1)."""
        chunks = [
            "<|tool_call>",
            "call:get_weather{",
            'location:<|"|>Paris<|"|>}',
            "<tool_call|>",
        ]
        results = self._simulate_streaming(parser, mock_request, chunks)
        args_text = self._collect_arguments(results)
        # Must start with { and contain quoted key
        assert args_text.lstrip().startswith("{"), (
            f"Arguments don't start with '{{': {args_text!r}"
        )
        assert '"location"' in args_text, (
            f"Key 'location' not properly quoted: {args_text!r}"
        )

    def test_streaming_name_no_call_prefix(self, parser, mock_request):
        """Verify function name has no 'call:' prefix."""
        chunks = [
            "<|tool_call>",
            "call:get_weather{",
            'location:<|"|>Paris<|"|>}',
            "<tool_call|>",
        ]
        results = self._simulate_streaming(parser, mock_request, chunks)
        name = self._collect_function_name(results)
        assert name == "get_weather"
        assert not name.startswith("call:"), f"Name has 'call:' prefix: {name!r}"

    def test_streaming_text_before_tool_call(self, parser, mock_request):
        """Text before tool call should be emitted as content."""
        chunks = [
            "Let me check ",
            "the weather. ",
            "<|tool_call>",
            "call:get_weather{",
            'location:<|"|>London<|"|>}',
            "<tool_call|>",
        ]
        results = self._simulate_streaming(parser, mock_request, chunks)
        # First chunks should be content
        content_parts = [
            delta.content for delta, _ in results if delta and delta.content
        ]
        assert "".join(content_parts).strip().startswith("Let me check")

    def test_streaming_numeric_args(self, parser, mock_request):
        """Streaming with numeric and boolean argument values."""
        chunks = [
            "<|tool_call>",
            "call:set_config{",
            "count:42,",
            "active:true}",
            "<tool_call|>",
        ]
        results = self._simulate_streaming(parser, mock_request, chunks)
        args_text = self._collect_arguments(results)
        # NOTE(review): unlike the sibling tests, this one passes vacuously
        # when no arguments are streamed. Confirm whether that is intended
        # before tightening it to ``assert args_text``.
        if args_text:
            parsed_args = json.loads(args_text)
            assert parsed_args["count"] == 42
            assert parsed_args["active"] is True

    def test_streaming_empty_args(self, parser, mock_request):
        """Tool call with no arguments."""
        chunks = [
            "<|tool_call>",
            "call:get_status{}",
            "<tool_call|>",
        ]
        results = self._simulate_streaming(parser, mock_request, chunks)
        name = self._collect_function_name(results)
        assert name == "get_status"
...@@ -12,6 +12,7 @@ from .dual_chunk_rope import DualChunkRotaryEmbedding ...@@ -12,6 +12,7 @@ from .dual_chunk_rope import DualChunkRotaryEmbedding
from .dynamic_ntk_alpha_rope import DynamicNTKAlphaRotaryEmbedding from .dynamic_ntk_alpha_rope import DynamicNTKAlphaRotaryEmbedding
from .dynamic_ntk_scaling_rope import DynamicNTKScalingRotaryEmbedding from .dynamic_ntk_scaling_rope import DynamicNTKScalingRotaryEmbedding
from .fope import FourierRotaryEmbedding from .fope import FourierRotaryEmbedding
from .gemma4_rope import Gemma4RotaryEmbedding
from .linear_scaling_rope import LinearScalingRotaryEmbedding from .linear_scaling_rope import LinearScalingRotaryEmbedding
from .llama3_rope import Llama3RotaryEmbedding from .llama3_rope import Llama3RotaryEmbedding
from .llama4_vision_rope import Llama4VisionRotaryEmbedding from .llama4_vision_rope import Llama4VisionRotaryEmbedding
...@@ -134,6 +135,17 @@ def get_rope( ...@@ -134,6 +135,17 @@ def get_rope(
is_neox_style, is_neox_style,
dtype, dtype,
) )
elif scaling_type == "proportional":
# Proportional RoPE is used by Gemma4 for global (full) attention.
# Gemma4 uses a sparse/fractional RoPE with cross-mixing between halves.
rotary_emb = Gemma4RotaryEmbedding(
head_size,
rotary_dim,
max_position,
base,
is_neox_style,
dtype,
)
elif scaling_type == "llama3": elif scaling_type == "llama3":
scaling_factor = rope_parameters["factor"] scaling_factor = rope_parameters["factor"]
low_freq_factor = rope_parameters["low_freq_factor"] low_freq_factor = rope_parameters["low_freq_factor"]
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Gemma4-specific Rotary Positional Embeddings (proportional scaling).
Gemma4 uses "proportional" RoPE which computes inv_freq frequencies scaled
by head_dim (not rotary_dim), and zero-pads for non-rotated dimensions when
partial_rotary_factor < 1. The actual rotation uses standard neox-style
rotate_half, matching HF transformers' apply_rotary_pos_emb.
"""
import torch
from .base import RotaryEmbedding
class Gemma4RotaryEmbedding(RotaryEmbedding):
    """Rotary embedding with Gemma4 "proportional" frequency layout.

    Only the inverse-frequency computation differs from the parent class:

    - the frequency exponents are divided by ``head_size`` rather than by
      the rotary dimension;
    - angle pairs that are not rotated get a frequency of zero, i.e. an
      identity rotation (cos=1, sin=0).

    The parent is always initialised with ``rotary_dim == head_size`` so that
    its rotation covers every dimension; the zero frequencies then leave the
    non-rotated dimensions unchanged. When ``rotary_dim == head_size`` there
    is nothing to pad and every dimension is rotated.
    """

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: float,
        is_neox_style: bool,
        dtype: torch.dtype,
    ) -> None:
        # These two counters are assigned before the parent constructor
        # runs because ``_compute_inv_freq`` reads them.
        half_head = head_size // 2
        # Angle pairs that receive a real rotation.
        self.rope_angles = rotary_dim // 2
        # Remaining angle pairs per half, which stay unrotated.
        self.nope_angles = half_head - self.rope_angles
        # The parent sees rotary_dim == head_size (second positional
        # argument), so it applies the cos/sin cache to all dimensions.
        super().__init__(
            head_size,
            head_size,
            max_position_embeddings,
            base,
            is_neox_style,
            dtype,
        )

    def _compute_inv_freq(self, base: float) -> torch.Tensor:
        """Return the inverse frequencies, zero-padded for unrotated pairs.

        The rotated part is ``1 / base ** (arange(0, 2*rope_angles, 2) /
        head_size)``; ``nope_angles`` zeros are appended after it when that
        count is positive.
        """
        pair_offsets = torch.arange(0, 2 * self.rope_angles, 2, dtype=torch.float)
        rotated_freq = 1.0 / (base ** (pair_offsets / self.head_size))
        if self.nope_angles <= 0:
            # Nothing to pad: every angle pair is rotated.
            return rotated_freq
        # A zero frequency gives angle 0 at every position (cos=1, sin=0).
        identity_tail = torch.zeros(self.nope_angles, dtype=torch.float)
        return torch.cat([rotated_freq, identity_tail])

    def extra_repr(self) -> str:
        """Return the comma-separated attribute summary used in the repr."""
        fields = (
            f"head_size={self.head_size}",
            f"rotary_dim={self.rotary_dim}",
            f"rope_angles={self.rope_angles}",
            f"nope_angles={self.nope_angles}",
            f"max_position_embeddings={self.max_position_embeddings}",
            f"base={self.base}",
            f"is_neox_style={self.is_neox_style}",
        )
        return ", ".join(fields)
...@@ -9,6 +9,7 @@ from vllm.utils.math_utils import round_up ...@@ -9,6 +9,7 @@ from vllm.utils.math_utils import round_up
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.config import ModelConfig, VllmConfig from vllm.config import ModelConfig, VllmConfig
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -52,6 +53,58 @@ class Gemma3TextModelConfig(VerifyAndUpdateConfig): ...@@ -52,6 +53,58 @@ class Gemma3TextModelConfig(VerifyAndUpdateConfig):
hf_config.is_causal = not hf_config.use_bidirectional_attention hf_config.is_causal = not hf_config.use_bidirectional_attention
class Gemma4Config(VerifyAndUpdateConfig):
    """Config hook for Gemma4 models with per-layer-type head dimensions."""

    @staticmethod
    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
        """Select TRITON_ATTN when Gemma4 head dimensions are heterogeneous.

        Some Gemma4 variants use ``head_dim`` for sliding-window layers and
        a different ``global_head_dim`` for full-attention layers. When the
        larger of the two exceeds 256, FlashAttention rejects those layers
        (head_size <= 256 kernel limit) and vLLM would pick a different
        backend per layer type; that mixed-backend execution produces
        numerical divergence and output corruption.

        To avoid this, the attention backend is set to TRITON_ATTN (which
        has no head_size ceiling) for all layers, but only if the two
        dimensions differ, the larger one is above 256, and the user has
        not already chosen a backend.

        TODO: Heterogeneous head_sizes (head_dim != global_head_dim)
        require NixlConnector changes to support per-layer KV transfer
        with different head dimensions for prefill-decode disaggregation.
        """
        text_config = vllm_config.model_config.hf_text_config
        head_dim = getattr(text_config, "head_dim", None)
        global_head_dim = getattr(text_config, "global_head_dim", None)
        # A missing dimension counts as 0 when taking the maximum.
        largest_head_dim = max(head_dim or 0, global_head_dim or 0)

        # Both dimensions must be present in the config.
        if head_dim is None or global_head_dim is None:
            return
        # Leave the backend alone when every layer can share one backend:
        # equal dimensions, or both within the FlashAttention limit.
        if head_dim == global_head_dim or largest_head_dim <= 256:
            return
        # Respect an explicit backend choice.
        if vllm_config.attention_config.backend is not None:
            return

        from vllm.v1.attention.backends.registry import (
            AttentionBackendEnum,
        )

        vllm_config.attention_config.backend = AttentionBackendEnum.TRITON_ATTN
        logger.info(
            "Gemma4 model has heterogeneous head dimensions "
            "(head_dim=%d, global_head_dim=%d). Forcing TRITON_ATTN "
            "backend to prevent mixed-backend numerical divergence.",
            head_dim,
            global_head_dim,
        )
class GptOssForCausalLMConfig(VerifyAndUpdateConfig): class GptOssForCausalLMConfig(VerifyAndUpdateConfig):
@staticmethod @staticmethod
def verify_and_update_config(vllm_config: "VllmConfig") -> None: def verify_and_update_config(vllm_config: "VllmConfig") -> None:
...@@ -533,6 +586,8 @@ MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = { ...@@ -533,6 +586,8 @@ MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = {
"Ernie4_5_VLMoeForConditionalGeneration": Ernie4_5_VLMoeForConditionalGenerationConfig, # noqa: E501 "Ernie4_5_VLMoeForConditionalGeneration": Ernie4_5_VLMoeForConditionalGenerationConfig, # noqa: E501
"FalconMambaForCausalLM": MambaModelConfig, "FalconMambaForCausalLM": MambaModelConfig,
"Gemma3TextModel": Gemma3TextModelConfig, "Gemma3TextModel": Gemma3TextModelConfig,
"Gemma4ForCausalLM": Gemma4Config,
"Gemma4ForConditionalGeneration": Gemma4Config,
"GptOssForCausalLM": GptOssForCausalLMConfig, "GptOssForCausalLM": GptOssForCausalLMConfig,
"GteModel": SnowflakeGteNewModelConfig, "GteModel": SnowflakeGteNewModelConfig,
"GteNewForSequenceClassification": GteNewModelConfig, "GteNewForSequenceClassification": GteNewModelConfig,
......
This diff is collapsed.
This diff is collapsed.
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved.
"""Gemma4 output parsing utilities for offline inference.
Standalone functions that parse decoded model text to extract structured
thinking content and tool calls from Gemma4 models. These are pure-Python
utilities with zero heavy dependencies — they work on raw decoded strings
from any inference backend (vLLM, HuggingFace, TGI, etc.).
Usage with vLLM offline inference::
from vllm import LLM, SamplingParams
from vllm.model_executor.models.gemma4_utils import (
parse_thinking_output,
parse_tool_calls,
)
llm = LLM(model="google/gemma-4-it")
outputs = llm.generate(prompt, SamplingParams(...))
text = tokenizer.decode(outputs[0].outputs[0].token_ids, skip_special_tokens=False)
# Extract thinking / answer (works with or without enable_thinking)
result = parse_thinking_output(text)
print(result["thinking"]) # chain-of-thought or None
print(result["answer"]) # final answer
# Extract tool calls
tool_calls = parse_tool_calls(text)
for tc in tool_calls:
print(f"{tc['name']}({tc['arguments']})")
Ported from ``transformers.models.gemma4.utils_gemma4`` so that vLLM users
do not need a transformers dependency for output parsing.
"""
import json
import regex as re
# ---- Thinking Mode Utility ----
# Thinking delimiter tokens as they appear in decoded text.
# Gemma4 uses <|channel> (start) and <channel|> (end) as thinking delimiters.
_THINKING_START_TAG = "<|channel>"
_THINKING_END_TAG = "<channel|>"
# Sentinel tokens that may appear in decoded output.
_TURN_END_TAG = "<turn|>"
def parse_thinking_output(text: str) -> dict[str, str | None]:
    """Split decoded Gemma4 output into chain-of-thought and final answer.

    Safe to call on **any** Gemma4 output, whether or not thinking mode
    was enabled. Three situations are handled:

    1. **Closing tag present**: the text before ``<channel|>`` is the
       thinking block (anything up to and including ``<|channel>`` is
       dropped) and the ``thought\\n`` role label is removed from it.
    2. **No tags, spurious label**: a leading bare ``thought\\n`` is
       removed from the answer.
    3. **Clean output**: the text is returned as the answer.

    In every case the answer has trailing sentinel tokens (``<turn|>``,
    ``<eos>``) removed.

    Args:
        text: Decoded model output text (from ``tokenizer.decode(...)``).

    Returns:
        A dict with keys:
        - ``"thinking"``: The chain-of-thought text, or ``None`` if no
          closing thinking delimiter was found.
        - ``"answer"``: The final answer text.

    Example::

        >>> from vllm.model_executor.models.gemma4_utils import parse_thinking_output
        >>> output_text = tokenizer.decode(outputs[0], skip_special_tokens=False)
        >>> result = parse_thinking_output(output_text)
        >>> print(result["thinking"])  # chain-of-thought reasoning or None
        >>> print(result["answer"])  # final answer
    """
    before_end, end_tag, after_end = text.partition(_THINKING_END_TAG)

    if not end_tag:
        # No closing delimiter: everything is answer text. Drop the stray
        # "thought\n" label first, then the trailing sentinels.
        return {
            "thinking": None,
            "answer": _clean_answer(_strip_thought_label(text)),
        }

    # Keep only what follows the opening delimiter, when there is one.
    _, start_tag, after_start = before_end.partition(_THINKING_START_TAG)
    reasoning = after_start if start_tag else before_end

    # The model writes <|channel>thought\n...<channel|>; "thought\n" is a
    # channel role label (like "user\n" in <|turn>user\n...<turn|>).
    reasoning = _strip_thought_label(reasoning.strip()).strip()

    return {"thinking": reasoning, "answer": _clean_answer(after_end)}
def _strip_thought_label(text: str) -> str:
"""Strip the spurious ``thought\\n`` label from the start of text.
Only strips when ``thought`` appears as the very first word followed by
a newline — preserving the word ``thought`` in any other context.
"""
if text.startswith("thought\n"):
return text[len("thought\n") :]
return text
def _clean_answer(text: str) -> str:
    """Remove trailing sentinel tokens from the answer text.

    Strips surrounding whitespace, then repeatedly removes a trailing
    ``<turn|>`` or ``<eos>`` (and the whitespace before it) until neither
    remains. Looping makes the result independent of the order or number
    of sentinels, so ``"...<turn|><eos>"`` is cleaned as fully as
    ``"...<eos><turn|>"``.

    Args:
        text: Answer portion of the decoded model output.

    Returns:
        The answer with no leading/trailing whitespace and no trailing
        sentinel tokens.
    """
    sentinels = (_TURN_END_TAG, "<eos>")
    text = text.strip()
    removed = True
    while removed:
        removed = False
        for tag in sentinels:
            if text.endswith(tag):
                text = text[: -len(tag)].rstrip()
                removed = True
    return text
# ---- Tool Call Parsing Utility ----
#
# NOTE: For the OpenAI-compatible API server tool parser (streaming +
# non-streaming), see vllm/tool_parsers/gemma4_tool_parser.py.
# This module provides offline inference utilities for direct user import.
# Tool call delimiter tokens as they appear in decoded text.
# Standard format: <|tool_call>call:name{args}<tool_call|>
_TOOL_CALL_START_TAG = "<|tool_call>"
_TOOL_CALL_END_TAG = "<tool_call|>"
_TOOL_RESPONSE_START_TAG = "<|tool_response>"
# Gemma4 escape token as it appears in decoded text.
_ESCAPE_TOKEN = '<|"|>'
def _parse_tool_arguments(args_str: str) -> dict[str, str]:
    """Parse tool call arguments from the Gemma4 compact format.

    Handles the ``key:<|"|>value<|"|>`` format used by Gemma4, and also
    tolerates the ``key: "value"`` form (space + plain quotes) that some
    chat templates produce. Parsing is attempted in this order:

    1. Strict JSON (``{`` + args + ``}``) after replacing escape tokens
       with plain quotes; non-string values are converted with ``str``.
    2. Quoted ``key:"value"`` pairs. Unquoted pairs that appear alongside
       them (e.g. ``city:"Paris",days:3``) are collected as well instead
       of being dropped; a quoted value wins on a key clash.
    3. If no quoted pair exists, unquoted ``key:value`` pairs only.

    Args:
        args_str: Raw argument string from inside ``call:name{...}``.

    Returns:
        Dictionary of argument name → value (values are strings). Empty
        when ``args_str`` is empty or blank.
    """
    if not args_str or not args_str.strip():
        return {}

    # Replace Gemma4 escape tokens with standard quotes.
    cleaned = args_str.replace(_ESCAPE_TOKEN, '"')

    # Try JSON parsing first (handles nested values, arrays, etc.).
    try:
        parsed = json.loads("{" + cleaned + "}")
        # Ensure all values are strings for consistency.
        return {k: v if isinstance(v, str) else str(v) for k, v in parsed.items()}
    except (json.JSONDecodeError, ValueError):
        pass

    quoted_pattern = r'(\w+):\s*"([^"]*)"'
    unquoted_pattern = r"(\w+):\s*([^,}]+)"

    # Fallback: extract key:"value" pairs (allow optional space after colon).
    arguments: dict[str, str] = dict(re.findall(quoted_pattern, cleaned))

    if arguments:
        # Mixed input: also pick up unquoted pairs (numbers, booleans, ...)
        # from the text left once quoted pairs are blanked out. Replacing
        # with "," keeps neighbouring fragments from merging into a bogus
        # key, and quoted values can never be re-read as unquoted ones.
        remainder = re.sub(quoted_pattern, ",", cleaned)
        for key, value in re.findall(unquoted_pattern, remainder):
            value = value.strip().strip('"')
            if value:
                arguments.setdefault(key, value)
        return arguments

    # Last resort: extract key:value pairs (unquoted).
    for key, value in re.findall(unquoted_pattern, args_str):
        arguments[key] = value.strip().strip('"').replace(_ESCAPE_TOKEN, "")
    return arguments
def parse_tool_calls(text: str, *, strict: bool = False) -> list[dict]:
    """Extract tool calls from decoded Gemma4 model output.

    Gemma4 models may emit non-standard tool call formats, so parsing is
    done in tiers:

    1. **Standard**: ``<|tool_call>call:name{args}<tool_call|>``
       (special token IDs 48/49 in decoded text). ``<turn|>`` is also
       accepted as the closing token.
    2. **Fallback** (only when ``strict=False`` and tier 1 found
       nothing): bare ``call:name{args}`` patterns, including
       ``<call>name{args}`` (fragmented tokens from multimodal inputs).

    Args:
        text: Decoded model output text (from ``tokenizer.decode(...,
            skip_special_tokens=False)``).
        strict: If ``True``, only match the standard ``<|tool_call>``
            format. If ``False`` (default), also try fallback patterns
            for known Gemma4 output variations.

    Returns:
        A list of dicts, each with keys:
        - ``"name"``: The tool function name (e.g. ``"get_weather"``).
        - ``"arguments"``: A dict of argument name → value.

    Example::

        >>> from vllm.model_executor.models.gemma4_utils import (
        ...     parse_tool_calls
        ... )
        >>> output = tokenizer.decode(outputs[0], skip_special_tokens=False)
        >>> tool_calls = parse_tool_calls(output)
        >>> for tc in tool_calls:
        ...     print(f"Call: {tc['name']}({tc['arguments']})")
    """

    def _collect(pattern: str) -> list[dict]:
        # Group 1 is the function name, group 2 the raw argument string.
        return [
            {
                "name": found.group(1),
                "arguments": _parse_tool_arguments(found.group(2)),
            }
            for found in re.finditer(pattern, text, re.DOTALL)
        ]

    # Tier 1: <|tool_call>call:name{args}<tool_call|>
    # Note: Some Gemma4 models emit <turn|> instead of <tool_call|>.
    standard_calls = _collect(
        r"<\|tool_call\>call:(\w+)\{(.*?)\}(?:<tool_call\|>|<turn\|>)"
    )
    if standard_calls or strict:
        return standard_calls

    # Tier 2: <call>name{args}, call:name{args}, or bare call:name{args}<eos>
    return _collect(r"(?:<call>|(?:^|\s)call:)(\w+)\{(.*?)\}")
def has_tool_response_tag(text: str) -> bool:
    """Report whether the output ends with the tool response tag.

    Some Gemma4 models sometimes emit ``<eos>`` instead of
    ``<|tool_response>`` after a tool call. Callers can use this check to
    decide whether ``<|tool_response>`` must be injected into the next
    prompt. Trailing whitespace is ignored.

    Args:
        text: Decoded model output text.

    Returns:
        ``True`` if the output ends with ``<|tool_response>``
        (proper behavior), ``False`` otherwise.

    Example::

        >>> from vllm.model_executor.models.gemma4_utils import (
        ...     has_tool_response_tag
        ... )
        >>> if not has_tool_response_tag(model_output):
        ...     # Model used <eos> instead — inject <|tool_response> manually
        ...     next_prompt = "<|tool_response>" + tool_result
    """
    return text.rstrip().endswith(_TOOL_RESPONSE_START_TAG)
...@@ -111,6 +111,7 @@ _TEXT_GENERATION_MODELS = { ...@@ -111,6 +111,7 @@ _TEXT_GENERATION_MODELS = {
"Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"), "Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"),
"Gemma3ForCausalLM": ("gemma3", "Gemma3ForCausalLM"), "Gemma3ForCausalLM": ("gemma3", "Gemma3ForCausalLM"),
"Gemma3nForCausalLM": ("gemma3n", "Gemma3nForCausalLM"), "Gemma3nForCausalLM": ("gemma3n", "Gemma3nForCausalLM"),
"Gemma4ForCausalLM": ("gemma4", "Gemma4ForCausalLM"),
"Qwen3NextForCausalLM": ("qwen3_next", "Qwen3NextForCausalLM"), "Qwen3NextForCausalLM": ("qwen3_next", "Qwen3NextForCausalLM"),
"GlmForCausalLM": ("glm", "GlmForCausalLM"), "GlmForCausalLM": ("glm", "GlmForCausalLM"),
"Glm4ForCausalLM": ("glm4", "Glm4ForCausalLM"), "Glm4ForCausalLM": ("glm4", "Glm4ForCausalLM"),
...@@ -383,6 +384,7 @@ _MULTIMODAL_MODELS = { ...@@ -383,6 +384,7 @@ _MULTIMODAL_MODELS = {
"gemma3n_mm", "gemma3n_mm",
"Gemma3nForConditionalGeneration", "Gemma3nForConditionalGeneration",
), ),
"Gemma4ForConditionalGeneration": ("gemma4_mm", "Gemma4ForConditionalGeneration"),
"GlmAsrForConditionalGeneration": ("glmasr", "GlmAsrForConditionalGeneration"), "GlmAsrForConditionalGeneration": ("glmasr", "GlmAsrForConditionalGeneration"),
"GLM4VForCausalLM": ("glm4v", "GLM4VForCausalLM"), "GLM4VForCausalLM": ("glm4v", "GLM4VForCausalLM"),
"Glm4vForConditionalGeneration": ("glm4_1v", "Glm4vForConditionalGeneration"), "Glm4vForConditionalGeneration": ("glm4_1v", "Glm4vForConditionalGeneration"),
......
...@@ -233,8 +233,15 @@ class AutoWeightsLoader: ...@@ -233,8 +233,15 @@ class AutoWeightsLoader:
): ):
""" """
Add tensor names that are not in the model params that may be in the Add tensor names that are not in the model params that may be in the
safetensors, e.g., batch normalization stats. safetensors, e.g., batch normalization stats and registered buffers.
""" """
# Add persistent registered buffers.
# Non-persistent buffers are excluded, matching PyTorch state_dict().
non_persistent = getattr(module, "_non_persistent_buffers_set", set())
for buf_name, buf in module.named_buffers(recurse=False):
if buf_name not in child_params and buf_name not in non_persistent:
child_params[buf_name] = buf
if isinstance( if isinstance(
module, module,
( (
......
...@@ -32,6 +32,10 @@ _REASONING_PARSERS_TO_REGISTER = { ...@@ -32,6 +32,10 @@ _REASONING_PARSERS_TO_REGISTER = {
"ernie45_reasoning_parser", "ernie45_reasoning_parser",
"Ernie45ReasoningParser", "Ernie45ReasoningParser",
), ),
"gemma4": (
"gemma4_reasoning_parser",
"Gemma4ReasoningParser",
),
"glm45": ( "glm45": (
"deepseek_v3_reasoning_parser", "deepseek_v3_reasoning_parser",
"DeepSeekV3ReasoningWithThinkingParser", "DeepSeekV3ReasoningWithThinkingParser",
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Sequence
from typing import TYPE_CHECKING
from vllm.entrypoints.openai.engine.protocol import DeltaMessage
from vllm.reasoning.basic_parsers import BaseThinkingReasoningParser
from vllm.tokenizers import TokenizerLike
if TYPE_CHECKING:
from vllm.entrypoints.openai.chat_completion.protocol import (
ChatCompletionRequest,
)
from vllm.entrypoints.openai.responses.protocol import ResponsesRequest
# Role label that Gemma4 emits at the start of the thinking channel.
# The model generates: <|channel>thought\n...reasoning...<channel|>
# This prefix must be stripped to expose only the actual reasoning content.
_THOUGHT_PREFIX = "thought\n"
class Gemma4ReasoningParser(BaseThinkingReasoningParser):
    """
    Reasoning parser for Google Gemma4 thinking models.

    Gemma4 uses <|channel>...<channel|> tokens to delimit reasoning/thinking
    content within its output. Thinking mode is activated by passing
    ``enable_thinking=True`` in the chat template kwargs, which injects a
    system turn containing <|think|> (token 98) to trigger chain-of-thought
    reasoning.

    Output pattern when thinking is enabled::

        <|channel>thought
        ...chain of thought reasoning...<channel|>
        Final answer text here.

    The ``thought\\n`` role label inside the channel delimiters is a
    structural artefact (analogous to ``user\\n`` in ``<|turn>user\\n...``).
    This parser strips it so that downstream consumers see only the
    actual reasoning text, consistent with the offline parser
    (``vllm.reasoning.gemma4_utils._strip_thought_label``).
    """

    def __init__(self, tokenizer: TokenizerLike, *args, **kwargs):
        super().__init__(tokenizer, *args, **kwargs)
        # Instance state for streaming prefix stripping.
        # Tracks only the reasoning text received from the base parser,
        # independent of current_text (which may contain pre-reasoning
        # content and lacks special token text due to
        # skip_special_tokens=True).
        self._reasoning_text: str = ""
        # True once the "thought\n" question is settled (stripped, or
        # shown to be absent); later deltas then pass through untouched.
        self._prefix_stripped: bool = False

    @property
    def start_token(self) -> str:
        """The token that starts reasoning content."""
        return "<|channel>"

    @property
    def end_token(self) -> str:
        """The token that ends reasoning content."""
        return "<channel|>"

    # ------------------------------------------------------------------
    # Non-streaming path
    # ------------------------------------------------------------------
    def extract_reasoning(
        self,
        model_output: str,
        request: "ChatCompletionRequest | ResponsesRequest",
    ) -> tuple[str | None, str | None]:
        """Extract reasoning, stripping the ``thought\\n`` role label.

        Args:
            model_output: Full decoded model output.
            request: The originating request, forwarded to the base parser.

        Returns:
            ``(reasoning, content)``. When neither delimiter appears in
            ``model_output`` the whole text is returned as content and
            reasoning is ``None``.
        """
        if self.start_token not in model_output and self.end_token not in model_output:
            # Default to content history if no tags are present
            # (or if they were stripped)
            return None, model_output

        reasoning, content = super().extract_reasoning(model_output, request)
        if reasoning is not None:
            reasoning = _strip_thought_label(reasoning)
        return reasoning, content

    # ------------------------------------------------------------------
    # Streaming path
    # ------------------------------------------------------------------
    @staticmethod
    def _content_only(result: DeltaMessage) -> DeltaMessage | None:
        """Drop the reasoning part of ``result`` but keep any content.

        Returns ``None`` when the delta carries no content, so a delta
        consisting only of the suppressed label emits nothing.
        """
        if not result.content:
            return None
        result.reasoning = None
        return result

    def extract_reasoning_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
    ) -> DeltaMessage | None:
        """Extract streaming reasoning, stripping ``thought\\n`` from the
        first reasoning delta(s).

        The ``thought\\n`` prefix may arrive as a single delta or split
        across multiple deltas (e.g. ``"thought"`` then ``"\\n"``). Early
        reasoning deltas are buffered until it is clear whether the
        prefix is present, then the buffered text minus the prefix is
        emitted.

        Accumulated reasoning is tracked in instance state
        (``_reasoning_text``) rather than rebuilt from ``current_text``:
        with ``skip_special_tokens=True`` (the vLLM default) the
        ``<|channel>`` delimiter is invisible in ``current_text``, so
        pre-reasoning content cannot be separated from reasoning content
        by string matching.

        A delta whose reasoning part is suppressed still has its
        ``content`` forwarded, so answer text that shares a delta with
        the label (or with the closing delimiter) is not lost.

        Returns:
            The (possibly modified) delta from the base parser, or
            ``None`` when nothing should be emitted for this step.
        """
        result = super().extract_reasoning_streaming(
            previous_text,
            current_text,
            delta_text,
            previous_token_ids,
            current_token_ids,
            delta_token_ids,
        )
        if result is None:
            return None
        if result.reasoning is None:
            return result

        # Accumulate ONLY the reasoning text from base parser results.
        # This is immune to pre-reasoning content pollution.
        self._reasoning_text += result.reasoning

        # Once the prefix has been handled, all subsequent reasoning
        # deltas pass through unchanged.
        if self._prefix_stripped:
            return result

        # ---- Prefix stripping logic ----
        # Case 1: enough text has accumulated to confirm the prefix is
        # present. Strip it and pass through the remainder.
        if self._reasoning_text.startswith(_THOUGHT_PREFIX):
            prefix_len = len(_THOUGHT_PREFIX)
            # How much reasoning was accumulated before this delta?
            prev_reasoning_len = len(self._reasoning_text) - len(result.reasoning)
            # The whole prefix has now been seen, whichever branch runs.
            self._prefix_stripped = True
            if prev_reasoning_len >= prefix_len:
                # Prefix was already consumed by prior deltas; this
                # delta is entirely real content — pass through.
                return result
            # Part or all of the prefix is in this delta.
            stripped = result.reasoning[prefix_len - prev_reasoning_len :]
            if stripped:
                result.reasoning = stripped
                return result
            # The reasoning part was nothing but prefix: suppress it, but
            # do not discard content that arrived in the same delta.
            return self._content_only(result)

        # Case 2: accumulated text is a strict prefix of _THOUGHT_PREFIX
        # (e.g. only "thou" so far). Normally keep buffering, since it
        # may still become the full prefix or diverge.
        if _THOUGHT_PREFIX.startswith(self._reasoning_text):
            if result.content:
                # Content in the same delta is taken to mean reasoning
                # has ended (assumption about the base parser — confirm).
                # The label can no longer complete, so flush the buffer
                # as reasoning, matching the non-streaming path, which
                # strips only the complete "thought\n".
                self._prefix_stripped = True
                result.reasoning = self._reasoning_text or None
                return result
            return None

        # Case 3: accumulated text does not match the thought prefix.
        # Prior deltas were buffered (suppressed by Case 2) but the text
        # diverged. Re-emit the full accumulated text to avoid data loss.
        self._prefix_stripped = True
        result.reasoning = self._reasoning_text
        return result
def _strip_thought_label(text: str) -> str:
    """Return ``text`` without a leading ``thought\\n`` role label.

    Text that does not start with the label is returned unchanged.
    Mirrors ``vllm.reasoning.gemma4_utils._strip_thought_label`` from the
    offline parser.
    """
    return text.removeprefix(_THOUGHT_PREFIX)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved.
"""Gemma4 thinking/reasoning output parsing utilities for offline inference.
Standalone functions that parse decoded model text to extract structured
thinking content from Gemma4 models. These are pure-Python utilities with
zero heavy dependencies — they work on raw decoded strings from any
inference backend (vLLM, HuggingFace, TGI, etc.).
For the OpenAI-compatible API reasoning parser (streaming +
non-streaming), see ``vllm.reasoning.gemma4_reasoning_parser``.
For tool call parsing, see ``vllm.tool_parsers.gemma4_utils``.
Usage with vLLM offline inference::
from vllm import LLM, SamplingParams
from vllm.reasoning.gemma4_utils import parse_thinking_output
llm = LLM(model="google/gemma-4-it")
outputs = llm.generate(prompt, SamplingParams(...))
text = tokenizer.decode(outputs[0].outputs[0].token_ids, skip_special_tokens=False)
# Extract thinking / answer (works with or without enable_thinking)
result = parse_thinking_output(text)
print(result["thinking"]) # chain-of-thought or None
print(result["answer"]) # final answer
Ported from ``transformers.models.gemma4.utils_gemma4`` so that vLLM users
do not need a transformers dependency for output parsing.
"""
# ---- Thinking Mode Utility ----
# Thinking delimiter tokens as they appear in decoded text.
# Gemma4 uses <|channel> (start) and <channel|> (end) as thinking delimiters.
_THINKING_START_TAG = "<|channel>"
_THINKING_END_TAG = "<channel|>"
# Sentinel tokens that may appear in decoded output.
_TURN_END_TAG = "<turn|>"
def parse_thinking_output(text: str) -> dict[str, str | None]:
    """Split decoded Gemma4 output into chain-of-thought and final answer.

    Safe to call on **any** Gemma4 output, whether or not thinking mode
    was enabled. Three situations are handled:

    1. **Closing tag present**: the text before ``<channel|>`` is the
       thinking block (anything up to and including ``<|channel>`` is
       dropped) and the ``thought\\n`` role label is removed from it.
    2. **No tags, spurious label**: a leading bare ``thought\\n`` is
       removed from the answer.
    3. **Clean output**: the text is returned as the answer.

    In every case the answer has trailing sentinel tokens (``<turn|>``,
    ``<eos>``) removed.

    Args:
        text: Decoded model output text (from ``tokenizer.decode(...)``).

    Returns:
        A dict with keys:
        - ``"thinking"``: The chain-of-thought text, or ``None`` if no
          closing thinking delimiter was found.
        - ``"answer"``: The final answer text.

    Example::

        >>> from vllm.reasoning.gemma4_utils import parse_thinking_output
        >>> output_text = tokenizer.decode(outputs[0], skip_special_tokens=False)
        >>> result = parse_thinking_output(output_text)
        >>> print(result["thinking"])  # chain-of-thought reasoning or None
        >>> print(result["answer"])  # final answer
    """
    before_end, end_tag, after_end = text.partition(_THINKING_END_TAG)

    if not end_tag:
        # No closing delimiter: everything is answer text. Drop the stray
        # "thought\n" label first, then the trailing sentinels.
        return {
            "thinking": None,
            "answer": _clean_answer(_strip_thought_label(text)),
        }

    # Keep only what follows the opening delimiter, when there is one.
    _, start_tag, after_start = before_end.partition(_THINKING_START_TAG)
    reasoning = after_start if start_tag else before_end

    # The model writes <|channel>thought\n...<channel|>; "thought\n" is a
    # channel role label (like "user\n" in <|turn>user\n...<turn|>).
    reasoning = _strip_thought_label(reasoning.strip()).strip()

    return {"thinking": reasoning, "answer": _clean_answer(after_end)}
def _strip_thought_label(text: str) -> str:
"""Strip the spurious ``thought\\n`` label from the start of text.
Only strips when ``thought`` appears as the very first word followed by
a newline — preserving the word ``thought`` in any other context.
"""
if text.startswith("thought\n"):
return text[len("thought\n") :]
return text
def _clean_answer(text: str) -> str:
    """Remove trailing sentinel tokens from the answer text.

    Strips surrounding whitespace, then repeatedly removes a trailing
    ``<turn|>`` or ``<eos>`` (and the whitespace before it) until neither
    remains. Looping makes the result independent of the order or number
    of sentinels, so ``"...<turn|><eos>"`` is cleaned as fully as
    ``"...<eos><turn|>"``.

    Args:
        text: Answer portion of the decoded model output.

    Returns:
        The answer with no leading/trailing whitespace and no trailing
        sentinel tokens.
    """
    sentinels = (_TURN_END_TAG, "<eos>")
    text = text.strip()
    removed = True
    while removed:
        removed = False
        for tag in sentinels:
            if text.endswith(tag):
                text = text[: -len(tag)].rstrip()
                removed = True
    return text
...@@ -154,6 +154,10 @@ _TOOL_PARSERS_TO_REGISTER = { ...@@ -154,6 +154,10 @@ _TOOL_PARSERS_TO_REGISTER = {
"functiongemma_tool_parser", "functiongemma_tool_parser",
"FunctionGemmaToolParser", "FunctionGemmaToolParser",
), ),
"gemma4": (
"gemma4_tool_parser",
"Gemma4ToolParser",
),
} }
......
This diff is collapsed.
This diff is collapsed.
...@@ -448,6 +448,16 @@ class LongCatFlashMTPModelArchConfigConvertor(ModelArchConfigConvertorBase): ...@@ -448,6 +448,16 @@ class LongCatFlashMTPModelArchConfigConvertor(ModelArchConfigConvertorBase):
return getattr(self.hf_text_config, "num_nextn_predict_layers", 1) return getattr(self.hf_text_config, "num_nextn_predict_layers", 1)
class Gemma4ModelArchConfigConvertor(ModelArchConfigConvertorBase):
    """Architecture config convertor for Gemma4 (dual head dimensions)."""

    def get_head_size(self) -> int:
        """Return the largest attention head size declared by the config.

        Gemma4 uses dual head dimensions: ``head_dim`` (sliding attention)
        and ``global_head_dim`` (full attention). The larger one is
        returned so that attention backends allocate buffers large enough
        for both. If neither yields a positive value, the base class
        implementation is used.
        """
        # A field may be missing OR present-but-None; treat both as 0 so
        # max() never compares None with an int.
        head_dim = getattr(self.hf_text_config, "head_dim", None) or 0
        global_head_dim = getattr(self.hf_text_config, "global_head_dim", None) or 0
        return max(head_dim, global_head_dim) or super().get_head_size()
# hf_config.model_type -> convertor class # hf_config.model_type -> convertor class
MODEL_ARCH_CONFIG_CONVERTORS = { MODEL_ARCH_CONFIG_CONVERTORS = {
"cohere_asr": CohereAsrModelArchConfigConvertor, "cohere_asr": CohereAsrModelArchConfigConvertor,
...@@ -471,4 +481,6 @@ MODEL_ARCH_CONFIG_CONVERTORS = { ...@@ -471,4 +481,6 @@ MODEL_ARCH_CONFIG_CONVERTORS = {
"ernie_mtp": ErnieMTPModelArchConfigConvertor, "ernie_mtp": ErnieMTPModelArchConfigConvertor,
"pangu_ultra_moe_mtp": PanguUltraMoeMTPModelArchConfigConvertor, "pangu_ultra_moe_mtp": PanguUltraMoeMTPModelArchConfigConvertor,
"longcat_flash_mtp": LongCatFlashMTPModelArchConfigConvertor, "longcat_flash_mtp": LongCatFlashMTPModelArchConfigConvertor,
"gemma4": Gemma4ModelArchConfigConvertor,
"gemma4_text": Gemma4ModelArchConfigConvertor,
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment