Commit 0da93439 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.18.1rc0' into v0.18.1rc0-ori

parents 25f2f756 298e5108
......@@ -7,9 +7,12 @@ from typing import Any
import msgspec
from vllm.config import ModelConfig, PoolerConfig
from vllm.logger import init_logger
from vllm.sampling_params import RequestOutputKind
from vllm.tasks import PoolingTask
logger = init_logger(__name__)
class LateInteractionParams(
msgspec.Struct,
......@@ -54,10 +57,6 @@ class PoolingParams(
dimensions: int | None = None
# --8<-- [end:embed-pooling-params]
## for classification, scoring and rerank
# --8<-- [start:classify-pooling-params]
# --8<-- [end:classify-pooling-params]
## for step pooling models
step_tag_id: int | None = None
returned_token_ids: list[int] | None = None
......@@ -79,7 +78,6 @@ class PoolingParams(
return {
"embed": ["dimensions", "use_activation"],
"classify": ["use_activation"],
"score": ["use_activation"],
"token_embed": ["dimensions", "use_activation"],
"token_classify": ["use_activation"],
}
......@@ -89,6 +87,13 @@ class PoolingParams(
return deepcopy(self)
def verify(self, model_config: ModelConfig) -> None:
if self.task == "score":
logger.warning_once(
"`score` task is deprecated and will be removed in v0.20. "
"Please use `classify` instead."
)
self.task = "classify"
# plugin task uses io_processor.parse_request to verify inputs,
# skipping PoolingParams verify
if self.task == "plugin":
......@@ -96,6 +101,10 @@ class PoolingParams(
self.skip_reading_prefix_cache = True
return
# skipping verify, let plugins configure and validate pooling params
if self.task not in self.valid_parameters:
return
# NOTE: Task validation needs to done against the model instance,
# which is not available in model config. So, it's not included
# in this method
......@@ -180,7 +189,7 @@ class PoolingParams(
elif self.dimensions < 1:
raise ValueError("Dimensions must be greater than 0")
elif self.task in ["classify", "score", "token_classify"]:
elif self.task in ["classify", "token_classify"]:
if self.use_activation is None:
self.use_activation = True
else:
......
......@@ -17,9 +17,7 @@ class NemotronV3ReasoningParser(DeepSeekR1ReasoningParser):
def extract_reasoning(
self, model_output: str, request: ChatCompletionRequest | ResponsesRequest
) -> tuple[str | None, str | None]:
reasoning_content, final_content = super().extract_reasoning(
model_output, request
)
reasoning, final_content = super().extract_reasoning(model_output, request)
chat_template_kwargs = getattr(request, "chat_template_kwargs", None)
if (
......@@ -30,6 +28,6 @@ class NemotronV3ReasoningParser(DeepSeekR1ReasoningParser):
)
and final_content is None
):
reasoning_content, final_content = final_content, reasoning_content
reasoning, final_content = final_content, reasoning
return reasoning_content, final_content
return reasoning, final_content
......@@ -172,9 +172,6 @@ class BaseRenderer(ABC, Generic[_T]):
For chat requests:
- Jinja2 template compilation
For multi-modal requests:
- Importing libraries such as librosa triggers JIT compilation.
"""
from vllm.entrypoints.chat_utils import ChatTemplateResolutionError
......@@ -700,12 +697,20 @@ class BaseRenderer(ABC, Generic[_T]):
enc_prompt = prompt["encoder_prompt"]
dec_prompt = prompt["decoder_prompt"]
skip_decoder_start_token = False
if self.mm_processor is not None:
from vllm.multimodal.processing import EncDecMultiModalProcessor
if isinstance(self.mm_processor, EncDecMultiModalProcessor):
skip_decoder_start_token = self.mm_processor.skip_decoder_start_token
return build_enc_dec_inputs(
encoder_inputs=self._process_singleton(enc_prompt),
decoder_inputs=(
None if dec_prompt is None else self._process_singleton(dec_prompt)
),
decoder_start_token_id=self.get_dec_start_token_id(),
skip_decoder_start_token=skip_decoder_start_token,
)
def process_for_engine(
......
......@@ -6,14 +6,15 @@ GenerationTask = Literal["generate", "transcription", "realtime"]
GENERATION_TASKS: tuple[GenerationTask, ...] = get_args(GenerationTask)
PoolingTask = Literal[
"embed", "classify", "score", "token_embed", "token_classify", "plugin"
"embed",
"classify",
"token_embed",
"token_classify",
"plugin",
"embed&token_classify",
]
POOLING_TASKS: tuple[PoolingTask, ...] = get_args(PoolingTask)
# Score API handles score/rerank for:
# - "score" task (score_type: cross-encoder models)
# - "embed" task (score_type: bi-encoder models)
# - "token_embed" task (score_type: late interaction models)
ScoreType = Literal["bi-encoder", "cross-encoder", "late-interaction"]
FrontendTask = Literal["render"]
......
......@@ -15,8 +15,15 @@ from mistral_common.protocol.instruct.validator import ValidationMode
from mistral_common.tokens.tokenizers.base import (
SpecialTokenPolicy,
SpecialTokens,
Tokenizer,
)
from mistral_common.tokens.tokenizers.instruct import (
InstructTokenizerBase,
InstructTokenizerV13,
)
from mistral_common.tokens.tokenizers.mistral import (
MistralTokenizer as MistralCommonTokenizer,
)
from mistral_common.tokens.tokenizers.instruct import InstructTokenizerV13
from mistral_common.tokens.tokenizers.sentencepiece import (
SentencePieceTokenizer,
)
......@@ -26,21 +33,20 @@ from pydantic import ValidationError
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionRequest
from vllm.logger import init_logger
from vllm.tokenizers.protocol import TokenizerLike
from .protocol import TokenizerLike
try:
# Transformers v5
from transformers.tokenization_mistral_common import MistralCommonBackend
except ImportError:
# Transformers v4
from transformers.tokenization_mistral_common import (
MistralCommonTokenizer as MistralCommonBackend,
)
if TYPE_CHECKING:
from transformers import BatchEncoding
try:
# Transformers v5
from transformers.tokenization_mistral_common import MistralCommonBackend
except ImportError:
# Transformers v4
from transformers.tokenization_mistral_common import (
MistralCommonTokenizer as MistralCommonBackend,
)
logger = init_logger(__name__)
......@@ -235,15 +241,6 @@ class MistralTokenizer(TokenizerLike):
download_dir: str | None = None,
**kwargs,
) -> "MistralTokenizer":
try:
# Transformers v5
from transformers.tokenization_mistral_common import MistralCommonBackend
except ImportError:
# Transformers v4
from transformers.tokenization_mistral_common import (
MistralCommonTokenizer as MistralCommonBackend,
)
tokenizer = MistralCommonBackend.from_pretrained(
path_or_repo_id,
*args,
......@@ -255,13 +252,13 @@ class MistralTokenizer(TokenizerLike):
return cls(tokenizer)
def __init__(self, tokenizer: "MistralCommonBackend") -> None:
def __init__(self, tokenizer: MistralCommonBackend) -> None:
super().__init__()
self.transformers_tokenizer = tokenizer
self.mistral = tokenizer.tokenizer
self.instruct = self.mistral.instruct_tokenizer
self.tokenizer = self.instruct.tokenizer
self.transformers_tokenizer: MistralCommonBackend = tokenizer
self.mistral: MistralCommonTokenizer = tokenizer.tokenizer
self.instruct: InstructTokenizerBase = self.mistral.instruct_tokenizer
self.tokenizer: Tokenizer = self.instruct.tokenizer
mode = self.mistral._chat_completion_request_validator._mode
if mode != ValidationMode.test:
......@@ -483,7 +480,11 @@ class MistralTokenizer(TokenizerLike):
return self.transformers_tokenizer.convert_tokens_to_ids(tokens)
def convert_tokens_to_string(self, tokens: list[str]) -> str:
to_decode_special_tokens = {SpecialTokens.tool_calls}
to_decode_special_tokens = {
SpecialTokens.tool_calls,
SpecialTokens.begin_think,
SpecialTokens.end_think,
}
if self.is_tekken:
assert isinstance(self.tokenizer, Tekkenizer), type(self.tokenizer)
tokens = [
......
......@@ -61,6 +61,10 @@ def get_qwen_vl_tokenizer(tokenizer: HfTokenizer) -> HfTokenizer:
class QwenVLTokenizer(TokenizerLike):
image_start_tag: str
image_end_tag: str
image_pad_tag: str
@classmethod
def from_pretrained(cls, *args, **kwargs) -> HfTokenizer:
tokenizer = AutoTokenizer.from_pretrained(*args, **kwargs)
......
......@@ -6,8 +6,9 @@ import os
from collections.abc import Callable, Sequence
from functools import cached_property
from openai.types.responses.response_format_text_json_schema_config import (
from openai.types.responses import (
ResponseFormatTextJSONSchemaConfig,
ResponseTextConfig,
)
from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionRequest
......@@ -17,7 +18,6 @@ from vllm.entrypoints.openai.engine.protocol import (
)
from vllm.entrypoints.openai.responses.protocol import (
ResponsesRequest,
ResponseTextConfig,
)
from vllm.logger import init_logger
from vllm.sampling_params import (
......
......@@ -48,41 +48,12 @@ class DeepSeekV32ToolParser(ToolParser):
self.prev_tool_call_arr: list[dict] = []
# Sentinel tokens
self.dsml_token: str = "|DSML|"
self.dsml_start_check: str = "<" + self.dsml_token
# Sentinel token
self.tool_call_start_token: str = "<|DSML|function_calls>"
self.tool_call_end_token: str = "</|DSML|function_calls>"
self.invoke_start_prefix: str = "<|DSML|invoke name="
self.invoke_end_token: str = "</|DSML|invoke>"
self.parameter_prefix: str = "<|DSML|parameter name="
self.parameter_end_token: str = "</|DSML|parameter>"
# Streaming state variables
self.current_tool_name_sent: bool = False
# Override base class type - we use string IDs for tool calls
self.current_tool_id: str | None = None # type: ignore
self.streamed_args_for_tool: list[str] = []
self.is_tool_call_started: bool = False
self.failed_count: int = 0
# Initialize streaming state variables
# Streaming state
self.is_tool_call_started: bool = False
self.current_tool_index: int = 0
self.invoke_index: int = 0
self.header_sent: bool = False
self.current_function_name: str | None = None
self.current_param_name: str | None = None
self.current_param_value: str = ""
self.param_count: int = 0
self.in_param: bool = False
self.in_function: bool = False
self.json_started: bool = False
self.json_closed: bool = False
self.accumulated_params: dict = {}
self.streaming_request: ChatCompletionRequest | None = None
# Enhanced streaming state - reset for each new message
self._reset_streaming_state()
# Regex patterns for complete parsing
self.tool_call_complete_regex = re.compile(
......@@ -106,10 +77,6 @@ class DeepSeekV32ToolParser(ToolParser):
"vLLM Successfully import tool parser %s !", self.__class__.__name__
)
def _generate_tool_call_id(self) -> str:
"""Generate a unique tool call ID."""
return f"call_{uuid.uuid4().hex[:24]}"
def adjust_request(self, request):
request = super().adjust_request(request)
if request.tools and request.tool_choice != "none":
......@@ -122,33 +89,77 @@ class DeepSeekV32ToolParser(ToolParser):
request.skip_special_tokens = False
return request
def _reset_streaming_state(self):
"""Reset all streaming state."""
self.current_tool_index = 0
self.invoke_index = 0
self.is_tool_call_started = False
self.header_sent = False
self.current_tool_id = None
self.current_function_name = None
self.current_param_name = None
self.current_param_value = ""
self.param_count = 0
self.in_param = False
self.in_function = False
self.json_started = False
self.json_closed = False
# Store accumulated parameters for type conversion
self.accumulated_params = {}
self.streaming_request = None
# Clear previous tool call history to avoid state pollution
self.prev_tool_call_arr.clear()
def _generate_tool_call_id(self) -> str:
"""Generate a unique tool call ID."""
return f"call_{uuid.uuid4().hex[:24]}"
def _parse_invoke_params(self, invoke_str: str) -> dict | None:
def _parse_invoke_params(self, invoke_str: str) -> dict:
param_dict = dict()
for param_name, param_val in self.parameter_complete_regex.findall(invoke_str):
param_dict[param_name] = param_val
return param_dict
def _convert_param_value(self, value: str, param_type: str) -> Any:
"""Convert parameter value to the correct type."""
if value.lower() == "null":
return None
param_type = param_type.lower()
if param_type in ["string", "str", "text"]:
return value
elif param_type in ["integer", "int"]:
try:
return int(value)
except (ValueError, TypeError):
return value
elif param_type in ["number", "float"]:
try:
val = float(value)
return val if val != int(val) else int(val)
except (ValueError, TypeError):
return value
elif param_type in ["boolean", "bool"]:
return value.lower() in ["true", "1"]
elif param_type in ["object", "array"]:
try:
return json.loads(value)
except json.JSONDecodeError:
return value
else:
# Try JSON parse first, fallback to string
try:
return json.loads(value)
except json.JSONDecodeError:
return value
def _convert_params_with_schema(
self,
function_name: str,
param_dict: dict[str, str],
request: ChatCompletionRequest | None,
) -> dict[str, Any]:
"""Convert raw string param values using the tool schema types."""
param_config: dict = {}
if request and request.tools:
for tool in request.tools:
if (
hasattr(tool, "function")
and tool.function.name == function_name
and hasattr(tool.function, "parameters")
):
schema = tool.function.parameters
if isinstance(schema, dict) and "properties" in schema:
param_config = schema["properties"]
break
converted: dict[str, Any] = {}
for name, value in param_dict.items():
param_type = "string"
if name in param_config and isinstance(param_config[name], dict):
param_type = param_config[name].get("type", "string")
converted[name] = self._convert_param_value(value, param_type)
return converted
def extract_tool_calls(
self,
model_output: str,
......@@ -200,56 +211,55 @@ class DeepSeekV32ToolParser(ToolParser):
tools_called=False, tool_calls=[], content=model_output
)
def _extract_name(self, name_str: str) -> str:
"""Extract name from quoted string."""
name_str = name_str.strip()
if (
name_str.startswith('"')
and name_str.endswith('"')
or name_str.startswith("'")
and name_str.endswith("'")
):
return name_str[1:-1]
return name_str
def _extract_param_name(self, input_str: str) -> str:
"""Extract param name"""
start = input_str.find('"') + 1
end = input_str.find('"', start)
return input_str[start:end] if start > 0 and end > start else input_str
def _reset_streaming_state(self):
"""Reset all streaming state."""
self.current_tool_index = 0
self.is_tool_call_started = False
self.prev_tool_call_arr.clear()
self.streamed_args_for_tool.clear()
def _convert_param_value(self, value: str, param_type: str) -> Any:
"""Convert parameter value to the correct type."""
if value.lower() == "null":
return None
def _extract_delta_tool_calls(
self,
current_text: str,
request: ChatCompletionRequest | None,
) -> list[DeltaToolCall]:
"""Extract DeltaToolCalls from newly completed <invoke> blocks.
Tracks progress via ``current_tool_index`` so each block is
extracted exactly once across successive streaming calls.
"""
complete_invokes = self.invoke_complete_regex.findall(current_text)
delta_tool_calls: list[DeltaToolCall] = []
while len(complete_invokes) > self.current_tool_index:
invoke_name, invoke_body = complete_invokes[self.current_tool_index]
param_dict = self._parse_invoke_params(invoke_body)
converted = self._convert_params_with_schema(
invoke_name, param_dict, request
)
args_json = json.dumps(converted, ensure_ascii=False)
idx = self.current_tool_index
self.current_tool_index += 1
param_type = param_type.lower()
if param_type in ["string", "str", "text"]:
return value
elif param_type in ["integer", "int"]:
try:
return int(value)
except (ValueError, TypeError):
return value
elif param_type in ["number", "float"]:
try:
val = float(value)
return val if val != int(val) else int(val)
except (ValueError, TypeError):
return value
elif param_type in ["boolean", "bool"]:
return value.lower() in ["true", "1"]
elif param_type in ["object", "array"]:
try:
return json.loads(value)
except json.JSONDecodeError:
return value
else:
# Try JSON parse first, fallback to string
try:
return json.loads(value)
except json.JSONDecodeError:
return value
self.prev_tool_call_arr.append(
{"name": invoke_name, "arguments": converted}
)
self.streamed_args_for_tool.append(args_json)
delta_tool_calls.append(
DeltaToolCall(
index=idx,
id=self._generate_tool_call_id(),
function=DeltaFunctionCall(
name=invoke_name,
arguments=args_json,
),
type="function",
)
)
return delta_tool_calls
def extract_tool_calls_streaming(
self,
......@@ -261,345 +271,44 @@ class DeepSeekV32ToolParser(ToolParser):
delta_token_ids: Sequence[int],
request: ChatCompletionRequest,
) -> DeltaMessage | None:
"""Extract tool calls from streaming model output."""
"""Extract tool calls from streaming model output.
Uses a buffer-until-complete-invoke strategy: tokens are buffered
until a complete invoke block is available, then parsed and emitted
in one shot.
"""
# Store request for type conversion
# First chunk of a new stream — reset state from prior request.
if not previous_text:
self._reset_streaming_state()
self.streaming_request = request
# If no delta text, return None unless it's an EOS token after tools
if not delta_text:
# Check if this is an EOS token after all tool calls are complete
if delta_token_ids:
# Count complete tool calls
complete_calls = len(
self.tool_call_complete_regex.findall(current_text)
)
# If we have completed tool calls and populated prev_tool_call_arr
if complete_calls > 0 and len(self.prev_tool_call_arr) > 0:
# Check if all tool calls are closed
open_calls = current_text.count(
self.tool_call_start_token
) - current_text.count(self.tool_call_end_token)
if open_calls == 0:
# Return empty delta for finish_reason processing
return DeltaMessage(content="")
elif not self.is_tool_call_started and current_text:
# This is a regular content response that's now complete
return DeltaMessage(content="")
return None
# Check if we need to advance to next tool
if self.json_closed and not self.in_function:
# Check if this tool call has ended
invoke_ends = current_text.count(self.invoke_end_token)
if invoke_ends > self.current_tool_index:
# This tool has ended, advance to next
self.current_tool_index += 1
self.header_sent = False
self.param_count = 0
self.json_started = False
self.json_closed = False
self.in_function = False # Now we can safely set this to False
self.accumulated_params = {}
# Continue processing next tool
return None
# Handle normal content before tool calls
if not self.is_tool_call_started:
# Check if tool call is starting
if self.dsml_token in current_text:
self.is_tool_call_started = True
# Return any content before the tool call
if self.dsml_start_check in delta_text:
content_before = delta_text[
: delta_text.index(self.dsml_start_check)
]
if content_before:
return DeltaMessage(content=content_before)
return None
else:
# Check if we're between tool calls - skip whitespace
if (
current_text.rstrip().endswith(self.tool_call_end_token)
and delta_text.strip() == ""
):
# We just ended a tool call, skip whitespace
return None
# Normal content, no tool call
if delta_text.endswith("<"):
return DeltaMessage(content=delta_text[:-1])
if previous_text and previous_text.endswith("<"):
return DeltaMessage(content="<" + delta_text)
return DeltaMessage(content=delta_text)
# Check if we're between tool calls (waiting for next one)
invoke_starts_count = current_text.count(self.invoke_start_prefix)
if self.current_tool_index >= invoke_starts_count:
# We're past all tool calls, shouldn't be here
return None
# Find the current tool call portion
invoke_start_positions: list[int] = []
idx = 0
while True:
idx = current_text.find(self.invoke_start_prefix, idx)
if idx == -1:
break
invoke_start_positions.append(idx)
idx += len(self.invoke_start_prefix)
if self.current_tool_index >= len(invoke_start_positions):
# No more tool calls to process yet
return None
invoke_start_idx = invoke_start_positions[self.current_tool_index]
# Find where this tool call ends (or current position if not ended yet)
invoke_end_idx = current_text.find(self.invoke_end_token, invoke_start_idx)
if invoke_end_idx == -1:
tool_text = current_text[invoke_start_idx:]
# Detect whether we've entered the tool-call region.
# Use current_text (not delta_text) since the start token may
# be split across chunks.
content_before = None
if self.is_tool_call_started:
pass
elif self.tool_call_start_token in current_text:
# Tool-call region found, capture any plain text before it.
self.is_tool_call_started = True
start_idx = current_text.index(self.tool_call_start_token)
content_before = current_text[len(previous_text) : start_idx] or None
else:
tool_text = current_text[
invoke_start_idx : invoke_end_idx + len(self.invoke_end_token)
]
# Looking for function header
if not self.header_sent:
if self.invoke_start_prefix in tool_text:
func_start = tool_text.find(self.invoke_start_prefix) + len(
self.invoke_start_prefix
)
# Find the end quote for the function name
func_end = tool_text.find(">", func_start)
if func_end != -1:
# Found complete function name
function_name_raw = tool_text[func_start:func_end]
self.current_function_name = self._extract_name(function_name_raw)
self.current_tool_id = self._generate_tool_call_id()
self.header_sent = True
self.in_function = True
# Add to prev_tool_call_arr immediately when we detect a tool call
# Each tool call should be recorded regardless of function name
# Ensure we don't add the same tool call index multiple times
if len(self.prev_tool_call_arr) <= self.current_tool_index:
self.prev_tool_call_arr.append(
{
"name": self.current_function_name,
"arguments": "{}", # Placeholder, will be updated later
}
)
# Still in plain-text region, forward as content.
return DeltaMessage(content=delta_text) if delta_text else None
# Send header with function info
return DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.current_tool_index,
id=self.current_tool_id,
function=DeltaFunctionCall(
name=self.current_function_name, arguments=""
),
type="function",
)
]
)
return None
# Inside tool-call region: emit any newly completed invokes.
delta_tool_calls = self._extract_delta_tool_calls(current_text, request)
# We've sent header, now handle function body
if self.in_function:
# Send opening brace if not sent yet
if self.in_function and not self.json_started:
self.json_started = True
return DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.current_tool_index,
function=DeltaFunctionCall(arguments="{"),
)
]
)
# Make sure json_started is set if we're processing parameters
if not self.json_started:
self.json_started = True
# Check for function end in accumulated text
if not self.json_closed and self.invoke_end_token in tool_text:
# Count total parameters in the tool text
total_param_count = tool_text.count(self.parameter_prefix)
# Only close JSON if all parameters have been processed
if self.param_count >= total_param_count:
# Close JSON
self.json_closed = True
# Extract complete tool call
# Find the invoke content
invoke_start = tool_text.find(self.invoke_start_prefix) + len(
self.invoke_start_prefix
)
invoke_content_end = tool_text.find(
self.invoke_end_token, invoke_start
)
if invoke_content_end != -1:
invoke_content = tool_text[invoke_start:invoke_content_end]
# Parse to get the complete arguments
try:
invoke_params = self._parse_invoke_params(invoke_content)
if invoke_params and self.current_tool_index < len(
self.prev_tool_call_arr
):
# Update existing entry in prev_tool_call_arr
self.prev_tool_call_arr[self.current_tool_index][
"arguments"
] = json.dumps(invoke_params, ensure_ascii=False)
except Exception:
pass # Ignore parsing errors during streaming
result = DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.current_tool_index,
function=DeltaFunctionCall(arguments="}"),
)
]
)
# Reset state for next tool
self.json_closed = True
self.in_function = False
self.accumulated_params = {}
logger.debug("[M2_STREAMING] Tool call completed")
return result
else:
# Don't close JSON yet, continue processing parameters
return None
# Look for parameters
# Find all parameter starts
param_starts = []
idx = 0
while True:
idx = tool_text.find(self.parameter_prefix, idx)
if idx == -1:
break
param_starts.append(idx)
idx += len(self.parameter_prefix)
# Check if we should start a new parameter
if (
not self.in_param
and self.param_count < len(param_starts)
and len(param_starts) > self.param_count
):
# Process the next parameter
param_idx = param_starts[self.param_count]
param_start = param_idx + len(self.parameter_prefix)
remaining = tool_text[param_start:]
if ">" in remaining:
# We have the complete parameter name
name_end = remaining.find(">")
param_name_raw = remaining[:name_end]
self.current_param_name = self._extract_param_name(param_name_raw)
# Find the parameter value
value_start = param_start + name_end + 1
value_text = tool_text[value_start:]
if value_text.startswith("\n"):
value_text = value_text[1:]
# Find where this parameter ends
param_end_idx = value_text.find(self.parameter_end_token)
if param_end_idx == -1:
# No closing tag, look for next parameter or function end
next_param_idx = value_text.find(self.parameter_prefix)
func_end_idx = value_text.find(self.invoke_end_token)
if next_param_idx != -1 and (
func_end_idx == -1 or next_param_idx < func_end_idx
):
param_end_idx = next_param_idx
elif func_end_idx != -1:
param_end_idx = func_end_idx
else:
# Neither found, check if tool call is complete
if self.invoke_end_token in tool_text:
# Tool call and parameter is complete
param_end_idx = len(value_text)
else:
# Still streaming, wait for more content
return None
if param_end_idx != -1:
# Complete parameter found
param_value = value_text[:param_end_idx]
if param_value.endswith("\n"):
param_value = param_value[:-1]
# Store raw value for later processing
self.accumulated_params[self.current_param_name] = param_value
# Get parameter configuration for type conversion
param_config = {}
if self.streaming_request and self.streaming_request.tools:
for tool in self.streaming_request.tools:
if (
hasattr(tool, "function")
and tool.function.name == self.current_function_name
and hasattr(tool.function, "parameters")
):
params = tool.function.parameters
if (
isinstance(params, dict)
and "properties" in params
):
param_config = params["properties"]
break
# Get parameter type
param_type = "string"
if (
self.current_param_name in param_config
and isinstance(param_config[self.current_param_name], dict)
and "type" in param_config[self.current_param_name]
):
param_type = param_config[self.current_param_name]["type"]
# Convert param value to appropriate type
converted_value = self._convert_param_value(
param_value, param_type
)
# Build JSON fragment based on the converted type
# Use json.dumps to properly serialize the value
serialized_value = json.dumps(
converted_value, ensure_ascii=False
)
if delta_tool_calls or content_before:
return DeltaMessage(
content=content_before,
tool_calls=delta_tool_calls,
)
if self.param_count == 0:
json_fragment = (
f'"{self.current_param_name}": {serialized_value}'
)
else:
json_fragment = (
f', "{self.current_param_name}": {serialized_value}'
)
self.param_count += 1
return DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.current_tool_index,
function=DeltaFunctionCall(arguments=json_fragment),
)
]
)
# Empty delta with token ids means EOS or closing tag; return
# non-None so the serving framework can finalize finish_reason.
if not delta_text and delta_token_ids and self.prev_tool_call_arr:
return DeltaMessage(content="")
return None
......@@ -25,7 +25,12 @@ from vllm.tool_parsers.abstract_tool_parser import ToolParser
logger = init_logger(__name__)
REGEX_FUNCTION_CALL = re.compile(
r"function call(?:<\|role_sep\|>\n)?(\{.*)",
r"(?:function call<\|role_sep\|>\n|<\|function_call\|>)(.*)",
re.DOTALL,
)
REGEX_CONTENT_PATTERN = re.compile(
r"^(.*?)(?:<\|message_sep\|>|<\|function_call\|>)",
re.DOTALL,
)
......@@ -47,57 +52,67 @@ class GigaChat3ToolParser(ToolParser):
self.tool_name_sent: bool = False
self.tool_id: str | None = None
self.prev_tool_call_arr: list[dict] = []
self.content_buffer: str = ""
self.trigger_start = "function call{"
self.end_content: bool = False
self.streamed_args_for_tool: list[str] = []
def adjust_request(self, request: ChatCompletionRequest) -> ChatCompletionRequest:
request = super().adjust_request(request)
if request.tools and request.tool_choice != "none":
request.skip_special_tokens = False
return request
def extract_tool_calls(
self,
model_output: str,
request: ChatCompletionRequest,
) -> ExtractedToolCallInformation:
match = REGEX_FUNCTION_CALL.search(model_output)
if not match:
return ExtractedToolCallInformation(
tools_called=False,
tool_calls=[],
content=model_output,
)
json_candidate = match.group(1).strip()
try:
data = json.loads(json_candidate)
except json.JSONDecodeError:
return ExtractedToolCallInformation(
tools_called=False,
tool_calls=[],
content=model_output,
)
if not (isinstance(data, dict) and "name" in data and "arguments" in data):
function_call = None
content = None
if model_output.rstrip().endswith("</s>"):
model_output = model_output[: model_output.rfind("</s>")]
m_func = REGEX_FUNCTION_CALL.search(model_output)
if m_func:
try:
function_call = json.loads(m_func.group(1), strict=False)
if (
isinstance(function_call, dict)
and "name" in function_call
and "arguments" in function_call
):
if not isinstance(function_call["arguments"], dict):
function_call = None
else:
function_call = None
except json.JSONDecodeError:
return ExtractedToolCallInformation(
tools_called=False,
tool_calls=[],
content=model_output,
)
m_content = REGEX_CONTENT_PATTERN.search(model_output)
content = m_content.group(1) if m_content else model_output
if not function_call:
return ExtractedToolCallInformation(
tools_called=False,
tool_calls=[],
content=model_output,
content=content if content else None,
)
name = data["name"]
args = data["arguments"]
name = function_call["name"]
args = function_call["arguments"]
if not isinstance(args, str):
args = json.dumps(args, ensure_ascii=False)
tool_calls = [
ToolCall(
type="function",
function=FunctionCall(
name=name,
arguments=args,
),
)
]
prefix = model_output[: match.start()]
content = prefix.rstrip() if prefix and prefix.strip() else None
args = json.dumps(function_call["arguments"], ensure_ascii=False)
return ExtractedToolCallInformation(
tools_called=True,
tool_calls=tool_calls,
content=content,
tool_calls=[
ToolCall(
type="function",
function=FunctionCall(
name=name,
arguments=args,
),
)
],
content=content if content else None,
)
def extract_tool_calls_streaming(
......@@ -110,39 +125,37 @@ class GigaChat3ToolParser(ToolParser):
delta_token_ids: Sequence[int],
request: ChatCompletionRequest,
) -> DeltaMessage | None:
content = None
func_name = None
cur_args = None
m_func = REGEX_FUNCTION_CALL.search(current_text)
if not self.tool_started:
match = REGEX_FUNCTION_CALL.search(current_text)
if match:
self.tool_started = True
self.content_buffer = ""
m_content = REGEX_CONTENT_PATTERN.search(delta_text)
if m_content:
content = m_content.group(1)
self.end_content = True
else:
self.content_buffer += delta_text
clean_buffer = self.content_buffer.lstrip()
is_prefix = self.trigger_start.startswith(clean_buffer)
starts_with_trigger = clean_buffer.startswith(self.trigger_start)
if is_prefix or starts_with_trigger:
return None
else:
flush_text = self.content_buffer
self.content_buffer = ""
return DeltaMessage(content=flush_text)
match = REGEX_FUNCTION_CALL.search(current_text)
if not match:
if not self.end_content:
content = delta_text
if m_func:
self.tool_started = True
if content:
return DeltaMessage(content=content)
if not m_func:
return None
json_tail = match.group(1).strip()
json_tail = m_func.group(1).strip()
name_match = NAME_REGEX.search(json_tail)
if name_match:
func_name = name_match.group(1)
args_match = ARGS_REGEX.search(json_tail)
if args_match:
cur_args = args_match.group(1).strip()
if cur_args.endswith("</s>"):
cur_args = cur_args[: -len("</s>")]
if cur_args.endswith("}"): # last '}' end of json
try:
candidate = cur_args[:-1].strip()
json.loads(candidate)
json.loads(candidate, strict=False)
cur_args = candidate
except json.JSONDecodeError:
pass
......@@ -165,11 +178,10 @@ class GigaChat3ToolParser(ToolParser):
).model_dump(exclude_none=True),
)
],
content=None,
)
if cur_args is None:
return None
prev_args = self.prev_tool_call_arr[0].get("arguments", "")
prev_args = self.prev_tool_call_arr[0].get("arguments_str", "")
if not prev_args:
delta_args = cur_args
elif cur_args.startswith(prev_args):
......@@ -178,7 +190,15 @@ class GigaChat3ToolParser(ToolParser):
return None
if not delta_args:
return None
self.prev_tool_call_arr[0]["arguments"] = cur_args
self.prev_tool_call_arr[0]["arguments_str"] = cur_args
try:
args_dict = json.loads(cur_args, strict=False)
self.prev_tool_call_arr[0]["arguments"] = args_dict
except json.JSONDecodeError:
self.prev_tool_call_arr[0]["arguments"] = {}
if len(self.streamed_args_for_tool) <= 0:
self.streamed_args_for_tool.append("")
self.streamed_args_for_tool[0] = cur_args
return DeltaMessage(
tool_calls=[
DeltaToolCall(
......@@ -188,5 +208,4 @@ class GigaChat3ToolParser(ToolParser):
).model_dump(exclude_none=True),
)
],
content=None,
)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
GLM-4.7 Tool Call Parser.
GLM-4.7 uses a slightly different tool call format compared to GLM-4.5:
- The function name may appear on the same line as ``<tool_call>`` without
a newline separator before the first ``<arg_key>``.
- Tool calls may have zero arguments
(e.g. ``<tool_call>func</tool_call>``).
This parser overrides the parent regex patterns to handle both formats.
"""
import regex as re
......@@ -14,10 +24,14 @@ logger = init_logger(__name__)
class Glm47MoeModelToolParser(Glm4MoeModelToolParser):
def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer)
# GLM-4.7 format: <tool_call>func_name[<arg_key>...]*</tool_call>
# The function name can be followed by a newline, whitespace, or
# directly by <arg_key> tags (no separator). The arg section is
# optional so that zero-argument calls are supported.
self.func_detail_regex = re.compile(
r"<tool_call>(.*?)(<arg_key>.*?)?</tool_call>", re.DOTALL
r"<tool_call>\s*(\S+?)\s*(<arg_key>.*)?</tool_call>", re.DOTALL
)
self.func_arg_regex = re.compile(
r"<arg_key>(.*?)</arg_key>(?:\\n|\s)*<arg_value>(.*?)</arg_value>",
r"<arg_key>(.*?)</arg_key>\s*<arg_value>(.*?)</arg_value>",
re.DOTALL,
)
......@@ -206,7 +206,12 @@ class Glm4MoeModelToolParser(ToolParser):
)
else:
if len(tool_calls) > 0:
content = model_output[: model_output.find(self.tool_calls_start_token)]
content: str | None = model_output[
: model_output.find(self.tool_calls_start_token)
]
# Normalize empty/whitespace-only content to None
if not content or not content.strip():
content = None
return ExtractedToolCallInformation(
tools_called=True, tool_calls=tool_calls, content=content
)
......
......@@ -241,7 +241,10 @@ class MistralToolParser(ToolParser):
delta_token_ids: Sequence[int],
request: ChatCompletionRequest,
) -> DeltaMessage | None:
if self.bot_token_id not in current_token_ids:
has_bot_token = (
self.bot_token_id in current_token_ids or self.bot_token in current_text
)
if not has_bot_token:
# if the tool call token is not in the tokens generated so far,
# append output to contents since it's not a tool
return DeltaMessage(content=delta_text)
......@@ -275,7 +278,8 @@ class MistralToolParser(ToolParser):
additional_content: str = ""
if self.streaming_state == StreamingState.WAITING_FOR_TOOL_START:
# this is the first tool call
assert self.bot_token_id in delta_token_ids
if self.bot_token not in delta_text:
return DeltaMessage(content=delta_text)
if not delta_text.startswith(self.bot_token):
additional_content += delta_text.split(self.bot_token)[0]
delta_text = self.bot_token + "".join(
......@@ -411,7 +415,7 @@ class MistralToolParser(ToolParser):
index=self.current_tool_id, type="function"
)
current_tool_call_modified = False
if self.bot_token_id in delta_token_ids:
if self.bot_token_id in delta_token_ids or self.bot_token in delta_text:
# this is the first tool call
if not delta_text.startswith(self.bot_token):
content = delta_text.split(self.bot_token)[0]
......
......@@ -295,7 +295,7 @@ class StreamingXMLToolCallParser:
final_delta = DeltaMessage(
role=None,
content=None,
reasoning_content=None,
reasoning=None,
tool_calls=[
DeltaToolCall(
index=self.tool_call_index - 1,
......
......@@ -55,7 +55,7 @@ _CLASS_TO_MODULE: dict[str, str] = {
"OvisConfig": "vllm.transformers_utils.configs.ovis",
"PixelShuffleSiglip2VisionConfig": "vllm.transformers_utils.configs.isaac",
"RadioConfig": "vllm.transformers_utils.configs.radio",
"SpeculatorsConfig": "vllm.transformers_utils.configs.speculators.base",
"SpeculatorsConfig": "vllm.transformers_utils.configs.speculators",
"UltravoxConfig": "vllm.transformers_utils.configs.ultravox",
"Step3VLConfig": "vllm.transformers_utils.configs.step3_vl",
"Step3VisionEncoderConfig": "vllm.transformers_utils.configs.step3_vl",
......
......@@ -27,7 +27,6 @@ class ColPaliConfig(PaliGemmaConfig):
embedding_dim: int | None = None,
embed_dim: int | None = None,
dim: int | None = None,
projection_dim: int | None = None,
colbert_dim: int | None = None,
pooling: str | None = None,
vlm_config: dict | None = None,
......@@ -37,7 +36,6 @@ class ColPaliConfig(PaliGemmaConfig):
self.embedding_dim = embedding_dim
self.embed_dim = embed_dim
self.dim = dim
self.projection_dim = projection_dim
self.colbert_dim = colbert_dim
self.pooling = pooling
......
......@@ -90,8 +90,6 @@ class MlpProjectorConfig(PretrainedConfig):
class DeepseekVLV2Config(PretrainedConfig):
model_type = "deepseek_vl_v2"
architectures: list[str] | None = None
vision_config: VisionEncoderConfig
projector_config: MlpProjectorConfig
tile_tag: str = "2D"
global_view_pos: str = "head"
......
......@@ -257,7 +257,6 @@ def _remap_mistral_audio_args(config: dict) -> dict:
encoder_attention_heads=encoder_args["n_heads"],
encoder_head_dim=encoder_args["head_dim"],
vocab_size=encoder_args["vocab_size"],
max_source_positions=encoder_args["max_source_positions"],
is_encoder_decoder=False, # Override WhisperConfig default
is_causal=encoder_args.get("causal", False),
sliding_window=encoder_args.get("sliding_window", None),
......@@ -270,6 +269,10 @@ def _remap_mistral_audio_args(config: dict) -> dict:
max_position_embeddings=block_pool_size * config["max_position_embeddings"],
),
}
# Sometimes max_source_positions is explicitly set to None in params.json but this
# is not a valid value for WhisperConfig (or downstream code that uses it).
if (max_source_positions := encoder_args.get("max_source_positions")) is not None:
config["audio_config"].max_source_positions = max_source_positions
if quant_config:
config["quantization_config"] = quant_config
return config
......
......@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from transformers.configuration_utils import PretrainedConfig, layer_type_validation
from transformers.configuration_utils import PretrainedConfig
class OlmoHybridConfig(PretrainedConfig):
......@@ -228,7 +228,15 @@ class OlmoHybridConfig(PretrainedConfig):
if "full_attention" not in layer_types:
layer_types[-1] = "full_attention"
layer_type_validation(layer_types, num_hidden_layers)
if hasattr(self, "validate_layer_type"):
# Transformers v5
self.layer_types = layer_types
self.validate_layer_type()
else:
# Transformers v4
from transformers.configuration_utils import layer_type_validation
layer_type_validation(layer_types, num_hidden_layers)
if "linear_attention" not in layer_types:
raise ValueError(
"OLMoHybrid expects at least one 'linear_attention' layer."
......
......@@ -6,11 +6,21 @@ from transformers import ParakeetEncoderConfig, PretrainedConfig
class ParakeetConfig(ParakeetEncoderConfig):
llm_hidden_size: int
projection_hidden_size: int
projection_bias: bool
projection_eps: float = 1e-5
sampling_rate: int
def __init__(
self,
llm_hidden_size: int,
projection_hidden_size: int,
projection_bias: bool,
sampling_rate: int,
projection_eps: float = 1e-5,
**kwargs,
):
super().__init__(**kwargs)
self.llm_hidden_size = llm_hidden_size
self.projection_hidden_size = projection_hidden_size
self.projection_bias = projection_bias
self.sampling_rate = sampling_rate
self.projection_eps = projection_eps
@staticmethod
def from_hf_config(
......
......@@ -16,7 +16,7 @@
# limitations under the License.
"""Qwen3.5 model configuration"""
from transformers.configuration_utils import PretrainedConfig, layer_type_validation
from transformers.configuration_utils import PretrainedConfig
class Qwen3_5TextConfig(PretrainedConfig):
......@@ -68,10 +68,6 @@ class Qwen3_5TextConfig(PretrainedConfig):
eos_token_id=None,
**kwargs,
):
kwargs["ignore_keys_at_rope_validation"] = [
"mrope_section",
"mrope_interleaved",
]
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
......@@ -98,7 +94,18 @@ class Qwen3_5TextConfig(PretrainedConfig):
else "full_attention"
for i in range(self.num_hidden_layers)
]
layer_type_validation(self.layer_types, self.num_hidden_layers)
if hasattr(self, "validate_layer_type"):
# Transformers v5
kwargs["ignore_keys_at_rope_validation"] = {
"mrope_section",
"mrope_interleaved",
}
self.validate_layer_type()
else:
# Transformers v4
from transformers.configuration_utils import layer_type_validation
layer_type_validation(self.layer_types, self.num_hidden_layers)
# linear attention part
self.linear_conv_kernel_dim = linear_conv_kernel_dim
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment