Unverified Commit 2ae20102 authored by Ayush Agarwal's avatar Ayush Agarwal Committed by GitHub
Browse files

chore: jail stream optimizations (v1) (#3195)


Signed-off-by: ayushag <ayushag@nvidia.com>
parent 6ba64c31
......@@ -3,12 +3,14 @@
use async_stream::stream;
use dynamo_async_openai::types::{
ChatChoiceStream, ChatCompletionMessageToolCallChunk, ChatCompletionStreamResponseDelta,
FinishReason, FunctionCallStream, Role,
ChatChoiceLogprobs, ChatChoiceStream, ChatCompletionMessageToolCallChunk,
ChatCompletionStreamResponseDelta, FinishReason, FunctionCallStream, Role,
};
use dynamo_parsers::tool_calling::parsers::get_tool_parser_map;
use dynamo_parsers::tool_calling::{detect_tool_call_start, try_tool_call_parse_aggregate};
use dynamo_parsers::tool_calling::{
detect_tool_call_start, find_tool_call_end_position, try_tool_call_parse_aggregate,
};
use dynamo_runtime::protocols::annotated::Annotated;
use futures::{Stream, StreamExt};
......@@ -72,6 +74,30 @@ struct ChoiceJailState {
partial_match_buffer: String,
}
/// Assemble a single streaming choice from its individual parts.
///
/// * `index` - position of the choice within the response.
/// * `role` - role to place on the delta, if any.
/// * `content` - text payload; it is copied into an owned `String` and is
///   always emitted as `Some(..)`, even when the text is empty.
/// * `tool_calls` - tool call chunks to attach to the delta, if any.
/// * `finish_reason` / `logprobs` - forwarded onto the choice unchanged.
///
/// The `function_call`, `refusal` and `reasoning_content` fields of the delta
/// are always left as `None`.
// NOTE(review): the deprecation lint is silenced because the struct literal
// below names a deprecated item (presumably the `function_call` field - confirm).
#[allow(deprecated)]
fn create_choice_stream(
    index: u32,
    role: Option<Role>,
    content: &str,
    tool_calls: Option<Vec<ChatCompletionMessageToolCallChunk>>,
    finish_reason: Option<FinishReason>,
    logprobs: Option<ChatChoiceLogprobs>,
) -> ChatChoiceStream {
    let delta = ChatCompletionStreamResponseDelta {
        role,
        content: Some(content.to_owned()),
        tool_calls,
        function_call: None,
        refusal: None,
        reasoning_content: None,
    };

    ChatChoiceStream {
        index,
        delta,
        finish_reason,
        logprobs,
    }
}
impl ChoiceJailState {
/// Create a new jail state for a choice
fn new(index: u32) -> Self {
......@@ -120,19 +146,14 @@ impl ChoiceJailState {
// Emit prefix if any
if !prefix.is_empty() {
#[allow(deprecated)]
let prefix_choice = ChatChoiceStream {
index: choice.index,
delta: ChatCompletionStreamResponseDelta {
role: choice.delta.role,
content: Some(prefix),
tool_calls: None,
function_call: None,
refusal: None,
reasoning_content: None,
},
finish_reason: None,
logprobs: choice.logprobs.clone(),
};
let prefix_choice = create_choice_stream(
choice.index,
choice.delta.role,
&prefix,
None,
None,
choice.logprobs.clone(),
);
emissions.push(ChoiceEmission::PassThrough(prefix_choice));
}
......@@ -165,19 +186,14 @@ impl ChoiceJailState {
// Handle trailing content if any
if !trailing_part.is_empty() {
#[allow(deprecated)]
let trailing_choice = ChatChoiceStream {
index: choice.index,
delta: ChatCompletionStreamResponseDelta {
role: choice.delta.role,
content: Some(trailing_part.to_string()),
tool_calls: None,
function_call: None,
refusal: None,
reasoning_content: None,
},
finish_reason: None,
logprobs: choice.logprobs.clone(),
};
let trailing_choice = create_choice_stream(
choice.index,
choice.delta.role,
trailing_part,
None,
None,
choice.logprobs.clone(),
);
emissions.push(ChoiceEmission::Trailing(trailing_choice));
}
} else {
......@@ -202,19 +218,14 @@ impl ChoiceJailState {
// Emit the safe prefix
if !prefix.is_empty() {
#[allow(deprecated)]
let prefix_choice = ChatChoiceStream {
index: choice.index,
delta: ChatCompletionStreamResponseDelta {
role: choice.delta.role,
content: Some(prefix),
tool_calls: None,
function_call: None,
refusal: None,
reasoning_content: None,
},
finish_reason: None,
logprobs: choice.logprobs.clone(),
};
let prefix_choice = create_choice_stream(
choice.index,
choice.delta.role,
&prefix,
None,
None,
choice.logprobs.clone(),
);
emissions.push(ChoiceEmission::PassThrough(prefix_choice));
}
......@@ -250,19 +261,14 @@ impl ChoiceJailState {
// No markers - emit everything
if !content.is_empty() {
#[allow(deprecated)]
let pass_through_choice = ChatChoiceStream {
index: choice.index,
delta: ChatCompletionStreamResponseDelta {
role: choice.delta.role,
content: Some(content),
tool_calls: None,
function_call: None,
refusal: None,
reasoning_content: None,
},
finish_reason: None,
logprobs: choice.logprobs.clone(),
};
let pass_through_choice = create_choice_stream(
choice.index,
choice.delta.role,
&content,
None,
None,
choice.logprobs.clone(),
);
emissions.push(ChoiceEmission::PassThrough(pass_through_choice));
}
self.partial_match_buffer.clear();
......@@ -300,19 +306,14 @@ impl ChoiceJailState {
// Handle trailing content if any
if !trailing_part.is_empty() {
#[allow(deprecated)]
let trailing_choice = ChatChoiceStream {
index: choice.index,
delta: ChatCompletionStreamResponseDelta {
role: choice.delta.role,
content: Some(trailing_part.to_string()),
tool_calls: None,
function_call: None,
refusal: None,
reasoning_content: None,
},
finish_reason: None,
logprobs: choice.logprobs.clone(),
};
let trailing_choice = create_choice_stream(
choice.index,
choice.delta.role,
trailing_part,
None,
None,
choice.logprobs.clone(),
);
emissions.push(ChoiceEmission::Trailing(trailing_choice));
}
......@@ -335,19 +336,14 @@ impl ChoiceJailState {
// Create a dummy choice for the method call
#[allow(deprecated)]
let dummy_choice = ChatChoiceStream {
index: self.index,
delta: ChatCompletionStreamResponseDelta {
role: Some(Role::Assistant),
content: None,
tool_calls: None,
function_call: None,
refusal: None,
reasoning_content: None,
},
finish_reason: None,
logprobs: None,
};
let dummy_choice = create_choice_stream(
self.index,
Some(Role::Assistant),
&self.accumulated_content,
None,
None,
None,
);
let final_choice = jail_stream
.create_tool_call_choice(self.index, &self.accumulated_content, &dummy_choice)
......@@ -663,7 +659,7 @@ impl JailedStream {
if let Ok((_, _)) =
try_tool_call_parse_aggregate(accumulated_content, Some(parser)).await
{
let split_pos = self.find_tool_call_end_position(accumulated_content, parser);
let split_pos = find_tool_call_end_position(accumulated_content, Some(parser));
(true, split_pos)
} else {
(false, accumulated_content.len())
......@@ -704,37 +700,25 @@ impl JailedStream {
.collect();
// Create choice with tool calls
#[allow(deprecated)]
return ChatChoiceStream {
index: choice_index,
delta: ChatCompletionStreamResponseDelta {
role: Some(Role::Assistant),
content: normal_text.filter(|t| !t.is_empty()),
tool_calls: Some(tool_call_chunks),
function_call: None,
refusal: None,
reasoning_content: None,
},
finish_reason: Some(FinishReason::ToolCalls),
logprobs: None,
};
return create_choice_stream(
choice_index,
Some(Role::Assistant),
normal_text.as_deref().unwrap_or(""),
Some(tool_call_chunks),
Some(FinishReason::ToolCalls),
None,
);
}
// No tool calls found or parsing failed, return content choice
#[allow(deprecated)]
ChatChoiceStream {
index: choice_index,
delta: ChatCompletionStreamResponseDelta {
role: Some(Role::Assistant),
content: Some(accumulated_content.to_string()),
tool_calls: None,
function_call: None,
refusal: None,
reasoning_content: None,
},
finish_reason: None,
logprobs: base_choice.logprobs.clone(),
}
create_choice_stream(
choice_index,
Some(Role::Assistant),
accumulated_content,
None,
None,
base_choice.logprobs.clone(),
)
}
/// Check if accumulated content contains complete tool calls that can be parsed
......@@ -750,63 +734,6 @@ impl JailedStream {
}
false
}
/// Compute the byte offset at which accumulated content should be split
/// between the tool call and any text that trails it.
///
/// The offset depends on `parser`:
/// * `"hermes"` - just past the first `</tool_call>`.
/// * `"nemotron_deci"` - just past the first `</TOOLCALL>`.
/// * `"mistral"` - just past the first `[/TOOL_CALLS]`, otherwise just past
///   the last `]`.
/// * `"phi4"` - just past the first `>` found at or after the last
///   `<|tool_call|>`.
/// * anything else (including `"llama3_json"`, which has no explicit end
///   marker) - the full length of `content`.
///
/// Whenever the relevant marker is absent the full length of `content` is
/// returned, so nothing is treated as trailing.
fn find_tool_call_end_position(&self, content: &str, parser: &str) -> usize {
    // Offset just past the first occurrence of `marker`, when it is present.
    let past_first = |marker: &str| content.find(marker).map(|start| start + marker.len());

    let split_at = match parser {
        "hermes" => past_first("</tool_call>"),
        "nemotron_deci" => past_first("</TOOLCALL>"),
        "mistral" => {
            past_first("[/TOOL_CALLS]").or_else(|| content.rfind(']').map(|last| last + 1))
        }
        "phi4" => content.rfind("<|tool_call|>").and_then(|marker_start| {
            // NOTE(review): the search begins at the marker itself, so the
            // first `>` found is the marker's own closing character.
            content[marker_start..]
                .find('>')
                .map(|offset| marker_start + offset + 1)
        }),
        // "llama3_json" has no end marker; unknown parsers are handled the same way.
        _ => None,
    };

    split_at.unwrap_or(content.len())
}
}
/// Builder for configuring a JailedStream
......
......@@ -3,7 +3,20 @@
pub mod harmony_parser;
pub use super::config::JsonParserConfig;
pub use super::{config, response};
pub use harmony_parser::{
detect_tool_call_start_harmony, parse_tool_calls_harmony, parse_tool_calls_harmony_complete,
};
/// Return the byte offset just past the last Harmony end-of-call token in
/// `chunk`.
///
/// The token searched for is the first entry of `config.tool_call_end_tokens`;
/// when that list is empty, `"<|call|>"` is used instead. If the token does
/// not occur in `chunk`, the full length of `chunk` is returned.
pub fn find_tool_call_end_position_harmony(chunk: &str, config: &JsonParserConfig) -> usize {
    const FALLBACK_END_TOKEN: &str = "<|call|>";

    let terminator: &str = match config.tool_call_end_tokens.first() {
        Some(configured) => configured,
        None => FALLBACK_END_TOKEN,
    };

    // `rfind` picks the last occurrence, so the offset covers every call in the chunk.
    chunk
        .rfind(terminator)
        .map_or(chunk.len(), |start| start + terminator.len())
}
......@@ -41,3 +41,31 @@ pub fn detect_tool_call_start_json(chunk: &str, config: &JsonParserConfig) -> bo
JsonParserType::DeepseekV31 => detect_tool_call_start_deepseek_v3_1(chunk, config),
}
}
/// Return the byte offset in `chunk` at which a JSON-family tool call ends.
///
/// * `chunk` - the accumulated text to inspect.
/// * `parser` - the parser name selecting the end-marker strategy.
/// * `config` - parser configuration; only `tool_call_end_tokens` is read.
///
/// Strategy per parser:
/// * `"hermes"` / `"nemotron_deci"` - just past the first occurrence of the
///   first configured end token.
/// * `"mistral"` / `"phi4"` - just past the last `]` in `chunk`.
/// * any other parser - no end marker is assumed.
///
/// Whenever no usable end marker is found (no configured token, an empty
/// token, or no match), the full length of `chunk` is returned so that
/// nothing is treated as trailing content.
pub fn find_tool_call_end_position_json(
    chunk: &str,
    parser: &str,
    config: &JsonParserConfig,
) -> usize {
    let end_position = match parser {
        "hermes" | "nemotron_deci" => config
            .tool_call_end_tokens
            .first()
            // An empty token matches at offset 0, which would report the call
            // as ending before it starts; treat it as "no end marker".
            .filter(|token| !token.is_empty())
            .and_then(|token| {
                chunk
                    .find(token.as_str())
                    .map(|start| start + token.len())
            }),
        "mistral" | "phi4" => chunk.rfind(']').map(|last| last + 1),
        _ => None,
    };

    end_position.unwrap_or(chunk.len())
}
......@@ -13,7 +13,10 @@ pub mod tools;
pub use config::{JsonParserConfig, ToolCallConfig, ToolCallParserType};
pub use harmony::{parse_tool_calls_harmony, parse_tool_calls_harmony_complete};
pub use json::try_tool_call_parse_json;
pub use parsers::{detect_and_parse_tool_call, detect_tool_call_start, try_tool_call_parse};
pub use parsers::{
detect_and_parse_tool_call, detect_tool_call_start, find_tool_call_end_position,
try_tool_call_parse,
};
pub use pythonic::try_tool_call_parse_pythonic;
pub use response::{CalledFunction, ToolCallResponse, ToolCallType};
pub use tools::{try_tool_call_parse_aggregate, try_tool_call_parse_stream};
......@@ -2,9 +2,17 @@
// SPDX-License-Identifier: Apache-2.0
use super::config::{ToolCallConfig, ToolCallParserType};
use super::harmony::{detect_tool_call_start_harmony, parse_tool_calls_harmony_complete};
use super::json::{detect_tool_call_start_json, try_tool_call_parse_json};
use super::pythonic::{detect_tool_call_start_pythonic, try_tool_call_parse_pythonic};
use super::harmony::{
detect_tool_call_start_harmony, find_tool_call_end_position_harmony,
parse_tool_calls_harmony_complete,
};
use super::json::{
detect_tool_call_start_json, find_tool_call_end_position_json, try_tool_call_parse_json,
};
use super::pythonic::{
detect_tool_call_start_pythonic, find_tool_call_end_position_pythonic,
try_tool_call_parse_pythonic,
};
use super::response::ToolCallResponse;
use std::collections::HashMap;
use std::sync::OnceLock;
......@@ -116,6 +124,41 @@ pub fn detect_tool_call_start(chunk: &str, parser_str: Option<&str>) -> anyhow::
}
}
/// Return the byte offset in `chunk` where the tool call ends, dispatching on
/// the format of the named parser.
///
/// * `chunk` - the text to inspect.
/// * `parser_str` - parser name; `None` or an empty string selects the
///   `"default"` entry of the parser map.
///
/// JSON-format parsers are delegated to `find_tool_call_end_position_json`
/// (with `"default"` treated as `"nemotron_deci"`), Harmony to
/// `find_tool_call_end_position_harmony` and Pythonic to
/// `find_tool_call_end_position_pythonic`. Formats without an implementation
/// here (Typescript, Xml) and parser names missing from the map yield the
/// full length of `chunk`.
pub fn find_tool_call_end_position(chunk: &str, parser_str: Option<&str>) -> usize {
    let parser_map = get_tool_parser_map();
    let parser_key = parser_str
        .filter(|name| !name.is_empty())
        .unwrap_or("default");

    let config = match parser_map.get(parser_key) {
        Some(found) => found,
        // Unknown parser: report the whole chunk as the tool call.
        None => return chunk.len(),
    };

    match config.format {
        ToolCallParserType::Json => {
            // The "default" entry is handled with the "nemotron_deci" strategy;
            // every other key is passed through as-is.
            let effective_parser = match parser_key {
                "default" => "nemotron_deci",
                named => named,
            };
            find_tool_call_end_position_json(chunk, effective_parser, &config.json)
        }
        ToolCallParserType::Harmony => find_tool_call_end_position_harmony(chunk, &config.json),
        ToolCallParserType::Pythonic => find_tool_call_end_position_pythonic(chunk),
        // Neither of these parsers is implemented.
        ToolCallParserType::Typescript | ToolCallParserType::Xml => chunk.len(),
    }
}
// Tests
// cargo test postprocessor::tool_calling::parsers
#[cfg(test)]
......
......@@ -5,3 +5,7 @@ pub mod pythonic_parser;
pub use super::{config, response};
pub use pythonic_parser::{detect_tool_call_start_pythonic, try_tool_call_parse_pythonic};
/// Return the end position of a Pythonic tool call within `chunk`.
///
/// No end marker is searched for: the result is always the full byte length
/// of `chunk`, so no part of it is treated as trailing content.
pub fn find_tool_call_end_position_pythonic(chunk: &str) -> usize {
    // The whole chunk counts as the tool call.
    chunk.len()
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment