Unverified commit 45e881d3, authored by KrishnanPrash and committed by GitHub
Browse files

feat: Support for field include_stop_str_in_output (#4924)


Signed-off-by: Krishnan Prashanth <kprashanth@nvidia.com>
parent ef3027bd
...@@ -1006,6 +1006,13 @@ pub struct ChatChoiceLogprobs { ...@@ -1006,6 +1006,13 @@ pub struct ChatChoiceLogprobs {
pub refusal: Option<Vec<ChatCompletionTokenLogprob>>, pub refusal: Option<Vec<ChatCompletionTokenLogprob>>,
} }
/// Identifies what caused generation to stop: either a matched stop string or a
/// matched stop token id.
///
/// Because of `#[serde(untagged)]`, no variant name is written when serializing:
/// the value is emitted as the bare inner string or the bare inner integer.
/// When deserializing, serde tries the variants in declaration order, so the
/// order of the variants below is part of the wire behavior.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(untagged)]
pub enum StopReason {
    /// The user-provided stop sequence that was matched.
    String(String),
    /// The stop token id that was matched (requires stop_token_id support).
    Int(i64),
}
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)] #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
pub struct ChatChoice { pub struct ChatChoice {
/// The index of the choice in the list of choices. /// The index of the choice in the list of choices.
...@@ -1017,6 +1024,10 @@ pub struct ChatChoice { ...@@ -1017,6 +1024,10 @@ pub struct ChatChoice {
/// `tool_calls` if the model called a tool, or `function_call` (deprecated) if the model called a function. /// `tool_calls` if the model called a tool, or `function_call` (deprecated) if the model called a function.
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub finish_reason: Option<FinishReason>, pub finish_reason: Option<FinishReason>,
/// Which stop string matched (if any).
/// This is only set when `finish_reason` is `"stop"` because a user-provided stop sequence was hit.
#[serde(skip_serializing_if = "Option::is_none")]
pub stop_reason: Option<StopReason>,
/// Log probability information for the choice. /// Log probability information for the choice.
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub logprobs: Option<ChatChoiceLogprobs>, pub logprobs: Option<ChatChoiceLogprobs>,
...@@ -1112,6 +1123,10 @@ pub struct ChatChoiceStream { ...@@ -1112,6 +1123,10 @@ pub struct ChatChoiceStream {
/// (deprecated) if the model called a function. /// (deprecated) if the model called a function.
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub finish_reason: Option<FinishReason>, pub finish_reason: Option<FinishReason>,
/// Which stop string matched (if any).
/// This is only set when `finish_reason` is `"stop"` because a user-provided stop sequence was hit.
#[serde(skip_serializing_if = "Option::is_none")]
pub stop_reason: Option<StopReason>,
/// Log probability information for the choice. /// Log probability information for the choice.
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub logprobs: Option<ChatChoiceLogprobs>, pub logprobs: Option<ChatChoiceLogprobs>,
......
...@@ -413,6 +413,7 @@ impl ...@@ -413,6 +413,7 @@ impl
}, },
logprobs: None, logprobs: None,
finish_reason, finish_reason,
stop_reason: None,
}], }],
model: c.model, model: c.model,
created: c.created as u32, created: c.created as u32,
......
...@@ -54,7 +54,7 @@ pub fn decode(c: &mut Criterion) { ...@@ -54,7 +54,7 @@ pub fn decode(c: &mut Criterion) {
let tokenizer: Arc<dyn Tokenizer> = let tokenizer: Arc<dyn Tokenizer> =
Arc::new(HuggingFaceTokenizer::from_file(TEST_TOKENIZER).unwrap()); Arc::new(HuggingFaceTokenizer::from_file(TEST_TOKENIZER).unwrap());
let ds = DecodeStream::new(tokenizer, &[], false); let ds = DecodeStream::new(tokenizer, &[], false);
Decoder::new(ds, StopConditions::default()) Decoder::new(ds, StopConditions::default(), false)
}, },
|mut decoder| { |mut decoder| {
for tok in black_box(TEST_TOKS) { for tok in black_box(TEST_TOKS) {
...@@ -78,7 +78,7 @@ pub fn decode_big(c: &mut Criterion) { ...@@ -78,7 +78,7 @@ pub fn decode_big(c: &mut Criterion) {
let tokenizer: Arc<dyn Tokenizer> = let tokenizer: Arc<dyn Tokenizer> =
Arc::new(HuggingFaceTokenizer::from_file(TEST_TOKENIZER).unwrap()); Arc::new(HuggingFaceTokenizer::from_file(TEST_TOKENIZER).unwrap());
let ds = DecodeStream::new(tokenizer, &[], false); let ds = DecodeStream::new(tokenizer, &[], false);
Decoder::new(ds, StopConditions::default()) Decoder::new(ds, StopConditions::default(), false)
}, },
|mut decoder| { |mut decoder| {
for tok in black_box(&BIG_TEST_TOKS) { for tok in black_box(&BIG_TEST_TOKS) {
......
...@@ -215,6 +215,7 @@ pub fn final_response_to_one_chunk_stream( ...@@ -215,6 +215,7 @@ pub fn final_response_to_one_chunk_stream(
index: idx as u32, index: idx as u32,
delta, delta,
finish_reason: ch.finish_reason, finish_reason: ch.finish_reason,
stop_reason: ch.stop_reason.clone(),
logprobs: ch.logprobs.clone(), logprobs: ch.logprobs.clone(),
}; };
choices.push(choice); choices.push(choice);
...@@ -267,6 +268,7 @@ mod tests { ...@@ -267,6 +268,7 @@ mod tests {
reasoning_content: None, reasoning_content: None,
}, },
finish_reason: None, finish_reason: None,
stop_reason: None,
logprobs: None, logprobs: None,
}; };
...@@ -304,6 +306,7 @@ mod tests { ...@@ -304,6 +306,7 @@ mod tests {
reasoning_content: None, reasoning_content: None,
}, },
finish_reason: Some(FinishReason::Stop), finish_reason: Some(FinishReason::Stop),
stop_reason: None,
logprobs: None, logprobs: None,
}; };
...@@ -427,6 +430,7 @@ mod tests { ...@@ -427,6 +430,7 @@ mod tests {
reasoning_content: None, reasoning_content: None,
}, },
finish_reason: None, finish_reason: None,
stop_reason: None,
logprobs: None, logprobs: None,
} }
}], }],
......
...@@ -42,6 +42,7 @@ use crate::protocols::{ ...@@ -42,6 +42,7 @@ use crate::protocols::{
}, },
}; };
use crate::tokenizers::{DecodeStream, HuggingFaceTokenizer, Tokenizer}; use crate::tokenizers::{DecodeStream, HuggingFaceTokenizer, Tokenizer};
use dynamo_async_openai::types::StopReason;
use tokenizers::Tokenizer as HfTokenizer; use tokenizers::Tokenizer as HfTokenizer;
/// Represents the output stream from the execution engine /// Represents the output stream from the execution engine
...@@ -63,6 +64,8 @@ struct DecoderUnfoldState { ...@@ -63,6 +64,8 @@ struct DecoderUnfoldState {
stream: ManyOut<ExecutionOutputStream>, stream: ManyOut<ExecutionOutputStream>,
decoder: Decoder, decoder: Decoder,
validate_engine_decode: bool, validate_engine_decode: bool,
/// Set to true when a local stop condition is detected, causing the stream to end
finished: bool,
} }
impl Backend { impl Backend {
...@@ -95,6 +98,7 @@ impl Backend { ...@@ -95,6 +98,7 @@ impl Backend {
prompt_token_ids: &[TokenIdType], prompt_token_ids: &[TokenIdType],
stop_conditions: StopConditions, stop_conditions: StopConditions,
skip_special_tokens: bool, skip_special_tokens: bool,
include_stop_str_in_output: bool,
) -> anyhow::Result<DecoderUnfoldState> { ) -> anyhow::Result<DecoderUnfoldState> {
let Some(tokenizer) = self.tokenizer.as_ref() else { let Some(tokenizer) = self.tokenizer.as_ref() else {
anyhow::bail!("Backend built from blank ModelDeploymentCard, no tokenizer"); anyhow::bail!("Backend built from blank ModelDeploymentCard, no tokenizer");
...@@ -102,12 +106,14 @@ impl Backend { ...@@ -102,12 +106,14 @@ impl Backend {
let decoder = Decoder::new( let decoder = Decoder::new(
tokenizer.decode_stream(prompt_token_ids, skip_special_tokens), tokenizer.decode_stream(prompt_token_ids, skip_special_tokens),
stop_conditions, stop_conditions,
include_stop_str_in_output,
); );
Ok(DecoderUnfoldState { Ok(DecoderUnfoldState {
stream, stream,
decoder, decoder,
validate_engine_decode: self.validate_engine_decode, validate_engine_decode: self.validate_engine_decode,
finished: false,
}) })
} }
} }
...@@ -133,6 +139,12 @@ impl ...@@ -133,6 +139,12 @@ impl
// TODO: Consider updating default to true to match behavior of other frameworks // TODO: Consider updating default to true to match behavior of other frameworks
let skip_special_tokens = request.output_options.skip_special_tokens.unwrap_or(false); let skip_special_tokens = request.output_options.skip_special_tokens.unwrap_or(false);
// Extract include_stop_str_in_output from sampling_options (defaults to false)
let include_stop_str_in_output = request
.sampling_options
.include_stop_str_in_output
.unwrap_or(false);
let next_stream = next.generate(request).await?; let next_stream = next.generate(request).await?;
let context = next_stream.context(); let context = next_stream.context();
...@@ -141,9 +153,15 @@ impl ...@@ -141,9 +153,15 @@ impl
&prompt_token_ids, &prompt_token_ids,
stop_conditions, stop_conditions,
skip_special_tokens, skip_special_tokens,
include_stop_str_in_output,
)?; )?;
let processed_stream = stream::unfold(state, |mut state| async move { let processed_stream = stream::unfold(state, |mut state| async move {
// If we've already detected a local stop condition, end the stream
if state.finished {
return None;
}
match state.stream.next().await { match state.stream.next().await {
Some(output) => { Some(output) => {
// move to state.process_output // move to state.process_output
...@@ -169,21 +187,38 @@ impl ...@@ -169,21 +187,38 @@ impl
// NOTE: the `finish_reason` is computed from the generated `token_ids` alone. // NOTE: the `finish_reason` is computed from the generated `token_ids` alone.
// The `data` field can have a `finish_reason` set, coming from the underlying // The `data` field can have a `finish_reason` set, coming from the underlying
// LLM inference `Engine`, and empty `token_ids`. See comment below for more details. // LLM inference `Engine`, and empty `token_ids`. See comment below for more details.
let finish_reason = match &result.stop_trigger { //
Some(StopTrigger::MaxTokensLimit) => Some(FinishReason::Length), // stop_reason is only set for user-provided stop sequences, not for system
Some(StopTrigger::HiddenStopTokenDetected(_)) => Some(FinishReason::Stop), // EOS tokens (HiddenStopTokenDetected). `stop_reason` is not part of the OpenAI API itself; this follows
Some(StopTrigger::HiddenStopSequenceDetected(_)) => { // the vLLM-style extension, where it is absent when generation ends on the EOS token.
Some(FinishReason::Stop) let (finish_reason, stop_reason) = match &result.stop_trigger {
Some(StopTrigger::MaxTokensLimit) => (Some(FinishReason::Length), None),
Some(StopTrigger::HiddenStopTokenDetected(_)) => {
// System EOS token - no stop_reason (user didn't request this stop)
(Some(FinishReason::Stop), None)
}
Some(StopTrigger::HiddenStopSequenceDetected(seq)) => {
// User-provided stop sequence (hidden from output)
(
Some(FinishReason::Stop),
Some(StopReason::String(seq.clone())),
)
}
Some(StopTrigger::VisibleStopSequenceDetected(seq)) => {
// User-provided stop sequence (included in output)
(
Some(FinishReason::Stop),
Some(StopReason::String(seq.clone())),
)
} }
None => None, None => (None, None),
}; };
if data.finish_reason.is_none() && finish_reason.is_some() { // If we detected a local stop condition, mark stream as finished
tracing::debug!( // so we stop iterating (upstream may keep generating, but we ignore it)
?result.stop_trigger, if finish_reason.is_some() && data.finish_reason.is_none() {
"upstream did not provide a finish reason; issuing a stop_generation request to free resources",
);
state.stream.context().stop_generating(); state.stream.context().stop_generating();
state.finished = true;
} }
let text = result.text; let text = result.text;
...@@ -215,6 +250,7 @@ impl ...@@ -215,6 +250,7 @@ impl
// which we don't want to propagate to `data.finish_reason`. // which we don't want to propagate to `data.finish_reason`.
if finish_reason.is_some() { if finish_reason.is_some() {
data.finish_reason = finish_reason; data.finish_reason = finish_reason;
data.stop_reason = stop_reason;
} }
data.text = text; data.text = text;
data.tokens = Some(tokens); data.tokens = Some(tokens);
...@@ -240,6 +276,7 @@ impl ...@@ -240,6 +276,7 @@ impl
log_probs: data.log_probs, log_probs: data.log_probs,
top_logprobs: data.top_logprobs, top_logprobs: data.top_logprobs,
finish_reason: data.finish_reason, finish_reason: data.finish_reason,
stop_reason: data.stop_reason,
//mdcsum: mdcsum.clone(), //mdcsum: mdcsum.clone(),
index: data.index, index: data.index,
completion_usage: data.completion_usage, completion_usage: data.completion_usage,
...@@ -282,10 +319,6 @@ impl ...@@ -282,10 +319,6 @@ impl
} }
} }
// todo - add visible stop conditions
// visible_stop_ids: HashSet<TokenIdType>,
// visible_stop_sequences: Vec<String>,
/// The [`Decoder`] object could be a member of either the internal LLM engine or part of the /// The [`Decoder`] object could be a member of either the internal LLM engine or part of the
/// postprocessor. If in the postprocessor, should be minimally in the same process or at very minimum /// postprocessor. If in the postprocessor, should be minimally in the same process or at very minimum
/// on the same physical machine connected by an IPC. /// on the same physical machine connected by an IPC.
...@@ -301,9 +334,13 @@ pub struct Decoder { ...@@ -301,9 +334,13 @@ pub struct Decoder {
hidden_stop_ids: HashSet<TokenIdType>, hidden_stop_ids: HashSet<TokenIdType>,
// text sequences that if found in the response will trigger a stop condition after the // text sequences that if found in the response will trigger a stop condition after the
// minimum number of tokens have been generated // minimum number of tokens have been generated (excluded from output)
hidden_stop_sequences: Vec<String>, hidden_stop_sequences: Vec<String>,
// text sequences that if found in the response will trigger a stop condition after the
// minimum number of tokens have been generated (included in output)
visible_stop_sequences: Vec<String>,
// number of generated tokens // number of generated tokens
generated_tokens: u32, generated_tokens: u32,
...@@ -315,8 +352,6 @@ pub struct Decoder { ...@@ -315,8 +352,6 @@ pub struct Decoder {
// the number of bytes currently jailed // the number of bytes currently jailed
jailed_bytes: usize, jailed_bytes: usize,
// mdcsum
//mdcsum: String,
} }
#[allow(dead_code)] #[allow(dead_code)]
...@@ -325,16 +360,7 @@ pub enum StopTrigger { ...@@ -325,16 +360,7 @@ pub enum StopTrigger {
MaxTokensLimit, MaxTokensLimit,
HiddenStopTokenDetected(TokenIdType), HiddenStopTokenDetected(TokenIdType),
HiddenStopSequenceDetected(String), HiddenStopSequenceDetected(String),
} VisibleStopSequenceDetected(String),
impl StopTrigger {
pub fn should_hide_text(&self) -> bool {
match self {
StopTrigger::MaxTokensLimit => false,
StopTrigger::HiddenStopTokenDetected(_) => true,
StopTrigger::HiddenStopSequenceDetected(_) => true,
}
}
} }
pub struct StepResult { pub struct StepResult {
...@@ -370,7 +396,7 @@ impl Decoder { ...@@ -370,7 +396,7 @@ impl Decoder {
pub fn new( pub fn new(
decode_stream: DecodeStream, decode_stream: DecodeStream,
stop_condition: StopConditions, stop_condition: StopConditions,
//mdcsum: String, include_stop_str_in_output: bool,
) -> Self { ) -> Self {
let hidden_stop_ids: HashSet<TokenIdType> = stop_condition let hidden_stop_ids: HashSet<TokenIdType> = stop_condition
.stop_token_ids_hidden .stop_token_ids_hidden
...@@ -379,15 +405,19 @@ impl Decoder { ...@@ -379,15 +405,19 @@ impl Decoder {
.copied() .copied()
.collect(); .collect();
let hidden_stop_sequences: Vec<String> = stop_condition // Categorize stop sequences based on include_stop_str_in_output:
.stop // - When true: user-provided stop sequences go to visible (included in output)
.unwrap_or_default() // - When false: user-provided stop sequences go to hidden (excluded from output)
.iter() let (hidden_stop_sequences, visible_stop_sequences) = if include_stop_str_in_output {
.map(|x| x.to_string()) (Vec::new(), stop_condition.stop.unwrap_or_default())
.collect(); } else {
(stop_condition.stop.unwrap_or_default(), Vec::new())
};
// Calculate jail_max_bytes considering both hidden and visible stop sequences
let jail_max_bytes = hidden_stop_sequences let jail_max_bytes = hidden_stop_sequences
.iter() .iter()
.chain(visible_stop_sequences.iter())
.map(|x| x.len()) .map(|x| x.len())
.max() .max()
.unwrap_or(0); .unwrap_or(0);
...@@ -396,8 +426,7 @@ impl Decoder { ...@@ -396,8 +426,7 @@ impl Decoder {
decode_stream, decode_stream,
hidden_stop_ids, hidden_stop_ids,
hidden_stop_sequences, hidden_stop_sequences,
//visible_stop_ids: HashSet::new(), visible_stop_sequences,
//visible_stop_sequences: Vec::new(),
min_tokens: stop_condition.min_tokens.unwrap_or(0), min_tokens: stop_condition.min_tokens.unwrap_or(0),
generated_tokens: 0, generated_tokens: 0,
jail: String::new(), jail: String::new(),
...@@ -444,12 +473,13 @@ impl Decoder { ...@@ -444,12 +473,13 @@ impl Decoder {
log::debug!("post_append: {}", self.jail.len()); log::debug!("post_append: {}", self.jail.len());
log::debug!("jail: {}", self.jail); log::debug!("jail: {}", self.jail);
// Check hidden stop sequences first (excluded from output)
for seq in &self.hidden_stop_sequences { for seq in &self.hidden_stop_sequences {
log::debug!("stop seq: {}", seq); log::debug!("stop seq: {}", seq);
if let Some(offset) = galil_seiferas::gs_find(self.jail.as_bytes(), seq.as_bytes()) if let Some(offset) = galil_seiferas::gs_find(self.jail.as_bytes(), seq.as_bytes())
{ {
log::debug!("offset: {}", offset); log::debug!("offset: {}", offset);
// return only new bytes after pre_append .. offset+seq.len() // return only new bytes after pre_append .. offset (excluding stop sequence)
// example: seq = "ox", token = "boxes", return "b" // example: seq = "ox", token = "boxes", return "b"
// note: this changes when we start jailing tokens for partial matches // note: this changes when we start jailing tokens for partial matches
// on the suffix of the jail with prefixes of the stop sequences // on the suffix of the jail with prefixes of the stop sequences
...@@ -468,6 +498,26 @@ impl Decoder { ...@@ -468,6 +498,26 @@ impl Decoder {
} }
} }
// Check visible stop sequences (included in output)
for seq in &self.visible_stop_sequences {
if let Some(offset) = galil_seiferas::gs_find(self.jail.as_bytes(), seq.as_bytes())
{
// For visible stop sequences, include the stop string in the output
// Return all text from pre_append up to and including the stop sequence
let stop_end = offset + seq.len();
let token_with_stop = if stop_end > pre_append {
self.jail[pre_append..stop_end].to_string()
} else {
// Stop sequence was entirely in previously returned text
"".to_string()
};
return Ok(StepResult::with_stop_trigger(
Some(token_with_stop),
StopTrigger::VisibleStopSequenceDetected(seq.to_string()),
));
}
}
Self::maybe_drain_to_max_bytes(&mut self.jail, self.jail_max_bytes); Self::maybe_drain_to_max_bytes(&mut self.jail, self.jail_max_bytes);
} }
...@@ -484,12 +534,8 @@ impl Decoder { ...@@ -484,12 +534,8 @@ impl Decoder {
stop_trigger, stop_trigger,
} = self.step(*token_id)?; } = self.step(*token_id)?;
let hide_text = stop_trigger // Always include token text (for visible stops, the stop string is already in the token)
.as_ref() if let Some(token) = &token {
.map(|x| x.should_hide_text())
.unwrap_or(false);
if !hide_text && let Some(token) = &token {
text.get_or_insert_with(|| String::with_capacity(token_ids.len())) text.get_or_insert_with(|| String::with_capacity(token_ids.len()))
.push_str(token); .push_str(token);
} }
...@@ -511,24 +557,6 @@ impl Decoder { ...@@ -511,24 +557,6 @@ impl Decoder {
}) })
} }
fn return_token(&self, token: Option<String>) -> StepResult {
StepResult {
token,
stop_trigger: None,
}
}
fn return_with_stop_trigger(
&self,
token: Option<String>,
stop_trigger: StopTrigger,
) -> StepResult {
StepResult {
token,
stop_trigger: Some(stop_trigger),
}
}
fn jailed_string(&self) -> Option<String> { fn jailed_string(&self) -> Option<String> {
if self.jailed_bytes > 0 { if self.jailed_bytes > 0 {
// get the last jailed_bytes from the jail // get the last jailed_bytes from the jail
......
...@@ -159,12 +159,12 @@ impl ...@@ -159,12 +159,12 @@ impl
for c in prompt.chars() { for c in prompt.chars() {
// we are returning characters not tokens, so there will be some postprocessing overhead // we are returning characters not tokens, so there will be some postprocessing overhead
tokio::time::sleep(*TOKEN_ECHO_DELAY).await; tokio::time::sleep(*TOKEN_ECHO_DELAY).await;
let response = deltas.create_choice(0, Some(c.to_string()), None, None); let response = deltas.create_choice(0, Some(c.to_string()), None, None, None);
yield Annotated{ id: Some(id.to_string()), data: Some(response), event: None, comment: None }; yield Annotated{ id: Some(id.to_string()), data: Some(response), event: None, comment: None };
id += 1; id += 1;
} }
let response = deltas.create_choice(0, None, Some(dynamo_async_openai::types::FinishReason::Stop), None); let response = deltas.create_choice(0, None, Some(dynamo_async_openai::types::FinishReason::Stop), None, None);
yield Annotated { id: Some(id.to_string()), data: Some(response), event: None, comment: None }; yield Annotated { id: Some(id.to_string()), data: Some(response), event: None, comment: None };
}; };
......
...@@ -216,6 +216,7 @@ mod tests { ...@@ -216,6 +216,7 @@ mod tests {
log_probs: None, log_probs: None,
top_logprobs: None, top_logprobs: None,
finish_reason: None, finish_reason: None,
stop_reason: None,
index: None, index: None,
disaggregated_params: None, disaggregated_params: None,
completion_usage: None, completion_usage: None,
......
...@@ -309,6 +309,7 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<LLMEngineOutput>, Error> ...@@ -309,6 +309,7 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<LLMEngineOutput>, Error>
log_probs: None, log_probs: None,
top_logprobs: None, top_logprobs: None,
finish_reason: None, finish_reason: None,
stop_reason: None,
index: None, index: None,
// Add dummy disaggregated_params for prefill workers // Add dummy disaggregated_params for prefill workers
disaggregated_params: if is_prefill { disaggregated_params: if is_prefill {
......
...@@ -961,6 +961,7 @@ mod tests { ...@@ -961,6 +961,7 @@ mod tests {
reasoning_content: None, reasoning_content: None,
}, },
finish_reason: Some(FinishReason::Stop), finish_reason: Some(FinishReason::Stop),
stop_reason: None,
logprobs: Some(ChatChoiceLogprobs { logprobs: Some(ChatChoiceLogprobs {
content: Some(token_logprobs), content: Some(token_logprobs),
refusal: None, refusal: None,
...@@ -994,6 +995,7 @@ mod tests { ...@@ -994,6 +995,7 @@ mod tests {
reasoning_content: None, reasoning_content: None,
}, },
finish_reason: Some(FinishReason::Stop), finish_reason: Some(FinishReason::Stop),
stop_reason: None,
logprobs: Some(ChatChoiceLogprobs { logprobs: Some(ChatChoiceLogprobs {
content: Some(token_logprobs), content: Some(token_logprobs),
refusal: None, refusal: None,
...@@ -1343,6 +1345,7 @@ mod tests { ...@@ -1343,6 +1345,7 @@ mod tests {
reasoning_content: None, reasoning_content: None,
}, },
finish_reason: Some(FinishReason::Stop), finish_reason: Some(FinishReason::Stop),
stop_reason: None,
logprobs: None, // No logprobs logprobs: None, // No logprobs
}], }],
created: 1234567890, created: 1234567890,
......
...@@ -7,6 +7,7 @@ pub use super::FinishReason; ...@@ -7,6 +7,7 @@ pub use super::FinishReason;
pub use super::preprocessor::PreprocessedRequest; pub use super::preprocessor::PreprocessedRequest;
use crate::protocols::TokenIdType; use crate::protocols::TokenIdType;
use dynamo_async_openai::types::CompletionUsage; use dynamo_async_openai::types::CompletionUsage;
use dynamo_async_openai::types::StopReason;
use dynamo_runtime::protocols::maybe_error::MaybeError; use dynamo_runtime::protocols::maybe_error::MaybeError;
pub type TokenType = Option<String>; pub type TokenType = Option<String>;
...@@ -44,6 +45,12 @@ pub struct BackendOutput { ...@@ -44,6 +45,12 @@ pub struct BackendOutput {
// TODO: Enrich this with more information as can apply our first-level postprocessing // TODO: Enrich this with more information as can apply our first-level postprocessing
// logic and return more detailed information // logic and return more detailed information
pub finish_reason: Option<FinishReason>, pub finish_reason: Option<FinishReason>,
/// The stop string or token that triggered the stop condition.
/// This is set when finish_reason is Stop and identifies what triggered it.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub stop_reason: Option<StopReason>,
// Model Deployment Card checksum // Model Deployment Card checksum
//pub mdcsum: String, //pub mdcsum: String,
...@@ -91,6 +98,11 @@ pub struct LLMEngineOutput { ...@@ -91,6 +98,11 @@ pub struct LLMEngineOutput {
// logic and return more detailed information // logic and return more detailed information
pub finish_reason: Option<FinishReason>, pub finish_reason: Option<FinishReason>,
/// The stop string or token that triggered the stop condition.
/// This is set when finish_reason is Stop and identifies what triggered it.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub stop_reason: Option<StopReason>,
// Index field for batch requests to match OpenAI format // Index field for batch requests to match OpenAI format
pub index: Option<u32>, pub index: Option<u32>,
...@@ -117,6 +129,7 @@ impl LLMEngineOutput { ...@@ -117,6 +129,7 @@ impl LLMEngineOutput {
log_probs: None, log_probs: None,
top_logprobs: None, top_logprobs: None,
finish_reason: Some(FinishReason::Cancelled), finish_reason: Some(FinishReason::Cancelled),
stop_reason: None,
index: None, index: None,
disaggregated_params: None, disaggregated_params: None,
extra_args: None, extra_args: None,
...@@ -132,6 +145,7 @@ impl LLMEngineOutput { ...@@ -132,6 +145,7 @@ impl LLMEngineOutput {
cum_log_probs: None, cum_log_probs: None,
log_probs: None, log_probs: None,
finish_reason: Some(FinishReason::Stop), finish_reason: Some(FinishReason::Stop),
stop_reason: None,
top_logprobs: None, top_logprobs: None,
index: None, index: None,
disaggregated_params: None, disaggregated_params: None,
...@@ -149,6 +163,7 @@ impl LLMEngineOutput { ...@@ -149,6 +163,7 @@ impl LLMEngineOutput {
log_probs: None, log_probs: None,
top_logprobs: None, top_logprobs: None,
finish_reason: Some(FinishReason::Length), finish_reason: Some(FinishReason::Length),
stop_reason: None,
index: None, index: None,
disaggregated_params: None, disaggregated_params: None,
extra_args: None, extra_args: None,
...@@ -165,6 +180,7 @@ impl LLMEngineOutput { ...@@ -165,6 +180,7 @@ impl LLMEngineOutput {
log_probs: None, log_probs: None,
top_logprobs: None, top_logprobs: None,
finish_reason: Some(FinishReason::Error(err_msg)), finish_reason: Some(FinishReason::Error(err_msg)),
stop_reason: None,
index: None, index: None,
disaggregated_params: None, disaggregated_params: None,
extra_args: None, extra_args: None,
......
...@@ -12,6 +12,7 @@ use crate::protocols::{ ...@@ -12,6 +12,7 @@ use crate::protocols::{
openai::ParsingOptions, openai::ParsingOptions,
}; };
use dynamo_async_openai::types::StopReason;
use dynamo_runtime::engine::DataStream; use dynamo_runtime::engine::DataStream;
/// Aggregates a stream of [`NvCreateChatCompletionStreamResponse`]s into a single /// Aggregates a stream of [`NvCreateChatCompletionStreamResponse`]s into a single
...@@ -49,6 +50,8 @@ struct DeltaChoice { ...@@ -49,6 +50,8 @@ struct DeltaChoice {
role: Option<dynamo_async_openai::types::Role>, role: Option<dynamo_async_openai::types::Role>,
/// The reason the completion was finished (if applicable). /// The reason the completion was finished (if applicable).
finish_reason: Option<dynamo_async_openai::types::FinishReason>, finish_reason: Option<dynamo_async_openai::types::FinishReason>,
/// The stop string or token that triggered the stop condition.
stop_reason: Option<StopReason>,
/// Optional log probabilities for the chat choice. /// Optional log probabilities for the chat choice.
logprobs: Option<dynamo_async_openai::types::ChatChoiceLogprobs>, logprobs: Option<dynamo_async_openai::types::ChatChoiceLogprobs>,
// Optional tool calls for the chat choice. // Optional tool calls for the chat choice.
...@@ -159,6 +162,7 @@ impl DeltaAggregator { ...@@ -159,6 +162,7 @@ impl DeltaAggregator {
text: "".to_string(), text: "".to_string(),
role: choice.delta.role, role: choice.delta.role,
finish_reason: None, finish_reason: None,
stop_reason: None,
logprobs: None, logprobs: None,
tool_calls: None, tool_calls: None,
reasoning_content: None, reasoning_content: None,
...@@ -204,6 +208,11 @@ impl DeltaAggregator { ...@@ -204,6 +208,11 @@ impl DeltaAggregator {
state_choice.finish_reason = Some(finish_reason); state_choice.finish_reason = Some(finish_reason);
} }
// Update stop reason if provided.
if let Some(stop_reason) = choice.stop_reason {
state_choice.stop_reason = Some(stop_reason);
}
// Update logprobs // Update logprobs
if let Some(logprobs) = &choice.logprobs { if let Some(logprobs) = &choice.logprobs {
let state_lps = state_choice.logprobs.get_or_insert( let state_lps = state_choice.logprobs.get_or_insert(
...@@ -296,6 +305,7 @@ impl From<DeltaChoice> for dynamo_async_openai::types::ChatChoice { ...@@ -296,6 +305,7 @@ impl From<DeltaChoice> for dynamo_async_openai::types::ChatChoice {
}, },
index: delta.index, index: delta.index,
finish_reason, finish_reason,
stop_reason: delta.stop_reason,
logprobs: delta.logprobs, logprobs: delta.logprobs,
} }
} }
...@@ -408,6 +418,7 @@ mod tests { ...@@ -408,6 +418,7 @@ mod tests {
index, index,
delta, delta,
finish_reason, finish_reason,
stop_reason: None,
logprobs, logprobs,
}; };
...@@ -626,6 +637,7 @@ mod tests { ...@@ -626,6 +637,7 @@ mod tests {
reasoning_content: None, reasoning_content: None,
}, },
finish_reason: Some(dynamo_async_openai::types::FinishReason::Stop), finish_reason: Some(dynamo_async_openai::types::FinishReason::Stop),
stop_reason: None,
logprobs: None, logprobs: None,
}, },
dynamo_async_openai::types::ChatChoiceStream { dynamo_async_openai::types::ChatChoiceStream {
...@@ -639,6 +651,7 @@ mod tests { ...@@ -639,6 +651,7 @@ mod tests {
reasoning_content: None, reasoning_content: None,
}, },
finish_reason: Some(dynamo_async_openai::types::FinishReason::Stop), finish_reason: Some(dynamo_async_openai::types::FinishReason::Stop),
stop_reason: None,
logprobs: None, logprobs: None,
}, },
], ],
......
...@@ -234,6 +234,7 @@ impl DeltaGenerator { ...@@ -234,6 +234,7 @@ impl DeltaGenerator {
/// * `text` - The text content for the response. /// * `text` - The text content for the response.
/// * `finish_reason` - The reason why the response finished (e.g., stop, length, etc.). /// * `finish_reason` - The reason why the response finished (e.g., stop, length, etc.).
/// * `logprobs` - Optional log probabilities of the generated tokens. /// * `logprobs` - Optional log probabilities of the generated tokens.
/// * `stop_reason` - Optional stop string or token that triggered the stop.
/// ///
/// # Returns /// # Returns
/// * An [`dynamo_async_openai::types::CreateChatCompletionStreamResponse`] instance representing the choice. /// * An [`dynamo_async_openai::types::CreateChatCompletionStreamResponse`] instance representing the choice.
...@@ -244,6 +245,7 @@ impl DeltaGenerator { ...@@ -244,6 +245,7 @@ impl DeltaGenerator {
text: Option<String>, text: Option<String>,
finish_reason: Option<dynamo_async_openai::types::FinishReason>, finish_reason: Option<dynamo_async_openai::types::FinishReason>,
logprobs: Option<dynamo_async_openai::types::ChatChoiceLogprobs>, logprobs: Option<dynamo_async_openai::types::ChatChoiceLogprobs>,
stop_reason: Option<dynamo_async_openai::types::StopReason>,
) -> NvCreateChatCompletionStreamResponse { ) -> NvCreateChatCompletionStreamResponse {
let delta = dynamo_async_openai::types::ChatCompletionStreamResponseDelta { let delta = dynamo_async_openai::types::ChatCompletionStreamResponseDelta {
content: text, content: text,
...@@ -262,6 +264,7 @@ impl DeltaGenerator { ...@@ -262,6 +264,7 @@ impl DeltaGenerator {
index, index,
delta, delta,
finish_reason, finish_reason,
stop_reason,
logprobs, logprobs,
}; };
...@@ -384,7 +387,13 @@ impl crate::protocols::openai::DeltaGeneratorExt<NvCreateChatCompletionStreamRes ...@@ -384,7 +387,13 @@ impl crate::protocols::openai::DeltaGeneratorExt<NvCreateChatCompletionStreamRes
// Create the streaming response. // Create the streaming response.
let index = 0; let index = 0;
let mut stream_response = self.create_choice(index, delta.text, finish_reason, logprobs); let mut stream_response = self.create_choice(
index,
delta.text,
finish_reason,
logprobs,
delta.stop_reason,
);
// Record first token time (only succeeds on first call due to OnceLock) // Record first token time (only succeeds on first call due to OnceLock)
if let Some(ref tracker) = self.timing_tracker { if let Some(ref tracker) = self.timing_tracker {
......
...@@ -104,6 +104,7 @@ fn create_choice_stream( ...@@ -104,6 +104,7 @@ fn create_choice_stream(
content: &str, content: &str,
tool_calls: Option<Vec<ChatCompletionMessageToolCallChunk>>, tool_calls: Option<Vec<ChatCompletionMessageToolCallChunk>>,
finish_reason: Option<FinishReason>, finish_reason: Option<FinishReason>,
stop_reason: Option<dynamo_async_openai::types::StopReason>,
logprobs: Option<ChatChoiceLogprobs>, logprobs: Option<ChatChoiceLogprobs>,
) -> ChatChoiceStream { ) -> ChatChoiceStream {
#[allow(deprecated)] #[allow(deprecated)]
...@@ -118,6 +119,7 @@ fn create_choice_stream( ...@@ -118,6 +119,7 @@ fn create_choice_stream(
reasoning_content: None, reasoning_content: None,
}, },
finish_reason, finish_reason,
stop_reason,
logprobs, logprobs,
} }
} }
...@@ -178,6 +180,7 @@ impl ChoiceJailState { ...@@ -178,6 +180,7 @@ impl ChoiceJailState {
&prefix, &prefix,
None, None,
choice.finish_reason, choice.finish_reason,
None,
choice.logprobs.clone(), choice.logprobs.clone(),
); );
emissions.push(ChoiceEmission::PassThrough(prefix_choice)); emissions.push(ChoiceEmission::PassThrough(prefix_choice));
...@@ -226,6 +229,7 @@ impl ChoiceJailState { ...@@ -226,6 +229,7 @@ impl ChoiceJailState {
trailing_part, trailing_part,
None, None,
choice.finish_reason, choice.finish_reason,
None,
choice.logprobs.clone(), choice.logprobs.clone(),
); );
emissions.push(ChoiceEmission::Trailing(trailing_choice)); emissions.push(ChoiceEmission::Trailing(trailing_choice));
...@@ -258,6 +262,7 @@ impl ChoiceJailState { ...@@ -258,6 +262,7 @@ impl ChoiceJailState {
&prefix, &prefix,
None, None,
choice.finish_reason, choice.finish_reason,
None,
choice.logprobs.clone(), choice.logprobs.clone(),
); );
emissions.push(ChoiceEmission::PassThrough(prefix_choice)); emissions.push(ChoiceEmission::PassThrough(prefix_choice));
...@@ -301,6 +306,7 @@ impl ChoiceJailState { ...@@ -301,6 +306,7 @@ impl ChoiceJailState {
&content, &content,
None, None,
choice.finish_reason, choice.finish_reason,
None,
choice.logprobs.clone(), choice.logprobs.clone(),
); );
emissions.push(ChoiceEmission::PassThrough(pass_through_choice)); emissions.push(ChoiceEmission::PassThrough(pass_through_choice));
...@@ -354,6 +360,7 @@ impl ChoiceJailState { ...@@ -354,6 +360,7 @@ impl ChoiceJailState {
trailing_part, trailing_part,
None, None,
choice.finish_reason, choice.finish_reason,
None,
choice.logprobs.clone(), choice.logprobs.clone(),
); );
emissions.push(ChoiceEmission::Trailing(trailing_choice)); emissions.push(ChoiceEmission::Trailing(trailing_choice));
...@@ -384,6 +391,7 @@ impl ChoiceJailState { ...@@ -384,6 +391,7 @@ impl ChoiceJailState {
None, None,
self.stream_finish_reason, // For the accumulated content, assign the original stream finish reason, otherwise it will get lost self.stream_finish_reason, // For the accumulated content, assign the original stream finish reason, otherwise it will get lost
None, None,
None,
); );
let final_choice = jail_stream let final_choice = jail_stream
...@@ -557,6 +565,7 @@ impl JailedStream { ...@@ -557,6 +565,7 @@ impl JailedStream {
index: choice.index, index: choice.index,
delta: choice.delta.clone(), delta: choice.delta.clone(),
finish_reason: choice.finish_reason, finish_reason: choice.finish_reason,
stop_reason: choice.stop_reason.clone(),
logprobs: choice.logprobs.clone(), logprobs: choice.logprobs.clone(),
}; };
all_emissions.push(ChoiceEmission::PassThrough(pass_through_choice)); all_emissions.push(ChoiceEmission::PassThrough(pass_through_choice));
...@@ -834,6 +843,7 @@ impl JailedStream { ...@@ -834,6 +843,7 @@ impl JailedStream {
Some(tool_call_chunks), Some(tool_call_chunks),
None, None,
None, None,
None,
); );
return choice; return choice;
} }
...@@ -845,6 +855,7 @@ impl JailedStream { ...@@ -845,6 +855,7 @@ impl JailedStream {
accumulated_content, accumulated_content,
None, None,
base_choice.finish_reason, base_choice.finish_reason,
base_choice.stop_reason.clone(),
base_choice.logprobs.clone(), base_choice.logprobs.clone(),
) )
} }
...@@ -857,6 +868,7 @@ impl JailedStream { ...@@ -857,6 +868,7 @@ impl JailedStream {
"", "",
Some(tool_call_chunks), Some(tool_call_chunks),
base_choice.finish_reason, base_choice.finish_reason,
None,
base_choice.logprobs.clone(), base_choice.logprobs.clone(),
), ),
Ok(_) | Err(_) => { Ok(_) | Err(_) => {
...@@ -867,6 +879,7 @@ impl JailedStream { ...@@ -867,6 +879,7 @@ impl JailedStream {
accumulated_content, accumulated_content,
None, None,
base_choice.finish_reason, base_choice.finish_reason,
base_choice.stop_reason.clone(),
base_choice.logprobs.clone(), base_choice.logprobs.clone(),
) )
} }
......
...@@ -368,6 +368,7 @@ mod tests { ...@@ -368,6 +368,7 @@ mod tests {
reasoning_content: None, reasoning_content: None,
}, },
finish_reason: None, finish_reason: None,
stop_reason: None,
logprobs: None, logprobs: None,
}], }],
created: now, created: now,
......
...@@ -93,7 +93,7 @@ impl ...@@ -93,7 +93,7 @@ impl
let stream = stream! { let stream = stream! {
tokio::time::sleep(std::time::Duration::from_millis(max_tokens)).await; tokio::time::sleep(std::time::Duration::from_millis(max_tokens)).await;
for i in 0..10 { for i in 0..10 {
let output = generator.create_choice(i,Some(format!("choice {i}")), None, None); let output = generator.create_choice(i, Some(format!("choice {i}")), None, None, None);
yield Annotated::from_data(output); yield Annotated::from_data(output);
} }
......
...@@ -52,7 +52,7 @@ impl ...@@ -52,7 +52,7 @@ impl
// Generate 5 response chunks // Generate 5 response chunks
for i in 0..5 { for i in 0..5 {
let output = generator.create_choice(i, Some(format!("Mock response {i}")), None, None); let output = generator.create_choice(i, Some(format!("Mock response {i}")), None, None, None);
yield Annotated::from_data(output); yield Annotated::from_data(output);
} }
}; };
......
...@@ -389,6 +389,7 @@ fn create_response_with_linear_probs( ...@@ -389,6 +389,7 @@ fn create_response_with_linear_probs(
reasoning_content: None, reasoning_content: None,
}, },
finish_reason: Some(FinishReason::Stop), finish_reason: Some(FinishReason::Stop),
stop_reason: None,
logprobs: Some(ChatChoiceLogprobs { logprobs: Some(ChatChoiceLogprobs {
content: Some(token_logprobs), content: Some(token_logprobs),
refusal: None, refusal: None,
...@@ -468,6 +469,7 @@ fn create_multi_choice_response( ...@@ -468,6 +469,7 @@ fn create_multi_choice_response(
reasoning_content: None, reasoning_content: None,
}, },
finish_reason: Some(FinishReason::Stop), finish_reason: Some(FinishReason::Stop),
stop_reason: None,
logprobs: Some(ChatChoiceLogprobs { logprobs: Some(ChatChoiceLogprobs {
content: Some(token_logprobs), content: Some(token_logprobs),
refusal: None, refusal: None,
......
...@@ -34,6 +34,7 @@ mod tests { ...@@ -34,6 +34,7 @@ mod tests {
reasoning_content: None, reasoning_content: None,
}, },
finish_reason: None, finish_reason: None,
stop_reason: None,
logprobs: None, logprobs: None,
}; };
...@@ -73,6 +74,7 @@ mod tests { ...@@ -73,6 +74,7 @@ mod tests {
reasoning_content: None, reasoning_content: None,
}, },
finish_reason: Some(FinishReason::Stop), finish_reason: Some(FinishReason::Stop),
stop_reason: None,
logprobs: None, logprobs: None,
}; };
...@@ -116,6 +118,7 @@ mod tests { ...@@ -116,6 +118,7 @@ mod tests {
reasoning_content: None, reasoning_content: None,
}, },
finish_reason: None, finish_reason: None,
stop_reason: None,
logprobs: None, logprobs: None,
}; };
...@@ -158,6 +161,7 @@ mod tests { ...@@ -158,6 +161,7 @@ mod tests {
reasoning_content: None, reasoning_content: None,
}, },
finish_reason: None, finish_reason: None,
stop_reason: None,
logprobs: None, logprobs: None,
} }
}) })
...@@ -202,6 +206,7 @@ mod tests { ...@@ -202,6 +206,7 @@ mod tests {
reasoning_content: None, reasoning_content: None,
}, },
finish_reason: Some(FinishReason::Stop), finish_reason: Some(FinishReason::Stop),
stop_reason: None,
logprobs: None, logprobs: None,
} }
}) })
...@@ -2323,6 +2328,7 @@ mod parallel_jail_tests { ...@@ -2323,6 +2328,7 @@ mod parallel_jail_tests {
reasoning_content: None, reasoning_content: None,
}, },
finish_reason: None, finish_reason: None,
stop_reason: None,
logprobs: None, logprobs: None,
} }
}) })
......
...@@ -24,6 +24,7 @@ fn create_mock_response_chunk( ...@@ -24,6 +24,7 @@ fn create_mock_response_chunk(
reasoning_content, reasoning_content,
}, },
finish_reason: None, finish_reason: None,
stop_reason: None,
logprobs: None, logprobs: None,
}; };
......
...@@ -106,6 +106,7 @@ fn build_backend_outputs_with_cached_tokens(cached_tokens: Option<u32>) -> Vec<B ...@@ -106,6 +106,7 @@ fn build_backend_outputs_with_cached_tokens(cached_tokens: Option<u32>) -> Vec<B
log_probs: None, log_probs: None,
top_logprobs: None, top_logprobs: None,
finish_reason: None, finish_reason: None,
stop_reason: None,
index: Some(0), index: Some(0),
completion_usage: None, completion_usage: None,
disaggregated_params: None, disaggregated_params: None,
...@@ -118,6 +119,7 @@ fn build_backend_outputs_with_cached_tokens(cached_tokens: Option<u32>) -> Vec<B ...@@ -118,6 +119,7 @@ fn build_backend_outputs_with_cached_tokens(cached_tokens: Option<u32>) -> Vec<B
log_probs: None, log_probs: None,
top_logprobs: None, top_logprobs: None,
finish_reason: None, finish_reason: None,
stop_reason: None,
index: Some(0), index: Some(0),
completion_usage: None, completion_usage: None,
disaggregated_params: None, disaggregated_params: None,
...@@ -130,6 +132,7 @@ fn build_backend_outputs_with_cached_tokens(cached_tokens: Option<u32>) -> Vec<B ...@@ -130,6 +132,7 @@ fn build_backend_outputs_with_cached_tokens(cached_tokens: Option<u32>) -> Vec<B
log_probs: None, log_probs: None,
top_logprobs: None, top_logprobs: None,
finish_reason: Some(FinishReason::Stop), finish_reason: Some(FinishReason::Stop),
stop_reason: None,
index: Some(0), index: Some(0),
completion_usage: cached_tokens.map(|ct| AoaiCompletionUsage { completion_usage: cached_tokens.map(|ct| AoaiCompletionUsage {
prompt_tokens: 0, prompt_tokens: 0,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment