Unverified commit 582bbe6b, authored by bigmoyan and committed by GitHub
Browse files

[Fix] correct tool_id for kimi-k2 when use tool_choice=required (#21259)


Co-authored-by: wangzhengtao <wangzhengtao@msh.team>
parent 0cdbf5e6
...@@ -13,48 +13,7 @@ from ...utils import RemoteOpenAIServer ...@@ -13,48 +13,7 @@ from ...utils import RemoteOpenAIServer
# any model with a chat template should work here # any model with a chat template should work here
MODEL_NAME = "Qwen/Qwen3-0.6B" MODEL_NAME = "Qwen/Qwen3-0.6B"
tools = [
@pytest.fixture(scope="module")
def server(): # noqa: F811
args = [
# use half precision for speed and memory savings in CI environment
"--dtype",
"half",
"--enable-auto-tool-choice",
"--guided-decoding-backend",
"xgrammar",
"--tool-call-parser",
"hermes",
"--reasoning-parser",
"qwen3",
]
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
yield remote_server
@pytest_asyncio.fixture
async def client(server):
async with server.get_async_client() as async_client:
yield async_client
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize("stream", [True, False])
@pytest.mark.parametrize("tool_choice", [
"auto", "required", {
"type": "function",
"function": {
"name": "get_current_weather"
}
}
])
@pytest.mark.parametrize("enable_thinking", [True, False])
async def test_function_tool_use(client: openai.AsyncOpenAI, model_name: str,
stream: bool, tool_choice: Union[str, dict],
enable_thinking: bool):
tools = [
{ {
"type": "function", "type": "function",
"function": { "function": {
...@@ -77,14 +36,12 @@ async def test_function_tool_use(client: openai.AsyncOpenAI, model_name: str, ...@@ -77,14 +36,12 @@ async def test_function_tool_use(client: openai.AsyncOpenAI, model_name: str,
}, },
"unit": { "unit": {
"type": "string", "type": "string",
"description": "description": "The unit to fetch the temperature in",
"The unit to fetch the temperature in",
"enum": ["celsius", "fahrenheit"], "enum": ["celsius", "fahrenheit"],
}, },
"options": { "options": {
"$ref": "#/$defs/WeatherOptions", "$ref": "#/$defs/WeatherOptions",
"description": "description": "Optional parameters for weather query",
"Optional parameters for weather query",
}, },
}, },
"required": ["country", "unit"], "required": ["country", "unit"],
...@@ -149,8 +106,7 @@ async def test_function_tool_use(client: openai.AsyncOpenAI, model_name: str, ...@@ -149,8 +106,7 @@ async def test_function_tool_use(client: openai.AsyncOpenAI, model_name: str,
}, },
"unit": { "unit": {
"type": "string", "type": "string",
"description": "description": "The unit to fetch the temperature in",
"The unit to fetch the temperature in",
"enum": ["celsius", "fahrenheit"], "enum": ["celsius", "fahrenheit"],
}, },
}, },
...@@ -158,9 +114,9 @@ async def test_function_tool_use(client: openai.AsyncOpenAI, model_name: str, ...@@ -158,9 +114,9 @@ async def test_function_tool_use(client: openai.AsyncOpenAI, model_name: str,
}, },
}, },
}, },
] ]
messages = [ messages = [
{ {
"role": "user", "role": "user",
"content": "Hi! How are you doing today?" "content": "Hi! How are you doing today?"
...@@ -176,7 +132,51 @@ async def test_function_tool_use(client: openai.AsyncOpenAI, model_name: str, ...@@ -176,7 +132,51 @@ async def test_function_tool_use(client: openai.AsyncOpenAI, model_name: str,
"Can you tell me what the current weather is in Berlin and the "\ "Can you tell me what the current weather is in Berlin and the "\
"forecast for the next 5 days, in fahrenheit?", "forecast for the next 5 days, in fahrenheit?",
}, },
]
@pytest.fixture(scope="module")
def server():  # noqa: F811
    """Yield a module-scoped ``RemoteOpenAIServer`` serving ``MODEL_NAME``.

    The server is started with auto tool choice enabled, the xgrammar
    guided-decoding backend, the hermes tool-call parser and the qwen3
    reasoning parser. The ``with`` block is left once the fixture is
    finalized.
    """
    serve_args = [
        # use half precision for speed and memory savings in CI environment
        "--dtype", "half",
        "--enable-auto-tool-choice",
        "--guided-decoding-backend", "xgrammar",
        "--tool-call-parser", "hermes",
        "--reasoning-parser", "qwen3",
        "--gpu-memory-utilization", "0.4",
    ]
    with RemoteOpenAIServer(MODEL_NAME, serve_args) as running_server:
        yield running_server
@pytest_asyncio.fixture
async def client(server):
    """Yield an async OpenAI client obtained from the ``server`` fixture."""
    async with server.get_async_client() as openai_client:
        yield openai_client
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize("stream", [True, False])
@pytest.mark.parametrize("tool_choice", [
"auto", "required", {
"type": "function",
"function": {
"name": "get_current_weather"
}
}
])
@pytest.mark.parametrize("enable_thinking", [True, False])
async def test_function_tool_use(client: openai.AsyncOpenAI, model_name: str,
stream: bool, tool_choice: Union[str, dict],
enable_thinking: bool):
if not stream: if not stream:
# Non-streaming test # Non-streaming test
chat_completion = await client.chat.completions.create( chat_completion = await client.chat.completions.create(
...@@ -216,3 +216,71 @@ async def test_function_tool_use(client: openai.AsyncOpenAI, model_name: str, ...@@ -216,3 +216,71 @@ async def test_function_tool_use(client: openai.AsyncOpenAI, model_name: str,
output.extend(chunk.choices[0].delta.tool_calls) output.extend(chunk.choices[0].delta.tool_calls)
assert len(output) > 0 assert len(output) > 0
@pytest.fixture(scope="module")
def k2_server():  # noqa: F811
    """Yield a module-scoped server whose HF config claims to be kimi_k2.

    The same small ``MODEL_NAME`` checkpoint is served, but its
    ``model_type`` is overridden so the kimi_k2 tool-call id format can be
    tested.
    """
    serve_args = [
        # use half precision for speed and memory savings in CI environment
        "--dtype", "half",
        "--enable-auto-tool-choice",
        "--guided-decoding-backend", "xgrammar",
        "--tool-call-parser", "hermes",
        "--reasoning-parser", "qwen3",
        "--gpu-memory-utilization", "0.4",
    ]
    # hack to test kimi_k2 tool use tool_id format.
    # avoid error in is_deepseek_mla check by setting kv_lora_rank=null
    hf_overrides = {
        "model_type": "kimi_k2",
        "kv_lora_rank": None,
    }
    with RemoteOpenAIServer(MODEL_NAME,
                            serve_args,
                            override_hf_configs=hf_overrides) as k2_remote:
        yield k2_remote
@pytest_asyncio.fixture
async def k2_client(k2_server):
    """Yield an async OpenAI client obtained from the ``k2_server`` fixture."""
    async with k2_server.get_async_client() as openai_client:
        yield openai_client
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize("stream", [True, False])
@pytest.mark.parametrize("tool_choice", ["required"])
async def test_tool_id_kimi_k2(k2_client: openai.AsyncOpenAI, model_name: str,
                               stream: bool, tool_choice: str):
    """Check the kimi_k2 tool-call id format with ``tool_choice="required"``.

    The first tool call of the response must carry the id
    ``functions.get_current_weather:0`` in both the non-streaming and the
    streaming code path.
    """
    expected_first_id = 'functions.get_current_weather:0'

    if not stream:
        # Non-streaming test
        chat_completion = await k2_client.chat.completions.create(
            messages=messages,
            model=model_name,
            tools=tools,
            tool_choice=tool_choice)
        tool_calls = chat_completion.choices[0].message.tool_calls
        assert tool_calls is not None
        assert len(tool_calls) > 0
        assert tool_calls[0].id == expected_first_id
    else:
        # Streaming test
        output_stream = await k2_client.chat.completions.create(
            messages=messages,
            model=model_name,
            tools=tools,
            tool_choice=tool_choice,
            stream=True)

        output = []
        async for chunk in output_stream:
            if chunk.choices and chunk.choices[0].delta.tool_calls:
                output.extend(chunk.choices[0].delta.tool_calls)

        # Guard against a vacuous pass: with no tool-call deltas at all, a
        # plain loop over `output` would assert nothing.
        assert len(output) > 0

        # Only some deltas carry an id; the rest have id=None.
        streamed_ids = [delta.id for delta in output if delta.id is not None]
        assert len(streamed_ids) > 0

        # Mirror the non-streaming branch and check the first call only:
        # further tool calls in the same response get a higher index
        # suffix (":1", ":2", ...), so they cannot all equal the ":0" id.
        assert streamed_ids[0] == expected_first_id
...@@ -5,6 +5,7 @@ import asyncio ...@@ -5,6 +5,7 @@ import asyncio
import copy import copy
import functools import functools
import importlib import importlib
import json
import os import os
import signal import signal
import subprocess import subprocess
...@@ -101,7 +102,8 @@ class RemoteOpenAIServer: ...@@ -101,7 +102,8 @@ class RemoteOpenAIServer:
env_dict: Optional[dict[str, str]] = None, env_dict: Optional[dict[str, str]] = None,
seed: Optional[int] = 0, seed: Optional[int] = 0,
auto_port: bool = True, auto_port: bool = True,
max_wait_seconds: Optional[float] = None) -> None: max_wait_seconds: Optional[float] = None,
override_hf_configs: Optional[dict[str, Any]] = None) -> None:
if auto_port: if auto_port:
if "-p" in vllm_serve_args or "--port" in vllm_serve_args: if "-p" in vllm_serve_args or "--port" in vllm_serve_args:
raise ValueError("You have manually specified the port " raise ValueError("You have manually specified the port "
...@@ -120,6 +122,12 @@ class RemoteOpenAIServer: ...@@ -120,6 +122,12 @@ class RemoteOpenAIServer:
vllm_serve_args = vllm_serve_args + ["--seed", str(seed)] vllm_serve_args = vllm_serve_args + ["--seed", str(seed)]
if override_hf_configs is not None:
vllm_serve_args = vllm_serve_args + [
"--hf-overrides",
json.dumps(override_hf_configs)
]
parser = FlexibleArgumentParser( parser = FlexibleArgumentParser(
description="vLLM's remote OpenAI server.") description="vLLM's remote OpenAI server.")
subparsers = parser.add_subparsers(required=False, dest="subparser") subparsers = parser.add_subparsers(required=False, dest="subparser")
......
...@@ -1345,5 +1345,18 @@ def apply_mistral_chat_template( ...@@ -1345,5 +1345,18 @@ def apply_mistral_chat_template(
"template") "template")
raise ValueError(str(e)) from e raise ValueError(str(e)) from e
def get_history_tool_calls_cnt(
        conversation: "list[ConversationMessage]") -> int:
    """Count the tool calls made by the assistant earlier in a conversation.

    Args:
        conversation: The messages of the conversation so far.

    Returns:
        The total number of entries found under ``tool_calls`` across all
        messages whose ``role`` is ``'assistant'``. Messages with another
        role, or with a missing or ``None`` ``tool_calls`` value, add
        nothing.
    """
    total = 0
    for msg in conversation:
        if msg['role'] != 'assistant':
            continue
        tool_calls = msg.get('tool_calls')
        if tool_calls is not None:
            # list() lets any iterable of tool calls be counted with len().
            total += len(list(tool_calls))
    return total


def make_tool_call_id(id_type: str = 'random',
                      func_name=None,
                      idx=None) -> str:
    """Build the id attached to a tool call.

    Args:
        id_type: ``'kimi_k2'`` selects the ``functions.{func_name}:{idx}``
            format; any other value yields a random id.
        func_name: Name of the called function (used by ``'kimi_k2'`` only).
        idx: Index of the tool call (used by ``'kimi_k2'`` only).

    Returns:
        The tool call id string.

    Raises:
        ValueError: If ``id_type`` is ``'kimi_k2'`` and ``func_name`` or
            ``idx`` is ``None``. Formatting them anyway would produce a
            malformed id such as ``functions.None:None``.
    """
    if id_type == 'kimi_k2':
        if func_name is None or idx is None:
            raise ValueError(
                "func_name and idx are required when id_type='kimi_k2'")
        return f'functions.{func_name}:{idx}'
    # by default return random
    return f"chatcmpl-tool-{random_uuid()}"


def random_tool_call_id() -> str:
    """Return a random tool call id.

    Kept under the helper's previous name so code that still imports
    ``random_tool_call_id`` keeps working; it delegates to
    ``make_tool_call_id`` with its default (random) id type.
    """
    return make_tool_call_id()
...@@ -38,7 +38,7 @@ from typing_extensions import TypeAlias ...@@ -38,7 +38,7 @@ from typing_extensions import TypeAlias
from vllm import envs from vllm import envs
from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam, from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
random_tool_call_id) make_tool_call_id)
from vllm.entrypoints.score_utils import (ScoreContentPartParam, from vllm.entrypoints.score_utils import (ScoreContentPartParam,
ScoreMultiModalParam) ScoreMultiModalParam)
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -1634,7 +1634,7 @@ class FunctionCall(OpenAIBaseModel): ...@@ -1634,7 +1634,7 @@ class FunctionCall(OpenAIBaseModel):
class ToolCall(OpenAIBaseModel): class ToolCall(OpenAIBaseModel):
id: str = Field(default_factory=random_tool_call_id) id: str = Field(default_factory=make_tool_call_id)
type: Literal["function"] = "function" type: Literal["function"] = "function"
function: FunctionCall function: FunctionCall
......
...@@ -19,7 +19,8 @@ from vllm.config import ModelConfig ...@@ -19,7 +19,8 @@ from vllm.config import ModelConfig
from vllm.engine.protocol import EngineClient from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import (ChatTemplateContentFormatOption, from vllm.entrypoints.chat_utils import (ChatTemplateContentFormatOption,
ConversationMessage, ConversationMessage,
random_tool_call_id) get_history_tool_calls_cnt,
make_tool_call_id)
from vllm.entrypoints.harmony_utils import ( from vllm.entrypoints.harmony_utils import (
get_developer_message, get_stop_tokens_for_assistant_actions, get_developer_message, get_stop_tokens_for_assistant_actions,
get_streamable_parser_for_assistant, get_system_message, parse_chat_input, get_streamable_parser_for_assistant, get_system_message, parse_chat_input,
...@@ -133,6 +134,10 @@ class OpenAIServingChat(OpenAIServing): ...@@ -133,6 +134,10 @@ class OpenAIServingChat(OpenAIServing):
source = "model" if source == "auto" else source source = "model" if source == "auto" else source
logger.info("Using default chat sampling params from %s: %s", logger.info("Using default chat sampling params from %s: %s",
source, self.default_sampling_params) source, self.default_sampling_params)
if self.model_config.hf_config.model_type == 'kimi_k2':
self.tool_call_id_type = 'kimi_k2'
else:
self.tool_call_id_type = 'random'
self.use_harmony = model_config.hf_config.model_type == "gpt_oss" self.use_harmony = model_config.hf_config.model_type == "gpt_oss"
if self.use_harmony: if self.use_harmony:
...@@ -379,6 +384,7 @@ class OpenAIServingChat(OpenAIServing): ...@@ -379,6 +384,7 @@ class OpenAIServingChat(OpenAIServing):
current_text: Optional[str], current_text: Optional[str],
delta_text: str, delta_text: str,
function_name_returned: bool, function_name_returned: bool,
tool_call_idx: Optional[int] = None
) -> tuple[Optional[DeltaMessage], bool]: ) -> tuple[Optional[DeltaMessage], bool]:
if current_text is None or current_text == "": if current_text is None or current_text == "":
# if the current text is empty, we cannot parse it # if the current text is empty, we cannot parse it
...@@ -424,8 +430,12 @@ class OpenAIServingChat(OpenAIServing): ...@@ -424,8 +430,12 @@ class OpenAIServingChat(OpenAIServing):
current_tool_call = obj[-2] current_tool_call = obj[-2]
function_name_returned = True function_name_returned = True
tool_call_id = make_tool_call_id(
id_type=self.tool_call_id_type,
func_name=current_tool_call["name"],
idx=tool_call_idx)
delta_message = DeltaMessage(tool_calls=[ delta_message = DeltaMessage(tool_calls=[
DeltaToolCall(id=random_tool_call_id(), DeltaToolCall(id=tool_call_id,
function=DeltaFunctionCall( function=DeltaFunctionCall(
name=current_tool_call["name"], name=current_tool_call["name"],
arguments=arguments), arguments=arguments),
...@@ -491,6 +501,10 @@ class OpenAIServingChat(OpenAIServing): ...@@ -491,6 +501,10 @@ class OpenAIServingChat(OpenAIServing):
all_previous_token_ids: Optional[list[list[int]]] all_previous_token_ids: Optional[list[list[int]]]
function_name_returned = [False] * num_choices function_name_returned = [False] * num_choices
if self.tool_call_id_type == 'kimi_k2':
history_tool_call_cnt = get_history_tool_calls_cnt(conversation)
else:
history_tool_call_cnt = 0
# Always track previous_texts for comprehensive output logging # Always track previous_texts for comprehensive output logging
previous_texts = [""] * num_choices previous_texts = [""] * num_choices
...@@ -673,7 +687,6 @@ class OpenAIServingChat(OpenAIServing): ...@@ -673,7 +687,6 @@ class OpenAIServingChat(OpenAIServing):
previous_text = previous_texts[i] previous_text = previous_texts[i]
previous_token_ids = all_previous_token_ids[i] previous_token_ids = all_previous_token_ids[i]
current_text = previous_text + delta_text current_text = previous_text + delta_text
# avoid the None + list error. # avoid the None + list error.
if previous_token_ids: if previous_token_ids:
current_token_ids = previous_token_ids + as_list( current_token_ids = previous_token_ids + as_list(
...@@ -733,7 +746,7 @@ class OpenAIServingChat(OpenAIServing): ...@@ -733,7 +746,7 @@ class OpenAIServingChat(OpenAIServing):
index=i) index=i)
else: else:
delta_tool_call = DeltaToolCall( delta_tool_call = DeltaToolCall(
id=random_tool_call_id(), id=make_tool_call_id(),
type="function", type="function",
function=DeltaFunctionCall( function=DeltaFunctionCall(
name=tool_choice_function_name, name=tool_choice_function_name,
...@@ -764,7 +777,11 @@ class OpenAIServingChat(OpenAIServing): ...@@ -764,7 +777,11 @@ class OpenAIServingChat(OpenAIServing):
previous_text=previous_text, previous_text=previous_text,
current_text=content, current_text=content,
delta_text=delta_text, delta_text=delta_text,
function_name_returned=fn_name_returned)) function_name_returned=fn_name_returned,
tool_call_idx=history_tool_call_cnt))
if (delta_message and delta_message.tool_calls and
delta_message.tool_calls[0].id is not None):
history_tool_call_cnt += 1
# update the previous values for the next iteration # update the previous values for the next iteration
previous_texts[i] = current_text previous_texts[i] = current_text
...@@ -1089,6 +1106,10 @@ class OpenAIServingChat(OpenAIServing): ...@@ -1089,6 +1106,10 @@ class OpenAIServingChat(OpenAIServing):
assert final_res is not None assert final_res is not None
choices: list[ChatCompletionResponseChoice] = [] choices: list[ChatCompletionResponseChoice] = []
if self.tool_call_id_type == 'kimi_k2':
history_tool_call_cnt = get_history_tool_calls_cnt(conversation)
else:
history_tool_call_cnt = 0
role = self.get_chat_request_role(request) role = self.get_chat_request_role(request)
for output in final_res.outputs: for output in final_res.outputs:
...@@ -1194,17 +1215,26 @@ class OpenAIServingChat(OpenAIServing): ...@@ -1194,17 +1215,26 @@ class OpenAIServingChat(OpenAIServing):
assert content is not None assert content is not None
tool_calls = TypeAdapter( tool_calls = TypeAdapter(
list[FunctionDefinition]).validate_json(content) list[FunctionDefinition]).validate_json(content)
tool_call_ids = []
for tool_call in tool_calls:
tool_call_ids.append(
make_tool_call_id(id_type=self.tool_call_id_type,
func_name=tool_call.name,
idx=history_tool_call_cnt))
history_tool_call_cnt += 1
message = ChatMessage( message = ChatMessage(
role=role, role=role,
content="", content="",
reasoning_content=reasoning_content,
tool_calls=[ tool_calls=[
tool_call_class(function=FunctionCall( tool_call_class(id=tool_call_ids[i],
function=FunctionCall(
name=tool_call.name, name=tool_call.name,
arguments=json.dumps(tool_call.parameters, arguments=json.dumps(
tool_call.parameters,
ensure_ascii=False))) ensure_ascii=False)))
for tool_call in tool_calls for i, tool_call in enumerate(tool_calls)
]) ],
reasoning_content=reasoning_content)
# if the request doesn't use tool choice # if the request doesn't use tool choice
# OR specifies to not use a tool # OR specifies to not use a tool
...@@ -1248,7 +1278,6 @@ class OpenAIServingChat(OpenAIServing): ...@@ -1248,7 +1278,6 @@ class OpenAIServingChat(OpenAIServing):
if (tool_call_info.content if (tool_call_info.content
and len(tool_call_info.content) > 0): and len(tool_call_info.content) > 0):
ret_content = tool_call_info.content ret_content = tool_call_info.content
message = ChatMessage(role=role, message = ChatMessage(role=role,
reasoning_content=reasoning_content, reasoning_content=reasoning_content,
content=ret_content) content=ret_content)
...@@ -1327,12 +1356,11 @@ class OpenAIServingChat(OpenAIServing): ...@@ -1327,12 +1356,11 @@ class OpenAIServingChat(OpenAIServing):
elif choice.message.tool_calls: elif choice.message.tool_calls:
# For tool calls, log the function name and arguments # For tool calls, log the function name and arguments
tool_call_descriptions = [] tool_call_descriptions = []
for tool_call in choice.message.tool_calls: for tc in choice.message.tool_calls:
if hasattr(tool_call.function, "name") and hasattr( if hasattr(tc.function, "name") and hasattr(
tool_call.function, "arguments"): tc.function, "arguments"):
tool_call_descriptions.append( tool_call_descriptions.append(
f"{tool_call.function.name}({tool_call.function.arguments})" f"{tc.function.name}({tc.function.arguments})")
)
tool_calls_str = ", ".join(tool_call_descriptions) tool_calls_str = ", ".join(tool_call_descriptions)
output_text = f"[tool_calls: {tool_calls_str}]" output_text = f"[tool_calls: {tool_calls_str}]"
......
...@@ -6,7 +6,7 @@ from typing import Union ...@@ -6,7 +6,7 @@ from typing import Union
import regex as re import regex as re
from vllm.entrypoints.chat_utils import random_tool_call_id from vllm.entrypoints.chat_utils import make_tool_call_id
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DeltaFunctionCall, DeltaMessage, DeltaFunctionCall, DeltaMessage,
DeltaToolCall, DeltaToolCall,
...@@ -267,7 +267,7 @@ class DeepSeekV3ToolParser(ToolParser): ...@@ -267,7 +267,7 @@ class DeepSeekV3ToolParser(ToolParser):
DeltaToolCall( DeltaToolCall(
index=self.current_tool_id, index=self.current_tool_id,
type="function", type="function",
id=random_tool_call_id(), id=make_tool_call_id(),
function=DeltaFunctionCall( function=DeltaFunctionCall(
name=function_name).model_dump( name=function_name).model_dump(
exclude_none=True), exclude_none=True),
......
...@@ -10,7 +10,7 @@ import partial_json_parser ...@@ -10,7 +10,7 @@ import partial_json_parser
import regex as re import regex as re
from partial_json_parser.core.options import Allow from partial_json_parser.core.options import Allow
from vllm.entrypoints.chat_utils import random_tool_call_id from vllm.entrypoints.chat_utils import make_tool_call_id
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DeltaFunctionCall, DeltaMessage, DeltaFunctionCall, DeltaMessage,
DeltaToolCall, DeltaToolCall,
...@@ -203,7 +203,7 @@ class Granite20bFCToolParser(ToolParser): ...@@ -203,7 +203,7 @@ class Granite20bFCToolParser(ToolParser):
delta = DeltaMessage(tool_calls=[ delta = DeltaMessage(tool_calls=[
DeltaToolCall(index=self.current_tool_id, DeltaToolCall(index=self.current_tool_id,
type="function", type="function",
id=random_tool_call_id(), id=make_tool_call_id(),
function=DeltaFunctionCall( function=DeltaFunctionCall(
name=function_name).model_dump( name=function_name).model_dump(
exclude_none=True)) exclude_none=True))
......
...@@ -8,7 +8,7 @@ from typing import Union ...@@ -8,7 +8,7 @@ from typing import Union
import partial_json_parser import partial_json_parser
from partial_json_parser.core.options import Allow from partial_json_parser.core.options import Allow
from vllm.entrypoints.chat_utils import random_tool_call_id from vllm.entrypoints.chat_utils import make_tool_call_id
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DeltaFunctionCall, DeltaMessage, DeltaFunctionCall, DeltaMessage,
DeltaToolCall, DeltaToolCall,
...@@ -185,7 +185,7 @@ class GraniteToolParser(ToolParser): ...@@ -185,7 +185,7 @@ class GraniteToolParser(ToolParser):
delta = DeltaMessage(tool_calls=[ delta = DeltaMessage(tool_calls=[
DeltaToolCall(index=self.current_tool_id, DeltaToolCall(index=self.current_tool_id,
type="function", type="function",
id=random_tool_call_id(), id=make_tool_call_id(),
function=DeltaFunctionCall( function=DeltaFunctionCall(
name=function_name).model_dump( name=function_name).model_dump(
exclude_none=True)) exclude_none=True))
......
...@@ -9,7 +9,7 @@ import partial_json_parser ...@@ -9,7 +9,7 @@ import partial_json_parser
import regex as re import regex as re
from partial_json_parser.core.options import Allow from partial_json_parser.core.options import Allow
from vllm.entrypoints.chat_utils import random_tool_call_id from vllm.entrypoints.chat_utils import make_tool_call_id
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DeltaFunctionCall, DeltaMessage, DeltaFunctionCall, DeltaMessage,
DeltaToolCall, DeltaToolCall,
...@@ -307,7 +307,7 @@ class Hermes2ProToolParser(ToolParser): ...@@ -307,7 +307,7 @@ class Hermes2ProToolParser(ToolParser):
return DeltaMessage(tool_calls=[ return DeltaMessage(tool_calls=[
DeltaToolCall(index=self.current_tool_id, DeltaToolCall(index=self.current_tool_id,
type="function", type="function",
id=random_tool_call_id(), id=make_tool_call_id(),
function=DeltaFunctionCall( function=DeltaFunctionCall(
name=function_name).model_dump( name=function_name).model_dump(
exclude_none=True)) exclude_none=True))
......
...@@ -8,7 +8,7 @@ from typing import Union ...@@ -8,7 +8,7 @@ from typing import Union
import partial_json_parser import partial_json_parser
from partial_json_parser.core.options import Allow from partial_json_parser.core.options import Allow
from vllm.entrypoints.chat_utils import random_tool_call_id from vllm.entrypoints.chat_utils import make_tool_call_id
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DeltaFunctionCall, DeltaMessage, DeltaFunctionCall, DeltaMessage,
DeltaToolCall, DeltaToolCall,
...@@ -107,7 +107,7 @@ class Internlm2ToolParser(ToolParser): ...@@ -107,7 +107,7 @@ class Internlm2ToolParser(ToolParser):
delta = DeltaMessage(tool_calls=[ delta = DeltaMessage(tool_calls=[
DeltaToolCall(index=self.current_tool_id, DeltaToolCall(index=self.current_tool_id,
type="function", type="function",
id=random_tool_call_id(), id=make_tool_call_id(),
function=DeltaFunctionCall( function=DeltaFunctionCall(
name=function_name).model_dump( name=function_name).model_dump(
exclude_none=True)) exclude_none=True))
......
...@@ -9,7 +9,7 @@ import partial_json_parser ...@@ -9,7 +9,7 @@ import partial_json_parser
import regex as re import regex as re
from partial_json_parser.core.options import Allow from partial_json_parser.core.options import Allow
from vllm.entrypoints.chat_utils import random_tool_call_id from vllm.entrypoints.chat_utils import make_tool_call_id
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DeltaFunctionCall, DeltaMessage, DeltaFunctionCall, DeltaMessage,
DeltaToolCall, DeltaToolCall,
...@@ -222,7 +222,7 @@ class JambaToolParser(ToolParser): ...@@ -222,7 +222,7 @@ class JambaToolParser(ToolParser):
delta = DeltaMessage(tool_calls=[ delta = DeltaMessage(tool_calls=[
DeltaToolCall(index=self.current_tool_id, DeltaToolCall(index=self.current_tool_id,
type="function", type="function",
id=random_tool_call_id(), id=make_tool_call_id(),
function=DeltaFunctionCall( function=DeltaFunctionCall(
name=function_name).model_dump( name=function_name).model_dump(
exclude_none=True)) exclude_none=True))
......
...@@ -10,7 +10,7 @@ import regex as re ...@@ -10,7 +10,7 @@ import regex as re
from partial_json_parser.core.options import Allow from partial_json_parser.core.options import Allow
from transformers import PreTrainedTokenizerBase from transformers import PreTrainedTokenizerBase
from vllm.entrypoints.chat_utils import random_tool_call_id from vllm.entrypoints.chat_utils import make_tool_call_id
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DeltaFunctionCall, DeltaMessage, DeltaFunctionCall, DeltaMessage,
DeltaToolCall, DeltaToolCall,
...@@ -213,7 +213,7 @@ class Llama3JsonToolParser(ToolParser): ...@@ -213,7 +213,7 @@ class Llama3JsonToolParser(ToolParser):
delta = DeltaMessage(tool_calls=[ delta = DeltaMessage(tool_calls=[
DeltaToolCall(index=self.current_tool_id, DeltaToolCall(index=self.current_tool_id,
type="function", type="function",
id=random_tool_call_id(), id=make_tool_call_id(),
function=DeltaFunctionCall( function=DeltaFunctionCall(
name=function_name).model_dump( name=function_name).model_dump(
exclude_none=True)) exclude_none=True))
......
...@@ -7,7 +7,7 @@ from typing import Any, Optional, Union ...@@ -7,7 +7,7 @@ from typing import Any, Optional, Union
import regex as re import regex as re
from vllm.entrypoints.chat_utils import random_tool_call_id from vllm.entrypoints.chat_utils import make_tool_call_id
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DeltaFunctionCall, DeltaMessage, DeltaFunctionCall, DeltaMessage,
DeltaToolCall, DeltaToolCall,
...@@ -394,7 +394,7 @@ class MinimaxToolParser(ToolParser): ...@@ -394,7 +394,7 @@ class MinimaxToolParser(ToolParser):
sent_tools.append({ sent_tools.append({
"sent_name": False, "sent_name": False,
"sent_arguments": "", "sent_arguments": "",
"id": random_tool_call_id(), "id": make_tool_call_id(),
}) })
while len(tool_ids) < tool_count: while len(tool_ids) < tool_count:
......
...@@ -8,7 +8,7 @@ from typing import Any, Optional ...@@ -8,7 +8,7 @@ from typing import Any, Optional
import regex as re import regex as re
from transformers import PreTrainedTokenizerBase from transformers import PreTrainedTokenizerBase
from vllm.entrypoints.chat_utils import random_tool_call_id from vllm.entrypoints.chat_utils import make_tool_call_id
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DeltaMessage, DeltaMessage,
ExtractedToolCallInformation, ExtractedToolCallInformation,
...@@ -74,7 +74,7 @@ class Phi4MiniJsonToolParser(ToolParser): ...@@ -74,7 +74,7 @@ class Phi4MiniJsonToolParser(ToolParser):
tool_calls: list[ToolCall] = [ tool_calls: list[ToolCall] = [
ToolCall( ToolCall(
id=random_tool_call_id(), id=make_tool_call_id(),
type="function", type="function",
function=FunctionCall( function=FunctionCall(
name=raw_function_call["name"], name=raw_function_call["name"],
......
...@@ -7,7 +7,7 @@ from typing import Any, Optional, Union ...@@ -7,7 +7,7 @@ from typing import Any, Optional, Union
import regex as re import regex as re
from vllm.entrypoints.chat_utils import random_tool_call_id from vllm.entrypoints.chat_utils import make_tool_call_id
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DeltaFunctionCall, DeltaMessage, DeltaFunctionCall, DeltaMessage,
DeltaToolCall, DeltaToolCall,
...@@ -226,7 +226,7 @@ class xLAMToolParser(ToolParser): ...@@ -226,7 +226,7 @@ class xLAMToolParser(ToolParser):
function_name = name_match.group(1) function_name = name_match.group(1)
# The test expects us to send just the name first # The test expects us to send just the name first
tool_id = random_tool_call_id() tool_id = make_tool_call_id()
delta = DeltaMessage(tool_calls=[ delta = DeltaMessage(tool_calls=[
DeltaToolCall( DeltaToolCall(
index=0, index=0,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment