Unverified Commit 795e98f8 authored by Surya-Gunukula's avatar Surya-Gunukula Committed by GitHub
Browse files

Forward unknown tool calls instead of dropping (#12226)

parent 358ae356
......@@ -7,7 +7,7 @@ SGLang supports various environment variables that can be used to configure its
## General Configuration
| Environment Variable | Description | Default Value |
|-------------------------------------------|----------------------------------------------------------------------------------------------------------------------------------|---------------|
|-------------------------------------------|----------------------------------------------------------------------------------------------------------------------------------|------------------------------|
| `SGLANG_USE_MODELSCOPE` | Enable using models from ModelScope | `false` |
| `SGLANG_HOST_IP` | Host IP address for the server | `0.0.0.0` |
| `SGLANG_PORT` | Port for the server | auto-detected |
......@@ -15,6 +15,7 @@ SGLang supports various environment variables that can be used to configure its
| `SGLANG_DISABLE_REQUEST_LOGGING` | Disable request logging | `false` |
| `SGLANG_HEALTH_CHECK_TIMEOUT` | Timeout for health check in seconds | `20` |
| `SGLANG_EPLB_HEATMAP_COLLECTION_INTERVAL` | The interval of passes to collect the metric of selected count of physical experts on each layer and GPU rank. 0 means disabled. | `0` |
| `SGLANG_FORWARD_UNKNOWN_TOOLS` | Forward unknown tool calls to clients instead of dropping them | `false` (drop unknown tools) |
## Performance Tuning
......
......@@ -158,6 +158,9 @@ class Envs:
SGLANG_DISABLE_OUTLINES_DISK_CACHE = EnvBool(True)
SGLANG_GRAMMAR_TIMEOUT = EnvFloat(300)
# Tool Calling
SGLANG_FORWARD_UNKNOWN_TOOLS = EnvBool(False)
# Hi-Cache
SGLANG_HICACHE_HF3FS_CONFIG_PATH = EnvStr(None)
......
......@@ -8,6 +8,7 @@ from partial_json_parser.core.exceptions import MalformedJSON
from partial_json_parser.core.options import Allow
from sglang.srt.entrypoints.openai.protocol import Tool
from sglang.srt.environ import envs
from sglang.srt.function_call.core_types import (
StreamingParseResult,
ToolCallItem,
......@@ -75,7 +76,11 @@ class BaseFormatDetector(ABC):
results = []
for act in action:
name = act.get("name")
if name and name in tool_indices:
if not (name and name in tool_indices):
logger.warning(f"Model attempted to call undefined function: {name}")
if not envs.SGLANG_FORWARD_UNKNOWN_TOOLS.get():
continue # Skip unknown tools (default legacy behavior)
results.append(
ToolCallItem(
tool_index=-1, # Caller should update this based on the actual tools array called
......@@ -86,8 +91,6 @@ class BaseFormatDetector(ABC):
),
)
)
else:
logger.warning(f"Model attempted to call undefined function: {name}")
return results
......
......@@ -4,6 +4,7 @@ import re
from typing import List, Optional
from sglang.srt.entrypoints.openai.protocol import Tool
from sglang.srt.environ import envs
from sglang.srt.function_call.base_format_detector import BaseFormatDetector
from sglang.srt.function_call.core_types import (
StreamingParseResult,
......@@ -220,7 +221,8 @@ class GptOssDetector(BaseFormatDetector):
# Check if tool exists
if function_name not in tool_indices:
logger.debug(f"Function {function_name} not in available tools")
return None
if not envs.SGLANG_FORWARD_UNKNOWN_TOOLS.get():
return None # Skip unknown tools (default legacy behavior)
# Parse JSON arguments
try:
......
......@@ -5,6 +5,7 @@ import re
from typing import List, Optional
from sglang.srt.entrypoints.openai.protocol import Tool
from sglang.srt.environ import envs
from sglang.srt.function_call.base_format_detector import BaseFormatDetector
from sglang.srt.function_call.core_types import (
StreamingParseResult,
......@@ -91,7 +92,9 @@ class PythonicDetector(BaseFormatDetector):
logger.warning(
f"Model attempted to call undefined function: {function_name}"
)
continue
if not envs.SGLANG_FORWARD_UNKNOWN_TOOLS.get():
continue # Skip unknown tools (default legacy behavior)
arguments = {}
for keyword in call.keywords:
arguments[keyword.arg] = self._get_parameter_value(keyword.value)
......
......@@ -6,6 +6,7 @@ import re
from typing import Any, Dict, List, Tuple
from sglang.srt.entrypoints.openai.protocol import Tool
from sglang.srt.environ import envs
from sglang.srt.function_call.base_format_detector import BaseFormatDetector
from sglang.srt.function_call.core_types import (
StreamingParseResult,
......@@ -120,7 +121,17 @@ class Qwen3CoderDetector(BaseFormatDetector):
function_name = function_match.group(1).strip()
# Validate function name
if function_name in self._tool_indices:
is_valid = function_name in self._tool_indices
if not is_valid:
logger.warning(f"Invalid function name: {function_name}")
if not envs.SGLANG_FORWARD_UNKNOWN_TOOLS.get():
# Reset state and skip (default legacy behavior)
self._reset_streaming_state()
normal += self._buf
self._buf = ""
break
# Process tool call (valid or unknown with env=TRUE)
self._current_function_name = function_name
self._function_name_sent = True
......@@ -152,13 +163,6 @@ class Qwen3CoderDetector(BaseFormatDetector):
# Remove the processed function declaration
self._buf = self._buf[function_match.end() :]
continue
else:
# Invalid function name, reset state
logger.warning(f"Invalid function name: {function_name}")
self._reset_streaming_state()
normal += self._buf
self._buf = ""
break
else:
# Function name not complete yet, wait for more text
break
......
import json
import logging
from sglang.srt.entrypoints.openai.protocol import Function, Tool
from sglang.srt.environ import envs
from sglang.srt.function_call.base_format_detector import BaseFormatDetector
from sglang.srt.function_call.core_types import StreamingParseResult
class DummyDetector(BaseFormatDetector):
def has_tool_call(self, text: str) -> bool:
return True
def detect_and_parse(self, text: str, tools):
action = json.loads(text)
return StreamingParseResult(
normal_text="", calls=self.parse_base_json(action, tools)
)
def test_unknown_tool_name_dropped_default(caplog):
"""Test that unknown tools are dropped by default (legacy behavior)."""
with envs.SGLANG_FORWARD_UNKNOWN_TOOLS.override(False):
tools = [
Tool(
function=Function(
name="get_weather", parameters={"type": "object", "properties": {}}
)
)
]
detector = DummyDetector()
with caplog.at_level(
logging.WARNING, logger="sglang.srt.function_call.base_format_detector"
):
result = detector.detect_and_parse(
'{"name":"unknown_tool","parameters":{"city":"Paris"}}', tools
)
assert any(
"Model attempted to call undefined function: unknown_tool" in m
for m in caplog.messages
)
assert len(result.calls) == 0 # dropped in default mode
def test_unknown_tool_name_forwarded(caplog):
"""Test that unknown tools are forwarded when env var is True."""
with envs.SGLANG_FORWARD_UNKNOWN_TOOLS.override(True):
tools = [
Tool(
function=Function(
name="get_weather", parameters={"type": "object", "properties": {}}
)
)
]
detector = DummyDetector()
with caplog.at_level(
logging.WARNING, logger="sglang.srt.function_call.base_format_detector"
):
result = detector.detect_and_parse(
'{"name":"unknown_tool","parameters":{"city":"Paris"}}', tools
)
assert any(
"Model attempted to call undefined function: unknown_tool" in m
for m in caplog.messages
)
assert len(result.calls) == 1
assert result.calls[0].name == "unknown_tool"
assert result.calls[0].tool_index == -1
assert json.loads(result.calls[0].parameters)["city"] == "Paris"
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment