Unverified commit f476fd74, authored by Greg Clark, committed by GitHub
Browse files

feat: logprob handling (#2426)


Signed-off-by: Greg Clark <grclark@nvidia.com>
parent 5816c082
......@@ -268,6 +268,7 @@ fn run_request(
//text: if output.text.is_empty() { None } else { Some(output.text) },
cum_log_probs: None, // TODO output.cumulative_logprob.map(|v| v as f64),
log_probs: None, // TODO output.logprobs
top_logprobs: None,
finish_reason: None,
index: None,
};
......
......@@ -590,7 +590,7 @@ impl
None => None,
};
#[allow(deprecated)]
let inner = response_generator.create_choice(0, Some(from_assistant), None);
let inner = response_generator.create_choice(0, Some(from_assistant), None, None);
let ann = Annotated{
id: None,
data: Some(inner),
......
......@@ -231,6 +231,7 @@ impl
text: data.text,
cum_log_probs: data.cum_log_probs,
log_probs: data.log_probs,
top_logprobs: data.top_logprobs,
finish_reason: data.finish_reason,
//mdcsum: mdcsum.clone(),
index: data.index,
......
......@@ -102,6 +102,7 @@ fn delta_core(tok: u32) -> Annotated<LLMEngineOutput> {
text: None,
cum_log_probs: None,
log_probs: None,
top_logprobs: None,
finish_reason: None,
index: None,
};
......@@ -242,11 +243,11 @@ impl
let mut id = 1;
for c in chars_string.chars() {
tokio::time::sleep(*TOKEN_ECHO_DELAY).await;
let response = deltas.create_choice(0, Some(c.to_string()), None);
let response = deltas.create_choice(0, Some(c.to_string()), None, None);
yield Annotated{ id: Some(id.to_string()), data: Some(response), event: None, comment: None };
id += 1;
}
let response = deltas.create_choice(0, None, Some(async_openai::types::CompletionFinishReason::Stop));
let response = deltas.create_choice(0, None, Some(async_openai::types::CompletionFinishReason::Stop), None);
yield Annotated { id: Some(id.to_string()), data: Some(response), event: None, comment: None };
};
......
......@@ -166,7 +166,7 @@ impl RetryManager {
#[cfg(test)]
mod tests {
use super::*;
use crate::protocols::common::{SamplingOptions, StopConditions};
use crate::protocols::common::{OutputOptions, SamplingOptions, StopConditions};
use dynamo_runtime::pipeline::context::Controller;
use dynamo_runtime::pipeline::AsyncEngine;
use std::sync::atomic::{AtomicU32, Ordering};
......@@ -183,6 +183,7 @@ mod tests {
..Default::default()
},
sampling_options: SamplingOptions::default(),
output_options: OutputOptions::default(),
eos_token_ids: vec![],
mdc_sum: None,
annotations: vec![],
......@@ -198,6 +199,7 @@ mod tests {
text: Some(format!("token_{}", token_id)),
cum_log_probs: None,
log_probs: None,
top_logprobs: None,
finish_reason: None,
index: None,
})
......
......@@ -405,6 +405,7 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<LLMEngineOutput>, Error>
text: None,
cum_log_probs: None,
log_probs: None,
top_logprobs: None,
finish_reason: None,
index: None,
};
......@@ -525,7 +526,7 @@ mod integration_tests {
use super::*;
use crate::kv_router::indexer::RouterEvent;
use crate::kv_router::KV_EVENT_SUBJECT;
use crate::protocols::common::{SamplingOptions, StopConditions};
use crate::protocols::common::{OutputOptions, SamplingOptions, StopConditions};
use dynamo_runtime::{
pipeline::Context,
pipeline::{network::Ingress, PushRouter},
......@@ -641,6 +642,7 @@ mod integration_tests {
..Default::default()
},
sampling_options: SamplingOptions::default(),
output_options: OutputOptions::default(),
eos_token_ids: vec![],
mdc_sum: None,
annotations: vec![format!("dp_rank:{dp_rank}")],
......
......@@ -33,7 +33,7 @@ use dynamo_runtime::pipeline::{
use dynamo_runtime::protocols::annotated::{Annotated, AnnotationsProvider};
use crate::protocols::{
common::{SamplingOptionsProvider, StopConditionsProvider},
common::{OutputOptionsProvider, SamplingOptionsProvider, StopConditionsProvider},
openai::{
chat_completions::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse},
completions::{NvCreateCompletionRequest, NvCreateCompletionResponse},
......@@ -146,6 +146,7 @@ impl OpenAIPreprocessor {
+ AnnotationsProvider
+ SamplingOptionsProvider
+ StopConditionsProvider
+ OutputOptionsProvider
+ NvExtProvider,
>(
&self,
......@@ -249,6 +250,7 @@ impl OpenAIPreprocessor {
builder.stop_conditions(stop_conditions);
builder.sampling_options(request.extract_sampling_options()?);
builder.output_options(request.extract_output_options()?);
builder.annotations(request.annotations().unwrap_or_default());
builder.mdc_sum(Some(self.mdcsum.clone()));
builder.estimated_prefix_hit_num_blocks(None);
......
......@@ -45,6 +45,10 @@ pub trait StopConditionsProvider {
fn extract_stop_conditions(&self) -> Result<StopConditions>;
}
/// Implemented by request types that can describe how the engine output
/// should be shaped (see [`OutputOptions`]).
pub trait OutputOptionsProvider {
/// Builds the [`OutputOptions`] for this request.
///
/// # Errors
/// Returns an error if the options cannot be derived from the request.
fn extract_output_options(&self) -> Result<OutputOptions>;
}
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
pub enum FinishReason {
#[serde(rename = "eos")]
......@@ -179,6 +183,9 @@ pub struct CompletionRequest {
/// are needed.
pub sampling_options: SamplingOptions,
#[builder(default)]
pub output_options: OutputOptions,
/// The computed checksum of the Model Deployment Card (MDC).
#[builder(default)]
pub mdc_sum: Option<String>,
......
......@@ -23,6 +23,15 @@ use dynamo_runtime::protocols::maybe_error::MaybeError;
pub type TokenType = Option<String>;
pub type LogProbs = Vec<f64>;
/// One candidate token reported by the engine for a single output position,
/// together with its log probability.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct TopLogprob {
/// NOTE(review): presumably the candidate's position in the engine's
/// ranking for this output position — confirm whether it is 0- or 1-based.
pub rank: u32,
/// Id of the candidate token.
pub token_id: TokenIdType,
/// Text of the candidate token; `None` when no text is available.
pub token: TokenType,
/// Log probability of the candidate token.
pub logprob: f64,
}
/// Candidate lists for a run of tokens: the outer `Vec` has one entry per
/// token, the inner `Vec` holds the [`TopLogprob`] candidates for that token.
pub type TopLogprobs = Vec<Vec<TopLogprob>>; // num_tokens x top_logprobs
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct BackendOutput {
/// New token_ids generated from the LLM Engine
......@@ -41,6 +50,8 @@ pub struct BackendOutput {
/// Optional log probabilities
pub log_probs: Option<LogProbs>,
pub top_logprobs: Option<TopLogprobs>,
// TODO: Enrich this with more information as can apply our first-level postprocessing
// logic and return more detailed information
pub finish_reason: Option<FinishReason>,
......@@ -77,6 +88,8 @@ pub struct LLMEngineOutput {
/// Optional log probabilities
pub log_probs: Option<LogProbs>,
pub top_logprobs: Option<TopLogprobs>,
// TODO: Enrich this with more information as can apply our first-level postprocessing
// logic and return more detailed information
pub finish_reason: Option<FinishReason>,
......@@ -93,6 +106,7 @@ impl LLMEngineOutput {
text: None,
cum_log_probs: None,
log_probs: None,
top_logprobs: None,
finish_reason: Some(FinishReason::Cancelled),
index: None,
}
......@@ -106,6 +120,7 @@ impl LLMEngineOutput {
cum_log_probs: None,
log_probs: None,
finish_reason: Some(FinishReason::Stop),
top_logprobs: None,
index: None,
}
}
......@@ -117,6 +132,7 @@ impl LLMEngineOutput {
text: None,
cum_log_probs: None,
log_probs: None,
top_logprobs: None,
finish_reason: Some(FinishReason::Length),
index: None,
}
......@@ -129,6 +145,7 @@ impl LLMEngineOutput {
text: None,
cum_log_probs: None,
log_probs: None,
top_logprobs: None,
finish_reason: Some(FinishReason::Error(err_msg)),
index: None,
}
......
......@@ -4,7 +4,7 @@
use derive_builder::Builder;
use serde::{Deserialize, Serialize};
use super::{SamplingOptions, StopConditions};
use super::{OutputOptions, SamplingOptions, StopConditions};
use crate::protocols::TokenIdType;
/// [`PreprocessedRequest`] is the internal representation of an LLM request. The [`dynamo.llm-preprocessor`]
......@@ -29,6 +29,10 @@ pub struct PreprocessedRequest {
/// are needed.
pub sampling_options: SamplingOptions,
/// OutputOptions are options that control the output of the inference engine such as whether
/// to return log probabilities, or whether to skip special tokens in output.
pub output_options: OutputOptions,
/// The EOS token ID(s) for the Model
/// Not every backend needs this, but those that do can find it here.
/// TODO - refactor this to a better location
......
......@@ -17,7 +17,7 @@ use anyhow::Result;
use serde::{Deserialize, Serialize};
use super::{
common::{self, SamplingOptionsProvider, StopConditionsProvider},
common::{self, OutputOptionsProvider, SamplingOptionsProvider, StopConditionsProvider},
ContentProvider,
};
use crate::protocols::openai::common_ext::CommonExtProvider;
......@@ -79,6 +79,16 @@ trait OpenAIStopConditionsProvider {
}
}
/// Accessors an OpenAI-style request implements so that its output-related
/// settings can be collected into `common::OutputOptions`.
///
/// Each getter returns `None` when the request does not specify the setting.
trait OpenAIOutputOptionsProvider {
/// Value for the `logprobs` field of `common::OutputOptions`.
fn get_logprobs(&self) -> Option<u32>;

/// Value for the `prompt_logprobs` field of `common::OutputOptions`.
fn get_prompt_logprobs(&self) -> Option<u32>;

/// Value for the `skip_special_tokens` field of `common::OutputOptions`.
fn get_skip_special_tokens(&self) -> Option<bool>;

/// Value for the `formatted_prompt` field of `common::OutputOptions`.
fn get_formatted_prompt(&self) -> Option<bool>;
}
impl<T: OpenAISamplingOptionsProvider + CommonExtProvider> SamplingOptionsProvider for T {
fn extract_sampling_options(&self) -> Result<common::SamplingOptions> {
// let result = self.validate();
......@@ -168,6 +178,22 @@ impl<T: OpenAIStopConditionsProvider> StopConditionsProvider for T {
}
}
/// Every OpenAI-style request that exposes the individual output getters
/// automatically provides a complete `common::OutputOptions`.
impl<T: OpenAIOutputOptionsProvider> OutputOptionsProvider for T {
    /// Gathers the four output settings from the request; never fails.
    fn extract_output_options(&self) -> Result<common::OutputOptions> {
        // Fields are initialised in declaration order, so the getters run in
        // the same sequence as they are listed here.
        let options = common::OutputOptions {
            logprobs: self.get_logprobs(),
            prompt_logprobs: self.get_prompt_logprobs(),
            skip_special_tokens: self.get_skip_special_tokens(),
            formatted_prompt: self.get_formatted_prompt(),
        };
        Ok(options)
    }
}
pub trait DeltaGeneratorExt<ResponseType: Send + Sync + 'static + std::fmt::Debug>:
Send + Sync + 'static
{
......
......@@ -23,7 +23,8 @@ use super::{
common_ext::{CommonExt, CommonExtProvider},
nvext::NvExt,
nvext::NvExtProvider,
validate, OpenAISamplingOptionsProvider, OpenAIStopConditionsProvider,
validate, OpenAIOutputOptionsProvider, OpenAISamplingOptionsProvider,
OpenAIStopConditionsProvider,
};
mod aggregator;
......@@ -232,6 +233,31 @@ impl OpenAIStopConditionsProvider for NvCreateChatCompletionRequest {
}
}
/// Output settings carried by a chat completion request.
impl OpenAIOutputOptionsProvider for NvCreateChatCompletionRequest {
    /// Number of log probabilities to request per token.
    ///
    /// Only set when the request has `logprobs: true`; in that case it is the
    /// request's `top_logprobs`, or `1` when `top_logprobs` is absent.
    fn get_logprobs(&self) -> Option<u32> {
        if self.inner.logprobs != Some(true) {
            return None;
        }
        let count = self
            .inner
            .top_logprobs
            .map_or(1_u32, |requested| requested as u32);
        Some(count)
    }

    /// Chat requests do not set prompt log probabilities.
    fn get_prompt_logprobs(&self) -> Option<u32> {
        None
    }

    /// Chat requests do not set `skip_special_tokens`.
    fn get_skip_special_tokens(&self) -> Option<bool> {
        None
    }

    /// Chat requests do not set `formatted_prompt`.
    fn get_formatted_prompt(&self) -> Option<bool> {
        None
    }
}
/// Implements `ValidateRequest` for `NvCreateChatCompletionRequest`,
/// allowing us to validate the data.
impl ValidateRequest for NvCreateChatCompletionRequest {
......
......@@ -14,7 +14,10 @@
// limitations under the License.
use super::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse};
use crate::protocols::common;
use crate::{
protocols::common::{self},
types::TokenIdType,
};
/// Provides a method for generating a [`DeltaGenerator`] from a chat completion request.
impl NvCreateChatCompletionRequest {
......@@ -25,7 +28,8 @@ impl NvCreateChatCompletionRequest {
pub fn response_generator(&self) -> DeltaGenerator {
let options = DeltaGeneratorOptions {
enable_usage: true,
enable_logprobs: self.inner.logprobs.unwrap_or(false),
enable_logprobs: self.inner.logprobs.unwrap_or(false)
|| self.inner.top_logprobs.unwrap_or(0) > 0,
};
DeltaGenerator::new(self.inner.model.clone(), options)
......@@ -112,6 +116,71 @@ impl DeltaGenerator {
self.usage.prompt_tokens = isl;
}
/// Converts the backend's per-token log probabilities into the OpenAI chat
/// `logprobs` payload.
///
/// # Arguments
/// * `tokens` - Text of each token in this delta; a missing text becomes `""`.
/// * `token_ids` - Ids of the tokens, paired with `tokens` by position.
/// * `logprobs` - Log probability of each selected token, if reported.
/// * `top_logprobs` - Per-token candidate lists, if reported.
///
/// # Returns
/// `None` when log probabilities are disabled for this generator or the
/// backend reported none. Otherwise a payload with one `content` entry per
/// token; the number of entries is the shortest of `tokens`, `token_ids`,
/// `logprobs` and (when present) `top_logprobs`.
///
/// When `top_logprobs` is absent, each entry still carries the selected
/// token's log probability and an empty `top_logprobs` list, instead of the
/// whole `content` being dropped.
pub fn create_logprobs(
    &self,
    tokens: Vec<common::llm_backend::TokenType>,
    token_ids: Vec<TokenIdType>,
    logprobs: Option<common::llm_backend::LogProbs>,
    top_logprobs: Option<common::llm_backend::TopLogprobs>,
) -> Option<async_openai::types::ChatChoiceLogprobs> {
    if !self.options.enable_logprobs {
        return None;
    }
    let logprobs = logprobs?;

    let toks: Vec<(String, TokenIdType)> = tokens
        .into_iter()
        .zip(token_ids)
        .map(|(token, token_id)| (token.unwrap_or_default(), token_id))
        .collect();
    // Zipping against `toks` caps the log probabilities at the token count.
    let tok_lps: Vec<f32> = toks
        .iter()
        .zip(logprobs)
        .map(|(_, lp)| lp as f32)
        .collect();

    // Without reported alternatives, use one empty list per token so the
    // selected tokens' log probabilities are still returned.
    let has_top_logprobs = top_logprobs.is_some();
    let per_token_top = top_logprobs.unwrap_or_else(|| {
        std::iter::repeat_with(Vec::new)
            .take(tok_lps.len())
            .collect()
    });

    let content = toks
        .into_iter()
        .zip(tok_lps)
        .zip(per_token_top)
        .map(|(((t, tid), lp), top_lps)| {
            let mut found_selected_token = false;
            let mut converted_top_lps: Vec<async_openai::types::TopLogprobs> = top_lps
                .into_iter()
                .map(|top_lp| {
                    found_selected_token |= top_lp.token_id == tid;
                    async_openai::types::TopLogprobs {
                        token: top_lp.token.unwrap_or_default(),
                        logprob: top_lp.logprob as f32,
                        bytes: None,
                    }
                })
                .collect();
            // If the selected token is not among the reported alternatives,
            // add it. Skipped when no alternatives were reported at all.
            if has_top_logprobs && !found_selected_token {
                converted_top_lps.push(async_openai::types::TopLogprobs {
                    token: t.clone(),
                    logprob: lp,
                    bytes: None,
                });
            }
            async_openai::types::ChatCompletionTokenLogprob {
                token: t,
                logprob: lp,
                bytes: None,
                top_logprobs: converted_top_lps,
            }
        })
        .collect();

    Some(async_openai::types::ChatChoiceLogprobs {
        content: Some(content),
        refusal: None,
    })
}
/// Creates a choice within a chat completion response.
///
/// # Arguments
......@@ -203,8 +272,12 @@ impl crate::protocols::openai::DeltaGeneratorExt<NvCreateChatCompletionStreamRes
self.usage.completion_tokens += token_length;
}
// TODO: Implement log probabilities aggregation.
let logprobs = None;
let logprobs = self.create_logprobs(
delta.tokens,
delta.token_ids,
delta.log_probs,
delta.top_logprobs,
);
// Map backend finish reasons to OpenAI's finish reasons.
let finish_reason = match delta.finish_reason {
......
......@@ -21,10 +21,11 @@ use validator::Validate;
use crate::engines::ValidateRequest;
use super::{
common::{self, SamplingOptionsProvider, StopConditionsProvider},
common::{self, OutputOptionsProvider, SamplingOptionsProvider, StopConditionsProvider},
common_ext::{CommonExt, CommonExtProvider},
nvext::{NvExt, NvExtProvider},
validate, ContentProvider, OpenAISamplingOptionsProvider, OpenAIStopConditionsProvider,
validate, ContentProvider, OpenAIOutputOptionsProvider, OpenAISamplingOptionsProvider,
OpenAIStopConditionsProvider,
};
mod aggregator;
......@@ -279,6 +280,10 @@ impl TryFrom<NvCreateCompletionRequest> for common::CompletionRequest {
.extract_sampling_options()
.map_err(|e| anyhow::anyhow!("Failed to extract sampling options: {}", e))?;
let output_options = request
.extract_output_options()
.map_err(|e| anyhow::anyhow!("Failed to extract output options: {}", e))?;
let prompt = common::PromptType::Completion(common::CompletionContext {
prompt: prompt_to_string(&request.inner.prompt),
system_prompt: None,
......@@ -288,6 +293,7 @@ impl TryFrom<NvCreateCompletionRequest> for common::CompletionRequest {
prompt,
stop_conditions,
sampling_options,
output_options,
mdc_sum: None,
annotations: None,
})
......@@ -329,6 +335,26 @@ impl TryFrom<common::StreamingCompletionResponse> for async_openai::types::Choic
}
}
/// Output settings carried by a (legacy) completion request.
impl OpenAIOutputOptionsProvider for NvCreateCompletionRequest {
    /// The request's `logprobs` count, widened to `u32`; `None` when unset.
    fn get_logprobs(&self) -> Option<u32> {
        let requested = self.inner.logprobs?;
        Some(requested as u32)
    }

    /// `Some(1)` when the request sets `echo: true`, otherwise `None`.
    fn get_prompt_logprobs(&self) -> Option<u32> {
        match self.inner.echo {
            Some(true) => Some(1),
            Some(false) | None => None,
        }
    }

    /// Completion requests do not set `skip_special_tokens`.
    fn get_skip_special_tokens(&self) -> Option<bool> {
        None
    }

    /// Completion requests do not set `formatted_prompt`.
    fn get_formatted_prompt(&self) -> Option<bool> {
        None
    }
}
/// Implements `ValidateRequest` for `NvCreateCompletionRequest`,
/// allowing us to validate the data.
impl ValidateRequest for NvCreateCompletionRequest {
......
......@@ -14,7 +14,7 @@
// limitations under the License.
use super::{NvCreateCompletionRequest, NvCreateCompletionResponse};
use crate::protocols::common;
use crate::{protocols::common, types::TokenIdType};
impl NvCreateCompletionRequest {
// put this method on the request
......@@ -22,7 +22,7 @@ impl NvCreateCompletionRequest {
pub fn response_generator(&self) -> DeltaGenerator {
let options = DeltaGeneratorOptions {
enable_usage: true,
enable_logprobs: false,
enable_logprobs: self.inner.logprobs.unwrap_or(0) > 0,
};
DeltaGenerator::new(self.inner.model.clone(), options)
......@@ -82,11 +82,74 @@ impl DeltaGenerator {
self.usage.prompt_tokens = isl;
}
/// Converts the backend's per-token log probabilities into the OpenAI
/// completions `logprobs` payload.
///
/// # Arguments
/// * `tokens` - Text of each token in this delta; a missing text becomes `""`.
/// * `token_ids` - Ids of the tokens, paired with `tokens` by position.
/// * `logprobs` - Log probability of each selected token, if reported.
/// * `top_logprobs` - Per-token candidate lists, if reported.
///
/// # Returns
/// `None` when log probabilities are disabled for this generator or the
/// backend reported none. Otherwise the token texts, their log probabilities
/// and, when candidates were reported, one JSON array of candidates per token
/// (with the selected token appended if it was not among them).
/// `text_offset` is always left empty.
pub fn create_logprobs(
    &self,
    tokens: Vec<common::llm_backend::TokenType>,
    token_ids: Vec<TokenIdType>,
    logprobs: Option<common::llm_backend::LogProbs>,
    top_logprobs: Option<common::llm_backend::TopLogprobs>,
) -> Option<async_openai::types::Logprobs> {
    if !self.options.enable_logprobs {
        return None;
    }
    let logprobs = logprobs?;

    // Pair each token text with its id; unpaired leftovers are ignored.
    let selected: Vec<(String, TokenIdType)> = tokens
        .into_iter()
        .zip(token_ids)
        .map(|(text, id)| (text.unwrap_or_default(), id))
        .collect();

    // At most one log probability per paired token.
    let selected_lps: Vec<f32> = logprobs
        .into_iter()
        .take(selected.len())
        .map(|lp| lp as f32)
        .collect();

    let alternatives = match top_logprobs {
        None => Vec::new(),
        Some(per_token) => {
            let mut rendered = Vec::new();
            let positions = selected.iter().zip(&selected_lps).zip(&per_token);
            for (((text, id), lp), candidates) in positions {
                let mut entries: Vec<async_openai::types::TopLogprobs> =
                    Vec::with_capacity(candidates.len() + 1);
                let mut saw_selected = false;
                for candidate in candidates {
                    saw_selected |= candidate.token_id == *id;
                    entries.push(async_openai::types::TopLogprobs {
                        token: candidate.token.clone().unwrap_or_default(),
                        logprob: candidate.logprob as f32,
                        bytes: None,
                    });
                }
                // Make sure the selected token is always listed.
                if !saw_selected {
                    entries.push(async_openai::types::TopLogprobs {
                        token: text.clone(),
                        logprob: *lp,
                        bytes: None,
                    });
                }
                rendered.push(serde_json::to_value(entries).unwrap());
            }
            rendered
        }
    };

    let token_texts = selected.iter().map(|(text, _)| text.clone()).collect();
    let token_logprobs = selected_lps.into_iter().map(Some).collect();

    Some(async_openai::types::Logprobs {
        tokens: token_texts,
        token_logprobs,
        text_offset: vec![],
        top_logprobs: alternatives,
    })
}
pub fn create_choice(
&self,
index: u32,
text: Option<String>,
finish_reason: Option<async_openai::types::CompletionFinishReason>,
logprobs: Option<async_openai::types::Logprobs>,
) -> NvCreateCompletionResponse {
// todo - update for tool calling
......@@ -105,7 +168,7 @@ impl DeltaGenerator {
text: text.unwrap_or_default(),
index,
finish_reason,
logprobs: None,
logprobs,
}],
usage: if self.options.enable_usage {
Some(usage)
......@@ -136,13 +199,18 @@ impl crate::protocols::openai::DeltaGeneratorExt<NvCreateCompletionResponse> for
self.usage.completion_tokens += token_length;
}
// TODO logprobs
let logprobs = self.create_logprobs(
delta.tokens,
delta.token_ids,
delta.log_probs,
delta.top_logprobs,
);
let finish_reason = delta.finish_reason.map(Into::into);
// create choice
let index = delta.index.unwrap_or(0);
let response = self.create_choice(index, delta.text.clone(), finish_reason);
let response = self.create_choice(index, delta.text.clone(), finish_reason, logprobs);
Ok(response)
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment