Unverified Commit 41ba767f authored by Chang Su's avatar Chang Su Committed by GitHub
Browse files

feat: Add warnings for invalid tool_choice and UTs (#6582)

parent f127355a
import logging
from typing import Any, Dict, List, Literal, Optional, Set, Tuple, Type, Union from typing import Any, Dict, List, Literal, Optional, Set, Tuple, Type, Union
from sglang.srt.function_call.base_format_detector import BaseFormatDetector from sglang.srt.function_call.base_format_detector import BaseFormatDetector
...@@ -14,6 +15,8 @@ from sglang.srt.openai_api.protocol import ( ...@@ -14,6 +15,8 @@ from sglang.srt.openai_api.protocol import (
ToolChoice, ToolChoice,
) )
logger = logging.getLogger(__name__)
class FunctionCallParser: class FunctionCallParser:
""" """
...@@ -165,11 +168,35 @@ class FunctionCallParser: ...@@ -165,11 +168,35 @@ class FunctionCallParser:
) -> Optional[str]: ) -> Optional[str]:
""" """
Get the EBNF grammar for the specified tool choice. Get the EBNF grammar for the specified tool choice.
Args:
tool_choice: The tool choice specification
Returns:
EBNF grammar string, or None if no valid tools found
Note:
If a specific function is requested but not found in available tools,
logs a warning and falls back to using all available tools for backward compatibility.
""" """
filtered_tools = [] filtered_tools = []
if isinstance(tool_choice, ToolChoice): if isinstance(tool_choice, ToolChoice):
fn_name = tool_choice.function.name fn_name = tool_choice.function.name
filtered_tools = [t for t in self.tools if t.function.name == fn_name] filtered_tools = [t for t in self.tools if t.function.name == fn_name]
# Check if the requested function exists in available tools
if not filtered_tools:
available_functions = [t.function.name for t in self.tools]
logger.warning(
f"Function '{fn_name}' not found in available tools. "
f"Available functions: {available_functions}. "
f"Skipping tool choice."
)
# TODO: Return a 400 error instead of warning when adapter supports proper error handling
# For now, fall back to return None
return None
else: else:
filtered_tools = self.tools filtered_tools = self.tools
return self.detector.build_ebnf(filtered_tools) return self.detector.build_ebnf(filtered_tools)
...@@ -70,6 +70,7 @@ suites = { ...@@ -70,6 +70,7 @@ suites = {
TestFile("test_skip_tokenizer_init.py", 117), TestFile("test_skip_tokenizer_init.py", 117),
TestFile("test_srt_engine.py", 261), TestFile("test_srt_engine.py", 261),
TestFile("test_srt_endpoint.py", 130), TestFile("test_srt_endpoint.py", 130),
TestFile("test_tool_choice.py", 120),
TestFile("test_torch_compile.py", 76), TestFile("test_torch_compile.py", 76),
TestFile("test_torch_compile_moe.py", 172), TestFile("test_torch_compile_moe.py", 172),
TestFile("test_torch_native_attention_backend.py", 123), TestFile("test_torch_native_attention_backend.py", 123),
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment