Unverified Commit 2a2bf58a authored by Ayush Agarwal's avatar Ayush Agarwal Committed by GitHub
Browse files

feat: reasoning parser transformation (#3295)


Signed-off-by: default avatarayushag <ayushag@nvidia.com>
parent 713e9e48
......@@ -159,12 +159,12 @@ impl
for c in prompt.chars() {
// we are returning characters not tokens, so there will be some postprocessing overhead
tokio::time::sleep(*TOKEN_ECHO_DELAY).await;
let response = deltas.create_choice(0, Some(c.to_string()), None, None, None);
let response = deltas.create_choice(0, Some(c.to_string()), None, None);
yield Annotated{ id: Some(id.to_string()), data: Some(response), event: None, comment: None };
id += 1;
}
let response = deltas.create_choice(0, None, None, Some(dynamo_async_openai::types::FinishReason::Stop), None);
let response = deltas.create_choice(0, None, Some(dynamo_async_openai::types::FinishReason::Stop), None);
yield Annotated { id: Some(id.to_string()), data: Some(response), event: None, comment: None };
};
......
......@@ -28,6 +28,7 @@ use crate::preprocessor::prompt::OAIChatLikeRequest;
use crate::protocols::common::preprocessor::PreprocessedRequestBuilder;
use crate::tokenizers::Encoding;
use dynamo_parsers::{ReasoningParser, ReasoningParserType};
use dynamo_runtime::engine::{AsyncEngine, AsyncEngineContextProvider, ResponseStream};
use dynamo_runtime::pipeline::{
AsyncEngineContext, Error, ManyOut, Operator, SingleIn, async_trait,
......@@ -93,6 +94,12 @@ impl LLMMetricAnnotation {
}
}
// Reasoning State for reasoning parsing transformation step
struct ReasoningState {
stream: Pin<Box<dyn Stream<Item = Annotated<NvCreateChatCompletionStreamResponse>> + Send>>,
reasoning_parser: Option<Box<dyn ReasoningParser>>,
}
pub struct OpenAIPreprocessor {
mdcsum: String,
formatter: Arc<dyn OAIPromptFormatter>,
......@@ -668,6 +675,56 @@ impl OpenAIPreprocessor {
.build();
jail.apply(stream)
}
// Motivation: Each transformation on the stream should be a separate step to allow for more flexibility
// Earlier reasoning parser logic was nested under delta generation logic in choice_from_postprocessor
// Since we have tool calling parsing as separate step, it makes sense to have reasoning parser as separate step as well
pub fn parse_reasoning_content_from_stream<S>(
stream: S,
parser_name: String,
) -> impl Stream<Item = Annotated<NvCreateChatCompletionStreamResponse>> + Send
where
S: Stream<Item = Annotated<NvCreateChatCompletionStreamResponse>> + Send + 'static,
{
// Initialize reasoning parser from parser_name
let reasoning_parser = Box::new(ReasoningParserType::get_reasoning_parser_from_name(
parser_name.as_ref(),
)) as Box<dyn ReasoningParser>;
let state = ReasoningState {
stream: Box::pin(stream),
reasoning_parser: Some(reasoning_parser),
};
stream::unfold(state, |mut state| async move {
if let Some(response) = state.stream.next().await {
// Process the response through reasoning parser if available
let processed_response = if let Some(ref mut parser) = state.reasoning_parser {
response.map_data(|mut data| {
// Process all choices, not just the first one
for choice in data.choices.iter_mut() {
if let Some(text) = choice.delta.content.as_ref() {
let parser_result =
parser.parse_reasoning_streaming_incremental(text, &[]);
// Update this specific choice with parsed content
choice.delta.content = parser_result.get_some_normal_text();
choice.delta.reasoning_content = parser_result.get_some_reasoning();
}
}
Ok(data)
})
} else {
// No reasoning parser configured, pass through unchanged
response
};
Some((processed_response, state))
} else {
None
}
})
}
}
// for pals, we do not want to add the generation prompt to the formatted prompt
......@@ -715,9 +772,6 @@ impl
let mut response_generator = Box::new(response_generator);
// set the runtime configuration
response_generator.set_reasoning_parser(self.runtime_config.clone());
// update isl
response_generator.update_isl(common_request.token_ids.len() as u32);
......@@ -744,6 +798,25 @@ impl
context.clone(),
);
// Try to parse reasoning content only if parser is configured
let should_parse_reasoning = self.runtime_config.reasoning_parser.is_some();
// Reasoning Content Parsing Transformation Step
// Current Solution:
// This step operates on Deltas created by the transform_postprocessor_stream function
// Only access to text and not token_ids - so can not support parsing based on token_ids for now
// Future Solution:
// To address the limitation if needed in future: move this step before transform_postprocessor_stream and add new field of reasoning_content to the backend output
// Use backend_output.reasoning_content field to fill out the deltas.
let stream: Pin<Box<dyn Stream<Item = _> + Send>> = if should_parse_reasoning {
Box::pin(Self::parse_reasoning_content_from_stream(
stream,
self.runtime_config.reasoning_parser.clone().unwrap(), // Safety: We already checked that parser is some, so gtg
))
} else {
Box::pin(stream)
};
// Check if tools are present and if we should apply jail
let has_tools =
request.inner.tools.is_some() && !request.inner.tools.as_ref().unwrap().is_empty();
......
......@@ -7,7 +7,6 @@ use crate::{
protocols::common::{self},
types::TokenIdType,
};
use dynamo_parsers::{ParserResult, ReasoningParser, ReasoningParserType, ReasoningParserWrapper};
/// Provides a method for generating a [`DeltaGenerator`] from a chat completion request.
impl NvCreateChatCompletionRequest {
......@@ -66,11 +65,6 @@ pub struct DeltaGenerator {
msg_counter: u64,
/// Configuration options for response generation.
options: DeltaGeneratorOptions,
/// Reasoning Parser object
/// This is used to parse reasoning content in the response.
/// None means no reasoning parsing will be performed.
reasoning_parser: Option<ReasoningParserWrapper>,
}
impl DeltaGenerator {
......@@ -101,14 +95,6 @@ impl DeltaGenerator {
completion_tokens_details: None,
};
// Reasoning parser type
// If no parser is specified (None), no reasoning parsing will be performed
let reasoning_parser = options
.runtime_config
.reasoning_parser
.as_deref()
.map(ReasoningParserType::get_reasoning_parser_from_name);
let chatcmpl_id = format!("chatcmpl-{request_id}");
Self {
......@@ -121,21 +107,6 @@ impl DeltaGenerator {
usage,
msg_counter: 0,
options,
reasoning_parser,
}
}
/// Update runtime configuration and reconfigure the reasoning parser accordingly.
pub fn set_reasoning_parser(&mut self, runtime_config: ModelRuntimeConfig) {
self.options.runtime_config = runtime_config.clone();
match self.options.runtime_config.reasoning_parser.as_deref() {
Some(name) => {
self.reasoning_parser =
Some(ReasoningParserType::get_reasoning_parser_from_name(name));
}
None => {
self.reasoning_parser = None;
}
}
}
......@@ -212,24 +183,6 @@ impl DeltaGenerator {
})
}
fn create_reasoning_content(
&mut self,
text: &Option<String>,
token_ids: &[u32],
) -> Option<ParserResult> {
// If no reasoning parser is configured, return None
let reasoning_parser = self.reasoning_parser.as_mut()?;
let text_ref = text.as_deref().unwrap_or("");
if text_ref.is_empty() && token_ids.is_empty() {
return None;
}
let parser_result =
reasoning_parser.parse_reasoning_streaming_incremental(text_ref, token_ids);
Some(parser_result)
}
/// Creates a choice within a chat completion response.
///
/// # Arguments
......@@ -245,7 +198,6 @@ impl DeltaGenerator {
&mut self,
index: u32,
text: Option<String>,
reasoning_content: Option<String>,
finish_reason: Option<dynamo_async_openai::types::FinishReason>,
logprobs: Option<dynamo_async_openai::types::ChatChoiceLogprobs>,
) -> NvCreateChatCompletionStreamResponse {
......@@ -259,7 +211,7 @@ impl DeltaGenerator {
None
},
refusal: None,
reasoning_content,
reasoning_content: None,
};
let choice = dynamo_async_openai::types::ChatChoiceStream {
......@@ -371,25 +323,9 @@ impl crate::protocols::openai::DeltaGeneratorExt<NvCreateChatCompletionStreamRes
None => None,
};
// Handle reasoning parsing if enabled, otherwise treat all text as normal
let (normal_text, reasoning_content) =
match self.create_reasoning_content(&delta.text, &delta.token_ids) {
Some(reasoning_parser_result) => (
reasoning_parser_result.get_some_normal_text(),
reasoning_parser_result.get_some_reasoning(),
),
None => (delta.text, None),
};
// Create the streaming response.
let index = 0;
let stream_response = self.create_choice(
index,
normal_text,
reasoning_content,
finish_reason,
logprobs,
);
let stream_response = self.create_choice(index, delta.text, finish_reason, logprobs);
Ok(stream_response)
}
......
......@@ -89,7 +89,7 @@ impl
let stream = stream! {
tokio::time::sleep(std::time::Duration::from_millis(max_tokens)).await;
for i in 0..10 {
let output = generator.create_choice(i,Some(format!("choice {i}")), None, None, None);
let output = generator.create_choice(i,Some(format!("choice {i}")), None, None);
yield Annotated::from_data(output);
}
......
......@@ -52,7 +52,7 @@ impl
// Generate 5 response chunks
for i in 0..5 {
let output = generator.create_choice(i, Some(format!("Mock response {i}")), None, None, None);
let output = generator.create_choice(i, Some(format!("Mock response {i}")), None, None);
yield Annotated::from_data(output);
}
};
......
This diff is collapsed.
......@@ -119,7 +119,7 @@ impl ReasoningParser for BasicReasoningParser {
};
return ParserResult {
normal_text: normal_text.to_string(),
reasoning_text: reasoning_text.trim().to_string(),
reasoning_text: reasoning_text.to_string(),
};
}
// Continue with reasoning content
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment