Unverified commit a68ed766, authored by mlmz, committed by GitHub

feat: append more comprehensive fields in messages instead of merely role and content (#5996)

parent 82653f66
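In practice, the change means messages forwarded to the chat template (and echoed back to clients) keep their full OpenAI-style fields such as tool_calls, tool_call_id, name, and reasoning_content, instead of being reduced to role and content. A rough sketch of the message shapes this enables, with field names taken from the updated ChatCompletionMessageGenericParam below; the call id and tool arguments are illustrative only:

# Illustrative message history a client can now send back to SGLang verbatim.
messages = [
    {"role": "user", "content": "What is the weather like in Boston today?"},
    {
        "role": "assistant",
        "content": "",
        "tool_calls": [
            {
                "id": "call_abc123",  # hypothetical id in the new call_<token> format
                "type": "function",
                "function": {
                    "name": "get_current_weather",
                    "arguments": '{"city": "Boston", "unit": "fahrenheit"}',  # illustrative arguments
                },
            }
        ],
    },
    {
        "role": "tool",
        "tool_call_id": "call_abc123",  # links the result to the call above
        "name": "get_current_weather",
        "content": "20 degrees and sunny",  # illustrative tool output
    },
]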
@@ -38,7 +38,9 @@
 "    from patch import launch_server_cmd\n",
 "else:\n",
 "    from sglang.utils import launch_server_cmd\n",
+"    import nest_asyncio\n",
 "\n",
+"    nest_asyncio.apply()\n",
 "\n",
 "server_process, port = launch_server_cmd(\n",
 "    \"python3 -m sglang.launch_server --model-path Qwen/Qwen2.5-7B-Instruct --tool-call-parser qwen25 --host 0.0.0.0\" # qwen25\n",
@@ -164,7 +166,7 @@
 "response_non_stream = client.chat.completions.create(\n",
 "    model=model_name,\n",
 "    messages=messages,\n",
-"    temperature=0.1,\n",
+"    temperature=0,\n",
 "    top_p=0.95,\n",
 "    max_tokens=1024,\n",
 "    stream=False, # Non-streaming\n",
@@ -219,7 +221,7 @@
 "response_stream = client.chat.completions.create(\n",
 "    model=model_name,\n",
 "    messages=messages,\n",
-"    temperature=0.1,\n",
+"    temperature=0,\n",
 "    top_p=0.95,\n",
 "    max_tokens=1024,\n",
 "    stream=True, # Enable streaming\n",
@@ -309,23 +311,24 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"call_data = json.loads(full_arguments)\n",
+"messages.append(response_non_stream.choices[0].message)\n",
 "\n",
+"# Call the corresponding tool function\n",
+"tool_call = messages[-1].tool_calls[0]\n",
+"tool_name = tool_call.function.name\n",
+"tool_to_call = available_tools[tool_name]\n",
+"result = tool_to_call(**(json.loads(tool_call.function.arguments)))\n",
+"print_highlight(f\"Function call result: {result}\")\n",
+"# messages.append({\"role\": \"tool\", \"content\": result, \"name\": tool_name})\n",
 "messages.append(\n",
 "    {\n",
-"        \"role\": \"user\",\n",
-"        \"content\": \"\",\n",
-"        \"tool_calls\": {\"name\": \"get_current_weather\", \"arguments\": full_arguments},\n",
+"        \"role\": \"tool\",\n",
+"        \"tool_call_id\": tool_call.id,\n",
+"        \"content\": str(result),\n",
+"        \"name\": tool_name,\n",
 "    }\n",
 ")\n",
 "\n",
-"# Call the corresponding tool function\n",
-"tool_name = messages[-1][\"tool_calls\"][\"name\"]\n",
-"tool_to_call = available_tools[tool_name]\n",
-"result = tool_to_call(**call_data)\n",
-"print_highlight(f\"Function call result: {result}\")\n",
-"messages.append({\"role\": \"tool\", \"content\": result, \"name\": tool_name})\n",
-"\n",
 "print_highlight(f\"Updated message history: {messages}\")"
 ]
 },
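The updated cell appends the SDK message object directly (the richer request schema now round-trips it) and answers it with a "tool" message keyed by tool_call_id. A sketch of the same round trip generalized to any number of tool calls, assuming the client, model_name, messages, tools, and available_tools objects defined in the earlier cells of this notebook:

# Sketch of the full tool-call round trip for the non-streaming case.
import json

response = client.chat.completions.create(
    model=model_name, messages=messages, tools=tools, temperature=0, stream=False
)
assistant_msg = response.choices[0].message
messages.append(assistant_msg)  # the server now accepts this object as-is, tool_calls included

for tool_call in assistant_msg.tool_calls or []:
    fn = available_tools[tool_call.function.name]
    result = fn(**json.loads(tool_call.function.arguments))
    messages.append(
        {
            "role": "tool",
            "tool_call_id": tool_call.id,  # ties the result to the matching call
            "name": tool_call.function.name,
            "content": str(result),
        }
    )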
@@ -345,7 +348,7 @@
 "final_response = client.chat.completions.create(\n",
 "    model=model_name,\n",
 "    messages=messages,\n",
-"    temperature=0.1,\n",
+"    temperature=0,\n",
 "    top_p=0.95,\n",
 "    stream=False,\n",
 "    tools=tools,\n",
...@@ -391,7 +394,7 @@ ...@@ -391,7 +394,7 @@
" \"sampling_params\": {\n", " \"sampling_params\": {\n",
" \"skip_special_tokens\": False,\n", " \"skip_special_tokens\": False,\n",
" \"max_new_tokens\": 1024,\n", " \"max_new_tokens\": 1024,\n",
" \"temperature\": 0.1,\n", " \"temperature\": 0,\n",
" \"top_p\": 0.95,\n", " \"top_p\": 0.95,\n",
" },\n", " },\n",
"}\n", "}\n",
...@@ -452,7 +455,7 @@ ...@@ -452,7 +455,7 @@
"\n", "\n",
"sampling_params = {\n", "sampling_params = {\n",
" \"max_new_tokens\": 1024,\n", " \"max_new_tokens\": 1024,\n",
" \"temperature\": 0.1,\n", " \"temperature\": 0,\n",
" \"top_p\": 0.95,\n", " \"top_p\": 0.95,\n",
" \"skip_special_tokens\": False,\n", " \"skip_special_tokens\": False,\n",
"}\n", "}\n",
@@ -540,14 +543,6 @@
 "outputs": [],
 "source": [
 "import openai\n",
-"from sglang.utils import wait_for_server, print_highlight, terminate_process\n",
-"from sglang.test.test_utils import is_in_ci\n",
-"\n",
-"\n",
-"if is_in_ci():\n",
-"    from patch import launch_server_cmd\n",
-"else:\n",
-"    from sglang.utils import launch_server_cmd\n",
 "\n",
 "server_process, port = launch_server_cmd(\n",
 "    \" python3 -m sglang.launch_server --model-path meta-llama/Llama-3.2-1B-Instruct --tool-call-parser pythonic --tp 1\" # llama-3.2-1b-instruct\n",
...@@ -624,8 +619,8 @@ ...@@ -624,8 +619,8 @@
"response_non_stream = client.chat.completions.create(\n", "response_non_stream = client.chat.completions.create(\n",
" model=model_name,\n", " model=model_name,\n",
" messages=messages,\n", " messages=messages,\n",
" temperature=0.8,\n", " temperature=0,\n",
" top_p=0.8,\n", " top_p=0.9,\n",
" stream=False, # Non-streaming\n", " stream=False, # Non-streaming\n",
" tools=tools,\n", " tools=tools,\n",
")\n", ")\n",
@@ -635,8 +630,8 @@
 "response_stream = client.chat.completions.create(\n",
 "    model=model_name,\n",
 "    messages=messages,\n",
-"    temperature=0.8,\n",
-"    top_p=0.8,\n",
+"    temperature=0,\n",
+"    top_p=0.9,\n",
 "    stream=True,\n",
 "    tools=tools,\n",
 ")\n",
......
@@ -14,6 +14,7 @@
 """Conversion between OpenAI APIs and native SRT APIs"""

 import asyncio
+import base64
 import json
 import logging
 import os
@@ -970,17 +971,19 @@ def v1_chat_generate_request(
         for message in request.messages:
             if message.content is None:
                 message.content = ""
-            if isinstance(message.content, str):
-                openai_compatible_messages.append(
-                    {"role": message.role, "content": message.content}
-                )
+            msg_dict = message.dict()
+            if isinstance(msg_dict.get("content"), list):
+                for chunk in msg_dict["content"]:
+                    if isinstance(chunk, dict) and chunk.get("type") == "text":
+                        new_msg = msg_dict.copy()
+                        new_msg["content"] = chunk["text"]
+                        new_msg = {
+                            k: v for k, v in new_msg.items() if v is not None
+                        }
+                        openai_compatible_messages.append(new_msg)
             else:
-                content_list = message.dict()["content"]
-                for content in content_list:
-                    if content["type"] == "text":
-                        openai_compatible_messages.append(
-                            {"role": message.role, "content": content["text"]}
-                        )
+                msg_dict = {k: v for k, v in msg_dict.items() if v is not None}
+                openai_compatible_messages.append(msg_dict)
         if (
             openai_compatible_messages
             and openai_compatible_messages[-1]["role"] == "assistant"
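For context, the new branch keeps every non-None field of the original message while still splitting list-typed content into one text message per part. A standalone sketch of the same idea on plain dicts (outside the server code; the helper name is mine):

# Standalone sketch of the flattening logic above.
def flatten_message(msg_dict):
    """Split list-typed content into one text message per chunk, dropping None fields."""
    out = []
    content = msg_dict.get("content")
    if isinstance(content, list):
        for chunk in content:
            if isinstance(chunk, dict) and chunk.get("type") == "text":
                new_msg = dict(msg_dict, content=chunk["text"])
                out.append({k: v for k, v in new_msg.items() if v is not None})
    else:
        out.append({k: v for k, v in msg_dict.items() if v is not None})
    return out


# Extra fields such as tool_call_id or name survive because only None values are dropped.
print(flatten_message({"role": "user", "name": None, "content": [{"type": "text", "text": "hi"}]}))
# -> [{'role': 'user', 'content': 'hi'}]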
@@ -1290,7 +1293,8 @@ def v1_chat_generate_response(
            text, call_info_list = parser.parse_non_stream(text)
            tool_calls = [
                ToolCall(
-                   id=str(call_info.tool_index),
+                   id=f"call_{base64.urlsafe_b64encode(uuid.uuid4().bytes).rstrip(b'=').decode()}",
+                   index=call_info.tool_index,
                    function=FunctionResponse(
                        name=call_info.name, arguments=call_info.parameters
                    ),
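The generated id now follows the OpenAI-style call_<token> convention instead of reusing the tool index. A small sketch of the same expression, wrapped in a hypothetical helper:

import base64
import uuid


def new_tool_call_id() -> str:
    # 22-character URL-safe token derived from a random UUID, prefixed like OpenAI's tool-call ids.
    return f"call_{base64.urlsafe_b64encode(uuid.uuid4().bytes).rstrip(b'=').decode()}"


print(new_tool_call_id())  # e.g. call_x8Qm3v... (different every time)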
@@ -1406,6 +1410,7 @@ async def v1_chat_completions(
    reasoning_parser_dict = {}

    async def generate_stream_resp():
+       tool_call_first = True
        is_firsts = {}
        stream_buffers = {}
        n_prev_tokens = {}
@@ -1572,7 +1577,6 @@ async def v1_chat_completions(
                        # 2) if we found calls, we output them as separate chunk(s)
                        for call_item in calls:
                            # transform call_item -> FunctionResponse + ToolCall
-
                            if finish_reason_type == "stop":
                                latest_delta_len = 0
                                if isinstance(call_item.parameters, str):
@@ -1595,15 +1599,19 @@ async def v1_chat_completions(
                                    call_item.parameters = remaining_call
                                    finish_reason_type = "tool_calls"

                                tool_call = ToolCall(
-                                   id=str(call_item.tool_index),
+                                   id=(
+                                       f"call_{base64.urlsafe_b64encode(uuid.uuid4().bytes).rstrip(b'=').decode()}"
+                                       if tool_call_first
+                                       else None
+                                   ),
                                    index=call_item.tool_index,
                                    function=FunctionResponse(
                                        name=call_item.name,
                                        arguments=call_item.parameters,
                                    ),
                                )
-
+                               tool_call_first = False
                                choice_data = ChatCompletionResponseStreamChoice(
                                    index=index,
                                    delta=DeltaMessage(tool_calls=[tool_call]),
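Note that the id is attached only to the first tool-call delta of the stream (the tool_call_first flag); later deltas carry id=None, matching OpenAI's streaming format, so a client should accumulate arguments keyed by the delta index. A minimal client-side sketch, assuming stream is an OpenAI chat-completions stream opened with stream=True:

# Accumulate streamed tool-call deltas, keyed by index.
def collect_tool_calls(stream):
    calls = {}  # index -> {"id", "name", "arguments"}
    for chunk in stream:
        if not chunk.choices:
            continue
        delta = chunk.choices[0].delta
        for tc in delta.tool_calls or []:
            entry = calls.setdefault(tc.index, {"id": None, "name": None, "arguments": ""})
            if tc.id:  # only the first delta carries an id
                entry["id"] = tc.id
            if tc.function and tc.function.name:
                entry["name"] = tc.function.name
            if tc.function and tc.function.arguments:
                entry["arguments"] += tc.function.arguments
    return calls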
......
@@ -250,9 +250,29 @@ ChatCompletionMessageContentPart = Union[
 ]


+class FunctionResponse(BaseModel):
+    """Function response."""
+
+    name: Optional[str] = None
+    arguments: Optional[str] = None
+
+
+class ToolCall(BaseModel):
+    """Tool call response."""
+
+    id: Optional[str] = None
+    index: Optional[int] = None
+    type: Literal["function"] = "function"
+    function: FunctionResponse
+
+
 class ChatCompletionMessageGenericParam(BaseModel):
     role: Literal["system", "assistant", "tool"]
     content: Union[str, List[ChatCompletionMessageContentTextPart], None]
+    tool_call_id: Optional[str] = None
+    name: Optional[str] = None
+    reasoning_content: Optional[str] = None
+    tool_calls: Optional[List[ToolCall]] = Field(default=None, examples=[None])


 class ChatCompletionMessageUserParam(BaseModel):

@@ -378,22 +398,6 @@ class ChatCompletionRequest(BaseModel):
     bootstrap_room: Optional[int] = None


-class FunctionResponse(BaseModel):
-    """Function response."""
-
-    name: Optional[str] = None
-    arguments: Optional[str] = None
-
-
-class ToolCall(BaseModel):
-    """Tool call response."""
-
-    id: str
-    index: Optional[int] = None
-    type: Literal["function"] = "function"
-    function: FunctionResponse
-
-
 class ChatMessage(BaseModel):
     role: Optional[str] = None
     content: Optional[str] = None
......
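With FunctionResponse and ToolCall now defined above the message params, and ToolCall.id made optional, request-side assistant messages can carry tool calls directly. A self-contained sketch of validating such a message with simplified copies of the models above (the real content field also accepts a list of text parts; argument values are illustrative only):

from typing import List, Literal, Optional
from pydantic import BaseModel, Field


class FunctionResponse(BaseModel):
    name: Optional[str] = None
    arguments: Optional[str] = None


class ToolCall(BaseModel):
    id: Optional[str] = None
    index: Optional[int] = None
    type: Literal["function"] = "function"
    function: FunctionResponse


class ChatCompletionMessageGenericParam(BaseModel):
    role: Literal["system", "assistant", "tool"]
    content: Optional[str] = None  # simplified; the real model also accepts text parts
    tool_call_id: Optional[str] = None
    name: Optional[str] = None
    reasoning_content: Optional[str] = None
    tool_calls: Optional[List[ToolCall]] = Field(default=None, examples=[None])


msg = ChatCompletionMessageGenericParam(
    role="assistant",
    content="",
    tool_calls=[
        ToolCall(
            id="call_abc123",
            function=FunctionResponse(name="get_current_weather", arguments='{"city": "Boston"}'),
        )
    ],
)
print(msg.dict(exclude_none=True))  # .dict() is what the adapter uses on request messages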