Commit 2a61e29e (unverified), authored by Elyas Mehtabuddin, committed by GitHub
Browse files

feat: Tool calling support stream=True (#2932)


Signed-off-by: ayushag <ayushag@nvidia.com>
Signed-off-by: Biswa Panda <biswa.panda@gmail.com>
Co-authored-by: ayushag <ayushag@nvidia.com>
Co-authored-by: Biswa Panda <biswa.panda@gmail.com>
parent 55659eae
...@@ -598,12 +598,6 @@ pub fn validate_chat_completion_unsupported_fields( ...@@ -598,12 +598,6 @@ pub fn validate_chat_completion_unsupported_fields(
)); ));
} }
if inner.stream == Some(true) && inner.tools.is_some() {
return Err(ErrorMessage::not_implemented_error(
"`stream: true` is not supported when `tools` are provided.",
));
}
if inner.function_call.is_some() { if inner.function_call.is_some() {
return Err(ErrorMessage::not_implemented_error( return Err(ErrorMessage::not_implemented_error(
"`function_call` is deprecated. Please migrate to use `tool_choice` instead.", "`function_call` is deprecated. Please migrate to use `tool_choice` instead.",
......
...@@ -22,6 +22,10 @@ use rayon::iter::{IntoParallelRefIterator, ParallelIterator}; ...@@ -22,6 +22,10 @@ use rayon::iter::{IntoParallelRefIterator, ParallelIterator};
use std::{collections::HashMap, sync::Arc}; use std::{collections::HashMap, sync::Arc};
use tracing; use tracing;
use dynamo_parsers::tool_calling::{
parsers::detect_tool_call_start, try_tool_call_parse_aggregate,
};
use crate::model_card::{ModelDeploymentCard, ModelInfo}; use crate::model_card::{ModelDeploymentCard, ModelInfo};
use crate::preprocessor::prompt::OAIChatLikeRequest; use crate::preprocessor::prompt::OAIChatLikeRequest;
use crate::protocols::common::preprocessor::PreprocessedRequestBuilder; use crate::protocols::common::preprocessor::PreprocessedRequestBuilder;
...@@ -55,6 +59,7 @@ use crate::protocols::common::llm_backend::EmbeddingsEngineOutput; ...@@ -55,6 +59,7 @@ use crate::protocols::common::llm_backend::EmbeddingsEngineOutput;
pub const ANNOTATION_FORMATTED_PROMPT: &str = "formatted_prompt"; pub const ANNOTATION_FORMATTED_PROMPT: &str = "formatted_prompt";
pub const ANNOTATION_TOKEN_IDS: &str = "token_ids"; pub const ANNOTATION_TOKEN_IDS: &str = "token_ids";
pub const ANNOTATION_LLM_METRICS: &str = "llm_metrics"; pub const ANNOTATION_LLM_METRICS: &str = "llm_metrics";
pub const ANNOTATION_POSSIBLE_TOOL_CALL: &str = "possible_tool_call";
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] #[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct LLMMetricAnnotation { pub struct LLMMetricAnnotation {
pub input_tokens: usize, pub input_tokens: usize,
...@@ -62,6 +67,16 @@ pub struct LLMMetricAnnotation { ...@@ -62,6 +67,16 @@ pub struct LLMMetricAnnotation {
pub chunk_tokens: usize, pub chunk_tokens: usize,
} }
/// State threaded through the `stream::unfold` that implements the tool-calling jail.
#[derive(Debug)]
pub struct JailState {
    // Upstream response stream the jail pulls chunks from.
    stream: ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>,
    // True once a tool-call start has been detected; reset when the upstream stream ends.
    is_jailed: bool,
    // Parser name handed to `detect_tool_call_start` and `try_tool_call_parse_aggregate`.
    tool_call_parser: Option<String>,
    // Text withheld from the client, keyed by choice index.
    accumulated_content: HashMap<u32, String>, // choice index -> accumulated content
    // Response chunk kept as the template for the final synthesized chunk.
    last_response_metadata: Option<NvCreateChatCompletionStreamResponse>, // for response structure
    // Set once the jail has produced its last item; later polls yield `None` immediately.
    finished: bool,
}
impl LLMMetricAnnotation { impl LLMMetricAnnotation {
/// Convert this metrics struct to an Annotated event /// Convert this metrics struct to an Annotated event
pub fn to_annotation<T>(&self) -> Result<Annotated<T>, serde_json::Error> { pub fn to_annotation<T>(&self) -> Result<Annotated<T>, serde_json::Error> {
...@@ -90,11 +105,47 @@ impl LLMMetricAnnotation { ...@@ -90,11 +105,47 @@ impl LLMMetricAnnotation {
} }
} }
/// Payload of the `possible_tool_call` annotation attached to chunks whose content is being
/// withheld while the stream is jailed.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct PossibleToolCallAnnotation {
    // Token count attributed to the withheld chunk (the jail always reports 1 per chunk).
    pub possible_tokens: usize,
    // The chunk content that was withheld from the client.
    pub possible_content: String,
    // Name of the tool-call parser configured for the stream, if any.
    pub parser_used: Option<String>,
}
impl PossibleToolCallAnnotation {
    /// Serializes `self` into an [`Annotated`] event named [`ANNOTATION_POSSIBLE_TOOL_CALL`].
    ///
    /// # Errors
    /// Returns the `serde_json::Error` raised while encoding the payload.
    pub fn to_annotation<T>(&self) -> Result<Annotated<T>, serde_json::Error> {
        Annotated::from_annotation(ANNOTATION_POSSIBLE_TOOL_CALL, self)
    }

    /// Decodes a [`PossibleToolCallAnnotation`] from an [`Annotated`] event.
    ///
    /// Returns `Ok(None)` when the event is absent or carries a different name.
    ///
    /// # Errors
    /// Fails when the event matches but the comment block is missing, does not hold exactly
    /// one entry, or that entry is not valid JSON for this type.
    pub fn from_annotation<T>(
        annotation: &Annotated<T>,
    ) -> Result<Option<PossibleToolCallAnnotation>, Box<dyn std::error::Error>> {
        // Anything other than our own event name is simply "not ours".
        if annotation.event.as_deref() != Some(ANNOTATION_POSSIBLE_TOOL_CALL) {
            return Ok(None);
        }

        let Some(comments) = annotation.comment.as_ref() else {
            return Err("missing comments block".into());
        };

        // The payload is carried as the single comment of the event.
        let payload = match comments.first() {
            Some(only) if comments.len() == 1 => only,
            _ => return Err("malformed comments block - expected exactly 1 comment".into()),
        };

        let possible_info: PossibleToolCallAnnotation = serde_json::from_str(payload)?;
        Ok(Some(possible_info))
    }
}
pub struct OpenAIPreprocessor { pub struct OpenAIPreprocessor {
mdcsum: String, mdcsum: String,
formatter: Arc<dyn OAIPromptFormatter>, formatter: Arc<dyn OAIPromptFormatter>,
tokenizer: Arc<dyn Tokenizer>, tokenizer: Arc<dyn Tokenizer>,
model_info: Arc<dyn ModelInfo>, model_info: Arc<dyn ModelInfo>,
tool_call_parser: Option<String>,
} }
impl OpenAIPreprocessor { impl OpenAIPreprocessor {
...@@ -119,12 +170,14 @@ impl OpenAIPreprocessor { ...@@ -119,12 +170,14 @@ impl OpenAIPreprocessor {
); );
}; };
let model_info = model_info.get_model_info()?; let model_info = model_info.get_model_info()?;
let tool_call_parser = mdc.runtime_config.tool_call_parser.clone();
Ok(Arc::new(Self { Ok(Arc::new(Self {
formatter, formatter,
tokenizer, tokenizer,
model_info, model_info,
mdcsum, mdcsum,
tool_call_parser,
})) }))
} }
/// Encode a string to it's tokens /// Encode a string to it's tokens
...@@ -420,6 +473,7 @@ impl OpenAIPreprocessor { ...@@ -420,6 +473,7 @@ impl OpenAIPreprocessor {
context: Arc<dyn AsyncEngineContext>, context: Arc<dyn AsyncEngineContext>,
cancelled: bool, cancelled: bool,
cumulative_output_tokens: usize, cumulative_output_tokens: usize,
finished: bool, // Add this flag to track if stream is finished
} }
let state = State { let state = State {
...@@ -428,17 +482,24 @@ impl OpenAIPreprocessor { ...@@ -428,17 +482,24 @@ impl OpenAIPreprocessor {
context: context.clone(), context: context.clone(),
cancelled: false, cancelled: false,
cumulative_output_tokens: 0, cumulative_output_tokens: 0,
finished: false, // Initialize as not finished
}; };
// transform the common response stream into a chat response stream // transform the common response stream into a chat response stream
let stream = stream::unfold(state, |mut inner| { let stream = stream::unfold(state, |mut inner| {
async move { async move {
// If already finished, return None immediately
if inner.finished {
return None;
}
if let Some(response) = inner.response_stream.next().await { if let Some(response) = inner.response_stream.next().await {
if inner.cancelled { if inner.cancelled {
tracing::debug!( tracing::debug!(
request_id = inner.context.id(), request_id = inner.context.id(),
"Cancellation issued last message; closing stream" "Cancellation issued last message; closing stream"
); );
inner.finished = true; // Mark as finished
return None; return None;
} }
...@@ -502,7 +563,7 @@ impl OpenAIPreprocessor { ...@@ -502,7 +563,7 @@ impl OpenAIPreprocessor {
} else { } else {
// stream closed with out graceful closure // stream closed with out graceful closure
// we did not detect an is_finished/completed message // we did not detect an is_finished/completed message
// Ok(None) inner.finished = true; // Mark as finished
None None
} }
} }
...@@ -550,6 +611,255 @@ impl OpenAIPreprocessor { ...@@ -550,6 +611,255 @@ impl OpenAIPreprocessor {
ResponseStream::new(Box::pin(transformed_stream), context) ResponseStream::new(Box::pin(transformed_stream), context)
} }
/// Wraps `stream` in the tool-calling jail, configured with the parser name stored on this
/// preprocessor (`self.tool_call_parser`).
pub fn apply_tool_calling_jail_with_parser(
    &self,
    stream: ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>,
) -> ManyOut<Annotated<NvCreateChatCompletionStreamResponse>> {
    // The jail owns its parser name, so hand it an independent copy.
    let parser = self.tool_call_parser.clone();
    apply_tool_calling_jail_internal(stream, parser)
}
}
/// Apply tool calling jail to the stream - stops/jails the stream under certain conditions
/// When jailed, the stream will be unjailed when the input stream ends
pub fn apply_tool_calling_jail_internal(
stream: ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>,
tool_call_parser: Option<String>,
) -> ManyOut<Annotated<NvCreateChatCompletionStreamResponse>> {
let context = stream.context();
let jail_state = JailState {
stream,
is_jailed: false,
tool_call_parser,
accumulated_content: HashMap::new(),
last_response_metadata: None,
finished: false,
};
// Transform the stream using unfold to maintain state
// Input: ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>
// Returns None if the stream is finished
// Returns Some((Annotated<NvCreateChatCompletionStreamResponse>, JailState)) if the stream is not finished
// End output: ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>
let jailed_stream = stream::unfold(jail_state, |mut state| async move {
// If already finished, return None immediately
if state.finished {
return None;
}
if let Some(response) = state.stream.next().await {
// Check if we should jail the stream
if !state.is_jailed {
// Handle the case where response.data is Option<T>
if let Some(ref chat_response) = response.data {
// Store metadata for potential tool call parsing later
state.last_response_metadata = Some(chat_response.clone());
// Extract text content from the response
if let Some(choice) = chat_response.choices.first()
&& let Some(ref content) = choice.delta.content
{
// Check for tool call start
match detect_tool_call_start(content, state.tool_call_parser.as_deref()) {
Ok(should_jail) => {
if should_jail {
tracing::debug!("Tool call detected, jailing stream");
state.is_jailed = true;
// Start accumulating content for this choice
state
.accumulated_content
.insert(choice.index, content.clone());
// Create possible tool call annotation with token information
let possible_annotation = PossibleToolCallAnnotation {
possible_tokens: 1, // This chunk contains tokens being processed
possible_content: content.clone(),
parser_used: state.tool_call_parser.clone(),
};
// Create annotated response instead of empty response
let mut annotated_response = response.clone();
if let Ok(possible_annotated) =
possible_annotation
.to_annotation::<NvCreateChatCompletionStreamResponse>()
{
// Set annotation event and comment
annotated_response.event = possible_annotated.event;
annotated_response.comment = possible_annotated.comment;
}
// Modify the response to have empty content but keep metadata
annotated_response =
annotated_response.map_data(|mut chat_response| {
// Clear the content but keep choice structure for ITL measurement
for choice in &mut chat_response.choices {
choice.delta.content = Some(String::new()); // Empty content
}
Ok(chat_response)
});
return Some((annotated_response, state));
}
}
Err(e) => {
tracing::warn!("Error detecting tool call start: {}", e);
}
}
}
}
} else if state.is_jailed {
// If already jailed, continue to jail but with annotations and accumulate content
if let Some(ref chat_response) = response.data {
// Extract content for annotation and accumulation
for choice in &chat_response.choices {
if let Some(ref content) = choice.delta.content
&& !content.is_empty()
{
// Accumulate content for this choice
state
.accumulated_content
.entry(choice.index)
.or_default()
.push_str(content);
// Create possible tool call annotation
let possible_annotation = PossibleToolCallAnnotation {
possible_tokens: 1,
possible_content: content.clone(),
parser_used: state.tool_call_parser.clone(),
};
// Create annotated response
let mut annotated_response = response.clone();
if let Ok(possible_annotated) = possible_annotation
.to_annotation::<NvCreateChatCompletionStreamResponse>(
) {
annotated_response.event = possible_annotated.event;
annotated_response.comment = possible_annotated.comment;
}
// Clear content but keep structure
annotated_response =
annotated_response.map_data(|mut chat_response| {
for choice in &mut chat_response.choices {
choice.delta.content = Some(String::new());
}
Ok(chat_response)
});
return Some((annotated_response, state));
}
}
}
}
// If not jailed or jailing condition not met, return the response as-is
Some((response, state))
} else {
// Stream ended - if we were jailed, we should unjail now and parse tool calls
if state.is_jailed {
tracing::debug!("Stream ended, unjailing and parsing accumulated content");
state.is_jailed = false;
// Parse accumulated content for tool calls
if !state.accumulated_content.is_empty()
&& let Some(base_response) = state.last_response_metadata.take()
{
// Try to parse tool calls from accumulated content for each choice
let mut final_response = base_response.clone();
for (choice_index, accumulated_text) in &state.accumulated_content {
if let Ok((tool_calls, normal_text)) = try_tool_call_parse_aggregate(
accumulated_text,
state.tool_call_parser.as_deref(),
) {
// Found tool calls, create a final response with them
tracing::debug!(
"Parsed {} tool calls from accumulated content",
tool_calls.len()
);
for tool_call in &tool_calls {
tracing::debug!(
tool_call_id = %tool_call.id,
function_name = %tool_call.function.name,
arguments = %tool_call.function.arguments,
"Parsed structured tool call from accumulated content in jail"
);
}
// Convert ChatCompletionMessageToolCall to ChatCompletionMessageToolCallChunk for streaming
let tool_call_chunks: Vec<
dynamo_async_openai::types::ChatCompletionMessageToolCallChunk,
> = tool_calls
.into_iter()
.enumerate()
.map(|(idx, tool_call)| {
dynamo_async_openai::types::ChatCompletionMessageToolCallChunk {
index: idx as u32,
id: Some(tool_call.id),
r#type: Some(tool_call.r#type),
function: Some(
dynamo_async_openai::types::FunctionCallStream {
name: Some(tool_call.function.name),
arguments: Some(tool_call.function.arguments),
},
),
}
})
.collect();
// Create a choice with tool calls
#[allow(deprecated)]
let final_choice = dynamo_async_openai::types::ChatChoiceStream {
index: *choice_index,
delta:
dynamo_async_openai::types::ChatCompletionStreamResponseDelta {
role: Some(dynamo_async_openai::types::Role::Assistant),
content: normal_text.filter(|t| !t.is_empty()),
tool_calls: Some(tool_call_chunks.clone()),
function_call: None,
refusal: None,
reasoning_content: None,
},
finish_reason: Some(
dynamo_async_openai::types::FinishReason::ToolCalls,
),
logprobs: None,
};
// Update the response choices
final_response.choices = vec![final_choice];
// Create final annotated response
let final_annotated = Annotated {
data: Some(final_response),
id: None,
event: None,
comment: None,
};
state.finished = true; // Mark as finished before returning
return Some((final_annotated, state));
}
}
}
}
state.finished = true; // Mark as finished
None
}
});
// Jailed Stream contains empty content chunks with annotation event "possible_tool_call" whenever the stream is jailed
// This is a bad UX for the user, as they have to see a lot of empty content chunks
// Filter out the empty content chunks with annotation event "possible_tool_call"
let filtered_stream = jailed_stream.filter(|annotated| {
let keep = annotated.event.as_deref() != Some(ANNOTATION_POSSIBLE_TOOL_CALL);
async move { keep }
});
ResponseStream::new(Box::pin(filtered_stream), context)
} }
// for pals, we do not want to add the generation prompt to the formatted prompt // for pals, we do not want to add the generation prompt to the formatted prompt
...@@ -601,8 +911,9 @@ impl ...@@ -601,8 +911,9 @@ impl
// transform the postprocessor stream // transform the postprocessor stream
let stream = Self::transform_postprocessor_stream(response_stream, response_generator); let stream = Self::transform_postprocessor_stream(response_stream, response_generator);
let context = stream.context();
let stream = self.apply_tool_calling_jail_with_parser(stream);
let context = stream.context();
// prepend the annotations to the response stream // prepend the annotations to the response stream
let stream = annotations_stream.chain(stream); let stream = annotations_stream.chain(stream);
......
...@@ -12,7 +12,6 @@ use crate::protocols::{ ...@@ -12,7 +12,6 @@ use crate::protocols::{
openai::ParsingOptions, openai::ParsingOptions,
}; };
use dynamo_parsers::tool_calling::try_tool_call_parse_aggregate;
use dynamo_runtime::engine::DataStream; use dynamo_runtime::engine::DataStream;
/// Aggregates a stream of [`NvCreateChatCompletionStreamResponse`]s into a single /// Aggregates a stream of [`NvCreateChatCompletionStreamResponse`]s into a single
...@@ -38,6 +37,7 @@ pub struct DeltaAggregator { ...@@ -38,6 +37,7 @@ pub struct DeltaAggregator {
} }
/// Represents the accumulated state of a single chat choice during streaming aggregation. /// Represents the accumulated state of a single chat choice during streaming aggregation.
#[derive(Debug)]
struct DeltaChoice { struct DeltaChoice {
/// The index of the choice in the completion. /// The index of the choice in the completion.
index: u32, index: u32,
...@@ -63,6 +63,28 @@ impl Default for DeltaAggregator { ...@@ -63,6 +63,28 @@ impl Default for DeltaAggregator {
} }
} }
/// Converts a streamed tool-call chunk into a complete `ChatCompletionMessageToolCall`.
///
/// Returns `None` unless the chunk carries all of `id`, `type`, `function.name` and
/// `function.arguments`.
fn convert_tool_chunk_to_message_tool_call(
    chunk: &dynamo_async_openai::types::ChatCompletionMessageToolCallChunk,
) -> Option<dynamo_async_openai::types::ChatCompletionMessageToolCall> {
    // Resolve every required piece first so nothing is cloned for an incomplete chunk.
    let id = chunk.id.as_ref()?;
    let kind = chunk.r#type.as_ref()?;
    let function = chunk.function.as_ref()?;
    let name = function.name.as_ref()?;
    let arguments = function.arguments.as_ref()?;

    Some(dynamo_async_openai::types::ChatCompletionMessageToolCall {
        id: id.clone(),
        r#type: kind.clone(),
        function: dynamo_async_openai::types::FunctionCall {
            name: name.clone(),
            arguments: arguments.clone(),
        },
    })
}
impl DeltaAggregator { impl DeltaAggregator {
/// Creates a new, empty [`DeltaAggregator`] instance. /// Creates a new, empty [`DeltaAggregator`] instance.
pub fn new() -> Self { pub fn new() -> Self {
...@@ -89,7 +111,7 @@ impl DeltaAggregator { ...@@ -89,7 +111,7 @@ impl DeltaAggregator {
/// * `Err(String)` if an error occurs during processing. /// * `Err(String)` if an error occurs during processing.
pub async fn apply( pub async fn apply(
stream: impl Stream<Item = Annotated<NvCreateChatCompletionStreamResponse>>, stream: impl Stream<Item = Annotated<NvCreateChatCompletionStreamResponse>>,
parsing_options: ParsingOptions, _parsing_options: ParsingOptions,
) -> Result<NvCreateChatCompletionResponse, String> { ) -> Result<NvCreateChatCompletionResponse, String> {
let aggregator = stream let aggregator = stream
.fold(DeltaAggregator::new(), |mut aggregator, delta| async move { .fold(DeltaAggregator::new(), |mut aggregator, delta| async move {
...@@ -133,10 +155,9 @@ impl DeltaAggregator { ...@@ -133,10 +155,9 @@ impl DeltaAggregator {
tool_calls: None, tool_calls: None,
reasoning_content: None, reasoning_content: None,
}); });
// Append content if available. // Append content if available.
if let Some(content) = &choice.delta.content { if let Some(content) = &choice.delta.content {
state_choice.text.push_str(content); state_choice.text.push_str(content.trim_end());
} }
if let Some(reasoning_content) = &choice.delta.reasoning_content { if let Some(reasoning_content) = &choice.delta.reasoning_content {
...@@ -146,6 +167,27 @@ impl DeltaAggregator { ...@@ -146,6 +167,27 @@ impl DeltaAggregator {
.push_str(reasoning_content); .push_str(reasoning_content);
} }
// Since one tool call is one chunk, we don't need to aggregate them
// We just need to convert the ChatCompletionMessageToolCallChunk to ChatCompletionMessageToolCall and append to the state_choice.tool_calls
if let Some(tool_calls) = &choice.delta.tool_calls
&& !tool_calls.is_empty()
{
// Convert ChatCompletionMessageToolCallChunk to ChatCompletionMessageToolCall
let converted_tool_calls: Vec<
dynamo_async_openai::types::ChatCompletionMessageToolCall,
> = tool_calls
.iter()
.filter_map(convert_tool_chunk_to_message_tool_call)
.collect();
// Initialize and push the converted tool calls to state_choice.tool_calls
if let Some(existing_tool_calls) = &mut state_choice.tool_calls {
existing_tool_calls.extend(converted_tool_calls);
} else {
state_choice.tool_calls = Some(converted_tool_calls);
}
}
// Update finish reason if provided. // Update finish reason if provided.
if let Some(finish_reason) = choice.finish_reason { if let Some(finish_reason) = choice.finish_reason {
state_choice.finish_reason = Some(finish_reason); state_choice.finish_reason = Some(finish_reason);
...@@ -179,39 +221,8 @@ impl DeltaAggregator { ...@@ -179,39 +221,8 @@ impl DeltaAggregator {
.await; .await;
// Return early if an error was encountered. // Return early if an error was encountered.
let mut aggregator = if let Some(error) = aggregator.error { if let Some(error) = aggregator.error {
return Err(error); return Err(error);
} else {
aggregator
};
// After aggregation, inspect each choice's text for tool call syntax
for choice in aggregator.choices.values_mut() {
if choice.tool_calls.is_none()
&& let Ok((tool_calls, normal_text)) = try_tool_call_parse_aggregate(
&choice.text,
parsing_options.tool_call_parser.as_deref(),
)
{
if tool_calls.is_empty() {
continue;
}
for tool_call in &tool_calls {
tracing::debug!(
tool_call_id = %tool_call.id,
function_name = %tool_call.function.name,
arguments = %tool_call.function.arguments,
"Parsed structured tool call from aggregated content"
);
}
choice.tool_calls = Some(tool_calls);
choice.text.clear();
// If normal text is not empty, update the choice text
if let Some(normal_text) = normal_text.filter(|text| !text.is_empty()) {
choice.text = normal_text;
}
choice.finish_reason = Some(dynamo_async_openai::types::FinishReason::ToolCalls);
}
} }
// Extract aggregated choices and sort them by index. // Extract aggregated choices and sort them by index.
...@@ -328,12 +339,40 @@ mod tests { ...@@ -328,12 +339,40 @@ mod tests {
role: Option<dynamo_async_openai::types::Role>, role: Option<dynamo_async_openai::types::Role>,
finish_reason: Option<dynamo_async_openai::types::FinishReason>, finish_reason: Option<dynamo_async_openai::types::FinishReason>,
logprob: Option<f32>, logprob: Option<f32>,
tool_calls: Option<&str>,
) -> Annotated<NvCreateChatCompletionStreamResponse> { ) -> Annotated<NvCreateChatCompletionStreamResponse> {
// ALLOW: function_call is deprecated // ALLOW: function_call is deprecated
let tool_calls: Option<serde_json::Value> =
tool_calls.map(|tool_calls| serde_json::from_str(tool_calls).unwrap());
let tool_call_chunks = if let Some(tool_calls) = tool_calls {
vec![
dynamo_async_openai::types::ChatCompletionMessageToolCallChunk {
index: 0,
id: Some("test_id".to_string()),
r#type: Some(dynamo_async_openai::types::ChatCompletionToolType::Function),
function: Some(dynamo_async_openai::types::FunctionCallStream {
name: tool_calls["name"].as_str().map(|s| s.to_string()),
arguments: Some(serde_json::to_string(&tool_calls["arguments"]).unwrap()),
}),
},
]
} else {
vec![
dynamo_async_openai::types::ChatCompletionMessageToolCallChunk {
index: 0,
id: None,
r#type: None,
function: None,
},
]
};
let delta = dynamo_async_openai::types::ChatCompletionStreamResponseDelta { let delta = dynamo_async_openai::types::ChatCompletionStreamResponseDelta {
content: Some(text.to_string()), content: Some(text.to_string()),
function_call: None, function_call: None,
tool_calls: None, tool_calls: Some(tool_call_chunks),
role, role,
refusal: None, refusal: None,
reasoning_content: None, reasoning_content: None,
...@@ -407,6 +446,7 @@ mod tests { ...@@ -407,6 +446,7 @@ mod tests {
Some(dynamo_async_openai::types::Role::User), Some(dynamo_async_openai::types::Role::User),
None, None,
None, None,
None,
); );
// Create a stream // Create a stream
...@@ -445,6 +485,7 @@ mod tests { ...@@ -445,6 +485,7 @@ mod tests {
Some(dynamo_async_openai::types::Role::User), Some(dynamo_async_openai::types::Role::User),
None, None,
Some(-0.1), Some(-0.1),
None,
); );
let annotated_delta2 = create_test_delta( let annotated_delta2 = create_test_delta(
0, 0,
...@@ -452,6 +493,7 @@ mod tests { ...@@ -452,6 +493,7 @@ mod tests {
None, None,
Some(dynamo_async_openai::types::FinishReason::Stop), Some(dynamo_async_openai::types::FinishReason::Stop),
Some(-0.2), Some(-0.2),
None,
); );
// Create a stream // Create a stream
...@@ -591,71 +633,11 @@ mod tests { ...@@ -591,71 +633,11 @@ mod tests {
// Use create_test_delta to generate the annotated delta, then extract the inner delta for the test // Use create_test_delta to generate the annotated delta, then extract the inner delta for the test
let annotated_delta = create_test_delta( let annotated_delta = create_test_delta(
0, 0,
tool_call_json, "Hey Dude ! What's the weather in San Francisco in Fahrenheit?",
Some(dynamo_async_openai::types::Role::Assistant),
Some(dynamo_async_openai::types::FinishReason::ToolCalls),
None,
);
let data = annotated_delta.data.unwrap();
// Wrap it in Annotated and create a stream
let annotated_delta = Annotated {
data: Some(data),
id: Some("test_id".to_string()),
event: None,
comment: None,
};
let stream = Box::pin(stream::iter(vec![annotated_delta]));
// Call DeltaAggregator::apply
let result = DeltaAggregator::apply(stream, ParsingOptions::default()).await;
// Check the result
assert!(result.is_ok());
let response = result.unwrap();
// There should be one choice
assert_eq!(response.choices.len(), 1);
let choice = &response.choices[0];
// The tool_calls field should be present and parsed
assert!(choice.message.tool_calls.is_some());
let tool_calls = choice.message.tool_calls.as_ref().unwrap();
assert_eq!(tool_calls.len(), 1);
let tool_call = &tool_calls[0];
assert_eq!(tool_call.function.name, "get_weather");
// The arguments should be a JSON string containing the expected keys
let args: serde_json::Value = serde_json::from_str(&tool_call.function.arguments).unwrap();
assert_eq!(args["location"], "San Francisco, CA");
assert_eq!(args["unit"], "fahrenheit");
// The content should be cleared (None) after tool call parsing
assert!(choice.message.content.is_none());
// The finish_reason should be ToolCalls
assert_eq!(
choice.finish_reason,
Some(dynamo_async_openai::types::FinishReason::ToolCalls)
);
assert_eq!(
choice.message.role,
dynamo_async_openai::types::Role::Assistant
);
}
#[tokio::test]
async fn test_tool_calling_output_with_normal_text() {
// Simulate a delta with a tool call in the content
let tool_call_json = r#"Hey, I'm a normal text! {"name": "get_weather", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit"}}"#;
// Use create_test_delta to generate the annotated delta, then extract the inner delta for the test
let annotated_delta = create_test_delta(
0,
tool_call_json,
Some(dynamo_async_openai::types::Role::Assistant), Some(dynamo_async_openai::types::Role::Assistant),
Some(dynamo_async_openai::types::FinishReason::ToolCalls), Some(dynamo_async_openai::types::FinishReason::ToolCalls),
None, None,
Some(tool_call_json),
); );
let data = annotated_delta.data.unwrap(); let data = annotated_delta.data.unwrap();
...@@ -691,11 +673,9 @@ mod tests { ...@@ -691,11 +673,9 @@ mod tests {
assert_eq!(args["location"], "San Francisco, CA"); assert_eq!(args["location"], "San Francisco, CA");
assert_eq!(args["unit"], "fahrenheit"); assert_eq!(args["unit"], "fahrenheit");
// The content should be the normal text
assert!(choice.message.content.is_some());
assert_eq!( assert_eq!(
choice.message.content.as_ref().unwrap(), choice.message.content.as_ref().unwrap(),
"Hey, I'm a normal text!" "Hey Dude ! What's the weather in San Francisco in Fahrenheit?"
); );
// The finish_reason should be ToolCalls // The finish_reason should be ToolCalls
......
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use async_trait::async_trait;
use dynamo_async_openai::types::{
ChatChoiceStream, ChatCompletionStreamResponseDelta, FinishReason as OAIFinishReason, Role,
};
use dynamo_llm::preprocessor::{
ANNOTATION_POSSIBLE_TOOL_CALL, PossibleToolCallAnnotation, apply_tool_calling_jail_internal,
};
use dynamo_llm::protocols::openai::chat_completions::NvCreateChatCompletionStreamResponse;
use dynamo_parsers::tool_calling::parsers::detect_tool_call_start;
use dynamo_runtime::pipeline::ResponseStream;
use dynamo_runtime::protocols::annotated::Annotated;
use futures::stream::{self, StreamExt};
use std::sync::Arc;
// Test helper: wraps `content` in a one-choice streaming chunk (assistant role, no finish
// reason) with fixed id/model metadata and no annotation.
#[allow(deprecated)]
fn create_mock_response_chunk(
    content: String,
    index: u32,
) -> Annotated<NvCreateChatCompletionStreamResponse> {
    let delta = ChatCompletionStreamResponseDelta {
        role: Some(Role::Assistant),
        content: Some(content),
        tool_calls: None,
        function_call: None,
        refusal: None,
        reasoning_content: None,
    };

    let data = NvCreateChatCompletionStreamResponse {
        id: String::from("test-id"),
        choices: vec![ChatChoiceStream {
            index,
            delta,
            finish_reason: None,
            logprobs: None,
        }],
        created: 1234567890,
        model: String::from("test-model"),
        system_fingerprint: Some(String::from("test-fingerprint")),
        object: String::from("chat.completion.chunk"),
        usage: None,
        service_tier: None,
    };

    Annotated {
        data: Some(data),
        id: None,
        event: None,
        comment: None,
    }
}
// Test helper: builds the closing chunk of a stream — one choice with an empty delta and
// `finish_reason: Stop`, using the same fixed id/model metadata as the content chunks.
#[allow(deprecated)]
fn create_final_response_chunk(index: u32) -> Annotated<NvCreateChatCompletionStreamResponse> {
    let empty_delta = ChatCompletionStreamResponseDelta {
        role: None,
        content: None,
        tool_calls: None,
        function_call: None,
        refusal: None,
        reasoning_content: None,
    };

    let data = NvCreateChatCompletionStreamResponse {
        id: String::from("test-id"),
        choices: vec![ChatChoiceStream {
            index,
            delta: empty_delta,
            finish_reason: Some(OAIFinishReason::Stop),
            logprobs: None,
        }],
        created: 1234567890,
        model: String::from("test-model"),
        system_fingerprint: Some(String::from("test-fingerprint")),
        object: String::from("chat.completion.chunk"),
        usage: None,
        service_tier: None,
    };

    Annotated {
        data: Some(data),
        id: None,
        event: None,
        comment: None,
    }
}
// Minimal engine-context stand-in for tests: an id plus a single "stopped" flag.
#[derive(Debug)]
struct MockAsyncEngineContext {
    // Identifier handed out by `id()`.
    id: String,
    // Raised by stop/stop_generating/kill; never cleared.
    stopped: std::sync::atomic::AtomicBool,
}

impl MockAsyncEngineContext {
    // Creates a context with the given id whose stopped flag starts out false.
    fn new(id: String) -> Self {
        let stopped = std::sync::atomic::AtomicBool::default();
        Self { id, stopped }
    }
}
// Trait implementation for the mock: stop, stop_generating and kill all raise the same flag,
// and is_stopped / is_killed both read it. The async waiters and child linking do nothing.
#[async_trait]
impl dynamo_runtime::pipeline::AsyncEngineContext for MockAsyncEngineContext {
    fn id(&self) -> &str {
        self.id.as_str()
    }

    fn stop(&self) {
        use std::sync::atomic::Ordering::Relaxed;
        self.stopped.store(true, Relaxed);
    }

    fn stop_generating(&self) {
        use std::sync::atomic::Ordering::Relaxed;
        self.stopped.store(true, Relaxed);
    }

    fn kill(&self) {
        use std::sync::atomic::Ordering::Relaxed;
        self.stopped.store(true, Relaxed);
    }

    fn is_stopped(&self) -> bool {
        use std::sync::atomic::Ordering::Relaxed;
        self.stopped.load(Relaxed)
    }

    fn is_killed(&self) -> bool {
        use std::sync::atomic::Ordering::Relaxed;
        self.stopped.load(Relaxed)
    }

    async fn stopped(&self) {
        // Intentionally empty: tests never wait on the stop signal.
    }

    async fn killed(&self) {
        // Intentionally empty: tests never wait on the kill signal.
    }

    fn link_child(&self, _: Arc<dyn dynamo_runtime::pipeline::AsyncEngineContext>) {
        // Intentionally empty: the mock does not track child contexts.
    }
}
#[tokio::test]
async fn test_apply_tool_calling_jail_internal_with_tool_call_detection() {
    // A tool call split over four chunks; the first chunk is expected to trigger the jail.
    let mock_context = Arc::new(MockAsyncEngineContext::new("test-request-id".to_string()));
    let chunks: Vec<_> = [
        "<TOOLCALL>",
        "[{\"name\": \"get_weather\", ",
        "\"arguments\": {\"location\": \"San Francisco\"}}]",
        "</TOOLCALL>",
    ]
    .into_iter()
    .map(|piece| create_mock_response_chunk(piece.to_string(), 0))
    .collect();

    let response_stream =
        ResponseStream::new(Box::pin(stream::iter(chunks)), mock_context.clone());

    // Run the jail with the nemotron_deci parser and drain its output.
    let jailed_stream =
        apply_tool_calling_jail_internal(response_stream, Some("nemotron_deci".to_string()));
    let results: Vec<_> = jailed_stream.collect().await;

    // All four input chunks are withheld; only the final parsed tool-call chunk comes out.
    assert!(!results.is_empty(), "Should have some results");
    assert_eq!(results.len(), 1);

    let delta = &results[0].data.as_ref().unwrap().choices[0].delta;
    assert!(delta.tool_calls.is_some());

    let tools = delta.tool_calls.as_ref().unwrap();
    assert_eq!(tools.len(), 1);

    let function = tools[0].function.as_ref().unwrap();
    let name = function.name.as_ref().unwrap();
    let arguments =
        serde_json::from_str::<serde_json::Value>(function.arguments.as_ref().unwrap()).unwrap();

    assert_eq!(name, "get_weather");
    assert_eq!(arguments["location"], "San Francisco");
}
/// Plain text that never looks like a tool call must pass through the jail untouched:
/// every chunk is forwarded with its original content and the final chunk keeps its
/// finish reason.
#[tokio::test]
async fn test_apply_tool_calling_jail_internal_no_tool_calls() {
    let mock_context = Arc::new(MockAsyncEngineContext::new("test-request-id-2".to_string()));
    // Regular content followed by the terminating chunk; none of it should trigger jailing.
    let expected_contents = ["Hello, ", "how can I ", "help you today?"];
    let mut chunks: Vec<_> = expected_contents
        .into_iter()
        .map(|text| create_mock_response_chunk(text.to_string(), 0))
        .collect();
    chunks.push(create_final_response_chunk(0));

    let input_stream = stream::iter(chunks);
    let response_stream = ResponseStream::new(Box::pin(input_stream), mock_context.clone());
    let jailed_stream =
        apply_tool_calling_jail_internal(response_stream, Some("nemotron_deci".to_string()));
    let results: Vec<_> = jailed_stream.collect().await;

    assert_eq!(results.len(), 4, "Should have all 4 chunks");

    // Each content chunk must carry data. A missing payload is a failure, not something
    // to skip silently, otherwise the content assertions below would never run.
    for (i, (result, expected_content)) in results.iter().zip(expected_contents).enumerate() {
        let response_data = result
            .data
            .as_ref()
            .unwrap_or_else(|| panic!("Chunk {i} should carry response data"));
        assert_eq!(
            response_data.choices[0].delta.content.as_deref(),
            Some(expected_content),
            "Chunk {} should have original content, not be jailed",
            i
        );
        // Should NOT have annotation events for regular content.
        assert!(
            result.event.is_none(),
            "Regular content should not have annotation events"
        );
    }

    // Last chunk should be the final response with finish reason.
    let last_result = results.last().expect("Should have a final chunk");
    let response_data = last_result
        .data
        .as_ref()
        .expect("Final chunk should carry response data");
    assert_eq!(
        response_data.choices[0].finish_reason,
        Some(OAIFinishReason::Stop)
    );
}
/// An input stream with no chunks must yield no output, with no parser configured.
#[tokio::test]
async fn test_apply_tool_calling_jail_internal_with_empty_stream() {
    let no_chunks = Vec::<Annotated<NvCreateChatCompletionStreamResponse>>::new();
    let ctx = Arc::new(MockAsyncEngineContext::new("test-request-id-3".to_string()));
    let response_stream = ResponseStream::new(Box::pin(stream::iter(no_chunks)), ctx);

    let results: Vec<_> = apply_tool_calling_jail_internal(response_stream, None)
        .collect()
        .await;

    assert!(results.is_empty(), "Empty stream should produce no results");
}
/// Smoke test: a hermes-formatted tool call run through the hermes parser must
/// produce at least one output chunk.
#[tokio::test]
async fn test_apply_tool_calling_jail_internal_with_different_parsers() {
    // Hermes wraps the call in <tool_call>...</tool_call>.
    let hermes_pieces = [
        "<tool_call>",
        "{\"name\": \"get_weather\", \"arguments\": {\"location\": \"Tokyo\"}}",
        "</tool_call>",
    ];
    let chunks: Vec<_> = hermes_pieces
        .into_iter()
        .map(|piece| create_mock_response_chunk(piece.to_string(), 0))
        .collect();
    let ctx = Arc::new(MockAsyncEngineContext::new("test-request-id-4".to_string()));
    let response_stream = ResponseStream::new(Box::pin(stream::iter(chunks)), ctx);

    let results: Vec<_> =
        apply_tool_calling_jail_internal(response_stream, Some("hermes".to_string()))
            .collect()
            .await;

    assert!(!results.is_empty(), "Should have results for hermes parser");
}
/// Checks `detect_tool_call_start` against a table of (chunk, parser, expected) cases.
///
/// The function under test is synchronous, so this is a plain `#[test]`; each case
/// reports the offending chunk and parser when it fails.
#[test]
fn test_detect_tool_call_start_different_parsers() {
    // (chunk text, parser name, whether a tool-call start should be detected)
    let cases: &[(&str, Option<&str>, bool)] = &[
        // nemotron_deci parser
        ("<TOOLCALL>", Some("nemotron_deci"), true),
        ("Hello world", Some("nemotron_deci"), false),
        ("<tool_call>", Some("nemotron_deci"), false), // wrong format
        // hermes parser - also detects JSON patterns
        ("<tool_call>", Some("hermes"), true),
        ("{\"name\": \"test\"}", Some("hermes"), true), // JSON detection
        ("Hello world", Some("hermes"), false),
        ("<TOOLCALL>", Some("hermes"), false), // wrong format
        // phi4 parser
        ("functools[", Some("phi4"), true),
        ("{\"name\": \"test\"}", Some("phi4"), true), // JSON detection
        ("Hello world", Some("phi4"), false),
        // mistral parser
        ("[{", Some("mistral"), true),
        ("[TOOL_CALLS]", Some("mistral"), true),
        ("Hello world", Some("mistral"), false),
        // llama3_json parser
        ("<|python_tag|>", Some("llama3_json"), true),
        ("{\"name\": \"test\"}", Some("llama3_json"), true), // JSON detection
        // no parser given: the default parser is used
        ("<TOOLCALL>", None, true),
        ("{\"name\": \"test\"}", None, true), // JSON detection
        ("Hello world", None, false),
    ];

    for &(chunk, parser, expected) in cases {
        let detected = detect_tool_call_start(chunk, parser).unwrap_or_else(|err| {
            panic!("detection failed for chunk {chunk:?} with parser {parser:?}: {err:?}")
        });
        assert_eq!(
            detected, expected,
            "unexpected detection result for chunk {chunk:?} with parser {parser:?}"
        );
    }
}
/// Hermes format with leading prose: the prose chunk is forwarded as-is, and the
/// `<tool_call>` section that follows is jailed and emitted as one parsed tool call.
#[tokio::test]
async fn test_apply_tool_calling_jail_internal_hermes_parser() {
    let pieces = [
        "I'll help you with that. ",
        "<tool_call>", // jailing is expected to start here
        "{\"name\": \"get_weather\", \"arguments\": {\"location\": \"Tokyo\"}}",
        "</tool_call>",
    ];
    let chunks: Vec<_> = pieces
        .into_iter()
        .map(|piece| create_mock_response_chunk(piece.to_string(), 0))
        .collect();
    let ctx = Arc::new(MockAsyncEngineContext::new(
        "test-request-id-hermes".to_string(),
    ));
    let response_stream = ResponseStream::new(Box::pin(stream::iter(chunks)), ctx);

    let results: Vec<_> =
        apply_tool_calling_jail_internal(response_stream, Some("hermes".to_string()))
            .collect()
            .await;
    assert!(!results.is_empty(), "Should have results for hermes parser");

    // Expected output, in order:
    //   0: "I'll help you with that. " (plain text, untouched)
    //   1: the aggregated get_weather tool call (jailed)
    assert_eq!(results.len(), 2);

    let text_delta = &results[0].data.as_ref().unwrap().choices[0].delta;
    assert_eq!(
        text_delta.content.as_deref(),
        Some("I'll help you with that. ")
    );

    let call_delta = &results[1].data.as_ref().unwrap().choices[0].delta;
    assert!(call_delta.tool_calls.is_some());
    let tools = call_delta.tool_calls.as_ref().unwrap();
    assert_eq!(tools.len(), 1);

    let function = tools[0].function.as_ref().unwrap();
    let name = function.name.as_ref().unwrap();
    let raw_arguments = function.arguments.as_ref().unwrap();
    // Arguments arrive as a JSON-encoded string; decode before inspecting fields.
    let arguments: serde_json::Value = serde_json::from_str(raw_arguments).unwrap();
    assert_eq!(name, "get_weather");
    assert_eq!(arguments["location"], "Tokyo");
}
/// Round-trips a `PossibleToolCallAnnotation` through `to_annotation` and
/// `from_annotation` and checks that every field survives.
///
/// Nothing here is awaited, so this is a plain `#[test]`. `expect` is used instead of
/// `assert!(..is_ok())` + `unwrap()` so a failure prints the underlying error.
#[test]
fn test_possible_tool_call_annotation_serialization() {
    let annotation = PossibleToolCallAnnotation {
        possible_tokens: 5,
        possible_content: "test content".to_string(),
        parser_used: Some("nemotron_deci".to_string()),
    };

    // Serialization: the annotation is carried in the event name plus a comment payload.
    let annotated = annotation
        .to_annotation::<NvCreateChatCompletionStreamResponse>()
        .expect("Should be able to create annotation");
    assert_eq!(
        annotated.event,
        Some(ANNOTATION_POSSIBLE_TOOL_CALL.to_string())
    );
    assert!(annotated.comment.is_some(), "Should have comment");

    // Deserialization: the outer layer is the parse result, the inner one is whether
    // an annotation was present at all.
    let parsed = PossibleToolCallAnnotation::from_annotation(&annotated)
        .expect("Should be able to parse annotation")
        .expect("Should have parsed annotation");
    assert_eq!(parsed.possible_tokens, 5);
    assert_eq!(parsed.possible_content, "test content");
    assert_eq!(parsed.parser_used, Some("nemotron_deci".to_string()));
}
/// Mistral output that omits the `[TOOL_CALLS]` marker: the leading prose is forwarded,
/// and the bare JSON array that follows is still recognised and emitted as a tool call.
#[tokio::test]
async fn test_apply_tool_calling_jail_internal_mistral_parser_with_no_tool_call_start_token() {
    let texts = [
        "Hey How",
        "are you? ",
        r#"[{"name": "get_weather", "arguments":"#,
        r#"{"location": "San Francisco", "unit": "fahrenheit"}}]"#,
    ];
    let mut chunks: Vec<_> = texts
        .into_iter()
        .map(|text| create_mock_response_chunk(text.to_string(), 0))
        .collect();
    chunks.push(create_final_response_chunk(0));

    let ctx = Arc::new(MockAsyncEngineContext::new(
        "test-request-id-mistral".to_string(),
    ));
    let response_stream = ResponseStream::new(Box::pin(stream::iter(chunks)), ctx);
    let results: Vec<_> =
        apply_tool_calling_jail_internal(response_stream, Some("mistral".to_string()))
            .collect()
            .await;
    assert!(
        !results.is_empty(),
        "Should have results for mistral parser"
    );

    // Expected output, in order:
    //   0: "Hey How"
    //   1: "are you? "
    //   2: final response chunk (no content)
    //   3: the aggregated get_weather tool call (jailed)
    assert_eq!(results.len(), 4);
    let delta_at = |idx: usize| &results[idx].data.as_ref().unwrap().choices[0].delta;

    // The first two chunks are ordinary text.
    assert_eq!(delta_at(0).content.as_deref(), Some("Hey How"));
    assert_eq!(delta_at(1).content.as_deref(), Some("are you? "));
    assert_eq!(delta_at(2).content.as_deref(), None);

    // The last chunk carries the tool call.
    assert!(delta_at(3).tool_calls.is_some());
    let tools = delta_at(3).tool_calls.as_ref().unwrap();
    assert_eq!(tools.len(), 1);

    let function = tools[0].function.as_ref().unwrap();
    let name = function.name.as_ref().unwrap();
    let raw_arguments = function.arguments.as_ref().unwrap();
    // Arguments arrive as a JSON-encoded string; decode before inspecting fields.
    let arguments: serde_json::Value = serde_json::from_str(raw_arguments).unwrap();
    assert_eq!(name, "get_weather");
    assert_eq!(arguments["location"], "San Francisco");
    assert_eq!(arguments["unit"], "fahrenheit");
}
/// A stray `{` in ordinary text triggers the mistral jail, but since no tool call
/// follows, the held-back text must resurface as normal content at the end.
#[tokio::test]
async fn test_apply_tool_calling_jail_internal_mistral_parser_with_false_positive_tool_start() {
    let mut chunks: Vec<_> = ["Hey How", "are { you? "]
        .into_iter()
        .map(|text| create_mock_response_chunk(text.to_string(), 0))
        .collect();
    chunks.push(create_final_response_chunk(0));

    let ctx = Arc::new(MockAsyncEngineContext::new(
        "test-request-id-mistral".to_string(),
    ));
    let response_stream = ResponseStream::new(Box::pin(stream::iter(chunks)), ctx);
    let results: Vec<_> =
        apply_tool_calling_jail_internal(response_stream, Some("mistral".to_string()))
            .collect()
            .await;
    assert!(
        !results.is_empty(),
        "Should have results for mistral parser"
    );

    // Expected output, in order:
    //   0: "Hey How"
    //   1: final response chunk (no content)
    //   2: "are { you?" - the jailed text returned as normal content; note the
    //      expected value lacks the trailing space of the input chunk
    assert_eq!(results.len(), 3);
    let delta_at = |idx: usize| &results[idx].data.as_ref().unwrap().choices[0].delta;

    assert_eq!(delta_at(0).content.as_deref(), Some("Hey How"));
    assert_eq!(delta_at(1).content.as_deref(), None);
    assert_eq!(delta_at(2).content.as_deref(), Some("are { you?"));
}
/// A stray `{` starts the mistral jail early, and a real `[TOOL_CALLS]` section follows.
/// The final aggregated chunk must carry both the held-back text as content and the
/// parsed tool call.
#[tokio::test]
async fn test_apply_tool_calling_jail_internal_mistral_parser_with_false_positive_tool_start_and_tool_call_token()
{
    let texts = [
        "Hey How",
        "are { you? ",
        r#"[TOOL_CALLS][{"name": "get_weather", "arguments":"#,
        r#"{"location": "San Francisco", "unit": "fahrenheit"}}]"#,
    ];
    let mut chunks: Vec<_> = texts
        .into_iter()
        .map(|text| create_mock_response_chunk(text.to_string(), 0))
        .collect();
    chunks.push(create_final_response_chunk(0));

    let ctx = Arc::new(MockAsyncEngineContext::new(
        "test-request-id-mistral".to_string(),
    ));
    let response_stream = ResponseStream::new(Box::pin(stream::iter(chunks)), ctx);
    let results: Vec<_> =
        apply_tool_calling_jail_internal(response_stream, Some("mistral".to_string()))
            .collect()
            .await;
    assert!(
        !results.is_empty(),
        "Should have results for mistral parser"
    );

    // Expected output, in order:
    //   0: "Hey How"
    //   1: final response chunk (no content)
    //   2: content "are { you?" together with the get_weather tool call
    assert_eq!(results.len(), 3);
    let delta_at = |idx: usize| &results[idx].data.as_ref().unwrap().choices[0].delta;

    assert_eq!(delta_at(0).content.as_deref(), Some("Hey How"));
    assert_eq!(delta_at(1).content.as_deref(), None);

    // The last chunk holds both the recovered text and the tool call.
    let combined = delta_at(2);
    assert_eq!(combined.content.as_deref(), Some("are { you?"));
    assert!(combined.tool_calls.is_some());
    let tools = combined.tool_calls.as_ref().unwrap();
    assert_eq!(tools.len(), 1);

    let function = tools[0].function.as_ref().unwrap();
    let name = function.name.as_ref().unwrap();
    let raw_arguments = function.arguments.as_ref().unwrap();
    // Arguments arrive as a JSON-encoded string; decode before inspecting fields.
    let arguments: serde_json::Value = serde_json::from_str(raw_arguments).unwrap();
    assert_eq!(name, "get_weather");
    assert_eq!(arguments["location"], "San Francisco");
    assert_eq!(arguments["unit"], "fahrenheit");
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment