Unverified commit 37322011, authored by Ayush Agarwal, committed by GitHub
Browse files

chore: set default tool calling behaviour to be disabled (#3096)


Signed-off-by: ayushag <ayushag@nvidia.com>
parent d05795f2
......@@ -15,6 +15,7 @@ pub mod prompt;
pub mod tools;
use anyhow::Result;
use dynamo_async_openai::types::ChatCompletionToolChoiceOption;
use dynamo_async_openai::types::EncodingFormat;
use futures::stream::{self, StreamExt};
use prompt::OAIPromptFormatter;
......@@ -77,6 +78,20 @@ pub struct JailState {
finished: bool, // Add this flag to track if stream is finished
}
/// Decides whether tool calling should be enabled for `request`.
///
/// Returns `true` only when both of the following hold:
///
/// 1. `parser_str` is `Some(_)`, i.e. a tool call parser name was supplied.
/// 2. The request does not explicitly opt out, i.e. `request.inner.tool_choice`
///    is anything other than `Some(ChatCompletionToolChoiceOption::None)`.
///
/// Note that an *absent* `tool_choice` (the `Option` itself being `None`) is
/// not treated as an opt-out; only the explicit
/// `ChatCompletionToolChoiceOption::None` variant disables tool calling.
pub fn maybe_enable_tool_call(
    parser_str: Option<&str>,
    request: &NvCreateChatCompletionRequest,
) -> bool {
    // Without a parser there is nothing to enable, whatever the request says.
    if parser_str.is_none() {
        return false;
    }

    // The caller explicitly asked for no tool calls.
    let opted_out = matches!(
        request.inner.tool_choice,
        Some(ChatCompletionToolChoiceOption::None)
    );

    !opted_out
}
impl LLMMetricAnnotation {
/// Convert this metrics struct to an Annotated event
pub fn to_annotation<T>(&self) -> Result<Annotated<T>, serde_json::Error> {
......@@ -933,6 +948,8 @@ impl
let response_generator = request.response_generator(context.id().to_string());
let mut response_generator = Box::new(response_generator);
let enable_tool_calling =
maybe_enable_tool_call(self.tool_call_parser.as_deref(), &request);
// convert the chat completion request to a common completion request
let (common_request, annotations) = self.preprocess_request(&request)?;
......@@ -955,7 +972,13 @@ impl
// transform the postprocessor stream
let stream = Self::transform_postprocessor_stream(response_stream, response_generator);
let stream = self.apply_tool_calling_jail_with_parser(stream).await;
// Apply the tool calling jail only when tool calling is enabled for this request
// (a tool call parser is configured and the request's tool_choice is not `none`)
let stream = if enable_tool_calling {
self.apply_tool_calling_jail_with_parser(stream).await
} else {
stream
};
let context = stream.context();
// prepend the annotations to the response stream
let stream = annotations_stream.chain(stream);
......
......@@ -2,13 +2,18 @@
// SPDX-License-Identifier: Apache-2.0
use async_trait::async_trait;
use dynamo_async_openai::types::ChatCompletionToolChoiceOption;
use dynamo_async_openai::types::CreateChatCompletionRequest;
use dynamo_async_openai::types::{
ChatChoiceStream, ChatCompletionStreamResponseDelta, FinishReason as OAIFinishReason, Role,
};
use dynamo_llm::preprocessor::{
ANNOTATION_POSSIBLE_TOOL_CALL, PossibleToolCallAnnotation, apply_tool_calling_jail_internal,
maybe_enable_tool_call,
};
use dynamo_llm::protocols::openai::chat_completions::{
NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse,
};
use dynamo_llm::protocols::openai::chat_completions::NvCreateChatCompletionStreamResponse;
use dynamo_parsers::tool_calling::parsers::detect_tool_call_start;
use dynamo_runtime::pipeline::ResponseStream;
use dynamo_runtime::protocols::annotated::Annotated;
......@@ -709,3 +714,50 @@ async fn test_tool_calling_jail_internal_with_harmony_parser() {
assert_eq!(name, "get_current_weather");
assert_eq!(arguments["location"], "San Francisco");
}
/// Checks `maybe_enable_tool_call` for each explicit `tool_choice` variant
/// with a parser configured, and for the case where no parser is configured.
#[test]
fn test_enable_tool_call() {
    // Builds a request whose only non-default `inner` field is `tool_choice`.
    let request_with = |tool_choice: ChatCompletionToolChoiceOption| {
        NvCreateChatCompletionRequest {
            inner: CreateChatCompletionRequest {
                tool_choice: Some(tool_choice),
                ..Default::default()
            },
            common: Default::default(),
            nvext: None,
            chat_template_args: None,
        }
    };

    // Parser present + `auto` => enabled.
    let auto_request = request_with(ChatCompletionToolChoiceOption::Auto);
    assert!(maybe_enable_tool_call(Some("nemotron_deci"), &auto_request));

    // Parser present + explicit `none` => disabled.
    let none_request = request_with(ChatCompletionToolChoiceOption::None);
    assert!(!maybe_enable_tool_call(Some("nemotron_deci"), &none_request));

    // Parser present + `required` => enabled.
    let required_request = request_with(ChatCompletionToolChoiceOption::Required);
    assert!(maybe_enable_tool_call(Some("nemotron_deci"), &required_request));

    // No parser => disabled, even though the request asks for `auto`.
    let parserless_request = request_with(ChatCompletionToolChoiceOption::Auto);
    assert!(!maybe_enable_tool_call(None, &parserless_request));
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment