Unverified Commit 6e4e1c8c authored by Simo Lin's avatar Simo Lin Committed by GitHub
Browse files

[router] add deepseek tool parser (#9694)


Co-authored-by: default avatarChang Su <chang.s.su@oracle.com>
parent 9768c50d
...@@ -24,4 +24,6 @@ pub use traits::{PartialJsonParser, ToolParser}; ...@@ -24,4 +24,6 @@ pub use traits::{PartialJsonParser, ToolParser};
pub use types::{FunctionCall, PartialToolCall, StreamResult, TokenConfig, ToolCall}; pub use types::{FunctionCall, PartialToolCall, StreamResult, TokenConfig, ToolCall};
// Re-export parsers for convenience // Re-export parsers for convenience
pub use parsers::{JsonParser, LlamaParser, MistralParser, PythonicParser, QwenParser}; pub use parsers::{
DeepSeekParser, JsonParser, LlamaParser, MistralParser, PythonicParser, QwenParser,
};
use async_trait::async_trait;
use regex::Regex;
use serde_json::Value;
use crate::tool_parser::{
errors::{ToolParserError, ToolParserResult},
partial_json::PartialJson,
state::ParseState,
traits::ToolParser,
types::{FunctionCall, StreamResult, ToolCall},
};
/// DeepSeek V3 format parser for tool calls
///
/// Handles the DeepSeek V3 specific format that uses Unicode tokens:
/// `<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>{name}\n```json\n{args}\n```<|tool▁call▁end|><|tool▁calls▁end|>`
///
/// Features:
/// - Unicode token delimiters
/// - JSON arguments in code blocks
/// - Support for multiple sequential tool calls
pub struct DeepSeekParser {
    /// Parser for handling incomplete JSON during streaming
    partial_json: PartialJson,
    /// Regex for extracting complete tool calls (one match per
    /// `<|tool▁call▁begin|>...<|tool▁call▁end|>` block)
    tool_call_extractor: Regex,
    /// Regex for extracting function details (call type, name, JSON args)
    /// from a single extracted block
    func_detail_extractor: Regex,
}
impl DeepSeekParser {
    /// Create a new DeepSeek parser.
    ///
    /// # Panics
    /// Panics only if the embedded regex patterns fail to compile, which
    /// would be a programming error in the patterns themselves.
    pub fn new() -> Self {
        // Use (?s) flag for DOTALL mode so `.` also matches the newlines
        // inside the ```json code block.
        //
        // BUG FIX: the ASCII `|` characters inside the delimiter tokens are
        // regex alternation metacharacters and must be escaped (`\|`).
        // Unescaped, `<|tool▁call▁begin|>` parses as the alternatives
        // `<` / `tool▁call▁begin` / `>...` etc., so the extractor matched
        // stray `<` characters and never a real tool-call block.
        let tool_call_pattern = r"(?s)<\|tool▁call▁begin\|>.*?<\|tool▁call▁end\|>";
        let tool_call_extractor = Regex::new(tool_call_pattern).expect("Valid regex pattern");

        let func_detail_pattern = r"(?s)<\|tool▁call▁begin\|>(.*?)<\|tool▁sep\|>(.*?)\n```json\n(.*?)\n```<\|tool▁call▁end\|>";
        let func_detail_extractor = Regex::new(func_detail_pattern).expect("Valid regex pattern");

        Self {
            partial_json: PartialJson::default(),
            tool_call_extractor,
            func_detail_extractor,
        }
    }

    /// Check if text contains the DeepSeek tool-calls opening marker.
    fn has_tool_markers(&self, text: &str) -> bool {
        text.contains("<|tool▁calls▁begin|>")
    }

    /// Extract all complete `<|tool▁call▁begin|>...<|tool▁call▁end|>`
    /// blocks from text, in order of appearance.
    fn extract_tool_calls<'a>(&self, text: &'a str) -> Vec<&'a str> {
        self.tool_call_extractor
            .find_iter(text)
            .map(|m| m.as_str())
            .collect()
    }

    /// Parse a single tool call block.
    ///
    /// Returns `Ok(None)` when the block does not match the expected layout,
    /// declares a non-`function` call type, or carries JSON that fails to
    /// parse — malformed calls are skipped rather than treated as errors.
    fn parse_tool_call(&self, block: &str) -> ToolParserResult<Option<ToolCall>> {
        let Some(captures) = self.func_detail_extractor.captures(block) else {
            return Ok(None);
        };

        // Capture 1 is the call-type tag; only plain "function" is supported.
        let func_type = captures.get(1).map_or("", |m| m.as_str());
        if func_type != "function" {
            return Ok(None);
        }

        // Capture 2: function name; capture 3: JSON argument payload.
        let func_name = captures.get(2).map_or("", |m| m.as_str()).trim();
        let json_args = captures.get(3).map_or("{}", |m| m.as_str()).trim();

        match serde_json::from_str::<Value>(json_args) {
            Ok(value) => {
                // Arguments are expected to be a JSON object; wrap any other
                // value so the serialized form is always an object.
                let args = if value.is_object() {
                    value
                } else {
                    serde_json::json!({ "value": value })
                };
                let arguments = serde_json::to_string(&args)
                    .map_err(|e| ToolParserError::ParsingFailed(e.to_string()))?;

                // Unique ID so downstream consumers can correlate responses.
                let id = format!("deepseek_call_{}", uuid::Uuid::new_v4());

                Ok(Some(ToolCall {
                    id,
                    r#type: "function".to_string(),
                    function: FunctionCall {
                        name: func_name.to_string(),
                        arguments,
                    },
                }))
            }
            // Malformed JSON: skip this call instead of failing the batch.
            Err(_) => Ok(None),
        }
    }
}
impl Default for DeepSeekParser {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl ToolParser for DeepSeekParser {
    /// Parse a fully buffered response and extract every complete tool call.
    ///
    /// Returns an empty vector when the text contains no DeepSeek markers.
    /// Blocks that fail to parse (bad JSON, non-`function` type) are
    /// silently skipped by `parse_tool_call`, not treated as hard errors.
    async fn parse_complete(&self, text: &str) -> ToolParserResult<Vec<ToolCall>> {
        // Check if text contains DeepSeek format
        if !self.has_tool_markers(text) {
            return Ok(vec![]);
        }

        // Extract all tool call blocks
        let tool_blocks = self.extract_tool_calls(text);

        let mut tools = Vec::new();
        for block in tool_blocks {
            if let Some(tool) = self.parse_tool_call(block)? {
                tools.push(tool);
            }
        }

        Ok(tools)
    }

    /// Incrementally parse streamed chunks.
    ///
    /// Chunks accumulate in `state.buffer`. Once a complete
    /// `<|tool▁call▁begin|>...<|tool▁call▁end|>` span is buffered it is
    /// parsed, drained from the buffer, and returned as `ToolComplete`.
    /// Before that, partial information is surfaced: the function name once
    /// (`state.in_string` doubles as the "name already emitted" flag), then
    /// best-effort partial arguments recovered via the partial-JSON parser.
    ///
    /// NOTE(review): `state.in_string` is never reset after a call
    /// completes, so a second call in the same stream may skip its
    /// `ToolName` event — confirm this is intended.
    async fn parse_incremental(
        &self,
        chunk: &str,
        state: &mut ParseState,
    ) -> ToolParserResult<StreamResult> {
        state.buffer.push_str(chunk);

        // Check for tool markers; partial markers keep buffering.
        if !self.has_tool_markers(&state.buffer) {
            // No markers found, return as incomplete
            return Ok(StreamResult::Incomplete);
        }

        // Look for start of tool calls
        if let Some(start_pos) = state.buffer.find("<|tool▁calls▁begin|>") {
            // Look for individual tool call start
            let search_from = start_pos + "<|tool▁calls▁begin|>".len();

            if let Some(call_start) = state.buffer[search_from..].find("<|tool▁call▁begin|>")
            {
                let call_start_abs = search_from + call_start;

                // Look for the end of this tool call
                let search_end_from = call_start_abs + "<|tool▁call▁begin|>".len();
                if let Some(call_end) = state.buffer[search_end_from..].find("<|tool▁call▁end|>")
                {
                    // End offset is inclusive of the closing marker.
                    let call_end_abs = search_end_from + call_end + "<|tool▁call▁end|>".len();

                    // Extract and parse the complete tool call
                    let tool_call_text = &state.buffer[call_start_abs..call_end_abs];
                    if let Some(tool) = self.parse_tool_call(tool_call_text)? {
                        // Remove the processed part from buffer
                        state.buffer.drain(..call_end_abs);
                        return Ok(StreamResult::ToolComplete(tool));
                    }
                } else {
                    // Tool call not complete yet, try to extract partial info
                    let partial = &state.buffer[search_end_from..];

                    // Try to extract function name
                    if let Some(sep_pos) = partial.find("<|tool▁sep|>") {
                        if let Some(_func_start) = partial[..sep_pos].rfind("function") {
                            // We have the function type marker
                            let after_sep = &partial[sep_pos + "<|tool▁sep|>".len()..];

                            // Look for function name (ends at newline before ```json)
                            if let Some(name_end) = after_sep.find("\n```json\n") {
                                let func_name = after_sep[..name_end].trim();

                                // Emit the name exactly once per stream.
                                if !state.in_string {
                                    state.in_string = true; // Mark name as sent
                                    return Ok(StreamResult::ToolName {
                                        index: 0,
                                        name: func_name.to_string(),
                                    });
                                }

                                // Try to extract partial arguments
                                let args_start = name_end + "\n```json\n".len();
                                let partial_args = &after_sep[args_start..];

                                // Check if we can parse partial JSON.
                                // NOTE(review): emitted arguments are cumulative
                                // snapshots, not deltas — confirm downstream
                                // consumers expect that.
                                if !partial_args.is_empty() {
                                    match self.partial_json.parse_value(partial_args) {
                                        Ok((value, _consumed)) => {
                                            let args_str = serde_json::to_string(&value)
                                                .unwrap_or_else(|_| "{}".to_string());
                                            return Ok(StreamResult::ToolArguments {
                                                index: 0,
                                                arguments: args_str,
                                            });
                                        }
                                        Err(_) => {
                                            // Can't parse yet, keep buffering
                                        }
                                    }
                                }
                            }
                        }
                    }
                }
            }
        }

        Ok(StreamResult::Incomplete)
    }

    /// Quick format probe: true if the DeepSeek opening marker is present.
    fn detect_format(&self, text: &str) -> bool {
        self.has_tool_markers(text)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A single well-formed call surrounded by free text yields exactly
    /// one parsed tool call with the expected name and arguments.
    #[tokio::test]
    async fn test_parse_deepseek_single_tool() {
        let parser = DeepSeekParser::new();
        let input = r#"Some text
<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>get_weather
```json
{"location": "Tokyo", "units": "celsius"}
```<|tool▁call▁end|><|tool▁calls▁end|>More text"#;

        let result = parser.parse_complete(input).await.unwrap();
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].function.name, "get_weather");
        assert!(result[0].function.arguments.contains("Tokyo"));
    }

    /// Two sequential call blocks inside one calls-begin/end envelope are
    /// both extracted, in order.
    #[tokio::test]
    async fn test_parse_deepseek_multiple_tools() {
        let parser = DeepSeekParser::new();
        let input = r#"<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>get_weather
```json
{"location": "Tokyo"}
```<|tool▁call▁end|>
<|tool▁call▁begin|>function<|tool▁sep|>get_weather
```json
{"location": "Paris"}
```<|tool▁call▁end|><|tool▁calls▁end|>"#;

        let result = parser.parse_complete(input).await.unwrap();
        assert_eq!(result.len(), 2);
        assert_eq!(result[0].function.name, "get_weather");
        assert_eq!(result[1].function.name, "get_weather");
        assert!(result[0].function.arguments.contains("Tokyo"));
        assert!(result[1].function.arguments.contains("Paris"));
    }

    /// Detection fires only on the DeepSeek opening marker, not on other
    /// model families' markers or plain text.
    #[test]
    fn test_detect_format() {
        let parser = DeepSeekParser::new();
        assert!(parser.detect_format("<|tool▁calls▁begin|>"));
        assert!(!parser.detect_format("plain text"));
        assert!(!parser.detect_format("[TOOL_CALLS]"));
    }
}
...@@ -3,12 +3,16 @@ ...@@ -3,12 +3,16 @@
/// This module contains concrete parser implementations for various model-specific /// This module contains concrete parser implementations for various model-specific
/// tool/function call formats. /// tool/function call formats.
// Individual parser modules // Individual parser modules
pub mod deepseek_parser;
pub mod json_parser; pub mod json_parser;
pub mod llama_parser; pub mod llama_parser;
pub mod mistral_parser; pub mod mistral_parser;
pub mod pythonic_parser; pub mod pythonic_parser;
pub mod qwen_parser; pub mod qwen_parser;
// Re-export parser types for convenience
pub use deepseek_parser::DeepSeekParser;
pub use json_parser::JsonParser; pub use json_parser::JsonParser;
pub use llama_parser::LlamaParser; pub use llama_parser::LlamaParser;
pub use mistral_parser::MistralParser; pub use mistral_parser::MistralParser;
......
use crate::tool_parser::parsers::{ use crate::tool_parser::parsers::{
JsonParser, LlamaParser, MistralParser, PythonicParser, QwenParser, DeepSeekParser, JsonParser, LlamaParser, MistralParser, PythonicParser, QwenParser,
}; };
use crate::tool_parser::traits::ToolParser; use crate::tool_parser::traits::ToolParser;
use std::collections::HashMap; use std::collections::HashMap;
...@@ -110,6 +110,9 @@ impl ParserRegistry { ...@@ -110,6 +110,9 @@ impl ParserRegistry {
// Llama parser - <|python_tag|>{...} or plain JSON format // Llama parser - <|python_tag|>{...} or plain JSON format
self.register_parser("llama", Arc::new(LlamaParser::new())); self.register_parser("llama", Arc::new(LlamaParser::new()));
// DeepSeek V3 parser - Unicode tokens with JSON blocks
self.register_parser("deepseek", Arc::new(DeepSeekParser::new()));
} }
/// Register default model mappings /// Register default model mappings
...@@ -141,7 +144,11 @@ impl ParserRegistry { ...@@ -141,7 +144,11 @@ impl ParserRegistry {
self.map_model("llama-*", "json"); self.map_model("llama-*", "json");
self.map_model("meta-llama-*", "json"); self.map_model("meta-llama-*", "json");
// DeepSeek models - DeepSeek v3 would need custom parser, v2 uses pythonic // DeepSeek models
// DeepSeek V3 uses custom Unicode token format
self.map_model("deepseek-v3*", "deepseek");
self.map_model("deepseek-ai/DeepSeek-V3*", "deepseek");
// DeepSeek V2 uses pythonic format
self.map_model("deepseek-*", "pythonic"); self.map_model("deepseek-*", "pythonic");
// Other models default to JSON // Other models default to JSON
......
//! DeepSeek V3 Parser Integration Tests
use sglang_router_rs::tool_parser::{DeepSeekParser, ParseState, StreamResult, ToolParser};
/// End-to-end parse of a single tool call embedded in surrounding prose;
/// verifies both the function name and the decoded JSON arguments.
#[tokio::test]
async fn test_deepseek_complete_parsing() {
    let parser = DeepSeekParser::new();

    // Test single tool call
    let input = r#"Let me help you with that.
<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>get_weather
```json
{"location": "Tokyo", "units": "celsius"}
```<|tool▁call▁end|><|tool▁calls▁end|>
The weather in Tokyo is..."#;

    let result = parser.parse_complete(input).await.unwrap();
    assert_eq!(result.len(), 1);
    assert_eq!(result[0].function.name, "get_weather");

    // Verify arguments round-trip as valid JSON with the expected keys
    let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
    assert_eq!(args["location"], "Tokyo");
    assert_eq!(args["units"], "celsius");
}
/// Two different functions in one envelope, separated by newlines, are
/// both extracted in order.
#[tokio::test]
async fn test_deepseek_multiple_tools() {
    let parser = DeepSeekParser::new();
    let input = r#"<|tool▁calls▁begin|>
<|tool▁call▁begin|>function<|tool▁sep|>search
```json
{"query": "rust programming"}
```<|tool▁call▁end|>
<|tool▁call▁begin|>function<|tool▁sep|>translate
```json
{"text": "Hello World", "to": "ja"}
```<|tool▁call▁end|>
<|tool▁calls▁end|>"#;

    let result = parser.parse_complete(input).await.unwrap();
    assert_eq!(result.len(), 2);
    assert_eq!(result[0].function.name, "search");
    assert_eq!(result[1].function.name, "translate");
}
/// Feed a tool call in small chunks and verify that the streaming path
/// surfaces at least the name event or the completed call. The assertion
/// is deliberately lenient (`||`) since chunk boundaries decide which
/// events fire.
#[tokio::test]
async fn test_deepseek_streaming() {
    let parser = DeepSeekParser::new();
    let mut state = ParseState::new();

    // Simulate streaming chunks — markers, name, and JSON args split
    // across chunk boundaries.
    let chunks = vec![
        "<|tool▁calls▁begin|><|tool▁call▁begin|>",
        "function<|tool▁sep|>get_weather\n",
        "```json\n",
        r#"{"location": "#,
        r#""Beijing", "#,
        r#""units": "metric"}"#,
        "\n```<|tool▁call▁end|><|tool▁calls▁end|>",
    ];

    let mut found_name = false;
    let mut found_complete = false;

    for chunk in chunks {
        let result = parser.parse_incremental(chunk, &mut state).await.unwrap();
        match result {
            StreamResult::ToolName { name, .. } => {
                assert_eq!(name, "get_weather");
                found_name = true;
            }
            StreamResult::ToolComplete(tool) => {
                assert_eq!(tool.function.name, "get_weather");
                found_complete = true;
            }
            _ => {}
        }
    }

    assert!(found_name || found_complete);
}
/// Deeply nested JSON arguments (requiring the DOTALL regex to span
/// multiple lines) survive the round trip intact.
#[tokio::test]
async fn test_deepseek_nested_json() {
    let parser = DeepSeekParser::new();
    let input = r#"<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>process
```json
{
    "data": {
        "nested": {
            "deep": [1, 2, 3]
        }
    }
}
```<|tool▁call▁end|><|tool▁calls▁end|>"#;

    let result = parser.parse_complete(input).await.unwrap();
    assert_eq!(result.len(), 1);
    assert_eq!(result[0].function.name, "process");

    let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
    assert!(args["data"]["nested"]["deep"].is_array());
}
/// Format detection fires only on the DeepSeek opening marker, wherever
/// it appears, and never on other model families' markers or plain text.
#[test]
fn test_deepseek_format_detection() {
    let parser = DeepSeekParser::new();

    // Inputs carrying the DeepSeek marker anywhere must be detected.
    let positives = [
        "<|tool▁calls▁begin|>",
        "text with <|tool▁calls▁begin|> marker",
    ];
    for sample in positives {
        assert!(parser.detect_format(sample));
    }

    // Marker styles from other formats must not be detected.
    let negatives = ["[TOOL_CALLS]", "<tool_call>", "plain text"];
    for sample in negatives {
        assert!(!parser.detect_format(sample));
    }
}
/// A block with invalid JSON is skipped silently; the subsequent valid
/// block is still parsed (graceful degradation, no hard error).
#[tokio::test]
async fn test_deepseek_malformed_json_handling() {
    let parser = DeepSeekParser::new();

    // Malformed JSON should be skipped
    let input = r#"<|tool▁calls▁begin|>
<|tool▁call▁begin|>function<|tool▁sep|>broken
```json
{invalid json}
```<|tool▁call▁end|>
<|tool▁call▁begin|>function<|tool▁sep|>valid
```json
{"key": "value"}
```<|tool▁call▁end|>
<|tool▁calls▁end|>"#;

    let result = parser.parse_complete(input).await.unwrap();
    // Only the valid tool call should be parsed
    assert_eq!(result.len(), 1);
    assert_eq!(result[0].function.name, "valid");
}
/// Leading prose before the tool-call envelope does not interfere with
/// parsing. Parity gap with the Python implementation is tracked below.
#[tokio::test]
async fn test_normal_text_extraction() {
    let parser = DeepSeekParser::new();

    // Python extracts text before tool calls as normal_text
    let input = r#"Let me help you with that.
<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>get_weather
```json
{"location": "Tokyo"}
```<|tool▁call▁end|><|tool▁calls▁end|>"#;

    let result = parser.parse_complete(input).await.unwrap();
    assert_eq!(result.len(), 1);
    assert_eq!(result[0].function.name, "get_weather");

    // TODO: Verify normal text extraction when parser returns it
    // In Python: normal_text = "Let me help you with that."
}
/// Repeated calls to the same function plus a trailing end-of-sentence
/// token after the envelope are handled correctly.
#[tokio::test]
async fn test_multiple_tool_calls() {
    let parser = DeepSeekParser::new();
    let input = r#"<|tool▁calls▁begin|>
<|tool▁call▁begin|>function<|tool▁sep|>get_weather
```json
{"location": "Tokyo"}
```<|tool▁call▁end|>
<|tool▁call▁begin|>function<|tool▁sep|>get_weather
```json
{"location": "Paris"}
```<|tool▁call▁end|>
<|tool▁calls▁end|><|end▁of▁sentence|>"#;

    let result = parser.parse_complete(input).await.unwrap();
    assert_eq!(result.len(), 2);
    assert_eq!(result[0].function.name, "get_weather");
    assert_eq!(result[1].function.name, "get_weather");
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment