Unverified Commit 2422b83d authored by Graham King's avatar Graham King Committed by GitHub
Browse files

fix: Do not apply chat template to completions (#2718)


Signed-off-by: default avatarGraham King <grahamk@nvidia.com>
parent 2c30e41f
......@@ -24,6 +24,7 @@ use tracing;
use crate::model_card::{ModelDeploymentCard, ModelInfo, TokenizerKind};
use crate::preprocessor::prompt::OAIChatLikeRequest;
use crate::protocols::common::preprocessor::PreprocessedRequestBuilder;
use crate::tokenizers::Encoding;
use dynamo_runtime::engine::{AsyncEngine, AsyncEngineContextProvider, ResponseStream};
......@@ -151,10 +152,108 @@ impl OpenAIPreprocessor {
&self,
request: &R,
) -> Result<(PreprocessedRequest, HashMap<String, String>)> {
let mut annotations = HashMap::new();
let mut builder = self.builder(request)?;
let formatted_prompt = self.apply_template(request)?;
let annotations = self.gather_tokens(request, &mut builder, formatted_prompt)?;
Ok((builder.build()?, annotations))
}
pub fn builder<
R: OAIChatLikeRequest
+ AnnotationsProvider
+ SamplingOptionsProvider
+ StopConditionsProvider
+ OutputOptionsProvider
+ NvExtProvider,
>(
&self,
request: &R,
) -> Result<PreprocessedRequestBuilder> {
let mut builder = PreprocessedRequest::builder();
builder.model(request.model());
let mut stop_conditions = request.extract_stop_conditions()?;
if let Some(stop_tokens) = &mut stop_conditions.stop_token_ids_hidden {
for eos_token in self.model_info.eos_token_ids() {
if !stop_tokens.contains(&eos_token) {
stop_tokens.push(eos_token);
}
}
} else {
stop_conditions.stop_token_ids_hidden = Some(self.model_info.eos_token_ids());
}
// apply ignore eos if not already set
stop_conditions.apply_ignore_eos();
if !stop_conditions.ignore_eos.unwrap_or(false) {
builder.eos_token_ids(self.model_info.eos_token_ids());
}
builder.stop_conditions(stop_conditions);
builder.sampling_options(request.extract_sampling_options()?);
builder.output_options(request.extract_output_options()?);
builder.annotations(request.annotations().unwrap_or_default());
builder.mdc_sum(Some(self.mdcsum.clone()));
builder.estimated_prefix_hit_num_blocks(None);
// Extract backend_instance_id from nvext if present
if let Some(nvext) = request.nvext() {
builder.backend_instance_id(nvext.backend_instance_id);
}
Ok(builder)
}
pub fn apply_template<
R: OAIChatLikeRequest
+ AnnotationsProvider
+ SamplingOptionsProvider
+ StopConditionsProvider
+ OutputOptionsProvider
+ NvExtProvider,
>(
&self,
request: &R,
) -> Result<Option<String>> {
if let PromptInput::Text(_) = request.prompt_input_type()
&& let Some(TextInput::Single(_)) = request.extract_text()
{
let use_raw_prompt = request
.nvext()
.is_some_and(|ext| ext.use_raw_prompt.unwrap_or(false));
let formatted_prompt = if use_raw_prompt {
match request.raw_prompt() {
Some(prompt) => prompt,
None => {
tracing::warn!("Raw prompt requested but not available");
self.formatter.render(request)?
}
}
} else {
self.formatter.render(request)?
};
Ok(Some(formatted_prompt))
} else {
Ok(None)
}
}
pub fn gather_tokens<
R: OAIChatLikeRequest
+ AnnotationsProvider
+ SamplingOptionsProvider
+ StopConditionsProvider
+ OutputOptionsProvider
+ NvExtProvider,
>(
&self,
request: &R,
builder: &mut PreprocessedRequestBuilder,
formatted_prompt: Option<String>,
) -> Result<HashMap<String, String>> {
let mut annotations = HashMap::new();
// match request type before any conversion/processing
match request.prompt_input_type() {
PromptInput::Tokens(_) => {
......@@ -177,22 +276,16 @@ impl OpenAIPreprocessor {
PromptInput::Text(_) => {
if let Some(text_input) = request.extract_text() {
match text_input {
TextInput::Single(_) => {
let use_raw_prompt = request
.nvext()
.is_some_and(|ext| ext.use_raw_prompt.unwrap_or(false));
let formatted_prompt = if use_raw_prompt {
match request.raw_prompt() {
Some(prompt) => prompt,
None => {
tracing::warn!("Raw prompt requested but not available");
self.formatter.render(request)?
}
}
} else {
self.formatter.render(request)?
};
TextInput::Single(raw_prompt) => {
if let Some(f) = formatted_prompt.as_ref()
&& request.has_annotation(ANNOTATION_FORMATTED_PROMPT)
{
annotations
.insert(ANNOTATION_FORMATTED_PROMPT.to_string(), f.to_string());
}
// Completions will use raw_prompt, no template
let prompt = formatted_prompt.unwrap_or(raw_prompt);
// Check if backend_instance_id is present and token_data is provided
let has_backend_instance_id = request
......@@ -215,22 +308,15 @@ impl OpenAIPreprocessor {
tracing::warn!(
"backend_instance_id provided but no token_data; tokenizing prompt"
);
let encoding = self.tokenizer.encode(&formatted_prompt)?;
let encoding = self.tokenizer.encode(&prompt)?;
(encoding.token_ids().to_vec(), false)
}
} else {
// No backend_instance_id provided, continue the normal flow.
let encoding = self.tokenizer.encode(&formatted_prompt)?;
let encoding = self.tokenizer.encode(&prompt)?;
(encoding.token_ids().to_vec(), false)
};
if request.has_annotation(ANNOTATION_FORMATTED_PROMPT) {
annotations.insert(
ANNOTATION_FORMATTED_PROMPT.to_string(),
formatted_prompt,
);
}
if request.has_annotation(ANNOTATION_TOKEN_IDS)
&& !skip_token_annotation
{
......@@ -258,37 +344,7 @@ impl OpenAIPreprocessor {
}
}
}
let mut stop_conditions = request.extract_stop_conditions()?;
if let Some(stop_tokens) = &mut stop_conditions.stop_token_ids_hidden {
for eos_token in self.model_info.eos_token_ids() {
if !stop_tokens.contains(&eos_token) {
stop_tokens.push(eos_token);
}
}
} else {
stop_conditions.stop_token_ids_hidden = Some(self.model_info.eos_token_ids());
}
// apply ignore eos if not already set
stop_conditions.apply_ignore_eos();
if !stop_conditions.ignore_eos.unwrap_or(false) {
builder.eos_token_ids(self.model_info.eos_token_ids());
}
builder.stop_conditions(stop_conditions);
builder.sampling_options(request.extract_sampling_options()?);
builder.output_options(request.extract_output_options()?);
builder.annotations(request.annotations().unwrap_or_default());
builder.mdc_sum(Some(self.mdcsum.clone()));
builder.estimated_prefix_hit_num_blocks(None);
// Extract backend_instance_id from nvext if present
if let Some(nvext) = request.nvext() {
builder.backend_instance_id(nvext.backend_instance_id);
}
Ok((builder.build()?, annotations))
Ok(annotations)
}
/// Preprocess an embedding request, handling both text and token ID inputs.
......@@ -581,7 +637,9 @@ impl
let response_generator = request.response_generator(context.id().to_string());
let mut response_generator = Box::new(response_generator);
// convert the chat completion request to a common completion request
let (common_request, annotations) = self.preprocess_request(&request)?;
let mut builder = self.builder(&request)?;
let annotations = self.gather_tokens(&request, &mut builder, None)?;
let common_request = builder.build()?;
// update isl
response_generator.update_isl(common_request.token_ids.len() as u32);
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use tracing as log;
use crate::{ParserResult, ReasoningParser};
......@@ -34,13 +33,8 @@ impl BasicReasoningParser {
impl ReasoningParser for BasicReasoningParser {
fn detect_and_parse_reasoning(&mut self, text: &str, _token_ids: &[u32]) -> ParserResult {
log::debug!("detect_and_parse_reasoning called with text: {:?}", text);
let in_reasoning = self._in_reasoning || text.contains(&self.think_start_token);
log::debug!("in_reasoning: {}", in_reasoning);
if !in_reasoning {
log::debug!("No reasoning detected, returning normal text.");
return ParserResult {
normal_text: text.to_string(),
reasoning_text: String::new(),
......@@ -49,15 +43,8 @@ impl ReasoningParser for BasicReasoningParser {
// The text is considered to be in a reasoning block.
let processed_text = text.replace(&self.think_start_token, "").trim().to_string();
log::debug!(
"Processed text after removing think_start_token: {:?}",
processed_text
);
if !processed_text.contains(&self.think_end_token) {
log::debug!(
"Reasoning truncated, think_end_token not found. Returning reasoning text."
);
// Assume reasoning was truncated before `think_end_token`
return ParserResult {
normal_text: String::new(),
......@@ -73,9 +60,6 @@ impl ReasoningParser for BasicReasoningParser {
.map(|s| s.trim().to_string())
.unwrap_or_default();
log::debug!("Extracted reasoning_text: {:?}", reasoning_text);
log::debug!("Extracted normal_text: {:?}", normal_text);
ParserResult {
normal_text,
reasoning_text,
......@@ -92,19 +76,6 @@ impl ReasoningParser for BasicReasoningParser {
let mut current_text = self._buffer.to_string();
// If the current text is a prefix of the think token, keep buffering
log::debug!(
"parse_reasoning_streaming_incremental called with text: {:?}",
text
);
log::debug!("current buffer: {:?}", self._buffer);
log::debug!("current_text: {:?}", current_text);
log::debug!(
"in_reasoning: {}, stripped_think_start: {}, stream_reasoning: {}",
self._in_reasoning,
self.stripped_think_start,
self.stream_reasoning
);
if self.think_start_token.starts_with(&current_text)
&& self.think_start_token.as_str() != current_text.as_str()
{
......
......@@ -144,7 +144,7 @@ impl ReasoningParserType {
}
pub fn get_reasoning_parser_from_name(name: &str) -> ReasoningParserWrapper {
tracing::debug!("Selected reasoning parser: {}", name);
tracing::debug!(parser_name = name, "Selected reasoning parser");
match name.to_lowercase().as_str() {
"deepseek_r1" => Self::DeepseekR1.get_reasoning_parser(),
"basic" => Self::Basic.get_reasoning_parser(),
......@@ -156,8 +156,8 @@ impl ReasoningParserType {
"mistral" => Self::Mistral.get_reasoning_parser(),
_ => {
tracing::warn!(
"Unknown reasoning parser type '{}', falling back to Basic Reasoning Parser",
name
parser_name = name,
"Unknown reasoning parser type, falling back to Basic Reasoning Parser",
);
Self::Basic.get_reasoning_parser()
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment