Unverified Commit 45e881d3 authored by KrishnanPrash's avatar KrishnanPrash Committed by GitHub
Browse files

feat: Support for field include_stop_str_in_output (#4924)


Signed-off-by: default avatarKrishnan Prashanth <kprashanth@nvidia.com>
parent ef3027bd
......@@ -1006,6 +1006,13 @@ pub struct ChatChoiceLogprobs {
pub refusal: Option<Vec<ChatCompletionTokenLogprob>>,
}
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(untagged)]
pub enum StopReason {
String(String), // matched user-provided stop sequence
Int(i64), // matched stop token id (requires stop_token_id support)
}
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
pub struct ChatChoice {
/// The index of the choice in the list of choices.
......@@ -1017,6 +1024,10 @@ pub struct ChatChoice {
/// `tool_calls` if the model called a tool, or `function_call` (deprecated) if the model called a function.
#[serde(skip_serializing_if = "Option::is_none")]
pub finish_reason: Option<FinishReason>,
/// Which stop string matched (if any).
/// This is only set when `finish_reason` is `"stop"` because a user-provided stop sequence was hit.
#[serde(skip_serializing_if = "Option::is_none")]
pub stop_reason: Option<StopReason>,
/// Log probability information for the choice.
#[serde(skip_serializing_if = "Option::is_none")]
pub logprobs: Option<ChatChoiceLogprobs>,
......@@ -1112,6 +1123,10 @@ pub struct ChatChoiceStream {
/// (deprecated) if the model called a function.
#[serde(skip_serializing_if = "Option::is_none")]
pub finish_reason: Option<FinishReason>,
/// Which stop string matched (if any).
/// This is only set when `finish_reason` is `"stop"` because a user-provided stop sequence was hit.
#[serde(skip_serializing_if = "Option::is_none")]
pub stop_reason: Option<StopReason>,
/// Log probability information for the choice.
#[serde(skip_serializing_if = "Option::is_none")]
pub logprobs: Option<ChatChoiceLogprobs>,
......
......@@ -413,6 +413,7 @@ impl
},
logprobs: None,
finish_reason,
stop_reason: None,
}],
model: c.model,
created: c.created as u32,
......
......@@ -54,7 +54,7 @@ pub fn decode(c: &mut Criterion) {
let tokenizer: Arc<dyn Tokenizer> =
Arc::new(HuggingFaceTokenizer::from_file(TEST_TOKENIZER).unwrap());
let ds = DecodeStream::new(tokenizer, &[], false);
Decoder::new(ds, StopConditions::default())
Decoder::new(ds, StopConditions::default(), false)
},
|mut decoder| {
for tok in black_box(TEST_TOKS) {
......@@ -78,7 +78,7 @@ pub fn decode_big(c: &mut Criterion) {
let tokenizer: Arc<dyn Tokenizer> =
Arc::new(HuggingFaceTokenizer::from_file(TEST_TOKENIZER).unwrap());
let ds = DecodeStream::new(tokenizer, &[], false);
Decoder::new(ds, StopConditions::default())
Decoder::new(ds, StopConditions::default(), false)
},
|mut decoder| {
for tok in black_box(&BIG_TEST_TOKS) {
......
......@@ -215,6 +215,7 @@ pub fn final_response_to_one_chunk_stream(
index: idx as u32,
delta,
finish_reason: ch.finish_reason,
stop_reason: ch.stop_reason.clone(),
logprobs: ch.logprobs.clone(),
};
choices.push(choice);
......@@ -267,6 +268,7 @@ mod tests {
reasoning_content: None,
},
finish_reason: None,
stop_reason: None,
logprobs: None,
};
......@@ -304,6 +306,7 @@ mod tests {
reasoning_content: None,
},
finish_reason: Some(FinishReason::Stop),
stop_reason: None,
logprobs: None,
};
......@@ -427,6 +430,7 @@ mod tests {
reasoning_content: None,
},
finish_reason: None,
stop_reason: None,
logprobs: None,
}
}],
......
......@@ -42,6 +42,7 @@ use crate::protocols::{
},
};
use crate::tokenizers::{DecodeStream, HuggingFaceTokenizer, Tokenizer};
use dynamo_async_openai::types::StopReason;
use tokenizers::Tokenizer as HfTokenizer;
/// Represents the output stream from the execution engine
......@@ -63,6 +64,8 @@ struct DecoderUnfoldState {
stream: ManyOut<ExecutionOutputStream>,
decoder: Decoder,
validate_engine_decode: bool,
/// Set to true when a local stop condition is detected, causing the stream to end
finished: bool,
}
impl Backend {
......@@ -95,6 +98,7 @@ impl Backend {
prompt_token_ids: &[TokenIdType],
stop_conditions: StopConditions,
skip_special_tokens: bool,
include_stop_str_in_output: bool,
) -> anyhow::Result<DecoderUnfoldState> {
let Some(tokenizer) = self.tokenizer.as_ref() else {
anyhow::bail!("Backend built from blank ModelDeploymentCard, no tokenizer");
......@@ -102,12 +106,14 @@ impl Backend {
let decoder = Decoder::new(
tokenizer.decode_stream(prompt_token_ids, skip_special_tokens),
stop_conditions,
include_stop_str_in_output,
);
Ok(DecoderUnfoldState {
stream,
decoder,
validate_engine_decode: self.validate_engine_decode,
finished: false,
})
}
}
......@@ -133,6 +139,12 @@ impl
// TODO: Consider updating default to true to match behavior of other frameworks
let skip_special_tokens = request.output_options.skip_special_tokens.unwrap_or(false);
// Extract include_stop_str_in_output from sampling_options (defaults to false)
let include_stop_str_in_output = request
.sampling_options
.include_stop_str_in_output
.unwrap_or(false);
let next_stream = next.generate(request).await?;
let context = next_stream.context();
......@@ -141,9 +153,15 @@ impl
&prompt_token_ids,
stop_conditions,
skip_special_tokens,
include_stop_str_in_output,
)?;
let processed_stream = stream::unfold(state, |mut state| async move {
// If we've already detected a local stop condition, end the stream
if state.finished {
return None;
}
match state.stream.next().await {
Some(output) => {
// move to state.process_output
......@@ -169,21 +187,38 @@ impl
// NOTE: the `finish_reason` is computed from the generated `token_ids` alone.
// The `data` field can have a `finish_reason` set, coming from the underlying
// LLM inference `Engine`, and empty `token_ids`. See comment below for more details.
let finish_reason = match &result.stop_trigger {
Some(StopTrigger::MaxTokensLimit) => Some(FinishReason::Length),
Some(StopTrigger::HiddenStopTokenDetected(_)) => Some(FinishReason::Stop),
Some(StopTrigger::HiddenStopSequenceDetected(_)) => {
Some(FinishReason::Stop)
//
// stop_reason is only set for user-provided stop sequences, not for system
// EOS tokens (HiddenStopTokenDetected). This matches OpenAI API behavior where
// stop_reason is only present when a user-specified stop sequence is matched.
let (finish_reason, stop_reason) = match &result.stop_trigger {
Some(StopTrigger::MaxTokensLimit) => (Some(FinishReason::Length), None),
Some(StopTrigger::HiddenStopTokenDetected(_)) => {
// System EOS token - no stop_reason (user didn't request this stop)
(Some(FinishReason::Stop), None)
}
Some(StopTrigger::HiddenStopSequenceDetected(seq)) => {
// User-provided stop sequence (hidden from output)
(
Some(FinishReason::Stop),
Some(StopReason::String(seq.clone())),
)
}
Some(StopTrigger::VisibleStopSequenceDetected(seq)) => {
// User-provided stop sequence (included in output)
(
Some(FinishReason::Stop),
Some(StopReason::String(seq.clone())),
)
}
None => None,
None => (None, None),
};
if data.finish_reason.is_none() && finish_reason.is_some() {
tracing::debug!(
?result.stop_trigger,
"upstream did not provide a finish reason; issuing a stop_generation request to free resources",
);
// If we detected a local stop condition, mark stream as finished
// so we stop iterating (upstream may keep generating, but we ignore it)
if finish_reason.is_some() && data.finish_reason.is_none() {
state.stream.context().stop_generating();
state.finished = true;
}
let text = result.text;
......@@ -215,6 +250,7 @@ impl
// which we don't want to propagate to `data.finish_reason`.
if finish_reason.is_some() {
data.finish_reason = finish_reason;
data.stop_reason = stop_reason;
}
data.text = text;
data.tokens = Some(tokens);
......@@ -240,6 +276,7 @@ impl
log_probs: data.log_probs,
top_logprobs: data.top_logprobs,
finish_reason: data.finish_reason,
stop_reason: data.stop_reason,
//mdcsum: mdcsum.clone(),
index: data.index,
completion_usage: data.completion_usage,
......@@ -282,10 +319,6 @@ impl
}
}
// todo - add visible stop conditions
// visible_stop_ids: HashSet<TokenIdType>,
// visible_stop_sequences: Vec<String>,
/// The [`Decoder`] object could be a member of either the internal LLM engine or part of the
/// postprocessor. If in the postprocessor, should be minimally in the same process or at very minimum
/// on the same physical machine connected by an IPC.
......@@ -301,9 +334,13 @@ pub struct Decoder {
hidden_stop_ids: HashSet<TokenIdType>,
// text sequences that if found in the response will trigger a stop condition after the
// minimum number of tokens have been generated
// minimum number of tokens have been generated (excluded from output)
hidden_stop_sequences: Vec<String>,
// text sequences that if found in the response will trigger a stop condition after the
// minimum number of tokens have been generated (included in output)
visible_stop_sequences: Vec<String>,
// number of generated tokens
generated_tokens: u32,
......@@ -315,8 +352,6 @@ pub struct Decoder {
// the number of bytes currently jailed
jailed_bytes: usize,
// mdcsum
//mdcsum: String,
}
#[allow(dead_code)]
......@@ -325,16 +360,7 @@ pub enum StopTrigger {
MaxTokensLimit,
HiddenStopTokenDetected(TokenIdType),
HiddenStopSequenceDetected(String),
}
impl StopTrigger {
pub fn should_hide_text(&self) -> bool {
match self {
StopTrigger::MaxTokensLimit => false,
StopTrigger::HiddenStopTokenDetected(_) => true,
StopTrigger::HiddenStopSequenceDetected(_) => true,
}
}
VisibleStopSequenceDetected(String),
}
pub struct StepResult {
......@@ -370,7 +396,7 @@ impl Decoder {
pub fn new(
decode_stream: DecodeStream,
stop_condition: StopConditions,
//mdcsum: String,
include_stop_str_in_output: bool,
) -> Self {
let hidden_stop_ids: HashSet<TokenIdType> = stop_condition
.stop_token_ids_hidden
......@@ -379,15 +405,19 @@ impl Decoder {
.copied()
.collect();
let hidden_stop_sequences: Vec<String> = stop_condition
.stop
.unwrap_or_default()
.iter()
.map(|x| x.to_string())
.collect();
// Categorize stop sequences based on include_stop_str_in_output:
// - When true: user-provided stop sequences go to visible (included in output)
// - When false: user-provided stop sequences go to hidden (excluded from output)
let (hidden_stop_sequences, visible_stop_sequences) = if include_stop_str_in_output {
(Vec::new(), stop_condition.stop.unwrap_or_default())
} else {
(stop_condition.stop.unwrap_or_default(), Vec::new())
};
// Calculate jail_max_bytes considering both hidden and visible stop sequences
let jail_max_bytes = hidden_stop_sequences
.iter()
.chain(visible_stop_sequences.iter())
.map(|x| x.len())
.max()
.unwrap_or(0);
......@@ -396,8 +426,7 @@ impl Decoder {
decode_stream,
hidden_stop_ids,
hidden_stop_sequences,
//visible_stop_ids: HashSet::new(),
//visible_stop_sequences: Vec::new(),
visible_stop_sequences,
min_tokens: stop_condition.min_tokens.unwrap_or(0),
generated_tokens: 0,
jail: String::new(),
......@@ -444,12 +473,13 @@ impl Decoder {
log::debug!("post_append: {}", self.jail.len());
log::debug!("jail: {}", self.jail);
// Check hidden stop sequences first (excluded from output)
for seq in &self.hidden_stop_sequences {
log::debug!("stop seq: {}", seq);
if let Some(offset) = galil_seiferas::gs_find(self.jail.as_bytes(), seq.as_bytes())
{
log::debug!("offset: {}", offset);
// return only new bytes after pre_append .. offset+seq.len()
// return only new bytes after pre_append .. offset (excluding stop sequence)
// example: seq = "ox", token = "boxes", return "b"
// note: this changes when we start jailing tokens for partial matches
// on the suffix of the jail with prefixes of the stop sequences
......@@ -468,6 +498,26 @@ impl Decoder {
}
}
// Check visible stop sequences (included in output)
for seq in &self.visible_stop_sequences {
if let Some(offset) = galil_seiferas::gs_find(self.jail.as_bytes(), seq.as_bytes())
{
// For visible stop sequences, include the stop string in the output
// Return all text from pre_append up to and including the stop sequence
let stop_end = offset + seq.len();
let token_with_stop = if stop_end > pre_append {
self.jail[pre_append..stop_end].to_string()
} else {
// Stop sequence was entirely in previously returned text
"".to_string()
};
return Ok(StepResult::with_stop_trigger(
Some(token_with_stop),
StopTrigger::VisibleStopSequenceDetected(seq.to_string()),
));
}
}
Self::maybe_drain_to_max_bytes(&mut self.jail, self.jail_max_bytes);
}
......@@ -484,12 +534,8 @@ impl Decoder {
stop_trigger,
} = self.step(*token_id)?;
let hide_text = stop_trigger
.as_ref()
.map(|x| x.should_hide_text())
.unwrap_or(false);
if !hide_text && let Some(token) = &token {
// Always include token text (for visible stops, the stop string is already in the token)
if let Some(token) = &token {
text.get_or_insert_with(|| String::with_capacity(token_ids.len()))
.push_str(token);
}
......@@ -511,24 +557,6 @@ impl Decoder {
})
}
fn return_token(&self, token: Option<String>) -> StepResult {
StepResult {
token,
stop_trigger: None,
}
}
fn return_with_stop_trigger(
&self,
token: Option<String>,
stop_trigger: StopTrigger,
) -> StepResult {
StepResult {
token,
stop_trigger: Some(stop_trigger),
}
}
fn jailed_string(&self) -> Option<String> {
if self.jailed_bytes > 0 {
// get the last jailed_bytes from the jail
......
......@@ -159,12 +159,12 @@ impl
for c in prompt.chars() {
// we are returning characters not tokens, so there will be some postprocessing overhead
tokio::time::sleep(*TOKEN_ECHO_DELAY).await;
let response = deltas.create_choice(0, Some(c.to_string()), None, None);
let response = deltas.create_choice(0, Some(c.to_string()), None, None, None);
yield Annotated{ id: Some(id.to_string()), data: Some(response), event: None, comment: None };
id += 1;
}
let response = deltas.create_choice(0, None, Some(dynamo_async_openai::types::FinishReason::Stop), None);
let response = deltas.create_choice(0, None, Some(dynamo_async_openai::types::FinishReason::Stop), None, None);
yield Annotated { id: Some(id.to_string()), data: Some(response), event: None, comment: None };
};
......
......@@ -216,6 +216,7 @@ mod tests {
log_probs: None,
top_logprobs: None,
finish_reason: None,
stop_reason: None,
index: None,
disaggregated_params: None,
completion_usage: None,
......
......@@ -309,6 +309,7 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<LLMEngineOutput>, Error>
log_probs: None,
top_logprobs: None,
finish_reason: None,
stop_reason: None,
index: None,
// Add dummy disaggregated_params for prefill workers
disaggregated_params: if is_prefill {
......
......@@ -961,6 +961,7 @@ mod tests {
reasoning_content: None,
},
finish_reason: Some(FinishReason::Stop),
stop_reason: None,
logprobs: Some(ChatChoiceLogprobs {
content: Some(token_logprobs),
refusal: None,
......@@ -994,6 +995,7 @@ mod tests {
reasoning_content: None,
},
finish_reason: Some(FinishReason::Stop),
stop_reason: None,
logprobs: Some(ChatChoiceLogprobs {
content: Some(token_logprobs),
refusal: None,
......@@ -1343,6 +1345,7 @@ mod tests {
reasoning_content: None,
},
finish_reason: Some(FinishReason::Stop),
stop_reason: None,
logprobs: None, // No logprobs
}],
created: 1234567890,
......
......@@ -7,6 +7,7 @@ pub use super::FinishReason;
pub use super::preprocessor::PreprocessedRequest;
use crate::protocols::TokenIdType;
use dynamo_async_openai::types::CompletionUsage;
use dynamo_async_openai::types::StopReason;
use dynamo_runtime::protocols::maybe_error::MaybeError;
pub type TokenType = Option<String>;
......@@ -44,6 +45,12 @@ pub struct BackendOutput {
// TODO: Enrich this with more information as can apply our first-level postprocessing
// logic and return more detailed information
pub finish_reason: Option<FinishReason>,
/// The stop string or token that triggered the stop condition.
/// This is set when finish_reason is Stop and identifies what triggered it.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub stop_reason: Option<StopReason>,
// Model Deployment Card checksum
//pub mdcsum: String,
......@@ -91,6 +98,11 @@ pub struct LLMEngineOutput {
// logic and return more detailed information
pub finish_reason: Option<FinishReason>,
/// The stop string or token that triggered the stop condition.
/// This is set when finish_reason is Stop and identifies what triggered it.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub stop_reason: Option<StopReason>,
// Index field for batch requests to match OpenAI format
pub index: Option<u32>,
......@@ -117,6 +129,7 @@ impl LLMEngineOutput {
log_probs: None,
top_logprobs: None,
finish_reason: Some(FinishReason::Cancelled),
stop_reason: None,
index: None,
disaggregated_params: None,
extra_args: None,
......@@ -132,6 +145,7 @@ impl LLMEngineOutput {
cum_log_probs: None,
log_probs: None,
finish_reason: Some(FinishReason::Stop),
stop_reason: None,
top_logprobs: None,
index: None,
disaggregated_params: None,
......@@ -149,6 +163,7 @@ impl LLMEngineOutput {
log_probs: None,
top_logprobs: None,
finish_reason: Some(FinishReason::Length),
stop_reason: None,
index: None,
disaggregated_params: None,
extra_args: None,
......@@ -165,6 +180,7 @@ impl LLMEngineOutput {
log_probs: None,
top_logprobs: None,
finish_reason: Some(FinishReason::Error(err_msg)),
stop_reason: None,
index: None,
disaggregated_params: None,
extra_args: None,
......
......@@ -12,6 +12,7 @@ use crate::protocols::{
openai::ParsingOptions,
};
use dynamo_async_openai::types::StopReason;
use dynamo_runtime::engine::DataStream;
/// Aggregates a stream of [`NvCreateChatCompletionStreamResponse`]s into a single
......@@ -49,6 +50,8 @@ struct DeltaChoice {
role: Option<dynamo_async_openai::types::Role>,
/// The reason the completion was finished (if applicable).
finish_reason: Option<dynamo_async_openai::types::FinishReason>,
/// The stop string or token that triggered the stop condition.
stop_reason: Option<StopReason>,
/// Optional log probabilities for the chat choice.
logprobs: Option<dynamo_async_openai::types::ChatChoiceLogprobs>,
// Optional tool calls for the chat choice.
......@@ -159,6 +162,7 @@ impl DeltaAggregator {
text: "".to_string(),
role: choice.delta.role,
finish_reason: None,
stop_reason: None,
logprobs: None,
tool_calls: None,
reasoning_content: None,
......@@ -204,6 +208,11 @@ impl DeltaAggregator {
state_choice.finish_reason = Some(finish_reason);
}
// Update stop reason if provided.
if let Some(stop_reason) = choice.stop_reason {
state_choice.stop_reason = Some(stop_reason);
}
// Update logprobs
if let Some(logprobs) = &choice.logprobs {
let state_lps = state_choice.logprobs.get_or_insert(
......@@ -296,6 +305,7 @@ impl From<DeltaChoice> for dynamo_async_openai::types::ChatChoice {
},
index: delta.index,
finish_reason,
stop_reason: delta.stop_reason,
logprobs: delta.logprobs,
}
}
......@@ -408,6 +418,7 @@ mod tests {
index,
delta,
finish_reason,
stop_reason: None,
logprobs,
};
......@@ -626,6 +637,7 @@ mod tests {
reasoning_content: None,
},
finish_reason: Some(dynamo_async_openai::types::FinishReason::Stop),
stop_reason: None,
logprobs: None,
},
dynamo_async_openai::types::ChatChoiceStream {
......@@ -639,6 +651,7 @@ mod tests {
reasoning_content: None,
},
finish_reason: Some(dynamo_async_openai::types::FinishReason::Stop),
stop_reason: None,
logprobs: None,
},
],
......
......@@ -234,6 +234,7 @@ impl DeltaGenerator {
/// * `text` - The text content for the response.
/// * `finish_reason` - The reason why the response finished (e.g., stop, length, etc.).
/// * `logprobs` - Optional log probabilities of the generated tokens.
/// * `stop_reason` - Optional stop string or token that triggered the stop.
///
/// # Returns
/// * An [`dynamo_async_openai::types::CreateChatCompletionStreamResponse`] instance representing the choice.
......@@ -244,6 +245,7 @@ impl DeltaGenerator {
text: Option<String>,
finish_reason: Option<dynamo_async_openai::types::FinishReason>,
logprobs: Option<dynamo_async_openai::types::ChatChoiceLogprobs>,
stop_reason: Option<dynamo_async_openai::types::StopReason>,
) -> NvCreateChatCompletionStreamResponse {
let delta = dynamo_async_openai::types::ChatCompletionStreamResponseDelta {
content: text,
......@@ -262,6 +264,7 @@ impl DeltaGenerator {
index,
delta,
finish_reason,
stop_reason,
logprobs,
};
......@@ -384,7 +387,13 @@ impl crate::protocols::openai::DeltaGeneratorExt<NvCreateChatCompletionStreamRes
// Create the streaming response.
let index = 0;
let mut stream_response = self.create_choice(index, delta.text, finish_reason, logprobs);
let mut stream_response = self.create_choice(
index,
delta.text,
finish_reason,
logprobs,
delta.stop_reason,
);
// Record first token time (only succeeds on first call due to OnceLock)
if let Some(ref tracker) = self.timing_tracker {
......
......@@ -104,6 +104,7 @@ fn create_choice_stream(
content: &str,
tool_calls: Option<Vec<ChatCompletionMessageToolCallChunk>>,
finish_reason: Option<FinishReason>,
stop_reason: Option<dynamo_async_openai::types::StopReason>,
logprobs: Option<ChatChoiceLogprobs>,
) -> ChatChoiceStream {
#[allow(deprecated)]
......@@ -118,6 +119,7 @@ fn create_choice_stream(
reasoning_content: None,
},
finish_reason,
stop_reason,
logprobs,
}
}
......@@ -178,6 +180,7 @@ impl ChoiceJailState {
&prefix,
None,
choice.finish_reason,
None,
choice.logprobs.clone(),
);
emissions.push(ChoiceEmission::PassThrough(prefix_choice));
......@@ -226,6 +229,7 @@ impl ChoiceJailState {
trailing_part,
None,
choice.finish_reason,
None,
choice.logprobs.clone(),
);
emissions.push(ChoiceEmission::Trailing(trailing_choice));
......@@ -258,6 +262,7 @@ impl ChoiceJailState {
&prefix,
None,
choice.finish_reason,
None,
choice.logprobs.clone(),
);
emissions.push(ChoiceEmission::PassThrough(prefix_choice));
......@@ -301,6 +306,7 @@ impl ChoiceJailState {
&content,
None,
choice.finish_reason,
None,
choice.logprobs.clone(),
);
emissions.push(ChoiceEmission::PassThrough(pass_through_choice));
......@@ -354,6 +360,7 @@ impl ChoiceJailState {
trailing_part,
None,
choice.finish_reason,
None,
choice.logprobs.clone(),
);
emissions.push(ChoiceEmission::Trailing(trailing_choice));
......@@ -384,6 +391,7 @@ impl ChoiceJailState {
None,
self.stream_finish_reason, // For the accumulated content, assign the original stream finish reason, otherwise it will get lost
None,
None,
);
let final_choice = jail_stream
......@@ -557,6 +565,7 @@ impl JailedStream {
index: choice.index,
delta: choice.delta.clone(),
finish_reason: choice.finish_reason,
stop_reason: choice.stop_reason.clone(),
logprobs: choice.logprobs.clone(),
};
all_emissions.push(ChoiceEmission::PassThrough(pass_through_choice));
......@@ -834,6 +843,7 @@ impl JailedStream {
Some(tool_call_chunks),
None,
None,
None,
);
return choice;
}
......@@ -845,6 +855,7 @@ impl JailedStream {
accumulated_content,
None,
base_choice.finish_reason,
base_choice.stop_reason.clone(),
base_choice.logprobs.clone(),
)
}
......@@ -857,6 +868,7 @@ impl JailedStream {
"",
Some(tool_call_chunks),
base_choice.finish_reason,
None,
base_choice.logprobs.clone(),
),
Ok(_) | Err(_) => {
......@@ -867,6 +879,7 @@ impl JailedStream {
accumulated_content,
None,
base_choice.finish_reason,
base_choice.stop_reason.clone(),
base_choice.logprobs.clone(),
)
}
......
......@@ -368,6 +368,7 @@ mod tests {
reasoning_content: None,
},
finish_reason: None,
stop_reason: None,
logprobs: None,
}],
created: now,
......
......@@ -93,7 +93,7 @@ impl
let stream = stream! {
tokio::time::sleep(std::time::Duration::from_millis(max_tokens)).await;
for i in 0..10 {
let output = generator.create_choice(i,Some(format!("choice {i}")), None, None);
let output = generator.create_choice(i, Some(format!("choice {i}")), None, None, None);
yield Annotated::from_data(output);
}
......
......@@ -52,7 +52,7 @@ impl
// Generate 5 response chunks
for i in 0..5 {
let output = generator.create_choice(i, Some(format!("Mock response {i}")), None, None);
let output = generator.create_choice(i, Some(format!("Mock response {i}")), None, None, None);
yield Annotated::from_data(output);
}
};
......
......@@ -389,6 +389,7 @@ fn create_response_with_linear_probs(
reasoning_content: None,
},
finish_reason: Some(FinishReason::Stop),
stop_reason: None,
logprobs: Some(ChatChoiceLogprobs {
content: Some(token_logprobs),
refusal: None,
......@@ -468,6 +469,7 @@ fn create_multi_choice_response(
reasoning_content: None,
},
finish_reason: Some(FinishReason::Stop),
stop_reason: None,
logprobs: Some(ChatChoiceLogprobs {
content: Some(token_logprobs),
refusal: None,
......
......@@ -34,6 +34,7 @@ mod tests {
reasoning_content: None,
},
finish_reason: None,
stop_reason: None,
logprobs: None,
};
......@@ -73,6 +74,7 @@ mod tests {
reasoning_content: None,
},
finish_reason: Some(FinishReason::Stop),
stop_reason: None,
logprobs: None,
};
......@@ -116,6 +118,7 @@ mod tests {
reasoning_content: None,
},
finish_reason: None,
stop_reason: None,
logprobs: None,
};
......@@ -158,6 +161,7 @@ mod tests {
reasoning_content: None,
},
finish_reason: None,
stop_reason: None,
logprobs: None,
}
})
......@@ -202,6 +206,7 @@ mod tests {
reasoning_content: None,
},
finish_reason: Some(FinishReason::Stop),
stop_reason: None,
logprobs: None,
}
})
......@@ -2323,6 +2328,7 @@ mod parallel_jail_tests {
reasoning_content: None,
},
finish_reason: None,
stop_reason: None,
logprobs: None,
}
})
......
......@@ -24,6 +24,7 @@ fn create_mock_response_chunk(
reasoning_content,
},
finish_reason: None,
stop_reason: None,
logprobs: None,
};
......
......@@ -106,6 +106,7 @@ fn build_backend_outputs_with_cached_tokens(cached_tokens: Option<u32>) -> Vec<B
log_probs: None,
top_logprobs: None,
finish_reason: None,
stop_reason: None,
index: Some(0),
completion_usage: None,
disaggregated_params: None,
......@@ -118,6 +119,7 @@ fn build_backend_outputs_with_cached_tokens(cached_tokens: Option<u32>) -> Vec<B
log_probs: None,
top_logprobs: None,
finish_reason: None,
stop_reason: None,
index: Some(0),
completion_usage: None,
disaggregated_params: None,
......@@ -130,6 +132,7 @@ fn build_backend_outputs_with_cached_tokens(cached_tokens: Option<u32>) -> Vec<B
log_probs: None,
top_logprobs: None,
finish_reason: Some(FinishReason::Stop),
stop_reason: None,
index: Some(0),
completion_usage: cached_tokens.map(|ct| AoaiCompletionUsage {
prompt_tokens: 0,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment