Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
fc7980db
Commit
fc7980db
authored
Feb 05, 2026
by
zhuwenwen
Browse files
Merge tag 'v0.15.1' into v0.15.1-ori
parents
3eab7fef
1892993b
Changes
62
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
3033 additions
and
130 deletions
+3033
-130
vllm/model_executor/layers/quantization/fp8.py
vllm/model_executor/layers/quantization/fp8.py
+3
-12
vllm/model_executor/layers/quantization/kernels/scaled_mm/rocm.py
...el_executor/layers/quantization/kernels/scaled_mm/rocm.py
+1
-0
vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py
..._executor/layers/quantization/utils/flashinfer_fp4_moe.py
+0
-27
vllm/model_executor/layers/quantization/utils/nvfp4_moe_support.py
...l_executor/layers/quantization/utils/nvfp4_moe_support.py
+0
-67
vllm/model_executor/layers/utils.py
vllm/model_executor/layers/utils.py
+2
-0
vllm/model_executor/models/adapters.py
vllm/model_executor/models/adapters.py
+10
-4
vllm/model_executor/models/minimax_m2.py
vllm/model_executor/models/minimax_m2.py
+1
-0
vllm/model_executor/models/nemotron_parse.py
vllm/model_executor/models/nemotron_parse.py
+4
-1
vllm/model_executor/models/registry.py
vllm/model_executor/models/registry.py
+2
-0
vllm/model_executor/models/step3p5.py
vllm/model_executor/models/step3p5.py
+894
-0
vllm/model_executor/models/step3p5_mtp.py
vllm/model_executor/models/step3p5_mtp.py
+315
-0
vllm/reasoning/__init__.py
vllm/reasoning/__init__.py
+4
-0
vllm/reasoning/step3p5_reasoning_parser.py
vllm/reasoning/step3p5_reasoning_parser.py
+153
-0
vllm/tool_parsers/__init__.py
vllm/tool_parsers/__init__.py
+4
-0
vllm/tool_parsers/step3p5_tool_parser.py
vllm/tool_parsers/step3p5_tool_parser.py
+1511
-0
vllm/transformers_utils/config.py
vllm/transformers_utils/config.py
+2
-0
vllm/transformers_utils/configs/__init__.py
vllm/transformers_utils/configs/__init__.py
+4
-0
vllm/transformers_utils/configs/step3p5.py
vllm/transformers_utils/configs/step3p5.py
+100
-0
vllm/v1/attention/backends/flash_attn.py
vllm/v1/attention/backends/flash_attn.py
+0
-12
vllm/v1/core/kv_cache_coordinator.py
vllm/v1/core/kv_cache_coordinator.py
+23
-7
No files found.
vllm/model_executor/layers/quantization/fp8.py
View file @
fc7980db
...
...
@@ -26,7 +26,6 @@ from vllm.model_executor.layers.fused_moe import (
)
from
vllm.model_executor.layers.fused_moe.config
import
(
FusedMoEQuantConfig
,
RoutingMethodType
,
)
from
vllm.model_executor.layers.fused_moe.layer
import
UnquantizedFusedMoEMethod
from
vllm.model_executor.layers.fused_moe.oracle.fp8
import
(
...
...
@@ -964,17 +963,9 @@ class Fp8MoEMethod(FusedMoEMethodBase):
if
self
.
block_quant
:
import
vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe
# noqa: E501, F401
e_score_correction_bias
=
(
layer
.
e_score_correction_bias
.
to
(
x
.
dtype
)
if
layer
.
e_score_correction_bias
is
not
None
else
None
)
routing_method_type
=
layer
.
routing_method_type
return
torch
.
ops
.
vllm
.
flashinfer_fused_moe_blockscale_fp8
(
routing_logits
=
router_logits
.
to
(
torch
.
float32
)
if
routing_method_type
==
RoutingMethodType
.
DeepSeekV3
else
router_logits
,
routing_bias
=
e_score_correction_bias
,
routing_logits
=
router_logits
,
routing_bias
=
layer
.
e_score_correction_bias
,
x
=
x
,
w13_weight
=
layer
.
w13_weight
,
w13_weight_scale_inv
=
layer
.
w13_weight_scale_inv
,
...
...
@@ -988,7 +979,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
expert_offset
=
layer
.
ep_rank
*
layer
.
local_num_experts
,
local_num_experts
=
layer
.
local_num_experts
,
block_shape
=
self
.
weight_block_size
,
routing_method_type
=
routing_method_type
,
routing_method_type
=
layer
.
routing_method_type
,
routed_scaling
=
layer
.
routed_scaling_factor
,
)
else
:
...
...
vllm/model_executor/layers/quantization/kernels/scaled_mm/rocm.py
View file @
fc7980db
...
...
@@ -28,6 +28,7 @@ def rocm_per_tensor_float_w8a8_scaled_mm_impl(
A
.
shape
[
0
]
==
1
and
B
.
shape
[
1
]
%
16
==
0
and
((
bias
is
None
)
or
(
bias
.
dtype
==
out_dtype
))
and
A
.
is_contiguous
()
):
output
=
ops
.
wvSplitKQ
(
B
.
t
(),
...
...
vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py
View file @
fc7980db
...
...
@@ -6,7 +6,6 @@ from typing import TYPE_CHECKING
import
torch
import
vllm.envs
as
envs
import
vllm.model_executor.layers.fused_moe.modular_kernel
as
mk
from
vllm
import
_custom_ops
as
ops
from
vllm.logger
import
init_logger
...
...
@@ -22,10 +21,6 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
swizzle_blockscale
,
)
from
vllm.platforms
import
current_platform
from
vllm.utils.flashinfer
import
(
has_flashinfer_cutedsl_grouped_gemm_nt_masked
,
has_flashinfer_cutlass_fused_moe
,
)
if
TYPE_CHECKING
:
from
vllm.model_executor.layers.fused_moe.oracle.nvfp4
import
(
...
...
@@ -36,8 +31,6 @@ logger = init_logger(__name__)
__all__
=
[
"is_flashinfer_fp4_cutlass_moe_available"
,
"is_flashinfer_fp4_cutedsl_moe_available"
,
"reorder_w1w3_to_w3w1"
,
]
...
...
@@ -122,26 +115,6 @@ def is_supported_config_trtllm(
return
True
,
None
def is_flashinfer_fp4_cutlass_moe_available() -> bool:
    """Tell whether the FlashInfer CUTLASS NV-FP4 MoE path may be used.

    All of the following must hold, checked in this order with
    short-circuiting: the ``VLLM_USE_FLASHINFER_MOE_FP4`` env flag is set,
    ``has_flashinfer_cutlass_fused_moe()`` is true, the current platform
    reports CUDA, and ``has_device_capability(100)`` is true.
    """
    flashinfer_fp4_enabled = envs.VLLM_USE_FLASHINFER_MOE_FP4
    # Nested `and` keeps the exact evaluation order and result of a flat chain.
    return flashinfer_fp4_enabled and (
        has_flashinfer_cutlass_fused_moe()
        and current_platform.is_cuda()
        and current_platform.has_device_capability(100)
    )
def is_flashinfer_fp4_cutedsl_moe_available() -> bool:
    """Tell whether the FlashInfer CUTEDSL NV-FP4 MoE path may be used.

    All of the following must hold, checked in this order with
    short-circuiting: the ``VLLM_USE_FLASHINFER_MOE_FP4`` env flag is set,
    ``has_flashinfer_cutedsl_grouped_gemm_nt_masked()`` is true, the current
    platform reports CUDA, and ``is_device_capability_family(100)`` is true.
    """
    flashinfer_fp4_enabled = envs.VLLM_USE_FLASHINFER_MOE_FP4
    # Nested `and` keeps the exact evaluation order and result of a flat chain.
    return flashinfer_fp4_enabled and (
        has_flashinfer_cutedsl_grouped_gemm_nt_masked()
        and current_platform.is_cuda()
        and current_platform.is_device_capability_family(100)
    )
def
reorder_w1w3_to_w3w1
(
weight
:
torch
.
Tensor
,
scale
:
torch
.
Tensor
,
dim
:
int
=
-
2
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
...
...
vllm/model_executor/layers/quantization/utils/nvfp4_moe_support.py
deleted
100644 → 0
View file @
3eab7fef
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
dataclasses
import
dataclass
import
vllm.envs
as
envs
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe
import
(
is_flashinfer_fp4_cutedsl_moe_available
,
is_flashinfer_fp4_cutlass_moe_available
,
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils_fp4
import
(
is_fp4_marlin_supported
,
)
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
cutlass_fp4_supported
,
)
__all__
=
[
"detect_nvfp4_moe_support"
,
"NvFp4Support"
]
_logger
=
init_logger
(
__name__
)
@dataclass(frozen=True)
class NvFp4Support:
    """Immutable record of the outcome of NV-FP4 capability probing.

    Instances are frozen, so fields cannot be reassigned after creation.
    """

    # Value returned by the CUTLASS FP4 support probe.
    cutlass_supported: bool
    # Whether a FlashInfer NV-FP4 MoE kernel may be used.
    allow_flashinfer: bool
    # Whether the Marlin FP4 kernel is used as a fallback.
    use_marlin: bool
def detect_nvfp4_moe_support(class_name: str = "") -> NvFp4Support:
    """Probe which NV-FP4 fused-MoE kernels are usable and report the result.

    Args:
        class_name: Label used in log messages; when empty, the text
            ``"NVFP4 path"`` is logged instead.

    Returns:
        An ``NvFp4Support`` holding the CUTLASS probe result, whether
        FlashInfer kernels are allowed, and whether Marlin is the fallback.

    Raises:
        ValueError: If CUTLASS FP4 is unsupported and Marlin FP4 is not
            available either.
    """
    label = class_name or "NVFP4 path"
    cutlass_supported = cutlass_fp4_supported()

    # FlashInfer is only considered on top of CUTLASS support.
    allow_flashinfer = cutlass_supported and (
        is_flashinfer_fp4_cutlass_moe_available()
        or is_flashinfer_fp4_cutedsl_moe_available()
    )

    if allow_flashinfer:
        _logger.info_once("Using FlashInfer kernels for %s.", label)
    elif envs.VLLM_USE_FLASHINFER_MOE_FP4:
        # The user asked for FlashInfer but it cannot be used here.
        _logger.warning_once(
            "FlashInfer kernels unavailable for %s on current platform.",
            label,
        )

    use_marlin = False
    if not cutlass_supported:
        if not is_fp4_marlin_supported():
            raise ValueError(
                "Current platform does not support NVFP4 quantization. "
                "Please use Blackwell GPUs or enable FlashInfer."
            )
        use_marlin = True
        _logger.info_once("Falling back to Marlin FP4 MoE kernel.")

    return NvFp4Support(
        cutlass_supported=cutlass_supported,
        allow_flashinfer=allow_flashinfer,
        use_marlin=use_marlin,
    )
vllm/model_executor/layers/utils.py
View file @
fc7980db
...
...
@@ -146,6 +146,7 @@ def rocm_unquantized_gemm_impl(
and
n
<=
128
and
k
>
512
and
math
.
ceil
(
k
/
512
)
*
math
.
ceil
(
m
/
16
)
<
get_cu_count
()
and
x
.
is_contiguous
()
)
# k == 2880 and (m == 640 or m == 128))
)
...
...
@@ -165,6 +166,7 @@ def rocm_unquantized_gemm_impl(
and
on_gfx9
()
and
x
.
dtype
in
[
torch
.
float16
,
torch
.
bfloat16
]
and
k
%
8
==
0
and
x
.
is_contiguous
()
)
if
use_skinny
is
not
True
:
...
...
vllm/model_executor/models/adapters.py
View file @
fc7980db
...
...
@@ -466,6 +466,7 @@ def load_weights_using_from_2_way_softmax(
language_model
=
_get_language_model_for_seq_cls
(
model
)
is_vlm
=
language_model
is
not
model
using_vlm_head
=
is_vlm
and
hasattr
(
language_model
,
"score"
)
language_model
.
lm_head
=
ParallelLMHead
(
text_config
.
vocab_size
,
text_config
.
hidden_size
,
quant_config
=
quant_config
...
...
@@ -506,14 +507,16 @@ def load_weights_using_from_2_way_softmax(
torch
.
float32
)
-
lm_head_weight
.
data
[[
false_id
]].
to
(
torch
.
float32
)
score_layer
=
language_model
.
score
if
is_vlm
else
model
.
score
score_layer
=
language_model
.
score
if
using_vlm_head
else
model
.
score
param
=
score_layer
.
weight
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
score_weight
)
del
language_model
.
lm_head
score_weight_name
=
"language_model.score.weight"
if
is_vlm
else
"score.weight"
score_weight_name
=
(
"language_model.score.weight"
if
using_vlm_head
else
"score.weight"
)
loaded_weights
.
add
(
score_weight_name
)
lm_head_name
=
"lm_head.weight"
...
...
@@ -537,6 +540,7 @@ def load_weights_no_post_processing(model, weights: Iterable[tuple[str, torch.Te
language_model
=
_get_language_model_for_seq_cls
(
model
)
is_vlm
=
language_model
is
not
model
using_vlm_head
=
is_vlm
and
hasattr
(
language_model
,
"score"
)
language_model
.
lm_head
=
ParallelLMHead
(
text_config
.
vocab_size
,
text_config
.
hidden_size
,
quant_config
=
quant_config
...
...
@@ -572,14 +576,16 @@ def load_weights_no_post_processing(model, weights: Iterable[tuple[str, torch.Te
token_ids
=
[
tokenizer
.
convert_tokens_to_ids
(
t
)
for
t
in
tokens
]
score_weight
=
language_model
.
lm_head
.
weight
.
data
[
token_ids
]
score_layer
=
language_model
.
score
if
is_vlm
else
model
.
score
score_layer
=
language_model
.
score
if
using_vlm_head
else
model
.
score
param
=
score_layer
.
weight
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
score_weight
)
del
language_model
.
lm_head
score_weight_name
=
"language_model.score.weight"
if
is_vlm
else
"score.weight"
score_weight_name
=
(
"language_model.score.weight"
if
using_vlm_head
else
"score.weight"
)
loaded_weights
.
add
(
score_weight_name
)
lm_head_name
=
"lm_head.weight"
...
...
vllm/model_executor/models/minimax_m2.py
View file @
fc7980db
...
...
@@ -107,6 +107,7 @@ class MiniMaxM2MoE(nn.Module):
renormalize
=
True
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.experts"
,
router_logits_dtype
=
torch
.
float32
,
)
self
.
gate
=
ReplicatedLinear
(
...
...
vllm/model_executor/models/nemotron_parse.py
View file @
fc7980db
...
...
@@ -11,7 +11,6 @@ import math
from
collections.abc
import
Iterable
,
Mapping
,
Sequence
from
typing
import
Annotated
,
Literal
import
cv2
import
numpy
as
np
import
torch
import
torch.nn
as
nn
...
...
@@ -416,6 +415,8 @@ class NemotronParseImageProcessor:
else
:
self
.
target_height
=
self
.
target_width
=
int
(
self
.
final_size
)
import
cv2
self
.
transform
=
A
.
Compose
(
[
A
.
PadIfNeeded
(
...
...
@@ -457,6 +458,8 @@ class NemotronParseImageProcessor:
new_height
=
int
(
new_width
/
aspect_ratio
)
# Use cv2.INTER_LINEAR like the original
import
cv2
return
cv2
.
resize
(
image
,
(
new_width
,
new_height
),
interpolation
=
cv2
.
INTER_LINEAR
)
...
...
vllm/model_executor/models/registry.py
View file @
fc7980db
...
...
@@ -188,6 +188,7 @@ _TEXT_GENERATION_MODELS = {
"SeedOssForCausalLM"
:
(
"seed_oss"
,
"SeedOssForCausalLM"
),
"Step1ForCausalLM"
:
(
"step1"
,
"Step1ForCausalLM"
),
"Step3TextForCausalLM"
:
(
"step3_text"
,
"Step3TextForCausalLM"
),
"Step3p5ForCausalLM"
:
(
"step3p5"
,
"Step3p5ForCausalLM"
),
"StableLMEpochForCausalLM"
:
(
"stablelm"
,
"StablelmForCausalLM"
),
"StableLmForCausalLM"
:
(
"stablelm"
,
"StablelmForCausalLM"
),
"Starcoder2ForCausalLM"
:
(
"starcoder2"
,
"Starcoder2ForCausalLM"
),
...
...
@@ -478,6 +479,7 @@ _SPECULATIVE_DECODING_MODELS = {
"MedusaModel"
:
(
"medusa"
,
"Medusa"
),
"OpenPanguMTPModel"
:
(
"openpangu_mtp"
,
"OpenPanguMTP"
),
"Qwen3NextMTP"
:
(
"qwen3_next_mtp"
,
"Qwen3NextMTP"
),
"Step3p5MTP"
:
(
"step3p5_mtp"
,
"Step3p5MTP"
),
# Temporarily disabled.
# # TODO(woosuk): Re-enable this once the MLP Speculator is supported in V1.
# "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
...
...
vllm/model_executor/models/step3p5.py
0 → 100644
View file @
fc7980db
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Inference-only Jurassic model."""
from
collections.abc
import
Iterable
from
typing
import
Any
import
torch
from
torch
import
nn
from
torch.nn.parameter
import
Parameter
from
vllm.attention.layer
import
Attention
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
CacheConfig
,
ModelConfig
,
VllmConfig
from
vllm.distributed
import
(
get_dp_group
,
get_ep_group
,
get_pp_group
,
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
,
get_tp_group
,
)
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.activation
import
SiluAndMul
,
SwigluStepAndMul
from
vllm.model_executor.layers.fused_moe
import
FusedMoE
from
vllm.model_executor.layers.fused_moe.shared_fused_moe
import
SharedFusedMoE
from
vllm.model_executor.layers.layernorm
import
GemmaRMSNorm
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
MergedColumnParallelLinear
,
QKVParallelLinear
,
ReplicatedLinear
,
RowParallelLinear
,
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization.base_config
import
QuantizationConfig
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
DEFAULT_VOCAB_PADDING_SIZE
,
ParallelLMHead
,
VocabParallelEmbedding
,
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.sequence
import
IntermediateTensors
from
vllm.v1.attention.backend
import
AttentionType
from
.interfaces
import
MixtureOfExperts
,
SupportsPP
from
.utils
import
(
AutoWeightsLoader
,
PPMissingLayer
,
WeightsMapper
,
extract_layer_index
,
is_pp_missing_parameter
,
make_empty_intermediate_tensors_factory
,
make_layers
,
maybe_prefix
,
)
logger
=
init_logger
(
__name__
)
class FP32ReplicatedLinear(ReplicatedLinear):
    """``ReplicatedLinear`` variant that always runs on float32 inputs.

    The layer must have been created with ``params_dtype=torch.float32``;
    the input is cast to float32 before the parent forward is invoked.
    """

    def forward(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
        """Cast ``x`` to float32 and delegate to ``ReplicatedLinear.forward``."""
        assert self.params_dtype == torch.float32
        x_fp32 = x.to(torch.float32)
        return super().forward(x_fp32)
class Step3p5MLP(nn.Module):
    """Gated MLP: merged gate/up projection, gated activation, down projection.

    The activation is ``SiluAndMul`` unless ``config.swiglu_limits_shared``
    holds a non-``None``, non-zero entry for this layer, in which case
    ``SwigluStepAndMul(limit=<entry>)`` is used instead.
    """

    def __init__(
        self,
        config: ModelConfig,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: QuantizationConfig | None = None,
        reduce_results: bool = True,
        prefix: str = "",
    ) -> None:
        """Create the projections and pick the activation for this layer.

        Args:
            config: Object exposing ``swiglu_limits_shared``.
                NOTE(review): annotated as ``ModelConfig`` although the
                callers in this file pass
                ``vllm_config.model_config.hf_config`` — confirm the type.
            hidden_size: Input and output feature size.
            intermediate_size: Size of each of the gate and up projections.
            hidden_act: Activation name; only ``"silu"`` is accepted.
            quant_config: Forwarded to both linear layers.
            reduce_results: Forwarded to the down projection.
            prefix: Module name prefix; also used to derive the layer index.

        Raises:
            ValueError: If ``hidden_act`` is not ``"silu"``.
        """
        super().__init__()
        # Reject unsupported activations before allocating any layer.
        if hidden_act != "silu":
            raise ValueError(
                f"Unsupported activation: {hidden_act}. "
                "Only silu is supported for now."
            )

        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            [intermediate_size] * 2,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            reduce_results=reduce_results,
            prefix=f"{prefix}.down_proj",
        )
        self.act_fn = SiluAndMul()
        self.prefix = prefix
        self.hidden_size = hidden_size
        self.limit = None

        layer_idx = extract_layer_index(prefix)
        # Bounds-checked lookup, mirroring how FusedMoEBlock reads
        # `config.swiglu_limits`: a list shorter than the layer index means
        # "no limit" rather than an IndexError.
        shared_limits = config.swiglu_limits_shared or []
        limit = shared_limits[layer_idx] if layer_idx < len(shared_limits) else None
        if limit is not None and limit != 0:
            self.limit = limit
            self.act_fn = SwigluStepAndMul(limit=self.limit)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply gate/up projection, the activation, then the down projection."""
        gate_up, _ = self.gate_up_proj(hidden_states)
        intermediate_act = self.act_fn(gate_up)
        output, _ = self.down_proj(intermediate_act)
        return output
class Step3p5Attention(nn.Module):
    """Attention block with per-head q/k normalisation and optional gating.

    Per-layer behaviour (sliding window, rope usage, rope scaling, rope
    theta, head count) is selected from list-valued arguments indexed by the
    layer index derived from ``prefix``.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        max_position: int = 4096 * 32,
        head_dim: int | None = None,
        rms_norm_eps: float = 1e-06,
        qkv_bias: bool = False,
        rope_theta: float | list[float] | None = 10000,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        rope_scaling: dict[str, Any] | None = None,
        prefix: str = "",
        attn_type: str = AttentionType.DECODER,
        # Step3p5 specific args
        sliding_window: int | None = None,
        use_head_wise_attn_gate: bool = False,
        layer_types: list | None = None,
        use_rope_layers: list | None = None,
        yarn_only_types: list | None = None,
        swa_num_attention_heads: int | None = None,
        partial_rotary_factor: float = 1.0,
    ):
        """Build projections, rotary embedding, norms and the attention op.

        Args:
            hidden_size: Model hidden size.
            num_heads: Total query heads; replaced by
                ``swa_num_attention_heads`` on sliding-window layers when
                that argument is given.
            num_kv_heads: Total key/value heads.
            max_position: Passed to ``get_rope`` as ``max_position``.
            head_dim: Per-head size; defaults to
                ``hidden_size // total_num_heads`` when falsy.
            rms_norm_eps: Epsilon for the q/k norms.
            qkv_bias: Whether the QKV projection has a bias.
            rope_theta: Scalar, or a list indexed by layer index.
            cache_config: Forwarded to ``Attention``.
            quant_config: Forwarded to the QKV/O projections and ``Attention``.
            rope_scaling: Optional dict merged into the rope parameters.
            prefix: Module name prefix; the layer index is extracted from it.
            attn_type: Forwarded to ``Attention``.
            sliding_window: Window size, applied only on sliding layers.
            use_head_wise_attn_gate: Enable the sigmoid per-head output gate.
            layer_types: Per-layer type names; ``"sliding_attention"`` marks
                a sliding-window layer. When empty, even layer indices are
                treated as sliding-window layers.
            use_rope_layers: Per-layer flags enabling rotary embedding.
            yarn_only_types: Layer types that keep ``rope_scaling``; for any
                other layer type ``rope_scaling`` is dropped.
            swa_num_attention_heads: Head count for sliding-window layers.
            partial_rotary_factor: Must be ``1`` or ``0.5``.

        Raises:
            ValueError: If ``yarn_only_types`` is given without
                ``layer_types``, or if ``rope_scaling`` is not a dict.
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.total_num_heads = num_heads
        tp_size = get_tensor_model_parallel_world_size()

        self.layer_idx = extract_layer_index(prefix)
        if layer_types:
            enable_sliding_window = layer_types[self.layer_idx] == "sliding_attention"
        else:
            enable_sliding_window = self.layer_idx % 2 == 0

        if yarn_only_types:
            # The layer type is needed to decide whether scaling applies;
            # fail with a clear message instead of an opaque indexing error.
            if not layer_types:
                raise ValueError(
                    "yarn_only_types requires layer_types to be set "
                    "for Step3p5Attention."
                )
            if layer_types[self.layer_idx] not in yarn_only_types:
                rope_scaling = None

        if sliding_window is not None and enable_sliding_window:
            # Sliding-window layers may use a different number of heads.
            if swa_num_attention_heads is not None:
                self.total_num_heads = swa_num_attention_heads
        else:
            sliding_window = None

        if isinstance(rope_theta, list):
            rope_theta = rope_theta[self.layer_idx]

        self.rank = get_tensor_model_parallel_rank()
        self.partial_rotary_factor = partial_rotary_factor

        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = head_dim or hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=qkv_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        if rope_scaling is not None and not isinstance(rope_scaling, dict):
            raise ValueError("rope_scaling must be a dict for Step3p5Attention.")

        # Copy so the caller's dict is not mutated by the updates below.
        rope_parameters: dict[str, Any] = (
            dict(rope_scaling) if rope_scaling is not None else {}
        )
        rope_parameters.setdefault("rope_type", "default")
        rope_parameters["rope_theta"] = self.rope_theta
        rope_parameters["partial_rotary_factor"] = partial_rotary_factor

        self.rotary_emb = get_rope(
            head_size=self.head_dim,
            max_position=max_position,
            rope_parameters=rope_parameters,
        )

        self.q_norm = GemmaRMSNorm(self.head_dim, rms_norm_eps)
        self.k_norm = GemmaRMSNorm(self.head_dim, rms_norm_eps)
        self.use_head_wise_attn_gate = use_head_wise_attn_gate
        if use_head_wise_attn_gate:
            # One gate logit per head (no quant_config is passed here).
            self.g_proj = ColumnParallelLinear(
                hidden_size,
                self.total_num_heads,
                bias=False,
                prefix=f"{prefix}.g_proj",
            )

        self.use_rope = True
        if use_rope_layers:
            self.use_rope = use_rope_layers[self.layer_idx]

        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
            per_layer_sliding_window=sliding_window,
            attn_type=attn_type,
        )

        self.max_position_embeddings = max_position
        assert self.partial_rotary_factor == 1 or self.partial_rotary_factor == 0.5
        self.rotary_dim = (
            self.head_dim if self.partial_rotary_factor == 1 else self.head_dim // 2
        )

    def _norm_per_head(self, x: torch.Tensor, norm: nn.Module) -> torch.Tensor:
        """Apply ``norm`` over each ``head_dim`` slice of the last axis of ``x``."""
        by_head = x.view(*x.shape[:-1], x.shape[-1] // self.head_dim, self.head_dim)
        return norm(by_head.contiguous()).view(x.shape)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Run QKV projection, q/k norm, optional rope, attention and output.

        Args:
            positions: Forwarded to the rotary embedding.
            hidden_states: Input to the QKV (and gate) projections.

        Returns:
            The output of the ``o_proj`` projection.
        """
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)

        # Add qk-norm inline similar to Qwen3 MOE attention
        q = self._norm_per_head(q, self.q_norm)
        k = self._norm_per_head(k, self.k_norm)

        if self.use_rope:
            q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v)

        if self.use_head_wise_attn_gate:
            # Scale every head's output by sigmoid of its own gate logit.
            extra_dims, _ = self.g_proj(hidden_states)
            output = (
                attn_output.view(
                    *attn_output.shape[:-1], self.num_heads, self.head_dim
                )
                * extra_dims.unsqueeze(-1).sigmoid()
            )
            attn_output = output.view(*attn_output.shape)

        output, _ = self.o_proj(attn_output)
        return output
class FusedMoEBlock(nn.Module):
    """Mixture-of-experts block: FP32 router gate, shared expert and experts.

    The shared expert (``Step3p5MLP``) and the routed experts are wrapped in
    a single ``SharedFusedMoE``; their outputs are summed in ``forward``.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
    ):
        """Read the MoE settings from ``vllm_config`` and build sub-modules.

        Args:
            vllm_config: Source of the HF config, quant config and parallel
                config.
            prefix: Module name prefix; the layer index is extracted from it.

        Raises:
            ValueError: If the tensor-parallel size exceeds the expert count.
        """
        super().__init__()

        hf_config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        parallel_config = vllm_config.parallel_config

        self.tp_size = get_tensor_model_parallel_world_size()
        self.layer_idx = extract_layer_index(prefix)

        self.ep_size = get_ep_group().device_group.size()
        self.ep_rank = get_ep_group().device_group.rank()

        self.hidden_size = hf_config.hidden_size
        self.enable_eplb = parallel_config.enable_eplb

        # Expert bookkeeping: logical experts plus redundant copies, split
        # evenly over the expert-parallel group.
        self.n_routed_experts = hf_config.moe_num_experts
        self.n_logical_experts = self.n_routed_experts
        self.n_redundant_experts = parallel_config.eplb_config.num_redundant_experts
        self.n_physical_experts = self.n_logical_experts + self.n_redundant_experts
        self.n_local_physical_experts = self.n_physical_experts // self.ep_size
        self.physical_expert_start = self.ep_rank * self.n_local_physical_experts
        self.physical_expert_end = (
            self.physical_expert_start + self.n_local_physical_experts
        )

        if self.tp_size > hf_config.moe_num_experts:
            raise ValueError(
                f"Tensor parallel size {self.tp_size} is greater than "
                f"the number of experts {hf_config.moe_num_experts}."
            )

        self.gate = FP32ReplicatedLinear(
            hf_config.hidden_size,
            hf_config.moe_num_experts,
            bias=False,
            quant_config=None,
            params_dtype=torch.float32,  # Use FP32 for higher precision.
            prefix=f"{prefix}.gate",
        )

        self.use_moe_router_bias = hf_config.use_moe_router_bias
        assert self.use_moe_router_bias, "Only support use_moe_router_bias is true."
        self.routed_scaling_factor = hf_config.moe_router_scaling_factor
        self.router_bias = nn.Parameter(
            torch.zeros(hf_config.moe_num_experts, dtype=torch.float32),
            requires_grad=False,
        )
        self.need_fp32_gate = hf_config.need_fp32_gate
        assert self.need_fp32_gate, (
            "Router logits must use FP32 precision for numerical stability."
        )

        # Pick the expert activation; a non-zero per-layer limit switches
        # from plain silu to the "swiglustep" activation.
        activation = "silu"
        per_layer_limits = hf_config.swiglu_limits or []
        if self.layer_idx < len(per_layer_limits):
            swiglu_limit = per_layer_limits[self.layer_idx]
        else:
            swiglu_limit = None
        if swiglu_limit not in (None, 0):
            swiglu_limit = float(swiglu_limit)
            assert swiglu_limit == 7.0, (
                "Swiglu limit in fused moe block only suport 7.0 now."
            )
            activation = "swiglustep"
            logger.debug(
                "step3p5 layer_idx: %s, activation: %s, limit: %s",
                self.layer_idx,
                activation,
                swiglu_limit,
            )

        self.share_expert = Step3p5MLP(
            config=hf_config,
            hidden_size=self.hidden_size,
            intermediate_size=hf_config.share_expert_dim,
            hidden_act="silu",
            reduce_results=False,
            quant_config=quant_config,
            prefix=f"{prefix}.share_expert",
        )
        self.experts = SharedFusedMoE(
            shared_experts=self.share_expert,
            gate=self.gate,
            num_experts=hf_config.moe_num_experts,
            top_k=hf_config.moe_top_k,
            hidden_size=hf_config.hidden_size,
            intermediate_size=hf_config.moe_intermediate_size,
            reduce_results=False,
            renormalize=hf_config.norm_expert_weight,
            quant_config=quant_config,
            activation=activation,
            prefix=f"{prefix}.experts",
            scoring_func=getattr(hf_config, "moe_router_activation", "sigmoid"),
            e_score_correction_bias=self.router_bias,
            routed_scaling_factor=hf_config.moe_router_scaling_factor,
            enable_eplb=self.enable_eplb,
            num_redundant_experts=self.n_redundant_experts,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Route tokens through the experts and add the shared-expert output.

        Args:
            hidden_states: 2-D tensor; its shape is unpacked as
                ``(num_tokens, hidden_dim)``.

        Returns:
            Tensor viewed back to ``(num_tokens, hidden_dim)``.
        """
        num_tokens, hidden_dim = hidden_states.shape
        hidden_states = hidden_states.view(-1, hidden_dim)

        if self.experts.is_internal_router:
            # In this case, the gate/router runs inside the FusedMoE class,
            # so the raw hidden states are handed over as "router_logits".
            router_input = hidden_states
        else:
            # router_logits: (num_tokens, n_experts)
            router_input, _ = self.gate(hidden_states)

        shared_output, final_hidden_states = self.experts(
            hidden_states=hidden_states, router_logits=router_input
        )

        if self.share_expert is None:
            assert shared_output is None
        else:
            assert shared_output is not None
            final_hidden_states += shared_output

        if self.tp_size > 1:
            final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel(
                final_hidden_states
            )

        return final_hidden_states.view(num_tokens, hidden_dim)
class Step3p5DecoderLayer(nn.Module):
    """One decoder layer: attention followed by an MoE block or a dense MLP.

    Both sub-layers are preceded by a ``GemmaRMSNorm`` and wrapped in a
    residual connection.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
    ) -> None:
        """Build the attention and feed-forward sub-modules for this layer.

        Args:
            vllm_config: Source of the HF config, cache config and quant
                config.
            prefix: Module name prefix; the layer index is extracted from it.

        Raises:
            ValueError: If ``config.att_impl_type`` is not ``"GQA"``.
        """
        super().__init__()
        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.hidden_size = config.hidden_size
        layer_idx = extract_layer_index(prefix)
        self.layer_idx = layer_idx

        # The window is passed per layer to the attention module instead, so
        # the value on the shared cache config is cleared.
        if cache_config is not None:
            cache_config.sliding_window = None

        if config.att_impl_type != "GQA":
            raise ValueError(
                f"Unsupported attention implementation: {config.att_impl_type}"
            )

        # Layers whose type matches `attention_other_setting` use the head
        # geometry from that setting instead of the global one.
        override_heads = None
        override_groups = None
        override_head_dim = None
        other_setting = getattr(config, "attention_other_setting", None)
        if (
            other_setting
            and getattr(config, "layer_types", [])
            and config.layer_types[layer_idx] == other_setting["attention_type"]
        ):
            override_heads = other_setting["num_attention_heads"]
            override_groups = other_setting["num_attention_groups"]
            override_head_dim = other_setting["head_dim"]

        partial_rotary_factors = getattr(config, "partial_rotary_factors", [])
        if partial_rotary_factors:
            layer_rotary_factor = partial_rotary_factors[layer_idx]
        else:
            layer_rotary_factor = 1.0

        self.self_attn = Step3p5Attention(
            hidden_size=self.hidden_size,
            num_heads=override_heads or config.num_attention_heads,
            max_position=config.max_position_embeddings,
            num_kv_heads=override_groups or config.num_attention_groups,
            rope_theta=config.rope_theta,
            rms_norm_eps=config.rms_norm_eps,
            qkv_bias=getattr(config, "attention_bias", False),
            head_dim=override_head_dim or getattr(config, "head_dim", None),
            cache_config=cache_config,
            quant_config=quant_config,
            rope_scaling=getattr(config, "rope_scaling", None),
            sliding_window=getattr(config, "sliding_window", None),
            use_head_wise_attn_gate=getattr(config, "use_head_wise_attn_gate", False),
            layer_types=getattr(config, "layer_types", []),
            use_rope_layers=getattr(config, "use_rope_layers", []),
            yarn_only_types=getattr(config, "yarn_only_types", []),
            partial_rotary_factor=layer_rotary_factor,
            prefix=f"{prefix}.self_attn",
        )

        self.use_moe = False
        self.tp_group = get_tp_group()
        self.use_fused_all_reduce = (
            get_tensor_model_parallel_world_size() > 1
            and get_dp_group().world_size == 1
        )
        fused_state_msg = (
            "Enable custom fused all reduce..."
            if self.use_fused_all_reduce
            else "Disable custom fused all reduce..."
        )
        logger.warning_once(fused_state_msg)

        # `moe_layers_enum` is a comma-separated list of layer indices; when
        # absent, every layer except layer 0 is an MoE layer.
        moe_layers_enum = getattr(config, "moe_layers_enum", None)
        if moe_layers_enum is None:
            moe_layers_idx = list(range(1, config.num_hidden_layers))
        else:
            moe_layers_idx = [
                int(item) for item in moe_layers_enum.strip().split(",")
            ]

        if layer_idx in moe_layers_idx:
            self.moe = FusedMoEBlock(
                vllm_config,
                prefix=f"{prefix}.moe",
            )
            self.use_moe = True
        else:
            self.mlp = Step3p5MLP(
                config=config,
                hidden_size=config.hidden_size,
                intermediate_size=config.intermediate_size,
                hidden_act="silu",
                quant_config=quant_config,
                reduce_results=True,
                prefix=f"{prefix}.mlp",
            )

        self.input_layernorm = GemmaRMSNorm(config.hidden_size, config.rms_norm_eps)
        self.post_attention_layernorm = GemmaRMSNorm(
            config.hidden_size, config.rms_norm_eps
        )
        self.prefix = prefix

    def add_and_maybe_inplace_all_reduce(
        self, in1: torch.Tensor, in2: torch.Tensor
    ) -> torch.Tensor:
        """Return ``in1 + in2``, all-reduced over the TP group when enabled."""
        summed = in1 + in2
        if self.use_fused_all_reduce:
            return self.tp_group.all_reduce(summed)
        return summed

    def forward(
        self, positions: torch.Tensor, hidden_states: torch.Tensor
    ) -> torch.Tensor:
        """Apply attention and the feed-forward sub-layer with residuals.

        Args:
            positions: Forwarded to the attention module.
            hidden_states: Layer input.

        Returns:
            The layer output after both residual additions.
        """
        # Attention sub-layer.
        skip = hidden_states
        attn_out = self.self_attn(
            positions=positions,
            hidden_states=self.input_layernorm(hidden_states),
        )
        attn_out += skip

        # Feed-forward sub-layer (MoE or dense MLP).
        skip = attn_out
        normed = self.post_attention_layernorm(attn_out)
        if self.use_moe:
            ffn_output = self.moe(normed)
        else:
            ffn_output = self.mlp(normed)

        return ffn_output + skip
@support_torch_compile
class Step3p5Model(nn.Module):
    """Step3p5 decoder stack: token embedding, decoder layers and final norm.

    With pipeline parallelism, the embedding is only built on the first rank
    (or on the last rank when word embeddings are tied) and the final norm
    only on the last rank; elsewhere a ``PPMissingLayer`` stands in.
    """

    def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None:
        """Build the sub-modules this pipeline rank needs.

        Args:
            vllm_config: Engine configuration; the HF config is read from
                ``vllm_config.model_config.hf_config``.
            prefix: Dotted module path used to name the decoder layers.
        """
        super().__init__()
        self.vllm_config = vllm_config
        config = vllm_config.model_config.hf_config
        self.vocab_size = config.vocab_size
        self.config = config
        self.moe_num_experts = config.moe_num_experts

        if get_pp_group().is_first_rank or (
            config.tie_word_embeddings and get_pp_group().is_last_rank
        ):
            self.embed_tokens = VocabParallelEmbedding(
                self.vocab_size,
                config.hidden_size,
            )
        else:
            self.embed_tokens = PPMissingLayer()

        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: Step3p5DecoderLayer(
                vllm_config,
                prefix=prefix,
            ),
            prefix=f"{prefix}.layers",
        )

        if get_pp_group().is_last_rank:
            self.norm = GemmaRMSNorm(config.hidden_size, config.rms_norm_eps)
        else:
            self.norm = PPMissingLayer()

        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states"], config.hidden_size
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up embeddings for ``input_ids`` via ``self.embed_tokens``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the decoder layers owned by this pipeline rank.

        Args:
            input_ids: Token ids; only used on the first rank when
                ``inputs_embeds`` is not given.
            positions: Forwarded to every decoder layer.
            intermediate_tensors: Required on non-first ranks; its
                ``"hidden_states"`` entry is the input to this rank.
            inputs_embeds: Optional pre-computed embeddings (first rank only).

        Returns:
            The hidden states on the last rank, otherwise an
            ``IntermediateTensors`` wrapping them for the next rank. Note the
            final ``self.norm`` is NOT applied here.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.embed_input_ids(input_ids)
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]

        for i in range(self.start_layer, self.end_layer):
            layer = self.layers[i]
            hidden_states = layer(positions, hidden_states)

        if not get_pp_group().is_last_rank:
            return IntermediateTensors(
                {
                    "hidden_states": hidden_states,
                }
            )
        return hidden_states

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this module's parameters.

        Each checkpoint entry is tried, in order, as:

        1. a shard of a fused projection (``q/k/v_proj`` -> ``qkv_proj``,
           ``gate/up_proj`` -> ``gate_up_proj``);
        2. a stacked MoE expert tensor whose leading dimension indexes the
           experts (``.moe.{gate,up,down}_proj.weight``);
        3. a plain parameter loaded one-to-one.

        Entries for speculative-decoding layers, for layers at or beyond
        ``num_hidden_layers``, for parameters living on another pipeline
        rank, and names with no matching parameter are skipped silently.

        Args:
            weights: ``(name, tensor)`` pairs; names may or may not carry a
                leading ``"model."``.

        Returns:
            Names (relative to this module) of the parameters that were
            written.
        """
        config = self.config
        assert config.num_attention_groups > 1, "Only support GQA"

        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        expert_params_mapping = [
            # (param_name, checkpoint_name, shard_id)
            (".moe.experts.w13_weight", ".moe.gate_proj.weight", "w1"),
            (".moe.experts.w13_weight", ".moe.up_proj.weight", "w3"),
            (".moe.experts.w2_weight", ".moe.down_proj.weight", "w2"),
        ]
        # Checkpoint names of routed-expert tensors. They also contain
        # "gate_proj"/"up_proj", so they must be kept away from the dense
        # stacked mapping.
        moe_weight_names = [entry[1] for entry in expert_params_mapping]

        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()

        for name, loaded_weight in weights:
            # local_name is relative to this module, full_name always carries
            # the "model." prefix.
            if name.startswith("model."):
                local_name = name[len("model.") :]
                full_name = name
            else:
                local_name = name
                full_name = f"model.{name}" if name else "model"

            # Skip spec decode layers for the main model.
            spec_layer = get_spec_layer_idx_from_weight_name(config, full_name)
            if spec_layer is not None:
                continue

            # Skip any layers beyond the main model's depth (e.g., MTP layers).
            if full_name.startswith("model.layers."):
                parts = full_name.split(".")
                if len(parts) > 2 and parts[2].isdigit():
                    layer_idx = int(parts[2])
                    if layer_idx >= config.num_hidden_layers:
                        continue

            is_moe_expert_weight = any(
                moe_name in local_name for moe_name in moe_weight_names
            )

            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in local_name:
                    continue
                if is_moe_expert_weight:
                    continue
                replaced_name = local_name.replace(weight_name, param_name)
                if is_pp_missing_parameter(replaced_name, self):
                    continue
                if replaced_name not in params_dict:
                    continue
                param = params_dict[replaced_name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                loaded_params.add(replaced_name)
                break
            else:
                for param_name, weight_name, shard_id in expert_params_mapping:
                    if weight_name not in local_name:
                        continue
                    replaced_name = local_name.replace(weight_name, param_name)
                    if is_pp_missing_parameter(replaced_name, self):
                        continue
                    # One membership test covers both "extra bias" and any
                    # other name without a matching parameter.
                    if replaced_name not in params_dict:
                        continue
                    param = params_dict[replaced_name]
                    weight_loader = param.weight_loader
                    moe_expert_num = self.moe_num_experts
                    assert loaded_weight.shape[0] == moe_expert_num
                    # The checkpoint stacks all experts along dim 0; hand
                    # them to the loader one expert at a time.
                    for expert_id in range(moe_expert_num):
                        loaded_weight_expert = loaded_weight[expert_id]
                        weight_loader(
                            param,
                            loaded_weight_expert,
                            replaced_name,
                            shard_id=shard_id,
                            expert_id=expert_id,
                        )
                    loaded_params.add(replaced_name)
                    break
                else:
                    if is_pp_missing_parameter(local_name, self):
                        continue
                    if "expert_bias" in local_name:
                        logger.warning_once("ignore expert_bias")
                        continue
                    if local_name not in params_dict:
                        continue
                    param = params_dict[local_name]
                    weight_loader = getattr(
                        param, "weight_loader", default_weight_loader
                    )
                    weight_loader(param, loaded_weight)
                    loaded_params.add(local_name)
        return loaded_params
class Step3p5ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts):
    """Step3p5 causal LM: decoder stack plus LM head and MoE/EPLB bookkeeping."""

    # Checkpoints name the shared expert ".share_expert."; in this
    # implementation it lives under the layer's ".moe." sub-module.
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_substr={".share_expert.": ".moe.share_expert."}
    )

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
    ):
        """Build the decoder stack, the LM head (last PP rank) and MoE metadata.

        Args:
            vllm_config: Engine configuration.
            prefix: Dotted module path of this model inside its parent.
        """
        super().__init__()
        hf_config = vllm_config.model_config.hf_config
        lora_cfg = vllm_config.lora_config
        self.config = hf_config
        self.vllm_config = vllm_config

        self.model = Step3p5Model(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
        )

        # Collect the MoE blocks of the decoder layers present on this rank.
        self.moe_layers: list[FusedMoEBlock] = []
        for decoder_layer in self.model.layers:
            if isinstance(decoder_layer, PPMissingLayer):
                continue
            assert isinstance(decoder_layer, Step3p5DecoderLayer)
            moe_block = getattr(decoder_layer, "moe", None)
            if isinstance(moe_block, FusedMoEBlock):
                self.moe_layers.append(moe_block)

        if not get_pp_group().is_last_rank:
            self.lm_head = PPMissingLayer()
        else:
            extra_vocab = lora_cfg.lora_extra_vocab_size if lora_cfg else 0
            self.unpadded_vocab_size = hf_config.vocab_size + extra_vocab
            if lora_cfg:
                vocab_padding = lora_cfg.lora_vocab_padding_size
            else:
                vocab_padding = DEFAULT_VOCAB_PADDING_SIZE
            self.lm_head = ParallelLMHead(
                self.unpadded_vocab_size,
                hf_config.hidden_size,
                org_num_embeddings=hf_config.vocab_size,
                padding_size=vocab_padding,
            )
            self.logits_processor = LogitsProcessor(
                self.unpadded_vocab_size, hf_config.vocab_size
            )

        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors
        )

        # MoE hyperparameters, taken from the first MoE block on this rank.
        self.expert_weights = []
        assert self.moe_layers, "No MoE layers found in the model."
        reference_block = self.moe_layers[0]
        self.num_moe_layers = len(self.moe_layers)
        self.num_expert_groups = 1
        self.num_shared_experts = 0
        self.num_logical_experts = reference_block.n_logical_experts
        self.num_physical_experts = reference_block.n_physical_experts
        self.num_local_physical_experts = reference_block.n_local_physical_experts
        self.num_routed_experts = reference_block.n_routed_experts
        self.num_redundant_experts = reference_block.n_redundant_experts

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ):
        """Delegate to ``self.model`` and return its result unchanged."""
        return self.model(input_ids, positions, intermediate_tensors, inputs_embeds)

    def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply the final norm, then project to vocabulary logits.

        The norm is applied here because ``Step3p5Model.forward`` returns
        un-normalised hidden states.
        """
        normed_states = self.model.norm(hidden_states)
        return self.logits_processor(self.lm_head, normed_states)

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings via the inner model's embedding table."""
        return self.model.embed_tokens(input_ids)

    def set_eplb_state(
        self,
        expert_load_view: torch.Tensor,
        logical_to_physical_map: torch.Tensor,
        logical_replica_count: torch.Tensor,
    ) -> None:
        """Register expert weights and pass the EPLB tensors to every MoE layer.

        The position of a block in ``self.moe_layers`` is used as its
        ``moe_layer_idx``.
        """
        for moe_idx, moe_block in enumerate(self.moe_layers):
            fused_experts = moe_block.experts
            assert isinstance(fused_experts, FusedMoE)
            # Register the expert weights.
            self.expert_weights.append(fused_experts.get_expert_weights())
            fused_experts.set_eplb_state(
                moe_layer_idx=moe_idx,
                expert_load_view=expert_load_view,
                logical_to_physical_map=logical_to_physical_map,
                logical_replica_count=logical_replica_count,
            )

    def update_physical_experts_metadata(
        self,
        num_physical_experts: int,
        num_local_physical_experts: int,
    ) -> None:
        """Update physical-expert counts on the model and on every MoE block.

        The local physical expert count must equal the current one; the
        redundant count is recomputed as physical minus logical experts.
        """
        assert self.num_local_physical_experts == num_local_physical_experts
        redundant = num_physical_experts - self.num_logical_experts
        self.num_physical_experts = num_physical_experts
        self.num_local_physical_experts = num_local_physical_experts
        self.num_redundant_experts = redundant
        for moe_block in self.moe_layers:
            assert isinstance(moe_block, FusedMoEBlock)
            moe_block.n_local_physical_experts = num_local_physical_experts
            moe_block.n_physical_experts = num_physical_experts
            moe_block.n_redundant_experts = self.num_redundant_experts
            moe_block.experts.update_expert_map()

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load weights through ``AutoWeightsLoader`` using the class name mapper."""
        return AutoWeightsLoader(self).load_weights(
            weights, mapper=self.hf_to_vllm_mapper
        )
def
get_spec_layer_idx_from_weight_name
(
config
:
ModelConfig
,
weight_name
:
str
)
->
int
|
None
:
if
hasattr
(
config
,
"num_nextn_predict_layers"
)
and
(
config
.
num_nextn_predict_layers
>
0
):
layer_idx
=
config
.
num_hidden_layers
for
i
in
range
(
config
.
num_nextn_predict_layers
):
if
weight_name
.
startswith
(
f
"layers.
{
layer_idx
+
i
}
."
# Step3p5Model
)
or
weight_name
.
startswith
(
f
"model.layers.
{
layer_idx
+
i
}
."
):
# Step3p5MTP
return
layer_idx
+
i
return
None
vllm/model_executor/models/step3p5_mtp.py
0 → 100644
View file @
fc7980db
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
collections.abc
import
Iterable
import
torch
import
torch.nn
as
nn
from
transformers
import
PretrainedConfig
from
vllm.config
import
VllmConfig
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.layernorm
import
GemmaRMSNorm
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
ParallelLMHead
,
VocabParallelEmbedding
,
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.sequence
import
IntermediateTensors
from
.step3p5
import
Step3p5DecoderLayer
,
get_spec_layer_idx_from_weight_name
from
.utils
import
maybe_prefix
logger
=
init_logger
(
__name__
)
class SharedHead(nn.Module):
    """Final norm and LM head owned by one multi-token-prediction layer."""

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: QuantizationConfig | None = None,
    ) -> None:
        """Create the norm and the vocabulary projection.

        Args:
            config: Supplies ``hidden_size``, ``rms_norm_eps`` and
                ``vocab_size``.
            quant_config: Forwarded to ``ParallelLMHead``.
        """
        super().__init__()
        hidden_dim = config.hidden_size
        self.norm = GemmaRMSNorm(hidden_dim, config.rms_norm_eps)
        self.head = ParallelLMHead(
            config.vocab_size, hidden_dim, quant_config=quant_config
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Normalise ``hidden_states``; ``self.head`` is not applied here."""
        normed = self.norm(hidden_states)
        return normed
class Step3p5AMultiTokenPredictorLayer(nn.Module):
    """One MTP layer: fuse token embeddings with previous hidden states, then
    run a regular Step3p5 decoder layer on the result."""

    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str,
    ) -> None:
        """Build the norms, the fusion projection, the head and the decoder block.

        Args:
            vllm_config: Engine configuration.
            prefix: Dotted module path of this layer; the decoder block is
                created under ``"{prefix}.mtp_block"``.
        """
        super().__init__()
        hf_config = vllm_config.model_config.hf_config
        hidden_dim = hf_config.hidden_size
        eps = hf_config.rms_norm_eps

        self.enorm = GemmaRMSNorm(hidden_dim, eps)
        self.hnorm = GemmaRMSNorm(hidden_dim, eps)
        # Maps the concatenated [embedding, hidden] vector back to hidden size.
        self.eh_proj = nn.Linear(hidden_dim * 2, hidden_dim, bias=False)
        self.shared_head = SharedHead(
            config=hf_config, quant_config=vllm_config.quant_config
        )
        self.mtp_block = Step3p5DecoderLayer(
            vllm_config,
            prefix=f"{prefix}.mtp_block",
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        previous_hidden_states: torch.Tensor,
        inputs_embeds: torch.Tensor | None = None,
        spec_step_index: int = 0,
    ) -> torch.Tensor:
        """Compute this MTP layer's hidden states.

        Args:
            input_ids: Not read by this method.
            positions: Forwarded to the decoder block.
            previous_hidden_states: Hidden states to fuse with the embeddings.
            inputs_embeds: Token embeddings; must not be ``None``.
            spec_step_index: Not read by this method.

        Returns:
            Output of ``self.mtp_block``.
        """
        assert inputs_embeds is not None
        normed_embeds = self.enorm(inputs_embeds)
        normed_hidden = self.hnorm(previous_hidden_states)
        fused = self.eh_proj(torch.cat([normed_embeds, normed_hidden], dim=-1))
        return self.mtp_block(positions=positions, hidden_states=fused)
class Step3p5AMultiTokenPredictor(nn.Module):
    """Container for the MTP layers, keyed by their checkpoint layer index."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Create the embedding table, the MTP layers and the logits processor.

        Args:
            vllm_config: Engine configuration.
            prefix: Dotted module path of this module.
        """
        super().__init__()
        hf_config = vllm_config.model_config.hf_config
        self.embed_tokens = VocabParallelEmbedding(
            hf_config.vocab_size,
            hf_config.hidden_size,
        )
        self.mtp_start_layer_idx = hf_config.num_hidden_layers
        self.num_mtp_layers = hf_config.num_nextn_predict_layers

        # Keys are the absolute layer indices so that module names line up
        # with the layer index used in checkpoint weight names.
        first_idx = self.mtp_start_layer_idx
        mtp_layers = {}
        for layer_idx in range(first_idx, first_idx + self.num_mtp_layers):
            mtp_layers[str(layer_idx)] = Step3p5AMultiTokenPredictorLayer(
                vllm_config,
                f"{prefix}.layers.{layer_idx}",
            )
        self.layers = torch.nn.ModuleDict(mtp_layers)

        self.logits_processor = LogitsProcessor(hf_config.vocab_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        previous_hidden_states: torch.Tensor,
        inputs_embeds: torch.Tensor | None = None,
        spec_step_idx: int = 0,
    ) -> torch.Tensor:
        """Run the MTP layer selected by ``spec_step_idx``.

        The step index wraps around modulo the number of MTP layers.
        Embeddings are computed from ``input_ids`` when not supplied.
        """
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)
        step = spec_step_idx % self.num_mtp_layers
        layer_key = str(self.mtp_start_layer_idx + step)
        selected_layer = self.layers[layer_key]
        return selected_layer(
            input_ids,
            positions,
            previous_hidden_states,
            inputs_embeds,
            step,
        )

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        spec_step_idx: int = 0,
    ) -> torch.Tensor:
        """Compute logits using the shared head of the selected MTP layer.

        The layer's ``shared_head`` is called on the hidden states and its
        ``head`` is handed to the logits processor.
        """
        step = spec_step_idx % self.num_mtp_layers
        shared_head = self.layers[str(self.mtp_start_layer_idx + step)].shared_head
        return self.logits_processor(shared_head.head, shared_head(hidden_states))

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings via ``self.embed_tokens``."""
        return self.embed_tokens(input_ids)
class Step3p5MTP(nn.Module):
    """Multi-token-prediction (speculative draft) model for Step3p5."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the inner predictor under ``"{prefix}.model"``."""
        super().__init__()
        self.config = vllm_config.model_config.hf_config
        self.vllm_config = vllm_config
        self.model = Step3p5AMultiTokenPredictor(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings via the inner predictor."""
        return self.model.embed_input_ids(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        spec_step_idx: int = 0,
    ) -> torch.Tensor:
        """Run the inner predictor.

        ``intermediate_tensors`` is accepted but not used; ``hidden_states``
        is passed on as the predictor's previous hidden states.
        """
        hidden_states = self.model(
            input_ids, positions, hidden_states, inputs_embeds, spec_step_idx
        )
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        spec_step_idx: int = 0,
    ) -> torch.Tensor | None:
        """Compute logits with the MTP layer selected by ``spec_step_idx``."""
        return self.model.compute_logits(hidden_states, spec_step_idx)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load the MTP-related tensors of a checkpoint.

        Only weights of speculative layers and the token embedding are
        consumed; everything else in ``weights`` is ignored.

        Args:
            weights: ``(name, tensor)`` pairs from the checkpoint.

        Returns:
            Names of the parameters that were written.

        Raises:
            RuntimeError: If a required parameter of this module received no
                weight. Single-element, non-trainable ``k/v/q/prob_scale``
                parameters are treated as optional.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]

        expert_params_mapping = [
            (".moe.experts.w13_weight", ".moe.gate_proj.weight", "w1"),
            (".moe.experts.w13_weight", ".moe.up_proj.weight", "w3"),
            (".moe.experts.w2_weight", ".moe.down_proj.weight", "w2"),
        ]

        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
            if "embed_tokens" not in name and spec_layer is None:
                continue
            name = self._rewrite_spec_layer_name(spec_layer, name)
            for param_name, weight_name, shard_id in stacked_params_mapping:
                # Skip non-stacked layers and experts (experts handled below).
                if weight_name not in name:
                    continue
                # Expert weights are handled by expert_params_mapping below.
                # They must be skipped BEFORE the name is rewritten:
                # otherwise "...gate_proj" would first become
                # "...gate_up_proj" and could then be rewritten a second
                # time, which breaks loading.
                if ("mlp.experts." in name) and name not in params_dict:
                    continue
                if "experts" in name or "moe" in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                for mapping in expert_params_mapping:
                    param_name, weight_name, shard_id = mapping
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)
                    # Skip loading extra bias for GPTQ models.
                    if (
                        name.endswith(".bias") or name.endswith("_bias")
                    ) and name not in params_dict:
                        continue
                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    # The checkpoint stacks all experts along dim 0.
                    for expert_id in range(loaded_weight.shape[0]):
                        loaded_weight_expert = loaded_weight[expert_id]
                        weight_loader(
                            param,
                            loaded_weight_expert,
                            name,
                            shard_id=shard_id,
                            expert_id=expert_id,
                        )
                    loaded_params.add(name)
                    break
                else:
                    # Skip loading extra bias for GPTQ models.
                    if (
                        name.endswith(".bias")
                        and name not in params_dict
                        or "tok_embeddings" in name
                    ):
                        continue

                    if spec_layer is not None and ".transformer." in name:
                        name = name.replace(".transformer.", ".")
                    if "shared_head" in name:
                        name = name.replace("shared_head.output", "shared_head.head")
                    if "embed_tokens" in name:
                        assert (
                            hasattr(self.config, "num_nextn_predict_layers")
                            and self.config.num_nextn_predict_layers > 0
                        )
                        name = "model.embed_tokens.weight"
                    param = params_dict[name]
                    weight_loader = getattr(
                        param, "weight_loader", default_weight_loader
                    )
                    weight_loader(param, loaded_weight)
            # Reached by every branch that did not `continue` the weight
            # loop, so fused (qkv / gate_up) loads are recorded as well.
            loaded_params.add(name)

        params_need_to_load = set(params_dict.keys())
        # Some KV cache scales are optional: checkpoints may omit them and vLLM
        # will fall back to default scales during initialization.
        optional_params = {
            name
            for name, param in params_dict.items()
            if name.endswith((".k_scale", ".v_scale", ".q_scale", ".prob_scale"))
            and getattr(param, "numel", lambda: 0)() == 1
            and getattr(param, "requires_grad", False) is False
        }
        params_need_to_load -= optional_params

        # Only parameters that are genuinely missing are an error. Comparing
        # the two sets for equality would also trip when the checkpoint DID
        # provide an optional scale, leaving nothing to report.
        missing_params = params_need_to_load - loaded_params
        if missing_params:
            # min() makes the reported example deterministic.
            param_name_example = min(missing_params)
            raise RuntimeError(
                "Some parameters like "
                f"{param_name_example} are not in the checkpoint and will falsely "
                "use random initialization"
            )

        return loaded_params

    def _rewrite_spec_layer_name(self, spec_layer: int | None, name: str) -> str:
        """Map a checkpoint weight name onto this module's parameter layout.

        Weights that belong to the MTP layer itself (``embed_tokens``,
        ``enorm``, ``hnorm``, ``eh_proj``, ``shared_head``) keep their name.
        All other weights are treated as part of the transformer block, so
        ``.mtp_block`` is inserted after ``model.layers.{spec_layer}``.

        Args:
            spec_layer: Speculative layer index of the weight; ``None`` is
                passed for ``embed_tokens`` weights, whose name is returned
                unchanged.
            name: Checkpoint weight name.

        Returns:
            The possibly rewritten name.
        """
        spec_layer_weight_names = [
            "embed_tokens",
            "enorm",
            "hnorm",
            "eh_proj",
            "shared_head",
        ]
        if any(weight_name in name for weight_name in spec_layer_weight_names):
            return name
        # Treat the remaining weights as weights of the transformer block.
        return name.replace(
            f"model.layers.{spec_layer}.", f"model.layers.{spec_layer}.mtp_block."
        )
vllm/reasoning/__init__.py
View file @
fc7980db
...
...
@@ -84,6 +84,10 @@ _REASONING_PARSERS_TO_REGISTER = {
"step3_reasoning_parser"
,
"Step3ReasoningParser"
,
),
"step3p5"
:
(
"step3p5_reasoning_parser"
,
"Step3p5ReasoningParser"
,
),
}
...
...
vllm/reasoning/step3p5_reasoning_parser.py
0 → 100644
View file @
fc7980db
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
collections.abc
import
Sequence
from
vllm.entrypoints.openai.chat_completion.protocol
import
(
ChatCompletionRequest
,
)
from
vllm.entrypoints.openai.engine.protocol
import
DeltaMessage
from
vllm.entrypoints.openai.responses.protocol
import
(
ResponsesRequest
,
)
from
vllm.reasoning.basic_parsers
import
BaseThinkingReasoningParser
from
vllm.tokenizers
import
TokenizerLike
class Step3p5ReasoningParser(BaseThinkingReasoningParser):
    """
    Reasoning parser for the Step3p5 model.

    Step3p5 wraps its reasoning in <think>...</think>, but tends to emit one
    extra newline directly before and/or directly after the </think> token.
    This parser removes:
    - the newline right before </think>
    - the newline right after </think>
    """

    @property
    def start_token(self) -> str:
        """Text that opens a reasoning section."""
        return "<think>"

    @property
    def end_token(self) -> str:
        """Text that closes a reasoning section."""
        return "</think>"

    def __init__(self, tokenizer: TokenizerLike, *args, **kwargs):
        super().__init__(tokenizer, *args, **kwargs)

        # True while a trailing "\n" of the reasoning text is being held
        # back, until it is known whether </think> follows it directly.
        self._pending_reasoning_newline = False

        # Number of "end token seen" checks that still answer "not ended".
        # Delaying the end detection by one round leaves room to strip the
        # newline that may appear immediately after </think>.
        self.end_offset = 1

    def is_reasoning_end(self, input_ids: Sequence[int]) -> bool:
        """Report whether reasoning has ended, with a one-round delay.

        While ``end_offset`` is positive, seeing the end token only consumes
        one unit of the offset and still answers ``False``.
        """
        end_seen = self.end_token_id in input_ids
        if end_seen and self.end_offset > 0:
            self.end_offset -= 1
            return False
        return self.end_offset < 1

    def is_reasoning_end_streaming(
        self, input_ids: Sequence[int], delta_ids: Sequence[int]
    ) -> bool:
        """Streaming counterpart of ``is_reasoning_end``.

        Applies the same delayed check to ``input_ids``; ``delta_ids`` is not
        consulted.
        """
        if self.end_offset > 0 and self.end_token_id in input_ids:
            self.end_offset -= 1
            return False
        return self.end_offset <= 0

    def extract_reasoning(
        self,
        model_output: str,
        request: ChatCompletionRequest | ResponsesRequest,
    ) -> tuple[str | None, str | None]:
        """Split a complete output into ``(reasoning, content)``.

        Uses the base-class split, then drops one trailing newline from the
        reasoning and one leading newline from the content. Parts that end up
        empty are returned as ``None``.
        """
        raw_reasoning, raw_content = super().extract_reasoning(model_output, request)
        reasoning = (
            None if raw_reasoning is None else raw_reasoning.removesuffix("\n")
        )
        content = None if raw_content is None else raw_content.removeprefix("\n")
        return reasoning or None, content or None

    def extract_reasoning_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
    ) -> DeltaMessage | None:
        """Turn one streamed delta into a reasoning/content message.

        Returns ``None`` when the delta yields nothing to emit, for example
        when it consists only of a newline that is being trimmed or held.
        """
        newline = "\n"

        # A delta that directly follows </think> and starts with a newline:
        # drop that newline and emit whatever is left as content.
        if (
            delta_text
            and previous_text.endswith(self.end_token)
            and delta_text.startswith(newline)
        ):
            remaining = delta_text[len(newline) :]
            return DeltaMessage(content=remaining) if remaining else None

        ret = super().extract_reasoning_streaming(
            previous_text,
            current_text,
            delta_text,
            previous_token_ids,
            current_token_ids,
            delta_token_ids,
        )
        if ret is None:
            return None

        # Compatibility path for models that never generate the start token:
        # everything before </think> is reasoning, everything after it is
        # content.
        start_id = self.start_token_id
        start_seen = start_id in previous_token_ids or start_id in delta_token_ids
        if not start_seen:
            if self.end_token_id in delta_token_ids:
                split_at = delta_text.find(self.end_token)
                tail = delta_text[split_at + len(self.end_token) :]
                ret = DeltaMessage(
                    reasoning=delta_text[:split_at], content=tail or None
                )
            elif self.end_token_id in previous_token_ids:
                ret = DeltaMessage(content=delta_text)
            else:
                ret = DeltaMessage(reasoning=delta_text)

        reasoning_to_output = ret.reasoning
        content_to_output = ret.content
        end_in_delta = self.end_token in delta_text

        # Reasoning: handle the newline immediately before </think>.
        if reasoning_to_output is not None:
            if self._pending_reasoning_newline:
                # The held newline was not followed by </think>; release it.
                reasoning_to_output = newline + reasoning_to_output
                self._pending_reasoning_newline = False
            if reasoning_to_output.endswith(newline):
                reasoning_to_output = reasoning_to_output[: -len(newline)]
                # If </think> is part of this delta, the newline sat right
                # before it and is dropped for good; otherwise hold it until
                # the next delta decides.
                self._pending_reasoning_newline = not end_in_delta

        # Content: handle the newline immediately after </think>.
        if content_to_output is not None:
            # Content is flowing, so the delayed end detection need not wait
            # for another round.
            self.end_offset -= 1
            # Content implies reasoning is over; a held newline is discarded.
            self._pending_reasoning_newline = False
            if end_in_delta:
                content_to_output = content_to_output.removeprefix(newline)

        if not reasoning_to_output and not content_to_output:
            return None
        return DeltaMessage(
            reasoning=reasoning_to_output or None,
            content=content_to_output or None,
        )
vllm/tool_parsers/__init__.py
View file @
fc7980db
...
...
@@ -134,6 +134,10 @@ _TOOL_PARSERS_TO_REGISTER = {
"step3_tool_parser"
,
"Step3ToolParser"
,
),
"step3p5"
:
(
"step3p5_tool_parser"
,
"Step3p5ToolParser"
,
),
"xlam"
:
(
"xlam_tool_parser"
,
"xLAMToolParser"
,
...
...
vllm/tool_parsers/step3p5_tool_parser.py
0 → 100644
View file @
fc7980db
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
ast
import
json
from
collections.abc
import
Sequence
from
typing
import
Any
from
xml.parsers.expat
import
ParserCreate
import
regex
as
re
from
vllm.entrypoints.chat_utils
import
make_tool_call_id
from
vllm.entrypoints.openai.chat_completion.protocol
import
(
ChatCompletionRequest
,
ChatCompletionToolsParam
,
)
from
vllm.entrypoints.openai.engine.protocol
import
(
DeltaFunctionCall
,
DeltaMessage
,
DeltaToolCall
,
ExtractedToolCallInformation
,
FunctionCall
,
ToolCall
,
)
from
vllm.logger
import
init_logger
from
vllm.tokenizers
import
TokenizerLike
from
vllm.tool_parsers.abstract_tool_parser
import
(
ToolParser
,
ToolParserManager
,
)
# Module-level logger, created through vLLM's init_logger with this module's name.
logger = init_logger(__name__)
class
StreamingXMLToolCallParser
:
"""
Simplified streaming XML tool call parser
Supports streaming input, parsing, and output
"""
def
__init__
(
self
):
self
.
reset_streaming_state
()
# Tool configuration information
self
.
tools
:
list
[
ChatCompletionToolsParam
]
|
None
=
None
self
.
tool_call_start_token
:
str
=
"<tool_call>"
self
.
tool_call_end_token
:
str
=
"</tool_call>"
self
.
function_start_token
:
str
=
"<function="
self
.
function_end_token
:
str
=
"</function>"
self
.
parameter_start_token
:
str
=
"<parameter="
self
.
parameter_end_token
:
str
=
"</parameter>"
def
reset_streaming_state
(
self
):
"""Reset streaming parsing state"""
self
.
deltas
=
[]
# state for streaming
self
.
tool_call_index
=
0
self
.
current_call_id
=
None
self
.
last_completed_call_id
=
None
self
.
current_function_name
=
None
self
.
current_function_open
=
False
self
.
parameters
=
{}
self
.
current_param_name
=
None
self
.
current_param_value
=
""
self
.
current_param_value_converted
=
""
self
.
current_param_is_first
=
False
self
.
should_emit_end_newline
=
False
self
.
start_quote_emitted
=
False
self
.
streaming_buffer
=
""
self
.
last_processed_pos
=
0
self
.
text_content_buffer
=
""
# state for preprocessing and deferred parsing
self
.
_pre_inside_parameter
=
False
self
.
_pre_param_buffer
=
""
self
.
_pre_current_param_name
=
None
self
.
defer_current_parameter
=
False
self
.
deferred_param_raw_value
=
""
# recreate parser
self
.
parser
=
ParserCreate
()
self
.
setup_parser
()
def parse_single_streaming_chunks(self, xml_chunk: str) -> DeltaMessage:
    """
    Parse one streamed XML chunk and return the delta it produced.

    The chunk is appended to the internal buffer, every complete element
    in the buffer is processed, and the deltas emitted during this call
    are merged into a single response. When the chunk contains a closing
    tag whose end event did not show up in the emitted deltas, the
    matching end events are triggered manually.

    Args:
        xml_chunk: Single XML chunk string

    Returns:
        DeltaMessage: Delta information generated by this chunk, or a
        DeltaMessage with ``content=None`` if nothing was produced
    """
    # Only deltas appended after this point belong to this chunk.
    deltas_before = len(self.deltas)
    self.streaming_buffer += xml_chunk

    if not self._process_complete_xml_elements():
        # Nothing was handed to the XML parser. Text seen before the
        # first tool call is flushed as ordinary content.
        if self.text_content_buffer and self.tool_call_index == 0:
            text_delta = DeltaMessage(content=self.text_content_buffer)
            self._emit_delta(text_delta)
            # Clear buffer to avoid duplicate output
            self.text_content_buffer = ""
            return text_delta

        # The chunk carries end tags the parser did not act on:
        # complete the end events by hand for the call in progress.
        saw_function_end = self.function_end_token in xml_chunk
        saw_tool_call_end = self.tool_call_end_token in xml_chunk
        if self.current_call_id is not None and (
            saw_function_end or saw_tool_call_end
        ):
            if self.current_param_name:
                self._end_element("parameter")
            if saw_function_end and self.current_function_name:
                self._end_element("function")
            if saw_tool_call_end:
                self._end_element("tool_call")
            return self._merge_new_deltas_to_single_response(deltas_before)

        # No complete elements, return empty response
        return DeltaMessage(content=None)

    def closes_function(call) -> bool:
        # A function end shows up as '}' (with parameters) or '{}' (none).
        return bool(
            call.function
            and call.id == self.current_call_id
            and isinstance(call.function.arguments, str)
            and call.function.arguments in ("}", "{}")
        )

    def closes_tool_call(call) -> bool:
        # The tool_call end shows up as a delta with empty arguments.
        return bool(
            call.type == "function"
            and call.function
            and call.function.arguments == ""
            and call.id == self.current_call_id
        )

    def any_call(deltas, predicate) -> bool:
        return any(
            predicate(call)
            for delta in deltas
            for call in (delta.tool_calls or ())
        )

    # Complete elements were processed; check whether end events were missed.
    try:
        # Snapshot taken once: both checks below look at the same deltas,
        # not at anything the first fallback may add.
        produced = self.deltas[deltas_before:]

        # Chunk contains </function> but no closing brace was generated.
        if (
            self.current_call_id is not None
            and self.function_end_token in xml_chunk
            and not any_call(produced, closes_function)
        ):
            if self.current_param_name:
                self._end_element("parameter")
            if self.current_function_name:
                self._end_element("function")

        # Chunk contains </tool_call> but no final empty delta was generated.
        if (
            self.current_call_id is not None
            and self.tool_call_end_token in xml_chunk
            and not any_call(produced, closes_tool_call)
        ):
            if self.current_param_name:
                self._end_element("parameter")
            if self.current_function_name:
                self._end_element("function")
            self._end_element("tool_call")
    except Exception as e:
        logger.warning("Error with fallback parsing: %s", e)

    # Merge newly generated deltas into single response
    return self._merge_new_deltas_to_single_response(deltas_before)
def
_escape_xml_special_chars
(
self
,
text
:
str
)
->
str
:
"""
Escape XML special characters
Args:
text: Original text
Returns:
Escaped text
"""
xml_escapes
=
{
"&"
:
"&"
,
"<"
:
"<"
,
">"
:
">"
,
'"'
:
"""
,
"'"
:
"'"
,
}
for
char
,
escape
in
xml_escapes
.
items
():
text
=
text
.
replace
(
char
,
escape
)
return
text
def _process_complete_xml_elements(self) -> bool:
    """
    Feed every complete element currently in the buffer to the XML parser.

    Elements are taken one at a time starting at ``last_processed_pos``.
    An element that raises while being preprocessed or parsed is logged
    and skipped; the read position advances past it either way.

    Returns:
        bool: Whether at least one element was parsed successfully
    """
    processed_any = False

    while self.last_processed_pos < len(self.streaming_buffer):
        element, end_pos = self._find_next_complete_element(
            self.last_processed_pos
        )
        if element is None:
            # No complete element found, wait for more data
            break

        if not self._should_skip_element(element):
            try:
                prepared = self._preprocess_xml_chunk(element)
                head = prepared.strip()
                opens_tool_call = head.startswith("<tool_call>")

                # The first tool call is starting: flush the text that was
                # collected in front of it.
                if (
                    (opens_tool_call or head.startswith("<function name="))
                    and self.tool_call_index == 0
                    and self.text_content_buffer
                ):
                    text_delta = DeltaMessage(content=self.text_content_buffer)
                    self._emit_delta(text_delta)
                    self.text_content_buffer = ""

                # A further <tool_call> starts while an earlier call still
                # has a function name set: finish the earlier call first.
                if (
                    opens_tool_call
                    and self.tool_call_index > 0
                    and self.current_call_id
                    and self.current_function_name
                ):
                    if self.current_param_name:
                        self._end_element("parameter")
                    if self.current_function_open:
                        self._end_element("function")
                    # Output final tool_call tail delta
                    final_delta = DeltaMessage(
                        role=None,
                        content=None,
                        reasoning_content=None,
                        tool_calls=[
                            DeltaToolCall(
                                index=self.tool_call_index - 1,
                                id=self.current_call_id,
                                type="function",
                                function=DeltaFunctionCall(name=None, arguments=""),
                            )
                        ],
                    )
                    self._emit_delta(final_delta)
                    # Reset XML parser and current call state
                    self._reset_xml_parser_after_tool_call()

                self.parser.Parse(prepared, False)
                processed_any = True
            except Exception as e:
                logger.warning("Error when parsing XML elements: %s", e)

        # Skipped, parsed or failed: move past this element.
        self.last_processed_pos = end_pos

    return processed_any
def
_fix_incomplete_tag_in_chunk
(
self
,
chunk
:
str
)
->
str
:
"""
Fallback: fix incomplete <parameter=xxx or <function=xxx tags
(missing >)
Examples: <parameter=-C: -> <parameter=-C>, <parameter=parameter=-n:
-> <parameter=-n>
Also handles missing = cases: <function xxx> -> <function=xxx>,
<functionxxx> -> <function=xxx>
Only fixes tags that pass validation (parameter exists in tool definition)
"""
# First, handle missing = cases for function tags
chunk
=
self
.
_fix_missing_equals_in_function_tag
(
chunk
)
for
tag_type
in
[
"parameter"
,
"function"
]:
pattern
=
f
"<
{
tag_type
}
="
if
pattern
not
in
chunk
:
continue
start_idx
=
chunk
.
find
(
pattern
)
after_tag
=
chunk
[
start_idx
:]
gt_pos
=
after_tag
.
find
(
">"
)
lt_pos
=
after_tag
.
find
(
"<"
,
len
(
pattern
))
# Skip if already well-formed
if
(
gt_pos
!=
-
1
and
(
lt_pos
==
-
1
or
gt_pos
<
lt_pos
)
and
pattern
in
after_tag
[:
gt_pos
]
):
continue
# Extract tag name (stop at space, newline, or <)
content
=
chunk
[
start_idx
+
len
(
pattern
)
:]
end_pos
=
next
(
(
i
for
i
,
ch
in
enumerate
(
content
)
if
ch
in
(
" "
,
"
\n
"
,
"<"
)),
len
(
content
),
)
tag_name
=
content
[:
end_pos
]
if
not
tag_name
:
continue
# Remove duplicate prefix: <parameter=parameter=xxx -> <parameter=xxx
if
tag_name
.
startswith
(
f
"
{
tag_type
}
="
):
tag_name
=
tag_name
[
len
(
tag_type
)
+
1
:]
# Remove trailing non-alphanumeric chars (keep - and _)
while
tag_name
and
not
(
tag_name
[
-
1
].
isalnum
()
or
tag_name
[
-
1
]
in
(
"-"
,
"_"
)
):
tag_name
=
tag_name
[:
-
1
]
if
not
tag_name
:
continue
# Validate parameter exists in tool definition
if
tag_type
==
"parameter"
and
not
self
.
_validate_parameter_name
(
tag_name
):
continue
# Apply fix
chunk
=
chunk
.
replace
(
f
"<
{
tag_type
}
=
{
content
[:
end_pos
]
}
"
,
f
"<
{
tag_type
}
=
{
tag_name
}
>"
,
1
)
return
chunk
def
_fix_missing_equals_in_function_tag
(
self
,
chunk
:
str
)
->
str
:
"""
Fix missing = in function tags: <function xxx> or <functionxxx>
Examples:
<function execute_bash> -> <function=execute_bash>
<functionexecute_bash> -> <function=execute_bash>
Only fixes if function name exists in tool definition
"""
# already correct
if
"<function="
in
chunk
:
return
chunk
# Pattern 1: <function xxx> (with space/newline but no =)
pattern1
=
r
"<function\s+([a-zA-Z_][a-zA-Z0-9_]*)\s*>"
match1
=
re
.
search
(
pattern1
,
chunk
)
if
match1
:
func_name
=
match1
.
group
(
1
).
strip
()
# must validate function name exists before fixing
if
func_name
and
self
.
_validate_function_name
(
func_name
):
original
=
match1
.
group
(
0
)
fixed
=
f
"<function=
{
func_name
}
>"
chunk
=
chunk
.
replace
(
original
,
fixed
,
1
)
return
chunk
# Pattern 2: <functionxxx> (no space, no =)
# only match <function followed by letters
pattern2
=
r
"<function([a-zA-Z_][a-zA-Z0-9_]*)\s*>"
match2
=
re
.
search
(
pattern2
,
chunk
)
if
match2
:
func_name
=
match2
.
group
(
1
).
strip
()
# must validate function name exists before fixing
if
func_name
and
self
.
_validate_function_name
(
func_name
):
original
=
match2
.
group
(
0
)
fixed
=
f
"<function=
{
func_name
}
>"
chunk
=
chunk
.
replace
(
original
,
fixed
,
1
)
return
chunk
return
chunk
def
_validate_function_name
(
self
,
func_name
:
str
)
->
bool
:
"""Check if function name exists in tool definitions"""
if
not
self
.
tools
:
return
False
for
tool
in
self
.
tools
:
if
(
hasattr
(
tool
,
"type"
)
and
tool
.
type
==
"function"
and
hasattr
(
tool
,
"function"
)
and
hasattr
(
tool
.
function
,
"name"
)
and
tool
.
function
.
name
==
func_name
):
return
True
return
False
def
_validate_parameter_name
(
self
,
param_name
:
str
)
->
bool
:
"""Check if parameter exists in current function's tool definition"""
if
not
self
.
tools
or
not
self
.
current_function_name
:
return
True
for
tool
in
self
.
tools
:
if
(
hasattr
(
tool
,
"type"
)
and
tool
.
type
==
"function"
and
hasattr
(
tool
,
"function"
)
and
hasattr
(
tool
.
function
,
"name"
)
and
tool
.
function
.
name
==
self
.
current_function_name
):
if
not
hasattr
(
tool
.
function
,
"parameters"
):
return
True
params
=
tool
.
function
.
parameters
if
isinstance
(
params
,
dict
):
properties
=
params
.
get
(
"properties"
,
params
)
return
param_name
in
properties
break
return
True
def
_should_skip_element
(
self
,
element
:
str
)
->
bool
:
"""
Determine whether an element should be skipped
Args:
element: Element to evaluate
Returns:
bool: True means should skip, False means should process
"""
# If it's a tool_call XML tag, don't skip
if
(
element
.
startswith
(
self
.
tool_call_start_token
)
or
element
.
startswith
(
self
.
function_start_token
)
or
element
.
startswith
(
self
.
parameter_start_token
)
):
return
False
# If currently not parsing tool calls and not blank,
# collect this text instead of skipping
# Only process other XML elements after tool_call appears,
# otherwise treat as plain text
if
self
.
current_call_id
is
None
and
element
:
# Collect text content to buffer
self
.
text_content_buffer
+=
element
return
True
# Still skip, but content has been collected
# If currently parsing tool calls,
# this might be parameter value, don't skip
if
self
.
current_call_id
is
not
None
:
return
False
# Skip blank content
return
not
element
def
_find_next_complete_element
(
self
,
start_pos
:
int
)
->
tuple
[
str
|
None
,
int
]:
"""
Find next complete XML element from specified position
Args:
start_pos: Position to start searching
Returns:
(Complete element string, element end position),
returns (None, start_pos) if no complete element found
"""
buffer
=
self
.
streaming_buffer
[
start_pos
:]
if
not
buffer
:
return
None
,
start_pos
if
buffer
.
startswith
(
"<"
):
# Check if this is an incomplete parameter/function tag
# e.g., <parameter=-C: or <function=xxx
is_incomplete_param
=
(
buffer
.
startswith
(
"<parameter="
)
and
">"
not
in
buffer
.
split
(
"
\n
"
)[
0
]
)
is_incomplete_func
=
(
buffer
.
startswith
(
"<function="
)
and
">"
not
in
buffer
.
split
(
"
\n
"
)[
0
]
)
if
is_incomplete_param
or
is_incomplete_func
:
# Find the corresponding closing tag
tag_type
=
"parameter"
if
is_incomplete_param
else
"function"
closing_tag
=
f
"</
{
tag_type
}
>"
closing_pos
=
buffer
.
find
(
closing_tag
)
if
closing_pos
!=
-
1
:
# Found closing tag, return complete element including closing tag
complete_element
=
buffer
[:
closing_pos
+
len
(
closing_tag
)]
return
complete_element
,
start_pos
+
closing_pos
+
len
(
closing_tag
)
# Need to ensure no new < appears,
# find the nearest one between < and >
tag_end
=
buffer
.
find
(
"<"
,
1
)
tag_end2
=
buffer
.
find
(
">"
,
1
)
if
tag_end
!=
-
1
and
tag_end2
!=
-
1
:
# Next nearest is <
if
tag_end
<
tag_end2
:
return
buffer
[:
tag_end
],
start_pos
+
tag_end
# Next nearest is >, means found XML element
else
:
return
buffer
[:
tag_end2
+
1
],
start_pos
+
tag_end2
+
1
elif
tag_end
!=
-
1
:
return
buffer
[:
tag_end
],
start_pos
+
tag_end
elif
tag_end2
!=
-
1
:
return
buffer
[:
tag_end2
+
1
],
start_pos
+
tag_end2
+
1
else
:
# If currently not parsing tool calls (entering a tool_call),
# check if starts with <tool_call> or <function=
if
self
.
current_call_id
is
None
:
# Check if might be start of <tool_call>
if
buffer
==
"<tool_call>"
[:
len
(
buffer
)]:
# Might be start of <tool_call>, wait for more data
return
None
,
start_pos
elif
(
buffer
.
startswith
(
"<function="
)
or
buffer
==
"<function="
[:
len
(
buffer
)]
):
# Might be start of <function=, wait for more data
# to get the complete function tag
return
None
,
start_pos
else
:
# Not start of <tool_call> or <function=, treat as text
return
buffer
,
start_pos
+
len
(
buffer
)
else
:
# When parsing tool calls,
# wait for more data to get complete tag
return
None
,
start_pos
else
:
# Find text content (until next < or buffer end)
next_tag_pos
=
buffer
.
find
(
"<"
)
if
next_tag_pos
!=
-
1
:
# Found text content
text_content
=
buffer
[:
next_tag_pos
]
return
text_content
,
start_pos
+
next_tag_pos
else
:
# Buffer end is all text, process
# (no longer wait for more data)
remaining
=
buffer
return
remaining
,
start_pos
+
len
(
remaining
)
def
_merge_new_deltas_to_single_response
(
self
,
initial_count
:
int
)
->
DeltaMessage
:
"""
Merge newly generated deltas from this processing
into a single DeltaMessage
Args:
initial_count: Delta count before processing
Returns:
Merged DeltaMessage containing all newly generated delta information
"""
if
len
(
self
.
deltas
)
<=
initial_count
:
return
DeltaMessage
(
content
=
None
)
# Get newly generated deltas
new_deltas
=
self
.
deltas
[
initial_count
:]
if
len
(
new_deltas
)
==
1
:
# Only one new delta, return directly
return
new_deltas
[
0
]
# Merge multiple new deltas
merged_tool_calls
:
list
[
DeltaToolCall
]
=
[]
merged_content
:
str
=
""
for
delta
in
new_deltas
:
if
delta
.
content
:
merged_content
+=
delta
.
content
if
delta
.
tool_calls
:
# For tool_calls, we need to intelligently merge arguments
for
tool_call
in
delta
.
tool_calls
:
# Find if there's already a tool_call with the same call_id
existing_call
=
None
for
existing
in
merged_tool_calls
:
if
existing
.
id
==
tool_call
.
id
:
existing_call
=
existing
break
if
existing_call
and
existing_call
.
function
:
# Merge to existing tool_call
if
tool_call
.
function
and
tool_call
.
function
.
name
:
existing_call
.
function
.
name
=
tool_call
.
function
.
name
if
(
tool_call
.
function
and
tool_call
.
function
.
arguments
is
not
None
):
if
existing_call
.
function
.
arguments
is
None
:
existing_call
.
function
.
arguments
=
""
# For streaming JSON parameters,
# simply concatenate in order
new_args
=
tool_call
.
function
.
arguments
existing_call
.
function
.
arguments
+=
new_args
if
tool_call
.
type
:
existing_call
.
type
=
tool_call
.
type
else
:
# Add new tool_call
merged_tool_calls
.
append
(
tool_call
)
return
DeltaMessage
(
content
=
merged_content
if
merged_content
else
None
,
tool_calls
=
merged_tool_calls
,
)
def
_preprocess_xml_chunk
(
self
,
chunk
:
str
)
->
str
:
"""
Preprocess XML chunk, handle non-standard formats,
and escape special characters
Args:
chunk: Original XML chunk
Returns:
Processed XML chunk
"""
# Check if this is a tool_call related element
is_tool_call
=
False
if
chunk
.
startswith
(
self
.
tool_call_start_token
)
or
chunk
.
startswith
(
self
.
tool_call_end_token
):
is_tool_call
=
True
# Check for function tags (including malformed ones without =)
# <function=xxx>, </function>, <function xxx>, <functionxxx>
if
(
chunk
.
startswith
(
self
.
function_start_token
)
or
chunk
.
startswith
(
self
.
function_end_token
)
or
chunk
.
startswith
(
"<function "
)
or
re
.
match
(
r
"^<function[a-zA-Z_]"
,
chunk
)
):
# <functionXXX without space or =
is_tool_call
=
True
if
chunk
.
startswith
(
self
.
parameter_start_token
)
or
chunk
.
startswith
(
self
.
parameter_end_token
):
is_tool_call
=
True
# Fallback: fix incomplete <parameter= or <function= tags without
# closing >
# This handles cases like: <parameter=-C:\n or <parameter=-B 5\n
# Apply when parsing tool calls OR when chunk looks like a function/
# parameter tag
if
(
self
.
current_call_id
is
not
None
or
chunk
.
startswith
(
"<function"
)
or
chunk
.
startswith
(
"<parameter"
)
):
chunk
=
self
.
_fix_incomplete_tag_in_chunk
(
chunk
)
# Handle <function=name> format -> <function name="name">
processed
=
re
.
sub
(
r
"<function=([^>]+)>"
,
r
'<function name="\1">'
,
chunk
)
# Handle <parameter=name> format -> <parameter name="name">
processed
=
re
.
sub
(
r
"<parameter=([^>]+)>"
,
r
'<parameter name="\1">'
,
processed
)
original_chunk
=
chunk
# If in parameter value accumulation mode
if
self
.
_pre_inside_parameter
:
# Parameter end: output accumulated raw text
# safely then return </parameter>
if
processed
.
startswith
(
"</parameter>"
):
body_text
=
self
.
_pre_param_buffer
# Trigger deferred parsing mode
# literal_eval+json output in end_element
self
.
defer_current_parameter
=
True
self
.
deferred_param_raw_value
=
body_text
# Clean up state
self
.
_pre_inside_parameter
=
False
self
.
_pre_param_buffer
=
""
self
.
_pre_current_param_name
=
None
safe_text
=
self
.
_escape_xml_special_chars
(
body_text
)
return
f
"
{
safe_text
}
</parameter>"
else
:
# If this is the first block of content after entering parameter
# evaluate if deferred parsing is needed;
# If not needed, exit accumulation mode
# and pass through directly
if
self
.
_pre_param_buffer
==
""
:
# Get current parameter type
param_type
=
(
self
.
_get_param_type
(
self
.
_pre_current_param_name
)
if
self
.
_pre_current_param_name
else
"string"
)
# Only these types need deferred parsing to
# handle Python literals containing single quotes
is_object_type
=
param_type
in
[
"object"
]
is_complex_type
=
(
param_type
in
[
"array"
,
"arr"
,
"sequence"
]
or
param_type
.
startswith
(
"dict"
)
or
param_type
.
startswith
(
"list"
)
)
# Only delay when contains container symbols
# and has single quotes and is complex type
has_container_hint
=
(
(
"["
in
original_chunk
)
or
(
"{"
in
original_chunk
)
or
(
"("
in
original_chunk
)
)
# Determine if deferred parsing is needed
need_defer
=
False
if
is_complex_type
:
# Complex type, always need deferred parsing
need_defer
=
True
elif
(
is_object_type
and
has_container_hint
and
(
"'"
in
original_chunk
)
):
# Object type with container symbols
# and single quotes, need deferred parsing
need_defer
=
True
if
not
need_defer
:
# No need for deferred parsing,
# exit parameter mode directly
self
.
_pre_inside_parameter
=
False
return
self
.
_escape_xml_special_chars
(
original_chunk
)
self
.
_pre_param_buffer
+=
original_chunk
return
""
# Parameter start: enable accumulation
if
processed
.
startswith
(
"<parameter name="
):
m
=
re
.
match
(
r
'<parameter name="([^"]+)">'
,
processed
)
if
m
:
self
.
_pre_current_param_name
=
m
.
group
(
1
)
self
.
_pre_inside_parameter
=
True
self
.
_pre_param_buffer
=
""
return
processed
# If processed doesn't contain special_token, escape processed
# This is because XML parsing encounters special characters
# and reports errors, so escaping is needed
if
not
is_tool_call
:
processed
=
self
.
_escape_xml_special_chars
(
processed
)
return
processed
def
_emit_delta
(
self
,
delta
:
DeltaMessage
):
"""Emit Delta response (streaming output)"""
self
.
deltas
.
append
(
delta
)
def
_auto_close_open_parameter_if_needed
(
self
,
incoming_tag
:
str
|
None
=
None
):
"""Before starting to process new elements,
if there are unclosed tags from before,
automatically complete their endings to the parser.
- If there are unclosed parameters,
it's equivalent to feeding `</parameter>`
- When about to start a new function or tool_call,
if there are unclosed functions, complete `</function>`.
- When about to start a new tool_call,
if there are unclosed tool_calls, complete `</tool_call>`.
"""
# First close unclosed parameters
if
self
.
current_param_name
:
self
.
_end_element
(
"parameter"
)
# If about to start new function or tool_call,
# and there are unclosed functions, close function first
if
incoming_tag
in
(
"function"
,
"tool_call"
)
and
self
.
current_function_name
:
self
.
_end_element
(
"function"
)
# If about to start new tool_call,
# and there are unclosed tool_calls, close tool_call first
if
incoming_tag
==
"tool_call"
and
self
.
current_call_id
:
self
.
_end_element
(
"tool_call"
)
def _start_element(self, name: str, attrs: dict[str, str]):
    """Handle XML start element events.

    - ``tool_call``: closes whatever is still open, assigns a new call id
      and advances the tool-call index.
    - ``function...``: opens a tool call first if none is active, then
      emits a delta carrying the function name.
    - ``parameter...``: emits the JSON key fragment for the parameter
      (``{"name": `` for the first one, ``, "name": `` afterwards). The
      value and its quoting are emitted later, once data arrives.

    Args:
        name: Element name reported by the XML parser.
        attrs: Element attributes reported by the XML parser.
    """
    if name == "root":
        return

    def emit_call_delta(arguments: str, function_name: str | None = None) -> None:
        # Every delta emitted here targets the tool call in progress.
        self._emit_delta(
            DeltaMessage(
                tool_calls=[
                    DeltaToolCall(
                        index=self.tool_call_index - 1,
                        id=self.current_call_id,
                        type="function",
                        function=DeltaFunctionCall(
                            name=function_name, arguments=arguments
                        ),
                    )
                ]
            )
        )

    if name == "tool_call":
        # Before opening new tool_call,
        # automatically complete previous unclosed tags
        self._auto_close_open_parameter_if_needed("tool_call")

        self.parameters = {}
        self.current_call_id = make_tool_call_id()
        self.current_param_is_first = True
        self.tool_call_index += 1
    elif name.startswith("function"):
        # If missing tool_call, manually complete
        if not self.current_call_id:
            self._start_element("tool_call", {})
        # Before opening new function,
        # automatically complete previous unclosed tags (parameter/function)
        self._auto_close_open_parameter_if_needed("function")

        function_name = self._extract_function_name(name, attrs)
        self.current_function_name = function_name
        self.current_function_open = True
        if function_name:
            emit_call_delta("", function_name)
    elif name.startswith("parameter"):
        # If previous parameter hasn't ended normally,
        # complete its end first, then start new parameter
        self._auto_close_open_parameter_if_needed("parameter")

        param_name = self._extract_parameter_name(name, attrs)
        self.current_param_name = param_name
        self.current_param_value = ""
        self.current_param_value_converted = ""
        self.start_quote_emitted = False  # Reset start quote flag

        if param_name:
            # "First" means no parameter of this call has been stored yet.
            is_first = not self.parameters
            # Encode the key as a JSON string so a quote or backslash in
            # the name cannot corrupt the streamed arguments.
            key = json.dumps(param_name, ensure_ascii=False)
            emit_call_delta(f"{{{key}: " if is_first else f", {key}: ")
            self.current_param_is_first = is_first
def
_char_data
(
self
,
data
:
str
):
"""Handle XML character data events"""
if
data
and
self
.
current_param_name
:
# If preprocessing stage determines deferred parsing is needed,
# only cache character data, no streaming output
if
self
.
defer_current_parameter
:
original_data
=
data
if
self
.
should_emit_end_newline
:
original_data
=
"
\n
"
+
original_data
self
.
should_emit_end_newline
=
False
if
original_data
.
endswith
(
"
\n
"
):
self
.
should_emit_end_newline
=
True
original_data
=
original_data
[:
-
1
]
self
.
current_param_value
+=
original_data
return
param_type
=
self
.
_get_param_type
(
self
.
current_param_name
)
# Check if this is the first time receiving data for this parameter
# If this is the first packet of data and starts with \n, remove \n
if
not
self
.
current_param_value
and
data
.
startswith
(
"
\n
"
):
data
=
data
[
1
:]
# Output start quote for string type (if not already output)
if
(
param_type
in
[
"string"
,
"str"
,
"text"
,
"varchar"
,
"char"
,
"enum"
]
and
not
self
.
start_quote_emitted
):
quote_delta
=
DeltaMessage
(
tool_calls
=
[
DeltaToolCall
(
index
=
self
.
tool_call_index
-
1
,
id
=
self
.
current_call_id
,
type
=
"function"
,
function
=
DeltaFunctionCall
(
name
=
None
,
arguments
=
'"'
),
)
]
)
self
.
_emit_delta
(
quote_delta
)
self
.
start_quote_emitted
=
True
if
not
data
:
return
original_data
=
data
# Delay output of trailing newline
if
self
.
should_emit_end_newline
:
original_data
=
"
\n
"
+
original_data
self
.
should_emit_end_newline
=
False
if
original_data
.
endswith
(
"
\n
"
):
self
.
should_emit_end_newline
=
True
original_data
=
original_data
[:
-
1
]
self
.
current_param_value
+=
original_data
# convert parameter value by param_type
converted_value
=
self
.
_convert_param_value
(
self
.
current_param_value
,
param_type
)
output_data
=
self
.
_convert_for_json_streaming
(
converted_value
,
param_type
)
delta_data
=
output_data
[
len
(
self
.
current_param_value_converted
)
:]
self
.
current_param_value_converted
=
output_data
delta
=
DeltaMessage
(
tool_calls
=
[
DeltaToolCall
(
index
=
self
.
tool_call_index
-
1
,
id
=
self
.
current_call_id
,
type
=
"function"
,
function
=
DeltaFunctionCall
(
name
=
None
,
arguments
=
delta_data
),
)
]
)
self
.
_emit_delta
(
delta
)
def _end_element(self, name: str):
    """Handle XML end element events.

    - ``parameter...``: finishes the open parameter. A deferred value is
      parsed as a whole and emitted once; a streamed string value gets
      its closing quote (or ``""`` when nothing was emitted).
    - ``function...``: emits ``}`` when parameters were stored, else ``{}``.
    - ``tool_call``: closes a still-open function, emits the final
      empty-arguments delta, emits any non-blank buffered text, and
      resets the parser for the next call.

    Args:
        name: Element name reported by the XML parser, or passed by the
            fallback paths that close elements manually.
    """
    if name == "root":
        return

    def emit_arguments(arguments: str) -> None:
        # Every delta emitted here targets the tool call in progress.
        self._emit_delta(
            DeltaMessage(
                tool_calls=[
                    DeltaToolCall(
                        index=self.tool_call_index - 1,
                        id=self.current_call_id,
                        type="function",
                        function=DeltaFunctionCall(name=None, arguments=arguments),
                    )
                ]
            )
        )

    def reject_constant(token: str):
        # NaN/Infinity are not valid JSON; refusing them keeps such input
        # on the plain-string fallback.
        raise ValueError(f"unsupported JSON constant: {token}")

    closes_function = name.startswith("function")
    closes_parameter = name.startswith("parameter")

    # If function or tool_call ends and there are still unclosed parameters,
    # complete parameter end first
    if (closes_function or name == "tool_call") and self.current_param_name:
        self._auto_close_open_parameter_if_needed()

    if closes_parameter and self.current_param_name:
        param_name = self.current_param_name
        param_value = self.current_param_value

        # Deferred mode: parse the raw content accumulated during
        # preprocessing as a whole and output it once.
        if self.defer_current_parameter:
            raw_text = self.deferred_param_raw_value or param_value
            # If previously delayed trailing newline, add it back before parsing
            raw_for_parse = (
                raw_text + "\n" if self.should_emit_end_newline else raw_text
            )
            try:
                # Python literal syntax first (handles single quotes).
                parsed_value = ast.literal_eval(raw_for_parse)
                output_arguments = json.dumps(parsed_value, ensure_ascii=False)
            except Exception:
                try:
                    # Second attempt: JSON syntax (true/false/null), which
                    # literal_eval rejects.
                    parsed_value = json.loads(
                        raw_for_parse, parse_constant=reject_constant
                    )
                    output_arguments = json.dumps(parsed_value, ensure_ascii=False)
                except Exception:
                    # Fallback: output as string as-is
                    output_arguments = json.dumps(raw_text, ensure_ascii=False)
                    parsed_value = raw_text

            emit_arguments(output_arguments)

            # Clean up and store
            self.should_emit_end_newline = False
            self.parameters[param_name] = parsed_value
            self.current_param_name = None
            self.current_param_value = ""
            self.current_param_value_converted = ""
            self.start_quote_emitted = False
            self.defer_current_parameter = False
            self.deferred_param_raw_value = ""
            return

        param_type = self._get_param_type(param_name)

        # Convert complete parameter value by param_type
        converted_value = self._convert_param_value(param_value, param_type)

        # Only string-like types are wrapped in quotes.
        if param_type in ["string", "str", "text", "varchar", "char", "enum"]:
            if not param_value and not self.start_quote_emitted:
                # No start quote was output: emit a complete empty string.
                emit_arguments('""')
            else:
                # Start quote already out: emit the end quote.
                emit_arguments('"')

        self.should_emit_end_newline = False
        # Store converted value
        self.parameters[param_name] = converted_value
        self.current_param_name = None
        self.current_param_value = ""
        self.current_param_value_converted = ""
        self.start_quote_emitted = False

    elif closes_function:
        # Close the JSON object, or emit an empty one when no parameter
        # was stored for this call.
        emit_arguments("}" if self.parameters else "{}")
        self.current_function_open = False
        # Clear function name to prevent duplicate closing
        self.current_function_name = None

    elif name == "tool_call":
        # Before ending tool_call,
        # ensure function is closed to complete missing right brace
        if self.current_function_open:
            # If there are still unclosed parameters, close them first
            if self.current_param_name:
                self._end_element("parameter")
            # Close function, ensure output '}' or '{}'
            self._end_element("function")

        # Final Delta
        emit_arguments("")

        # Check if there's text content to output (between tool_calls)
        if self.text_content_buffer.strip():
            self._emit_delta(DeltaMessage(content=self.text_content_buffer))

        self._reset_xml_parser_after_tool_call()
def setup_parser(self):
    """Attach this object's callbacks to the parser held in ``self.parser``.

    Enables ``buffer_text`` on the parser and then installs the
    start-element, end-element and character-data handlers, in that order.
    """
    xml_parser = self.parser
    xml_parser.buffer_text = True

    # (parser hook, name of the bound method that services it)
    wiring = (
        ("StartElementHandler", "_start_element"),
        ("EndElementHandler", "_end_element"),
        ("CharacterDataHandler", "_char_data"),
    )
    for hook_name, method_name in wiring:
        setattr(xml_parser, hook_name, getattr(self, method_name))
def set_tools(self, tools: list[ChatCompletionToolsParam] | None):
    """Set tool configuration information

    Args:
        tools: Tool definitions to keep on ``self.tools``; ``None`` is
            stored unchanged. ``_get_param_type`` reads this attribute
            when looking up the declared type of a parameter.
    """
    self.tools = tools
def
_extract_function_name
(
self
,
name
:
str
,
attrs
:
dict
[
str
,
str
])
->
str
|
None
:
"""Extract function name from various formats"""
if
attrs
and
"name"
in
attrs
:
return
attrs
[
"name"
]
if
"="
in
name
:
parts
=
name
.
split
(
"="
,
1
)
if
len
(
parts
)
==
2
and
parts
[
0
]
==
"function"
:
return
parts
[
1
]
return
None
def
_extract_parameter_name
(
self
,
name
:
str
,
attrs
:
dict
[
str
,
str
])
->
str
|
None
:
"""Extract parameter name from various formats"""
if
attrs
and
"name"
in
attrs
:
return
attrs
[
"name"
]
if
"="
in
name
:
parts
=
name
.
split
(
"="
,
1
)
if
len
(
parts
)
==
2
and
parts
[
0
]
==
"parameter"
:
return
parts
[
1
]
return
None
def
_get_param_type
(
self
,
param_name
:
str
)
->
str
:
"""Get parameter type based on tool configuration, defaults to string
Args:
param_name: Parameter name
Returns:
Parameter type
"""
if
not
self
.
tools
or
not
self
.
current_function_name
:
return
"string"
for
tool
in
self
.
tools
:
if
not
hasattr
(
tool
,
"type"
)
or
not
(
hasattr
(
tool
,
"function"
)
and
hasattr
(
tool
.
function
,
"name"
)
):
continue
if
(
tool
.
type
==
"function"
and
tool
.
function
.
name
==
self
.
current_function_name
):
if
not
hasattr
(
tool
.
function
,
"parameters"
):
return
"string"
params
=
tool
.
function
.
parameters
if
isinstance
(
params
,
dict
)
and
"properties"
in
params
:
properties
=
params
[
"properties"
]
if
param_name
in
properties
and
isinstance
(
properties
[
param_name
],
dict
):
return
self
.
repair_param_type
(
str
(
properties
[
param_name
].
get
(
"type"
,
"string"
))
)
elif
isinstance
(
params
,
dict
)
and
param_name
in
params
:
param_config
=
params
[
param_name
]
if
isinstance
(
param_config
,
dict
):
return
self
.
repair_param_type
(
str
(
param_config
.
get
(
"type"
,
"string"
))
)
break
return
"string"
def repair_param_type(self, param_type: str) -> str:
    """Map unrecognised parameter types to ``"string"``.

    A type is recognised when it equals one of the known string, boolean
    or container names, or when it starts with one of the known integer,
    float or container prefixes. The comparison is case-sensitive and the
    value is not stripped.

    Args:
        param_type: Parameter type
    Returns:
        ``param_type`` unchanged when recognised, otherwise ``"string"``
    """
    exact_names = (
        # string-like
        "string", "str", "text", "varchar", "char", "enum",
        # boolean-like
        "boolean", "bool", "binary",
        # containers
        "object", "array", "arr", "sequence",
    )
    prefixes = (
        # integer-like
        "int", "uint", "long", "short", "unsigned",
        # float-like
        "num", "float",
        # containers
        "dict", "list",
    )
    if param_type in exact_names or param_type.startswith(prefixes):
        return param_type
    return "string"
def
_convert_param_value
(
self
,
param_value
:
str
,
param_type
:
str
)
->
Any
:
"""Convert value based on parameter type
Args:
param_value: Parameter value
param_type: Parameter type
Returns:
Converted value
"""
if
param_value
.
lower
()
==
"null"
:
return
None
param_type
=
param_type
.
strip
().
lower
()
if
param_type
in
[
"string"
,
"str"
,
"text"
,
"varchar"
,
"char"
,
"enum"
]:
return
param_value
elif
(
param_type
.
startswith
(
"int"
)
or
param_type
.
startswith
(
"uint"
)
or
param_type
.
startswith
(
"long"
)
or
param_type
.
startswith
(
"short"
)
or
param_type
.
startswith
(
"unsigned"
)
):
try
:
return
int
(
param_value
)
except
(
ValueError
,
TypeError
):
logger
.
warning
(
"Parsed value '%s' is not an integer, degenerating to string."
,
param_value
,
)
return
param_value
elif
param_type
.
startswith
(
"num"
)
or
param_type
.
startswith
(
"float"
):
try
:
float_param_value
:
float
=
float
(
param_value
)
return
(
float_param_value
if
float_param_value
-
int
(
float_param_value
)
!=
0
else
int
(
float_param_value
)
)
except
(
ValueError
,
TypeError
):
logger
.
warning
(
"Parsed value '%s' is not a float, degenerating to string."
,
param_value
,
)
return
param_value
elif
param_type
in
[
"boolean"
,
"bool"
,
"binary"
]:
param_value
=
param_value
.
lower
()
return
param_value
==
"true"
else
:
return
param_value
def
_convert_for_json_streaming
(
self
,
converted_value
:
Any
,
param_type
:
str
)
->
str
:
"""Convert converted_value based on
whether it's empty and if type is string
Args:
converted_value: Converted value
param_type: Parameter type
Returns:
Converted string for streaming output
"""
# Check if value is empty, but exclude numeric 0
if
converted_value
is
None
or
converted_value
==
""
:
return
""
if
param_type
in
[
"string"
,
"str"
,
"text"
,
"varchar"
,
"char"
,
"enum"
]:
# String type, remove double quotes
return
json
.
dumps
(
converted_value
,
ensure_ascii
=
False
)[
1
:
-
1
]
else
:
# Non-string type, return complete JSON string
if
not
isinstance
(
converted_value
,
str
):
return
json
.
dumps
(
converted_value
,
ensure_ascii
=
False
)
else
:
return
converted_value
def
_reset_xml_parser_after_tool_call
(
self
):
"""
Each tool_call is treated as a separate XML document,
so we need to reset the parser after each tool_call.
"""
# recreate XML parser
self
.
parser
=
ParserCreate
()
self
.
setup_parser
()
# Reset current tool_call state
if
self
.
current_call_id
:
self
.
last_completed_call_id
=
self
.
current_call_id
self
.
current_call_id
=
None
self
.
current_function_name
=
None
self
.
current_function_open
=
False
self
.
parameters
=
{}
self
.
current_param_name
=
None
self
.
current_param_value
=
""
self
.
current_param_value_converted
=
""
self
.
current_param_is_first
=
False
self
.
should_emit_end_newline
=
False
self
.
start_quote_emitted
=
False
self
.
text_content_buffer
=
""
# Reset preprocessing and deferred parsing state
self
.
_pre_inside_parameter
=
False
self
.
_pre_param_buffer
=
""
self
.
_pre_current_param_name
=
None
self
.
defer_current_parameter
=
False
self
.
deferred_param_raw_value
=
""
@ToolParserManager.register_module("step3p5")
class Step3p5ToolParser(ToolParser):
    """Tool-call parser registered under the name ``"step3p5"``.

    The actual parsing is delegated to a ``StreamingXMLToolCallParser``.
    This class adapts its results to the ``ToolParser`` interface and keeps
    the ``prev_tool_call_arr`` / ``streamed_args_for_tool`` bookkeeping
    lists up to date.
    """

    def __init__(self, tokenizer: TokenizerLike):
        super().__init__(tokenizer)
        self.parser = StreamingXMLToolCallParser()

        # Add missing attributes for compatibility with serving_chat.py
        self.prev_tool_call_arr: list[dict] = []
        self.streamed_args_for_tool: list[str] = []

        logger.info(
            "vLLM Successfully import tool parser %s !", self.__class__.__name__
        )

    def _ensure_tracking_slot(self, index: int | None) -> int:
        """Resolve a tool-call index and grow both tracking lists to hold it.

        Args:
            index: Index reported on the tool call, or ``None``
        Returns:
            The index to use. ``None`` resolves to the last existing slot,
            or to ``0`` when no slot exists yet (the unclamped ``-1`` would
            skip the grow loops and then fail on the empty lists).
        """
        if index is None:
            index = max(len(self.prev_tool_call_arr) - 1, 0)

        while len(self.prev_tool_call_arr) <= index:
            self.prev_tool_call_arr.append({"name": "", "arguments": ""})
        while len(self.streamed_args_for_tool) <= index:
            self.streamed_args_for_tool.append("")
        return index

    def extract_tool_calls(
        self,
        model_output: str,
        request: ChatCompletionRequest,
    ) -> ExtractedToolCallInformation:
        """Extract all tool calls from a complete model output.

        Args:
            model_output: Full text produced by the model
            request: Originating request; its ``tools`` are handed to the
                underlying parser when the request is truthy
        Returns:
            The named tool calls found plus the parser's ``content``.
            ``tools_called`` is true only when at least one named call
            was collected.
        """
        self.parser.reset_streaming_state()
        # Reset tool call tracking arrays for new extraction
        self.prev_tool_call_arr = []
        self.streamed_args_for_tool = []

        if request:
            self.parser.set_tools(request.tools)

        result = self.parser.parse_single_streaming_chunks(model_output)
        if not result.tool_calls:
            return ExtractedToolCallInformation(
                tool_calls=[],
                tools_called=False,
                content=result.content,
            )

        tool_calls = []
        for tool_call in result.tool_calls:
            # Entries without a function name are not reported or tracked.
            if not (tool_call.function and tool_call.function.name):
                continue

            tool_calls.append(
                ToolCall(
                    id=tool_call.id,
                    type=tool_call.type,
                    function=FunctionCall(
                        name=tool_call.function.name,
                        arguments=tool_call.function.arguments,
                    ),
                )
            )

            # Update tool call tracking arrays for compatibility
            tool_index = self._ensure_tracking_slot(tool_call.index)
            self.prev_tool_call_arr[tool_index]["name"] = tool_call.function.name
            self.prev_tool_call_arr[tool_index]["arguments"] = (
                tool_call.function.arguments
            )
            if tool_call.function.arguments:
                self.streamed_args_for_tool[tool_index] = (
                    tool_call.function.arguments
                )

        return ExtractedToolCallInformation(
            tool_calls=tool_calls,
            tools_called=len(tool_calls) > 0,
            content=result.content,
        )

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> DeltaMessage | None:
        """Feed one streamed chunk to the parser and return its delta.

        State is reset whenever ``previous_text`` is empty, i.e. at the
        start of a new stream. Only ``previous_text``, ``current_text``,
        ``delta_text``, ``delta_token_ids`` and ``request`` are used.

        Returns:
            The parser's result for ``delta_text``; for a chunk that has
            token ids but no text, either an empty-content ``DeltaMessage``
            or ``None`` (see the inline comment below).
        """
        if not previous_text:
            self.parser.reset_streaming_state()
            # Reset tool call tracking arrays for new streaming session
            self.prev_tool_call_arr = []
            self.streamed_args_for_tool = []
            if request:
                self.parser.set_tools(request.tools)

        # Model sometimes outputs separately causing delta_text to be empty.
        # If there were tool_calls before and all current tool_calls have ended,
        # return an empty tool_call for outer streaming output
        # to correctly output tool_call field
        if not delta_text and delta_token_ids:
            open_calls = current_text.count(
                self.parser.tool_call_start_token
            ) - current_text.count(self.parser.tool_call_end_token)
            all_calls_closed = open_calls == 0 and self.parser.tool_call_index > 0
            plain_text_only = not self.parser.tool_call_index and current_text
            if all_calls_closed or plain_text_only:
                return DeltaMessage(content="")
            return None

        # Parse the delta text and get the result
        result = self.parser.parse_single_streaming_chunks(delta_text)

        # Update tool call tracking arrays based on incremental parsing results
        if result and result.tool_calls:
            for tool_call in result.tool_calls:
                if not tool_call.function:
                    continue

                tool_index = self._ensure_tracking_slot(tool_call.index)

                # Update tool name if provided
                if tool_call.function.name:
                    self.prev_tool_call_arr[tool_index]["name"] = (
                        tool_call.function.name
                    )

                # Update arguments incrementally
                if tool_call.function.arguments is not None:
                    # Concatenate the incremental arguments
                    # to the existing streamed arguments
                    self.prev_tool_call_arr[tool_index]["arguments"] += (
                        tool_call.function.arguments
                    )
                    self.streamed_args_for_tool[tool_index] += (
                        tool_call.function.arguments
                    )
        return result

    def parser_should_check_for_unstreamed_tool_arg_tokens(self) -> bool:
        """
        Skip the remaining_call calculation in serving_chat
        """
        return False
vllm/transformers_utils/config.py
View file @
fc7980db
...
...
@@ -96,6 +96,8 @@ _CONFIG_REGISTRY: dict[str, type[PretrainedConfig]] = LazyConfigDict(
ultravox
=
"UltravoxConfig"
,
step3_vl
=
"Step3VLConfig"
,
step3_text
=
"Step3TextConfig"
,
step3p5
=
"Step3p5Config"
,
qwen3_asr
=
"Qwen3ASRConfig"
,
qwen3_next
=
"Qwen3NextConfig"
,
lfm2_moe
=
"Lfm2MoeConfig"
,
tarsier2
=
"Tarsier2Config"
,
...
...
vllm/transformers_utils/configs/__init__.py
View file @
fc7980db
...
...
@@ -50,6 +50,8 @@ _CLASS_TO_MODULE: dict[str, str] = {
"Step3VLConfig"
:
"vllm.transformers_utils.configs.step3_vl"
,
"Step3VisionEncoderConfig"
:
"vllm.transformers_utils.configs.step3_vl"
,
"Step3TextConfig"
:
"vllm.transformers_utils.configs.step3_vl"
,
"Step3p5Config"
:
"vllm.transformers_utils.configs.step3p5"
,
"Qwen3ASRConfig"
:
"vllm.transformers_utils.configs.qwen3_asr"
,
"Qwen3NextConfig"
:
"vllm.transformers_utils.configs.qwen3_next"
,
"Tarsier2Config"
:
"vllm.transformers_utils.configs.tarsier2"
,
# Special case: DeepseekV3Config is from HuggingFace Transformers
...
...
@@ -90,6 +92,8 @@ __all__ = [
"Step3VLConfig"
,
"Step3VisionEncoderConfig"
,
"Step3TextConfig"
,
"Step3p5Config"
,
"Qwen3ASRConfig"
,
"Qwen3NextConfig"
,
"Tarsier2Config"
,
]
...
...
vllm/transformers_utils/configs/step3p5.py
0 → 100644
View file @
fc7980db
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
typing
import
Any
from
transformers.configuration_utils
import
PretrainedConfig
class Step3p5Config(PretrainedConfig):
    """Configuration for models whose ``model_type`` is ``"step3p5"``.

    Every constructor argument is stored on the instance under the same
    name. Three values are derived:

    * ``num_experts_per_tok`` mirrors ``moe_top_k``.
    * ``share_expert_dim`` falls back to
      ``moe_intermediate_size * moe_top_k`` when not supplied.
    * ``bos_token_id`` / ``eos_token_id`` fall back to ``1`` / ``[2, 3]``
      when not supplied; the resolved values are also forwarded to
      ``PretrainedConfig.__init__`` together with ``**kwargs``.
    """

    model_type = "step3p5"

    def __init__(
        self,
        hidden_size: int = 5120,
        intermediate_size: int = 13312,
        num_attention_heads: int = 40,
        num_attention_groups: int = 8,
        num_hidden_layers: int = 48,
        max_seq_len: int = 4096,
        vocab_size: int = 65536,
        rms_norm_eps: float = 1e-5,
        moe_every_n_layer: int = 2,
        use_moe: bool = False,
        moe_intermediate_size: int = 10240,
        moe_num_experts: int = 16,
        moe_top_k: int = 4,
        moe_layer_offset: int = 0,
        rope_theta: float | list[float] | None = 500000,
        rope_scaling: dict[str, Any] | None = None,
        head_dim: int | None = None,
        share_expert_dim: int | None = None,
        norm_expert_weight: bool = True,
        bos_token_id: list[int] | int | None = None,
        eos_token_id: list[int] | int | None = None,
        moe_router_activation: str = "softmax",
        moe_router_scaling_factor: float = 1.0,
        att_impl_type: str = "GQA",
        use_head_wise_attn_gate: bool = False,
        use_moe_router_bias: bool = True,
        need_fp32_gate: bool = True,
        layer_types: list[str] | None = None,
        use_rope_layers: list[bool] | None = None,
        yarn_only_types: list[str] | None = None,
        attention_other_setting: dict[str, Any] | None = None,
        num_nextn_predict_layers: int = 0,
        swiglu_limits: list[float] | None = None,
        swiglu_limits_shared: list[float] | None = None,
        max_position_embeddings: int | None = None,
        **kwargs,
    ):
        # --- core dimensions ------------------------------------------
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.vocab_size = vocab_size
        self.rms_norm_eps = rms_norm_eps
        self.max_seq_len = max_seq_len
        self.max_position_embeddings = max_position_embeddings

        # --- attention ------------------------------------------------
        self.num_attention_heads = num_attention_heads
        self.num_attention_groups = num_attention_groups
        self.head_dim = head_dim
        self.att_impl_type = att_impl_type
        self.use_head_wise_attn_gate = use_head_wise_attn_gate
        self.layer_types = layer_types
        self.attention_other_setting = attention_other_setting

        # --- rotary embeddings ----------------------------------------
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self.use_rope_layers = use_rope_layers
        self.yarn_only_types = yarn_only_types

        # --- mixture of experts ---------------------------------------
        self.use_moe = use_moe
        self.moe_intermediate_size = moe_intermediate_size
        self.moe_every_n_layer = moe_every_n_layer
        self.moe_num_experts = moe_num_experts
        self.moe_top_k = moe_top_k
        self.num_experts_per_tok = moe_top_k
        self.moe_layer_offset = moe_layer_offset
        self.norm_expert_weight = norm_expert_weight
        self.moe_router_activation = moe_router_activation
        self.moe_router_scaling_factor = moe_router_scaling_factor
        self.use_moe_router_bias = use_moe_router_bias
        self.need_fp32_gate = need_fp32_gate
        if share_expert_dim is None:
            share_expert_dim = self.moe_intermediate_size * self.moe_top_k
        self.share_expert_dim = share_expert_dim

        # --- activation limits and next-token prediction layers -------
        self.swiglu_limits = swiglu_limits
        self.swiglu_limits_shared = swiglu_limits_shared
        self.num_nextn_predict_layers = num_nextn_predict_layers

        # --- special tokens -------------------------------------------
        if bos_token_id is None:
            resolved_bos_token_id = 1
        else:
            resolved_bos_token_id = bos_token_id
        if eos_token_id is None:
            resolved_eos_token_id = [2, 3]
        else:
            resolved_eos_token_id = eos_token_id
        self.bos_token_id = resolved_bos_token_id
        self.eos_token_id = resolved_eos_token_id

        super().__init__(
            bos_token_id=resolved_bos_token_id,
            eos_token_id=resolved_eos_token_id,
            **kwargs,
        )
vllm/v1/attention/backends/flash_attn.py
View file @
fc7980db
...
...
@@ -263,18 +263,6 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
vllm_config
:
"VllmConfig"
,
kv_cache_spec
:
"AttentionSpec"
,
)
->
AttentionCGSupport
:
# FA2 does not support CUDA graphs with encoder-decoder models due to
# accuracy issues reported in https://github.com/vllm-project/vllm/issues/33091
if
(
vllm_config
.
model_config
.
is_encoder_decoder
and
get_flash_attn_version
()
==
2
):
logger
.
warning_once
(
"FlashAttention2 does not support CUDA graphs with "
"encoder-decoder models due to accuracy issues reported in #33091. "
"Disabling CUDA graph."
)
return
AttentionCGSupport
.
NEVER
return
cls
.
_cudagraph_support
def
__init__
(
...
...
vllm/v1/core/kv_cache_coordinator.py
View file @
fc7980db
...
...
@@ -479,6 +479,16 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):
hit_length
=
max_cache_hit_length
hit_blocks_by_group
:
list
[
list
[
KVCacheBlock
]
|
None
]
=
[
None
]
*
num_groups
# Simple hybrid (1 full attn + 1 other): one iteration suffices.
# Full attn is always first if it exists. This avoids EAGLE drops
# being applied multiple times to non-full-attn groups.
# FIXME (yifan): However, for complex hybrid models with multiple attn
# groups, we still have the EAGLE spiral block dropping problem. See
# discussion in issue https://github.com/vllm-project/vllm/issues/32802.
is_simple_hybrid
=
len
(
self
.
attention_groups
)
==
2
and
isinstance
(
self
.
attention_groups
[
0
][
0
],
FullAttentionSpec
)
while
True
:
curr_hit_length
=
hit_length
...
...
@@ -495,10 +505,6 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):
# the last iteration.
num_blocks
=
curr_hit_length
//
spec
.
block_size
curr_hit_length
=
num_blocks
*
spec
.
block_size
for
group_id
in
group_ids
:
blocks
=
hit_blocks_by_group
[
group_id
]
assert
blocks
is
not
None
del
blocks
[
num_blocks
:]
else
:
hit_blocks
=
manager_cls
.
find_longest_cache_hit
(
block_hashes
=
_get_block_hashes
(
spec
),
...
...
@@ -513,10 +519,20 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):
for
group_id
,
blocks
in
zip
(
group_ids
,
hit_blocks
):
hit_blocks_by_group
[
group_id
]
=
blocks
if
curr_hit_length
<
hit_length
:
hit_length
=
curr_hit_length
else
:
if
curr_hit_length
>=
hit_length
:
break
hit_length
=
curr_hit_length
# Simple hybrid: exit after one iteration
if
is_simple_hybrid
:
break
# Truncate full attention blocks to final hit_length (if present)
spec
,
group_ids
,
_
=
self
.
attention_groups
[
0
]
if
isinstance
(
spec
,
FullAttentionSpec
):
num_blocks
=
hit_length
//
spec
.
block_size
for
group_id
in
group_ids
:
if
(
blks
:
=
hit_blocks_by_group
[
group_id
])
is
not
None
:
del
blks
[
num_blocks
:]
return
tuple
(
blocks
if
blocks
is
not
None
else
[]
for
blocks
in
hit_blocks_by_group
...
...
Prev
1
2
3
4
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment