Unverified Commit 2a2bf58a authored by Ayush Agarwal's avatar Ayush Agarwal Committed by GitHub
Browse files

feat: reasoning parser transformation (#3295)


Signed-off-by: default avatarayushag <ayushag@nvidia.com>
parent 713e9e48
...@@ -159,12 +159,12 @@ impl ...@@ -159,12 +159,12 @@ impl
for c in prompt.chars() { for c in prompt.chars() {
// we are returning characters not tokens, so there will be some postprocessing overhead // we are returning characters not tokens, so there will be some postprocessing overhead
tokio::time::sleep(*TOKEN_ECHO_DELAY).await; tokio::time::sleep(*TOKEN_ECHO_DELAY).await;
let response = deltas.create_choice(0, Some(c.to_string()), None, None, None); let response = deltas.create_choice(0, Some(c.to_string()), None, None);
yield Annotated{ id: Some(id.to_string()), data: Some(response), event: None, comment: None }; yield Annotated{ id: Some(id.to_string()), data: Some(response), event: None, comment: None };
id += 1; id += 1;
} }
let response = deltas.create_choice(0, None, None, Some(dynamo_async_openai::types::FinishReason::Stop), None); let response = deltas.create_choice(0, None, Some(dynamo_async_openai::types::FinishReason::Stop), None);
yield Annotated { id: Some(id.to_string()), data: Some(response), event: None, comment: None }; yield Annotated { id: Some(id.to_string()), data: Some(response), event: None, comment: None };
}; };
......
...@@ -28,6 +28,7 @@ use crate::preprocessor::prompt::OAIChatLikeRequest; ...@@ -28,6 +28,7 @@ use crate::preprocessor::prompt::OAIChatLikeRequest;
use crate::protocols::common::preprocessor::PreprocessedRequestBuilder; use crate::protocols::common::preprocessor::PreprocessedRequestBuilder;
use crate::tokenizers::Encoding; use crate::tokenizers::Encoding;
use dynamo_parsers::{ReasoningParser, ReasoningParserType};
use dynamo_runtime::engine::{AsyncEngine, AsyncEngineContextProvider, ResponseStream}; use dynamo_runtime::engine::{AsyncEngine, AsyncEngineContextProvider, ResponseStream};
use dynamo_runtime::pipeline::{ use dynamo_runtime::pipeline::{
AsyncEngineContext, Error, ManyOut, Operator, SingleIn, async_trait, AsyncEngineContext, Error, ManyOut, Operator, SingleIn, async_trait,
...@@ -93,6 +94,12 @@ impl LLMMetricAnnotation { ...@@ -93,6 +94,12 @@ impl LLMMetricAnnotation {
} }
} }
// Reasoning State for reasoning parsing transformation step
struct ReasoningState {
stream: Pin<Box<dyn Stream<Item = Annotated<NvCreateChatCompletionStreamResponse>> + Send>>,
reasoning_parser: Option<Box<dyn ReasoningParser>>,
}
pub struct OpenAIPreprocessor { pub struct OpenAIPreprocessor {
mdcsum: String, mdcsum: String,
formatter: Arc<dyn OAIPromptFormatter>, formatter: Arc<dyn OAIPromptFormatter>,
...@@ -668,6 +675,56 @@ impl OpenAIPreprocessor { ...@@ -668,6 +675,56 @@ impl OpenAIPreprocessor {
.build(); .build();
jail.apply(stream) jail.apply(stream)
} }
// Motivation: Each transformation on the stream should be a separate step to allow for more flexibility
// Earlier reasoning parser logic was nested under delta generation logic in choice_from_postprocessor
// Since we have tool calling parsing as separate step, it makes sense to have reasoning parser as separate step as well
pub fn parse_reasoning_content_from_stream<S>(
stream: S,
parser_name: String,
) -> impl Stream<Item = Annotated<NvCreateChatCompletionStreamResponse>> + Send
where
S: Stream<Item = Annotated<NvCreateChatCompletionStreamResponse>> + Send + 'static,
{
// Initialize reasoning parser from parser_name
let reasoning_parser = Box::new(ReasoningParserType::get_reasoning_parser_from_name(
parser_name.as_ref(),
)) as Box<dyn ReasoningParser>;
let state = ReasoningState {
stream: Box::pin(stream),
reasoning_parser: Some(reasoning_parser),
};
stream::unfold(state, |mut state| async move {
if let Some(response) = state.stream.next().await {
// Process the response through reasoning parser if available
let processed_response = if let Some(ref mut parser) = state.reasoning_parser {
response.map_data(|mut data| {
// Process all choices, not just the first one
for choice in data.choices.iter_mut() {
if let Some(text) = choice.delta.content.as_ref() {
let parser_result =
parser.parse_reasoning_streaming_incremental(text, &[]);
// Update this specific choice with parsed content
choice.delta.content = parser_result.get_some_normal_text();
choice.delta.reasoning_content = parser_result.get_some_reasoning();
}
}
Ok(data)
})
} else {
// No reasoning parser configured, pass through unchanged
response
};
Some((processed_response, state))
} else {
None
}
})
}
} }
// for pals, we do not want to add the generation prompt to the formatted prompt // for pals, we do not want to add the generation prompt to the formatted prompt
...@@ -715,9 +772,6 @@ impl ...@@ -715,9 +772,6 @@ impl
let mut response_generator = Box::new(response_generator); let mut response_generator = Box::new(response_generator);
// set the runtime configuration
response_generator.set_reasoning_parser(self.runtime_config.clone());
// update isl // update isl
response_generator.update_isl(common_request.token_ids.len() as u32); response_generator.update_isl(common_request.token_ids.len() as u32);
...@@ -744,6 +798,25 @@ impl ...@@ -744,6 +798,25 @@ impl
context.clone(), context.clone(),
); );
// Try to parse reasoning content only if parser is configured
let should_parse_reasoning = self.runtime_config.reasoning_parser.is_some();
// Reasoning Content Parsing Transformation Step
// Current Solution:
// This step operates on Deltas created by the transform_postprocessor_stream function
// Only access to text and not token_ids - so can not support parsing based on token_ids for now
// Future Solution:
// To address the limitation if needed in future: move this step before transform_postprocessor_stream and add new field of reasoning_content to the backend output
// Use backend_output.reasoning_content field to fill out the deltas.
let stream: Pin<Box<dyn Stream<Item = _> + Send>> = if should_parse_reasoning {
Box::pin(Self::parse_reasoning_content_from_stream(
stream,
self.runtime_config.reasoning_parser.clone().unwrap(), // Safety: We already checked that parser is some, so gtg
))
} else {
Box::pin(stream)
};
// Check if tools are present and if we should apply jail // Check if tools are present and if we should apply jail
let has_tools = let has_tools =
request.inner.tools.is_some() && !request.inner.tools.as_ref().unwrap().is_empty(); request.inner.tools.is_some() && !request.inner.tools.as_ref().unwrap().is_empty();
......
...@@ -7,7 +7,6 @@ use crate::{ ...@@ -7,7 +7,6 @@ use crate::{
protocols::common::{self}, protocols::common::{self},
types::TokenIdType, types::TokenIdType,
}; };
use dynamo_parsers::{ParserResult, ReasoningParser, ReasoningParserType, ReasoningParserWrapper};
/// Provides a method for generating a [`DeltaGenerator`] from a chat completion request. /// Provides a method for generating a [`DeltaGenerator`] from a chat completion request.
impl NvCreateChatCompletionRequest { impl NvCreateChatCompletionRequest {
...@@ -66,11 +65,6 @@ pub struct DeltaGenerator { ...@@ -66,11 +65,6 @@ pub struct DeltaGenerator {
msg_counter: u64, msg_counter: u64,
/// Configuration options for response generation. /// Configuration options for response generation.
options: DeltaGeneratorOptions, options: DeltaGeneratorOptions,
/// Reasoning Parser object
/// This is used to parse reasoning content in the response.
/// None means no reasoning parsing will be performed.
reasoning_parser: Option<ReasoningParserWrapper>,
} }
impl DeltaGenerator { impl DeltaGenerator {
...@@ -101,14 +95,6 @@ impl DeltaGenerator { ...@@ -101,14 +95,6 @@ impl DeltaGenerator {
completion_tokens_details: None, completion_tokens_details: None,
}; };
// Reasoning parser type
// If no parser is specified (None), no reasoning parsing will be performed
let reasoning_parser = options
.runtime_config
.reasoning_parser
.as_deref()
.map(ReasoningParserType::get_reasoning_parser_from_name);
let chatcmpl_id = format!("chatcmpl-{request_id}"); let chatcmpl_id = format!("chatcmpl-{request_id}");
Self { Self {
...@@ -121,21 +107,6 @@ impl DeltaGenerator { ...@@ -121,21 +107,6 @@ impl DeltaGenerator {
usage, usage,
msg_counter: 0, msg_counter: 0,
options, options,
reasoning_parser,
}
}
/// Update runtime configuration and reconfigure the reasoning parser accordingly.
pub fn set_reasoning_parser(&mut self, runtime_config: ModelRuntimeConfig) {
self.options.runtime_config = runtime_config.clone();
match self.options.runtime_config.reasoning_parser.as_deref() {
Some(name) => {
self.reasoning_parser =
Some(ReasoningParserType::get_reasoning_parser_from_name(name));
}
None => {
self.reasoning_parser = None;
}
} }
} }
...@@ -212,24 +183,6 @@ impl DeltaGenerator { ...@@ -212,24 +183,6 @@ impl DeltaGenerator {
}) })
} }
fn create_reasoning_content(
&mut self,
text: &Option<String>,
token_ids: &[u32],
) -> Option<ParserResult> {
// If no reasoning parser is configured, return None
let reasoning_parser = self.reasoning_parser.as_mut()?;
let text_ref = text.as_deref().unwrap_or("");
if text_ref.is_empty() && token_ids.is_empty() {
return None;
}
let parser_result =
reasoning_parser.parse_reasoning_streaming_incremental(text_ref, token_ids);
Some(parser_result)
}
/// Creates a choice within a chat completion response. /// Creates a choice within a chat completion response.
/// ///
/// # Arguments /// # Arguments
...@@ -245,7 +198,6 @@ impl DeltaGenerator { ...@@ -245,7 +198,6 @@ impl DeltaGenerator {
&mut self, &mut self,
index: u32, index: u32,
text: Option<String>, text: Option<String>,
reasoning_content: Option<String>,
finish_reason: Option<dynamo_async_openai::types::FinishReason>, finish_reason: Option<dynamo_async_openai::types::FinishReason>,
logprobs: Option<dynamo_async_openai::types::ChatChoiceLogprobs>, logprobs: Option<dynamo_async_openai::types::ChatChoiceLogprobs>,
) -> NvCreateChatCompletionStreamResponse { ) -> NvCreateChatCompletionStreamResponse {
...@@ -259,7 +211,7 @@ impl DeltaGenerator { ...@@ -259,7 +211,7 @@ impl DeltaGenerator {
None None
}, },
refusal: None, refusal: None,
reasoning_content, reasoning_content: None,
}; };
let choice = dynamo_async_openai::types::ChatChoiceStream { let choice = dynamo_async_openai::types::ChatChoiceStream {
...@@ -371,25 +323,9 @@ impl crate::protocols::openai::DeltaGeneratorExt<NvCreateChatCompletionStreamRes ...@@ -371,25 +323,9 @@ impl crate::protocols::openai::DeltaGeneratorExt<NvCreateChatCompletionStreamRes
None => None, None => None,
}; };
// Handle reasoning parsing if enabled, otherwise treat all text as normal
let (normal_text, reasoning_content) =
match self.create_reasoning_content(&delta.text, &delta.token_ids) {
Some(reasoning_parser_result) => (
reasoning_parser_result.get_some_normal_text(),
reasoning_parser_result.get_some_reasoning(),
),
None => (delta.text, None),
};
// Create the streaming response. // Create the streaming response.
let index = 0; let index = 0;
let stream_response = self.create_choice( let stream_response = self.create_choice(index, delta.text, finish_reason, logprobs);
index,
normal_text,
reasoning_content,
finish_reason,
logprobs,
);
Ok(stream_response) Ok(stream_response)
} }
......
...@@ -89,7 +89,7 @@ impl ...@@ -89,7 +89,7 @@ impl
let stream = stream! { let stream = stream! {
tokio::time::sleep(std::time::Duration::from_millis(max_tokens)).await; tokio::time::sleep(std::time::Duration::from_millis(max_tokens)).await;
for i in 0..10 { for i in 0..10 {
let output = generator.create_choice(i,Some(format!("choice {i}")), None, None, None); let output = generator.create_choice(i,Some(format!("choice {i}")), None, None);
yield Annotated::from_data(output); yield Annotated::from_data(output);
} }
......
...@@ -52,7 +52,7 @@ impl ...@@ -52,7 +52,7 @@ impl
// Generate 5 response chunks // Generate 5 response chunks
for i in 0..5 { for i in 0..5 {
let output = generator.create_choice(i, Some(format!("Mock response {i}")), None, None, None); let output = generator.create_choice(i, Some(format!("Mock response {i}")), None, None);
yield Annotated::from_data(output); yield Annotated::from_data(output);
} }
}; };
......
This diff is collapsed.
...@@ -119,7 +119,7 @@ impl ReasoningParser for BasicReasoningParser { ...@@ -119,7 +119,7 @@ impl ReasoningParser for BasicReasoningParser {
}; };
return ParserResult { return ParserResult {
normal_text: normal_text.to_string(), normal_text: normal_text.to_string(),
reasoning_text: reasoning_text.trim().to_string(), reasoning_text: reasoning_text.to_string(),
}; };
} }
// Continue with reasoning content // Continue with reasoning content
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment