Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
fc67613a
"tests/vscode:/vscode.git/clone" did not exist on "965525667b70dc23463d57295dce792eba1ac452"
Commit
fc67613a
authored
Apr 18, 2026
by
zhuwenwen
Browse files
Merge tag 'v0.19.1' into v0.19.0
parents
31aec25b
b1388b1f
Changes
82
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1182 additions
and
163 deletions
+1182
-163
vllm/config/vllm.py
vllm/config/vllm.py
+16
-0
vllm/entrypoints/anthropic/serving.py
vllm/entrypoints/anthropic/serving.py
+2
-1
vllm/entrypoints/openai/api_server.py
vllm/entrypoints/openai/api_server.py
+2
-0
vllm/entrypoints/openai/responses/serving.py
vllm/entrypoints/openai/responses/serving.py
+2
-0
vllm/entrypoints/serve/render/serving.py
vllm/entrypoints/serve/render/serving.py
+15
-0
vllm/model_executor/layers/mamba/lamport_workspace.py
vllm/model_executor/layers/mamba/lamport_workspace.py
+302
-0
vllm/model_executor/model_loader/gguf_loader.py
vllm/model_executor/model_loader/gguf_loader.py
+12
-0
vllm/model_executor/models/gemma4.py
vllm/model_executor/models/gemma4.py
+441
-67
vllm/model_executor/models/gemma4_mm.py
vllm/model_executor/models/gemma4_mm.py
+55
-19
vllm/model_executor/models/kimi_k25.py
vllm/model_executor/models/kimi_k25.py
+24
-3
vllm/model_executor/models/minimax_m2.py
vllm/model_executor/models/minimax_m2.py
+1
-3
vllm/model_executor/models/musicflamingo.py
vllm/model_executor/models/musicflamingo.py
+1
-1
vllm/model_executor/models/transformers/__init__.py
vllm/model_executor/models/transformers/__init__.py
+0
-21
vllm/model_executor/models/transformers/base.py
vllm/model_executor/models/transformers/base.py
+107
-18
vllm/model_executor/models/transformers/multimodal.py
vllm/model_executor/models/transformers/multimodal.py
+67
-9
vllm/parser/abstract_parser.py
vllm/parser/abstract_parser.py
+9
-0
vllm/reasoning/abs_reasoning_parsers.py
vllm/reasoning/abs_reasoning_parsers.py
+8
-2
vllm/reasoning/gemma4_reasoning_parser.py
vllm/reasoning/gemma4_reasoning_parser.py
+35
-3
vllm/tokenizers/registry.py
vllm/tokenizers/registry.py
+34
-1
vllm/tool_parsers/gemma4_tool_parser.py
vllm/tool_parsers/gemma4_tool_parser.py
+49
-15
No files found.
vllm/config/vllm.py
View file @
fc67613a
...
...
@@ -1577,6 +1577,22 @@ class VllmConfig:
compile_range_end
,
)
if
compilation_config
.
pass_config
.
fuse_minimax_qk_norm
:
from
vllm.compilation.passes.fusion.minimax_qk_norm_fusion
import
(
MAX_TOKEN_NUM
,
)
max_token_num
=
min
(
MAX_TOKEN_NUM
,
self
.
scheduler_config
.
max_num_batched_tokens
)
if
compile_range_end
is
not
None
and
max_token_num
<
compile_range_end
:
computed_compile_ranges_endpoints
.
append
(
max_token_num
)
else
:
logger
.
debug
(
"Max num batched tokens below MiniMax QK norm fusion threshold, "
"MiniMax QK norm fusion enabled for all num_tokens."
)
if
compilation_config
.
compile_ranges_endpoints
is
not
None
:
for
x
in
compilation_config
.
compile_ranges_endpoints
:
assert
isinstance
(
x
,
int
)
...
...
vllm/entrypoints/anthropic/serving.py
View file @
fc67613a
...
...
@@ -170,7 +170,8 @@ class AnthropicServingMessages(OpenAIServingChat):
else
:
cls
.
_convert_message_content
(
msg
,
openai_msg
,
openai_messages
)
openai_messages
.
append
(
openai_msg
)
if
not
(
msg
.
role
==
"user"
and
"content"
not
in
openai_msg
):
openai_messages
.
append
(
openai_msg
)
@
classmethod
def
_convert_message_content
(
...
...
vllm/entrypoints/openai/api_server.py
View file @
fc67613a
...
...
@@ -372,6 +372,7 @@ async def init_app_state(
enable_auto_tools
=
args
.
enable_auto_tool_choice
,
exclude_tools_when_tool_choice_none
=
args
.
exclude_tools_when_tool_choice_none
,
tool_parser
=
args
.
tool_call_parser
,
reasoning_parser
=
args
.
structured_outputs_config
.
reasoning_parser
,
default_chat_template_kwargs
=
args
.
default_chat_template_kwargs
,
log_error_stack
=
args
.
log_error_stack
,
)
...
...
@@ -467,6 +468,7 @@ async def init_render_app_state(
enable_auto_tools
=
args
.
enable_auto_tool_choice
,
exclude_tools_when_tool_choice_none
=
args
.
exclude_tools_when_tool_choice_none
,
tool_parser
=
args
.
tool_call_parser
,
reasoning_parser
=
args
.
structured_outputs_config
.
reasoning_parser
,
default_chat_template_kwargs
=
args
.
default_chat_template_kwargs
,
log_error_stack
=
args
.
log_error_stack
,
)
...
...
vllm/entrypoints/openai/responses/serving.py
View file @
fc67613a
...
...
@@ -594,6 +594,7 @@ class OpenAIServingResponses(OpenAIServing):
default_template_kwargs
=
None
,
tool_dicts
=
tool_dicts
,
tool_parser
=
self
.
parser
.
tool_parser_cls
if
self
.
parser
else
None
,
reasoning_parser
=
self
.
parser
.
reasoning_parser_cls
if
self
.
parser
else
None
,
)
return
messages
,
engine_inputs
...
...
@@ -618,6 +619,7 @@ class OpenAIServingResponses(OpenAIServing):
default_template_kwargs
=
None
,
tool_dicts
=
tool_dicts
,
tool_parser
=
tool_parser
,
reasoning_parser
=
self
.
parser
.
reasoning_parser_cls
if
self
.
parser
else
None
,
)
return
engine_inputs
...
...
vllm/entrypoints/serve/render/serving.py
View file @
fc67613a
...
...
@@ -44,6 +44,7 @@ from vllm.inputs import (
)
from
vllm.logger
import
init_logger
from
vllm.parser
import
ParserManager
from
vllm.reasoning.abs_reasoning_parsers
import
ReasoningParser
from
vllm.renderers
import
BaseRenderer
,
merge_kwargs
from
vllm.renderers.inputs.preprocess
import
(
extract_prompt_components
,
...
...
@@ -74,6 +75,7 @@ class OpenAIServingRender:
enable_auto_tools
:
bool
=
False
,
exclude_tools_when_tool_choice_none
:
bool
=
False
,
tool_parser
:
str
|
None
=
None
,
reasoning_parser
:
str
|
None
=
None
,
default_chat_template_kwargs
:
dict
[
str
,
Any
]
|
None
=
None
,
log_error_stack
:
bool
=
False
,
)
->
None
:
...
...
@@ -94,6 +96,11 @@ class OpenAIServingRender:
enable_auto_tools
=
enable_auto_tools
,
model_name
=
model_config
.
model
,
)
self
.
reasoning_parser
:
type
[
ReasoningParser
]
|
None
=
(
ParserManager
.
get_reasoning_parser
(
reasoning_parser_name
=
reasoning_parser
,
)
)
self
.
default_chat_template_kwargs
:
dict
[
str
,
Any
]
=
(
default_chat_template_kwargs
or
{}
)
...
...
@@ -245,6 +252,7 @@ class OpenAIServingRender:
default_template_kwargs
=
self
.
default_chat_template_kwargs
,
tool_dicts
=
tool_dicts
,
tool_parser
=
tool_parser
,
reasoning_parser
=
self
.
reasoning_parser
,
)
else
:
# For GPT-OSS.
...
...
@@ -498,6 +506,9 @@ class OpenAIServingRender:
default_template_kwargs
:
dict
[
str
,
Any
]
|
None
,
tool_dicts
:
list
[
dict
[
str
,
Any
]]
|
None
=
None
,
tool_parser
:
type
[
ToolParser
]
|
None
=
None
,
reasoning_parser
:
type
[
ReasoningParser
]
|
None
=
None
,
*
,
skip_mm_cache
:
bool
=
False
,
)
->
tuple
[
list
[
ConversationMessage
],
list
[
EngineInput
]]:
"""Copied from OpenAIServing._preprocess_chat."""
renderer
=
self
.
renderer
...
...
@@ -531,6 +542,10 @@ class OpenAIServingRender:
},
)
if
reasoning_parser
is
not
None
:
tokenizer
=
renderer
.
get_tokenizer
()
request
=
reasoning_parser
(
tokenizer
).
adjust_request
(
request
=
request
)
# tool parsing is done only if a tool_parser has been set and if
# tool_choice is not "none" (if tool_choice is "none" but a tool_parser
# is set, we want to prevent parsing a tool_call hallucinated by the LLM
...
...
vllm/model_executor/layers/mamba/lamport_workspace.py
0 → 100644
View file @
fc67613a
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
array
import
contextlib
import
struct
import
sys
import
threading
import
torch
try
:
from
cuda.bindings
import
runtime
as
cudart
except
ImportError
:
from
cuda
import
cudart
# Allocation granularity in bytes. Byte counts in this module are rounded up
# to a multiple of this value before being handed to the CUDA allocator.
_ALIGN = 1 << 21  # 2 MiB — CUDA IPC allocation alignment
# ---------------------------------------------------------------------------
# CUDA helpers
# ---------------------------------------------------------------------------
def _check(error):
    """Validate a CUDA runtime status code.

    Args:
        error: Status value returned as the first element of a ``cudart`` call.

    Raises:
        RuntimeError: If ``error`` is not the runtime's success code.
    """
    ok_code = getattr(cudart.cudaError_t, "cudaSuccess", None)
    if not ok_code:
        # Attribute missing (or falsy): build the success code from value 0.
        ok_code = cudart.cudaError_t(0)
    if error != ok_code:
        raise RuntimeError(f"CUDA runtime error: {error}")
def _cuda_malloc(size: int):
    """Allocate device memory, rounding the request up to a 2 MiB multiple.

    Args:
        size: Requested number of bytes.

    Returns:
        A ``(ptr, aligned)`` pair: the pointer returned by ``cudaMalloc`` and
        the rounded-up byte count that was actually requested.

    Raises:
        RuntimeError: If ``cudaMalloc`` reports an error.
    """
    # Ceil-divide into 2 MiB units, then scale back to bytes.
    n_units = (size + _ALIGN - 1) >> 21
    aligned = n_units << 21
    status, ptr = cudart.cudaMalloc(aligned)
    _check(status)
    return ptr, aligned
def _cuda_free(ptr: int):
    """Release device memory with ``cudaFree``; a falsy pointer is a no-op.

    Raises:
        RuntimeError: If ``cudaFree`` reports an error.
    """
    if not ptr:
        return
    status = cudart.cudaFree(ptr)[0]
    _check(status)
def _cuda_memset_zero(ptr: int, size: int):
    """Set ``size`` bytes starting at device pointer ``ptr`` to zero.

    Raises:
        RuntimeError: If ``cudaMemset`` reports an error.
    """
    status = cudart.cudaMemset(ptr, 0, size)[0]
    _check(status)
def _cuda_memcpy_d2d(dst: int, src: int, size: int):
    """Copy ``size`` bytes from device pointer ``src`` to device pointer ``dst``.

    Raises:
        RuntimeError: If ``cudaMemcpy`` reports an error.
    """
    copy_kind = cudart.cudaMemcpyKind.cudaMemcpyDeviceToDevice
    status = cudart.cudaMemcpy(dst, src, size, copy_kind)[0]
    _check(status)
# ---------------------------------------------------------------------------
# IPC buffer
# ---------------------------------------------------------------------------
class IpcBuffer:
    """
    Allocates CUDA device memory and exchanges IPC handles with all ranks
    so that every rank holds a valid device pointer to every other rank's buffer.

    When ``size <= 0`` nothing is allocated and no handles are exchanged; all
    pointers stay 0 and ``cleanup()`` is a no-op.
    """

    def __init__(self, rank: int, world_size: int, size: int, process_group=None):
        """Allocate the local buffer and open every peer's buffer.

        Args:
            rank: Index of this process within the group.
            world_size: Number of ranks taking part in the exchange.
            size: Requested buffer size in bytes (the allocation is rounded
                up by ``_cuda_malloc``).
            process_group: ``torch.distributed`` group passed to
                ``all_gather_object``.

        Raises:
            RuntimeError: If any CUDA runtime call reports an error.
        """
        self.rank = rank
        self.world_size = world_size
        # peer_ptrs[r] is the device pointer to rank r's buffer (0 = not opened).
        self.peer_ptrs: list[int] = [0] * world_size
        self.local_ptr: int = 0
        self._alive = False

        if size <= 0:
            return

        self.local_ptr, _ = _cuda_malloc(size)
        # Mark the buffer live as soon as memory is owned, so that cleanup()
        # (e.g. via __del__) releases it even if a later step below raises.
        self._alive = True
        _cuda_memset_zero(self.local_ptr, size)

        # --- exchange IPC handles via torch.distributed ---
        err, local_handle = cudart.cudaIpcGetMemHandle(self.local_ptr)
        _check(err)

        all_handles: list[bytes | None] = [None] * world_size
        torch.distributed.all_gather_object(
            all_handles, bytes(local_handle.reserved), group=process_group
        )

        for r in range(world_size):
            if r == rank:
                self.peer_ptrs[r] = self.local_ptr
            else:
                handle = cudart.cudaIpcMemHandle_t()
                handle.reserved = all_handles[r]
                err, ptr = cudart.cudaIpcOpenMemHandle(
                    handle, cudart.cudaIpcMemLazyEnablePeerAccess
                )
                _check(err)
                self.peer_ptrs[r] = ptr

    def serialize(self) -> list[int]:
        """Return peer pointers as a list of 64-bit integer values (one per rank)."""
        # Pack as native pointers and re-read as unsigned 64-bit words.
        # NOTE(review): this assumes pointer size is 8 bytes.
        raw = b"".join(struct.pack("P", ptr) for ptr in self.peer_ptrs)
        return array.array("Q", raw).tolist()

    def cleanup(self):
        """Close peer IPC handles and free the local allocation (idempotent).

        Errors while closing peer handles are suppressed (best effort). The
        local buffer is freed last, and is freed even if peer pointers were
        never recorded because ``__init__`` failed part-way.

        Raises:
            RuntimeError: If freeing the local buffer fails.
        """
        # getattr guards against __del__ on an instance whose __init__ raised
        # before the attribute was set.
        if not getattr(self, "_alive", False):
            return
        self._alive = False

        local_ptr, self.local_ptr = self.local_ptr, 0
        try:
            for r, ptr in enumerate(self.peer_ptrs):
                self.peer_ptrs[r] = 0
                # Slot `rank` aliases the local allocation; it is released
                # through local_ptr in the finally clause below.
                if ptr == 0 or r == self.rank:
                    continue
                with contextlib.suppress(RuntimeError):
                    _check(cudart.cudaIpcCloseMemHandle(ptr)[0])
        finally:
            _cuda_free(local_ptr)

    def __del__(self):
        # Skip CUDA calls while the interpreter is shutting down.
        if not sys.is_finalizing():
            self.cleanup()
# ---------------------------------------------------------------------------
# Lamport negative-zero initialization
# ---------------------------------------------------------------------------
def
_lamport_fill_neg_zero
(
device_ptr
:
int
,
size_bytes
:
int
):
"""
Fill device memory with IEEE-754 negative zero (-0.0f = 0x80000000).
This is the "slot empty" sentinel for the Lamport protocol: the kernel
spin-waits until a value is *not* negative zero.
"""
if
size_bytes
==
0
or
device_ptr
==
0
:
return
n_floats
=
size_bytes
//
4
# torch preserves -0.0 in IEEE-754
fill
=
torch
.
full
((
n_floats
,),
-
0.0
,
dtype
=
torch
.
float32
,
device
=
"cuda"
)
_cuda_memcpy_d2d
(
device_ptr
,
fill
.
data_ptr
(),
size_bytes
)
del
fill
# ---------------------------------------------------------------------------
# LamportWorkspace — the main class
# ---------------------------------------------------------------------------
class LamportWorkspace:
    """
    Self-contained workspace for Lamport-based cross-GPU AllReduce.

    Parameters
    ----------
    rank : int
        Local rank (0-based).
    world_size : int
        Total number of ranks in the TP group.
    comm_size : int
        Size in bytes of *one* Lamport buffer slot.  The IPC allocation per
        rank is ``3 * comm_size`` (triple-buffering).  Use
        ``compute_comm_size_for_minimax()`` for a safe default.
    process_group : optional
        ``torch.distributed`` process group for IPC handle exchange.
        ``None`` uses the default group.
    """

    def __init__(self, rank: int, world_size: int, comm_size: int, process_group=None):
        assert world_size >= 2, "Lamport workspace requires at least 2 ranks"
        assert comm_size > 0, "comm_size must be positive"

        self.rank = rank
        self.world_size = world_size
        self.comm_size = comm_size

        # Triple-buffered Lamport region, shared across ranks through IPC and
        # initialised to the negative-zero "empty" sentinel.
        triple_bytes = 3 * comm_size
        self._lamport = IpcBuffer(rank, world_size, triple_bytes, process_group)
        _lamport_fill_neg_zero(self._lamport.local_ptr, triple_bytes)

        # int32[3] = {counter, unused, lamport_flag}
        #   counter      — block-level sync inside the kernel
        #   unused       — reserved (index 1)
        #   lamport_flag — triple-buffer rotation index (0 → 1 → 2 → 0 …)
        self._flag_buf = torch.zeros(3, dtype=torch.int32, device="cuda")

        # int64[2] = {clear_size, comm_size}
        #   clear_size — bytes to clear from the *previous* slot (set by kernel)
        #   comm_size  — size of one triple-buffer slot
        self._layout_buf = torch.tensor(
            [0, comm_size], dtype=torch.int64, device="cuda"
        )

        # Device-side pointer table, laid out as:
        #   [0 .. N-1]    ipc_buffers   (placeholder zeros)
        #   [N .. 2N-1]   ipc_barriers  (placeholder zeros)
        #   [2N .. 3N-1]  lamport peer pointers
        #   [3N]          flag_buffer
        #   [3N+1]        layout_buffer
        table: list[int] = [0] * (2 * world_size)
        table.extend(self._lamport.serialize())
        table += [self._flag_buf.data_ptr(), self._layout_buf.data_ptr()]
        self._workspace = torch.tensor(table, dtype=torch.int64, device="cuda")

    @property
    def workspace(self) -> torch.Tensor:
        """Device tensor (int64) that can be passed to the kernel
        as ``void** workspace``."""
        return self._workspace

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------

    @staticmethod
    def compute_comm_size_for_minimax(
        max_tokens: int,
        world_size: int,
        fused_qk: bool = True,
    ) -> int:
        """
        Return a safe ``comm_size`` (in bytes) for MiniMaxReduceRMSKernel.

        Per-slot byte count before alignment:
          - single-matrix path: ``world_size * max_tokens * 4``
          - fused Q+K path:     ``world_size * 2 * ceil(max_tokens / 4) * 16``

        The returned value is rounded up to 2 MiB alignment.
        """
        if not fused_qk:
            raw_bytes = world_size * max_tokens * 4  # 4 = sizeof(float)
        else:
            float4_groups = (max_tokens + 3) // 4
            raw_bytes = world_size * 2 * float4_groups * 16  # 16 = sizeof(float4)
        return ((raw_bytes + _ALIGN - 1) >> 21) << 21

    def cleanup(self):
        """Release the IPC buffer, if construction got far enough to create it."""
        if not hasattr(self, "_lamport"):
            return
        self._lamport.cleanup()

    def __del__(self):
        # No CUDA teardown once the interpreter is shutting down.
        if sys.is_finalizing():
            return
        self.cleanup()

    def __repr__(self):
        return (
            f"LamportWorkspace(rank={self.rank}, world_size={self.world_size}, "
            f"comm_size={self.comm_size})"
        )
# ---------------------------------------------------------------------------
# Cached convenience function (mirrors TRT-LLM's get_allreduce_workspace)
# ---------------------------------------------------------------------------
# Guards creation of and lookup in the workspace cache below.
_cache_lock = threading.Lock()
# Maps (rank, world_size, comm_size, process-group id) -> LamportWorkspace.
_workspace_cache: dict = {}


def get_allreduce_workspace(
    rank: int,
    world_size: int,
    comm_size: int | None = None,
    max_tokens: int = 16384,
    process_group=None,
) -> torch.Tensor:
    """
    Return a cached workspace tensor.

    The cache key is ``(rank, world_size, comm_size, id(process_group))``,
    with 0 standing in for the id when ``process_group`` is ``None``.  On the
    first call for a key the workspace is allocated and IPC handles are
    exchanged; later calls with the same key return the cached tensor.

    Parameters
    ----------
    rank, world_size : int
        TP rank and TP size.
    comm_size : int, optional
        Explicit slot size in bytes.  If ``None``, computed automatically
        from ``max_tokens`` and ``world_size`` (fused Q+K path).
    max_tokens : int
        Maximum number of tokens per batch (used when ``comm_size is None``).
    process_group : optional
        ``torch.distributed`` process group.
    """
    if comm_size is None:
        comm_size = LamportWorkspace.compute_comm_size_for_minimax(
            max_tokens, world_size, fused_qk=True
        )

    group_token = 0 if process_group is None else id(process_group)
    cache_key = (rank, world_size, comm_size, group_token)

    with _cache_lock:
        cached = _workspace_cache.get(cache_key)
        if cached is None:
            cached = LamportWorkspace(rank, world_size, comm_size, process_group)
            _workspace_cache[cache_key] = cached
        return cached.workspace
vllm/model_executor/model_loader/gguf_loader.py
View file @
fc67613a
...
...
@@ -209,12 +209,24 @@ class GGUFModelLoader(BaseModelLoader):
GGUF tensor name with suffix (e.g., 'mm.soft_emb_norm.weight')
or None if no mapping found
"""
# In transformers v5, multimodal models (e.g. Gemma3) wrap
# all sub-models under an outer 'model.' attribute, producing
# state_dict keys like 'model.language_model.layers.0...' and
# 'model.vision_tower.vision_model...'. Strip this outer
# prefix so the keys match what gguf-py expects.
if
is_multimodal
and
hf_name
.
startswith
(
"model."
):
hf_name
=
hf_name
[
6
:]
# Remove outer 'model.'
# Strip 'language_model.' prefix for multimodal models - gguf-py
# tensor mappings expect parameter names without this prefix.
# Note: 'model.' prefix should be KEPT for text-only models as
# gguf-py expects it.
if
hf_name
.
startswith
(
"language_model."
):
hf_name
=
hf_name
[
15
:]
# Remove 'language_model.'
# Re-add 'model.' prefix because gguf-py text tensor maps
# expect 'model.layers...' format.
if
is_multimodal
:
hf_name
=
"model."
+
hf_name
# Parse parameter name and suffix
if
hf_name
.
endswith
((
".weight"
,
".bias"
)):
...
...
vllm/model_executor/models/gemma4.py
View file @
fc67613a
...
...
@@ -19,6 +19,7 @@
"""Gemma 4 model implementation for vLLM."""
from
collections.abc
import
Iterable
from
dataclasses
import
replace
from
itertools
import
islice
import
regex
as
re
...
...
@@ -32,6 +33,7 @@ from vllm.distributed import (
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
,
)
from
vllm.forward_context
import
get_forward_context
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.activation
import
GeluAndMul
from
vllm.model_executor.layers.attention
import
Attention
...
...
@@ -56,10 +58,18 @@ from vllm.model_executor.model_loader.weight_utils import (
maybe_remap_kv_scale_name
,
)
from
vllm.sequence
import
IntermediateTensors
from
.interfaces
import
MixtureOfExperts
,
SupportsLoRA
,
SupportsPP
from
vllm.v1.attention.backends.utils
import
KVSharingFastPrefillMetadata
from
.interfaces
import
(
EagleModelMixin
,
MixtureOfExperts
,
SupportsEagle3
,
SupportsLoRA
,
SupportsPP
,
)
from
.utils
import
(
AutoWeightsLoader
,
WeightsMapper
,
extract_layer_index
,
is_pp_missing_parameter
,
make_layers
,
...
...
@@ -636,8 +646,206 @@ class Gemma4DecoderLayer(nn.Module):
return
hidden_states
,
None
@
support_torch_compile
class
Gemma4Model
(
nn
.
Module
):
def _run_decoder_layers(
    decoder_layers: list[Gemma4DecoderLayer],
    layer_idx_start: int,
    positions: torch.Tensor,
    hidden_states: torch.Tensor,
    per_layer_inputs: torch.Tensor | None = None,
    **kwargs,
) -> torch.Tensor:
    """Apply ``decoder_layers`` in order and return the final hidden states.

    Layer ``i`` of the slice receives ``per_layer_inputs[:, layer_idx_start + i, :]``
    as its ``per_layer_input`` (or ``None`` when ``per_layer_inputs`` is ``None``).
    The residual starts as ``None`` and is threaded from one layer to the
    next; the last residual is not returned.
    """
    carry = None
    for offset, layer in enumerate(decoder_layers):
        global_idx = offset + layer_idx_start
        if per_layer_inputs is None:
            ple_slice = None
        else:
            ple_slice = per_layer_inputs[:, global_idx, :]
        hidden_states, carry = layer(
            positions,
            hidden_states,
            carry,
            per_layer_input=ple_slice,
            **kwargs,
        )
    return hidden_states
@support_torch_compile(
    enable_if=lambda vllm_config: vllm_config.cache_config.kv_sharing_fast_prefill
)
class Gemma4SelfDecoderLayers(nn.Module):
    """Compiled wrapper: embedding + non-KV-shared layers (YOCO first half).

    Owns the embedding and PLE modules so they are inside the compiled
    graph. Gemma4Model delegates embedding methods here.

    The ``support_torch_compile`` decorator above is enabled only when
    ``cache_config.kv_sharing_fast_prefill`` is set.
    """

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
        decoder_layers: list[Gemma4DecoderLayer],
        layer_idx_start: int,
        embed_tokens: VocabParallelEmbedding,
        normalizer: torch.Tensor,
        embed_tokens_per_layer: VocabParallelEmbedding | None,
        embed_scale_per_layer: torch.Tensor | None,
        per_layer_model_projection: ColumnParallelLinear | None,
        per_layer_projection_norm: RMSNorm | None,
        per_layer_input_scale: torch.Tensor | None,
        per_layer_projection_scale: torch.Tensor | None,
    ):
        """Store the layer slice and the embedding / PLE modules it uses.

        All arguments are keyword-only. ``prefix`` is accepted but not used
        in this constructor. The PLE-related arguments may be ``None``; the
        methods below return ``None`` early in that case.
        """
        super().__init__()
        self.decoder_layers = decoder_layers
        self.layer_idx_start = layer_idx_start

        config = _get_text_config(vllm_config.model_config.hf_config)
        self.config = config
        # Falls back to 0 when the config has no per-layer hidden size.
        self.hidden_size_per_layer_input = getattr(
            config, "hidden_size_per_layer_input", 0
        )
        # Falls back to the main vocabulary size when not configured.
        self.vocab_size_per_layer_input = getattr(
            config, "vocab_size_per_layer_input", config.vocab_size
        )

        # Shared references to modules owned by Gemma4Model — must be
        # inside this nn.Module so torch.compile captures them.
        self.embed_tokens = embed_tokens
        self.normalizer = normalizer
        self.embed_tokens_per_layer = embed_tokens_per_layer
        self.embed_scale_per_layer = embed_scale_per_layer
        self.per_layer_model_projection = per_layer_model_projection
        self.per_layer_projection_norm = per_layer_projection_norm
        self.per_layer_input_scale = per_layer_input_scale
        self.per_layer_projection_scale = per_layer_projection_scale

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Embed ``input_ids`` and multiply the result by ``self.normalizer``."""
        return self.embed_tokens(input_ids) * self.normalizer

    def get_per_layer_inputs(self, input_ids: torch.Tensor) -> torch.Tensor | None:
        """Get per-layer embeddings from embed_tokens_per_layer.

        Returns:
            Per-layer embeddings (num_tokens, num_layers,
            hidden_size_per_layer_input), or ``None`` when
            ``embed_tokens_per_layer`` is ``None``.
        """
        if self.embed_tokens_per_layer is None:
            return None
        # Ids outside [0, vocab_size_per_layer_input) are replaced by id 0
        # before the per-layer embedding lookup.
        per_layer_inputs_mask = torch.logical_and(
            input_ids >= 0,
            input_ids < self.vocab_size_per_layer_input,
        )
        per_layer_inputs_tokens = torch.where(
            per_layer_inputs_mask, input_ids, torch.zeros_like(input_ids)
        )
        per_layer_embeds = self.embed_tokens_per_layer(per_layer_inputs_tokens)
        per_layer_embeds = per_layer_embeds * self.embed_scale_per_layer
        # Split the packed last dimension into (num_hidden_layers, per-layer dim).
        return per_layer_embeds.reshape(
            *input_ids.shape,
            self.config.num_hidden_layers,
            self.hidden_size_per_layer_input,
        )

    def project_per_layer_inputs(
        self,
        inputs_embeds: torch.Tensor,
        per_layer_inputs: torch.Tensor | None,
    ) -> torch.Tensor | None:
        """Project inputs_embeds and combine with per_layer_inputs.

        Steps:
        1. Project inputs_embeds: hidden_size → total_ple_dim
        2. Scale by hidden_size^{-0.5}
        3. Reshape to (num_tokens, num_layers, per_layer_dim)
        4. Normalize with per_layer_projection_norm
        5. Combine: (projection + per_layer_inputs) * 1/sqrt(2)

        Returns ``None`` when ``per_layer_model_projection`` is ``None``, and
        the normalized projection alone when ``per_layer_inputs`` is ``None``.
        """
        if self.per_layer_model_projection is None:
            return None
        per_layer_projection = self.per_layer_model_projection(inputs_embeds)
        per_layer_projection = per_layer_projection * self.per_layer_projection_scale
        per_layer_projection = per_layer_projection.reshape(
            *inputs_embeds.shape[:-1],
            self.config.num_hidden_layers,
            self.hidden_size_per_layer_input,
        )
        per_layer_projection = self.per_layer_projection_norm(per_layer_projection)
        if per_layer_inputs is None:
            return per_layer_projection
        return (per_layer_projection + per_layer_inputs) * self.per_layer_input_scale

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        inputs_embeds: torch.Tensor | None = None,
        per_layer_inputs: torch.Tensor | None = None,
        **kwargs,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Embed (or accept embeddings), build PLE inputs, run the layer slice.

        Returns:
            ``(hidden_states, per_layer_inputs)``: the output of the last
            layer in the slice and the projected per-layer inputs that were
            fed to the layers (``None`` when PLE is not configured).
        """
        if inputs_embeds is not None:
            # Caller supplied embeddings: use them directly and combine the
            # projection with the caller's per_layer_inputs (may be None).
            hidden_states = inputs_embeds
            per_layer_inputs = self.project_per_layer_inputs(
                hidden_states, per_layer_inputs
            )
        else:
            # Token-id path: the per_layer_inputs argument is ignored and
            # recomputed from input_ids.
            hidden_states = self.embed_input_ids(input_ids)
            per_layer_embeds = self.get_per_layer_inputs(input_ids)
            per_layer_inputs = self.project_per_layer_inputs(
                hidden_states, per_layer_embeds
            )

        hidden_states = _run_decoder_layers(
            self.decoder_layers,
            self.layer_idx_start,
            positions,
            hidden_states,
            per_layer_inputs,
            **kwargs,
        )
        return hidden_states, per_layer_inputs
@support_torch_compile(
    enable_if=lambda vllm_config: vllm_config.cache_config.kv_sharing_fast_prefill
)
class Gemma4CrossDecoderLayers(nn.Module):
    """Cross-decoder layers (YOCO second half, KV-shared).

    Thin module around a slice of decoder layers; its forward pass simply
    runs that slice through ``_run_decoder_layers``.
    """

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
        decoder_layers: list[Gemma4DecoderLayer],
        layer_idx_start: int,
    ):
        """Remember the layer slice and the global index of its first layer.

        ``vllm_config`` and ``prefix`` are accepted but not used in this
        constructor body.
        """
        super().__init__()
        self.layer_idx_start = layer_idx_start
        self.decoder_layers = decoder_layers

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        per_layer_inputs: torch.Tensor | None = None,
        **kwargs,
    ) -> torch.Tensor:
        """Run the stored layer slice and return the resulting hidden states."""
        layer_slice = self.decoder_layers
        first_layer_idx = self.layer_idx_start
        return _run_decoder_layers(
            layer_slice,
            first_layer_idx,
            positions,
            hidden_states,
            per_layer_inputs,
            **kwargs,
        )
@
support_torch_compile
(
enable_if
=
lambda
vllm_config
:
not
vllm_config
.
cache_config
.
kv_sharing_fast_prefill
)
class
Gemma4Model
(
nn
.
Module
,
EagleModelMixin
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
config
=
_get_text_config
(
vllm_config
.
model_config
.
hf_config
)
...
...
@@ -740,6 +948,75 @@ class Gemma4Model(nn.Module):
torch
.
tensor
(
config
.
hidden_size
**
0.5
),
persistent
=
False
,
)
# --- You Only Cache Once (YOCO) split for fast prefill ---
first_kv_shared_layer_idx
=
config
.
num_hidden_layers
-
getattr
(
config
,
"num_kv_shared_layers"
,
0
)
from
vllm.compilation.backends
import
set_model_tag
# Layers 0..(K-1) are self-decoder layers in YOCO
with
set_model_tag
(
"self_decoder"
):
self
.
self_decoder
=
Gemma4SelfDecoderLayers
(
vllm_config
=
vllm_config
,
prefix
=
f
"
{
prefix
}
.self_decoder"
,
decoder_layers
=
self
.
layers
[:
first_kv_shared_layer_idx
],
layer_idx_start
=
0
,
embed_tokens
=
self
.
embed_tokens
,
normalizer
=
self
.
normalizer
,
embed_tokens_per_layer
=
getattr
(
self
,
"embed_tokens_per_layer"
,
None
),
embed_scale_per_layer
=
getattr
(
self
,
"embed_scale_per_layer"
,
None
),
per_layer_model_projection
=
getattr
(
self
,
"per_layer_model_projection"
,
None
),
per_layer_projection_norm
=
getattr
(
self
,
"per_layer_projection_norm"
,
None
),
per_layer_input_scale
=
getattr
(
self
,
"per_layer_input_scale"
,
None
),
per_layer_projection_scale
=
getattr
(
self
,
"per_layer_projection_scale"
,
None
),
)
# Layers K..(N-1) are cross-decoder layers in YOCO
with
set_model_tag
(
"cross_decoder"
):
self
.
cross_decoder
=
Gemma4CrossDecoderLayers
(
vllm_config
=
vllm_config
,
prefix
=
f
"
{
prefix
}
.cross_decoder"
,
decoder_layers
=
self
.
layers
[
first_kv_shared_layer_idx
:],
layer_idx_start
=
first_kv_shared_layer_idx
,
)
self
.
fast_prefill_enabled
=
cache_config
.
kv_sharing_fast_prefill
if
self
.
fast_prefill_enabled
:
# Allocate static buffers for CUDAGraph
max_num_tokens
=
vllm_config
.
scheduler_config
.
max_num_batched_tokens
device
=
next
(
self
.
parameters
()).
device
self
.
positions
=
torch
.
zeros
(
max_num_tokens
,
dtype
=
torch
.
int64
,
device
=
device
)
self
.
hidden_states
=
torch
.
zeros
(
(
max_num_tokens
,
config
.
hidden_size
),
dtype
=
self
.
embed_tokens
.
weight
.
dtype
,
device
=
device
,
)
if
(
self
.
hidden_size_per_layer_input
and
self
.
hidden_size_per_layer_input
>
0
):
self
.
per_layer_inputs
=
torch
.
zeros
(
(
max_num_tokens
,
config
.
num_hidden_layers
,
self
.
hidden_size_per_layer_input
,
),
dtype
=
self
.
embed_tokens
.
weight
.
dtype
,
device
=
device
,
)
else
:
self
.
per_layer_inputs
=
None
# Custom factory that includes per_layer_inputs for PLE-enabled PP.
# per_layer_inputs has shape (batch, num_layers, per_layer_dim),
# which differs from the standard (batch, hidden_size) shape,
...
...
@@ -776,47 +1053,22 @@ class Gemma4Model(nn.Module):
self
.
make_empty_intermediate_tensors
=
_make_empty_intermediate_tensors
def
embed_input_ids
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
embed_tokens
(
input_ids
)
*
self
.
normalizer
return
self
.
self_decoder
.
embed_input_ids
(
input_ids
)
def
get_per_layer_inputs
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
get_per_layer_inputs
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
|
None
:
"""Get per-layer embeddings from embed_tokens_per_layer.
Returns:
Per-layer embeddings (num_tokens, num_layers,
hidden_size_per_layer_input)
"""
if
self
.
embed_tokens_per_layer
is
None
:
return
None
# Handle out-of-vocab tokens for PLE (vocab_size_per_layer_input may
# be smaller than the main vocab_size).
per_layer_inputs_mask
=
torch
.
logical_and
(
input_ids
>=
0
,
input_ids
<
self
.
vocab_size_per_layer_input
,
)
per_layer_inputs_tokens
=
torch
.
where
(
per_layer_inputs_mask
,
input_ids
,
torch
.
zeros_like
(
input_ids
)
)
# Get packed per-layer embeddings: (num_tokens, total_ple_dim)
per_layer_embeds
=
self
.
embed_tokens_per_layer
(
per_layer_inputs_tokens
)
# Apply embed_scale (sqrt of per-layer hidden dim)
per_layer_embeds
=
per_layer_embeds
*
self
.
embed_scale_per_layer
# Reshape to (num_tokens, num_layers, hidden_size_per_layer_input)
per_layer_embeds
=
per_layer_embeds
.
reshape
(
*
input_ids
.
shape
,
self
.
config
.
num_hidden_layers
,
self
.
hidden_size_per_layer_input
,
)
return
per_layer_embeds
return
self
.
self_decoder
.
get_per_layer_inputs
(
input_ids
)
def
project_per_layer_inputs
(
self
,
inputs_embeds
:
torch
.
Tensor
,
per_layer_inputs
:
torch
.
Tensor
|
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
|
None
:
"""Project inputs_embeds and combine with per_layer_inputs.
Steps:
...
...
@@ -826,29 +1078,94 @@ class Gemma4Model(nn.Module):
4. Normalize with per_layer_projection_norm
5. Combine: (projection + per_layer_inputs) * 1/sqrt(2)
"""
if
self
.
per_layer_model_projection
is
None
:
return
None
return
self
.
self_decoder
.
project_per_layer_inputs
(
inputs_embeds
,
per_layer_inputs
)
# Project from hidden_size to total_ple_dim
# Scaled projection: output = linear(input, weight) * scale
per_layer_projection
=
self
.
per_layer_model_projection
(
inputs_embeds
)
per_layer_projection
=
per_layer_projection
*
self
.
per_layer_projection_scale
def
fast_prefill_forward
(
self
,
input_ids
:
torch
.
Tensor
|
None
,
positions
:
torch
.
Tensor
,
inputs_embeds
:
torch
.
Tensor
|
None
=
None
,
per_layer_inputs
:
torch
.
Tensor
|
None
=
None
,
**
kwargs
,
)
->
torch
.
Tensor
:
logits_indices_padded
,
num_logits_indices
=
None
,
None
attn_metadata
=
get_forward_context
().
attn_metadata
# Reshape to (num_tokens, num_layers, hidden_size_per_layer_input)
per_layer_projection
=
per_layer_projection
.
reshape
(
*
inputs_embeds
.
shape
[:
-
1
],
self
.
config
.
num_hidden_layers
,
self
.
hidden_size_per_layer_input
,
if
attn_metadata
is
not
None
:
assert
isinstance
(
attn_metadata
,
dict
)
layer_attn_metadata
=
attn_metadata
[
self
.
layers
[
-
1
].
self_attn
.
attn
.
layer_name
]
if
isinstance
(
layer_attn_metadata
,
KVSharingFastPrefillMetadata
):
logits_indices_padded
=
layer_attn_metadata
.
logits_indices_padded
num_logits_indices
=
layer_attn_metadata
.
num_logits_indices
batch_size
=
positions
.
size
(
0
)
self
.
positions
[:
batch_size
].
copy_
(
positions
)
self_decoder_hidden_states
,
per_layer_inputs
=
self
.
self_decoder
(
input_ids
=
input_ids
,
positions
=
self
.
positions
[:
batch_size
],
inputs_embeds
=
inputs_embeds
,
per_layer_inputs
=
per_layer_inputs
,
**
kwargs
,
)
# Normalize
per_layer_projection
=
self
.
per_layer_projection_norm
(
per_layer_projection
)
if
logits_indices_padded
is
None
:
logits_indices_padded
=
torch
.
arange
(
batch_size
,
dtype
=
positions
.
dtype
,
device
=
positions
.
device
,
)
if
per_layer_inputs
is
None
:
return
per_layer_projection
# NOTE: Keep .clone() until fix in
# https://github.com/vllm-project/vllm/pull/22282
hidden_states
=
self_decoder_hidden_states
.
clone
()
# Combine: (projection + per_layer_inputs) * scale
return
(
per_layer_projection
+
per_layer_inputs
)
*
self
.
per_layer_input_scale
num_padded
=
logits_indices_padded
.
size
(
0
)
self
.
positions
[:
num_padded
].
copy_
(
positions
[
logits_indices_padded
])
self
.
hidden_states
[:
num_padded
].
copy_
(
self_decoder_hidden_states
[
logits_indices_padded
]
)
if
self
.
per_layer_inputs
is
not
None
and
per_layer_inputs
is
not
None
:
self
.
per_layer_inputs
[:
num_padded
].
copy_
(
per_layer_inputs
[
logits_indices_padded
]
)
# Update batch_descriptor so the cross-decoder's piecewise
# CUDAGraphWrapper dispatches to the correct (reduced) batch size.
forward_context
=
get_forward_context
()
orig_batch_desc
=
forward_context
.
batch_descriptor
if
orig_batch_desc
is
not
None
:
forward_context
.
batch_descriptor
=
replace
(
orig_batch_desc
,
num_tokens
=
num_padded
)
cross_per_layer
=
(
self
.
per_layer_inputs
[:
num_padded
]
if
self
.
per_layer_inputs
is
not
None
else
None
)
cross_hidden_states
=
self
.
cross_decoder
(
self
.
positions
[:
num_padded
],
self
.
hidden_states
[:
num_padded
],
cross_per_layer
,
**
kwargs
,
)
# Restore the original batch_descriptor
forward_context
.
batch_descriptor
=
orig_batch_desc
if
num_logits_indices
is
not
None
:
assert
num_logits_indices
>
0
hidden_states
[
logits_indices_padded
[:
num_logits_indices
]]
=
(
cross_hidden_states
[:
num_logits_indices
]
)
else
:
hidden_states
=
cross_hidden_states
return
hidden_states
def
forward
(
self
,
...
...
@@ -858,7 +1175,19 @@ class Gemma4Model(nn.Module):
inputs_embeds
:
torch
.
Tensor
|
None
=
None
,
per_layer_inputs
:
torch
.
Tensor
|
None
=
None
,
**
kwargs
,
)
->
torch
.
Tensor
|
IntermediateTensors
:
)
->
torch
.
Tensor
|
IntermediateTensors
|
tuple
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]]:
if
self
.
fast_prefill_enabled
:
hidden_states
=
self
.
fast_prefill_forward
(
input_ids
,
positions
,
inputs_embeds
,
per_layer_inputs
,
**
kwargs
,
)
hidden_states
=
self
.
norm
(
hidden_states
)
return
hidden_states
# Normal (non-fast-prefill) path with PP support
if
get_pp_group
().
is_first_rank
:
if
inputs_embeds
is
not
None
:
hidden_states
=
inputs_embeds
...
...
@@ -882,6 +1211,7 @@ class Gemma4Model(nn.Module):
residual
=
intermediate_tensors
[
"residual"
]
per_layer_inputs
=
intermediate_tensors
.
get
(
"per_layer_inputs"
)
aux_hidden_states
=
self
.
_maybe_add_hidden_state
([],
0
,
hidden_states
,
residual
)
for
layer_idx
,
layer
in
enumerate
(
islice
(
self
.
layers
,
self
.
start_layer
,
self
.
end_layer
)
):
...
...
@@ -900,6 +1230,9 @@ class Gemma4Model(nn.Module):
per_layer_input
=
layer_per_input
,
**
kwargs
,
)
self
.
_maybe_add_hidden_state
(
aux_hidden_states
,
layer_idx
+
1
,
hidden_states
,
residual
)
if
not
get_pp_group
().
is_last_rank
:
return
IntermediateTensors
(
{
...
...
@@ -914,6 +1247,9 @@ class Gemma4Model(nn.Module):
hidden_states
=
self
.
norm
(
hidden_states
)
else
:
hidden_states
,
_
=
self
.
norm
(
hidden_states
,
residual
)
if
len
(
aux_hidden_states
)
>
0
:
return
hidden_states
,
aux_hidden_states
return
hidden_states
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]])
->
set
[
str
]:
...
...
@@ -926,21 +1262,27 @@ class Gemma4Model(nn.Module):
(
"gate_up_proj"
,
"up_proj"
,
1
),
]
# MoE expert weight mapping: checkpoint 3D packed tensors are
# exploded in _weight_iterator to per-expert 2D weights like:
# MoE expert weight mapping: checkpoint can have either:
# 1. 3D packed tensors (exploded in _weight_iterator to per-expert 2D)
# 2. Already per-expert 2D weights (if quantized)
# Map to FusedMoE parameters:
# moe.experts.{id}.gate_proj → FusedMoE w1 (shard of w13)
# moe.experts.{id}.up_proj → FusedMoE w3 (shard of w13)
# moe.experts.{id}.down_proj → FusedMoE w2
# We build the mapping directly since Gemma4 uses bare param
# names (no .weight suffix) unlike standard MoE checkpoints.
#
# Use prefix matching to handle both weights and
# quantization scale parameters. The param_name is a prefix ending
# in underscore, and weight_name ends with a dot, so that:
# "experts.0.gate_proj.weight_scale" -> "experts.w13_weight_scale"
# "experts.0.gate_proj.weight" -> "experts.w13_weight"
num_experts
=
getattr
(
self
.
config
,
"num_experts"
,
None
)
or
0
expert_params_mapping
=
[
# (param_name, weight_name, expert_id, shard_id)
(
"experts.w13_
weight
"
"experts.w13_"
if
proj_name
in
[
"gate_proj"
,
"up_proj"
]
else
"experts.w2_
weight
"
,
f
"experts.
{
expert_id
}
.
{
proj_name
}
"
,
else
"experts.w2_"
,
f
"experts.
{
expert_id
}
.
{
proj_name
}
.
"
,
expert_id
,
shard_id
,
)
...
...
@@ -1000,9 +1342,21 @@ class Gemma4Model(nn.Module):
expert_id
,
shard_id
,
)
in
expert_params_mapping
:
if
weight_name
not
in
name
:
# Match both:
# - Bare weights: "experts.0.down_proj" (from 3D explosion)
# - With suffix: "experts.0.down_proj.weight_scale" (2D quantized)
# weight_name has trailing dot, so check with and without it
weight_name_base
=
weight_name
.
rstrip
(
"."
)
if
weight_name
in
name
:
# Has suffix (e.g., .weight_scale)
moe_name
=
name
.
replace
(
weight_name
,
param_name
)
elif
name
.
endswith
(
weight_name_base
):
# Bare weight (no suffix)
moe_name
=
name
.
replace
(
weight_name_base
,
param_name
.
rstrip
(
"_"
)
+
"_weight"
)
else
:
continue
moe_name
=
name
.
replace
(
weight_name
,
param_name
)
if
moe_name
not
in
params_dict
:
continue
if
is_pp_missing_parameter
(
moe_name
,
self
):
...
...
@@ -1012,15 +1366,12 @@ class Gemma4Model(nn.Module):
# orientation for FusedMoE after _weight_iterator:
# gate/up: [I, H] → w1/w3 expects [I, H]
# down: [H, I] → w2 expects [H, I]
assert
loaded_weight
.
dim
()
==
2
,
(
f
"Expected 2D expert weight for
{
weight_name
}
, "
f
"got shape
{
loaded_weight
.
shape
}
"
)
# Scales and other quantization params may be 1D or scalar.
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
loaded_weight
,
weight_name
+
".weight"
,
moe_name
,
# Pass mapped name (handles both weights and scales)
shard_id
=
shard_id
,
expert_id
=
expert_id
,
)
...
...
@@ -1044,7 +1395,25 @@ class Gemma4Model(nn.Module):
return
loaded_params
class
Gemma4ForCausalLM
(
nn
.
Module
,
SupportsLoRA
,
SupportsPP
,
MixtureOfExperts
):
class
Gemma4ForCausalLM
(
nn
.
Module
,
SupportsLoRA
,
SupportsPP
,
MixtureOfExperts
,
SupportsEagle3
):
hf_to_vllm_mapper
=
WeightsMapper
(
orig_to_new_prefix
=
{
# Gemma4ForConditionalGeneration already loads the text stack
# from `model.language_model.*`. We reuse that same checkpoint
# and adapter naming for the text-only Gemma4ForCausalLM path,
# so LoRA keys from the conditional wrapper map onto `model.*`.
"model.language_model."
:
"model."
,
},
orig_to_new_substr
=
{
# Gemma4ForConditionalGeneration names MoE adapter targets under
# `...moe.experts.*`, while the text-only model exposes them
# under `...moe.*`.
".moe.experts.gate_up_proj"
:
".moe.gate_up_proj"
,
".moe.experts.down_proj"
:
".moe.down_proj"
,
},
)
# Note: qkv_proj packing applies to non-k_eq_v layers (sliding
# attention and full attention without k_eq_v). k_eq_v layers use
# separate q_proj + k_proj without packing.
...
...
@@ -1126,7 +1495,7 @@ class Gemma4ForCausalLM(nn.Module, SupportsLoRA, SupportsPP, MixtureOfExperts):
intermediate_tensors
:
IntermediateTensors
|
None
=
None
,
inputs_embeds
:
torch
.
Tensor
|
None
=
None
,
**
kwargs
,
)
->
torch
.
Tensor
|
IntermediateTensors
:
)
->
torch
.
Tensor
|
IntermediateTensors
|
tuple
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]]
:
hidden_states
=
self
.
model
(
input_ids
,
positions
,
intermediate_tensors
,
inputs_embeds
,
**
kwargs
)
...
...
@@ -1177,6 +1546,11 @@ class Gemma4ForCausalLM(nn.Module, SupportsLoRA, SupportsPP, MixtureOfExperts):
".moe.down_proj"
,
)
# Remap individual 2D expert weights:
# .experts.{id}.{proj} → .moe.experts.{id}.{proj}
# (This handles per-expert 2D quantized weights)
name
=
re
.
sub
(
r
"\.experts\.(\d+)\."
,
r
".moe.experts.\1."
,
name
)
# MoE expert weights: checkpoint stores as 3D packed
# tensors. Explode into per-expert 2D weights for
# FusedMoE weight_loader.
...
...
vllm/model_executor/models/gemma4_mm.py
View file @
fc67613a
...
...
@@ -65,7 +65,12 @@ from vllm.multimodal.processing.processor import (
from
vllm.sequence
import
IntermediateTensors
from
vllm.utils.tensor_schema
import
TensorSchema
,
TensorShape
from
.interfaces
import
MultiModalEmbeddings
,
SupportsMultiModal
,
SupportsPP
from
.interfaces
import
(
MultiModalEmbeddings
,
SupportsEagle3
,
SupportsMultiModal
,
SupportsPP
,
)
from
.utils
import
(
AutoWeightsLoader
,
WeightsMapper
,
...
...
@@ -121,8 +126,12 @@ class Gemma4AudioInputs(TensorSchema):
"""
type
:
Literal
[
"audio"
]
=
"audio"
input_features_padded
:
Annotated
[
torch
.
Tensor
,
TensorShape
(
"bn"
,
"s"
,
"f"
)]
input_features_mask
:
Annotated
[
torch
.
Tensor
,
TensorShape
(
"bn"
,
"s"
)]
input_features_padded
:
Annotated
[
torch
.
Tensor
,
TensorShape
(
"bn"
,
"s"
,
"f"
,
dynamic_dims
=
{
"s"
})
]
input_features_mask
:
Annotated
[
torch
.
Tensor
,
TensorShape
(
"bn"
,
"s"
,
dynamic_dims
=
{
"s"
})
]
Gemma4ImageInputs
=
Gemma4ImagePixelInputs
...
...
@@ -163,10 +172,15 @@ class Gemma4ProcessingInfo(BaseProcessingInfo):
Setting ``add_special_tokens=False`` here prevents the duplicate and
ensures both ``llm.generate()`` and the chat/completions API behave
correctly.
correctly for IT models. For PT models (without chat template), we
keep the default (True) to ensure BOS is added for raw prompts.
"""
tokenizer
=
self
.
ctx
.
get_tokenizer
()
has_chat_template
=
getattr
(
tokenizer
,
"chat_template"
,
None
)
is
not
None
params
=
super
().
get_default_tok_params
()
params
=
params
.
with_kwargs
(
add_special_tokens
=
False
)
if
has_chat_template
:
params
=
params
.
with_kwargs
(
add_special_tokens
=
False
)
return
params
def
get_hf_processor
(
self
,
**
kwargs
:
object
)
->
Gemma4Processor
:
...
...
@@ -503,6 +517,8 @@ class Gemma4MultiModalProcessor(BaseMultiModalProcessor[Gemma4ProcessingInfo]):
video_timestamps_per_video
:
list
[
list
[
float
]]
=
[]
video_frame_counts
:
list
[
int
]
=
[]
video_replacements
:
list
[
str
]
=
[]
for
item
in
videos
:
video_array
,
metadata
=
item
...
...
@@ -555,10 +571,7 @@ class Gemma4MultiModalProcessor(BaseMultiModalProcessor[Gemma4ProcessingInfo]):
video_timestamps_per_video
.
append
(
timestamps
)
video_frame_counts
.
append
(
len
(
frames
))
# Build expanded replacement text and replace the
# <|video|> placeholder in the prompt.
# Use split(token, 1) to avoid collision — the
# replacement text itself contains <|video|> tokens.
# Build expanded replacement text for this video.
ts_strs
=
[
f
"
{
int
(
s
//
60
):
02
d
}
:
{
int
(
s
%
60
):
02
d
}
"
for
s
in
timestamps
]
replacement
=
" "
.
join
(
f
"
{
t
}
{
processor
.
boi_token
}
"
...
...
@@ -566,9 +579,23 @@ class Gemma4MultiModalProcessor(BaseMultiModalProcessor[Gemma4ProcessingInfo]):
f
"
{
processor
.
eoi_token
}
"
for
t
,
n
in
zip
(
ts_strs
,
num_soft_per_frame
)
)
parts
=
prompt
.
split
(
processor
.
video_token
,
1
)
if
len
(
parts
)
==
2
:
prompt
=
parts
[
0
]
+
replacement
+
parts
[
1
]
video_replacements
.
append
(
replacement
)
# Replace all <|video|> placeholders at once. We split on
# video_token to get N+1 parts, then interleave with the
# N replacement strings. This avoids the iterative
# split-replace bug where replacement text (which itself
# contains <|video|> tokens) collides with later splits.
vt
=
processor
.
video_token
parts
=
prompt
.
split
(
vt
,
len
(
video_replacements
))
# NOTE: len(parts) <= len(video_replacements) + 1
parts_with_repl
:
list
[
str
]
=
[]
for
part
,
repl
in
zip
(
parts
,
video_replacements
):
parts_with_repl
.
extend
([
part
,
repl
])
parts_with_repl
.
extend
(
parts
[
len
(
video_replacements
)
:])
prompt
=
""
.
join
(
parts_with_repl
)
video_outputs
=
{
"pixel_values_videos"
:
torch
.
cat
(
all_video_pixel_values
,
dim
=
0
),
...
...
@@ -631,19 +658,23 @@ class Gemma4MultiModalProcessor(BaseMultiModalProcessor[Gemma4ProcessingInfo]):
)
if
"input_features"
in
processed_outputs
:
# Keep padded features for batched audio tower execution.
processed_outputs
[
"input_features_padded"
]
=
processed_outputs
[
"input_features"
]
# Unpad per-item so each item's cache entry is self-contained.
# Unpad per-item so each item's cache entry is
# self-contained. The batched() field config in
# _get_mm_fields_config will re-pad all fields to the
# batch's max length at batch time, ensuring consistent
# padding regardless of cache history.
masks
=
processed_outputs
[
"input_features_mask"
]
unpadded_features
=
[
f
[
mask
]
for
f
,
mask
in
zip
(
processed_outputs
[
"input_features"
],
processed_outputs
[
"input_features_
mask
"
]
,
mask
s
,
)
]
unpadded_masks
=
[
mask
[
mask
]
for
mask
in
masks
]
processed_outputs
[
"input_features"
]
=
unpadded_features
processed_outputs
[
"input_features_padded"
]
=
unpadded_features
processed_outputs
[
"input_features_mask"
]
=
unpadded_masks
# Merge video outputs into the final result
combined_outputs
=
dict
(
processed_outputs
,
**
video_outputs
)
...
...
@@ -848,7 +879,12 @@ class Gemma4MultimodalEmbedder(nn.Module):
info
=
Gemma4ProcessingInfo
,
dummy_inputs
=
Gemma4DummyInputsBuilder
,
)
class
Gemma4ForConditionalGeneration
(
nn
.
Module
,
SupportsMultiModal
,
SupportsPP
):
class
Gemma4ForConditionalGeneration
(
nn
.
Module
,
SupportsMultiModal
,
SupportsPP
,
SupportsEagle3
,
):
packed_modules_mapping
=
{
"qkv_proj"
:
[
"q_proj"
,
...
...
vllm/model_executor/models/kimi_k25.py
View file @
fc67613a
...
...
@@ -113,7 +113,29 @@ class KimiK25ProcessingInfo(BaseProcessingInfo):
trust_remote_code
=
self
.
ctx
.
model_config
.
trust_remote_code
,
)
self
.
media_token_id
=
media_token_id
=
hf_config
.
media_placeholder_token_id
# Resolve token ID from the tokenizer because transformers v5
# may remap token IDs vs config.json.
config_token_id
=
hf_config
.
media_placeholder_token_id
resolved_token_id
=
tokenizer
.
convert_tokens_to_ids
(
"<|media_pad|>"
)
is_valid_resolved
=
isinstance
(
resolved_token_id
,
int
)
and
(
tokenizer
.
unk_token_id
is
None
or
resolved_token_id
!=
tokenizer
.
unk_token_id
)
if
is_valid_resolved
and
resolved_token_id
!=
config_token_id
:
logger
.
warning_once
(
"Kimi-K2.5 config.media_placeholder_token_id (%d) disagrees "
"with tokenizer mapping for <|media_pad|> (%d). "
"Using tokenizer value."
,
config_token_id
,
resolved_token_id
,
)
media_token_id
=
resolved_token_id
# Patch config so downstream code also sees the correct ID.
hf_config
.
media_placeholder_token_id
=
resolved_token_id
else
:
media_token_id
=
config_token_id
self
.
media_token_id
=
media_token_id
self
.
media_token
=
tokenizer
.
decode
(
media_token_id
)
self
.
image_processor
=
image_processor
...
...
@@ -232,8 +254,7 @@ class KimiK25MultiModalProcessor(BaseMultiModalProcessor[KimiK25ProcessingInfo])
hf_processor_mm_kwargs
:
Mapping
[
str
,
Any
],
out_mm_kwargs
:
MultiModalKwargsItems
,
)
->
Sequence
[
PromptUpdate
]:
hf_config
=
self
.
info
.
get_hf_config
()
media_token_id
=
hf_config
.
media_placeholder_token_id
media_token_id
=
self
.
info
.
media_token_id
def
get_replacement
(
item_idx
:
int
):
media
=
mm_items
.
get_items
(
"vision_chunk"
,
(
VisionChunkProcessorItems
,))
...
...
vllm/model_executor/models/minimax_m2.py
View file @
fc67613a
...
...
@@ -232,9 +232,7 @@ class MiniMaxM2Attention(nn.Module):
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
,
k
=
MiniMaxText01RMSNormTP
.
forward_qk
(
self
.
q_norm
,
self
.
k_norm
,
q
.
contiguous
(),
k
.
contiguous
()
)
q
,
k
=
MiniMaxText01RMSNormTP
.
forward_qk
(
self
.
q_norm
,
self
.
k_norm
,
q
,
k
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
attn_output
=
self
.
attn
(
q
,
k
,
v
)
output
,
_
=
self
.
o_proj
(
attn_output
)
...
...
vllm/model_executor/models/musicflamingo.py
View file @
fc67613a
...
...
@@ -32,9 +32,9 @@ from transformers.models.musicflamingo import (
from
vllm.config
import
VllmConfig
from
vllm.config.multimodal
import
BaseDummyOptions
from
vllm.inputs
import
MultiModalDataDict
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.inputs
import
(
MultiModalDataDict
,
MultiModalFieldConfig
,
MultiModalKwargsItems
,
)
...
...
vllm/model_executor/models/transformers/__init__.py
View file @
fc67613a
...
...
@@ -16,13 +16,11 @@
# limitations under the License.
"""Wrapper around `transformers` models"""
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.model_executor.models.transformers.base
import
Base
from
vllm.model_executor.models.transformers.causal
import
CausalMixin
from
vllm.model_executor.models.transformers.legacy
import
LegacyMixin
from
vllm.model_executor.models.transformers.moe
import
MoEMixin
from
vllm.model_executor.models.transformers.multimodal
import
(
DYNAMIC_ARG_DIMS
,
MultiModalDummyInputsBuilder
,
MultiModalMixin
,
MultiModalProcessingInfo
,
...
...
@@ -32,16 +30,13 @@ from vllm.model_executor.models.transformers.pooling import (
EmbeddingMixin
,
SequenceClassificationMixin
,
)
from
vllm.model_executor.models.transformers.utils
import
can_enable_torch_compile
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
# Text only models
@
support_torch_compile
(
enable_if
=
can_enable_torch_compile
)
class
TransformersForCausalLM
(
CausalMixin
,
Base
):
...
@
support_torch_compile
(
enable_if
=
can_enable_torch_compile
)
class
TransformersMoEForCausalLM
(
MoEMixin
,
CausalMixin
,
Base
):
...
...
...
@@ -51,9 +46,6 @@ class TransformersMoEForCausalLM(MoEMixin, CausalMixin, Base): ...
info
=
MultiModalProcessingInfo
,
dummy_inputs
=
MultiModalDummyInputsBuilder
,
)
@
support_torch_compile
(
dynamic_arg_dims
=
DYNAMIC_ARG_DIMS
,
enable_if
=
can_enable_torch_compile
)
class
TransformersMultiModalForCausalLM
(
MultiModalMixin
,
CausalMixin
,
Base
):
...
...
...
@@ -62,20 +54,15 @@ class TransformersMultiModalForCausalLM(MultiModalMixin, CausalMixin, Base): ...
info
=
MultiModalProcessingInfo
,
dummy_inputs
=
MultiModalDummyInputsBuilder
,
)
@
support_torch_compile
(
dynamic_arg_dims
=
DYNAMIC_ARG_DIMS
,
enable_if
=
can_enable_torch_compile
)
class
TransformersMultiModalMoEForCausalLM
(
MoEMixin
,
MultiModalMixin
,
CausalMixin
,
Base
):
...
# Embedding models
@
support_torch_compile
(
enable_if
=
can_enable_torch_compile
)
class
TransformersEmbeddingModel
(
EmbeddingMixin
,
LegacyMixin
,
Base
):
...
@
support_torch_compile
(
enable_if
=
can_enable_torch_compile
)
class
TransformersMoEEmbeddingModel
(
EmbeddingMixin
,
MoEMixin
,
Base
):
...
...
...
@@ -84,20 +71,15 @@ class TransformersMoEEmbeddingModel(EmbeddingMixin, MoEMixin, Base): ...
info
=
MultiModalProcessingInfo
,
dummy_inputs
=
MultiModalDummyInputsBuilder
,
)
@
support_torch_compile
(
dynamic_arg_dims
=
DYNAMIC_ARG_DIMS
,
enable_if
=
can_enable_torch_compile
)
class
TransformersMultiModalEmbeddingModel
(
EmbeddingMixin
,
MultiModalMixin
,
Base
):
...
# Sequence classification models
@
support_torch_compile
(
enable_if
=
can_enable_torch_compile
)
class
TransformersForSequenceClassification
(
SequenceClassificationMixin
,
LegacyMixin
,
Base
):
...
@
support_torch_compile
(
enable_if
=
can_enable_torch_compile
)
class
TransformersMoEForSequenceClassification
(
SequenceClassificationMixin
,
MoEMixin
,
Base
):
...
...
...
@@ -108,9 +90,6 @@ class TransformersMoEForSequenceClassification(
info
=
MultiModalProcessingInfo
,
dummy_inputs
=
MultiModalDummyInputsBuilder
,
)
@
support_torch_compile
(
dynamic_arg_dims
=
DYNAMIC_ARG_DIMS
,
enable_if
=
can_enable_torch_compile
)
class
TransformersMultiModalForSequenceClassification
(
SequenceClassificationMixin
,
MultiModalMixin
,
Base
):
...
...
...
vllm/model_executor/models/transformers/base.py
View file @
fc67613a
...
...
@@ -16,6 +16,7 @@
# limitations under the License.
"""Transformers modeling backend base class."""
import
sys
from
collections.abc
import
Callable
,
Iterable
from
itertools
import
chain
from
operator
import
attrgetter
...
...
@@ -29,6 +30,7 @@ from torch import nn
from
transformers
import
AutoModel
from
transformers.modeling_utils
import
ALL_ATTENTION_FUNCTIONS
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config.utils
import
getattr_iter
from
vllm.distributed
import
get_pp_group
,
get_tp_group
from
vllm.distributed.utils
import
get_pp_indices
...
...
@@ -47,6 +49,7 @@ from vllm.model_executor.models.interfaces import (
)
from
vllm.model_executor.models.interfaces_base
import
VllmModel
from
vllm.model_executor.models.transformers.utils
import
(
can_enable_torch_compile
,
get_feature_request_tip
,
init_on_device_without_buffers
,
log_replacement
,
...
...
@@ -117,6 +120,7 @@ class Base(
self
.
config
=
vllm_config
.
model_config
.
hf_config
self
.
text_config
=
self
.
config
.
get_text_config
()
self
.
cache_config
=
vllm_config
.
cache_config
self
.
compilation_config
=
vllm_config
.
compilation_config
self
.
device_config
=
vllm_config
.
device_config
self
.
model_config
=
vllm_config
.
model_config
self
.
parallel_config
=
vllm_config
.
parallel_config
...
...
@@ -146,7 +150,7 @@ class Base(
if
self
.
quant_config
:
quant_method_name
=
self
.
quant_config
.
get_name
()
# Check for unsupported quantization methods.
if
quant_method_name
==
"
mxfp4"
:
if
quant_method_name
in
(
"mxfp4"
,
"gpt_oss_
mxfp4"
)
:
raise
NotImplementedError
(
"Transformers modeling backend does "
"not support MXFP4 quantization yet."
...
...
@@ -155,14 +159,16 @@ class Base(
if
"gptq"
in
quant_method_name
:
self
.
ignore_unexpected_suffixes
.
append
(
".bias"
)
# Patch config and init on "meta" to delay allocating GPU tensors
self
.
_patch_config
()
from_config_kwargs
=
dict
(
config
=
self
.
config
,
dtype
=
self
.
model_config
.
dtype
,
trust_remote_code
=
self
.
model_config
.
trust_remote_code
,
)
self
.
_decorate_for_torch_compile
(
**
from_config_kwargs
)
# Init on "meta" to delay allocating GPU tensors
with
init_on_device_without_buffers
(
"meta"
):
self
.
model
:
PreTrainedModel
=
AutoModel
.
from_config
(
self
.
config
,
dtype
=
self
.
model_config
.
dtype
,
trust_remote_code
=
self
.
model_config
.
trust_remote_code
,
)
self
.
model
:
PreTrainedModel
=
AutoModel
.
from_config
(
**
from_config_kwargs
)
# Create weight name to module qualname mapper
self
.
_create_hf_to_vllm_mapper
()
...
...
@@ -218,6 +224,87 @@ class Base(
if
sub_config
.
dtype
!=
(
dtype
:
=
self
.
config
.
dtype
):
sub_config
.
dtype
=
dtype
def _get_decoder_cls(self, **kwargs: dict) -> type[PreTrainedModel]:
    """
    Identify the class of the model's decoder.

    A throwaway model is built from `kwargs` inside a
    `torch.device("meta")` context, and the type of the object returned by
    its `get_decoder()` method is reported.

    Args:
        kwargs: Keyword arguments forwarded to `AutoModel.from_config`.

    Returns:
        The class of the object returned by the model's `get_decoder()`.
    """
    with torch.device("meta"):
        probe_model: PreTrainedModel = AutoModel.from_config(**kwargs)
    decoder_type = type(probe_model.get_decoder())
    logger.debug("Identified decoder class as: %s", decoder_type)
    # Only the class is needed; drop the probe model before returning.
    del probe_model
    return decoder_type
def _decorate_cls_for_torch_compile(
    self,
    cls: type[PreTrainedModel],
    dynamic_arg_dims: dict[str, int] | None,
    enable_if: Callable[["VllmConfig"], bool],
    is_encoder: bool,
) -> None:
    """
    Decorate `cls` to indicate to vLLM that it supports torch compile.

    A subclass of `cls` is created, decorated with `support_torch_compile`,
    and then bound under `cls.__name__` in the module that defines `cls`,
    replacing the original name there.

    Args:
        cls: The PreTrainedModel class to decorate.
        dynamic_arg_dims: A mapping from argument name to the dynamic dimensions
            of the argument. If None, default dynamic arg dims will be used. See
            [`support_torch_compile`][vllm.compilation.decorators.support_torch_compile]
            for more details.
        enable_if: A function which takes in the vLLM config and returns whether
            torch compile should be enabled for this class.
        is_encoder: Whether the class being decorated is an encoder.
    """
    logger.debug(
        "Decorating `%s` as %s for torch compile with dynamic_arg_dims of %s",
        cls.__name__,
        "encoder" if is_encoder else "decoder",
        dynamic_arg_dims,
    )

    # The decorator is applied to an empty subclass rather than to `cls`
    # itself, so the original class object is left unmodified.
    @support_torch_compile(
        dynamic_arg_dims=dynamic_arg_dims,
        enable_if=enable_if,
        is_encoder=is_encoder,
    )
    class SupportTorchCompileWrapper(cls): ...

    # Preserve __module__ so transformers v5's source-file checks
    # (e.g. _can_set_experts_implementation) read the original
    # model's module instead of this file.
    SupportTorchCompileWrapper.__module__ = cls.__module__
    # Patch the class in its module
    # NOTE(review): every call builds a new wrapper subclass and rebinds the
    # module attribute; confirm the intended behavior when this runs more
    # than once for the same class name.
    module = sys.modules[cls.__module__]
    setattr(module, cls.__name__, SupportTorchCompileWrapper)
def _decorate_for_torch_compile(self, **kwargs: dict):
    """
    Mark the model's decoder class as supporting torch compile.

    The decoder class is looked up with `_get_decoder_cls` and handed to
    `_decorate_cls_for_torch_compile`, with `can_enable_torch_compile` as the
    `enable_if` predicate.

    Args:
        kwargs: The kwargs to create the model, which are needed to get the
            decoder class.
    """
    decoder_cls = self._get_decoder_cls(**kwargs)
    # Applied to a PreTrainedModel so the batch dimension will exist
    decoder_dynamic_dims: dict[str, int] = {
        "input_ids": 1,  # shape: [1, seq_len]
        "inputs_embeds": 1,  # shape: [1, seq_len, hidden_size]
        "position_ids": -1,  # shape: [1, seq_len] or [3, 1, seq_len] for mrope
    }
    self._decorate_cls_for_torch_compile(
        cls=decoder_cls,
        dynamic_arg_dims=decoder_dynamic_dims,
        enable_if=can_enable_torch_compile,
        is_encoder=False,
    )
def
_create_hf_to_vllm_mapper
(
self
):
"""
Create a WeightsMapper to map checkpoint weight names to module qualnames.
...
...
@@ -553,11 +640,6 @@ class Base(
input_ids
=
None
inputs_embeds
=
intermediate_tensors
[
"hidden_states"
]
if
input_ids
is
not
None
:
input_ids
=
input_ids
[
None
,
...]
if
inputs_embeds
is
not
None
:
inputs_embeds
=
inputs_embeds
[
None
,
...]
# If the model scales embeddings inside the input embedding layer we must
# ensure they are scaled here since VocabParallelEmbedding will not do it
if
(
...
...
@@ -568,22 +650,29 @@ class Base(
inputs_embeds
=
self
.
embed_input_ids
(
input_ids
)
input_ids
=
None
if
self
.
model_config
.
uses_mrope
:
position_ids
=
positions
[:,
None
]
else
:
position_ids
=
positions
[
None
,
...]
# Add batch dimension before entering Transformers model
if
input_ids
is
not
None
and
input_ids
.
ndim
==
1
:
# [seq_len] -> [1, seq_len]
input_ids
=
input_ids
[
None
,
...]
if
inputs_embeds
is
not
None
and
inputs_embeds
.
ndim
==
2
:
# [seq_len, hidden_size] -> [1, seq_len, hidden_size]
inputs_embeds
=
inputs_embeds
[
None
,
...]
if
positions
.
ndim
==
1
:
# [seq_len] -> [1, seq_len]
positions
=
positions
[
None
,
...]
outputs
=
self
.
model
(
input_ids
=
input_ids
,
inputs_embeds
=
inputs_embeds
,
use_cache
=
False
,
position_ids
=
position
_id
s
,
position_ids
=
positions
,
attention_instances
=
self
.
attention_instances
,
return_dict
=
False
,
**
self
.
_output_aux_hidden_states_kwargs
,
**
kwargs
,
)
# We must remove the batch dimension from these outputs
# Remove batch dimension after exiting Transformers model
hidden_states
=
outputs
[
0
][
0
,
...]
if
self
.
_output_aux_hidden_states_kwargs
:
aux_hidden_states
=
[
x
[
0
][
0
,
...]
for
x
in
outputs
[
1
:]]
...
...
vllm/model_executor/models/transformers/multimodal.py
View file @
fc67613a
...
...
@@ -20,7 +20,9 @@ from collections.abc import Mapping
from
typing
import
TYPE_CHECKING
import
torch
from
transformers
import
AutoModel
from
vllm.compilation.decorators
import
should_torch_compile_mm_encoder
from
vllm.config.utils
import
getattr_iter
from
vllm.inputs
import
MultiModalDataDict
,
MultiModalInput
,
mm_input
from
vllm.logger
import
init_logger
...
...
@@ -46,19 +48,11 @@ from vllm.platforms import current_platform
from
vllm.sequence
import
IntermediateTensors
if
TYPE_CHECKING
:
from
transformers
import
BatchFeature
from
transformers
import
BatchFeature
,
PreTrainedModel
from
vllm.config
import
VllmConfig
from
vllm.config.multimodal
import
BaseDummyOptions
DYNAMIC_ARG_DIMS
=
{
"input_ids"
:
0
,
# set `positions` to last dim to support Qwen-mrope
"positions"
:
-
1
,
"intermediate_tensors"
:
0
,
"inputs_embeds"
:
0
,
}
logger
=
init_logger
(
__name__
)
...
...
@@ -274,6 +268,66 @@ class MultiModalMixin(SupportsMultiModal, SupportsMRoPE):
# Skip SupportsMRoPE.__init__ and call the next class in MRO
super
(
SupportsMRoPE
,
self
).
__init__
(
vllm_config
=
vllm_config
,
prefix
=
prefix
)
def _get_encoder_cls(
    self, modality: str = "image", **kwargs: dict
) -> type["PreTrainedModel"]:
    """
    Identify the class of the model's encoder for a given modality.

    A throwaway model is built from `kwargs` inside a
    `torch.device("meta")` context, and the type of the object returned by
    `get_encoder(modality=modality)` is reported.

    Args:
        modality: Passed through to the model's `get_encoder`.
        kwargs: Keyword arguments forwarded to `AutoModel.from_config`.

    Returns:
        The encoder class.

    Raises:
        ValueError: If the reported encoder class is the class of the whole
            model, i.e. no distinct encoder could be identified.
    """
    with torch.device("meta"):
        probe_model: PreTrainedModel = AutoModel.from_config(**kwargs)
    encoder_type = type(probe_model.get_encoder(modality=modality))
    logger.debug("Identified encoder class as: %s", encoder_type)
    # Getting the model's own class back means no separate encoder was found.
    if encoder_type is type(probe_model):
        raise ValueError(
            "Unable to infer vision encoder class from the model. "
            "You must either: update the model so that "
            "https://huggingface.co/docs/transformers/en/main_classes/model#transformers.PreTrainedModel.get_encoder"
            " can detect the vision encoder correctly, or remove "
            "'compile_mm_encoder'."
        )
    del probe_model
    return encoder_type
def _decorate_for_torch_compile(self, **kwargs: dict):
    """
    Mark the model's decoder and, when requested, encoder classes as
    supporting torch compile.

    The decoder is handled by the parent implementation. The encoder is only
    decorated when `compilation_config.compile_mm_encoder` is set, using
    `should_torch_compile_mm_encoder` as the `enable_if` predicate.

    Args:
        kwargs: The kwargs to create the model, which are needed to get the
            decoder and encoder classes.
    """
    super()._decorate_for_torch_compile(**kwargs)

    # Nothing more to do unless encoder compilation was requested.
    if not self.compilation_config.compile_mm_encoder:
        return

    self.check_version("5.0.0", "multimodal encoder compilation support")
    logger.warning_once(
        "Multimodal encoder compilation with the Transformers modeling backend "
        "is an experimental feature. It relies on:\n"
        "- The vision encoder being torch compilable.\n"
        "- All vision encoder tensor inputs must be type hinted as either "
        "`torch.Tensor` or `torch.FloatTensor`.\n"
        "- The 0-th dimension of all tensor inputs to the vision encoder being "
        "the dynamic dimension (i.e., sequence length or number of patches).\n"
        "Please report any issues you encounter to help us improve it."
    )
    encoder_cls = self._get_encoder_cls(**kwargs)
    self._decorate_cls_for_torch_compile(
        cls=encoder_cls,
        # TODO: properly infer dynamic_arg_dims based on the encoder's forward
        # method signature. Currently we assume dim 0 for all tensor inputs.
        dynamic_arg_dims=None,
        enable_if=should_torch_compile_mm_encoder,
        is_encoder=True,
    )
def
forward
(
self
,
input_ids
:
torch
.
Tensor
|
None
,
...
...
@@ -285,6 +339,10 @@ class MultiModalMixin(SupportsMultiModal, SupportsMRoPE):
# Gemma3 and PaliGemma needs `token_type_ids` to work correctly
# Other models will not have `token_type_ids` in kwargs
kwargs
=
{
k
:
v
for
k
,
v
in
kwargs
.
items
()
if
k
==
"token_type_ids"
}
# Positions shape handling for MRoPE models
if
self
.
model_config
.
uses_mrope
:
# [3, seq_len] -> [3, 1, seq_len]
positions
=
positions
[:,
None
]
model_output
=
super
().
forward
(
input_ids
,
positions
,
intermediate_tensors
,
inputs_embeds
,
**
kwargs
)
...
...
vllm/parser/abstract_parser.py
View file @
fc67613a
...
...
@@ -470,6 +470,15 @@ class DelegatingParser(Parser):
# No tool calls
return
[],
content
def adjust_request(
    self, request: ChatCompletionRequest | ResponsesRequest
) -> ChatCompletionRequest | ResponsesRequest:
    """
    Let the configured parsers adjust the request.

    The reasoning parser (if set) is applied first, then the tool parser
    (if set); each receives the request returned by the previous step.

    Args:
        request: The incoming request.

    Returns:
        The request after both parsers have had a chance to adjust it.
    """
    if (reasoning_parser := self._reasoning_parser) is not None:
        request = reasoning_parser.adjust_request(request)
    if (tool_parser := self._tool_parser) is not None:
        request = tool_parser.adjust_request(request)
    return request
def
extract_reasoning_streaming
(
self
,
previous_text
:
str
,
...
...
vllm/reasoning/abs_reasoning_parsers.py
View file @
fc67613a
...
...
@@ -6,7 +6,7 @@ import os
from
abc
import
abstractmethod
from
collections.abc
import
Callable
,
Iterable
,
Sequence
from
functools
import
cached_property
from
typing
import
TYPE_CHECKING
from
typing
import
TYPE_CHECKING
,
cast
from
vllm.entrypoints.mcp.tool_server
import
ToolServer
from
vllm.logger
import
init_logger
...
...
@@ -150,6 +150,12 @@ class ReasoningParser:
previously been parsed and extracted (see constructor)
"""
def adjust_request(
    self, request: "ChatCompletionRequest | ResponsesRequest"
) -> "ChatCompletionRequest | ResponsesRequest":
    """
    Hook for adjusting request parameters.

    This base implementation returns `request` unchanged; subclasses may
    override it.
    """
    return request
def
prepare_structured_tag
(
self
,
original_tag
:
str
|
None
,
...
...
@@ -298,7 +304,7 @@ class ReasoningParserManager:
if
isinstance
(
name
,
str
):
names
=
[
name
]
elif
is_list_of
(
name
,
str
):
names
=
name
names
=
cast
(
list
[
str
],
name
)
else
:
names
=
[
class_name
]
...
...
vllm/reasoning/gemma4_reasoning_parser.py
View file @
fc67613a
...
...
@@ -52,6 +52,16 @@ class Gemma4ReasoningParser(BaseThinkingReasoningParser):
# skip_special_tokens=True).
self
.
_reasoning_text
:
str
=
""
self
.
_prefix_stripped
:
bool
=
False
self
.
new_turn_token_id
=
self
.
vocab
[
"<|turn>"
]
self
.
tool_call_token_id
=
self
.
vocab
[
"<|tool_call>"
]
self
.
tool_response_token_id
=
self
.
vocab
[
"<|tool_response>"
]
def adjust_request(
    self, request: "ChatCompletionRequest | ResponsesRequest"
) -> "ChatCompletionRequest | ResponsesRequest":
    """
    Turn off special-token stripping on the request.

    Sets `request.skip_special_tokens` to False in place so that boundary
    tokens are preserved, then returns the same request object.
    """
    request.skip_special_tokens = False
    return request
@
property
def
start_token
(
self
)
->
str
:
...
...
@@ -63,6 +73,29 @@ class Gemma4ReasoningParser(BaseThinkingReasoningParser):
"""The token that ends reasoning content."""
return
"<channel|>"
def is_reasoning_end(self, input_ids: Sequence[int]) -> bool:
    """Report whether the reasoning section has ended in ``input_ids``.

    The sequence is scanned from its last token backwards and the first
    marker token encountered decides the answer:

    * start token             -> ``False``
    * tool-call token         -> ``True``
    * new-turn / tool-response -> ``False``
    * end token               -> ``True``

    Args:
        input_ids: Token ids produced so far.

    Returns:
        ``True`` if the most recent marker is the end token or a tool-call
        token; ``False`` otherwise, including when ``input_ids`` is empty
        or contains no marker token.
    """
    start_token_id = self.start_token_id
    end_token_id = self.end_token_id
    tool_call_token_id = self.tool_call_token_id
    # The model starts new reasoning after these tokens, so seeing one of
    # them means reasoning is not yet considered ended.
    reopening_token_ids = (
        self.new_turn_token_id,
        self.tool_response_token_id,
    )

    # Walk from the end so the last marker in the sequence wins. The order
    # of the checks below is kept exactly as in the index-based version.
    for token_id in reversed(input_ids):
        if token_id == start_token_id:
            return False
        if token_id == tool_call_token_id:
            # We're generating a tool call, so reasoning must be ended.
            return True
        if token_id in reopening_token_ids:
            return False
        if token_id == end_token_id:
            return True
    return False
# ------------------------------------------------------------------
# Non-streaming path
# ------------------------------------------------------------------
...
...
@@ -159,11 +192,10 @@ class Gemma4ReasoningParser(BaseThinkingReasoningParser):
result
.
reasoning
=
stripped
return
result
else
:
# This entire delta was prefix — suppress it.
# Don't set _prefix_stripped yet; there may be more
# prefix chars to consume in the next delta.
if
len
(
self
.
_reasoning_text
)
>=
prefix_len
:
self
.
_prefix_stripped
=
True
result
.
reasoning
=
""
return
result
return
None
# Case 2: Accumulated text is a strict prefix of
...
...
vllm/tokenizers/registry.py
View file @
fc67613a
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
contextlib
from
dataclasses
import
dataclass
,
field
from
functools
import
lru_cache
from
pathlib
import
Path
...
...
@@ -10,6 +11,7 @@ from typing_extensions import TypeVar, assert_never
import
vllm.envs
as
envs
from
vllm.logger
import
init_logger
from
vllm.transformers_utils.config
import
get_config
from
vllm.transformers_utils.gguf_utils
import
(
check_gguf_file
,
get_gguf_file_path_from_hf
,
...
...
@@ -31,6 +33,13 @@ if TYPE_CHECKING:
logger
=
init_logger
(
__name__
)
# Model types whose hub tokenizer_class is incorrect and should be overridden with
# TokenizersBackend (the generic fast tokenizer). Adding a model type here is always a
# temporary workaround and better long term solutions are:
# - Add model type to MODELS_WITH_INCORRECT_HUB_TOKENIZER_CLASS in transformers (better)
# - Fix tokenizer_class on the hub for the affected models (best)
_MODEL_TYPES_WITH_INCORRECT_TOKENIZER_CLASS
:
set
[
str
]
=
{
"step3_vl"
}
_VLLM_TOKENIZERS
=
{
"deepseek_v32"
:
(
"deepseek_v32"
,
"DeepseekV32Tokenizer"
),
"grok2"
:
(
"grok2"
,
"Grok2Tokenizer"
),
...
...
@@ -202,7 +211,31 @@ def get_tokenizer(
**
kwargs
,
)
if
tokenizer_cls
==
TokenizerLike
:
# Ensure that, if the config were to come from vllm.transformers_utils.config, it is
# registered with AutoConfig before the tokenizer is loaded. This is necessary since
# tokenizer_cls_.from_pretrained will call AutoConfig.from_pretrained internally.
# This may fail for paths that don't have a model config (e.g. LoRA adapters),
# which is fine — those don't need custom config registration.
config
=
None
with
contextlib
.
suppress
(
ValueError
,
OSError
):
config
=
get_config
(
tokenizer_name
,
trust_remote_code
=
trust_remote_code
,
revision
=
revision
,
)
# Some models have an incorrect tokenizer_class on the hub.
# For these model types, bypass AutoTokenizer and use TokenizersBackend directly.
model_type
=
getattr
(
config
,
"model_type"
,
None
)
if
config
else
None
if
model_type
in
_MODEL_TYPES_WITH_INCORRECT_TOKENIZER_CLASS
:
from
transformers.tokenization_utils_tokenizers
import
TokenizersBackend
logger
.
debug
(
"Overriding tokenizer_class to TokenizersBackend for model_type=%r"
,
model_type
,
)
tokenizer_cls_
=
TokenizersBackend
elif
tokenizer_cls
==
TokenizerLike
:
tokenizer_cls_
=
TokenizerRegistry
.
load_tokenizer_cls
(
tokenizer_mode
)
else
:
tokenizer_cls_
=
tokenizer_cls
...
...
vllm/tool_parsers/gemma4_tool_parser.py
View file @
fc67613a
...
...
@@ -66,6 +66,10 @@ def _parse_gemma4_value(value_str: str) -> object:
if
value_str
==
"false"
:
return
False
# Null
if
value_str
.
lower
()
in
(
"null"
,
"none"
,
"nil"
):
return
None
# Number (int or float)
try
:
if
"."
in
value_str
:
...
...
@@ -78,7 +82,7 @@ def _parse_gemma4_value(value_str: str) -> object:
return
value_str
def
_parse_gemma4_args
(
args_str
:
str
)
->
dict
:
def
_parse_gemma4_args
(
args_str
:
str
,
*
,
partial
:
bool
=
False
)
->
dict
:
"""Parse Gemma4's custom key:value format into a Python dict.
Format examples::
...
...
@@ -89,6 +93,12 @@ def _parse_gemma4_args(args_str: str) -> dict:
nested:{inner_key:<|"|>val<|"|>}
items:[<|"|>a<|"|>,<|"|>b<|"|>]
Args:
args_str: The raw Gemma4 argument string.
partial: When True (streaming), bare values at end of string are
omitted because they may be incomplete and type-unstable
(e.g. partial boolean parsed as bare string).
Returns a dict ready for ``json.dumps()``.
"""
if
not
args_str
or
not
args_str
.
strip
():
...
...
@@ -116,14 +126,16 @@ def _parse_gemma4_args(args_str: str) -> dict:
# Parse value
if
i
>=
n
:
result
[
key
]
=
""
if
not
partial
:
result
[
key
]
=
""
break
# Skip whitespace after ':'
while
i
<
n
and
args_str
[
i
]
in
(
" "
,
"
\n
"
,
"
\t
"
):
i
+=
1
if
i
>=
n
:
result
[
key
]
=
""
if
not
partial
:
result
[
key
]
=
""
break
# String value: <|"|>...<|"|>
...
...
@@ -155,7 +167,12 @@ def _parse_gemma4_args(args_str: str) -> dict:
elif
args_str
[
i
]
==
"}"
:
depth
-=
1
i
+=
1
result
[
key
]
=
_parse_gemma4_args
(
args_str
[
obj_start
:
i
-
1
])
if
depth
>
0
:
# Incomplete nested object — use i (not i-1) to avoid
# dropping the last char, and recurse as partial.
result
[
key
]
=
_parse_gemma4_args
(
args_str
[
obj_start
:
i
],
partial
=
True
)
else
:
result
[
key
]
=
_parse_gemma4_args
(
args_str
[
obj_start
:
i
-
1
])
# Array: [...]
elif
args_str
[
i
]
==
"["
:
...
...
@@ -173,20 +190,26 @@ def _parse_gemma4_args(args_str: str) -> dict:
elif
args_str
[
i
]
==
"]"
:
depth
-=
1
i
+=
1
arr_content
=
args_str
[
arr_start
:
i
-
1
]
result
[
key
]
=
_parse_gemma4_array
(
arr_content
)
if
depth
>
0
:
result
[
key
]
=
_parse_gemma4_array
(
args_str
[
arr_start
:
i
],
partial
=
True
)
else
:
result
[
key
]
=
_parse_gemma4_array
(
args_str
[
arr_start
:
i
-
1
])
# Bare value (number, boolean, etc.)
else
:
val_start
=
i
while
i
<
n
and
args_str
[
i
]
not
in
(
","
,
"}"
,
"]"
):
i
+=
1
if
partial
and
i
>=
n
:
# Value may be incomplete (e.g. partial boolean) —
# withhold to avoid type instability during streaming.
break
result
[
key
]
=
_parse_gemma4_value
(
args_str
[
val_start
:
i
])
return
result
def
_parse_gemma4_array
(
arr_str
:
str
)
->
list
:
def
_parse_gemma4_array
(
arr_str
:
str
,
*
,
partial
:
bool
=
False
)
->
list
:
"""Parse a Gemma4 array content string into a Python list."""
items
:
list
=
[]
i
=
0
...
...
@@ -224,7 +247,10 @@ def _parse_gemma4_array(arr_str: str) -> list:
elif
arr_str
[
i
]
==
"}"
:
depth
-=
1
i
+=
1
items
.
append
(
_parse_gemma4_args
(
arr_str
[
obj_start
:
i
-
1
]))
if
depth
>
0
:
items
.
append
(
_parse_gemma4_args
(
arr_str
[
obj_start
:
i
],
partial
=
True
))
else
:
items
.
append
(
_parse_gemma4_args
(
arr_str
[
obj_start
:
i
-
1
]))
# Nested array
elif
arr_str
[
i
]
==
"["
:
...
...
@@ -237,13 +263,18 @@ def _parse_gemma4_array(arr_str: str) -> list:
elif
arr_str
[
i
]
==
"]"
:
depth
-=
1
i
+=
1
items
.
append
(
_parse_gemma4_array
(
arr_str
[
sub_start
:
i
-
1
]))
if
depth
>
0
:
items
.
append
(
_parse_gemma4_array
(
arr_str
[
sub_start
:
i
],
partial
=
True
))
else
:
items
.
append
(
_parse_gemma4_array
(
arr_str
[
sub_start
:
i
-
1
]))
# Bare value
else
:
val_start
=
i
while
i
<
n
and
arr_str
[
i
]
not
in
(
","
,
"]"
):
i
+=
1
if
partial
and
i
>=
n
:
break
items
.
append
(
_parse_gemma4_value
(
arr_str
[
val_start
:
i
]))
return
items
...
...
@@ -436,8 +467,10 @@ class Gemma4ToolParser(ToolParser):
)
->
DeltaMessage
|
None
:
# Buffer delta text to handle multi-token special sequences
delta_text
=
self
.
_buffer_delta_text
(
delta_text
)
# Reconstruct current_text after buffering to stay in sync
current_text
=
previous_text
+
delta_text
# Keep current_text from the upstream stream state. The buffered delta
# is only for emission, and must not be stitched back into the
# accumulated model text or normal content like "<div>" can be
# duplicated into "<<div>" when a tool call just ended.
# If no tool call token seen yet, emit as content
if
self
.
tool_call_start_token
not
in
current_text
:
...
...
@@ -661,7 +694,7 @@ class Gemma4ToolParser(ToolParser):
DeltaMessage with the argument diff, or None if no new content.
"""
try
:
current_args
=
_parse_gemma4_args
(
raw_args_str
)
current_args
=
_parse_gemma4_args
(
raw_args_str
,
partial
=
True
)
except
Exception
:
logger
.
debug
(
"Could not parse partial Gemma4 args yet: %s"
,
...
...
@@ -675,10 +708,11 @@ class Gemma4ToolParser(ToolParser):
current_args_json
=
json
.
dumps
(
current_args
,
ensure_ascii
=
False
)
# Withhold trailing closing characters that may shift as more
# tokens arrive. Strip trailing '}', '"', and ']' sequences
# to get the "safe prefix".
# tokens arrive. Strip trailing '}', '"', ']' and partial
# STRING_DELIM fragments ('<', '|', '\\', '>') to get the
# "safe prefix".
safe_json
=
current_args_json
while
safe_json
and
safe_json
[
-
1
]
in
(
"}"
,
'"'
,
"]"
):
while
safe_json
and
safe_json
[
-
1
]
in
(
"}"
,
'"'
,
"]"
,
"<"
,
"|"
,
"
\\
"
,
">"
):
safe_json
=
safe_json
[:
-
1
]
prev_streamed
=
self
.
streamed_args_for_tool
[
self
.
current_tool_id
]
...
...
Prev
1
2
3
4
5
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment