Unverified Commit 31f5ed3c authored by Elyas Mehtabuddin's avatar Elyas Mehtabuddin Committed by GitHub
Browse files

feat: add finish reason = tool_calls for stream=False and phi-4 detect token start fix (#3087)


Signed-off-by: default avatarElyas Mehtabuddin <emehtabuddin@nvidia.com>
parent ef6734d0
...@@ -1469,6 +1469,7 @@ dependencies = [ ...@@ -1469,6 +1469,7 @@ dependencies = [
"rustpython-parser", "rustpython-parser",
"serde", "serde",
"serde_json", "serde_json",
"tokio",
"tracing", "tracing",
"uuid", "uuid",
] ]
......
...@@ -181,10 +181,13 @@ impl DeltaAggregator { ...@@ -181,10 +181,13 @@ impl DeltaAggregator {
.collect(); .collect();
// Initialize and push the converted tool calls to state_choice.tool_calls // Initialize and push the converted tool calls to state_choice.tool_calls
if let Some(existing_tool_calls) = &mut state_choice.tool_calls { // Only set tool_calls to Some if there are actual tool calls
existing_tool_calls.extend(converted_tool_calls); if !converted_tool_calls.is_empty() {
} else { if let Some(existing_tool_calls) = &mut state_choice.tool_calls {
state_choice.tool_calls = Some(converted_tool_calls); existing_tool_calls.extend(converted_tool_calls);
} else {
state_choice.tool_calls = Some(converted_tool_calls);
}
} }
} }
...@@ -257,6 +260,17 @@ impl From<DeltaChoice> for dynamo_async_openai::types::ChatChoice { ...@@ -257,6 +260,17 @@ impl From<DeltaChoice> for dynamo_async_openai::types::ChatChoice {
/// # Note /// # Note
/// The `function_call` field is deprecated. /// The `function_call` field is deprecated.
fn from(delta: DeltaChoice) -> Self { fn from(delta: DeltaChoice) -> Self {
// If tool calls are present and non-empty, finish reason should be ToolCalls
let finish_reason = if delta
.tool_calls
.as_ref()
.is_some_and(|calls| !calls.is_empty())
{
Some(dynamo_async_openai::types::FinishReason::ToolCalls)
} else {
delta.finish_reason
};
dynamo_async_openai::types::ChatChoice { dynamo_async_openai::types::ChatChoice {
message: dynamo_async_openai::types::ChatCompletionResponseMessage { message: dynamo_async_openai::types::ChatCompletionResponseMessage {
role: delta.role.expect("delta should have a Role"), role: delta.role.expect("delta should have a Role"),
...@@ -272,7 +286,7 @@ impl From<DeltaChoice> for dynamo_async_openai::types::ChatChoice { ...@@ -272,7 +286,7 @@ impl From<DeltaChoice> for dynamo_async_openai::types::ChatChoice {
reasoning_content: delta.reasoning_content, reasoning_content: delta.reasoning_content,
}, },
index: delta.index, index: delta.index,
finish_reason: delta.finish_reason, finish_reason,
logprobs: delta.logprobs, logprobs: delta.logprobs,
} }
} }
...@@ -347,7 +361,7 @@ mod tests { ...@@ -347,7 +361,7 @@ mod tests {
tool_calls.map(|tool_calls| serde_json::from_str(tool_calls).unwrap()); tool_calls.map(|tool_calls| serde_json::from_str(tool_calls).unwrap());
let tool_call_chunks = if let Some(tool_calls) = tool_calls { let tool_call_chunks = if let Some(tool_calls) = tool_calls {
vec![ Some(vec![
dynamo_async_openai::types::ChatCompletionMessageToolCallChunk { dynamo_async_openai::types::ChatCompletionMessageToolCallChunk {
index: 0, index: 0,
id: Some("test_id".to_string()), id: Some("test_id".to_string()),
...@@ -357,22 +371,15 @@ mod tests { ...@@ -357,22 +371,15 @@ mod tests {
arguments: Some(serde_json::to_string(&tool_calls["arguments"]).unwrap()), arguments: Some(serde_json::to_string(&tool_calls["arguments"]).unwrap()),
}), }),
}, },
] ])
} else { } else {
vec![ None
dynamo_async_openai::types::ChatCompletionMessageToolCallChunk {
index: 0,
id: None,
r#type: None,
function: None,
},
]
}; };
let delta = dynamo_async_openai::types::ChatCompletionStreamResponseDelta { let delta = dynamo_async_openai::types::ChatCompletionStreamResponseDelta {
content: Some(text.to_string()), content: Some(text.to_string()),
function_call: None, function_call: None,
tool_calls: Some(tool_call_chunks), tool_calls: tool_call_chunks,
role, role,
refusal: None, refusal: None,
reasoning_content: None, reasoning_content: None,
...@@ -625,6 +632,215 @@ mod tests { ...@@ -625,6 +632,215 @@ mod tests {
); );
} }
#[tokio::test]
async fn test_tool_calling_finish_reason_override_from_stop() {
// Test that when tool calls are present but finish reason is Stop, it gets overridden to ToolCalls
let tool_call_json =
r#"{"name": "get_weather", "arguments": {"location": "New York", "unit": "celsius"}}"#;
let annotated_delta = create_test_delta(
0,
"I'll check the weather for you.",
Some(dynamo_async_openai::types::Role::Assistant),
Some(dynamo_async_openai::types::FinishReason::Stop), // Original finish reason is Stop
None,
Some(tool_call_json),
);
let data = annotated_delta.data.unwrap();
let annotated_delta = Annotated {
data: Some(data),
id: Some("test_id".to_string()),
event: None,
comment: None,
};
let stream = Box::pin(stream::iter(vec![annotated_delta]));
let result = DeltaAggregator::apply(stream, ParsingOptions::default()).await;
assert!(result.is_ok());
let response = result.unwrap();
assert_eq!(response.choices.len(), 1);
let choice = &response.choices[0];
// Verify tool calls are present
assert!(choice.message.tool_calls.is_some());
let tool_calls = choice.message.tool_calls.as_ref().unwrap();
assert_eq!(tool_calls.len(), 1);
// Most importantly, verify that finish reason was overridden to ToolCalls despite original being Stop
assert_eq!(
choice.finish_reason,
Some(dynamo_async_openai::types::FinishReason::ToolCalls)
);
}
#[tokio::test]
async fn test_tool_calling_finish_reason_override_from_length() {
// Test that when tool calls are present but finish reason is Length, it gets overridden to ToolCalls
let tool_call_json = r#"{"name": "search", "arguments": {"query": "rust programming"}}"#;
let annotated_delta = create_test_delta(
0,
"Let me search for that.",
Some(dynamo_async_openai::types::Role::Assistant),
Some(dynamo_async_openai::types::FinishReason::Length), // Original finish reason is Length
None,
Some(tool_call_json),
);
let data = annotated_delta.data.unwrap();
let annotated_delta = Annotated {
data: Some(data),
id: Some("test_id".to_string()),
event: None,
comment: None,
};
let stream = Box::pin(stream::iter(vec![annotated_delta]));
let result = DeltaAggregator::apply(stream, ParsingOptions::default()).await;
assert!(result.is_ok());
let response = result.unwrap();
assert_eq!(response.choices.len(), 1);
let choice = &response.choices[0];
// Verify tool calls are present
assert!(choice.message.tool_calls.is_some());
let tool_calls = choice.message.tool_calls.as_ref().unwrap();
assert_eq!(tool_calls.len(), 1);
// Verify that finish reason was overridden to ToolCalls despite original being Length
assert_eq!(
choice.finish_reason,
Some(dynamo_async_openai::types::FinishReason::ToolCalls)
);
}
#[tokio::test]
async fn test_tool_calling_finish_reason_override_from_none() {
// Test that when tool calls are present but finish reason is None, it gets set to ToolCalls
let tool_call_json = r#"{"name": "calculate", "arguments": {"expression": "2+2"}}"#;
let annotated_delta = create_test_delta(
0,
"I'll calculate that for you.",
Some(dynamo_async_openai::types::Role::Assistant),
None, // Original finish reason is None
None,
Some(tool_call_json),
);
let data = annotated_delta.data.unwrap();
let annotated_delta = Annotated {
data: Some(data),
id: Some("test_id".to_string()),
event: None,
comment: None,
};
let stream = Box::pin(stream::iter(vec![annotated_delta]));
let result = DeltaAggregator::apply(stream, ParsingOptions::default()).await;
assert!(result.is_ok());
let response = result.unwrap();
assert_eq!(response.choices.len(), 1);
let choice = &response.choices[0];
// Verify tool calls are present
assert!(choice.message.tool_calls.is_some());
let tool_calls = choice.message.tool_calls.as_ref().unwrap();
assert_eq!(tool_calls.len(), 1);
// Verify that finish reason was set to ToolCalls despite original being None
assert_eq!(
choice.finish_reason,
Some(dynamo_async_openai::types::FinishReason::ToolCalls)
);
}
#[tokio::test]
async fn test_no_tool_calling_preserves_original_finish_reason() {
// Test that when no tool calls are present, the original finish reason is preserved
let annotated_delta = create_test_delta(
0,
"This is a regular response without tool calls.",
Some(dynamo_async_openai::types::Role::Assistant),
Some(dynamo_async_openai::types::FinishReason::Stop),
None,
None, // No tool calls
);
let data = annotated_delta.data.unwrap();
let annotated_delta = Annotated {
data: Some(data),
id: Some("test_id".to_string()),
event: None,
comment: None,
};
let stream = Box::pin(stream::iter(vec![annotated_delta]));
let result = DeltaAggregator::apply(stream, ParsingOptions::default()).await;
assert!(result.is_ok());
let response = result.unwrap();
assert_eq!(response.choices.len(), 1);
let choice = &response.choices[0];
// Verify no tool calls are present
assert!(choice.message.tool_calls.is_none());
// Verify that original finish reason (Stop) is preserved
assert_eq!(
choice.finish_reason,
Some(dynamo_async_openai::types::FinishReason::Stop)
);
}
#[tokio::test]
async fn test_empty_tool_calls_preserves_original_finish_reason() {
// Test that when tool calls array is empty, the original finish reason is preserved
// Create a delta with empty tool calls by modifying the create_test_delta output
let mut annotated_delta = create_test_delta(
0,
"Response with empty tool calls array.",
Some(dynamo_async_openai::types::Role::Assistant),
Some(dynamo_async_openai::types::FinishReason::Length),
None,
None,
);
// Manually set empty tool calls array
if let Some(ref mut data) = annotated_delta.data {
data.choices[0].delta.tool_calls = Some(vec![]); // Empty tool calls array
}
let data = annotated_delta.data.unwrap();
let annotated_delta = Annotated {
data: Some(data),
id: Some("test_id".to_string()),
event: None,
comment: None,
};
let stream = Box::pin(stream::iter(vec![annotated_delta]));
let result = DeltaAggregator::apply(stream, ParsingOptions::default()).await;
assert!(result.is_ok());
let response = result.unwrap();
assert_eq!(response.choices.len(), 1);
let choice = &response.choices[0];
// Verify tool calls array is empty
assert!(choice.message.tool_calls.is_none());
// Verify that original finish reason (Length) is preserved since tool calls are empty
assert_eq!(
choice.finish_reason,
Some(dynamo_async_openai::types::FinishReason::Length)
);
}
#[tokio::test] #[tokio::test]
async fn test_tool_calling_output() { async fn test_tool_calling_output() {
// Simulate a delta with a tool call in the content // Simulate a delta with a tool call in the content
...@@ -688,4 +904,45 @@ mod tests { ...@@ -688,4 +904,45 @@ mod tests {
dynamo_async_openai::types::Role::Assistant dynamo_async_openai::types::Role::Assistant
); );
} }
#[tokio::test]
async fn test_tool_calling_finish_reason_override_from_stop_alternative() {
// Test that when tool calls are present but finish reason is Stop, it gets overridden to ToolCalls
let tool_call_json =
r#"{"name": "get_weather", "arguments": {"location": "New York", "unit": "celsius"}}"#;
let annotated_delta = create_test_delta(
0,
"Getting weather for New York",
Some(dynamo_async_openai::types::Role::Assistant),
Some(dynamo_async_openai::types::FinishReason::Stop), // This should be overridden
None,
Some(tool_call_json),
);
let stream = Box::pin(stream::iter(vec![annotated_delta]));
// Call DeltaAggregator::apply
let result = DeltaAggregator::apply(stream, ParsingOptions::default()).await;
// Check the result
assert!(result.is_ok());
let response = result.unwrap();
// There should be one choice
assert_eq!(response.choices.len(), 1);
let choice = &response.choices[0];
// The finish_reason should be ToolCalls, not Stop, because tool calls are present
assert_eq!(
choice.finish_reason,
Some(dynamo_async_openai::types::FinishReason::ToolCalls)
);
// Verify tool calls are present
assert!(choice.message.tool_calls.is_some());
let tool_calls = choice.message.tool_calls.as_ref().unwrap();
assert_eq!(tool_calls.len(), 1);
assert_eq!(tool_calls[0].function.name, "get_weather");
}
} }
...@@ -173,16 +173,64 @@ pub fn detect_tool_call_start_harmony( ...@@ -173,16 +173,64 @@ pub fn detect_tool_call_start_harmony(
} }
if strict { if strict {
config // Check for complete start tokens first
let has_complete_token = config
.tool_call_start_tokens .tool_call_start_tokens
.iter() .iter()
.any(|token| trimmed.contains(token)) .any(|token| !token.is_empty() && trimmed.contains(token));
if has_complete_token {
return true;
}
// Check for partial start tokens (streaming scenario)
// This handles cases where start tokens are split across multiple chunks
config.tool_call_start_tokens.iter().any(|token| {
if token.is_empty() {
return false;
}
// Check if the chunk could be a prefix of this start token
// Handle Unicode character boundaries properly
for i in 1..=token.chars().count() {
if let Some(prefix) = token.chars().take(i).collect::<String>().get(..) {
let prefix_str = &prefix[..prefix.len()];
if trimmed == prefix_str || trimmed.ends_with(prefix_str) {
return true;
}
}
}
false
})
} else { } else {
config // Non-strict mode: check complete tokens and some heuristics
let has_complete_token = config
.tool_call_start_tokens .tool_call_start_tokens
.iter() .iter()
.any(|token| trimmed.contains(token)) .any(|token| !token.is_empty() && trimmed.contains(token));
|| trimmed.contains("<|channel|>")
if has_complete_token {
return true;
}
// Check for partial start tokens or known patterns
let has_partial_token = config.tool_call_start_tokens.iter().any(|token| {
if token.is_empty() {
return false;
}
// Check if the chunk could be a prefix of this start token
// Handle Unicode character boundaries properly
for i in 1..=token.chars().count() {
if let Some(prefix) = token.chars().take(i).collect::<String>().get(..) {
let prefix_str = &prefix[..prefix.len()];
if trimmed == prefix_str || trimmed.ends_with(prefix_str) {
return true;
}
}
}
false
});
has_partial_token || trimmed.contains("<|channel|>")
} }
} }
...@@ -328,4 +376,42 @@ mod detect_parser_tests { ...@@ -328,4 +376,42 @@ mod detect_parser_tests {
let result = detect_tool_call_start_harmony(text, &config, false); let result = detect_tool_call_start_harmony(text, &config, false);
assert!(result); assert!(result);
} }
#[test]
fn test_detect_tool_call_start_harmony_partial_tokens() {
// Test partial token detection for streaming scenarios
let config = JsonParserConfig {
tool_call_start_tokens: vec!["<|start|>assistant<|channel|>commentary".to_string()],
tool_call_end_tokens: vec!["<|call|>".to_string()],
..Default::default()
};
// Test various partial prefixes in strict mode
assert!(
detect_tool_call_start_harmony("<", &config, true),
"'<' should be detected as potential start"
);
assert!(
detect_tool_call_start_harmony("<|", &config, true),
"'<|' should be detected as potential start"
);
assert!(
detect_tool_call_start_harmony("<|start|>", &config, true),
"'<|start|>' should be detected as potential start"
);
assert!(
detect_tool_call_start_harmony("<|start|>assistant", &config, true),
"'<|start|>assistant' should be detected as potential start"
);
// Test that unrelated text is not detected in strict mode
assert!(
!detect_tool_call_start_harmony("hello world", &config, true),
"'hello world' should not be detected in strict mode"
);
assert!(
!detect_tool_call_start_harmony("xyz", &config, true),
"'xyz' should not be detected in strict mode"
);
}
} }
...@@ -73,9 +73,9 @@ fn handle_single_token_tool_calls(input: &str, start_token: &str) -> Option<Stri ...@@ -73,9 +73,9 @@ fn handle_single_token_tool_calls(input: &str, start_token: &str) -> Option<Stri
if s.is_empty() { if s.is_empty() {
continue; continue;
} }
// Only consider segments that start like JSON // Only consider segments that start like JSON (objects or arrays)
if s.starts_with('{') { if s.starts_with('{') {
// Trim trailing non-JSON by cutting at the last closing brace/bracket // Trim trailing non-JSON by cutting at the last closing brace
if let Some(pos) = s.rfind('}') { if let Some(pos) = s.rfind('}') {
let candidate = &s[..=pos].trim(); let candidate = &s[..=pos].trim();
// Keep only valid JSON candidates // Keep only valid JSON candidates
...@@ -83,17 +83,30 @@ fn handle_single_token_tool_calls(input: &str, start_token: &str) -> Option<Stri ...@@ -83,17 +83,30 @@ fn handle_single_token_tool_calls(input: &str, start_token: &str) -> Option<Stri
items.push(candidate.to_string()); items.push(candidate.to_string());
} }
} }
} else if s.starts_with('[') {
// Handle array format (like phi4: functools[{...}])
if let Some(pos) = s.rfind(']') {
let candidate = &s[..=pos].trim();
// Keep only valid JSON arrays
if serde_json::from_str::<serde_json::Value>(candidate).is_ok() {
// For arrays, we need to extract the individual objects
if let Ok(serde_json::Value::Array(arr)) =
serde_json::from_str::<serde_json::Value>(candidate)
{
for item in arr {
if let Ok(item_str) = serde_json::to_string(&item) {
items.push(item_str);
}
}
}
}
}
} }
} }
if items.is_empty() { if items.is_empty() {
// Remove everything up to and including the first occurrence of the start token // If we found the start token but no valid JSON after it, return empty string
if let Some(idx) = input.find(start_token) { // to avoid leaking the invalid content (important for phi4 and similar models)
let rest = &input[idx + start_token.len()..]; return Some(String::new());
return Some(rest.trim_start().to_string());
} else {
// Shouldn't happen because we checked contains() above, but be defensive
return None;
}
} }
Some(format!("[{}]", items.join(","))) Some(format!("[{}]", items.join(",")))
} }
...@@ -174,6 +187,7 @@ pub fn try_tool_call_parse_basic_json( ...@@ -174,6 +187,7 @@ pub fn try_tool_call_parse_basic_json(
// Assumption : One message will not contain different tags for tool calls. Iteration over tags is to support different tags by default for multiple models // Assumption : One message will not contain different tags for tool calls. Iteration over tags is to support different tags by default for multiple models
let mut json = trimmed.to_string(); let mut json = trimmed.to_string();
let mut normal_text = trimmed.to_string(); let mut normal_text = trimmed.to_string();
let mut found_start_token_with_no_valid_json = false;
// First, check if ANY start token exists in the input // First, check if ANY start token exists in the input
let has_start_token = tool_call_start_tokens let has_start_token = tool_call_start_tokens
...@@ -204,9 +218,16 @@ pub fn try_tool_call_parse_basic_json( ...@@ -204,9 +218,16 @@ pub fn try_tool_call_parse_basic_json(
// Single token case // Single token case
let result = handle_single_token_tool_calls(&json, start_token); let result = handle_single_token_tool_calls(&json, start_token);
if let Some(content) = result { if let Some(content) = result {
// Check if we found a start token but got empty JSON back
// This indicates the token was found but no valid JSON followed
if content.is_empty() {
found_start_token_with_no_valid_json = true;
}
json = content; json = content;
// For single token case, use the normal text we extracted earlier // For single token case, use the normal text we extracted earlier
normal_text = new_normal_text; normal_text = new_normal_text;
break; // Found content, exit early break; // Found content, exit early
} }
} }
...@@ -214,8 +235,15 @@ pub fn try_tool_call_parse_basic_json( ...@@ -214,8 +235,15 @@ pub fn try_tool_call_parse_basic_json(
// Start and end token case // Start and end token case
let result = extract_tool_call_content(&json, start_token, end_token); let result = extract_tool_call_content(&json, start_token, end_token);
if let Some(content) = result { if let Some(content) = result {
// Check if we found a start token but got empty JSON back
// This indicates the token was found but no valid JSON followed
if content.is_empty() {
found_start_token_with_no_valid_json = true;
}
json = content; json = content;
normal_text = new_normal_text; normal_text = new_normal_text;
break; // Found content, exit early break; // Found content, exit early
} }
} }
...@@ -304,7 +332,13 @@ pub fn try_tool_call_parse_basic_json( ...@@ -304,7 +332,13 @@ pub fn try_tool_call_parse_basic_json(
return Ok((results, Some(normal_text))); return Ok((results, Some(normal_text)));
} }
Ok((vec![], Some(trimmed.to_string()))) // If we found a start token but no valid JSON, return empty content
// to avoid leaking the token and invalid JSON content
if found_start_token_with_no_valid_json {
Ok((vec![], Some(String::new())))
} else {
Ok((vec![], Some(trimmed.to_string())))
}
} }
pub fn detect_tool_call_start_basic_json(chunk: &str, config: &JsonParserConfig) -> bool { pub fn detect_tool_call_start_basic_json(chunk: &str, config: &JsonParserConfig) -> bool {
...@@ -312,12 +346,48 @@ pub fn detect_tool_call_start_basic_json(chunk: &str, config: &JsonParserConfig) ...@@ -312,12 +346,48 @@ pub fn detect_tool_call_start_basic_json(chunk: &str, config: &JsonParserConfig)
if trimmed.is_empty() { if trimmed.is_empty() {
return false; return false;
} }
config
// Check if chunk contains any complete start token
let contains_complete_token = config
.tool_call_start_tokens .tool_call_start_tokens
.iter() .iter()
.any(|token| trimmed.contains(token)) .any(|token| !token.is_empty() && trimmed.contains(token));
|| trimmed.contains('{')
|| trimmed.contains('[') if contains_complete_token {
return true;
}
// Check for partial start tokens (streaming scenario)
// This handles cases where start tokens are split across multiple chunks
let has_partial_token = config.tool_call_start_tokens.iter().any(|token| {
if token.is_empty() {
return false;
}
// Check if the chunk could be a prefix of this start token
// Handle Unicode character boundaries properly
for i in 1..=token.chars().count() {
if let Some(prefix) = token.chars().take(i).collect::<String>().get(..) {
let prefix_str = &prefix[..prefix.len()];
// Check for exact prefix match
if trimmed == prefix_str {
return true;
}
// For longer prefixes (3+ chars), allow them anywhere in the input
// This allows "funny joke" to match "functools" via "fun"
// but prevents "<tool_call>" from matching "<TOOLCALL>" via single char "<"
if prefix_str.len() >= 3 && trimmed.contains(prefix_str) {
return true;
}
// For shorter prefixes, only match if they're at the end (streaming scenario)
if prefix_str.len() < 3 && trimmed.ends_with(prefix_str) {
return true;
}
}
}
false
});
has_partial_token || trimmed.contains('{') || trimmed.contains('[')
} }
#[cfg(test)] #[cfg(test)]
...@@ -435,4 +505,97 @@ mod detect_parser_tests { ...@@ -435,4 +505,97 @@ mod detect_parser_tests {
let result = detect_tool_call_start_basic_json(text, &config); let result = detect_tool_call_start_basic_json(text, &config);
assert!(result); assert!(result);
} }
#[test]
fn detect_tool_call_start_basic_json_chunk_phi4_partial_token_fun() {
// Test the streaming scenario where "fun" arrives first
let text = r#"fun"#;
let config = JsonParserConfig {
tool_call_start_tokens: vec!["functools".to_string()],
tool_call_end_tokens: vec!["".to_string()],
..Default::default()
};
let result = detect_tool_call_start_basic_json(text, &config);
assert!(
result,
"Should detect 'fun' as potential start of 'functools'"
);
}
#[test]
fn detect_tool_call_start_basic_json_chunk_phi4_partial_token_func() {
let text = r#"func"#;
let config = JsonParserConfig {
tool_call_start_tokens: vec!["functools".to_string()],
tool_call_end_tokens: vec!["".to_string()],
..Default::default()
};
let result = detect_tool_call_start_basic_json(text, &config);
assert!(
result,
"Should detect 'func' as potential start of 'functools'"
);
}
#[test]
fn detect_tool_call_start_basic_json_chunk_phi4_partial_token_f() {
let text = r#"f"#;
let config = JsonParserConfig {
tool_call_start_tokens: vec!["functools".to_string()],
tool_call_end_tokens: vec!["".to_string()],
..Default::default()
};
let result = detect_tool_call_start_basic_json(text, &config);
assert!(
result,
"Should detect 'f' as potential start of 'functools'"
);
}
#[test]
fn detect_tool_call_start_basic_json_chunk_phi4_partial_with_prefix() {
// Test case where text ends with a partial token (more realistic streaming scenario)
let text = r#"Hello fun"#;
let config = JsonParserConfig {
tool_call_start_tokens: vec!["functools".to_string()],
tool_call_end_tokens: vec!["".to_string()],
..Default::default()
};
let result = detect_tool_call_start_basic_json(text, &config);
assert!(
result,
"Should detect text ending with 'fun' as potential tool call start"
);
}
#[test]
fn detect_tool_call_start_basic_json_chunk_phi4_avoid_false_positive() {
// Test to ensure we don't get false positives for unrelated text
let text = r#"funny joke"#;
let config = JsonParserConfig {
tool_call_start_tokens: vec!["functools".to_string()],
tool_call_end_tokens: vec!["".to_string()],
..Default::default()
};
let result = detect_tool_call_start_basic_json(text, &config);
// This should still return true because "fun" is a prefix, but that's expected behavior
// The key is that we detect potential starts, and false positives are acceptable
// in streaming scenarios to avoid missing real tool calls
assert!(result);
}
#[test]
fn detect_tool_call_start_basic_json_chunk_phi4_no_match() {
let text = r#"hello world"#;
let config = JsonParserConfig {
tool_call_start_tokens: vec!["functools".to_string()],
tool_call_end_tokens: vec!["".to_string()],
..Default::default()
};
let result = detect_tool_call_start_basic_json(text, &config);
assert!(
!result,
"Should not detect unrelated text as tool call start"
);
}
} }
...@@ -103,11 +103,38 @@ pub fn parse_tool_calls_deepseek_v3_1( ...@@ -103,11 +103,38 @@ pub fn parse_tool_calls_deepseek_v3_1(
pub fn detect_tool_call_start_deepseek_v3_1(chunk: &str, config: &JsonParserConfig) -> bool { pub fn detect_tool_call_start_deepseek_v3_1(chunk: &str, config: &JsonParserConfig) -> bool {
let trimmed = chunk.trim(); let trimmed = chunk.trim();
!trimmed.is_empty() if trimmed.is_empty() {
&& config return false;
.tool_call_start_tokens }
.iter()
.any(|token| trimmed.contains(token)) // Check for complete start tokens first
let has_complete_token = config
.tool_call_start_tokens
.iter()
.any(|token| !token.is_empty() && trimmed.contains(token));
if has_complete_token {
return true;
}
// Check for partial start tokens (streaming scenario)
// This handles cases where start tokens are split across multiple chunks
config.tool_call_start_tokens.iter().any(|token| {
if token.is_empty() {
return false;
}
// Check if the chunk could be a prefix of this start token
// Handle Unicode character boundaries properly
for i in 1..=token.chars().count() {
if let Some(prefix) = token.chars().take(i).collect::<String>().get(..) {
let prefix_str = &prefix[..prefix.len()];
if trimmed == prefix_str || trimmed.ends_with(prefix_str) {
return true;
}
}
}
false
})
} }
#[cfg(test)] #[cfg(test)]
...@@ -263,4 +290,42 @@ mod detect_parser_tests { ...@@ -263,4 +290,42 @@ mod detect_parser_tests {
let result = detect_tool_call_start_deepseek_v3_1(text, &config); let result = detect_tool_call_start_deepseek_v3_1(text, &config);
assert!(result); assert!(result);
} }
#[test]
fn test_detect_tool_call_start_deepseek_v3_1_partial_tokens() {
// Test partial token detection for streaming scenarios with unicode characters
let config = JsonParserConfig {
tool_call_start_tokens: vec!["<|tool▁calls▁begin|>".to_string()],
tool_call_end_tokens: vec!["<|tool▁calls▁end|>".to_string()],
..Default::default()
};
// Test various partial prefixes
assert!(
detect_tool_call_start_deepseek_v3_1("<", &config),
"'<' should be detected as potential start"
);
assert!(
detect_tool_call_start_deepseek_v3_1("<|", &config),
"'<|' should be detected as potential start"
);
assert!(
detect_tool_call_start_deepseek_v3_1("<|tool", &config),
"'<|tool' should be detected as potential start"
);
assert!(
detect_tool_call_start_deepseek_v3_1("<|tool▁calls", &config),
"'<|tool▁calls' should be detected as potential start"
);
// Test that unrelated text is not detected
assert!(
!detect_tool_call_start_deepseek_v3_1("hello world", &config),
"'hello world' should not be detected"
);
assert!(
!detect_tool_call_start_deepseek_v3_1("xyz", &config),
"'xyz' should not be detected"
);
}
} }
...@@ -1218,6 +1218,197 @@ Remember, San Francisco weather can be quite unpredictable, particularly with it ...@@ -1218,6 +1218,197 @@ Remember, San Francisco weather can be quite unpredictable, particularly with it
assert_eq!(args["to"], "Los Angeles"); assert_eq!(args["to"], "Los Angeles");
} }
#[tokio::test]
async fn test_phi4_token_leak_reproduction() {
// Reproduce the issue where "functools" appears in content field
// This might happen when there's malformed JSON or parsing issues
let input = r#"functools{"name": "get_weather","arguments":{"location":"San Francisco"}}"#;
let (result, content) = detect_and_parse_tool_call(input, Some("phi4"))
.await
.unwrap();
// Content should be empty, not contain "functools"
assert_eq!(content, Some("".to_string()));
assert_eq!(result.len(), 1);
let (name, args) = extract_name_and_args(result[0].clone());
assert_eq!(name, "get_weather");
assert_eq!(args["location"], "San Francisco");
}
#[tokio::test]
async fn test_phi4_token_leak_edge_case() {
// Test the case where only the token appears without JSON
// This case is less critical but shouldn't leak the full token
let input = r#"functools"#;
let (result, _content) = detect_and_parse_tool_call(input, Some("phi4"))
.await
.unwrap();
// Content may contain the token if no valid JSON follows, but shouldn't crash
// The important thing is that no tool calls are returned
assert_eq!(result.len(), 0); // No tool calls found
// Content behavior is less critical for this edge case
}
#[tokio::test]
async fn test_phi4_token_with_invalid_json() {
// Test the case where token is followed by invalid JSON
let input = r#"functools{invalid json}"#;
let (result, content) = detect_and_parse_tool_call(input, Some("phi4"))
.await
.unwrap();
// Content should be empty, not contain "functools" or leak the token
assert_eq!(content, Some("".to_string()));
assert_eq!(result.len(), 0); // No tool calls found due to invalid JSON
}
#[tokio::test]
async fn test_phi4_streaming_partial_tokens() {
// Test that our fix handles the actual streaming scenario described by the user
// Where "fun", "ct", "ools" arrive as separate chunks
// Test that "fun" is detected as a potential tool call start (for streaming jailing)
let config = super::get_tool_parser_map().get("phi4").unwrap();
// Test detection of partial tokens
use super::super::json::detect_tool_call_start_json;
assert!(
detect_tool_call_start_json("fun", &config.json),
"'fun' should be detected as potential start"
);
assert!(
detect_tool_call_start_json("f", &config.json),
"'f' should be detected as potential start"
);
assert!(
detect_tool_call_start_json("func", &config.json),
"'func' should be detected as potential start"
);
assert!(
detect_tool_call_start_json("functo", &config.json),
"'functo' should be detected as potential start"
);
// Test that unrelated text is not detected
assert!(
!detect_tool_call_start_json("hello", &config.json),
"'hello' should not be detected"
);
assert!(
!detect_tool_call_start_json("xyz", &config.json),
"'xyz' should not be detected"
);
}
#[tokio::test]
async fn test_phi4_false_positive_words() {
// Test that words like "funk" or text starting with "func" but not "functools"
// are correctly treated as normal content, not tool calls
let input = r#"funk music is great"#;
let (result, content) = detect_and_parse_tool_call(input, Some("phi4"))
.await
.unwrap();
// Should be treated as normal content, not tool call
assert_eq!(
result.len(),
0,
"No tool calls should be found in 'funk music is great'"
);
assert_eq!(
content,
Some("funk music is great".to_string()),
"Content should contain the original text"
);
}
#[tokio::test]
async fn test_phi4_partial_but_complete_words() {
// Test words that start with "func" but are not "functools"
let input = r#"The function works well"#;
let (result, content) = detect_and_parse_tool_call(input, Some("phi4"))
.await
.unwrap();
assert_eq!(
result.len(),
0,
"No tool calls should be found in 'The function works well'"
);
assert_eq!(content, Some("The function works well".to_string()));
let input = r#"functional programming"#;
let (result, content) = detect_and_parse_tool_call(input, Some("phi4"))
.await
.unwrap();
assert_eq!(
result.len(),
0,
"No tool calls should be found in 'functional programming'"
);
assert_eq!(content, Some("functional programming".to_string()));
}
#[tokio::test]
async fn test_phi4_funk_variations() {
// Test various "funk" related words to ensure they're not treated as tool calls
let test_cases = vec![
"funk",
"funky",
"funktion", // German word for function
"funked",
"I love funk music",
"This is funky stuff",
];
for test_input in test_cases {
let (result, content) = detect_and_parse_tool_call(test_input, Some("phi4"))
.await
.unwrap();
assert_eq!(
result.len(),
0,
"No tool calls should be found in '{}'",
test_input
);
assert_eq!(
content,
Some(test_input.to_string()),
"Content should match input for '{}'",
test_input
);
}
}
#[tokio::test]
async fn test_phi4_func_but_not_functools() {
// Test words starting with "func" that are complete words, not partial "functools"
let test_cases = vec![
"func()", // Programming syntax
"funcdef", // Python keyword variant
"functions are useful",
"functionally speaking",
];
for test_input in test_cases {
let (result, content) = detect_and_parse_tool_call(test_input, Some("phi4"))
.await
.unwrap();
assert_eq!(
result.len(),
0,
"No tool calls should be found in '{}'",
test_input
);
assert_eq!(
content,
Some(test_input.to_string()),
"Content should match input for '{}'",
test_input
);
}
}
#[tokio::test] #[tokio::test]
async fn test_pythonic_parser_basic_with_constants() { async fn test_pythonic_parser_basic_with_constants() {
let input = r#"[get_weather(location="San Francisco", unit="fahrenheit"), get_weather(location="New York", unit="fahrenheit")]"#; let input = r#"[get_weather(location="San Francisco", unit="fahrenheit"), get_weather(location="New York", unit="fahrenheit")]"#;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment