"...ssh:/git@developer.sourcefind.cn:2222/OpenDAS/dynamo.git" did not exist on "f9839161ed78e5e11395e30385d2e8403533fa61"
Unverified Commit 2422b83d authored by Graham King's avatar Graham King Committed by GitHub
Browse files

fix: Do not apply chat template to completions (#2718)


Signed-off-by: default avatarGraham King <grahamk@nvidia.com>
parent 2c30e41f
...@@ -24,6 +24,7 @@ use tracing; ...@@ -24,6 +24,7 @@ use tracing;
use crate::model_card::{ModelDeploymentCard, ModelInfo, TokenizerKind}; use crate::model_card::{ModelDeploymentCard, ModelInfo, TokenizerKind};
use crate::preprocessor::prompt::OAIChatLikeRequest; use crate::preprocessor::prompt::OAIChatLikeRequest;
use crate::protocols::common::preprocessor::PreprocessedRequestBuilder;
use crate::tokenizers::Encoding; use crate::tokenizers::Encoding;
use dynamo_runtime::engine::{AsyncEngine, AsyncEngineContextProvider, ResponseStream}; use dynamo_runtime::engine::{AsyncEngine, AsyncEngineContextProvider, ResponseStream};
...@@ -151,10 +152,108 @@ impl OpenAIPreprocessor { ...@@ -151,10 +152,108 @@ impl OpenAIPreprocessor {
&self, &self,
request: &R, request: &R,
) -> Result<(PreprocessedRequest, HashMap<String, String>)> { ) -> Result<(PreprocessedRequest, HashMap<String, String>)> {
let mut annotations = HashMap::new(); let mut builder = self.builder(request)?;
let formatted_prompt = self.apply_template(request)?;
let annotations = self.gather_tokens(request, &mut builder, formatted_prompt)?;
Ok((builder.build()?, annotations))
}
pub fn builder<
R: OAIChatLikeRequest
+ AnnotationsProvider
+ SamplingOptionsProvider
+ StopConditionsProvider
+ OutputOptionsProvider
+ NvExtProvider,
>(
&self,
request: &R,
) -> Result<PreprocessedRequestBuilder> {
let mut builder = PreprocessedRequest::builder(); let mut builder = PreprocessedRequest::builder();
builder.model(request.model()); builder.model(request.model());
let mut stop_conditions = request.extract_stop_conditions()?;
if let Some(stop_tokens) = &mut stop_conditions.stop_token_ids_hidden {
for eos_token in self.model_info.eos_token_ids() {
if !stop_tokens.contains(&eos_token) {
stop_tokens.push(eos_token);
}
}
} else {
stop_conditions.stop_token_ids_hidden = Some(self.model_info.eos_token_ids());
}
// apply ignore eos if not already set
stop_conditions.apply_ignore_eos();
if !stop_conditions.ignore_eos.unwrap_or(false) {
builder.eos_token_ids(self.model_info.eos_token_ids());
}
builder.stop_conditions(stop_conditions);
builder.sampling_options(request.extract_sampling_options()?);
builder.output_options(request.extract_output_options()?);
builder.annotations(request.annotations().unwrap_or_default());
builder.mdc_sum(Some(self.mdcsum.clone()));
builder.estimated_prefix_hit_num_blocks(None);
// Extract backend_instance_id from nvext if present
if let Some(nvext) = request.nvext() {
builder.backend_instance_id(nvext.backend_instance_id);
}
Ok(builder)
}
pub fn apply_template<
R: OAIChatLikeRequest
+ AnnotationsProvider
+ SamplingOptionsProvider
+ StopConditionsProvider
+ OutputOptionsProvider
+ NvExtProvider,
>(
&self,
request: &R,
) -> Result<Option<String>> {
if let PromptInput::Text(_) = request.prompt_input_type()
&& let Some(TextInput::Single(_)) = request.extract_text()
{
let use_raw_prompt = request
.nvext()
.is_some_and(|ext| ext.use_raw_prompt.unwrap_or(false));
let formatted_prompt = if use_raw_prompt {
match request.raw_prompt() {
Some(prompt) => prompt,
None => {
tracing::warn!("Raw prompt requested but not available");
self.formatter.render(request)?
}
}
} else {
self.formatter.render(request)?
};
Ok(Some(formatted_prompt))
} else {
Ok(None)
}
}
pub fn gather_tokens<
R: OAIChatLikeRequest
+ AnnotationsProvider
+ SamplingOptionsProvider
+ StopConditionsProvider
+ OutputOptionsProvider
+ NvExtProvider,
>(
&self,
request: &R,
builder: &mut PreprocessedRequestBuilder,
formatted_prompt: Option<String>,
) -> Result<HashMap<String, String>> {
let mut annotations = HashMap::new();
// match request type before any conversion/processing // match request type before any conversion/processing
match request.prompt_input_type() { match request.prompt_input_type() {
PromptInput::Tokens(_) => { PromptInput::Tokens(_) => {
...@@ -177,22 +276,16 @@ impl OpenAIPreprocessor { ...@@ -177,22 +276,16 @@ impl OpenAIPreprocessor {
PromptInput::Text(_) => { PromptInput::Text(_) => {
if let Some(text_input) = request.extract_text() { if let Some(text_input) = request.extract_text() {
match text_input { match text_input {
TextInput::Single(_) => { TextInput::Single(raw_prompt) => {
let use_raw_prompt = request if let Some(f) = formatted_prompt.as_ref()
.nvext() && request.has_annotation(ANNOTATION_FORMATTED_PROMPT)
.is_some_and(|ext| ext.use_raw_prompt.unwrap_or(false)); {
annotations
let formatted_prompt = if use_raw_prompt { .insert(ANNOTATION_FORMATTED_PROMPT.to_string(), f.to_string());
match request.raw_prompt() { }
Some(prompt) => prompt,
None => { // Completions will use raw_prompt, no template
tracing::warn!("Raw prompt requested but not available"); let prompt = formatted_prompt.unwrap_or(raw_prompt);
self.formatter.render(request)?
}
}
} else {
self.formatter.render(request)?
};
// Check if backend_instance_id is present and token_data is provided // Check if backend_instance_id is present and token_data is provided
let has_backend_instance_id = request let has_backend_instance_id = request
...@@ -215,22 +308,15 @@ impl OpenAIPreprocessor { ...@@ -215,22 +308,15 @@ impl OpenAIPreprocessor {
tracing::warn!( tracing::warn!(
"backend_instance_id provided but no token_data; tokenizing prompt" "backend_instance_id provided but no token_data; tokenizing prompt"
); );
let encoding = self.tokenizer.encode(&formatted_prompt)?; let encoding = self.tokenizer.encode(&prompt)?;
(encoding.token_ids().to_vec(), false) (encoding.token_ids().to_vec(), false)
} }
} else { } else {
// No backend_instance_id provided, continue the normal flow. // No backend_instance_id provided, continue the normal flow.
let encoding = self.tokenizer.encode(&formatted_prompt)?; let encoding = self.tokenizer.encode(&prompt)?;
(encoding.token_ids().to_vec(), false) (encoding.token_ids().to_vec(), false)
}; };
if request.has_annotation(ANNOTATION_FORMATTED_PROMPT) {
annotations.insert(
ANNOTATION_FORMATTED_PROMPT.to_string(),
formatted_prompt,
);
}
if request.has_annotation(ANNOTATION_TOKEN_IDS) if request.has_annotation(ANNOTATION_TOKEN_IDS)
&& !skip_token_annotation && !skip_token_annotation
{ {
...@@ -258,37 +344,7 @@ impl OpenAIPreprocessor { ...@@ -258,37 +344,7 @@ impl OpenAIPreprocessor {
} }
} }
} }
Ok(annotations)
let mut stop_conditions = request.extract_stop_conditions()?;
if let Some(stop_tokens) = &mut stop_conditions.stop_token_ids_hidden {
for eos_token in self.model_info.eos_token_ids() {
if !stop_tokens.contains(&eos_token) {
stop_tokens.push(eos_token);
}
}
} else {
stop_conditions.stop_token_ids_hidden = Some(self.model_info.eos_token_ids());
}
// apply ignore eos if not already set
stop_conditions.apply_ignore_eos();
if !stop_conditions.ignore_eos.unwrap_or(false) {
builder.eos_token_ids(self.model_info.eos_token_ids());
}
builder.stop_conditions(stop_conditions);
builder.sampling_options(request.extract_sampling_options()?);
builder.output_options(request.extract_output_options()?);
builder.annotations(request.annotations().unwrap_or_default());
builder.mdc_sum(Some(self.mdcsum.clone()));
builder.estimated_prefix_hit_num_blocks(None);
// Extract backend_instance_id from nvext if present
if let Some(nvext) = request.nvext() {
builder.backend_instance_id(nvext.backend_instance_id);
}
Ok((builder.build()?, annotations))
} }
/// Preprocess an embedding request, handling both text and token ID inputs. /// Preprocess an embedding request, handling both text and token ID inputs.
...@@ -581,7 +637,9 @@ impl ...@@ -581,7 +637,9 @@ impl
let response_generator = request.response_generator(context.id().to_string()); let response_generator = request.response_generator(context.id().to_string());
let mut response_generator = Box::new(response_generator); let mut response_generator = Box::new(response_generator);
// convert the chat completion request to a common completion request // convert the chat completion request to a common completion request
let (common_request, annotations) = self.preprocess_request(&request)?; let mut builder = self.builder(&request)?;
let annotations = self.gather_tokens(&request, &mut builder, None)?;
let common_request = builder.build()?;
// update isl // update isl
response_generator.update_isl(common_request.token_ids.len() as u32); response_generator.update_isl(common_request.token_ids.len() as u32);
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
use tracing as log;
use crate::{ParserResult, ReasoningParser}; use crate::{ParserResult, ReasoningParser};
...@@ -34,13 +33,8 @@ impl BasicReasoningParser { ...@@ -34,13 +33,8 @@ impl BasicReasoningParser {
impl ReasoningParser for BasicReasoningParser { impl ReasoningParser for BasicReasoningParser {
fn detect_and_parse_reasoning(&mut self, text: &str, _token_ids: &[u32]) -> ParserResult { fn detect_and_parse_reasoning(&mut self, text: &str, _token_ids: &[u32]) -> ParserResult {
log::debug!("detect_and_parse_reasoning called with text: {:?}", text);
let in_reasoning = self._in_reasoning || text.contains(&self.think_start_token); let in_reasoning = self._in_reasoning || text.contains(&self.think_start_token);
log::debug!("in_reasoning: {}", in_reasoning);
if !in_reasoning { if !in_reasoning {
log::debug!("No reasoning detected, returning normal text.");
return ParserResult { return ParserResult {
normal_text: text.to_string(), normal_text: text.to_string(),
reasoning_text: String::new(), reasoning_text: String::new(),
...@@ -49,15 +43,8 @@ impl ReasoningParser for BasicReasoningParser { ...@@ -49,15 +43,8 @@ impl ReasoningParser for BasicReasoningParser {
// The text is considered to be in a reasoning block. // The text is considered to be in a reasoning block.
let processed_text = text.replace(&self.think_start_token, "").trim().to_string(); let processed_text = text.replace(&self.think_start_token, "").trim().to_string();
log::debug!(
"Processed text after removing think_start_token: {:?}",
processed_text
);
if !processed_text.contains(&self.think_end_token) { if !processed_text.contains(&self.think_end_token) {
log::debug!(
"Reasoning truncated, think_end_token not found. Returning reasoning text."
);
// Assume reasoning was truncated before `think_end_token` // Assume reasoning was truncated before `think_end_token`
return ParserResult { return ParserResult {
normal_text: String::new(), normal_text: String::new(),
...@@ -73,9 +60,6 @@ impl ReasoningParser for BasicReasoningParser { ...@@ -73,9 +60,6 @@ impl ReasoningParser for BasicReasoningParser {
.map(|s| s.trim().to_string()) .map(|s| s.trim().to_string())
.unwrap_or_default(); .unwrap_or_default();
log::debug!("Extracted reasoning_text: {:?}", reasoning_text);
log::debug!("Extracted normal_text: {:?}", normal_text);
ParserResult { ParserResult {
normal_text, normal_text,
reasoning_text, reasoning_text,
...@@ -92,19 +76,6 @@ impl ReasoningParser for BasicReasoningParser { ...@@ -92,19 +76,6 @@ impl ReasoningParser for BasicReasoningParser {
let mut current_text = self._buffer.to_string(); let mut current_text = self._buffer.to_string();
// If the current text is a prefix of the think token, keep buffering // If the current text is a prefix of the think token, keep buffering
log::debug!(
"parse_reasoning_streaming_incremental called with text: {:?}",
text
);
log::debug!("current buffer: {:?}", self._buffer);
log::debug!("current_text: {:?}", current_text);
log::debug!(
"in_reasoning: {}, stripped_think_start: {}, stream_reasoning: {}",
self._in_reasoning,
self.stripped_think_start,
self.stream_reasoning
);
if self.think_start_token.starts_with(&current_text) if self.think_start_token.starts_with(&current_text)
&& self.think_start_token.as_str() != current_text.as_str() && self.think_start_token.as_str() != current_text.as_str()
{ {
......
...@@ -144,7 +144,7 @@ impl ReasoningParserType { ...@@ -144,7 +144,7 @@ impl ReasoningParserType {
} }
pub fn get_reasoning_parser_from_name(name: &str) -> ReasoningParserWrapper { pub fn get_reasoning_parser_from_name(name: &str) -> ReasoningParserWrapper {
tracing::debug!("Selected reasoning parser: {}", name); tracing::debug!(parser_name = name, "Selected reasoning parser");
match name.to_lowercase().as_str() { match name.to_lowercase().as_str() {
"deepseek_r1" => Self::DeepseekR1.get_reasoning_parser(), "deepseek_r1" => Self::DeepseekR1.get_reasoning_parser(),
"basic" => Self::Basic.get_reasoning_parser(), "basic" => Self::Basic.get_reasoning_parser(),
...@@ -156,8 +156,8 @@ impl ReasoningParserType { ...@@ -156,8 +156,8 @@ impl ReasoningParserType {
"mistral" => Self::Mistral.get_reasoning_parser(), "mistral" => Self::Mistral.get_reasoning_parser(),
_ => { _ => {
tracing::warn!( tracing::warn!(
"Unknown reasoning parser type '{}', falling back to Basic Reasoning Parser", parser_name = name,
name "Unknown reasoning parser type, falling back to Basic Reasoning Parser",
); );
Self::Basic.get_reasoning_parser() Self::Basic.get_reasoning_parser()
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment