Commit 45a060d6 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.15.1' into v0.15.1-dev

parents 99fc9fc3 1892993b
...@@ -370,6 +370,11 @@ def apply_moe_activation( ...@@ -370,6 +370,11 @@ def apply_moe_activation(
torch.ops._C.gelu_and_mul(output, input) torch.ops._C.gelu_and_mul(output, input)
elif activation == "swigluoai": elif activation == "swigluoai":
torch.ops._C.swigluoai_and_mul(output, input) torch.ops._C.swigluoai_and_mul(output, input)
elif activation == "swiglustep":
from vllm.model_executor.layers.activation import swiglustep_and_mul_triton
swiglustep_and_mul_triton(output, input)
# Activations without gated multiplication # Activations without gated multiplication
elif activation == SILU_NO_MUL: elif activation == SILU_NO_MUL:
output.copy_(F.silu(input)) output.copy_(F.silu(input))
......
...@@ -27,7 +27,6 @@ from vllm.model_executor.layers.fused_moe import ( ...@@ -27,7 +27,6 @@ from vllm.model_executor.layers.fused_moe import (
from vllm.model_executor.layers.fused_moe.config import ( from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig, FusedMoEConfig,
FusedMoEQuantConfig, FusedMoEQuantConfig,
RoutingMethodType,
int4_w4a16_moe_quant_config, int4_w4a16_moe_quant_config,
int4_w4afp8_moe_quant_config, int4_w4afp8_moe_quant_config,
int8_w8a8_moe_quant_config, int8_w8a8_moe_quant_config,
...@@ -1043,17 +1042,9 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): ...@@ -1043,17 +1042,9 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
if self.block_quant: if self.block_quant:
import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe # noqa: E501, F401 import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe # noqa: E501, F401
e_score_correction_bias = (
layer.e_score_correction_bias.to(x.dtype)
if layer.e_score_correction_bias is not None
else None
)
routing_method_type = layer.routing_method_type
return torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8( return torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8(
routing_logits=router_logits.to(torch.float32) routing_logits=router_logits,
if routing_method_type == RoutingMethodType.DeepSeekV3 routing_bias=layer.e_score_correction_bias,
else router_logits,
routing_bias=e_score_correction_bias,
x=x, x=x,
w13_weight=layer.w13_weight, w13_weight=layer.w13_weight,
w13_weight_scale_inv=layer.w13_weight_scale, w13_weight_scale_inv=layer.w13_weight_scale,
...@@ -1067,7 +1058,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): ...@@ -1067,7 +1058,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
expert_offset=layer.ep_rank * layer.local_num_experts, expert_offset=layer.ep_rank * layer.local_num_experts,
local_num_experts=layer.local_num_experts, local_num_experts=layer.local_num_experts,
block_shape=self.weight_block_size, block_shape=self.weight_block_size,
routing_method_type=routing_method_type, routing_method_type=layer.routing_method_type,
routed_scaling=layer.routed_scaling_factor, routed_scaling=layer.routed_scaling_factor,
) )
else: else:
......
...@@ -26,7 +26,6 @@ from vllm.model_executor.layers.fused_moe import ( ...@@ -26,7 +26,6 @@ from vllm.model_executor.layers.fused_moe import (
) )
from vllm.model_executor.layers.fused_moe.config import ( from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig, FusedMoEQuantConfig,
RoutingMethodType,
) )
from vllm.model_executor.layers.fused_moe.layer import UnquantizedFusedMoEMethod from vllm.model_executor.layers.fused_moe.layer import UnquantizedFusedMoEMethod
from vllm.model_executor.layers.fused_moe.oracle.fp8 import ( from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
...@@ -964,17 +963,9 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -964,17 +963,9 @@ class Fp8MoEMethod(FusedMoEMethodBase):
if self.block_quant: if self.block_quant:
import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe # noqa: E501, F401 import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe # noqa: E501, F401
e_score_correction_bias = (
layer.e_score_correction_bias.to(x.dtype)
if layer.e_score_correction_bias is not None
else None
)
routing_method_type = layer.routing_method_type
return torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8( return torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8(
routing_logits=router_logits.to(torch.float32) routing_logits=router_logits,
if routing_method_type == RoutingMethodType.DeepSeekV3 routing_bias=layer.e_score_correction_bias,
else router_logits,
routing_bias=e_score_correction_bias,
x=x, x=x,
w13_weight=layer.w13_weight, w13_weight=layer.w13_weight,
w13_weight_scale_inv=layer.w13_weight_scale_inv, w13_weight_scale_inv=layer.w13_weight_scale_inv,
...@@ -988,7 +979,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -988,7 +979,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
expert_offset=layer.ep_rank * layer.local_num_experts, expert_offset=layer.ep_rank * layer.local_num_experts,
local_num_experts=layer.local_num_experts, local_num_experts=layer.local_num_experts,
block_shape=self.weight_block_size, block_shape=self.weight_block_size,
routing_method_type=routing_method_type, routing_method_type=layer.routing_method_type,
routed_scaling=layer.routed_scaling_factor, routed_scaling=layer.routed_scaling_factor,
) )
else: else:
......
...@@ -28,6 +28,7 @@ def rocm_per_tensor_float_w8a8_scaled_mm_impl( ...@@ -28,6 +28,7 @@ def rocm_per_tensor_float_w8a8_scaled_mm_impl(
A.shape[0] == 1 A.shape[0] == 1
and B.shape[1] % 16 == 0 and B.shape[1] % 16 == 0
and ((bias is None) or (bias.dtype == out_dtype)) and ((bias is None) or (bias.dtype == out_dtype))
and A.is_contiguous()
): ):
output = ops.wvSplitKQ( output = ops.wvSplitKQ(
B.t(), B.t(),
......
...@@ -6,7 +6,6 @@ from typing import TYPE_CHECKING ...@@ -6,7 +6,6 @@ from typing import TYPE_CHECKING
import torch import torch
import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -22,10 +21,6 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( ...@@ -22,10 +21,6 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
swizzle_blockscale, swizzle_blockscale,
) )
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.flashinfer import (
has_flashinfer_cutedsl_grouped_gemm_nt_masked,
has_flashinfer_cutlass_fused_moe,
)
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.model_executor.layers.fused_moe.oracle.nvfp4 import ( from vllm.model_executor.layers.fused_moe.oracle.nvfp4 import (
...@@ -36,8 +31,6 @@ logger = init_logger(__name__) ...@@ -36,8 +31,6 @@ logger = init_logger(__name__)
__all__ = [ __all__ = [
"is_flashinfer_fp4_cutlass_moe_available",
"is_flashinfer_fp4_cutedsl_moe_available",
"reorder_w1w3_to_w3w1", "reorder_w1w3_to_w3w1",
] ]
...@@ -122,26 +115,6 @@ def is_supported_config_trtllm( ...@@ -122,26 +115,6 @@ def is_supported_config_trtllm(
return True, None return True, None
def is_flashinfer_fp4_cutlass_moe_available() -> bool:
"""Return `True` when FlashInfer CUTLASS NV-FP4 kernels can be used."""
return (
envs.VLLM_USE_FLASHINFER_MOE_FP4
and has_flashinfer_cutlass_fused_moe()
and current_platform.is_cuda()
and current_platform.has_device_capability(100)
)
def is_flashinfer_fp4_cutedsl_moe_available() -> bool:
"""Return ``True`` when FlashInfer CUTEDSL NV-FP4 kernels can be used."""
return (
envs.VLLM_USE_FLASHINFER_MOE_FP4
and has_flashinfer_cutedsl_grouped_gemm_nt_masked()
and current_platform.is_cuda()
and current_platform.is_device_capability_family(100)
)
def reorder_w1w3_to_w3w1( def reorder_w1w3_to_w3w1(
weight: torch.Tensor, scale: torch.Tensor, dim: int = -2 weight: torch.Tensor, scale: torch.Tensor, dim: int = -2
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
import vllm.envs as envs
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
is_flashinfer_fp4_cutedsl_moe_available,
is_flashinfer_fp4_cutlass_moe_available,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
is_fp4_marlin_supported,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
cutlass_fp4_supported,
)
__all__ = ["detect_nvfp4_moe_support", "NvFp4Support"]
_logger = init_logger(__name__)
@dataclass(frozen=True)
class NvFp4Support:
    """Immutable record of the NV-FP4 fused-MoE kernel paths that were probed.

    Instances are frozen, so fields cannot be reassigned after construction.
    """

    # Whether the CUTLASS FP4 support probe succeeded.
    cutlass_supported: bool
    # Whether FlashInfer FP4 MoE kernels may be used.
    allow_flashinfer: bool
    # Whether the Marlin FP4 kernel is used as the fallback.
    use_marlin: bool
def detect_nvfp4_moe_support(class_name: str = "") -> NvFp4Support:
    """Probe which NV-FP4 fused-MoE kernel paths the current platform offers.

    Args:
        class_name: Name used in log messages; ``"NVFP4 path"`` is used when
            it is empty.

    Returns:
        An ``NvFp4Support`` holding the CUTLASS, FlashInfer and Marlin results.

    Raises:
        ValueError: If CUTLASS FP4 is unsupported and Marlin FP4 is
            unsupported as well.
    """
    label = class_name or "NVFP4 path"

    cutlass_supported = cutlass_fp4_supported()
    # FlashInfer is only considered on top of CUTLASS FP4 support.
    allow_flashinfer = cutlass_supported and (
        is_flashinfer_fp4_cutlass_moe_available()
        or is_flashinfer_fp4_cutedsl_moe_available()
    )

    if allow_flashinfer:
        _logger.info_once("Using FlashInfer kernels for %s.", label)
    elif envs.VLLM_USE_FLASHINFER_MOE_FP4:
        # The user asked for FlashInfer but it cannot be used here.
        _logger.warning_once(
            "FlashInfer kernels unavailable for %s on current platform.",
            label,
        )

    if cutlass_supported:
        use_marlin = False
    elif is_fp4_marlin_supported():
        use_marlin = True
        _logger.info_once("Falling back to Marlin FP4 MoE kernel.")
    else:
        raise ValueError(
            "Current platform does not support NVFP4 quantization. "
            "Please use Blackwell GPUs or enable FlashInfer."
        )

    return NvFp4Support(
        cutlass_supported=cutlass_supported,
        allow_flashinfer=allow_flashinfer,
        use_marlin=use_marlin,
    )
...@@ -146,6 +146,7 @@ def rocm_unquantized_gemm_impl( ...@@ -146,6 +146,7 @@ def rocm_unquantized_gemm_impl(
and n <= 128 and n <= 128
and k > 512 and k > 512
and math.ceil(k / 512) * math.ceil(m / 16) < get_cu_count() and math.ceil(k / 512) * math.ceil(m / 16) < get_cu_count()
and x.is_contiguous()
) )
# k == 2880 and (m == 640 or m == 128)) # k == 2880 and (m == 640 or m == 128))
) )
...@@ -165,6 +166,7 @@ def rocm_unquantized_gemm_impl( ...@@ -165,6 +166,7 @@ def rocm_unquantized_gemm_impl(
and on_gfx9() and on_gfx9()
and x.dtype in [torch.float16, torch.bfloat16] and x.dtype in [torch.float16, torch.bfloat16]
and k % 8 == 0 and k % 8 == 0
and x.is_contiguous()
) )
if use_skinny is not True: if use_skinny is not True:
......
...@@ -466,6 +466,7 @@ def load_weights_using_from_2_way_softmax( ...@@ -466,6 +466,7 @@ def load_weights_using_from_2_way_softmax(
language_model = _get_language_model_for_seq_cls(model) language_model = _get_language_model_for_seq_cls(model)
is_vlm = language_model is not model is_vlm = language_model is not model
using_vlm_head = is_vlm and hasattr(language_model, "score")
language_model.lm_head = ParallelLMHead( language_model.lm_head = ParallelLMHead(
text_config.vocab_size, text_config.hidden_size, quant_config=quant_config text_config.vocab_size, text_config.hidden_size, quant_config=quant_config
...@@ -506,14 +507,16 @@ def load_weights_using_from_2_way_softmax( ...@@ -506,14 +507,16 @@ def load_weights_using_from_2_way_softmax(
torch.float32 torch.float32
) - lm_head_weight.data[[false_id]].to(torch.float32) ) - lm_head_weight.data[[false_id]].to(torch.float32)
score_layer = language_model.score if is_vlm else model.score score_layer = language_model.score if using_vlm_head else model.score
param = score_layer.weight param = score_layer.weight
weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, score_weight) weight_loader(param, score_weight)
del language_model.lm_head del language_model.lm_head
score_weight_name = "language_model.score.weight" if is_vlm else "score.weight" score_weight_name = (
"language_model.score.weight" if using_vlm_head else "score.weight"
)
loaded_weights.add(score_weight_name) loaded_weights.add(score_weight_name)
lm_head_name = "lm_head.weight" lm_head_name = "lm_head.weight"
...@@ -537,6 +540,7 @@ def load_weights_no_post_processing(model, weights: Iterable[tuple[str, torch.Te ...@@ -537,6 +540,7 @@ def load_weights_no_post_processing(model, weights: Iterable[tuple[str, torch.Te
language_model = _get_language_model_for_seq_cls(model) language_model = _get_language_model_for_seq_cls(model)
is_vlm = language_model is not model is_vlm = language_model is not model
using_vlm_head = is_vlm and hasattr(language_model, "score")
language_model.lm_head = ParallelLMHead( language_model.lm_head = ParallelLMHead(
text_config.vocab_size, text_config.hidden_size, quant_config=quant_config text_config.vocab_size, text_config.hidden_size, quant_config=quant_config
...@@ -572,14 +576,16 @@ def load_weights_no_post_processing(model, weights: Iterable[tuple[str, torch.Te ...@@ -572,14 +576,16 @@ def load_weights_no_post_processing(model, weights: Iterable[tuple[str, torch.Te
token_ids = [tokenizer.convert_tokens_to_ids(t) for t in tokens] token_ids = [tokenizer.convert_tokens_to_ids(t) for t in tokens]
score_weight = language_model.lm_head.weight.data[token_ids] score_weight = language_model.lm_head.weight.data[token_ids]
score_layer = language_model.score if is_vlm else model.score score_layer = language_model.score if using_vlm_head else model.score
param = score_layer.weight param = score_layer.weight
weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, score_weight) weight_loader(param, score_weight)
del language_model.lm_head del language_model.lm_head
score_weight_name = "language_model.score.weight" if is_vlm else "score.weight" score_weight_name = (
"language_model.score.weight" if using_vlm_head else "score.weight"
)
loaded_weights.add(score_weight_name) loaded_weights.add(score_weight_name)
lm_head_name = "lm_head.weight" lm_head_name = "lm_head.weight"
......
...@@ -107,6 +107,7 @@ class MiniMaxM2MoE(nn.Module): ...@@ -107,6 +107,7 @@ class MiniMaxM2MoE(nn.Module):
renormalize=True, renormalize=True,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.experts", prefix=f"{prefix}.experts",
router_logits_dtype=torch.float32,
) )
self.gate = ReplicatedLinear( self.gate = ReplicatedLinear(
......
...@@ -11,7 +11,6 @@ import math ...@@ -11,7 +11,6 @@ import math
from collections.abc import Iterable, Mapping, Sequence from collections.abc import Iterable, Mapping, Sequence
from typing import Annotated, Literal from typing import Annotated, Literal
import cv2
import numpy as np import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -416,6 +415,8 @@ class NemotronParseImageProcessor: ...@@ -416,6 +415,8 @@ class NemotronParseImageProcessor:
else: else:
self.target_height = self.target_width = int(self.final_size) self.target_height = self.target_width = int(self.final_size)
import cv2
self.transform = A.Compose( self.transform = A.Compose(
[ [
A.PadIfNeeded( A.PadIfNeeded(
...@@ -457,6 +458,8 @@ class NemotronParseImageProcessor: ...@@ -457,6 +458,8 @@ class NemotronParseImageProcessor:
new_height = int(new_width / aspect_ratio) new_height = int(new_width / aspect_ratio)
# Use cv2.INTER_LINEAR like the original # Use cv2.INTER_LINEAR like the original
import cv2
return cv2.resize( return cv2.resize(
image, (new_width, new_height), interpolation=cv2.INTER_LINEAR image, (new_width, new_height), interpolation=cv2.INTER_LINEAR
) )
......
...@@ -189,6 +189,7 @@ _TEXT_GENERATION_MODELS = { ...@@ -189,6 +189,7 @@ _TEXT_GENERATION_MODELS = {
"SeedOssForCausalLM": ("seed_oss", "SeedOssForCausalLM"), "SeedOssForCausalLM": ("seed_oss", "SeedOssForCausalLM"),
"Step1ForCausalLM": ("step1", "Step1ForCausalLM"), "Step1ForCausalLM": ("step1", "Step1ForCausalLM"),
"Step3TextForCausalLM": ("step3_text", "Step3TextForCausalLM"), "Step3TextForCausalLM": ("step3_text", "Step3TextForCausalLM"),
"Step3p5ForCausalLM": ("step3p5", "Step3p5ForCausalLM"),
"StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"), "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
"StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"), "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
"Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"), "Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"),
...@@ -479,6 +480,7 @@ _SPECULATIVE_DECODING_MODELS = { ...@@ -479,6 +480,7 @@ _SPECULATIVE_DECODING_MODELS = {
"MedusaModel": ("medusa", "Medusa"), "MedusaModel": ("medusa", "Medusa"),
"OpenPanguMTPModel": ("openpangu_mtp", "OpenPanguMTP"), "OpenPanguMTPModel": ("openpangu_mtp", "OpenPanguMTP"),
"Qwen3NextMTP": ("qwen3_next_mtp", "Qwen3NextMTP"), "Qwen3NextMTP": ("qwen3_next_mtp", "Qwen3NextMTP"),
"Step3p5MTP": ("step3p5_mtp", "Step3p5MTP"),
# Temporarily disabled. # Temporarily disabled.
# # TODO(woosuk): Re-enable this once the MLP Speculator is supported in V1. # # TODO(woosuk): Re-enable this once the MLP Speculator is supported in V1.
# "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"), # "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Inference-only Step3p5 model."""
from collections.abc import Iterable
from typing import Any
import torch
from torch import nn
from torch.nn.parameter import Parameter
from vllm.attention.layer import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.distributed import (
get_dp_group,
get_ep_group,
get_pp_group,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
get_tp_group,
)
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul, SwigluStepAndMul
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.fused_moe.shared_fused_moe import SharedFusedMoE
from vllm.model_executor.layers.layernorm import GemmaRMSNorm
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
MergedColumnParallelLinear,
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear,
)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE,
ParallelLMHead,
VocabParallelEmbedding,
)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.sequence import IntermediateTensors
from vllm.v1.attention.backend import AttentionType
from .interfaces import MixtureOfExperts, SupportsPP
from .utils import (
AutoWeightsLoader,
PPMissingLayer,
WeightsMapper,
extract_layer_index,
is_pp_missing_parameter,
make_empty_intermediate_tensors_factory,
make_layers,
maybe_prefix,
)
logger = init_logger(__name__)
class FP32ReplicatedLinear(ReplicatedLinear):
    """ReplicatedLinear whose input is cast to float32 before the matmul.

    The layer must be constructed with ``params_dtype=torch.float32``;
    ``forward`` asserts this on every call.
    """

    def forward(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
        """Cast ``x`` to float32 and delegate to ``ReplicatedLinear.forward``."""
        assert self.params_dtype == torch.float32
        x_fp32 = x.to(torch.float32)
        return super().forward(x_fp32)
class Step3p5MLP(nn.Module):
    """Gated MLP: merged gate/up projection, activation, then down projection.

    The activation is ``SiluAndMul`` unless ``config.swiglu_limits_shared``
    holds a non-zero limit for this layer, in which case
    ``SwigluStepAndMul(limit=...)`` is used.

    Args:
        config: Object exposing ``swiglu_limits_shared``. Callers in this
            file pass the HF config (``vllm_config.model_config.hf_config``).
        hidden_size: Input and output feature size.
        intermediate_size: Size of each of the gate and up projections.
        hidden_act: Activation name; must be ``"silu"``.
        quant_config: Forwarded to both linear layers.
        reduce_results: Forwarded to the down projection.
        prefix: Module name prefix; the layer index is extracted from it.

    Raises:
        ValueError: If ``hidden_act`` is not ``"silu"``.
    """

    def __init__(
        self,
        config: ModelConfig,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: QuantizationConfig | None = None,
        reduce_results: bool = True,
        prefix: str = "",
    ) -> None:
        super().__init__()
        # Validate before building any layers so a bad config fails fast.
        if hidden_act != "silu":
            raise ValueError(
                f"Unsupported activation: {hidden_act}. Only silu is supported for now."
            )

        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            [intermediate_size] * 2,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            reduce_results=reduce_results,
            prefix=f"{prefix}.down_proj",
        )
        self.act_fn = SiluAndMul()
        self.prefix = prefix
        self.hidden_size = hidden_size
        self.limit = None

        layer_idx = extract_layer_index(prefix)
        # Bounds-check the per-layer lookup (same guard FusedMoEBlock applies
        # to ``swiglu_limits``) so a list shorter than the layer index falls
        # back to plain SiLU instead of raising IndexError.
        shared_limits = config.swiglu_limits_shared or []
        limit = shared_limits[layer_idx] if layer_idx < len(shared_limits) else None
        if limit is not None and limit != 0:
            self.limit = limit
            self.act_fn = SwigluStepAndMul(limit=self.limit)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply gate/up projection, the gated activation and down projection."""
        gate_up, _ = self.gate_up_proj(hidden_states)
        intermediate_act = self.act_fn(gate_up)
        output, _ = self.down_proj(intermediate_act)
        return output
class Step3p5Attention(nn.Module):
    """Attention layer for Step3p5 with per-head q/k norm and optional head gate.

    Args:
        hidden_size: Model hidden size.
        num_heads: Total number of query heads (before tensor parallelism).
        num_kv_heads: Total number of key/value heads.
        max_position: Maximum position passed to the rotary embedding.
        head_dim: Per-head size; defaults to ``hidden_size // total heads``.
        rms_norm_eps: Epsilon for the q/k ``GemmaRMSNorm`` layers.
        qkv_bias: Whether the fused QKV projection has a bias.
        rope_theta: RoPE base, or a list indexed by layer.
        cache_config: Forwarded to ``Attention``.
        quant_config: Forwarded to the QKV/output projections and ``Attention``.
        rope_scaling: Extra RoPE parameters; must be a dict when given.
        prefix: Module name prefix; the layer index is extracted from it.
        attn_type: Forwarded to ``Attention``.
        sliding_window: Window size used when this layer is a sliding layer.
        use_head_wise_attn_gate: Add a per-head sigmoid gate on the output.
        layer_types: Per-layer type names; ``"sliding_attention"`` enables the
            sliding window. When empty, even layer indices use it.
        use_rope_layers: Per-layer flags selecting whether RoPE is applied.
        yarn_only_types: Layer types that keep ``rope_scaling``; other types
            drop it. Only applied when ``layer_types`` is provided.
        swa_num_attention_heads: Head count override for sliding layers.
        partial_rotary_factor: Must be ``1`` or ``0.5``.

    Raises:
        ValueError: If ``rope_scaling`` is given but is not a dict.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        max_position: int = 4096 * 32,
        head_dim: int | None = None,
        rms_norm_eps: float = 1e-06,
        qkv_bias: bool = False,
        rope_theta: float | list[float] | None = 10000,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        rope_scaling: dict[str, Any] | None = None,
        prefix: str = "",
        attn_type: str = AttentionType.DECODER,
        # Step3p5 specific args
        sliding_window: int | None = None,
        use_head_wise_attn_gate: bool = False,
        layer_types: list[str] | None = None,
        use_rope_layers: list | None = None,
        yarn_only_types: list[str] | None = None,
        swa_num_attention_heads: int | None = None,
        partial_rotary_factor: float = 1.0,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.total_num_heads = num_heads
        tp_size = get_tensor_model_parallel_world_size()
        self.layer_idx = extract_layer_index(prefix)

        if layer_types:
            layer_type = layer_types[self.layer_idx]
            enable_sliding_window = layer_type == "sliding_attention"
        else:
            # No per-layer types: even layer indices use the sliding window.
            layer_type = None
            enable_sliding_window = self.layer_idx % 2 == 0

        # The yarn filter needs the layer's type; without ``layer_types`` the
        # lookup would fail, so the filter is skipped and rope_scaling kept.
        if yarn_only_types and layer_types and layer_type not in yarn_only_types:
            rope_scaling = None

        if sliding_window is not None and enable_sliding_window:
            # Sliding layers may use a different number of query heads.
            if swa_num_attention_heads is not None:
                self.total_num_heads = swa_num_attention_heads
        else:
            sliding_window = None

        if isinstance(rope_theta, list):
            rope_theta = rope_theta[self.layer_idx]

        self.rank = get_tensor_model_parallel_rank()
        self.partial_rotary_factor = partial_rotary_factor
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = head_dim or hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=qkv_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        if rope_scaling is not None and not isinstance(rope_scaling, dict):
            raise ValueError("rope_scaling must be a dict for Step3p5Attention.")

        # Copy so the caller's dict is not mutated by the updates below.
        rope_parameters: dict[str, Any] = (
            dict(rope_scaling) if rope_scaling is not None else {}
        )
        rope_parameters.setdefault("rope_type", "default")
        rope_parameters["rope_theta"] = self.rope_theta
        rope_parameters["partial_rotary_factor"] = partial_rotary_factor

        self.rotary_emb = get_rope(
            head_size=self.head_dim,
            max_position=max_position,
            rope_parameters=rope_parameters,
        )

        self.q_norm = GemmaRMSNorm(self.head_dim, rms_norm_eps)
        self.k_norm = GemmaRMSNorm(self.head_dim, rms_norm_eps)

        self.use_head_wise_attn_gate = use_head_wise_attn_gate
        if use_head_wise_attn_gate:
            # One gate output per query head.
            self.g_proj = ColumnParallelLinear(
                hidden_size,
                self.total_num_heads,
                bias=False,
                prefix=f"{prefix}.g_proj",
            )

        self.use_rope = True
        if use_rope_layers:
            self.use_rope = use_rope_layers[self.layer_idx]

        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
            per_layer_sliding_window=sliding_window,
            attn_type=attn_type,
        )

        self.max_position_embeddings = max_position
        assert self.partial_rotary_factor == 1 or self.partial_rotary_factor == 0.5
        self.rotary_dim = (
            self.head_dim if self.partial_rotary_factor == 1 else self.head_dim // 2
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Run QKV projection, per-head q/k norm, optional RoPE and attention.

        Args:
            positions: Positions passed to the rotary embedding.
            hidden_states: Layer input; its last dimension is projected to QKV.

        Returns:
            The output of the output projection.
        """
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)

        # Add qk-norm inline similar to Qwen3 MOE attention
        # Normalize each head separately: reshape to (..., heads, head_dim),
        # normalize, then restore the flattened layout.
        q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim, self.head_dim)
        q_by_head = self.q_norm(q_by_head.contiguous())
        q = q_by_head.view(q.shape)
        k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim, self.head_dim)
        k_by_head = self.k_norm(k_by_head.contiguous())
        k = k_by_head.view(k.shape)

        if self.use_rope:
            q, k = self.rotary_emb(positions, q, k)

        attn_output = self.attn(q, k, v)

        if self.use_head_wise_attn_gate:
            # Scale every head's output by the sigmoid of its gate value.
            extra_dims, _ = self.g_proj(hidden_states)
            output = (
                attn_output.view(*attn_output.shape[:-1], self.num_heads, self.head_dim)
                * extra_dims.unsqueeze(-1).sigmoid()
            )
            attn_output = output.view(*attn_output.shape)

        output, _ = self.o_proj(attn_output)
        return output
class FusedMoEBlock(nn.Module):
    """Step3p5 MoE block: FP32 router gate, routed experts and a shared expert.

    The routed experts and the shared expert are wrapped together in a
    ``SharedFusedMoE``; ``forward`` adds the shared output to the routed
    output and all-reduces the sum when tensor parallelism is enabled.

    Args:
        vllm_config: Source of the HF config, quantization config and
            parallel config.
        prefix: Module name prefix; the layer index is extracted from it.

    Raises:
        ValueError: If the tensor-parallel size exceeds the expert count.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
    ):
        super().__init__()
        self.tp_size = get_tensor_model_parallel_world_size()
        self.layer_idx = extract_layer_index(prefix)

        ep_device_group = get_ep_group().device_group
        self.ep_size = ep_device_group.size()
        self.ep_rank = ep_device_group.rank()

        hf_config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        parallel_config = vllm_config.parallel_config

        self.hidden_size = hf_config.hidden_size
        self.enable_eplb = parallel_config.enable_eplb

        # Expert bookkeeping: logical experts plus redundant copies, split
        # evenly over the expert-parallel ranks.
        self.n_routed_experts = hf_config.moe_num_experts
        self.n_logical_experts = self.n_routed_experts
        self.n_redundant_experts = parallel_config.eplb_config.num_redundant_experts
        self.n_physical_experts = self.n_logical_experts + self.n_redundant_experts
        self.n_local_physical_experts = self.n_physical_experts // self.ep_size
        self.physical_expert_start = self.ep_rank * self.n_local_physical_experts
        self.physical_expert_end = (
            self.physical_expert_start + self.n_local_physical_experts
        )

        if self.tp_size > hf_config.moe_num_experts:
            raise ValueError(
                f"Tensor parallel size {self.tp_size} is greater than "
                f"the number of experts {hf_config.moe_num_experts}."
            )

        self.gate = FP32ReplicatedLinear(
            hf_config.hidden_size,
            hf_config.moe_num_experts,
            bias=False,
            quant_config=None,
            params_dtype=torch.float32,  # Use FP32 for higher precision.
            prefix=f"{prefix}.gate",
        )

        self.use_moe_router_bias = hf_config.use_moe_router_bias
        assert self.use_moe_router_bias, "Only support use_moe_router_bias is true."
        self.routed_scaling_factor = hf_config.moe_router_scaling_factor
        self.router_bias = nn.Parameter(
            torch.zeros(hf_config.moe_num_experts, dtype=torch.float32),
            requires_grad=False,
        )

        self.need_fp32_gate = hf_config.need_fp32_gate
        assert self.need_fp32_gate, (
            "Router logits must use FP32 precision for numerical stability."
        )

        # Pick the expert activation: "silu" unless this layer has a
        # non-zero swiglu limit configured.
        activation = "silu"
        limits_per_layer = hf_config.swiglu_limits or []
        if self.layer_idx < len(limits_per_layer):
            swiglu_limit = limits_per_layer[self.layer_idx]
        else:
            swiglu_limit = None

        if swiglu_limit not in (None, 0):
            swiglu_limit = float(swiglu_limit)
            assert swiglu_limit == 7.0, (
                "Swiglu limit in fused moe block only suport 7.0 now."
            )
            activation = "swiglustep"
            # NOTE(review): confirm this debug log is meant to fire only for
            # swiglustep layers rather than for every layer.
            logger.debug(
                "step3p5 layer_idx: %s, activation: %s, limit: %s",
                self.layer_idx,
                activation,
                swiglu_limit,
            )

        self.share_expert = Step3p5MLP(
            config=hf_config,
            hidden_size=self.hidden_size,
            intermediate_size=hf_config.share_expert_dim,
            hidden_act="silu",
            reduce_results=False,
            quant_config=quant_config,
            prefix=f"{prefix}.share_expert",
        )
        self.experts = SharedFusedMoE(
            shared_experts=self.share_expert,
            gate=self.gate,
            num_experts=hf_config.moe_num_experts,
            top_k=hf_config.moe_top_k,
            hidden_size=hf_config.hidden_size,
            intermediate_size=hf_config.moe_intermediate_size,
            reduce_results=False,
            renormalize=hf_config.norm_expert_weight,
            quant_config=quant_config,
            activation=activation,
            prefix=f"{prefix}.experts",
            scoring_func=getattr(hf_config, "moe_router_activation", "sigmoid"),
            e_score_correction_bias=self.router_bias,
            routed_scaling_factor=hf_config.moe_router_scaling_factor,
            enable_eplb=self.enable_eplb,
            num_redundant_experts=self.n_redundant_experts,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Route tokens through the experts and add the shared-expert output.

        Args:
            hidden_states: Two-dimensional input ``(num_tokens, hidden_dim)``.

        Returns:
            A tensor viewed as ``(num_tokens, hidden_dim)``.
        """
        num_tokens, hidden_dim = hidden_states.shape
        flat_states = hidden_states.view(-1, hidden_dim)

        if self.experts.is_internal_router:
            # In this case, the gate/router runs inside the FusedMoE class,
            # so the hidden states are handed over in place of the logits.
            router_input = flat_states
        else:
            # router_logits: (num_tokens, n_experts)
            router_input, _ = self.gate(flat_states)

        shared_output, final_hidden_states = self.experts(
            hidden_states=flat_states, router_logits=router_input
        )

        if self.share_expert is None:
            assert shared_output is None
        else:
            assert shared_output is not None
            final_hidden_states += shared_output

        if self.tp_size > 1:
            final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel(
                final_hidden_states
            )

        return final_hidden_states.view(num_tokens, hidden_dim)
class Step3p5DecoderLayer(nn.Module):
def __init__(
self,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
super().__init__()
config = vllm_config.model_config.hf_config
self.hidden_size = config.hidden_size
layer_idx = extract_layer_index(prefix)
self.layer_idx = layer_idx
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
if cache_config is not None:
cache_config.sliding_window = None
if config.att_impl_type == "GQA":
num_attention_heads = None
num_attention_groups = None
head_dim = None
if (
getattr(config, "attention_other_setting", None)
and getattr(config, "layer_types", [])
and config.layer_types[layer_idx]
== config.attention_other_setting["attention_type"]
):
num_attention_heads = config.attention_other_setting[
"num_attention_heads"
]
num_attention_groups = config.attention_other_setting[
"num_attention_groups"
]
head_dim = config.attention_other_setting["head_dim"]
partial_rotary_factors = getattr(config, "partial_rotary_factors", [])
self.self_attn = Step3p5Attention(
hidden_size=self.hidden_size,
num_heads=num_attention_heads
if num_attention_heads
else config.num_attention_heads,
max_position=config.max_position_embeddings,
num_kv_heads=num_attention_groups
if num_attention_groups
else config.num_attention_groups,
rope_theta=config.rope_theta,
rms_norm_eps=config.rms_norm_eps,
qkv_bias=getattr(config, "attention_bias", False),
head_dim=head_dim if head_dim else getattr(config, "head_dim", None),
cache_config=cache_config,
quant_config=quant_config,
rope_scaling=getattr(config, "rope_scaling", None),
sliding_window=getattr(config, "sliding_window", None),
use_head_wise_attn_gate=getattr(
config, "use_head_wise_attn_gate", False
),
layer_types=getattr(config, "layer_types", []),
use_rope_layers=getattr(config, "use_rope_layers", []),
yarn_only_types=getattr(config, "yarn_only_types", []),
partial_rotary_factor=partial_rotary_factors[layer_idx]
if partial_rotary_factors
else 1.0,
prefix=f"{prefix}.self_attn",
)
else:
raise ValueError(
f"Unsupported attention implementation: {config.att_impl_type}"
)
self.use_moe = False
self.tp_group = get_tp_group()
self.use_fused_all_reduce = (
get_tensor_model_parallel_world_size() > 1
and get_dp_group().world_size == 1
)
if self.use_fused_all_reduce:
logger.warning_once("Enable custom fused all reduce...")
else:
logger.warning_once("Disable custom fused all reduce...")
moe_layers_enum = getattr(config, "moe_layers_enum", None)
if moe_layers_enum is not None:
moe_layers_idx = [int(i) for i in moe_layers_enum.strip().split(",")]
else:
moe_layers_idx = [i for i in range(1, config.num_hidden_layers)]
if layer_idx in moe_layers_idx:
self.moe = FusedMoEBlock(
vllm_config,
prefix=f"{prefix}.moe",
)
self.use_moe = True
else:
self.mlp = Step3p5MLP(
config=config,
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act="silu",
quant_config=quant_config,
reduce_results=True,
prefix=f"{prefix}.mlp",
)
self.input_layernorm = GemmaRMSNorm(config.hidden_size, config.rms_norm_eps)
self.post_attention_layernorm = GemmaRMSNorm(
config.hidden_size, config.rms_norm_eps
)
self.prefix = prefix
def add_and_maybe_inplace_all_reduce(
    self, in1: torch.Tensor, in2: torch.Tensor
) -> torch.Tensor:
    """Add two tensors, all-reducing the sum when fused all-reduce is enabled.

    Args:
        in1: First addend.
        in2: Second addend.

    Returns:
        ``in1 + in2`` when ``self.use_fused_all_reduce`` is false; otherwise
        whatever ``self.tp_group.all_reduce`` returns for that sum.
    """
    summed = in1 + in2
    if self.use_fused_all_reduce:
        return self.tp_group.all_reduce(summed)
    return summed
def forward(
    self, positions: torch.Tensor, hidden_states: torch.Tensor
) -> torch.Tensor:
    """Run one decoder layer: attention sub-block, then feed-forward sub-block.

    Each sub-block normalizes its input, applies the module and adds the
    un-normalized input back as a residual.

    Args:
        positions: Passed through to ``self.self_attn`` as ``positions``.
        hidden_states: Input activations of this layer.

    Returns:
        The activations after both residual sub-blocks.
    """
    # Attention sub-block. The residual is accumulated in place on the
    # attention output.
    attn_out = self.self_attn(
        positions=positions,
        hidden_states=self.input_layernorm(hidden_states),
    )
    attn_out += hidden_states

    # Feed-forward sub-block: MoE when this layer was built with one,
    # otherwise the dense MLP.
    ffn_block = self.moe if self.use_moe else self.mlp
    ffn_output = ffn_block(self.post_attention_layernorm(attn_out))
    return ffn_output + attn_out
@support_torch_compile
class Step3p5Model(nn.Module):
    """Step3p5 decoder stack: token embedding, decoder layers and final norm.

    Modules that a pipeline-parallel rank does not own are replaced by
    ``PPMissingLayer`` placeholders.
    """

    def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None:
        super().__init__()
        self.vllm_config = vllm_config
        config = vllm_config.model_config.hf_config
        self.vocab_size = config.vocab_size
        self.config = config
        self.moe_num_experts = config.moe_num_experts

        # The embedding is built on the first PP rank, and also on the last
        # rank when the word embeddings are tied.
        if get_pp_group().is_first_rank or (
            config.tie_word_embeddings and get_pp_group().is_last_rank
        ):
            self.embed_tokens = VocabParallelEmbedding(
                self.vocab_size,
                config.hidden_size,
            )
        else:
            self.embed_tokens = PPMissingLayer()

        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: Step3p5DecoderLayer(
                vllm_config,
                prefix=prefix,
            ),
            prefix=f"{prefix}.layers",
        )

        # Only the last PP rank owns the final norm.
        if get_pp_group().is_last_rank:
            self.norm = GemmaRMSNorm(config.hidden_size, config.rms_norm_eps)
        else:
            self.norm = PPMissingLayer()

        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states"], config.hidden_size
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up the embeddings of ``input_ids`` via ``self.embed_tokens``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the decoder layers in ``[start_layer, end_layer)``.

        On the first PP rank the input is ``inputs_embeds`` when given,
        otherwise the embedding of ``input_ids``; on later ranks it is
        ``intermediate_tensors["hidden_states"]``. The final norm is NOT
        applied here.

        Returns:
            The hidden states on the last PP rank; on every other rank an
            ``IntermediateTensors`` wrapping them under ``"hidden_states"``.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is None:
                hidden_states = self.embed_input_ids(input_ids)
            else:
                hidden_states = inputs_embeds
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]

        for layer_idx in range(self.start_layer, self.end_layer):
            hidden_states = self.layers[layer_idx](positions, hidden_states)

        if get_pp_group().is_last_rank:
            return hidden_states
        return IntermediateTensors({"hidden_states": hidden_states})

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this module's parameters.

        Each weight is tried, in order, as (1) a shard of a stacked
        parameter (``qkv_proj`` / ``gate_up_proj``), (2) a stacked MoE expert
        tensor whose leading dimension indexes the experts, and (3) a plain
        parameter. Weights that belong to speculative-decoding layers, to
        layers at or beyond ``config.num_hidden_layers``, to other pipeline
        ranks, or that have no matching parameter are skipped.

        Args:
            weights: ``(name, tensor)`` pairs; names may or may not carry a
                leading ``"model."``.

        Returns:
            Names (relative to this module) of the parameters that were loaded.
        """
        config = self.config
        assert config.num_attention_groups > 1, "Only support GQA"

        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        expert_params_mapping = [
            # (param_name, checkpoint_name, shard_id)
            (".moe.experts.w13_weight", ".moe.gate_proj.weight", "w1"),
            (".moe.experts.w13_weight", ".moe.up_proj.weight", "w3"),
            (".moe.experts.w2_weight", ".moe.down_proj.weight", "w2"),
        ]
        # Checkpoint names handled by the expert mapping; they must not be
        # consumed by the stacked mapping even though they contain
        # "gate_proj" / "up_proj".
        moe_weight_names = [entry[1] for entry in expert_params_mapping]

        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()

        for name, loaded_weight in weights:
            # local_name is relative to this module, full_name always carries
            # the "model." prefix.
            if name.startswith("model."):
                local_name = name[len("model.") :]
                full_name = name
            else:
                local_name = name
                full_name = f"model.{name}" if name else "model"

            if get_spec_layer_idx_from_weight_name(config, full_name) is not None:
                continue  # skip spec decode layers for main model

            # Skip any layers beyond the main model's depth (e.g., MTP layers)
            if full_name.startswith("model.layers."):
                parts = full_name.split(".")
                if (
                    len(parts) > 2
                    and parts[2].isdigit()
                    and int(parts[2]) >= config.num_hidden_layers
                ):
                    continue

            is_moe_expert_weight = any(
                moe_name in local_name for moe_name in moe_weight_names
            )

            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in local_name or is_moe_expert_weight:
                    continue
                replaced_name = local_name.replace(weight_name, param_name)
                if is_pp_missing_parameter(replaced_name, self):
                    continue
                if replaced_name not in params_dict:
                    continue
                param = params_dict[replaced_name]
                param.weight_loader(param, loaded_weight, shard_id)
                loaded_params.add(replaced_name)
                break
            else:
                for param_name, weight_name, shard_id in expert_params_mapping:
                    if weight_name not in local_name:
                        continue
                    replaced_name = local_name.replace(weight_name, param_name)
                    if is_pp_missing_parameter(replaced_name, self):
                        continue
                    if replaced_name not in params_dict:
                        continue
                    param = params_dict[replaced_name]
                    weight_loader = param.weight_loader
                    # The checkpoint stores all experts in one tensor; feed
                    # them to the loader one expert at a time.
                    num_experts = self.moe_num_experts
                    assert loaded_weight.shape[0] == num_experts
                    for expert_id in range(num_experts):
                        weight_loader(
                            param,
                            loaded_weight[expert_id],
                            replaced_name,
                            shard_id=shard_id,
                            expert_id=expert_id,
                        )
                    loaded_params.add(replaced_name)
                    break
                else:
                    # Plain, un-stacked parameter.
                    if is_pp_missing_parameter(local_name, self):
                        continue
                    if "expert_bias" in local_name:
                        logger.warning_once("ignore expert_bias")
                        continue
                    if local_name not in params_dict:
                        continue
                    param = params_dict[local_name]
                    weight_loader = getattr(
                        param, "weight_loader", default_weight_loader
                    )
                    weight_loader(param, loaded_weight)
                    loaded_params.add(local_name)
        return loaded_params
class Step3p5ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts):
    """Causal-LM wrapper around ``Step3p5Model`` with LM head and EPLB hooks."""

    # Checkpoint names use ".share_expert."; the module tree nests it under
    # ".moe.".
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_substr={".share_expert.": ".moe.share_expert."}
    )

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
    ):
        super().__init__()
        config = vllm_config.model_config.hf_config
        lora_config = vllm_config.lora_config
        self.config = config
        self.vllm_config = vllm_config
        self.model = Step3p5Model(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
        )

        # Collect the MoE blocks of the decoder layers present on this rank.
        self.moe_layers: list[FusedMoEBlock] = []
        for decoder_layer in self.model.layers:
            if isinstance(decoder_layer, PPMissingLayer):
                continue
            assert isinstance(decoder_layer, Step3p5DecoderLayer)
            moe_block = getattr(decoder_layer, "moe", None)
            if isinstance(moe_block, FusedMoEBlock):
                self.moe_layers.append(moe_block)

        if get_pp_group().is_last_rank:
            # LoRA may extend the vocabulary and uses its own padding size.
            if lora_config:
                unpadded_vocab_size = (
                    config.vocab_size + lora_config.lora_extra_vocab_size
                )
                padding_size = lora_config.lora_vocab_padding_size
            else:
                unpadded_vocab_size = config.vocab_size
                padding_size = DEFAULT_VOCAB_PADDING_SIZE
            self.unpadded_vocab_size = unpadded_vocab_size
            self.lm_head = ParallelLMHead(
                self.unpadded_vocab_size,
                config.hidden_size,
                org_num_embeddings=config.vocab_size,
                padding_size=padding_size,
            )
            self.logits_processor = LogitsProcessor(
                self.unpadded_vocab_size, config.vocab_size
            )
        else:
            self.lm_head = PPMissingLayer()

        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors
        )

        # Set MoE hyperparameters, taking the first MoE block as representative.
        self.expert_weights = []
        assert len(self.moe_layers) > 0, "No MoE layers found in the model."
        reference_block = self.moe_layers[0]
        self.num_moe_layers = len(self.moe_layers)
        self.num_expert_groups = 1
        self.num_shared_experts = 0
        self.num_logical_experts = reference_block.n_logical_experts
        self.num_physical_experts = reference_block.n_physical_experts
        self.num_local_physical_experts = reference_block.n_local_physical_experts
        self.num_routed_experts = reference_block.n_routed_experts
        self.num_redundant_experts = reference_block.n_redundant_experts

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ):
        """Delegate to ``self.model`` and return its output unchanged."""
        return self.model(input_ids, positions, intermediate_tensors, inputs_embeds)

    def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply the model's final norm, then project through the LM head."""
        normed_states = self.model.norm(hidden_states)
        return self.logits_processor(self.lm_head, normed_states)

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Embed ``input_ids`` with the model's token embedding."""
        return self.model.embed_tokens(input_ids)

    def set_eplb_state(
        self,
        expert_load_view: torch.Tensor,
        logical_to_physical_map: torch.Tensor,
        logical_replica_count: torch.Tensor,
    ) -> None:
        """Forward the EPLB state to every MoE block's experts.

        Each block's expert weights are also appended to
        ``self.expert_weights``; the block's position in ``self.moe_layers``
        is passed as its ``moe_layer_idx``.
        """
        for moe_layer_idx, moe_block in enumerate(self.moe_layers):
            experts = moe_block.experts
            assert isinstance(experts, FusedMoE)
            # Register the expert weights.
            self.expert_weights.append(experts.get_expert_weights())
            experts.set_eplb_state(
                moe_layer_idx=moe_layer_idx,
                expert_load_view=expert_load_view,
                logical_to_physical_map=logical_to_physical_map,
                logical_replica_count=logical_replica_count,
            )

    def update_physical_experts_metadata(
        self,
        num_physical_experts: int,
        num_local_physical_experts: int,
    ) -> None:
        """Record new physical-expert counts here and on every MoE block.

        The local physical expert count must equal the current one (asserted).
        The redundant-expert count is recomputed as physical minus logical,
        and each block's expert map is refreshed.
        """
        assert self.num_local_physical_experts == num_local_physical_experts
        num_redundant_experts = num_physical_experts - self.num_logical_experts
        self.num_physical_experts = num_physical_experts
        self.num_local_physical_experts = num_local_physical_experts
        self.num_redundant_experts = num_redundant_experts
        for moe_block in self.moe_layers:
            assert isinstance(moe_block, FusedMoEBlock)
            moe_block.n_local_physical_experts = num_local_physical_experts
            moe_block.n_physical_experts = num_physical_experts
            moe_block.n_redundant_experts = self.num_redundant_experts
            moe_block.experts.update_expert_map()

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load ``weights`` through ``AutoWeightsLoader`` with the name mapper."""
        return AutoWeightsLoader(self).load_weights(
            weights, mapper=self.hf_to_vllm_mapper
        )
def get_spec_layer_idx_from_weight_name(
    config: ModelConfig, weight_name: str
) -> int | None:
    """Return the speculative (next-n) layer index a weight name refers to.

    Speculative layers are numbered directly after the main layers, i.e.
    ``num_hidden_layers`` up to ``num_hidden_layers +
    num_nextn_predict_layers - 1``.

    Args:
        config: Object exposing ``num_hidden_layers`` and, optionally,
            ``num_nextn_predict_layers``.
        weight_name: Weight name starting with either ``layers.<i>.`` or
            ``model.layers.<i>.``.

    Returns:
        The matching layer index, or ``None`` when the config declares no
        next-n layers or the name does not start with one of their prefixes.
    """
    num_spec_layers = getattr(config, "num_nextn_predict_layers", 0)
    if not num_spec_layers > 0:
        return None

    first_spec_layer = config.num_hidden_layers
    for spec_idx in range(first_spec_layer, first_spec_layer + num_spec_layers):
        # Both the bare (Step3p5Model) and the "model."-prefixed (Step3p5MTP)
        # spellings are accepted.
        if weight_name.startswith(
            (f"layers.{spec_idx}.", f"model.layers.{spec_idx}.")
        ):
            return spec_idx
    return None
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable
import torch
import torch.nn as nn
from transformers import PretrainedConfig
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import GemmaRMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.sequence import IntermediateTensors
from .step3p5 import Step3p5DecoderLayer, get_spec_layer_idx_from_weight_name
from .utils import maybe_prefix
logger = init_logger(__name__)
class SharedHead(nn.Module):
    """Final norm plus LM head used by a multi-token-prediction layer.

    Calling the module only applies the norm; the ``head`` projection is
    exposed as an attribute for the caller to use.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: QuantizationConfig | None = None,
    ) -> None:
        super().__init__()
        hidden_size = config.hidden_size
        self.norm = GemmaRMSNorm(hidden_size, config.rms_norm_eps)
        self.head = ParallelLMHead(
            config.vocab_size, hidden_size, quant_config=quant_config
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return ``hidden_states`` after ``self.norm`` (the head is not applied)."""
        normed_states = self.norm(hidden_states)
        return normed_states
class Step3p5AMultiTokenPredictorLayer(nn.Module):
    """One multi-token-prediction layer.

    Normalizes the token embeddings and the previous hidden states
    separately, concatenates them on the last dimension, projects the result
    back to ``hidden_size`` and runs it through a ``Step3p5DecoderLayer``.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str,
    ) -> None:
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        hidden_size = config.hidden_size

        self.enorm = GemmaRMSNorm(hidden_size, config.rms_norm_eps)
        self.hnorm = GemmaRMSNorm(hidden_size, config.rms_norm_eps)
        # Fuses [embedding ; previous hidden state] back to hidden_size.
        self.eh_proj = nn.Linear(hidden_size * 2, hidden_size, bias=False)
        self.shared_head = SharedHead(config=config, quant_config=quant_config)
        self.mtp_block = Step3p5DecoderLayer(
            vllm_config,
            prefix=f"{prefix}.mtp_block",
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        previous_hidden_states: torch.Tensor,
        inputs_embeds: torch.Tensor | None = None,
        spec_step_index: int = 0,
    ) -> torch.Tensor:
        """Compute this layer's hidden states.

        Args:
            input_ids: Not read by this method.
            positions: Forwarded to the decoder block.
            previous_hidden_states: Hidden states to fuse with the embeddings.
            inputs_embeds: Token embeddings; required (asserted non-None).
            spec_step_index: Not read by this method.

        Returns:
            Output of ``self.mtp_block``.
        """
        assert inputs_embeds is not None
        normed_embeds = self.enorm(inputs_embeds)
        normed_previous = self.hnorm(previous_hidden_states)
        fused = torch.cat([normed_embeds, normed_previous], dim=-1)
        projected = self.eh_proj(fused)
        return self.mtp_block(positions=positions, hidden_states=projected)
class Step3p5AMultiTokenPredictor(nn.Module):
    """Container for the multi-token-prediction layers and their embedding.

    Layers are stored in a ``ModuleDict`` keyed by their absolute layer index
    (as a string), starting at ``config.num_hidden_layers``, so that keys line
    up with the layer indices used in weight names.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )
        self.mtp_start_layer_idx = config.num_hidden_layers
        self.num_mtp_layers = config.num_nextn_predict_layers

        # to map the exact layer index from weights
        first_idx = self.mtp_start_layer_idx
        mtp_layers = {
            str(layer_idx): Step3p5AMultiTokenPredictorLayer(
                vllm_config,
                f"{prefix}.layers.{layer_idx}",
            )
            for layer_idx in range(first_idx, first_idx + self.num_mtp_layers)
        }
        self.layers = torch.nn.ModuleDict(mtp_layers)
        self.logits_processor = LogitsProcessor(config.vocab_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        previous_hidden_states: torch.Tensor,
        inputs_embeds: torch.Tensor | None = None,
        spec_step_idx: int = 0,
    ) -> torch.Tensor:
        """Run the MTP layer selected by ``spec_step_idx``.

        The step index wraps modulo the number of MTP layers. When
        ``inputs_embeds`` is ``None`` the embeddings are computed from
        ``input_ids``.
        """
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)
        step = spec_step_idx % self.num_mtp_layers
        mtp_layer = self.layers[str(self.mtp_start_layer_idx + step)]
        return mtp_layer(
            input_ids,
            positions,
            previous_hidden_states,
            inputs_embeds,
            step,
        )

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        spec_step_idx: int = 0,
    ) -> torch.Tensor:
        """Compute logits with the shared head of the selected MTP layer.

        The hidden states go through the layer's ``shared_head`` (its norm)
        before being handed, with the head projection, to the logits
        processor.
        """
        step = spec_step_idx % self.num_mtp_layers
        shared_head = self.layers[str(self.mtp_start_layer_idx + step)].shared_head
        return self.logits_processor(shared_head.head, shared_head(hidden_states))

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Embed ``input_ids`` with ``self.embed_tokens``."""
        return self.embed_tokens(input_ids)
class Step3p5MTP(nn.Module):
    """Multi-token-prediction (speculative) head model for Step3p5."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        self.config = vllm_config.model_config.hf_config
        self.vllm_config = vllm_config
        self.model = Step3p5AMultiTokenPredictor(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Embed ``input_ids`` with the predictor's embedding."""
        return self.model.embed_input_ids(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        spec_step_idx: int = 0,
    ) -> torch.Tensor:
        """Run the predictor; ``intermediate_tensors`` is accepted but unused."""
        hidden_states = self.model(
            input_ids, positions, hidden_states, inputs_embeds, spec_step_idx
        )
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        spec_step_idx: int = 0,
    ) -> torch.Tensor | None:
        """Compute logits for the MTP layer selected by ``spec_step_idx``."""
        return self.model.compute_logits(hidden_states, spec_step_idx)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load the MTP-layer and embedding weights from a checkpoint.

        Only weights whose name maps to a speculative layer (per
        ``get_spec_layer_idx_from_weight_name``) or contains
        ``"embed_tokens"`` are considered. Each one is tried as a stacked
        shard, then as a stacked MoE expert tensor, then as a plain parameter.

        Args:
            weights: ``(name, tensor)`` pairs from the checkpoint.

        Returns:
            Names of the parameters that were loaded.

        Raises:
            RuntimeError: If a parameter other than an optional single-element
                ``.k_scale`` / ``.v_scale`` / ``.q_scale`` / ``.prob_scale``
                was not loaded.
            KeyError: If a considered weight maps to a name that is not a
                parameter of this module.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        expert_params_mapping = [
            # (param_name, checkpoint_name, shard_id)
            (".moe.experts.w13_weight", ".moe.gate_proj.weight", "w1"),
            (".moe.experts.w13_weight", ".moe.up_proj.weight", "w3"),
            (".moe.experts.w2_weight", ".moe.down_proj.weight", "w2"),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()

        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
            if "embed_tokens" not in name and spec_layer is None:
                continue
            name = self._rewrite_spec_layer_name(spec_layer, name)

            # Expert / MoE weights are handled by expert_params_mapping below,
            # so the stacked mapping must leave them alone.
            is_expert_weight = "experts" in name or "moe" in name

            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name or is_expert_weight:
                    continue
                # Build the candidate separately so that a rejected candidate
                # does not alter `name` for the remaining mappings.
                mapped_name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if mapped_name.endswith(".bias") and mapped_name not in params_dict:
                    continue
                name = mapped_name
                param = params_dict[name]
                param.weight_loader(param, loaded_weight, shard_id)
                break
            else:
                for param_name, weight_name, shard_id in expert_params_mapping:
                    if weight_name not in name:
                        continue
                    mapped_name = name.replace(weight_name, param_name)
                    # Skip loading extra bias for GPTQ models.
                    if (
                        mapped_name.endswith(".bias") or mapped_name.endswith("_bias")
                    ) and mapped_name not in params_dict:
                        continue
                    name = mapped_name
                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    # The leading dimension of the tensor indexes the experts.
                    for expert_id in range(loaded_weight.shape[0]):
                        weight_loader(
                            param,
                            loaded_weight[expert_id],
                            name,
                            shard_id=shard_id,
                            expert_id=expert_id,
                        )
                    break
                else:
                    # Skip loading extra bias for GPTQ models.
                    if (
                        name.endswith(".bias") and name not in params_dict
                    ) or "tok_embeddings" in name:
                        continue
                    if spec_layer is not None and ".transformer." in name:
                        name = name.replace(".transformer.", ".")
                    if "shared_head" in name:
                        name = name.replace("shared_head.output", "shared_head.head")
                    if "embed_tokens" in name:
                        assert (
                            hasattr(self.config, "num_nextn_predict_layers")
                            and self.config.num_nextn_predict_layers > 0
                        )
                        name = "model.embed_tokens.weight"
                    param = params_dict[name]
                    weight_loader = getattr(
                        param, "weight_loader", default_weight_loader
                    )
                    weight_loader(param, loaded_weight)
            loaded_params.add(name)

        # Some KV cache scales are optional: checkpoints may omit them and vLLM
        # will fall back to default scales during initialization.
        optional_params = {
            param_name
            for param_name, param in params_dict.items()
            if param_name.endswith((".k_scale", ".v_scale", ".q_scale", ".prob_scale"))
            and getattr(param, "numel", lambda: 0)() == 1
            and getattr(param, "requires_grad", False) is False
        }
        # Only genuinely missing parameters are an error. A checkpoint that
        # does provide optional scales makes loaded_params a superset of the
        # required set, which must not be reported as a failure.
        missing_params = set(params_dict) - optional_params - loaded_params
        if missing_params:
            param_name_example = min(missing_params)
            raise RuntimeError(
                "Some parameters like "
                f"{param_name_example} are not in the checkpoint and will falsely "
                "use random initialization"
            )
        return loaded_params

    def _rewrite_spec_layer_name(self, spec_layer: int | None, name: str) -> str:
        """
        Rewrite the weight name to match the format of the original model.
        Add .mtp_block for modules in transformer layer block for spec layer

        Names containing one of the MTP-specific module names (embedding,
        norms, projection, shared head) are returned unchanged.
        """
        spec_layer_weight_names = (
            "embed_tokens",
            "enorm",
            "hnorm",
            "eh_proj",
            "shared_head",
        )
        if any(weight_name in name for weight_name in spec_layer_weight_names):
            return name
        # treat rest weights as weights for transformer layer block
        return name.replace(
            f"model.layers.{spec_layer}.", f"model.layers.{spec_layer}.mtp_block."
        )
...@@ -85,6 +85,10 @@ _REASONING_PARSERS_TO_REGISTER = { ...@@ -85,6 +85,10 @@ _REASONING_PARSERS_TO_REGISTER = {
"step3_reasoning_parser", "step3_reasoning_parser",
"Step3ReasoningParser", "Step3ReasoningParser",
), ),
"step3p5": (
"step3p5_reasoning_parser",
"Step3p5ReasoningParser",
),
} }
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Sequence
from vllm.entrypoints.openai.chat_completion.protocol import (
ChatCompletionRequest,
)
from vllm.entrypoints.openai.engine.protocol import DeltaMessage
from vllm.entrypoints.openai.responses.protocol import (
ResponsesRequest,
)
from vllm.reasoning.basic_parsers import BaseThinkingReasoningParser
from vllm.tokenizers import TokenizerLike
class Step3p5ReasoningParser(BaseThinkingReasoningParser):
    """
    Reasoning parser for the Step3p5 model.

    Step3p5 wraps its reasoning in <think>...</think>, but it tends to emit an
    extra newline immediately before and/or after the </think> token. This
    parser strips exactly those newlines:
    - one newline directly before </think>
    - one newline directly after </think>
    """

    @property
    def start_token(self) -> str:
        return "<think>"

    @property
    def end_token(self) -> str:
        return "</think>"

    def __init__(self, tokenizer: TokenizerLike, *args, **kwargs):
        super().__init__(tokenizer, *args, **kwargs)
        # True while a trailing "\n" of the reasoning text is being withheld
        # until we know whether </think> follows it.
        self._pending_reasoning_newline = False
        # Number of end-detection calls that still have to answer "not ended"
        # after </think> was seen. The delay leaves room to drop the newline
        # right after </think>, at the price of reporting the end one round
        # late.
        self.end_offset = 1

    def _reasoning_end_reached(self, input_ids: Sequence[int]) -> bool:
        """Shared end detection; consumes one unit of ``end_offset`` per hit.

        While ``end_offset`` is positive, a call that sees the end token
        decrements the offset and reports ``False``. Otherwise the result is
        whether the offset has dropped below 1.
        """
        if self.end_token_id in input_ids and self.end_offset > 0:
            self.end_offset -= 1
            return False
        return self.end_offset < 1

    def is_reasoning_end(self, input_ids: Sequence[int]) -> bool:
        return self._reasoning_end_reached(input_ids)

    def is_reasoning_end_streaming(
        self, input_ids: Sequence[int], delta_ids: Sequence[int]
    ) -> bool:
        return self._reasoning_end_reached(input_ids)

    def extract_reasoning(
        self,
        model_output: str,
        request: ChatCompletionRequest | ResponsesRequest,
    ) -> tuple[str | None, str | None]:
        """Split a full output and trim the newlines adjacent to </think>.

        Empty reasoning or content is reported as ``None``.
        """
        reasoning, content = super().extract_reasoning(model_output, request)
        if reasoning is not None:
            reasoning = reasoning.removesuffix("\n")
        if content is not None:
            content = content.removeprefix("\n")
        return reasoning or None, content or None

    def extract_reasoning_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
    ) -> DeltaMessage | None:
        """Classify one streamed delta as reasoning and/or content.

        Returns ``None`` when the delta yields nothing to emit (including a
        withheld or dropped newline).
        """
        # Drop the immediate newline that models often emit after </think>.
        if previous_text.endswith(self.end_token) and delta_text:
            if delta_text == "\n":
                return None
            if delta_text.startswith("\n"):
                remaining = delta_text.removeprefix("\n")
                return DeltaMessage(content=remaining) if remaining else None

        parsed = super().extract_reasoning_streaming(
            previous_text,
            current_text,
            delta_text,
            previous_token_ids,
            current_token_ids,
            delta_token_ids,
        )
        if parsed is None:
            return None

        # Compatibility path for models that don't generate the start token:
        # treat everything before </think> as reasoning and everything after
        # as content.
        start_token_seen = (
            self.start_token_id in previous_token_ids
            or self.start_token_id in delta_token_ids
        )
        if not start_token_seen:
            if self.end_token_id in delta_token_ids:
                end_index = delta_text.find(self.end_token)
                reasoning_part = delta_text[:end_index]
                content_part = delta_text[end_index + len(self.end_token) :]
                parsed = DeltaMessage(
                    reasoning=reasoning_part, content=content_part or None
                )
            elif self.end_token_id in previous_token_ids:
                parsed = DeltaMessage(content=delta_text)
            else:
                parsed = DeltaMessage(reasoning=delta_text)

        reasoning_out = parsed.reasoning
        content_out = parsed.content
        end_in_delta = self.end_token in delta_text

        # Reasoning: handle the newline immediately before </think>.
        if reasoning_out is not None:
            if self._pending_reasoning_newline:
                # The withheld newline was not followed by </think>; emit it.
                reasoning_out = "\n" + reasoning_out
                self._pending_reasoning_newline = False
            if reasoning_out.endswith("\n"):
                reasoning_out = reasoning_out.removesuffix("\n")
                # When </think> is in this delta the newline sits right before
                # it and is dropped; otherwise it is withheld until we know
                # whether </think> follows.
                self._pending_reasoning_newline = not end_in_delta

        # Content: handle the newline immediately after </think>.
        if content_out is not None:
            # No need to get into parser again to remove newline after </think>.
            self.end_offset -= 1
            # If we have content, reasoning must have ended.
            self._pending_reasoning_newline = False
            if end_in_delta and content_out.startswith("\n"):
                content_out = content_out.removeprefix("\n")

        reasoning_out = reasoning_out or None
        content_out = content_out or None
        if reasoning_out is None and content_out is None:
            return None
        return DeltaMessage(reasoning=reasoning_out, content=content_out)
...@@ -134,6 +134,10 @@ _TOOL_PARSERS_TO_REGISTER = { ...@@ -134,6 +134,10 @@ _TOOL_PARSERS_TO_REGISTER = {
"step3_tool_parser", "step3_tool_parser",
"Step3ToolParser", "Step3ToolParser",
), ),
"step3p5": (
"step3p5_tool_parser",
"Step3p5ToolParser",
),
"xlam": ( "xlam": (
"xlam_tool_parser", "xlam_tool_parser",
"xLAMToolParser", "xLAMToolParser",
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import ast
import json
from collections.abc import Sequence
from typing import Any
from xml.parsers.expat import ParserCreate
import regex as re
from vllm.entrypoints.chat_utils import make_tool_call_id
from vllm.entrypoints.openai.chat_completion.protocol import (
ChatCompletionRequest,
ChatCompletionToolsParam,
)
from vllm.entrypoints.openai.engine.protocol import (
DeltaFunctionCall,
DeltaMessage,
DeltaToolCall,
ExtractedToolCallInformation,
FunctionCall,
ToolCall,
)
from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers.abstract_tool_parser import (
ToolParser,
ToolParserManager,
)
logger = init_logger(__name__)
class StreamingXMLToolCallParser:
"""
Simplified streaming XML tool call parser
Supports streaming input, parsing, and output
"""
def __init__(self):
    """Initialize streaming state, then the tool list and tag literals."""
    # The streaming state (including the expat parser) is set up first.
    self.reset_streaming_state()

    # Tool configuration information
    self.tools: list[ChatCompletionToolsParam] | None = None

    # Opening tags / tag prefixes recognized in the model output.
    self.tool_call_start_token: str = "<tool_call>"
    self.function_start_token: str = "<function="
    self.parameter_start_token: str = "<parameter="

    # Matching closing tags.
    self.tool_call_end_token: str = "</tool_call>"
    self.function_end_token: str = "</function>"
    self.parameter_end_token: str = "</parameter>"
def reset_streaming_state(self):
    """Reset every piece of streaming state and build a fresh XML parser."""
    # Deltas emitted so far.
    self.deltas = []

    # Tool call / function currently being parsed.
    self.tool_call_index = 0
    self.current_call_id = self.last_completed_call_id = None
    self.current_function_name = None
    self.current_function_open = False

    # Parameter currently being parsed.
    self.parameters = {}
    self.current_param_name = None
    self.current_param_value = self.current_param_value_converted = ""
    self.current_param_is_first = False
    self.should_emit_end_newline = False
    self.start_quote_emitted = False

    # Raw input buffering.
    self.streaming_buffer = ""
    self.last_processed_pos = 0
    self.text_content_buffer = ""

    # state for preprocessing and deferred parsing
    self._pre_inside_parameter = False
    self._pre_param_buffer = ""
    self._pre_current_param_name = None
    self.defer_current_parameter = False
    self.deferred_param_raw_value = ""

    # recreate parser
    self.parser = ParserCreate()
    self.setup_parser()
def parse_single_streaming_chunks(self, xml_chunk: str) -> DeltaMessage:
    """
    Feed one streamed XML chunk to the parser and return the resulting delta.

    This is the streaming entry point: chunks arrive one at a time and all
    parsing state is kept on the instance between calls.

    Args:
        xml_chunk: Single XML chunk string

    Returns:
        DeltaMessage: The deltas produced by this chunk merged into a single
        response; pending plain text when no tool call has started yet; or an
        empty response (``content=None``) when nothing could be emitted.
    """
    # Everything appended to self.deltas from this index on belongs to this
    # chunk.
    deltas_before = len(self.deltas)
    self.streaming_buffer += xml_chunk

    if self._process_complete_xml_elements():
        # Complete elements were parsed. Some end events may still not have
        # fired, so synthesize the missing ones.
        try:
            emitted = self.deltas[deltas_before:]

            def emitted_has(matches) -> bool:
                """Whether any tool call in the emitted deltas matches."""
                for delta in emitted:
                    if not delta.tool_calls:
                        continue
                    for tool_call in delta.tool_calls:
                        if matches(tool_call):
                            return True
                return False

            # The chunk contains </function> but no closing '}' (non-empty
            # parameters) or '{}' (no parameters) was generated.
            if (
                self.current_call_id is not None
                and self.function_end_token in xml_chunk
            ):
                function_closed = emitted_has(
                    lambda tc: tc.function
                    and tc.id == self.current_call_id
                    and isinstance(tc.function.arguments, str)
                    and (tc.function.arguments in ("}", "{}"))
                )
                if not function_closed:
                    # Close potentially unclosed element
                    if self.current_param_name:
                        self._end_element("parameter")
                    if self.current_function_name:
                        self._end_element("function")

            # The chunk contains </tool_call> but the final empty-arguments
            # delta was not generated.
            if (
                self.current_call_id is not None
                and self.tool_call_end_token in xml_chunk
            ):
                tool_call_closed = emitted_has(
                    lambda tc: tc.type == "function"
                    and tc.function
                    and tc.function.arguments == ""
                    and tc.id == self.current_call_id
                )
                if not tool_call_closed:
                    # Close potentially unclosed element
                    if self.current_param_name:
                        self._end_element("parameter")
                    if self.current_function_name:
                        self._end_element("function")
                    self._end_element("tool_call")
        except Exception as e:
            logger.warning("Error with fallback parsing: %s", e)

        # Merge newly generated deltas into single response
        return self._merge_new_deltas_to_single_response(deltas_before)

    # No complete element was found in this chunk.
    if self.text_content_buffer and self.tool_call_index == 0:
        # Plain text seen before any tool call: flush it, then clear the
        # buffer to avoid duplicate output.
        text_delta = DeltaMessage(content=self.text_content_buffer)
        self._emit_delta(text_delta)
        self.text_content_buffer = ""
        return text_delta

    # The chunk contains end tags that the parser did not trigger on:
    # complete the end events manually. Only done while a call is open, to
    # avoid closing a new call in multi <tool_call> scenarios.
    chunk_has_function_end = self.function_end_token in xml_chunk
    chunk_has_tool_call_end = self.tool_call_end_token in xml_chunk
    if self.current_call_id is not None and (
        chunk_has_function_end or chunk_has_tool_call_end
    ):
        # Close potentially unclosed element
        if self.current_param_name:
            self._end_element("parameter")
        if chunk_has_function_end and self.current_function_name:
            self._end_element("function")
        if chunk_has_tool_call_end:
            self._end_element("tool_call")
        # Return the merged delta result generated by this fallback
        return self._merge_new_deltas_to_single_response(deltas_before)

    # No complete elements, return empty response
    return DeltaMessage(content=None)
def _escape_xml_special_chars(self, text: str) -> str:
    """
    Replace the five XML special characters with their entities.

    Args:
        text: Original text

    Returns:
        ``text`` with ``&``, ``<``, ``>``, ``"`` and ``'`` replaced by
        ``&amp;``, ``&lt;``, ``&gt;``, ``&quot;`` and ``&apos;``. Existing
        entities are not recognized: their ``&`` is escaped again.
    """
    # One pass over the text; each character is mapped at most once, so the
    # "&" of a freshly inserted entity is never re-escaped.
    entity_table = str.maketrans(
        {
            "&": "&amp;",
            "<": "&lt;",
            ">": "&gt;",
            '"': "&quot;",
            "'": "&apos;",
        }
    )
    return text.translate(entity_table)
def _process_complete_xml_elements(self) -> bool:
"""
Process complete XML elements in buffer
Returns:
bool: Whether complete elements were found and processed
"""
found_any = False
while self.last_processed_pos < len(self.streaming_buffer):
# Find next complete xml element
element, end_pos = self._find_next_complete_element(self.last_processed_pos)
if element is None:
# No complete element found, wait for more data
break
# Check if this element should be skipped
if self._should_skip_element(element):
self.last_processed_pos = end_pos
continue
# Found complete XML element, process it
try:
preprocessed_element = self._preprocess_xml_chunk(element)
# Check if this is the first tool_call start
if (
(
preprocessed_element.strip().startswith("<tool_call>")
or preprocessed_element.strip().startswith("<function name=")
)
and self.tool_call_index == 0
) and self.text_content_buffer:
# First tool_call starts,
# output previously collected text content first
text_delta = DeltaMessage(content=self.text_content_buffer)
self._emit_delta(text_delta)
# Clear buffer for potential subsequent text content
self.text_content_buffer = ""
# If a new tool_call starts and
# there are already completed tool_calls with function name
if (
preprocessed_element.strip().startswith("<tool_call>")
and self.tool_call_index > 0
and self.current_call_id
and self.current_function_name
):
# Reset parser state but preserve generated deltas
if self.current_param_name:
self._end_element("parameter")
if self.current_function_open:
self._end_element("function")
# Output final tool_call tail delta
final_delta = DeltaMessage(
role=None,
content=None,
reasoning_content=None,
tool_calls=[
DeltaToolCall(
index=self.tool_call_index - 1,
id=self.current_call_id,
type="function",
function=DeltaFunctionCall(name=None, arguments=""),
)
],
)
self._emit_delta(final_delta)
# Reset XML parser and current call state
self._reset_xml_parser_after_tool_call()
# Parse preprocessed element
self.parser.Parse(preprocessed_element, False)
found_any = True
except Exception as e:
logger.warning("Error when parsing XML elements: %s", e)
# Update processed position
self.last_processed_pos = end_pos
return found_any
def _fix_incomplete_tag_in_chunk(self, chunk: str) -> str:
"""
Fallback: fix incomplete <parameter=xxx or <function=xxx tags
(missing >)
Examples: <parameter=-C: -> <parameter=-C>, <parameter=parameter=-n:
-> <parameter=-n>
Also handles missing = cases: <function xxx> -> <function=xxx>,
<functionxxx> -> <function=xxx>
Only fixes tags that pass validation (parameter exists in tool definition)
"""
# First, handle missing = cases for function tags
chunk = self._fix_missing_equals_in_function_tag(chunk)
for tag_type in ["parameter", "function"]:
pattern = f"<{tag_type}="
if pattern not in chunk:
continue
start_idx = chunk.find(pattern)
after_tag = chunk[start_idx:]
gt_pos = after_tag.find(">")
lt_pos = after_tag.find("<", len(pattern))
# Skip if already well-formed
if (
gt_pos != -1
and (lt_pos == -1 or gt_pos < lt_pos)
and pattern in after_tag[:gt_pos]
):
continue
# Extract tag name (stop at space, newline, or <)
content = chunk[start_idx + len(pattern) :]
end_pos = next(
(i for i, ch in enumerate(content) if ch in (" ", "\n", "<")),
len(content),
)
tag_name = content[:end_pos]
if not tag_name:
continue
# Remove duplicate prefix: <parameter=parameter=xxx -> <parameter=xxx
if tag_name.startswith(f"{tag_type}="):
tag_name = tag_name[len(tag_type) + 1 :]
# Remove trailing non-alphanumeric chars (keep - and _)
while tag_name and not (
tag_name[-1].isalnum() or tag_name[-1] in ("-", "_")
):
tag_name = tag_name[:-1]
if not tag_name:
continue
# Validate parameter exists in tool definition
if tag_type == "parameter" and not self._validate_parameter_name(tag_name):
continue
# Apply fix
chunk = chunk.replace(
f"<{tag_type}={content[:end_pos]}", f"<{tag_type}={tag_name}>", 1
)
return chunk
def _fix_missing_equals_in_function_tag(self, chunk: str) -> str:
"""
Fix missing = in function tags: <function xxx> or <functionxxx>
Examples:
<function execute_bash> -> <function=execute_bash>
<functionexecute_bash> -> <function=execute_bash>
Only fixes if function name exists in tool definition
"""
# already correct
if "<function=" in chunk:
return chunk
# Pattern 1: <function xxx> (with space/newline but no =)
pattern1 = r"<function\s+([a-zA-Z_][a-zA-Z0-9_]*)\s*>"
match1 = re.search(pattern1, chunk)
if match1:
func_name = match1.group(1).strip()
# must validate function name exists before fixing
if func_name and self._validate_function_name(func_name):
original = match1.group(0)
fixed = f"<function={func_name}>"
chunk = chunk.replace(original, fixed, 1)
return chunk
# Pattern 2: <functionxxx> (no space, no =)
# only match <function followed by letters
pattern2 = r"<function([a-zA-Z_][a-zA-Z0-9_]*)\s*>"
match2 = re.search(pattern2, chunk)
if match2:
func_name = match2.group(1).strip()
# must validate function name exists before fixing
if func_name and self._validate_function_name(func_name):
original = match2.group(0)
fixed = f"<function={func_name}>"
chunk = chunk.replace(original, fixed, 1)
return chunk
return chunk
def _validate_function_name(self, func_name: str) -> bool:
"""Check if function name exists in tool definitions"""
if not self.tools:
return False
for tool in self.tools:
if (
hasattr(tool, "type")
and tool.type == "function"
and hasattr(tool, "function")
and hasattr(tool.function, "name")
and tool.function.name == func_name
):
return True
return False
def _validate_parameter_name(self, param_name: str) -> bool:
"""Check if parameter exists in current function's tool definition"""
if not self.tools or not self.current_function_name:
return True
for tool in self.tools:
if (
hasattr(tool, "type")
and tool.type == "function"
and hasattr(tool, "function")
and hasattr(tool.function, "name")
and tool.function.name == self.current_function_name
):
if not hasattr(tool.function, "parameters"):
return True
params = tool.function.parameters
if isinstance(params, dict):
properties = params.get("properties", params)
return param_name in properties
break
return True
def _should_skip_element(self, element: str) -> bool:
"""
Determine whether an element should be skipped
Args:
element: Element to evaluate
Returns:
bool: True means should skip, False means should process
"""
# If it's a tool_call XML tag, don't skip
if (
element.startswith(self.tool_call_start_token)
or element.startswith(self.function_start_token)
or element.startswith(self.parameter_start_token)
):
return False
# If currently not parsing tool calls and not blank,
# collect this text instead of skipping
# Only process other XML elements after tool_call appears,
# otherwise treat as plain text
if self.current_call_id is None and element:
# Collect text content to buffer
self.text_content_buffer += element
return True # Still skip, but content has been collected
# If currently parsing tool calls,
# this might be parameter value, don't skip
if self.current_call_id is not None:
return False
# Skip blank content
return not element
def _find_next_complete_element(self, start_pos: int) -> tuple[str | None, int]:
"""
Find next complete XML element from specified position
Args:
start_pos: Position to start searching
Returns:
(Complete element string, element end position),
returns (None, start_pos) if no complete element found
"""
buffer = self.streaming_buffer[start_pos:]
if not buffer:
return None, start_pos
if buffer.startswith("<"):
# Check if this is an incomplete parameter/function tag
# e.g., <parameter=-C: or <function=xxx
is_incomplete_param = (
buffer.startswith("<parameter=") and ">" not in buffer.split("\n")[0]
)
is_incomplete_func = (
buffer.startswith("<function=") and ">" not in buffer.split("\n")[0]
)
if is_incomplete_param or is_incomplete_func:
# Find the corresponding closing tag
tag_type = "parameter" if is_incomplete_param else "function"
closing_tag = f"</{tag_type}>"
closing_pos = buffer.find(closing_tag)
if closing_pos != -1:
# Found closing tag, return complete element including closing tag
complete_element = buffer[: closing_pos + len(closing_tag)]
return complete_element, start_pos + closing_pos + len(closing_tag)
# Need to ensure no new < appears,
# find the nearest one between < and >
tag_end = buffer.find("<", 1)
tag_end2 = buffer.find(">", 1)
if tag_end != -1 and tag_end2 != -1:
# Next nearest is <
if tag_end < tag_end2:
return buffer[:tag_end], start_pos + tag_end
# Next nearest is >, means found XML element
else:
return buffer[: tag_end2 + 1], start_pos + tag_end2 + 1
elif tag_end != -1:
return buffer[:tag_end], start_pos + tag_end
elif tag_end2 != -1:
return buffer[: tag_end2 + 1], start_pos + tag_end2 + 1
else:
# If currently not parsing tool calls (entering a tool_call),
# check if starts with <tool_call> or <function=
if self.current_call_id is None:
# Check if might be start of <tool_call>
if buffer == "<tool_call>"[: len(buffer)]:
# Might be start of <tool_call>, wait for more data
return None, start_pos
elif (
buffer.startswith("<function=")
or buffer == "<function="[: len(buffer)]
):
# Might be start of <function=, wait for more data
# to get the complete function tag
return None, start_pos
else:
# Not start of <tool_call> or <function=, treat as text
return buffer, start_pos + len(buffer)
else:
# When parsing tool calls,
# wait for more data to get complete tag
return None, start_pos
else:
# Find text content (until next < or buffer end)
next_tag_pos = buffer.find("<")
if next_tag_pos != -1:
# Found text content
text_content = buffer[:next_tag_pos]
return text_content, start_pos + next_tag_pos
else:
# Buffer end is all text, process
# (no longer wait for more data)
remaining = buffer
return remaining, start_pos + len(remaining)
def _merge_new_deltas_to_single_response(self, initial_count: int) -> DeltaMessage:
"""
Merge newly generated deltas from this processing
into a single DeltaMessage
Args:
initial_count: Delta count before processing
Returns:
Merged DeltaMessage containing all newly generated delta information
"""
if len(self.deltas) <= initial_count:
return DeltaMessage(content=None)
# Get newly generated deltas
new_deltas = self.deltas[initial_count:]
if len(new_deltas) == 1:
# Only one new delta, return directly
return new_deltas[0]
# Merge multiple new deltas
merged_tool_calls: list[DeltaToolCall] = []
merged_content: str = ""
for delta in new_deltas:
if delta.content:
merged_content += delta.content
if delta.tool_calls:
# For tool_calls, we need to intelligently merge arguments
for tool_call in delta.tool_calls:
# Find if there's already a tool_call with the same call_id
existing_call = None
for existing in merged_tool_calls:
if existing.id == tool_call.id:
existing_call = existing
break
if existing_call and existing_call.function:
# Merge to existing tool_call
if tool_call.function and tool_call.function.name:
existing_call.function.name = tool_call.function.name
if (
tool_call.function
and tool_call.function.arguments is not None
):
if existing_call.function.arguments is None:
existing_call.function.arguments = ""
# For streaming JSON parameters,
# simply concatenate in order
new_args = tool_call.function.arguments
existing_call.function.arguments += new_args
if tool_call.type:
existing_call.type = tool_call.type
else:
# Add new tool_call
merged_tool_calls.append(tool_call)
return DeltaMessage(
content=merged_content if merged_content else None,
tool_calls=merged_tool_calls,
)
def _preprocess_xml_chunk(self, chunk: str) -> str:
"""
Preprocess XML chunk, handle non-standard formats,
and escape special characters
Args:
chunk: Original XML chunk
Returns:
Processed XML chunk
"""
# Check if this is a tool_call related element
is_tool_call = False
if chunk.startswith(self.tool_call_start_token) or chunk.startswith(
self.tool_call_end_token
):
is_tool_call = True
# Check for function tags (including malformed ones without =)
# <function=xxx>, </function>, <function xxx>, <functionxxx>
if (
chunk.startswith(self.function_start_token)
or chunk.startswith(self.function_end_token)
or chunk.startswith("<function ")
or re.match(r"^<function[a-zA-Z_]", chunk)
): # <functionXXX without space or =
is_tool_call = True
if chunk.startswith(self.parameter_start_token) or chunk.startswith(
self.parameter_end_token
):
is_tool_call = True
# Fallback: fix incomplete <parameter= or <function= tags without
# closing >
# This handles cases like: <parameter=-C:\n or <parameter=-B 5\n
# Apply when parsing tool calls OR when chunk looks like a function/
# parameter tag
if (
self.current_call_id is not None
or chunk.startswith("<function")
or chunk.startswith("<parameter")
):
chunk = self._fix_incomplete_tag_in_chunk(chunk)
# Handle <function=name> format -> <function name="name">
processed = re.sub(r"<function=([^>]+)>", r'<function name="\1">', chunk)
# Handle <parameter=name> format -> <parameter name="name">
processed = re.sub(r"<parameter=([^>]+)>", r'<parameter name="\1">', processed)
original_chunk = chunk
# If in parameter value accumulation mode
if self._pre_inside_parameter:
# Parameter end: output accumulated raw text
# safely then return </parameter>
if processed.startswith("</parameter>"):
body_text = self._pre_param_buffer
# Trigger deferred parsing mode
# literal_eval+json output in end_element
self.defer_current_parameter = True
self.deferred_param_raw_value = body_text
# Clean up state
self._pre_inside_parameter = False
self._pre_param_buffer = ""
self._pre_current_param_name = None
safe_text = self._escape_xml_special_chars(body_text)
return f"{safe_text}</parameter>"
else:
# If this is the first block of content after entering parameter
# evaluate if deferred parsing is needed;
# If not needed, exit accumulation mode
# and pass through directly
if self._pre_param_buffer == "":
# Get current parameter type
param_type = (
self._get_param_type(self._pre_current_param_name)
if self._pre_current_param_name
else "string"
)
# Only these types need deferred parsing to
# handle Python literals containing single quotes
is_object_type = param_type in ["object"]
is_complex_type = (
param_type in ["array", "arr", "sequence"]
or param_type.startswith("dict")
or param_type.startswith("list")
)
# Only delay when contains container symbols
# and has single quotes and is complex type
has_container_hint = (
("[" in original_chunk)
or ("{" in original_chunk)
or ("(" in original_chunk)
)
# Determine if deferred parsing is needed
need_defer = False
if is_complex_type:
# Complex type, always need deferred parsing
need_defer = True
elif (
is_object_type
and has_container_hint
and ("'" in original_chunk)
):
# Object type with container symbols
# and single quotes, need deferred parsing
need_defer = True
if not need_defer:
# No need for deferred parsing,
# exit parameter mode directly
self._pre_inside_parameter = False
return self._escape_xml_special_chars(original_chunk)
self._pre_param_buffer += original_chunk
return ""
# Parameter start: enable accumulation
if processed.startswith("<parameter name="):
m = re.match(r'<parameter name="([^"]+)">', processed)
if m:
self._pre_current_param_name = m.group(1)
self._pre_inside_parameter = True
self._pre_param_buffer = ""
return processed
# If processed doesn't contain special_token, escape processed
# This is because XML parsing encounters special characters
# and reports errors, so escaping is needed
if not is_tool_call:
processed = self._escape_xml_special_chars(processed)
return processed
def _emit_delta(self, delta: DeltaMessage) -> None:
    """Emit Delta response (streaming output)

    The delta is appended to ``self.deltas``; nothing is sent from here.
    Args:
        delta: Delta to record
    """
    self.deltas.append(delta)
def _auto_close_open_parameter_if_needed(self, incoming_tag: str | None = None):
"""Before starting to process new elements,
if there are unclosed tags from before,
automatically complete their endings to the parser.
- If there are unclosed parameters,
it's equivalent to feeding `</parameter>`
- When about to start a new function or tool_call,
if there are unclosed functions, complete `</function>`.
- When about to start a new tool_call,
if there are unclosed tool_calls, complete `</tool_call>`.
"""
# First close unclosed parameters
if self.current_param_name:
self._end_element("parameter")
# If about to start new function or tool_call,
# and there are unclosed functions, close function first
if incoming_tag in ("function", "tool_call") and self.current_function_name:
self._end_element("function")
# If about to start new tool_call,
# and there are unclosed tool_calls, close tool_call first
if incoming_tag == "tool_call" and self.current_call_id:
self._end_element("tool_call")
def _start_element(self, name: str, attrs: dict[str, str]):
"""Handle XML start element events"""
if name == "root":
return
if name == "tool_call":
# Before opening new tool_call,
# automatically complete previous unclosed tags
self._auto_close_open_parameter_if_needed("tool_call")
self.parameters = {}
self.current_call_id = make_tool_call_id()
self.current_param_is_first = True
self.tool_call_index += 1
elif name.startswith("function") or (name == "function"):
# If missing tool_call, manually complete
if not self.current_call_id:
self._start_element("tool_call", {})
# Before opening new function,
# automatically complete previous unclosed tags (parameter/function)
self._auto_close_open_parameter_if_needed("function")
function_name = self._extract_function_name(name, attrs)
self.current_function_name = function_name
self.current_function_open = True
if function_name:
delta = DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.tool_call_index - 1,
id=self.current_call_id,
type="function",
function=DeltaFunctionCall(
name=function_name, arguments=""
),
)
]
)
self._emit_delta(delta)
elif name.startswith("parameter") or (name == "parameter"):
# If previous parameter hasn't ended normally,
# complete its end first, then start new parameter
self._auto_close_open_parameter_if_needed("parameter")
param_name = self._extract_parameter_name(name, attrs)
self.current_param_name = param_name
self.current_param_value = ""
self.current_param_value_converted = ""
self.start_quote_emitted = False # Reset start quote flag
# Only output parameter name and colon,
# don't output quotes
# decide after parameter value type is determined
if param_name:
if not self.parameters:
# First parameter
# start JSON, only output parameter name and colon
json_start = f'{{"{param_name}": '
delta = DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.tool_call_index - 1,
id=self.current_call_id,
type="function",
function=DeltaFunctionCall(
name=None, arguments=json_start
),
)
]
)
self._emit_delta(delta)
self.current_param_is_first = True
else:
# Subsequent parameters
# add comma and parameter name, no quotes
json_continue = f', "{param_name}": '
delta = DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.tool_call_index - 1,
id=self.current_call_id,
type="function",
function=DeltaFunctionCall(
name=None, arguments=json_continue
),
)
]
)
self._emit_delta(delta)
self.current_param_is_first = False
def _char_data(self, data: str):
"""Handle XML character data events"""
if data and self.current_param_name:
# If preprocessing stage determines deferred parsing is needed,
# only cache character data, no streaming output
if self.defer_current_parameter:
original_data = data
if self.should_emit_end_newline:
original_data = "\n" + original_data
self.should_emit_end_newline = False
if original_data.endswith("\n"):
self.should_emit_end_newline = True
original_data = original_data[:-1]
self.current_param_value += original_data
return
param_type = self._get_param_type(self.current_param_name)
# Check if this is the first time receiving data for this parameter
# If this is the first packet of data and starts with \n, remove \n
if not self.current_param_value and data.startswith("\n"):
data = data[1:]
# Output start quote for string type (if not already output)
if (
param_type in ["string", "str", "text", "varchar", "char", "enum"]
and not self.start_quote_emitted
):
quote_delta = DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.tool_call_index - 1,
id=self.current_call_id,
type="function",
function=DeltaFunctionCall(name=None, arguments='"'),
)
]
)
self._emit_delta(quote_delta)
self.start_quote_emitted = True
if not data:
return
original_data = data
# Delay output of trailing newline
if self.should_emit_end_newline:
original_data = "\n" + original_data
self.should_emit_end_newline = False
if original_data.endswith("\n"):
self.should_emit_end_newline = True
original_data = original_data[:-1]
self.current_param_value += original_data
# convert parameter value by param_type
converted_value = self._convert_param_value(
self.current_param_value, param_type
)
output_data = self._convert_for_json_streaming(converted_value, param_type)
delta_data = output_data[len(self.current_param_value_converted) :]
self.current_param_value_converted = output_data
delta = DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.tool_call_index - 1,
id=self.current_call_id,
type="function",
function=DeltaFunctionCall(name=None, arguments=delta_data),
)
]
)
self._emit_delta(delta)
def _end_element(self, name: str):
"""Handle XML end element events"""
if name == "root":
return
# If function or tool_call ends and there are still unclosed parameters,
# complete parameter end first
if (
name.startswith("function") or name == "function" or name == "tool_call"
) and self.current_param_name:
self._auto_close_open_parameter_if_needed()
if (
name.startswith("parameter") or name == "parameter"
) and self.current_param_name:
# End current parameter
param_name = self.current_param_name
param_value = self.current_param_value
# If in deferred parsing mode,
# perform overall parsing on raw content
# accumulated in preprocessing stage and output once
if self.defer_current_parameter:
raw_text = (
self.deferred_param_raw_value
if self.deferred_param_raw_value
else param_value
)
parsed_value = None
output_arguments = None
try:
# If previously delayed trailing newline,
# add it back before parsing
if self.should_emit_end_newline:
raw_for_parse = raw_text + "\n"
else:
raw_for_parse = raw_text
parsed_value = ast.literal_eval(raw_for_parse)
output_arguments = json.dumps(parsed_value, ensure_ascii=False)
except Exception:
# Fallback: output as string as-is
output_arguments = json.dumps(raw_text, ensure_ascii=False)
parsed_value = raw_text
delta = DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.tool_call_index - 1,
id=self.current_call_id,
type="function",
function=DeltaFunctionCall(
name=None, arguments=output_arguments
),
)
]
)
self._emit_delta(delta)
# Clean up and store
self.should_emit_end_newline = False
self.parameters[param_name] = parsed_value
self.current_param_name = None
self.current_param_value = ""
self.current_param_value_converted = ""
self.start_quote_emitted = False
self.defer_current_parameter = False
self.deferred_param_raw_value = ""
return
param_type = self._get_param_type(param_name)
# convert complete parameter value by param_type
converted_value = self._convert_param_value(param_value, param_type)
# Decide whether to add end quote based on parameter type
if param_type in ["string", "str", "text", "varchar", "char", "enum"]:
# For empty string parameters, need special handling
if not param_value and not self.start_quote_emitted:
# No start quote output,
# directly output complete empty string
delta = DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.tool_call_index - 1,
id=self.current_call_id,
type="function",
function=DeltaFunctionCall(name=None, arguments='""'),
)
]
)
self._emit_delta(delta)
else:
# Non-empty parameter value, output end quote
delta = DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.tool_call_index - 1,
id=self.current_call_id,
type="function",
function=DeltaFunctionCall(name=None, arguments='"'),
)
]
)
self._emit_delta(delta)
self.should_emit_end_newline = False
# Store converted value
self.parameters[param_name] = converted_value
self.current_param_name = None
self.current_param_value = ""
self.current_param_value_converted = ""
self.start_quote_emitted = False
elif name.startswith("function") or name == "function":
# if there are parameters, close JSON object
if self.parameters:
delta = DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.tool_call_index - 1,
id=self.current_call_id,
type="function",
function=DeltaFunctionCall(name=None, arguments="}"),
)
]
)
self._emit_delta(delta)
# return empty object
else:
delta = DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.tool_call_index - 1,
id=self.current_call_id,
type="function",
function=DeltaFunctionCall(name=None, arguments="{}"),
)
]
)
self._emit_delta(delta)
self.current_function_open = False
self.current_function_name = (
None # Clear function name to prevent duplicate closing
)
elif name == "tool_call":
# Before ending tool_call,
# ensure function is closed to complete missing right brace
if self.current_function_open:
# If there are still unclosed parameters, close them first
if self.current_param_name:
self._end_element("parameter")
# Close function, ensure output '}' or '{}'
self._end_element("function")
# Final Delta
delta = DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.tool_call_index - 1,
id=self.current_call_id,
type="function",
function=DeltaFunctionCall(name=None, arguments=""),
)
]
)
self._emit_delta(delta)
# Check if there's text content to output (between tool_calls)
if self.text_content_buffer.strip():
text_delta = DeltaMessage(content=self.text_content_buffer)
self._emit_delta(text_delta)
self._reset_xml_parser_after_tool_call()
def setup_parser(self):
    """Set up XML parser event handlers

    Enables text buffering on ``self.parser`` and installs this object's
    start-element, end-element and character-data callbacks on it.
    """
    xml_parser = self.parser
    xml_parser.buffer_text = True
    xml_parser.StartElementHandler = self._start_element
    xml_parser.EndElementHandler = self._end_element
    xml_parser.CharacterDataHandler = self._char_data
def set_tools(self, tools: list[ChatCompletionToolsParam] | None) -> None:
    """Set tool configuration information

    The list is stored on ``self.tools`` as given (no copy).
    Args:
        tools: Tool definitions, or None when there are none
    """
    self.tools = tools
def _extract_function_name(self, name: str, attrs: dict[str, str]) -> str | None:
"""Extract function name from various formats"""
if attrs and "name" in attrs:
return attrs["name"]
if "=" in name:
parts = name.split("=", 1)
if len(parts) == 2 and parts[0] == "function":
return parts[1]
return None
def _extract_parameter_name(self, name: str, attrs: dict[str, str]) -> str | None:
"""Extract parameter name from various formats"""
if attrs and "name" in attrs:
return attrs["name"]
if "=" in name:
parts = name.split("=", 1)
if len(parts) == 2 and parts[0] == "parameter":
return parts[1]
return None
def _get_param_type(self, param_name: str) -> str:
"""Get parameter type based on tool configuration, defaults to string
Args:
param_name: Parameter name
Returns:
Parameter type
"""
if not self.tools or not self.current_function_name:
return "string"
for tool in self.tools:
if not hasattr(tool, "type") or not (
hasattr(tool, "function") and hasattr(tool.function, "name")
):
continue
if (
tool.type == "function"
and tool.function.name == self.current_function_name
):
if not hasattr(tool.function, "parameters"):
return "string"
params = tool.function.parameters
if isinstance(params, dict) and "properties" in params:
properties = params["properties"]
if param_name in properties and isinstance(
properties[param_name], dict
):
return self.repair_param_type(
str(properties[param_name].get("type", "string"))
)
elif isinstance(params, dict) and param_name in params:
param_config = params[param_name]
if isinstance(param_config, dict):
return self.repair_param_type(
str(param_config.get("type", "string"))
)
break
return "string"
def repair_param_type(self, param_type: str) -> str:
    """Map unrecognised parameter types to "string"

    A type is recognised when it is one of the known exact names or starts
    with one of the known prefixes; recognised types are returned as-is.
    The comparison is case-sensitive.
    Args:
        param_type: Parameter type
    Returns:
        Repaired parameter type
    """
    known_names = (
        "string",
        "str",
        "text",
        "varchar",
        "char",
        "enum",
        "boolean",
        "bool",
        "binary",
        "object",
        "array",
        "arr",
        "sequence",
    )
    known_prefixes = (
        "int",
        "uint",
        "long",
        "short",
        "unsigned",
        "num",
        "float",
        "dict",
        "list",
    )
    if param_type in known_names or param_type.startswith(known_prefixes):
        return param_type
    return "string"
def _convert_param_value(self, param_value: str, param_type: str) -> Any:
"""Convert value based on parameter type
Args:
param_value: Parameter value
param_type: Parameter type
Returns:
Converted value
"""
if param_value.lower() == "null":
return None
param_type = param_type.strip().lower()
if param_type in ["string", "str", "text", "varchar", "char", "enum"]:
return param_value
elif (
param_type.startswith("int")
or param_type.startswith("uint")
or param_type.startswith("long")
or param_type.startswith("short")
or param_type.startswith("unsigned")
):
try:
return int(param_value)
except (ValueError, TypeError):
logger.warning(
"Parsed value '%s' is not an integer, degenerating to string.",
param_value,
)
return param_value
elif param_type.startswith("num") or param_type.startswith("float"):
try:
float_param_value: float = float(param_value)
return (
float_param_value
if float_param_value - int(float_param_value) != 0
else int(float_param_value)
)
except (ValueError, TypeError):
logger.warning(
"Parsed value '%s' is not a float, degenerating to string.",
param_value,
)
return param_value
elif param_type in ["boolean", "bool", "binary"]:
param_value = param_value.lower()
return param_value == "true"
else:
return param_value
def _convert_for_json_streaming(self, converted_value: Any, param_type: str) -> str:
"""Convert converted_value based on
whether it's empty and if type is string
Args:
converted_value: Converted value
param_type: Parameter type
Returns:
Converted string for streaming output
"""
# Check if value is empty, but exclude numeric 0
if converted_value is None or converted_value == "":
return ""
if param_type in ["string", "str", "text", "varchar", "char", "enum"]:
# String type, remove double quotes
return json.dumps(converted_value, ensure_ascii=False)[1:-1]
else:
# Non-string type, return complete JSON string
if not isinstance(converted_value, str):
return json.dumps(converted_value, ensure_ascii=False)
else:
return converted_value
def _reset_xml_parser_after_tool_call(self):
"""
Each tool_call is treated as a separate XML document,
so we need to reset the parser after each tool_call.
"""
# recreate XML parser
self.parser = ParserCreate()
self.setup_parser()
# Reset current tool_call state
if self.current_call_id:
self.last_completed_call_id = self.current_call_id
self.current_call_id = None
self.current_function_name = None
self.current_function_open = False
self.parameters = {}
self.current_param_name = None
self.current_param_value = ""
self.current_param_value_converted = ""
self.current_param_is_first = False
self.should_emit_end_newline = False
self.start_quote_emitted = False
self.text_content_buffer = ""
# Reset preprocessing and deferred parsing state
self._pre_inside_parameter = False
self._pre_param_buffer = ""
self._pre_current_param_name = None
self.defer_current_parameter = False
self.deferred_param_raw_value = ""
@ToolParserManager.register_module("step3p5")
class Step3p5ToolParser(ToolParser):
def __init__(self, tokenizer: TokenizerLike):
    """Create the tool parser with its own streaming XML parser.

    Args:
        tokenizer: Tokenizer handed on to the ``ToolParser`` base class
    """
    super().__init__(tokenizer)
    # Tracking lists kept for compatibility with serving_chat.py
    self.streamed_args_for_tool: list[str] = []
    self.prev_tool_call_arr: list[dict] = []
    self.parser = StreamingXMLToolCallParser()
    parser_class_name = self.__class__.__name__
    logger.info("vLLM Successfully import tool parser %s !", parser_class_name)
def extract_tool_calls(
self,
model_output: str,
request: ChatCompletionRequest,
) -> ExtractedToolCallInformation:
self.parser.reset_streaming_state()
# Reset tool call tracking arrays for new extraction
self.prev_tool_call_arr = []
self.streamed_args_for_tool = []
if request:
self.parser.set_tools(request.tools)
result = self.parser.parse_single_streaming_chunks(model_output)
if not result.tool_calls:
return ExtractedToolCallInformation(
tool_calls=[],
tools_called=False,
content=result.content,
)
else:
tool_calls = []
for tool_call in result.tool_calls:
if tool_call.function and tool_call.function.name:
tool_calls.append(
ToolCall(
id=tool_call.id,
type=tool_call.type,
function=FunctionCall(
name=tool_call.function.name,
arguments=tool_call.function.arguments,
),
)
)
# Update tool call tracking arrays for compatibility
tool_index = (
tool_call.index
if tool_call.index is not None
else len(self.prev_tool_call_arr) - 1
)
# Ensure we have enough entries in our tracking arrays
while len(self.prev_tool_call_arr) <= tool_index:
self.prev_tool_call_arr.append({"name": "", "arguments": ""})
while len(self.streamed_args_for_tool) <= tool_index:
self.streamed_args_for_tool.append("")
# Update tool call information
self.prev_tool_call_arr[tool_index]["name"] = (
tool_call.function.name
)
self.prev_tool_call_arr[tool_index]["arguments"] = (
tool_call.function.arguments
)
# Update streamed arguments
if tool_call.function.arguments:
self.streamed_args_for_tool[tool_index] = (
tool_call.function.arguments
)
return ExtractedToolCallInformation(
tool_calls=tool_calls,
tools_called=len(tool_calls) > 0,
content=result.content,
)
def extract_tool_calls_streaming(
self,
previous_text: str,
current_text: str,
delta_text: str,
previous_token_ids: Sequence[int],
current_token_ids: Sequence[int],
delta_token_ids: Sequence[int],
request: ChatCompletionRequest,
) -> DeltaMessage | None:
if not previous_text:
self.parser.reset_streaming_state()
# Reset tool call tracking arrays for new streaming session
self.prev_tool_call_arr = []
self.streamed_args_for_tool = []
if request:
self.parser.set_tools(request.tools)
# Model sometimes outputs separately causing delta_text to be empty.
# If there were tool_calls before and all current tool_calls have ended,
# return an empty tool_call for outer streaming output
# to correctly output tool_call field
if not delta_text and delta_token_ids:
open_calls = current_text.count(
self.parser.tool_call_start_token
) - current_text.count(self.parser.tool_call_end_token)
if (
open_calls == 0
and self.parser.tool_call_index > 0
or not self.parser.tool_call_index
and current_text
):
return DeltaMessage(content="")
return None
# Parse the delta text and get the result
result = self.parser.parse_single_streaming_chunks(delta_text)
# Update tool call tracking arrays based on incremental parsing results
if result and result.tool_calls:
for tool_call in result.tool_calls:
if tool_call.function:
tool_index = (
tool_call.index
if tool_call.index is not None
else len(self.prev_tool_call_arr) - 1
)
# Ensure we have enough entries in our tracking arrays
while len(self.prev_tool_call_arr) <= tool_index:
self.prev_tool_call_arr.append({"name": "", "arguments": ""})
while len(self.streamed_args_for_tool) <= tool_index:
self.streamed_args_for_tool.append("")
# Update tool name if provided
if tool_call.function.name:
self.prev_tool_call_arr[tool_index]["name"] = (
tool_call.function.name
)
# Update arguments incrementally
if tool_call.function.arguments is not None:
# Concatenate the incremental arguments
# to the existing streamed arguments
self.prev_tool_call_arr[tool_index]["arguments"] += (
tool_call.function.arguments
)
self.streamed_args_for_tool[tool_index] += (
tool_call.function.arguments
)
return result
def parser_should_check_for_unstreamed_tool_arg_tokens(self) -> bool:
"""
Skip the remaining_call calculation in serving_chat
"""
return False
...@@ -96,6 +96,8 @@ _CONFIG_REGISTRY: dict[str, type[PretrainedConfig]] = LazyConfigDict( ...@@ -96,6 +96,8 @@ _CONFIG_REGISTRY: dict[str, type[PretrainedConfig]] = LazyConfigDict(
ultravox="UltravoxConfig", ultravox="UltravoxConfig",
step3_vl="Step3VLConfig", step3_vl="Step3VLConfig",
step3_text="Step3TextConfig", step3_text="Step3TextConfig",
step3p5="Step3p5Config",
qwen3_asr="Qwen3ASRConfig",
qwen3_next="Qwen3NextConfig", qwen3_next="Qwen3NextConfig",
lfm2_moe="Lfm2MoeConfig", lfm2_moe="Lfm2MoeConfig",
tarsier2="Tarsier2Config", tarsier2="Tarsier2Config",
......
...@@ -51,6 +51,8 @@ _CLASS_TO_MODULE: dict[str, str] = { ...@@ -51,6 +51,8 @@ _CLASS_TO_MODULE: dict[str, str] = {
"Step3VLConfig": "vllm.transformers_utils.configs.step3_vl", "Step3VLConfig": "vllm.transformers_utils.configs.step3_vl",
"Step3VisionEncoderConfig": "vllm.transformers_utils.configs.step3_vl", "Step3VisionEncoderConfig": "vllm.transformers_utils.configs.step3_vl",
"Step3TextConfig": "vllm.transformers_utils.configs.step3_vl", "Step3TextConfig": "vllm.transformers_utils.configs.step3_vl",
"Step3p5Config": "vllm.transformers_utils.configs.step3p5",
"Qwen3ASRConfig": "vllm.transformers_utils.configs.qwen3_asr",
"Qwen3NextConfig": "vllm.transformers_utils.configs.qwen3_next", "Qwen3NextConfig": "vllm.transformers_utils.configs.qwen3_next",
"Tarsier2Config": "vllm.transformers_utils.configs.tarsier2", "Tarsier2Config": "vllm.transformers_utils.configs.tarsier2",
# Special case: DeepseekV3Config is from HuggingFace Transformers # Special case: DeepseekV3Config is from HuggingFace Transformers
...@@ -91,6 +93,8 @@ __all__ = [ ...@@ -91,6 +93,8 @@ __all__ = [
"Step3VLConfig", "Step3VLConfig",
"Step3VisionEncoderConfig", "Step3VisionEncoderConfig",
"Step3TextConfig", "Step3TextConfig",
"Step3p5Config",
"Qwen3ASRConfig",
"Qwen3NextConfig", "Qwen3NextConfig",
"Tarsier2Config", "Tarsier2Config",
] ]
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any
from transformers.configuration_utils import PretrainedConfig
class Step3p5Config(PretrainedConfig):
    """Configuration object for the ``step3p5`` model type.

    Every constructor argument is stored on the instance under the same
    name. In addition:

    * ``num_experts_per_tok`` is set to ``moe_top_k``;
    * ``share_expert_dim`` falls back to
      ``moe_intermediate_size * moe_top_k`` when not given;
    * ``bos_token_id`` / ``eos_token_id`` fall back to ``1`` / ``[2, 3]``
      when not given, and are also forwarded to ``PretrainedConfig``.

    Any remaining keyword arguments are passed to ``PretrainedConfig``.
    """

    model_type = "step3p5"

    def __init__(
        self,
        hidden_size: int = 5120,
        intermediate_size: int = 13312,
        num_attention_heads: int = 40,
        num_attention_groups: int = 8,
        num_hidden_layers: int = 48,
        max_seq_len: int = 4096,
        vocab_size: int = 65536,
        rms_norm_eps: float = 1e-5,
        moe_every_n_layer: int = 2,
        use_moe: bool = False,
        moe_intermediate_size: int = 10240,
        moe_num_experts: int = 16,
        moe_top_k: int = 4,
        moe_layer_offset: int = 0,
        rope_theta: float | list[float] | None = 500000,
        rope_scaling: dict[str, Any] | None = None,
        head_dim: int | None = None,
        share_expert_dim: int | None = None,
        norm_expert_weight: bool = True,
        bos_token_id: list[int] | int | None = None,
        eos_token_id: list[int] | int | None = None,
        moe_router_activation: str = "softmax",
        moe_router_scaling_factor: float = 1.0,
        att_impl_type: str = "GQA",
        use_head_wise_attn_gate: bool = False,
        use_moe_router_bias: bool = True,
        need_fp32_gate: bool = True,
        layer_types: list[str] | None = None,
        use_rope_layers: list[bool] | None = None,
        yarn_only_types: list[str] | None = None,
        attention_other_setting: dict[str, Any] | None = None,
        num_nextn_predict_layers: int = 0,
        swiglu_limits: list[float] | None = None,
        swiglu_limits_shared: list[float] | None = None,
        max_position_embeddings: int | None = None,
        **kwargs,
    ):
        # --- sizes, depth and normalisation -------------------------------
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.vocab_size = vocab_size
        self.rms_norm_eps = rms_norm_eps
        self.max_seq_len = max_seq_len
        self.max_position_embeddings = max_position_embeddings
        self.num_nextn_predict_layers = num_nextn_predict_layers

        # --- attention ----------------------------------------------------
        self.num_attention_heads = num_attention_heads
        self.num_attention_groups = num_attention_groups
        self.head_dim = head_dim
        self.att_impl_type = att_impl_type
        self.use_head_wise_attn_gate = use_head_wise_attn_gate
        self.layer_types = layer_types
        self.attention_other_setting = attention_other_setting

        # --- rotary embeddings --------------------------------------------
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self.use_rope_layers = use_rope_layers
        self.yarn_only_types = yarn_only_types

        # --- mixture of experts -------------------------------------------
        self.use_moe = use_moe
        self.moe_every_n_layer = moe_every_n_layer
        self.moe_layer_offset = moe_layer_offset
        self.moe_intermediate_size = moe_intermediate_size
        self.moe_num_experts = moe_num_experts
        self.moe_top_k = moe_top_k
        # Same value exposed under a second attribute name
        self.num_experts_per_tok = moe_top_k
        self.share_expert_dim = (
            moe_intermediate_size * moe_top_k
            if share_expert_dim is None
            else share_expert_dim
        )
        self.norm_expert_weight = norm_expert_weight
        self.moe_router_activation = moe_router_activation
        self.moe_router_scaling_factor = moe_router_scaling_factor
        self.use_moe_router_bias = use_moe_router_bias
        self.need_fp32_gate = need_fp32_gate

        # --- SwiGLU limits ------------------------------------------------
        self.swiglu_limits = swiglu_limits
        self.swiglu_limits_shared = swiglu_limits_shared

        # --- special tokens -----------------------------------------------
        # Resolve the defaults once, then store them here and hand the very
        # same values to the base class.
        if bos_token_id is None:
            bos_token_id = 1
        if eos_token_id is None:
            eos_token_id = [2, 3]
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id

        super().__init__(
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            **kwargs,
        )
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment