Commit 8adcf8c4 authored by Luciano Martins, committed by khluu
Browse files

feat(models): implement Google Gemma 4 architecture support (MoE, Multimodal,...


feat(models): implement Google Gemma 4 architecture support (MoE, Multimodal, Reasoning, Tool-Use) (#38826)
Signed-off-by: Luciano Martins <lucianommartins@users.noreply.github.com>
Signed-off-by: Luciano Martins <lucianomartins@google.com>
Co-authored-by: Luciano Martins <lucianommartins@users.noreply.github.com>
Co-authored-by: Isotr0py <2037008807@qq.com>
(cherry picked from commit 08ed2b96)
parent cfad6a50
......@@ -394,6 +394,22 @@ VLM_TEST_SETTINGS = {
vllm_runner_kwargs={"mm_processor_kwargs": {"do_pan_and_scan": True}},
patch_hf_runner=model_utils.gemma3_patch_hf_runner,
),
"gemma4": VLMTestInfo(
models=["google/gemma-4-E2B-it"],
test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
prompt_formatter=lambda img_prompt: f"<bos><start_of_turn>user\n{img_prompt}<end_of_turn>\n<start_of_turn>model\n", # noqa: E501
single_image_prompts=IMAGE_ASSETS.prompts(
{
"stop_sign": "What's the content in the center of the image?",
"cherry_blossom": "What is the season?",
}
),
multi_image_prompt="Describe the two images in detail.",
max_model_len=4096,
max_num_seqs=2,
auto_cls=AutoModelForImageTextToText,
vllm_runner_kwargs={"limit_mm_per_prompt": {"image": 4}},
),
"granite_vision": VLMTestInfo(
models=["ibm-granite/granite-vision-3.3-2b"],
test_type=(VLMTestType.IMAGE),
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
from vllm.multimodal import MULTIMODAL_REGISTRY
from ....conftest import ImageTestAssets
from ...utils import build_model_context
# TODO: to be updated to "google/gemma-4-e2b-it" once the models are available
GEMMA4_MODEL_ID = "google/gemma-4-E2B-it"
@pytest.mark.parametrize("model_id", [GEMMA4_MODEL_ID])
def test_limit_mm_per_prompt(
    image_assets: ImageTestAssets,
    model_id: str,
):
    """Passing two images while only one is allowed must raise ValueError."""
    # Build a model context whose multimodal limit is a single image.
    model_ctx = build_model_context(
        model_id,
        mm_processor_kwargs={},
        limit_mm_per_prompt={"image": 1},
    )
    mm_processor = MULTIMODAL_REGISTRY.create_processor(model_ctx.model_config)

    # The prompt references two images, one more than the configured limit.
    two_image_prompt = "<image><image>"

    # Take the first two asset images; if fewer are available, reuse the
    # first one so that two images are always supplied.
    pil_images = [asset.pil_image for asset in image_assets][:2]
    if len(pil_images) < 2:
        pil_images = [pil_images[0]] * 2
    raw_mm_data = {"image": pil_images}

    # Both parsing and processing stay inside the raises-context so the
    # error is accepted from either step.
    with pytest.raises(ValueError, match="At most 1 image"):
        mm_processor(
            two_image_prompt,
            mm_items=mm_processor.info.parse_mm_data(raw_mm_data),
            hf_processor_mm_kwargs={},
        )
......@@ -277,6 +277,10 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
"google/gemma-2-9b", extras={"tiny": "google/gemma-2-2b-it"}
),
"Gemma3ForCausalLM": _HfExamplesInfo("google/gemma-3-1b-it"),
"Gemma4ForCausalLM": _HfExamplesInfo(
"google/gemma-4-E2B-it",
min_transformers_version="5.0.0",
),
"Gemma3nForCausalLM": _HfExamplesInfo("google/gemma-3n-E2B-it"),
"GlmForCausalLM": _HfExamplesInfo("zai-org/glm-4-9b-chat-hf"),
"Glm4ForCausalLM": _HfExamplesInfo("zai-org/GLM-4-9B-0414"),
......@@ -805,6 +809,10 @@ _MULTIMODAL_EXAMPLE_MODELS = {
),
"FuyuForCausalLM": _HfExamplesInfo("adept/fuyu-8b"),
"Gemma3ForConditionalGeneration": _HfExamplesInfo("google/gemma-3-4b-it"),
"Gemma4ForConditionalGeneration": _HfExamplesInfo(
"google/gemma-4-E2B-it",
min_transformers_version="5.5.0",
),
"Gemma3nForConditionalGeneration": _HfExamplesInfo("google/gemma-3n-E2B-it"),
"GlmAsrForConditionalGeneration": _HfExamplesInfo(
"zai-org/GLM-ASR-Nano-2512",
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
from tests.reasoning.utils import run_reasoning_extraction
from vllm.reasoning import ReasoningParser, ReasoningParserManager
# The Gemma 4 tokenizer is loaded below so the channel marker token IDs can be read from its vocab
from vllm.tokenizers.registry import get_tokenizer
parser_name = "gemma4"
@pytest.fixture(scope="module")
def generic_tokenizer():
    """Provide the Gemma 4 tokenizer, created once per test module."""
    gemma4_tokenizer = get_tokenizer("google/gemma-4-E2B-it")
    return gemma4_tokenizer
# Each dict below describes one scenario:
#   "output"           - raw model text fed to the parser
#   "reasoning"        - expected extracted reasoning (None when absent)
#   "content"          - expected extracted content (None when absent)
#   "is_reasoning_end" - expected result of parser.is_reasoning_end()

# --- End marker (<channel|>) present without a start marker (<|channel>) ---
# Non-streaming expectation: text before the end marker is reasoning.
INVALID_SIMPLE_NONSTREAMING = {
"output": "This is a reasoning section<channel|>This is the rest",
"reasoning": "This is a reasoning section",
"content": "This is the rest",
"is_reasoning_end": True,
}
# Streaming expectation for the same output: everything is content.
INVALID_SIMPLE_STREAMING = {
"output": "This is a reasoning section<channel|>This is the rest",
"reasoning": None,
"content": "This is a reasoning sectionThis is the rest",
"is_reasoning_end": True,
}
# Same as above but with nothing after the end marker.
INVALID_COMPLETE_NONSTREAMING = {
"output": "This is a reasoning section<channel|>",
"reasoning": "This is a reasoning section",
"content": None,
"is_reasoning_end": True,
}
INVALID_COMPLETE_STREAMING = {
"output": "This is a reasoning section<channel|>",
"reasoning": None,
"content": "This is a reasoning section",
"is_reasoning_end": True,
}
# --- Start marker only: reasoning never terminates ---
NO_CONTENT = {
"output": "<|channel>This is reasoning",
"reasoning": "This is reasoning",
"content": None,
"is_reasoning_end": False,
}
# --- No markers at all: plain content ---
NO_REASONING = {
"output": "This is content",
"reasoning": None,
"content": "This is content",
"is_reasoning_end": False,
}
# --- Well-formed reasoning wrapped in start and end markers ---
REASONING_WITH_CHANNEL = {
"output": "<|channel>This is a reasoning section<channel|>This is the rest",
"reasoning": "This is a reasoning section",
"content": "This is the rest",
"is_reasoning_end": True,
}
# Well-formed reasoning with no content after the end marker.
COMPLETE_REASONING_WITH_CHANNEL = {
"output": "<|channel>This is a reasoning section<channel|>",
"reasoning": "This is a reasoning section",
"content": None,
"is_reasoning_end": True,
}
# Newlines inside both the reasoning and the content parts.
MULTIPLE_LINES_WITH_CHANNEL = {
"output": "<|channel>This\nThat<channel|>This is the rest\nThat",
"reasoning": "This\nThat",
"content": "This is the rest\nThat",
"is_reasoning_end": True,
}
# Start marker with no end marker (same shape as NO_CONTENT).
CHANNEL_NO_END = {
"output": "<|channel>This is a reasoning section",
"reasoning": "This is a reasoning section",
"content": None,
"is_reasoning_end": False,
}
# Empty model output.
EMPTY = {
"output": "",
"reasoning": None,
"content": "",
"is_reasoning_end": False,
}
# --- Text before the start marker ---
# Non-streaming expectation: the leading "Before\n" is not part of content.
NEW_LINE_NONSTREAMING = {
"output": (
"Before\n<|channel>This is a reasoning section<channel|>\nThis is the rest"
),
"reasoning": "This is a reasoning section",
"content": "\nThis is the rest",
"is_reasoning_end": True,
}
# Streaming expectation: the leading "Before\n" is kept in content.
NEW_LINE_STREAMING = {
"output": (
"Before\n<|channel>This is a reasoning section<channel|>\nThis is the rest"
),
"reasoning": "This is a reasoning section",
"content": "Before\n\nThis is the rest",
"is_reasoning_end": True,
}
# Parametrization table: each entry is (streaming flag, scenario dict).
TEST_CASES = [
pytest.param(False, INVALID_SIMPLE_NONSTREAMING, id="invalid_simple"),
pytest.param(True, INVALID_SIMPLE_STREAMING, id="invalid_simple_streaming"),
pytest.param(False, INVALID_COMPLETE_NONSTREAMING, id="invalid_complete"),
pytest.param(True, INVALID_COMPLETE_STREAMING, id="invalid_complete_streaming"),
pytest.param(False, NO_CONTENT, id="no_content"),
pytest.param(False, NO_REASONING, id="no_reasoning"),
pytest.param(False, REASONING_WITH_CHANNEL, id="reasoning"),
pytest.param(True, REASONING_WITH_CHANNEL, id="reasoning_streaming"),
pytest.param(False, COMPLETE_REASONING_WITH_CHANNEL, id="complete_reasoning"),
pytest.param(
True, COMPLETE_REASONING_WITH_CHANNEL, id="complete_reasoning_streaming"
),
pytest.param(False, MULTIPLE_LINES_WITH_CHANNEL, id="multiple_lines"),
pytest.param(True, MULTIPLE_LINES_WITH_CHANNEL, id="multiple_lines_streaming"),
pytest.param(False, CHANNEL_NO_END, id="no_end"),
pytest.param(True, CHANNEL_NO_END, id="no_end_streaming"),
pytest.param(False, EMPTY, id="empty"),
pytest.param(False, NEW_LINE_NONSTREAMING, id="new_line"),
pytest.param(True, NEW_LINE_STREAMING, id="new_line_streaming"),
]
@pytest.mark.parametrize("streaming, param_dict", TEST_CASES)
def test_gemma4_reasoning(
    streaming: bool,
    param_dict: dict,
    generic_tokenizer,
):
    """Run one scenario through the gemma4 reasoning parser.

    The scenario's output text is converted to token IDs, with the channel
    markers mapped to their dedicated vocab IDs and the surrounding text
    encoded normally. The extracted reasoning/content and the
    ``is_reasoning_end`` flag are then compared with the expected values.
    """
    start_marker = "<|channel>"
    end_marker = "<channel|>"
    output = param_dict["output"]

    # Resolve token IDs dynamically from the real tokenizer
    vocab = generic_tokenizer.get_vocab()
    start_token_id = vocab[start_marker]
    end_token_id = vocab[end_marker]

    def _encode(text: str) -> list[int]:
        if not text:
            return []
        # Handle both raw transformers and vLLM wrappers
        enc = getattr(generic_tokenizer, "tokenizer", generic_tokenizer)
        try:
            return enc.encode(text, add_special_tokens=False)
        except TypeError:
            return enc.encode(text)

    start_at = output.find(start_marker)
    end_at = output.find(end_marker)
    has_start = start_at != -1
    has_end = end_at != -1

    # Describe the output as an ordered list of segments: plain strings are
    # encoded with the tokenizer, integers are marker token IDs used as-is.
    if has_start and has_end:
        segments = [
            output[:start_at],
            start_token_id,
            output[start_at + len(start_marker) : end_at],
            end_token_id,
            output[end_at + len(end_marker) :],
        ]
    elif has_start:
        segments = [
            output[:start_at],
            start_token_id,
            output[start_at + len(start_marker) :],
        ]
    elif has_end:
        segments = [
            output[:end_at],
            end_token_id,
            output[end_at + len(end_marker) :],
        ]
    else:
        segments = [output]

    output_tokens = []
    for segment in segments:
        if isinstance(segment, str):
            output_tokens.extend(_encode(segment))
        else:
            output_tokens.append(segment)

    parser_cls = ReasoningParserManager.get_reasoning_parser(parser_name)
    parser: ReasoningParser = parser_cls(generic_tokenizer)

    # We use the generic run_reasoning_extraction from utils
    # Use decode per token to get standard spaces instead of
    # SentencePiece space characters
    output_token_strings = [
        generic_tokenizer.decode([token_id]) for token_id in output_tokens
    ]
    reasoning, content = run_reasoning_extraction(
        parser, output_token_strings, streaming=streaming
    )
    assert reasoning == param_dict["reasoning"]
    assert content == param_dict["content"]

    # Test is_reasoning_end
    assert parser.is_reasoning_end(output_tokens) == param_dict["is_reasoning_end"]
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
from typing import Any
from unittest.mock import MagicMock
import pytest
from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionRequest
from vllm.tool_parsers.gemma4_tool_parser import (
TOOL_CALL_END,
TOOL_CALL_START,
Gemma4ToolParser,
_parse_gemma4_args,
_parse_gemma4_array,
)
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
@pytest.fixture
def mock_tokenizer():
    """Return a MagicMock tokenizer exposing the tool-call delimiter tokens."""
    # Include the tool call start/end tokens in the vocab for the parser
    delimiter_vocab = {TOOL_CALL_START: 48, TOOL_CALL_END: 49}
    tokenizer = MagicMock()
    tokenizer.configure_mock(
        **{
            "encode.return_value": [1, 2, 3],
            "get_vocab.return_value": delimiter_vocab,
        }
    )
    return tokenizer
@pytest.fixture
def parser(mock_tokenizer):
    """Construct a Gemma4ToolParser on top of the mocked tokenizer."""
    tool_parser = Gemma4ToolParser(mock_tokenizer)
    return tool_parser
@pytest.fixture
def mock_request():
    """Return a ChatCompletionRequest mock with no tools and auto tool choice."""
    chat_request = MagicMock(spec=ChatCompletionRequest)
    chat_request.configure_mock(tools=[], tool_choice="auto")
    return chat_request
# ---------------------------------------------------------------------------
# Unit tests for _parse_gemma4_args (shared parser logic)
# ---------------------------------------------------------------------------
class TestParseGemma4Args:
    """Unit tests for ``_parse_gemma4_args`` on raw argument strings."""

    def test_empty_string(self):
        parsed = _parse_gemma4_args("")
        assert parsed == {}

    def test_whitespace_only(self):
        parsed = _parse_gemma4_args(" ")
        assert parsed == {}

    def test_single_string_value(self):
        expected = {"location": "Paris"}
        assert _parse_gemma4_args('location:<|"|>Paris<|"|>') == expected

    def test_string_value_with_comma(self):
        """A comma inside a delimited string must not split the value."""
        expected = {"location": "Paris, France"}
        assert _parse_gemma4_args('location:<|"|>Paris, France<|"|>') == expected

    def test_multiple_string_values(self):
        raw_args = 'location:<|"|>San Francisco<|"|>,unit:<|"|>celsius<|"|>'
        expected = {"location": "San Francisco", "unit": "celsius"}
        assert _parse_gemma4_args(raw_args) == expected

    def test_integer_value(self):
        assert _parse_gemma4_args("count:42") == {"count": 42}

    def test_float_value(self):
        assert _parse_gemma4_args("score:3.14") == {"score": 3.14}

    def test_boolean_true(self):
        assert _parse_gemma4_args("flag:true") == {"flag": True}

    def test_boolean_false(self):
        assert _parse_gemma4_args("flag:false") == {"flag": False}

    def test_mixed_types(self):
        raw_args = 'name:<|"|>test<|"|>,count:42,active:true,score:3.14'
        expected = {
            "name": "test",
            "count": 42,
            "active": True,
            "score": 3.14,
        }
        assert _parse_gemma4_args(raw_args) == expected

    def test_nested_object(self):
        expected = {"nested": {"inner": "value"}}
        assert _parse_gemma4_args('nested:{inner:<|"|>value<|"|>}') == expected

    def test_array_of_strings(self):
        expected = {"items": ["a", "b"]}
        assert _parse_gemma4_args('items:[<|"|>a<|"|>,<|"|>b<|"|>]') == expected

    def test_unterminated_string(self):
        """Unterminated strings should take everything after the delimiter."""
        expected = {"key": "unterminated"}
        assert _parse_gemma4_args('key:<|"|>unterminated') == expected

    def test_empty_value(self):
        """Key with no value after colon."""
        assert _parse_gemma4_args("key:") == {"key": ""}
class TestParseGemma4Array:
    """Unit tests for ``_parse_gemma4_array`` on raw array bodies."""

    def test_string_array(self):
        items = _parse_gemma4_array('<|"|>a<|"|>,<|"|>b<|"|>')
        assert items == ["a", "b"]

    def test_empty_array(self):
        items = _parse_gemma4_array("")
        assert items == []

    def test_bare_values(self):
        """Undelimited values are expected to come back as typed scalars."""
        items = _parse_gemma4_array("42,true,3.14")
        assert items == [42, True, 3.14]
# ---------------------------------------------------------------------------
# Non-streaming extraction tests
# ---------------------------------------------------------------------------
class TestExtractToolCalls:
    """Non-streaming tests for ``Gemma4ToolParser.extract_tool_calls``."""

    def test_no_tool_calls(self, parser, mock_request):
        """Plain text yields no tool calls and is passed through as content."""
        text = "Hello, how can I help you today?"
        extracted = parser.extract_tool_calls(text, mock_request)
        assert extracted.tools_called is False
        assert extracted.tool_calls == []
        assert extracted.content == text

    def test_single_tool_call(self, parser, mock_request):
        text = '<|tool_call>call:get_weather{location:<|"|>London<|"|>}<tool_call|>'
        extracted = parser.extract_tool_calls(text, mock_request)
        assert extracted.tools_called is True
        assert len(extracted.tool_calls) == 1
        call = extracted.tool_calls[0].function
        assert call.name == "get_weather"
        assert json.loads(call.arguments) == {"location": "London"}

    def test_multiple_arguments(self, parser, mock_request):
        text = (
            "<|tool_call>call:get_weather{"
            'location:<|"|>San Francisco<|"|>,'
            'unit:<|"|>celsius<|"|>}'
            "<tool_call|>"
        )
        extracted = parser.extract_tool_calls(text, mock_request)
        assert extracted.tools_called is True
        assert len(extracted.tool_calls) == 1
        call = extracted.tool_calls[0].function
        assert call.name == "get_weather"
        decoded = json.loads(call.arguments)
        assert decoded == {"location": "San Francisco", "unit": "celsius"}

    def test_text_before_tool_call(self, parser, mock_request):
        """Leading text is returned as content alongside the tool call."""
        text = (
            "Let me check the weather for you. "
            '<|tool_call>call:get_weather{location:<|"|>Paris<|"|>}'
            "<tool_call|>"
        )
        extracted = parser.extract_tool_calls(text, mock_request)
        assert extracted.tools_called is True
        assert extracted.content == "Let me check the weather for you."
        assert len(extracted.tool_calls) == 1
        assert extracted.tool_calls[0].function.name == "get_weather"

    def test_multiple_tool_calls(self, parser, mock_request):
        text = (
            '<|tool_call>call:get_weather{location:<|"|>London<|"|>}'
            "<tool_call|>"
            '<|tool_call>call:get_time{location:<|"|>London<|"|>}'
            "<tool_call|>"
        )
        extracted = parser.extract_tool_calls(text, mock_request)
        assert extracted.tools_called is True
        assert len(extracted.tool_calls) == 2
        first_call, second_call = extracted.tool_calls[0], extracted.tool_calls[1]
        assert first_call.function.name == "get_weather"
        assert second_call.function.name == "get_time"

    def test_nested_arguments(self, parser, mock_request):
        text = (
            "<|tool_call>call:complex_function{"
            'nested:{inner:<|"|>value<|"|>},'
            'list:[<|"|>a<|"|>,<|"|>b<|"|>]}'
            "<tool_call|>"
        )
        extracted = parser.extract_tool_calls(text, mock_request)
        assert extracted.tools_called is True
        assert len(extracted.tool_calls) == 1
        call = extracted.tool_calls[0].function
        assert call.name == "complex_function"
        decoded = json.loads(call.arguments)
        assert decoded == {"nested": {"inner": "value"}, "list": ["a", "b"]}

    def test_tool_call_with_number_and_boolean(self, parser, mock_request):
        text = (
            "<|tool_call>call:set_status{"
            "is_active:true,"
            "count:42,"
            "score:3.14}"
            "<tool_call|>"
        )
        extracted = parser.extract_tool_calls(text, mock_request)
        assert extracted.tools_called is True
        assert len(extracted.tool_calls) == 1
        call = extracted.tool_calls[0].function
        assert call.name == "set_status"
        decoded = json.loads(call.arguments)
        assert decoded == {"is_active": True, "count": 42, "score": 3.14}

    def test_incomplete_tool_call(self, parser, mock_request):
        """A call missing its <tool_call|> end marker is treated as content."""
        text = '<|tool_call>call:get_weather{location:<|"|>London'
        extracted = parser.extract_tool_calls(text, mock_request)
        assert extracted.tools_called is False
        assert extracted.content == text

    def test_hyphenated_function_name(self, parser, mock_request):
        """Ensure function names with hyphens are parsed correctly."""
        text = '<|tool_call>call:get-weather{location:<|"|>London<|"|>}<tool_call|>'
        extracted = parser.extract_tool_calls(text, mock_request)
        assert extracted.tools_called is True
        assert extracted.tool_calls[0].function.name == "get-weather"

    def test_dotted_function_name(self, parser, mock_request):
        """Ensure function names with dots are parsed correctly."""
        text = '<|tool_call>call:weather.get{location:<|"|>London<|"|>}<tool_call|>'
        extracted = parser.extract_tool_calls(text, mock_request)
        assert extracted.tools_called is True
        assert extracted.tool_calls[0].function.name == "weather.get"

    def test_no_arguments(self, parser, mock_request):
        """Tool calls with empty arguments."""
        text = "<|tool_call>call:get_status{}<tool_call|>"
        extracted = parser.extract_tool_calls(text, mock_request)
        assert extracted.tools_called is True
        call = extracted.tool_calls[0].function
        assert call.name == "get_status"
        assert json.loads(call.arguments) == {}
# ---------------------------------------------------------------------------
# Streaming extraction tests
# ---------------------------------------------------------------------------
class TestStreamingExtraction:
    """Tests for the streaming tool call extraction.

    These simulate the token-by-token streaming that vLLM performs,
    feeding incremental text to extract_tool_calls_streaming() and
    verifying that the accumulated argument deltas form valid JSON.
    """

    @staticmethod
    def _function_field(tool_call: Any, field: str) -> Any:
        """Read ``field`` from a tool call's ``function`` payload.

        The payload may be a plain dict or an object with attributes;
        ``None`` is returned when the field is missing in either form.
        """
        func = tool_call.function
        if isinstance(func, dict):
            return func.get(field)
        return getattr(func, field, None)

    def _simulate_streaming(
        self, parser: Gemma4ToolParser, mock_request: Any, chunks: list[str]
    ) -> list[tuple[Any, str]]:
        """Feed chunks through the streaming parser and collect results.

        Args:
            parser: The tool parser under test.
            mock_request: Request object forwarded to the parser unchanged.
            chunks: Text deltas, fed in order, one parser call per chunk.

        Returns:
            A list of (delta_message, accumulated_text) tuples, one per chunk.
        """
        results: list[tuple[Any, str]] = []
        previous_text: str = ""
        previous_token_ids: list[int] = []
        for chunk in chunks:
            current_text = previous_text + chunk
            # Use token ID 48 for tool_call start, 49 for end, 0 otherwise.
            # Exactly one synthetic token ID is emitted per chunk.
            if TOOL_CALL_START in chunk:
                delta_token_ids = [48]
            elif TOOL_CALL_END in chunk:
                delta_token_ids = [49]
            else:
                delta_token_ids = [0]
            current_token_ids = previous_token_ids + delta_token_ids
            delta = parser.extract_tool_calls_streaming(
                previous_text=previous_text,
                current_text=current_text,
                delta_text=chunk,
                previous_token_ids=tuple(previous_token_ids),
                current_token_ids=tuple(current_token_ids),
                delta_token_ids=tuple(delta_token_ids),
                request=mock_request,
            )
            results.append((delta, current_text))
            previous_text = current_text
            previous_token_ids = current_token_ids
        return results

    def _collect_arguments(self, results):
        """Collect all argument deltas from streaming results into one string."""
        fragments = []
        for delta, _ in results:
            if not (delta and delta.tool_calls):
                continue
            for tool_call in delta.tool_calls:
                fragment = self._function_field(tool_call, "arguments")
                # Skip missing or empty argument deltas.
                if fragment:
                    fragments.append(fragment)
        return "".join(fragments)

    def _collect_function_name(self, results):
        """Return the first non-empty function name streamed, else None."""
        for delta, _ in results:
            if not (delta and delta.tool_calls):
                continue
            for tool_call in delta.tool_calls:
                name = self._function_field(tool_call, "name")
                if name:
                    return name
        return None

    def test_basic_streaming_single_tool(self, parser, mock_request):
        """Simulate the exact streaming scenario from the bug report.

        Model generates:
        <|tool_call>call:get_weather{location:<|"|>Paris, France<|"|>}<tool_call|>
        Expected: arguments should be valid JSON {"location": "Paris, France"}
        """
        chunks = [
            "<|tool_call>",
            "call:get_weather{",
            'location:<|"|>Paris',
            ", France",
            '<|"|>}',
            "<tool_call|>",
        ]
        results = self._simulate_streaming(parser, mock_request, chunks)
        # Verify function name
        name = self._collect_function_name(results)
        assert name == "get_weather", f"Expected 'get_weather', got '{name}'"
        # Verify arguments form valid JSON
        args_text = self._collect_arguments(results)
        assert args_text, "No arguments were streamed"
        parsed_args = json.loads(args_text)
        assert parsed_args == {"location": "Paris, France"}

    def test_streaming_multi_arg(self, parser, mock_request):
        """Streaming with multiple arguments."""
        chunks = [
            "<|tool_call>",
            "call:get_weather{",
            'location:<|"|>Tokyo<|"|>,',
            'unit:<|"|>celsius<|"|>}',
            "<tool_call|>",
        ]
        results = self._simulate_streaming(parser, mock_request, chunks)
        name = self._collect_function_name(results)
        assert name == "get_weather"
        args_text = self._collect_arguments(results)
        assert args_text
        parsed_args = json.loads(args_text)
        assert parsed_args == {"location": "Tokyo", "unit": "celsius"}

    def test_streaming_no_extra_brace(self, parser, mock_request):
        """Verify the closing } is NOT leaked into arguments (Bug #2)."""
        chunks = [
            "<|tool_call>",
            "call:get_weather{",
            'location:<|"|>London<|"|>}',
            "<tool_call|>",
        ]
        results = self._simulate_streaming(parser, mock_request, chunks)
        args_text = self._collect_arguments(results)
        assert args_text
        # The args text must be valid JSON (no extra })
        parsed = json.loads(args_text)
        assert parsed == {"location": "London"}
        # Specifically assert no double-brace
        assert args_text.count("}") <= 1, (
            f"Arguments contain extra closing brace: {args_text!r}"
        )

    def test_streaming_no_unquoted_keys(self, parser, mock_request):
        """Verify keys are properly quoted in JSON (Bug #1)."""
        chunks = [
            "<|tool_call>",
            "call:get_weather{",
            'location:<|"|>Paris<|"|>}',
            "<tool_call|>",
        ]
        results = self._simulate_streaming(parser, mock_request, chunks)
        args_text = self._collect_arguments(results)
        # Must start with { and contain quoted key
        assert args_text.lstrip().startswith("{"), (
            f"Arguments don't start with '{{': {args_text!r}"
        )
        assert '"location"' in args_text, (
            f"Key 'location' not properly quoted: {args_text!r}"
        )

    def test_streaming_name_no_call_prefix(self, parser, mock_request):
        """Verify function name has no 'call:' prefix."""
        chunks = [
            "<|tool_call>",
            "call:get_weather{",
            'location:<|"|>Paris<|"|>}',
            "<tool_call|>",
        ]
        results = self._simulate_streaming(parser, mock_request, chunks)
        name = self._collect_function_name(results)
        assert name == "get_weather"
        assert not name.startswith("call:"), f"Name has 'call:' prefix: {name!r}"

    def test_streaming_text_before_tool_call(self, parser, mock_request):
        """Text before tool call should be emitted as content."""
        chunks = [
            "Let me check ",
            "the weather. ",
            "<|tool_call>",
            "call:get_weather{",
            'location:<|"|>London<|"|>}',
            "<tool_call|>",
        ]
        results = self._simulate_streaming(parser, mock_request, chunks)
        # First chunks should be content
        content_parts = [
            delta.content for delta, _ in results if delta and delta.content
        ]
        assert "".join(content_parts).strip().startswith("Let me check")

    def test_streaming_numeric_args(self, parser, mock_request):
        """Streaming with numeric and boolean argument values."""
        chunks = [
            "<|tool_call>",
            "call:set_config{",
            "count:42,",
            "active:true}",
            "<tool_call|>",
        ]
        results = self._simulate_streaming(parser, mock_request, chunks)
        args_text = self._collect_arguments(results)
        # NOTE(review): the checks below are skipped when no arguments are
        # streamed, so this test can pass vacuously - confirm that is intended.
        if args_text:
            parsed_args = json.loads(args_text)
            assert parsed_args["count"] == 42
            assert parsed_args["active"] is True

    def test_streaming_empty_args(self, parser, mock_request):
        """Tool call with no arguments."""
        chunks = [
            "<|tool_call>",
            "call:get_status{}",
            "<tool_call|>",
        ]
        results = self._simulate_streaming(parser, mock_request, chunks)
        name = self._collect_function_name(results)
        assert name == "get_status"
......@@ -12,6 +12,7 @@ from .dual_chunk_rope import DualChunkRotaryEmbedding
from .dynamic_ntk_alpha_rope import DynamicNTKAlphaRotaryEmbedding
from .dynamic_ntk_scaling_rope import DynamicNTKScalingRotaryEmbedding
from .fope import FourierRotaryEmbedding
from .gemma4_rope import Gemma4RotaryEmbedding
from .linear_scaling_rope import LinearScalingRotaryEmbedding
from .llama3_rope import Llama3RotaryEmbedding
from .llama4_vision_rope import Llama4VisionRotaryEmbedding
......@@ -134,6 +135,17 @@ def get_rope(
is_neox_style,
dtype,
)
elif scaling_type == "proportional":
# Proportional RoPE is used by Gemma4 for global (full) attention.
# Gemma4 uses a sparse/fractional RoPE with cross-mixing between halves.
rotary_emb = Gemma4RotaryEmbedding(
head_size,
rotary_dim,
max_position,
base,
is_neox_style,
dtype,
)
elif scaling_type == "llama3":
scaling_factor = rope_parameters["factor"]
low_freq_factor = rope_parameters["low_freq_factor"]
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Gemma4-specific Rotary Positional Embeddings (proportional scaling).
Gemma4 uses "proportional" RoPE which computes inv_freq frequencies scaled
by head_dim (not rotary_dim), and zero-pads for non-rotated dimensions when
partial_rotary_factor < 1. The actual rotation uses standard neox-style
rotate_half, matching HF transformers' apply_rotary_pos_emb.
"""
import torch
from .base import RotaryEmbedding
class Gemma4RotaryEmbedding(RotaryEmbedding):
    """Rotary embedding with Gemma4's "proportional" frequency layout.

    Rotation itself is inherited from RotaryEmbedding; only the inverse
    frequency table is customised so that it matches HF's
    _compute_proportional_rope_parameters:

    - the exponent denominator is head_dim rather than rotary_dim;
    - angle pairs that are not rotated get a zero frequency, which yields
      cos=1 and sin=0, i.e. an identity rotation.

    With partial_rotary_factor=1.0 every dim is rotated and the result is a
    standard RotaryEmbedding using head_dim-scaled frequencies.
    """

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: float,
        is_neox_style: bool,
        dtype: torch.dtype,
    ) -> None:
        half_head = head_size // 2
        # Angle pairs that really rotate (derived from partial_rotary_factor)
        self.rope_angles = rotary_dim // 2
        # Remaining angle pairs in each half, left unrotated
        self.nope_angles = half_head - self.rope_angles
        # NOTE(review): both counts are assigned before super().__init__(),
        # presumably because the base constructor invokes _compute_inv_freq
        # - confirm against RotaryEmbedding.

        # The base class is told rotary_dim == head_size so that its
        # forward_static applies the cos/sin cache to ALL dims; the
        # zero-padded frequencies from _compute_inv_freq keep the unrotated
        # dims unchanged.
        full_rotary_dim = head_size
        super().__init__(
            head_size,
            full_rotary_dim,
            max_position_embeddings,
            base,
            is_neox_style,
            dtype,
        )

    def _compute_inv_freq(self, base: float) -> torch.Tensor:
        """Build the inverse-frequency vector for proportional RoPE.

        Unlike the base implementation, exponents are divided by head_size
        (not rotary_dim), and zeros are appended for unrotated angle pairs.
        """
        # HF formula: base ** (arange(0, 2*rope_angles, 2) / head_dim)
        rotated_span = 2 * self.rope_angles
        exponents = (
            torch.arange(0, rotated_span, 2, dtype=torch.float) / self.head_size
        )
        rotated_freqs = 1.0 / (base**exponents)
        if self.nope_angles <= 0:
            # Every angle pair rotates; nothing to pad.
            return rotated_freqs
        # Zero frequency => cos=1, sin=0 => identity for the unrotated dims
        identity_freqs = torch.zeros(self.nope_angles, dtype=torch.float)
        return torch.cat([rotated_freqs, identity_freqs])

    def extra_repr(self) -> str:
        """Return the comma-separated attribute summary for this module."""
        fields = [
            f"head_size={self.head_size}",
            f"rotary_dim={self.rotary_dim}",
            f"rope_angles={self.rope_angles}",
            f"nope_angles={self.nope_angles}",
            f"max_position_embeddings={self.max_position_embeddings}",
            f"base={self.base}",
            f"is_neox_style={self.is_neox_style}",
        ]
        return ", ".join(fields)
......@@ -14,6 +14,7 @@ from vllm.v1.kv_cache_interface import FullAttentionSpec, MambaSpec, MLAAttentio
if TYPE_CHECKING:
from vllm.config import ModelConfig, VllmConfig
logger = init_logger(__name__)
......@@ -57,6 +58,58 @@ class Gemma3TextModelConfig(VerifyAndUpdateConfig):
hf_config.is_causal = not hf_config.use_bidirectional_attention
class Gemma4Config(VerifyAndUpdateConfig):
    @staticmethod
    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
        """Pin a single attention backend when head dimensions are mixed.

        Some Gemma4 variants use one head dimension for sliding window
        layers (head_dim) and another for full attention layers
        (global_head_dim). Once global_head_dim exceeds 256, FlashAttention
        rejects those layers (its kernels require head_size <= 256), so
        vLLM would pick a different backend per layer type. Running mixed
        backends produces numerical divergence and corrupted output.

        To avoid that, TRITON_ATTN (which has no head_size ceiling) is
        selected for all layers, but only when the user has not chosen a
        backend explicitly.

        TODO: Heterogeneous head_sizes (head_dim != global_head_dim)
        require NixlConnector changes to support per-layer KV transfer
        with different head dimensions for prefill-decode disaggregation.
        """
        text_cfg = vllm_config.model_config.hf_text_config
        head_dim = getattr(text_cfg, "head_dim", None)
        global_head_dim = getattr(text_cfg, "global_head_dim", None)

        # Both dimensions must be present before they can be compared.
        if head_dim is None or global_head_dim is None:
            return
        # Uniform head dimensions: every layer can share one backend.
        if head_dim == global_head_dim:
            return
        # Mixed but small enough for FlashAttention (head_size <= 256):
        # no need to force anything on these smaller models.
        if not max(head_dim, global_head_dim) > 256:
            return
        # Respect an explicit user choice.
        attention_config = vllm_config.attention_config
        if attention_config.backend is not None:
            return

        from vllm.v1.attention.backends.registry import (
            AttentionBackendEnum,
        )

        attention_config.backend = AttentionBackendEnum.TRITON_ATTN
        logger.info(
            "Gemma4 model has heterogeneous head dimensions "
            "(head_dim=%d, global_head_dim=%d). Forcing TRITON_ATTN "
            "backend to prevent mixed-backend numerical divergence.",
            head_dim,
            global_head_dim,
        )
class GptOssForCausalLMConfig(VerifyAndUpdateConfig):
@staticmethod
def verify_and_update_config(vllm_config: "VllmConfig") -> None:
......@@ -668,6 +721,8 @@ MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = {
"Ernie4_5_VLMoeForConditionalGeneration": Ernie4_5_VLMoeForConditionalGenerationConfig, # noqa: E501
"FalconMambaForCausalLM": MambaModelConfig,
"Gemma3TextModel": Gemma3TextModelConfig,
"Gemma4ForCausalLM": Gemma4Config,
"Gemma4ForConditionalGeneration": Gemma4Config,
"GptOssForCausalLM": GptOssForCausalLMConfig,
"GteModel": SnowflakeGteNewModelConfig,
"GteNewForSequenceClassification": GteNewModelConfig,
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright 2025 The vLLM team.
# Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Gemma 4 model implementation for vLLM."""
from collections.abc import Iterable
from itertools import islice
import regex as re
import torch
from torch import nn
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (
get_pp_group,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import GeluAndMul
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.fused_moe import FusedMoE, GateLinear
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
MergedColumnParallelLinear,
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear,
)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader,
maybe_remap_kv_scale_name,
)
from vllm.sequence import IntermediateTensors
from .interfaces import MixtureOfExperts, SupportsLoRA, SupportsPP
from .utils import (
AutoWeightsLoader,
extract_layer_index,
is_pp_missing_parameter,
make_layers,
maybe_prefix,
)
logger = init_logger(__name__)
def _get_text_config(config):
"""Dereference text_config if config is a nested Gemma4Config.
Gemma4 checkpoints use architectures=["Gemma4ForConditionalGeneration"]
which yields a Gemma4Config with nested text_config. This function
transparently returns the text config regardless of nesting.
"""
if hasattr(config, "text_config"):
return config.text_config
return config
class Gemma4MLP(nn.Module):
    """Gated feed-forward block of a Gemma4 decoder layer.

    The gate and up projections are fused into one
    ``MergedColumnParallelLinear``; its output goes through ``GeluAndMul``
    (tanh approximation) and then the row-parallel down projection.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_activation: str,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        """Create the projections and check the configured activation.

        Args:
            hidden_size: Width of the block's input and output.
            intermediate_size: Width of each of the gate and up projections.
            hidden_activation: Activation name from the model config; only
                ``gelu_pytorch_tanh`` is accepted.
            quant_config: Optional quantization config handed to both linears.
            prefix: Dotted module path used to name the sub-layers.

        Raises:
            ValueError: If ``hidden_activation`` is not ``gelu_pytorch_tanh``.
        """
        super().__init__()
        fused_output_sizes = [intermediate_size, intermediate_size]
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            fused_output_sizes,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.down_proj",
        )
        activation_supported = hidden_activation == "gelu_pytorch_tanh"
        if not activation_supported:
            raise ValueError(
                "Gemma4 uses `gelu_pytorch_tanh` as the hidden activation "
                "function. Please set `hidden_act` and `hidden_activation` to "
                "`gelu_pytorch_tanh`."
            )
        self.act_fn = GeluAndMul(approximate="tanh")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply gate/up projection, gated activation and down projection."""
        fused, _ = self.gate_up_proj(x)
        activated = self.act_fn(fused)
        projected, _ = self.down_proj(activated)
        return projected
class Gemma4Router(nn.Module):
    """Produces expert logits for the Gemma4 MoE block.

    The router input is pre-processed before projection: a weight-free
    RMSNorm, multiplication by ``hidden_size ** -0.5`` (``root_size``), and
    multiplication by a learned per-dimension ``scale``. This pre-processing
    feeds the router projection only; the expert MLPs receive their own
    input from the caller.
    """

    def __init__(
        self,
        config,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        """Create the norm, scale parameter, constant buffer and projection.

        ``quant_config`` is accepted for signature parity with the other
        layers; it is not forwarded to any sub-module here.
        """
        super().__init__()
        width = config.hidden_size
        self.hidden_size = width

        # Pure normalisation: no learned RMSNorm weight.
        self.norm = RMSNorm(
            width,
            eps=config.rms_norm_eps,
            has_weight=False,
        )

        # Learned per-dimension scale, applied after norm and root_size.
        self.scale = nn.Parameter(torch.ones(width))

        # Constant 1/sqrt(hidden_size); non-persistent, so it is not part of
        # the state dict.
        self.register_buffer(
            "root_size",
            torch.tensor(width**-0.5),
            persistent=False,
        )

        # Logit projection, requested with float32 output. The original
        # author's rationale: GateLinear handles bf16 weights/activations
        # with fp32 output, which the top-k kernel often needs for stable
        # routing.
        self.proj = GateLinear(
            width,
            config.num_experts,
            bias=False,
            out_dtype=torch.float32,
            prefix=f"{prefix}.proj",
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the raw router logits for ``x`` (documented as ``[T, E]``)."""
        normed = self.norm(x)
        rescaled = normed * self.root_size.to(normed.dtype)
        rescaled = rescaled * self.scale.to(rescaled.dtype)
        logits, _ = self.proj(rescaled)
        return logits
class Gemma4MoE(nn.Module):
    """Expert-dispatch half of the Gemma4 MoE block, built on ``FusedMoE``.

    Router logits are computed outside this module (``Gemma4Router``) and
    handed to ``forward``. Routing is: softmax over all experts, select the
    top-k, renormalise the selected probabilities, then multiply by the
    learned ``per_expert_scale`` so the fused kernel applies the per-expert
    output scale as part of the routing weight.
    """

    def __init__(
        self,
        config,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        """Create the per-expert scale and the ``FusedMoE`` expert bank."""
        super().__init__()
        self.hidden_size = config.hidden_size
        self.num_experts = config.num_experts

        # Folded into the routing weights below so the fused kernel computes
        # sum_e(expert_e * w_e * scale_e).
        self.per_expert_scale = nn.Parameter(torch.ones(config.num_experts))

        # Captured by the closure; this is the same Parameter object, so
        # values loaded later are seen by the routing function.
        scale_per_expert = self.per_expert_scale

        def routing_function(
            hidden_states: torch.Tensor,
            gating_output: torch.Tensor,
            topk: int,
            renormalize: bool,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            """Gemma4 routing: full softmax, top-k, renormalise, scale.

            ``hidden_states`` and ``renormalize`` are part of the callback
            signature but are not read; renormalisation always happens.
            Returns ``(topk_weights as float32, topk_ids as int32)``.
            """
            topk_ids = torch.topk(gating_output, k=topk, dim=-1).indices
            probs = torch.nn.functional.softmax(gating_output, dim=-1)

            # 0/1 mask over experts marking the ones picked for each row.
            num_classes = gating_output.size(-1)
            picked_mask = torch.nn.functional.one_hot(
                topk_ids, num_classes=num_classes
            ).sum(dim=-2)
            picked_probs = picked_mask * probs

            # Renormalise over the picked experts; a zero sum is replaced by
            # 1.0 to avoid dividing by zero.
            denom = torch.sum(picked_probs, dim=-1, keepdim=True)
            denom = torch.where(denom > 0.0, denom, 1.0)
            normalized = picked_probs / denom
            topk_weights = normalized.gather(1, topk_ids)

            # Fold the learned per-expert scale into the routing weights.
            scales = scale_per_expert[topk_ids].to(topk_weights.dtype)
            topk_weights = topk_weights * scales
            return topk_weights.to(torch.float32), topk_ids.to(torch.int32)

        # The expert width is read from `moe_intermediate_size`, falling back
        # to `expert_intermediate_size`, and finally to None.
        fallback_width = getattr(config, "expert_intermediate_size", None)
        expert_width = getattr(config, "moe_intermediate_size", fallback_width)

        self.experts = FusedMoE(
            num_experts=config.num_experts,
            top_k=config.top_k_experts,
            hidden_size=config.hidden_size,
            intermediate_size=expert_width,
            reduce_results=True,
            renormalize=True,
            quant_config=quant_config,
            prefix=f"{prefix}.experts",
            custom_routing_function=routing_function,
            activation="gelu",
        )

    def forward(self, x: torch.Tensor, router_logits: torch.Tensor) -> torch.Tensor:
        """Dispatch ``x`` to the experts using externally computed logits."""
        return self.experts(x, router_logits)
class Gemma4Attention(nn.Module):
    """Self-attention for a single Gemma4 decoder layer.

    Uses a packed QKV projection, per-head RMSNorm on Q and K (learned
    weight) and on V (no weight), rotary embeddings whose parameters depend
    on the layer type, and optional KV-cache sharing with an earlier layer
    of the same attention type.
    """

    def __init__(
        self,
        config,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        head_dim: int,
        max_position_embeddings: int,
        use_k_eq_v: bool = False,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        attn_logits_soft_cap: float | None = None,
        prefix: str = "",
    ) -> None:
        """Build projections, norms, rotary embedding and attention backend.

        The layer index is parsed from ``prefix`` and used to look up the
        layer type in ``config.layer_types``.

        Raises:
            AssertionError: If the head counts do not divide evenly across
                the tensor-parallel world size.
            ValueError: If this is a KV-shared layer and ``prefix`` does not
                contain ``".layers."``.
        """
        super().__init__()
        self.config = config
        self.hidden_size = hidden_size
        self.use_k_eq_v = use_k_eq_v

        tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()

        # Query heads are split evenly across tensor-parallel ranks.
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size

        # KV heads are either partitioned across ranks or, when there are
        # fewer KV heads than ranks, each rank keeps at least one.
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads < tp_size:
            assert tp_size % self.total_num_kv_heads == 0
        else:
            assert self.total_num_kv_heads % tp_size == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)

        self.head_dim = head_dim
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim

        # Gemma4 uses a fixed attention scale of 1.0. Per the original
        # author, query_pre_attn_scalar is not used here (unlike Gemma2/3);
        # the learnable Q/K norms handle scaling implicitly.
        self.scaling = 1.0

        # One packed QKV projection for every layer type. For k_eq_v layers
        # the weight iterator loads the K weights into both the K and V
        # slots, so no structural difference is needed.
        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=config.attention_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=config.attention_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        # Q/K norms carry a learnable per-head-dim weight; the V norm is
        # created without a weight.
        norm_eps = config.rms_norm_eps
        self.q_norm = RMSNorm(self.head_dim, eps=norm_eps)
        self.k_norm = RMSNorm(self.head_dim, eps=norm_eps)
        self.v_norm = RMSNorm(self.head_dim, eps=norm_eps, has_weight=False)

        # Layer type decides the sliding window and the RoPE parameters.
        layer_idx = extract_layer_index(prefix)
        layer_type = config.layer_types[layer_idx]
        self.is_sliding = layer_type == "sliding_attention"
        sliding_window = None
        if self.is_sliding:
            sliding_window = config.sliding_window

        if layer_type in config.rope_parameters:
            # Per-layer-type rope config (dict format): take this layer
            # type's entry as-is, without overriding its
            # partial_rotary_factor.
            rope_parameters = dict(config.rope_parameters[layer_type])
        else:
            # Legacy flat config: copy it and, for sliding layers, replace
            # rope_theta with the local base frequency (default 10000.0).
            rope_parameters = dict(config.rope_parameters.copy())
            if self.is_sliding:
                rope_parameters["rope_theta"] = getattr(
                    config, "rope_local_base_freq", 10000.0
                )

        # KV sharing: the last `num_kv_shared_layers` layers reuse the KV
        # cache of the last non-shared layer with the same attention type.
        kv_sharing_target_layer_name = None
        self.is_kv_shared_layer = False
        num_kv_shared_layers = getattr(config, "num_kv_shared_layers", 0)
        if num_kv_shared_layers > 0:
            first_kv_shared_layer_idx = (
                config.num_hidden_layers - num_kv_shared_layers
            )
            if layer_idx >= first_kv_shared_layer_idx:
                self.is_kv_shared_layer = True
                # Search the non-shared layers from the end for the same
                # type. `.index` raises ValueError if there is none.
                owner_candidates = config.layer_types[:first_kv_shared_layer_idx]
                current_layer_type = config.layer_types[layer_idx]
                offset_from_end = owner_candidates[::-1].index(current_layer_type)
                kv_shared_layer_index = (
                    len(owner_candidates) - 1 - offset_from_end
                )
                if kv_shared_layer_index >= 0:
                    if ".layers." not in prefix:
                        raise ValueError(
                            "Unexpected prefix format for Gemma4Attention: "
                            f"'{prefix}'. Expected to contain '.layers.'."
                        )
                    param_name_before_layers = prefix.split(".layers.")[0]
                    kv_sharing_target_layer_name = (
                        f"{param_name_before_layers}.layers."
                        f"{kv_shared_layer_index}.self_attn.attn"
                    )

        self.rotary_emb = get_rope(
            self.head_dim,
            max_position=max_position_embeddings,
            rope_parameters=rope_parameters,
            is_neox_style=True,
        )
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            logits_soft_cap=attn_logits_soft_cap,
            per_layer_sliding_window=sliding_window,
            kv_sharing_target_layer_name=kv_sharing_target_layer_name,
            prefix=f"{prefix}.attn",
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        """Run attention over ``hidden_states`` and project the result.

        Q is always normalised and rotated. K and V are normalised (and K
        rotated) only on layers that own their KV cache; on KV-shared layers
        the raw K/V projections are passed through unchanged. ``kwargs`` is
        accepted but not used here.
        """
        # Single packed projection serves both standard and k_eq_v layers.
        packed, _ = self.qkv_proj(hidden_states)
        q, k, v = packed.split([self.q_size, self.kv_size, self.kv_size], dim=-1)

        # Per-head Q norm: unflatten to (..., heads, head_dim), normalise,
        # then flatten back.
        q = self.q_norm(q.unflatten(-1, (self.num_heads, self.head_dim)))
        q = q.flatten(-2, -1)

        if self.is_kv_shared_layer:
            # Shared layer: only the rotated Q is kept.
            q = self.rotary_emb(positions, q, k)[0]
        else:
            kv_head_shape = (self.num_kv_heads, self.head_dim)
            k = self.k_norm(k.unflatten(-1, kv_head_shape))
            k = k.flatten(-2, -1)
            q, k = self.rotary_emb(positions, q, k)
            v = self.v_norm(v.unflatten(-1, kv_head_shape))
            v = v.flatten(-2, -1)

        context = self.attn(q, k, v)
        projected, _ = self.o_proj(context)
        return projected
class Gemma4DecoderLayer(nn.Module):
    """One Gemma4 transformer layer.

    Contains attention, a dense MLP, an optional router + MoE branch that
    runs alongside the MLP, optional Per-Layer Embedding (PLE) gating, and a
    ``layer_scalar`` buffer that multiplies the layer output.
    """

    def __init__(
        self,
        config,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        """Build every sub-module for the layer whose index is in ``prefix``."""
        super().__init__()
        self.hidden_size = config.hidden_size
        self.hidden_size_per_layer_input = getattr(
            config, "hidden_size_per_layer_input", 0
        )
        layer_idx = extract_layer_index(prefix)
        self.layer_idx = layer_idx

        # Full-attention layers may use a different head dim
        # (`global_head_dim`) than sliding layers.
        layer_type = config.layer_types[layer_idx]
        self.is_full_attention = layer_type == "full_attention"
        head_dim = config.head_dim
        if self.is_full_attention:
            head_dim = getattr(config, "global_head_dim", head_dim)

        # k_eq_v applies to full-attention layers only, and only when the
        # config enables it; those layers reuse K as V.
        use_k_eq_v = self.is_full_attention and getattr(
            config, "attention_k_eq_v", False
        )

        # k_eq_v layers take their KV head count from
        # `num_global_key_value_heads` when the config provides it.
        num_kv_heads = config.num_key_value_heads
        if use_k_eq_v:
            num_kv_heads = getattr(
                config, "num_global_key_value_heads", num_kv_heads
            )

        self.self_attn = Gemma4Attention(
            config=config,
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=num_kv_heads,
            head_dim=head_dim,
            max_position_embeddings=config.max_position_embeddings,
            use_k_eq_v=use_k_eq_v,
            cache_config=cache_config,
            quant_config=quant_config,
            attn_logits_soft_cap=getattr(config, "attn_logit_softcapping", None),
            prefix=f"{prefix}.self_attn",
        )

        # With `use_double_wide_mlp`, KV-shared layers (index at or beyond
        # the first shared index, which must itself be > 0) get a doubled
        # intermediate size.
        num_kv_shared_layers = getattr(config, "num_kv_shared_layers", 0)
        first_kv_shared_layer_idx = config.num_hidden_layers - num_kv_shared_layers
        is_kv_shared_layer = (
            first_kv_shared_layer_idx > 0 and layer_idx >= first_kv_shared_layer_idx
        )
        use_double_wide_mlp = (
            getattr(config, "use_double_wide_mlp", False) and is_kv_shared_layer
        )
        width_multiplier = 2 if use_double_wide_mlp else 1
        layer_intermediate_size = config.intermediate_size * width_multiplier

        self.mlp = Gemma4MLP(
            hidden_size=self.hidden_size,
            intermediate_size=layer_intermediate_size,
            hidden_activation=config.hidden_activation,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp",
        )

        # The four standard norms, each with a learned weight.
        norm_eps = config.rms_norm_eps
        self.input_layernorm = RMSNorm(config.hidden_size, eps=norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=norm_eps)
        self.pre_feedforward_layernorm = RMSNorm(config.hidden_size, eps=norm_eps)
        self.post_feedforward_layernorm = RMSNorm(config.hidden_size, eps=norm_eps)

        # Optional MoE branch; either config flag switches it on.
        self.enable_moe_block = getattr(config, "enable_moe_block", False) or getattr(
            config, "use_second_mlp_block", False
        )
        if self.enable_moe_block:
            self.router = Gemma4Router(
                config,
                quant_config=quant_config,
                prefix=f"{prefix}.router",
            )
            self.moe = Gemma4MoE(
                config,
                quant_config=quant_config,
                prefix=f"{prefix}.moe",
            )
            self.post_feedforward_layernorm_1 = RMSNorm(
                config.hidden_size, eps=norm_eps
            )
            self.post_feedforward_layernorm_2 = RMSNorm(
                config.hidden_size, eps=norm_eps
            )
            self.pre_feedforward_layernorm_2 = RMSNorm(
                config.hidden_size, eps=norm_eps
            )
        else:
            self.router = None
            self.moe = None
            self.post_feedforward_layernorm_1 = None
            self.post_feedforward_layernorm_2 = None
            self.pre_feedforward_layernorm_2 = None

        # Optional PLE modules; built only for a positive per-layer width.
        ple_width = self.hidden_size_per_layer_input
        if ple_width is not None and ple_width > 0:
            # Gate: hidden state -> per-layer width.
            self.per_layer_input_gate = ReplicatedLinear(
                self.hidden_size,
                self.hidden_size_per_layer_input,
                bias=False,
                quant_config=quant_config,
                prefix=f"{prefix}.per_layer_input_gate",
                return_bias=False,
            )
            # Projection: gated per-layer input -> hidden size.
            self.per_layer_projection = ReplicatedLinear(
                self.hidden_size_per_layer_input,
                self.hidden_size,
                bias=False,
                quant_config=quant_config,
                prefix=f"{prefix}.per_layer_projection",
                return_bias=False,
            )
            self.post_per_layer_input_norm = RMSNorm(
                config.hidden_size, eps=norm_eps
            )
        else:
            self.per_layer_input_gate = None
            self.per_layer_projection = None
            self.post_per_layer_input_norm = None

        # Output scalar, initialised to one and overwritten from the
        # checkpoint; present on every layer.
        self.register_buffer("layer_scalar", torch.ones(1))

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: torch.Tensor | None,
        per_layer_input: torch.Tensor | None = None,
        **kwargs,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Apply the layer and return ``(hidden_states, None)``.

        The incoming ``residual`` argument is ignored: residuals are added
        inside the layer, and ``None`` is always returned in its place.

        Residual pattern:
            1. input_norm -> attention -> post_attn_norm -> add residual
            2. pre_ff_norm -> MLP (+ MoE) -> post_ff_norm -> add residual
        """
        # Attention sub-block.
        attn_skip = hidden_states
        attn_out = self.self_attn(
            positions=positions,
            hidden_states=self.input_layernorm(attn_skip),
            **kwargs,
        )
        hidden_states = self.post_attention_layernorm(attn_out) + attn_skip

        # Feed-forward sub-block; the dense MLP always runs.
        ff_skip = hidden_states
        hidden_states = self.mlp(self.pre_feedforward_layernorm(hidden_states))

        if self.enable_moe_block:
            dense_branch = self.post_feedforward_layernorm_1(hidden_states)

            # Router and experts read the pre-MLP state (the residual),
            # matching the HF transformers forward path per the original
            # author.
            router_logits = self.router(ff_skip)
            expert_branch = self.pre_feedforward_layernorm_2(ff_skip)
            expert_branch = self.moe(expert_branch, router_logits)
            expert_branch = self.post_feedforward_layernorm_2(expert_branch)

            hidden_states = dense_branch + expert_branch

        hidden_states = self.post_feedforward_layernorm(hidden_states) + ff_skip

        # PLE: gate the per-layer input with the hidden state, project back
        # to hidden size, normalise, and add.
        if per_layer_input is not None and self.per_layer_input_gate is not None:
            gate = self.per_layer_input_gate(hidden_states)
            gate = torch.nn.functional.gelu(gate, approximate="tanh")
            per_layer_contribution = self.per_layer_projection(gate * per_layer_input)
            per_layer_contribution = self.post_per_layer_input_norm(
                per_layer_contribution
            )
            hidden_states = hidden_states + per_layer_contribution

        # The layer scalar is applied unconditionally on every layer.
        hidden_states = hidden_states * self.layer_scalar
        return hidden_states, None
@support_torch_compile
class Gemma4Model(nn.Module):
    """Gemma4 text backbone: embeddings, decoder layers and the final norm.

    Optionally carries Per-Layer Embedding (PLE) tables and projections,
    enabled when ``hidden_size_per_layer_input`` is a positive number.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the backbone from the (possibly nested) HF text config."""
        super().__init__()
        config = _get_text_config(vllm_config.model_config.hf_config)
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        self.config = config
        self.quant_config = quant_config

        # PLE settings: a missing width defaults to 0 (PLE disabled); a
        # missing PLE vocab size defaults to the main vocab size.
        self.hidden_size_per_layer_input = getattr(
            config, "hidden_size_per_layer_input", 0
        )
        self.vocab_size_per_layer_input = getattr(
            config, "vocab_size_per_layer_input", config.vocab_size
        )

        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
            quant_config=quant_config,
            prefix=f"{prefix}.embed_tokens",
        )

        if (
            self.hidden_size_per_layer_input is not None
            and self.hidden_size_per_layer_input > 0
        ):
            # One packed table holds the per-layer embeddings of all layers.
            total_ple_dim = self.hidden_size_per_layer_input * config.num_hidden_layers
            self.embed_tokens_per_layer = VocabParallelEmbedding(
                self.vocab_size_per_layer_input,
                total_ple_dim,
                quant_config=quant_config,
                prefix=f"{prefix}.embed_tokens_per_layer",
            )
            # sqrt(per-layer width). Per the original author, the scale
            # constants are registered as buffers so they move with the
            # model and work with torch.compile AOT caching.
            self.register_buffer(
                "embed_scale_per_layer",
                torch.tensor(self.hidden_size_per_layer_input**0.5),
                persistent=False,
            )
            # hidden_size -> total_ple_dim, with the output gathered.
            self.per_layer_model_projection = ColumnParallelLinear(
                config.hidden_size,
                total_ple_dim,
                bias=False,
                gather_output=True,
                return_bias=False,
                quant_config=quant_config,
                prefix=f"{prefix}.per_layer_model_projection",
            )
            self.per_layer_projection_norm = RMSNorm(
                self.hidden_size_per_layer_input,
                eps=config.rms_norm_eps,
            )
            # 1/sqrt(2): used when summing projection and per-layer inputs.
            self.register_buffer(
                "per_layer_input_scale",
                torch.rsqrt(torch.tensor(2.0)),
                persistent=False,
            )
            # hidden_size ** -0.5: scales the projection output.
            self.register_buffer(
                "per_layer_projection_scale",
                torch.tensor(config.hidden_size**-0.5),
                persistent=False,
            )
        else:
            self.embed_tokens_per_layer = None
            self.embed_scale_per_layer = None
            self.per_layer_model_projection = None
            self.per_layer_projection_norm = None
            self.per_layer_input_scale = None
            self.per_layer_projection_scale = None

        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: Gemma4DecoderLayer(
                config,
                cache_config=cache_config,
                quant_config=quant_config,
                prefix=prefix,
            ),
            prefix=f"{prefix}.layers",
        )

        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        # Token-embedding scale sqrt(hidden_size), kept as a non-persistent
        # buffer.
        self.register_buffer(
            "normalizer",
            torch.tensor(config.hidden_size**0.5),
            persistent=False,
        )

        # Custom factory for pipeline-parallel placeholders: the standard
        # one cannot be used because per_layer_inputs is shaped
        # (batch, num_layers, per_layer_dim) instead of (batch, hidden_size).
        ple_dim = self.hidden_size_per_layer_input
        num_layers = config.num_hidden_layers
        hidden_size = config.hidden_size

        def _make_empty_intermediate_tensors(
            batch_size: int,
            dtype: torch.dtype,
            device: torch.device,
        ) -> IntermediateTensors:
            """Allocate zero-filled tensors for a non-first pipeline rank."""
            tensors: dict[str, torch.Tensor] = {
                "hidden_states": torch.zeros(
                    (batch_size, hidden_size),
                    dtype=dtype,
                    device=device,
                ),
                "residual": torch.zeros(
                    (batch_size, hidden_size),
                    dtype=dtype,
                    device=device,
                ),
            }
            if ple_dim and ple_dim > 0:
                tensors["per_layer_inputs"] = torch.zeros(
                    (batch_size, num_layers, ple_dim),
                    dtype=dtype,
                    device=device,
                )
            return IntermediateTensors(tensors)

        self.make_empty_intermediate_tensors = _make_empty_intermediate_tensors

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Embed token ids and multiply by the ``normalizer`` buffer."""
        return self.embed_tokens(input_ids) * self.normalizer

    def get_per_layer_inputs(self, input_ids: torch.Tensor) -> torch.Tensor | None:
        """Look up the raw per-layer embeddings for ``input_ids``.

        Returns:
            ``None`` when PLE is disabled; otherwise the scaled embeddings
            reshaped to ``(*input_ids.shape, num_hidden_layers,
            hidden_size_per_layer_input)``.
        """
        if self.embed_tokens_per_layer is None:
            return None

        # The PLE vocabulary may be smaller than the main one; ids outside
        # [0, vocab_size_per_layer_input) are mapped to id 0.
        per_layer_inputs_mask = torch.logical_and(
            input_ids >= 0,
            input_ids < self.vocab_size_per_layer_input,
        )
        per_layer_inputs_tokens = torch.where(
            per_layer_inputs_mask, input_ids, torch.zeros_like(input_ids)
        )

        # Packed lookup, scaled by sqrt(per-layer width).
        per_layer_embeds = self.embed_tokens_per_layer(per_layer_inputs_tokens)
        per_layer_embeds = per_layer_embeds * self.embed_scale_per_layer

        # Split the packed last dimension into (layer, per-layer width).
        per_layer_embeds = per_layer_embeds.reshape(
            *input_ids.shape,
            self.config.num_hidden_layers,
            self.hidden_size_per_layer_input,
        )
        return per_layer_embeds

    def project_per_layer_inputs(
        self,
        inputs_embeds: torch.Tensor,
        per_layer_inputs: torch.Tensor | None,
    ) -> torch.Tensor | None:
        """Project ``inputs_embeds`` to PLE space and merge per-layer inputs.

        Steps:
            1. Project hidden_size -> total_ple_dim.
            2. Scale by hidden_size ** -0.5.
            3. Reshape to (..., num_layers, per_layer_dim).
            4. Normalise with ``per_layer_projection_norm``.
            5. If ``per_layer_inputs`` is given, return
               (projection + per_layer_inputs) * 1/sqrt(2).

        Returns:
            ``None`` when PLE is disabled, the normalised projection alone
            when ``per_layer_inputs`` is ``None``, else the combined tensor.
        """
        if self.per_layer_model_projection is None:
            return None

        per_layer_projection = self.per_layer_model_projection(inputs_embeds)
        per_layer_projection = per_layer_projection * self.per_layer_projection_scale

        per_layer_projection = per_layer_projection.reshape(
            *inputs_embeds.shape[:-1],
            self.config.num_hidden_layers,
            self.hidden_size_per_layer_input,
        )
        per_layer_projection = self.per_layer_projection_norm(per_layer_projection)

        if per_layer_inputs is None:
            return per_layer_projection

        return (per_layer_projection + per_layer_inputs) * self.per_layer_input_scale

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None,
        inputs_embeds: torch.Tensor | None = None,
        per_layer_inputs: torch.Tensor | None = None,
        **kwargs,
    ) -> torch.Tensor | IntermediateTensors:
        """Run this pipeline rank's slice of the decoder stack.

        On the first rank the hidden states come from ``inputs_embeds`` (if
        given) or from embedding ``input_ids``; other ranks read them from
        ``intermediate_tensors``. Non-last ranks return an
        ``IntermediateTensors`` bundle; the last rank returns the
        final-normed hidden states.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
                # Raw PLE embeddings may have been pre-computed by the
                # caller (e.g. the multimodal wrapper); they still need the
                # projection step.
                per_layer_inputs = self.project_per_layer_inputs(
                    hidden_states, per_layer_inputs
                )
            else:
                hidden_states = self.embed_input_ids(input_ids)
                per_layer_embeds = self.get_per_layer_inputs(input_ids)
                per_layer_inputs = self.project_per_layer_inputs(
                    hidden_states, per_layer_embeds
                )
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]
            # NOTE(review): relies on IntermediateTensors exposing `.get`
            # — confirm against its definition.
            per_layer_inputs = intermediate_tensors.get("per_layer_inputs")

        local_layers = islice(self.layers, self.start_layer, self.end_layer)
        for actual_layer_idx, layer in enumerate(local_layers, start=self.start_layer):
            # Slice out this layer's per-layer embedding by its global
            # layer index.
            if per_layer_inputs is not None:
                layer_per_input = per_layer_inputs[:, actual_layer_idx, :]
            else:
                layer_per_input = None
            hidden_states, residual = layer(
                positions,
                hidden_states,
                residual,
                per_layer_input=layer_per_input,
                **kwargs,
            )

        if not get_pp_group().is_last_rank:
            # NOTE(review): decoder layers return None as the residual, so
            # this bundle carries residual=None (and per_layer_inputs=None
            # when PLE is off) — confirm the pipeline transport accepts
            # that.
            return IntermediateTensors(
                {
                    "hidden_states": hidden_states,
                    "residual": residual,
                    "per_layer_inputs": per_layer_inputs,
                }
            )

        # Layers fold the residual into hidden_states, so the plain norm is
        # used whenever no residual is pending.
        if residual is None:
            hidden_states = self.norm(hidden_states)
        else:
            hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this module's parameters and buffers.

        Each ``(name, tensor)`` pair is tried, in order, as: a quantization
        cache scale, a remappable k/v/q/prob scale, a shard of a fused
        parameter (qkv_proj / gate_up_proj), a per-expert MoE weight, and
        finally a direct load by name.

        Returns:
            The set of parameter/buffer names that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]

        # Per-expert names are expected in the form
        #   moe.experts.{id}.gate_proj -> w13 (shard w1)
        #   moe.experts.{id}.up_proj   -> w13 (shard w3)
        #   moe.experts.{id}.down_proj -> w2
        # i.e. bare names without a ".weight" suffix, so the mapping is
        # built here directly.
        num_experts = getattr(self.config, "num_experts", None) or 0
        expert_params_mapping = [
            # (param_name, weight_name, expert_id, shard_id)
            (
                "experts.w13_weight"
                if proj_name in ["gate_proj", "up_proj"]
                else "experts.w2_weight",
                f"experts.{expert_id}.{proj_name}",
                expert_id,
                shard_id,
            )
            for expert_id in range(num_experts)
            for shard_id, proj_name in [
                ("w1", "gate_proj"),
                ("w2", "down_proj"),
                ("w3", "up_proj"),
            ]
        ]

        # Buffers (e.g. layer_scalar) are loadable alongside parameters.
        params_dict = dict(self.named_parameters())
        params_dict.update(dict(self.named_buffers()))

        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            # 1) KV-cache scales identified by the quantization config.
            if self.quant_config is not None and (
                scale_name := self.quant_config.get_cache_scale(name)
            ):
                param = params_dict[scale_name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                loaded_weight = loaded_weight[0]
                weight_loader(param, loaded_weight)
                loaded_params.add(scale_name)
                continue

            # 2) Scale tensors whose name must be remapped first. If
            #    remapping fails, fall through to the generic handling.
            if name.endswith((".k_scale", ".v_scale", ".q_scale", ".prob_scale")):
                remapped_name = maybe_remap_kv_scale_name(name, params_dict)
                if remapped_name is not None and remapped_name in params_dict:
                    param = params_dict[remapped_name]
                    weight_loader = getattr(
                        param, "weight_loader", default_weight_loader
                    )
                    weight_loader(param, loaded_weight)
                    loaded_params.add(remapped_name)
                    continue

            # 3) Shards of fused parameters.
            for param_name, shard_name, shard_id in stacked_params_mapping:
                if shard_name not in name:
                    continue
                stacked_name = name.replace(shard_name, param_name)
                # The substring match can hit names with no fused
                # counterpart here (e.g. per-expert gate_proj/up_proj);
                # skip and let a later branch handle them.
                if stacked_name not in params_dict:
                    continue
                if is_pp_missing_parameter(stacked_name, self):
                    continue
                param = params_dict[stacked_name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                loaded_params.add(stacked_name)
                break
            else:
                # 4) Per-expert MoE weights.
                for (
                    param_name,
                    weight_name,
                    expert_id,
                    shard_id,
                ) in expert_params_mapping:
                    if weight_name not in name:
                        continue
                    moe_name = name.replace(weight_name, param_name)
                    if moe_name not in params_dict:
                        continue
                    if is_pp_missing_parameter(moe_name, self):
                        continue
                    param = params_dict[moe_name]
                    # Expert weights must arrive as 2D tensors already
                    # oriented for FusedMoE (gate/up: [I, H], down: [H, I]).
                    assert loaded_weight.dim() == 2, (
                        f"Expected 2D expert weight for {weight_name}, "
                        f"got shape {loaded_weight.shape}"
                    )
                    weight_loader = param.weight_loader
                    weight_loader(
                        param,
                        loaded_weight,
                        weight_name + ".weight",
                        shard_id=shard_id,
                        expert_id=expert_id,
                    )
                    loaded_params.add(moe_name)
                    break
                else:
                    # 5) Direct load by (possibly remapped) name.
                    if name.endswith(".bias") and name not in params_dict:
                        continue
                    name = maybe_remap_kv_scale_name(name, params_dict)
                    if name is None:
                        continue
                    if is_pp_missing_parameter(name, self):
                        continue
                    param = params_dict[name]
                    weight_loader = getattr(
                        param, "weight_loader", default_weight_loader
                    )
                    weight_loader(param, loaded_weight)
                    loaded_params.add(name)

        return loaded_params
class Gemma4ForCausalLM(nn.Module, SupportsLoRA, SupportsPP, MixtureOfExperts):
# Note: qkv_proj packing applies to every attention layer. k_eq_v layers
# have k_proj but no v_proj in the checkpoint; the weight iterator in
# load_weights duplicates k_proj as v_proj so the packed layout still holds.
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"gate_up_proj": [
"gate_proj",
"up_proj",
],
}
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
    """Assemble the backbone, LM head, logits processor and MoE metadata.

    The HF config is unwrapped to its text config first. When
    ``tie_word_embeddings`` is set, the LM head is tied to the input
    embedding table.
    """
    config = _get_text_config(vllm_config.model_config.hf_config)
    quant_config = vllm_config.quant_config
    super().__init__()
    self.config = config
    self.quant_config = quant_config

    self.model = Gemma4Model(
        vllm_config=vllm_config,
        prefix=maybe_prefix(prefix, "model"),
    )
    self.lm_head = ParallelLMHead(
        config.vocab_size,
        config.hidden_size,
        quant_config=quant_config,
        prefix=maybe_prefix(prefix, "lm_head"),
    )
    if config.tie_word_embeddings:
        self.lm_head = self.lm_head.tie_weights(self.model.embed_tokens)

    final_soft_cap = getattr(config, "final_logit_softcapping", None)
    self.logits_processor = LogitsProcessor(
        config.vocab_size,
        soft_cap=final_soft_cap,
    )
    self.make_empty_intermediate_tensors = (
        self.model.make_empty_intermediate_tensors
    )

    # --- MixtureOfExperts protocol ---
    self.expert_weights: list[list[torch.Tensor]] = []

    # Collect the MoE block of every layer that has one.
    moe_blocks = [
        layer.moe
        for layer in self.model.layers
        if isinstance(getattr(layer, "moe", None), Gemma4MoE)
    ]
    self.moe_layers: list[nn.Module] = [block.experts for block in moe_blocks]
    self.num_moe_layers = len(self.moe_layers)

    # Expert counts come from the last MoE block found; all are zero when
    # the model has no MoE layers.
    example_moe: Gemma4MoE | None = moe_blocks[-1] if moe_blocks else None
    expert_count = 0 if example_moe is None else example_moe.num_experts
    self.num_logical_experts = expert_count
    self.num_physical_experts = expert_count
    self.num_local_physical_experts = expert_count
    self.num_routed_experts = expert_count

    self.num_expert_groups = 1
    self.num_shared_experts = 0
    self.num_redundant_experts = 0
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
    """Embed token ids by delegating to the wrapped decoder model."""
    decoder = self.model
    return decoder.embed_input_ids(input_ids)
def forward(
    self,
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    intermediate_tensors: IntermediateTensors | None = None,
    inputs_embeds: torch.Tensor | None = None,
    **kwargs,
) -> torch.Tensor | IntermediateTensors:
    """Run the decoder stack and return whatever it produces.

    All arguments, including extra keyword arguments, are forwarded
    unchanged to ``self.model``.
    """
    return self.model(
        input_ids,
        positions,
        intermediate_tensors,
        inputs_embeds,
        **kwargs,
    )
def compute_logits(
self,
hidden_states: torch.Tensor,
) -> torch.Tensor | None:
return self.logits_processor(self.lm_head, hidden_states)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
    """Load checkpoint tensors into the text-only model.

    Checkpoint names are rewritten on the fly to match this module tree:
    the ``language_model.`` segment is dropped, HF MoE names are mapped
    to the internal ``moe.*`` names, packed 3D expert tensors are split
    per expert, and ``k_proj`` is duplicated as ``v_proj`` for k_eq_v
    layers. Multimodal tower/embedder weights are skipped, as is
    ``lm_head`` when embeddings are tied.

    Returns:
        The set of names reported as loaded by ``AutoWeightsLoader``.
    """

    def _remapped_weights():
        cfg = self.config
        # Layers whose checkpoint has k_proj but no v_proj: the
        # full_attention layers, and only when attention_k_eq_v is on.
        shared_kv_layers: set[int] = set()
        if getattr(cfg, "attention_k_eq_v", False):
            shared_kv_layers = {
                layer_idx
                for layer_idx, layer_type in enumerate(cfg.layer_types)
                if layer_type == "full_attention"
            }

        for ckpt_name, tensor in weights:
            # Checkpoint: model.language_model.layers.X.*
            # Our model:  model.layers.X.*
            param_name = ckpt_name.replace("language_model.", "")

            # Newer HF checkpoints keep per_expert_scale under the router
            # and call the MoE block "experts"; internally it is "moe".
            param_name = param_name.replace(
                ".router.per_expert_scale", ".moe.per_expert_scale"
            )
            if ".experts.gate_up_proj" in param_name:
                param_name = param_name.replace(
                    ".experts.gate_up_proj", ".moe.gate_up_proj"
                )
            elif ".experts.down_proj" in param_name:
                param_name = param_name.replace(
                    ".experts.down_proj", ".moe.down_proj"
                )

            # Packed expert tensors are exploded into per-expert 2D
            # weights for the FusedMoE weight_loader:
            #   moe.gate_up_proj [E, 2*I, H] -> gate_proj [I, H] (first
            #                                   half) + up_proj [I, H]
            #   moe.down_proj    [E, H, I]   -> down_proj [H, I]
            # No transpose: the checkpoint orientation already matches.
            if "moe.gate_up_proj" in param_name and tensor.dim() == 3:
                half = tensor.size(1) // 2
                for expert_idx in range(tensor.size(0)):
                    per_expert = param_name.replace(
                        "moe.", f"moe.experts.{expert_idx}."
                    )
                    yield (
                        per_expert.replace("gate_up_proj", "gate_proj"),
                        tensor[expert_idx, :half, :],
                    )
                    yield (
                        per_expert.replace("gate_up_proj", "up_proj"),
                        tensor[expert_idx, half:, :],
                    )
                continue

            if "moe.down_proj" in param_name and tensor.dim() == 3:
                for expert_idx in range(tensor.size(0)):
                    yield (
                        param_name.replace("moe.", f"moe.experts.{expert_idx}."),
                        tensor[expert_idx],
                    )
                continue

            # k_eq_v layers: QKVParallelLinear expects both K and V, so
            # emit k_proj twice (once renamed to v_proj). Sliding layers
            # have real v_proj weights and fall through untouched.
            if shared_kv_layers and "self_attn.k_proj" in param_name:
                layer_match = re.search(r"layers\.(\d+)\.", param_name)
                if layer_match and int(layer_match.group(1)) in shared_kv_layers:
                    yield param_name, tensor
                    yield param_name.replace("k_proj", "v_proj"), tensor.clone()
                    continue

            yield param_name, tensor

    # Multimodal weights belong to the multimodal wrapper; lm_head is
    # redundant when it is tied to the embeddings.
    skipped = [
        "audio_tower.",
        "vision_tower.",
        "embed_audio.",
        "embed_vision.",
    ]
    if self.config.tie_word_embeddings:
        skipped.append("lm_head.")

    return AutoWeightsLoader(self, skip_substrs=skipped).load_weights(
        _remapped_weights()
    )
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Gemma 4 multimodal model (image + audio + video support).
Adds vision tower, audio tower, and multimodal embedders on top of the
text-only Gemma4ForCausalLM. The vision/audio encoders are loaded via
AutoModel.from_config and run in eager mode while the language model uses
the vLLM-optimized path.
Video support: Gemma4 does **not** have a native video tower. Videos are
decomposed into timestamped image frames (up to 32 frames at 70 soft tokens
each) and fed through the same vision tower as regular images. The
processor inserts ``mm:ss`` timestamps between frames so the model can
reason about temporal order.
"""
import math
import sys
from collections.abc import Iterable, Mapping, Sequence
from typing import Annotated, Any, Literal
import numpy as np
import torch
from PIL import Image as PILImage
from torch import nn
from transformers import AutoModel, BatchFeature
from transformers.models.gemma4 import (
Gemma4Config,
Gemma4Processor,
Gemma4VisionConfig,
)
from transformers.models.gemma4.configuration_gemma4 import (
Gemma4AudioConfig,
Gemma4TextConfig,
)
from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions, VideoDummyOptions
from vllm.inputs import MultiModalDataDict
from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.models.gemma4 import Gemma4ForCausalLM
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (
MultiModalFieldConfig,
MultiModalKwargsItems,
VideoItem,
)
from vllm.multimodal.parse import (
AudioProcessorItems,
ImageProcessorItems,
MultiModalDataItems,
MultiModalDataParser,
)
from vllm.multimodal.processing import BaseDummyInputsBuilder
from vllm.multimodal.processing.processor import (
BaseMultiModalProcessor,
BaseProcessingInfo,
PromptReplacement,
PromptUpdate,
PromptUpdateDetails,
)
from vllm.sequence import IntermediateTensors
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .utils import (
AutoWeightsLoader,
WeightsMapper,
init_vllm_registered_model,
maybe_prefix,
)
logger = init_logger(__name__)
# Video constants — match transformers Gemma4VideoProcessor defaults.
# They drive the per-video token budget, the size of dummy videos and the
# max_soft_tokens override applied when video frames are processed as images.
_VIDEO_MAX_SOFT_TOKENS = 70  # soft tokens per video frame (vs 280 for images)
_VIDEO_MAX_FRAMES = 32  # max sampled frames per video
# ---------------------------------------------------------------------------
# Input schema
# ---------------------------------------------------------------------------
class Gemma4ImagePixelInputs(TensorSchema):
    """Schema for pre-patchified image tensors from the Gemma4 image processor.

    Dimensions:
        - bn: Batch size * number of images
        - np: Number of patches
          (max_patches = max_soft_tokens * pooling_kernel_size²)
        - pp: Patch pixels (patch_size² * 3)

    ``pixel_values`` arrives from the HF Gemma4ImageProcessor already
    patchified as (batch, max_patches, patch_pixels), zero-padded past the
    real image content. ``pixel_position_ids`` holds the (x, y) coordinate
    of every patch, using (-1, -1) for padding patches.
    """

    type: Literal["pixel_values"] = "pixel_values"
    pixel_values: Annotated[torch.Tensor, TensorShape("bn", "np", "pp")]
    pixel_position_ids: Annotated[torch.Tensor, TensorShape("bn", "np", 2)]
class Gemma4AudioInputs(TensorSchema):
    """Schema for padded audio features and their validity mask.

    Dimensions:
        - bn: Batch size * number of audios
        - s: Sequence length (MEL spectrogram frames)
        - f: Number of features (MEL bins)
    """

    type: Literal["audio"] = "audio"
    input_features_padded: Annotated[
        torch.Tensor,
        TensorShape("bn", "s", "f"),
    ]
    input_features_mask: Annotated[
        torch.Tensor,
        TensorShape("bn", "s"),
    ]
# Alias for the image input type; pixel inputs are the only image schema
# defined in this module.
Gemma4ImageInputs = Gemma4ImagePixelInputs
class Gemma4VideoInputs(TensorSchema):
    """Schema for video frames, laid out exactly like image inputs.

    Gemma4 has no dedicated video tower: frames go through the vision
    tower at a lower resolution (max_soft_tokens=70).
    """

    type: Literal["pixel_values_videos"] = "pixel_values_videos"
    pixel_values_videos: Annotated[torch.Tensor, TensorShape("bn", "np", "pp")]
    pixel_position_ids_videos: Annotated[torch.Tensor, TensorShape("bn", "np", 2)]
# ---------------------------------------------------------------------------
# Processing info
# ---------------------------------------------------------------------------
class Gemma4ProcessingInfo(BaseProcessingInfo):
def get_hf_config(self):
    """Return the model's HF config, requested as a ``Gemma4Config``."""
    context = self.ctx
    return context.get_hf_config(Gemma4Config)
def get_default_tok_params(self):
    """Return the base tokenization params with special tokens disabled.

    Gemma4's chat template already renders a literal ``<bos>``. With the
    base-class default (``add_special_tokens=True``) the tokenizer would
    prepend a second BOS, giving a ``[2, 2, ...]`` prefix the model was
    not trained on. Forcing ``add_special_tokens=False`` avoids the
    duplicate for both ``llm.generate()`` and the chat/completions API.
    """
    return super().get_default_tok_params().with_kwargs(add_special_tokens=False)
def get_hf_processor(self, **kwargs: object) -> Gemma4Processor:
    """Fetch the ``Gemma4Processor`` from the context, forwarding kwargs."""
    return self.ctx.get_hf_processor(Gemma4Processor, **kwargs)
def validate_num_items(self, modality: str, num_items: int) -> None:
    """Reject audio items for checkpoints without an audio tower.

    Args:
        modality: Modality name of the items being validated.
        num_items: Number of items of that modality in the prompt.

    Raises:
        ValueError: If audio items are supplied but the HF config has no
            ``audio_config``. Any other check is left to the base class.
    """
    if modality == "audio" and num_items > 0:
        if self.get_hf_config().audio_config is None:
            model = self.ctx.model_config.model
            raise ValueError(
                f"Audio input was provided but the model '{model}' does "
                "not have an audio tower. Audio inference is only "
                "supported for Gemma4 models that include an audio_config."
            )
    super().validate_num_items(modality, num_items)
def get_supported_mm_limits(self) -> Mapping[str, int | None]:
limits: dict[str, int | None] = {"image": None}
if self.get_hf_config().audio_config is not None:
limits["audio"] = None
limits["video"] = None
return limits
def get_mm_max_tokens_per_item(
    self, seq_len: int, mm_counts: Mapping[str, int]
) -> Mapping[str, int] | None:
    """Return the per-item token upper bound for each modality.

    ``seq_len`` and ``mm_counts`` are accepted for interface
    compatibility but are not consulted here.
    """
    hf_config = self.get_hf_config()

    # Image: the pooler emits default_output_length slots per image (280).
    # The real count after padding is stripped is <= this, but vLLM needs
    # the maximum for memory planning.
    budget: dict[str, int] = {
        "image": hf_config.vision_config.default_output_length,
    }

    # Audio: taken from the processor's audio_seq_length.
    if hf_config.audio_config is not None:
        budget["audio"] = self.get_hf_processor().audio_seq_length

    # Video: each frame <= 70 soft tokens + boi + eoi + ~6 timestamp tokens.
    tokens_per_frame = _VIDEO_MAX_SOFT_TOKENS + 2 + 6
    budget["video"] = _VIDEO_MAX_FRAMES * tokens_per_frame
    return budget
def get_data_parser(self) -> MultiModalDataParser:
    """Build the data parser; videos always need metadata.

    When the config has an ``audio_config``, the parser's target sample
    rate is set to the feature extractor's sampling rate.
    """
    parser_kwargs: dict[str, Any] = {"video_needs_metadata": True}
    audio_config = getattr(self.get_hf_config(), "audio_config", None)
    if audio_config is not None:
        feature_extractor = self.get_hf_processor().feature_extractor
        parser_kwargs["target_sr"] = feature_extractor.sampling_rate
    return MultiModalDataParser(**parser_kwargs)
def _compute_num_soft_tokens(
self,
image_width: int,
image_height: int,
max_soft_tokens: int | None = None,
) -> int:
"""Compute the number of soft tokens the vision tower produces
for an image of the given dimensions, after padding is stripped.
Args:
max_soft_tokens: Override for the vision config's
``default_output_length``. When *None*, the value from
the model config is used.
"""
vision_cfg = self.get_hf_config().vision_config
patch_size = vision_cfg.patch_size
pooling_kernel_size = vision_cfg.pooling_kernel_size
if max_soft_tokens is None:
max_soft_tokens = vision_cfg.default_output_length
unit = patch_size * pooling_kernel_size
max_patches = max_soft_tokens * pooling_kernel_size**2
num_patches_orig = (image_height / patch_size) * (image_width / patch_size)
scale = math.sqrt(max_patches / num_patches_orig)
target_h = max(unit, int(math.floor(image_height * scale / unit)) * unit)
target_w = max(unit, int(math.floor(image_width * scale / unit)) * unit)
num_patches = (target_h // patch_size) * (target_w // patch_size)
return num_patches // (pooling_kernel_size**2)
def get_image_repl(
    self,
    *,
    image_width: int,
    image_height: int,
    processor: Gemma4Processor | None,
    max_soft_tokens: int | None = None,
) -> PromptUpdateDetails[list[int]]:
    """Build the token-id replacement for a single image.

    The sequence is ``boi``, then one image token per soft token computed
    for this image size, then ``eoi``; only the image tokens are selected
    as embedding positions.

    Args:
        image_width: Image width in pixels.
        image_height: Image height in pixels.
        processor: HF processor to read token ids from; fetched from
            the context when *None*.
        max_soft_tokens: Override for the default token budget.
            When *None*, falls back to the model config value.
    """
    hf_processor = processor if processor is not None else self.get_hf_processor()
    soft_count = self._compute_num_soft_tokens(
        image_width,
        image_height,
        max_soft_tokens=max_soft_tokens,
    )
    hf_config = self.get_hf_config()
    image_token_id = hf_processor.image_token_id
    sequence = [
        hf_config.boi_token_id,
        *([image_token_id] * soft_count),
        hf_config.eoi_token_id,
    ]
    return PromptUpdateDetails.select_token_id(sequence, image_token_id)
def get_audio_repl(
    self,
    *,
    audio_len: int,
    processor: Gemma4Processor | None,
) -> PromptUpdateDetails[list[int]]:
    """Build the token-id replacement for a single audio clip.

    The soft-token count comes from the processor's
    ``_compute_audio_num_tokens`` applied to a zero waveform of
    ``audio_len`` samples (the original note describes it as
    ``ceil(duration_ms / audio_ms_per_token)``). The result is wrapped in
    ``boa`` / ``eoa``; only the audio tokens are selected.
    """
    hf_processor = processor if processor is not None else self.get_hf_processor()
    sample_rate = hf_processor.feature_extractor.sampling_rate
    # Only the length of the waveform matters here, hence the zeros.
    placeholder_wave = torch.zeros(audio_len)
    soft_count = hf_processor._compute_audio_num_tokens(
        placeholder_wave, sample_rate
    )
    hf_config = self.get_hf_config()
    audio_token_id = hf_processor.audio_token_id
    sequence = [
        hf_config.boa_token_id,
        *([audio_token_id] * soft_count),
        hf_config.eoa_token_id,
    ]
    return PromptUpdateDetails.select_token_id(sequence, audio_token_id)
def get_video_repl(
    self,
    *,
    timestamps: list[float],
    num_soft_tokens_per_frame: list[int],
    processor: Gemma4Processor,
) -> PromptUpdateDetails[list[int]]:
    """Build the full token-id replacement for one video.

    Mirrors the interleaved layout of the HF Gemma4Processor:
        mm:ss <boi><|video|>*N<eoi> mm:ss <boi><|video|>*N<eoi> ...

    Args:
        timestamps: Per-frame timestamps in seconds.
        num_soft_tokens_per_frame: Soft-token count for each frame,
            paired positionally with ``timestamps``.
        processor: HF processor supplying the video token id.
    """
    tokenizer = self.ctx.get_tokenizer()
    hf_config = self.get_hf_config()
    boi_id, eoi_id = hf_config.boi_token_id, hf_config.eoi_token_id
    video_id = processor.video_token_id

    sequence: list[int] = []
    frames = zip(timestamps, num_soft_tokens_per_frame)
    for frame_idx, (stamp, soft_count) in enumerate(frames):
        # mm:ss timestamp — matches transformers: int-truncated,
        # zero-padded. Every frame after the first gets a leading space.
        minutes, seconds = int(stamp // 60), int(stamp % 60)
        label = f"{minutes:02d}:{seconds:02d} "
        if frame_idx > 0:
            label = " " + label
        sequence += tokenizer.encode(label, add_special_tokens=False)
        sequence += [boi_id] + [video_id] * soft_count + [eoi_id]

    return PromptUpdateDetails.select_token_id(sequence, video_id)
# ---------------------------------------------------------------------------
# Dummy inputs builder
# ---------------------------------------------------------------------------
class Gemma4DummyInputsBuilder(BaseDummyInputsBuilder[Gemma4ProcessingInfo]):
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
    """Build a placeholder prompt holding one token per requested item.

    Images use the tab-prefixed image token — what the Gemma4 chat
    template inserts per image (``\\t<|image|>``); ``_get_prompt_updates``
    later expands that token into the full image sequence. Audio tokens
    are added only if the processor defines one.
    """
    hf_processor = self.info.get_hf_processor()

    pieces = ["\t" + hf_processor.image_token] * mm_counts.get("image", 0)

    audio_count = mm_counts.get("audio", 0)
    if audio_count > 0 and hf_processor.audio_token:
        pieces.append(hf_processor.audio_token * audio_count)

    video_count = mm_counts.get("video", 0)
    if video_count > 0:
        pieces.append(hf_processor.video_token * video_count)

    return "".join(pieces)
def get_dummy_mm_data(
    self,
    seq_len: int,
    mm_counts: Mapping[str, int],
    mm_options: Mapping[str, BaseDummyOptions] | None = None,
) -> MultiModalDataDict:
    """Create dummy image/audio/video data for the requested counts.

    The ``"image"`` entry is always present; audio and video entries are
    added only when their count is positive. ``seq_len`` is not used.
    """
    hf_processor = self.info.get_hf_processor()

    # Gemma4ImageProcessor sets size=None (it works from patch_size /
    # max_soft_tokens rather than a size dict), hence the `or {}` guard
    # and the 224 fallbacks.
    configured_size = getattr(hf_processor.image_processor, "size", None) or {}
    width = configured_size.get("width", 224)
    height = configured_size.get("height", 224)

    options = mm_options or {}

    dummy: MultiModalDataDict = {
        "image": self._get_dummy_images(
            width=width,
            height=height,
            num_images=mm_counts.get("image", 0),
            overrides=options.get("image"),
        ),
    }

    audio_count = mm_counts.get("audio", 0)
    if audio_count > 0:
        dummy["audio"] = self._get_dummy_audios(
            length=hf_processor.feature_extractor.fft_length,
            num_audios=audio_count,
            overrides=options.get("audio"),
        )

    video_count = mm_counts.get("video", 0)
    if video_count > 0:
        dummy["video"] = self._get_dummy_videos(
            width=width,
            height=height,
            num_frames=_VIDEO_MAX_FRAMES,
            num_videos=video_count,
            overrides=options.get("video"),
        )

    return dummy
def _get_dummy_videos(
    self,
    *,
    width: int,
    height: int,
    num_frames: int,
    num_videos: int,
    overrides: VideoDummyOptions | None = None,
) -> list[VideoItem]:
    """Create dummy videos paired with synthetic metadata.

    At least two frames are requested from the base class. Each returned
    item is ``(copied_video, metadata)`` where the metadata describes a
    2 fps clip whose frames are used as-is (no further sampling).
    """
    base_videos = super()._get_dummy_videos(
        width=width,
        height=height,
        num_frames=max(num_frames, 2),
        num_videos=num_videos,
        overrides=overrides,
    )

    items: list[VideoItem] = []
    for clip in (base.copy() for base in base_videos):
        frame_count = clip.shape[0]
        metadata = {
            "fps": 2.0,
            "duration": frame_count / 2.0,
            "total_num_frames": frame_count,
            "frames_indices": list(range(frame_count)),
            "video_backend": "opencv",
            "do_sample_frames": False,
        }
        items.append((clip, metadata))
    return items
# ---------------------------------------------------------------------------
# Multimodal processor
# ---------------------------------------------------------------------------
class Gemma4MultiModalProcessor(BaseMultiModalProcessor[Gemma4ProcessingInfo]):
def _call_hf_processor(
    self,
    prompt: str,
    mm_data: Mapping[str, object],
    mm_kwargs: Mapping[str, object],
    tok_kwargs: Mapping[str, object],
) -> BatchFeature:
    """Run the HF Gemma4 processor with video and audio special-casing.

    Videos are removed from ``mm_data``, processed frame-by-frame as
    images with ``max_soft_tokens=_VIDEO_MAX_SOFT_TOKENS``, and each
    video-token placeholder in ``prompt`` is expanded in place. The
    remaining data goes through the base-class call; afterwards the HF
    ``image_position_ids`` key is renamed and audio features are
    unpadded per item.

    Args:
        prompt: Prompt text containing the multimodal placeholders.
        mm_data: Multimodal data keyed the vLLM way (``images``,
            ``audios``, ``videos``).
        mm_kwargs: Per-call processor kwargs.
        tok_kwargs: Tokenizer kwargs forwarded to the base class.

    Returns:
        A ``BatchFeature`` combining the base-class outputs with the
        video outputs (if any).

    Raises:
        ValueError: If the resolved ``max_soft_tokens`` is not one of the
            supported budgets. Raising (rather than terminating the
            process) lets the caller reject just the offending request.
    """
    # Validate max_soft_tokens early so bad values fail before any work.
    _SUPPORTED_SOFT_TOKENS = (70, 140, 280, 560, 1120)
    merged_kwargs = self.info.ctx.get_merged_mm_kwargs(mm_kwargs)
    val = merged_kwargs.get("max_soft_tokens")
    if val is None:
        # `or {}` also covers an explicit images_kwargs=None.
        val = (merged_kwargs.get("images_kwargs") or {}).get("max_soft_tokens")
    if val is not None and val not in _SUPPORTED_SOFT_TOKENS:
        raise ValueError(
            f"Unsupported max_soft_tokens value: {val!r}. "
            f"Valid values are {_SUPPORTED_SOFT_TOKENS}."
        )

    mm_data = dict(mm_data)

    # ---- VIDEO HANDLING ----
    # Gemma4 decomposes video into timestamped image frames.
    # Each frame is processed with max_soft_tokens=70 through the
    # same vision tower, matching transformers processing_gemma4.py.
    video_outputs: dict[str, Any] = {}
    if videos := mm_data.pop("videos", []):
        processor = self.info.get_hf_processor()
        video_token = processor.video_token

        all_video_pixel_values: list[torch.Tensor] = []
        all_video_position_ids: list[torch.Tensor] = []
        video_num_soft_tokens_per_video: list[list[int]] = []
        video_timestamps_per_video: list[list[float]] = []
        video_frame_counts: list[int] = []

        # Where to resume looking for the next unexpanded placeholder.
        # Each expansion itself contains video tokens, so searching from
        # the start again would land inside the previous expansion.
        search_from = 0

        for item in videos:
            video_array, metadata = item

            # Convert frames to PIL images
            if isinstance(video_array, np.ndarray):
                frames = [
                    PILImage.fromarray(video_array[i])
                    for i in range(video_array.shape[0])
                ]
            else:
                frames = list(video_array)

            # Compute timestamps from metadata (same as transformers)
            fps = metadata.get("fps") or 24
            frame_indices = metadata.get("frames_indices", list(range(len(frames))))
            timestamps = [idx / fps for idx in frame_indices]

            # Process frames as images with max_soft_tokens=70
            video_mm_kwargs = dict(mm_kwargs)
            video_mm_kwargs["max_soft_tokens"] = _VIDEO_MAX_SOFT_TOKENS
            dummy_prompt = ("\t" + processor.image_token) * len(frames)
            frame_outputs = super()._call_hf_processor(
                prompt=dummy_prompt,
                mm_data={"images": frames},
                mm_kwargs=video_mm_kwargs,
                tok_kwargs=tok_kwargs,
            )

            # Remap HF key name
            if "image_position_ids" in frame_outputs:
                frame_outputs["pixel_position_ids"] = frame_outputs.pop(
                    "image_position_ids"
                )

            all_video_pixel_values.append(frame_outputs["pixel_values"])
            all_video_position_ids.append(frame_outputs["pixel_position_ids"])

            # Compute soft tokens per frame
            num_soft_per_frame = []
            for img in frames:
                w, h = img.size
                n = self.info._compute_num_soft_tokens(
                    w, h, max_soft_tokens=_VIDEO_MAX_SOFT_TOKENS
                )
                num_soft_per_frame.append(n)

            video_num_soft_tokens_per_video.append(num_soft_per_frame)
            video_timestamps_per_video.append(timestamps)
            video_frame_counts.append(len(frames))

            # Build the expanded replacement text for this video.
            ts_strs = [f"{int(s // 60):02d}:{int(s % 60):02d}" for s in timestamps]
            replacement = " ".join(
                f"{t} {processor.boi_token}"
                f"{video_token * n}"
                f"{processor.eoi_token}"
                for t, n in zip(ts_strs, num_soft_per_frame)
            )

            # Replace the next untouched placeholder, then skip past the
            # inserted text so the following video cannot match the video
            # tokens that were just written.
            placeholder_at = prompt.find(video_token, search_from)
            if placeholder_at != -1:
                prompt = (
                    prompt[:placeholder_at]
                    + replacement
                    + prompt[placeholder_at + len(video_token):]
                )
                search_from = placeholder_at + len(replacement)

        video_outputs = {
            "pixel_values_videos": torch.cat(all_video_pixel_values, dim=0),
            "pixel_position_ids_videos": torch.cat(all_video_position_ids, dim=0),
            "video_frame_counts": torch.tensor(video_frame_counts),
            "video_num_soft_tokens": video_num_soft_tokens_per_video,
            "video_timestamps": video_timestamps_per_video,
        }

    # The processor accepts 'audio' not 'audios'.
    if "audios" in mm_data:
        mm_data["audio"] = mm_data.pop("audios")

    # Warn if any audio waveform exceeds the model's max duration.
    if "audio" in mm_data:
        processor = self.info.get_hf_processor()
        sr = processor.feature_extractor.sampling_rate
        max_tokens = processor.audio_seq_length
        ms_per_tok = processor.audio_ms_per_token
        max_duration_s = max_tokens * ms_per_tok / 1000.0
        audios = mm_data["audio"]
        if not isinstance(audios, (list, tuple)):
            audios = [audios]
        for waveform in audios:
            duration_s = len(waveform) / sr
            if duration_s > max_duration_s:
                logger.warning(
                    "Audio duration exceeds max: %f > %f seconds",
                    duration_s,
                    max_duration_s,
                )

    # vLLM's call_hf_processor (context.py) re-merges
    # mm_processor_kwargs from the model config on every call via:
    #   config_kwargs | incoming_kwargs  (right side wins)
    #
    # If we strip max_soft_tokens from incoming, the re-merge puts
    # back the config's global default (e.g. 280), ignoring any
    # per-prompt override. Instead, we keep it in the kwargs with
    # the validated per-prompt value so it wins during the merge.
    #
    # NOTE: This requires a corresponding type annotation on the
    # HF side (Gemma4ProcessorKwargs.images_kwargs) so that
    # _merge_kwargs routes max_soft_tokens into images_kwargs.
    patched_mm_kwargs = dict(mm_kwargs)
    if val is not None:
        patched_mm_kwargs["max_soft_tokens"] = val

    processed_outputs = super()._call_hf_processor(
        prompt,
        mm_data,
        patched_mm_kwargs,
        tok_kwargs,
    )

    # HF uses 'image_position_ids'; vLLM uses 'pixel_position_ids'.
    # Remap here to keep a single translation point.
    if "image_position_ids" in processed_outputs:
        processed_outputs["pixel_position_ids"] = processed_outputs.pop(
            "image_position_ids"
        )

    if "input_features" in processed_outputs:
        # Keep padded features for batched audio tower execution.
        processed_outputs["input_features_padded"] = processed_outputs[
            "input_features"
        ]
        # Unpad per-item so each item's cache entry is self-contained.
        unpadded_features = [
            f[mask]
            for f, mask in zip(
                processed_outputs["input_features"],
                processed_outputs["input_features_mask"],
            )
        ]
        processed_outputs["input_features"] = unpadded_features

    # Merge video outputs into the final result
    combined_outputs = dict(processed_outputs, **video_outputs)
    return BatchFeature(combined_outputs)
def _get_mm_fields_config(
    self,
    hf_inputs: BatchFeature,
    hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
    """Describe how each processor output is split across items.

    Image and audio fields are batched per item. Video fields are added
    only when ``video_frame_counts`` is present: frame tensors are stored
    flat and split per video by those counts, while the per-video
    metadata is batched (token counts and timestamps stay on CPU).
    """
    field_configs = {
        "pixel_values": MultiModalFieldConfig.batched("image"),
        "pixel_position_ids": MultiModalFieldConfig.batched("image"),
        "input_features_padded": MultiModalFieldConfig.batched("audio"),
        "input_features_mask": MultiModalFieldConfig.batched("audio"),
    }

    frame_counts = hf_inputs.get("video_frame_counts")
    if frame_counts is None:
        return field_configs

    if not isinstance(frame_counts, torch.Tensor):
        frame_counts = torch.tensor(frame_counts)

    field_configs["pixel_values_videos"] = MultiModalFieldConfig.flat_from_sizes(
        "video", frame_counts
    )
    field_configs["pixel_position_ids_videos"] = (
        MultiModalFieldConfig.flat_from_sizes("video", frame_counts)
    )
    field_configs["video_frame_counts"] = MultiModalFieldConfig.batched("video")
    field_configs["video_num_soft_tokens"] = MultiModalFieldConfig.batched(
        "video", keep_on_cpu=True
    )
    field_configs["video_timestamps"] = MultiModalFieldConfig.batched(
        "video", keep_on_cpu=True
    )
    return field_configs
def _get_prompt_updates(
    self,
    mm_items: MultiModalDataItems,
    hf_processor_mm_kwargs: Mapping[str, Any],
    out_mm_kwargs: MultiModalKwargsItems,
) -> Sequence[PromptUpdate]:
    """Build the placeholder replacements for image, video and audio.

    One ``PromptReplacement`` is emitted per modality present in
    ``mm_items``; each targets that modality's single placeholder token
    and expands it to the full per-item token sequence.

    Args:
        mm_items: Parsed multimodal items of the prompt.
        hf_processor_mm_kwargs: Per-prompt processor kwargs.
        out_mm_kwargs: Processor outputs; the video replacement reads
            timestamps and per-frame token counts from here.
    """
    hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
    prompt_updates = []

    if "image" in mm_items:
        # Target image_token (<|image|>) — the single placeholder the
        # Gemma4 chat template inserts once per image in the prompt.
        # vLLM tokenizes the prompt without token expansion, so only
        # one image_token exists per image in the token stream.
        # The replacement expands it to boi + N×image_token + eoi, with
        # N computed per image from its size and the token budget.
        image_token = hf_processor.image_token

        # Resolve the effective max_soft_tokens the same way
        # _call_hf_processor does: per-prompt kwargs merged with the
        # config-level defaults, top-level key first, then
        # images_kwargs. Skipping the images_kwargs fallback would size
        # the placeholders with a different budget than the one the HF
        # processor was actually called with.
        merged_kwargs = self.info.ctx.get_merged_mm_kwargs(
            hf_processor_mm_kwargs,
        )
        max_soft_tokens = merged_kwargs.get("max_soft_tokens")
        if max_soft_tokens is None:
            max_soft_tokens = (merged_kwargs.get("images_kwargs") or {}).get(
                "max_soft_tokens"
            )

        def get_replacement_image(item_idx: int):
            images = mm_items.get_items("image", ImageProcessorItems)
            image_size = images.get_image_size(item_idx)
            return self.info.get_image_repl(
                image_width=image_size.width,
                image_height=image_size.height,
                processor=hf_processor,
                max_soft_tokens=max_soft_tokens,
            )

        prompt_updates.append(
            PromptReplacement(
                modality="image",
                target=image_token,
                replacement=get_replacement_image,
            )
        )

    if "video" in mm_items:
        video_token = hf_processor.video_token

        def get_replacement_video(item_idx: int):
            out_item = out_mm_kwargs["video"][item_idx]
            timestamps = out_item["video_timestamps"].data
            num_soft = out_item["video_num_soft_tokens"].data
            return self.info.get_video_repl(
                timestamps=timestamps,
                num_soft_tokens_per_frame=num_soft,
                processor=hf_processor,
            )

        prompt_updates.append(
            PromptReplacement(
                modality="video",
                target=video_token,
                replacement=get_replacement_video,
            )
        )

    if "audio" in mm_items:
        audio_token = hf_processor.audio_token

        def get_replacement_audio(item_idx: int):
            audios = mm_items.get_items("audio", AudioProcessorItems)
            audio_len = audios.get_audio_length(item_idx)
            return self.info.get_audio_repl(
                audio_len=audio_len,
                processor=hf_processor,
            )

        prompt_updates.append(
            PromptReplacement(
                modality="audio",
                target=audio_token,
                replacement=get_replacement_audio,
            )
        )

    return prompt_updates
# NOTE: Gemma3/Gemma3n override _apply_token_matches and
# _find_mm_placeholders to merge adjacent newline tokens that arise
# when full_image_sequence contains "\n\n" wrappers. Gemma4's
# full_image_sequence has NO newlines (just BOI + N×image_token + EOI,
# with N computed per image), so the base class implementations work
# correctly as-is.
# ---------------------------------------------------------------------------
# Multimodal embedder
# ---------------------------------------------------------------------------
class Gemma4MultimodalEmbedder(nn.Module):
    """Map vision/audio soft tokens into the language model's embedding space.

    Pipeline:
        inputs_embeds → embedding_projection → embedding_post_projection_norm

    Gemma3n uses separate hard/soft embedding paths with per-path
    normalization and a learned embedding table. Gemma4 is simpler: one
    bias-free linear projection followed by an RMSNorm without a
    learnable scale. The checkpoint agrees — it only contains
    ``embedding_projection.weight``, with no embedding table and no
    pre-projection norm weights.
    """

    def __init__(
        self,
        multimodal_config: Gemma4VisionConfig | Gemma4AudioConfig,
        text_config: Gemma4TextConfig,
    ):
        super().__init__()

        self.eps = multimodal_config.rms_norm_eps
        self.text_hidden_size = text_config.hidden_size

        # The audio tower emits output_proj_dims (1536) rather than its
        # hidden_size (1024); the vision tower emits hidden_size (768).
        tower_dim = getattr(multimodal_config, "output_proj_dims", None)
        if not tower_dim:
            tower_dim = multimodal_config.hidden_size

        self.embedding_projection = ReplicatedLinear(
            tower_dim,
            self.text_hidden_size,
            bias=False,
        )
        self.embedding_post_projection_norm = RMSNorm(
            self.text_hidden_size,
            eps=self.eps,
            has_weight=False,
        )

    def forward(self, inputs_embeds: torch.Tensor) -> torch.Tensor:
        """Project tower outputs to the LM width, then normalize them."""
        # The projection returns a pair; only the first element is used.
        projected, _ = self.embedding_projection(inputs_embeds)
        return self.embedding_post_projection_norm(projected)
# ---------------------------------------------------------------------------
# Main model
# ---------------------------------------------------------------------------
@MULTIMODAL_REGISTRY.register_processor(
Gemma4MultiModalProcessor,
info=Gemma4ProcessingInfo,
dummy_inputs=Gemma4DummyInputsBuilder,
)
class Gemma4ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
# Fused vLLM module name -> checkpoint sub-projections packed into it.
packed_modules_mapping = {
    "qkv_proj": ["q_proj", "k_proj", "v_proj"],
    "gate_up_proj": ["gate_proj", "up_proj"],
}
# Maps checkpoint prefixes to vLLM module paths.
# Multimodal towers and embedders live at the top level of this module,
# while everything belonging to the text model is re-rooted under
# ``language_model``.
# NOTE(review): the bare "model" entry looks like a catch-all for the
# remaining ``model*`` names — confirm how WeightsMapper resolves
# overlapping prefixes before reordering these entries.
hf_to_vllm_mapper = WeightsMapper(
    orig_to_new_prefix={
        "model.embed_audio.": "embed_audio.",
        "model.embed_vision.": "embed_vision.",
        "model.language_model.": "language_model.model.",
        "model.vision_tower.": "vision_tower.",
        "model.audio_tower.": "audio_tower.",
        "lm_head.": "language_model.lm_head.",
        "model": "language_model.model",
    }
)
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
    """Build the multimodal Gemma4 model.

    Creates the vision tower and its embedder, the audio tower and its
    embedder when the config has an ``audio_config`` (both are ``None``
    otherwise), and the vLLM language model. It then pre-allocates the
    per-layer-embedding buffer when the text config defines
    ``hidden_size_per_layer_input`` and copies the language model's
    MoE attributes onto this module.
    """
    super().__init__()
    config = vllm_config.model_config.hf_config
    quant_config = vllm_config.quant_config
    multimodal_config = vllm_config.model_config.multimodal_config
    self.config = config
    self.quant_config = quant_config
    self.multimodal_config = multimodal_config
    # ---- Vision tower (shared by image and video) ----
    with self._mark_tower_model(vllm_config, {"image", "video"}):
        self.vision_tower = AutoModel.from_config(config=config.vision_config)
        self.embed_vision = Gemma4MultimodalEmbedder(
            config.vision_config, config.text_config
        )
    # ---- Audio tower (variants with audio_config) ----
    if config.audio_config is not None:
        with self._mark_tower_model(vllm_config, "audio"):
            self.audio_tower = AutoModel.from_config(config=config.audio_config)
            # AutoModel.from_config does NOT call post_init(),
            # which is needed to initialize buffers that are absent
            # from the checkpoint (e.g. inv_timescales for relative
            # position embeddings, softcap, gradient_clipping).
            self.audio_tower.post_init()
            self.embed_audio = Gemma4MultimodalEmbedder(
                config.audio_config, config.text_config
            )
    else:
        self.audio_tower = None
        self.embed_audio = None
    # ---- Language model (vLLM optimised) ----
    with self._mark_language_model(vllm_config):
        self.language_model: Gemma4ForCausalLM = init_vllm_registered_model(
            vllm_config=vllm_config,
            hf_config=config.text_config,
            prefix=maybe_prefix(prefix, "language_model"),
            architectures=["Gemma4ForCausalLM"],
        )
        # Pre-allocate PLE buffer for CUDA graph compatibility.
        # Some variants have hidden_size_per_layer_input=None (no PLE).
        # The buffer has one row per batched token and takes its device
        # and dtype from the token-embedding weight.
        ple_dim = config.text_config.hidden_size_per_layer_input
        if ple_dim is not None:
            self.per_layer_embeddings = torch.zeros(
                vllm_config.scheduler_config.max_num_batched_tokens,
                config.text_config.num_hidden_layers,
                ple_dim,
                device=(self.language_model.model.embed_tokens.weight.device),
                dtype=(self.language_model.model.embed_tokens.weight.dtype),
            )
        else:
            self.per_layer_embeddings = None
    self.make_empty_intermediate_tensors = (
        self.language_model.make_empty_intermediate_tensors
    )
    # --- MixtureOfExperts delegation to language_model ---
    # Plain assignments taken at construction time: the two lists are
    # shared by reference, the counters are copied values.
    self.expert_weights = self.language_model.expert_weights
    self.moe_layers = self.language_model.moe_layers
    self.num_moe_layers = self.language_model.num_moe_layers
    self.num_logical_experts = self.language_model.num_logical_experts
    self.num_physical_experts = self.language_model.num_physical_experts
    self.num_local_physical_experts = self.language_model.num_local_physical_experts
    self.num_routed_experts = self.language_model.num_routed_experts
    self.num_expert_groups = self.language_model.num_expert_groups
    self.num_shared_experts = self.language_model.num_shared_experts
    self.num_redundant_experts = self.language_model.num_redundant_experts
# ------------------------------------------------------------------ #
# Input parsing
# ------------------------------------------------------------------ #
def _parse_and_validate_image_input(
    self, **kwargs: object
) -> Gemma4ImageInputs | None:
    """Build the image input bundle from multimodal keyword arguments.

    Reads ``pixel_values``, ``pixel_position_ids`` and ``image_embeds``
    from ``kwargs``.

    Returns:
        ``None`` when no ``pixel_values`` were supplied, otherwise a
        ``Gemma4ImagePixelInputs`` holding the pixel values and their
        position ids.

    Raises:
        AssertionError: If ``image_embeds`` is supplied; precomputed
            image embeddings are not supported.
    """
    embeds = kwargs.pop("image_embeds", None)
    values = kwargs.pop("pixel_values", None)
    position_ids = kwargs.pop("pixel_position_ids", None)

    assert embeds is None, "Gemma4 does not support image_embeds."

    if values is not None:
        return Gemma4ImagePixelInputs(
            pixel_values=values,
            pixel_position_ids=position_ids,
        )
    return None
def _parse_and_validate_audio_input(
    self, **kwargs: object
) -> Gemma4AudioInputs | None:
    """Build the audio input bundle from multimodal keyword arguments.

    Both ``input_features_padded`` and ``input_features_mask`` must be
    present; if either one is missing the audio modality is treated as
    absent.

    Returns:
        A ``Gemma4AudioInputs`` with the padded features and their mask,
        or ``None`` when either of the two entries is missing.
    """
    features = kwargs.pop("input_features_padded", None)
    # The mask is only looked up once the features are known to exist.
    mask = None if features is None else kwargs.pop("input_features_mask", None)
    if features is None or mask is None:
        return None
    return Gemma4AudioInputs(
        input_features_padded=features,
        input_features_mask=mask,
    )
def _parse_and_validate_video_input(
self, **kwargs: object
) -> dict[str, torch.Tensor] | None:
pixel_values_videos = kwargs.pop("pixel_values_videos", None)
pixel_position_ids_videos = kwargs.pop("pixel_position_ids_videos", None)
video_frame_counts = kwargs.pop("video_frame_counts", None)
if pixel_values_videos is None:
return None
return {
"pixel_values_videos": pixel_values_videos,
"pixel_position_ids_videos": pixel_position_ids_videos,
"video_frame_counts": video_frame_counts,
}
def _parse_and_validate_multimodal_inputs(
    self, **kwargs: object
) -> dict[str, Gemma4ImageInputs | Gemma4AudioInputs | Gemma4VideoInputs | None]:
    """Parse every modality present in ``kwargs``.

    The keyword arguments are scanned in order; the first key that
    identifies a modality triggers that modality's parser (called with
    the full set of keyword arguments), and each modality is parsed at
    most once.

    Returns:
        A dict mapping ``"image"``, ``"video"`` and/or ``"audio"`` to the
        parsed input (or ``None`` if the parser rejected it). Entries
        appear in the order their trigger keys occur in ``kwargs``.
    """
    # Trigger key -> (modality name, parser for that modality).
    triggers = {
        "pixel_values": ("image", self._parse_and_validate_image_input),
        "image_embeds": ("image", self._parse_and_validate_image_input),
        "pixel_values_videos": ("video", self._parse_and_validate_video_input),
        "input_features_padded": ("audio", self._parse_and_validate_audio_input),
    }

    parsed_by_modality = {}
    for key in kwargs:
        trigger = triggers.get(key)
        if trigger is None:
            continue
        modality, parse = trigger
        if modality not in parsed_by_modality:
            parsed_by_modality[modality] = parse(**kwargs)
    return parsed_by_modality
# ------------------------------------------------------------------ #
# Image processing
# ------------------------------------------------------------------ #
def _process_image_input(
    self,
    image_input: Gemma4ImageInputs,
) -> list[torch.Tensor]:
    """Encode each image and project it into the LM embedding space.

    The HF image processor outputs pre-patchified data:
        pixel_values:       (num_images, max_patches, patch_pixels)
        pixel_position_ids: (num_images, max_patches, 2)
    The vision tower's forward() is called directly; per the upstream
    notes it handles patch embedding, encoding, pooling, padding removal
    and optional standardization internally.

    Args:
        image_input: Mapping providing ``pixel_values`` and
            ``pixel_position_ids``.

    Returns:
        One projected embedding tensor per image. Images may yield
        different numbers of tokens because padding is stripped.
    """
    pixel_values = image_input["pixel_values"]
    pixel_position_ids = image_input["pixel_position_ids"]
    tower = self.vision_tower
    patches_per_token = self.config.vision_config.pooling_kernel_size**2

    # TODO: Move this per-image loop into the input processor to reduce
    # dynamism at the model runner / engine core. This requires spatially
    # padding all images to uniform (H_max, W_max) in _call_hf_processor()
    # so they arrive as a single stacked tensor, tracking padded regions
    # via image_sizes metadata, and validating numerical equivalence with
    # the current per-image path.
    #
    # Images go through the tower one at a time so that each result keeps
    # its own variable length, matching the dynamic token count from
    # get_image_repl.
    encoded_images = []
    for idx in range(pixel_values.shape[0]):
        patches = pixel_values[idx].unsqueeze(0)  # (1, max_patches, patch_pixels)
        positions = pixel_position_ids[idx].unsqueeze(0)  # (1, max_patches, 2)
        # The pooler needs input_seq_len / output_length == k**2 with
        # k == pooling_kernel_size. The total patch count includes
        # padding (the image processor allocates
        # max_patches = max_soft_tokens * k**2 slots), so dividing it by
        # k**2 recovers max_soft_tokens. Without passing this, the pooler
        # falls back to config.image_seq_length, which fails when a
        # different max_soft_tokens was used at preprocessing time.
        num_soft_tokens = patches.shape[1] // patches_per_token
        tower_out = tower(patches, positions, output_length=num_soft_tokens)
        # last_hidden_state: (num_valid_tokens, hidden_size), already
        # flat with padding stripped by the vision tower.
        encoded_images.append(tower_out.last_hidden_state)

    # Projection is also per image because token counts differ. Features
    # are cast to the projection layer's dtype first (the model may be
    # bf16 while the vision tower outputs fp32).
    projection_dtype = self.embed_vision.embedding_projection.weight.dtype
    projected_images = []
    for features in encoded_images:
        batched = features.unsqueeze(0).to(projection_dtype)
        projected_images.append(self.embed_vision(inputs_embeds=batched).squeeze(0))
    return projected_images
# ------------------------------------------------------------------ #
# Video processing (frames through vision tower)
# ------------------------------------------------------------------ #
def _process_video_input(
self,
video_input: dict[str, torch.Tensor],
) -> list[torch.Tensor]:
"""Process video frames through the vision tower.
Reuses the image processing pipeline — Gemma4 has no separate
video tower; video frames are just images at lower resolution
(max_soft_tokens=70).
Returns one concatenated embedding tensor per video (not per
frame), because vLLM treats one video as one multimodal item.
The flat_from_sizes field config groups all frames of a video
together, so embed_multimodal must return one tensor per video.
"""
pixel_values = video_input["pixel_values_videos"]
pixel_position_ids = video_input["pixel_position_ids_videos"]
frame_counts = video_input["video_frame_counts"]
vt = self.vision_tower
pooling_k2 = self.config.vision_config.pooling_kernel_size**2
target_dtype = self.embed_vision.embedding_projection.weight.dtype
# Split flat tensors into per-video chunks
if isinstance(frame_counts, torch.Tensor):
fc_list = frame_counts.tolist()
else:
fc_list = list(frame_counts)
pv_per_video = torch.split(pixel_values, fc_list, dim=0)
pp_per_video = torch.split(pixel_position_ids, fc_list, dim=0)
per_video_embeddings = []
for pv_chunk, pp_chunk in zip(pv_per_video, pp_per_video):
frame_embs = []
for i in range(pv_chunk.shape[0]):
pv = pv_chunk[i].unsqueeze(0)
pp = pp_chunk[i].unsqueeze(0)
max_patches = pv.shape[1]
output_length = max_patches // pooling_k2
vt_output = vt(pv, pp, output_length=output_length)
frame_emb = self.embed_vision(
inputs_embeds=(
vt_output.last_hidden_state.unsqueeze(0).to(target_dtype)
)
).squeeze(0)
frame_embs.append(frame_emb)
# Concatenate all frames of this video into one tensor.
per_video_embeddings.append(torch.cat(frame_embs, dim=0))
return per_video_embeddings
# ------------------------------------------------------------------ #
# Audio processing
# ------------------------------------------------------------------ #
def _process_audio_input(
    self,
    audio_input: Gemma4AudioInputs,
) -> list[torch.Tensor]:
    """Encode audio clips and project them into the LM embedding space.

    Args:
        audio_input: Mapping providing ``input_features_padded`` and
            ``input_features_mask``; dim 1 of both is squeezed away
            before they reach the audio tower.

    Returns:
        One tensor of shape ``[num_real, hidden_size]`` per batch
        element, containing only the non-padding positions.
    """
    features = audio_input["input_features_padded"].squeeze(1)
    features_mask = audio_input["input_features_mask"].squeeze(1)

    # The mask follows the standard HF convention (True=valid,
    # False=padding).
    tower_out = self.audio_tower(features, features_mask)
    # The tower may hand back either a plain (encodings, mask) tuple or
    # an output object exposing the same two pieces as attributes.
    if isinstance(tower_out, tuple):
        encodings, valid_mask = tower_out
    else:
        encodings = tower_out.last_hidden_state
        valid_mask = tower_out.attention_mask

    # Project into LM embedding space.
    projected = self.embed_audio(inputs_embeds=encodings)

    # Drop padding per batch element; valid_mask is True for real tokens.
    return [
        clip[clip_mask]  # [num_real, hidden_size]
        for clip, clip_mask in zip(projected, valid_mask, strict=True)
    ]
# ------------------------------------------------------------------ #
# MultiModalEmbeddings interface
# ------------------------------------------------------------------ #
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
    """Turn raw multimodal inputs into a flat list of embedding tensors.

    The keyword arguments are parsed per modality, then each modality
    that is present is run through its encoder. Results are appended in
    the order the modalities were parsed; modalities whose parsed input
    is ``None`` contribute nothing.
    """
    parsed_inputs = self._parse_and_validate_multimodal_inputs(**kwargs)
    encoders = {
        "image": self._process_image_input,
        "video": self._process_video_input,
        "audio": self._process_audio_input,
    }

    embeddings: list[torch.Tensor] = []
    for modality, mm_input in parsed_inputs.items():
        encode = encoders.get(modality)
        if mm_input is None or encode is None:
            continue
        embeddings.extend(encode(mm_input))
    return embeddings
def embed_input_ids(
    self,
    input_ids: torch.Tensor,
    multimodal_embeddings: MultiModalEmbeddings | None = None,
    *,
    is_multimodal: torch.Tensor | None = None,
) -> torch.Tensor:
    """Embed token ids and cache per-layer embeddings (PLE) as a side effect.

    When the PLE buffer exists, the per-layer inputs for this batch are
    computed and copied into the leading rows of
    ``self.per_layer_embeddings`` so that ``forward`` can pick them up.
    During profiling this method is not called, so the pre-allocated
    zeros are used instead.

    Args:
        input_ids: Token ids to embed.
        multimodal_embeddings: Optional multimodal embeddings to merge.
        is_multimodal: Optional mask marking multimodal token positions.

    Returns:
        Whatever the parent class's ``embed_input_ids`` returns; the
        multimodal arguments are forwarded only when both are given.
    """
    if self.per_layer_embeddings is not None:
        # Multimodal positions (image/audio) are mapped to token id 0 for
        # the PLE lookup, replicating the text-mask
        # (token_type_ids == 0) behaviour.
        ple_token_ids = input_ids
        if is_multimodal is not None:
            is_multimodal = is_multimodal.to(input_ids.device)
            ple_token_ids = torch.where(
                is_multimodal, torch.zeros_like(input_ids), input_ids
            )

        ple = self.language_model.model.get_per_layer_inputs(ple_token_ids)
        if ple is not None:
            text_config = self.config.text_config
            ple = ple.reshape(
                -1,
                text_config.num_hidden_layers,
                text_config.hidden_size_per_layer_input,
            )
            num_tokens = ple.shape[0]
            self.per_layer_embeddings[:num_tokens].copy_(ple)

    if multimodal_embeddings is None or is_multimodal is None:
        return super().embed_input_ids(input_ids)

    return super().embed_input_ids(
        input_ids,
        multimodal_embeddings=multimodal_embeddings,
        is_multimodal=is_multimodal,
    )
# ------------------------------------------------------------------ #
# Forward
# ------------------------------------------------------------------ #
def forward(
    self,
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    intermediate_tensors: IntermediateTensors | None = None,
    inputs_embeds: torch.Tensor | None = None,
    **kwargs: object,
) -> IntermediateTensors:
    """Run the wrapped language model.

    ``inputs_embeds`` is discarded whenever ``intermediate_tensors`` is
    provided. Per-layer embeddings are taken from the pre-filled
    ``self.per_layer_embeddings`` buffer, sliced to the number of rows in
    ``inputs_embeds``; they are ``None`` when the buffer does not exist
    (variants without PLE) or when there are no input embeddings.

    Returns:
        The output of ``self.language_model.model(...)``.
    """
    if intermediate_tensors is not None:
        inputs_embeds = None

    per_layer_inputs = None
    if self.per_layer_embeddings is not None and inputs_embeds is not None:
        num_tokens = inputs_embeds.shape[0]
        per_layer_inputs = self.per_layer_embeddings[:num_tokens]

    return self.language_model.model(
        input_ids,
        positions,
        per_layer_inputs=per_layer_inputs,
        intermediate_tensors=intermediate_tensors,
        inputs_embeds=inputs_embeds,
        **kwargs,
    )
def compute_logits(
self,
hidden_states: torch.Tensor,
) -> torch.Tensor | None:
return self.language_model.compute_logits(hidden_states)
# ------------------------------------------------------------------ #
# Weight loading
# ------------------------------------------------------------------ #
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
    """Load checkpoint weights through ``AutoWeightsLoader``.

    Args:
        weights: Iterable of ``(name, tensor)`` pairs from the checkpoint.

    Returns:
        The set returned by ``AutoWeightsLoader.load_weights``.
    """
    # Some checkpoints carry vestigial embed_vision.embedding and
    # embed_audio.embedding weights from the Gemma3n architecture.
    # Gemma4's MultimodalEmbedder (embedding_projection +
    # embedding_post_projection_norm only) does not use them.
    skipped_prefixes = [
        "embed_vision.embedding.",
        "embed_audio.embedding.",
    ]
    # Without an audio tower, every audio weight is skipped.
    if self.audio_tower is None:
        skipped_prefixes += ["audio_tower.", "embed_audio."]

    weights_loader = AutoWeightsLoader(
        self,
        ignore_unexpected_prefixes=skipped_prefixes,
    )
    return weights_loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
# ------------------------------------------------------------------ #
# LoRA / multimodal mapping
# ------------------------------------------------------------------ #
def get_mm_mapping(self) -> MultiModelKeys:
    """Return the module-prefix mapping for this multimodal model.

    Names the language model, the connector modules (the two multimodal
    embedders) and the tower modules (vision and audio).
    """
    connector_prefixes = ["embed_vision", "embed_audio"]
    tower_prefixes = ["vision_tower", "audio_tower"]
    return MultiModelKeys.from_string_field(
        language_model="language_model",
        connector=connector_prefixes,
        tower_model=tower_prefixes,
    )
@classmethod
def get_placeholder_str(cls, modality: str, i: int) -> str | None:
if modality == "image":
return "<image_soft_token>"
if modality == "audio":
return "<audio_soft_token>"
if modality == "video":
return "<|video|>"
raise ValueError(f"Unsupported modality: {modality}")
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved.
"""Gemma4 output parsing utilities for offline inference.
Standalone functions that parse decoded model text to extract structured
thinking content and tool calls from Gemma4 models. These are pure-Python
utilities with zero heavy dependencies — they work on raw decoded strings
from any inference backend (vLLM, HuggingFace, TGI, etc.).
Usage with vLLM offline inference::
from vllm import LLM, SamplingParams
from vllm.model_executor.models.gemma4_utils import (
parse_thinking_output,
parse_tool_calls,
)
llm = LLM(model="google/gemma-4-it")
outputs = llm.generate(prompt, SamplingParams(...))
text = tokenizer.decode(outputs[0].outputs[0].token_ids, skip_special_tokens=False)
# Extract thinking / answer (works with or without enable_thinking)
result = parse_thinking_output(text)
print(result["thinking"]) # chain-of-thought or None
print(result["answer"]) # final answer
# Extract tool calls
tool_calls = parse_tool_calls(text)
for tc in tool_calls:
print(f"{tc['name']}({tc['arguments']})")
Ported from ``transformers.models.gemma4.utils_gemma4`` so that vLLM users
do not need a transformers dependency for output parsing.
"""
import json
import regex as re
# ---- Thinking Mode Utility ----
# Thinking delimiter tokens as they appear in decoded text.
# Gemma4 uses <|channel> (start) and <channel|> (end) as thinking delimiters.
# parse_thinking_output() splits the decoded text on these two tags.
_THINKING_START_TAG = "<|channel>"
_THINKING_END_TAG = "<channel|>"
# Sentinel tokens that may appear in decoded output.
# _clean_answer() removes this turn-end marker from the end of the answer.
_TURN_END_TAG = "<turn|>"
def parse_thinking_output(text: str) -> dict[str, str | None]:
    """Split decoded Gemma4 output into chain-of-thought and final answer.

    Safe to call on **all** Gemma4 output, whether or not thinking mode
    was enabled. Three situations are handled:

    1. **End-of-thinking tag present** — everything before the first
       ``<channel|>`` is the thinking block (minus an optional
       ``<|channel>`` start tag and the ``thought\\n`` role label);
       everything after it is the answer.
    2. **No tags, spurious label** — a bare leading ``thought\\n`` that
       some Gemma4 models emit even without thinking mode is removed.
    3. **Clean output** — the text is returned unchanged.

    The answer is always cleaned of trailing sentinel tokens
    (``<turn|>``, ``<eos>``, etc.).

    Args:
        text: Decoded model output text (from ``tokenizer.decode(...)``).

    Returns:
        A dict with keys:
        - ``"thinking"``: The chain-of-thought text, or ``None`` if no
          thinking delimiters were found.
        - ``"answer"``: The final answer text.

    Example::
        >>> from vllm.model_executor.models.gemma4_utils import parse_thinking_output
        >>> output_text = tokenizer.decode(outputs[0], skip_special_tokens=False)
        >>> result = parse_thinking_output(output_text)
        >>> print(result["thinking"])  # chain-of-thought reasoning or None
        >>> print(result["answer"])  # final answer
    """
    before_end, end_tag, after_end = text.partition(_THINKING_END_TAG)

    if not end_tag:
        # No thinking delimiters: drop a stray "thought\n" label, then
        # clean trailing sentinel tokens.
        return {
            "thinking": None,
            "answer": _clean_answer(_strip_thought_label(text)),
        }

    answer = _clean_answer(after_end)

    # Keep only what follows the start tag, when there is one.
    _, start_tag, after_start = before_end.partition(_THINKING_START_TAG)
    reasoning = after_start if start_tag else before_end

    # The model writes <|channel>thought\n...<channel|>; the "thought\n"
    # role label (analogous to "user\n" in <|turn>user\n...<turn|>) is
    # not part of the reasoning itself.
    reasoning = _strip_thought_label(reasoning.strip()).strip()
    return {"thinking": reasoning, "answer": answer}
def _strip_thought_label(text: str) -> str:
"""Strip the spurious ``thought\\n`` label from the start of text.
Only strips when ``thought`` appears as the very first word followed by
a newline — preserving the word ``thought`` in any other context.
"""
if text.startswith("thought\n"):
return text[len("thought\n") :]
return text
def _clean_answer(text: str) -> str:
    """Remove trailing sentinel tokens from the answer text.

    Strips surrounding whitespace, then repeatedly removes a trailing
    ``<turn|>`` (the Gemma4 turn-end marker) or ``<eos>`` together with
    any whitespace in front of it, until the text no longer ends with a
    sentinel.

    Stripping is repeated so the result does not depend on the order in
    which the sentinels appear: a single fixed-order pass would turn
    ``"answer<turn|><eos>"`` into ``"answer<turn|>"``, leaving one
    sentinel behind.

    Args:
        text: Answer portion of the decoded model output.

    Returns:
        The answer with leading/trailing whitespace and all trailing
        sentinel tokens removed.
    """
    sentinels = (_TURN_END_TAG, "<eos>")
    text = text.strip()
    # Every iteration removes at least one sentinel, so this terminates.
    while text.endswith(sentinels):
        for sentinel in sentinels:
            text = text.removesuffix(sentinel).rstrip()
    return text
# ---- Tool Call Parsing Utility ----
#
# NOTE: For the OpenAI-compatible API server tool parser (streaming +
# non-streaming), see vllm/tool_parsers/gemma4_tool_parser.py.
# This module provides offline inference utilities for direct user import.
# Tool call delimiter tokens as they appear in decoded text.
# Standard format: <|tool_call>call:name{args}<tool_call|>
# NOTE(review): parse_tool_calls() spells the start/end tags out inline in
# its regular expressions instead of using these two constants — keep them
# in sync if the tags ever change.
_TOOL_CALL_START_TAG = "<|tool_call>"
_TOOL_CALL_END_TAG = "<tool_call|>"
# Checked by has_tool_response_tag() at the end of the model output.
_TOOL_RESPONSE_START_TAG = "<|tool_response>"
# Gemma4 escape token as it appears in decoded text.
# _parse_tool_arguments() replaces it with a plain double quote.
_ESCAPE_TOKEN = '<|"|>'
def _parse_tool_arguments(args_str: str) -> dict[str, str]:
    """Parse tool call arguments from the Gemma4 compact format.

    Three strategies are tried in order, and the first that yields
    anything wins:

    1. Replace the Gemma4 escape token with ``"`` and parse the result as
       the body of a JSON object; non-string values are converted with
       ``str()``.
    2. Extract ``key:"value"`` pairs (an optional space after the colon
       is tolerated, which also covers the ``key: "value"`` form some
       chat templates produce).
    3. Extract unquoted ``key:value`` pairs from the raw string.

    Args:
        args_str: Raw argument string from inside ``call:name{...}``.

    Returns:
        Dictionary of argument name → value; empty for blank input.
    """
    if not args_str or not args_str.strip():
        return {}

    # Replace Gemma4 escape tokens with standard quotes.
    normalized = args_str.replace(_ESCAPE_TOKEN, '"')

    # Strategy 1: JSON (handles nested values, arrays, etc.).
    try:
        decoded = json.loads("{" + normalized + "}")
    except ValueError:  # json.JSONDecodeError is a ValueError subclass
        pass
    else:
        # Ensure all values are strings for consistency.
        return {
            name: value if isinstance(value, str) else str(value)
            for name, value in decoded.items()
        }

    # Strategy 2: key:"value" pairs (allow optional space after colon).
    quoted_pairs = re.findall(r'(\w+):\s*"([^"]*)"', normalized)
    if quoted_pairs:
        return dict(quoted_pairs)

    # Strategy 3 (last resort): unquoted key:value pairs.
    bare_pairs = re.findall(r"(\w+):\s*([^,}]+)", args_str)
    return {
        name: raw.strip().strip('"').replace(_ESCAPE_TOKEN, "")
        for name, raw in bare_pairs
    }
def parse_tool_calls(text: str, *, strict: bool = False) -> list[dict]:
    """Parse tool calls from decoded Gemma4 model output.

    Gemma4 models may emit non-standard tool call formats, so parsing is
    tiered:

    1. **Standard**: ``<|tool_call>call:name{args}<tool_call|>``
       (special token IDs 48/49 in decoded text). ``<turn|>`` is also
       accepted as the closing tag because some Gemma4 models emit it
       instead of ``<tool_call|>``.
    2. **Fallback** (only when ``strict=False`` and tier 1 found
       nothing): bare ``call:name{args}`` patterns, including
       ``<call>name{args}`` (fragmented tokens from multimodal inputs).

    Args:
        text: Decoded model output text (from ``tokenizer.decode(...,
            skip_special_tokens=False)``).
        strict: If ``True``, only match the standard ``<|tool_call>``
            format. If ``False`` (default), also try fallback patterns
            for known Gemma4 output variations.

    Returns:
        A list of dicts, each with keys:
        - ``"name"``: The tool function name (e.g. ``"get_weather"``).
        - ``"arguments"``: A dict of argument name → value.

    Example::
        >>> from vllm.model_executor.models.gemma4_utils import (
        ...     parse_tool_calls
        ... )
        >>> output = tokenizer.decode(outputs[0], skip_special_tokens=False)
        >>> tool_calls = parse_tool_calls(output)
        >>> for tc in tool_calls:
        ...     print(f"Call: {tc['name']}({tc['arguments']})")
    """

    def collect(pattern: str) -> list[dict]:
        # Group 1 is the function name, group 2 the raw argument string.
        return [
            {
                "name": found.group(1),
                "arguments": _parse_tool_arguments(found.group(2)),
            }
            for found in re.finditer(pattern, text, re.DOTALL)
        ]

    # Tier 1: standard format with special tokens.
    standard_calls = collect(
        r"<\|tool_call\>call:(\w+)\{(.*?)\}(?:<tool_call\|>|<turn\|>)"
    )
    if standard_calls or strict:
        return standard_calls

    # Tier 2: <call>name{args}, call:name{args}, or bare
    # call:name{args}<eos>.
    return collect(r"(?:<call>|(?:^|\s)call:)(\w+)\{(.*?)\}")
def has_tool_response_tag(text: str) -> bool:
    """Tell whether the model output ends with ``<|tool_response>``.

    Some Gemma4 models sometimes emit ``<eos>`` instead of
    ``<|tool_response>`` after a tool call. Callers can use this check
    to decide whether ``<|tool_response>`` has to be injected into the
    next prompt. Trailing whitespace is ignored.

    Args:
        text: Decoded model output text.

    Returns:
        ``True`` if the output ends with ``<|tool_response>``
        (proper behavior), ``False`` otherwise.

    Example::
        >>> from vllm.model_executor.models.gemma4_utils import (
        ...     has_tool_response_tag
        ... )
        >>> if not has_tool_response_tag(model_output):
        ...     # Model used <eos> instead — inject <|tool_response> manually
        ...     next_prompt = "<|tool_response>" + tool_result
    """
    return text.rstrip().endswith(_TOOL_RESPONSE_START_TAG)
......@@ -111,6 +111,7 @@ _TEXT_GENERATION_MODELS = {
"Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"),
"Gemma3ForCausalLM": ("gemma3", "Gemma3ForCausalLM"),
"Gemma3nForCausalLM": ("gemma3n", "Gemma3nForCausalLM"),
"Gemma4ForCausalLM": ("gemma4", "Gemma4ForCausalLM"),
"Qwen3NextForCausalLM": ("qwen3_next", "Qwen3NextForCausalLM"),
"GlmForCausalLM": ("glm", "GlmForCausalLM"),
"Glm4ForCausalLM": ("glm4", "Glm4ForCausalLM"),
......@@ -381,6 +382,7 @@ _MULTIMODAL_MODELS = {
"gemma3n_mm",
"Gemma3nForConditionalGeneration",
),
"Gemma4ForConditionalGeneration": ("gemma4_mm", "Gemma4ForConditionalGeneration"),
"GlmAsrForConditionalGeneration": ("glmasr", "GlmAsrForConditionalGeneration"),
"GLM4VForCausalLM": ("glm4v", "GLM4VForCausalLM"),
"Glm4vForConditionalGeneration": ("glm4_1v", "Glm4vForConditionalGeneration"),
......
......@@ -233,8 +233,15 @@ class AutoWeightsLoader:
):
"""
Add tensor names that are not in the model params that may be in the
safetensors, e.g., batch normalization stats.
safetensors, e.g., batch normalization stats and registered buffers.
"""
# Add persistent registered buffers.
# Non-persistent buffers are excluded, matching PyTorch state_dict().
non_persistent = getattr(module, "_non_persistent_buffers_set", set())
for buf_name, buf in module.named_buffers(recurse=False):
if buf_name not in child_params and buf_name not in non_persistent:
child_params[buf_name] = buf
if isinstance(
module,
(
......
......@@ -32,6 +32,10 @@ _REASONING_PARSERS_TO_REGISTER = {
"ernie45_reasoning_parser",
"Ernie45ReasoningParser",
),
"gemma4": (
"gemma4_reasoning_parser",
"Gemma4ReasoningParser",
),
"glm45": (
"deepseek_v3_reasoning_parser",
"DeepSeekV3ReasoningWithThinkingParser",
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Sequence
from typing import TYPE_CHECKING
from vllm.entrypoints.openai.engine.protocol import DeltaMessage
from vllm.reasoning.basic_parsers import BaseThinkingReasoningParser
from vllm.tokenizers import TokenizerLike
if TYPE_CHECKING:
from vllm.entrypoints.openai.chat_completion.protocol import (
ChatCompletionRequest,
)
from vllm.entrypoints.openai.responses.protocol import ResponsesRequest
# Role label that Gemma4 emits at the start of the thinking channel.
# The model generates: <|channel>thought\n...reasoning...<channel|>
# This prefix must be stripped to expose only the actual reasoning content.
# Used by both the non-streaming path (_strip_thought_label) and the
# streaming path (Gemma4ReasoningParser.extract_reasoning_streaming).
_THOUGHT_PREFIX = "thought\n"
class Gemma4ReasoningParser(BaseThinkingReasoningParser):
    """
    Reasoning parser for Google Gemma4 thinking models.

    Gemma4 uses <|channel>...<channel|> tokens to delimit reasoning/thinking
    content within its output. Thinking mode is activated by passing
    ``enable_thinking=True`` in the chat template kwargs, which injects a
    system turn containing <|think|> (token 98) to trigger chain-of-thought
    reasoning.

    Output pattern when thinking is enabled::

        <|channel>thought
        ...chain of thought reasoning...<channel|>
        Final answer text here.

    The ``thought\\n`` role label inside the channel delimiters is a
    structural artefact (analogous to ``user\\n`` in ``<|turn>user\\n...``).
    This parser strips it so that downstream consumers see only the
    actual reasoning text, consistent with the offline parser
    (``vllm.reasoning.gemma4_utils._strip_thought_label``).
    """

    def __init__(self, tokenizer: TokenizerLike, *args, **kwargs):
        super().__init__(tokenizer, *args, **kwargs)
        # Instance state for streaming prefix stripping.
        # Tracks only the reasoning text received from the base parser,
        # independent of current_text (which may contain pre-reasoning
        # content and lacks special token text due to
        # skip_special_tokens=True).
        self._reasoning_text: str = ""
        # True once the label question is settled (label removed, or the
        # reasoning turned out not to start with it).
        self._prefix_stripped: bool = False

    @property
    def start_token(self) -> str:
        """The token that starts reasoning content."""
        return "<|channel>"

    @property
    def end_token(self) -> str:
        """The token that ends reasoning content."""
        return "<channel|>"

    # ------------------------------------------------------------------
    # Non-streaming path
    # ------------------------------------------------------------------
    def extract_reasoning(
        self,
        model_output: str,
        request: "ChatCompletionRequest | ResponsesRequest",
    ) -> tuple[str | None, str | None]:
        """Extract reasoning, stripping the ``thought\\n`` role label.

        Returns:
            ``(reasoning, content)``. When neither delimiter occurs in
            ``model_output`` the whole output is returned as content and
            reasoning is ``None``.
        """
        if self.start_token not in model_output and self.end_token not in model_output:
            # Default to content history if no tags are present
            # (or if they were stripped)
            return None, model_output
        reasoning, content = super().extract_reasoning(model_output, request)
        if reasoning is not None:
            reasoning = _strip_thought_label(reasoning)
        return reasoning, content

    # ------------------------------------------------------------------
    # Streaming path
    # ------------------------------------------------------------------
    @staticmethod
    def _content_only(result: DeltaMessage) -> DeltaMessage | None:
        """Suppress the reasoning part of ``result`` but keep its content.

        Returns ``None`` when the delta carries no content, so that a
        delta consisting purely of label text is not emitted at all.
        """
        if result.content is None:
            return None
        result.reasoning = None
        return result

    def extract_reasoning_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
    ) -> DeltaMessage | None:
        """Extract streaming reasoning, stripping ``thought\\n`` from the
        first reasoning delta(s).

        The ``thought\\n`` prefix may arrive as a single delta or split
        across multiple deltas (e.g. ``"thought"`` then ``"\\n"``). Early
        reasoning text is buffered until it is clear whether the prefix
        is present, then the buffered text is emitted minus the prefix.

        Accumulated reasoning is tracked in instance state
        (``_reasoning_text``) rather than reconstructed from
        ``current_text``: with ``skip_special_tokens=True`` (the vLLM
        default) the ``<|channel>`` delimiter is invisible in
        ``current_text``, so pre-reasoning content cannot be separated
        from reasoning content via string matching.

        Suppressing or buffering the label never discards other text:
        ``content`` carried by the same delta is still emitted, and
        buffered reasoning that turns out not to be the label is flushed
        as soon as the reasoning block is known to have ended.

        Returns:
            The (possibly modified) delta from the base parser, or
            ``None`` when there is nothing to emit for this delta.
        """
        result = super().extract_reasoning_streaming(
            previous_text,
            current_text,
            delta_text,
            previous_token_ids,
            current_token_ids,
            delta_token_ids,
        )
        if result is None:
            return None

        if result.reasoning is None:
            # A content delta while label buffering is still unresolved
            # means the reasoning block ended on a strict prefix of the
            # label (e.g. the whole reasoning was "thought"). Flush the
            # buffered text here; otherwise it would be lost for good.
            if (
                not self._prefix_stripped
                and self._reasoning_text
                and result.content is not None
            ):
                self._prefix_stripped = True
                result.reasoning = self._reasoning_text
            return result

        # Accumulate ONLY the reasoning text from base parser results.
        # This is immune to pre-reasoning content pollution.
        self._reasoning_text += result.reasoning

        # Once the prefix has been handled, all subsequent reasoning
        # deltas pass through unchanged.
        if self._prefix_stripped:
            return result

        # ---- Prefix stripping logic ----
        # Case 1: Enough text has accumulated to confirm the label is
        # present. The label question is settled from here on.
        if self._reasoning_text.startswith(_THOUGHT_PREFIX):
            self._prefix_stripped = True
            # How much reasoning was accumulated before this delta?
            seen_before = len(self._reasoning_text) - len(result.reasoning)
            # Number of leading characters of this delta that still
            # belong to the label.
            label_chars_in_delta = max(len(_THOUGHT_PREFIX) - seen_before, 0)
            remainder = result.reasoning[label_chars_in_delta:]
            if remainder:
                result.reasoning = remainder
                return result
            # This entire delta was label text: emit no reasoning, but
            # do not drop content that arrived in the same delta.
            return self._content_only(result)

        # Case 2: Accumulated text is a strict prefix of
        # _THOUGHT_PREFIX (e.g. only "thou" so far). It cannot yet be
        # decided whether this becomes the full label or diverges.
        if _THOUGHT_PREFIX.startswith(self._reasoning_text):
            if result.content is None:
                # Reasoning may still continue: keep buffering.
                return None
            # Content in the same delta means the reasoning block has
            # ended, so the buffered text can no longer grow into the
            # label. Emit it as reasoning together with the content.
            self._prefix_stripped = True
            result.reasoning = self._reasoning_text or None
            return result

        # Case 3: Accumulated text doesn't match the thought prefix
        # at all. Prior deltas may have been buffered (suppressed by
        # Case 2) before the text diverged, so re-emit the full
        # accumulated text to avoid data loss.
        self._prefix_stripped = True
        result.reasoning = self._reasoning_text
        return result
def _strip_thought_label(text: str) -> str:
    """Drop the ``thought\\n`` role label from the start of ``text``.

    Text that does not begin with the label is returned unchanged.
    Mirrors ``vllm.reasoning.gemma4_utils._strip_thought_label`` from the
    offline parser.
    """
    return text.removeprefix(_THOUGHT_PREFIX)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved.
"""Gemma4 thinking/reasoning output parsing utilities for offline inference.
Standalone functions that parse decoded model text to extract structured
thinking content from Gemma4 models. These are pure-Python utilities with
zero heavy dependencies — they work on raw decoded strings from any
inference backend (vLLM, HuggingFace, TGI, etc.).
For the OpenAI-compatible API reasoning parser (streaming +
non-streaming), see ``vllm.reasoning.gemma4_reasoning_parser``.
For tool call parsing, see ``vllm.tool_parsers.gemma4_utils``.
Usage with vLLM offline inference::
from vllm import LLM, SamplingParams
from vllm.reasoning.gemma4_utils import parse_thinking_output
llm = LLM(model="google/gemma-4-it")
outputs = llm.generate(prompt, SamplingParams(...))
text = tokenizer.decode(outputs[0].outputs[0].token_ids, skip_special_tokens=False)
# Extract thinking / answer (works with or without enable_thinking)
result = parse_thinking_output(text)
print(result["thinking"]) # chain-of-thought or None
print(result["answer"]) # final answer
Ported from ``transformers.models.gemma4.utils_gemma4`` so that vLLM users
do not need a transformers dependency for output parsing.
"""
# ---- Thinking Mode Utility ----
# Thinking delimiter tokens as they appear in decoded text.
# Gemma4 uses <|channel> (start) and <channel|> (end) as thinking delimiters.
# parse_thinking_output() splits the decoded text on these two tags.
_THINKING_START_TAG = "<|channel>"
_THINKING_END_TAG = "<channel|>"
# Sentinel tokens that may appear in decoded output.
# "<turn|>" is the tag that closes a turn (cf. <|turn>user\n...<turn|>).
_TURN_END_TAG = "<turn|>"
def parse_thinking_output(text: str) -> dict[str, str | None]:
    """Split decoded Gemma4 output into chain-of-thought and final answer.

    Safe to call on **all** Gemma4 output, whether or not thinking mode was
    enabled. Three situations are handled:

    1. **End tag present** — everything before the first ``<channel|>`` is
       the thinking block (with an optional leading ``<|channel>`` and the
       ``thought\\n`` role label removed); everything after it is the answer.
    2. **No tags, spurious label** — a bare ``thought\\n`` prefix at the
       start of the text is removed.
    3. **Clean output** — the text is returned as the answer.

    In every case the answer is passed through ``_clean_answer`` to remove
    trailing sentinel tokens.

    Args:
        text: Decoded model output text (from ``tokenizer.decode(...)``).

    Returns:
        A dict with keys:
        - ``"thinking"``: The chain-of-thought text, or ``None`` if no
          ``<channel|>`` end tag was found.
        - ``"answer"``: The final answer text.

    Example::

        >>> from vllm.reasoning.gemma4_utils import parse_thinking_output
        >>> output_text = tokenizer.decode(outputs[0], skip_special_tokens=False)
        >>> result = parse_thinking_output(output_text)
        >>> print(result["thinking"])  # chain-of-thought reasoning or None
        >>> print(result["answer"])  # final answer
    """
    head, end_tag, tail = text.partition(_THINKING_END_TAG)

    if not end_tag:
        # No thinking delimiters: only the spurious role label and the
        # trailing sentinels need to go.
        return {
            "thinking": None,
            "answer": _clean_answer(_strip_thought_label(text)),
        }

    # The start tag is optional; when present, keep only what follows it.
    _, start_tag, after_start = head.partition(_THINKING_START_TAG)
    reasoning = after_start if start_tag else head

    # Remove the "thought\n" channel role label emitted inside
    # <|channel>thought\n...<channel|>, then tidy surrounding whitespace.
    reasoning = _strip_thought_label(reasoning.strip()).strip()

    return {"thinking": reasoning, "answer": _clean_answer(tail)}
def _strip_thought_label(text: str) -> str:
    """Remove a leading ``thought\\n`` role label, if there is one.

    The label is removed only when ``thought`` is the very first word and is
    immediately followed by a newline; the word ``thought`` anywhere else in
    the text is preserved.
    """
    return text.removeprefix("thought\n")
def _clean_answer(text: str) -> str:
    """Remove trailing sentinel tokens and surrounding whitespace.

    Strips ``<turn|>`` and ``<eos>`` from the end of ``text``, in whatever
    order and however many times they appear (for example ``<eos><turn|>``,
    ``<turn|><eos>`` or ``<turn|>\\n<eos>``), together with the whitespace
    around them.

    Args:
        text: Answer text as decoded from the model.

    Returns:
        The text with leading/trailing whitespace and all trailing sentinel
        tokens removed. Sentinels that are not at the end are left in place.
    """
    sentinels = (_TURN_END_TAG, "<eos>")
    text = text.strip()

    # Keep peeling sentinels until a full pass removes nothing; a single
    # fixed-order pass would leave "<turn|>" behind in "...<turn|><eos>".
    removed = True
    while removed:
        removed = False
        for tag in sentinels:
            if tag and text.endswith(tag):
                text = text[: -len(tag)].rstrip()
                removed = True
    return text
......@@ -154,6 +154,10 @@ _TOOL_PARSERS_TO_REGISTER = {
"functiongemma_tool_parser",
"FunctionGemmaToolParser",
),
"gemma4": (
"gemma4_tool_parser",
"Gemma4ToolParser",
),
}
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Tool call parser for Google Gemma4 models.
Gemma4 uses a custom serialization format (not JSON) for tool calls::
<|tool_call>call:func_name{key:<|"|>value<|"|>,num:42}<tool_call|>
Strings are delimited by ``<|"|>`` (token 52), keys are unquoted, and
multiple tool calls are concatenated without separators.
Used when ``--enable-auto-tool-choice --tool-call-parser gemma4`` are set.
For offline inference tool call parsing (direct ``tokenizer.decode()`` output),
see ``vllm.tool_parsers.gemma4_utils.parse_tool_calls``.
"""
import json
from collections.abc import Sequence
import regex as re
from vllm.entrypoints.chat_utils import make_tool_call_id
from vllm.entrypoints.openai.chat_completion.protocol import (
ChatCompletionRequest,
)
from vllm.entrypoints.openai.engine.protocol import (
DeltaFunctionCall,
DeltaMessage,
DeltaToolCall,
ExtractedToolCallInformation,
FunctionCall,
ToolCall,
)
from vllm.entrypoints.openai.responses.protocol import (
ResponsesRequest,
)
from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers.abstract_tool_parser import ToolParser
from vllm.tool_parsers.utils import find_common_prefix
logger = init_logger(__name__)
# Gemma4 special tokens for tool calls
TOOL_CALL_START = "<|tool_call>"
TOOL_CALL_END = "<tool_call|>"
STRING_DELIM = '<|"|>'

# Characters skipped between entries, and after the ':' of a key.
_ENTRY_SEPARATORS = (" ", ",", "\n", "\t")
_VALUE_WHITESPACE = (" ", "\n", "\t")


# ---------------------------------------------------------------------------
# Gemma4 argument parser (used by both streaming and non-streaming paths)
# ---------------------------------------------------------------------------


def _parse_gemma4_value(value_str: str) -> object:
    """Parse a single bare Gemma4 value (after ``key:``) into a Python object.

    Returns ``True``/``False`` for the literals ``true``/``false``, an
    ``int`` or ``float`` when the text parses as one (``float`` is tried only
    when the text contains a ``.``), and otherwise the stripped text itself.
    An empty or whitespace-only input yields the empty string.
    """
    value_str = value_str.strip()
    if not value_str:
        return value_str

    # Boolean
    if value_str == "true":
        return True
    if value_str == "false":
        return False

    # Number (int or float)
    try:
        if "." in value_str:
            return float(value_str)
        return int(value_str)
    except ValueError:
        pass

    # Bare string (no <|"|> delimiters — shouldn't happen but be safe)
    return value_str


def _extract_bracketed(
    text: str, start: int, opener: str, closer: str
) -> tuple[str, int]:
    """Return the content of the bracketed region opening at ``text[start]``.

    ``text[start]`` must be ``opener``. Nested ``opener``/``closer`` pairs are
    tracked, and anything between two ``STRING_DELIM`` markers is skipped so
    brackets inside string values are not counted.

    Returns:
        ``(inner, next_index)`` where ``inner`` is the text strictly between
        the opener and its matching closer, and ``next_index`` is the index
        just past that closer. If the region is unterminated (as happens
        with partial streaming input), ``inner`` is everything after the
        opener and ``next_index`` is ``len(text)``.
    """
    n = len(text)
    delim_len = len(STRING_DELIM)
    depth = 1
    i = start + 1
    while i < n:
        if text.startswith(STRING_DELIM, i):
            closing = text.find(STRING_DELIM, i + delim_len)
            if closing == -1:
                # Unterminated string swallows the rest of the input.
                break
            i = closing + delim_len
            continue
        ch = text[i]
        if ch == opener:
            depth += 1
        elif ch == closer:
            depth -= 1
            if depth == 0:
                return text[start + 1 : i], i + 1
        i += 1
    # No matching closer: keep every remaining character.
    return text[start + 1 :], n


def _parse_gemma4_args(args_str: str) -> dict:
    """Parse Gemma4's custom key:value format into a Python dict.

    Format examples::

        location:<|"|>Tokyo<|"|>
        location:<|"|>San Francisco<|"|>,unit:<|"|>celsius<|"|>
        count:42,flag:true
        nested:{inner_key:<|"|>val<|"|>}
        items:[<|"|>a<|"|>,<|"|>b<|"|>]

    The parser is tolerant of truncated input so it can be used on partial
    text while streaming: an unterminated string, object or array takes the
    rest of the input as its content, and a key with no value maps to ``""``.

    Args:
        args_str: The text found between the outer ``{`` and ``}`` of a call.

    Returns:
        A dict ready for ``json.dumps()``.
    """
    if not args_str or not args_str.strip():
        return {}

    result: dict = {}
    delim_len = len(STRING_DELIM)
    n = len(args_str)
    i = 0

    while i < n:
        # Skip whitespace and commas between entries
        while i < n and args_str[i] in _ENTRY_SEPARATORS:
            i += 1
        if i >= n:
            break

        # Parse key (unquoted, ends at ':'); a trailing fragment without a
        # colon is an incomplete key and is ignored.
        colon = args_str.find(":", i)
        if colon == -1:
            break
        key = args_str[i:colon].strip()
        i = colon + 1

        # Skip whitespace after ':'
        while i < n and args_str[i] in _VALUE_WHITESPACE:
            i += 1
        if i >= n:
            result[key] = ""
            break

        if args_str.startswith(STRING_DELIM, i):
            # String value: <|"|>...<|"|>
            val_start = i + delim_len
            end_pos = args_str.find(STRING_DELIM, val_start)
            if end_pos == -1:
                # Unterminated string — take rest
                result[key] = args_str[val_start:]
                break
            result[key] = args_str[val_start:end_pos]
            i = end_pos + delim_len
        elif args_str[i] == "{":
            # Nested object: {...}
            inner, i = _extract_bracketed(args_str, i, "{", "}")
            result[key] = _parse_gemma4_args(inner)
        elif args_str[i] == "[":
            # Array: [...]
            inner, i = _extract_bracketed(args_str, i, "[", "]")
            result[key] = _parse_gemma4_array(inner)
        else:
            # Bare value (number, boolean, etc.)
            val_start = i
            while i < n and args_str[i] not in (",", "}", "]"):
                i += 1
            result[key] = _parse_gemma4_value(args_str[val_start:i])

    return result


def _parse_gemma4_array(arr_str: str) -> list:
    """Parse a Gemma4 array content string into a Python list.

    Args:
        arr_str: The text found between the ``[`` and ``]`` of an array.

    Returns:
        A list whose elements are strings, nested dicts/lists, or values
        produced by ``_parse_gemma4_value``.
    """
    items: list = []
    delim_len = len(STRING_DELIM)
    n = len(arr_str)
    i = 0

    while i < n:
        while i < n and arr_str[i] in _ENTRY_SEPARATORS:
            i += 1
        if i >= n:
            break

        if arr_str.startswith(STRING_DELIM, i):
            # String element
            val_start = i + delim_len
            end_pos = arr_str.find(STRING_DELIM, val_start)
            if end_pos == -1:
                items.append(arr_str[val_start:])
                break
            items.append(arr_str[val_start:end_pos])
            i = end_pos + delim_len
        elif arr_str[i] == "{":
            # Nested object
            inner, i = _extract_bracketed(arr_str, i, "{", "}")
            items.append(_parse_gemma4_args(inner))
        elif arr_str[i] == "[":
            # Nested array
            inner, i = _extract_bracketed(arr_str, i, "[", "]")
            items.append(_parse_gemma4_array(inner))
        else:
            # Bare value
            val_start = i
            while i < n and arr_str[i] not in (",", "]"):
                i += 1
            if i == val_start:
                # Stray ']' with nothing before it: skip it, otherwise the
                # scan would never advance and the loop would not terminate.
                i += 1
                continue
            items.append(_parse_gemma4_value(arr_str[val_start:i]))

    return items
# ---------------------------------------------------------------------------
# Parser
# ---------------------------------------------------------------------------
class Gemma4ToolParser(ToolParser):
    """
    Tool call parser for Google Gemma4 models.

    Handles the Gemma4 function call format::

        <|tool_call>call:func_name{key:<|"|>value<|"|>}<tool_call|>

    Used when ``--enable-auto-tool-choice --tool-call-parser gemma4``
    are set.

    Streaming strategy: **accumulate-then-parse-then-diff**

    Instead of trying to convert Gemma4's custom format to JSON
    token-by-token (which fails because Gemma4 uses bare keys, custom
    delimiters, and structural braces that differ from JSON), this parser:

    1. Accumulates the raw Gemma4 argument string during streaming
    2. Parses it with ``_parse_gemma4_args()`` into a Python dict
    3. Converts to JSON with ``json.dumps()``
    4. Diffs against the previously-streamed JSON string
    5. Emits only the new JSON fragment as the delta

    This follows the same pattern used by FunctionGemma, Hermes, and Llama
    tool parsers.
    """

    def __init__(self, tokenizer: TokenizerLike):
        """Set up token strings/IDs, the extraction regex and stream state.

        Raises:
            ValueError: If no model tokenizer is available.
            RuntimeError: If the tokenizer vocabulary has no entry for the
                ``<|tool_call>`` start token.
        """
        super().__init__(tokenizer)

        if not self.model_tokenizer:
            raise ValueError(
                "The model tokenizer must be passed to the ToolParser "
                "constructor during construction."
            )

        # Token strings
        self.tool_call_start_token = TOOL_CALL_START
        self.tool_call_end_token = TOOL_CALL_END

        # Token IDs
        self.tool_call_start_token_id = self.vocab.get(TOOL_CALL_START)
        self.tool_call_end_token_id = self.vocab.get(TOOL_CALL_END)

        if self.tool_call_start_token_id is None:
            raise RuntimeError(
                "Gemma4 ToolParser could not locate the tool call start "
                f"token '{TOOL_CALL_START}' in the tokenizer!"
            )

        # Regex for non-streaming: extract complete tool calls.
        # Supports function names with letters, digits, underscores,
        # hyphens, and dots (e.g. "get-weather", "module.func").
        self.tool_call_regex = re.compile(
            r"<\|tool_call>call:([\w\-\.]+)\{(.*?)\}<tool_call\|>",
            re.DOTALL,
        )

        # Streaming state, including the delta buffer used for multi-token
        # special sequences.
        self._reset_streaming_state()

    def _reset_streaming_state(self) -> None:
        """Reset all streaming state for a new request."""
        self.current_tool_id = -1
        self.current_tool_name_sent = False
        self.prev_tool_call_arr: list[dict] = []
        self.streamed_args_for_tool: list[str] = []
        # Text withheld by _buffer_delta_text() because it may be the start
        # of a special token; it belongs to the stream, so reset it as well.
        self.buffered_delta_text = ""

    def adjust_request(
        self, request: ChatCompletionRequest | ResponsesRequest
    ) -> ChatCompletionRequest | ResponsesRequest:
        """Keep special tokens in the decoded text when tools are in play.

        For a ``ChatCompletionRequest`` that carries tools and whose
        ``tool_choice`` is not ``"none"``, ``skip_special_tokens`` is forced
        to ``False`` so ``<|tool_call>`` and friends reach this parser.
        """
        request = super().adjust_request(request)
        if (
            isinstance(request, ChatCompletionRequest)
            and request.tools
            and request.tool_choice != "none"
        ):
            # Don't skip special tokens — <|tool_call> etc. are needed
            request.skip_special_tokens = False
        return request

    # ------------------------------------------------------------------
    # Delta buffering for multi-token special sequences
    # ------------------------------------------------------------------

    def _buffer_delta_text(self, delta_text: str) -> str:
        """Buffer incoming delta text to handle multi-token special sequences.

        Accumulates partial tokens that could be the start of
        ``<|tool_call>`` or ``<tool_call|>`` and only flushes them
        when the complete sequence is recognized or the sequence breaks.

        This prevents partial special tokens (e.g., ``<|tool``) from being
        emitted prematurely as content text.

        Returns:
            The text that is safe to process now (previously withheld text
            plus ``delta_text``, minus any newly withheld suffix).
        """
        combined = self.buffered_delta_text + delta_text

        # Ends with a complete special token: nothing left to withhold.
        if combined.endswith((TOOL_CALL_START, TOOL_CALL_END)):
            self.buffered_delta_text = ""
            return combined

        # Ends with a proper prefix of a special token: withhold that suffix.
        for tag in (TOOL_CALL_START, TOOL_CALL_END):
            for size in range(1, len(tag)):
                if combined.endswith(tag[:size]):
                    self.buffered_delta_text = combined[-size:]
                    return combined[:-size]

        # No partial match — flush everything
        self.buffered_delta_text = ""
        return combined

    # ------------------------------------------------------------------
    # Non-streaming extraction
    # ------------------------------------------------------------------

    def extract_tool_calls(
        self,
        model_output: str,
        request: ChatCompletionRequest,
    ) -> ExtractedToolCallInformation:
        """Extract every complete tool call from a finished model output.

        Returns:
            ``tools_called=True`` with the parsed calls and any text that
            precedes the first ``<|tool_call>`` as ``content`` (``None`` when
            that text is empty). If no start token or no complete call is
            found, or if parsing raises, the whole output is returned as
            plain content with ``tools_called=False``.
        """
        if self.tool_call_start_token not in model_output:
            return ExtractedToolCallInformation(
                tools_called=False, tool_calls=[], content=model_output
            )

        try:
            matches = self.tool_call_regex.findall(model_output)
            if not matches:
                return ExtractedToolCallInformation(
                    tools_called=False, tool_calls=[], content=model_output
                )

            tool_calls: list[ToolCall] = [
                ToolCall(
                    type="function",
                    function=FunctionCall(
                        name=func_name,
                        arguments=json.dumps(
                            _parse_gemma4_args(args_str), ensure_ascii=False
                        ),
                    ),
                )
                for func_name, args_str in matches
            ]

            # Content = text before first tool call (if any)
            content_end = model_output.find(self.tool_call_start_token)
            content = model_output[:content_end].strip() or None

            return ExtractedToolCallInformation(
                tools_called=True,
                tool_calls=tool_calls,
                content=content,
            )

        except Exception:
            # Best effort: fall back to returning the raw text as content.
            logger.exception("Error extracting tool calls from Gemma4 response")
            return ExtractedToolCallInformation(
                tools_called=False, tool_calls=[], content=model_output
            )

    # ------------------------------------------------------------------
    # Streaming extraction — accumulate-then-parse-then-diff
    # ------------------------------------------------------------------

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> DeltaMessage | None:
        """Turn one streamed delta into a content or tool-call delta.

        Only the text arguments are used; the token-ID arguments are
        accepted for interface compatibility.

        Returns:
            A ``DeltaMessage`` carrying content or a tool-call fragment, or
            ``None`` when there is nothing to emit for this delta (including
            when streaming extraction raises, which is logged).
        """
        # Buffer delta text to handle multi-token special sequences
        delta_text = self._buffer_delta_text(delta_text)
        # Reconstruct current_text after buffering to stay in sync
        current_text = previous_text + delta_text

        # If no tool call token seen yet, emit as content
        if self.tool_call_start_token not in current_text:
            if delta_text:
                return DeltaMessage(content=delta_text)
            return None

        try:
            return self._extract_streaming(
                previous_text=previous_text,
                current_text=current_text,
                delta_text=delta_text,
            )
        except Exception:
            logger.exception("Error in Gemma4 streaming tool call extraction")
            return None

    def _extract_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
    ) -> DeltaMessage | None:
        """Tag-counting streaming parser.

        Uses the proven approach from FunctionGemma/Hermes: count start/end
        tags in previous vs current text to determine phase, then
        accumulate-parse-diff for arguments.

        Format: ``<|tool_call>call:name{args}<tool_call|>``
        """
        start_count = current_text.count(self.tool_call_start_token)
        end_count = current_text.count(self.tool_call_end_token)
        prev_start_count = previous_text.count(self.tool_call_start_token)
        prev_end_count = previous_text.count(self.tool_call_end_token)

        # Case 1: Not inside any tool call — emit as content
        if (
            start_count == end_count
            and prev_end_count == end_count
            and self.tool_call_end_token not in delta_text
        ):
            if delta_text:
                return DeltaMessage(content=delta_text)
            return None

        # Case 2: Starting a new tool call
        if start_count > prev_start_count and start_count > end_count:
            self.current_tool_id += 1
            self.current_tool_name_sent = False
            self.streamed_args_for_tool.append("")
            self.prev_tool_call_arr.append({})
            logger.debug("Starting new tool call %d", self.current_tool_id)
            # Don't return yet — fall through to try parsing if there's
            # content after <|tool_call> in this same delta
            # (but usually it's just the token itself, so return None)
            if len(delta_text) <= len(self.tool_call_start_token):
                return None

        # Case 3: Tool call just ended
        if end_count > prev_end_count:
            return self._handle_tool_call_end(current_text)

        # Case 4: In the middle of a tool call — parse partial content
        if start_count > end_count:
            return self._handle_tool_call_middle(current_text)

        # Default: generate text outside tool calls
        if delta_text:
            text = delta_text.replace(self.tool_call_start_token, "")
            text = text.replace(self.tool_call_end_token, "")
            if text:
                return DeltaMessage(content=text)
        return None

    def _extract_partial_call(self, current_text: str) -> tuple[str | None, str]:
        """Extract function name and raw argument string from partial text.

        Looks only at the text after the last ``<|tool_call>`` token.

        Returns (func_name, raw_args_str) or (None, "") if not parseable yet.
        """
        call_prefix = "call:"

        # Get the text after the last <|tool_call> token
        last_start = current_text.rfind(self.tool_call_start_token)
        if last_start == -1:
            return None, ""

        partial_call = current_text[last_start + len(self.tool_call_start_token) :]

        # Strip end token if present
        if self.tool_call_end_token in partial_call:
            partial_call = partial_call.split(self.tool_call_end_token)[0]

        # Expect "call:name{args...}" (closing brace may not have arrived)
        if not partial_call.startswith(call_prefix):
            return None, ""

        func_part = partial_call[len(call_prefix) :]

        if "{" not in func_part:
            # Still accumulating function name, not ready yet
            return None, ""

        func_name, _, args_part = func_part.partition("{")
        func_name = func_name.strip()

        # Strip trailing '}' if present (Gemma4 structural brace)
        if args_part.endswith("}"):
            args_part = args_part[:-1]

        return func_name, args_part

    def _handle_tool_call_middle(self, current_text: str) -> DeltaMessage | None:
        """Handle streaming when we're inside an active tool call.

        Accumulates the raw Gemma4 arguments, parses them into JSON, and
        diffs against the previously-streamed JSON to emit only the new
        fragment. The first delta for a call carries only its name, id and
        type, with empty arguments.
        """
        func_name, args_part = self._extract_partial_call(current_text)

        if func_name is None:
            return None

        # Step 1: Send function name (once)
        if not self.current_tool_name_sent and func_name:
            self.current_tool_name_sent = True
            self.prev_tool_call_arr[self.current_tool_id] = {
                "name": func_name,
                "arguments": {},
            }
            return DeltaMessage(
                tool_calls=[
                    DeltaToolCall(
                        index=self.current_tool_id,
                        type="function",
                        id=make_tool_call_id(),
                        function=DeltaFunctionCall(
                            name=func_name,
                            arguments="",
                        ).model_dump(exclude_none=True),
                    )
                ]
            )

        # Step 2: Parse and diff arguments
        if self.current_tool_name_sent and args_part:
            return self._emit_argument_diff(args_part)

        return None

    def _handle_tool_call_end(self, current_text: str) -> DeltaMessage | None:
        """Handle streaming when a tool call has just completed.

        Performs a final parse of the complete tool call and flushes
        any remaining un-streamed argument fragments.

        If the call completed before its header delta was ever emitted
        (for example when the whole call arrives inside one delta), the
        name, id and type are sent here together with the full arguments,
        so the client never receives arguments for an unannounced call.
        """
        if self.current_tool_id < 0 or self.current_tool_id >= len(
            self.prev_tool_call_arr
        ):
            logger.debug(
                "Tool call end detected but no active tool call (current_tool_id=%d)",
                self.current_tool_id,
            )
            return None

        # Parse the complete tool call using regex for accuracy
        all_matches = self.tool_call_regex.findall(current_text)
        if self.current_tool_id >= len(all_matches):
            return None

        func_name, args_str = all_matches[self.current_tool_id]
        final_args = _parse_gemma4_args(args_str)
        final_args_json = json.dumps(final_args, ensure_ascii=False)

        if not self.current_tool_name_sent:
            # Header was never sent: emit it now with the complete arguments.
            self.current_tool_name_sent = True
            self.streamed_args_for_tool[self.current_tool_id] = final_args_json
            self.prev_tool_call_arr[self.current_tool_id] = {
                "name": func_name,
                "arguments": final_args,
            }
            return DeltaMessage(
                tool_calls=[
                    DeltaToolCall(
                        index=self.current_tool_id,
                        type="function",
                        id=make_tool_call_id(),
                        function=DeltaFunctionCall(
                            name=func_name,
                            arguments=final_args_json,
                        ).model_dump(exclude_none=True),
                    )
                ]
            )

        prev_streamed = self.streamed_args_for_tool[self.current_tool_id]
        if len(final_args_json) <= len(prev_streamed):
            return None

        # Flush whatever was withheld during streaming (closing chars etc.).
        diff = final_args_json[len(prev_streamed) :]
        self.streamed_args_for_tool[self.current_tool_id] = final_args_json
        self.prev_tool_call_arr[self.current_tool_id]["arguments"] = final_args

        return DeltaMessage(
            tool_calls=[
                DeltaToolCall(
                    index=self.current_tool_id,
                    function=DeltaFunctionCall(arguments=diff).model_dump(
                        exclude_none=True
                    ),
                )
            ]
        )

    def _emit_argument_diff(self, raw_args_str: str) -> DeltaMessage | None:
        """Parse raw Gemma4 arguments, convert to JSON, diff, and emit.

        This is the core of the accumulate-then-parse-then-diff strategy:

        1. Parse ``raw_args_str`` with ``_parse_gemma4_args()``
        2. Convert to JSON string with ``json.dumps()``
        3. Withhold trailing closing characters (``"}``) that may move
           as more tokens arrive
        4. Diff against previously streamed JSON and emit only new chars

        **Why withholding is necessary:**

        Gemma4's custom format produces *structurally incomplete* JSON
        during streaming. For example, when ``<|"|>Paris`` arrives
        without a closing delimiter, ``_parse_gemma4_args`` treats it
        as a complete value and produces ``{"location": "Paris"}``. But
        when ``, France<|"|>`` arrives next, the JSON becomes
        ``{"location": "Paris, France"}``. If we had sent the closing
        ``"}`` from the first parse, the concatenated client output
        would be ``{"location": "Paris"}France"}``, which is garbage.

        The solution: **never send trailing closing chars during
        streaming**. They get flushed by ``_handle_tool_call_end()``
        when the ``<tool_call|>`` end marker arrives.

        Args:
            raw_args_str: The raw Gemma4 argument text accumulated so far
                (without the surrounding ``{`` ``}``).

        Returns:
            DeltaMessage with the argument diff, or None if no new content.
        """
        try:
            current_args = _parse_gemma4_args(raw_args_str)
        except Exception:
            logger.debug(
                "Could not parse partial Gemma4 args yet: %s",
                raw_args_str[:100],
            )
            return None

        if not current_args:
            return None

        current_args_json = json.dumps(current_args, ensure_ascii=False)

        # Withhold trailing closing characters that may shift as more
        # tokens arrive: drop every trailing '}', '"' and ']' to get the
        # "safe prefix".
        safe_json = current_args_json.rstrip('}"]')

        prev_streamed = self.streamed_args_for_tool[self.current_tool_id]

        if not safe_json or safe_json == prev_streamed:
            return None

        if prev_streamed:
            # Use find_common_prefix to handle cases where the value changed
            # structurally (e.g., a string grew).
            prefix = find_common_prefix(prev_streamed, safe_json)
            sent_len = len(prev_streamed)

            if len(prefix) < sent_len:
                # Structure changed — we sent too much. Truncate our
                # tracking to the common prefix and wait for the final
                # flush in _handle_tool_call_end.
                self.streamed_args_for_tool[self.current_tool_id] = prefix
                return None

            # Stream the new stable portion
            diff = safe_json[sent_len:]
        else:
            # First emission
            diff = safe_json

        if not diff:
            return None

        self.streamed_args_for_tool[self.current_tool_id] = safe_json
        self.prev_tool_call_arr[self.current_tool_id]["arguments"] = current_args

        return DeltaMessage(
            tool_calls=[
                DeltaToolCall(
                    index=self.current_tool_id,
                    function=DeltaFunctionCall(arguments=diff).model_dump(
                        exclude_none=True
                    ),
                )
            ]
        )
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved.
"""Gemma4 tool call parsing utilities for offline inference.
Standalone functions that parse decoded model text to extract tool calls
from Gemma4 models. These are pure-Python utilities with zero heavy
dependencies — they work on raw decoded strings from any inference
backend (vLLM, HuggingFace, TGI, etc.).
For the OpenAI-compatible API server tool parser (streaming +
non-streaming), see ``vllm.tool_parsers.gemma4_tool_parser``.
For thinking/reasoning output parsing, see
``vllm.reasoning.gemma4_utils``.
Usage with vLLM offline inference::
from vllm import LLM, SamplingParams
from vllm.tool_parsers.gemma4_utils import (
parse_tool_calls,
has_tool_response_tag,
)
llm = LLM(model="google/gemma-4-it")
outputs = llm.generate(prompt, SamplingParams(...))
text = tokenizer.decode(outputs[0].outputs[0].token_ids, skip_special_tokens=False)
# Extract tool calls
tool_calls = parse_tool_calls(text)
for tc in tool_calls:
print(f"{tc['name']}({tc['arguments']})")
Ported from ``transformers.models.gemma4.utils_gemma4`` so that vLLM users
do not need a transformers dependency for output parsing.
"""
import json
import regex as re
# Tool call delimiter tokens as they appear in decoded text.
# Standard format: <|tool_call>call:name{args}<tool_call|>
# NOTE: parse_tool_calls() spells these two tags out inside its regex
# patterns rather than referencing the constants below.
_TOOL_CALL_START_TAG = "<|tool_call>"
_TOOL_CALL_END_TAG = "<tool_call|>"
# Tag checked by has_tool_response_tag() at the end of the model output.
_TOOL_RESPONSE_START_TAG = "<|tool_response>"
# Gemma4 escape token as it appears in decoded text.
# _parse_tool_arguments() replaces it with a plain double quote.
_ESCAPE_TOKEN = '<|"|>'
def _parse_tool_arguments(args_str: str) -> dict[str, str]:
    """Parse tool call arguments from the Gemma4 compact format.

    Handles the ``key:<|"|>value<|"|>`` format used by Gemma4, with fallback
    to heuristic key-value extraction. Also tolerates the slightly different
    ``key: "value"`` format (space + plain quotes) that some chat templates
    produce. Quoted and unquoted values may be mixed in one argument list
    (for example ``city:<|"|>Tokyo<|"|>,days:3``); both kinds are kept.

    Args:
        args_str: Raw argument string from inside ``call:name{...}``.

    Returns:
        Dictionary of argument name → value. Every value is a string;
        non-string values from the JSON path are converted with ``str()``.
    """
    if not args_str or not args_str.strip():
        return {}

    # Replace Gemma4 escape tokens with standard quotes.
    cleaned = args_str.replace(_ESCAPE_TOKEN, '"')

    # Try JSON parsing first (handles nested values, arrays, etc.).
    try:
        parsed = json.loads("{" + cleaned + "}")
        # Ensure all values are strings for consistency.
        return {k: v if isinstance(v, str) else str(v) for k, v in parsed.items()}
    except (json.JSONDecodeError, ValueError):
        pass

    # Fallback: one left-to-right pass over ``key:value`` pairs (optional
    # space after the colon). A value is either fully quoted or a bare run
    # up to the next ',' or '}'. Scanning in a single pass means text inside
    # a quoted value is never mistaken for another key, and bare values are
    # not lost when quoted values are also present.
    arguments: dict[str, str] = {}
    for match in re.finditer(r'(\w+):\s*(?:"([^"]*)"|([^,}]+))', cleaned):
        key, quoted, bare = match.group(1, 2, 3)
        if quoted is not None:
            arguments[key] = quoted
        else:
            # Also tidies an unterminated quote such as key:"value
            arguments[key] = bare.strip().strip('"')

    return arguments
def parse_tool_calls(text: str, *, strict: bool = False) -> list[dict]:
    """Parse tool calls from decoded Gemma4 model output.

    Uses a tiered parsing strategy to handle known output variations in
    Gemma4 models, which may emit non-standard tool call formats.

    Parsing tiers:
      1. **Standard**: ``<|tool_call>call:name{args}<tool_call|>``
         (special token IDs 48/49 in decoded text); ``<turn|>`` is also
         accepted as the closing marker.
      2. **Fallback** (when ``strict=False``): bare ``call:name{args}``
         patterns, including ``<call>name{args}`` (fragmented tokens from
         multimodal inputs). Only tried when tier 1 found nothing.

    Function names may contain letters, digits, underscores, hyphens and
    dots (e.g. ``get-weather``, ``module.func``), matching the names the
    API-server tool parser accepts.

    Args:
        text: Decoded model output text (from ``tokenizer.decode(...,
              skip_special_tokens=False)``).
        strict: If ``True``, only match the standard ``<|tool_call>`` format.
                If ``False`` (default), also try fallback patterns for
                known Gemma4 output variations.

    Returns:
        A list of dicts, each with keys:
          - ``"name"``: The tool function name (e.g. ``"get_weather"``).
          - ``"arguments"``: A dict of argument name → value.

    Example::

        >>> from vllm.tool_parsers.gemma4_utils import parse_tool_calls
        >>> output = tokenizer.decode(outputs[0], skip_special_tokens=False)
        >>> tool_calls = parse_tool_calls(output)
        >>> for tc in tool_calls:
        ...     print(f"Call: {tc['name']}({tc['arguments']})")
    """
    func_name = r"([\w\-.]+)"

    # Tier 1: Standard format with special tokens.
    # <|tool_call>call:name{args}<tool_call|>
    # Note: Some Gemma4 models emit <turn|> instead of <tool_call|>.
    standard_pattern = (
        r"<\|tool_call>call:" + func_name + r"\{(.*?)\}(?:<tool_call\|>|<turn\|>)"
    )
    # Tier 2: Fallback for known Gemma4 output variations.
    # Matches: <call>name{args}, call:name{args}, or bare call:name{args}<eos>
    fallback_pattern = r"(?:<call>|(?:^|\s)call:)" + func_name + r"\{(.*?)\}"

    def collect(pattern: str) -> list[dict]:
        """Build one result dict per match of ``pattern`` in ``text``."""
        return [
            {
                "name": match.group(1),
                "arguments": _parse_tool_arguments(match.group(2)),
            }
            for match in re.finditer(pattern, text, re.DOTALL)
        ]

    results = collect(standard_pattern)
    if results or strict:
        return results

    return collect(fallback_pattern)
def has_tool_response_tag(text: str) -> bool:
    """Report whether the output ends with the ``<|tool_response>`` tag.

    Some Gemma4 models sometimes emit ``<eos>`` instead of
    ``<|tool_response>`` after a tool call. Callers can use this check to
    decide whether ``<|tool_response>`` has to be injected into the next
    prompt. Trailing whitespace is ignored.

    Args:
        text: Decoded model output text.

    Returns:
        ``True`` if the output ends with ``<|tool_response>``
        (proper behavior), ``False`` otherwise.

    Example::

        >>> from vllm.tool_parsers.gemma4_utils import has_tool_response_tag
        >>> if not has_tool_response_tag(model_output):
        ...     # Model used <eos> instead — inject <|tool_response> manually
        ...     next_prompt = "<|tool_response>" + tool_result
    """
    return text.rstrip().endswith(_TOOL_RESPONSE_START_TAG)
......@@ -448,6 +448,16 @@ class LongCatFlashMTPModelArchConfigConvertor(ModelArchConfigConvertorBase):
return getattr(self.hf_text_config, "num_nextn_predict_layers", 1)
class Gemma4ModelArchConfigConvertor(ModelArchConfigConvertorBase):
    """Architecture-config convertor for Gemma4 models."""

    def get_head_size(self) -> int:
        """Return the larger of the two Gemma4 attention head dimensions.

        Gemma4 uses dual head dimensions: ``head_dim`` (sliding attention)
        and ``global_head_dim`` (full attention). The largest is reported so
        that attention backends allocate buffers large enough for both. When
        neither attribute yields a non-zero value, the base-class
        implementation is used instead.
        """
        text_config = self.hf_text_config
        largest = max(
            getattr(text_config, "head_dim", 0),
            getattr(text_config, "global_head_dim", 0),
        )
        if largest:
            return largest
        return super().get_head_size()
# hf_config.model_type -> convertor class
MODEL_ARCH_CONFIG_CONVERTORS = {
"cohere_asr": CohereAsrModelArchConfigConvertor,
......@@ -471,4 +481,6 @@ MODEL_ARCH_CONFIG_CONVERTORS = {
"ernie_mtp": ErnieMTPModelArchConfigConvertor,
"pangu_ultra_moe_mtp": PanguUltraMoeMTPModelArchConfigConvertor,
"longcat_flash_mtp": LongCatFlashMTPModelArchConfigConvertor,
"gemma4": Gemma4ModelArchConfigConvertor,
"gemma4_text": Gemma4ModelArchConfigConvertor,
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment