Commit 795e98f8 (unverified), authored by Surya-Gunukula and committed by GitHub
Browse files

Forward unknown tool calls instead of dropping (#12226)

parent 358ae356
......@@ -6,15 +6,16 @@ SGLang supports various environment variables that can be used to configure its
## General Configuration
| Environment Variable | Description | Default Value |
|-------------------------------------------|----------------------------------------------------------------------------------------------------------------------------------|---------------|
| `SGLANG_USE_MODELSCOPE` | Enable using models from ModelScope | `false` |
| `SGLANG_HOST_IP` | Host IP address for the server | `0.0.0.0` |
| `SGLANG_PORT` | Port for the server | auto-detected |
| `SGLANG_LOGGING_CONFIG_PATH` | Custom logging configuration path | Not set |
| `SGLANG_DISABLE_REQUEST_LOGGING` | Disable request logging | `false` |
| `SGLANG_HEALTH_CHECK_TIMEOUT` | Timeout for health check in seconds | `20` |
| `SGLANG_EPLB_HEATMAP_COLLECTION_INTERVAL` | The interval of passes to collect the metric of selected count of physical experts on each layer and GPU rank. 0 means disabled. | `0` |
| Environment Variable | Description | Default Value |
|-------------------------------------------|----------------------------------------------------------------------------------------------------------------------------------|------------------------------|
| `SGLANG_USE_MODELSCOPE` | Enable using models from ModelScope | `false` |
| `SGLANG_HOST_IP` | Host IP address for the server | `0.0.0.0` |
| `SGLANG_PORT` | Port for the server | auto-detected |
| `SGLANG_LOGGING_CONFIG_PATH` | Custom logging configuration path | Not set |
| `SGLANG_DISABLE_REQUEST_LOGGING` | Disable request logging | `false` |
| `SGLANG_HEALTH_CHECK_TIMEOUT` | Timeout for health check in seconds | `20` |
| `SGLANG_EPLB_HEATMAP_COLLECTION_INTERVAL` | The interval of passes to collect the metric of selected count of physical experts on each layer and GPU rank. 0 means disabled. | `0` |
| `SGLANG_FORWARD_UNKNOWN_TOOLS` | Forward unknown tool calls to clients instead of dropping them | `false` (drop unknown tools) |
## Performance Tuning
......
......@@ -158,6 +158,9 @@ class Envs:
SGLANG_DISABLE_OUTLINES_DISK_CACHE = EnvBool(True)
SGLANG_GRAMMAR_TIMEOUT = EnvFloat(300)
# Tool Calling
SGLANG_FORWARD_UNKNOWN_TOOLS = EnvBool(False)
# Hi-Cache
SGLANG_HICACHE_HF3FS_CONFIG_PATH = EnvStr(None)
......
......@@ -8,6 +8,7 @@ from partial_json_parser.core.exceptions import MalformedJSON
from partial_json_parser.core.options import Allow
from sglang.srt.entrypoints.openai.protocol import Tool
from sglang.srt.environ import envs
from sglang.srt.function_call.core_types import (
StreamingParseResult,
ToolCallItem,
......@@ -75,19 +76,21 @@ class BaseFormatDetector(ABC):
results = []
for act in action:
name = act.get("name")
if name and name in tool_indices:
results.append(
ToolCallItem(
tool_index=-1, # Caller should update this based on the actual tools array called
name=name,
parameters=json.dumps(
act.get("parameters") or act.get("arguments", {}),
ensure_ascii=False,
),
)
)
else:
if not (name and name in tool_indices):
logger.warning(f"Model attempted to call undefined function: {name}")
if not envs.SGLANG_FORWARD_UNKNOWN_TOOLS.get():
continue # Skip unknown tools (default legacy behavior)
results.append(
ToolCallItem(
tool_index=-1, # Caller should update this based on the actual tools array called
name=name,
parameters=json.dumps(
act.get("parameters") or act.get("arguments", {}),
ensure_ascii=False,
),
)
)
return results
......
......@@ -4,6 +4,7 @@ import re
from typing import List, Optional
from sglang.srt.entrypoints.openai.protocol import Tool
from sglang.srt.environ import envs
from sglang.srt.function_call.base_format_detector import BaseFormatDetector
from sglang.srt.function_call.core_types import (
StreamingParseResult,
......@@ -220,7 +221,8 @@ class GptOssDetector(BaseFormatDetector):
# Check if tool exists
if function_name not in tool_indices:
logger.debug(f"Function {function_name} not in available tools")
return None
if not envs.SGLANG_FORWARD_UNKNOWN_TOOLS.get():
return None # Skip unknown tools (default legacy behavior)
# Parse JSON arguments
try:
......
......@@ -5,6 +5,7 @@ import re
from typing import List, Optional
from sglang.srt.entrypoints.openai.protocol import Tool
from sglang.srt.environ import envs
from sglang.srt.function_call.base_format_detector import BaseFormatDetector
from sglang.srt.function_call.core_types import (
StreamingParseResult,
......@@ -91,7 +92,9 @@ class PythonicDetector(BaseFormatDetector):
logger.warning(
f"Model attempted to call undefined function: {function_name}"
)
continue
if not envs.SGLANG_FORWARD_UNKNOWN_TOOLS.get():
continue # Skip unknown tools (default legacy behavior)
arguments = {}
for keyword in call.keywords:
arguments[keyword.arg] = self._get_parameter_value(keyword.value)
......
......@@ -6,6 +6,7 @@ import re
from typing import Any, Dict, List, Tuple
from sglang.srt.entrypoints.openai.protocol import Tool
from sglang.srt.environ import envs
from sglang.srt.function_call.base_format_detector import BaseFormatDetector
from sglang.srt.function_call.core_types import (
StreamingParseResult,
......@@ -120,45 +121,48 @@ class Qwen3CoderDetector(BaseFormatDetector):
function_name = function_match.group(1).strip()
# Validate function name
if function_name in self._tool_indices:
self._current_function_name = function_name
self._function_name_sent = True
# Initialize tool call tracking
if self.current_tool_id == -1:
self.current_tool_id = 0
# Ensure tracking arrays are large enough
while len(self.prev_tool_call_arr) <= self.current_tool_id:
self.prev_tool_call_arr.append({})
while len(self.streamed_args_for_tool) <= self.current_tool_id:
self.streamed_args_for_tool.append("")
# Store tool call info
self.prev_tool_call_arr[self.current_tool_id] = {
"name": function_name,
"arguments": {},
}
# Send tool name with empty parameters
calls.append(
ToolCallItem(
tool_index=self.current_tool_id,
name=function_name,
parameters="",
)
is_valid = function_name in self._tool_indices
if not is_valid:
logger.warning(f"Invalid function name: {function_name}")
if not envs.SGLANG_FORWARD_UNKNOWN_TOOLS.get():
# Reset state and skip (default legacy behavior)
self._reset_streaming_state()
normal += self._buf
self._buf = ""
break
# Process tool call (valid or unknown with env=TRUE)
self._current_function_name = function_name
self._function_name_sent = True
# Initialize tool call tracking
if self.current_tool_id == -1:
self.current_tool_id = 0
# Ensure tracking arrays are large enough
while len(self.prev_tool_call_arr) <= self.current_tool_id:
self.prev_tool_call_arr.append({})
while len(self.streamed_args_for_tool) <= self.current_tool_id:
self.streamed_args_for_tool.append("")
# Store tool call info
self.prev_tool_call_arr[self.current_tool_id] = {
"name": function_name,
"arguments": {},
}
# Send tool name with empty parameters
calls.append(
ToolCallItem(
tool_index=self.current_tool_id,
name=function_name,
parameters="",
)
)
# Remove the processed function declaration
self._buf = self._buf[function_match.end() :]
continue
else:
# Invalid function name, reset state
logger.warning(f"Invalid function name: {function_name}")
self._reset_streaming_state()
normal += self._buf
self._buf = ""
break
# Remove the processed function declaration
self._buf = self._buf[function_match.end() :]
continue
else:
# Function name not complete yet, wait for more text
break
......
import json
import logging
from sglang.srt.entrypoints.openai.protocol import Function, Tool
from sglang.srt.environ import envs
from sglang.srt.function_call.base_format_detector import BaseFormatDetector
from sglang.srt.function_call.core_types import StreamingParseResult
class DummyDetector(BaseFormatDetector):
    """Minimal detector stub that routes JSON text through ``parse_base_json``."""

    def has_tool_call(self, text: str) -> bool:
        """Report a tool call for every input, regardless of ``text``."""
        return True

    def detect_and_parse(self, text: str, tools):
        """Decode ``text`` as JSON and wrap the parsed calls in a result object.

        The decoded payload is handed to the inherited ``parse_base_json``
        together with ``tools``; no normal text is ever returned.
        """
        decoded = json.loads(text)
        parsed_calls = self.parse_base_json(decoded, tools)
        return StreamingParseResult(normal_text="", calls=parsed_calls)
def test_unknown_tool_name_dropped_default(caplog):
    """An undefined tool call is warned about and discarded when forwarding is off."""
    raw_call = '{"name":"unknown_tool","parameters":{"city":"Paris"}}'
    expected_warning = "Model attempted to call undefined function: unknown_tool"

    with envs.SGLANG_FORWARD_UNKNOWN_TOOLS.override(False):
        # Only "get_weather" is declared, so "unknown_tool" is undefined.
        weather_fn = Function(
            name="get_weather", parameters={"type": "object", "properties": {}}
        )
        declared_tools = [Tool(function=weather_fn)]
        detector = DummyDetector()

        with caplog.at_level(
            logging.WARNING, logger="sglang.srt.function_call.base_format_detector"
        ):
            outcome = detector.detect_and_parse(raw_call, declared_tools)

        matching = [msg for msg in caplog.messages if expected_warning in msg]
        assert matching
        # Default mode drops the unknown call entirely.
        assert len(outcome.calls) == 0
def test_unknown_tool_name_forwarded(caplog):
    """An undefined tool call is still warned about but passed through when forwarding is on."""
    raw_call = '{"name":"unknown_tool","parameters":{"city":"Paris"}}'
    expected_warning = "Model attempted to call undefined function: unknown_tool"

    with envs.SGLANG_FORWARD_UNKNOWN_TOOLS.override(True):
        # Only "get_weather" is declared, so "unknown_tool" is undefined.
        weather_fn = Function(
            name="get_weather", parameters={"type": "object", "properties": {}}
        )
        declared_tools = [Tool(function=weather_fn)]
        detector = DummyDetector()

        with caplog.at_level(
            logging.WARNING, logger="sglang.srt.function_call.base_format_detector"
        ):
            outcome = detector.detect_and_parse(raw_call, declared_tools)

        matching = [msg for msg in caplog.messages if expected_warning in msg]
        assert matching

        # Exactly one call comes back, carrying the unknown name and its arguments.
        assert len(outcome.calls) == 1
        forwarded = outcome.calls[0]
        assert forwarded.name == "unknown_tool"
        assert forwarded.tool_index == -1
        assert json.loads(forwarded.parameters)["city"] == "Paris"
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment