Unverified Commit 466992b2 authored by Chang Su's avatar Chang Su Committed by GitHub
Browse files

[router][tool call] Clean up redundant `detect_format` and `has_tool_markers` (#11270)

parent 155cbb51
...@@ -1859,7 +1859,7 @@ impl GrpcPDRouter { ...@@ -1859,7 +1859,7 @@ impl GrpcPDRouter {
// Check format detection first // Check format detection first
let can_parse = { let can_parse = {
let parser = pooled_parser.lock().await; let parser = pooled_parser.lock().await;
parser.detect_format(processed_text) parser.has_tool_markers(processed_text)
// Lock is dropped here // Lock is dropped here
}; };
......
...@@ -306,7 +306,7 @@ impl GrpcRouter { ...@@ -306,7 +306,7 @@ impl GrpcRouter {
// Check format detection first // Check format detection first
let can_parse = { let can_parse = {
let parser = pooled_parser.lock().await; let parser = pooled_parser.lock().await;
parser.detect_format(processed_text) parser.has_tool_markers(processed_text)
// Lock is dropped here // Lock is dropped here
}; };
......
...@@ -77,11 +77,6 @@ impl DeepSeekParser { ...@@ -77,11 +77,6 @@ impl DeepSeekParser {
} }
} }
/// Check if text contains DeepSeek tool markers
fn has_tool_markers(&self, text: &str) -> bool {
text.contains("<|tool▁calls▁begin|>")
}
/// Parse a single tool call block - throws error if parsing fails /// Parse a single tool call block - throws error if parsing fails
fn parse_tool_call(&self, block: &str) -> ToolParserResult<ToolCall> { fn parse_tool_call(&self, block: &str) -> ToolParserResult<ToolCall> {
let captures = self.func_detail_extractor.captures(block).ok_or_else(|| { let captures = self.func_detail_extractor.captures(block).ok_or_else(|| {
...@@ -312,8 +307,8 @@ impl ToolParser for DeepSeekParser { ...@@ -312,8 +307,8 @@ impl ToolParser for DeepSeekParser {
}) })
} }
fn detect_format(&self, text: &str) -> bool { fn has_tool_markers(&self, text: &str) -> bool {
self.has_tool_markers(text) text.contains("<|tool▁calls▁begin|>")
} }
fn get_unstreamed_tool_args(&self) -> Option<Vec<ToolCallItem>> { fn get_unstreamed_tool_args(&self) -> Option<Vec<ToolCallItem>> {
......
...@@ -71,11 +71,6 @@ impl Glm4MoeParser { ...@@ -71,11 +71,6 @@ impl Glm4MoeParser {
} }
} }
/// Check if text contains GLM-4 MoE tool markers
fn has_tool_markers(&self, text: &str) -> bool {
text.contains(self.bot_token)
}
/// Parse arguments from key-value pairs /// Parse arguments from key-value pairs
fn parse_arguments(&self, args_text: &str) -> ToolParserResult<serde_json::Map<String, Value>> { fn parse_arguments(&self, args_text: &str) -> ToolParserResult<serde_json::Map<String, Value>> {
let mut arguments = serde_json::Map::new(); let mut arguments = serde_json::Map::new();
...@@ -313,8 +308,8 @@ impl ToolParser for Glm4MoeParser { ...@@ -313,8 +308,8 @@ impl ToolParser for Glm4MoeParser {
}) })
} }
fn detect_format(&self, text: &str) -> bool { fn has_tool_markers(&self, text: &str) -> bool {
self.has_tool_markers(text) text.contains(self.bot_token)
} }
fn get_unstreamed_tool_args(&self) -> Option<Vec<ToolCallItem>> { fn get_unstreamed_tool_args(&self) -> Option<Vec<ToolCallItem>> {
......
...@@ -38,7 +38,7 @@ impl ToolParser for GptOssHarmonyParser { ...@@ -38,7 +38,7 @@ impl ToolParser for GptOssHarmonyParser {
Ok(StreamingParseResult::default()) Ok(StreamingParseResult::default())
} }
fn detect_format(&self, text: &str) -> bool { fn has_tool_markers(&self, text: &str) -> bool {
// Reuse the legacy heuristics for now; this will be replaced with Harmony-specific // Reuse the legacy heuristics for now; this will be replaced with Harmony-specific
// start-token detection when the parser is fully implemented. // start-token detection when the parser is fully implemented.
text.contains("<|channel|>commentary") text.contains("<|channel|>commentary")
......
...@@ -58,11 +58,6 @@ impl GptOssParser { ...@@ -58,11 +58,6 @@ impl GptOssParser {
} }
} }
/// Check if text contains GPT-OSS tool markers
fn has_tool_markers(&self, text: &str) -> bool {
text.contains("<|channel|>commentary to=")
}
/// Extract function name from full namespace (e.g., "functions.get_weather" -> "get_weather") /// Extract function name from full namespace (e.g., "functions.get_weather" -> "get_weather")
fn extract_function_name(&self, full_name: &str) -> String { fn extract_function_name(&self, full_name: &str) -> String {
if let Some(dot_pos) = full_name.rfind('.') { if let Some(dot_pos) = full_name.rfind('.') {
...@@ -242,7 +237,7 @@ impl ToolParser for GptOssParser { ...@@ -242,7 +237,7 @@ impl ToolParser for GptOssParser {
Ok(StreamingParseResult::default()) Ok(StreamingParseResult::default())
} }
fn detect_format(&self, text: &str) -> bool { fn has_tool_markers(&self, text: &str) -> bool {
self.has_tool_markers(text) || text.contains("<|channel|>commentary") text.contains("<|channel|>commentary")
} }
} }
...@@ -261,7 +261,7 @@ impl ToolParser for JsonParser { ...@@ -261,7 +261,7 @@ impl ToolParser for JsonParser {
) )
} }
fn detect_format(&self, text: &str) -> bool { fn has_tool_markers(&self, text: &str) -> bool {
let trimmed = text.trim(); let trimmed = text.trim();
(trimmed.starts_with('[') || trimmed.starts_with('{')) && trimmed.contains(r#""name""#) (trimmed.starts_with('[') || trimmed.starts_with('{')) && trimmed.contains(r#""name""#)
} }
......
...@@ -82,11 +82,6 @@ impl KimiK2Parser { ...@@ -82,11 +82,6 @@ impl KimiK2Parser {
} }
} }
/// Check if text contains Kimi K2 tool markers
fn has_tool_markers(&self, text: &str) -> bool {
text.contains("<|tool_calls_section_begin|>")
}
/// Parse function ID to extract name and index /// Parse function ID to extract name and index
fn parse_function_id(&self, id: &str) -> Option<(String, usize)> { fn parse_function_id(&self, id: &str) -> Option<(String, usize)> {
if let Some(captures) = self.tool_call_id_regex.captures(id) { if let Some(captures) = self.tool_call_id_regex.captures(id) {
...@@ -331,8 +326,8 @@ impl ToolParser for KimiK2Parser { ...@@ -331,8 +326,8 @@ impl ToolParser for KimiK2Parser {
}) })
} }
fn detect_format(&self, text: &str) -> bool { fn has_tool_markers(&self, text: &str) -> bool {
self.has_tool_markers(text) || text.contains("<|tool_call_begin|>") text.contains("<|tool_calls_section_begin|>")
} }
fn get_unstreamed_tool_args(&self) -> Option<Vec<ToolCallItem>> { fn get_unstreamed_tool_args(&self) -> Option<Vec<ToolCallItem>> {
......
...@@ -228,7 +228,7 @@ impl ToolParser for LlamaParser { ...@@ -228,7 +228,7 @@ impl ToolParser for LlamaParser {
) )
} }
fn detect_format(&self, text: &str) -> bool { fn has_tool_markers(&self, text: &str) -> bool {
// Llama format if contains python_tag or starts with JSON object // Llama format if contains python_tag or starts with JSON object
text.contains("<|python_tag|>") text.contains("<|python_tag|>")
|| (text.trim_start().starts_with('{') && text.contains(r#""name""#)) || (text.trim_start().starts_with('{') && text.contains(r#""name""#))
......
...@@ -156,11 +156,6 @@ impl MistralParser { ...@@ -156,11 +156,6 @@ impl MistralParser {
Ok(None) Ok(None)
} }
} }
/// Check if text contains Mistral tool markers
fn has_tool_markers(&self, text: &str) -> bool {
text.contains("[TOOL_CALLS]")
}
} }
impl Default for MistralParser { impl Default for MistralParser {
...@@ -254,8 +249,8 @@ impl ToolParser for MistralParser { ...@@ -254,8 +249,8 @@ impl ToolParser for MistralParser {
) )
} }
fn detect_format(&self, text: &str) -> bool { fn has_tool_markers(&self, text: &str) -> bool {
self.has_tool_markers(text) text.contains("[TOOL_CALLS]")
} }
fn get_unstreamed_tool_args(&self) -> Option<Vec<crate::tool_parser::types::ToolCallItem>> { fn get_unstreamed_tool_args(&self) -> Option<Vec<crate::tool_parser::types::ToolCallItem>> {
......
...@@ -203,7 +203,7 @@ impl ToolParser for PythonicParser { ...@@ -203,7 +203,7 @@ impl ToolParser for PythonicParser {
}) })
} }
fn detect_format(&self, text: &str) -> bool { fn has_tool_markers(&self, text: &str) -> bool {
let cleaned = Self::strip_special_tokens(text); let cleaned = Self::strip_special_tokens(text);
if pythonic_block_regex().is_match(&cleaned) { if pythonic_block_regex().is_match(&cleaned) {
return true; return true;
......
...@@ -98,16 +98,6 @@ impl QwenParser { ...@@ -98,16 +98,6 @@ impl QwenParser {
Ok(None) Ok(None)
} }
} }
/// Check if text contains Qwen tool markers
fn has_tool_markers(&self, text: &str) -> bool {
text.contains("<tool_call>")
}
/// Check if text has tool call
fn has_tool_call(&self, text: &str) -> bool {
text.contains("<tool_call>")
}
} }
impl Default for QwenParser { impl Default for QwenParser {
...@@ -165,7 +155,7 @@ impl ToolParser for QwenParser { ...@@ -165,7 +155,7 @@ impl ToolParser for QwenParser {
let current_text = &self.buffer.clone(); let current_text = &self.buffer.clone();
// Check if current_text has tool_call // Check if current_text has tool_call
let has_tool_start = self.has_tool_call(current_text) let has_tool_start = self.has_tool_markers(current_text)
|| (self.current_tool_id >= 0 && current_text.starts_with(self.tool_call_separator)); || (self.current_tool_id >= 0 && current_text.starts_with(self.tool_call_separator));
if !has_tool_start { if !has_tool_start {
...@@ -243,8 +233,8 @@ impl ToolParser for QwenParser { ...@@ -243,8 +233,8 @@ impl ToolParser for QwenParser {
Ok(result) Ok(result)
} }
fn detect_format(&self, text: &str) -> bool { fn has_tool_markers(&self, text: &str) -> bool {
self.has_tool_markers(text) text.contains("<tool_call>")
} }
fn get_unstreamed_tool_args(&self) -> Option<Vec<crate::tool_parser::types::ToolCallItem>> { fn get_unstreamed_tool_args(&self) -> Option<Vec<crate::tool_parser::types::ToolCallItem>> {
......
...@@ -96,11 +96,6 @@ impl Step3Parser { ...@@ -96,11 +96,6 @@ impl Step3Parser {
} }
} }
/// Check if text contains Step3 tool markers
fn has_tool_markers(&self, text: &str) -> bool {
text.contains(self.bot_token)
}
/// Reset streaming state for the next tool call /// Reset streaming state for the next tool call
fn reset_streaming_state(&mut self) { fn reset_streaming_state(&mut self) {
self.in_tool_call = false; self.in_tool_call = false;
...@@ -553,8 +548,8 @@ impl ToolParser for Step3Parser { ...@@ -553,8 +548,8 @@ impl ToolParser for Step3Parser {
Ok(StreamingParseResult::default()) Ok(StreamingParseResult::default())
} }
fn detect_format(&self, text: &str) -> bool { fn has_tool_markers(&self, text: &str) -> bool {
self.has_tool_markers(text) text.contains(self.bot_token)
} }
fn get_unstreamed_tool_args(&self) -> Option<Vec<ToolCallItem>> { fn get_unstreamed_tool_args(&self) -> Option<Vec<ToolCallItem>> {
......
...@@ -12,7 +12,7 @@ async fn test_tool_parser_factory() { ...@@ -12,7 +12,7 @@ async fn test_tool_parser_factory() {
// Test that we can get a pooled parser // Test that we can get a pooled parser
let pooled_parser = factory.get_pooled("gpt-4"); let pooled_parser = factory.get_pooled("gpt-4");
let parser = pooled_parser.lock().await; let parser = pooled_parser.lock().await;
assert!(parser.detect_format(r#"{"name": "test", "arguments": {}}"#)); assert!(parser.has_tool_markers(r#"{"name": "test", "arguments": {}}"#));
} }
#[tokio::test] #[tokio::test]
...@@ -25,7 +25,7 @@ async fn test_tool_parser_factory_model_mapping() { ...@@ -25,7 +25,7 @@ async fn test_tool_parser_factory_model_mapping() {
// Get parser for the test model // Get parser for the test model
let pooled_parser = factory.get_pooled("test-model"); let pooled_parser = factory.get_pooled("test-model");
let parser = pooled_parser.lock().await; let parser = pooled_parser.lock().await;
assert!(parser.detect_format(r#"{"name": "test", "arguments": {}}"#)); assert!(parser.has_tool_markers(r#"{"name": "test", "arguments": {}}"#));
} }
#[test] #[test]
...@@ -234,12 +234,12 @@ fn test_json_parser_format_detection() { ...@@ -234,12 +234,12 @@ fn test_json_parser_format_detection() {
let parser = JsonParser::new(); let parser = JsonParser::new();
// Should detect valid tool call formats // Should detect valid tool call formats
assert!(parser.detect_format(r#"{"name": "test", "arguments": {}}"#)); assert!(parser.has_tool_markers(r#"{"name": "test", "arguments": {}}"#));
assert!(parser.detect_format(r#"{"name": "test", "parameters": {"x": 1}}"#)); assert!(parser.has_tool_markers(r#"{"name": "test", "parameters": {"x": 1}}"#));
assert!(parser.detect_format(r#"[{"name": "test"}]"#)); assert!(parser.has_tool_markers(r#"[{"name": "test"}]"#));
// Should not detect non-tool formats // Should not detect non-tool formats
assert!(!parser.detect_format("plain text")); assert!(!parser.has_tool_markers("plain text"));
} }
#[tokio::test] #[tokio::test]
......
...@@ -25,7 +25,7 @@ pub trait ToolParser: Send + Sync { ...@@ -25,7 +25,7 @@ pub trait ToolParser: Send + Sync {
) -> ToolParserResult<StreamingParseResult>; ) -> ToolParserResult<StreamingParseResult>;
/// Check if text contains tool calls in this parser's format /// Check if text contains tool calls in this parser's format
fn detect_format(&self, text: &str) -> bool; fn has_tool_markers(&self, text: &str) -> bool;
/// Optionally expose a token-aware parser implementation. /// Optionally expose a token-aware parser implementation.
/// Default returns `None`, meaning the parser only supports text input. /// Default returns `None`, meaning the parser only supports text input.
......
...@@ -108,13 +108,13 @@ fn test_deepseek_format_detection() { ...@@ -108,13 +108,13 @@ fn test_deepseek_format_detection() {
let parser = DeepSeekParser::new(); let parser = DeepSeekParser::new();
// Should detect DeepSeek format // Should detect DeepSeek format
assert!(parser.detect_format("<|tool▁calls▁begin|>")); assert!(parser.has_tool_markers("<|tool▁calls▁begin|>"));
assert!(parser.detect_format("text with <|tool▁calls▁begin|> marker")); assert!(parser.has_tool_markers("text with <|tool▁calls▁begin|> marker"));
// Should not detect other formats // Should not detect other formats
assert!(!parser.detect_format("[TOOL_CALLS]")); assert!(!parser.has_tool_markers("[TOOL_CALLS]"));
assert!(!parser.detect_format("<tool_call>")); assert!(!parser.has_tool_markers("<tool_call>"));
assert!(!parser.detect_format("plain text")); assert!(!parser.has_tool_markers("plain text"));
} }
#[tokio::test] #[tokio::test]
......
...@@ -117,13 +117,13 @@ fn test_glm4_format_detection() { ...@@ -117,13 +117,13 @@ fn test_glm4_format_detection() {
let parser = Glm4MoeParser::new(); let parser = Glm4MoeParser::new();
// Should detect GLM-4 format // Should detect GLM-4 format
assert!(parser.detect_format("<tool_call>")); assert!(parser.has_tool_markers("<tool_call>"));
assert!(parser.detect_format("text with <tool_call> marker")); assert!(parser.has_tool_markers("text with <tool_call> marker"));
// Should not detect other formats // Should not detect other formats
assert!(!parser.detect_format("[TOOL_CALLS]")); assert!(!parser.has_tool_markers("[TOOL_CALLS]"));
assert!(!parser.detect_format("<|tool▁calls▁begin|>")); assert!(!parser.has_tool_markers("<|tool▁calls▁begin|>"));
assert!(!parser.detect_format("plain text")); assert!(!parser.has_tool_markers("plain text"));
} }
#[tokio::test] #[tokio::test]
......
...@@ -109,14 +109,14 @@ fn test_gpt_oss_format_detection() { ...@@ -109,14 +109,14 @@ fn test_gpt_oss_format_detection() {
let parser = GptOssParser::new(); let parser = GptOssParser::new();
// Should detect GPT-OSS format // Should detect GPT-OSS format
assert!(parser.detect_format("<|channel|>commentary to=")); assert!(parser.has_tool_markers("<|channel|>commentary to="));
assert!(parser.detect_format("<|channel|>commentary")); assert!(parser.has_tool_markers("<|channel|>commentary"));
assert!(parser.detect_format("text with <|channel|>commentary to= marker")); assert!(parser.has_tool_markers("text with <|channel|>commentary to= marker"));
// Should not detect other formats // Should not detect other formats
assert!(!parser.detect_format("[TOOL_CALLS]")); assert!(!parser.has_tool_markers("[TOOL_CALLS]"));
assert!(!parser.detect_format("<tool_call>")); assert!(!parser.has_tool_markers("<tool_call>"));
assert!(!parser.detect_format("plain text")); assert!(!parser.has_tool_markers("plain text"));
} }
#[tokio::test] #[tokio::test]
......
...@@ -155,7 +155,7 @@ async fn test_json_invalid_format() { ...@@ -155,7 +155,7 @@ async fn test_json_invalid_format() {
async fn test_json_format_detection() { async fn test_json_format_detection() {
let parser = JsonParser::new(); let parser = JsonParser::new();
assert!(parser.detect_format(r#"{"name": "test", "arguments": {}}"#)); assert!(parser.has_tool_markers(r#"{"name": "test", "arguments": {}}"#));
assert!(parser.detect_format(r#"[{"name": "test"}]"#)); assert!(parser.has_tool_markers(r#"[{"name": "test"}]"#));
assert!(!parser.detect_format("plain text")); assert!(!parser.has_tool_markers("plain text"));
} }
...@@ -98,14 +98,13 @@ fn test_kimik2_format_detection() { ...@@ -98,14 +98,13 @@ fn test_kimik2_format_detection() {
let parser = KimiK2Parser::new(); let parser = KimiK2Parser::new();
// Should detect Kimi K2 format // Should detect Kimi K2 format
assert!(parser.detect_format("<|tool_calls_section_begin|>")); assert!(parser.has_tool_markers("<|tool_calls_section_begin|>"));
assert!(parser.detect_format("<|tool_call_begin|>")); assert!(parser.has_tool_markers("text with <|tool_calls_section_begin|> marker"));
assert!(parser.detect_format("text with <|tool_calls_section_begin|> marker"));
// Should not detect other formats // Should not detect other formats
assert!(!parser.detect_format("[TOOL_CALLS]")); assert!(!parser.has_tool_markers("[TOOL_CALLS]"));
assert!(!parser.detect_format("<tool_call>")); assert!(!parser.has_tool_markers("<tool_call>"));
assert!(!parser.detect_format("plain text")); assert!(!parser.has_tool_markers("plain text"));
} }
#[tokio::test] #[tokio::test]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment