Unverified Commit 01079e17 authored by Chang Su's avatar Chang Su Committed by GitHub
Browse files

feat(function call): complete utility method for KimiK2Detector and enhance documentation (#8043)

parent 0e7a5b26
......@@ -25,23 +25,49 @@ class BaseFormatDetector(ABC):
"""Base class providing two sets of interfaces: one-time and streaming incremental."""
def __init__(self):
# initialize properties used for state when parsing tool calls in
# Streaming state management
# Buffer for accumulating incomplete patterns that arrive across multiple streaming chunks
self._buffer = ""
# streaming mode
# Stores complete tool call info (name and arguments) for each tool being parsed.
# Used by serving layer for completion handling when streaming ends.
# Format: [{"name": str, "arguments": dict}, ...]
self.prev_tool_call_arr: List[Dict] = []
# Index of currently streaming tool call. Starts at -1 (no active tool),
# increments as each tool completes. Tracks which tool's arguments are streaming.
self.current_tool_id: int = -1
# Flag for whether current tool's name has been sent to client.
# Tool names sent first with empty parameters, then arguments stream incrementally.
self.current_tool_name_sent: bool = False
self.streamed_args_for_tool: List[str] = (
[]
) # map what has been streamed for each tool so far to a list
# Tracks raw JSON string content streamed to client for each tool's arguments.
# Critical for serving layer to calculate remaining content when streaming ends.
# Each index corresponds to a tool_id. Example: ['{"location": "San Francisco"', '{"temp": 72']
self.streamed_args_for_tool: List[str] = []
# Token configuration (override in subclasses)
self.bot_token = ""
self.eot_token = ""
self.tool_call_separator = ", "
def parse_base_json(self, action: Any, tools: List[Tool]) -> List[ToolCallItem]:
tool_indices = {
def _get_tool_indices(self, tools: List[Tool]) -> Dict[str, int]:
"""
Get a mapping of tool names to their indices in the tools list.
This utility method creates a dictionary mapping function names to their
indices in the tools list, which is commonly needed for tool validation
and ToolCallItem creation.
Args:
tools: List of available tools
Returns:
Dictionary mapping tool names to their indices
"""
return {
tool.function.name: i for i, tool in enumerate(tools) if tool.function.name
}
def parse_base_json(self, action: Any, tools: List[Tool]) -> List[ToolCallItem]:
tool_indices = self._get_tool_indices(tools)
if not isinstance(action, list):
action = [action]
......@@ -130,11 +156,7 @@ class BaseFormatDetector(ABC):
# Build tool indices if not already built
if not hasattr(self, "_tool_indices"):
self._tool_indices = {
tool.function.name: i
for i, tool in enumerate(tools)
if tool.function and tool.function.name
}
self._tool_indices = self._get_tool_indices(tools)
flags = Allow.ALL if self.current_tool_name_sent else Allow.ALL & ~Allow.STR
......@@ -294,12 +316,48 @@ class BaseFormatDetector(ABC):
@abstractmethod
def has_tool_call(self, text: str) -> bool:
    """
    Report whether ``text`` contains this format's function call markers.

    Abstract hook: the base implementation only raises, so every concrete
    detector has to supply its own check.

    Args:
        text: The text to inspect

    Raises:
        NotImplementedError: Always, when the base implementation is invoked
    """
    raise NotImplementedError
@abstractmethod
def structure_info(self) -> _GetInfoFunc:
    """
    Provide the factory used to build StructureInfo for constrained generation.

    The factory accepts a tool name and yields a StructureInfo object holding
    the begin/end patterns and trigger tokens that constrained generation of
    function calls in this format requires.

    Returns:
        A callable mapping a tool name (str) to a StructureInfo

    Raises:
        NotImplementedError: Always, when the base implementation is invoked
    """
    raise NotImplementedError
@abstractmethod
def build_ebnf(self, tools: List[Tool]) -> str:
    """
    Produce an EBNF grammar for constrained generation of function calls.

    The Extended Backus-Naur Form (EBNF) grammar restricts the model's output
    to function calls that are valid in this format, covering every available
    tool together with its parameter schema.

    An implementation's grammar is expected to:
        - Describe the overall structure of function calls in this format
        - Contain every tool name from the provided tools list
        - Describe valid JSON structures for function arguments
        - Cope with multiple function calls if the format supports them

    Args:
        tools: List of available tools/functions that can be called

    Returns:
        A string containing the EBNF grammar for this function call format

    Raises:
        NotImplementedError: Always, when the base implementation is invoked

    Note:
        Most implementations delegate to the EBNFComposer.build_ebnf() utility
        with format-specific parameters instead of writing EBNF by hand.
    """
    raise NotImplementedError
......@@ -19,9 +19,28 @@ logger = logging.getLogger(__name__)
class DeepSeekV3Detector(BaseFormatDetector):
"""
Detector for DeepSeek models.
Assumes function call format:
'<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>get_current_weather\n```json\n{"location": "Tokyo"}\n```<|tool▁call▁end|>\n<|tool▁call▁begin|>function<|tool▁sep|>get_current_weather\n```json\n{"location": "Paris"}\n```<|tool▁call▁end|><|tool▁calls▁end|><|end▁of▁sentence|>
Detector for DeepSeek V3 model function call format.
The DeepSeek V3 format uses special Unicode tokens to delimit function calls
with JSON code blocks for arguments.
Format Structure:
```
<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>{function_name}\n```json\n{json_arguments}\n```<|tool▁call▁end|><|tool▁calls▁end|><|end▁of▁sentence|>
```
Examples:
```
<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>get_current_weather\n```json\n{"location": "Tokyo"}\n```<|tool▁call▁end|>\n<|tool▁call▁begin|>function<|tool▁sep|>get_current_weather\n```json\n{"location": "Paris"}\n```<|tool▁call▁end|><|tool▁calls▁end|><|end▁of▁sentence|>
```
Key Components:
- Tool Calls Section: Wrapped between `<|tool▁calls▁begin|>` and `<|tool▁calls▁end|>`
- Individual Tool Call: Wrapped between `<|tool▁call▁begin|>` and `<|tool▁call▁end|>`
- Function Declaration: `function<|tool▁sep|>{function_name}`
- Arguments: JSON code block opened by "```json" and closed by "```"
- Supports multiple tool calls
Reference: https://huggingface.co/deepseek-ai/DeepSeek-V3-0324?chat_template=default
"""
def __init__(self):
......@@ -89,11 +108,7 @@ class DeepSeekV3Detector(BaseFormatDetector):
return StreamingParseResult(normal_text=new_text)
if not hasattr(self, "_tool_indices"):
self._tool_indices = {
tool.function.name: i
for i, tool in enumerate(tools)
if tool.function and tool.function.name
}
self._tool_indices = self._get_tool_indices(tools)
calls: list[ToolCallItem] = []
try:
......@@ -127,7 +142,7 @@ class DeepSeekV3Detector(BaseFormatDetector):
)
)
self.current_tool_name_sent = True
# Store the tool call info for adapter.py
# Store the tool call info for serving layer completions endpoint
self.prev_tool_call_arr[self.current_tool_id] = {
"name": func_name,
"arguments": {},
......@@ -153,7 +168,7 @@ class DeepSeekV3Detector(BaseFormatDetector):
] += argument_diff
if _is_complete_json(func_args_raw):
# Update the stored arguments for adapter.py
# Update the stored arguments
try:
parsed_args = json.loads(func_args_raw)
self.prev_tool_call_arr[self.current_tool_id][
......
......@@ -18,16 +18,21 @@ logger = logging.getLogger(__name__)
class KimiK2Detector(BaseFormatDetector):
"""
Detector for Kimi K2 model function call format.
Format Structure:
```
<|tool_calls_section_begin|>
<|tool_call_begin|>functions.{func_name}:{index} <|tool_call_argument_begin|>{json_args}<|tool_call_end|>
<|tool_calls_section_end|>
```
Reference: https://huggingface.co/moonshotai/Kimi-K2-Instruct/blob/main/docs/tool_call_guidance.md
"""
def __init__(self):
super().__init__()
self._buffer = ""
self.current_tool_name_sent: bool = False
self.prev_tool_call_arr: list[dict] = []
self.current_tool_id: int = -1
self.streamed_args_for_tool: list[str] = (
[]
) # map what has been streamed for each tool so far to a list
self.bot_token: str = "<|tool_calls_section_begin|>"
self.eot_token: str = "<|tool_calls_section_end|>"
......@@ -114,11 +119,7 @@ class KimiK2Detector(BaseFormatDetector):
return StreamingParseResult(normal_text=new_text)
if not hasattr(self, "_tool_indices"):
self._tool_indices = {
tool.function.name: i
for i, tool in enumerate(tools)
if tool.function and tool.function.name
}
self._tool_indices = self._get_tool_indices(tools)
calls: list[ToolCallItem] = []
try:
......@@ -150,7 +151,7 @@ class KimiK2Detector(BaseFormatDetector):
)
)
self.current_tool_name_sent = True
# Store the tool call info for adapter.py
# Store the tool call info for serving layer completions endpoint
self.prev_tool_call_arr[self.current_tool_id] = {
"name": function_name,
"arguments": {},
......@@ -214,7 +215,31 @@ class KimiK2Detector(BaseFormatDetector):
return StreamingParseResult(normal_text=current_text)
def structure_info(self) -> _GetInfoFunc:
raise NotImplementedError()
"""Return function that creates StructureInfo for guided generation."""
def build_ebnf(self, tools: List[Tool]):
raise NotImplementedError()
def get_info(name: str) -> StructureInfo:
return StructureInfo(
begin=f"<|tool_calls_section_begin|><|tool_call_begin|>functions.{name}:0 <|tool_call_argument_begin|>",
end="<|tool_call_end|><|tool_calls_section_end|>",
trigger="<|tool_calls_section_begin|>",
)
return get_info
def build_ebnf(self, tools: List[Tool]) -> str:
    """
    Compose the EBNF grammar for the KimiK2 tool call format.

    The grammar is delegated to EBNFComposer.build_ebnf, using this detector's
    ``bot_token`` / ``eot_token`` as the sequence start and end tokens, an
    empty separator between calls, and JSON-formatted arguments.

    NOTE: The per-call rule matches the function index with [0-9]+ so that the
    grammar accepts any numeric index (0, 1, 2, etc.), which allows sequential
    indexing when several function calls are generated, while the rest of the
    KimiK2 call structure stays fixed for constrained generation.

    Args:
        tools: List of available tools

    Returns:
        Whatever EBNFComposer.build_ebnf produces for these settings
        (annotated as the grammar string)
    """
    # Template for a single call; {name} and {arguments_rule} are left as
    # placeholders for EBNFComposer (presumably substituted per tool).
    kimi_call_rule = '"<|tool_call_begin|>functions.{name}:" [0-9]+ " <|tool_call_argument_begin|>" {arguments_rule} "<|tool_call_end|>"'
    return EBNFComposer.build_ebnf(
        tools,
        function_format="json",
        call_rule_fmt=kimi_call_rule,
        tool_call_separator="",
        sequence_end_token=self.eot_token,
        sequence_start_token=self.bot_token,
    )
......@@ -16,9 +16,12 @@ logger = logging.getLogger(__name__)
class Llama32Detector(BaseFormatDetector):
"""
Detector for Llama 3.2 models.
Assumes function call format:
<|python_tag|>{"name":"xxx", "arguments":{...}}
Detector for Llama 3.2 models with json tool call format.
Format Structure:
```
<|python_tag|>{"name":"xxx", "arguments":{...}}
```
"""
def __init__(self):
......
......@@ -17,9 +17,17 @@ logger = logging.getLogger(__name__)
class MistralDetector(BaseFormatDetector):
"""
Detector for Mistral models.
Assumes function call format:
[TOOL_CALLS] [{"name":"func1", "arguments":{...}}, {"name":"func2", "arguments":{...}}]
Detector for Mistral model function call format.
The Mistral format uses a simple bracket-delimited structure with JSON arrays
containing function call objects.
Format Structure:
```
[TOOL_CALLS] [{"name": "function_name", "arguments": {json_args}}, ...]
```
Reference: https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.3?chat_template=default
"""
def __init__(self):
......
......@@ -19,10 +19,17 @@ logger = logging.getLogger(__name__)
class PythonicDetector(BaseFormatDetector):
"""
Detector for Llama-3.2 and Llama-4 models with pythonic tool call format.
Assumes function call format:
Detector for Llama-4 models with Pythonic tool call format.
The Pythonic format uses Python function call syntax within square brackets,
with arguments as Python literals rather than JSON.
Format Structure:
```
[tool1(arg1=val1, arg2=val2), tool2(arg1=val3)]
Arguments are Python literals (not JSON).
```
Reference: https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E-Instruct?chat_template=default
"""
def __init__(self):
......@@ -75,11 +82,7 @@ class PythonicDetector(BaseFormatDetector):
return StreamingParseResult(normal_text=normal_text, calls=[])
calls = []
tool_indices = {
tool.function.name: i
for i, tool in enumerate(tools)
if tool.function.name
}
tool_indices = self._get_tool_indices(tools)
for call_index, call in enumerate(parsed.elts):
if not isinstance(call.func, ast.Name):
continue
......
......@@ -17,9 +17,18 @@ logger = logging.getLogger(__name__)
class Qwen25Detector(BaseFormatDetector):
"""
Detector for Qwen 2.5 models.
Assumes function call format:
Detector for Qwen 2.5 and Qwen 3 model function call format.
Format Structure:
```
<tool_call>\n{"name":"func1", "arguments":{...}}\n</tool_call>\n<tool_call>\n{"name":"func2", "arguments":{...}}\n</tool_call>
```
Key Components:
- Tool Call Tags: `<tool_call>` and `</tool_call>` wrap each individual call
- Function Call Object: JSON object with "name" and "arguments" fields
Reference: https://huggingface.co/Qwen/Qwen2.5-0.5B-Instruct?chat_template=default
"""
def __init__(self):
......
......@@ -507,6 +507,7 @@ class TestEBNFGeneration(unittest.TestCase):
self.llama32_detector = Llama32Detector()
self.mistral_detector = MistralDetector()
self.qwen25_detector = Qwen25Detector()
self.kimik2_detector = KimiK2Detector()
def test_pythonic_detector_ebnf(self):
"""Test that the PythonicDetector generates valid EBNF."""
......@@ -542,6 +543,33 @@ class TestEBNFGeneration(unittest.TestCase):
except RuntimeError as e:
self.fail(f"Failed to compile EBNF: {e}")
def test_kimik2_detector_ebnf(self):
    """Test that the KimiK2Detector generates valid EBNF."""
    grammar = self.kimik2_detector.build_ebnf(self.tools)
    self.assertIsNotNone(grammar)

    expected_fragments = (
        # Section delimiters of the KimiK2 format
        "<|tool_calls_section_begin|>",
        "<|tool_calls_section_end|>",
        # KimiK2-specific structure of a single function call
        "<|tool_call_begin|>functions.get_weather:",
        "<|tool_call_begin|>functions.search:",
        "<|tool_call_argument_begin|>",
        "<|tool_call_end|>",
        # namespace.function naming followed by a numeric index pattern
        "functions.get_weather:",
        "functions.search:",
        "[0-9]+",
    )
    for fragment in expected_fragments:
        self.assertIn(fragment, grammar)

    # The grammar must also be accepted by GrammarCompiler.
    try:
        compiled = self.grammar_compiler.compile_grammar(grammar)
    except RuntimeError as e:
        self.fail(f"Failed to compile EBNF: {e}")
    self.assertIsNotNone(compiled, "EBNF should be valid and compile successfully")
def test_llama32_detector_ebnf(self):
"""Test that the Llama32Detector generates valid EBNF."""
ebnf = self.llama32_detector.build_ebnf(self.tools)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment