Commit fc67613a authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.19.1' into v0.19.0

parents 31aec25b b1388b1f
......@@ -1577,6 +1577,22 @@ class VllmConfig:
compile_range_end,
)
if compilation_config.pass_config.fuse_minimax_qk_norm:
from vllm.compilation.passes.fusion.minimax_qk_norm_fusion import (
MAX_TOKEN_NUM,
)
max_token_num = min(
MAX_TOKEN_NUM, self.scheduler_config.max_num_batched_tokens
)
if compile_range_end is not None and max_token_num < compile_range_end:
computed_compile_ranges_endpoints.append(max_token_num)
else:
logger.debug(
"Max num batched tokens below MiniMax QK norm fusion threshold, "
"MiniMax QK norm fusion enabled for all num_tokens."
)
if compilation_config.compile_ranges_endpoints is not None:
for x in compilation_config.compile_ranges_endpoints:
assert isinstance(x, int)
......
......@@ -170,7 +170,8 @@ class AnthropicServingMessages(OpenAIServingChat):
else:
cls._convert_message_content(msg, openai_msg, openai_messages)
openai_messages.append(openai_msg)
if not (msg.role == "user" and "content" not in openai_msg):
openai_messages.append(openai_msg)
@classmethod
def _convert_message_content(
......
......@@ -372,6 +372,7 @@ async def init_app_state(
enable_auto_tools=args.enable_auto_tool_choice,
exclude_tools_when_tool_choice_none=args.exclude_tools_when_tool_choice_none,
tool_parser=args.tool_call_parser,
reasoning_parser=args.structured_outputs_config.reasoning_parser,
default_chat_template_kwargs=args.default_chat_template_kwargs,
log_error_stack=args.log_error_stack,
)
......@@ -467,6 +468,7 @@ async def init_render_app_state(
enable_auto_tools=args.enable_auto_tool_choice,
exclude_tools_when_tool_choice_none=args.exclude_tools_when_tool_choice_none,
tool_parser=args.tool_call_parser,
reasoning_parser=args.structured_outputs_config.reasoning_parser,
default_chat_template_kwargs=args.default_chat_template_kwargs,
log_error_stack=args.log_error_stack,
)
......
......@@ -594,6 +594,7 @@ class OpenAIServingResponses(OpenAIServing):
default_template_kwargs=None,
tool_dicts=tool_dicts,
tool_parser=self.parser.tool_parser_cls if self.parser else None,
reasoning_parser=self.parser.reasoning_parser_cls if self.parser else None,
)
return messages, engine_inputs
......@@ -618,6 +619,7 @@ class OpenAIServingResponses(OpenAIServing):
default_template_kwargs=None,
tool_dicts=tool_dicts,
tool_parser=tool_parser,
reasoning_parser=self.parser.reasoning_parser_cls if self.parser else None,
)
return engine_inputs
......
......@@ -44,6 +44,7 @@ from vllm.inputs import (
)
from vllm.logger import init_logger
from vllm.parser import ParserManager
from vllm.reasoning.abs_reasoning_parsers import ReasoningParser
from vllm.renderers import BaseRenderer, merge_kwargs
from vllm.renderers.inputs.preprocess import (
extract_prompt_components,
......@@ -74,6 +75,7 @@ class OpenAIServingRender:
enable_auto_tools: bool = False,
exclude_tools_when_tool_choice_none: bool = False,
tool_parser: str | None = None,
reasoning_parser: str | None = None,
default_chat_template_kwargs: dict[str, Any] | None = None,
log_error_stack: bool = False,
) -> None:
......@@ -94,6 +96,11 @@ class OpenAIServingRender:
enable_auto_tools=enable_auto_tools,
model_name=model_config.model,
)
self.reasoning_parser: type[ReasoningParser] | None = (
ParserManager.get_reasoning_parser(
reasoning_parser_name=reasoning_parser,
)
)
self.default_chat_template_kwargs: dict[str, Any] = (
default_chat_template_kwargs or {}
)
......@@ -245,6 +252,7 @@ class OpenAIServingRender:
default_template_kwargs=self.default_chat_template_kwargs,
tool_dicts=tool_dicts,
tool_parser=tool_parser,
reasoning_parser=self.reasoning_parser,
)
else:
# For GPT-OSS.
......@@ -498,6 +506,9 @@ class OpenAIServingRender:
default_template_kwargs: dict[str, Any] | None,
tool_dicts: list[dict[str, Any]] | None = None,
tool_parser: type[ToolParser] | None = None,
reasoning_parser: type[ReasoningParser] | None = None,
*,
skip_mm_cache: bool = False,
) -> tuple[list[ConversationMessage], list[EngineInput]]:
"""Copied from OpenAIServing._preprocess_chat."""
renderer = self.renderer
......@@ -531,6 +542,10 @@ class OpenAIServingRender:
},
)
if reasoning_parser is not None:
tokenizer = renderer.get_tokenizer()
request = reasoning_parser(tokenizer).adjust_request(request=request)
# tool parsing is done only if a tool_parser has been set and if
# tool_choice is not "none" (if tool_choice is "none" but a tool_parser
# is set, we want to prevent parsing a tool_call hallucinated by the LLM
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import array
import contextlib
import struct
import sys
import threading
import torch
try:
from cuda.bindings import runtime as cudart
except ImportError:
from cuda import cudart
_ALIGN = 1 << 21 # 2 MiB — CUDA IPC allocation alignment
# ---------------------------------------------------------------------------
# CUDA helpers
# ---------------------------------------------------------------------------
def _check(error):
"""Raise on CUDA runtime error."""
success = getattr(cudart.cudaError_t, "cudaSuccess", None) or cudart.cudaError_t(0)
if error != success:
raise RuntimeError(f"CUDA runtime error: {error}")
def _cuda_malloc(size: int):
aligned = ((size + _ALIGN - 1) >> 21) << 21
err, ptr = cudart.cudaMalloc(aligned)
_check(err)
return ptr, aligned
def _cuda_free(ptr: int):
if ptr:
_check(cudart.cudaFree(ptr)[0])
def _cuda_memset_zero(ptr: int, size: int):
_check(cudart.cudaMemset(ptr, 0, size)[0])
def _cuda_memcpy_d2d(dst: int, src: int, size: int):
_check(
cudart.cudaMemcpy(
dst, src, size, cudart.cudaMemcpyKind.cudaMemcpyDeviceToDevice
)[0]
)
# ---------------------------------------------------------------------------
# IPC buffer
# ---------------------------------------------------------------------------
class IpcBuffer:
"""
Allocates CUDA device memory and exchanges IPC handles with all ranks
so that every rank holds a valid device pointer to every other rank's buffer.
"""
def __init__(self, rank: int, world_size: int, size: int, process_group=None):
self.rank = rank
self.world_size = world_size
self.peer_ptrs: list[int] = [0] * world_size
self.local_ptr: int = 0
self._alive = False
if size <= 0:
return
self.local_ptr, _ = _cuda_malloc(size)
_cuda_memset_zero(self.local_ptr, size)
self._alive = True
# --- exchange IPC handles via torch.distributed ---
err, local_handle = cudart.cudaIpcGetMemHandle(self.local_ptr)
_check(err)
all_handles: list[bytes | None] = [None] * world_size
torch.distributed.all_gather_object(
all_handles, bytes(local_handle.reserved), group=process_group
)
for r in range(world_size):
if r == rank:
self.peer_ptrs[r] = self.local_ptr
else:
handle = cudart.cudaIpcMemHandle_t()
handle.reserved = all_handles[r]
err, ptr = cudart.cudaIpcOpenMemHandle(
handle, cudart.cudaIpcMemLazyEnablePeerAccess
)
_check(err)
self.peer_ptrs[r] = ptr
def serialize(self) -> list[int]:
"""Return peer pointers as a list of int64 values (one per rank)."""
raw = b""
for ptr in self.peer_ptrs:
raw += struct.pack("P", ptr)
return array.array("Q", raw).tolist()
def cleanup(self):
if not self._alive:
return
self._alive = False
for r in range(self.world_size):
if self.peer_ptrs[r] == 0:
continue
if r == self.rank:
_cuda_free(self.peer_ptrs[r])
else:
with contextlib.suppress(RuntimeError):
_check(cudart.cudaIpcCloseMemHandle(self.peer_ptrs[r])[0])
self.peer_ptrs[r] = 0
self.local_ptr = 0
def __del__(self):
if not sys.is_finalizing():
self.cleanup()
# ---------------------------------------------------------------------------
# Lamport negative-zero initialization
# ---------------------------------------------------------------------------
def _lamport_fill_neg_zero(device_ptr: int, size_bytes: int):
"""
Fill device memory with IEEE-754 negative zero (-0.0f = 0x80000000).
This is the "slot empty" sentinel for the Lamport protocol: the kernel
spin-waits until a value is *not* negative zero.
"""
if size_bytes == 0 or device_ptr == 0:
return
n_floats = size_bytes // 4
# torch preserves -0.0 in IEEE-754
fill = torch.full((n_floats,), -0.0, dtype=torch.float32, device="cuda")
_cuda_memcpy_d2d(device_ptr, fill.data_ptr(), size_bytes)
del fill
# ---------------------------------------------------------------------------
# LamportWorkspace — the main class
# ---------------------------------------------------------------------------
class LamportWorkspace:
"""
Self-contained workspace for Lamport-based cross-GPU AllReduce.
Parameters
----------
rank : int
Local rank (0-based).
world_size : int
Total number of ranks in the TP group.
comm_size : int
Size in bytes of *one* Lamport buffer slot. The total IPC allocation
per rank is ``3 * comm_size`` (triple-buffering). Must be large enough
to hold the per-slot data written by the kernel. Use
``compute_comm_size_for_minimax()`` for a safe default.
process_group : optional
``torch.distributed`` process group for IPC handle exchange.
``None`` uses the default group.
"""
def __init__(self, rank: int, world_size: int, comm_size: int, process_group=None):
assert world_size >= 2, "Lamport workspace requires at least 2 ranks"
assert comm_size > 0, "comm_size must be positive"
self.rank = rank
self.world_size = world_size
self.comm_size = comm_size
# 1) Lamport triple-buffer (the only IPC memory the kernel reads/writes)
lamport_total = 3 * comm_size
self._lamport = IpcBuffer(rank, world_size, lamport_total, process_group)
_lamport_fill_neg_zero(self._lamport.local_ptr, lamport_total)
# 2) flag_buffer on device: int32[3] = {counter, unused, lamport_flag}
# counter — used for block-level sync inside the kernel
# unused — reserved (index 1)
# lamport_flag — triple-buffer rotation index (0 → 1 → 2 → 0 …)
self._flag_buf = torch.zeros(3, dtype=torch.int32, device="cuda")
# 3) layout_buffer on device: int64[2] = {clear_size, comm_size}
# clear_size — bytes to clear from *previous* slot (set by kernel)
# comm_size — size of one triple-buffer slot
self._layout_buf = torch.tensor(
[0, comm_size], dtype=torch.int64, device="cuda"
)
# 4) Assemble device-side void* pointer array
N = world_size
ptrs: list[int] = []
ptrs += [0] * N # [0 .. N-1] ipc_buffers (placeholder)
ptrs += [0] * N # [N .. 2N-1] ipc_barriers (placeholder)
ptrs += self._lamport.serialize() # [2N .. 3N-1] lamport peer ptrs
ptrs.append(self._flag_buf.data_ptr()) # [3N] flag_buffer
ptrs.append(self._layout_buf.data_ptr()) # [3N+1] layout_buffer
self._workspace = torch.tensor(ptrs, dtype=torch.int64, device="cuda")
@property
def workspace(self) -> torch.Tensor:
"""Device tensor (int64) that can be passed to the kernel
as ``void** workspace``."""
return self._workspace
# ------------------------------------------------------------------
# Helpers
# ------------------------------------------------------------------
@staticmethod
def compute_comm_size_for_minimax(
max_tokens: int,
world_size: int,
fused_qk: bool = True,
) -> int:
"""
Return a safe ``comm_size`` (in bytes) for MiniMaxReduceRMSKernel.
The kernel stores per-token variance scalars in the Lamport buffer:
- single-matrix path: ``world_size × max_tokens × 4`` bytes per slot
- fused Q+K path: ``world_size × 2 × ceil(max_tokens/4) × 16`` bytes per slot
The returned value is rounded up to 2 MiB alignment.
"""
if fused_qk:
groups = (max_tokens + 3) // 4
slot_bytes = world_size * 2 * groups * 16 # 16 = sizeof(float4)
else:
slot_bytes = world_size * max_tokens * 4 # 4 = sizeof(float)
return ((slot_bytes + _ALIGN - 1) >> 21) << 21
def cleanup(self):
if hasattr(self, "_lamport"):
self._lamport.cleanup()
def __del__(self):
if not sys.is_finalizing():
self.cleanup()
def __repr__(self):
return (
f"LamportWorkspace(rank={self.rank}, world_size={self.world_size}, "
f"comm_size={self.comm_size})"
)
# ---------------------------------------------------------------------------
# Cached convenience function (mirrors TRT-LLM's get_allreduce_workspace)
# ---------------------------------------------------------------------------
_cache_lock = threading.Lock()
_workspace_cache: dict = {}
def get_allreduce_workspace(
rank: int,
world_size: int,
comm_size: int | None = None,
max_tokens: int = 16384,
process_group=None,
) -> torch.Tensor:
"""
Return a cached workspace tensor for the given (rank, world_size) pair.
On first call the workspace is allocated and IPC handles are exchanged;
subsequent calls with the same arguments return the cached tensor.
Parameters
----------
rank, world_size : int
TP rank and TP size.
comm_size : int, optional
Explicit slot size in bytes. If ``None``, computed automatically
from ``max_tokens`` and ``world_size`` (fused Q+K path).
max_tokens : int
Maximum number of tokens per batch (used when ``comm_size is None``).
process_group : optional
``torch.distributed`` process group.
"""
if comm_size is None:
comm_size = LamportWorkspace.compute_comm_size_for_minimax(
max_tokens, world_size, fused_qk=True
)
pg_id = id(process_group) if process_group is not None else 0
key = (rank, world_size, comm_size, pg_id)
with _cache_lock:
if key not in _workspace_cache:
ws = LamportWorkspace(rank, world_size, comm_size, process_group)
_workspace_cache[key] = ws
return _workspace_cache[key].workspace
......@@ -209,12 +209,24 @@ class GGUFModelLoader(BaseModelLoader):
GGUF tensor name with suffix (e.g., 'mm.soft_emb_norm.weight')
or None if no mapping found
"""
# In transformers v5, multimodal models (e.g. Gemma3) wrap
# all sub-models under an outer 'model.' attribute, producing
# state_dict keys like 'model.language_model.layers.0...' and
# 'model.vision_tower.vision_model...'. Strip this outer
# prefix so the keys match what gguf-py expects.
if is_multimodal and hf_name.startswith("model."):
hf_name = hf_name[6:] # Remove outer 'model.'
# Strip 'language_model.' prefix for multimodal models - gguf-py
# tensor mappings expect parameter names without this prefix.
# Note: 'model.' prefix should be KEPT for text-only models as
# gguf-py expects it.
if hf_name.startswith("language_model."):
hf_name = hf_name[15:] # Remove 'language_model.'
# Re-add 'model.' prefix because gguf-py text tensor maps
# expect 'model.layers...' format.
if is_multimodal:
hf_name = "model." + hf_name
# Parse parameter name and suffix
if hf_name.endswith((".weight", ".bias")):
......
......@@ -19,6 +19,7 @@
"""Gemma 4 model implementation for vLLM."""
from collections.abc import Iterable
from dataclasses import replace
from itertools import islice
import regex as re
......@@ -32,6 +33,7 @@ from vllm.distributed import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from vllm.forward_context import get_forward_context
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import GeluAndMul
from vllm.model_executor.layers.attention import Attention
......@@ -56,10 +58,18 @@ from vllm.model_executor.model_loader.weight_utils import (
maybe_remap_kv_scale_name,
)
from vllm.sequence import IntermediateTensors
from .interfaces import MixtureOfExperts, SupportsLoRA, SupportsPP
from vllm.v1.attention.backends.utils import KVSharingFastPrefillMetadata
from .interfaces import (
EagleModelMixin,
MixtureOfExperts,
SupportsEagle3,
SupportsLoRA,
SupportsPP,
)
from .utils import (
AutoWeightsLoader,
WeightsMapper,
extract_layer_index,
is_pp_missing_parameter,
make_layers,
......@@ -636,8 +646,206 @@ class Gemma4DecoderLayer(nn.Module):
return hidden_states, None
@support_torch_compile
class Gemma4Model(nn.Module):
def _run_decoder_layers(
decoder_layers: list[Gemma4DecoderLayer],
layer_idx_start: int,
positions: torch.Tensor,
hidden_states: torch.Tensor,
per_layer_inputs: torch.Tensor | None = None,
**kwargs,
) -> torch.Tensor:
"""Run a slice of decoder layers with PLE extraction."""
residual = None
for idx, layer in enumerate(decoder_layers):
layer_idx = idx + layer_idx_start
layer_per_input = (
per_layer_inputs[:, layer_idx, :] if per_layer_inputs is not None else None
)
hidden_states, residual = layer(
positions,
hidden_states,
residual,
per_layer_input=layer_per_input,
**kwargs,
)
return hidden_states
@support_torch_compile(
enable_if=lambda vllm_config: vllm_config.cache_config.kv_sharing_fast_prefill
)
class Gemma4SelfDecoderLayers(nn.Module):
"""Compiled wrapper: embedding + non-KV-shared layers (YOCO first half).
Owns the embedding and PLE modules so they are inside the compiled
graph. Gemma4Model delegates embedding methods here.
"""
def __init__(
self,
*,
vllm_config: VllmConfig,
prefix: str = "",
decoder_layers: list[Gemma4DecoderLayer],
layer_idx_start: int,
embed_tokens: VocabParallelEmbedding,
normalizer: torch.Tensor,
embed_tokens_per_layer: VocabParallelEmbedding | None,
embed_scale_per_layer: torch.Tensor | None,
per_layer_model_projection: ColumnParallelLinear | None,
per_layer_projection_norm: RMSNorm | None,
per_layer_input_scale: torch.Tensor | None,
per_layer_projection_scale: torch.Tensor | None,
):
super().__init__()
self.decoder_layers = decoder_layers
self.layer_idx_start = layer_idx_start
config = _get_text_config(vllm_config.model_config.hf_config)
self.config = config
self.hidden_size_per_layer_input = getattr(
config, "hidden_size_per_layer_input", 0
)
self.vocab_size_per_layer_input = getattr(
config, "vocab_size_per_layer_input", config.vocab_size
)
# Shared references to modules owned by Gemma4Model — must be
# inside this nn.Module so torch.compile captures them.
self.embed_tokens = embed_tokens
self.normalizer = normalizer
self.embed_tokens_per_layer = embed_tokens_per_layer
self.embed_scale_per_layer = embed_scale_per_layer
self.per_layer_model_projection = per_layer_model_projection
self.per_layer_projection_norm = per_layer_projection_norm
self.per_layer_input_scale = per_layer_input_scale
self.per_layer_projection_scale = per_layer_projection_scale
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids) * self.normalizer
def get_per_layer_inputs(self, input_ids: torch.Tensor) -> torch.Tensor | None:
"""Get per-layer embeddings from embed_tokens_per_layer.
Returns:
Per-layer embeddings (num_tokens, num_layers,
hidden_size_per_layer_input)
"""
if self.embed_tokens_per_layer is None:
return None
per_layer_inputs_mask = torch.logical_and(
input_ids >= 0,
input_ids < self.vocab_size_per_layer_input,
)
per_layer_inputs_tokens = torch.where(
per_layer_inputs_mask, input_ids, torch.zeros_like(input_ids)
)
per_layer_embeds = self.embed_tokens_per_layer(per_layer_inputs_tokens)
per_layer_embeds = per_layer_embeds * self.embed_scale_per_layer
return per_layer_embeds.reshape(
*input_ids.shape,
self.config.num_hidden_layers,
self.hidden_size_per_layer_input,
)
def project_per_layer_inputs(
self,
inputs_embeds: torch.Tensor,
per_layer_inputs: torch.Tensor | None,
) -> torch.Tensor | None:
"""Project inputs_embeds and combine with per_layer_inputs.
Steps:
1. Project inputs_embeds: hidden_size → total_ple_dim
2. Scale by hidden_size^{-0.5}
3. Reshape to (num_tokens, num_layers, per_layer_dim)
4. Normalize with per_layer_projection_norm
5. Combine: (projection + per_layer_inputs) * 1/sqrt(2)
"""
if self.per_layer_model_projection is None:
return None
per_layer_projection = self.per_layer_model_projection(inputs_embeds)
per_layer_projection = per_layer_projection * self.per_layer_projection_scale
per_layer_projection = per_layer_projection.reshape(
*inputs_embeds.shape[:-1],
self.config.num_hidden_layers,
self.hidden_size_per_layer_input,
)
per_layer_projection = self.per_layer_projection_norm(per_layer_projection)
if per_layer_inputs is None:
return per_layer_projection
return (per_layer_projection + per_layer_inputs) * self.per_layer_input_scale
def forward(
self,
input_ids: torch.Tensor | None,
positions: torch.Tensor,
inputs_embeds: torch.Tensor | None = None,
per_layer_inputs: torch.Tensor | None = None,
**kwargs,
) -> tuple[torch.Tensor, torch.Tensor | None]:
if inputs_embeds is not None:
hidden_states = inputs_embeds
per_layer_inputs = self.project_per_layer_inputs(
hidden_states, per_layer_inputs
)
else:
hidden_states = self.embed_input_ids(input_ids)
per_layer_embeds = self.get_per_layer_inputs(input_ids)
per_layer_inputs = self.project_per_layer_inputs(
hidden_states, per_layer_embeds
)
hidden_states = _run_decoder_layers(
self.decoder_layers,
self.layer_idx_start,
positions,
hidden_states,
per_layer_inputs,
**kwargs,
)
return hidden_states, per_layer_inputs
@support_torch_compile(
enable_if=lambda vllm_config: vllm_config.cache_config.kv_sharing_fast_prefill
)
class Gemma4CrossDecoderLayers(nn.Module):
"""Cross-decoder layers (YOCO second half, KV-shared)."""
def __init__(
self,
*,
vllm_config: VllmConfig,
prefix: str = "",
decoder_layers: list[Gemma4DecoderLayer],
layer_idx_start: int,
):
super().__init__()
self.decoder_layers = decoder_layers
self.layer_idx_start = layer_idx_start
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
per_layer_inputs: torch.Tensor | None = None,
**kwargs,
) -> torch.Tensor:
return _run_decoder_layers(
self.decoder_layers,
self.layer_idx_start,
positions,
hidden_states,
per_layer_inputs,
**kwargs,
)
@support_torch_compile(
enable_if=lambda vllm_config: not vllm_config.cache_config.kv_sharing_fast_prefill
)
class Gemma4Model(nn.Module, EagleModelMixin):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = _get_text_config(vllm_config.model_config.hf_config)
......@@ -740,6 +948,75 @@ class Gemma4Model(nn.Module):
torch.tensor(config.hidden_size**0.5),
persistent=False,
)
# --- You Only Cache Once (YOCO) split for fast prefill ---
first_kv_shared_layer_idx = config.num_hidden_layers - getattr(
config, "num_kv_shared_layers", 0
)
from vllm.compilation.backends import set_model_tag
# Layers 0..(K-1) are self-decoder layers in YOCO
with set_model_tag("self_decoder"):
self.self_decoder = Gemma4SelfDecoderLayers(
vllm_config=vllm_config,
prefix=f"{prefix}.self_decoder",
decoder_layers=self.layers[:first_kv_shared_layer_idx],
layer_idx_start=0,
embed_tokens=self.embed_tokens,
normalizer=self.normalizer,
embed_tokens_per_layer=getattr(self, "embed_tokens_per_layer", None),
embed_scale_per_layer=getattr(self, "embed_scale_per_layer", None),
per_layer_model_projection=getattr(
self, "per_layer_model_projection", None
),
per_layer_projection_norm=getattr(
self, "per_layer_projection_norm", None
),
per_layer_input_scale=getattr(self, "per_layer_input_scale", None),
per_layer_projection_scale=getattr(
self, "per_layer_projection_scale", None
),
)
# Layers K..(N-1) are cross-decoder layers in YOCO
with set_model_tag("cross_decoder"):
self.cross_decoder = Gemma4CrossDecoderLayers(
vllm_config=vllm_config,
prefix=f"{prefix}.cross_decoder",
decoder_layers=self.layers[first_kv_shared_layer_idx:],
layer_idx_start=first_kv_shared_layer_idx,
)
self.fast_prefill_enabled = cache_config.kv_sharing_fast_prefill
if self.fast_prefill_enabled:
# Allocate static buffers for CUDAGraph
max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens
device = next(self.parameters()).device
self.positions = torch.zeros(
max_num_tokens, dtype=torch.int64, device=device
)
self.hidden_states = torch.zeros(
(max_num_tokens, config.hidden_size),
dtype=self.embed_tokens.weight.dtype,
device=device,
)
if (
self.hidden_size_per_layer_input
and self.hidden_size_per_layer_input > 0
):
self.per_layer_inputs = torch.zeros(
(
max_num_tokens,
config.num_hidden_layers,
self.hidden_size_per_layer_input,
),
dtype=self.embed_tokens.weight.dtype,
device=device,
)
else:
self.per_layer_inputs = None
# Custom factory that includes per_layer_inputs for PLE-enabled PP.
# per_layer_inputs has shape (batch, num_layers, per_layer_dim),
# which differs from the standard (batch, hidden_size) shape,
......@@ -776,47 +1053,22 @@ class Gemma4Model(nn.Module):
self.make_empty_intermediate_tensors = _make_empty_intermediate_tensors
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids) * self.normalizer
return self.self_decoder.embed_input_ids(input_ids)
def get_per_layer_inputs(self, input_ids: torch.Tensor) -> torch.Tensor:
def get_per_layer_inputs(self, input_ids: torch.Tensor) -> torch.Tensor | None:
"""Get per-layer embeddings from embed_tokens_per_layer.
Returns:
Per-layer embeddings (num_tokens, num_layers,
hidden_size_per_layer_input)
"""
if self.embed_tokens_per_layer is None:
return None
# Handle out-of-vocab tokens for PLE (vocab_size_per_layer_input may
# be smaller than the main vocab_size).
per_layer_inputs_mask = torch.logical_and(
input_ids >= 0,
input_ids < self.vocab_size_per_layer_input,
)
per_layer_inputs_tokens = torch.where(
per_layer_inputs_mask, input_ids, torch.zeros_like(input_ids)
)
# Get packed per-layer embeddings: (num_tokens, total_ple_dim)
per_layer_embeds = self.embed_tokens_per_layer(per_layer_inputs_tokens)
# Apply embed_scale (sqrt of per-layer hidden dim)
per_layer_embeds = per_layer_embeds * self.embed_scale_per_layer
# Reshape to (num_tokens, num_layers, hidden_size_per_layer_input)
per_layer_embeds = per_layer_embeds.reshape(
*input_ids.shape,
self.config.num_hidden_layers,
self.hidden_size_per_layer_input,
)
return per_layer_embeds
return self.self_decoder.get_per_layer_inputs(input_ids)
def project_per_layer_inputs(
self,
inputs_embeds: torch.Tensor,
per_layer_inputs: torch.Tensor | None,
) -> torch.Tensor:
) -> torch.Tensor | None:
"""Project inputs_embeds and combine with per_layer_inputs.
Steps:
......@@ -826,29 +1078,94 @@ class Gemma4Model(nn.Module):
4. Normalize with per_layer_projection_norm
5. Combine: (projection + per_layer_inputs) * 1/sqrt(2)
"""
if self.per_layer_model_projection is None:
return None
return self.self_decoder.project_per_layer_inputs(
inputs_embeds, per_layer_inputs
)
# Project from hidden_size to total_ple_dim
# Scaled projection: output = linear(input, weight) * scale
per_layer_projection = self.per_layer_model_projection(inputs_embeds)
per_layer_projection = per_layer_projection * self.per_layer_projection_scale
def fast_prefill_forward(
self,
input_ids: torch.Tensor | None,
positions: torch.Tensor,
inputs_embeds: torch.Tensor | None = None,
per_layer_inputs: torch.Tensor | None = None,
**kwargs,
) -> torch.Tensor:
logits_indices_padded, num_logits_indices = None, None
attn_metadata = get_forward_context().attn_metadata
# Reshape to (num_tokens, num_layers, hidden_size_per_layer_input)
per_layer_projection = per_layer_projection.reshape(
*inputs_embeds.shape[:-1],
self.config.num_hidden_layers,
self.hidden_size_per_layer_input,
if attn_metadata is not None:
assert isinstance(attn_metadata, dict)
layer_attn_metadata = attn_metadata[
self.layers[-1].self_attn.attn.layer_name
]
if isinstance(layer_attn_metadata, KVSharingFastPrefillMetadata):
logits_indices_padded = layer_attn_metadata.logits_indices_padded
num_logits_indices = layer_attn_metadata.num_logits_indices
batch_size = positions.size(0)
self.positions[:batch_size].copy_(positions)
self_decoder_hidden_states, per_layer_inputs = self.self_decoder(
input_ids=input_ids,
positions=self.positions[:batch_size],
inputs_embeds=inputs_embeds,
per_layer_inputs=per_layer_inputs,
**kwargs,
)
# Normalize
per_layer_projection = self.per_layer_projection_norm(per_layer_projection)
if logits_indices_padded is None:
logits_indices_padded = torch.arange(
batch_size,
dtype=positions.dtype,
device=positions.device,
)
if per_layer_inputs is None:
return per_layer_projection
# NOTE: Keep .clone() until fix in
# https://github.com/vllm-project/vllm/pull/22282
hidden_states = self_decoder_hidden_states.clone()
# Combine: (projection + per_layer_inputs) * scale
return (per_layer_projection + per_layer_inputs) * self.per_layer_input_scale
num_padded = logits_indices_padded.size(0)
self.positions[:num_padded].copy_(positions[logits_indices_padded])
self.hidden_states[:num_padded].copy_(
self_decoder_hidden_states[logits_indices_padded]
)
if self.per_layer_inputs is not None and per_layer_inputs is not None:
self.per_layer_inputs[:num_padded].copy_(
per_layer_inputs[logits_indices_padded]
)
# Update batch_descriptor so the cross-decoder's piecewise
# CUDAGraphWrapper dispatches to the correct (reduced) batch size.
forward_context = get_forward_context()
orig_batch_desc = forward_context.batch_descriptor
if orig_batch_desc is not None:
forward_context.batch_descriptor = replace(
orig_batch_desc, num_tokens=num_padded
)
cross_per_layer = (
self.per_layer_inputs[:num_padded]
if self.per_layer_inputs is not None
else None
)
cross_hidden_states = self.cross_decoder(
self.positions[:num_padded],
self.hidden_states[:num_padded],
cross_per_layer,
**kwargs,
)
# Restore the original batch_descriptor
forward_context.batch_descriptor = orig_batch_desc
if num_logits_indices is not None:
assert num_logits_indices > 0
hidden_states[logits_indices_padded[:num_logits_indices]] = (
cross_hidden_states[:num_logits_indices]
)
else:
hidden_states = cross_hidden_states
return hidden_states
def forward(
self,
......@@ -858,7 +1175,19 @@ class Gemma4Model(nn.Module):
inputs_embeds: torch.Tensor | None = None,
per_layer_inputs: torch.Tensor | None = None,
**kwargs,
) -> torch.Tensor | IntermediateTensors:
) -> torch.Tensor | IntermediateTensors | tuple[torch.Tensor, list[torch.Tensor]]:
if self.fast_prefill_enabled:
hidden_states = self.fast_prefill_forward(
input_ids,
positions,
inputs_embeds,
per_layer_inputs,
**kwargs,
)
hidden_states = self.norm(hidden_states)
return hidden_states
# Normal (non-fast-prefill) path with PP support
if get_pp_group().is_first_rank:
if inputs_embeds is not None:
hidden_states = inputs_embeds
......@@ -882,6 +1211,7 @@ class Gemma4Model(nn.Module):
residual = intermediate_tensors["residual"]
per_layer_inputs = intermediate_tensors.get("per_layer_inputs")
aux_hidden_states = self._maybe_add_hidden_state([], 0, hidden_states, residual)
for layer_idx, layer in enumerate(
islice(self.layers, self.start_layer, self.end_layer)
):
......@@ -900,6 +1230,9 @@ class Gemma4Model(nn.Module):
per_layer_input=layer_per_input,
**kwargs,
)
self._maybe_add_hidden_state(
aux_hidden_states, layer_idx + 1, hidden_states, residual
)
if not get_pp_group().is_last_rank:
return IntermediateTensors(
{
......@@ -914,6 +1247,9 @@ class Gemma4Model(nn.Module):
hidden_states = self.norm(hidden_states)
else:
hidden_states, _ = self.norm(hidden_states, residual)
if len(aux_hidden_states) > 0:
return hidden_states, aux_hidden_states
return hidden_states
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
......@@ -926,21 +1262,27 @@ class Gemma4Model(nn.Module):
("gate_up_proj", "up_proj", 1),
]
# MoE expert weight mapping: checkpoint 3D packed tensors are
# exploded in _weight_iterator to per-expert 2D weights like:
# MoE expert weight mapping: checkpoint can have either:
# 1. 3D packed tensors (exploded in _weight_iterator to per-expert 2D)
# 2. Already per-expert 2D weights (if quantized)
# Map to FusedMoE parameters:
# moe.experts.{id}.gate_proj → FusedMoE w1 (shard of w13)
# moe.experts.{id}.up_proj → FusedMoE w3 (shard of w13)
# moe.experts.{id}.down_proj → FusedMoE w2
# We build the mapping directly since Gemma4 uses bare param
# names (no .weight suffix) unlike standard MoE checkpoints.
#
# Use prefix matching to handle both weights and
# quantization scale parameters. The param_name is a prefix ending
# in underscore, and weight_name ends with a dot, so that:
# "experts.0.gate_proj.weight_scale" -> "experts.w13_weight_scale"
# "experts.0.gate_proj.weight" -> "experts.w13_weight"
num_experts = getattr(self.config, "num_experts", None) or 0
expert_params_mapping = [
# (param_name, weight_name, expert_id, shard_id)
(
"experts.w13_weight"
"experts.w13_"
if proj_name in ["gate_proj", "up_proj"]
else "experts.w2_weight",
f"experts.{expert_id}.{proj_name}",
else "experts.w2_",
f"experts.{expert_id}.{proj_name}.",
expert_id,
shard_id,
)
......@@ -1000,9 +1342,21 @@ class Gemma4Model(nn.Module):
expert_id,
shard_id,
) in expert_params_mapping:
if weight_name not in name:
# Match both:
# - Bare weights: "experts.0.down_proj" (from 3D explosion)
# - With suffix: "experts.0.down_proj.weight_scale" (2D quantized)
# weight_name has trailing dot, so check with and without it
weight_name_base = weight_name.rstrip(".")
if weight_name in name:
# Has suffix (e.g., .weight_scale)
moe_name = name.replace(weight_name, param_name)
elif name.endswith(weight_name_base):
# Bare weight (no suffix)
moe_name = name.replace(
weight_name_base, param_name.rstrip("_") + "_weight"
)
else:
continue
moe_name = name.replace(weight_name, param_name)
if moe_name not in params_dict:
continue
if is_pp_missing_parameter(moe_name, self):
......@@ -1012,15 +1366,12 @@ class Gemma4Model(nn.Module):
# orientation for FusedMoE after _weight_iterator:
# gate/up: [I, H] → w1/w3 expects [I, H]
# down: [H, I] → w2 expects [H, I]
assert loaded_weight.dim() == 2, (
f"Expected 2D expert weight for {weight_name}, "
f"got shape {loaded_weight.shape}"
)
# Scales and other quantization params may be 1D or scalar.
weight_loader = param.weight_loader
weight_loader(
param,
loaded_weight,
weight_name + ".weight",
moe_name, # Pass mapped name (handles both weights and scales)
shard_id=shard_id,
expert_id=expert_id,
)
......@@ -1044,7 +1395,25 @@ class Gemma4Model(nn.Module):
return loaded_params
class Gemma4ForCausalLM(nn.Module, SupportsLoRA, SupportsPP, MixtureOfExperts):
class Gemma4ForCausalLM(
nn.Module, SupportsLoRA, SupportsPP, MixtureOfExperts, SupportsEagle3
):
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={
# Gemma4ForConditionalGeneration already loads the text stack
# from `model.language_model.*`. We reuse that same checkpoint
# and adapter naming for the text-only Gemma4ForCausalLM path,
# so LoRA keys from the conditional wrapper map onto `model.*`.
"model.language_model.": "model.",
},
orig_to_new_substr={
# Gemma4ForConditionalGeneration names MoE adapter targets under
# `...moe.experts.*`, while the text-only model exposes them
# under `...moe.*`.
".moe.experts.gate_up_proj": ".moe.gate_up_proj",
".moe.experts.down_proj": ".moe.down_proj",
},
)
# Note: qkv_proj packing applies to non-k_eq_v layers (sliding
# attention and full attention without k_eq_v). k_eq_v layers use
# separate q_proj + k_proj without packing.
......@@ -1126,7 +1495,7 @@ class Gemma4ForCausalLM(nn.Module, SupportsLoRA, SupportsPP, MixtureOfExperts):
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
**kwargs,
) -> torch.Tensor | IntermediateTensors:
) -> torch.Tensor | IntermediateTensors | tuple[torch.Tensor, list[torch.Tensor]]:
hidden_states = self.model(
input_ids, positions, intermediate_tensors, inputs_embeds, **kwargs
)
......@@ -1177,6 +1546,11 @@ class Gemma4ForCausalLM(nn.Module, SupportsLoRA, SupportsPP, MixtureOfExperts):
".moe.down_proj",
)
# Remap individual 2D expert weights:
# .experts.{id}.{proj} → .moe.experts.{id}.{proj}
# (This handles per-expert 2D quantized weights)
name = re.sub(r"\.experts\.(\d+)\.", r".moe.experts.\1.", name)
# MoE expert weights: checkpoint stores as 3D packed
# tensors. Explode into per-expert 2D weights for
# FusedMoE weight_loader.
......
......@@ -65,7 +65,12 @@ from vllm.multimodal.processing.processor import (
from vllm.sequence import IntermediateTensors
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .interfaces import (
MultiModalEmbeddings,
SupportsEagle3,
SupportsMultiModal,
SupportsPP,
)
from .utils import (
AutoWeightsLoader,
WeightsMapper,
......@@ -121,8 +126,12 @@ class Gemma4AudioInputs(TensorSchema):
"""
type: Literal["audio"] = "audio"
input_features_padded: Annotated[torch.Tensor, TensorShape("bn", "s", "f")]
input_features_mask: Annotated[torch.Tensor, TensorShape("bn", "s")]
input_features_padded: Annotated[
torch.Tensor, TensorShape("bn", "s", "f", dynamic_dims={"s"})
]
input_features_mask: Annotated[
torch.Tensor, TensorShape("bn", "s", dynamic_dims={"s"})
]
Gemma4ImageInputs = Gemma4ImagePixelInputs
......@@ -163,10 +172,15 @@ class Gemma4ProcessingInfo(BaseProcessingInfo):
Setting ``add_special_tokens=False`` here prevents the duplicate and
ensures both ``llm.generate()`` and the chat/completions API behave
correctly.
correctly for IT models. For PT models (without chat template), we
keep the default (True) to ensure BOS is added for raw prompts.
"""
tokenizer = self.ctx.get_tokenizer()
has_chat_template = getattr(tokenizer, "chat_template", None) is not None
params = super().get_default_tok_params()
params = params.with_kwargs(add_special_tokens=False)
if has_chat_template:
params = params.with_kwargs(add_special_tokens=False)
return params
def get_hf_processor(self, **kwargs: object) -> Gemma4Processor:
......@@ -503,6 +517,8 @@ class Gemma4MultiModalProcessor(BaseMultiModalProcessor[Gemma4ProcessingInfo]):
video_timestamps_per_video: list[list[float]] = []
video_frame_counts: list[int] = []
video_replacements: list[str] = []
for item in videos:
video_array, metadata = item
......@@ -555,10 +571,7 @@ class Gemma4MultiModalProcessor(BaseMultiModalProcessor[Gemma4ProcessingInfo]):
video_timestamps_per_video.append(timestamps)
video_frame_counts.append(len(frames))
# Build expanded replacement text and replace the
# <|video|> placeholder in the prompt.
# Use split(token, 1) to avoid collision — the
# replacement text itself contains <|video|> tokens.
# Build expanded replacement text for this video.
ts_strs = [f"{int(s // 60):02d}:{int(s % 60):02d}" for s in timestamps]
replacement = " ".join(
f"{t} {processor.boi_token}"
......@@ -566,9 +579,23 @@ class Gemma4MultiModalProcessor(BaseMultiModalProcessor[Gemma4ProcessingInfo]):
f"{processor.eoi_token}"
for t, n in zip(ts_strs, num_soft_per_frame)
)
parts = prompt.split(processor.video_token, 1)
if len(parts) == 2:
prompt = parts[0] + replacement + parts[1]
video_replacements.append(replacement)
# Replace all <|video|> placeholders at once. We split on
# video_token to get N+1 parts, then interleave with the
# N replacement strings. This avoids the iterative
# split-replace bug where replacement text (which itself
# contains <|video|> tokens) collides with later splits.
vt = processor.video_token
parts = prompt.split(vt, len(video_replacements))
# NOTE: len(parts) <= len(video_replacements) + 1
parts_with_repl: list[str] = []
for part, repl in zip(parts, video_replacements):
parts_with_repl.extend([part, repl])
parts_with_repl.extend(parts[len(video_replacements) :])
prompt = "".join(parts_with_repl)
video_outputs = {
"pixel_values_videos": torch.cat(all_video_pixel_values, dim=0),
......@@ -631,19 +658,23 @@ class Gemma4MultiModalProcessor(BaseMultiModalProcessor[Gemma4ProcessingInfo]):
)
if "input_features" in processed_outputs:
# Keep padded features for batched audio tower execution.
processed_outputs["input_features_padded"] = processed_outputs[
"input_features"
]
# Unpad per-item so each item's cache entry is self-contained.
# Unpad per-item so each item's cache entry is
# self-contained. The batched() field config in
# _get_mm_fields_config will re-pad all fields to the
# batch's max length at batch time, ensuring consistent
# padding regardless of cache history.
masks = processed_outputs["input_features_mask"]
unpadded_features = [
f[mask]
for f, mask in zip(
processed_outputs["input_features"],
processed_outputs["input_features_mask"],
masks,
)
]
unpadded_masks = [mask[mask] for mask in masks]
processed_outputs["input_features"] = unpadded_features
processed_outputs["input_features_padded"] = unpadded_features
processed_outputs["input_features_mask"] = unpadded_masks
# Merge video outputs into the final result
combined_outputs = dict(processed_outputs, **video_outputs)
......@@ -848,7 +879,12 @@ class Gemma4MultimodalEmbedder(nn.Module):
info=Gemma4ProcessingInfo,
dummy_inputs=Gemma4DummyInputsBuilder,
)
class Gemma4ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
class Gemma4ForConditionalGeneration(
nn.Module,
SupportsMultiModal,
SupportsPP,
SupportsEagle3,
):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
......
......@@ -113,7 +113,29 @@ class KimiK25ProcessingInfo(BaseProcessingInfo):
trust_remote_code=self.ctx.model_config.trust_remote_code,
)
self.media_token_id = media_token_id = hf_config.media_placeholder_token_id
# Resolve token ID from the tokenizer because transformers v5
# may remap token IDs vs config.json.
config_token_id = hf_config.media_placeholder_token_id
resolved_token_id = tokenizer.convert_tokens_to_ids("<|media_pad|>")
is_valid_resolved = isinstance(resolved_token_id, int) and (
tokenizer.unk_token_id is None
or resolved_token_id != tokenizer.unk_token_id
)
if is_valid_resolved and resolved_token_id != config_token_id:
logger.warning_once(
"Kimi-K2.5 config.media_placeholder_token_id (%d) disagrees "
"with tokenizer mapping for <|media_pad|> (%d). "
"Using tokenizer value.",
config_token_id,
resolved_token_id,
)
media_token_id = resolved_token_id
# Patch config so downstream code also sees the correct ID.
hf_config.media_placeholder_token_id = resolved_token_id
else:
media_token_id = config_token_id
self.media_token_id = media_token_id
self.media_token = tokenizer.decode(media_token_id)
self.image_processor = image_processor
......@@ -232,8 +254,7 @@ class KimiK25MultiModalProcessor(BaseMultiModalProcessor[KimiK25ProcessingInfo])
hf_processor_mm_kwargs: Mapping[str, Any],
out_mm_kwargs: MultiModalKwargsItems,
) -> Sequence[PromptUpdate]:
hf_config = self.info.get_hf_config()
media_token_id = hf_config.media_placeholder_token_id
media_token_id = self.info.media_token_id
def get_replacement(item_idx: int):
media = mm_items.get_items("vision_chunk", (VisionChunkProcessorItems,))
......
......@@ -232,9 +232,7 @@ class MiniMaxM2Attention(nn.Module):
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = MiniMaxText01RMSNormTP.forward_qk(
self.q_norm, self.k_norm, q.contiguous(), k.contiguous()
)
q, k = MiniMaxText01RMSNormTP.forward_qk(self.q_norm, self.k_norm, q, k)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v)
output, _ = self.o_proj(attn_output)
......
......@@ -32,9 +32,9 @@ from transformers.models.musicflamingo import (
from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.inputs import MultiModalDataDict
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (
MultiModalDataDict,
MultiModalFieldConfig,
MultiModalKwargsItems,
)
......
......@@ -16,13 +16,11 @@
# limitations under the License.
"""Wrapper around `transformers` models"""
from vllm.compilation.decorators import support_torch_compile
from vllm.model_executor.models.transformers.base import Base
from vllm.model_executor.models.transformers.causal import CausalMixin
from vllm.model_executor.models.transformers.legacy import LegacyMixin
from vllm.model_executor.models.transformers.moe import MoEMixin
from vllm.model_executor.models.transformers.multimodal import (
DYNAMIC_ARG_DIMS,
MultiModalDummyInputsBuilder,
MultiModalMixin,
MultiModalProcessingInfo,
......@@ -32,16 +30,13 @@ from vllm.model_executor.models.transformers.pooling import (
EmbeddingMixin,
SequenceClassificationMixin,
)
from vllm.model_executor.models.transformers.utils import can_enable_torch_compile
from vllm.multimodal import MULTIMODAL_REGISTRY
# Text only models
@support_torch_compile(enable_if=can_enable_torch_compile)
class TransformersForCausalLM(CausalMixin, Base): ...
@support_torch_compile(enable_if=can_enable_torch_compile)
class TransformersMoEForCausalLM(MoEMixin, CausalMixin, Base): ...
......@@ -51,9 +46,6 @@ class TransformersMoEForCausalLM(MoEMixin, CausalMixin, Base): ...
info=MultiModalProcessingInfo,
dummy_inputs=MultiModalDummyInputsBuilder,
)
@support_torch_compile(
dynamic_arg_dims=DYNAMIC_ARG_DIMS, enable_if=can_enable_torch_compile
)
class TransformersMultiModalForCausalLM(MultiModalMixin, CausalMixin, Base): ...
......@@ -62,20 +54,15 @@ class TransformersMultiModalForCausalLM(MultiModalMixin, CausalMixin, Base): ...
info=MultiModalProcessingInfo,
dummy_inputs=MultiModalDummyInputsBuilder,
)
@support_torch_compile(
dynamic_arg_dims=DYNAMIC_ARG_DIMS, enable_if=can_enable_torch_compile
)
class TransformersMultiModalMoEForCausalLM(
MoEMixin, MultiModalMixin, CausalMixin, Base
): ...
# Embedding models
@support_torch_compile(enable_if=can_enable_torch_compile)
class TransformersEmbeddingModel(EmbeddingMixin, LegacyMixin, Base): ...
@support_torch_compile(enable_if=can_enable_torch_compile)
class TransformersMoEEmbeddingModel(EmbeddingMixin, MoEMixin, Base): ...
......@@ -84,20 +71,15 @@ class TransformersMoEEmbeddingModel(EmbeddingMixin, MoEMixin, Base): ...
info=MultiModalProcessingInfo,
dummy_inputs=MultiModalDummyInputsBuilder,
)
@support_torch_compile(
dynamic_arg_dims=DYNAMIC_ARG_DIMS, enable_if=can_enable_torch_compile
)
class TransformersMultiModalEmbeddingModel(EmbeddingMixin, MultiModalMixin, Base): ...
# Sequence classification models
@support_torch_compile(enable_if=can_enable_torch_compile)
class TransformersForSequenceClassification(
SequenceClassificationMixin, LegacyMixin, Base
): ...
@support_torch_compile(enable_if=can_enable_torch_compile)
class TransformersMoEForSequenceClassification(
SequenceClassificationMixin, MoEMixin, Base
): ...
......@@ -108,9 +90,6 @@ class TransformersMoEForSequenceClassification(
info=MultiModalProcessingInfo,
dummy_inputs=MultiModalDummyInputsBuilder,
)
@support_torch_compile(
dynamic_arg_dims=DYNAMIC_ARG_DIMS, enable_if=can_enable_torch_compile
)
class TransformersMultiModalForSequenceClassification(
SequenceClassificationMixin, MultiModalMixin, Base
): ...
......
......@@ -16,6 +16,7 @@
# limitations under the License.
"""Transformers modeling backend base class."""
import sys
from collections.abc import Callable, Iterable
from itertools import chain
from operator import attrgetter
......@@ -29,6 +30,7 @@ from torch import nn
from transformers import AutoModel
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
from vllm.compilation.decorators import support_torch_compile
from vllm.config.utils import getattr_iter
from vllm.distributed import get_pp_group, get_tp_group
from vllm.distributed.utils import get_pp_indices
......@@ -47,6 +49,7 @@ from vllm.model_executor.models.interfaces import (
)
from vllm.model_executor.models.interfaces_base import VllmModel
from vllm.model_executor.models.transformers.utils import (
can_enable_torch_compile,
get_feature_request_tip,
init_on_device_without_buffers,
log_replacement,
......@@ -117,6 +120,7 @@ class Base(
self.config = vllm_config.model_config.hf_config
self.text_config = self.config.get_text_config()
self.cache_config = vllm_config.cache_config
self.compilation_config = vllm_config.compilation_config
self.device_config = vllm_config.device_config
self.model_config = vllm_config.model_config
self.parallel_config = vllm_config.parallel_config
......@@ -146,7 +150,7 @@ class Base(
if self.quant_config:
quant_method_name = self.quant_config.get_name()
# Check for unsupported quantization methods.
if quant_method_name == "mxfp4":
if quant_method_name in ("mxfp4", "gpt_oss_mxfp4"):
raise NotImplementedError(
"Transformers modeling backend does "
"not support MXFP4 quantization yet."
......@@ -155,14 +159,16 @@ class Base(
if "gptq" in quant_method_name:
self.ignore_unexpected_suffixes.append(".bias")
# Patch config and init on "meta" to delay allocating GPU tensors
self._patch_config()
from_config_kwargs = dict(
config=self.config,
dtype=self.model_config.dtype,
trust_remote_code=self.model_config.trust_remote_code,
)
self._decorate_for_torch_compile(**from_config_kwargs)
# Init on "meta" to delay allocating GPU tensors
with init_on_device_without_buffers("meta"):
self.model: PreTrainedModel = AutoModel.from_config(
self.config,
dtype=self.model_config.dtype,
trust_remote_code=self.model_config.trust_remote_code,
)
self.model: PreTrainedModel = AutoModel.from_config(**from_config_kwargs)
# Create weight name to module qualname mapper
self._create_hf_to_vllm_mapper()
......@@ -218,6 +224,87 @@ class Base(
if sub_config.dtype != (dtype := self.config.dtype):
sub_config.dtype = dtype
def _get_decoder_cls(self, **kwargs: dict) -> type[PreTrainedModel]:
"""
Get the decoder class from the model.
Args:
kwargs: The kwargs to create the model.
Returns:
The decoder class.
"""
with torch.device("meta"):
model: PreTrainedModel = AutoModel.from_config(**kwargs)
decoder_cls = type(model.get_decoder())
logger.debug("Identified decoder class as: %s", decoder_cls)
del model
return decoder_cls
def _decorate_cls_for_torch_compile(
self,
cls: type[PreTrainedModel],
dynamic_arg_dims: dict[str, int] | None,
enable_if: Callable[["VllmConfig"], bool],
is_encoder: bool,
):
"""
Decorate `cls` to indicate to vLLM that it supports torch compile.
Args:
cls: The PreTrainedModel class to decorate.
dynamic_arg_dims: A mapping from argument name to the dynamic dimensions
of the argument. If None, default dynamic arg dims will be used. See
[`support_torch_compile`][vllm.compilation.decorators.support_torch_compile]
for more details.
enable_if: A function which takes in the vLLM config and returns whether
torch compile should be enabled for this class.
is_encoder: Whether the class being decorated is an encoder.
"""
logger.debug(
"Decorating `%s` as %s for torch compile with dynamic_arg_dims of %s",
cls.__name__,
"encoder" if is_encoder else "decoder",
dynamic_arg_dims,
)
@support_torch_compile(
dynamic_arg_dims=dynamic_arg_dims,
enable_if=enable_if,
is_encoder=is_encoder,
)
class SupportTorchCompileWrapper(cls): ...
# Preserve __module__ so transformers v5's source-file checks
# (e.g. _can_set_experts_implementation) read the original
# model's module instead of this file.
SupportTorchCompileWrapper.__module__ = cls.__module__
# Patch the class in its module
module = sys.modules[cls.__module__]
setattr(module, cls.__name__, SupportTorchCompileWrapper)
def _decorate_for_torch_compile(self, **kwargs: dict):
"""
Decorate the model's decoder class to indicate to vLLM that it supports torch
compile if `can_enable_torch_compile` is True.
Args:
kwargs: The kwargs to create the model, which are needed to get the decoder
class.
"""
self._decorate_cls_for_torch_compile(
cls=self._get_decoder_cls(**kwargs),
# Applied to a PreTrainedModel so the batch dimension will exist
dynamic_arg_dims=dict[str, int](
input_ids=1, # shape: [1, seq_len]
inputs_embeds=1, # shape: [1, seq_len, hidden_size]
position_ids=-1, # shape: [1, seq_len] or [3, 1, seq_len] for mrope
),
enable_if=can_enable_torch_compile,
is_encoder=False,
)
def _create_hf_to_vllm_mapper(self):
"""
Create a WeightsMapper to map checkpoint weight names to module qualnames.
......@@ -553,11 +640,6 @@ class Base(
input_ids = None
inputs_embeds = intermediate_tensors["hidden_states"]
if input_ids is not None:
input_ids = input_ids[None, ...]
if inputs_embeds is not None:
inputs_embeds = inputs_embeds[None, ...]
# If the model scales embeddings inside the input embedding layer we must
# ensure they are scaled here since VocabParallelEmbedding will not do it
if (
......@@ -568,22 +650,29 @@ class Base(
inputs_embeds = self.embed_input_ids(input_ids)
input_ids = None
if self.model_config.uses_mrope:
position_ids = positions[:, None]
else:
position_ids = positions[None, ...]
# Add batch dimension before entering Transformers model
if input_ids is not None and input_ids.ndim == 1:
# [seq_len] -> [1, seq_len]
input_ids = input_ids[None, ...]
if inputs_embeds is not None and inputs_embeds.ndim == 2:
# [seq_len, hidden_size] -> [1, seq_len, hidden_size]
inputs_embeds = inputs_embeds[None, ...]
if positions.ndim == 1:
# [seq_len] -> [1, seq_len]
positions = positions[None, ...]
outputs = self.model(
input_ids=input_ids,
inputs_embeds=inputs_embeds,
use_cache=False,
position_ids=position_ids,
position_ids=positions,
attention_instances=self.attention_instances,
return_dict=False,
**self._output_aux_hidden_states_kwargs,
**kwargs,
)
# We must remove the batch dimension from these outputs
# Remove batch dimension after exiting Transformers model
hidden_states = outputs[0][0, ...]
if self._output_aux_hidden_states_kwargs:
aux_hidden_states = [x[0][0, ...] for x in outputs[1:]]
......
......@@ -20,7 +20,9 @@ from collections.abc import Mapping
from typing import TYPE_CHECKING
import torch
from transformers import AutoModel
from vllm.compilation.decorators import should_torch_compile_mm_encoder
from vllm.config.utils import getattr_iter
from vllm.inputs import MultiModalDataDict, MultiModalInput, mm_input
from vllm.logger import init_logger
......@@ -46,19 +48,11 @@ from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
if TYPE_CHECKING:
from transformers import BatchFeature
from transformers import BatchFeature, PreTrainedModel
from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
DYNAMIC_ARG_DIMS = {
"input_ids": 0,
# set `positions` to last dim to support Qwen-mrope
"positions": -1,
"intermediate_tensors": 0,
"inputs_embeds": 0,
}
logger = init_logger(__name__)
......@@ -274,6 +268,66 @@ class MultiModalMixin(SupportsMultiModal, SupportsMRoPE):
# Skip SupportsMRoPE.__init__ and call the next class in MRO
super(SupportsMRoPE, self).__init__(vllm_config=vllm_config, prefix=prefix)
def _get_encoder_cls(
self, modality: str = "image", **kwargs: dict
) -> type["PreTrainedModel"]:
"""
Get the encoder class from the model.
Args:
kwargs: The kwargs to create the model.
Returns:
The encoder class.
"""
with torch.device("meta"):
model: PreTrainedModel = AutoModel.from_config(**kwargs)
encoder_cls = type(model.get_encoder(modality=modality))
logger.debug("Identified encoder class as: %s", encoder_cls)
if type(model) is encoder_cls:
raise ValueError(
"Unable to infer vision encoder class from the model. "
"You must either: update the model so that "
"https://huggingface.co/docs/transformers/en/main_classes/model#transformers.PreTrainedModel.get_encoder"
" can detect the vision encoder correctly, or remove "
"'compile_mm_encoder'."
)
del model
return encoder_cls
def _decorate_for_torch_compile(self, **kwargs: dict):
"""
Decorate the model's decoder and encoder classes to indicate to vLLM that they
support torch compile if `can_enable_torch_compile` and
`should_torch_compile_mm_encoder` are True respectively.
Args:
kwargs: The kwargs to create the model, which are needed to get the decoder
and encoder classes.
"""
super()._decorate_for_torch_compile(**kwargs)
# Decorate the vision encoder model class to support torch compile if needed
if self.compilation_config.compile_mm_encoder:
self.check_version("5.0.0", "multimodal encoder compilation support")
logger.warning_once(
"Multimodal encoder compilation with the Transformers modeling backend "
"is an experimental feature. It relies on:\n"
"- The vision encoder being torch compilable.\n"
"- All vision encoder tensor inputs must be type hinted as either "
"`torch.Tensor` or `torch.FloatTensor`.\n"
"- The 0-th dimension of all tensor inputs to the vision encoder being "
"the dynamic dimension (i.e., sequence length or number of patches).\n"
"Please report any issues you encounter to help us improve it."
)
self._decorate_cls_for_torch_compile(
cls=self._get_encoder_cls(**kwargs),
# TODO: properly infer dynamic_arg_dims based on the encoder's forward
# method signature. Currently we assume dim 0 for all tensor inputs.
dynamic_arg_dims=None,
enable_if=should_torch_compile_mm_encoder,
is_encoder=True,
)
def forward(
self,
input_ids: torch.Tensor | None,
......@@ -285,6 +339,10 @@ class MultiModalMixin(SupportsMultiModal, SupportsMRoPE):
# Gemma3 and PaliGemma needs `token_type_ids` to work correctly
# Other models will not have `token_type_ids` in kwargs
kwargs = {k: v for k, v in kwargs.items() if k == "token_type_ids"}
# Positions shape handling for MRoPE models
if self.model_config.uses_mrope:
# [3, seq_len] -> [3, 1, seq_len]
positions = positions[:, None]
model_output = super().forward(
input_ids, positions, intermediate_tensors, inputs_embeds, **kwargs
)
......
......@@ -470,6 +470,15 @@ class DelegatingParser(Parser):
# No tool calls
return [], content
def adjust_request(
self, request: ChatCompletionRequest | ResponsesRequest
) -> ChatCompletionRequest | ResponsesRequest:
if self._reasoning_parser is not None:
request = self._reasoning_parser.adjust_request(request)
if self._tool_parser is not None:
request = self._tool_parser.adjust_request(request)
return request
def extract_reasoning_streaming(
self,
previous_text: str,
......
......@@ -6,7 +6,7 @@ import os
from abc import abstractmethod
from collections.abc import Callable, Iterable, Sequence
from functools import cached_property
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, cast
from vllm.entrypoints.mcp.tool_server import ToolServer
from vllm.logger import init_logger
......@@ -150,6 +150,12 @@ class ReasoningParser:
previously been parsed and extracted (see constructor)
"""
def adjust_request(
self, request: "ChatCompletionRequest | ResponsesRequest"
) -> "ChatCompletionRequest | ResponsesRequest":
"""Adjust request parameters; override in subclasses as needed."""
return request
def prepare_structured_tag(
self,
original_tag: str | None,
......@@ -298,7 +304,7 @@ class ReasoningParserManager:
if isinstance(name, str):
names = [name]
elif is_list_of(name, str):
names = name
names = cast(list[str], name)
else:
names = [class_name]
......
......@@ -52,6 +52,16 @@ class Gemma4ReasoningParser(BaseThinkingReasoningParser):
# skip_special_tokens=True).
self._reasoning_text: str = ""
self._prefix_stripped: bool = False
self.new_turn_token_id = self.vocab["<|turn>"]
self.tool_call_token_id = self.vocab["<|tool_call>"]
self.tool_response_token_id = self.vocab["<|tool_response>"]
def adjust_request(
self, request: "ChatCompletionRequest | ResponsesRequest"
) -> "ChatCompletionRequest | ResponsesRequest":
"""Disable special-token stripping to preserve boundary tokens."""
request.skip_special_tokens = False
return request
@property
def start_token(self) -> str:
......@@ -63,6 +73,29 @@ class Gemma4ReasoningParser(BaseThinkingReasoningParser):
"""The token that ends reasoning content."""
return "<channel|>"
def is_reasoning_end(self, input_ids: Sequence[int]) -> bool:
start_token_id = self.start_token_id
end_token_id = self.end_token_id
new_turn_token_id = self.new_turn_token_id
tool_call_token_id = self.tool_call_token_id
tool_response_token_id = self.tool_response_token_id
# Search from the end of input_ids to find the last match.
for i in range(len(input_ids) - 1, -1, -1):
if input_ids[i] == start_token_id:
return False
if input_ids[i] == tool_call_token_id:
# We're generating a tool call, so reasoning must be ended.
return True
if input_ids[i] in (new_turn_token_id, tool_response_token_id):
# We found a new turn or tool response token so don't consider
# reasoning ended yet, since the model starts new reasoning
# after these tokens.
return False
if input_ids[i] == end_token_id:
return True
return False
# ------------------------------------------------------------------
# Non-streaming path
# ------------------------------------------------------------------
......@@ -159,11 +192,10 @@ class Gemma4ReasoningParser(BaseThinkingReasoningParser):
result.reasoning = stripped
return result
else:
# This entire delta was prefix — suppress it.
# Don't set _prefix_stripped yet; there may be more
# prefix chars to consume in the next delta.
if len(self._reasoning_text) >= prefix_len:
self._prefix_stripped = True
result.reasoning = ""
return result
return None
# Case 2: Accumulated text is a strict prefix of
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import contextlib
from dataclasses import dataclass, field
from functools import lru_cache
from pathlib import Path
......@@ -10,6 +11,7 @@ from typing_extensions import TypeVar, assert_never
import vllm.envs as envs
from vllm.logger import init_logger
from vllm.transformers_utils.config import get_config
from vllm.transformers_utils.gguf_utils import (
check_gguf_file,
get_gguf_file_path_from_hf,
......@@ -31,6 +33,13 @@ if TYPE_CHECKING:
logger = init_logger(__name__)
# Model types whose hub tokenizer_class is incorrect and should be overridden with
# TokenizersBackend (the generic fast tokenizer). Adding a model type here is always a
# temporary workaround and better long term solutions are:
# - Add model type to MODELS_WITH_INCORRECT_HUB_TOKENIZER_CLASS in transformers (better)
# - Fix tokenizer_class on the hub for the affected models (best)
_MODEL_TYPES_WITH_INCORRECT_TOKENIZER_CLASS: set[str] = {"step3_vl"}
_VLLM_TOKENIZERS = {
"deepseek_v32": ("deepseek_v32", "DeepseekV32Tokenizer"),
"grok2": ("grok2", "Grok2Tokenizer"),
......@@ -202,7 +211,31 @@ def get_tokenizer(
**kwargs,
)
if tokenizer_cls == TokenizerLike:
# Ensure that, if the config were to come from vllm.transformers_utils.config, it is
# registered with AutoConfig before the tokenizer is loaded. This is necessary since
# tokenizer_cls_.from_pretrained will call AutoConfig.from_pretrained internally.
# This may fail for paths that don't have a model config (e.g. LoRA adapters),
# which is fine — those don't need custom config registration.
config = None
with contextlib.suppress(ValueError, OSError):
config = get_config(
tokenizer_name,
trust_remote_code=trust_remote_code,
revision=revision,
)
# Some models have an incorrect tokenizer_class on the hub.
# For these model types, bypass AutoTokenizer and use TokenizersBackend directly.
model_type = getattr(config, "model_type", None) if config else None
if model_type in _MODEL_TYPES_WITH_INCORRECT_TOKENIZER_CLASS:
from transformers.tokenization_utils_tokenizers import TokenizersBackend
logger.debug(
"Overriding tokenizer_class to TokenizersBackend for model_type=%r",
model_type,
)
tokenizer_cls_ = TokenizersBackend
elif tokenizer_cls == TokenizerLike:
tokenizer_cls_ = TokenizerRegistry.load_tokenizer_cls(tokenizer_mode)
else:
tokenizer_cls_ = tokenizer_cls
......
......@@ -66,6 +66,10 @@ def _parse_gemma4_value(value_str: str) -> object:
if value_str == "false":
return False
# Null
if value_str.lower() in ("null", "none", "nil"):
return None
# Number (int or float)
try:
if "." in value_str:
......@@ -78,7 +82,7 @@ def _parse_gemma4_value(value_str: str) -> object:
return value_str
def _parse_gemma4_args(args_str: str) -> dict:
def _parse_gemma4_args(args_str: str, *, partial: bool = False) -> dict:
"""Parse Gemma4's custom key:value format into a Python dict.
Format examples::
......@@ -89,6 +93,12 @@ def _parse_gemma4_args(args_str: str) -> dict:
nested:{inner_key:<|"|>val<|"|>}
items:[<|"|>a<|"|>,<|"|>b<|"|>]
Args:
args_str: The raw Gemma4 argument string.
partial: When True (streaming), bare values at end of string are
omitted because they may be incomplete and type-unstable
(e.g. partial boolean parsed as bare string).
Returns a dict ready for ``json.dumps()``.
"""
if not args_str or not args_str.strip():
......@@ -116,14 +126,16 @@ def _parse_gemma4_args(args_str: str) -> dict:
# Parse value
if i >= n:
result[key] = ""
if not partial:
result[key] = ""
break
# Skip whitespace after ':'
while i < n and args_str[i] in (" ", "\n", "\t"):
i += 1
if i >= n:
result[key] = ""
if not partial:
result[key] = ""
break
# String value: <|"|>...<|"|>
......@@ -155,7 +167,12 @@ def _parse_gemma4_args(args_str: str) -> dict:
elif args_str[i] == "}":
depth -= 1
i += 1
result[key] = _parse_gemma4_args(args_str[obj_start : i - 1])
if depth > 0:
# Incomplete nested object — use i (not i-1) to avoid
# dropping the last char, and recurse as partial.
result[key] = _parse_gemma4_args(args_str[obj_start:i], partial=True)
else:
result[key] = _parse_gemma4_args(args_str[obj_start : i - 1])
# Array: [...]
elif args_str[i] == "[":
......@@ -173,20 +190,26 @@ def _parse_gemma4_args(args_str: str) -> dict:
elif args_str[i] == "]":
depth -= 1
i += 1
arr_content = args_str[arr_start : i - 1]
result[key] = _parse_gemma4_array(arr_content)
if depth > 0:
result[key] = _parse_gemma4_array(args_str[arr_start:i], partial=True)
else:
result[key] = _parse_gemma4_array(args_str[arr_start : i - 1])
# Bare value (number, boolean, etc.)
else:
val_start = i
while i < n and args_str[i] not in (",", "}", "]"):
i += 1
if partial and i >= n:
# Value may be incomplete (e.g. partial boolean) —
# withhold to avoid type instability during streaming.
break
result[key] = _parse_gemma4_value(args_str[val_start:i])
return result
def _parse_gemma4_array(arr_str: str) -> list:
def _parse_gemma4_array(arr_str: str, *, partial: bool = False) -> list:
"""Parse a Gemma4 array content string into a Python list."""
items: list = []
i = 0
......@@ -224,7 +247,10 @@ def _parse_gemma4_array(arr_str: str) -> list:
elif arr_str[i] == "}":
depth -= 1
i += 1
items.append(_parse_gemma4_args(arr_str[obj_start : i - 1]))
if depth > 0:
items.append(_parse_gemma4_args(arr_str[obj_start:i], partial=True))
else:
items.append(_parse_gemma4_args(arr_str[obj_start : i - 1]))
# Nested array
elif arr_str[i] == "[":
......@@ -237,13 +263,18 @@ def _parse_gemma4_array(arr_str: str) -> list:
elif arr_str[i] == "]":
depth -= 1
i += 1
items.append(_parse_gemma4_array(arr_str[sub_start : i - 1]))
if depth > 0:
items.append(_parse_gemma4_array(arr_str[sub_start:i], partial=True))
else:
items.append(_parse_gemma4_array(arr_str[sub_start : i - 1]))
# Bare value
else:
val_start = i
while i < n and arr_str[i] not in (",", "]"):
i += 1
if partial and i >= n:
break
items.append(_parse_gemma4_value(arr_str[val_start:i]))
return items
......@@ -436,8 +467,10 @@ class Gemma4ToolParser(ToolParser):
) -> DeltaMessage | None:
# Buffer delta text to handle multi-token special sequences
delta_text = self._buffer_delta_text(delta_text)
# Reconstruct current_text after buffering to stay in sync
current_text = previous_text + delta_text
# Keep current_text from the upstream stream state. The buffered delta
# is only for emission, and must not be stitched back into the
# accumulated model text or normal content like "<div>" can be
# duplicated into "<<div>" when a tool call just ended.
# If no tool call token seen yet, emit as content
if self.tool_call_start_token not in current_text:
......@@ -661,7 +694,7 @@ class Gemma4ToolParser(ToolParser):
DeltaMessage with the argument diff, or None if no new content.
"""
try:
current_args = _parse_gemma4_args(raw_args_str)
current_args = _parse_gemma4_args(raw_args_str, partial=True)
except Exception:
logger.debug(
"Could not parse partial Gemma4 args yet: %s",
......@@ -675,10 +708,11 @@ class Gemma4ToolParser(ToolParser):
current_args_json = json.dumps(current_args, ensure_ascii=False)
# Withhold trailing closing characters that may shift as more
# tokens arrive. Strip trailing '}', '"', and ']' sequences
# to get the "safe prefix".
# tokens arrive. Strip trailing '}', '"', ']' and partial
# STRING_DELIM fragments ('<', '|', '\\', '>') to get the
# "safe prefix".
safe_json = current_args_json
while safe_json and safe_json[-1] in ("}", '"', "]"):
while safe_json and safe_json[-1] in ("}", '"', "]", "<", "|", "\\", ">"):
safe_json = safe_json[:-1]
prev_streamed = self.streamed_args_for_tool[self.current_tool_id]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment