Unverified Commit 8e8152a1 authored by nachiketb-nvidia's avatar nachiketb-nvidia Committed by GitHub
Browse files

feat: enable basic reasoning parsing of <think> </think> tokens (#2555)

parent ae4fb58f
...@@ -183,7 +183,7 @@ impl ...@@ -183,7 +183,7 @@ impl
incoming_request: SingleIn<NvCreateChatCompletionRequest>, incoming_request: SingleIn<NvCreateChatCompletionRequest>,
) -> Result<ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>, Error> { ) -> Result<ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>, Error> {
let (request, context) = incoming_request.transfer(()); let (request, context) = incoming_request.transfer(());
let deltas = request.response_generator(); let mut deltas = request.response_generator();
let ctx = context.context(); let ctx = context.context();
let req = request.inner.messages.into_iter().next_back().unwrap(); let req = request.inner.messages.into_iter().next_back().unwrap();
......
...@@ -61,6 +61,9 @@ struct DeltaChoice { ...@@ -61,6 +61,9 @@ struct DeltaChoice {
logprobs: Option<dynamo_async_openai::types::ChatChoiceLogprobs>, logprobs: Option<dynamo_async_openai::types::ChatChoiceLogprobs>,
// Optional tool calls for the chat choice. // Optional tool calls for the chat choice.
tool_calls: Option<Vec<dynamo_async_openai::types::ChatCompletionMessageToolCall>>, tool_calls: Option<Vec<dynamo_async_openai::types::ChatCompletionMessageToolCall>>,
/// Optional reasoning content for the chat choice.
reasoning_content: Option<String>,
} }
impl Default for DeltaAggregator { impl Default for DeltaAggregator {
...@@ -137,6 +140,7 @@ impl DeltaAggregator { ...@@ -137,6 +140,7 @@ impl DeltaAggregator {
finish_reason: None, finish_reason: None,
logprobs: choice.logprobs, logprobs: choice.logprobs,
tool_calls: None, tool_calls: None,
reasoning_content: None,
}); });
// Append content if available. // Append content if available.
...@@ -144,6 +148,13 @@ impl DeltaAggregator { ...@@ -144,6 +148,13 @@ impl DeltaAggregator {
state_choice.text.push_str(content); state_choice.text.push_str(content);
} }
if let Some(reasoning_content) = &choice.delta.reasoning_content {
state_choice
.reasoning_content
.get_or_insert_with(String::new)
.push_str(reasoning_content);
}
// Update finish reason if provided. // Update finish reason if provided.
if let Some(finish_reason) = choice.finish_reason { if let Some(finish_reason) = choice.finish_reason {
state_choice.finish_reason = Some(finish_reason); state_choice.finish_reason = Some(finish_reason);
...@@ -228,7 +239,7 @@ impl From<DeltaChoice> for dynamo_async_openai::types::ChatChoice { ...@@ -228,7 +239,7 @@ impl From<DeltaChoice> for dynamo_async_openai::types::ChatChoice {
refusal: None, refusal: None,
function_call: None, function_call: None,
audio: None, audio: None,
reasoning_content: None, reasoning_content: delta.reasoning_content,
}, },
index: delta.index, index: delta.index,
finish_reason: delta.finish_reason, finish_reason: delta.finish_reason,
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
use dynamo_parsers::{ParserResult, ReasoningParser, ReasoningParserType, ReasoningParserWrapper};
use super::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse}; use super::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse};
use crate::{ use crate::{
protocols::common::{self}, protocols::common::{self},
...@@ -42,7 +44,6 @@ pub struct DeltaGenerator { ...@@ -42,7 +44,6 @@ pub struct DeltaGenerator {
object: String, object: String,
/// Timestamp (Unix epoch) when the response was created. /// Timestamp (Unix epoch) when the response was created.
created: u32, created: u32,
/// Model name used for generating responses.
model: String, model: String,
/// Optional system fingerprint for version tracking. /// Optional system fingerprint for version tracking.
system_fingerprint: Option<String>, system_fingerprint: Option<String>,
...@@ -54,6 +55,10 @@ pub struct DeltaGenerator { ...@@ -54,6 +55,10 @@ pub struct DeltaGenerator {
msg_counter: u64, msg_counter: u64,
/// Configuration options for response generation. /// Configuration options for response generation.
options: DeltaGeneratorOptions, options: DeltaGeneratorOptions,
/// Reasoning Parser object
/// This is used to parse reasoning content in the response.
reasoning_parser: ReasoningParserWrapper,
} }
impl DeltaGenerator { impl DeltaGenerator {
...@@ -83,6 +88,14 @@ impl DeltaGenerator { ...@@ -83,6 +88,14 @@ impl DeltaGenerator {
completion_tokens_details: None, completion_tokens_details: None,
}; };
// Reasoning parser type
// This is hardcoded for now, but can be made configurable later.
// TODO: Make parser type configurable once front-end integration is determined
let reasoning_parser_type = ReasoningParserType::Basic;
// Reasoning parser wrapper
let reasoning_parser = reasoning_parser_type.get_reasoning_parser();
Self { Self {
id: format!("chatcmpl-{}", uuid::Uuid::new_v4()), id: format!("chatcmpl-{}", uuid::Uuid::new_v4()),
object: "chat.completion.chunk".to_string(), object: "chat.completion.chunk".to_string(),
...@@ -93,6 +106,7 @@ impl DeltaGenerator { ...@@ -93,6 +106,7 @@ impl DeltaGenerator {
usage, usage,
msg_counter: 0, msg_counter: 0,
options, options,
reasoning_parser,
} }
} }
...@@ -169,6 +183,15 @@ impl DeltaGenerator { ...@@ -169,6 +183,15 @@ impl DeltaGenerator {
}) })
} }
/// Runs one streamed text chunk through the reasoning parser.
///
/// Returns `None` when `text` is `None`. Otherwise returns the [`ParserResult`] that
/// `parse_reasoning_streaming_incremental` produces for this chunk. The receiver is
/// `&mut self` because that parser method takes a mutable receiver.
fn create_reasoning_content(&mut self, text: Option<String>) -> Option<ParserResult> {
    text.map(|chunk| {
        self.reasoning_parser
            .parse_reasoning_streaming_incremental(&chunk)
    })
}
/// Creates a choice within a chat completion response. /// Creates a choice within a chat completion response.
/// ///
/// # Arguments /// # Arguments
...@@ -181,14 +204,20 @@ impl DeltaGenerator { ...@@ -181,14 +204,20 @@ impl DeltaGenerator {
/// * An [`dynamo_async_openai::types::CreateChatCompletionStreamResponse`] instance representing the choice. /// * An [`dynamo_async_openai::types::CreateChatCompletionStreamResponse`] instance representing the choice.
#[allow(deprecated)] #[allow(deprecated)]
pub fn create_choice( pub fn create_choice(
&self, &mut self,
index: u32, index: u32,
text: Option<String>, text: Option<String>,
finish_reason: Option<dynamo_async_openai::types::FinishReason>, finish_reason: Option<dynamo_async_openai::types::FinishReason>,
logprobs: Option<dynamo_async_openai::types::ChatChoiceLogprobs>, logprobs: Option<dynamo_async_openai::types::ChatChoiceLogprobs>,
) -> dynamo_async_openai::types::CreateChatCompletionStreamResponse { ) -> NvCreateChatCompletionStreamResponse {
let reasoning_parser_result = self.create_reasoning_content(text).unwrap_or_default();
let (normal_text, reasoning_content) = (
reasoning_parser_result.get_some_normal_text(),
reasoning_parser_result.get_some_reasoning(),
);
let delta = dynamo_async_openai::types::ChatCompletionStreamResponseDelta { let delta = dynamo_async_openai::types::ChatCompletionStreamResponseDelta {
content: text, content: normal_text,
function_call: None, function_call: None,
tool_calls: None, tool_calls: None,
role: if self.msg_counter == 0 { role: if self.msg_counter == 0 {
...@@ -197,7 +226,7 @@ impl DeltaGenerator { ...@@ -197,7 +226,7 @@ impl DeltaGenerator {
None None
}, },
refusal: None, refusal: None,
reasoning_content: None, reasoning_content,
}; };
let choice = dynamo_async_openai::types::ChatChoiceStream { let choice = dynamo_async_openai::types::ChatChoiceStream {
......
...@@ -95,7 +95,7 @@ impl ...@@ -95,7 +95,7 @@ impl
let max_tokens = request.inner.max_tokens.unwrap_or(0) as u64; let max_tokens = request.inner.max_tokens.unwrap_or(0) as u64;
// let generator = NvCreateChatCompletionStreamResponse::generator(request.model.clone()); // let generator = NvCreateChatCompletionStreamResponse::generator(request.model.clone());
let generator = request.response_generator(); let mut generator = request.response_generator();
let stream = stream! { let stream = stream! {
tokio::time::sleep(std::time::Duration::from_millis(max_tokens)).await; tokio::time::sleep(std::time::Duration::from_millis(max_tokens)).await;
......
...@@ -2,24 +2,10 @@ ...@@ -2,24 +2,10 @@
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
use tracing as log; use tracing as log;
pub struct ParserResult { use crate::{ParserResult, ReasoningParser};
/// The normal text outside of reasoning blocks.
pub normal_text: String,
/// The extracted reasoning text from within reasoning blocks. #[derive(Default, Debug, Clone)]
pub reasoning_text: String, pub struct BasicReasoningParser {
}
pub trait ReasoningParser {
/// Detects and parses reasoning from the input text.
fn detect_and_parse_reasoning(&mut self, text: &str) -> ParserResult;
/// Parses reasoning incrementally from streaming input.
fn parse_reasoning_streaming_incremental(&mut self, text: &str) -> ParserResult;
}
#[derive(Default)]
pub struct BaseReasoningParser {
think_start_token: String, think_start_token: String,
think_end_token: String, think_end_token: String,
_in_reasoning: bool, _in_reasoning: bool,
...@@ -28,7 +14,7 @@ pub struct BaseReasoningParser { ...@@ -28,7 +14,7 @@ pub struct BaseReasoningParser {
stripped_think_start: bool, stripped_think_start: bool,
} }
impl BaseReasoningParser { impl BasicReasoningParser {
pub fn new( pub fn new(
think_start_token: String, think_start_token: String,
think_end_token: String, think_end_token: String,
...@@ -46,8 +32,8 @@ impl BaseReasoningParser { ...@@ -46,8 +32,8 @@ impl BaseReasoningParser {
} }
} }
impl ReasoningParser for BaseReasoningParser { impl ReasoningParser for BasicReasoningParser {
fn detect_and_parse_reasoning(&mut self, text: &str) -> ParserResult { fn detect_and_parse_reasoning(&self, text: &str) -> ParserResult {
log::debug!("detect_and_parse_reasoning called with text: {:?}", text); log::debug!("detect_and_parse_reasoning called with text: {:?}", text);
let in_reasoning = self._in_reasoning || text.contains(&self.think_start_token); let in_reasoning = self._in_reasoning || text.contains(&self.think_start_token);
...@@ -194,8 +180,8 @@ mod tests { ...@@ -194,8 +180,8 @@ mod tests {
#[test] #[test]
fn test_detect_and_parse_reasoning_reasoning() { fn test_detect_and_parse_reasoning_reasoning() {
let mut parser = let parser =
BaseReasoningParser::new("<think>".to_string(), "</think>".to_string(), false, true); BasicReasoningParser::new("<think>".to_string(), "</think>".to_string(), false, true);
let result = let result =
parser.detect_and_parse_reasoning("<think>with reasoning</think> and more text."); parser.detect_and_parse_reasoning("<think>with reasoning</think> and more text.");
assert_eq!(result.normal_text, "and more text."); assert_eq!(result.normal_text, "and more text.");
...@@ -203,16 +189,16 @@ mod tests { ...@@ -203,16 +189,16 @@ mod tests {
} }
#[test] #[test]
fn test_detect_and_parse_reasoning_reasoning_no_reasoning() { fn test_detect_and_parse_reasoning_reasoning_no_reasoning() {
let mut parser = let parser =
BaseReasoningParser::new("<think>".to_string(), "</think>".to_string(), false, true); BasicReasoningParser::new("<think>".to_string(), "</think>".to_string(), false, true);
let result = parser.detect_and_parse_reasoning("This is a test without reasoning."); let result = parser.detect_and_parse_reasoning("This is a test without reasoning.");
assert_eq!(result.normal_text, "This is a test without reasoning."); assert_eq!(result.normal_text, "This is a test without reasoning.");
assert_eq!(result.reasoning_text, ""); assert_eq!(result.reasoning_text, "");
} }
#[test] #[test]
fn test_detect_and_parse_reasoning_reasoning_truncated_reasoning() { fn test_detect_and_parse_reasoning_reasoning_truncated_reasoning() {
let mut parser = let parser =
BaseReasoningParser::new("<think>".to_string(), "</think>".to_string(), false, true); BasicReasoningParser::new("<think>".to_string(), "</think>".to_string(), false, true);
let result = parser.detect_and_parse_reasoning("<think>with truncated reasoning"); let result = parser.detect_and_parse_reasoning("<think>with truncated reasoning");
assert_eq!(result.normal_text, ""); assert_eq!(result.normal_text, "");
assert_eq!(result.reasoning_text, "with truncated reasoning"); assert_eq!(result.reasoning_text, "with truncated reasoning");
...@@ -221,7 +207,7 @@ mod tests { ...@@ -221,7 +207,7 @@ mod tests {
#[test] #[test]
fn test_parse_reasoning_streaming_incremental() { fn test_parse_reasoning_streaming_incremental() {
let mut parser = let mut parser =
BaseReasoningParser::new("<think>".to_string(), "</think>".to_string(), false, true); BasicReasoningParser::new("<think>".to_string(), "</think>".to_string(), false, true);
let result = parser.parse_reasoning_streaming_incremental("<thi"); let result = parser.parse_reasoning_streaming_incremental("<thi");
assert_eq!(result.normal_text, ""); assert_eq!(result.normal_text, "");
assert_eq!(result.reasoning_text, ""); assert_eq!(result.reasoning_text, "");
...@@ -230,7 +216,7 @@ mod tests { ...@@ -230,7 +216,7 @@ mod tests {
#[test] #[test]
fn test_parse_reasoning_streaming_incremental_complete() { fn test_parse_reasoning_streaming_incremental_complete() {
let mut parser = let mut parser =
BaseReasoningParser::new("<think>".to_string(), "</think>".to_string(), false, true); BasicReasoningParser::new("<think>".to_string(), "</think>".to_string(), false, true);
let result = parser let result = parser
.parse_reasoning_streaming_incremental("<think>with reasoning</think> and more text."); .parse_reasoning_streaming_incremental("<think>with reasoning</think> and more text.");
assert_eq!(result.normal_text, " and more text."); assert_eq!(result.normal_text, " and more text.");
...@@ -240,7 +226,7 @@ mod tests { ...@@ -240,7 +226,7 @@ mod tests {
#[test] #[test]
fn test_parse_reasoning_streaming_incremental_no_end_token() { fn test_parse_reasoning_streaming_incremental_no_end_token() {
let mut parser = let mut parser =
BaseReasoningParser::new("<think>".to_string(), "</think>".to_string(), true, true); BasicReasoningParser::new("<think>".to_string(), "</think>".to_string(), true, true);
let result = parser.parse_reasoning_streaming_incremental("<think>with reasoning"); let result = parser.parse_reasoning_streaming_incremental("<think>with reasoning");
assert_eq!(result.normal_text, ""); assert_eq!(result.normal_text, "");
assert_eq!(result.reasoning_text, "with reasoning"); assert_eq!(result.reasoning_text, "with reasoning");
...@@ -248,8 +234,8 @@ mod tests { ...@@ -248,8 +234,8 @@ mod tests {
#[test] #[test]
fn test_detect_and_parse_reasoning_multiple_reasoning_blocks() { fn test_detect_and_parse_reasoning_multiple_reasoning_blocks() {
let mut parser = let parser =
BaseReasoningParser::new("<think>".to_string(), "</think>".to_string(), false, true); BasicReasoningParser::new("<think>".to_string(), "</think>".to_string(), false, true);
let result = parser.detect_and_parse_reasoning( let result = parser.detect_and_parse_reasoning(
"<think>first reasoning</think> middle <think>second reasoning</think> end", "<think>first reasoning</think> middle <think>second reasoning</think> end",
); );
...@@ -261,7 +247,7 @@ mod tests { ...@@ -261,7 +247,7 @@ mod tests {
#[test] #[test]
fn test_streaming_multiple_reasoning_blocks() { fn test_streaming_multiple_reasoning_blocks() {
let mut parser = let mut parser =
BaseReasoningParser::new("<think>".to_string(), "</think>".to_string(), false, false); BasicReasoningParser::new("<think>".to_string(), "</think>".to_string(), false, false);
let result1 = let result1 =
parser.parse_reasoning_streaming_incremental("<think>first reasoning</think> middle"); parser.parse_reasoning_streaming_incremental("<think>first reasoning</think> middle");
assert_eq!(result1.normal_text, " middle"); assert_eq!(result1.normal_text, " middle");
...@@ -277,7 +263,7 @@ mod tests { ...@@ -277,7 +263,7 @@ mod tests {
#[test] #[test]
fn test_partial_token_matching_opening_tag() { fn test_partial_token_matching_opening_tag() {
let mut parser = let mut parser =
BaseReasoningParser::new("<think>".to_string(), "</think>".to_string(), false, true); BasicReasoningParser::new("<think>".to_string(), "</think>".to_string(), false, true);
// Feed partial opening tag // Feed partial opening tag
let result1 = parser.parse_reasoning_streaming_incremental("<th"); let result1 = parser.parse_reasoning_streaming_incremental("<th");
...@@ -294,7 +280,7 @@ mod tests { ...@@ -294,7 +280,7 @@ mod tests {
#[test] #[test]
fn test_partial_token_matching_closing_tag() { fn test_partial_token_matching_closing_tag() {
let mut parser = let mut parser =
BaseReasoningParser::new("<think>".to_string(), "</think>".to_string(), false, false); BasicReasoningParser::new("<think>".to_string(), "</think>".to_string(), false, false);
// Start with complete opening and partial content // Start with complete opening and partial content
let result1 = parser.parse_reasoning_streaming_incremental("<think>reasoning content</th"); let result1 = parser.parse_reasoning_streaming_incremental("<think>reasoning content</th");
...@@ -310,7 +296,7 @@ mod tests { ...@@ -310,7 +296,7 @@ mod tests {
#[test] #[test]
fn test_buffer_state_persistence_across_calls() { fn test_buffer_state_persistence_across_calls() {
let mut parser = let mut parser =
BaseReasoningParser::new("<think>".to_string(), "</think>".to_string(), false, false); BasicReasoningParser::new("<think>".to_string(), "</think>".to_string(), false, false);
// First call - partial opening tag // First call - partial opening tag
let result1 = parser.parse_reasoning_streaming_incremental("<th"); let result1 = parser.parse_reasoning_streaming_incremental("<th");
...@@ -336,7 +322,7 @@ mod tests { ...@@ -336,7 +322,7 @@ mod tests {
#[test] #[test]
fn test_streaming_with_stream_reasoning_enabled() { fn test_streaming_with_stream_reasoning_enabled() {
let mut parser = let mut parser =
BaseReasoningParser::new("<think>".to_string(), "</think>".to_string(), false, true); BasicReasoningParser::new("<think>".to_string(), "</think>".to_string(), false, true);
// Start reasoning block // Start reasoning block
let result1 = parser.parse_reasoning_streaming_incremental("<think>reasoning "); let result1 = parser.parse_reasoning_streaming_incremental("<think>reasoning ");
...@@ -356,8 +342,8 @@ mod tests { ...@@ -356,8 +342,8 @@ mod tests {
#[test] #[test]
fn test_nested_reasoning_blocks() { fn test_nested_reasoning_blocks() {
let mut parser = let parser =
BaseReasoningParser::new("<think>".to_string(), "</think>".to_string(), false, true); BasicReasoningParser::new("<think>".to_string(), "</think>".to_string(), false, true);
let result = parser.detect_and_parse_reasoning( let result = parser.detect_and_parse_reasoning(
"<think>outer <think>inner</think> reasoning</think> normal", "<think>outer <think>inner</think> reasoning</think> normal",
); );
...@@ -369,8 +355,8 @@ mod tests { ...@@ -369,8 +355,8 @@ mod tests {
#[test] #[test]
fn test_malformed_missing_closing_tag() { fn test_malformed_missing_closing_tag() {
let mut parser = let parser =
BaseReasoningParser::new("<think>".to_string(), "</think>".to_string(), false, true); BasicReasoningParser::new("<think>".to_string(), "</think>".to_string(), false, true);
let result = parser.detect_and_parse_reasoning("<think>reasoning without closing tag"); let result = parser.detect_and_parse_reasoning("<think>reasoning without closing tag");
assert_eq!(result.normal_text, ""); assert_eq!(result.normal_text, "");
assert_eq!(result.reasoning_text, "reasoning without closing tag"); assert_eq!(result.reasoning_text, "reasoning without closing tag");
...@@ -378,8 +364,8 @@ mod tests { ...@@ -378,8 +364,8 @@ mod tests {
#[test] #[test]
fn test_malformed_stray_closing_tag() { fn test_malformed_stray_closing_tag() {
let mut parser = let parser =
BaseReasoningParser::new("<think>".to_string(), "</think>".to_string(), false, true); BasicReasoningParser::new("<think>".to_string(), "</think>".to_string(), false, true);
let result = parser.detect_and_parse_reasoning("normal text</think> more normal"); let result = parser.detect_and_parse_reasoning("normal text</think> more normal");
assert_eq!(result.normal_text, "normal text</think> more normal"); assert_eq!(result.normal_text, "normal text</think> more normal");
assert_eq!(result.reasoning_text, ""); assert_eq!(result.reasoning_text, "");
...@@ -387,8 +373,8 @@ mod tests { ...@@ -387,8 +373,8 @@ mod tests {
#[test] #[test]
fn test_malformed_multiple_opening_tags() { fn test_malformed_multiple_opening_tags() {
let mut parser = let parser =
BaseReasoningParser::new("<think>".to_string(), "</think>".to_string(), false, true); BasicReasoningParser::new("<think>".to_string(), "</think>".to_string(), false, true);
let result = parser let result = parser
.detect_and_parse_reasoning("<think>first <think>second reasoning</think> normal"); .detect_and_parse_reasoning("<think>first <think>second reasoning</think> normal");
// Should handle by replacing all opening tags and using first closing tag // Should handle by replacing all opening tags and using first closing tag
...@@ -398,8 +384,8 @@ mod tests { ...@@ -398,8 +384,8 @@ mod tests {
#[test] #[test]
fn test_empty_reasoning_block() { fn test_empty_reasoning_block() {
let mut parser = let parser =
BaseReasoningParser::new("<think>".to_string(), "</think>".to_string(), false, true); BasicReasoningParser::new("<think>".to_string(), "</think>".to_string(), false, true);
let result = parser.detect_and_parse_reasoning("<think></think> normal text"); let result = parser.detect_and_parse_reasoning("<think></think> normal text");
assert_eq!(result.normal_text, "normal text"); assert_eq!(result.normal_text, "normal text");
assert_eq!(result.reasoning_text, ""); assert_eq!(result.reasoning_text, "");
...@@ -407,8 +393,8 @@ mod tests { ...@@ -407,8 +393,8 @@ mod tests {
#[test] #[test]
fn test_whitespace_only_reasoning_block() { fn test_whitespace_only_reasoning_block() {
let mut parser = let parser =
BaseReasoningParser::new("<think>".to_string(), "</think>".to_string(), false, true); BasicReasoningParser::new("<think>".to_string(), "</think>".to_string(), false, true);
let result = parser.detect_and_parse_reasoning("<think> \n\t </think> normal text"); let result = parser.detect_and_parse_reasoning("<think> \n\t </think> normal text");
assert_eq!(result.normal_text, "normal text"); assert_eq!(result.normal_text, "normal text");
assert_eq!(result.reasoning_text, ""); // Should be empty after trim assert_eq!(result.reasoning_text, ""); // Should be empty after trim
...@@ -416,8 +402,8 @@ mod tests { ...@@ -416,8 +402,8 @@ mod tests {
#[test] #[test]
fn test_force_reasoning_mode() { fn test_force_reasoning_mode() {
let mut parser = let parser =
BaseReasoningParser::new("<think>".to_string(), "</think>".to_string(), true, true); BasicReasoningParser::new("<think>".to_string(), "</think>".to_string(), true, true);
let result = parser.detect_and_parse_reasoning("no think tags here"); let result = parser.detect_and_parse_reasoning("no think tags here");
assert_eq!(result.normal_text, ""); assert_eq!(result.normal_text, "");
assert_eq!(result.reasoning_text, "no think tags here"); assert_eq!(result.reasoning_text, "no think tags here");
...@@ -426,7 +412,7 @@ mod tests { ...@@ -426,7 +412,7 @@ mod tests {
#[test] #[test]
fn test_streaming_reset_state_after_complete_block() { fn test_streaming_reset_state_after_complete_block() {
let mut parser = let mut parser =
BaseReasoningParser::new("<think>".to_string(), "</think>".to_string(), false, false); BasicReasoningParser::new("<think>".to_string(), "</think>".to_string(), false, false);
// Process complete reasoning block // Process complete reasoning block
let result1 = let result1 =
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
use super::base_parser::BaseReasoningParser; use super::base_parser::BasicReasoningParser;
use super::base_parser::ParserResult; use crate::ParserResult;
use super::base_parser::ReasoningParser; use crate::ReasoningParser;
#[derive(Default)] #[derive(Default, Debug, Clone)]
pub struct DeepseekR1ReasoningParser { pub struct DeepseekR1ReasoningParser {
base: BaseReasoningParser, base: BasicReasoningParser,
} }
impl DeepseekR1ReasoningParser { impl DeepseekR1ReasoningParser {
pub fn new() -> Self { pub fn new() -> Self {
Self { Self {
base: BaseReasoningParser::new( base: BasicReasoningParser::new(
"<think>".to_string(), "<think>".to_string(),
"</think>".to_string(), "</think>".to_string(),
true, true,
...@@ -28,7 +28,7 @@ impl ReasoningParser for DeepseekR1ReasoningParser { ...@@ -28,7 +28,7 @@ impl ReasoningParser for DeepseekR1ReasoningParser {
self.base.parse_reasoning_streaming_incremental(text) self.base.parse_reasoning_streaming_incremental(text)
} }
fn detect_and_parse_reasoning(&mut self, text: &str) -> ParserResult { fn detect_and_parse_reasoning(&self, text: &str) -> ParserResult {
self.base.detect_and_parse_reasoning(text) self.base.detect_and_parse_reasoning(text)
} }
} }
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
pub mod base_parser; mod base_parser;
pub mod deepseek_r1_parser; mod deepseek_r1_parser;
// Re-export main types and functions for convenience // Re-export main types and functions for convenience
pub use base_parser::ReasoningParser; pub use base_parser::BasicReasoningParser;
pub use deepseek_r1_parser::DeepseekR1ReasoningParser; pub use deepseek_r1_parser::DeepseekR1ReasoningParser;
/// Result of splitting text into its normal part and its reasoning part.
#[derive(Debug, Clone, Default)]
pub struct ParserResult {
    /// The normal text outside of reasoning blocks.
    pub normal_text: String,

    /// The reasoning text extracted from within reasoning blocks.
    pub reasoning_text: String,
}

impl ParserResult {
    /// Returns an owned copy of the reasoning text, or `None` when it is empty.
    pub fn get_some_reasoning(&self) -> Option<String> {
        let has_reasoning = !self.reasoning_text.is_empty();
        has_reasoning.then(|| self.reasoning_text.clone())
    }

    /// Returns an owned copy of the normal text, or `None` when it is empty.
    pub fn get_some_normal_text(&self) -> Option<String> {
        let has_normal_text = !self.normal_text.is_empty();
        has_normal_text.then(|| self.normal_text.clone())
    }
}
/// Splits model output into normal text and reasoning text, returned as a [`ParserResult`].
///
/// Every implementor must also be `Send` and `Debug` (the supertraits below).
pub trait ReasoningParser: Send + std::fmt::Debug {
/// Parses a standalone, non-streaming input and returns the split of normal vs reasoning
/// text for this complete input. The receiver is `&self`, so this call cannot update the
/// parser's streaming state unless an implementation uses interior mutability. Marker
/// tokens must not be included in either output.
fn detect_and_parse_reasoning(&self, text: &str) -> ParserResult;
/// Parses a streaming chunk and updates internal state (hence `&mut self`). The return
/// value should be the delta: only the newly discovered normal and reasoning text
/// attributable to this chunk, not the cumulative totals. Marker tokens must not be
/// included in either output.
fn parse_reasoning_streaming_incremental(&mut self, text: &str) -> ParserResult;
}
/// Selects which parser implementation `get_reasoning_parser` constructs.
///
/// The enum is `#[non_exhaustive]`, so code outside this crate must include a wildcard
/// arm when matching on it.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum ReasoningParserType {
/// Selects `DeepseekR1ReasoningParser`, built with its `new()` constructor.
DeepseekR1,
/// Selects `BasicReasoningParser`, built with the `<think>` / `</think>` marker tokens.
Basic,
}
/// Holds a boxed [`ReasoningParser`] trait object, giving callers a single concrete type
/// to store whichever parser implementation was selected.
#[derive(std::fmt::Debug)]
pub struct ReasoningParserWrapper {
/// The wrapped parser. The field is private, so it is set only where the wrapper is built.
parser: Box<dyn ReasoningParser>,
}
/// Forwards both trait methods, unchanged, to the boxed parser held by the wrapper.
impl ReasoningParser for ReasoningParserWrapper {
    /// Delegates the non-streaming parse to the wrapped parser.
    fn detect_and_parse_reasoning(&self, text: &str) -> ParserResult {
        let inner: &dyn ReasoningParser = &*self.parser;
        inner.detect_and_parse_reasoning(text)
    }

    /// Delegates the streaming parse to the wrapped parser, which may update its own state.
    fn parse_reasoning_streaming_incremental(&mut self, text: &str) -> ParserResult {
        let inner: &mut dyn ReasoningParser = &mut *self.parser;
        inner.parse_reasoning_streaming_incremental(text)
    }
}
impl ReasoningParserType {
pub fn get_reasoning_parser(self) -> ReasoningParserWrapper {
match self {
ReasoningParserType::DeepseekR1 => ReasoningParserWrapper {
parser: Box::new(DeepseekR1ReasoningParser::new()),
},
ReasoningParserType::Basic => ReasoningParserWrapper {
parser: Box::new(BasicReasoningParser::new(
"<think>".into(),
"</think>".into(),
false,
true,
)),
},
}
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment