Commit fc67613a authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.19.1' into v0.19.0

parents 31aec25b b1388b1f
...@@ -1577,6 +1577,22 @@ class VllmConfig: ...@@ -1577,6 +1577,22 @@ class VllmConfig:
compile_range_end, compile_range_end,
) )
if compilation_config.pass_config.fuse_minimax_qk_norm:
from vllm.compilation.passes.fusion.minimax_qk_norm_fusion import (
MAX_TOKEN_NUM,
)
max_token_num = min(
MAX_TOKEN_NUM, self.scheduler_config.max_num_batched_tokens
)
if compile_range_end is not None and max_token_num < compile_range_end:
computed_compile_ranges_endpoints.append(max_token_num)
else:
logger.debug(
"Max num batched tokens below MiniMax QK norm fusion threshold, "
"MiniMax QK norm fusion enabled for all num_tokens."
)
if compilation_config.compile_ranges_endpoints is not None: if compilation_config.compile_ranges_endpoints is not None:
for x in compilation_config.compile_ranges_endpoints: for x in compilation_config.compile_ranges_endpoints:
assert isinstance(x, int) assert isinstance(x, int)
......
...@@ -170,7 +170,8 @@ class AnthropicServingMessages(OpenAIServingChat): ...@@ -170,7 +170,8 @@ class AnthropicServingMessages(OpenAIServingChat):
else: else:
cls._convert_message_content(msg, openai_msg, openai_messages) cls._convert_message_content(msg, openai_msg, openai_messages)
openai_messages.append(openai_msg) if not (msg.role == "user" and "content" not in openai_msg):
openai_messages.append(openai_msg)
@classmethod @classmethod
def _convert_message_content( def _convert_message_content(
......
...@@ -372,6 +372,7 @@ async def init_app_state( ...@@ -372,6 +372,7 @@ async def init_app_state(
enable_auto_tools=args.enable_auto_tool_choice, enable_auto_tools=args.enable_auto_tool_choice,
exclude_tools_when_tool_choice_none=args.exclude_tools_when_tool_choice_none, exclude_tools_when_tool_choice_none=args.exclude_tools_when_tool_choice_none,
tool_parser=args.tool_call_parser, tool_parser=args.tool_call_parser,
reasoning_parser=args.structured_outputs_config.reasoning_parser,
default_chat_template_kwargs=args.default_chat_template_kwargs, default_chat_template_kwargs=args.default_chat_template_kwargs,
log_error_stack=args.log_error_stack, log_error_stack=args.log_error_stack,
) )
...@@ -467,6 +468,7 @@ async def init_render_app_state( ...@@ -467,6 +468,7 @@ async def init_render_app_state(
enable_auto_tools=args.enable_auto_tool_choice, enable_auto_tools=args.enable_auto_tool_choice,
exclude_tools_when_tool_choice_none=args.exclude_tools_when_tool_choice_none, exclude_tools_when_tool_choice_none=args.exclude_tools_when_tool_choice_none,
tool_parser=args.tool_call_parser, tool_parser=args.tool_call_parser,
reasoning_parser=args.structured_outputs_config.reasoning_parser,
default_chat_template_kwargs=args.default_chat_template_kwargs, default_chat_template_kwargs=args.default_chat_template_kwargs,
log_error_stack=args.log_error_stack, log_error_stack=args.log_error_stack,
) )
......
...@@ -594,6 +594,7 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -594,6 +594,7 @@ class OpenAIServingResponses(OpenAIServing):
default_template_kwargs=None, default_template_kwargs=None,
tool_dicts=tool_dicts, tool_dicts=tool_dicts,
tool_parser=self.parser.tool_parser_cls if self.parser else None, tool_parser=self.parser.tool_parser_cls if self.parser else None,
reasoning_parser=self.parser.reasoning_parser_cls if self.parser else None,
) )
return messages, engine_inputs return messages, engine_inputs
...@@ -618,6 +619,7 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -618,6 +619,7 @@ class OpenAIServingResponses(OpenAIServing):
default_template_kwargs=None, default_template_kwargs=None,
tool_dicts=tool_dicts, tool_dicts=tool_dicts,
tool_parser=tool_parser, tool_parser=tool_parser,
reasoning_parser=self.parser.reasoning_parser_cls if self.parser else None,
) )
return engine_inputs return engine_inputs
......
...@@ -44,6 +44,7 @@ from vllm.inputs import ( ...@@ -44,6 +44,7 @@ from vllm.inputs import (
) )
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.parser import ParserManager from vllm.parser import ParserManager
from vllm.reasoning.abs_reasoning_parsers import ReasoningParser
from vllm.renderers import BaseRenderer, merge_kwargs from vllm.renderers import BaseRenderer, merge_kwargs
from vllm.renderers.inputs.preprocess import ( from vllm.renderers.inputs.preprocess import (
extract_prompt_components, extract_prompt_components,
...@@ -74,6 +75,7 @@ class OpenAIServingRender: ...@@ -74,6 +75,7 @@ class OpenAIServingRender:
enable_auto_tools: bool = False, enable_auto_tools: bool = False,
exclude_tools_when_tool_choice_none: bool = False, exclude_tools_when_tool_choice_none: bool = False,
tool_parser: str | None = None, tool_parser: str | None = None,
reasoning_parser: str | None = None,
default_chat_template_kwargs: dict[str, Any] | None = None, default_chat_template_kwargs: dict[str, Any] | None = None,
log_error_stack: bool = False, log_error_stack: bool = False,
) -> None: ) -> None:
...@@ -94,6 +96,11 @@ class OpenAIServingRender: ...@@ -94,6 +96,11 @@ class OpenAIServingRender:
enable_auto_tools=enable_auto_tools, enable_auto_tools=enable_auto_tools,
model_name=model_config.model, model_name=model_config.model,
) )
self.reasoning_parser: type[ReasoningParser] | None = (
ParserManager.get_reasoning_parser(
reasoning_parser_name=reasoning_parser,
)
)
self.default_chat_template_kwargs: dict[str, Any] = ( self.default_chat_template_kwargs: dict[str, Any] = (
default_chat_template_kwargs or {} default_chat_template_kwargs or {}
) )
...@@ -245,6 +252,7 @@ class OpenAIServingRender: ...@@ -245,6 +252,7 @@ class OpenAIServingRender:
default_template_kwargs=self.default_chat_template_kwargs, default_template_kwargs=self.default_chat_template_kwargs,
tool_dicts=tool_dicts, tool_dicts=tool_dicts,
tool_parser=tool_parser, tool_parser=tool_parser,
reasoning_parser=self.reasoning_parser,
) )
else: else:
# For GPT-OSS. # For GPT-OSS.
...@@ -498,6 +506,9 @@ class OpenAIServingRender: ...@@ -498,6 +506,9 @@ class OpenAIServingRender:
default_template_kwargs: dict[str, Any] | None, default_template_kwargs: dict[str, Any] | None,
tool_dicts: list[dict[str, Any]] | None = None, tool_dicts: list[dict[str, Any]] | None = None,
tool_parser: type[ToolParser] | None = None, tool_parser: type[ToolParser] | None = None,
reasoning_parser: type[ReasoningParser] | None = None,
*,
skip_mm_cache: bool = False,
) -> tuple[list[ConversationMessage], list[EngineInput]]: ) -> tuple[list[ConversationMessage], list[EngineInput]]:
"""Copied from OpenAIServing._preprocess_chat.""" """Copied from OpenAIServing._preprocess_chat."""
renderer = self.renderer renderer = self.renderer
...@@ -531,6 +542,10 @@ class OpenAIServingRender: ...@@ -531,6 +542,10 @@ class OpenAIServingRender:
}, },
) )
if reasoning_parser is not None:
tokenizer = renderer.get_tokenizer()
request = reasoning_parser(tokenizer).adjust_request(request=request)
# tool parsing is done only if a tool_parser has been set and if # tool parsing is done only if a tool_parser has been set and if
# tool_choice is not "none" (if tool_choice is "none" but a tool_parser # tool_choice is not "none" (if tool_choice is "none" but a tool_parser
# is set, we want to prevent parsing a tool_call hallucinated by the LLM # is set, we want to prevent parsing a tool_call hallucinated by the LLM
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import array
import contextlib
import struct
import sys
import threading
import torch
try:
from cuda.bindings import runtime as cudart
except ImportError:
from cuda import cudart
_ALIGN = 1 << 21 # 2 MiB — CUDA IPC allocation alignment
# ---------------------------------------------------------------------------
# CUDA helpers
# ---------------------------------------------------------------------------
def _check(error):
    """Raise ``RuntimeError`` unless *error* equals the CUDA success code.

    Args:
        error: Status value taken from the first element of a ``cudart``
            call's return tuple.

    Raises:
        RuntimeError: If *error* differs from the success code.
    """
    # Prefer the named enum member; when it is missing (or evaluates as
    # falsy) build the success value from the integer 0 instead.
    ok_code = getattr(cudart.cudaError_t, "cudaSuccess", None)
    if not ok_code:
        ok_code = cudart.cudaError_t(0)
    if error != ok_code:
        raise RuntimeError(f"CUDA runtime error: {error}")
def _cuda_malloc(size: int):
    """Allocate device memory, rounding *size* up to a multiple of ``_ALIGN``.

    Args:
        size: Requested number of bytes.

    Returns:
        A ``(ptr, aligned)`` tuple: the pointer returned by ``cudaMalloc``
        and the rounded-up byte count that was actually requested.

    Raises:
        RuntimeError: If ``cudaMalloc`` reports an error.
    """
    # _ALIGN is a power of two (1 << 21), so clearing the low bits after
    # adding _ALIGN - 1 rounds up to the next multiple.
    padded = (size + _ALIGN - 1) & ~(_ALIGN - 1)
    status, device_ptr = cudart.cudaMalloc(padded)
    _check(status)
    return device_ptr, padded
def _cuda_free(ptr: int):
    """Release device memory at *ptr*; a falsy pointer (e.g. 0) is ignored.

    Raises:
        RuntimeError: If ``cudaFree`` reports an error.
    """
    if not ptr:
        return
    status = cudart.cudaFree(ptr)[0]
    _check(status)
def _cuda_memset_zero(ptr: int, size: int):
    """Set *size* bytes starting at device pointer *ptr* to zero.

    Raises:
        RuntimeError: If ``cudaMemset`` reports an error.
    """
    status = cudart.cudaMemset(ptr, 0, size)[0]
    _check(status)
def _cuda_memcpy_d2d(dst: int, src: int, size: int):
    """Copy *size* bytes from device pointer *src* to device pointer *dst*.

    Raises:
        RuntimeError: If ``cudaMemcpy`` reports an error.
    """
    copy_kind = cudart.cudaMemcpyKind.cudaMemcpyDeviceToDevice
    status = cudart.cudaMemcpy(dst, src, size, copy_kind)[0]
    _check(status)
# ---------------------------------------------------------------------------
# IPC buffer
# ---------------------------------------------------------------------------
class IpcBuffer:
    """Device allocation shared with every rank through CUDA IPC handles.

    Each rank allocates a local buffer, publishes its IPC handle with
    ``torch.distributed.all_gather_object`` and opens the handles of all
    other ranks, so ``peer_ptrs[r]`` is a device pointer to rank ``r``'s
    buffer (``peer_ptrs[rank]`` is the local allocation itself).
    """

    def __init__(self, rank: int, world_size: int, size: int, process_group=None):
        """Allocate the local buffer and exchange IPC handles.

        Args:
            rank: Index of this process within the group.
            world_size: Number of ranks taking part in the exchange.
            size: Requested buffer size in bytes. When ``size <= 0`` nothing
                is allocated and no handle exchange takes place.
            process_group: Group passed to ``all_gather_object``.

        Raises:
            RuntimeError: If any CUDA runtime call reports an error.
        """
        self.rank = rank
        self.world_size = world_size
        self.peer_ptrs: list[int] = [0] * world_size
        self.local_ptr: int = 0
        self._alive = False
        if size <= 0:
            return

        # The rounded-up allocation size returned by the helper is not kept.
        self.local_ptr, _ = _cuda_malloc(size)
        _cuda_memset_zero(self.local_ptr, size)
        # From here on cleanup() has something to release.
        self._alive = True

        # Publish our handle and collect everybody else's.
        status, own_handle = cudart.cudaIpcGetMemHandle(self.local_ptr)
        _check(status)
        gathered: list[bytes | None] = [None] * world_size
        torch.distributed.all_gather_object(
            gathered, bytes(own_handle.reserved), group=process_group
        )

        for peer, handle_bytes in enumerate(gathered):
            if peer == rank:
                # Our own slot is the raw allocation, not an IPC mapping.
                self.peer_ptrs[peer] = self.local_ptr
                continue
            peer_handle = cudart.cudaIpcMemHandle_t()
            peer_handle.reserved = handle_bytes
            status, mapped_ptr = cudart.cudaIpcOpenMemHandle(
                peer_handle, cudart.cudaIpcMemLazyEnablePeerAccess
            )
            _check(status)
            self.peer_ptrs[peer] = mapped_ptr

    def serialize(self) -> list[int]:
        """Return the peer pointers as a list of integers, one per rank.

        Every pointer is packed with the native pointer format (``"P"``) and
        the packed bytes are decoded as unsigned 64-bit values (``"Q"``).
        """
        packed = b"".join(struct.pack("P", ptr) for ptr in self.peer_ptrs)
        return array.array("Q", packed).tolist()

    def cleanup(self):
        """Free the local buffer and close peer mappings; safe to call twice.

        Errors while closing a peer handle are suppressed; an error while
        freeing the local allocation propagates as ``RuntimeError``.
        """
        if not self._alive:
            return
        self._alive = False
        for peer in range(self.world_size):
            mapped_ptr = self.peer_ptrs[peer]
            if mapped_ptr == 0:
                continue
            if peer != self.rank:
                with contextlib.suppress(RuntimeError):
                    _check(cudart.cudaIpcCloseMemHandle(mapped_ptr)[0])
            else:
                _cuda_free(mapped_ptr)
            self.peer_ptrs[peer] = 0
        self.local_ptr = 0

    def __del__(self):
        # Skip cleanup while the interpreter is shutting down.
        if sys.is_finalizing():
            return
        self.cleanup()
# ---------------------------------------------------------------------------
# Lamport negative-zero initialization
# ---------------------------------------------------------------------------
def _lamport_fill_neg_zero(device_ptr: int, size_bytes: int):
    """Fill device memory with IEEE-754 negative zero (-0.0f = 0x80000000).

    This is the "slot empty" sentinel for the Lamport protocol: the kernel
    spin-waits until a value is *not* negative zero.

    Only whole 4-byte floats are written. If ``size_bytes`` is not a multiple
    of 4, the trailing ``size_bytes % 4`` bytes are left untouched, so the
    copy never reads past the end of the staging tensor.

    Args:
        device_ptr: Destination device pointer; ``0`` makes this a no-op.
        size_bytes: Size of the destination region in bytes.

    Raises:
        RuntimeError: If the device-to-device copy reports an error.
    """
    n_floats = size_bytes // 4
    if n_floats <= 0 or device_ptr == 0:
        return
    # torch preserves -0.0 in IEEE-754
    fill = torch.full((n_floats,), -0.0, dtype=torch.float32, device="cuda")
    # Copy exactly what the staging tensor holds (n_floats * 4 bytes) rather
    # than size_bytes, which may be up to 3 bytes larger.
    # NOTE(review): confirm the copy has completed (or is stream-ordered)
    # before the staging tensor's memory can be reused after `del fill`.
    _cuda_memcpy_d2d(device_ptr, fill.data_ptr(), n_floats * 4)
    del fill
# ---------------------------------------------------------------------------
# LamportWorkspace — the main class
# ---------------------------------------------------------------------------
class LamportWorkspace:
    """
    Self-contained workspace for Lamport-based cross-GPU AllReduce.

    Parameters
    ----------
    rank : int
        Local rank (0-based).
    world_size : int
        Total number of ranks in the TP group; must be at least 2.
    comm_size : int
        Size in bytes of *one* Lamport buffer slot; must be positive. The
        IPC allocation requested per rank is ``3 * comm_size``
        (triple-buffering). ``compute_comm_size_for_minimax()`` provides a
        default sizing.
    process_group : optional
        ``torch.distributed`` process group for the IPC handle exchange.
        ``None`` uses the default group.
    """

    def __init__(self, rank: int, world_size: int, comm_size: int, process_group=None):
        assert world_size >= 2, "Lamport workspace requires at least 2 ranks"
        assert comm_size > 0, "comm_size must be positive"

        self.rank = rank
        self.world_size = world_size
        self.comm_size = comm_size

        # Three slots of comm_size bytes, shared across ranks through IPC and
        # pre-filled with the negative-zero sentinel.
        total_bytes = 3 * comm_size
        self._lamport = IpcBuffer(rank, world_size, total_bytes, process_group)
        _lamport_fill_neg_zero(self._lamport.local_ptr, total_bytes)

        # Device int32[3]; per the layout notes: {counter, unused,
        # lamport_flag}, where lamport_flag is the triple-buffer rotation
        # index.
        self._flag_buf = torch.zeros(3, dtype=torch.int32, device="cuda")

        # Device int64[2] = {clear_size, comm_size}; clear_size starts at 0.
        self._layout_buf = torch.tensor(
            [0, comm_size], dtype=torch.int64, device="cuda"
        )

        # Pointer table handed to the kernel:
        #   [0 .. 2N-1]   placeholders (ipc_buffers, ipc_barriers), all zero
        #   [2N .. 3N-1]  lamport peer pointers
        #   [3N]          flag buffer
        #   [3N+1]        layout buffer
        table: list[int] = [0] * (2 * world_size)
        table.extend(self._lamport.serialize())
        table.append(self._flag_buf.data_ptr())
        table.append(self._layout_buf.data_ptr())
        self._workspace = torch.tensor(table, dtype=torch.int64, device="cuda")

    @property
    def workspace(self) -> torch.Tensor:
        """Device int64 tensor holding the pointer table, intended to be
        passed to the kernel as ``void** workspace``."""
        return self._workspace

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------
    @staticmethod
    def compute_comm_size_for_minimax(
        max_tokens: int,
        world_size: int,
        fused_qk: bool = True,
    ) -> int:
        """
        Return a ``comm_size`` (in bytes) sized for MiniMaxReduceRMSKernel.

        Per-slot byte count before alignment:
        - single-matrix path: ``world_size * max_tokens * 4``
        - fused Q+K path: ``world_size * 2 * ceil(max_tokens / 4) * 16``

        The result is rounded up to a multiple of 2 MiB.
        """
        if not fused_qk:
            slot_bytes = world_size * max_tokens * 4  # 4 = sizeof(float)
        else:
            float4_groups = (max_tokens + 3) // 4
            slot_bytes = world_size * 2 * float4_groups * 16  # 16 = sizeof(float4)
        return (slot_bytes + _ALIGN - 1) & ~(_ALIGN - 1)

    def cleanup(self):
        """Release the IPC buffer if construction got far enough to create it."""
        lamport = getattr(self, "_lamport", None)
        if lamport is not None:
            lamport.cleanup()

    def __del__(self):
        # Skip cleanup while the interpreter is shutting down.
        if sys.is_finalizing():
            return
        self.cleanup()

    def __repr__(self):
        return (
            f"LamportWorkspace(rank={self.rank}, world_size={self.world_size}, "
            f"comm_size={self.comm_size})"
        )
# ---------------------------------------------------------------------------
# Cached convenience function (mirrors TRT-LLM's get_allreduce_workspace)
# ---------------------------------------------------------------------------
# Guards lookups and insertions on _workspace_cache in
# get_allreduce_workspace().
_cache_lock = threading.Lock()
# Maps (rank, world_size, comm_size, process-group id) to the LamportWorkspace
# created for that key. Entries are never evicted in this module.
_workspace_cache: dict = {}
def get_allreduce_workspace(
    rank: int,
    world_size: int,
    comm_size: int | None = None,
    max_tokens: int = 16384,
    process_group=None,
) -> torch.Tensor:
    """
    Return the cached workspace tensor for the given configuration.

    The cache key is ``(rank, world_size, comm_size, process-group id)``,
    where the id is ``id(process_group)`` or ``0`` for ``None``. On the first
    call for a key a ``LamportWorkspace`` is created (allocating memory and
    exchanging IPC handles); later calls with the same key return the same
    tensor.

    Parameters
    ----------
    rank, world_size : int
        TP rank and TP size.
    comm_size : int, optional
        Explicit slot size in bytes. If ``None``, it is computed from
        ``max_tokens`` and ``world_size`` for the fused Q+K path.
    max_tokens : int
        Maximum number of tokens per batch (used only when ``comm_size`` is
        ``None``).
    process_group : optional
        ``torch.distributed`` process group.
    """
    slot_size = comm_size
    if slot_size is None:
        slot_size = LamportWorkspace.compute_comm_size_for_minimax(
            max_tokens, world_size, fused_qk=True
        )

    group_token = 0 if process_group is None else id(process_group)
    cache_key = (rank, world_size, slot_size, group_token)

    with _cache_lock:
        cached = _workspace_cache.get(cache_key)
        if cached is None:
            cached = LamportWorkspace(rank, world_size, slot_size, process_group)
            _workspace_cache[cache_key] = cached
        return cached.workspace
...@@ -209,12 +209,24 @@ class GGUFModelLoader(BaseModelLoader): ...@@ -209,12 +209,24 @@ class GGUFModelLoader(BaseModelLoader):
GGUF tensor name with suffix (e.g., 'mm.soft_emb_norm.weight') GGUF tensor name with suffix (e.g., 'mm.soft_emb_norm.weight')
or None if no mapping found or None if no mapping found
""" """
# In transformers v5, multimodal models (e.g. Gemma3) wrap
# all sub-models under an outer 'model.' attribute, producing
# state_dict keys like 'model.language_model.layers.0...' and
# 'model.vision_tower.vision_model...'. Strip this outer
# prefix so the keys match what gguf-py expects.
if is_multimodal and hf_name.startswith("model."):
hf_name = hf_name[6:] # Remove outer 'model.'
# Strip 'language_model.' prefix for multimodal models - gguf-py # Strip 'language_model.' prefix for multimodal models - gguf-py
# tensor mappings expect parameter names without this prefix. # tensor mappings expect parameter names without this prefix.
# Note: 'model.' prefix should be KEPT for text-only models as # Note: 'model.' prefix should be KEPT for text-only models as
# gguf-py expects it. # gguf-py expects it.
if hf_name.startswith("language_model."): if hf_name.startswith("language_model."):
hf_name = hf_name[15:] # Remove 'language_model.' hf_name = hf_name[15:] # Remove 'language_model.'
# Re-add 'model.' prefix because gguf-py text tensor maps
# expect 'model.layers...' format.
if is_multimodal:
hf_name = "model." + hf_name
# Parse parameter name and suffix # Parse parameter name and suffix
if hf_name.endswith((".weight", ".bias")): if hf_name.endswith((".weight", ".bias")):
......
...@@ -19,6 +19,7 @@ ...@@ -19,6 +19,7 @@
"""Gemma 4 model implementation for vLLM.""" """Gemma 4 model implementation for vLLM."""
from collections.abc import Iterable from collections.abc import Iterable
from dataclasses import replace
from itertools import islice from itertools import islice
import regex as re import regex as re
...@@ -32,6 +33,7 @@ from vllm.distributed import ( ...@@ -32,6 +33,7 @@ from vllm.distributed import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size, get_tensor_model_parallel_world_size,
) )
from vllm.forward_context import get_forward_context
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.activation import GeluAndMul from vllm.model_executor.layers.activation import GeluAndMul
from vllm.model_executor.layers.attention import Attention from vllm.model_executor.layers.attention import Attention
...@@ -56,10 +58,18 @@ from vllm.model_executor.model_loader.weight_utils import ( ...@@ -56,10 +58,18 @@ from vllm.model_executor.model_loader.weight_utils import (
maybe_remap_kv_scale_name, maybe_remap_kv_scale_name,
) )
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.v1.attention.backends.utils import KVSharingFastPrefillMetadata
from .interfaces import MixtureOfExperts, SupportsLoRA, SupportsPP
from .interfaces import (
EagleModelMixin,
MixtureOfExperts,
SupportsEagle3,
SupportsLoRA,
SupportsPP,
)
from .utils import ( from .utils import (
AutoWeightsLoader, AutoWeightsLoader,
WeightsMapper,
extract_layer_index, extract_layer_index,
is_pp_missing_parameter, is_pp_missing_parameter,
make_layers, make_layers,
...@@ -636,8 +646,206 @@ class Gemma4DecoderLayer(nn.Module): ...@@ -636,8 +646,206 @@ class Gemma4DecoderLayer(nn.Module):
return hidden_states, None return hidden_states, None
def _run_decoder_layers(
    decoder_layers: list[Gemma4DecoderLayer],
    layer_idx_start: int,
    positions: torch.Tensor,
    hidden_states: torch.Tensor,
    per_layer_inputs: torch.Tensor | None = None,
    **kwargs,
) -> torch.Tensor:
    """Run a slice of decoder layers, feeding each its per-layer input.

    Args:
        decoder_layers: Layers to apply, in order.
        layer_idx_start: Global index of ``decoder_layers[0]``; used to pick
            the matching slice of ``per_layer_inputs``.
        positions: Passed to every layer unchanged.
        hidden_states: Input to the first layer.
        per_layer_inputs: Optional tensor indexed as ``[:, layer_idx, :]`` to
            obtain each layer's extra input; ``None`` passes ``None`` to
            every layer.
        **kwargs: Forwarded to every layer call.

    Returns:
        The hidden states produced by the last layer (or the input unchanged
        when ``decoder_layers`` is empty).
    """
    residual = None
    for layer_idx, layer in enumerate(decoder_layers, start=layer_idx_start):
        if per_layer_inputs is None:
            layer_per_input = None
        else:
            layer_per_input = per_layer_inputs[:, layer_idx, :]
        # Each layer returns (hidden_states, residual); the residual is
        # threaded into the next layer and dropped after the last one.
        hidden_states, residual = layer(
            positions,
            hidden_states,
            residual,
            per_layer_input=layer_per_input,
            **kwargs,
        )
    return hidden_states
@support_torch_compile(
    enable_if=lambda vllm_config: vllm_config.cache_config.kv_sharing_fast_prefill
)
class Gemma4SelfDecoderLayers(nn.Module):
    """Compiled wrapper: embedding + non-KV-shared layers (YOCO first half).

    Holds references to the embedding and per-layer-embedding (PLE) modules
    so they are part of this module's forward pass; Gemma4Model delegates its
    embedding methods here.
    """

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
        decoder_layers: list[Gemma4DecoderLayer],
        layer_idx_start: int,
        embed_tokens: VocabParallelEmbedding,
        normalizer: torch.Tensor,
        embed_tokens_per_layer: VocabParallelEmbedding | None,
        embed_scale_per_layer: torch.Tensor | None,
        per_layer_model_projection: ColumnParallelLinear | None,
        per_layer_projection_norm: RMSNorm | None,
        per_layer_input_scale: torch.Tensor | None,
        per_layer_projection_scale: torch.Tensor | None,
    ):
        super().__init__()
        self.decoder_layers = decoder_layers
        self.layer_idx_start = layer_idx_start

        text_config = _get_text_config(vllm_config.model_config.hf_config)
        self.config = text_config
        # PLE dimensions; fall back to "disabled" (0) and the main vocab size
        # when the config does not define them.
        self.hidden_size_per_layer_input = getattr(
            text_config, "hidden_size_per_layer_input", 0
        )
        self.vocab_size_per_layer_input = getattr(
            text_config, "vocab_size_per_layer_input", text_config.vocab_size
        )

        # References to objects created by the caller; they are stored on
        # this module so its forward pass can use them.
        self.embed_tokens = embed_tokens
        self.normalizer = normalizer
        self.embed_tokens_per_layer = embed_tokens_per_layer
        self.embed_scale_per_layer = embed_scale_per_layer
        self.per_layer_model_projection = per_layer_model_projection
        self.per_layer_projection_norm = per_layer_projection_norm
        self.per_layer_input_scale = per_layer_input_scale
        self.per_layer_projection_scale = per_layer_projection_scale

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Embed token ids and multiply by the normalizer."""
        token_embeds = self.embed_tokens(input_ids)
        return token_embeds * self.normalizer

    def get_per_layer_inputs(self, input_ids: torch.Tensor) -> torch.Tensor | None:
        """Look up per-layer embeddings from ``embed_tokens_per_layer``.

        Ids outside ``[0, vocab_size_per_layer_input)`` are replaced by 0
        before the lookup.

        Returns:
            ``None`` when there is no per-layer embedding table; otherwise a
            tensor reshaped to ``(*input_ids.shape, num_hidden_layers,
            hidden_size_per_layer_input)``.
        """
        if self.embed_tokens_per_layer is None:
            return None

        in_vocab = torch.logical_and(
            input_ids >= 0,
            input_ids < self.vocab_size_per_layer_input,
        )
        safe_ids = torch.where(in_vocab, input_ids, torch.zeros_like(input_ids))

        ple = self.embed_tokens_per_layer(safe_ids)
        ple = ple * self.embed_scale_per_layer
        return ple.reshape(
            *input_ids.shape,
            self.config.num_hidden_layers,
            self.hidden_size_per_layer_input,
        )

    def project_per_layer_inputs(
        self,
        inputs_embeds: torch.Tensor,
        per_layer_inputs: torch.Tensor | None,
    ) -> torch.Tensor | None:
        """Project ``inputs_embeds`` and combine with ``per_layer_inputs``.

        Steps:
            1. Apply ``per_layer_model_projection`` to ``inputs_embeds``.
            2. Multiply by ``per_layer_projection_scale``.
            3. Reshape to ``(..., num_hidden_layers,
               hidden_size_per_layer_input)``.
            4. Normalize with ``per_layer_projection_norm``.
            5. If ``per_layer_inputs`` is given, add it and multiply the sum
               by ``per_layer_input_scale``.

        Returns:
            ``None`` when there is no projection module, otherwise the
            projected (and possibly combined) tensor.
        """
        if self.per_layer_model_projection is None:
            return None

        projected = self.per_layer_model_projection(inputs_embeds)
        projected = projected * self.per_layer_projection_scale
        projected = projected.reshape(
            *inputs_embeds.shape[:-1],
            self.config.num_hidden_layers,
            self.hidden_size_per_layer_input,
        )
        projected = self.per_layer_projection_norm(projected)

        if per_layer_inputs is not None:
            return (projected + per_layer_inputs) * self.per_layer_input_scale
        return projected

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        inputs_embeds: torch.Tensor | None = None,
        per_layer_inputs: torch.Tensor | None = None,
        **kwargs,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Embed (unless embeddings are supplied) and run the owned layers.

        Returns:
            ``(hidden_states, per_layer_inputs)`` where the second element is
            the projected per-layer input tensor (or ``None``).
        """
        if inputs_embeds is None:
            # Token-id path: build both the embeddings and the PLE lookup.
            hidden_states = self.embed_input_ids(input_ids)
            ple_source = self.get_per_layer_inputs(input_ids)
        else:
            # Precomputed embeddings: use the caller-supplied PLE tensor.
            hidden_states = inputs_embeds
            ple_source = per_layer_inputs
        per_layer_inputs = self.project_per_layer_inputs(hidden_states, ple_source)

        hidden_states = _run_decoder_layers(
            self.decoder_layers,
            self.layer_idx_start,
            positions,
            hidden_states,
            per_layer_inputs,
            **kwargs,
        )
        return hidden_states, per_layer_inputs
@support_torch_compile(
    enable_if=lambda vllm_config: vllm_config.cache_config.kv_sharing_fast_prefill
)
class Gemma4CrossDecoderLayers(nn.Module):
    """Cross-decoder layers (YOCO second half, KV-shared)."""

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
        decoder_layers: list[Gemma4DecoderLayer],
        layer_idx_start: int,
    ):
        super().__init__()
        # layer_idx_start is the global index of decoder_layers[0]; it selects
        # the matching slice of per_layer_inputs in forward().
        self.layer_idx_start = layer_idx_start
        self.decoder_layers = decoder_layers

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        per_layer_inputs: torch.Tensor | None = None,
        **kwargs,
    ) -> torch.Tensor:
        """Run the owned layers over ``hidden_states`` and return the result."""
        layer_outputs = _run_decoder_layers(
            self.decoder_layers,
            self.layer_idx_start,
            positions,
            hidden_states,
            per_layer_inputs,
            **kwargs,
        )
        return layer_outputs
@support_torch_compile(
enable_if=lambda vllm_config: not vllm_config.cache_config.kv_sharing_fast_prefill
)
class Gemma4Model(nn.Module, EagleModelMixin):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__() super().__init__()
config = _get_text_config(vllm_config.model_config.hf_config) config = _get_text_config(vllm_config.model_config.hf_config)
...@@ -740,6 +948,75 @@ class Gemma4Model(nn.Module): ...@@ -740,6 +948,75 @@ class Gemma4Model(nn.Module):
torch.tensor(config.hidden_size**0.5), torch.tensor(config.hidden_size**0.5),
persistent=False, persistent=False,
) )
# --- You Only Cache Once (YOCO) split for fast prefill ---
first_kv_shared_layer_idx = config.num_hidden_layers - getattr(
config, "num_kv_shared_layers", 0
)
from vllm.compilation.backends import set_model_tag
# Layers 0..(K-1) are self-decoder layers in YOCO
with set_model_tag("self_decoder"):
self.self_decoder = Gemma4SelfDecoderLayers(
vllm_config=vllm_config,
prefix=f"{prefix}.self_decoder",
decoder_layers=self.layers[:first_kv_shared_layer_idx],
layer_idx_start=0,
embed_tokens=self.embed_tokens,
normalizer=self.normalizer,
embed_tokens_per_layer=getattr(self, "embed_tokens_per_layer", None),
embed_scale_per_layer=getattr(self, "embed_scale_per_layer", None),
per_layer_model_projection=getattr(
self, "per_layer_model_projection", None
),
per_layer_projection_norm=getattr(
self, "per_layer_projection_norm", None
),
per_layer_input_scale=getattr(self, "per_layer_input_scale", None),
per_layer_projection_scale=getattr(
self, "per_layer_projection_scale", None
),
)
# Layers K..(N-1) are cross-decoder layers in YOCO
with set_model_tag("cross_decoder"):
self.cross_decoder = Gemma4CrossDecoderLayers(
vllm_config=vllm_config,
prefix=f"{prefix}.cross_decoder",
decoder_layers=self.layers[first_kv_shared_layer_idx:],
layer_idx_start=first_kv_shared_layer_idx,
)
self.fast_prefill_enabled = cache_config.kv_sharing_fast_prefill
if self.fast_prefill_enabled:
# Allocate static buffers for CUDAGraph
max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens
device = next(self.parameters()).device
self.positions = torch.zeros(
max_num_tokens, dtype=torch.int64, device=device
)
self.hidden_states = torch.zeros(
(max_num_tokens, config.hidden_size),
dtype=self.embed_tokens.weight.dtype,
device=device,
)
if (
self.hidden_size_per_layer_input
and self.hidden_size_per_layer_input > 0
):
self.per_layer_inputs = torch.zeros(
(
max_num_tokens,
config.num_hidden_layers,
self.hidden_size_per_layer_input,
),
dtype=self.embed_tokens.weight.dtype,
device=device,
)
else:
self.per_layer_inputs = None
# Custom factory that includes per_layer_inputs for PLE-enabled PP. # Custom factory that includes per_layer_inputs for PLE-enabled PP.
# per_layer_inputs has shape (batch, num_layers, per_layer_dim), # per_layer_inputs has shape (batch, num_layers, per_layer_dim),
# which differs from the standard (batch, hidden_size) shape, # which differs from the standard (batch, hidden_size) shape,
...@@ -776,47 +1053,22 @@ class Gemma4Model(nn.Module): ...@@ -776,47 +1053,22 @@ class Gemma4Model(nn.Module):
self.make_empty_intermediate_tensors = _make_empty_intermediate_tensors self.make_empty_intermediate_tensors = _make_empty_intermediate_tensors
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids) * self.normalizer return self.self_decoder.embed_input_ids(input_ids)
def get_per_layer_inputs(self, input_ids: torch.Tensor) -> torch.Tensor: def get_per_layer_inputs(self, input_ids: torch.Tensor) -> torch.Tensor | None:
"""Get per-layer embeddings from embed_tokens_per_layer. """Get per-layer embeddings from embed_tokens_per_layer.
Returns: Returns:
Per-layer embeddings (num_tokens, num_layers, Per-layer embeddings (num_tokens, num_layers,
hidden_size_per_layer_input) hidden_size_per_layer_input)
""" """
if self.embed_tokens_per_layer is None: return self.self_decoder.get_per_layer_inputs(input_ids)
return None
# Handle out-of-vocab tokens for PLE (vocab_size_per_layer_input may
# be smaller than the main vocab_size).
per_layer_inputs_mask = torch.logical_and(
input_ids >= 0,
input_ids < self.vocab_size_per_layer_input,
)
per_layer_inputs_tokens = torch.where(
per_layer_inputs_mask, input_ids, torch.zeros_like(input_ids)
)
# Get packed per-layer embeddings: (num_tokens, total_ple_dim)
per_layer_embeds = self.embed_tokens_per_layer(per_layer_inputs_tokens)
# Apply embed_scale (sqrt of per-layer hidden dim)
per_layer_embeds = per_layer_embeds * self.embed_scale_per_layer
# Reshape to (num_tokens, num_layers, hidden_size_per_layer_input)
per_layer_embeds = per_layer_embeds.reshape(
*input_ids.shape,
self.config.num_hidden_layers,
self.hidden_size_per_layer_input,
)
return per_layer_embeds
def project_per_layer_inputs( def project_per_layer_inputs(
self, self,
inputs_embeds: torch.Tensor, inputs_embeds: torch.Tensor,
per_layer_inputs: torch.Tensor | None, per_layer_inputs: torch.Tensor | None,
) -> torch.Tensor: ) -> torch.Tensor | None:
"""Project inputs_embeds and combine with per_layer_inputs. """Project inputs_embeds and combine with per_layer_inputs.
Steps: Steps:
...@@ -826,29 +1078,94 @@ class Gemma4Model(nn.Module): ...@@ -826,29 +1078,94 @@ class Gemma4Model(nn.Module):
4. Normalize with per_layer_projection_norm 4. Normalize with per_layer_projection_norm
5. Combine: (projection + per_layer_inputs) * 1/sqrt(2) 5. Combine: (projection + per_layer_inputs) * 1/sqrt(2)
""" """
if self.per_layer_model_projection is None: return self.self_decoder.project_per_layer_inputs(
return None inputs_embeds, per_layer_inputs
)
def fast_prefill_forward(
    self,
    input_ids: torch.Tensor | None,
    positions: torch.Tensor,
    inputs_embeds: torch.Tensor | None = None,
    per_layer_inputs: torch.Tensor | None = None,
    **kwargs,
) -> torch.Tensor:
    """Run the self-decoder on every token, then the cross-decoder on a subset.

    The subset comes from the attention metadata of the last layer: when
    that metadata is a ``KVSharingFastPrefillMetadata`` its
    ``logits_indices_padded`` selects the tokens handed to the
    cross-decoder. Otherwise all ``positions.size(0)`` tokens are used.

    Args:
        input_ids: Forwarded to ``self.self_decoder``.
        positions: Token positions; copied into ``self.positions`` before
            each decoder call.
        inputs_embeds: Forwarded to ``self.self_decoder``.
        per_layer_inputs: Forwarded to ``self.self_decoder``, which returns
            the value used afterwards.
        **kwargs: Forwarded to both decoders.

    Returns:
        When the metadata supplied ``num_logits_indices``: a clone of the
        self-decoder output whose rows at the first ``num_logits_indices``
        selected indices are overwritten with the cross-decoder output.
        Otherwise the cross-decoder output itself. No final norm is applied
        here.
    """
    logits_indices_padded, num_logits_indices = None, None
    attn_metadata = get_forward_context().attn_metadata

    if attn_metadata is not None:
        assert isinstance(attn_metadata, dict)
        # Metadata is keyed by attention layer name; the last layer's entry
        # tells us whether fast prefill is active for this step.
        layer_attn_metadata = attn_metadata[
            self.layers[-1].self_attn.attn.layer_name
        ]
        if isinstance(layer_attn_metadata, KVSharingFastPrefillMetadata):
            logits_indices_padded = layer_attn_metadata.logits_indices_padded
            num_logits_indices = layer_attn_metadata.num_logits_indices

    # Stage positions in the module-level buffer and run the self-decoder
    # over the full batch.
    # NOTE(review): the buffers are presumably preallocated so compiled /
    # CUDA-graph code sees stable storage -- confirm.
    batch_size = positions.size(0)
    self.positions[:batch_size].copy_(positions)
    self_decoder_hidden_states, per_layer_inputs = self.self_decoder(
        input_ids=input_ids,
        positions=self.positions[:batch_size],
        inputs_embeds=inputs_embeds,
        per_layer_inputs=per_layer_inputs,
        **kwargs,
    )

    if logits_indices_padded is None:
        # No fast-prefill metadata: select every token.
        logits_indices_padded = torch.arange(
            batch_size,
            dtype=positions.dtype,
            device=positions.device,
        )

    # NOTE: Keep .clone() until fix in
    # https://github.com/vllm-project/vllm/pull/22282
    hidden_states = self_decoder_hidden_states.clone()

    # Gather the selected rows into the front of each buffer.
    num_padded = logits_indices_padded.size(0)
    self.positions[:num_padded].copy_(positions[logits_indices_padded])
    self.hidden_states[:num_padded].copy_(
        self_decoder_hidden_states[logits_indices_padded]
    )
    if self.per_layer_inputs is not None and per_layer_inputs is not None:
        self.per_layer_inputs[:num_padded].copy_(
            per_layer_inputs[logits_indices_padded]
        )

    cross_per_layer = (
        self.per_layer_inputs[:num_padded]
        if self.per_layer_inputs is not None
        else None
    )

    # Update batch_descriptor so the cross-decoder's piecewise
    # CUDAGraphWrapper dispatches to the correct (reduced) batch size.
    forward_context = get_forward_context()
    orig_batch_desc = forward_context.batch_descriptor
    if orig_batch_desc is not None:
        forward_context.batch_descriptor = replace(
            orig_batch_desc, num_tokens=num_padded
        )
    try:
        cross_hidden_states = self.cross_decoder(
            self.positions[:num_padded],
            self.hidden_states[:num_padded],
            cross_per_layer,
            **kwargs,
        )
    finally:
        # Restore the original batch_descriptor even if the cross-decoder
        # raises, so the reduced token count never leaks to later users of
        # this forward context.
        forward_context.batch_descriptor = orig_batch_desc

    if num_logits_indices is not None:
        assert num_logits_indices > 0
        # Only the first num_logits_indices rows are real; the rest of the
        # padded selection is discarded.
        hidden_states[logits_indices_padded[:num_logits_indices]] = (
            cross_hidden_states[:num_logits_indices]
        )
    else:
        hidden_states = cross_hidden_states

    return hidden_states
def forward( def forward(
self, self,
...@@ -858,7 +1175,19 @@ class Gemma4Model(nn.Module): ...@@ -858,7 +1175,19 @@ class Gemma4Model(nn.Module):
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
per_layer_inputs: torch.Tensor | None = None, per_layer_inputs: torch.Tensor | None = None,
**kwargs, **kwargs,
) -> torch.Tensor | IntermediateTensors: ) -> torch.Tensor | IntermediateTensors | tuple[torch.Tensor, list[torch.Tensor]]:
if self.fast_prefill_enabled:
hidden_states = self.fast_prefill_forward(
input_ids,
positions,
inputs_embeds,
per_layer_inputs,
**kwargs,
)
hidden_states = self.norm(hidden_states)
return hidden_states
# Normal (non-fast-prefill) path with PP support
if get_pp_group().is_first_rank: if get_pp_group().is_first_rank:
if inputs_embeds is not None: if inputs_embeds is not None:
hidden_states = inputs_embeds hidden_states = inputs_embeds
...@@ -882,6 +1211,7 @@ class Gemma4Model(nn.Module): ...@@ -882,6 +1211,7 @@ class Gemma4Model(nn.Module):
residual = intermediate_tensors["residual"] residual = intermediate_tensors["residual"]
per_layer_inputs = intermediate_tensors.get("per_layer_inputs") per_layer_inputs = intermediate_tensors.get("per_layer_inputs")
aux_hidden_states = self._maybe_add_hidden_state([], 0, hidden_states, residual)
for layer_idx, layer in enumerate( for layer_idx, layer in enumerate(
islice(self.layers, self.start_layer, self.end_layer) islice(self.layers, self.start_layer, self.end_layer)
): ):
...@@ -900,6 +1230,9 @@ class Gemma4Model(nn.Module): ...@@ -900,6 +1230,9 @@ class Gemma4Model(nn.Module):
per_layer_input=layer_per_input, per_layer_input=layer_per_input,
**kwargs, **kwargs,
) )
self._maybe_add_hidden_state(
aux_hidden_states, layer_idx + 1, hidden_states, residual
)
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
return IntermediateTensors( return IntermediateTensors(
{ {
...@@ -914,6 +1247,9 @@ class Gemma4Model(nn.Module): ...@@ -914,6 +1247,9 @@ class Gemma4Model(nn.Module):
hidden_states = self.norm(hidden_states) hidden_states = self.norm(hidden_states)
else: else:
hidden_states, _ = self.norm(hidden_states, residual) hidden_states, _ = self.norm(hidden_states, residual)
if len(aux_hidden_states) > 0:
return hidden_states, aux_hidden_states
return hidden_states return hidden_states
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
...@@ -926,21 +1262,27 @@ class Gemma4Model(nn.Module): ...@@ -926,21 +1262,27 @@ class Gemma4Model(nn.Module):
("gate_up_proj", "up_proj", 1), ("gate_up_proj", "up_proj", 1),
] ]
# MoE expert weight mapping: checkpoint 3D packed tensors are # MoE expert weight mapping: checkpoint can have either:
# exploded in _weight_iterator to per-expert 2D weights like: # 1. 3D packed tensors (exploded in _weight_iterator to per-expert 2D)
# 2. Already per-expert 2D weights (if quantized)
# Map to FusedMoE parameters:
# moe.experts.{id}.gate_proj → FusedMoE w1 (shard of w13) # moe.experts.{id}.gate_proj → FusedMoE w1 (shard of w13)
# moe.experts.{id}.up_proj → FusedMoE w3 (shard of w13) # moe.experts.{id}.up_proj → FusedMoE w3 (shard of w13)
# moe.experts.{id}.down_proj → FusedMoE w2 # moe.experts.{id}.down_proj → FusedMoE w2
# We build the mapping directly since Gemma4 uses bare param #
# names (no .weight suffix) unlike standard MoE checkpoints. # Use prefix matching to handle both weights and
# quantization scale parameters. The param_name is a prefix ending
# in underscore, and weight_name ends with a dot, so that:
# "experts.0.gate_proj.weight_scale" -> "experts.w13_weight_scale"
# "experts.0.gate_proj.weight" -> "experts.w13_weight"
num_experts = getattr(self.config, "num_experts", None) or 0 num_experts = getattr(self.config, "num_experts", None) or 0
expert_params_mapping = [ expert_params_mapping = [
# (param_name, weight_name, expert_id, shard_id) # (param_name, weight_name, expert_id, shard_id)
( (
"experts.w13_weight" "experts.w13_"
if proj_name in ["gate_proj", "up_proj"] if proj_name in ["gate_proj", "up_proj"]
else "experts.w2_weight", else "experts.w2_",
f"experts.{expert_id}.{proj_name}", f"experts.{expert_id}.{proj_name}.",
expert_id, expert_id,
shard_id, shard_id,
) )
...@@ -1000,9 +1342,21 @@ class Gemma4Model(nn.Module): ...@@ -1000,9 +1342,21 @@ class Gemma4Model(nn.Module):
expert_id, expert_id,
shard_id, shard_id,
) in expert_params_mapping: ) in expert_params_mapping:
if weight_name not in name: # Match both:
# - Bare weights: "experts.0.down_proj" (from 3D explosion)
# - With suffix: "experts.0.down_proj.weight_scale" (2D quantized)
# weight_name has trailing dot, so check with and without it
weight_name_base = weight_name.rstrip(".")
if weight_name in name:
# Has suffix (e.g., .weight_scale)
moe_name = name.replace(weight_name, param_name)
elif name.endswith(weight_name_base):
# Bare weight (no suffix)
moe_name = name.replace(
weight_name_base, param_name.rstrip("_") + "_weight"
)
else:
continue continue
moe_name = name.replace(weight_name, param_name)
if moe_name not in params_dict: if moe_name not in params_dict:
continue continue
if is_pp_missing_parameter(moe_name, self): if is_pp_missing_parameter(moe_name, self):
...@@ -1012,15 +1366,12 @@ class Gemma4Model(nn.Module): ...@@ -1012,15 +1366,12 @@ class Gemma4Model(nn.Module):
# orientation for FusedMoE after _weight_iterator: # orientation for FusedMoE after _weight_iterator:
# gate/up: [I, H] → w1/w3 expects [I, H] # gate/up: [I, H] → w1/w3 expects [I, H]
# down: [H, I] → w2 expects [H, I] # down: [H, I] → w2 expects [H, I]
assert loaded_weight.dim() == 2, ( # Scales and other quantization params may be 1D or scalar.
f"Expected 2D expert weight for {weight_name}, "
f"got shape {loaded_weight.shape}"
)
weight_loader = param.weight_loader weight_loader = param.weight_loader
weight_loader( weight_loader(
param, param,
loaded_weight, loaded_weight,
weight_name + ".weight", moe_name, # Pass mapped name (handles both weights and scales)
shard_id=shard_id, shard_id=shard_id,
expert_id=expert_id, expert_id=expert_id,
) )
...@@ -1044,7 +1395,25 @@ class Gemma4Model(nn.Module): ...@@ -1044,7 +1395,25 @@ class Gemma4Model(nn.Module):
return loaded_params return loaded_params
class Gemma4ForCausalLM(nn.Module, SupportsLoRA, SupportsPP, MixtureOfExperts): class Gemma4ForCausalLM(
nn.Module, SupportsLoRA, SupportsPP, MixtureOfExperts, SupportsEagle3
):
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={
# Gemma4ForConditionalGeneration already loads the text stack
# from `model.language_model.*`. We reuse that same checkpoint
# and adapter naming for the text-only Gemma4ForCausalLM path,
# so LoRA keys from the conditional wrapper map onto `model.*`.
"model.language_model.": "model.",
},
orig_to_new_substr={
# Gemma4ForConditionalGeneration names MoE adapter targets under
# `...moe.experts.*`, while the text-only model exposes them
# under `...moe.*`.
".moe.experts.gate_up_proj": ".moe.gate_up_proj",
".moe.experts.down_proj": ".moe.down_proj",
},
)
# Note: qkv_proj packing applies to non-k_eq_v layers (sliding # Note: qkv_proj packing applies to non-k_eq_v layers (sliding
# attention and full attention without k_eq_v). k_eq_v layers use # attention and full attention without k_eq_v). k_eq_v layers use
# separate q_proj + k_proj without packing. # separate q_proj + k_proj without packing.
...@@ -1126,7 +1495,7 @@ class Gemma4ForCausalLM(nn.Module, SupportsLoRA, SupportsPP, MixtureOfExperts): ...@@ -1126,7 +1495,7 @@ class Gemma4ForCausalLM(nn.Module, SupportsLoRA, SupportsPP, MixtureOfExperts):
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
**kwargs, **kwargs,
) -> torch.Tensor | IntermediateTensors: ) -> torch.Tensor | IntermediateTensors | tuple[torch.Tensor, list[torch.Tensor]]:
hidden_states = self.model( hidden_states = self.model(
input_ids, positions, intermediate_tensors, inputs_embeds, **kwargs input_ids, positions, intermediate_tensors, inputs_embeds, **kwargs
) )
...@@ -1177,6 +1546,11 @@ class Gemma4ForCausalLM(nn.Module, SupportsLoRA, SupportsPP, MixtureOfExperts): ...@@ -1177,6 +1546,11 @@ class Gemma4ForCausalLM(nn.Module, SupportsLoRA, SupportsPP, MixtureOfExperts):
".moe.down_proj", ".moe.down_proj",
) )
# Remap individual 2D expert weights:
# .experts.{id}.{proj} → .moe.experts.{id}.{proj}
# (This handles per-expert 2D quantized weights)
name = re.sub(r"\.experts\.(\d+)\.", r".moe.experts.\1.", name)
# MoE expert weights: checkpoint stores as 3D packed # MoE expert weights: checkpoint stores as 3D packed
# tensors. Explode into per-expert 2D weights for # tensors. Explode into per-expert 2D weights for
# FusedMoE weight_loader. # FusedMoE weight_loader.
......
...@@ -65,7 +65,12 @@ from vllm.multimodal.processing.processor import ( ...@@ -65,7 +65,12 @@ from vllm.multimodal.processing.processor import (
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils.tensor_schema import TensorSchema, TensorShape from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .interfaces import (
MultiModalEmbeddings,
SupportsEagle3,
SupportsMultiModal,
SupportsPP,
)
from .utils import ( from .utils import (
AutoWeightsLoader, AutoWeightsLoader,
WeightsMapper, WeightsMapper,
...@@ -121,8 +126,12 @@ class Gemma4AudioInputs(TensorSchema): ...@@ -121,8 +126,12 @@ class Gemma4AudioInputs(TensorSchema):
""" """
type: Literal["audio"] = "audio" type: Literal["audio"] = "audio"
input_features_padded: Annotated[torch.Tensor, TensorShape("bn", "s", "f")] input_features_padded: Annotated[
input_features_mask: Annotated[torch.Tensor, TensorShape("bn", "s")] torch.Tensor, TensorShape("bn", "s", "f", dynamic_dims={"s"})
]
input_features_mask: Annotated[
torch.Tensor, TensorShape("bn", "s", dynamic_dims={"s"})
]
Gemma4ImageInputs = Gemma4ImagePixelInputs Gemma4ImageInputs = Gemma4ImagePixelInputs
...@@ -163,10 +172,15 @@ class Gemma4ProcessingInfo(BaseProcessingInfo): ...@@ -163,10 +172,15 @@ class Gemma4ProcessingInfo(BaseProcessingInfo):
Setting ``add_special_tokens=False`` here prevents the duplicate and Setting ``add_special_tokens=False`` here prevents the duplicate and
ensures both ``llm.generate()`` and the chat/completions API behave ensures both ``llm.generate()`` and the chat/completions API behave
correctly. correctly for IT models. For PT models (without chat template), we
keep the default (True) to ensure BOS is added for raw prompts.
""" """
tokenizer = self.ctx.get_tokenizer()
has_chat_template = getattr(tokenizer, "chat_template", None) is not None
params = super().get_default_tok_params() params = super().get_default_tok_params()
params = params.with_kwargs(add_special_tokens=False) if has_chat_template:
params = params.with_kwargs(add_special_tokens=False)
return params return params
def get_hf_processor(self, **kwargs: object) -> Gemma4Processor: def get_hf_processor(self, **kwargs: object) -> Gemma4Processor:
...@@ -503,6 +517,8 @@ class Gemma4MultiModalProcessor(BaseMultiModalProcessor[Gemma4ProcessingInfo]): ...@@ -503,6 +517,8 @@ class Gemma4MultiModalProcessor(BaseMultiModalProcessor[Gemma4ProcessingInfo]):
video_timestamps_per_video: list[list[float]] = [] video_timestamps_per_video: list[list[float]] = []
video_frame_counts: list[int] = [] video_frame_counts: list[int] = []
video_replacements: list[str] = []
for item in videos: for item in videos:
video_array, metadata = item video_array, metadata = item
...@@ -555,10 +571,7 @@ class Gemma4MultiModalProcessor(BaseMultiModalProcessor[Gemma4ProcessingInfo]): ...@@ -555,10 +571,7 @@ class Gemma4MultiModalProcessor(BaseMultiModalProcessor[Gemma4ProcessingInfo]):
video_timestamps_per_video.append(timestamps) video_timestamps_per_video.append(timestamps)
video_frame_counts.append(len(frames)) video_frame_counts.append(len(frames))
# Build expanded replacement text and replace the # Build expanded replacement text for this video.
# <|video|> placeholder in the prompt.
# Use split(token, 1) to avoid collision — the
# replacement text itself contains <|video|> tokens.
ts_strs = [f"{int(s // 60):02d}:{int(s % 60):02d}" for s in timestamps] ts_strs = [f"{int(s // 60):02d}:{int(s % 60):02d}" for s in timestamps]
replacement = " ".join( replacement = " ".join(
f"{t} {processor.boi_token}" f"{t} {processor.boi_token}"
...@@ -566,9 +579,23 @@ class Gemma4MultiModalProcessor(BaseMultiModalProcessor[Gemma4ProcessingInfo]): ...@@ -566,9 +579,23 @@ class Gemma4MultiModalProcessor(BaseMultiModalProcessor[Gemma4ProcessingInfo]):
f"{processor.eoi_token}" f"{processor.eoi_token}"
for t, n in zip(ts_strs, num_soft_per_frame) for t, n in zip(ts_strs, num_soft_per_frame)
) )
parts = prompt.split(processor.video_token, 1) video_replacements.append(replacement)
if len(parts) == 2:
prompt = parts[0] + replacement + parts[1] # Replace all <|video|> placeholders at once. We split on
# video_token to get N+1 parts, then interleave with the
# N replacement strings. This avoids the iterative
# split-replace bug where replacement text (which itself
# contains <|video|> tokens) collides with later splits.
vt = processor.video_token
parts = prompt.split(vt, len(video_replacements))
# NOTE: len(parts) <= len(video_replacements) + 1
parts_with_repl: list[str] = []
for part, repl in zip(parts, video_replacements):
parts_with_repl.extend([part, repl])
parts_with_repl.extend(parts[len(video_replacements) :])
prompt = "".join(parts_with_repl)
video_outputs = { video_outputs = {
"pixel_values_videos": torch.cat(all_video_pixel_values, dim=0), "pixel_values_videos": torch.cat(all_video_pixel_values, dim=0),
...@@ -631,19 +658,23 @@ class Gemma4MultiModalProcessor(BaseMultiModalProcessor[Gemma4ProcessingInfo]): ...@@ -631,19 +658,23 @@ class Gemma4MultiModalProcessor(BaseMultiModalProcessor[Gemma4ProcessingInfo]):
) )
if "input_features" in processed_outputs: if "input_features" in processed_outputs:
# Keep padded features for batched audio tower execution. # Unpad per-item so each item's cache entry is
processed_outputs["input_features_padded"] = processed_outputs[ # self-contained. The batched() field config in
"input_features" # _get_mm_fields_config will re-pad all fields to the
] # batch's max length at batch time, ensuring consistent
# Unpad per-item so each item's cache entry is self-contained. # padding regardless of cache history.
masks = processed_outputs["input_features_mask"]
unpadded_features = [ unpadded_features = [
f[mask] f[mask]
for f, mask in zip( for f, mask in zip(
processed_outputs["input_features"], processed_outputs["input_features"],
processed_outputs["input_features_mask"], masks,
) )
] ]
unpadded_masks = [mask[mask] for mask in masks]
processed_outputs["input_features"] = unpadded_features processed_outputs["input_features"] = unpadded_features
processed_outputs["input_features_padded"] = unpadded_features
processed_outputs["input_features_mask"] = unpadded_masks
# Merge video outputs into the final result # Merge video outputs into the final result
combined_outputs = dict(processed_outputs, **video_outputs) combined_outputs = dict(processed_outputs, **video_outputs)
...@@ -848,7 +879,12 @@ class Gemma4MultimodalEmbedder(nn.Module): ...@@ -848,7 +879,12 @@ class Gemma4MultimodalEmbedder(nn.Module):
info=Gemma4ProcessingInfo, info=Gemma4ProcessingInfo,
dummy_inputs=Gemma4DummyInputsBuilder, dummy_inputs=Gemma4DummyInputsBuilder,
) )
class Gemma4ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): class Gemma4ForConditionalGeneration(
nn.Module,
SupportsMultiModal,
SupportsPP,
SupportsEagle3,
):
packed_modules_mapping = { packed_modules_mapping = {
"qkv_proj": [ "qkv_proj": [
"q_proj", "q_proj",
......
...@@ -113,7 +113,29 @@ class KimiK25ProcessingInfo(BaseProcessingInfo): ...@@ -113,7 +113,29 @@ class KimiK25ProcessingInfo(BaseProcessingInfo):
trust_remote_code=self.ctx.model_config.trust_remote_code, trust_remote_code=self.ctx.model_config.trust_remote_code,
) )
self.media_token_id = media_token_id = hf_config.media_placeholder_token_id # Resolve token ID from the tokenizer because transformers v5
# may remap token IDs vs config.json.
config_token_id = hf_config.media_placeholder_token_id
resolved_token_id = tokenizer.convert_tokens_to_ids("<|media_pad|>")
is_valid_resolved = isinstance(resolved_token_id, int) and (
tokenizer.unk_token_id is None
or resolved_token_id != tokenizer.unk_token_id
)
if is_valid_resolved and resolved_token_id != config_token_id:
logger.warning_once(
"Kimi-K2.5 config.media_placeholder_token_id (%d) disagrees "
"with tokenizer mapping for <|media_pad|> (%d). "
"Using tokenizer value.",
config_token_id,
resolved_token_id,
)
media_token_id = resolved_token_id
# Patch config so downstream code also sees the correct ID.
hf_config.media_placeholder_token_id = resolved_token_id
else:
media_token_id = config_token_id
self.media_token_id = media_token_id
self.media_token = tokenizer.decode(media_token_id) self.media_token = tokenizer.decode(media_token_id)
self.image_processor = image_processor self.image_processor = image_processor
...@@ -232,8 +254,7 @@ class KimiK25MultiModalProcessor(BaseMultiModalProcessor[KimiK25ProcessingInfo]) ...@@ -232,8 +254,7 @@ class KimiK25MultiModalProcessor(BaseMultiModalProcessor[KimiK25ProcessingInfo])
hf_processor_mm_kwargs: Mapping[str, Any], hf_processor_mm_kwargs: Mapping[str, Any],
out_mm_kwargs: MultiModalKwargsItems, out_mm_kwargs: MultiModalKwargsItems,
) -> Sequence[PromptUpdate]: ) -> Sequence[PromptUpdate]:
hf_config = self.info.get_hf_config() media_token_id = self.info.media_token_id
media_token_id = hf_config.media_placeholder_token_id
def get_replacement(item_idx: int): def get_replacement(item_idx: int):
media = mm_items.get_items("vision_chunk", (VisionChunkProcessorItems,)) media = mm_items.get_items("vision_chunk", (VisionChunkProcessorItems,))
......
...@@ -232,9 +232,7 @@ class MiniMaxM2Attention(nn.Module): ...@@ -232,9 +232,7 @@ class MiniMaxM2Attention(nn.Module):
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states) qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = MiniMaxText01RMSNormTP.forward_qk( q, k = MiniMaxText01RMSNormTP.forward_qk(self.q_norm, self.k_norm, q, k)
self.q_norm, self.k_norm, q.contiguous(), k.contiguous()
)
q, k = self.rotary_emb(positions, q, k) q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v) attn_output = self.attn(q, k, v)
output, _ = self.o_proj(attn_output) output, _ = self.o_proj(attn_output)
......
...@@ -32,9 +32,9 @@ from transformers.models.musicflamingo import ( ...@@ -32,9 +32,9 @@ from transformers.models.musicflamingo import (
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions from vllm.config.multimodal import BaseDummyOptions
from vllm.inputs import MultiModalDataDict
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import ( from vllm.multimodal.inputs import (
MultiModalDataDict,
MultiModalFieldConfig, MultiModalFieldConfig,
MultiModalKwargsItems, MultiModalKwargsItems,
) )
......
...@@ -16,13 +16,11 @@ ...@@ -16,13 +16,11 @@
# limitations under the License. # limitations under the License.
"""Wrapper around `transformers` models""" """Wrapper around `transformers` models"""
from vllm.compilation.decorators import support_torch_compile
from vllm.model_executor.models.transformers.base import Base from vllm.model_executor.models.transformers.base import Base
from vllm.model_executor.models.transformers.causal import CausalMixin from vllm.model_executor.models.transformers.causal import CausalMixin
from vllm.model_executor.models.transformers.legacy import LegacyMixin from vllm.model_executor.models.transformers.legacy import LegacyMixin
from vllm.model_executor.models.transformers.moe import MoEMixin from vllm.model_executor.models.transformers.moe import MoEMixin
from vllm.model_executor.models.transformers.multimodal import ( from vllm.model_executor.models.transformers.multimodal import (
DYNAMIC_ARG_DIMS,
MultiModalDummyInputsBuilder, MultiModalDummyInputsBuilder,
MultiModalMixin, MultiModalMixin,
MultiModalProcessingInfo, MultiModalProcessingInfo,
...@@ -32,16 +30,13 @@ from vllm.model_executor.models.transformers.pooling import ( ...@@ -32,16 +30,13 @@ from vllm.model_executor.models.transformers.pooling import (
EmbeddingMixin, EmbeddingMixin,
SequenceClassificationMixin, SequenceClassificationMixin,
) )
from vllm.model_executor.models.transformers.utils import can_enable_torch_compile
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
# Text only models # Text only models
@support_torch_compile(enable_if=can_enable_torch_compile)
class TransformersForCausalLM(CausalMixin, Base): ... class TransformersForCausalLM(CausalMixin, Base): ...
@support_torch_compile(enable_if=can_enable_torch_compile)
class TransformersMoEForCausalLM(MoEMixin, CausalMixin, Base): ... class TransformersMoEForCausalLM(MoEMixin, CausalMixin, Base): ...
...@@ -51,9 +46,6 @@ class TransformersMoEForCausalLM(MoEMixin, CausalMixin, Base): ... ...@@ -51,9 +46,6 @@ class TransformersMoEForCausalLM(MoEMixin, CausalMixin, Base): ...
info=MultiModalProcessingInfo, info=MultiModalProcessingInfo,
dummy_inputs=MultiModalDummyInputsBuilder, dummy_inputs=MultiModalDummyInputsBuilder,
) )
@support_torch_compile(
dynamic_arg_dims=DYNAMIC_ARG_DIMS, enable_if=can_enable_torch_compile
)
class TransformersMultiModalForCausalLM(MultiModalMixin, CausalMixin, Base): ... class TransformersMultiModalForCausalLM(MultiModalMixin, CausalMixin, Base): ...
...@@ -62,20 +54,15 @@ class TransformersMultiModalForCausalLM(MultiModalMixin, CausalMixin, Base): ... ...@@ -62,20 +54,15 @@ class TransformersMultiModalForCausalLM(MultiModalMixin, CausalMixin, Base): ...
info=MultiModalProcessingInfo, info=MultiModalProcessingInfo,
dummy_inputs=MultiModalDummyInputsBuilder, dummy_inputs=MultiModalDummyInputsBuilder,
) )
@support_torch_compile(
dynamic_arg_dims=DYNAMIC_ARG_DIMS, enable_if=can_enable_torch_compile
)
class TransformersMultiModalMoEForCausalLM( class TransformersMultiModalMoEForCausalLM(
MoEMixin, MultiModalMixin, CausalMixin, Base MoEMixin, MultiModalMixin, CausalMixin, Base
): ... ): ...
# Embedding models # Embedding models
@support_torch_compile(enable_if=can_enable_torch_compile)
class TransformersEmbeddingModel(EmbeddingMixin, LegacyMixin, Base): ... class TransformersEmbeddingModel(EmbeddingMixin, LegacyMixin, Base): ...
@support_torch_compile(enable_if=can_enable_torch_compile)
class TransformersMoEEmbeddingModel(EmbeddingMixin, MoEMixin, Base): ... class TransformersMoEEmbeddingModel(EmbeddingMixin, MoEMixin, Base): ...
...@@ -84,20 +71,15 @@ class TransformersMoEEmbeddingModel(EmbeddingMixin, MoEMixin, Base): ... ...@@ -84,20 +71,15 @@ class TransformersMoEEmbeddingModel(EmbeddingMixin, MoEMixin, Base): ...
info=MultiModalProcessingInfo, info=MultiModalProcessingInfo,
dummy_inputs=MultiModalDummyInputsBuilder, dummy_inputs=MultiModalDummyInputsBuilder,
) )
@support_torch_compile(
dynamic_arg_dims=DYNAMIC_ARG_DIMS, enable_if=can_enable_torch_compile
)
class TransformersMultiModalEmbeddingModel(EmbeddingMixin, MultiModalMixin, Base): ... class TransformersMultiModalEmbeddingModel(EmbeddingMixin, MultiModalMixin, Base): ...
# Sequence classification models # Sequence classification models
@support_torch_compile(enable_if=can_enable_torch_compile)
class TransformersForSequenceClassification( class TransformersForSequenceClassification(
SequenceClassificationMixin, LegacyMixin, Base SequenceClassificationMixin, LegacyMixin, Base
): ... ): ...
@support_torch_compile(enable_if=can_enable_torch_compile)
class TransformersMoEForSequenceClassification( class TransformersMoEForSequenceClassification(
SequenceClassificationMixin, MoEMixin, Base SequenceClassificationMixin, MoEMixin, Base
): ... ): ...
...@@ -108,9 +90,6 @@ class TransformersMoEForSequenceClassification( ...@@ -108,9 +90,6 @@ class TransformersMoEForSequenceClassification(
info=MultiModalProcessingInfo, info=MultiModalProcessingInfo,
dummy_inputs=MultiModalDummyInputsBuilder, dummy_inputs=MultiModalDummyInputsBuilder,
) )
@support_torch_compile(
dynamic_arg_dims=DYNAMIC_ARG_DIMS, enable_if=can_enable_torch_compile
)
class TransformersMultiModalForSequenceClassification( class TransformersMultiModalForSequenceClassification(
SequenceClassificationMixin, MultiModalMixin, Base SequenceClassificationMixin, MultiModalMixin, Base
): ... ): ...
......
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
# limitations under the License. # limitations under the License.
"""Transformers modeling backend base class.""" """Transformers modeling backend base class."""
import sys
from collections.abc import Callable, Iterable from collections.abc import Callable, Iterable
from itertools import chain from itertools import chain
from operator import attrgetter from operator import attrgetter
...@@ -29,6 +30,7 @@ from torch import nn ...@@ -29,6 +30,7 @@ from torch import nn
from transformers import AutoModel from transformers import AutoModel
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
from vllm.compilation.decorators import support_torch_compile
from vllm.config.utils import getattr_iter from vllm.config.utils import getattr_iter
from vllm.distributed import get_pp_group, get_tp_group from vllm.distributed import get_pp_group, get_tp_group
from vllm.distributed.utils import get_pp_indices from vllm.distributed.utils import get_pp_indices
...@@ -47,6 +49,7 @@ from vllm.model_executor.models.interfaces import ( ...@@ -47,6 +49,7 @@ from vllm.model_executor.models.interfaces import (
) )
from vllm.model_executor.models.interfaces_base import VllmModel from vllm.model_executor.models.interfaces_base import VllmModel
from vllm.model_executor.models.transformers.utils import ( from vllm.model_executor.models.transformers.utils import (
can_enable_torch_compile,
get_feature_request_tip, get_feature_request_tip,
init_on_device_without_buffers, init_on_device_without_buffers,
log_replacement, log_replacement,
...@@ -117,6 +120,7 @@ class Base( ...@@ -117,6 +120,7 @@ class Base(
self.config = vllm_config.model_config.hf_config self.config = vllm_config.model_config.hf_config
self.text_config = self.config.get_text_config() self.text_config = self.config.get_text_config()
self.cache_config = vllm_config.cache_config self.cache_config = vllm_config.cache_config
self.compilation_config = vllm_config.compilation_config
self.device_config = vllm_config.device_config self.device_config = vllm_config.device_config
self.model_config = vllm_config.model_config self.model_config = vllm_config.model_config
self.parallel_config = vllm_config.parallel_config self.parallel_config = vllm_config.parallel_config
...@@ -146,7 +150,7 @@ class Base( ...@@ -146,7 +150,7 @@ class Base(
if self.quant_config: if self.quant_config:
quant_method_name = self.quant_config.get_name() quant_method_name = self.quant_config.get_name()
# Check for unsupported quantization methods. # Check for unsupported quantization methods.
if quant_method_name == "mxfp4": if quant_method_name in ("mxfp4", "gpt_oss_mxfp4"):
raise NotImplementedError( raise NotImplementedError(
"Transformers modeling backend does " "Transformers modeling backend does "
"not support MXFP4 quantization yet." "not support MXFP4 quantization yet."
...@@ -155,14 +159,16 @@ class Base( ...@@ -155,14 +159,16 @@ class Base(
if "gptq" in quant_method_name: if "gptq" in quant_method_name:
self.ignore_unexpected_suffixes.append(".bias") self.ignore_unexpected_suffixes.append(".bias")
# Patch config and init on "meta" to delay allocating GPU tensors
self._patch_config() self._patch_config()
from_config_kwargs = dict(
config=self.config,
dtype=self.model_config.dtype,
trust_remote_code=self.model_config.trust_remote_code,
)
self._decorate_for_torch_compile(**from_config_kwargs)
# Init on "meta" to delay allocating GPU tensors
with init_on_device_without_buffers("meta"): with init_on_device_without_buffers("meta"):
self.model: PreTrainedModel = AutoModel.from_config( self.model: PreTrainedModel = AutoModel.from_config(**from_config_kwargs)
self.config,
dtype=self.model_config.dtype,
trust_remote_code=self.model_config.trust_remote_code,
)
# Create weight name to module qualname mapper # Create weight name to module qualname mapper
self._create_hf_to_vllm_mapper() self._create_hf_to_vllm_mapper()
...@@ -218,6 +224,87 @@ class Base( ...@@ -218,6 +224,87 @@ class Base(
if sub_config.dtype != (dtype := self.config.dtype): if sub_config.dtype != (dtype := self.config.dtype):
sub_config.dtype = dtype sub_config.dtype = dtype
def _get_decoder_cls(self, **kwargs: dict) -> type[PreTrainedModel]:
    """
    Work out which class the model uses for its decoder.

    A throwaway model is built from `kwargs` inside a `torch.device("meta")`
    context, its `get_decoder()` result is inspected, and the throwaway
    model is dropped again before returning.

    Args:
        kwargs: Keyword arguments forwarded to `AutoModel.from_config`.

    Returns:
        The type of the object returned by the model's `get_decoder()`.
    """
    with torch.device("meta"):
        probe: PreTrainedModel = AutoModel.from_config(**kwargs)
    decoder_type = type(probe.get_decoder())
    logger.debug("Identified decoder class as: %s", decoder_type)
    # Only the class is needed; release the probe model explicitly.
    del probe
    return decoder_type
def _decorate_cls_for_torch_compile(
    self,
    cls: type[PreTrainedModel],
    dynamic_arg_dims: dict[str, int] | None,
    enable_if: Callable[["VllmConfig"], bool],
    is_encoder: bool,
):
    """
    Mark `cls` as torch-compile capable by swapping in a decorated subclass.

    A subclass of `cls` is created and decorated with `support_torch_compile`;
    the subclass then replaces `cls` under the name `cls.__name__` in the
    module `cls` was defined in.

    Args:
        cls: The PreTrainedModel class to decorate.
        dynamic_arg_dims: Mapping from argument name to that argument's dynamic
            dimension(s). If None, the default dynamic arg dims are used. See
            [`support_torch_compile`][vllm.compilation.decorators.support_torch_compile]
            for more details.
        enable_if: Callable that receives the vLLM config and returns whether
            torch compile should be enabled for this class.
        is_encoder: Whether `cls` is an encoder (as opposed to a decoder).
    """
    role = "encoder" if is_encoder else "decoder"
    logger.debug(
        "Decorating `%s` as %s for torch compile with dynamic_arg_dims of %s",
        cls.__name__,
        role,
        dynamic_arg_dims,
    )

    mark_compilable = support_torch_compile(
        dynamic_arg_dims=dynamic_arg_dims,
        enable_if=enable_if,
        is_encoder=is_encoder,
    )

    @mark_compilable
    class SupportTorchCompileWrapper(cls): ...

    # Keep the original `__module__` so that transformers v5's source-file
    # checks (e.g. _can_set_experts_implementation) inspect the original
    # model's module rather than this file.
    SupportTorchCompileWrapper.__module__ = cls.__module__

    # Rebind the class name inside its defining module to the wrapper.
    setattr(sys.modules[cls.__module__], cls.__name__, SupportTorchCompileWrapper)
def _decorate_for_torch_compile(self, **kwargs: dict):
    """
    Register the model's decoder class as torch-compile capable.

    The decoder class is decorated with `can_enable_torch_compile` as its
    `enable_if` predicate, so compilation only applies when that predicate
    returns True.

    Args:
        kwargs: The kwargs used to create the model; they are needed to look
            up the decoder class.
    """
    decoder_cls = self._get_decoder_cls(**kwargs)
    # The decorator is applied to a PreTrainedModel, so a batch dimension is
    # present on every input.
    decoder_dynamic_dims: dict[str, int] = {
        "input_ids": 1,  # shape: [1, seq_len]
        "inputs_embeds": 1,  # shape: [1, seq_len, hidden_size]
        "position_ids": -1,  # shape: [1, seq_len] or [3, 1, seq_len] for mrope
    }
    self._decorate_cls_for_torch_compile(
        cls=decoder_cls,
        dynamic_arg_dims=decoder_dynamic_dims,
        enable_if=can_enable_torch_compile,
        is_encoder=False,
    )
def _create_hf_to_vllm_mapper(self): def _create_hf_to_vllm_mapper(self):
""" """
Create a WeightsMapper to map checkpoint weight names to module qualnames. Create a WeightsMapper to map checkpoint weight names to module qualnames.
...@@ -553,11 +640,6 @@ class Base( ...@@ -553,11 +640,6 @@ class Base(
input_ids = None input_ids = None
inputs_embeds = intermediate_tensors["hidden_states"] inputs_embeds = intermediate_tensors["hidden_states"]
if input_ids is not None:
input_ids = input_ids[None, ...]
if inputs_embeds is not None:
inputs_embeds = inputs_embeds[None, ...]
# If the model scales embeddings inside the input embedding layer we must # If the model scales embeddings inside the input embedding layer we must
# ensure they are scaled here since VocabParallelEmbedding will not do it # ensure they are scaled here since VocabParallelEmbedding will not do it
if ( if (
...@@ -568,22 +650,29 @@ class Base( ...@@ -568,22 +650,29 @@ class Base(
inputs_embeds = self.embed_input_ids(input_ids) inputs_embeds = self.embed_input_ids(input_ids)
input_ids = None input_ids = None
if self.model_config.uses_mrope: # Add batch dimension before entering Transformers model
position_ids = positions[:, None] if input_ids is not None and input_ids.ndim == 1:
else: # [seq_len] -> [1, seq_len]
position_ids = positions[None, ...] input_ids = input_ids[None, ...]
if inputs_embeds is not None and inputs_embeds.ndim == 2:
# [seq_len, hidden_size] -> [1, seq_len, hidden_size]
inputs_embeds = inputs_embeds[None, ...]
if positions.ndim == 1:
# [seq_len] -> [1, seq_len]
positions = positions[None, ...]
outputs = self.model( outputs = self.model(
input_ids=input_ids, input_ids=input_ids,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
use_cache=False, use_cache=False,
position_ids=position_ids, position_ids=positions,
attention_instances=self.attention_instances, attention_instances=self.attention_instances,
return_dict=False, return_dict=False,
**self._output_aux_hidden_states_kwargs, **self._output_aux_hidden_states_kwargs,
**kwargs, **kwargs,
) )
# We must remove the batch dimension from these outputs
# Remove batch dimension after exiting Transformers model
hidden_states = outputs[0][0, ...] hidden_states = outputs[0][0, ...]
if self._output_aux_hidden_states_kwargs: if self._output_aux_hidden_states_kwargs:
aux_hidden_states = [x[0][0, ...] for x in outputs[1:]] aux_hidden_states = [x[0][0, ...] for x in outputs[1:]]
......
...@@ -20,7 +20,9 @@ from collections.abc import Mapping ...@@ -20,7 +20,9 @@ from collections.abc import Mapping
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
import torch import torch
from transformers import AutoModel
from vllm.compilation.decorators import should_torch_compile_mm_encoder
from vllm.config.utils import getattr_iter from vllm.config.utils import getattr_iter
from vllm.inputs import MultiModalDataDict, MultiModalInput, mm_input from vllm.inputs import MultiModalDataDict, MultiModalInput, mm_input
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -46,19 +48,11 @@ from vllm.platforms import current_platform ...@@ -46,19 +48,11 @@ from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
if TYPE_CHECKING: if TYPE_CHECKING:
from transformers import BatchFeature from transformers import BatchFeature, PreTrainedModel
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions from vllm.config.multimodal import BaseDummyOptions
DYNAMIC_ARG_DIMS = {
"input_ids": 0,
# set `positions` to last dim to support Qwen-mrope
"positions": -1,
"intermediate_tensors": 0,
"inputs_embeds": 0,
}
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -274,6 +268,66 @@ class MultiModalMixin(SupportsMultiModal, SupportsMRoPE): ...@@ -274,6 +268,66 @@ class MultiModalMixin(SupportsMultiModal, SupportsMRoPE):
# Skip SupportsMRoPE.__init__ and call the next class in MRO # Skip SupportsMRoPE.__init__ and call the next class in MRO
super(SupportsMRoPE, self).__init__(vllm_config=vllm_config, prefix=prefix) super(SupportsMRoPE, self).__init__(vllm_config=vllm_config, prefix=prefix)
def _get_encoder_cls(
    self, modality: str = "image", **kwargs: dict
) -> type["PreTrainedModel"]:
    """
    Work out which class the model uses for its encoder.

    A throwaway model is built from `kwargs` inside a `torch.device("meta")`
    context and its `get_encoder(modality=...)` result is inspected.

    Args:
        modality: Passed through to the model's `get_encoder` as `modality`.
        kwargs: Keyword arguments forwarded to `AutoModel.from_config`.

    Returns:
        The type of the object returned by `get_encoder`.

    Raises:
        ValueError: If `get_encoder` returns an object of the same type as
            the model itself, i.e. no separate encoder class was found.
    """
    with torch.device("meta"):
        probe: PreTrainedModel = AutoModel.from_config(**kwargs)
    encoder_type = type(probe.get_encoder(modality=modality))
    logger.debug("Identified encoder class as: %s", encoder_type)
    if encoder_type is type(probe):
        raise ValueError(
            "Unable to infer vision encoder class from the model. "
            "You must either: update the model so that "
            "https://huggingface.co/docs/transformers/en/main_classes/model#transformers.PreTrainedModel.get_encoder"
            " can detect the vision encoder correctly, or remove "
            "'compile_mm_encoder'."
        )
    # Only the class is needed; release the probe model explicitly.
    del probe
    return encoder_type
def _decorate_for_torch_compile(self, **kwargs: dict):
    """
    Register the decoder and, if requested, the encoder for torch compile.

    The decoder is handled by the parent implementation. The encoder class is
    decorated only when `compilation_config.compile_mm_encoder` is set, with
    `should_torch_compile_mm_encoder` as its `enable_if` predicate.

    Args:
        kwargs: The kwargs used to create the model; they are needed to look
            up the decoder and encoder classes.
    """
    super()._decorate_for_torch_compile(**kwargs)

    # Nothing more to do unless encoder compilation was asked for.
    if not self.compilation_config.compile_mm_encoder:
        return

    self.check_version("5.0.0", "multimodal encoder compilation support")
    logger.warning_once(
        "Multimodal encoder compilation with the Transformers modeling backend "
        "is an experimental feature. It relies on:\n"
        "- The vision encoder being torch compilable.\n"
        "- All vision encoder tensor inputs must be type hinted as either "
        "`torch.Tensor` or `torch.FloatTensor`.\n"
        "- The 0-th dimension of all tensor inputs to the vision encoder being "
        "the dynamic dimension (i.e., sequence length or number of patches).\n"
        "Please report any issues you encounter to help us improve it."
    )
    encoder_cls = self._get_encoder_cls(**kwargs)
    self._decorate_cls_for_torch_compile(
        cls=encoder_cls,
        # TODO: properly infer dynamic_arg_dims based on the encoder's forward
        # method signature. Currently we assume dim 0 for all tensor inputs.
        dynamic_arg_dims=None,
        enable_if=should_torch_compile_mm_encoder,
        is_encoder=True,
    )
def forward( def forward(
self, self,
input_ids: torch.Tensor | None, input_ids: torch.Tensor | None,
...@@ -285,6 +339,10 @@ class MultiModalMixin(SupportsMultiModal, SupportsMRoPE): ...@@ -285,6 +339,10 @@ class MultiModalMixin(SupportsMultiModal, SupportsMRoPE):
# Gemma3 and PaliGemma needs `token_type_ids` to work correctly # Gemma3 and PaliGemma needs `token_type_ids` to work correctly
# Other models will not have `token_type_ids` in kwargs # Other models will not have `token_type_ids` in kwargs
kwargs = {k: v for k, v in kwargs.items() if k == "token_type_ids"} kwargs = {k: v for k, v in kwargs.items() if k == "token_type_ids"}
# Positions shape handling for MRoPE models
if self.model_config.uses_mrope:
# [3, seq_len] -> [3, 1, seq_len]
positions = positions[:, None]
model_output = super().forward( model_output = super().forward(
input_ids, positions, intermediate_tensors, inputs_embeds, **kwargs input_ids, positions, intermediate_tensors, inputs_embeds, **kwargs
) )
......
...@@ -470,6 +470,15 @@ class DelegatingParser(Parser): ...@@ -470,6 +470,15 @@ class DelegatingParser(Parser):
# No tool calls # No tool calls
return [], content return [], content
def adjust_request(
    self, request: ChatCompletionRequest | ResponsesRequest
) -> ChatCompletionRequest | ResponsesRequest:
    """Pass `request` through the reasoning parser, then the tool parser.

    A parser that is None is skipped. Each parser's `adjust_request` result
    becomes the input to the next step, and the final result is returned.
    """
    if (reasoning_parser := self._reasoning_parser) is not None:
        request = reasoning_parser.adjust_request(request)
    if (tool_parser := self._tool_parser) is not None:
        request = tool_parser.adjust_request(request)
    return request
def extract_reasoning_streaming( def extract_reasoning_streaming(
self, self,
previous_text: str, previous_text: str,
......
...@@ -6,7 +6,7 @@ import os ...@@ -6,7 +6,7 @@ import os
from abc import abstractmethod from abc import abstractmethod
from collections.abc import Callable, Iterable, Sequence from collections.abc import Callable, Iterable, Sequence
from functools import cached_property from functools import cached_property
from typing import TYPE_CHECKING from typing import TYPE_CHECKING, cast
from vllm.entrypoints.mcp.tool_server import ToolServer from vllm.entrypoints.mcp.tool_server import ToolServer
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -150,6 +150,12 @@ class ReasoningParser: ...@@ -150,6 +150,12 @@ class ReasoningParser:
previously been parsed and extracted (see constructor) previously been parsed and extracted (see constructor)
""" """
def adjust_request(
    self, request: "ChatCompletionRequest | ResponsesRequest"
) -> "ChatCompletionRequest | ResponsesRequest":
    """Return `request` unchanged.

    Subclasses may override this to adjust request parameters.
    """
    return request
def prepare_structured_tag( def prepare_structured_tag(
self, self,
original_tag: str | None, original_tag: str | None,
...@@ -298,7 +304,7 @@ class ReasoningParserManager: ...@@ -298,7 +304,7 @@ class ReasoningParserManager:
if isinstance(name, str): if isinstance(name, str):
names = [name] names = [name]
elif is_list_of(name, str): elif is_list_of(name, str):
names = name names = cast(list[str], name)
else: else:
names = [class_name] names = [class_name]
......
...@@ -52,6 +52,16 @@ class Gemma4ReasoningParser(BaseThinkingReasoningParser): ...@@ -52,6 +52,16 @@ class Gemma4ReasoningParser(BaseThinkingReasoningParser):
# skip_special_tokens=True). # skip_special_tokens=True).
self._reasoning_text: str = "" self._reasoning_text: str = ""
self._prefix_stripped: bool = False self._prefix_stripped: bool = False
self.new_turn_token_id = self.vocab["<|turn>"]
self.tool_call_token_id = self.vocab["<|tool_call>"]
self.tool_response_token_id = self.vocab["<|tool_response>"]
def adjust_request(
    self, request: "ChatCompletionRequest | ResponsesRequest"
) -> "ChatCompletionRequest | ResponsesRequest":
    """Turn off special-token stripping so boundary tokens are preserved.

    Sets `skip_special_tokens` to False on `request` in place and returns
    that same request object.
    """
    request.skip_special_tokens = False
    return request
@property @property
def start_token(self) -> str: def start_token(self) -> str:
...@@ -63,6 +73,29 @@ class Gemma4ReasoningParser(BaseThinkingReasoningParser): ...@@ -63,6 +73,29 @@ class Gemma4ReasoningParser(BaseThinkingReasoningParser):
"""The token that ends reasoning content.""" """The token that ends reasoning content."""
return "<channel|>" return "<channel|>"
def is_reasoning_end(self, input_ids: Sequence[int]) -> bool:
    """Report whether reasoning has ended, judged by the last marker token.

    `input_ids` is scanned from the end; the first marker token encountered
    decides the result:

    - reasoning start token: reasoning is still open -> False
    - tool call token: a tool call is being generated -> True
    - new turn / tool response token: the model starts new reasoning after
      these, so reasoning is not considered ended -> False
    - reasoning end token -> True

    Returns False when no marker token is present.
    """
    reasoning_start = self.start_token_id
    reasoning_end = self.end_token_id
    turn_boundaries = (self.new_turn_token_id, self.tool_response_token_id)
    tool_call = self.tool_call_token_id

    # Walk backwards so the most recent marker wins.
    for token_id in reversed(input_ids):
        if token_id == reasoning_start:
            return False
        if token_id == tool_call:
            # A tool call is being generated, so reasoning must have ended.
            return True
        if token_id in turn_boundaries:
            # New reasoning begins after a new turn or a tool response.
            return False
        if token_id == reasoning_end:
            return True
    return False
# ------------------------------------------------------------------ # ------------------------------------------------------------------
# Non-streaming path # Non-streaming path
# ------------------------------------------------------------------ # ------------------------------------------------------------------
...@@ -159,11 +192,10 @@ class Gemma4ReasoningParser(BaseThinkingReasoningParser): ...@@ -159,11 +192,10 @@ class Gemma4ReasoningParser(BaseThinkingReasoningParser):
result.reasoning = stripped result.reasoning = stripped
return result return result
else: else:
# This entire delta was prefix — suppress it.
# Don't set _prefix_stripped yet; there may be more
# prefix chars to consume in the next delta.
if len(self._reasoning_text) >= prefix_len: if len(self._reasoning_text) >= prefix_len:
self._prefix_stripped = True self._prefix_stripped = True
result.reasoning = ""
return result
return None return None
# Case 2: Accumulated text is a strict prefix of # Case 2: Accumulated text is a strict prefix of
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import contextlib
from dataclasses import dataclass, field from dataclasses import dataclass, field
from functools import lru_cache from functools import lru_cache
from pathlib import Path from pathlib import Path
...@@ -10,6 +11,7 @@ from typing_extensions import TypeVar, assert_never ...@@ -10,6 +11,7 @@ from typing_extensions import TypeVar, assert_never
import vllm.envs as envs import vllm.envs as envs
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.transformers_utils.config import get_config
from vllm.transformers_utils.gguf_utils import ( from vllm.transformers_utils.gguf_utils import (
check_gguf_file, check_gguf_file,
get_gguf_file_path_from_hf, get_gguf_file_path_from_hf,
...@@ -31,6 +33,13 @@ if TYPE_CHECKING: ...@@ -31,6 +33,13 @@ if TYPE_CHECKING:
logger = init_logger(__name__) logger = init_logger(__name__)
# Model types whose hub tokenizer_class is incorrect and should be overridden with
# TokenizersBackend (the generic fast tokenizer). Adding a model type here is always a
# temporary workaround and better long term solutions are:
# - Add model type to MODELS_WITH_INCORRECT_HUB_TOKENIZER_CLASS in transformers (better)
# - Fix tokenizer_class on the hub for the affected models (best)
_MODEL_TYPES_WITH_INCORRECT_TOKENIZER_CLASS: set[str] = {"step3_vl"}
_VLLM_TOKENIZERS = { _VLLM_TOKENIZERS = {
"deepseek_v32": ("deepseek_v32", "DeepseekV32Tokenizer"), "deepseek_v32": ("deepseek_v32", "DeepseekV32Tokenizer"),
"grok2": ("grok2", "Grok2Tokenizer"), "grok2": ("grok2", "Grok2Tokenizer"),
...@@ -202,7 +211,31 @@ def get_tokenizer( ...@@ -202,7 +211,31 @@ def get_tokenizer(
**kwargs, **kwargs,
) )
if tokenizer_cls == TokenizerLike: # Ensure that, if the config were to come from vllm.transformers_utils.config, it is
# registered with AutoConfig before the tokenizer is loaded. This is necessary since
# tokenizer_cls_.from_pretrained will call AutoConfig.from_pretrained internally.
# This may fail for paths that don't have a model config (e.g. LoRA adapters),
# which is fine — those don't need custom config registration.
config = None
with contextlib.suppress(ValueError, OSError):
config = get_config(
tokenizer_name,
trust_remote_code=trust_remote_code,
revision=revision,
)
# Some models have an incorrect tokenizer_class on the hub.
# For these model types, bypass AutoTokenizer and use TokenizersBackend directly.
model_type = getattr(config, "model_type", None) if config else None
if model_type in _MODEL_TYPES_WITH_INCORRECT_TOKENIZER_CLASS:
from transformers.tokenization_utils_tokenizers import TokenizersBackend
logger.debug(
"Overriding tokenizer_class to TokenizersBackend for model_type=%r",
model_type,
)
tokenizer_cls_ = TokenizersBackend
elif tokenizer_cls == TokenizerLike:
tokenizer_cls_ = TokenizerRegistry.load_tokenizer_cls(tokenizer_mode) tokenizer_cls_ = TokenizerRegistry.load_tokenizer_cls(tokenizer_mode)
else: else:
tokenizer_cls_ = tokenizer_cls tokenizer_cls_ = tokenizer_cls
......
...@@ -66,6 +66,10 @@ def _parse_gemma4_value(value_str: str) -> object: ...@@ -66,6 +66,10 @@ def _parse_gemma4_value(value_str: str) -> object:
if value_str == "false": if value_str == "false":
return False return False
# Null
if value_str.lower() in ("null", "none", "nil"):
return None
# Number (int or float) # Number (int or float)
try: try:
if "." in value_str: if "." in value_str:
...@@ -78,7 +82,7 @@ def _parse_gemma4_value(value_str: str) -> object: ...@@ -78,7 +82,7 @@ def _parse_gemma4_value(value_str: str) -> object:
return value_str return value_str
def _parse_gemma4_args(args_str: str) -> dict: def _parse_gemma4_args(args_str: str, *, partial: bool = False) -> dict:
"""Parse Gemma4's custom key:value format into a Python dict. """Parse Gemma4's custom key:value format into a Python dict.
Format examples:: Format examples::
...@@ -89,6 +93,12 @@ def _parse_gemma4_args(args_str: str) -> dict: ...@@ -89,6 +93,12 @@ def _parse_gemma4_args(args_str: str) -> dict:
nested:{inner_key:<|"|>val<|"|>} nested:{inner_key:<|"|>val<|"|>}
items:[<|"|>a<|"|>,<|"|>b<|"|>] items:[<|"|>a<|"|>,<|"|>b<|"|>]
Args:
args_str: The raw Gemma4 argument string.
partial: When True (streaming), bare values at end of string are
omitted because they may be incomplete and type-unstable
(e.g. partial boolean parsed as bare string).
Returns a dict ready for ``json.dumps()``. Returns a dict ready for ``json.dumps()``.
""" """
if not args_str or not args_str.strip(): if not args_str or not args_str.strip():
...@@ -116,14 +126,16 @@ def _parse_gemma4_args(args_str: str) -> dict: ...@@ -116,14 +126,16 @@ def _parse_gemma4_args(args_str: str) -> dict:
# Parse value # Parse value
if i >= n: if i >= n:
result[key] = "" if not partial:
result[key] = ""
break break
# Skip whitespace after ':' # Skip whitespace after ':'
while i < n and args_str[i] in (" ", "\n", "\t"): while i < n and args_str[i] in (" ", "\n", "\t"):
i += 1 i += 1
if i >= n: if i >= n:
result[key] = "" if not partial:
result[key] = ""
break break
# String value: <|"|>...<|"|> # String value: <|"|>...<|"|>
...@@ -155,7 +167,12 @@ def _parse_gemma4_args(args_str: str) -> dict: ...@@ -155,7 +167,12 @@ def _parse_gemma4_args(args_str: str) -> dict:
elif args_str[i] == "}": elif args_str[i] == "}":
depth -= 1 depth -= 1
i += 1 i += 1
result[key] = _parse_gemma4_args(args_str[obj_start : i - 1]) if depth > 0:
# Incomplete nested object — use i (not i-1) to avoid
# dropping the last char, and recurse as partial.
result[key] = _parse_gemma4_args(args_str[obj_start:i], partial=True)
else:
result[key] = _parse_gemma4_args(args_str[obj_start : i - 1])
# Array: [...] # Array: [...]
elif args_str[i] == "[": elif args_str[i] == "[":
...@@ -173,20 +190,26 @@ def _parse_gemma4_args(args_str: str) -> dict: ...@@ -173,20 +190,26 @@ def _parse_gemma4_args(args_str: str) -> dict:
elif args_str[i] == "]": elif args_str[i] == "]":
depth -= 1 depth -= 1
i += 1 i += 1
arr_content = args_str[arr_start : i - 1] if depth > 0:
result[key] = _parse_gemma4_array(arr_content) result[key] = _parse_gemma4_array(args_str[arr_start:i], partial=True)
else:
result[key] = _parse_gemma4_array(args_str[arr_start : i - 1])
# Bare value (number, boolean, etc.) # Bare value (number, boolean, etc.)
else: else:
val_start = i val_start = i
while i < n and args_str[i] not in (",", "}", "]"): while i < n and args_str[i] not in (",", "}", "]"):
i += 1 i += 1
if partial and i >= n:
# Value may be incomplete (e.g. partial boolean) —
# withhold to avoid type instability during streaming.
break
result[key] = _parse_gemma4_value(args_str[val_start:i]) result[key] = _parse_gemma4_value(args_str[val_start:i])
return result return result
def _parse_gemma4_array(arr_str: str) -> list: def _parse_gemma4_array(arr_str: str, *, partial: bool = False) -> list:
"""Parse a Gemma4 array content string into a Python list.""" """Parse a Gemma4 array content string into a Python list."""
items: list = [] items: list = []
i = 0 i = 0
...@@ -224,7 +247,10 @@ def _parse_gemma4_array(arr_str: str) -> list: ...@@ -224,7 +247,10 @@ def _parse_gemma4_array(arr_str: str) -> list:
elif arr_str[i] == "}": elif arr_str[i] == "}":
depth -= 1 depth -= 1
i += 1 i += 1
items.append(_parse_gemma4_args(arr_str[obj_start : i - 1])) if depth > 0:
items.append(_parse_gemma4_args(arr_str[obj_start:i], partial=True))
else:
items.append(_parse_gemma4_args(arr_str[obj_start : i - 1]))
# Nested array # Nested array
elif arr_str[i] == "[": elif arr_str[i] == "[":
...@@ -237,13 +263,18 @@ def _parse_gemma4_array(arr_str: str) -> list: ...@@ -237,13 +263,18 @@ def _parse_gemma4_array(arr_str: str) -> list:
elif arr_str[i] == "]": elif arr_str[i] == "]":
depth -= 1 depth -= 1
i += 1 i += 1
items.append(_parse_gemma4_array(arr_str[sub_start : i - 1])) if depth > 0:
items.append(_parse_gemma4_array(arr_str[sub_start:i], partial=True))
else:
items.append(_parse_gemma4_array(arr_str[sub_start : i - 1]))
# Bare value # Bare value
else: else:
val_start = i val_start = i
while i < n and arr_str[i] not in (",", "]"): while i < n and arr_str[i] not in (",", "]"):
i += 1 i += 1
if partial and i >= n:
break
items.append(_parse_gemma4_value(arr_str[val_start:i])) items.append(_parse_gemma4_value(arr_str[val_start:i]))
return items return items
...@@ -436,8 +467,10 @@ class Gemma4ToolParser(ToolParser): ...@@ -436,8 +467,10 @@ class Gemma4ToolParser(ToolParser):
) -> DeltaMessage | None: ) -> DeltaMessage | None:
# Buffer delta text to handle multi-token special sequences # Buffer delta text to handle multi-token special sequences
delta_text = self._buffer_delta_text(delta_text) delta_text = self._buffer_delta_text(delta_text)
# Reconstruct current_text after buffering to stay in sync # Keep current_text from the upstream stream state. The buffered delta
current_text = previous_text + delta_text # is only for emission, and must not be stitched back into the
# accumulated model text or normal content like "<div>" can be
# duplicated into "<<div>" when a tool call just ended.
# If no tool call token seen yet, emit as content # If no tool call token seen yet, emit as content
if self.tool_call_start_token not in current_text: if self.tool_call_start_token not in current_text:
...@@ -661,7 +694,7 @@ class Gemma4ToolParser(ToolParser): ...@@ -661,7 +694,7 @@ class Gemma4ToolParser(ToolParser):
DeltaMessage with the argument diff, or None if no new content. DeltaMessage with the argument diff, or None if no new content.
""" """
try: try:
current_args = _parse_gemma4_args(raw_args_str) current_args = _parse_gemma4_args(raw_args_str, partial=True)
except Exception: except Exception:
logger.debug( logger.debug(
"Could not parse partial Gemma4 args yet: %s", "Could not parse partial Gemma4 args yet: %s",
...@@ -675,10 +708,11 @@ class Gemma4ToolParser(ToolParser): ...@@ -675,10 +708,11 @@ class Gemma4ToolParser(ToolParser):
current_args_json = json.dumps(current_args, ensure_ascii=False) current_args_json = json.dumps(current_args, ensure_ascii=False)
# Withhold trailing closing characters that may shift as more # Withhold trailing closing characters that may shift as more
# tokens arrive. Strip trailing '}', '"', and ']' sequences # tokens arrive. Strip trailing '}', '"', ']' and partial
# to get the "safe prefix". # STRING_DELIM fragments ('<', '|', '\\', '>') to get the
# "safe prefix".
safe_json = current_args_json safe_json = current_args_json
while safe_json and safe_json[-1] in ("}", '"', "]"): while safe_json and safe_json[-1] in ("}", '"', "]", "<", "|", "\\", ">"):
safe_json = safe_json[:-1] safe_json = safe_json[:-1]
prev_streamed = self.streamed_args_for_tool[self.current_tool_id] prev_streamed = self.streamed_args_for_tool[self.current_tool_id]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment