Unverified Commit 8e8152a1 authored by nachiketb-nvidia's avatar nachiketb-nvidia Committed by GitHub
Browse files

feat: enable basic reasoning parsing of <think> </think> tokens (#2555)

parent ae4fb58f
......@@ -183,7 +183,7 @@ impl
incoming_request: SingleIn<NvCreateChatCompletionRequest>,
) -> Result<ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>, Error> {
let (request, context) = incoming_request.transfer(());
let deltas = request.response_generator();
let mut deltas = request.response_generator();
let ctx = context.context();
let req = request.inner.messages.into_iter().next_back().unwrap();
......
......@@ -61,6 +61,9 @@ struct DeltaChoice {
logprobs: Option<dynamo_async_openai::types::ChatChoiceLogprobs>,
// Optional tool calls for the chat choice.
tool_calls: Option<Vec<dynamo_async_openai::types::ChatCompletionMessageToolCall>>,
/// Optional reasoning content for the chat choice.
reasoning_content: Option<String>,
}
impl Default for DeltaAggregator {
......@@ -137,6 +140,7 @@ impl DeltaAggregator {
finish_reason: None,
logprobs: choice.logprobs,
tool_calls: None,
reasoning_content: None,
});
// Append content if available.
......@@ -144,6 +148,13 @@ impl DeltaAggregator {
state_choice.text.push_str(content);
}
if let Some(reasoning_content) = &choice.delta.reasoning_content {
state_choice
.reasoning_content
.get_or_insert_with(String::new)
.push_str(reasoning_content);
}
// Update finish reason if provided.
if let Some(finish_reason) = choice.finish_reason {
state_choice.finish_reason = Some(finish_reason);
......@@ -228,7 +239,7 @@ impl From<DeltaChoice> for dynamo_async_openai::types::ChatChoice {
refusal: None,
function_call: None,
audio: None,
reasoning_content: None,
reasoning_content: delta.reasoning_content,
},
index: delta.index,
finish_reason: delta.finish_reason,
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use dynamo_parsers::{ParserResult, ReasoningParser, ReasoningParserType, ReasoningParserWrapper};
use super::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse};
use crate::{
protocols::common::{self},
......@@ -42,7 +44,6 @@ pub struct DeltaGenerator {
object: String,
/// Timestamp (Unix epoch) when the response was created.
created: u32,
/// Model name used for generating responses.
model: String,
/// Optional system fingerprint for version tracking.
system_fingerprint: Option<String>,
......@@ -54,6 +55,10 @@ pub struct DeltaGenerator {
msg_counter: u64,
/// Configuration options for response generation.
options: DeltaGeneratorOptions,
/// Reasoning Parser object
/// This is used to parse reasoning content in the response.
reasoning_parser: ReasoningParserWrapper,
}
impl DeltaGenerator {
......@@ -83,6 +88,14 @@ impl DeltaGenerator {
completion_tokens_details: None,
};
// Reasoning parser type
// This is hardcoded for now, but can be made configurable later.
// TODO: Make parser type configurable once front-end integration is determined
let reasoning_parser_type = ReasoningParserType::Basic;
// Reasoning parser wrapper
let reasoning_parser = reasoning_parser_type.get_reasoning_parser();
Self {
id: format!("chatcmpl-{}", uuid::Uuid::new_v4()),
object: "chat.completion.chunk".to_string(),
......@@ -93,6 +106,7 @@ impl DeltaGenerator {
usage,
msg_counter: 0,
options,
reasoning_parser,
}
}
......@@ -169,6 +183,15 @@ impl DeltaGenerator {
})
}
/// Runs one streamed text chunk through the generator's reasoning parser.
///
/// # Arguments
/// * `text` - The chunk produced for this delta, if any.
///
/// # Returns
/// `None` when `text` is `None` (the parser is not invoked); otherwise the
/// [`ParserResult`] returned by the parser's incremental streaming call.
fn create_reasoning_content(&mut self, text: Option<String>) -> Option<ParserResult> {
    // The streaming parse takes `&mut self` on the parser, so it is only
    // called when there is actually a chunk to feed it.
    text.map(|chunk| {
        self.reasoning_parser
            .parse_reasoning_streaming_incremental(&chunk)
    })
}
/// Creates a choice within a chat completion response.
///
/// # Arguments
......@@ -181,14 +204,20 @@ impl DeltaGenerator {
/// * An [`dynamo_async_openai::types::CreateChatCompletionStreamResponse`] instance representing the choice.
#[allow(deprecated)]
pub fn create_choice(
&self,
&mut self,
index: u32,
text: Option<String>,
finish_reason: Option<dynamo_async_openai::types::FinishReason>,
logprobs: Option<dynamo_async_openai::types::ChatChoiceLogprobs>,
) -> dynamo_async_openai::types::CreateChatCompletionStreamResponse {
) -> NvCreateChatCompletionStreamResponse {
let reasoning_parser_result = self.create_reasoning_content(text).unwrap_or_default();
let (normal_text, reasoning_content) = (
reasoning_parser_result.get_some_normal_text(),
reasoning_parser_result.get_some_reasoning(),
);
let delta = dynamo_async_openai::types::ChatCompletionStreamResponseDelta {
content: text,
content: normal_text,
function_call: None,
tool_calls: None,
role: if self.msg_counter == 0 {
......@@ -197,7 +226,7 @@ impl DeltaGenerator {
None
},
refusal: None,
reasoning_content: None,
reasoning_content,
};
let choice = dynamo_async_openai::types::ChatChoiceStream {
......
......@@ -95,7 +95,7 @@ impl
let max_tokens = request.inner.max_tokens.unwrap_or(0) as u64;
// let generator = NvCreateChatCompletionStreamResponse::generator(request.model.clone());
let generator = request.response_generator();
let mut generator = request.response_generator();
let stream = stream! {
tokio::time::sleep(std::time::Duration::from_millis(max_tokens)).await;
......
......@@ -2,24 +2,10 @@
// SPDX-License-Identifier: Apache-2.0
use tracing as log;
pub struct ParserResult {
/// The normal text outside of reasoning blocks.
pub normal_text: String,
use crate::{ParserResult, ReasoningParser};
/// The extracted reasoning text from within reasoning blocks.
pub reasoning_text: String,
}
pub trait ReasoningParser {
/// Detects and parses reasoning from the input text.
fn detect_and_parse_reasoning(&mut self, text: &str) -> ParserResult;
/// Parses reasoning incrementally from streaming input.
fn parse_reasoning_streaming_incremental(&mut self, text: &str) -> ParserResult;
}
#[derive(Default)]
pub struct BaseReasoningParser {
#[derive(Default, Debug, Clone)]
pub struct BasicReasoningParser {
think_start_token: String,
think_end_token: String,
_in_reasoning: bool,
......@@ -28,7 +14,7 @@ pub struct BaseReasoningParser {
stripped_think_start: bool,
}
impl BaseReasoningParser {
impl BasicReasoningParser {
pub fn new(
think_start_token: String,
think_end_token: String,
......@@ -46,8 +32,8 @@ impl BaseReasoningParser {
}
}
impl ReasoningParser for BaseReasoningParser {
fn detect_and_parse_reasoning(&mut self, text: &str) -> ParserResult {
impl ReasoningParser for BasicReasoningParser {
fn detect_and_parse_reasoning(&self, text: &str) -> ParserResult {
log::debug!("detect_and_parse_reasoning called with text: {:?}", text);
let in_reasoning = self._in_reasoning || text.contains(&self.think_start_token);
......@@ -194,8 +180,8 @@ mod tests {
#[test]
fn test_detect_and_parse_reasoning_reasoning() {
let mut parser =
BaseReasoningParser::new("<think>".to_string(), "</think>".to_string(), false, true);
let parser =
BasicReasoningParser::new("<think>".to_string(), "</think>".to_string(), false, true);
let result =
parser.detect_and_parse_reasoning("<think>with reasoning</think> and more text.");
assert_eq!(result.normal_text, "and more text.");
......@@ -203,16 +189,16 @@ mod tests {
}
#[test]
fn test_detect_and_parse_reasoning_reasoning_no_reasoning() {
let mut parser =
BaseReasoningParser::new("<think>".to_string(), "</think>".to_string(), false, true);
let parser =
BasicReasoningParser::new("<think>".to_string(), "</think>".to_string(), false, true);
let result = parser.detect_and_parse_reasoning("This is a test without reasoning.");
assert_eq!(result.normal_text, "This is a test without reasoning.");
assert_eq!(result.reasoning_text, "");
}
#[test]
fn test_detect_and_parse_reasoning_reasoning_truncated_reasoning() {
let mut parser =
BaseReasoningParser::new("<think>".to_string(), "</think>".to_string(), false, true);
let parser =
BasicReasoningParser::new("<think>".to_string(), "</think>".to_string(), false, true);
let result = parser.detect_and_parse_reasoning("<think>with truncated reasoning");
assert_eq!(result.normal_text, "");
assert_eq!(result.reasoning_text, "with truncated reasoning");
......@@ -221,7 +207,7 @@ mod tests {
#[test]
fn test_parse_reasoning_streaming_incremental() {
let mut parser =
BaseReasoningParser::new("<think>".to_string(), "</think>".to_string(), false, true);
BasicReasoningParser::new("<think>".to_string(), "</think>".to_string(), false, true);
let result = parser.parse_reasoning_streaming_incremental("<thi");
assert_eq!(result.normal_text, "");
assert_eq!(result.reasoning_text, "");
......@@ -230,7 +216,7 @@ mod tests {
#[test]
fn test_parse_reasoning_streaming_incremental_complete() {
let mut parser =
BaseReasoningParser::new("<think>".to_string(), "</think>".to_string(), false, true);
BasicReasoningParser::new("<think>".to_string(), "</think>".to_string(), false, true);
let result = parser
.parse_reasoning_streaming_incremental("<think>with reasoning</think> and more text.");
assert_eq!(result.normal_text, " and more text.");
......@@ -240,7 +226,7 @@ mod tests {
#[test]
fn test_parse_reasoning_streaming_incremental_no_end_token() {
let mut parser =
BaseReasoningParser::new("<think>".to_string(), "</think>".to_string(), true, true);
BasicReasoningParser::new("<think>".to_string(), "</think>".to_string(), true, true);
let result = parser.parse_reasoning_streaming_incremental("<think>with reasoning");
assert_eq!(result.normal_text, "");
assert_eq!(result.reasoning_text, "with reasoning");
......@@ -248,8 +234,8 @@ mod tests {
#[test]
fn test_detect_and_parse_reasoning_multiple_reasoning_blocks() {
let mut parser =
BaseReasoningParser::new("<think>".to_string(), "</think>".to_string(), false, true);
let parser =
BasicReasoningParser::new("<think>".to_string(), "</think>".to_string(), false, true);
let result = parser.detect_and_parse_reasoning(
"<think>first reasoning</think> middle <think>second reasoning</think> end",
);
......@@ -261,7 +247,7 @@ mod tests {
#[test]
fn test_streaming_multiple_reasoning_blocks() {
let mut parser =
BaseReasoningParser::new("<think>".to_string(), "</think>".to_string(), false, false);
BasicReasoningParser::new("<think>".to_string(), "</think>".to_string(), false, false);
let result1 =
parser.parse_reasoning_streaming_incremental("<think>first reasoning</think> middle");
assert_eq!(result1.normal_text, " middle");
......@@ -277,7 +263,7 @@ mod tests {
#[test]
fn test_partial_token_matching_opening_tag() {
let mut parser =
BaseReasoningParser::new("<think>".to_string(), "</think>".to_string(), false, true);
BasicReasoningParser::new("<think>".to_string(), "</think>".to_string(), false, true);
// Feed partial opening tag
let result1 = parser.parse_reasoning_streaming_incremental("<th");
......@@ -294,7 +280,7 @@ mod tests {
#[test]
fn test_partial_token_matching_closing_tag() {
let mut parser =
BaseReasoningParser::new("<think>".to_string(), "</think>".to_string(), false, false);
BasicReasoningParser::new("<think>".to_string(), "</think>".to_string(), false, false);
// Start with complete opening and partial content
let result1 = parser.parse_reasoning_streaming_incremental("<think>reasoning content</th");
......@@ -310,7 +296,7 @@ mod tests {
#[test]
fn test_buffer_state_persistence_across_calls() {
let mut parser =
BaseReasoningParser::new("<think>".to_string(), "</think>".to_string(), false, false);
BasicReasoningParser::new("<think>".to_string(), "</think>".to_string(), false, false);
// First call - partial opening tag
let result1 = parser.parse_reasoning_streaming_incremental("<th");
......@@ -336,7 +322,7 @@ mod tests {
#[test]
fn test_streaming_with_stream_reasoning_enabled() {
let mut parser =
BaseReasoningParser::new("<think>".to_string(), "</think>".to_string(), false, true);
BasicReasoningParser::new("<think>".to_string(), "</think>".to_string(), false, true);
// Start reasoning block
let result1 = parser.parse_reasoning_streaming_incremental("<think>reasoning ");
......@@ -356,8 +342,8 @@ mod tests {
#[test]
fn test_nested_reasoning_blocks() {
let mut parser =
BaseReasoningParser::new("<think>".to_string(), "</think>".to_string(), false, true);
let parser =
BasicReasoningParser::new("<think>".to_string(), "</think>".to_string(), false, true);
let result = parser.detect_and_parse_reasoning(
"<think>outer <think>inner</think> reasoning</think> normal",
);
......@@ -369,8 +355,8 @@ mod tests {
#[test]
fn test_malformed_missing_closing_tag() {
let mut parser =
BaseReasoningParser::new("<think>".to_string(), "</think>".to_string(), false, true);
let parser =
BasicReasoningParser::new("<think>".to_string(), "</think>".to_string(), false, true);
let result = parser.detect_and_parse_reasoning("<think>reasoning without closing tag");
assert_eq!(result.normal_text, "");
assert_eq!(result.reasoning_text, "reasoning without closing tag");
......@@ -378,8 +364,8 @@ mod tests {
#[test]
fn test_malformed_stray_closing_tag() {
let mut parser =
BaseReasoningParser::new("<think>".to_string(), "</think>".to_string(), false, true);
let parser =
BasicReasoningParser::new("<think>".to_string(), "</think>".to_string(), false, true);
let result = parser.detect_and_parse_reasoning("normal text</think> more normal");
assert_eq!(result.normal_text, "normal text</think> more normal");
assert_eq!(result.reasoning_text, "");
......@@ -387,8 +373,8 @@ mod tests {
#[test]
fn test_malformed_multiple_opening_tags() {
let mut parser =
BaseReasoningParser::new("<think>".to_string(), "</think>".to_string(), false, true);
let parser =
BasicReasoningParser::new("<think>".to_string(), "</think>".to_string(), false, true);
let result = parser
.detect_and_parse_reasoning("<think>first <think>second reasoning</think> normal");
// Should handle by replacing all opening tags and using first closing tag
......@@ -398,8 +384,8 @@ mod tests {
#[test]
fn test_empty_reasoning_block() {
let mut parser =
BaseReasoningParser::new("<think>".to_string(), "</think>".to_string(), false, true);
let parser =
BasicReasoningParser::new("<think>".to_string(), "</think>".to_string(), false, true);
let result = parser.detect_and_parse_reasoning("<think></think> normal text");
assert_eq!(result.normal_text, "normal text");
assert_eq!(result.reasoning_text, "");
......@@ -407,8 +393,8 @@ mod tests {
#[test]
fn test_whitespace_only_reasoning_block() {
let mut parser =
BaseReasoningParser::new("<think>".to_string(), "</think>".to_string(), false, true);
let parser =
BasicReasoningParser::new("<think>".to_string(), "</think>".to_string(), false, true);
let result = parser.detect_and_parse_reasoning("<think> \n\t </think> normal text");
assert_eq!(result.normal_text, "normal text");
assert_eq!(result.reasoning_text, ""); // Should be empty after trim
......@@ -416,8 +402,8 @@ mod tests {
#[test]
fn test_force_reasoning_mode() {
let mut parser =
BaseReasoningParser::new("<think>".to_string(), "</think>".to_string(), true, true);
let parser =
BasicReasoningParser::new("<think>".to_string(), "</think>".to_string(), true, true);
let result = parser.detect_and_parse_reasoning("no think tags here");
assert_eq!(result.normal_text, "");
assert_eq!(result.reasoning_text, "no think tags here");
......@@ -426,7 +412,7 @@ mod tests {
#[test]
fn test_streaming_reset_state_after_complete_block() {
let mut parser =
BaseReasoningParser::new("<think>".to_string(), "</think>".to_string(), false, false);
BasicReasoningParser::new("<think>".to_string(), "</think>".to_string(), false, false);
// Process complete reasoning block
let result1 =
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use super::base_parser::BaseReasoningParser;
use super::base_parser::ParserResult;
use super::base_parser::ReasoningParser;
use super::base_parser::BasicReasoningParser;
use crate::ParserResult;
use crate::ReasoningParser;
#[derive(Default)]
#[derive(Default, Debug, Clone)]
pub struct DeepseekR1ReasoningParser {
base: BaseReasoningParser,
base: BasicReasoningParser,
}
impl DeepseekR1ReasoningParser {
pub fn new() -> Self {
Self {
base: BaseReasoningParser::new(
base: BasicReasoningParser::new(
"<think>".to_string(),
"</think>".to_string(),
true,
......@@ -28,7 +28,7 @@ impl ReasoningParser for DeepseekR1ReasoningParser {
self.base.parse_reasoning_streaming_incremental(text)
}
fn detect_and_parse_reasoning(&mut self, text: &str) -> ParserResult {
fn detect_and_parse_reasoning(&self, text: &str) -> ParserResult {
self.base.detect_and_parse_reasoning(text)
}
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
pub mod base_parser;
pub mod deepseek_r1_parser;
mod base_parser;
mod deepseek_r1_parser;
// Re-export main types and functions for convenience
pub use base_parser::ReasoningParser;
pub use base_parser::BasicReasoningParser;
pub use deepseek_r1_parser::DeepseekR1ReasoningParser;
/// Outcome of splitting a piece of model output into normal and reasoning text.
#[derive(Debug, Clone, Default)]
pub struct ParserResult {
    /// The normal text outside of reasoning blocks.
    pub normal_text: String,
    /// The extracted reasoning text from within reasoning blocks.
    pub reasoning_text: String,
}

impl ParserResult {
    /// Returns an owned copy of the reasoning text, or `None` when it is empty.
    pub fn get_some_reasoning(&self) -> Option<String> {
        Self::non_empty(&self.reasoning_text)
    }

    /// Returns an owned copy of the normal text, or `None` when it is empty.
    pub fn get_some_normal_text(&self) -> Option<String> {
        Self::non_empty(&self.normal_text)
    }

    /// Clones `text` into `Some`, mapping the empty string to `None`.
    fn non_empty(text: &str) -> Option<String> {
        (!text.is_empty()).then(|| text.to_owned())
    }
}
/// Common interface for parsers that split model output into normal text and
/// reasoning text, returned together as a [`ParserResult`].
///
/// The `Send + Debug` supertraits apply to every implementor, including the
/// boxed `dyn ReasoningParser` held by [`ReasoningParserWrapper`].
pub trait ReasoningParser: Send + std::fmt::Debug {
/// Parses a standalone, non-streaming input chunk. Implementations may reset or ignore
/// internal streaming state and should return the split of normal vs reasoning text for
/// this complete input. Marker tokens must not be included in either output.
///
/// Takes `&self`, so this call cannot mutate the parser unless the implementation
/// uses interior mutability.
fn detect_and_parse_reasoning(&self, text: &str) -> ParserResult;
/// Parses a streaming chunk and updates internal state. The return value should be the
/// delta: only the newly discovered normal and reasoning text attributable to this chunk
/// (not the cumulative totals). Marker tokens must not be included in either output.
///
/// Takes `&mut self` because state is carried from one chunk to the next.
fn parse_reasoning_streaming_incremental(&mut self, text: &str) -> ParserResult;
}
/// Selects which [`ReasoningParser`] implementation to construct.
///
/// The enum is `#[non_exhaustive]`: code in other crates that matches on it
/// must include a wildcard arm, which lets new parser kinds be added later.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum ReasoningParserType {
/// Selects the parser built by `DeepseekR1ReasoningParser::new()`.
DeepseekR1,
/// Selects a [`BasicReasoningParser`] using `<think>` / `</think>` as markers.
Basic,
}
/// Concrete holder for a boxed [`ReasoningParser`] trait object, so callers can
/// store any parser implementation behind a single type.
#[derive(std::fmt::Debug)]
pub struct ReasoningParserWrapper {
/// The wrapped parser; the wrapper's `ReasoningParser` impl forwards to it.
parser: Box<dyn ReasoningParser>,
}
/// Forwards both trait methods unchanged to the boxed inner parser.
impl ReasoningParser for ReasoningParserWrapper {
// Non-streaming parse: delegates to the inner parser and returns its result as-is.
fn detect_and_parse_reasoning(&self, text: &str) -> ParserResult {
self.parser.detect_and_parse_reasoning(text)
}
// Streaming parse: delegates to the inner parser, which receives the `&mut` access.
fn parse_reasoning_streaming_incremental(&mut self, text: &str) -> ParserResult {
self.parser.parse_reasoning_streaming_incremental(text)
}
}
impl ReasoningParserType {
pub fn get_reasoning_parser(self) -> ReasoningParserWrapper {
match self {
ReasoningParserType::DeepseekR1 => ReasoningParserWrapper {
parser: Box::new(DeepseekR1ReasoningParser::new()),
},
ReasoningParserType::Basic => ReasoningParserWrapper {
parser: Box::new(BasicReasoningParser::new(
"<think>".into(),
"</think>".into(),
false,
true,
)),
},
}
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment