Unverified Commit 2a2bf58a authored by Ayush Agarwal's avatar Ayush Agarwal Committed by GitHub
Browse files

feat: reasoning parser transformation (#3295)


Signed-off-by: ayushag <ayushag@nvidia.com>
parent 713e9e48
...@@ -159,12 +159,12 @@ impl ...@@ -159,12 +159,12 @@ impl
for c in prompt.chars() { for c in prompt.chars() {
// we are returning characters not tokens, so there will be some postprocessing overhead // we are returning characters not tokens, so there will be some postprocessing overhead
tokio::time::sleep(*TOKEN_ECHO_DELAY).await; tokio::time::sleep(*TOKEN_ECHO_DELAY).await;
let response = deltas.create_choice(0, Some(c.to_string()), None, None, None); let response = deltas.create_choice(0, Some(c.to_string()), None, None);
yield Annotated{ id: Some(id.to_string()), data: Some(response), event: None, comment: None }; yield Annotated{ id: Some(id.to_string()), data: Some(response), event: None, comment: None };
id += 1; id += 1;
} }
let response = deltas.create_choice(0, None, None, Some(dynamo_async_openai::types::FinishReason::Stop), None); let response = deltas.create_choice(0, None, Some(dynamo_async_openai::types::FinishReason::Stop), None);
yield Annotated { id: Some(id.to_string()), data: Some(response), event: None, comment: None }; yield Annotated { id: Some(id.to_string()), data: Some(response), event: None, comment: None };
}; };
......
...@@ -28,6 +28,7 @@ use crate::preprocessor::prompt::OAIChatLikeRequest; ...@@ -28,6 +28,7 @@ use crate::preprocessor::prompt::OAIChatLikeRequest;
use crate::protocols::common::preprocessor::PreprocessedRequestBuilder; use crate::protocols::common::preprocessor::PreprocessedRequestBuilder;
use crate::tokenizers::Encoding; use crate::tokenizers::Encoding;
use dynamo_parsers::{ReasoningParser, ReasoningParserType};
use dynamo_runtime::engine::{AsyncEngine, AsyncEngineContextProvider, ResponseStream}; use dynamo_runtime::engine::{AsyncEngine, AsyncEngineContextProvider, ResponseStream};
use dynamo_runtime::pipeline::{ use dynamo_runtime::pipeline::{
AsyncEngineContext, Error, ManyOut, Operator, SingleIn, async_trait, AsyncEngineContext, Error, ManyOut, Operator, SingleIn, async_trait,
...@@ -93,6 +94,12 @@ impl LLMMetricAnnotation { ...@@ -93,6 +94,12 @@ impl LLMMetricAnnotation {
} }
} }
/// State carried between iterations of the reasoning-parsing transformation step
/// (threaded through `stream::unfold` in `parse_reasoning_content_from_stream`).
struct ReasoningState {
/// Upstream stream of chat-completion stream responses; boxed and pinned so it can
/// be polled with `.next()` from inside the unfold closure.
stream: Pin<Box<dyn Stream<Item = Annotated<NvCreateChatCompletionStreamResponse>> + Send>>,
/// Parser applied to each choice's `delta.content`. When `None`, responses are
/// forwarded unchanged.
reasoning_parser: Option<Box<dyn ReasoningParser>>,
}
pub struct OpenAIPreprocessor { pub struct OpenAIPreprocessor {
mdcsum: String, mdcsum: String,
formatter: Arc<dyn OAIPromptFormatter>, formatter: Arc<dyn OAIPromptFormatter>,
...@@ -668,6 +675,56 @@ impl OpenAIPreprocessor { ...@@ -668,6 +675,56 @@ impl OpenAIPreprocessor {
.build(); .build();
jail.apply(stream) jail.apply(stream)
} }
/// Reasoning-content parsing as a standalone stream transformation.
///
/// Motivation: every transformation on the response stream is its own step, which keeps
/// the pipeline flexible. Reasoning parsing used to be nested inside delta generation
/// (`choice_from_postprocessor`); tool-call parsing is already a separate step, so
/// reasoning parsing follows the same shape.
///
/// For each response pulled from `stream`, every choice that carries `delta.content` is
/// fed through the parser selected by `parser_name`. The choice's `delta.content` is
/// replaced with the parser's normal text and `delta.reasoning_content` with its
/// reasoning text. Choices without content are left untouched. Only text is passed to
/// the parser; the token-id slice is always empty.
///
/// NOTE(review): a single parser instance is shared by all choices of all chunks. If a
/// chunk ever carries more than one choice, their text is fed into the same incremental
/// parser — confirm that is intended.
pub fn parse_reasoning_content_from_stream<S>(
    stream: S,
    parser_name: String,
) -> impl Stream<Item = Annotated<NvCreateChatCompletionStreamResponse>> + Send
where
    S: Stream<Item = Annotated<NvCreateChatCompletionStreamResponse>> + Send + 'static,
{
    // Resolve the parser implementation from its configured name.
    let parser: Box<dyn ReasoningParser> = Box::new(
        ReasoningParserType::get_reasoning_parser_from_name(parser_name.as_ref()),
    );

    let initial = ReasoningState {
        stream: Box::pin(stream),
        reasoning_parser: Some(parser),
    };

    stream::unfold(initial, |mut state| async move {
        // Upstream exhausted => this stream ends as well.
        let response = state.stream.next().await?;

        let processed = match state.reasoning_parser.as_mut() {
            Some(parser) => response.map_data(|mut data| {
                // Every choice is processed, not only the first one.
                for choice in data.choices.iter_mut() {
                    let Some(text) = choice.delta.content.as_deref() else {
                        continue;
                    };
                    let parsed = parser.parse_reasoning_streaming_incremental(text, &[]);

                    // Write the split result back onto this specific choice.
                    choice.delta.content = parsed.get_some_normal_text();
                    choice.delta.reasoning_content = parsed.get_some_reasoning();
                }
                Ok(data)
            }),
            // No reasoning parser configured: pass the response through unchanged.
            None => response,
        };

        Some((processed, state))
    })
}
} }
// for pals, we do not want to add the generation prompt to the formatted prompt // for pals, we do not want to add the generation prompt to the formatted prompt
...@@ -715,9 +772,6 @@ impl ...@@ -715,9 +772,6 @@ impl
let mut response_generator = Box::new(response_generator); let mut response_generator = Box::new(response_generator);
// set the runtime configuration
response_generator.set_reasoning_parser(self.runtime_config.clone());
// update isl // update isl
response_generator.update_isl(common_request.token_ids.len() as u32); response_generator.update_isl(common_request.token_ids.len() as u32);
...@@ -744,6 +798,25 @@ impl ...@@ -744,6 +798,25 @@ impl
context.clone(), context.clone(),
); );
// Try to parse reasoning content only if parser is configured
let should_parse_reasoning = self.runtime_config.reasoning_parser.is_some();
// Reasoning Content Parsing Transformation Step
// Current Solution:
// This step operates on Deltas created by the transform_postprocessor_stream function
// Only access to text and not token_ids - so can not support parsing based on token_ids for now
// Future Solution:
// To address the limitation if needed in future: move this step before transform_postprocessor_stream and add new field of reasoning_content to the backend output
// Use backend_output.reasoning_content field to fill out the deltas.
let stream: Pin<Box<dyn Stream<Item = _> + Send>> = if should_parse_reasoning {
Box::pin(Self::parse_reasoning_content_from_stream(
stream,
self.runtime_config.reasoning_parser.clone().unwrap(), // Safety: We already checked that parser is some, so gtg
))
} else {
Box::pin(stream)
};
// Check if tools are present and if we should apply jail // Check if tools are present and if we should apply jail
let has_tools = let has_tools =
request.inner.tools.is_some() && !request.inner.tools.as_ref().unwrap().is_empty(); request.inner.tools.is_some() && !request.inner.tools.as_ref().unwrap().is_empty();
......
...@@ -7,7 +7,6 @@ use crate::{ ...@@ -7,7 +7,6 @@ use crate::{
protocols::common::{self}, protocols::common::{self},
types::TokenIdType, types::TokenIdType,
}; };
use dynamo_parsers::{ParserResult, ReasoningParser, ReasoningParserType, ReasoningParserWrapper};
/// Provides a method for generating a [`DeltaGenerator`] from a chat completion request. /// Provides a method for generating a [`DeltaGenerator`] from a chat completion request.
impl NvCreateChatCompletionRequest { impl NvCreateChatCompletionRequest {
...@@ -66,11 +65,6 @@ pub struct DeltaGenerator { ...@@ -66,11 +65,6 @@ pub struct DeltaGenerator {
msg_counter: u64, msg_counter: u64,
/// Configuration options for response generation. /// Configuration options for response generation.
options: DeltaGeneratorOptions, options: DeltaGeneratorOptions,
/// Reasoning Parser object
/// This is used to parse reasoning content in the response.
/// None means no reasoning parsing will be performed.
reasoning_parser: Option<ReasoningParserWrapper>,
} }
impl DeltaGenerator { impl DeltaGenerator {
...@@ -101,14 +95,6 @@ impl DeltaGenerator { ...@@ -101,14 +95,6 @@ impl DeltaGenerator {
completion_tokens_details: None, completion_tokens_details: None,
}; };
// Reasoning parser type
// If no parser is specified (None), no reasoning parsing will be performed
let reasoning_parser = options
.runtime_config
.reasoning_parser
.as_deref()
.map(ReasoningParserType::get_reasoning_parser_from_name);
let chatcmpl_id = format!("chatcmpl-{request_id}"); let chatcmpl_id = format!("chatcmpl-{request_id}");
Self { Self {
...@@ -121,21 +107,6 @@ impl DeltaGenerator { ...@@ -121,21 +107,6 @@ impl DeltaGenerator {
usage, usage,
msg_counter: 0, msg_counter: 0,
options, options,
reasoning_parser,
}
}
/// Update runtime configuration and reconfigure the reasoning parser accordingly.
pub fn set_reasoning_parser(&mut self, runtime_config: ModelRuntimeConfig) {
self.options.runtime_config = runtime_config.clone();
match self.options.runtime_config.reasoning_parser.as_deref() {
Some(name) => {
self.reasoning_parser =
Some(ReasoningParserType::get_reasoning_parser_from_name(name));
}
None => {
self.reasoning_parser = None;
}
} }
} }
...@@ -212,24 +183,6 @@ impl DeltaGenerator { ...@@ -212,24 +183,6 @@ impl DeltaGenerator {
}) })
} }
fn create_reasoning_content(
&mut self,
text: &Option<String>,
token_ids: &[u32],
) -> Option<ParserResult> {
// If no reasoning parser is configured, return None
let reasoning_parser = self.reasoning_parser.as_mut()?;
let text_ref = text.as_deref().unwrap_or("");
if text_ref.is_empty() && token_ids.is_empty() {
return None;
}
let parser_result =
reasoning_parser.parse_reasoning_streaming_incremental(text_ref, token_ids);
Some(parser_result)
}
/// Creates a choice within a chat completion response. /// Creates a choice within a chat completion response.
/// ///
/// # Arguments /// # Arguments
...@@ -245,7 +198,6 @@ impl DeltaGenerator { ...@@ -245,7 +198,6 @@ impl DeltaGenerator {
&mut self, &mut self,
index: u32, index: u32,
text: Option<String>, text: Option<String>,
reasoning_content: Option<String>,
finish_reason: Option<dynamo_async_openai::types::FinishReason>, finish_reason: Option<dynamo_async_openai::types::FinishReason>,
logprobs: Option<dynamo_async_openai::types::ChatChoiceLogprobs>, logprobs: Option<dynamo_async_openai::types::ChatChoiceLogprobs>,
) -> NvCreateChatCompletionStreamResponse { ) -> NvCreateChatCompletionStreamResponse {
...@@ -259,7 +211,7 @@ impl DeltaGenerator { ...@@ -259,7 +211,7 @@ impl DeltaGenerator {
None None
}, },
refusal: None, refusal: None,
reasoning_content, reasoning_content: None,
}; };
let choice = dynamo_async_openai::types::ChatChoiceStream { let choice = dynamo_async_openai::types::ChatChoiceStream {
...@@ -371,25 +323,9 @@ impl crate::protocols::openai::DeltaGeneratorExt<NvCreateChatCompletionStreamRes ...@@ -371,25 +323,9 @@ impl crate::protocols::openai::DeltaGeneratorExt<NvCreateChatCompletionStreamRes
None => None, None => None,
}; };
// Handle reasoning parsing if enabled, otherwise treat all text as normal
let (normal_text, reasoning_content) =
match self.create_reasoning_content(&delta.text, &delta.token_ids) {
Some(reasoning_parser_result) => (
reasoning_parser_result.get_some_normal_text(),
reasoning_parser_result.get_some_reasoning(),
),
None => (delta.text, None),
};
// Create the streaming response. // Create the streaming response.
let index = 0; let index = 0;
let stream_response = self.create_choice( let stream_response = self.create_choice(index, delta.text, finish_reason, logprobs);
index,
normal_text,
reasoning_content,
finish_reason,
logprobs,
);
Ok(stream_response) Ok(stream_response)
} }
......
...@@ -89,7 +89,7 @@ impl ...@@ -89,7 +89,7 @@ impl
let stream = stream! { let stream = stream! {
tokio::time::sleep(std::time::Duration::from_millis(max_tokens)).await; tokio::time::sleep(std::time::Duration::from_millis(max_tokens)).await;
for i in 0..10 { for i in 0..10 {
let output = generator.create_choice(i,Some(format!("choice {i}")), None, None, None); let output = generator.create_choice(i,Some(format!("choice {i}")), None, None);
yield Annotated::from_data(output); yield Annotated::from_data(output);
} }
......
...@@ -52,7 +52,7 @@ impl ...@@ -52,7 +52,7 @@ impl
// Generate 5 response chunks // Generate 5 response chunks
for i in 0..5 { for i in 0..5 {
let output = generator.create_choice(i, Some(format!("Mock response {i}")), None, None, None); let output = generator.create_choice(i, Some(format!("Mock response {i}")), None, None);
yield Annotated::from_data(output); yield Annotated::from_data(output);
} }
}; };
......
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use dynamo_async_openai::types::{ChatChoiceStream, ChatCompletionStreamResponseDelta, Role};
use dynamo_llm::preprocessor::OpenAIPreprocessor;
use dynamo_llm::protocols::openai::chat_completions::NvCreateChatCompletionStreamResponse;
use dynamo_runtime::protocols::annotated::Annotated;
use futures::{StreamExt, stream};
/// Builds one annotated streaming chat-completion chunk for use as test input.
///
/// The chunk holds a single assistant choice at index 0 whose delta carries `content`
/// and the given `reasoning_content`; every other optional field is `None`. The
/// response id, model, fingerprint, timestamp and annotation id are fixed placeholders.
fn create_mock_response_chunk(
    content: String,
    reasoning_content: Option<String>,
) -> Annotated<NvCreateChatCompletionStreamResponse> {
    #[allow(deprecated)]
    let only_choice = ChatChoiceStream {
        index: 0,
        delta: ChatCompletionStreamResponseDelta {
            role: Some(Role::Assistant),
            content: Some(content),
            tool_calls: None,
            function_call: None,
            refusal: None,
            reasoning_content,
        },
        finish_reason: None,
        logprobs: None,
    };

    let payload = NvCreateChatCompletionStreamResponse {
        id: String::from("test-id"),
        choices: vec![only_choice],
        created: 1_234_567_890,
        model: String::from("test-model"),
        system_fingerprint: Some(String::from("test-fingerprint")),
        object: String::from("chat.completion.chunk"),
        usage: None,
        service_tier: None,
    };

    Annotated {
        id: Some(String::from("test-id")),
        data: Some(payload),
        event: None,
        comment: None,
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Item type flowing through every stream exercised in these tests.
    type Chunk = Annotated<NvCreateChatCompletionStreamResponse>;

    /// Drains `source` to completion and returns its items in arrival order.
    async fn collect_chunks<S>(source: S) -> Vec<Chunk>
    where
        S: futures::Stream<Item = Chunk>,
    {
        source.collect().await
    }

    /// Feeds `input_chunks` through the reasoning-parser transformation configured with
    /// `parser_name` and returns every chunk it emits.
    async fn run_reasoning_parser(parser_name: &str, input_chunks: Vec<Chunk>) -> Vec<Chunk> {
        let parsed = OpenAIPreprocessor::parse_reasoning_content_from_stream(
            stream::iter(input_chunks),
            parser_name.to_string(),
        );
        collect_chunks(parsed).await
    }

    /// Iterates over every choice of every chunk that carries data.
    fn all_choices(chunks: &[Chunk]) -> impl Iterator<Item = &ChatChoiceStream> {
        chunks
            .iter()
            .filter_map(|chunk| chunk.data.as_ref())
            .flat_map(|data| data.choices.iter())
    }

    /// Returns the first choice of `chunk`, panicking if the chunk has no data.
    fn first_choice(chunk: &Chunk) -> &ChatChoiceStream {
        &chunk
            .data
            .as_ref()
            .expect("chunk should carry response data")
            .choices[0]
    }

    /// Concatenates `(reasoning_content, content)` across all choices of all chunks.
    fn concat_deltas(chunks: &[Chunk]) -> (String, String) {
        let mut reasoning = String::new();
        let mut normal = String::new();
        for choice in all_choices(chunks) {
            if let Some(text) = choice.delta.reasoning_content.as_deref() {
                reasoning.push_str(text);
            }
            if let Some(text) = choice.delta.content.as_deref() {
                normal.push_str(text);
            }
        }
        (reasoning, normal)
    }

    /// Asserts a choice's content and reasoning content.
    ///
    /// `expected_content == None` accepts either a missing or an empty `content`;
    /// `expected_reasoning_content == None` requires `reasoning_content` to be absent.
    fn assert_choice(
        choice: &ChatChoiceStream,
        expected_content: Option<&str>,
        expected_reasoning_content: Option<&str>,
    ) {
        let content = choice.delta.content.as_deref();
        match expected_content {
            Some(expected) => assert_eq!(content, Some(expected), "Content mismatch"),
            None => assert!(
                content.is_none_or(str::is_empty),
                "Expected content to be None or empty, got: {:?}",
                choice.delta.content
            ),
        }

        let reasoning = choice.delta.reasoning_content.as_deref();
        match expected_reasoning_content {
            Some(expected) => {
                assert_eq!(reasoning, Some(expected), "Reasoning content mismatch")
            }
            None => assert!(
                reasoning.is_none(),
                "Expected reasoning content to be None, got: {:?}",
                choice.delta.reasoning_content
            ),
        }
    }

    #[tokio::test]
    async fn test_reasoning_parser_with_basic_parser() {
        // Basic parser uses <think> </think> tags:
        //   <think> This is reasoning content </think> Here's my answer.
        //   content:           Here's my answer.
        //   reasoning_content: This is reasoning content
        let input_chunks = vec![
            create_mock_response_chunk("<think>This".to_string(), None),
            create_mock_response_chunk(" is reasoning content".to_string(), None),
            create_mock_response_chunk("</think> Here's my answer.".to_string(), None),
        ];

        let output_chunks = run_reasoning_parser("basic", input_chunks).await;

        // One output chunk per input chunk.
        assert_eq!(output_chunks.len(), 3);
        // Chunk 0: "<think>This"
        assert_choice(first_choice(&output_chunks[0]), None, Some("This"));
        // Chunk 1: " is reasoning content"
        assert_choice(
            first_choice(&output_chunks[1]),
            None,
            Some(" is reasoning content"),
        );
        // Chunk 2: "</think> Here's my answer."
        assert_choice(
            first_choice(&output_chunks[2]),
            Some(" Here's my answer."),
            None,
        );
    }

    #[tokio::test]
    async fn test_reasoning_parser_with_only_reasoning_content() {
        // Input consists solely of reasoning content.
        let input_chunks = vec![
            create_mock_response_chunk("<think>Only".to_string(), None),
            create_mock_response_chunk(" reasoning".to_string(), None),
            create_mock_response_chunk(" here</think>".to_string(), None),
        ];

        let output_chunks = run_reasoning_parser("basic", input_chunks).await;

        // Reasoning content is split correctly across the three chunks.
        assert_eq!(output_chunks.len(), 3);
        // Chunk 0: "<think>Only"
        assert_choice(first_choice(&output_chunks[0]), None, Some("Only"));
        // Chunk 1: " reasoning"
        assert_choice(first_choice(&output_chunks[1]), None, Some(" reasoning"));
        // Chunk 2: " here</think>"
        assert_choice(first_choice(&output_chunks[2]), None, Some(" here"));
    }

    #[tokio::test]
    async fn test_reasoning_parser_with_only_normal_content() {
        // Input has no reasoning tags at all.
        let input_chunks = vec![create_mock_response_chunk(
            "Just normal text without reasoning tags.".to_string(),
            None,
        )];

        let output_chunks = run_reasoning_parser("basic", input_chunks).await;

        // Only normal content is present.
        assert_eq!(output_chunks.len(), 1);
        assert_choice(
            first_choice(&output_chunks[0]),
            Some("Just normal text without reasoning tags."),
            None,
        );
    }

    #[tokio::test]
    async fn test_reasoning_parser_with_invalid_parser_name() {
        let input_chunks = vec![create_mock_response_chunk("Hello world!".to_string(), None)];

        let output_chunks =
            run_reasoning_parser("invalid_parser_name", input_chunks.clone()).await;

        // An unknown parser name must behave as a pass-through.
        assert_eq!(output_chunks.len(), input_chunks.len());
        for (input, output) in input_chunks.iter().zip(output_chunks.iter()) {
            let input_choice = first_choice(input);
            assert_choice(
                first_choice(output),
                input_choice.delta.content.as_deref(),
                input_choice.delta.reasoning_content.as_deref(),
            );
        }
    }

    #[tokio::test]
    async fn test_reasoning_parser_with_mistral_parser() {
        // Mistral-style reasoning tags: [THINK] ... [/THINK]
        let input_chunks = vec![create_mock_response_chunk(
            "Let me think. [THINK]This is Mistral reasoning[/THINK] Here's my answer.".to_string(),
            None,
        )];

        let output_chunks = run_reasoning_parser("mistral", input_chunks).await;

        assert_eq!(output_chunks.len(), 1);
        let output_choice = first_choice(&output_chunks[0]);
        let reasoning_content = output_choice
            .delta
            .reasoning_content
            .as_deref()
            .expect("Should extract Mistral reasoning content");
        let normal_content = output_choice
            .delta
            .content
            .as_deref()
            .expect("Should have normal content");

        // The content was split on the Mistral tags.
        assert!(
            reasoning_content.contains("Mistral reasoning"),
            "Should contain Mistral reasoning content"
        );
        assert!(
            normal_content.contains("Let me think") || normal_content.contains("Here's my answer"),
            "Should contain normal content"
        );
    }

    #[tokio::test]
    async fn test_reasoning_parser_with_gpt_oss_parser() {
        let input_chunks = vec![
            // Chunk 1: Start of analysis channel
            create_mock_response_chunk("<|channel|>".to_string(), None),
            // Chunk 2: Analysis channel with reasoning content
            create_mock_response_chunk(
                "analysis<|message|>Let me analyze this question carefully.".to_string(),
                None,
            ),
            // Chunk 3: Continue reasoning content
            create_mock_response_chunk(
                " The user is asking about weather in San Francisco.".to_string(),
                None,
            ),
            // Chunk 4: End analysis and start assistant final channel
            create_mock_response_chunk(
                "<|end|><|start|>assistant<|channel|>final<|message|>".to_string(),
                None,
            ),
            // Chunk 5: Normal content (final response)
            create_mock_response_chunk(
                "I can help you with the weather in San Francisco.".to_string(),
                None,
            ),
        ];

        let output_chunks = run_reasoning_parser("gpt_oss", input_chunks).await;
        assert!(!output_chunks.is_empty(), "Should have output chunks");

        // Compare the concatenated text, independent of how it was chunked.
        let (all_reasoning, all_normal_content) = concat_deltas(&output_chunks);
        assert_eq!(
            all_reasoning,
            "Let me analyze this question carefully. The user is asking about weather in San Francisco.",
            "Reasoning content should exactly match expected text. Got: {}",
            all_reasoning
        );
        assert_eq!(
            all_normal_content, "I can help you with the weather in San Francisco.",
            "Normal content should exactly match expected text. Got: {}",
            all_normal_content
        );
    }

    #[tokio::test]
    async fn test_reasoning_parser_with_kimi_parser() {
        // Kimi-style reasoning tags: ◁think▷ ... ◁/think▷
        let input_chunks = vec![create_mock_response_chunk(
            "Let me analyze this. ◁think▷This is Kimi reasoning content◁/think▷ Here's my conclusion."
                .to_string(),
            None,
        )];

        let output_chunks = run_reasoning_parser("kimi", input_chunks).await;

        assert_eq!(output_chunks.len(), 1);
        let output_choice = first_choice(&output_chunks[0]);
        let reasoning_content = output_choice
            .delta
            .reasoning_content
            .as_deref()
            .expect("Should extract Kimi reasoning content");
        let normal_content = output_choice
            .delta
            .content
            .as_deref()
            .expect("Should have normal content");

        // The content was split on the Kimi tags.
        assert!(
            reasoning_content.contains("Kimi reasoning"),
            "Should contain Kimi reasoning content"
        );
        assert!(
            normal_content.contains("Let me analyze")
                || normal_content.contains("Here's my conclusion"),
            "Should contain normal content"
        );
    }

    #[tokio::test]
    async fn test_nemotron_with_reasoning_and_tool_calls() {
        let input_chunks = vec![
            // Chunk 1: Start of reasoning
            create_mock_response_chunk("<think>I need to".to_string(), None),
            // Chunk 2: Continue reasoning
            create_mock_response_chunk(" check the weather first</think>".to_string(), None),
            // Chunk 3: Normal text after reasoning
            create_mock_response_chunk("Let me help you with that. ".to_string(), None),
            // Chunk 4: Tool call start
            create_mock_response_chunk("<TOOLCALL>[{\"name\": \"get_weather\",".to_string(), None),
            // Chunk 5: Tool call arguments
            create_mock_response_chunk(
                " \"arguments\": {\"location\": \"San Francisco\"}}]".to_string(),
                None,
            ),
            // Chunk 6: Tool call end
            create_mock_response_chunk("</TOOLCALL>".to_string(), None),
        ];

        // Step 1: reasoning parser transformation.
        let reasoning_parsed_stream = OpenAIPreprocessor::parse_reasoning_content_from_stream(
            stream::iter(input_chunks),
            "nemotron_deci".to_string(),
        );
        // Step 2: tool calling jail transformation, chained directly onto step 1.
        let tool_parsed_stream = OpenAIPreprocessor::apply_tool_calling_jail(
            "nemotron_deci".to_string(),
            reasoning_parsed_stream,
        );
        let output_chunks = collect_chunks(tool_parsed_stream).await;
        assert!(!output_chunks.is_empty(), "Should have output chunks");

        let (all_reasoning, all_normal_content) = concat_deltas(&output_chunks);

        // Scan for tool calls; the last name/arguments seen win.
        let mut found_tool_calls = false;
        let mut tool_call_function_name: Option<String> = None;
        let mut tool_call_arguments: Option<serde_json::Value> = None;
        for choice in all_choices(&output_chunks) {
            let Some(tool_calls) = choice.delta.tool_calls.as_ref() else {
                continue;
            };
            if tool_calls.is_empty() {
                continue;
            }
            found_tool_calls = true;
            for function in tool_calls.iter().filter_map(|call| call.function.as_ref()) {
                if let Some(name) = function.name.as_ref() {
                    tool_call_function_name = Some(name.clone());
                }
                if let Some(args) = function.arguments.as_deref() {
                    tool_call_arguments = Some(serde_json::from_str(args).unwrap());
                }
            }
        }

        assert_eq!(
            all_reasoning, "I need to check the weather first",
            "Reasoning content should exactly match expected text. Got: {}",
            all_reasoning
        );
        assert_eq!(
            all_normal_content, "Let me help you with that. ",
            "Normal content should exactly match expected text. Got: {}",
            all_normal_content
        );
        assert!(
            found_tool_calls,
            "Should have found tool calls in the output"
        );
        assert_eq!(
            tool_call_function_name.as_deref(),
            Some("get_weather"),
            "Tool call function name should be 'get_weather'"
        );
        assert_eq!(
            tool_call_arguments.as_ref(),
            Some(&serde_json::json!({"location": "San Francisco"})),
            "Tool call arguments should exactly match expected value"
        );
    }

    #[tokio::test]
    #[ignore]
    // (TODO: Ayush) Fix this test
    async fn test_gpt_oss_with_reasoning_and_tool_calls_full() {
        let input_chunks = vec![
            create_mock_response_chunk("<|channel|>analysis<|message|>Let me help you with that. I need to check the weather first.<|end|>".to_string(), None),
            create_mock_response_chunk("<|start|>assistant<|channel|>commentary to=functions.get_weather <|constrain|>json<|message|>{\"location\":\"San Francisco\"}".to_string(), None),
            create_mock_response_chunk("<|start|>assistant<|channel|>final<|message|>I'll check the weather for you.".to_string(), None),
        ];

        // Materialise the reasoning-parsed chunks first, then replay them into the jail
        // so the intermediate output can be inspected while this test is being fixed.
        let reasoning_chunks = run_reasoning_parser("gpt_oss", input_chunks).await;
        let tool_parsed_stream = OpenAIPreprocessor::apply_tool_calling_jail(
            "harmony".to_string(),
            stream::iter(reasoning_chunks),
        );
        let output_chunks = collect_chunks(tool_parsed_stream).await;
        assert!(!output_chunks.is_empty(), "Should have output chunks");

        let (all_reasoning, all_normal_content) = concat_deltas(&output_chunks);
        let found_tool_calls = all_choices(&output_chunks).any(|choice| {
            choice
                .delta
                .tool_calls
                .as_ref()
                .is_some_and(|calls| !calls.is_empty())
        });

        // Expected reasoning is exactly the analysis-channel message of the first input
        // chunk above.
        assert_eq!(
            all_reasoning,
            "Let me help you with that. I need to check the weather first."
        );
        assert!(all_normal_content.contains("I'll check the weather for you"));
        assert!(found_tool_calls, "Should have found tool calls");
    }
}
...@@ -119,7 +119,7 @@ impl ReasoningParser for BasicReasoningParser { ...@@ -119,7 +119,7 @@ impl ReasoningParser for BasicReasoningParser {
}; };
return ParserResult { return ParserResult {
normal_text: normal_text.to_string(), normal_text: normal_text.to_string(),
reasoning_text: reasoning_text.trim().to_string(), reasoning_text: reasoning_text.to_string(),
}; };
} }
// Continue with reasoning content // Continue with reasoning content
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment