Unverified Commit 37322011 authored by Ayush Agarwal's avatar Ayush Agarwal Committed by GitHub
Browse files

chore: set default tool calling behaviour to be disabled (#3096)


Signed-off-by: default avatarayushag <ayushag@nvidia.com>
parent d05795f2
...@@ -15,6 +15,7 @@ pub mod prompt; ...@@ -15,6 +15,7 @@ pub mod prompt;
pub mod tools; pub mod tools;
use anyhow::Result; use anyhow::Result;
use dynamo_async_openai::types::ChatCompletionToolChoiceOption;
use dynamo_async_openai::types::EncodingFormat; use dynamo_async_openai::types::EncodingFormat;
use futures::stream::{self, StreamExt}; use futures::stream::{self, StreamExt};
use prompt::OAIPromptFormatter; use prompt::OAIPromptFormatter;
...@@ -77,6 +78,20 @@ pub struct JailState { ...@@ -77,6 +78,20 @@ pub struct JailState {
finished: bool, // Add this flag to track if stream is finished finished: bool, // Add this flag to track if stream is finished
} }
pub fn maybe_enable_tool_call(
parser_str: Option<&str>,
request: &NvCreateChatCompletionRequest,
) -> bool {
// Enable tool call if the below two conditions are satisfied
// 1. parser_str is not None
// 2. tool_choice is not None
parser_str.is_some()
&& !matches!(
request.inner.tool_choice,
Some(ChatCompletionToolChoiceOption::None)
)
}
impl LLMMetricAnnotation { impl LLMMetricAnnotation {
/// Convert this metrics struct to an Annotated event /// Convert this metrics struct to an Annotated event
pub fn to_annotation<T>(&self) -> Result<Annotated<T>, serde_json::Error> { pub fn to_annotation<T>(&self) -> Result<Annotated<T>, serde_json::Error> {
...@@ -933,6 +948,8 @@ impl ...@@ -933,6 +948,8 @@ impl
let response_generator = request.response_generator(context.id().to_string()); let response_generator = request.response_generator(context.id().to_string());
let mut response_generator = Box::new(response_generator); let mut response_generator = Box::new(response_generator);
let enable_tool_calling =
maybe_enable_tool_call(self.tool_call_parser.as_deref(), &request);
// convert the chat completion request to a common completion request // convert the chat completion request to a common completion request
let (common_request, annotations) = self.preprocess_request(&request)?; let (common_request, annotations) = self.preprocess_request(&request)?;
...@@ -955,7 +972,13 @@ impl ...@@ -955,7 +972,13 @@ impl
// transform the postprocessor stream // transform the postprocessor stream
let stream = Self::transform_postprocessor_stream(response_stream, response_generator); let stream = Self::transform_postprocessor_stream(response_stream, response_generator);
let stream = self.apply_tool_calling_jail_with_parser(stream).await; // Apply tool calling jail to the stream if tool call parser is present
let stream = if enable_tool_calling {
self.apply_tool_calling_jail_with_parser(stream).await
} else {
stream
};
let context = stream.context(); let context = stream.context();
// prepend the annotations to the response stream // prepend the annotations to the response stream
let stream = annotations_stream.chain(stream); let stream = annotations_stream.chain(stream);
......
...@@ -2,13 +2,18 @@ ...@@ -2,13 +2,18 @@
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
use async_trait::async_trait; use async_trait::async_trait;
use dynamo_async_openai::types::ChatCompletionToolChoiceOption;
use dynamo_async_openai::types::CreateChatCompletionRequest;
use dynamo_async_openai::types::{ use dynamo_async_openai::types::{
ChatChoiceStream, ChatCompletionStreamResponseDelta, FinishReason as OAIFinishReason, Role, ChatChoiceStream, ChatCompletionStreamResponseDelta, FinishReason as OAIFinishReason, Role,
}; };
use dynamo_llm::preprocessor::{ use dynamo_llm::preprocessor::{
ANNOTATION_POSSIBLE_TOOL_CALL, PossibleToolCallAnnotation, apply_tool_calling_jail_internal, ANNOTATION_POSSIBLE_TOOL_CALL, PossibleToolCallAnnotation, apply_tool_calling_jail_internal,
maybe_enable_tool_call,
};
use dynamo_llm::protocols::openai::chat_completions::{
NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse,
}; };
use dynamo_llm::protocols::openai::chat_completions::NvCreateChatCompletionStreamResponse;
use dynamo_parsers::tool_calling::parsers::detect_tool_call_start; use dynamo_parsers::tool_calling::parsers::detect_tool_call_start;
use dynamo_runtime::pipeline::ResponseStream; use dynamo_runtime::pipeline::ResponseStream;
use dynamo_runtime::protocols::annotated::Annotated; use dynamo_runtime::protocols::annotated::Annotated;
...@@ -709,3 +714,50 @@ async fn test_tool_calling_jail_internal_with_harmony_parser() { ...@@ -709,3 +714,50 @@ async fn test_tool_calling_jail_internal_with_harmony_parser() {
assert_eq!(name, "get_current_weather"); assert_eq!(name, "get_current_weather");
assert_eq!(arguments["location"], "San Francisco"); assert_eq!(arguments["location"], "San Francisco");
} }
#[test]
fn test_enable_tool_call() {
let request = NvCreateChatCompletionRequest {
inner: CreateChatCompletionRequest {
tool_choice: Some(ChatCompletionToolChoiceOption::Auto),
..Default::default()
},
common: Default::default(),
nvext: None,
chat_template_args: None,
};
assert!(maybe_enable_tool_call(Some("nemotron_deci"), &request));
let request = NvCreateChatCompletionRequest {
inner: CreateChatCompletionRequest {
tool_choice: Some(ChatCompletionToolChoiceOption::None),
..Default::default()
},
common: Default::default(),
nvext: None,
chat_template_args: None,
};
assert!(!maybe_enable_tool_call(Some("nemotron_deci"), &request));
let request = NvCreateChatCompletionRequest {
inner: CreateChatCompletionRequest {
tool_choice: Some(ChatCompletionToolChoiceOption::Required),
..Default::default()
},
common: Default::default(),
nvext: None,
chat_template_args: None,
};
assert!(maybe_enable_tool_call(Some("nemotron_deci"), &request));
let request = NvCreateChatCompletionRequest {
inner: CreateChatCompletionRequest {
tool_choice: Some(ChatCompletionToolChoiceOption::Auto),
..Default::default()
},
common: Default::default(),
nvext: None,
chat_template_args: None,
};
assert!(!maybe_enable_tool_call(None, &request));
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment