Unverified Commit 61555307 authored by Atream, committed by GitHub

Support Kimi K2 (#7940)

parent 49a5915f
...@@ -135,7 +135,7 @@ Please consult the documentation below and [server_args.py](https://github.com/s
| `--file-storage-path` | The path of the file storage in backend. | sglang_storage |
| `--enable-cache-report` | Return number of cached tokens in usage.prompt_tokens_details for each openai request. | False |
| `--reasoning-parser` | Specify the parser for reasoning models, supported parsers are: {list(ReasoningParser.DetectorMap.keys())}. | None |
| `--tool-call-parser` | Specify the parser for handling tool-call interactions. Options include: 'qwen25', 'mistral', 'llama3', 'deepseekv3', 'pythonic', and 'kimi_k2'. | None |
## Data parallelism
...
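For orientation, a minimal sketch of what the new `kimi_k2` option selects at runtime. The constructor signature appears later in this diff, but the import path and the `detector` attribute are assumptions, and the empty tool list is purely illustrative:

from sglang.srt.function_call.function_call_parser import FunctionCallParser  # module path assumed

# DetectorMap (extended later in this diff) routes "kimi_k2" to KimiK2Detector.
parser = FunctionCallParser(tools=[], tool_call_parser="kimi_k2")
print(type(parser.detector).__name__)  # attribute name assumed; expected: KimiK2Detector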
...@@ -25,6 +25,7 @@ from transformers import PretrainedConfig
from sglang.srt.hf_transformers_utils import (
get_config,
get_context_length,
get_generation_config,
get_hf_text_config,
)
from sglang.srt.layers.quantization import QUANTIZATION_METHODS
...@@ -83,6 +84,13 @@ class ModelConfig:
**kwargs,
)
self.hf_generation_config = get_generation_config(
self.model_path,
trust_remote_code=trust_remote_code,
revision=revision,
**kwargs,
)
self.hf_text_config = get_hf_text_config(self.hf_config)
self.attention_chunk_size = getattr(
self.hf_text_config, "attention_chunk_size", None
...@@ -467,6 +475,19 @@ class ModelConfig:
if eos_ids:
# it can be either int or list of int
eos_ids = {eos_ids} if isinstance(eos_ids, int) else set(eos_ids)
if eos_ids is None:
eos_ids = set()
if self.hf_generation_config:
generation_eos_ids = getattr(
self.hf_generation_config, "eos_token_id", None
)
if generation_eos_ids:
generation_eos_ids = (
{generation_eos_ids}
if isinstance(generation_eos_ids, int)
else set(generation_eos_ids)
)
eos_ids = eos_ids | generation_eos_ids
return eos_ids
def maybe_pull_model_tokenizer_from_remote(self) -> None:
...
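Taken together, the new block above makes the effective EOS set the union of ids from `config.json` and `generation_config.json`. A self-contained sketch of that merge semantics (an illustrative standalone helper, not the actual `ModelConfig` method):

from typing import List, Set, Union

def merge_eos_ids(
    config_eos: Union[int, List[int], None],
    generation_eos: Union[int, List[int], None],
) -> Set[int]:
    """Union EOS token ids from hf_config and hf_generation_config."""

    def to_set(ids: Union[int, List[int], None]) -> Set[int]:
        if not ids:
            return set()
        return {ids} if isinstance(ids, int) else set(ids)

    return to_set(config_eos) | to_set(generation_eos)

# e.g. config.json declares 2, generation_config.json declares [2, 163586]:
assert merge_eos_ids(2, [2, 163586]) == {2, 163586}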
...@@ -10,6 +10,7 @@ from sglang.srt.entrypoints.openai.protocol import (
from sglang.srt.function_call.base_format_detector import BaseFormatDetector
from sglang.srt.function_call.core_types import ToolCallItem
from sglang.srt.function_call.deepseekv3_detector import DeepSeekV3Detector
from sglang.srt.function_call.kimik2_detector import KimiK2Detector
from sglang.srt.function_call.llama32_detector import Llama32Detector
from sglang.srt.function_call.mistral_detector import MistralDetector
from sglang.srt.function_call.pythonic_detector import PythonicDetector
...@@ -33,6 +34,7 @@ class FunctionCallParser:
"mistral": MistralDetector,
"deepseekv3": DeepSeekV3Detector,
"pythonic": PythonicDetector,
"kimi_k2": KimiK2Detector,
}
def __init__(self, tools: List[Tool], tool_call_parser: str):
...
import json
import logging
import re
from typing import List
from sglang.srt.entrypoints.openai.protocol import Tool
from sglang.srt.function_call.base_format_detector import BaseFormatDetector
from sglang.srt.function_call.core_types import (
StreamingParseResult,
StructureInfo,
ToolCallItem,
_GetInfoFunc,
)
from sglang.srt.function_call.ebnf_composer import EBNFComposer
from sglang.srt.function_call.utils import _is_complete_json
logger = logging.getLogger(__name__)
class KimiK2Detector(BaseFormatDetector):
def __init__(self):
super().__init__()
self._buffer = ""
self.current_tool_name_sent: bool = False
self.prev_tool_call_arr: list[dict] = []
self.current_tool_id: int = -1
self.streamed_args_for_tool: list[str] = []  # arguments streamed so far, per tool call
self.bot_token: str = "<|tool_calls_section_begin|>"
self.eot_token: str = "<|tool_calls_section_end|>"
self.tool_call_start_token: str = "<|tool_call_begin|>"
self.tool_call_end_token: str = "<|tool_call_end|>"
self.tool_call_regex = re.compile(
r"<\|tool_call_begin\|>\s*(?P<tool_call_id>[\w\.]+:\d+)\s*<\|tool_call_argument_begin\|>\s*(?P<function_arguments>\{.*?\})\s*<\|tool_call_end\|>"
)
self.stream_tool_call_portion_regex = re.compile(
r"<\|tool_call_begin\|>\s*(?P<tool_call_id>[\w\.]+:\d+)\s*<\|tool_call_argument_begin\|>\s*(?P<function_arguments>\{.*)"
)
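# Example of a complete call matched by tool_call_regex (the same format the
# unit tests later in this diff exercise):
#   <|tool_call_begin|>functions.get_weather:0<|tool_call_argument_begin|>{"city": "Paris"}<|tool_call_end|>
# tool_call_id captures "functions.get_weather:0" and function_arguments the
# JSON object; stream_tool_call_portion_regex also accepts a still-open JSON prefix.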
self._last_arguments = ""
def has_tool_call(self, text: str) -> bool:
"""Check if the text contains a KimiK2 format tool call."""
return self.bot_token in text
def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
"""
One-time parsing: Detects and parses tool calls in the provided text.
:param text: The complete text to parse.
:param tools: List of available tools.
:return: StreamingParseResult with any leading normal text and the parsed tool calls.
"""
if self.bot_token not in text:
return StreamingParseResult(normal_text=text, calls=[])
try:
# findall returns one (tool_call_id, function_arguments) tuple per
# complete tool call found in the text
function_call_tuples = self.tool_call_regex.findall(text)
logger.debug("function_call_tuples: %s", function_call_tuples)
tool_calls = []
for match in function_call_tuples:
function_id, function_args = match
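# function_id has the form "functions.{name}:{index}", e.g. "functions.get_weather:0"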
function_name = function_id.split(".")[1].split(":")[0]
function_idx = int(function_id.split(".")[1].split(":")[1])
logger.info(f"function_name {function_name}")
tool_calls.append(
ToolCallItem(
tool_index=function_idx, # Use the call index in the response, not tool position
name=function_name,
parameters=function_args,
)
)
content = text[: text.find(self.bot_token)]
return StreamingParseResult(normal_text=content, calls=tool_calls)
except Exception as e:
logger.error(f"Error in detect_and_parse: {e}")
# return the normal text if parsing fails
return StreamingParseResult(normal_text=text)
def parse_streaming_increment(
self, new_text: str, tools: List[Tool]
) -> StreamingParseResult:
"""
Streaming incremental parsing tool calls for KimiK2 format.
"""
self._buffer += new_text
current_text = self._buffer
# Check if we have a tool call (either the start token or individual tool call)
has_tool_call = (
self.bot_token in current_text or self.tool_call_start_token in current_text
)
if not has_tool_call:
self._buffer = ""
for e_token in [self.eot_token, self.tool_call_end_token]:
if e_token in new_text:
new_text = new_text.replace(e_token, "")
return StreamingParseResult(normal_text=new_text)
if not hasattr(self, "_tool_indices"):
self._tool_indices = {
tool.function.name: i
for i, tool in enumerate(tools)
if tool.function and tool.function.name
}
calls: list[ToolCallItem] = []
try:
match = self.stream_tool_call_portion_regex.search(current_text)
if match:
function_id = match.group("tool_call_id")
function_args = match.group("function_arguments")
function_name = function_id.split(".")[1].split(":")[0]
# Initialize state if this is the first tool call
if self.current_tool_id == -1:
self.current_tool_id = 0
self.prev_tool_call_arr = []
self.streamed_args_for_tool = [""]
# Ensure we have enough entries in our tracking arrays
while len(self.prev_tool_call_arr) <= self.current_tool_id:
self.prev_tool_call_arr.append({})
while len(self.streamed_args_for_tool) <= self.current_tool_id:
self.streamed_args_for_tool.append("")
if not self.current_tool_name_sent:
calls.append(
ToolCallItem(
tool_index=self.current_tool_id,
name=function_name,
parameters="",
)
)
self.current_tool_name_sent = True
# Store the tool call info for adapter.py
self.prev_tool_call_arr[self.current_tool_id] = {
"name": function_name,
"arguments": {},
}
else:
argument_diff = (
function_args[len(self._last_arguments) :]
if function_args.startswith(self._last_arguments)
else function_args
)
parsed_args_diff = argument_diff.split("<|tool_call_end|>", 1)[0]
if parsed_args_diff:
calls.append(
ToolCallItem(
tool_index=self.current_tool_id,
name=None,
parameters=parsed_args_diff,
)
)
self._last_arguments += argument_diff
self.streamed_args_for_tool[
self.current_tool_id
] += parsed_args_diff
parsed_args = function_args.split("<|tool_call_end|>", 1)[0]
if _is_complete_json(parsed_args):
try:
parsed_args = json.loads(parsed_args)
self.prev_tool_call_arr[self.current_tool_id][
"arguments"
] = parsed_args
except json.JSONDecodeError:
pass
# Find the end of the current tool call and remove only that part from buffer
tool_call_end_pattern = (
r"<\|tool_call_begin\|>.*?<\|tool_call_end\|>"
)
match = re.search(
tool_call_end_pattern, current_text, re.DOTALL
)
if match:
# Remove the completed tool call from buffer, keep any remaining content
self._buffer = current_text[match.end() :]
else:
self._buffer = ""
result = StreamingParseResult(normal_text="", calls=calls)
self.current_tool_id += 1
self._last_arguments = ""
self.current_tool_name_sent = False
return result
return StreamingParseResult(normal_text="", calls=calls)
except Exception as e:
logger.error(f"Error in parse_streaming_increment: {e}")
return StreamingParseResult(normal_text=current_text)
def structure_info(self) -> _GetInfoFunc:
raise NotImplementedError()
def build_ebnf(self, tools: List[Tool]):
raise NotImplementedError()
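A minimal offline sketch of the detector above; the sample text mirrors the unit tests added later in this diff, and the empty tool list is fine here because the one-shot detect_and_parse path does not consult it:

from sglang.srt.function_call.kimik2_detector import KimiK2Detector

detector = KimiK2Detector()
text = (
    "<|tool_calls_section_begin|>"
    "<|tool_call_begin|>functions.get_weather:0"
    '<|tool_call_argument_begin|>{"city": "Paris"}<|tool_call_end|>'
    "<|tool_calls_section_end|>"
)
result = detector.detect_and_parse(text, tools=[])
print(result.calls[0].name)        # get_weather
print(result.calls[0].parameters)  # {"city": "Paris"}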
...@@ -14,6 +14,7 @@
"""Utilities for Huggingface Transformers."""
import contextlib
import logging
import os
import warnings
from pathlib import Path
...@@ -25,6 +26,7 @@ from transformers import (
AutoConfig,
AutoProcessor,
AutoTokenizer,
GenerationConfig,
PretrainedConfig,
PreTrainedTokenizer,
PreTrainedTokenizerBase,
...@@ -153,6 +155,22 @@ def get_config(
return config
@lru_cache_frozenset(maxsize=32)
def get_generation_config(
model: str,
trust_remote_code: bool,
revision: Optional[str] = None,
**kwargs,
):
try:
return GenerationConfig.from_pretrained(
model, trust_remote_code=trust_remote_code, revision=revision, **kwargs
)
except OSError:
logging.info("Model doesn't have a generation_config.json")
return None
# Models don't use the same configuration key for determining the maximum
# context length. Store them here so we can sanely check them.
# NOTE: The ordering here is important. Some models have two of these and we
...
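The new helper degrades gracefully, returning None when a checkpoint ships no generation_config.json. A hedged usage sketch (the model id is a placeholder, not taken from this diff):

from sglang.srt.hf_transformers_utils import get_generation_config

# Returns a transformers.GenerationConfig, or None if the file is absent.
gen_cfg = get_generation_config("moonshotai/Kimi-K2-Instruct", trust_remote_code=True)
if gen_cfg is not None:
    print(gen_cfg.eos_token_id)  # int or list of ints, merged into ModelConfig's EOS set above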
...@@ -1048,9 +1048,16 @@ class ServerArgs:
parser.add_argument(
"--tool-call-parser",
type=str,
choices=[
"qwen25",
"mistral",
"llama3",
"deepseekv3",
"pythonic",
"kimi_k2",
],
default=ServerArgs.tool_call_parser,
help="Specify the parser for handling tool-call interactions. Options include: 'qwen25', 'mistral', 'llama3', 'deepseekv3', 'pythonic', and 'kimi_k2'.",
)
# Data parallelism
...
...@@ -6,6 +6,7 @@ from xgrammar import GrammarCompiler, TokenizerInfo
from sglang.srt.entrypoints.openai.protocol import Function, Tool
from sglang.srt.function_call.base_format_detector import BaseFormatDetector
from sglang.srt.function_call.deepseekv3_detector import DeepSeekV3Detector
from sglang.srt.function_call.kimik2_detector import KimiK2Detector
from sglang.srt.function_call.llama32_detector import Llama32Detector
from sglang.srt.function_call.mistral_detector import MistralDetector
from sglang.srt.function_call.pythonic_detector import PythonicDetector
...@@ -1138,5 +1139,213 @@ class TestLlama32Detector(unittest.TestCase):
self.assertTrue(result.normal_text.strip().startswith("Some intro."))
class TestKimiK2Detector(unittest.TestCase):
def setUp(self):
"""Set up test tools and detector."""
self.tools = [
Tool(
type="function",
function=Function(
name="get_weather",
description="Get weather information",
parameters={
"type": "object",
"properties": {
"city": {
"type": "string",
"description": "City name",
}
},
"required": ["city"],
},
),
),
Tool(
type="function",
function=Function(
name="get_tourist_attractions",
description="Get tourist attractions",
parameters={
"type": "object",
"properties": {
"city": {
"type": "string",
"description": "City name",
}
},
"required": ["city"],
},
),
),
]
self.detector = KimiK2Detector()
def test_single_tool_call(self):
"""Test parsing a single tool call in a complete text."""
text = '<|tool_calls_section_begin|><|tool_call_begin|>functions.get_weather:0<|tool_call_argument_begin|>{"city": "Paris"}<|tool_call_end|><|tool_calls_section_end|>'
result = self.detector.detect_and_parse(text, self.tools)
self.assertEqual(len(result.calls), 1)
self.assertEqual(result.calls[0].name, "get_weather")
self.assertEqual(result.calls[0].parameters, '{"city": "Paris"}')
self.assertEqual(result.normal_text, "")
def test_multiple_tool_calls(self):
"""Test parsing multiple tool calls in a complete text."""
text = '<|tool_calls_section_begin|><|tool_call_begin|>functions.get_weather:0<|tool_call_argument_begin|>{"city": "Paris"}<|tool_call_end|><|tool_call_begin|>functions.get_tourist_attractions:1<|tool_call_argument_begin|>{"city": "London"}<|tool_call_end|><|tool_calls_section_end|>'
result = self.detector.detect_and_parse(text, self.tools)
self.assertEqual(len(result.calls), 2)
self.assertEqual(result.calls[0].name, "get_weather")
self.assertEqual(result.calls[0].parameters, '{"city": "Paris"}')
self.assertEqual(result.calls[1].name, "get_tourist_attractions")
self.assertEqual(result.calls[1].parameters, '{"city": "London"}')
self.assertEqual(result.normal_text, "")
def test_streaming_tool_call(self):
"""Test streaming incremental parsing of a tool call."""
chunks = [
"<|tool_calls_section_begin|><|tool_call_begin|>functions.get_weather:0<|tool_call_argument_begin|>{",
'"city": "Paris"',
"}",
"<|tool_call_end|><|tool_calls_section_end|>",
]
tool_calls = []
for chunk in chunks:
result = self.detector.parse_streaming_increment(chunk, self.tools)
for tool_call_chunk in result.calls:
if tool_call_chunk.tool_index is not None:
while len(tool_calls) <= tool_call_chunk.tool_index:
tool_calls.append({"name": "", "parameters": ""})
tc = tool_calls[tool_call_chunk.tool_index]
if tool_call_chunk.name:
tc["name"] += tool_call_chunk.name
if tool_call_chunk.parameters:
tc["parameters"] += tool_call_chunk.parameters
self.assertEqual(len(tool_calls), 1)
self.assertEqual(tool_calls[0]["name"], "get_weather")
self.assertEqual(tool_calls[0]["parameters"], '{"city": "Paris"}')
def test_streaming_multiple_tool_calls(self):
"""Test streaming incremental parsing of multiple tool calls."""
chunks = [
"<|tool_calls_section_begin|><|tool_call_begin|>functions.get_weather:0<|tool_call_argument_begin|>{",
'"city": "Paris"',
"}<|tool_call_end|>",
"<|tool_call_begin|>functions.get_tourist_attractions:1<|tool_call_argument_begin|>{",
'"city": "London"',
"}<|tool_call_end|>",
"<|tool_calls_section_end|>",
]
tool_calls = []
for chunk in chunks:
result = self.detector.parse_streaming_increment(chunk, self.tools)
for tool_call_chunk in result.calls:
if tool_call_chunk.tool_index is not None:
while len(tool_calls) <= tool_call_chunk.tool_index:
tool_calls.append({"name": "", "parameters": ""})
tc = tool_calls[tool_call_chunk.tool_index]
if tool_call_chunk.name:
tc["name"] += tool_call_chunk.name
if tool_call_chunk.parameters:
tc["parameters"] += tool_call_chunk.parameters
self.assertEqual(len(tool_calls), 2)
self.assertEqual(tool_calls[0]["name"], "get_weather")
self.assertEqual(tool_calls[0]["parameters"], '{"city": "Paris"}')
self.assertEqual(tool_calls[1]["name"], "get_tourist_attractions")
self.assertEqual(tool_calls[1]["parameters"], '{"city": "London"}')
def test_tool_call_completion(self):
"""Test that the buffer and state are reset after a tool call is completed."""
chunks = [
"<|tool_calls_section_begin|><|tool_call_begin|>functions.get_weather:0<|tool_call_argument_begin|>{",
'"city": "Paris"',
"}",
"<|tool_call_end|>",
"<|tool_calls_section_end|>",
]
for chunk in chunks:
result = self.detector.parse_streaming_increment(chunk, self.tools)
# After all chunks, the buffer should be empty and current_tool_id should have
# advanced past the completed call (hence 1, not a reset to -1)
self.assertEqual(self.detector._buffer, "")
self.assertEqual(self.detector.current_tool_id, 1)
def test_tool_name_streaming(self):
"""Test that tool names are streamed correctly with the right index."""
chunks = [
"<|tool_calls_section_begin|><|tool_call_begin|>functions.get_weather:0<|tool_call_argument_begin|>{",
'"city": "Paris"',
"}",
"<|tool_call_end|>",
"<|tool_call_begin|>functions.get_tourist_attractions:1<|tool_call_argument_begin|>{",
]
tool_calls = []
for chunk in chunks:
result = self.detector.parse_streaming_increment(chunk, self.tools)
for tool_call_chunk in result.calls:
if tool_call_chunk.tool_index is not None:
while len(tool_calls) <= tool_call_chunk.tool_index:
tool_calls.append({"name": "", "parameters": ""})
tc = tool_calls[tool_call_chunk.tool_index]
if tool_call_chunk.name:
tc["name"] += tool_call_chunk.name
if tool_call_chunk.parameters:
tc["parameters"] += tool_call_chunk.parameters
self.assertEqual(len(tool_calls), 2)
self.assertEqual(tool_calls[0]["name"], "get_weather")
self.assertEqual(tool_calls[0]["parameters"], '{"city": "Paris"}')
self.assertEqual(tool_calls[1]["name"], "get_tourist_attractions")
def test_invalid_tool_call(self):
"""Test that invalid tool calls are handled correctly."""
text = 'invalid_tool:0<|tool_call_argument_begin|>{"city": "Paris"}<|tool_call_end|><|tool_calls_section_end|>'
result = self.detector.detect_and_parse(text, self.tools)
self.assertEqual(len(result.calls), 0)
self.assertEqual(result.normal_text, text)
def test_partial_tool_call(self):
"""Test that partial tool calls are handled correctly in streaming mode."""
chunks = [
"<|tool_calls_section_begin|><|tool_call_begin|>functions.get_weather:0<|tool_call_argument_begin|>{",
'"city": "Paris"',
]
tool_calls = []
for chunk in chunks:
result = self.detector.parse_streaming_increment(chunk, self.tools)
for tool_call_chunk in result.calls:
if tool_call_chunk.tool_index is not None:
while len(tool_calls) <= tool_call_chunk.tool_index:
tool_calls.append({"name": "", "parameters": ""})
tc = tool_calls[tool_call_chunk.tool_index]
if tool_call_chunk.name:
tc["name"] += tool_call_chunk.name
if tool_call_chunk.parameters:
tc["parameters"] += tool_call_chunk.parameters
self.assertEqual(len(tool_calls), 1)
self.assertEqual(tool_calls[0]["name"], "get_weather")
self.assertEqual(tool_calls[0]["parameters"], '{"city": "Paris"')
if __name__ == "__main__":
unittest.main()