Unverified Commit f476fd74 authored by Greg Clark, committed by GitHub
Browse files

feat: logprob handling (#2426)


Signed-off-by: Greg Clark <grclark@nvidia.com>
parent 5816c082
...@@ -268,6 +268,7 @@ fn run_request( ...@@ -268,6 +268,7 @@ fn run_request(
//text: if output.text.is_empty() { None } else { Some(output.text) }, //text: if output.text.is_empty() { None } else { Some(output.text) },
cum_log_probs: None, // TODO output.cumulative_logprob.map(|v| v as f64), cum_log_probs: None, // TODO output.cumulative_logprob.map(|v| v as f64),
log_probs: None, // TODO output.logprobs log_probs: None, // TODO output.logprobs
top_logprobs: None,
finish_reason: None, finish_reason: None,
index: None, index: None,
}; };
......
...@@ -590,7 +590,7 @@ impl ...@@ -590,7 +590,7 @@ impl
None => None, None => None,
}; };
#[allow(deprecated)] #[allow(deprecated)]
let inner = response_generator.create_choice(0, Some(from_assistant), None); let inner = response_generator.create_choice(0, Some(from_assistant), None, None);
let ann = Annotated{ let ann = Annotated{
id: None, id: None,
data: Some(inner), data: Some(inner),
......
...@@ -231,6 +231,7 @@ impl ...@@ -231,6 +231,7 @@ impl
text: data.text, text: data.text,
cum_log_probs: data.cum_log_probs, cum_log_probs: data.cum_log_probs,
log_probs: data.log_probs, log_probs: data.log_probs,
top_logprobs: data.top_logprobs,
finish_reason: data.finish_reason, finish_reason: data.finish_reason,
//mdcsum: mdcsum.clone(), //mdcsum: mdcsum.clone(),
index: data.index, index: data.index,
......
...@@ -102,6 +102,7 @@ fn delta_core(tok: u32) -> Annotated<LLMEngineOutput> { ...@@ -102,6 +102,7 @@ fn delta_core(tok: u32) -> Annotated<LLMEngineOutput> {
text: None, text: None,
cum_log_probs: None, cum_log_probs: None,
log_probs: None, log_probs: None,
top_logprobs: None,
finish_reason: None, finish_reason: None,
index: None, index: None,
}; };
...@@ -242,11 +243,11 @@ impl ...@@ -242,11 +243,11 @@ impl
let mut id = 1; let mut id = 1;
for c in chars_string.chars() { for c in chars_string.chars() {
tokio::time::sleep(*TOKEN_ECHO_DELAY).await; tokio::time::sleep(*TOKEN_ECHO_DELAY).await;
let response = deltas.create_choice(0, Some(c.to_string()), None); let response = deltas.create_choice(0, Some(c.to_string()), None, None);
yield Annotated{ id: Some(id.to_string()), data: Some(response), event: None, comment: None }; yield Annotated{ id: Some(id.to_string()), data: Some(response), event: None, comment: None };
id += 1; id += 1;
} }
let response = deltas.create_choice(0, None, Some(async_openai::types::CompletionFinishReason::Stop)); let response = deltas.create_choice(0, None, Some(async_openai::types::CompletionFinishReason::Stop), None);
yield Annotated { id: Some(id.to_string()), data: Some(response), event: None, comment: None }; yield Annotated { id: Some(id.to_string()), data: Some(response), event: None, comment: None };
}; };
......
...@@ -166,7 +166,7 @@ impl RetryManager { ...@@ -166,7 +166,7 @@ impl RetryManager {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use crate::protocols::common::{SamplingOptions, StopConditions}; use crate::protocols::common::{OutputOptions, SamplingOptions, StopConditions};
use dynamo_runtime::pipeline::context::Controller; use dynamo_runtime::pipeline::context::Controller;
use dynamo_runtime::pipeline::AsyncEngine; use dynamo_runtime::pipeline::AsyncEngine;
use std::sync::atomic::{AtomicU32, Ordering}; use std::sync::atomic::{AtomicU32, Ordering};
...@@ -183,6 +183,7 @@ mod tests { ...@@ -183,6 +183,7 @@ mod tests {
..Default::default() ..Default::default()
}, },
sampling_options: SamplingOptions::default(), sampling_options: SamplingOptions::default(),
output_options: OutputOptions::default(),
eos_token_ids: vec![], eos_token_ids: vec![],
mdc_sum: None, mdc_sum: None,
annotations: vec![], annotations: vec![],
...@@ -198,6 +199,7 @@ mod tests { ...@@ -198,6 +199,7 @@ mod tests {
text: Some(format!("token_{}", token_id)), text: Some(format!("token_{}", token_id)),
cum_log_probs: None, cum_log_probs: None,
log_probs: None, log_probs: None,
top_logprobs: None,
finish_reason: None, finish_reason: None,
index: None, index: None,
}) })
......
...@@ -405,6 +405,7 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<LLMEngineOutput>, Error> ...@@ -405,6 +405,7 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<LLMEngineOutput>, Error>
text: None, text: None,
cum_log_probs: None, cum_log_probs: None,
log_probs: None, log_probs: None,
top_logprobs: None,
finish_reason: None, finish_reason: None,
index: None, index: None,
}; };
...@@ -525,7 +526,7 @@ mod integration_tests { ...@@ -525,7 +526,7 @@ mod integration_tests {
use super::*; use super::*;
use crate::kv_router::indexer::RouterEvent; use crate::kv_router::indexer::RouterEvent;
use crate::kv_router::KV_EVENT_SUBJECT; use crate::kv_router::KV_EVENT_SUBJECT;
use crate::protocols::common::{SamplingOptions, StopConditions}; use crate::protocols::common::{OutputOptions, SamplingOptions, StopConditions};
use dynamo_runtime::{ use dynamo_runtime::{
pipeline::Context, pipeline::Context,
pipeline::{network::Ingress, PushRouter}, pipeline::{network::Ingress, PushRouter},
...@@ -641,6 +642,7 @@ mod integration_tests { ...@@ -641,6 +642,7 @@ mod integration_tests {
..Default::default() ..Default::default()
}, },
sampling_options: SamplingOptions::default(), sampling_options: SamplingOptions::default(),
output_options: OutputOptions::default(),
eos_token_ids: vec![], eos_token_ids: vec![],
mdc_sum: None, mdc_sum: None,
annotations: vec![format!("dp_rank:{dp_rank}")], annotations: vec![format!("dp_rank:{dp_rank}")],
......
...@@ -33,7 +33,7 @@ use dynamo_runtime::pipeline::{ ...@@ -33,7 +33,7 @@ use dynamo_runtime::pipeline::{
use dynamo_runtime::protocols::annotated::{Annotated, AnnotationsProvider}; use dynamo_runtime::protocols::annotated::{Annotated, AnnotationsProvider};
use crate::protocols::{ use crate::protocols::{
common::{SamplingOptionsProvider, StopConditionsProvider}, common::{OutputOptionsProvider, SamplingOptionsProvider, StopConditionsProvider},
openai::{ openai::{
chat_completions::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse}, chat_completions::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse},
completions::{NvCreateCompletionRequest, NvCreateCompletionResponse}, completions::{NvCreateCompletionRequest, NvCreateCompletionResponse},
...@@ -146,6 +146,7 @@ impl OpenAIPreprocessor { ...@@ -146,6 +146,7 @@ impl OpenAIPreprocessor {
+ AnnotationsProvider + AnnotationsProvider
+ SamplingOptionsProvider + SamplingOptionsProvider
+ StopConditionsProvider + StopConditionsProvider
+ OutputOptionsProvider
+ NvExtProvider, + NvExtProvider,
>( >(
&self, &self,
...@@ -249,6 +250,7 @@ impl OpenAIPreprocessor { ...@@ -249,6 +250,7 @@ impl OpenAIPreprocessor {
builder.stop_conditions(stop_conditions); builder.stop_conditions(stop_conditions);
builder.sampling_options(request.extract_sampling_options()?); builder.sampling_options(request.extract_sampling_options()?);
builder.output_options(request.extract_output_options()?);
builder.annotations(request.annotations().unwrap_or_default()); builder.annotations(request.annotations().unwrap_or_default());
builder.mdc_sum(Some(self.mdcsum.clone())); builder.mdc_sum(Some(self.mdcsum.clone()));
builder.estimated_prefix_hit_num_blocks(None); builder.estimated_prefix_hit_num_blocks(None);
......
...@@ -45,6 +45,10 @@ pub trait StopConditionsProvider { ...@@ -45,6 +45,10 @@ pub trait StopConditionsProvider {
fn extract_stop_conditions(&self) -> Result<StopConditions>; fn extract_stop_conditions(&self) -> Result<StopConditions>;
} }
/// Implemented by request types that can produce the [`OutputOptions`] for a request.
///
/// Sits alongside `StopConditionsProvider`: the preprocessor calls
/// `extract_output_options` on the incoming request and stores the result on the
/// request it builds.
pub trait OutputOptionsProvider {
/// Builds the [`OutputOptions`] for this request, or returns an error if they cannot be
/// extracted.
fn extract_output_options(&self) -> Result<OutputOptions>;
}
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
pub enum FinishReason { pub enum FinishReason {
#[serde(rename = "eos")] #[serde(rename = "eos")]
...@@ -179,6 +183,9 @@ pub struct CompletionRequest { ...@@ -179,6 +183,9 @@ pub struct CompletionRequest {
/// are needed. /// are needed.
pub sampling_options: SamplingOptions, pub sampling_options: SamplingOptions,
#[builder(default)]
pub output_options: OutputOptions,
/// The computed checksum of the Model Deployment Card (MDC). /// The computed checksum of the Model Deployment Card (MDC).
#[builder(default)] #[builder(default)]
pub mdc_sum: Option<String>, pub mdc_sum: Option<String>,
......
...@@ -23,6 +23,15 @@ use dynamo_runtime::protocols::maybe_error::MaybeError; ...@@ -23,6 +23,15 @@ use dynamo_runtime::protocols::maybe_error::MaybeError;
pub type TokenType = Option<String>; pub type TokenType = Option<String>;
pub type LogProbs = Vec<f64>; pub type LogProbs = Vec<f64>;
/// One candidate token and its log probability for a single generated position.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct TopLogprob {
/// Rank of this candidate.
/// NOTE(review): presumably the position within the top-k list; base (0 or 1) and ordering
/// are not established here — confirm against the backend that fills it.
pub rank: u32,
/// Token id of this candidate.
pub token_id: TokenIdType,
/// Token string of this candidate; `None` when no string is provided
/// (`TokenType` is `Option<String>`).
pub token: TokenType,
/// Log probability of this candidate.
pub logprob: f64,
}
/// Candidate lists for a run of tokens: the outer `Vec` has one entry per token, the inner
/// `Vec` holds that token's candidates.
pub type TopLogprobs = Vec<Vec<TopLogprob>>; // num_tokens x top_logprobs
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] #[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct BackendOutput { pub struct BackendOutput {
/// New token_ids generated from the LLM Engine /// New token_ids generated from the LLM Engine
...@@ -41,6 +50,8 @@ pub struct BackendOutput { ...@@ -41,6 +50,8 @@ pub struct BackendOutput {
/// Optional log probabilities /// Optional log probabilities
pub log_probs: Option<LogProbs>, pub log_probs: Option<LogProbs>,
pub top_logprobs: Option<TopLogprobs>,
// TODO: Enrich this with more information as can apply our first-level postprocessing // TODO: Enrich this with more information as can apply our first-level postprocessing
// logic and return more detailed information // logic and return more detailed information
pub finish_reason: Option<FinishReason>, pub finish_reason: Option<FinishReason>,
...@@ -77,6 +88,8 @@ pub struct LLMEngineOutput { ...@@ -77,6 +88,8 @@ pub struct LLMEngineOutput {
/// Optional log probabilities /// Optional log probabilities
pub log_probs: Option<LogProbs>, pub log_probs: Option<LogProbs>,
pub top_logprobs: Option<TopLogprobs>,
// TODO: Enrich this with more information as can apply our first-level postprocessing // TODO: Enrich this with more information as can apply our first-level postprocessing
// logic and return more detailed information // logic and return more detailed information
pub finish_reason: Option<FinishReason>, pub finish_reason: Option<FinishReason>,
...@@ -93,6 +106,7 @@ impl LLMEngineOutput { ...@@ -93,6 +106,7 @@ impl LLMEngineOutput {
text: None, text: None,
cum_log_probs: None, cum_log_probs: None,
log_probs: None, log_probs: None,
top_logprobs: None,
finish_reason: Some(FinishReason::Cancelled), finish_reason: Some(FinishReason::Cancelled),
index: None, index: None,
} }
...@@ -106,6 +120,7 @@ impl LLMEngineOutput { ...@@ -106,6 +120,7 @@ impl LLMEngineOutput {
cum_log_probs: None, cum_log_probs: None,
log_probs: None, log_probs: None,
finish_reason: Some(FinishReason::Stop), finish_reason: Some(FinishReason::Stop),
top_logprobs: None,
index: None, index: None,
} }
} }
...@@ -117,6 +132,7 @@ impl LLMEngineOutput { ...@@ -117,6 +132,7 @@ impl LLMEngineOutput {
text: None, text: None,
cum_log_probs: None, cum_log_probs: None,
log_probs: None, log_probs: None,
top_logprobs: None,
finish_reason: Some(FinishReason::Length), finish_reason: Some(FinishReason::Length),
index: None, index: None,
} }
...@@ -129,6 +145,7 @@ impl LLMEngineOutput { ...@@ -129,6 +145,7 @@ impl LLMEngineOutput {
text: None, text: None,
cum_log_probs: None, cum_log_probs: None,
log_probs: None, log_probs: None,
top_logprobs: None,
finish_reason: Some(FinishReason::Error(err_msg)), finish_reason: Some(FinishReason::Error(err_msg)),
index: None, index: None,
} }
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
use derive_builder::Builder; use derive_builder::Builder;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use super::{SamplingOptions, StopConditions}; use super::{OutputOptions, SamplingOptions, StopConditions};
use crate::protocols::TokenIdType; use crate::protocols::TokenIdType;
/// [`PreprocessedRequest`] is the internal representation of an LLM request. The [`dynamo.llm-preprocessor`] /// [`PreprocessedRequest`] is the internal representation of an LLM request. The [`dynamo.llm-preprocessor`]
...@@ -29,6 +29,10 @@ pub struct PreprocessedRequest { ...@@ -29,6 +29,10 @@ pub struct PreprocessedRequest {
/// are needed. /// are needed.
pub sampling_options: SamplingOptions, pub sampling_options: SamplingOptions,
/// OutputOptions are options that control the output of the inference engine such as whether
/// to return log probabilities, or whether to skip special tokens in output.
pub output_options: OutputOptions,
/// The EOS token ID(s) for the Model /// The EOS token ID(s) for the Model
/// Not every backend needs this, but those that do can find it here. /// Not every backend needs this, but those that do can find it here.
/// TODO - refactor this to a better location /// TODO - refactor this to a better location
......
...@@ -17,7 +17,7 @@ use anyhow::Result; ...@@ -17,7 +17,7 @@ use anyhow::Result;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use super::{ use super::{
common::{self, SamplingOptionsProvider, StopConditionsProvider}, common::{self, OutputOptionsProvider, SamplingOptionsProvider, StopConditionsProvider},
ContentProvider, ContentProvider,
}; };
use crate::protocols::openai::common_ext::CommonExtProvider; use crate::protocols::openai::common_ext::CommonExtProvider;
...@@ -79,6 +79,16 @@ trait OpenAIStopConditionsProvider { ...@@ -79,6 +79,16 @@ trait OpenAIStopConditionsProvider {
} }
} }
/// Accessors an OpenAI-style request type implements so that its output options can be
/// extracted generically.
///
/// Each getter feeds the `common::OutputOptions` field of the same name (minus the `get_`
/// prefix); `None` means the request does not set that option.
trait OpenAIOutputOptionsProvider {
/// Value for `OutputOptions::logprobs`.
/// NOTE(review): presumably the number of log probabilities to return per generated token —
/// confirm against the `OutputOptions` definition.
fn get_logprobs(&self) -> Option<u32>;

/// Value for `OutputOptions::prompt_logprobs`.
/// NOTE(review): presumably the number of log probabilities to return per prompt token —
/// confirm against the `OutputOptions` definition.
fn get_prompt_logprobs(&self) -> Option<u32>;

/// Value for `OutputOptions::skip_special_tokens`.
fn get_skip_special_tokens(&self) -> Option<bool>;

/// Value for `OutputOptions::formatted_prompt`.
fn get_formatted_prompt(&self) -> Option<bool>;
}
impl<T: OpenAISamplingOptionsProvider + CommonExtProvider> SamplingOptionsProvider for T { impl<T: OpenAISamplingOptionsProvider + CommonExtProvider> SamplingOptionsProvider for T {
fn extract_sampling_options(&self) -> Result<common::SamplingOptions> { fn extract_sampling_options(&self) -> Result<common::SamplingOptions> {
// let result = self.validate(); // let result = self.validate();
...@@ -168,6 +178,22 @@ impl<T: OpenAIStopConditionsProvider> StopConditionsProvider for T { ...@@ -168,6 +178,22 @@ impl<T: OpenAIStopConditionsProvider> StopConditionsProvider for T {
} }
} }
/// Blanket implementation: any type exposing the OpenAI output-option getters can produce
/// a `common::OutputOptions`.
impl<T: OpenAIOutputOptionsProvider> OutputOptionsProvider for T {
    /// Copies each getter's value into the matching `OutputOptions` field. Never fails; the
    /// `Result` is only there to satisfy the trait signature.
    fn extract_output_options(&self) -> Result<common::OutputOptions> {
        // Fields are evaluated in the order written, matching the getter call order.
        Ok(common::OutputOptions {
            logprobs: self.get_logprobs(),
            prompt_logprobs: self.get_prompt_logprobs(),
            skip_special_tokens: self.get_skip_special_tokens(),
            formatted_prompt: self.get_formatted_prompt(),
        })
    }
}
pub trait DeltaGeneratorExt<ResponseType: Send + Sync + 'static + std::fmt::Debug>: pub trait DeltaGeneratorExt<ResponseType: Send + Sync + 'static + std::fmt::Debug>:
Send + Sync + 'static Send + Sync + 'static
{ {
......
...@@ -23,7 +23,8 @@ use super::{ ...@@ -23,7 +23,8 @@ use super::{
common_ext::{CommonExt, CommonExtProvider}, common_ext::{CommonExt, CommonExtProvider},
nvext::NvExt, nvext::NvExt,
nvext::NvExtProvider, nvext::NvExtProvider,
validate, OpenAISamplingOptionsProvider, OpenAIStopConditionsProvider, validate, OpenAIOutputOptionsProvider, OpenAISamplingOptionsProvider,
OpenAIStopConditionsProvider,
}; };
mod aggregator; mod aggregator;
...@@ -232,6 +233,31 @@ impl OpenAIStopConditionsProvider for NvCreateChatCompletionRequest { ...@@ -232,6 +233,31 @@ impl OpenAIStopConditionsProvider for NvCreateChatCompletionRequest {
} }
} }
/// Output-option accessors for chat completion requests.
impl OpenAIOutputOptionsProvider for NvCreateChatCompletionRequest {
    /// Returns `None` unless the request sets `logprobs` to `true`. When it does, returns the
    /// request's `top_logprobs` value, or `1` if `top_logprobs` is not set.
    fn get_logprobs(&self) -> Option<u32> {
        // An absent flag is treated the same as an explicit `false`.
        if !self.inner.logprobs.unwrap_or(false) {
            return None;
        }
        let requested = self.inner.top_logprobs.map_or(1_u32, |n| n as u32);
        Some(requested)
    }

    /// Chat completion requests never set prompt log probabilities.
    fn get_prompt_logprobs(&self) -> Option<u32> {
        None
    }

    /// Not set by chat completion requests.
    fn get_skip_special_tokens(&self) -> Option<bool> {
        None
    }

    /// Not set by chat completion requests.
    fn get_formatted_prompt(&self) -> Option<bool> {
        None
    }
}
/// Implements `ValidateRequest` for `NvCreateChatCompletionRequest`, /// Implements `ValidateRequest` for `NvCreateChatCompletionRequest`,
/// allowing us to validate the data. /// allowing us to validate the data.
impl ValidateRequest for NvCreateChatCompletionRequest { impl ValidateRequest for NvCreateChatCompletionRequest {
......
...@@ -14,7 +14,10 @@ ...@@ -14,7 +14,10 @@
// limitations under the License. // limitations under the License.
use super::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse}; use super::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse};
use crate::protocols::common; use crate::{
protocols::common::{self},
types::TokenIdType,
};
/// Provides a method for generating a [`DeltaGenerator`] from a chat completion request. /// Provides a method for generating a [`DeltaGenerator`] from a chat completion request.
impl NvCreateChatCompletionRequest { impl NvCreateChatCompletionRequest {
...@@ -25,7 +28,8 @@ impl NvCreateChatCompletionRequest { ...@@ -25,7 +28,8 @@ impl NvCreateChatCompletionRequest {
pub fn response_generator(&self) -> DeltaGenerator { pub fn response_generator(&self) -> DeltaGenerator {
let options = DeltaGeneratorOptions { let options = DeltaGeneratorOptions {
enable_usage: true, enable_usage: true,
enable_logprobs: self.inner.logprobs.unwrap_or(false), enable_logprobs: self.inner.logprobs.unwrap_or(false)
|| self.inner.top_logprobs.unwrap_or(0) > 0,
}; };
DeltaGenerator::new(self.inner.model.clone(), options) DeltaGenerator::new(self.inner.model.clone(), options)
...@@ -112,6 +116,71 @@ impl DeltaGenerator { ...@@ -112,6 +116,71 @@ impl DeltaGenerator {
self.usage.prompt_tokens = isl; self.usage.prompt_tokens = isl;
} }
/// Converts the log-probability data of one backend delta into the chat-completion
/// `logprobs` payload.
///
/// # Arguments
/// * `tokens` - Token strings for the delta; a missing string becomes `""`.
/// * `token_ids` - Token ids, paired positionally with `tokens`.
/// * `logprobs` - Log probability of each selected token, paired positionally with the tokens.
/// * `top_logprobs` - Candidate lists, one per token, paired positionally with the tokens.
///
/// All inputs are zipped together, so the output is as long as the shortest of them.
///
/// # Returns
/// `None` when log probabilities are disabled for this generator or `logprobs` is `None`.
/// Otherwise `Some(..)`; its `content` is `None` when `top_logprobs` is `None`.
pub fn create_logprobs(
    &self,
    tokens: Vec<common::llm_backend::TokenType>,
    token_ids: Vec<TokenIdType>,
    logprobs: Option<common::llm_backend::LogProbs>,
    top_logprobs: Option<common::llm_backend::TopLogprobs>,
) -> Option<async_openai::types::ChatChoiceLogprobs> {
    if !self.options.enable_logprobs {
        return None;
    }
    // `?` replaces the former `is_none()` check followed by `unwrap()`.
    let logprobs = logprobs?;

    let toks: Vec<(String, TokenIdType)> = tokens
        .into_iter()
        .zip(token_ids)
        .map(|(token, token_id)| (token.unwrap_or_default(), token_id))
        .collect();
    // Zipping against `toks` caps the list at the number of tokens.
    let tok_lps: Vec<f32> = toks
        .iter()
        .zip(logprobs)
        .map(|(_, lp)| lp as f32)
        .collect();

    let content = top_logprobs.map(|top_logprobs| {
        toks.iter()
            .zip(tok_lps)
            .zip(top_logprobs)
            .map(|(((t, tid), lp), top_lps)| {
                let found_selected_token =
                    top_lps.iter().any(|top_lp| top_lp.token_id == *tid);
                let mut converted_top_lps: Vec<async_openai::types::TopLogprobs> = top_lps
                    .iter()
                    .map(|top_lp| async_openai::types::TopLogprobs {
                        token: top_lp.token.clone().unwrap_or_default(),
                        logprob: top_lp.logprob as f32,
                        bytes: None,
                    })
                    .collect();
                if !found_selected_token {
                    // If the selected token is not in the top logprobs, add it
                    converted_top_lps.push(async_openai::types::TopLogprobs {
                        token: t.clone(),
                        logprob: lp,
                        bytes: None,
                    });
                }
                async_openai::types::ChatCompletionTokenLogprob {
                    token: t.clone(),
                    logprob: lp,
                    bytes: None,
                    top_logprobs: converted_top_lps,
                }
            })
            .collect()
    });

    Some(async_openai::types::ChatChoiceLogprobs {
        content,
        refusal: None,
    })
}
/// Creates a choice within a chat completion response. /// Creates a choice within a chat completion response.
/// ///
/// # Arguments /// # Arguments
...@@ -203,8 +272,12 @@ impl crate::protocols::openai::DeltaGeneratorExt<NvCreateChatCompletionStreamRes ...@@ -203,8 +272,12 @@ impl crate::protocols::openai::DeltaGeneratorExt<NvCreateChatCompletionStreamRes
self.usage.completion_tokens += token_length; self.usage.completion_tokens += token_length;
} }
// TODO: Implement log probabilities aggregation. let logprobs = self.create_logprobs(
let logprobs = None; delta.tokens,
delta.token_ids,
delta.log_probs,
delta.top_logprobs,
);
// Map backend finish reasons to OpenAI's finish reasons. // Map backend finish reasons to OpenAI's finish reasons.
let finish_reason = match delta.finish_reason { let finish_reason = match delta.finish_reason {
......
...@@ -21,10 +21,11 @@ use validator::Validate; ...@@ -21,10 +21,11 @@ use validator::Validate;
use crate::engines::ValidateRequest; use crate::engines::ValidateRequest;
use super::{ use super::{
common::{self, SamplingOptionsProvider, StopConditionsProvider}, common::{self, OutputOptionsProvider, SamplingOptionsProvider, StopConditionsProvider},
common_ext::{CommonExt, CommonExtProvider}, common_ext::{CommonExt, CommonExtProvider},
nvext::{NvExt, NvExtProvider}, nvext::{NvExt, NvExtProvider},
validate, ContentProvider, OpenAISamplingOptionsProvider, OpenAIStopConditionsProvider, validate, ContentProvider, OpenAIOutputOptionsProvider, OpenAISamplingOptionsProvider,
OpenAIStopConditionsProvider,
}; };
mod aggregator; mod aggregator;
...@@ -279,6 +280,10 @@ impl TryFrom<NvCreateCompletionRequest> for common::CompletionRequest { ...@@ -279,6 +280,10 @@ impl TryFrom<NvCreateCompletionRequest> for common::CompletionRequest {
.extract_sampling_options() .extract_sampling_options()
.map_err(|e| anyhow::anyhow!("Failed to extract sampling options: {}", e))?; .map_err(|e| anyhow::anyhow!("Failed to extract sampling options: {}", e))?;
let output_options = request
.extract_output_options()
.map_err(|e| anyhow::anyhow!("Failed to extract output options: {}", e))?;
let prompt = common::PromptType::Completion(common::CompletionContext { let prompt = common::PromptType::Completion(common::CompletionContext {
prompt: prompt_to_string(&request.inner.prompt), prompt: prompt_to_string(&request.inner.prompt),
system_prompt: None, system_prompt: None,
...@@ -288,6 +293,7 @@ impl TryFrom<NvCreateCompletionRequest> for common::CompletionRequest { ...@@ -288,6 +293,7 @@ impl TryFrom<NvCreateCompletionRequest> for common::CompletionRequest {
prompt, prompt,
stop_conditions, stop_conditions,
sampling_options, sampling_options,
output_options,
mdc_sum: None, mdc_sum: None,
annotations: None, annotations: None,
}) })
...@@ -329,6 +335,26 @@ impl TryFrom<common::StreamingCompletionResponse> for async_openai::types::Choic ...@@ -329,6 +335,26 @@ impl TryFrom<common::StreamingCompletionResponse> for async_openai::types::Choic
} }
} }
/// Output-option accessors for (legacy) completion requests.
impl OpenAIOutputOptionsProvider for NvCreateCompletionRequest {
    /// Forwards the request's `logprobs` value, widened to `u32`; `None` when unset.
    fn get_logprobs(&self) -> Option<u32> {
        let requested = self.inner.logprobs?;
        Some(requested as u32)
    }

    /// Returns `Some(1)` when the request sets `echo` to `true`, otherwise `None`.
    fn get_prompt_logprobs(&self) -> Option<u32> {
        match self.inner.echo {
            Some(true) => Some(1),
            Some(false) | None => None,
        }
    }

    /// Not set by completion requests.
    fn get_skip_special_tokens(&self) -> Option<bool> {
        None
    }

    /// Not set by completion requests.
    fn get_formatted_prompt(&self) -> Option<bool> {
        None
    }
}
/// Implements `ValidateRequest` for `NvCreateCompletionRequest`, /// Implements `ValidateRequest` for `NvCreateCompletionRequest`,
/// allowing us to validate the data. /// allowing us to validate the data.
impl ValidateRequest for NvCreateCompletionRequest { impl ValidateRequest for NvCreateCompletionRequest {
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
// limitations under the License. // limitations under the License.
use super::{NvCreateCompletionRequest, NvCreateCompletionResponse}; use super::{NvCreateCompletionRequest, NvCreateCompletionResponse};
use crate::protocols::common; use crate::{protocols::common, types::TokenIdType};
impl NvCreateCompletionRequest { impl NvCreateCompletionRequest {
// put this method on the request // put this method on the request
...@@ -22,7 +22,7 @@ impl NvCreateCompletionRequest { ...@@ -22,7 +22,7 @@ impl NvCreateCompletionRequest {
pub fn response_generator(&self) -> DeltaGenerator { pub fn response_generator(&self) -> DeltaGenerator {
let options = DeltaGeneratorOptions { let options = DeltaGeneratorOptions {
enable_usage: true, enable_usage: true,
enable_logprobs: false, enable_logprobs: self.inner.logprobs.unwrap_or(0) > 0,
}; };
DeltaGenerator::new(self.inner.model.clone(), options) DeltaGenerator::new(self.inner.model.clone(), options)
...@@ -82,11 +82,74 @@ impl DeltaGenerator { ...@@ -82,11 +82,74 @@ impl DeltaGenerator {
self.usage.prompt_tokens = isl; self.usage.prompt_tokens = isl;
} }
/// Converts the log-probability data of one backend delta into the completions
/// `logprobs` payload.
///
/// # Arguments
/// * `tokens` - Token strings for the delta; a missing string becomes `""`.
/// * `token_ids` - Token ids, paired positionally with `tokens`.
/// * `logprobs` - Log probability of each selected token, paired positionally with the tokens.
/// * `top_logprobs` - Candidate lists, one per token, paired positionally with the tokens.
///
/// # Returns
/// `None` when log probabilities are disabled for this generator or `logprobs` is `None`.
/// Otherwise `Some(..)` with `text_offset` left empty, and `top_logprobs` empty when the
/// `top_logprobs` argument is `None`.
///
/// # Panics
/// Panics if a converted candidate list cannot be serialized to a JSON value.
pub fn create_logprobs(
    &self,
    tokens: Vec<common::llm_backend::TokenType>,
    token_ids: Vec<TokenIdType>,
    logprobs: Option<common::llm_backend::LogProbs>,
    top_logprobs: Option<common::llm_backend::TopLogprobs>,
) -> Option<async_openai::types::Logprobs> {
    if !self.options.enable_logprobs {
        return None;
    }
    // `?` replaces the former `is_none()` check followed by `unwrap()`.
    let logprobs = logprobs?;

    let toks: Vec<(String, TokenIdType)> = tokens
        .into_iter()
        .zip(token_ids)
        .map(|(token, token_id)| (token.unwrap_or_default(), token_id))
        .collect();
    // Zipping against `toks` caps the list at the number of tokens.
    let tok_lps: Vec<f32> = toks
        .iter()
        .zip(logprobs)
        .map(|(_, lp)| lp as f32)
        .collect();

    let top_lps = top_logprobs.map_or(vec![], |top_logprobs| {
        toks.iter()
            .zip(tok_lps.iter())
            .zip(top_logprobs.iter())
            .map(|(((t, tid), lp), top_lps)| {
                let found_selected_token =
                    top_lps.iter().any(|top_lp| top_lp.token_id == *tid);
                let mut converted_top_lps: Vec<async_openai::types::TopLogprobs> = top_lps
                    .iter()
                    .map(|top_lp| async_openai::types::TopLogprobs {
                        token: top_lp.token.clone().unwrap_or_default(),
                        logprob: top_lp.logprob as f32,
                        bytes: None,
                    })
                    .collect();
                if !found_selected_token {
                    // If the selected token is not in the top logprobs, add it
                    converted_top_lps.push(async_openai::types::TopLogprobs {
                        token: t.clone(),
                        logprob: *lp,
                        bytes: None,
                    });
                }
                // The payload field holds JSON values, so each candidate list is serialized.
                serde_json::to_value(converted_top_lps)
                    .expect("top logprob entries must serialize to a JSON value")
            })
            .collect()
    });

    Some(async_openai::types::Logprobs {
        // `toks` is not needed afterwards, so move the strings out instead of cloning them.
        tokens: toks.into_iter().map(|(t, _)| t).collect(),
        token_logprobs: tok_lps.into_iter().map(Some).collect(),
        text_offset: vec![],
        top_logprobs: top_lps,
    })
}
pub fn create_choice( pub fn create_choice(
&self, &self,
index: u32, index: u32,
text: Option<String>, text: Option<String>,
finish_reason: Option<async_openai::types::CompletionFinishReason>, finish_reason: Option<async_openai::types::CompletionFinishReason>,
logprobs: Option<async_openai::types::Logprobs>,
) -> NvCreateCompletionResponse { ) -> NvCreateCompletionResponse {
// todo - update for tool calling // todo - update for tool calling
...@@ -105,7 +168,7 @@ impl DeltaGenerator { ...@@ -105,7 +168,7 @@ impl DeltaGenerator {
text: text.unwrap_or_default(), text: text.unwrap_or_default(),
index, index,
finish_reason, finish_reason,
logprobs: None, logprobs,
}], }],
usage: if self.options.enable_usage { usage: if self.options.enable_usage {
Some(usage) Some(usage)
...@@ -136,13 +199,18 @@ impl crate::protocols::openai::DeltaGeneratorExt<NvCreateCompletionResponse> for ...@@ -136,13 +199,18 @@ impl crate::protocols::openai::DeltaGeneratorExt<NvCreateCompletionResponse> for
self.usage.completion_tokens += token_length; self.usage.completion_tokens += token_length;
} }
// TODO logprobs let logprobs = self.create_logprobs(
delta.tokens,
delta.token_ids,
delta.log_probs,
delta.top_logprobs,
);
let finish_reason = delta.finish_reason.map(Into::into); let finish_reason = delta.finish_reason.map(Into::into);
// create choice // create choice
let index = delta.index.unwrap_or(0); let index = delta.index.unwrap_or(0);
let response = self.create_choice(index, delta.text.clone(), finish_reason); let response = self.create_choice(index, delta.text.clone(), finish_reason, logprobs);
Ok(response) Ok(response)
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment