Commit 9162f3ad authored by Paul Hendricks, committed by GitHub

refactor: use async-openai CompletionRequest (#310)

parent 057f8f47
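The core of the change, sketched below for orientation: the hand-rolled request struct is replaced by a thin wrapper whose OpenAI fields come from async-openai and are flattened into the same JSON object. This is a minimal standalone sketch, not the crate's code; NvExt is stubbed with serde_json::Value because its definition is outside this diff.

use async_openai::types::{CreateCompletionRequest, CreateCompletionRequestArgs};
use serde::{Deserialize, Serialize};

// Wrapper pattern introduced by this commit: the OpenAI fields live in `inner`
// and are flattened into the same JSON object; `nvext` carries the NVIDIA
// extension (stubbed here with serde_json::Value instead of the crate's NvExt).
#[derive(Serialize, Deserialize, Debug, Clone)]
struct CompletionRequest {
    #[serde(flatten)]
    inner: CreateCompletionRequest,
    #[serde(skip_serializing_if = "Option::is_none")]
    nvext: Option<serde_json::Value>,
}

fn main() {
    // Build the OpenAI part with the async-openai builder, then wrap it,
    // mirroring what the updated tests below do.
    let inner = CreateCompletionRequestArgs::default()
        .model("gpt-3.5-turbo")
        .prompt("What is the meaning of life?")
        .build()
        .expect("valid request");
    let request = CompletionRequest { inner, nvext: None };
    println!("{}", serde_json::to_string_pretty(&request).expect("serializable"));
}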
@@ -140,17 +140,19 @@ async fn completions(
     let request_id = uuid::Uuid::new_v4().to_string();
     // todo - decide on default
-    let streaming = request.stream.unwrap_or(false);
+    let streaming = request.inner.stream.unwrap_or(false);
     // update the request to always stream
-    let request = CompletionRequest {
+    let inner = async_openai::types::CreateCompletionRequest {
         stream: Some(true),
-        ..request
+        ..request.inner
     };
+    let request = CompletionRequest { inner, nvext: None };
     // todo - make the protocols be optional for model name
     // todo - when optional, if none, apply a default
-    let model = &request.model;
+    let model = &request.inner.model;
    // todo - error handling should be more robust
    let engine = state
...
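Illustratively, the handler change above reduces to this pattern: record whether the client asked for streaming, then force stream = true on the inner request before dispatching. The force_stream helper and main below are illustrative only, not part of the commit.

use async_openai::types::{CreateCompletionRequest, CreateCompletionRequestArgs};

// Hypothetical helper: remember the client's streaming preference, then force
// streaming on the inner async-openai request, as the handler above does with
// struct-update syntax.
fn force_stream(inner: CreateCompletionRequest) -> (bool, CreateCompletionRequest) {
    let client_wants_stream = inner.stream.unwrap_or(false);
    let forced = CreateCompletionRequest {
        stream: Some(true),
        ..inner
    };
    (client_wants_stream, forced)
}

fn main() {
    let inner = CreateCompletionRequestArgs::default()
        .model("bar")
        .prompt("hi")
        .build()
        .expect("valid request");
    let (streaming, forced) = force_stream(inner);
    assert!(!streaming); // the client did not ask for streaming
    assert_eq!(forced.stream, Some(true)); // but the engine always sees stream = true
}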
@@ -60,7 +60,7 @@ impl OAIChatLikeRequest for CompletionRequest {
         let message = async_openai::types::ChatCompletionRequestMessage::User(
             async_openai::types::ChatCompletionRequestUserMessage {
                 content: async_openai::types::ChatCompletionRequestUserMessageContent::Text(
-                    self.prompt.clone(),
+                    crate::protocols::openai::completions::prompt_to_string(&self.inner.prompt),
                 ),
                 name: None,
             },
...
@@ -22,11 +22,9 @@ pub mod nvext;
 use anyhow::Result;
 use serde::{Deserialize, Serialize};
 use std::{
-    collections::HashMap,
     fmt::Display,
     ops::{Add, Div, Mul, Sub},
 };
-use validator::ValidationError;
 use super::{
     common::{self, SamplingOptionsProvider, StopConditionsProvider},
@@ -263,17 +261,6 @@ pub struct GenericCompletionResponse<C>
     // TODO() - add NvResponseExtention
 }
-fn validate_logit_bias(logit_bias: &HashMap<String, i32>) -> Result<(), ValidationError> {
-    for key in logit_bias.keys() {
-        if key.parse::<i32>().is_err() {
-            return Err(
-                ValidationError::new("logit_bias").with_message("Keys must be integers".into())
-            );
-        }
-    }
-    Ok(())
-}
 // todo - move to common location
 fn validate_range<T>(value: Option<T>, range: &(T, T)) -> Result<Option<T>>
 where
...
@@ -22,285 +22,26 @@ use validator::Validate;
 mod aggregator;
 mod delta;
-// pub use aggregator::DeltaAggregator;
+pub use aggregator::DeltaAggregator;
+pub use delta::DeltaGenerator;
 use super::{
     common::{self, SamplingOptionsProvider, StopConditionsProvider},
     nvext::{NvExt, NvExtProvider},
-    validate_logit_bias, CompletionUsage, ContentProvider, OpenAISamplingOptionsProvider,
-    OpenAIStopConditionsProvider, MAX_FREQUENCY_PENALTY, MAX_PRESENCE_PENALTY, MAX_TEMPERATURE,
-    MAX_TOP_P, MIN_FREQUENCY_PENALTY, MIN_PRESENCE_PENALTY, MIN_TEMPERATURE, MIN_TOP_P,
+    CompletionUsage, ContentProvider, OpenAISamplingOptionsProvider, OpenAIStopConditionsProvider,
 };
 use triton_distributed_runtime::protocols::annotated::AnnotationsProvider;
-/// Legacy OpenAI CompletionRequest
-///
-/// Reference: <https://platform.openai.com/docs/api-reference/completions>
-#[derive(Serialize, Deserialize, Builder, Validate, Debug, Clone)]
-#[builder(build_fn(private, name = "build_internal", validate = "Self::validate"))]
+#[derive(Serialize, Deserialize, Validate, Debug, Clone)]
 pub struct CompletionRequest {
-    /// ID of the model to use.
-    #[builder(setter(into))]
-    pub model: String,
+    #[serde(flatten)]
+    pub inner: async_openai::types::CreateCompletionRequest,
-    /// The prompt(s) to generate completions for, encoded as a string, array of
-    /// strings, array of tokens, or array of token arrays.
-    ///
-    /// NIM Compatibility:
-    /// The NIM LLM API only supports a single prompt as a string at this time.
-    #[builder(setter(into))]
-    pub prompt: String,
-    /// The maximum number of tokens that can be generated in the completion.
-    /// The token count of your prompt plus max_tokens cannot exceed the model's context length.
-    #[serde(skip_serializing_if = "Option::is_none")]
-    #[builder(default, setter(into, strip_option))]
-    pub max_tokens: Option<u32>,
-    /// The minimum number of tokens to generate. We ignore stop tokens until we see this many
-    /// tokens. Leave this None unless you are working on the pre-processor.
-    #[serde(skip_serializing_if = "Option::is_none")]
-    #[builder(default, setter(into, strip_option))]
-    pub min_tokens: Option<u32>,
-    /// If set, partial message deltas will be sent, like in ChatGPT. Tokens will be sent as data-only
-    /// server-sent events as they become available, with the stream terminated by a data: \[DONE\]
-    ///
-    /// If this is set to true, but the response cannot be streamed an error will be returned.
-    ///
-    /// NIM Compatibility:
-    /// The NIM SDK can send extra meta data in the SSE stream using the `:` comment, `event:`,
-    /// or `id:` fields. See the `enable_sse_metadata` field in the NvExt object.
-    #[serde(skip_serializing_if = "Option::is_none")]
-    #[builder(default, setter(strip_option))]
-    pub stream: Option<bool>,
-    /// How many completions to generate for each prompt.
-    ///
-    /// Note: Because this parameter generates many completions, it can quickly consume your token quota.
-    /// Use carefully and ensure that you have reasonable settings for `max_tokens` and `stop`.
-    ///
-    /// NIM Compatibility:
-    /// At this time, the NIM LLM API does not support `n` completions.
-    #[serde(skip_serializing_if = "Option::is_none")]
-    #[builder(default, setter(into, strip_option))]
-    pub n: Option<i32>,
-    /// Generates `best_of` completions server-side and returns the "best" (the one with the
-    /// highest log probability per token). Results cannot be streamed.
-    ///
-    /// When used with `n`, best_of controls the number of candidate completions and `n` specifies
-    /// how many to return – `best_of` must be greater than `n`.
-    ///
-    /// NIM Compatibility:
-    /// At this time, the NIM LLM API does not support `best_of` completions.
-    #[serde(skip_serializing_if = "Option::is_none")]
-    #[builder(default, setter(into, strip_option))]
-    pub best_of: Option<i32>,
-    /// What sampling `temperature` to use, between 0 and 2. Higher values like 0.8 will make the
-    /// output more random, while lower values like 0.2 will make it more focused and deterministic.
-    ///
-    /// We generally recommend altering this or `top_p` but not both.
-    #[serde(skip_serializing_if = "Option::is_none")]
-    #[validate(range(min = "MIN_TEMPERATURE", max = "MAX_TEMPERATURE"))]
-    #[builder(default, setter(into, strip_option))]
-    pub temperature: Option<f32>,
-    /// An alternative to sampling with `temperature`, called nucleus sampling, where the model
-    /// considers the results of the tokens with `top_p` probability mass. So 0.1 means only the tokens
-    /// comprising the top 10% probability mass are considered.
-    ///
-    /// We generally recommend altering this or `temperature` but not both.
-    #[serde(skip_serializing_if = "Option::is_none")]
-    #[validate(range(min = "MIN_TOP_P", max = "MAX_TOP_P"))]
-    #[builder(default, setter(into, strip_option))]
-    pub top_p: Option<f32>,
-    /// Include the log probabilities on the logprobs most likely output tokens, as well the chosen tokens.
-    /// For example, if logprobs is 5, the API will return a list of the 5 most likely tokens. The API will
-    /// always return the logprob of the sampled token, so there may be up to logprobs+1 elements in the
-    /// response.
     #[serde(skip_serializing_if = "Option::is_none")]
-    #[builder(default, setter(into, strip_option))]
-    pub logprobs: Option<i32>,
-    /// Echo back the prompt in addition to the completion
-    ///
-    /// NIM Compatibility:
-    /// At this time, the NIM LLM API does not support `echo` completions.
-    #[serde(skip_serializing_if = "Option::is_none")]
-    #[builder(default, setter(strip_option))]
-    pub echo: Option<bool>,
-    /// Up to 4 sequences where the API will stop generating further tokens. The returned text will not
-    /// contain the stop sequence.
-    #[serde(skip_serializing_if = "Option::is_none")]
-    // #[builder(default, setter(into, strip_option))]
-    #[builder(default, setter(strip_option))]
-    pub stop: Option<Vec<String>>,
-    /// Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency
-    /// in the text so far, decreasing the model's likelihood to repeat the same line verbatim.
-    #[serde(skip_serializing_if = "Option::is_none")]
-    #[validate(range(min = "MIN_FREQUENCY_PENALTY", max = "MAX_FREQUENCY_PENALTY"))]
-    #[builder(default, setter(into, strip_option))]
-    pub frequency_penalty: Option<f32>,
-    /// Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in
-    /// the text so far, increasing the model's likelihood to talk about new topics.
-    #[serde(skip_serializing_if = "Option::is_none")]
-    #[validate(range(min = "MIN_PRESENCE_PENALTY", max = "MAX_PRESENCE_PENALTY"))]
-    #[builder(default, setter(into, strip_option))]
-    pub presence_penalty: Option<f32>,
-    /// Modify the likelihood of specified tokens appearing in the completion.
-    ///
-    /// Accepts a JSON object that maps tokens (specified by their token ID in the GPT tokenizer) to an
-    /// associated bias value from -100 to 100. You can use this tokenizer tool to convert text to token IDs.
-    /// Mathematically, the bias is added to the logits generated by the model prior to sampling. The exact
-    /// effect will vary per model, but values between -1 and 1 should decrease or increase likelihood of
-    /// selection; values like -100 or 100 should result in a ban or exclusive selection of the relevant token.
-    ///
-    /// As specified in the OpenAI examples, this is a map of tokens_ids as strings to a bias value that
-    /// is an integer.
-    ///
-    /// However, the OpenAI blog using the SDK shows that it can also be specified more accurately as a
-    /// map of token_ids as ints to a bias value that is also an int.
-    ///
-    /// NIM Compatibility:
-    /// In the conversion of the OpenAI request to the internal NIM format, the keys of this map will be
-    /// validated to ensure they are integers. Since different models may have different tokenizers, the
-    /// range and values will again be validated on the compute backend to ensure they map to valid tokens
-    /// in the vocabulary of the model.
-    ///
-    /// ```rust
-    /// use triton_distributed_llm::protocols::openai::completions::CompletionRequest;
-    ///
-    /// let request = CompletionRequest::builder()
-    ///     .prompt("What is the meaning of life?")
-    ///     .model("gpt-3.5-turbo")
-    ///     .add_logit_bias(1337, -100) // using an int as a key is ok
-    ///     .add_logit_bias("42", 100) // using a string as a key is also ok
-    ///     .build()
-    ///     .expect("Should not fail");
-    ///
-    /// assert!(CompletionRequest::builder()
-    ///     .prompt("What is the meaning of life?")
-    ///     .model("gpt-3.5-turbo")
-    ///     .add_logit_bias("some non int", -100)
-    ///     .build()
-    ///     .is_err());
-    /// ```
-    #[serde(skip_serializing_if = "Option::is_none")]
-    #[validate(custom(function = "validate_logit_bias"))]
-    #[builder(default)]
-    pub logit_bias: Option<HashMap<String, i32>>,
-    /// A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse.
-    ///
-    /// NIM Compatibility:
-    /// If provided, then the value of this field will be included in the trace metadata and the accounting
-    /// data (if enabled).
-    #[serde(skip_serializing_if = "Option::is_none")]
-    #[builder(default, setter(into, strip_option))]
-    pub user: Option<String>,
-    /// OpenAI specific API parameter; this is not supported by NIM models; however,
-    /// is preserved as part of the API for compatibility.
-    ///
-    /// OpenAI API Reference:
-    /// <https://platform.openai.com/docs/api-reference/completions/create>
-    ///
-    /// A validation error will be thrown if this field is set when executing against
-    /// any NIM model.
-    #[serde(skip_serializing_if = "Option::is_none")]
-    #[builder(default, setter(into, strip_option))]
-    pub suffix: Option<String>,
-    /// NVIDIA extension to OpenAI's legacy v1::completion::CompletionRequest
-    #[serde(skip_serializing_if = "Option::is_none")]
-    #[builder(default, setter(strip_option))]
     pub nvext: Option<NvExt>,
 }
-impl CompletionRequest {
-    /// Create a new CompletionRequestBuilder
-    pub fn builder() -> CompletionRequestBuilder {
-        CompletionRequestBuilder::default()
-    }
-}
-impl CompletionRequestBuilder {
-    // This is a pre-build validate function
-    // This is called before the generated build method, in this case build_internal, is called
-    // This has access to the internal state of the builder
-    fn validate(&self) -> Result<(), String> {
-        Ok(())
-    }
-    /// Builds and validates the CompletionRequest
-    ///
-    /// ```rust
-    /// use triton_distributed_llm::protocols::openai::completions::CompletionRequest;
-    ///
-    /// let request = CompletionRequest::builder()
-    ///     .model("mixtral-8x7b-instruct-v0.1")
-    ///     .prompt("Hello")
-    ///     .max_tokens(16_u32)
-    ///     .build()
-    ///     .expect("Failed to build CompletionRequest");
-    /// ```
-    pub fn build(&self) -> anyhow::Result<CompletionRequest> {
-        // Calls the build_internal, validates the result, then performs addition
-        // post build validation. This is where we might handle any mutually exclusive fields
-        // and ensure there are no collisions.
-        let request = self
-            .build_internal()
-            .map_err(|e| anyhow::anyhow!("Failed to build CompletionRequest: {}", e))?;
-        request
-            .validate()
-            .map_err(|e| anyhow::anyhow!("Failed to validate CompletionRequest: {}", e))?;
-        Ok(request)
-    }
-    /// Add a stop condition to the `Vec<String>` in the ChatCompletionRequest
-    /// This will either create or append to the `Vec<String>`
-    pub fn add_stop(&mut self, stop: impl Into<String>) -> &mut Self {
-        if self.stop.is_none() {
-            self.stop = Some(Some(vec![]));
-        }
-        self.stop
-            .as_mut()
-            .unwrap()
-            .as_mut()
-            .unwrap()
-            .push(stop.into());
-        self
-    }
-    /// Add a tool to the `HashMap<String, i32>` in the ChatCompletionRequest
-    /// This will either create or update the `HashMap<String, i32>`
-    pub fn add_logit_bias<T>(&mut self, key: T, value: i32) -> &mut Self
-    where
-        T: std::fmt::Display,
-    {
-        if self.logit_bias.is_none() {
-            self.logit_bias = Some(Some(HashMap::new()));
-        }
-        self.logit_bias
-            .as_mut()
-            .unwrap()
-            .as_mut()
-            .unwrap()
-            .insert(key.to_string(), value);
-        self
-    }
-}
 /// Legacy OpenAI CompletionResponse
 /// Represents a completion response from the API.
 /// Note: both the streamed and non-streamed response objects share the same
@@ -377,6 +118,29 @@ pub struct LogprobResult {
     pub text_offset: Vec<i32>,
 }
+pub fn prompt_to_string(prompt: &async_openai::types::Prompt) -> String {
+    match prompt {
+        async_openai::types::Prompt::String(s) => s.clone(),
+        async_openai::types::Prompt::StringArray(arr) => arr.join(" "), // Join strings with spaces
+        async_openai::types::Prompt::IntegerArray(arr) => arr
+            .iter()
+            .map(|&num| num.to_string())
+            .collect::<Vec<_>>()
+            .join(" "),
+        async_openai::types::Prompt::ArrayOfIntegerArray(arr) => arr
+            .iter()
+            .map(|inner| {
+                inner
+                    .iter()
+                    .map(|&num| num.to_string())
+                    .collect::<Vec<_>>()
+                    .join(" ")
+            })
+            .collect::<Vec<_>>()
+            .join(" | "), // Separate arrays with a delimiter
+    }
+}
 impl NvExtProvider for CompletionRequest {
     fn nvext(&self) -> Option<&NvExt> {
         self.nvext.as_ref()
@@ -386,7 +150,7 @@ impl NvExtProvider for CompletionRequest {
         if let Some(nvext) = self.nvext.as_ref() {
             if let Some(use_raw_prompt) = nvext.use_raw_prompt {
                 if use_raw_prompt {
-                    return Some(self.prompt.clone());
+                    return Some(prompt_to_string(&self.inner.prompt));
                 }
             }
         }
@@ -412,19 +176,19 @@ impl AnnotationsProvider for CompletionRequest {
 impl OpenAISamplingOptionsProvider for CompletionRequest {
     fn get_temperature(&self) -> Option<f32> {
-        self.temperature
+        self.inner.temperature
     }
     fn get_top_p(&self) -> Option<f32> {
-        self.top_p
+        self.inner.top_p
     }
     fn get_frequency_penalty(&self) -> Option<f32> {
-        self.frequency_penalty
+        self.inner.frequency_penalty
     }
     fn get_presence_penalty(&self) -> Option<f32> {
-        self.presence_penalty
+        self.inner.presence_penalty
    }
    fn nvext(&self) -> Option<&NvExt> {
@@ -434,15 +198,15 @@ impl OpenAISamplingOptionsProvider for CompletionRequest {
 impl OpenAIStopConditionsProvider for CompletionRequest {
     fn get_max_tokens(&self) -> Option<u32> {
-        self.max_tokens
+        self.inner.max_tokens
     }
     fn get_min_tokens(&self) -> Option<u32> {
-        self.min_tokens
+        None
     }
     fn get_stop(&self) -> Option<Vec<String>> {
-        self.stop.clone()
+        None
    }
    fn nvext(&self) -> Option<&NvExt> {
@@ -516,7 +280,7 @@ impl TryFrom<CompletionRequest> for common::CompletionRequest {
         //
         // ** no supported
-        if request.suffix.is_some() {
+        if request.inner.suffix.is_some() {
             return Err(anyhow::anyhow!("suffix is not supported"));
         }
@@ -529,7 +293,7 @@ impl TryFrom<CompletionRequest> for common::CompletionRequest {
             .map_err(|e| anyhow::anyhow!("Failed to extract sampling options: {}", e))?;
         let prompt = common::PromptType::Completion(common::CompletionContext {
-            prompt: request.prompt,
+            prompt: prompt_to_string(&request.inner.prompt),
             system_prompt: None,
         });
...
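For reference, an illustrative check of the joining behavior of the new prompt_to_string helper added above. The crate path is taken from the test imports further below; this snippet itself is not part of the commit.

use async_openai::types::Prompt;
use triton_distributed_llm::protocols::openai::completions::prompt_to_string;

fn main() {
    // Each Prompt variant collapses to a single string for the engine.
    assert_eq!(prompt_to_string(&Prompt::String("hi".into())), "hi");
    assert_eq!(
        prompt_to_string(&Prompt::StringArray(vec!["hello".into(), "world".into()])),
        "hello world" // string arrays are joined with single spaces
    );
    assert_eq!(prompt_to_string(&Prompt::IntegerArray(vec![1, 2, 3])), "1 2 3");
    assert_eq!(
        prompt_to_string(&Prompt::ArrayOfIntegerArray(vec![vec![1, 2], vec![3]])),
        "1 2 | 3" // token arrays are separated by " | "
    );
}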
@@ -26,7 +26,7 @@ impl CompletionRequest {
             enable_logprobs: false,
         };
-        DeltaGenerator::new(self.model.clone(), options)
+        DeltaGenerator::new(self.inner.model.clone(), options)
     }
 }
...
@@ -387,7 +387,7 @@ async fn test_http_service() {
     // ==== ChatCompletions / Unary / Error ====
     // ==== Completions / Unary / Error ====
-    let mut request = CompletionRequest::builder()
+    let mut request = async_openai::types::CreateCompletionRequestArgs::default()
        .model("bar")
        .prompt("hi")
        .build()
...
...@@ -13,15 +13,9 @@ ...@@ -13,15 +13,9 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use async_openai::types::CreateCompletionRequestArgs;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use triton_distributed_llm::protocols::{ use triton_distributed_llm::protocols::openai::{self, completions::CompletionRequest};
common,
openai::{
self,
completions::{CompletionRequest, CompletionRequestBuilder},
nvext::NvExt,
},
};
#[derive(Serialize, Deserialize, Debug, Clone)] #[derive(Serialize, Deserialize, Debug, Clone)]
struct CompletionSample { struct CompletionSample {
...@@ -32,15 +26,20 @@ struct CompletionSample { ...@@ -32,15 +26,20 @@ struct CompletionSample {
impl CompletionSample { impl CompletionSample {
fn new<F>(description: impl Into<String>, configure: F) -> Result<Self, String> fn new<F>(description: impl Into<String>, configure: F) -> Result<Self, String>
where where
F: FnOnce(&mut CompletionRequestBuilder) -> &mut CompletionRequestBuilder, F: FnOnce(&mut CreateCompletionRequestArgs) -> &mut CreateCompletionRequestArgs,
{ {
let mut builder = CompletionRequestBuilder::default(); let mut builder = CreateCompletionRequestArgs::default();
builder builder
.model("gpt-3.5-turbo") .model("gpt-3.5-turbo")
.prompt("What is the meaning of life?"); .prompt("What is the meaning of life?");
configure(&mut builder); configure(&mut builder);
let inner = builder.build().unwrap();
let request = CompletionRequest { inner, nvext: None };
Ok(Self { Ok(Self {
request: builder.build().unwrap(), request,
description: description.into(), description: description.into(),
}) })
} }
...@@ -48,7 +47,7 @@ impl CompletionSample { ...@@ -48,7 +47,7 @@ impl CompletionSample {
#[test] #[test]
fn minimum_viable_request() { fn minimum_viable_request() {
let request = CompletionRequest::builder() let request = CreateCompletionRequestArgs::default()
.prompt("What is the meaning of life?") .prompt("What is the meaning of life?")
.model("gpt-3.5-turbo") .model("gpt-3.5-turbo")
.build() .build()
...@@ -57,57 +56,6 @@ fn minimum_viable_request() { ...@@ -57,57 +56,6 @@ fn minimum_viable_request() {
insta::assert_json_snapshot!(request); insta::assert_json_snapshot!(request);
} }
#[test]
fn missing_model() {
let request = CompletionRequest::builder()
.prompt("What is the meaning of life?")
.build();
assert!(request.is_err());
}
#[test]
fn missing_prompt() {
let request = CompletionRequest::builder().model("gpt-3.5-turbo").build();
assert!(request.is_err());
}
#[test]
fn out_of_range() {
let request = CompletionRequest::builder()
.prompt("What is the meaning of life?")
.model("gpt-3.5-turbo")
.temperature(openai::MAX_TEMPERATURE + 1.0)
.build();
assert!(request.is_err());
let request = CompletionRequest::builder()
.prompt("What is the meaning of life?")
.model("gpt-3.5-turbo")
.temperature(openai::MIN_TEMPERATURE - 1.0)
.build();
assert!(request.is_err());
}
#[test]
fn ignore_eos() {
let request = CompletionRequest::builder()
.prompt("What is the meaning of life?")
.model("gpt-3.5-turbo")
.nvext(
NvExt::builder()
.ignore_eos(true)
.build()
.expect("error building nvext"),
)
.build()
.expect("error building request");
let request = common::CompletionRequest::try_from(request).expect("error converting request");
let ignore_eos = request.stop_conditions.ignore_eos.unwrap();
assert!(ignore_eos);
}
#[test] #[test]
fn valid_samples() { fn valid_samples() {
let mut settings = insta::Settings::clone_current(); let mut settings = insta::Settings::clone_current();
...@@ -174,10 +122,5 @@ fn build_samples() -> Result<Vec<CompletionSample>, String> { ...@@ -174,10 +122,5 @@ fn build_samples() -> Result<Vec<CompletionSample>, String> {
|builder| builder.stream(true), |builder| builder.stream(true),
)?); )?);
samples.push(CompletionSample::new(
"should have prompt, model, and logit_bias fields with the logits_bias having two key/value pairs",
|builder| builder.add_logit_bias(1337, -100).add_logit_bias("42", 100),
)?);
Ok(samples) Ok(samples)
} }
 ---
-source: triton-llm/tests/openai_completions.rs
+source: tests/openai_completions.rs
 description: "should have prompt, model, and max_tokens fields"
 expression: sample.request
 ---
 {
+  "max_tokens": 10,
   "model": "gpt-3.5-turbo",
-  "prompt": "What is the meaning of life?",
-  "max_tokens": 10
+  "prompt": "What is the meaning of life?"
 }
 ---
-source: triton-llm/tests/openai_completions.rs
+source: tests/openai_completions.rs
 description: "should have prompt, model, and frequency_penalty fields"
 expression: sample.request
 ---
 {
+  "frequency_penalty": -2.0,
   "model": "gpt-3.5-turbo",
-  "prompt": "What is the meaning of life?",
-  "frequency_penalty": -2.0
+  "prompt": "What is the meaning of life?"
 }
 ---
-source: triton-llm/tests/openai_completions.rs
+source: tests/openai_completions.rs
 description: "should have prompt, model, and presence_penalty fields"
 expression: sample.request
 ---
 {
   "model": "gpt-3.5-turbo",
-  "prompt": "What is the meaning of life?",
-  "presence_penalty": -2.0
+  "presence_penalty": -2.0,
+  "prompt": "What is the meaning of life?"
 }
 ---
-source: triton-llm/tests/openai_completions.rs
+source: tests/openai_completions.rs
 description: "should have prompt, model, and echo fields"
 expression: sample.request
 ---
 {
+  "echo": true,
   "model": "gpt-3.5-turbo",
-  "prompt": "What is the meaning of life?",
-  "echo": true
+  "prompt": "What is the meaning of life?"
 }