Unverified commit 04ea38c3, authored by Nikita, committed by GitHub
Browse files

fix(parsers): back-port sglang kimi-k2 tool-call and reasoning detection (#8532)


Signed-off-by: Nikita Sukharev <kaonael@gmail.com>
parent 48911230
...@@ -61,6 +61,10 @@ pub struct BasicReasoningParser { ...@@ -61,6 +61,10 @@ pub struct BasicReasoningParser {
stream_reasoning: bool, stream_reasoning: bool,
_buffer: String, _buffer: String,
stripped_think_start: bool, stripped_think_start: bool,
/// Optional marker that force-exits reasoning mode when encountered inside a
/// reasoning block (e.g. Kimi-K2/K2.5 models sometimes emit
/// `<|tool_calls_section_begin|>` without first closing `</think>`).
tool_start_token: Option<String>,
} }
impl BasicReasoningParser { impl BasicReasoningParser {
...@@ -77,8 +81,16 @@ impl BasicReasoningParser { ...@@ -77,8 +81,16 @@ impl BasicReasoningParser {
stream_reasoning, stream_reasoning,
_buffer: String::new(), _buffer: String::new(),
stripped_think_start: false, stripped_think_start: false,
tool_start_token: None,
} }
} }
/// Enables force-exit from reasoning when `token` appears inside an open reasoning
/// block.
///
/// Builder-style setter: takes the parser by value, stores `token` (converted into
/// an owned `String`) as the tool-start marker, and returns the updated parser so
/// the call can be chained directly onto the constructor. Calling it again replaces
/// any previously configured marker.
///
/// This is intended for models such as Kimi-K2/K2.5, which sometimes emit
/// `<|tool_calls_section_begin|>` without first closing `</think>`.
#[must_use = "this returns the configured parser; dropping it discards the tool-start token"]
pub fn with_tool_start_token(mut self, token: impl Into<String>) -> Self {
    self.tool_start_token = Some(token.into());
    self
}
} }
impl ReasoningParser for BasicReasoningParser { impl ReasoningParser for BasicReasoningParser {
...@@ -101,8 +113,17 @@ impl ReasoningParser for BasicReasoningParser { ...@@ -101,8 +113,17 @@ impl ReasoningParser for BasicReasoningParser {
}; };
} }
// If force_reasoning and no start tag, treat entire text as reasoning // If force_reasoning and no start tag, no end tag, and no tool-start marker,
if self._in_reasoning && !has_think_tag && !text.contains(&self.think_end_token) { // treat entire text as reasoning.
let has_tool_start = self
.tool_start_token
.as_deref()
.is_some_and(|tok| text.contains(tok));
if self._in_reasoning
&& !has_think_tag
&& !text.contains(&self.think_end_token)
&& !has_tool_start
{
return ParserResult { return ParserResult {
normal_text: String::new(), normal_text: String::new(),
reasoning_text: text.to_string(), reasoning_text: text.to_string(),
...@@ -121,15 +142,39 @@ impl ReasoningParser for BasicReasoningParser { ...@@ -121,15 +142,39 @@ impl ReasoningParser for BasicReasoningParser {
if text[cursor..].starts_with(&self.think_start_token) { if text[cursor..].starts_with(&self.think_start_token) {
cursor += self.think_start_token.len(); cursor += self.think_start_token.len();
} }
// We're inside a reasoning block — look for end token // Look for the earliest reasoning exit point: either </think> or the
if let Some(end_offset) = text[cursor..].find(&self.think_end_token) { // optional tool_start_token (force-exit case).
reasoning_parts.push(&text[cursor..cursor + end_offset]); let end_offset = text[cursor..].find(&self.think_end_token);
cursor += end_offset + self.think_end_token.len(); let tool_offset = self
currently_reasoning = false; .tool_start_token
} else { .as_deref()
// No end token — rest is reasoning (truncated) .and_then(|tok| text[cursor..].find(tok));
reasoning_parts.push(&text[cursor..]);
cursor = text.len(); match (end_offset, tool_offset) {
(Some(e), Some(t)) if t < e => {
// tool_start arrives before </think> — force-exit.
reasoning_parts.push(&text[cursor..cursor + t]);
normal_parts.push(&text[cursor + t..]);
cursor = text.len();
currently_reasoning = false;
}
(Some(e), _) => {
reasoning_parts.push(&text[cursor..cursor + e]);
cursor += e + self.think_end_token.len();
currently_reasoning = false;
}
(None, Some(t)) => {
// No </think> but tool_start is present — force-exit.
reasoning_parts.push(&text[cursor..cursor + t]);
normal_parts.push(&text[cursor + t..]);
cursor = text.len();
currently_reasoning = false;
}
(None, None) => {
// No end token — rest is reasoning (truncated)
reasoning_parts.push(&text[cursor..]);
cursor = text.len();
}
} }
} else { } else {
// We're in normal text — look for start token // We're in normal text — look for start token
...@@ -200,7 +245,29 @@ impl ReasoningParser for BasicReasoningParser { ...@@ -200,7 +245,29 @@ impl ReasoningParser for BasicReasoningParser {
} }
if self._in_reasoning { if self._in_reasoning {
if let Some(end_idx) = current_text.find(self.think_end_token.as_str()) { let end_idx = current_text.find(self.think_end_token.as_str());
let tool_idx = self
.tool_start_token
.as_deref()
.and_then(|tok| current_text.find(tok));
// Prefer whichever marker appears first. If only one is present, use it.
let force_exit_idx = match (end_idx, tool_idx) {
(Some(e), Some(t)) if t < e => Some(t),
(None, Some(t)) => Some(t),
_ => None,
};
if let Some(tool_at) = force_exit_idx {
accumulated_reasoning.push_str(&current_text[..tool_at]);
accumulated_normal.push_str(&current_text[tool_at..]);
self._buffer.clear();
self._in_reasoning = false;
self.stripped_think_start = false;
break;
}
if let Some(end_idx) = end_idx {
// End of reasoning block: accumulate content and transition out. // End of reasoning block: accumulate content and transition out.
accumulated_reasoning.push_str(&current_text[..end_idx]); accumulated_reasoning.push_str(&current_text[..end_idx]);
let after_end = end_idx + self.think_end_token.len(); let after_end = end_idx + self.think_end_token.len();
...@@ -211,8 +278,16 @@ impl ReasoningParser for BasicReasoningParser { ...@@ -211,8 +278,16 @@ impl ReasoningParser for BasicReasoningParser {
} else { } else {
// No complete end token — check for partial at end of buffer // No complete end token — check for partial at end of buffer
// (e.g., "reasoning content</th" where "</th" is a prefix of "</think>"). // (e.g., "reasoning content</th" where "</th" is a prefix of "</think>").
// Partial prefixes of tool_start_token must also be buffered so the
// force-exit marker isn't split into reasoning text.
if self.stream_reasoning { if self.stream_reasoning {
let ol = overlap(&current_text, &self.think_end_token); let ol_end = overlap(&current_text, &self.think_end_token);
let ol_tool = self
.tool_start_token
.as_deref()
.map(|tok| overlap(&current_text, tok))
.unwrap_or(0);
let ol = ol_end.max(ol_tool);
if ol >= 2 { if ol >= 2 {
let safe_end = current_text.len() - ol; let safe_end = current_text.len() - ol;
if safe_end > 0 { if safe_end > 0 {
...@@ -268,6 +343,7 @@ impl ReasoningParser for BasicReasoningParser { ...@@ -268,6 +343,7 @@ impl ReasoningParser for BasicReasoningParser {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use rstest::rstest;
#[test] #[test]
fn test_detect_and_parse_reasoning_reasoning() { fn test_detect_and_parse_reasoning_reasoning() {
...@@ -1053,4 +1129,96 @@ mod tests { ...@@ -1053,4 +1129,96 @@ mod tests {
assert_eq!(overlap("text◁/thi", "◁/think▷"), 7); assert_eq!(overlap("text◁/thi", "◁/think▷"), 7);
assert_eq!(overlap("no match", "◁think▷"), 0); assert_eq!(overlap("no match", "◁think▷"), 0);
} }
fn kimi_k2_parser() -> BasicReasoningParser {
    // Same configuration as the `kimi_k25` registration in reasoning/mod.rs:
    // `<think>` / `</think>` delimiters plus the Kimi tool-section marker.
    let think_open = String::from("<think>");
    let think_close = String::from("</think>");
    let base = BasicReasoningParser::new(think_open, think_close, true, true);
    base.with_tool_start_token(crate::reasoning::KIMI_K2_TOOL_SECTION_BEGIN)
}
#[rstest]
#[case(
    "thinking text <|tool_calls_section_begin|><|tool_call_begin|>functions.foo:0<|tool_call_argument_begin|>{}<|tool_call_end|><|tool_calls_section_end|>",
    "thinking text",
    "<|tool_calls_section_begin|><|tool_call_begin|>functions.foo:0<|tool_call_argument_begin|>{}<|tool_call_end|><|tool_calls_section_end|>"
)]
#[case("r</think>a", "r", "a")]
#[case(
    "reasoning</think>answer <|tool_calls_section_begin|>tc",
    "reasoning",
    "answer <|tool_calls_section_begin|>tc"
)]
fn test_kimi_k2_one_shot_split(
    #[case] input: &str,
    #[case] expected_reasoning: &str,
    #[case] expected_normal: &str,
) {
    // One-shot (non-streaming) parse of the whole input; each case pins how the
    // text is divided between the reasoning and normal outputs.
    let mut kimi = kimi_k2_parser();
    let parsed = kimi.detect_and_parse_reasoning(input, &[]);
    assert_eq!(
        (parsed.reasoning_text.as_str(), parsed.normal_text.as_str()),
        (expected_reasoning, expected_normal)
    );
}
#[test]
fn test_force_exit_streaming_single_chunk() {
    // A single streamed chunk in which the tool section follows the reasoning
    // text directly, with no `</think>` in between.
    let tool_section = "<|tool_calls_section_begin|><|tool_call_begin|>functions.foo:0<|tool_call_argument_begin|>{}<|tool_call_end|><|tool_calls_section_end|>";
    let chunk = format!("thinking text {tool_section}");

    let mut kimi = kimi_k2_parser();
    let out = kimi.parse_reasoning_streaming_incremental(&chunk, &[]);

    // Everything before the marker is reasoning; the marker and the rest are normal text.
    assert_eq!(out.reasoning_text, "thinking text ");
    assert_eq!(out.normal_text, tool_section);
}
#[test]
fn test_force_exit_streaming_split_across_chunks() {
    // (chunk, expected reasoning, expected normal), fed in order to one parser.
    let steps = [
        ("thinking ", "thinking ", ""),
        // Ends with a prefix of the tool marker — that suffix must be buffered.
        ("text <|tool_cal", "text ", ""),
        // Completes the marker: the buffered prefix plus the rest become normal text.
        ("ls_section_begin|>rest", "", "<|tool_calls_section_begin|>rest"),
    ];

    let mut kimi = kimi_k2_parser();
    for (chunk, want_reasoning, want_normal) in steps {
        let out = kimi.parse_reasoning_streaming_incremental(chunk, &[]);
        assert_eq!(out.reasoning_text, want_reasoning);
        assert_eq!(out.normal_text, want_normal);
    }
}
#[test]
fn test_force_exit_partial_marker_resolves_as_non_marker() {
    // Chunk 1 ends with "<|tool_ca", a prefix of the marker, so it must be buffered.
    // Chunk 2 ("xxx") turns the buffered text into "<|tool_caxxx", which is NOT a marker.
    // With force_reasoning=true, that content then flushes as reasoning.
    let mut kimi = kimi_k2_parser();

    let first = kimi.parse_reasoning_streaming_incremental("abc <|tool_ca", &[]);
    assert_eq!(
        (first.reasoning_text.as_str(), first.normal_text.as_str()),
        ("abc ", "")
    );

    let second = kimi.parse_reasoning_streaming_incremental("xxx", &[]);
    assert_eq!(
        (second.reasoning_text.as_str(), second.normal_text.as_str()),
        ("<|tool_caxxx", "")
    );
}
#[test]
fn test_no_tool_start_token_behaves_as_before() {
    // When `with_tool_start_token` is never called, BasicReasoningParser keeps its
    // pre-patch behavior: the marker is just reasoning content.
    let input = "thinking <|tool_calls_section_begin|>stuff";
    let mut plain = BasicReasoningParser::new(
        String::from("<think>"),
        String::from("</think>"),
        true,
        true,
    );

    let out = plain.detect_and_parse_reasoning(input, &[]);

    assert_eq!(out.reasoning_text, input);
    assert!(out.normal_text.is_empty());
}
} }
...@@ -14,6 +14,11 @@ pub use gpt_oss_parser::GptOssReasoningParser; ...@@ -14,6 +14,11 @@ pub use gpt_oss_parser::GptOssReasoningParser;
pub use granite_parser::GraniteReasoningParser; pub use granite_parser::GraniteReasoningParser;
pub use minimax_append_think_parser::MiniMaxAppendThinkParser; pub use minimax_append_think_parser::MiniMaxAppendThinkParser;
/// Kimi-K2/K2.5 tool-call section marker, passed to the reasoning parser as its
/// tool-start token. Shared between the `kimi_k25` reasoning-parser registration and
/// its test fixtures so both stay in sync. Mirrors
/// `KimiK2ParserConfig::default().section_start` in `crate::tool_calling::config`;
/// the value is duplicated here, so update both together if the marker changes.
pub(crate) const KIMI_K2_TOOL_SECTION_BEGIN: &str = "<|tool_calls_section_begin|>";
static REASONING_PARSER_MAP: OnceLock<HashMap<&'static str, ReasoningParserType>> = OnceLock::new(); static REASONING_PARSER_MAP: OnceLock<HashMap<&'static str, ReasoningParserType>> = OnceLock::new();
/// Initialize the global reasoning parser map /// Initialize the global reasoning parser map
...@@ -168,12 +173,10 @@ impl ReasoningParserType { ...@@ -168,12 +173,10 @@ impl ReasoningParserType {
)), )),
}, },
ReasoningParserType::KimiK25 => ReasoningParserWrapper { ReasoningParserType::KimiK25 => ReasoningParserWrapper {
parser: Box::new(BasicReasoningParser::new( parser: Box::new(
"<think>".into(), BasicReasoningParser::new("<think>".into(), "</think>".into(), true, true)
"</think>".into(), .with_tool_start_token(KIMI_K2_TOOL_SECTION_BEGIN),
true, ),
true,
)),
}, },
ReasoningParserType::Mistral => ReasoningParserWrapper { ReasoningParserType::Mistral => ReasoningParserWrapper {
parser: Box::new(BasicReasoningParser::new( parser: Box::new(BasicReasoningParser::new(
......
...@@ -21,12 +21,13 @@ static TOOL_CALL_REGEX: OnceLock<Regex> = OnceLock::new(); ...@@ -21,12 +21,13 @@ static TOOL_CALL_REGEX: OnceLock<Regex> = OnceLock::new();
/// `arguments` (JSON object) between the configured `call_start`, `argument_begin`, and /// `arguments` (JSON object) between the configured `call_start`, `argument_begin`, and
/// `call_end` tokens. /// `call_end` tokens.
/// ///
/// The `function_id` pattern `[\w.]+:\d+` matches the `functions.name:index` format used by /// The `function_id` pattern `[\w.\-]+:\d+` matches the `functions.name:index` format used by
/// Kimi K2, consistent with sglang/vllm reference implementations. /// Kimi K2, consistent with sglang's reference implementation. The hyphen is included to
/// support function names with dashes (common in MCP tools, e.g. `mcp__portal__search-documents`).
fn get_tool_call_regex(config: &KimiK2ParserConfig) -> &'static Regex { fn get_tool_call_regex(config: &KimiK2ParserConfig) -> &'static Regex {
TOOL_CALL_REGEX.get_or_init(|| { TOOL_CALL_REGEX.get_or_init(|| {
let pattern = format!( let pattern = format!(
r"(?s){}\s*(?P<function_id>[\w.]+:\d+)\s*{}\s*(?P<arguments>\{{.*?\}})\s*{}", r"(?s){}\s*(?P<function_id>[\w.\-]+:\d+)\s*{}\s*(?P<arguments>\{{.*?\}})\s*{}",
regex::escape(&config.call_start), regex::escape(&config.call_start),
regex::escape(&config.argument_begin), regex::escape(&config.argument_begin),
regex::escape(&config.call_end), regex::escape(&config.call_end),
...@@ -37,7 +38,7 @@ fn get_tool_call_regex(config: &KimiK2ParserConfig) -> &'static Regex { ...@@ -37,7 +38,7 @@ fn get_tool_call_regex(config: &KimiK2ParserConfig) -> &'static Regex {
fn get_id_regex() -> &'static Regex { fn get_id_regex() -> &'static Regex {
ID_REGEX.get_or_init(|| { ID_REGEX.get_or_init(|| {
Regex::new(r"^(?:functions\.)?(?P<name>[\w\.]+):(?P<index>\d+)$") Regex::new(r"^(?:functions\.)?(?P<name>[\w.\-]+):(?P<index>\d+)$")
.expect("Failed to compile kimi k2 id regex") .expect("Failed to compile kimi k2 id regex")
}) })
} }
...@@ -300,6 +301,7 @@ fn parse_section_block( ...@@ -300,6 +301,7 @@ fn parse_section_block(
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use rstest::rstest;
fn default_config() -> KimiK2ParserConfig { fn default_config() -> KimiK2ParserConfig {
KimiK2ParserConfig::default() KimiK2ParserConfig::default()
...@@ -630,7 +632,7 @@ mod tests { ...@@ -630,7 +632,7 @@ mod tests {
fn test_parse_invalid_function_id_rejected_by_regex() { fn test_parse_invalid_function_id_rejected_by_regex() {
// vllm: test_extract_tool_calls_invalid_funcall // vllm: test_extract_tool_calls_invalid_funcall
// sglang: test_invalid_tool_call // sglang: test_invalid_tool_call
// After C2 fix, function_id regex requires [\w.]+:\d+ — IDs without :digit are rejected // function_id regex requires [\w.\-]+:\d+ — IDs without :digit are rejected
let config = default_config(); let config = default_config();
// No colon+digit suffix at all // No colon+digit suffix at all
...@@ -734,4 +736,28 @@ mod tests { ...@@ -734,4 +736,28 @@ mod tests {
"Text around empty section should be preserved" "Text around empty section should be preserved"
); );
} }
#[rstest]
#[case(
    r#"<|tool_calls_section_begin|><|tool_call_begin|>functions.list-tasklists:0<|tool_call_argument_begin|>{}<|tool_call_end|><|tool_calls_section_end|>"#,
    "list-tasklists",
    "functions.list-tasklists:0"
)]
#[case(
    r#"<|tool_calls_section_begin|><|tool_call_begin|>functions.mcp__portal__search-documents:3<|tool_call_argument_begin|>{}<|tool_call_end|><|tool_calls_section_end|>"#,
    "mcp__portal__search-documents",
    "functions.mcp__portal__search-documents:3"
)]
#[case(
    r#"<|tool_calls_section_begin|><|tool_call_begin|>functions.gtasks_list-tasklists:0<|tool_call_argument_begin|>{}<|tool_call_end|><|tool_calls_section_end|>"#,
    "gtasks_list-tasklists",
    "functions.gtasks_list-tasklists:0"
)]
fn test_parse_names_with_hyphens(#[case] input: &str, #[case] name: &str, #[case] id: &str) {
    // Each case holds a single tool call whose function name contains a hyphen;
    // the parsed call must report the expected name and the expected id.
    let config = default_config();
    let (calls, _normal) = try_tool_call_parse_kimi_k2(input, &config, None).unwrap();

    assert_eq!(calls.len(), 1);
    let only_call = &calls[0];
    assert_eq!(only_call.function.name, name);
    assert_eq!(only_call.id, id);
}
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment