Unverified Commit 6675bfc8 authored by zhongdaor-nv's avatar zhongdaor-nv Committed by GitHub
Browse files

feat: enhance GPT OSS frontend with improved harmony tool calling parser and...


feat: enhance GPT OSS frontend with improved harmony tool calling parser and reasoning parser (#2999)
Signed-off-by: default avatarzhongdaor <zhongdaor@nvidia.com>
parent 20b7a8ae
......@@ -81,4 +81,4 @@ opt-level = 3
[profile.release]
# These make the build much slower but shrink the binary, and could help performance
codegen-units = 1
lto = true
lto = true
\ No newline at end of file
......@@ -263,7 +263,14 @@ impl ModelWatcher {
let client = component.endpoint(&endpoint_id.name).client().await?;
let model_slug = model_entry.slug();
let card = match ModelDeploymentCard::load_from_store(&model_slug, &self.drt).await {
Ok(Some(card)) => card,
Ok(Some(mut card)) => {
tracing::debug!(card.display_name, "adding model");
// Ensure runtime_config is populated
if let Some(rc) = model_entry.runtime_config.clone() {
card.runtime_config = rc;
}
card
}
Ok(None) => {
anyhow::bail!("Missing ModelDeploymentCard in storage under key {model_slug}");
}
......
......@@ -160,6 +160,8 @@ pub struct OpenAIPreprocessor {
formatter: Arc<dyn OAIPromptFormatter>,
tokenizer: Arc<dyn Tokenizer>,
model_info: Arc<dyn ModelInfo>,
/// Per-model runtime configuration propagated to response generator (e.g., reasoning/tool parser)
runtime_config: crate::local_model::runtime_config::ModelRuntimeConfig,
tool_call_parser: Option<String>,
}
......@@ -187,11 +189,15 @@ impl OpenAIPreprocessor {
let model_info = model_info.get_model_info()?;
let tool_call_parser = mdc.runtime_config.tool_call_parser.clone();
// // Initialize runtime config from the ModelDeploymentCard
let runtime_config = mdc.runtime_config.clone();
Ok(Arc::new(Self {
formatter,
tokenizer,
model_info,
mdcsum,
runtime_config,
tool_call_parser,
}))
}
......@@ -948,6 +954,8 @@ impl
let response_generator = request.response_generator(context.id().to_string());
let mut response_generator = Box::new(response_generator);
// set the runtime configuration
response_generator.set_reasoning_parser(self.runtime_config.clone());
let enable_tool_calling =
maybe_enable_tool_call(self.tool_call_parser.as_deref(), &request);
// convert the chat completion request to a common completion request
......
......@@ -125,6 +125,20 @@ impl DeltaGenerator {
}
}
/// Update runtime configuration and reconfigure the reasoning parser accordingly.
pub fn set_reasoning_parser(&mut self, runtime_config: ModelRuntimeConfig) {
self.options.runtime_config = runtime_config.clone();
match self.options.runtime_config.reasoning_parser.as_deref() {
Some(name) => {
self.reasoning_parser =
Some(ReasoningParserType::get_reasoning_parser_from_name(name));
}
None => {
self.reasoning_parser = None;
}
}
}
/// Updates the prompt token usage count.
///
/// # Arguments
......
......@@ -150,7 +150,7 @@ impl ReasoningParser for GptOssReasoningParser {
fn parse_reasoning_streaming_incremental(
&mut self,
_text: &str,
text: &str,
token_ids: &[u32],
) -> ParserResult {
tracing::debug!(
......@@ -173,9 +173,8 @@ impl ReasoningParser for GptOssReasoningParser {
}
if let Some(channel) = self.parser.current_channel() {
tracing::debug!("Current channel: {}", channel);
tracing::debug!("Current channel {}", channel);
if channel == "final" {
tracing::debug!("In final channel, processing normal text");
// If we're in the final channel, we should not parse reasoning
if let Some(current) = self.parser.last_content_delta().unwrap_or_default() {
tracing::debug!("Got normal text delta of {} chars", current.len());
......@@ -186,6 +185,64 @@ impl ReasoningParser for GptOssReasoningParser {
}
tracing::debug!("No content delta in final channel");
ParserResult::default()
} else if channel == "commentary" {
// If we're in the commentary channel, we should return raw token content and recover content that has been consumed by the parser
// so that the tool parser can process it properly
if let Ok(enc) = get_harmony_encoding() {
let current_content = self.parser.current_content().unwrap_or_default();
let mut final_text = text.to_string();
// Restore commentary metadata consumed by the parser so the tool-call parser can
// process it correctly.
//
// Example:
// Before parsing:
// "<|start|>assistant<|channel|>commentary to=functions.get_current_weather <|constrain|>json<|message|>{\"format\":\"celsius\",\"location\":\"San Francisco\"}<|call|>"
// After parsing, the header is stripped, so we must reconstruct it:
// "<|channel|>commentary to=functions.get_current_weather <|constrain|>json<|message|>"
//
// This ensures downstream tool-call parsing receives the channel, target, and
// constraint metadata together with the message payload.
// Recovery should only happen once, and only when `current_content` is empty.
if current_content.is_empty() {
let tokens = self.parser.tokens();
// Get the token id for " <|channel|>"
let channel_token_id = enc
.tokenizer()
.encode_with_special_tokens("<|channel|>")
.last()
.copied();
// Find the last occurrence of the <|channel|> token (id 20005) in the tokens vector
let last_channel_token_idx = channel_token_id
.and_then(|token_id| {
tokens.iter().rposition(|token| *token == token_id)
})
.unwrap_or(0);
// Then get the generated text from the last <|channel|> to the end of self.parser.tokens()
let end_token_idx = self.parser.tokens().len();
// Use Harmony's decode_utf8 to decode tokens into text
let generated_text = enc
.tokenizer()
.decode_utf8(
&self.parser.tokens()[last_channel_token_idx..end_token_idx],
)
.unwrap_or_default();
final_text = generated_text;
}
ParserResult {
normal_text: final_text,
reasoning_text: String::new(),
}
} else {
tracing::warn!("Failed to get harmony encoding for raw token decoding");
ParserResult::default()
}
} else {
tracing::debug!("In reasoning channel: {}", channel);
if let Some(current) = self.parser.last_content_delta().unwrap_or_default() {
......
......@@ -3,9 +3,10 @@
use super::config::JsonParserConfig;
use super::response::{CalledFunction, ToolCallResponse, ToolCallType};
use openai_harmony::StreamableParser;
use openai_harmony::chat::{Content::Text, Role};
use openai_harmony::{HarmonyEncoding, HarmonyEncodingName, load_harmony_encoding};
use openai_harmony::{
HarmonyEncoding, HarmonyEncodingName, StreamableParser, load_harmony_encoding,
};
use serde_json::Value;
static GLOBAL_HARMONY_GPTOSS_ENCODING: tokio::sync::OnceCell<
......@@ -162,6 +163,109 @@ pub async fn parse_tool_calls_harmony(
Ok((res, Some(normal_text.to_string())))
}
/// Parse tool calls from a complete Harmony Format text chunk using direct token parsing.
///
/// This function is optimized for parsing complete text chunks where the entire content
/// is available at once. It uses `parse_messages_from_completion_tokens` to directly
/// parse all tokens into Harmony Format messages, then extracts tool calls from messages
/// with the "commentary" channel and "functions.*" recipients.
///
/// Unlike `parse_tool_calls_harmony`, this function doesn't perform start token detection
/// or token-by-token streaming, making it more efficient for complete chunks.
///
/// # Arguments
/// * `text` - The full Harmony-format string to be parsed, excluding any trailing stop tokens.
/// Example:
/// `<|channel|>commentary to=functions.get_current_weather <|constrain|>json<|message|>{"location":"San Francisco"}`
/// * `_config` - Parser configuration (currently unused but kept for API consistency)
///
/// # Returns
/// * `Ok((tool_calls, normal_text))` - Tuple containing extracted tool calls and any normal text
/// * `Err(e)` - If parsing fails due to encoding or tokenization errors
pub async fn parse_tool_calls_harmony_complete(
text: &str,
_config: &JsonParserConfig,
) -> anyhow::Result<(Vec<ToolCallResponse>, Option<String>)> {
let enc = match get_harmony_encoding().await.as_ref() {
Ok(e) => e,
Err(e) => {
tracing::debug!("Failed to load harmony encoding: {e}. Tool calls will not be parsed.");
return Ok((vec![], Some(text.to_string())));
}
};
// // Encode the text into tokens using harmony encoding
let tokens: Vec<u32> = enc.tokenizer().encode_with_special_tokens(text);
let messages = match enc.parse_messages_from_completion_tokens(tokens, Some(Role::Assistant)) {
Ok(messages) => messages,
Err(e) => {
tracing::debug!(
"Failed to parse messages from completion tokens: {e}. Tool calls will not be parsed."
);
return Ok((vec![], Some(text.to_string())));
}
};
let mut normal_text = String::new();
let mut res = Vec::with_capacity(messages.len());
let mut call_idx = 0; // Index of the tool call
for message in messages.iter() {
if message.author.role != Role::Assistant {
continue;
}
let channel = message.channel.as_deref();
let recipient = message.recipient.as_deref().unwrap_or_default();
// Handle commentary channel
if channel == Some("commentary") && recipient.starts_with("functions.") {
let Some(fname) = message
.recipient
.as_ref()
.and_then(|r| r.split('.').nth(1))
.filter(|s| !s.is_empty())
.map(|s| s.to_string())
else {
continue;
};
let args = match message.content.first() {
Some(Text(text)) => match serde_json::from_str::<Value>(text.text.trim()) {
Ok(value) => value,
Err(_) => {
Value::Null // Set args to null if it's not valid JSON
}
},
_ => {
Value::Null // Set args to null if it's not a text content
}
};
// Add tool call to result if args is valid JSON
if !args.is_null() {
call_idx += 1;
res.push(ToolCallResponse {
id: format!("call-{}", call_idx),
tp: ToolCallType::Function,
function: CalledFunction {
name: fname.to_string(),
// Safety: `Value::Object` is always valid JSON, so serialization cannot fail
arguments: serde_json::to_string(&args).unwrap(),
},
});
}
// Handle reasoning(analysis) channel
} else if channel == Some("analysis") {
normal_text.push_str(match &message.content[0] {
Text(t) => &t.text,
_ => "",
});
}
}
Ok((res, Some(normal_text.to_string())))
}
pub fn detect_tool_call_start_harmony(
chunk: &str,
config: &JsonParserConfig,
......@@ -266,6 +370,20 @@ mod tests {
assert_eq!(args["location"], "San Francisco");
}
#[tokio::test]
async fn test_parse_tool_calls_harmony_complete_basic() {
let text = r#"<|channel|>commentary to=functions.get_current_weather <|constrain|>json<|message|>{"format":"celsius","location":"San Francisco"}"#;
let (tool_calls, normal_content) =
parse_tool_calls_harmony_complete(text, &Default::default())
.await
.unwrap();
assert_eq!(normal_content, Some("".to_string()));
let (name, args) = extract_name_and_args(tool_calls[0].clone());
assert_eq!(name, "get_current_weather");
assert_eq!(args["location"], "San Francisco");
assert_eq!(args["format"], "celsius");
}
#[tokio::test]
async fn test_parse_tools_harmony_without_start_token() {
let text = r#"
......
......@@ -4,4 +4,6 @@
pub mod harmony_parser;
pub use super::{config, response};
pub use harmony_parser::{detect_tool_call_start_harmony, parse_tool_calls_harmony};
pub use harmony_parser::{
detect_tool_call_start_harmony, parse_tool_calls_harmony, parse_tool_calls_harmony_complete,
};
......@@ -11,7 +11,7 @@ pub mod tools;
// Re-export main types and functions for convenience
pub use config::{JsonParserConfig, ToolCallConfig, ToolCallParserType};
pub use harmony::parse_tool_calls_harmony;
pub use harmony::{parse_tool_calls_harmony, parse_tool_calls_harmony_complete};
pub use json::try_tool_call_parse_json;
pub use parsers::{detect_and_parse_tool_call, try_tool_call_parse};
pub use pythonic::try_tool_call_parse_pythonic;
......
......@@ -2,7 +2,7 @@
// SPDX-License-Identifier: Apache-2.0
use super::config::{ToolCallConfig, ToolCallParserType};
use super::harmony::{detect_tool_call_start_harmony, parse_tool_calls_harmony};
use super::harmony::{detect_tool_call_start_harmony, parse_tool_calls_harmony_complete};
use super::json::{detect_tool_call_start_json, try_tool_call_parse_json};
use super::pythonic::{detect_tool_call_start_pythonic, try_tool_call_parse_pythonic};
use super::response::ToolCallResponse;
......@@ -43,7 +43,8 @@ pub async fn try_tool_call_parse(
Ok((results, normal_content))
}
ToolCallParserType::Harmony => {
let (results, normal_content) = parse_tool_calls_harmony(message, &config.json).await?;
let (results, normal_content) =
parse_tool_calls_harmony_complete(message, &config.json).await?;
Ok((results, normal_content))
}
ToolCallParserType::Pythonic => {
......@@ -1450,10 +1451,7 @@ Remember, San Francisco weather can be quite unpredictable, particularly with it
#[tokio::test]
async fn test_harmony_parser_basic() {
let input = r#"
<|channel|>analysis<|message|>Need to use function get_current_weather.<|end|>
<|start|>assistant<|channel|>commentary to=functions.get_current_weather <|constrain|>json
<|message|>{"location":"San Francisco", "unit":"fahrenheit"}<|call|>
"#;
<|channel|>analysis<|message|>Need to use function get_current_weather.<|end|><|start|>assistant<|channel|>commentary to=functions.get_current_weather <|constrain|>json<|message|>{"location":"San Francisco", "unit":"fahrenheit"}"#;
let (result, content) = detect_and_parse_tool_call(input, Some("harmony"))
.await
.unwrap();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment