Unverified Commit 61555307 authored by Atream, committed by GitHub

Support Kimi K2 (#7940)

parent 49a5915f
......@@ -135,7 +135,7 @@ Please consult the documentation below and [server_args.py](https://github.com/s
| `--file-storage-path` | The path of the file storage in backend. | sglang_storage |
| `--enable-cache-report` | Return number of cached tokens in usage.prompt_tokens_details for each openai request. | False |
| `--reasoning-parser` | Specify the parser for reasoning models, supported parsers are: {list(ReasoningParser.DetectorMap.keys())}. | None |
| `--tool-call-parser` | Specify the parser for handling tool-call interactions. Options include: 'qwen25', 'mistral', 'llama3', 'deepseekv3', and 'pythonic'. | None |
| `--tool-call-parser` | Specify the parser for handling tool-call interactions. Options include: 'qwen25', 'mistral', 'llama3', 'deepseekv3', 'pythonic', and 'kimi_k2'. | None |
## Data parallelism
......
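The new `kimi_k2` option documented above is exercised end to end through the OpenAI-compatible endpoint. A minimal sketch, assuming a server already launched with `--tool-call-parser kimi_k2` and the standard `openai` Python client; the port, model name, and tool schema are placeholders:

```python
# Sketch only: calling a local SGLang server that was started with e.g.
#   python -m sglang.launch_server --model-path <kimi-k2-model> \
#       --tool-call-parser kimi_k2 --trust-remote-code
# Base URL and model name below are assumptions, not fixed values.
import openai

client = openai.OpenAI(base_url="http://127.0.0.1:30000/v1", api_key="EMPTY")

tools = [
    {
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "Get weather information",
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
        },
    }
]

response = client.chat.completions.create(
    model="default",
    messages=[{"role": "user", "content": "What is the weather in Paris?"}],
    tools=tools,
)
# With the kimi_k2 parser active, tool calls arrive as structured objects
print(response.choices[0].message.tool_calls)
```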
......@@ -25,6 +25,7 @@ from transformers import PretrainedConfig
from sglang.srt.hf_transformers_utils import (
get_config,
get_context_length,
get_generation_config,
get_hf_text_config,
)
from sglang.srt.layers.quantization import QUANTIZATION_METHODS
......@@ -83,6 +84,13 @@ class ModelConfig:
**kwargs,
)
self.hf_generation_config = get_generation_config(
self.model_path,
trust_remote_code=trust_remote_code,
revision=revision,
**kwargs,
)
self.hf_text_config = get_hf_text_config(self.hf_config)
self.attention_chunk_size = getattr(
self.hf_text_config, "attention_chunk_size", None
......@@ -467,6 +475,19 @@ class ModelConfig:
if eos_ids:
# eos_token_id can be either an int or a list of ints
eos_ids = {eos_ids} if isinstance(eos_ids, int) else set(eos_ids)
if eos_ids is None:
eos_ids = set()
if self.hf_generation_config:
generation_eos_ids = getattr(
self.hf_generation_config, "eos_token_id", None
)
if generation_eos_ids:
generation_eos_ids = (
{generation_eos_ids}
if isinstance(generation_eos_ids, int)
else set(generation_eos_ids)
)
eos_ids = eos_ids | generation_eos_ids
return eos_ids
def maybe_pull_model_tokenizer_from_remote(self) -> None:
......
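The net effect of the hunk above: EOS token ids declared in `generation_config.json` are unioned with whatever the model config already provides, so neither source can silently drop a stop token. A standalone sketch of the merge semantics (pure Python; the function name is hypothetical):

```python
from typing import Set, Union

def merge_eos_ids(
    config_eos: Union[int, list, None],
    generation_eos: Union[int, list, None],
) -> Set[int]:
    """Mirror of the merging logic above: normalize each source
    (int, list of ints, or None) to a set, then take the union."""
    eos_ids = (
        {config_eos} if isinstance(config_eos, int) else set(config_eos or [])
    )
    if generation_eos is not None:
        gen_ids = (
            {generation_eos}
            if isinstance(generation_eos, int)
            else set(generation_eos)
        )
        eos_ids |= gen_ids
    return eos_ids

# e.g. the config declares id 2 while generation_config adds [2, 7]
assert merge_eos_ids(2, [2, 7]) == {2, 7}
```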
......@@ -10,6 +10,7 @@ from sglang.srt.entrypoints.openai.protocol import (
from sglang.srt.function_call.base_format_detector import BaseFormatDetector
from sglang.srt.function_call.core_types import ToolCallItem
from sglang.srt.function_call.deepseekv3_detector import DeepSeekV3Detector
from sglang.srt.function_call.kimik2_detector import KimiK2Detector
from sglang.srt.function_call.llama32_detector import Llama32Detector
from sglang.srt.function_call.mistral_detector import MistralDetector
from sglang.srt.function_call.pythonic_detector import PythonicDetector
......@@ -33,6 +34,7 @@ class FunctionCallParser:
"mistral": MistralDetector,
"deepseekv3": DeepSeekV3Detector,
"pythonic": PythonicDetector,
"kimi_k2": KimiK2Detector,
}
def __init__(self, tools: List[Tool], tool_call_parser: str):
......
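With the detector registered, `--tool-call-parser kimi_k2` resolves to `KimiK2Detector`. A one-shot usage sketch, mirroring the unit tests added later in this diff:

```python
from sglang.srt.entrypoints.openai.protocol import Function, Tool
from sglang.srt.function_call.kimik2_detector import KimiK2Detector

tools = [
    Tool(
        type="function",
        function=Function(
            name="get_weather",
            description="Get weather information",
            parameters={
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
        ),
    )
]

# Kimi K2 wraps tool calls in dedicated special tokens
text = (
    "<|tool_calls_section_begin|>"
    "<|tool_call_begin|>functions.get_weather:0<|tool_call_argument_begin|>"
    '{"city": "Paris"}<|tool_call_end|>'
    "<|tool_calls_section_end|>"
)

result = KimiK2Detector().detect_and_parse(text, tools)
# result.calls[0].name == "get_weather"
# result.calls[0].parameters == '{"city": "Paris"}'
```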
import json
import logging
import re
from typing import List
from sglang.srt.entrypoints.openai.protocol import Tool
from sglang.srt.function_call.base_format_detector import BaseFormatDetector
from sglang.srt.function_call.core_types import (
StreamingParseResult,
StructureInfo,
ToolCallItem,
_GetInfoFunc,
)
from sglang.srt.function_call.ebnf_composer import EBNFComposer
from sglang.srt.function_call.utils import _is_complete_json
logger = logging.getLogger(__name__)
class KimiK2Detector(BaseFormatDetector):
def __init__(self):
super().__init__()
self._buffer = ""
self.current_tool_name_sent: bool = False
self.prev_tool_call_arr: list[dict] = []
self.current_tool_id: int = -1
self.streamed_args_for_tool: list[str] = []  # arguments streamed so far, one entry per tool call
self.bot_token: str = "<|tool_calls_section_begin|>"
self.eot_token: str = "<|tool_calls_section_end|>"
self.tool_call_start_token: str = "<|tool_call_begin|>"
self.tool_call_end_token: str = "<|tool_call_end|>"
self.tool_call_regex = re.compile(
r"<\|tool_call_begin\|>\s*(?P<tool_call_id>[\w\.]+:\d+)\s*<\|tool_call_argument_begin\|>\s*(?P<function_arguments>\{.*?\})\s*<\|tool_call_end\|>"
)
self.stream_tool_call_portion_regex = re.compile(
r"<\|tool_call_begin\|>\s*(?P<tool_call_id>[\w\.]+:\d+)\s*<\|tool_call_argument_begin\|>\s*(?P<function_arguments>\{.*)"
)
self._last_arguments = ""
def has_tool_call(self, text: str) -> bool:
"""Check if the text contains a KimiK2 format tool call."""
return self.bot_token in text
def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
"""
One-time parsing: Detects and parses tool calls in the provided text.
:param text: The complete text to parse.
:param tools: List of available tools.
:return: StreamingParseResult containing any normal text and the parsed tool calls.
"""
if self.bot_token not in text:
return StreamingParseResult(normal_text=text, calls=[])
try:
# The regex has two named capture groups, so findall returns a list of
# (tool_call_id, function_arguments) tuples, one per complete tool call
function_call_tuples = self.tool_call_regex.findall(text)
logger.debug("function_call_tuples: %s", function_call_tuples)
tool_calls = []
for match in function_call_tuples:
function_id, function_args = match
# tool_call_id has the form "functions.{name}:{index}"
function_name = function_id.split(".")[1].split(":")[0]
function_idx = int(function_id.split(".")[1].split(":")[1])
logger.debug("function_name %s", function_name)
tool_calls.append(
ToolCallItem(
tool_index=function_idx, # Use the call index in the response, not tool position
name=function_name,
parameters=function_args,
)
)
content = text[: text.find(self.bot_token)]
return StreamingParseResult(normal_text=content, calls=tool_calls)
except Exception as e:
logger.error(f"Error in detect_and_parse: {e}")
# return the normal text if parsing fails
return StreamingParseResult(normal_text=text)
def parse_streaming_increment(
self, new_text: str, tools: List[Tool]
) -> StreamingParseResult:
"""
Streaming incremental parsing tool calls for KimiK2 format.
"""
self._buffer += new_text
current_text = self._buffer
# Check if we have a tool call (either the start token or individual tool call)
has_tool_call = (
self.bot_token in current_text or self.tool_call_start_token in current_text
)
if not has_tool_call:
self._buffer = ""
for e_token in [self.eot_token, self.tool_call_end_token]:
if e_token in new_text:
new_text = new_text.replace(e_token, "")
return StreamingParseResult(normal_text=new_text)
if not hasattr(self, "_tool_indices"):
self._tool_indices = {
tool.function.name: i
for i, tool in enumerate(tools)
if tool.function and tool.function.name
}
calls: list[ToolCallItem] = []
try:
match = self.stream_tool_call_portion_regex.search(current_text)
if match:
function_id = match.group("tool_call_id")
function_args = match.group("function_arguments")
function_name = function_id.split(".")[1].split(":")[0]
# Initialize state if this is the first tool call
if self.current_tool_id == -1:
self.current_tool_id = 0
self.prev_tool_call_arr = []
self.streamed_args_for_tool = [""]
# Ensure we have enough entries in our tracking arrays
while len(self.prev_tool_call_arr) <= self.current_tool_id:
self.prev_tool_call_arr.append({})
while len(self.streamed_args_for_tool) <= self.current_tool_id:
self.streamed_args_for_tool.append("")
if not self.current_tool_name_sent:
calls.append(
ToolCallItem(
tool_index=self.current_tool_id,
name=function_name,
parameters="",
)
)
self.current_tool_name_sent = True
# Store the tool call info for adapter.py
self.prev_tool_call_arr[self.current_tool_id] = {
"name": function_name,
"arguments": {},
}
else:
argument_diff = (
function_args[len(self._last_arguments) :]
if function_args.startswith(self._last_arguments)
else function_args
)
parsed_args_diff = argument_diff.split(self.tool_call_end_token, 1)[0]
if parsed_args_diff:
calls.append(
ToolCallItem(
tool_index=self.current_tool_id,
name=None,
parameters=parsed_args_diff,
)
)
self._last_arguments += argument_diff
self.streamed_args_for_tool[
self.current_tool_id
] += parsed_args_diff
parsed_args = function_args.split(self.tool_call_end_token, 1)[0]
if _is_complete_json(parsed_args):
try:
parsed_args = json.loads(parsed_args)
self.prev_tool_call_arr[self.current_tool_id][
"arguments"
] = parsed_args
except json.JSONDecodeError:
pass
# Find the end of the current tool call and remove only that part from buffer
tool_call_end_pattern = (
r"<\|tool_call_begin\|>.*?<\|tool_call_end\|>"
)
match = re.search(
tool_call_end_pattern, current_text, re.DOTALL
)
if match:
# Remove the completed tool call from buffer, keep any remaining content
self._buffer = current_text[match.end() :]
else:
self._buffer = ""
result = StreamingParseResult(normal_text="", calls=calls)
self.current_tool_id += 1
self._last_arguments = ""
self.current_tool_name_sent = False
return result
return StreamingParseResult(normal_text="", calls=calls)
except Exception as e:
logger.error(f"Error in parse_streaming_increment: {e}")
return StreamingParseResult(normal_text=current_text)
def structure_info(self) -> _GetInfoFunc:
raise NotImplementedError()
def build_ebnf(self, tools: List[Tool]):
raise NotImplementedError()
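For the streaming path, the detector is fed incremental chunks; it emits the tool name once, then argument JSON as fragments. A consumer-loop sketch mirroring the streaming tests added below:

```python
from sglang.srt.entrypoints.openai.protocol import Function, Tool
from sglang.srt.function_call.kimik2_detector import KimiK2Detector

tools = [
    Tool(
        type="function",
        function=Function(
            name="get_weather",
            description="Get weather information",
            parameters={
                "type": "object",
                "properties": {"city": {"type": "string"}},
            },
        ),
    )
]

detector = KimiK2Detector()
chunks = [
    "<|tool_calls_section_begin|><|tool_call_begin|>functions.get_weather:0"
    "<|tool_call_argument_begin|>{",
    '"city": "Paris"',
    "}<|tool_call_end|><|tool_calls_section_end|>",
]

name, args = "", ""
for chunk in chunks:
    result = detector.parse_streaming_increment(chunk, tools)
    for item in result.calls:
        if item.name:        # the tool name is emitted exactly once
            name += item.name
        if item.parameters:  # argument JSON arrives as incremental fragments
            args += item.parameters

# name == "get_weather"; args == '{"city": "Paris"}'
```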
......@@ -14,6 +14,7 @@
"""Utilities for Huggingface Transformers."""
import contextlib
import logging
import os
import warnings
from pathlib import Path
......@@ -25,6 +26,7 @@ from transformers import (
AutoConfig,
AutoProcessor,
AutoTokenizer,
GenerationConfig,
PretrainedConfig,
PreTrainedTokenizer,
PreTrainedTokenizerBase,
......@@ -153,6 +155,22 @@ def get_config(
return config
@lru_cache_frozenset(maxsize=32)
def get_generation_config(
model: str,
trust_remote_code: bool,
revision: Optional[str] = None,
**kwargs,
):
try:
return GenerationConfig.from_pretrained(
model, trust_remote_code=trust_remote_code, revision=revision, **kwargs
)
except OSError:
logging.info("Model does not have a generation_config.json")
return None
# Models don't use the same configuration key for determining the maximum
# context length. Store them here so we can sanely check them.
# NOTE: The ordering here is important. Some models have two of these and we
......
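`get_generation_config` is a thin wrapper around `GenerationConfig.from_pretrained` that degrades to `None` when the repository ships no `generation_config.json`. A usage sketch (the model id is a placeholder):

```python
from sglang.srt.hf_transformers_utils import get_generation_config

# Returns a transformers.GenerationConfig, or None when the repo has no
# generation_config.json (the OSError path above).
gen_cfg = get_generation_config("some-org/some-model", trust_remote_code=False)
if gen_cfg is not None:
    # eos_token_id may be an int, a list of ints, or None
    print(gen_cfg.eos_token_id)
```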
......@@ -1048,9 +1048,16 @@ class ServerArgs:
parser.add_argument(
"--tool-call-parser",
type=str,
choices=["qwen25", "mistral", "llama3", "deepseekv3", "pythonic"],
choices=[
"qwen25",
"mistral",
"llama3",
"deepseekv3",
"pythonic",
"kimi_k2",
],
default=ServerArgs.tool_call_parser,
help="Specify the parser for handling tool-call interactions. Options include: 'qwen25', 'mistral', 'llama3', 'deepseekv3', and 'pythonic'.",
help="Specify the parser for handling tool-call interactions. Options include: 'qwen25', 'mistral', 'llama3', 'deepseekv3', 'pythonic', and 'kimi_k2'.",
)
# Data parallelism
......
......@@ -6,6 +6,7 @@ from xgrammar import GrammarCompiler, TokenizerInfo
from sglang.srt.entrypoints.openai.protocol import Function, Tool
from sglang.srt.function_call.base_format_detector import BaseFormatDetector
from sglang.srt.function_call.deepseekv3_detector import DeepSeekV3Detector
from sglang.srt.function_call.kimik2_detector import KimiK2Detector
from sglang.srt.function_call.llama32_detector import Llama32Detector
from sglang.srt.function_call.mistral_detector import MistralDetector
from sglang.srt.function_call.pythonic_detector import PythonicDetector
......@@ -1138,5 +1139,213 @@ class TestLlama32Detector(unittest.TestCase):
self.assertTrue(result.normal_text.strip().startswith("Some intro."))
class TestKimiK2Detector(unittest.TestCase):
def setUp(self):
"""Set up test tools and detector."""
self.tools = [
Tool(
type="function",
function=Function(
name="get_weather",
description="Get weather information",
parameters={
"type": "object",
"properties": {
"city": {
"type": "string",
"description": "City name",
}
},
"required": ["city"],
},
),
),
Tool(
type="function",
function=Function(
name="get_tourist_attractions",
description="Get tourist attractions",
parameters={
"type": "object",
"properties": {
"city": {
"type": "string",
"description": "City name",
}
},
"required": ["city"],
},
),
),
]
self.detector = KimiK2Detector()
def test_single_tool_call(self):
"""Test parsing a single tool call in a complete text."""
text = '<|tool_calls_section_begin|><|tool_call_begin|>functions.get_weather:0<|tool_call_argument_begin|>{"city": "Paris"}<|tool_call_end|><|tool_calls_section_end|>'
result = self.detector.detect_and_parse(text, self.tools)
self.assertEqual(len(result.calls), 1)
self.assertEqual(result.calls[0].name, "get_weather")
self.assertEqual(result.calls[0].parameters, '{"city": "Paris"}')
self.assertEqual(result.normal_text, "")
def test_multiple_tool_calls(self):
"""Test parsing multiple tool calls in a complete text."""
text = '<|tool_calls_section_begin|><|tool_call_begin|>functions.get_weather:0<|tool_call_argument_begin|>{"city": "Paris"}<|tool_call_end|><|tool_call_begin|>functions.get_tourist_attractions:1<|tool_call_argument_begin|>{"city": "London"}<|tool_call_end|><|tool_calls_section_end|>'
result = self.detector.detect_and_parse(text, self.tools)
self.assertEqual(len(result.calls), 2)
self.assertEqual(result.calls[0].name, "get_weather")
self.assertEqual(result.calls[0].parameters, '{"city": "Paris"}')
self.assertEqual(result.calls[1].name, "get_tourist_attractions")
self.assertEqual(result.calls[1].parameters, '{"city": "London"}')
self.assertEqual(result.normal_text, "")
def test_streaming_tool_call(self):
"""Test streaming incremental parsing of a tool call."""
chunks = [
"<|tool_calls_section_begin|><|tool_call_begin|>functions.get_weather:0<|tool_call_argument_begin|>{",
'"city": "Paris"',
"}",
"<|tool_call_end|><|tool_calls_section_end|>",
]
tool_calls = []
for chunk in chunks:
result = self.detector.parse_streaming_increment(chunk, self.tools)
for tool_call_chunk in result.calls:
if tool_call_chunk.tool_index is not None:
while len(tool_calls) <= tool_call_chunk.tool_index:
tool_calls.append({"name": "", "parameters": ""})
tc = tool_calls[tool_call_chunk.tool_index]
if tool_call_chunk.name:
tc["name"] += tool_call_chunk.name
if tool_call_chunk.parameters:
tc["parameters"] += tool_call_chunk.parameters
self.assertEqual(len(tool_calls), 1)
self.assertEqual(tool_calls[0]["name"], "get_weather")
self.assertEqual(tool_calls[0]["parameters"], '{"city": "Paris"}')
def test_streaming_multiple_tool_calls(self):
"""Test streaming incremental parsing of multiple tool calls."""
chunks = [
"<|tool_calls_section_begin|><|tool_call_begin|>functions.get_weather:0<|tool_call_argument_begin|>{",
'"city": "Paris"',
"}<|tool_call_end|>",
"<|tool_call_begin|>functions.get_tourist_attractions:1<|tool_call_argument_begin|>{",
'"city": "London"',
"}<|tool_call_end|>",
"<|tool_calls_section_end|>",
]
tool_calls = []
for chunk in chunks:
result = self.detector.parse_streaming_increment(chunk, self.tools)
for tool_call_chunk in result.calls:
if tool_call_chunk.tool_index is not None:
while len(tool_calls) <= tool_call_chunk.tool_index:
tool_calls.append({"name": "", "parameters": ""})
tc = tool_calls[tool_call_chunk.tool_index]
if tool_call_chunk.name:
tc["name"] += tool_call_chunk.name
if tool_call_chunk.parameters:
tc["parameters"] += tool_call_chunk.parameters
self.assertEqual(len(tool_calls), 2)
self.assertEqual(tool_calls[0]["name"], "get_weather")
self.assertEqual(tool_calls[0]["parameters"], '{"city": "Paris"}')
self.assertEqual(tool_calls[1]["name"], "get_tourist_attractions")
self.assertEqual(tool_calls[1]["parameters"], '{"city": "London"}')
def test_tool_call_completion(self):
"""Test that the buffer and state are reset after a tool call is completed."""
chunks = [
"<|tool_calls_section_begin|><|tool_call_begin|>functions.get_weather:0<|tool_call_argument_begin|>{",
'"city": "Paris"',
"}",
"<|tool_call_end|>",
"<|tool_calls_section_end|>",
]
for chunk in chunks:
result = self.detector.parse_streaming_increment(chunk, self.tools)
# After processing all chunks, the buffer should be empty and
# current_tool_id should have advanced past the completed tool call
self.assertEqual(self.detector._buffer, "")
self.assertEqual(self.detector.current_tool_id, 1)
def test_tool_name_streaming(self):
"""Test that tool names are streamed correctly with the right index."""
chunks = [
"<|tool_calls_section_begin|><|tool_call_begin|>functions.get_weather:0<|tool_call_argument_begin|>{",
'"city": "Paris"',
"}",
"<|tool_call_end|>",
"<|tool_call_begin|>functions.get_tourist_attractions:1<|tool_call_argument_begin|>{",
]
tool_calls = []
for chunk in chunks:
result = self.detector.parse_streaming_increment(chunk, self.tools)
for tool_call_chunk in result.calls:
if tool_call_chunk.tool_index is not None:
while len(tool_calls) <= tool_call_chunk.tool_index:
tool_calls.append({"name": "", "parameters": ""})
tc = tool_calls[tool_call_chunk.tool_index]
if tool_call_chunk.name:
tc["name"] += tool_call_chunk.name
if tool_call_chunk.parameters:
tc["parameters"] += tool_call_chunk.parameters
self.assertEqual(len(tool_calls), 2)
self.assertEqual(tool_calls[0]["name"], "get_weather")
self.assertEqual(tool_calls[0]["parameters"], '{"city": "Paris"}')
self.assertEqual(tool_calls[1]["name"], "get_tourist_attractions")
def test_invalid_tool_call(self):
"""Test that invalid tool calls are handled correctly."""
text = 'invalid_tool:0<|tool_call_argument_begin|>{"city": "Paris"}<|tool_call_end|><|tool_calls_section_end|>'
result = self.detector.detect_and_parse(text, self.tools)
self.assertEqual(len(result.calls), 0)
self.assertEqual(result.normal_text, text)
def test_partial_tool_call(self):
"""Test that partial tool calls are handled correctly in streaming mode."""
chunks = [
"<|tool_calls_section_begin|><|tool_call_begin|>functions.get_weather:0<|tool_call_argument_begin|>{",
'"city": "Paris"',
]
tool_calls = []
for chunk in chunks:
result = self.detector.parse_streaming_increment(chunk, self.tools)
for tool_call_chunk in result.calls:
if tool_call_chunk.tool_index is not None:
while len(tool_calls) <= tool_call_chunk.tool_index:
tool_calls.append({"name": "", "parameters": ""})
tc = tool_calls[tool_call_chunk.tool_index]
if tool_call_chunk.name:
tc["name"] += tool_call_chunk.name
if tool_call_chunk.parameters:
tc["parameters"] += tool_call_chunk.parameters
self.assertEqual(len(tool_calls), 1)
self.assertEqual(tool_calls[0]["name"], "get_weather")
self.assertEqual(tool_calls[0]["parameters"], '{"city": "Paris"')
if __name__ == "__main__":
unittest.main()