Unverified Commit d1bd210f authored by Nikita's avatar Nikita Committed by GitHub
Browse files

feat: Kimi K2/K2.5 tool and reasoning parsers (#6407)


Signed-off-by: default avatarNikita Sukharev <kaonael@gmail.com>
parent ff06b17e
......@@ -45,7 +45,14 @@ Parser to Model Mapping
| pythonic | meta-llama/Llama-4-* |
| jamba | ai21labs/AI21-Jamba-*-1.5, ai21labs/AI21-Jamba-*-1.6, ai21labs/AI21-Jamba-*-1.7, |
| glm47 | zai-org/GLM-4.7 |
| kimi_k2 | moonshotai/Kimi-K2-Thinking*, moonshotai/Kimi-K2-Instruct*, moonshotai/Kimi-K2.5* |
\* Currently requires converting `tiktoken.model` to `tokenizers.json`.
> [!TIP]
> For Kimi K2.5 thinking models, pair `--dyn-tool-call-parser kimi_k2` with
> `--dyn-reasoning-parser kimi_k25` so that both `<think>` blocks and tool calls
> are parsed correctly from the same response.
## Examples
......
......@@ -147,8 +147,8 @@ galil-seiferas = { version = "0.1" }
# preprocessor
bs62 = { version = "0.1" }
minijinja = { version = "2.14.0", features = ["loader"] }
minijinja-contrib = { version = "2.14.0", features = ["pycompat"] }
minijinja = { version = "2.15.1", features = ["loader", "loop_controls"] }
minijinja-contrib = { version = "2.15.1", features = ["pycompat"] }
json-five = { version = "0.3" }
# media loading in the preprocessor
......
......@@ -946,6 +946,25 @@ impl OpenAIPreprocessor {
jail.apply_with_finish_reason(stream)
}
/// Check if reasoning parsing should be disabled based on per-request parameters.
/// For kimi_k25: disabled when chat_template_args contains "thinking": false.
fn is_reasoning_disabled_by_request(
reasoning_parser: Option<&str>,
chat_template_args: Option<&std::collections::HashMap<String, serde_json::Value>>,
) -> bool {
match reasoning_parser {
Some("kimi_k25") => {
if let Some(args) = chat_template_args
&& let Some(thinking) = args.get("thinking")
{
return thinking == &serde_json::Value::Bool(false);
}
false
}
_ => false,
}
}
// Motivation: Each transformation on the stream should be a separate step to allow for more flexibility
// Earlier reasoning parser logic was nested under delta generation logic in choice_from_postprocessor
// Since we have tool calling parsing as separate step, it makes sense to have reasoning parser as separate step as well
......@@ -1094,7 +1113,11 @@ impl
);
// Try to parse reasoning content only if parser is configured
let should_parse_reasoning = self.runtime_config.reasoning_parser.is_some();
let should_parse_reasoning = self.runtime_config.reasoning_parser.is_some()
&& !Self::is_reasoning_disabled_by_request(
self.runtime_config.reasoning_parser.as_deref(),
request.chat_template_args.as_ref(),
);
// Reasoning Content Parsing Transformation Step
// Current Solution:
......@@ -1329,3 +1352,77 @@ impl
}
// Note: tests for jailing and parser detection live in `lib/llm/tests/test_jail.rs`
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_is_reasoning_disabled_by_request() {
let thinking_true = {
let mut m = std::collections::HashMap::new();
m.insert("thinking".to_string(), serde_json::Value::Bool(true));
m
};
let thinking_false = {
let mut m = std::collections::HashMap::new();
m.insert("thinking".to_string(), serde_json::Value::Bool(false));
m
};
let empty_args = std::collections::HashMap::new();
// (parser, args, expected_disabled, description)
let cases = [
(
Some("kimi_k25"),
Some(&thinking_false),
true,
"kimi_k25 + thinking=false → disabled",
),
(
Some("kimi_k25"),
Some(&thinking_true),
false,
"kimi_k25 + thinking=true → enabled",
),
(
Some("kimi_k25"),
None,
false,
"kimi_k25 + no args → enabled",
),
(
Some("kimi_k25"),
Some(&empty_args),
false,
"kimi_k25 + empty args → enabled",
),
(
Some("deepseek_r1"),
Some(&thinking_false),
false,
"deepseek_r1 → never disabled",
),
(
Some("basic"),
Some(&thinking_false),
false,
"basic → never disabled",
),
(
None,
Some(&thinking_false),
false,
"no parser → never disabled",
),
];
for (parser, args, expected, desc) in cases {
assert_eq!(
OpenAIPreprocessor::is_reasoning_disabled_by_request(parser, args),
expected,
"FAILED: {desc}",
);
}
}
}
......@@ -105,6 +105,38 @@ mod tests {
}
}
/// Shorthand for creating a mock chunk with content only
fn chunk(content: &str) -> Annotated<NvCreateChatCompletionStreamResponse> {
create_mock_response_chunk(content.to_string(), None)
}
/// Run chunks through a reasoning parser, return aggregated (reasoning, content)
async fn run_parser(
chunks: Vec<Annotated<NvCreateChatCompletionStreamResponse>>,
parser: &str,
) -> (String, String) {
let output_stream = OpenAIPreprocessor::parse_reasoning_content_from_stream(
stream::iter(chunks),
parser.to_string(),
);
let mut output_stream = std::pin::pin!(output_stream);
let mut all_reasoning = String::new();
let mut all_content = String::new();
while let Some(item) = output_stream.next().await {
if let Some(ref data) = item.data {
for choice in &data.choices {
if let Some(ref r) = choice.delta.reasoning_content {
all_reasoning.push_str(r);
}
if let Some(ref c) = choice.delta.content {
all_content.push_str(get_text(c));
}
}
}
}
(all_reasoning, all_content)
}
#[tokio::test]
async fn test_reasoning_parser_with_basic_parser() {
// Basic Parser test <think> </think> tags
......@@ -414,57 +446,69 @@ mod tests {
}
#[tokio::test]
async fn test_reasoning_parser_with_kimi_parser() {
// Create a mock runtime config with Kimi reasoning parser
let runtime_config = dynamo_llm::local_model::runtime_config::ModelRuntimeConfig {
reasoning_parser: Some("kimi".to_string()),
..Default::default()
};
// Create test input stream with Kimi-style reasoning tags
let input_chunks = vec![
create_mock_response_chunk("Let me analyze this. ◁think▷This is Kimi reasoning content◁/think▷ Here's my conclusion.".to_string(), None),
async fn test_reasoning_parser_with_kimi_k25() {
// (description, input_chunks, expected_reasoning, expected_content)
let cases = vec![
(
"thinking mode",
vec![
chunk("<think>Let me"),
chunk(" think about this carefully."),
chunk("</think>Bonjour!"),
],
"Let me think about this carefully.",
"Bonjour!",
),
(
"instant mode (empty think)",
vec![
chunk("<think>"),
chunk("</think>"),
chunk("Direct answer without thinking."),
],
"",
"Direct answer without thinking.",
),
(
"token-by-token",
vec![
chunk("<think>"),
chunk("The user"),
chunk(" asked me"),
chunk(" to say hello."),
chunk("</think>"),
chunk("Hello"),
chunk("!"),
],
"The user asked me to say hello.",
"Hello!",
),
];
let input_stream = stream::iter(input_chunks);
// Apply the reasoning parser transformation
let output_stream = OpenAIPreprocessor::parse_reasoning_content_from_stream(
input_stream,
runtime_config.reasoning_parser.unwrap(),
);
// Pin the stream and collect all output chunks
let mut output_stream = std::pin::pin!(output_stream);
let mut output_chunks = Vec::new();
while let Some(chunk) = output_stream.next().await {
output_chunks.push(chunk);
for (desc, chunks, expected_reasoning, expected_content) in cases {
let (reasoning, content) = run_parser(chunks, "kimi_k25").await;
assert_eq!(reasoning, expected_reasoning, "FAILED reasoning: {desc}");
assert_eq!(content, expected_content, "FAILED content: {desc}");
}
}
// Verify that Kimi-style reasoning is parsed correctly
assert_eq!(output_chunks.len(), 1);
let output_choice = &output_chunks[0].data.as_ref().unwrap().choices[0];
assert!(
output_choice.delta.reasoning_content.is_some(),
"Should extract Kimi reasoning content"
);
assert!(
output_choice.delta.content.is_some(),
"Should have normal content"
);
let reasoning_content = output_choice.delta.reasoning_content.as_ref().unwrap();
let normal_content = output_choice.delta.content.as_ref().unwrap();
#[tokio::test]
async fn test_reasoning_parser_with_kimi_parser() {
let (reasoning, content) = run_parser(
vec![chunk(
"Let me analyze this. ◁think▷This is Kimi reasoning content◁/think▷ Here's my conclusion.",
)],
"kimi",
)
.await;
// Verify the content was parsed with Kimi tags
assert!(
reasoning_content.contains("Kimi reasoning"),
"Should contain Kimi reasoning content"
reasoning.contains("Kimi reasoning"),
"Should contain Kimi reasoning, got: {reasoning}"
);
assert!(
get_text(normal_content).contains("Let me analyze")
|| get_text(normal_content).contains("Here's my conclusion"),
"Should contain normal content"
content.contains("Let me analyze") || content.contains("Here's my conclusion"),
"Should contain normal content, got: {content}"
);
}
......@@ -586,6 +630,103 @@ mod tests {
);
}
#[tokio::test]
async fn test_kimi_k25_with_reasoning_and_tool_calls() {
// Simulates a real Kimi K2.5 response: <think> block followed by tool calls.
// Verifies that reasoning and tool_calling parsers don't interfere with each other.
let input_chunks = vec![
chunk("<think>I should check the weather"),
chunk(" before answering.</think>"),
chunk("<|tool_calls_section_begin|>"),
chunk("<|tool_call_begin|>functions.get_weather:0"),
chunk("<|tool_call_argument_begin|>"),
chunk(r#"{"location":"NYC"}"#),
chunk("<|tool_call_end|>"),
chunk("<|tool_calls_section_end|>"),
];
let input_stream = stream::iter(input_chunks);
// Step 1: reasoning parser (kimi_k25) extracts <think> into reasoning_content
let reasoning_parsed_stream = OpenAIPreprocessor::parse_reasoning_content_from_stream(
input_stream,
"kimi_k25".to_string(),
);
// Step 2: tool calling jail (kimi_k2) extracts tool calls from remaining content
let tool_parsed_stream = OpenAIPreprocessor::apply_tool_calling_jail(
Some("kimi_k2".to_string()),
None,
None,
reasoning_parsed_stream,
);
let mut tool_parsed_stream = std::pin::pin!(tool_parsed_stream);
let mut output_chunks = Vec::new();
while let Some(chunk) = tool_parsed_stream.next().await {
output_chunks.push(chunk);
}
assert!(!output_chunks.is_empty(), "Should have output chunks");
let mut all_reasoning = String::new();
let mut all_normal_content = String::new();
let mut found_tool_calls = false;
let mut tool_call_function_name: Option<String> = None;
let mut tool_call_arguments: Option<serde_json::Value> = None;
for chunk in output_chunks.iter() {
if let Some(ref data) = chunk.data {
for choice in &data.choices {
if let Some(ref r) = choice.delta.reasoning_content {
all_reasoning.push_str(r);
}
if let Some(ref c) = choice.delta.content {
all_normal_content.push_str(get_text(c));
}
if let Some(ref tool_calls) = choice.delta.tool_calls
&& !tool_calls.is_empty()
{
found_tool_calls = true;
for tc in tool_calls {
if let Some(ref f) = tc.function {
if let Some(ref name) = f.name {
tool_call_function_name = Some(name.clone());
}
if let Some(ref args) = f.arguments {
tool_call_arguments = Some(serde_json::from_str(args).unwrap());
}
}
}
}
}
}
}
assert_eq!(
all_reasoning, "I should check the weather before answering.",
"Reasoning mismatch"
);
assert!(
found_tool_calls,
"Should have found tool calls in the output"
);
assert_eq!(
tool_call_function_name.as_deref(),
Some("get_weather"),
"Tool call function name should be 'get_weather'"
);
assert_eq!(
tool_call_arguments.as_ref(),
Some(&serde_json::json!({"location": "NYC"})),
"Tool call arguments mismatch"
);
// No normal content expected — everything is either reasoning or tool calls
assert!(
all_normal_content.trim().is_empty(),
"Expected no normal content, got: {all_normal_content:?}"
);
}
#[tokio::test]
#[ignore]
// (TODO: Ayush) Fix this test
......
......@@ -108,6 +108,10 @@ impl ReasoningParser for BasicReasoningParser {
while cursor < text.len() {
if currently_reasoning {
// Skip leading start token if present (handles force_reasoning + explicit <think>)
if text[cursor..].starts_with(&self.think_start_token) {
cursor += self.think_start_token.len();
}
// We're inside a reasoning block — look for end token
if let Some(end_offset) = text[cursor..].find(&self.think_end_token) {
reasoning_parts.push(&text[cursor..cursor + end_offset]);
......@@ -175,6 +179,17 @@ impl ReasoningParser for BasicReasoningParser {
continue;
}
// Buffer is a prefix of the start token (e.g., "<thi" for "<think>") — wait
// for more data before deciding whether to strip it or emit as reasoning.
// Only applies when force_reasoning=true and we haven't stripped the tag yet.
if !self.stripped_think_start
&& self._in_reasoning
&& !current_text.is_empty()
&& self.think_start_token.starts_with(current_text.as_str())
{
break;
}
if self._in_reasoning {
if let Some(end_idx) = current_text.find(self.think_end_token.as_str()) {
// End of reasoning block: accumulate content and transition out.
......
......@@ -26,6 +26,7 @@ fn get_reasoning_parser_map() -> &'static HashMap<&'static str, ReasoningParserT
map.insert("qwen3", ReasoningParserType::Qwen);
map.insert("nemotron_deci", ReasoningParserType::NemotronDeci);
map.insert("kimi", ReasoningParserType::Kimi);
map.insert("kimi_k25", ReasoningParserType::KimiK25);
map.insert("step3", ReasoningParserType::Step3);
map.insert("mistral", ReasoningParserType::Mistral);
map.insert("granite", ReasoningParserType::Granite);
......@@ -97,6 +98,7 @@ pub enum ReasoningParserType {
Qwen,
NemotronDeci,
Kimi,
KimiK25,
Mistral,
Granite,
MiniMaxAppendThink,
......@@ -152,6 +154,14 @@ impl ReasoningParserType {
true,
)),
},
ReasoningParserType::KimiK25 => ReasoningParserWrapper {
parser: Box::new(BasicReasoningParser::new(
"<think>".into(),
"</think>".into(),
true,
true,
)),
},
ReasoningParserType::Mistral => ReasoningParserWrapper {
parser: Box::new(BasicReasoningParser::new(
"[THINK]".into(),
......@@ -222,6 +232,7 @@ mod tests {
"qwen3",
"nemotron_deci",
"kimi",
"kimi_k25",
"step3",
"mistral",
"granite",
......@@ -233,4 +244,135 @@ mod tests {
assert!(parsers.contains(&parser));
}
}
#[test]
fn test_kimi_k25_detect_and_parse() {
// (description, input, expected_reasoning, expected_normal)
let cases = [
(
"force reasoning: no think tags",
"no think tags here",
"no think tags here",
"",
),
(
"standard think tags",
"<think>Let me reason about this.</think>Hello!",
"Let me reason about this.",
"Hello!",
),
(
"empty think block (instant mode)",
"<think></think>Hello from instant mode!",
"",
"Hello from instant mode!",
),
(
"empty think block with newline",
"<think>\n</think>Hello from instant mode!",
"",
"Hello from instant mode!",
),
];
for (desc, input, expected_reasoning, expected_normal) in cases {
let mut parser = ReasoningParserType::KimiK25.get_reasoning_parser();
let result = parser.detect_and_parse_reasoning(input, &[]);
assert_eq!(
result.reasoning_text, expected_reasoning,
"FAILED reasoning: {desc}"
);
assert_eq!(result.normal_text, expected_normal, "FAILED normal: {desc}");
}
}
#[test]
fn test_kimi_k25_streaming_force_reasoning() {
// Streaming: force_reasoning means tokens before <think> are treated as reasoning
let mut parser = ReasoningParserType::KimiK25.get_reasoning_parser();
// First chunk: partial think tag — buffered because it's a prefix of "<think>"
let r1 = parser.parse_reasoning_streaming_incremental("<thi", &[]);
assert_eq!(r1.reasoning_text, "");
assert_eq!(r1.normal_text, "");
// Second chunk: completes the think tag + reasoning content
let r2 = parser.parse_reasoning_streaming_incremental("nk>reasoning here", &[]);
assert_eq!(r2.reasoning_text, "reasoning here");
assert_eq!(r2.normal_text, "");
// Third chunk: close tag + normal content
let r3 = parser.parse_reasoning_streaming_incremental("</think>Hello!", &[]);
assert_eq!(r3.reasoning_text, "");
assert_eq!(r3.normal_text, "Hello!");
}
#[test]
fn test_kimi_k25_streaming() {
// (description, tokens, expected_reasoning, expected_content)
let cases: Vec<(&str, &[&str], &str, &str)> = vec![
(
"complete response",
&[
"<think>",
"I need to",
" think about",
" this carefully.",
"</think>",
"Bonjour",
"!",
],
"I need to think about this carefully.",
"Bonjour!",
),
(
"empty think (instant mode)",
&["<think>", "</think>", "Direct answer."],
"",
"Direct answer.",
),
];
for (desc, tokens, expected_reasoning, expected_content) in cases {
let mut parser = ReasoningParserType::KimiK25.get_reasoning_parser();
let mut all_reasoning = String::new();
let mut all_content = String::new();
for token in tokens {
let r = parser.parse_reasoning_streaming_incremental(token, &[]);
all_reasoning.push_str(&r.reasoning_text);
all_content.push_str(&r.normal_text);
}
assert_eq!(
all_reasoning, expected_reasoning,
"FAILED reasoning: {desc}"
);
assert_eq!(all_content, expected_content, "FAILED content: {desc}");
}
}
#[test]
fn test_kimi_k25_parser_lookup_by_name() {
// Verify the parser can be looked up by name
let mut parser = ReasoningParserType::get_reasoning_parser_from_name("kimi_k25");
let result = parser.detect_and_parse_reasoning("<think>thinking</think>answer", &[]);
assert_eq!(result.reasoning_text, "thinking");
assert_eq!(result.normal_text, "answer");
}
#[test]
fn test_kimi_vs_kimi_k25_different_tags() {
// Kimi (original) uses ◁think▷/◁/think▷, KimiK25 uses <think>/</think>
let mut kimi = ReasoningParserType::Kimi.get_reasoning_parser();
let mut kimi_k25 = ReasoningParserType::KimiK25.get_reasoning_parser();
// Kimi original does NOT parse <think> tags
let r_kimi = kimi.detect_and_parse_reasoning("<think>reasoning</think>answer", &[]);
assert_eq!(r_kimi.normal_text, "<think>reasoning</think>answer");
assert_eq!(r_kimi.reasoning_text, "");
// KimiK25 does parse <think> tags
let r_k25 = kimi_k25.detect_and_parse_reasoning("<think>reasoning</think>answer", &[]);
assert_eq!(r_k25.reasoning_text, "reasoning");
assert_eq!(r_k25.normal_text, "answer");
}
}
......@@ -130,6 +130,57 @@ impl Default for Glm47ParserConfig {
}
}
/// Configuration for Kimi K2 tool call parser
///
/// Format:
/// ```text
/// <|tool_calls_section_begin|>
/// <|tool_call_begin|>functions.{name}:{index}<|tool_call_argument_begin|>{json_args}<|tool_call_end|>
/// <|tool_calls_section_end|>
/// ```
///
/// The model may emit either plural or singular forms of section tokens
/// (e.g., `<|tool_calls_section_begin|>` or `<|tool_call_section_begin|>`).
/// Both forms are supported via the `section_start_variants` and `section_end_variants` fields.
/// See vllm `kimi_k2_tool_parser.py` for reference.
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct KimiK2ParserConfig {
/// Primary start token for the tool calls section
pub section_start: String,
/// Primary end token for the tool calls section
pub section_end: String,
/// All recognized start tokens for the tool calls section (includes singular variants)
pub section_start_variants: Vec<String>,
/// All recognized end tokens for the tool calls section (includes singular variants)
pub section_end_variants: Vec<String>,
/// Start token for an individual tool call (e.g., "<|tool_call_begin|>")
pub call_start: String,
/// End token for an individual tool call (e.g., "<|tool_call_end|>")
pub call_end: String,
/// Token separating function ID from JSON arguments (e.g., "<|tool_call_argument_begin|>")
pub argument_begin: String,
}
impl Default for KimiK2ParserConfig {
fn default() -> Self {
Self {
section_start: "<|tool_calls_section_begin|>".to_string(),
section_end: "<|tool_calls_section_end|>".to_string(),
section_start_variants: vec![
"<|tool_calls_section_begin|>".to_string(),
"<|tool_call_section_begin|>".to_string(),
],
section_end_variants: vec![
"<|tool_calls_section_end|>".to_string(),
"<|tool_call_section_end|>".to_string(),
],
call_start: "<|tool_call_begin|>".to_string(),
call_end: "<|tool_call_end|>".to_string(),
argument_begin: "<|tool_call_argument_begin|>".to_string(),
}
}
}
/// Parser-specific configuration
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
......@@ -140,6 +191,7 @@ pub enum ParserConfig {
Harmony(JsonParserConfig),
Typescript,
Dsml(DsmlParserConfig),
KimiK2(KimiK2ParserConfig),
Glm47(Glm47ParserConfig),
}
......@@ -155,6 +207,7 @@ impl ParserConfig {
ParserConfig::Typescript => vec![],
ParserConfig::Dsml(config) => vec![config.function_calls_start.clone()],
ParserConfig::Glm47(config) => vec![config.tool_call_start.clone()],
ParserConfig::KimiK2(config) => config.section_start_variants.clone(),
}
}
......@@ -169,6 +222,7 @@ impl ParserConfig {
ParserConfig::Typescript => vec![],
ParserConfig::Dsml(config) => vec![config.function_calls_end.clone()],
ParserConfig::Glm47(config) => vec![config.tool_call_end.clone()],
ParserConfig::KimiK2(config) => config.section_end_variants.clone(),
}
}
}
......@@ -357,4 +411,15 @@ impl ToolCallConfig {
parser_config: ParserConfig::Glm47(Glm47ParserConfig::default()),
}
}
pub fn kimi_k2() -> Self {
// Kimi K2 format:
// <|tool_calls_section_begin|>
// <|tool_call_begin|>functions.{name}:{index}<|tool_call_argument_begin|>{json_args}<|tool_call_end|>
// <|tool_calls_section_end|>
// Reference: https://huggingface.co/moonshotai/Kimi-K2-Instruct/blob/main/docs/tool_call_guidance.md
Self {
parser_config: ParserConfig::KimiK2(KimiK2ParserConfig::default()),
}
}
}
......@@ -23,7 +23,9 @@ pub struct ToolDefinition {
}
// Re-export main types and functions for convenience
pub use config::{JsonParserConfig, ParserConfig, ToolCallConfig, XmlParserConfig};
pub use config::{
JsonParserConfig, KimiK2ParserConfig, ParserConfig, ToolCallConfig, XmlParserConfig,
};
pub use dsml::try_tool_call_parse_dsml;
pub use harmony::parse_tool_calls_harmony_complete;
pub use json::try_tool_call_parse_json;
......@@ -34,4 +36,5 @@ pub use parsers::{
pub use pythonic::try_tool_call_parse_pythonic;
pub use response::{CalledFunction, ToolCallResponse, ToolCallType};
pub use tools::{try_tool_call_parse_aggregate, try_tool_call_parse_stream};
pub use xml::try_tool_call_parse_kimi_k2;
pub use xml::try_tool_call_parse_xml;
......@@ -19,8 +19,10 @@ use super::pythonic::{
};
use super::response::ToolCallResponse;
use super::xml::{
detect_tool_call_start_glm47, detect_tool_call_start_xml, find_tool_call_end_position_glm47,
find_tool_call_end_position_xml, try_tool_call_parse_glm47, try_tool_call_parse_xml,
detect_tool_call_start_glm47, detect_tool_call_start_kimi_k2, detect_tool_call_start_xml,
find_tool_call_end_position_glm47, find_tool_call_end_position_kimi_k2,
find_tool_call_end_position_xml, try_tool_call_parse_glm47, try_tool_call_parse_kimi_k2,
try_tool_call_parse_xml,
};
use std::collections::HashMap;
use std::sync::OnceLock;
......@@ -45,6 +47,7 @@ pub fn get_tool_parser_map() -> &'static HashMap<&'static str, ToolCallConfig> {
map.insert("jamba", ToolCallConfig::jamba());
map.insert("minimax_m2", ToolCallConfig::minimax_m2());
map.insert("glm47", ToolCallConfig::glm47());
map.insert("kimi_k2", ToolCallConfig::kimi_k2());
map.insert("default", ToolCallConfig::default());
map.insert("nemotron_nano", ToolCallConfig::qwen3_coder()); // nemotron nano follows qwen3_coder format
map
......@@ -91,6 +94,11 @@ pub async fn try_tool_call_parse(
try_tool_call_parse_glm47(message, glm47_config, tools)?;
Ok((results, normal_content))
}
ParserConfig::KimiK2(kimi_config) => {
let (results, normal_content) =
try_tool_call_parse_kimi_k2(message, kimi_config, tools)?;
Ok((results, normal_content))
}
}
}
......@@ -144,6 +152,9 @@ pub fn detect_tool_call_start(chunk: &str, parser_str: Option<&str>) -> anyhow::
ParserConfig::Glm47(glm47_config) => {
Ok(detect_tool_call_start_glm47(chunk, glm47_config))
}
ParserConfig::KimiK2(kimi_config) => {
Ok(detect_tool_call_start_kimi_k2(chunk, kimi_config))
}
},
None => anyhow::bail!(
"Parser '{}' is not implemented. Available parsers: {:?}",
......@@ -184,6 +195,9 @@ pub fn find_tool_call_end_position(chunk: &str, parser_str: Option<&str>) -> usi
ParserConfig::Glm47(glm47_config) => {
find_tool_call_end_position_glm47(chunk, glm47_config)
}
ParserConfig::KimiK2(kimi_config) => {
find_tool_call_end_position_kimi_k2(chunk, kimi_config)
}
},
None => {
// Unknown parser, return full content length
......@@ -225,6 +239,7 @@ mod tests {
"nemotron_nano",
"minimax_m2",
"glm47",
"kimi_k2",
];
for parser in available_parsers {
assert!(parsers.contains(&parser));
......
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
// Reference implementation:
// https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/function_call/kimik2_detector.py
// https://github.com/vllm-project/vllm/blob/main/vllm/tool_parsers/kimi_k2_tool_parser.py
use std::sync::OnceLock;
use regex::Regex;
use super::super::ToolDefinition;
use super::super::config::KimiK2ParserConfig;
use super::response::{CalledFunction, ToolCallResponse, ToolCallType};
static ID_REGEX: OnceLock<Regex> = OnceLock::new();
static TOOL_CALL_REGEX: OnceLock<Regex> = OnceLock::new();
/// Returns the cached regex that captures `function_id` (e.g. `functions.get_weather:0`) and
/// `arguments` (JSON object) between the configured `call_start`, `argument_begin`, and
/// `call_end` tokens.
///
/// The `function_id` pattern `[\w.]+:\d+` matches the `functions.name:index` format used by
/// Kimi K2, consistent with sglang/vllm reference implementations.
fn get_tool_call_regex(config: &KimiK2ParserConfig) -> &'static Regex {
TOOL_CALL_REGEX.get_or_init(|| {
let pattern = format!(
r"(?s){}\s*(?P<function_id>[\w.]+:\d+)\s*{}\s*(?P<arguments>\{{.*?\}})\s*{}",
regex::escape(&config.call_start),
regex::escape(&config.argument_begin),
regex::escape(&config.call_end),
);
Regex::new(&pattern).expect("Failed to compile kimi k2 tool call regex")
})
}
fn get_id_regex() -> &'static Regex {
ID_REGEX.get_or_init(|| {
Regex::new(r"^(?:functions\.)?(?P<name>[\w\.]+):(?P<index>\d+)$")
.expect("Failed to compile kimi k2 id regex")
})
}
/// Check if a chunk contains the start of a Kimi K2 style tool call.
/// Detects `<|tool_calls_section_begin|>` (or singular variant) or partial match for streaming.
pub fn detect_tool_call_start_kimi_k2(chunk: &str, config: &KimiK2ParserConfig) -> bool {
for start_token in &config.section_start_variants {
debug_assert!(
start_token.is_ascii(),
"Kimi K2 section tokens must be ASCII for safe byte slicing, got: {start_token:?}"
);
// Check for complete start token.
if chunk.contains(start_token.as_str()) {
return true;
}
// Check for partial match at the end of the chunk (for streaming).
for i in 1..start_token.len() {
if chunk.ends_with(&start_token[..i]) {
return true;
}
}
}
false
}
/// Returns the position after `<|tool_calls_section_end|>` (or singular variant) or the length
/// of the chunk if not found.
pub fn find_tool_call_end_position_kimi_k2(chunk: &str, config: &KimiK2ParserConfig) -> usize {
// Find the earliest matching end token variant.
let mut earliest: Option<usize> = None;
for end_token in &config.section_end_variants {
if let Some(pos) = chunk.find(end_token.as_str()) {
let end_pos = pos + end_token.len();
earliest = Some(earliest.map_or(end_pos, |e: usize| e.min(end_pos)));
}
}
earliest.unwrap_or(chunk.len())
}
/// Format:
/// ```text
/// <|tool_calls_section_begin|>
/// <|tool_call_begin|>functions.get_weather:0<|tool_call_argument_begin|>{"location":"NYC"}<|tool_call_end|>
/// <|tool_calls_section_end|>
/// ```
///
/// Returns (parsed_tool_calls, normal_text_content)
pub fn try_tool_call_parse_kimi_k2(
message: &str,
config: &KimiK2ParserConfig,
tools: Option<&[ToolDefinition]>,
) -> anyhow::Result<(Vec<ToolCallResponse>, Option<String>)> {
let (normal_text, tool_calls) = extract_tool_calls(message, config, tools)?;
let normal_content = if normal_text.is_empty() {
Some("".to_string())
} else {
Some(normal_text)
};
Ok((tool_calls, normal_content))
}
/// Find the first occurrence of any section start variant in `text[cursor..]`.
/// Returns `(relative_position, matched_token_length)` or `None`.
fn find_section_start(
text: &str,
cursor: usize,
config: &KimiK2ParserConfig,
) -> Option<(usize, usize)> {
let mut best: Option<(usize, usize)> = None;
for variant in &config.section_start_variants {
if let Some(pos) = text[cursor..].find(variant.as_str())
&& best.is_none_or(|(bp, _)| pos < bp)
{
best = Some((pos, variant.len()));
}
}
best
}
/// Find the first occurrence of any section end variant in `text[from..]`.
/// Returns `(relative_position, matched_token_length)` or `None`.
fn find_section_end(
text: &str,
from: usize,
config: &KimiK2ParserConfig,
) -> Option<(usize, usize)> {
let mut best: Option<(usize, usize)> = None;
for variant in &config.section_end_variants {
if let Some(pos) = text[from..].find(variant.as_str())
&& best.is_none_or(|(bp, _)| pos < bp)
{
best = Some((pos, variant.len()));
}
}
best
}
/// Extract tool calls and normal text from message.
fn extract_tool_calls(
text: &str,
config: &KimiK2ParserConfig,
tools: Option<&[ToolDefinition]>,
) -> anyhow::Result<(String, Vec<ToolCallResponse>)> {
let mut normal_parts = Vec::new();
let mut calls = Vec::new();
let mut cursor = 0;
while cursor < text.len() {
if let Some((start_pos, _start_len)) = find_section_start(text, cursor, config) {
let abs_start = cursor + start_pos;
// Add text before tool call section to normal parts.
normal_parts.push(&text[cursor..abs_start]);
if let Some((end_pos, end_len)) = find_section_end(text, abs_start, config) {
let abs_end = abs_start + end_pos + end_len;
let block = &text[abs_start..abs_end];
// Parse individual tool calls within this section block.
if let Ok(mut parsed_calls) = parse_section_block(block, config, tools) {
calls.append(&mut parsed_calls);
}
cursor = abs_end;
} else {
// No end token found -> treat the rest as normal text.
normal_parts.push(&text[abs_start..]);
break;
}
} else {
// No more tool call sections.
normal_parts.push(&text[cursor..]);
break;
}
}
let normal_text = normal_parts.join("").trim().to_string();
Ok((normal_text, calls))
}
/// Parse a tool calls section block, extracting individual tool calls.
///
/// The block is between `<|tool_calls_section_begin|>` and `<|tool_calls_section_end|>`.
/// Each individual call is between `<|tool_call_begin|>` and `<|tool_call_end|>`.
fn parse_section_block(
block: &str,
config: &KimiK2ParserConfig,
tools: Option<&[ToolDefinition]>,
) -> anyhow::Result<Vec<ToolCallResponse>> {
let tool_call_regex = get_tool_call_regex(config);
let id_regex = get_id_regex();
let mut results = Vec::new();
for cap in tool_call_regex.captures_iter(block) {
let function_id = cap
.name("function_id")
.map(|m| m.as_str().trim())
.unwrap_or("");
let arguments_raw = cap
.name("arguments")
.map(|m| m.as_str().trim())
.unwrap_or("{}");
// Parse function ID
let function_name = if let Some(id_cap) = id_regex.captures(function_id) {
id_cap
.name("name")
.map(|m| m.as_str().to_string())
.unwrap_or_default()
} else {
// Fallback: use the whole ID as the function name
tracing::warn!(
"Unexpected tool_call_id format: '{}', using as-is",
function_id
);
function_id.to_string()
};
if function_name.is_empty() {
continue;
}
// Validate function name against tools if provided
if let Some(tools) = tools
&& !tools.iter().any(|t| t.name == function_name)
{
tracing::warn!("Tool '{}' is not defined in the tools list.", function_name);
}
// Validate JSON arguments
let arguments_json = match serde_json::from_str::<serde_json::Value>(arguments_raw) {
Ok(val) => serde_json::to_string(&val)?,
Err(e) => {
tracing::warn!(
"Failed to parse JSON arguments for tool '{}': {}. Using raw string.",
function_name,
e,
);
arguments_raw.to_string()
}
};
// NOTE: Unlike other parsers (XML, DSML) which generate `call-{UUID}` IDs,
// we preserve the model's native function_id (e.g., "functions.bash:0") here.
// This matches the behavior of vllm/sglang and is required for Kimi K2 compatibility.
let tool_call = ToolCallResponse {
id: function_id.to_string(),
tp: ToolCallType::Function,
function: CalledFunction {
name: function_name,
arguments: arguments_json,
},
};
results.push(tool_call);
}
Ok(results)
}
#[cfg(test)]
mod tests {
use super::*;
fn default_config() -> KimiK2ParserConfig {
KimiK2ParserConfig::default()
}
#[test]
fn test_detect_tool_call_start() {
let config = default_config();
assert!(detect_tool_call_start_kimi_k2(
"<|tool_calls_section_begin|>",
&config
));
assert!(detect_tool_call_start_kimi_k2(
"text <|tool_calls_section_begin|>",
&config
));
// Partial match at end
assert!(detect_tool_call_start_kimi_k2("<|tool_calls_sec", &config));
assert!(detect_tool_call_start_kimi_k2("<|", &config));
// No match
assert!(!detect_tool_call_start_kimi_k2(
"no tool call here",
&config
));
assert!(!detect_tool_call_start_kimi_k2("toolcall", &config));
}
#[test]
fn test_find_tool_call_end_position() {
let config = default_config();
let text = "<|tool_calls_section_begin|><|tool_call_begin|>functions.test:0<|tool_call_argument_begin|>{}<|tool_call_end|><|tool_calls_section_end|>more text";
let pos = find_tool_call_end_position_kimi_k2(text, &config);
assert_eq!(&text[pos..], "more text");
let text_no_end = "<|tool_calls_section_begin|><|tool_call_begin|>functions.test:0";
let pos = find_tool_call_end_position_kimi_k2(text_no_end, &config);
assert_eq!(pos, text_no_end.len());
}
#[test]
fn test_parse_simple_tool_call() {
let config = default_config();
let input = r#"<|tool_calls_section_begin|><|tool_call_begin|>functions.get_weather:0<|tool_call_argument_begin|>{"location":"NYC"}<|tool_call_end|><|tool_calls_section_end|>"#;
let (calls, normal) = try_tool_call_parse_kimi_k2(input, &config, None).unwrap();
assert_eq!(calls.len(), 1);
assert_eq!(calls[0].function.name, "get_weather");
assert_eq!(normal, Some("".to_string()));
let args: serde_json::Value = serde_json::from_str(&calls[0].function.arguments).unwrap();
assert_eq!(args["location"], "NYC");
}
#[test]
fn test_parse_multiple_args() {
let config = default_config();
let input = r#"<|tool_calls_section_begin|><|tool_call_begin|>functions.get_weather:0<|tool_call_argument_begin|>{"location":"San Francisco, CA","unit":"fahrenheit"}<|tool_call_end|><|tool_calls_section_end|>"#;
let (calls, _) = try_tool_call_parse_kimi_k2(input, &config, None).unwrap();
assert_eq!(calls.len(), 1);
assert_eq!(calls[0].function.name, "get_weather");
let args: serde_json::Value = serde_json::from_str(&calls[0].function.arguments).unwrap();
assert_eq!(args["location"], "San Francisco, CA");
assert_eq!(args["unit"], "fahrenheit");
}
#[test]
fn test_parse_multiple_tool_calls() {
let config = default_config();
let input = r#"<|tool_calls_section_begin|><|tool_call_begin|>functions.get_weather:0<|tool_call_argument_begin|>{"location":"NYC"}<|tool_call_end|><|tool_call_begin|>functions.get_time:1<|tool_call_argument_begin|>{"timezone":"EST"}<|tool_call_end|><|tool_calls_section_end|>"#;
let (calls, normal) = try_tool_call_parse_kimi_k2(input, &config, None).unwrap();
assert_eq!(calls.len(), 2);
assert_eq!(calls[0].function.name, "get_weather");
assert_eq!(calls[1].function.name, "get_time");
assert_eq!(normal, Some("".to_string()));
let args0: serde_json::Value = serde_json::from_str(&calls[0].function.arguments).unwrap();
let args1: serde_json::Value = serde_json::from_str(&calls[1].function.arguments).unwrap();
assert_eq!(args0["location"], "NYC");
assert_eq!(args1["timezone"], "EST");
}
#[test]
fn test_parse_with_normal_text() {
let config = default_config();
let input = r#"I'll help you with that. <|tool_calls_section_begin|><|tool_call_begin|>functions.get_weather:0<|tool_call_argument_begin|>{"location":"Dallas"}<|tool_call_end|><|tool_calls_section_end|> Let me check."#;
let (calls, normal) = try_tool_call_parse_kimi_k2(input, &config, None).unwrap();
assert_eq!(calls.len(), 1);
assert_eq!(calls[0].function.name, "get_weather");
assert_eq!(
normal,
Some("I'll help you with that. Let me check.".to_string())
);
}
#[test]
fn test_parse_no_arg_call() {
let config = default_config();
let input = r#"<|tool_calls_section_begin|><|tool_call_begin|>functions.get_current_time:0<|tool_call_argument_begin|>{}<|tool_call_end|><|tool_calls_section_end|>"#;
let (calls, _) = try_tool_call_parse_kimi_k2(input, &config, None).unwrap();
assert_eq!(calls.len(), 1);
assert_eq!(calls[0].function.name, "get_current_time");
let args: serde_json::Value = serde_json::from_str(&calls[0].function.arguments).unwrap();
assert!(args.as_object().unwrap().is_empty());
}
#[test]
fn test_parse_no_tool_calls() {
let config = default_config();
let input = "This is just normal text without any tool calls.";
let (calls, normal) = try_tool_call_parse_kimi_k2(input, &config, None).unwrap();
assert_eq!(calls.len(), 0);
assert_eq!(normal, Some(input.to_string()));
}
#[test]
fn test_parse_without_functions_prefix() {
let config = default_config();
// Some models may emit without the "functions." prefix
let input = r#"<|tool_calls_section_begin|><|tool_call_begin|>get_weather:0<|tool_call_argument_begin|>{"location":"NYC"}<|tool_call_end|><|tool_calls_section_end|>"#;
let (calls, _) = try_tool_call_parse_kimi_k2(input, &config, None).unwrap();
assert_eq!(calls.len(), 1);
assert_eq!(calls[0].function.name, "get_weather");
}
#[test]
fn test_parse_with_tool_validation() {
let config = default_config();
let tools = vec![ToolDefinition {
name: "get_weather".to_string(),
parameters: Some(serde_json::json!({
"type": "object",
"properties": {
"location": {"type": "string"}
}
})),
}];
let input = r#"<|tool_calls_section_begin|><|tool_call_begin|>functions.get_weather:0<|tool_call_argument_begin|>{"location":"NYC"}<|tool_call_end|><|tool_calls_section_end|>"#;
let (calls, _) = try_tool_call_parse_kimi_k2(input, &config, Some(&tools)).unwrap();
assert_eq!(calls.len(), 1);
assert_eq!(calls[0].function.name, "get_weather");
}
#[test]
fn test_parse_malformed_no_section_end() {
let config = default_config();
let input = r#"<|tool_calls_section_begin|><|tool_call_begin|>functions.get_weather:0<|tool_call_argument_begin|>{"location":"NYC"}<|tool_call_end|>"#;
// Should handle gracefully - section_end not found so whole text is treated as normal
let (calls, normal) = try_tool_call_parse_kimi_k2(input, &config, None).unwrap();
assert_eq!(
calls.len(),
0,
"No tool calls should be parsed without section end"
);
assert_eq!(
normal,
Some(input.to_string()),
"Input should be preserved as normal text"
);
}
#[test]
fn test_parse_with_whitespace() {
let config = default_config();
let input = "<|tool_calls_section_begin|>\n<|tool_call_begin|> functions.search:0 <|tool_call_argument_begin|> {\"query\":\"rust programming\"} <|tool_call_end|>\n<|tool_calls_section_end|>";
let (calls, _) = try_tool_call_parse_kimi_k2(input, &config, None).unwrap();
assert_eq!(calls.len(), 1);
assert_eq!(calls[0].function.name, "search");
let args: serde_json::Value = serde_json::from_str(&calls[0].function.arguments).unwrap();
assert_eq!(args["query"], "rust programming");
}
#[test]
fn test_parse_complex_json_arguments() {
let config = default_config();
let input = r#"<|tool_calls_section_begin|><|tool_call_begin|>functions.process_data:0<|tool_call_argument_begin|>{"items":[1,2,3],"config":{"nested":true}}<|tool_call_end|><|tool_calls_section_end|>"#;
let (calls, _) = try_tool_call_parse_kimi_k2(input, &config, None).unwrap();
assert_eq!(calls.len(), 1);
assert_eq!(calls[0].function.name, "process_data");
let args: serde_json::Value = serde_json::from_str(&calls[0].function.arguments).unwrap();
assert_eq!(args["items"], serde_json::json!([1, 2, 3]));
assert_eq!(args["config"]["nested"], true);
}
#[test]
fn test_parse_deeply_nested_json_multiple_calls() {
let config = default_config();
// Multiple tool calls with deeply nested JSON - stress test for regex backtracking
let input = r#"<|tool_calls_section_begin|><|tool_call_begin|>functions.create_config:0<|tool_call_argument_begin|>{"database":{"primary":{"host":"db1.example.com","port":5432,"options":{"ssl":true,"pool":{"min":5,"max":20}}},"replica":{"host":"db2.example.com","port":5432}},"features":["auth","logging"]}<|tool_call_end|><|tool_call_begin|>functions.deploy:1<|tool_call_argument_begin|>{"env":"production","services":[{"name":"api","replicas":3,"config":{"memory":"2Gi","cpu":"1000m"}},{"name":"worker","replicas":2,"config":{"memory":"4Gi","cpu":"2000m"}}]}<|tool_call_end|><|tool_call_begin|>functions.notify:2<|tool_call_argument_begin|>{"channels":["slack","email"],"message":"Deployment started"}<|tool_call_end|><|tool_calls_section_end|>"#;
let (calls, normal) = try_tool_call_parse_kimi_k2(input, &config, None).unwrap();
assert_eq!(calls.len(), 3);
assert_eq!(calls[0].function.name, "create_config");
assert_eq!(calls[0].id, "functions.create_config:0");
let args0: serde_json::Value = serde_json::from_str(&calls[0].function.arguments).unwrap();
assert_eq!(args0["database"]["primary"]["options"]["pool"]["max"], 20);
assert_eq!(calls[1].function.name, "deploy");
assert_eq!(calls[1].id, "functions.deploy:1");
let args1: serde_json::Value = serde_json::from_str(&calls[1].function.arguments).unwrap();
assert_eq!(args1["services"][0]["config"]["memory"], "2Gi");
assert_eq!(calls[2].function.name, "notify");
assert_eq!(calls[2].id, "functions.notify:2");
let args2: serde_json::Value = serde_json::from_str(&calls[2].function.arguments).unwrap();
assert_eq!(args2["channels"], serde_json::json!(["slack", "email"]));
assert_eq!(normal, Some("".to_string()));
}
#[test]
fn test_detect_singular_section_start() {
let config = default_config();
// Singular variant: <|tool_call_section_begin|> (without 's')
assert!(detect_tool_call_start_kimi_k2(
"<|tool_call_section_begin|>",
&config
));
// Partial match of singular variant
assert!(detect_tool_call_start_kimi_k2(
"text <|tool_call_section_b",
&config
));
}
#[test]
fn test_parse_with_singular_section_tokens() {
let config = default_config();
// Use singular form: tool_call_section_begin/end (without 's')
let input = r#"<|tool_call_section_begin|><|tool_call_begin|>functions.get_weather:0<|tool_call_argument_begin|>{"location":"NYC"}<|tool_call_end|><|tool_call_section_end|>"#;
let (calls, normal) = try_tool_call_parse_kimi_k2(input, &config, None).unwrap();
assert_eq!(calls.len(), 1);
assert_eq!(calls[0].function.name, "get_weather");
assert_eq!(normal, Some("".to_string()));
}
#[test]
fn test_find_end_position_singular_variant() {
let config = default_config();
// Singular variant end token
let text = "<|tool_call_section_begin|><|tool_call_begin|>functions.test:0<|tool_call_argument_begin|>{}<|tool_call_end|><|tool_call_section_end|>more text";
let pos = find_tool_call_end_position_kimi_k2(text, &config);
assert_eq!(&text[pos..], "more text");
}
// --- Tests inspired by vllm/sglang coverage gaps ---
#[test]
fn test_parse_invalid_json_falls_back_to_raw_string() {
// vllm: test_extract_tool_calls_invalid_json
// Invalid JSON in arguments should fall back to raw string, not panic
let config = default_config();
let input = r#"<|tool_calls_section_begin|><|tool_call_begin|>functions.get_weather:0<|tool_call_argument_begin|>{invalid json here}<|tool_call_end|><|tool_calls_section_end|>"#;
let (calls, _) = try_tool_call_parse_kimi_k2(input, &config, None).unwrap();
assert_eq!(calls.len(), 1);
assert_eq!(calls[0].function.name, "get_weather");
// Arguments should be preserved as raw string since JSON parsing failed
assert_eq!(calls[0].function.arguments, "{invalid json here}");
}
#[test]
fn test_parse_invalid_function_id_rejected_by_regex() {
// vllm: test_extract_tool_calls_invalid_funcall
// sglang: test_invalid_tool_call
// After C2 fix, function_id regex requires [\w.]+:\d+ — IDs without :digit are rejected
let config = default_config();
// No colon+digit suffix at all
let input1 = r#"<|tool_calls_section_begin|><|tool_call_begin|>just_a_name<|tool_call_argument_begin|>{"key":"val"}<|tool_call_end|><|tool_calls_section_end|>"#;
let (calls, _) = try_tool_call_parse_kimi_k2(input1, &config, None).unwrap();
assert_eq!(calls.len(), 0, "ID without :digit should be rejected");
// Colon but non-digit suffix
let input2 = r#"<|tool_calls_section_begin|><|tool_call_begin|>functions.get_weather:abc<|tool_call_argument_begin|>{"key":"val"}<|tool_call_end|><|tool_calls_section_end|>"#;
let (calls, _) = try_tool_call_parse_kimi_k2(input2, &config, None).unwrap();
assert_eq!(calls.len(), 0, "ID with :non-digit should be rejected");
// Multiple colons (garbage)
let input3 = r#"<|tool_calls_section_begin|><|tool_call_begin|>:::0<|tool_call_argument_begin|>{"key":"val"}<|tool_call_end|><|tool_calls_section_end|>"#;
let (calls, _) = try_tool_call_parse_kimi_k2(input3, &config, None).unwrap();
assert_eq!(calls.len(), 0, "Garbage ID should be rejected");
// Valid call mixed with invalid — only valid should be extracted
let input4 = r#"<|tool_calls_section_begin|><|tool_call_begin|>no_colon<|tool_call_argument_begin|>{"a":"b"}<|tool_call_end|><|tool_call_begin|>functions.valid:0<|tool_call_argument_begin|>{"x":"y"}<|tool_call_end|><|tool_calls_section_end|>"#;
let (calls, _) = try_tool_call_parse_kimi_k2(input4, &config, None).unwrap();
assert_eq!(calls.len(), 1, "Only valid call should be extracted");
assert_eq!(calls[0].function.name, "valid");
}
#[test]
fn test_parse_angle_brackets_in_json_arguments() {
// vllm: angle_brackets_in_json
// JSON values containing <tag> constructs should not confuse the parser,
// since Kimi markers use <| prefix which is distinct from bare <
let config = default_config();
let input = r#"<|tool_calls_section_begin|><|tool_call_begin|>functions.render_html:0<|tool_call_argument_begin|>{"template":"<div class=\"main\"><h1>Title</h1><p>Content</p></div>","format":"html"}<|tool_call_end|><|tool_calls_section_end|>"#;
let (calls, _) = try_tool_call_parse_kimi_k2(input, &config, None).unwrap();
assert_eq!(calls.len(), 1);
assert_eq!(calls[0].function.name, "render_html");
let args: serde_json::Value = serde_json::from_str(&calls[0].function.arguments).unwrap();
assert!(args["template"].as_str().unwrap().contains("<div"));
assert!(args["template"].as_str().unwrap().contains("</div>"));
assert_eq!(args["format"], "html");
}
#[test]
fn test_parse_three_concatenated_calls_no_spacing() {
// vllm: concatenated_tool_calls_bug_fix, three_concatenated_tool_calls
// Three tool calls concatenated with zero whitespace between them
let config = default_config();
let input = "<|tool_calls_section_begin|>\
<|tool_call_begin|>functions.search:0<|tool_call_argument_begin|>{\"q\":\"rust\"}<|tool_call_end|>\
<|tool_call_begin|>functions.search:1<|tool_call_argument_begin|>{\"q\":\"python\"}<|tool_call_end|>\
<|tool_call_begin|>functions.search:2<|tool_call_argument_begin|>{\"q\":\"go\"}<|tool_call_end|>\
<|tool_calls_section_end|>";
let (calls, normal) = try_tool_call_parse_kimi_k2(input, &config, None).unwrap();
assert_eq!(calls.len(), 3);
assert_eq!(calls[0].function.name, "search");
assert_eq!(calls[0].id, "functions.search:0");
assert_eq!(calls[1].function.name, "search");
assert_eq!(calls[1].id, "functions.search:1");
assert_eq!(calls[2].function.name, "search");
assert_eq!(calls[2].id, "functions.search:2");
let a0: serde_json::Value = serde_json::from_str(&calls[0].function.arguments).unwrap();
let a1: serde_json::Value = serde_json::from_str(&calls[1].function.arguments).unwrap();
let a2: serde_json::Value = serde_json::from_str(&calls[2].function.arguments).unwrap();
assert_eq!(a0["q"], "rust");
assert_eq!(a1["q"], "python");
assert_eq!(a2["q"], "go");
assert_eq!(normal, Some("".to_string()));
}
#[test]
fn test_parse_newlines_in_json_arguments() {
// vllm: newlines_in_json
// Multi-line formatted JSON arguments (model may emit pretty-printed JSON)
let config = default_config();
let input = "<|tool_calls_section_begin|><|tool_call_begin|>functions.create_user:0<|tool_call_argument_begin|>{\n \"name\": \"John Doe\",\n \"address\": {\n \"street\": \"123 Main St\",\n \"city\": \"Springfield\"\n },\n \"tags\": [\n \"admin\",\n \"active\"\n ]\n}<|tool_call_end|><|tool_calls_section_end|>";
let (calls, _) = try_tool_call_parse_kimi_k2(input, &config, None).unwrap();
assert_eq!(calls.len(), 1);
assert_eq!(calls[0].function.name, "create_user");
let args: serde_json::Value = serde_json::from_str(&calls[0].function.arguments).unwrap();
assert_eq!(args["name"], "John Doe");
assert_eq!(args["address"]["city"], "Springfield");
assert_eq!(args["tags"], serde_json::json!(["admin", "active"]));
}
#[test]
fn test_parse_empty_tool_section() {
// vllm: test_empty_tool_section
// Section begin immediately followed by section end, no tool calls inside
let config = default_config();
let input = "Here is my answer. <|tool_calls_section_begin|><|tool_calls_section_end|> And more text.";
let (calls, normal) = try_tool_call_parse_kimi_k2(input, &config, None).unwrap();
assert_eq!(calls.len(), 0, "Empty section should produce no tool calls");
assert_eq!(
normal,
Some("Here is my answer. And more text.".to_string()),
"Text around empty section should be preserved"
);
}
}
......@@ -2,12 +2,17 @@
// SPDX-License-Identifier: Apache-2.0
mod glm47_parser;
mod kimi_k2_parser;
mod parser;
pub use super::response;
pub use glm47_parser::{
detect_tool_call_start_glm47, find_tool_call_end_position_glm47, try_tool_call_parse_glm47,
};
pub use kimi_k2_parser::{
detect_tool_call_start_kimi_k2, find_tool_call_end_position_kimi_k2,
try_tool_call_parse_kimi_k2,
};
pub use parser::{
detect_tool_call_start_xml, find_tool_call_end_position_xml, try_tool_call_parse_xml,
};
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment