Unverified Commit 927ca935 authored by Xihuai Wang, committed by GitHub

Constraint Decoding: Tool call with text (#4067)

parent ef3c2dd0
@@ -41,7 +41,7 @@
 "\n",
 "\n",
 "server_process, port = launch_server_cmd(\n",
-"    \"python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --tool-call-parser llama3 --host 0.0.0.0\"  # llama3\n",
+"    \"python3 -m sglang.launch_server --model-path Qwen/Qwen2.5-7B-Instruct --tool-call-parser qwen25 --host 0.0.0.0\"  # qwen25\n",
 ")\n",
 "wait_for_server(f\"http://localhost:{port}\")"
 ]
@@ -55,7 +55,7 @@
 "- llama3: Llama 3.1 / 3.2 (e.g. meta-llama/Llama-3.1-8B-Instruct, meta-llama/Llama-3.2-1B-Instruct).\n",
 "- mistral: Mistral (e.g. mistralai/Mistral-7B-Instruct-v0.3, mistralai/Mistral-Nemo-Instruct-2407, mistralai/Mistral-7B-v0.3).\n",
-"- qwen25: Qwen 2.5 (e.g. Qwen/Qwen2.5-1.5B-Instruct, Qwen/Qwen2.5-7B-Instruct)."
+"- qwen25: Qwen 2.5 (e.g. Qwen/Qwen2.5-1.5B-Instruct, Qwen/Qwen2.5-7B-Instruct) and QwQ (e.g. Qwen/QwQ-32B). In particular, for QwQ the reasoning parser can be enabled together with the tool call parser; see [reasoning parser](https://docs.sglang.ai/backend/separate_reasoning.html) for details."
 ]
 },
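To illustrate the QwQ note above, here is a minimal sketch of a launch command that enables both parsers at once. The `--reasoning-parser deepseek-r1` value is an assumption; consult the linked reasoning-parser docs for the exact name to use with QwQ.

```python
# Sketch: serve QwQ with tool-call parsing and reasoning parsing enabled.
# The "--reasoning-parser deepseek-r1" value is an assumption -- verify against
# https://docs.sglang.ai/backend/separate_reasoning.html before relying on it.
server_process, port = launch_server_cmd(
    "python3 -m sglang.launch_server --model-path Qwen/QwQ-32B "
    "--tool-call-parser qwen25 --reasoning-parser deepseek-r1 --host 0.0.0.0"
)
wait_for_server(f"http://localhost:{port}")
```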
 {
@@ -121,7 +121,7 @@
 "    return [\n",
 "        {\n",
 "            \"role\": \"user\",\n",
-"            \"content\": \"What's the weather like in Boston today? Please respond with the format: Today's weather is :{function call result}\",\n",
+"            \"content\": \"What's the weather like in Boston today? Output your reasoning before you act, then use the tools to help you.\",\n",
 "        }\n",
 "    ]\n",
 "\n",
@@ -164,20 +164,26 @@
 "response_non_stream = client.chat.completions.create(\n",
 "    model=model_name,\n",
 "    messages=messages,\n",
-"    temperature=0.8,\n",
-"    top_p=0.8,\n",
+"    temperature=0.1,\n",
+"    top_p=0.95,\n",
+"    max_tokens=1024,\n",
 "    stream=False,  # Non-streaming\n",
 "    tools=tools,\n",
 ")\n",
 "print_highlight(\"Non-stream response:\")\n",
-"print(response_non_stream)"
+"print(response_non_stream)\n",
+"print_highlight(\"==== content ====\")\n",
+"print(response_non_stream.choices[0].message.content)\n",
+"print_highlight(\"==== tool_calls ====\")\n",
+"print(response_non_stream.choices[0].message.tool_calls)"
 ]
 },
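Since this change lets one response carry both explanatory text and tool calls, client code should not assume either field is set. A small sketch of defensive handling, using only the OpenAI client objects already in play in the notebook:

```python
# Either field may be None: the model can answer in plain text, call a tool,
# or (with this change) do both in a single message.
message = response_non_stream.choices[0].message
if message.content:
    print("text:", message.content)
if message.tool_calls:
    for tool_call in message.tool_calls:
        print("tool:", tool_call.function.name, tool_call.function.arguments)
```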
 {
 "cell_type": "markdown",
 "metadata": {},
 "source": [
-"### Streaming Request"
+"#### Handle Tool Calls\n",
+"When the engine determines it should call a particular tool, it will return arguments or partial arguments through the response. You can parse these arguments and later invoke the tool accordingly."
 ]
 },
 {
@@ -186,39 +192,20 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"# Streaming mode test\n",
-"print_highlight(\"Streaming response:\")\n",
-"response_stream = client.chat.completions.create(\n",
-"    model=model_name,\n",
-"    messages=messages,\n",
-"    temperature=0.8,\n",
-"    top_p=0.8,\n",
-"    stream=True,  # Enable streaming\n",
-"    tools=tools,\n",
+"name_non_stream = response_non_stream.choices[0].message.tool_calls[0].function.name\n",
+"arguments_non_stream = (\n",
+"    response_non_stream.choices[0].message.tool_calls[0].function.arguments\n",
 ")\n",
 "\n",
-"chunks = []\n",
-"for chunk in response_stream:\n",
-"    chunks.append(chunk)\n",
-"    if chunk.choices[0].delta.tool_calls:\n",
-"        print(chunk.choices[0].delta.tool_calls[0])"
+"print_highlight(f\"Function call name: {name_non_stream}\")\n",
+"print_highlight(f\"Function call arguments: {arguments_non_stream}\")"
 ]
 },
-{
-"cell_type": "markdown",
-"metadata": {},
-"source": [
-"\n",
-"### Handle Tool Calls\n",
-"\n",
-"When the engine determines it should call a particular tool, it will return arguments or partial arguments through the response. You can parse these arguments and later invoke the tool accordingly."
-]
-},
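To close the loop on the non-streaming path, a hedged sketch of actually dispatching the parsed call. `get_current_weather` here is a hypothetical stand-in for whatever implementation backs the tool schema defined earlier in the notebook:

```python
import json

# Hypothetical local implementation backing the weather tool schema.
def get_current_weather(location: str, unit: str = "fahrenheit") -> str:
    return f"It is cloudy in {location} (unit: {unit})."

# The API returns arguments as a JSON string; decode before calling.
available_tools = {"get_current_weather": get_current_weather}
tool_result = available_tools[name_non_stream](**json.loads(arguments_non_stream))
print(tool_result)
```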
 {
 "cell_type": "markdown",
 "metadata": {},
 "source": [
-"**Non-Streaming Request**"
+"### Streaming Request"
 ]
 },
 {
@@ -227,20 +214,41 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"name_non_stream = response_non_stream.choices[0].message.tool_calls[0].function.name\n",
-"arguments_non_stream = (\n",
-"    response_non_stream.choices[0].message.tool_calls[0].function.arguments\n",
+"# Streaming mode test\n",
+"print_highlight(\"Streaming response:\")\n",
+"response_stream = client.chat.completions.create(\n",
+"    model=model_name,\n",
+"    messages=messages,\n",
+"    temperature=0.1,\n",
+"    top_p=0.95,\n",
+"    max_tokens=1024,\n",
+"    stream=True,  # Enable streaming\n",
+"    tools=tools,\n",
 ")\n",
 "\n",
-"print_highlight(f\"Final streamed function call name: {name_non_stream}\")\n",
-"print_highlight(f\"Final streamed function call arguments: {arguments_non_stream}\")"
+"texts = \"\"\n",
+"tool_calls = []\n",
+"name = \"\"\n",
+"arguments = \"\"\n",
+"for chunk in response_stream:\n",
+"    if chunk.choices[0].delta.content:\n",
+"        texts += chunk.choices[0].delta.content\n",
+"    if chunk.choices[0].delta.tool_calls:\n",
+"        tool_calls.append(chunk.choices[0].delta.tool_calls[0])\n",
+"print_highlight(\"==== Text ====\")\n",
+"print(texts)\n",
+"\n",
+"print_highlight(\"==== Tool Call ====\")\n",
+"for tool_call in tool_calls:\n",
+"    print(tool_call)"
 ]
 },
 {
 "cell_type": "markdown",
 "metadata": {},
 "source": [
-"**Streaming Request**"
+"#### Handle Tool Calls\n",
+"When the engine determines it should call a particular tool, it will return arguments or partial arguments through the response. You can parse these arguments and later invoke the tool accordingly."
 ]
 },
 {
@@ -251,21 +259,16 @@
 "source": [
 "# Parse and combine function call arguments\n",
 "arguments = []\n",
-"for chunk in chunks:\n",
-"    choice = chunk.choices[0]\n",
-"    delta = choice.delta\n",
-"    if delta.tool_calls:\n",
-"        tool_call = delta.tool_calls[0]\n",
-"        if tool_call.function.name:\n",
-"            print_highlight(f\"Streamed function call name: {tool_call.function.name}\")\n",
-"\n",
-"        if tool_call.function.arguments:\n",
-"            arguments.append(tool_call.function.arguments)\n",
-"            print(f\"Streamed function call arguments: {tool_call.function.arguments}\")\n",
+"for tool_call in tool_calls:\n",
+"    if tool_call.function.name:\n",
+"        print_highlight(f\"Streamed function call name: {tool_call.function.name}\")\n",
+"\n",
+"    if tool_call.function.arguments:\n",
+"        arguments.append(tool_call.function.arguments)\n",
 "\n",
 "# Combine all fragments into a single JSON string\n",
 "full_arguments = \"\".join(arguments)\n",
-"print_highlight(f\"Final streamed function call arguments: {full_arguments}\")"
+"print_highlight(f\"Streamed function call arguments: {full_arguments}\")"
 ]
 },
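The combined string can then be decoded, executed, and handed back to the model as a tool message, which is presumably what the elided cells before the final request do. A sketch under that assumption (the `tool_call_id` is illustrative, and `get_current_weather` is the hypothetical stub from the non-streaming sketch above):

```python
import json

# Decode the accumulated fragments; the payload shape depends on the tool schema.
args = json.loads(full_arguments)

# Execute the tool and feed the result back for the follow-up turn.
tool_result = get_current_weather(**args)
messages.append(
    {
        "role": "tool",
        "tool_call_id": "0",  # illustrative; use the id carried by the tool_call delta
        "content": tool_result,
    }
)
```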
 {
@@ -342,13 +345,16 @@
 "final_response = client.chat.completions.create(\n",
 "    model=model_name,\n",
 "    messages=messages,\n",
-"    temperature=0.8,\n",
-"    top_p=0.8,\n",
+"    temperature=0.1,\n",
+"    top_p=0.95,\n",
 "    stream=False,\n",
 "    tools=tools,\n",
 ")\n",
 "print_highlight(\"Non-stream response:\")\n",
-"print(final_response)"
+"print(final_response)\n",
+"\n",
+"print_highlight(\"==== Text ====\")\n",
+"print(final_response.choices[0].message.content)"
 ]
 },
 {
@@ -368,7 +374,7 @@
 "import requests\n",
 "\n",
 "# generate an answer\n",
-"tokenizer = AutoTokenizer.from_pretrained(\"meta-llama/Meta-Llama-3.1-8B-Instruct\")\n",
+"tokenizer = AutoTokenizer.from_pretrained(\"Qwen/Qwen2.5-7B-Instruct\")\n",
 "\n",
 "messages = get_messages()\n",
 "\n",
@@ -380,8 +386,17 @@
 ")\n",
 "\n",
 "gen_url = f\"http://localhost:{port}/generate\"\n",
-"gen_data = {\"text\": input, \"sampling_params\": {\"skip_special_tokens\": False}}\n",
+"gen_data = {\n",
+"    \"text\": input,\n",
+"    \"sampling_params\": {\n",
+"        \"skip_special_tokens\": False,\n",
+"        \"max_new_tokens\": 1024,\n",
+"        \"temperature\": 0.1,\n",
+"        \"top_p\": 0.95,\n",
+"    },\n",
+"}\n",
 "gen_response = requests.post(gen_url, json=gen_data).json()[\"text\"]\n",
+"print_highlight(\"==== Response ====\")\n",
 "print(gen_response)\n",
 "\n",
 "# parse the response\n",
@@ -389,12 +404,16 @@
 "\n",
 "function_call_input = {\n",
 "    \"text\": gen_response,\n",
-"    \"tool_call_parser\": \"llama3\",\n",
+"    \"tool_call_parser\": \"qwen25\",\n",
 "    \"tools\": tools,\n",
 "}\n",
 "\n",
 "function_call_response = requests.post(parse_url, json=function_call_input)\n",
 "function_call_response_json = function_call_response.json()\n",
+"\n",
+"print_highlight(\"==== Text ====\")\n",
+"print(function_call_response_json[\"normal_text\"])\n",
+"print_highlight(\"==== Calls ====\")\n",
 "print(\"function name: \", function_call_response_json[\"calls\"][0][\"name\"])\n",
 "print(\"function arguments: \", function_call_response_json[\"calls\"][0][\"parameters\"])"
 ]
@@ -425,15 +444,15 @@
 "from sglang.srt.function_call_parser import FunctionCallParser\n",
 "from sglang.srt.managers.io_struct import Tool, Function\n",
 "\n",
-"llm = sgl.Engine(model_path=\"meta-llama/Meta-Llama-3.1-8B-Instruct\")\n",
+"llm = sgl.Engine(model_path=\"Qwen/Qwen2.5-7B-Instruct\")\n",
 "tokenizer = llm.tokenizer_manager.tokenizer\n",
 "input_ids = tokenizer.apply_chat_template(\n",
 "    messages, tokenize=True, add_generation_prompt=True, tools=tools\n",
 ")\n",
 "\n",
 "sampling_params = {\n",
-"    \"max_new_tokens\": 128,\n",
-"    \"temperature\": 0.3,\n",
+"    \"max_new_tokens\": 1024,\n",
+"    \"temperature\": 0.1,\n",
 "    \"top_p\": 0.95,\n",
 "    \"skip_special_tokens\": False,\n",
 "}\n",
@@ -461,10 +480,10 @@
 "\n",
 "tools = [convert_dict_to_tool(raw_tool) for raw_tool in tools]\n",
 "\n",
-"parser = FunctionCallParser(tools=tools, tool_call_parser=\"llama3\")\n",
+"parser = FunctionCallParser(tools=tools, tool_call_parser=\"qwen25\")\n",
 "normal_text, calls = parser.parse_non_stream(generated_text)\n",
 "\n",
-"print(\"\\n=== Parsing Result ===\")\n",
+"print(\"=== Parsing Result ===\")\n",
 "print(\"Normal text portion:\", normal_text)\n",
 "print(\"Function call portion:\")\n",
 "for call in calls:\n",
@@ -521,5 +540,5 @@
 }
 },
 "nbformat": 4,
-"nbformat_minor": 2
+"nbformat_minor": 4
}
@@ -128,13 +128,15 @@ class BaseFormatDetector:
         return results

-    def detect_and_parse(self, text: str, tools: List[Function]) -> List[ToolCallItem]:
+    def detect_and_parse(
+        self, text: str, tools: List[Function]
+    ) -> StreamingParseResult:
         """
         Parses the text in one go. Returns success=True if the format matches, otherwise False.
         Note that leftover_text here represents "content that this parser will not consume further".
         """
         action = json.loads(text)
-        return self.parse_base_json(action, tools)
+        return StreamingParseResult(calls=self.parse_base_json(action, tools))

     def parse_streaming_increment(
         self, new_text: str, tools: List[Function]
@@ -322,7 +324,9 @@ class Qwen25Detector(BaseFormatDetector):
         """Check if the text contains a Qwen 2.5 format tool call."""
         return self.bot_token in text

-    def detect_and_parse(self, text: str, tools: List[Function]) -> List[ToolCallItem]:
+    def detect_and_parse(
+        self, text: str, tools: List[Function]
+    ) -> StreamingParseResult:
         """
         One-time parsing: Detects and parses tool calls in the provided text.
@@ -330,15 +334,17 @@ class Qwen25Detector(BaseFormatDetector):
         :param tools: List of available tools.
         :return: ParseResult indicating success or failure, consumed text, leftover text, and parsed calls.
         """
-        if "<tool_call>" not in text:
-            return []
-        pattern = r"<tool_call>(.*?)</tool_call>"
+        idx = text.find(self.bot_token)
+        normal_text = text[:idx].strip() if idx != -1 else text
+        if self.bot_token not in text:
+            return StreamingParseResult(normal_text=normal_text, calls=[])
+        pattern = rf"{self.bot_token}(.*?){self.eot_token}"
         match_result_list = re.findall(pattern, text, re.DOTALL)
         calls = []
         for match_result in match_result_list:
             match_result = json.loads(match_result)
             calls.extend(self.parse_base_json(match_result, tools))
-        return calls
+        return StreamingParseResult(normal_text=normal_text, calls=calls)
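For readers skimming the diff, the definition of the new `StreamingParseResult` return type is not shown here. Inferred from its call sites above, it is presumably a small container along these lines (a sketch, not the actual definition; `ToolCallItem` is the module's existing parsed-call type, aliased here only to keep the sketch self-contained):

```python
from dataclasses import dataclass, field
from typing import Any, List

ToolCallItem = Any  # placeholder for the module's real parsed-call item type

@dataclass
class StreamingParseResult:
    """Inferred shape: text the detector leaves unconsumed, plus parsed calls."""
    normal_text: str = ""
    calls: List[ToolCallItem] = field(default_factory=list)
```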
 class MistralDetector(BaseFormatDetector):
@@ -374,7 +380,9 @@ class MistralDetector(BaseFormatDetector):
         else:
             return ""

-    def detect_and_parse(self, text: str, tools: List[Function]) -> List[ToolCallItem]:
+    def detect_and_parse(
+        self, text: str, tools: List[Function]
+    ) -> StreamingParseResult:
         """
         One-time parsing: Detects and parses tool calls in the provided text.
@@ -382,6 +390,8 @@ class MistralDetector(BaseFormatDetector):
         :param tools: List of available tools.
         :return: ParseResult indicating success or failure, consumed text, leftover text, and parsed calls.
         """
+        idx = text.find(self.bot_token)
+        normal_text = text[:idx].strip() if idx != -1 else text
         text = self._clean_text(text)
         tool_content = text.replace("[TOOL_CALLS]", "").strip()
         raw_tool_calls = self.tool_call_regex.findall(tool_content)
@@ -391,7 +401,7 @@ class MistralDetector(BaseFormatDetector):
             function_call_arr = json.loads(raw_tool_call)
             for match_result in function_call_arr:
                 calls.extend(self.parse_base_json(match_result, tools))
-        return calls
+        return StreamingParseResult(normal_text=normal_text, calls=calls)
 class Llama32Detector(BaseFormatDetector):
@@ -414,7 +424,7 @@ class Llama32Detector(BaseFormatDetector):
-    def detect_and_parse(self, text: str, tools: List[Function]) -> List[ToolCallItem]:
+    def detect_and_parse(
+        self, text: str, tools: List[Function]
+    ) -> StreamingParseResult:
         """Parse function calls from text, handling multiple JSON objects."""
         if "<|python_tag|>" not in text and not text.startswith("{"):
-            return []
+            return StreamingParseResult(normal_text=text, calls=[])

         if "<|python_tag|>" in text:
             _, action_text = text.split("<|python_tag|>")
@@ -423,7 +433,6 @@ class Llama32Detector(BaseFormatDetector):
         # Split by semicolon and process each part
         json_parts = [part.strip() for part in action_text.split(";") if part.strip()]
         all_actions = []
-
         for part in json_parts:
             try:
@@ -434,12 +443,11 @@ class Llama32Detector(BaseFormatDetector):
                 logger.warning(f"Failed to parse JSON part: {part}")
                 logger.warning(f"JSON parse error: {str(e)}")
                 continue
+        calls = []
         # Only process if we found valid JSON objects
         if all_actions:
-            return self.parse_base_json(all_actions, tools)
-
-        return []
+            calls = self.parse_base_json(all_actions, tools)
+        return StreamingParseResult(normal_text=normal_text, calls=calls)
 class MultiFormatParser:
@@ -449,7 +457,9 @@ class MultiFormatParser:
         """
         self.detectors = detectors

-    def parse_once(self, text: str, tools: List[Function]):
+    def parse_once(
+        self, text: str, tools: List[Function]
+    ) -> Tuple[str, list[ToolCallItem]]:
         """
         One-time parsing: Loop through detectors until there are no new matches or text is exhausted
         Return: (final_text, all_calls)
@@ -459,15 +469,19 @@ class MultiFormatParser:
         final_calls = []
         final_normal_text = text
         for detector in self.detectors:
-            tool_call_list = detector.detect_and_parse(text, tools)
+            parsed_result = detector.detect_and_parse(text, tools)
+            tool_call_list = parsed_result.calls
             if len(tool_call_list) > 0:  # parsed successfully
                 final_calls = tool_call_list
+                final_normal_text = parsed_result.normal_text
                 break

         # leftover_text is the normal text not consumed by any Detector
         return final_normal_text, final_calls
-    def parse_streaming_increment(self, new_text: str, tools: List[Function]):
+    def parse_streaming_increment(
+        self, new_text: str, tools: List[Function]
+    ) -> Tuple[str, list[ToolCallItem]]:
         """
         Streaming incremental parsing: Feed new_text to each detector's parse_streaming_increment
         and merge their produced normal_text/calls to return.
@@ -532,7 +546,7 @@ class FunctionCallParser:
             return True
         return False

-    def parse_non_stream(self, full_text: str):
+    def parse_non_stream(self, full_text: str) -> Tuple[str, list[ToolCallItem]]:
         """
         Non-streaming call: one-time parsing
         """
@@ -541,7 +555,7 @@ class FunctionCallParser:
         )
         return full_normal_text, calls

-    def parse_stream_chunk(self, chunk_text: str):
+    def parse_stream_chunk(self, chunk_text: str) -> Tuple[str, list[ToolCallItem]]:
         """
         Streaming call: incremental parsing
         """
...
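With the signatures typed above, incremental parsing can be exercised without a server. A minimal sketch that replays a finished Qwen-style generation through `parse_stream_chunk`; the payload is illustrative, and `parser` is assumed to be the `FunctionCallParser` built as in the notebook:

```python
# Replay a finished generation in small chunks to exercise incremental parsing.
# Assumes: parser = FunctionCallParser(tools=tools, tool_call_parser="qwen25").
generated = (
    "Let me check the weather for you.\n"
    '<tool_call>\n{"name": "get_current_weather", '
    '"arguments": {"location": "Boston, MA", "unit": "fahrenheit"}}\n</tool_call>'
)
normal_parts, streamed_calls = [], []
for start in range(0, len(generated), 16):  # simulate 16-char network chunks
    normal_text, calls = parser.parse_stream_chunk(generated[start : start + 16])
    if normal_text:
        normal_parts.append(normal_text)
    streamed_calls.extend(calls)

print("normal text:", "".join(normal_parts))
for call in streamed_calls:
    print("call item:", call)
```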
@@ -1130,7 +1130,7 @@ def v1_chat_generate_response(
             finish_reason["type"] = "tool_calls"
             finish_reason["matched"] = None
             try:
-                full_normal_text, call_info_list = parser.parse_non_stream(text)
+                text, call_info_list = parser.parse_non_stream(text)
                 tool_calls = [
                     ToolCall(
                         id=str(call_info.tool_index),
@@ -1153,9 +1153,9 @@ def v1_chat_generate_response(
                 "index": 0,
                 "message": {
                     "role": "assistant",
-                    "content": text if tool_calls is None else None,
+                    "content": text if text else None,
                     "tool_calls": tool_calls,
-                    "reasoning_content": reasoning_text,
+                    "reasoning_content": reasoning_text if reasoning_text else None,
                 },
                 "logprobs": choice_logprobs.model_dump() if choice_logprobs else None,
                 "finish_reason": (finish_reason["type"] if finish_reason else ""),
@@ -1170,9 +1170,9 @@ def v1_chat_generate_response(
                 index=idx,
                 message=ChatMessage(
                     role="assistant",
-                    content=text if tool_calls is None else None,
+                    content=text if text else None,
                     tool_calls=tool_calls,
-                    reasoning_content=reasoning_text,
+                    reasoning_content=reasoning_text if reasoning_text else None,
                 ),
                 logprobs=choice_logprobs,
                 finish_reason=(finish_reason["type"] if finish_reason else ""),
@@ -1317,9 +1317,11 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
                         tokenizer_manager.server_args.reasoning_parser
                         and request.separate_reasoning
                     ):
-                        delta = DeltaMessage(role="assistant", reasoning_content="")
+                        delta = DeltaMessage(
+                            role="assistant", reasoning_content=None
+                        )
                     else:
-                        delta = DeltaMessage(role="assistant", content="")
+                        delta = DeltaMessage(role="assistant", content=None)
                     choice_data = ChatCompletionResponseStreamChoice(
                         index=index,
                         delta=delta,
@@ -1362,7 +1364,11 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
                     if reasoning_text:
                         choice_data = ChatCompletionResponseStreamChoice(
                             index=index,
-                            delta=DeltaMessage(reasoning_content=reasoning_text),
+                            delta=DeltaMessage(
+                                reasoning_content=(
+                                    reasoning_text if reasoning_text else None
+                                )
+                            ),
                             finish_reason=(
                                 None
                                 if finish_reason_type
@@ -1396,7 +1402,9 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
                     if normal_text:
                         choice_data = ChatCompletionResponseStreamChoice(
                             index=index,
-                            delta=DeltaMessage(content=normal_text),
+                            delta=DeltaMessage(
+                                content=normal_text if normal_text else None
+                            ),
                             finish_reason=(
                                 None
                                 if finish_reason_type
@@ -1468,7 +1476,7 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
                     # No tool calls => just treat this as normal text
                     choice_data = ChatCompletionResponseStreamChoice(
                         index=index,
-                        delta=DeltaMessage(content=delta),
+                        delta=DeltaMessage(content=delta if delta else None),
                         finish_reason=(
                             None
                             if finish_reason_type and len(finish_reason_type) == 0
...
@@ -257,7 +257,7 @@ class TestOpenAIServer(unittest.TestCase):
                 ret_num_top_logprobs == logprobs
             ), f"{ret_num_top_logprobs} vs {logprobs}"

-            assert isinstance(data.content, str)
+            assert isinstance(data.content, str) or response.choices[0].finish_reason
             assert response.id
             assert response.created
...