Unverified Commit 795e98f8 authored by Surya-Gunukula, committed by GitHub
Browse files

Forward unknown tool calls instead of dropping (#12226)

parent 358ae356
...@@ -6,15 +6,16 @@ SGLang supports various environment variables that can be used to configure its ...@@ -6,15 +6,16 @@ SGLang supports various environment variables that can be used to configure its
## General Configuration ## General Configuration
| Environment Variable | Description | Default Value | | Environment Variable | Description | Default Value |
|-------------------------------------------|----------------------------------------------------------------------------------------------------------------------------------|---------------| |-------------------------------------------|----------------------------------------------------------------------------------------------------------------------------------|------------------------------|
| `SGLANG_USE_MODELSCOPE` | Enable using models from ModelScope | `false` | | `SGLANG_USE_MODELSCOPE` | Enable using models from ModelScope | `false` |
| `SGLANG_HOST_IP` | Host IP address for the server | `0.0.0.0` | | `SGLANG_HOST_IP` | Host IP address for the server | `0.0.0.0` |
| `SGLANG_PORT` | Port for the server | auto-detected | | `SGLANG_PORT` | Port for the server | auto-detected |
| `SGLANG_LOGGING_CONFIG_PATH` | Custom logging configuration path | Not set | | `SGLANG_LOGGING_CONFIG_PATH` | Custom logging configuration path | Not set |
| `SGLANG_DISABLE_REQUEST_LOGGING` | Disable request logging | `false` | | `SGLANG_DISABLE_REQUEST_LOGGING` | Disable request logging | `false` |
| `SGLANG_HEALTH_CHECK_TIMEOUT` | Timeout for health check in seconds | `20` | | `SGLANG_HEALTH_CHECK_TIMEOUT` | Timeout for health check in seconds | `20` |
| `SGLANG_EPLB_HEATMAP_COLLECTION_INTERVAL` | The interval of passes to collect the metric of selected count of physical experts on each layer and GPU rank. 0 means disabled. | `0` | | `SGLANG_EPLB_HEATMAP_COLLECTION_INTERVAL` | The interval of passes to collect the metric of selected count of physical experts on each layer and GPU rank. 0 means disabled. | `0` |
| `SGLANG_FORWARD_UNKNOWN_TOOLS` | Forward unknown tool calls to clients instead of dropping them | `false` (drop unknown tools) |
## Performance Tuning ## Performance Tuning
......
...@@ -158,6 +158,9 @@ class Envs: ...@@ -158,6 +158,9 @@ class Envs:
SGLANG_DISABLE_OUTLINES_DISK_CACHE = EnvBool(True) SGLANG_DISABLE_OUTLINES_DISK_CACHE = EnvBool(True)
SGLANG_GRAMMAR_TIMEOUT = EnvFloat(300) SGLANG_GRAMMAR_TIMEOUT = EnvFloat(300)
# Tool Calling
SGLANG_FORWARD_UNKNOWN_TOOLS = EnvBool(False)
# Hi-Cache # Hi-Cache
SGLANG_HICACHE_HF3FS_CONFIG_PATH = EnvStr(None) SGLANG_HICACHE_HF3FS_CONFIG_PATH = EnvStr(None)
......
...@@ -8,6 +8,7 @@ from partial_json_parser.core.exceptions import MalformedJSON ...@@ -8,6 +8,7 @@ from partial_json_parser.core.exceptions import MalformedJSON
from partial_json_parser.core.options import Allow from partial_json_parser.core.options import Allow
from sglang.srt.entrypoints.openai.protocol import Tool from sglang.srt.entrypoints.openai.protocol import Tool
from sglang.srt.environ import envs
from sglang.srt.function_call.core_types import ( from sglang.srt.function_call.core_types import (
StreamingParseResult, StreamingParseResult,
ToolCallItem, ToolCallItem,
...@@ -75,19 +76,21 @@ class BaseFormatDetector(ABC): ...@@ -75,19 +76,21 @@ class BaseFormatDetector(ABC):
results = [] results = []
for act in action: for act in action:
name = act.get("name") name = act.get("name")
if name and name in tool_indices: if not (name and name in tool_indices):
results.append(
ToolCallItem(
tool_index=-1, # Caller should update this based on the actual tools array called
name=name,
parameters=json.dumps(
act.get("parameters") or act.get("arguments", {}),
ensure_ascii=False,
),
)
)
else:
logger.warning(f"Model attempted to call undefined function: {name}") logger.warning(f"Model attempted to call undefined function: {name}")
if not envs.SGLANG_FORWARD_UNKNOWN_TOOLS.get():
continue # Skip unknown tools (default legacy behavior)
results.append(
ToolCallItem(
tool_index=-1, # Caller should update this based on the actual tools array called
name=name,
parameters=json.dumps(
act.get("parameters") or act.get("arguments", {}),
ensure_ascii=False,
),
)
)
return results return results
......
...@@ -4,6 +4,7 @@ import re ...@@ -4,6 +4,7 @@ import re
from typing import List, Optional from typing import List, Optional
from sglang.srt.entrypoints.openai.protocol import Tool from sglang.srt.entrypoints.openai.protocol import Tool
from sglang.srt.environ import envs
from sglang.srt.function_call.base_format_detector import BaseFormatDetector from sglang.srt.function_call.base_format_detector import BaseFormatDetector
from sglang.srt.function_call.core_types import ( from sglang.srt.function_call.core_types import (
StreamingParseResult, StreamingParseResult,
...@@ -220,7 +221,8 @@ class GptOssDetector(BaseFormatDetector): ...@@ -220,7 +221,8 @@ class GptOssDetector(BaseFormatDetector):
# Check if tool exists # Check if tool exists
if function_name not in tool_indices: if function_name not in tool_indices:
logger.debug(f"Function {function_name} not in available tools") logger.debug(f"Function {function_name} not in available tools")
return None if not envs.SGLANG_FORWARD_UNKNOWN_TOOLS.get():
return None # Skip unknown tools (default legacy behavior)
# Parse JSON arguments # Parse JSON arguments
try: try:
......
...@@ -5,6 +5,7 @@ import re ...@@ -5,6 +5,7 @@ import re
from typing import List, Optional from typing import List, Optional
from sglang.srt.entrypoints.openai.protocol import Tool from sglang.srt.entrypoints.openai.protocol import Tool
from sglang.srt.environ import envs
from sglang.srt.function_call.base_format_detector import BaseFormatDetector from sglang.srt.function_call.base_format_detector import BaseFormatDetector
from sglang.srt.function_call.core_types import ( from sglang.srt.function_call.core_types import (
StreamingParseResult, StreamingParseResult,
...@@ -91,7 +92,9 @@ class PythonicDetector(BaseFormatDetector): ...@@ -91,7 +92,9 @@ class PythonicDetector(BaseFormatDetector):
logger.warning( logger.warning(
f"Model attempted to call undefined function: {function_name}" f"Model attempted to call undefined function: {function_name}"
) )
continue if not envs.SGLANG_FORWARD_UNKNOWN_TOOLS.get():
continue # Skip unknown tools (default legacy behavior)
arguments = {} arguments = {}
for keyword in call.keywords: for keyword in call.keywords:
arguments[keyword.arg] = self._get_parameter_value(keyword.value) arguments[keyword.arg] = self._get_parameter_value(keyword.value)
......
...@@ -6,6 +6,7 @@ import re ...@@ -6,6 +6,7 @@ import re
from typing import Any, Dict, List, Tuple from typing import Any, Dict, List, Tuple
from sglang.srt.entrypoints.openai.protocol import Tool from sglang.srt.entrypoints.openai.protocol import Tool
from sglang.srt.environ import envs
from sglang.srt.function_call.base_format_detector import BaseFormatDetector from sglang.srt.function_call.base_format_detector import BaseFormatDetector
from sglang.srt.function_call.core_types import ( from sglang.srt.function_call.core_types import (
StreamingParseResult, StreamingParseResult,
...@@ -120,45 +121,48 @@ class Qwen3CoderDetector(BaseFormatDetector): ...@@ -120,45 +121,48 @@ class Qwen3CoderDetector(BaseFormatDetector):
function_name = function_match.group(1).strip() function_name = function_match.group(1).strip()
# Validate function name # Validate function name
if function_name in self._tool_indices: is_valid = function_name in self._tool_indices
self._current_function_name = function_name if not is_valid:
self._function_name_sent = True logger.warning(f"Invalid function name: {function_name}")
if not envs.SGLANG_FORWARD_UNKNOWN_TOOLS.get():
# Initialize tool call tracking # Reset state and skip (default legacy behavior)
if self.current_tool_id == -1: self._reset_streaming_state()
self.current_tool_id = 0 normal += self._buf
self._buf = ""
# Ensure tracking arrays are large enough break
while len(self.prev_tool_call_arr) <= self.current_tool_id:
self.prev_tool_call_arr.append({}) # Process tool call (valid or unknown with env=TRUE)
while len(self.streamed_args_for_tool) <= self.current_tool_id: self._current_function_name = function_name
self.streamed_args_for_tool.append("") self._function_name_sent = True
# Store tool call info # Initialize tool call tracking
self.prev_tool_call_arr[self.current_tool_id] = { if self.current_tool_id == -1:
"name": function_name, self.current_tool_id = 0
"arguments": {},
} # Ensure tracking arrays are large enough
while len(self.prev_tool_call_arr) <= self.current_tool_id:
# Send tool name with empty parameters self.prev_tool_call_arr.append({})
calls.append( while len(self.streamed_args_for_tool) <= self.current_tool_id:
ToolCallItem( self.streamed_args_for_tool.append("")
tool_index=self.current_tool_id,
name=function_name, # Store tool call info
parameters="", self.prev_tool_call_arr[self.current_tool_id] = {
) "name": function_name,
"arguments": {},
}
# Send tool name with empty parameters
calls.append(
ToolCallItem(
tool_index=self.current_tool_id,
name=function_name,
parameters="",
) )
)
# Remove the processed function declaration # Remove the processed function declaration
self._buf = self._buf[function_match.end() :] self._buf = self._buf[function_match.end() :]
continue continue
else:
# Invalid function name, reset state
logger.warning(f"Invalid function name: {function_name}")
self._reset_streaming_state()
normal += self._buf
self._buf = ""
break
else: else:
# Function name not complete yet, wait for more text # Function name not complete yet, wait for more text
break break
......
import json
import logging
from sglang.srt.entrypoints.openai.protocol import Function, Tool
from sglang.srt.environ import envs
from sglang.srt.function_call.base_format_detector import BaseFormatDetector
from sglang.srt.function_call.core_types import StreamingParseResult
class DummyDetector(BaseFormatDetector):
    """Bare-bones detector used to exercise ``parse_base_json`` in isolation."""

    def has_tool_call(self, text: str) -> bool:
        """Always report a tool call, regardless of ``text``."""
        return True

    def detect_and_parse(self, text: str, tools):
        """Decode ``text`` as JSON and hand it to ``parse_base_json``.

        The parsed calls are returned in a ``StreamingParseResult`` whose
        ``normal_text`` is empty.
        """
        decoded = json.loads(text)
        parsed_calls = self.parse_base_json(decoded, tools)
        return StreamingParseResult(normal_text="", calls=parsed_calls)
def test_unknown_tool_name_dropped_default(caplog):
    """With forwarding disabled, an undeclared tool call is logged and dropped."""
    detector_logger = "sglang.srt.function_call.base_format_detector"
    payload = '{"name":"unknown_tool","parameters":{"city":"Paris"}}'

    with envs.SGLANG_FORWARD_UNKNOWN_TOOLS.override(False):
        # Only "get_weather" is declared, so "unknown_tool" is not a known name.
        weather_fn = Function(
            name="get_weather", parameters={"type": "object", "properties": {}}
        )
        available_tools = [Tool(function=weather_fn)]
        detector = DummyDetector()

        with caplog.at_level(logging.WARNING, logger=detector_logger):
            parsed = detector.detect_and_parse(payload, available_tools)

        # The warning must be emitted even though the call is discarded.
        warned = any(
            "Model attempted to call undefined function: unknown_tool" in message
            for message in caplog.messages
        )
        assert warned
        assert len(parsed.calls) == 0  # dropped in default mode
def test_unknown_tool_name_forwarded(caplog):
    """With forwarding enabled, an undeclared tool call is logged and passed on."""
    detector_logger = "sglang.srt.function_call.base_format_detector"
    payload = '{"name":"unknown_tool","parameters":{"city":"Paris"}}'

    with envs.SGLANG_FORWARD_UNKNOWN_TOOLS.override(True):
        # Only "get_weather" is declared, so "unknown_tool" is not a known name.
        weather_fn = Function(
            name="get_weather", parameters={"type": "object", "properties": {}}
        )
        available_tools = [Tool(function=weather_fn)]
        detector = DummyDetector()

        with caplog.at_level(logging.WARNING, logger=detector_logger):
            parsed = detector.detect_and_parse(payload, available_tools)

        # The warning is still expected when the call is forwarded.
        warned = any(
            "Model attempted to call undefined function: unknown_tool" in message
            for message in caplog.messages
        )
        assert warned

        # Exactly one call comes back, carrying the unknown name and its arguments.
        assert len(parsed.calls) == 1
        forwarded = parsed.calls[0]
        assert forwarded.name == "unknown_tool"
        assert forwarded.tool_index == -1
        assert json.loads(forwarded.parameters)["city"] == "Paris"
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment