OpenDAS / dynamo / Commits / 9162f3ad

Commit 9162f3ad authored Feb 28, 2025 by Paul Hendricks, committed by GitHub on Feb 28, 2025

refactor: use async-openai CompletionRequest (#310)

parent 057f8f47
Showing 11 changed files with 71 additions and 375 deletions (+71 -375)
lib/llm/src/http/service/openai.rs                                 +6   -4
lib/llm/src/preprocessor/prompt/template/oai.rs                    +1   -1
lib/llm/src/protocols/openai.rs                                    +0  -13
lib/llm/src/protocols/openai/completions.rs                       +39 -275
lib/llm/src/protocols/openai/completions/delta.rs                  +1   -1
lib/llm/tests/http-service.rs                                      +1   -1
lib/llm/tests/openai_completions.rs                               +11  -68
lib/llm/tests/snapshots/openai_completions__valid_samples-2.snap   +3   -3
lib/llm/tests/snapshots/openai_completions__valid_samples-5.snap   +3   -3
lib/llm/tests/snapshots/openai_completions__valid_samples-6.snap   +3   -3
lib/llm/tests/snapshots/openai_completions__valid_samples-8.snap   +3   -3
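In short, the hand-rolled legacy CompletionRequest is replaced by a thin wrapper around async_openai::types::CreateCompletionRequest. A minimal sketch of the new shape, assuming only the async-openai crate; the stand-in struct below mirrors the type in lib/llm/src/protocols/openai/completions.rs (the real nvext field is Option<NvExt> and the real type also derives Serialize/Deserialize/Validate):

use async_openai::types::{CreateCompletionRequest, CreateCompletionRequestArgs};

// Sketch only: stand-in mirroring the refactored wrapper from this commit.
struct CompletionRequest {
    inner: CreateCompletionRequest,
    nvext: Option<()>, // placeholder for the real Option<NvExt>
}

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Build the OpenAI portion with async-openai's builder, as the updated tests do...
    let inner = CreateCompletionRequestArgs::default()
        .model("gpt-3.5-turbo")
        .prompt("What is the meaning of life?")
        .build()?;

    // ...then wrap it, leaving the NVIDIA extension slot empty.
    let request = CompletionRequest { inner, nvext: None };
    assert_eq!(request.inner.model, "gpt-3.5-turbo");
    Ok(())
}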
lib/llm/src/http/service/openai.rs

@@ -140,17 +140,19 @@ async fn completions(
     let request_id = uuid::Uuid::new_v4().to_string();
 
     // todo - decide on default
-    let streaming = request.stream.unwrap_or(false);
+    let streaming = request.inner.stream.unwrap_or(false);
 
     // update the request to always stream
-    let request = CompletionRequest {
+    let inner = async_openai::types::CreateCompletionRequest {
         stream: Some(true),
-        ..request
+        ..request.inner
     };
 
+    let request = CompletionRequest { inner, nvext: None };
+
     // todo - make the protocols be optional for model name
     // todo - when optional, if none, apply a default
-    let model = &request.model;
+    let model = &request.inner.model;
 
     // todo - error handling should be more robust
     let engine = state
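For reference, a self-contained sketch of the struct-update trick used above, assuming only the async-openai crate (the model and prompt values are placeholders):

use async_openai::types::{CreateCompletionRequest, CreateCompletionRequestArgs};

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let request: CreateCompletionRequest = CreateCompletionRequestArgs::default()
        .model("gpt-3.5-turbo")
        .prompt("hi")
        .build()?;

    // Rebuild the request with `stream` forced on, carrying every other field
    // over via struct-update syntax; the handler then re-wraps the result as
    // `CompletionRequest { inner, nvext: None }`.
    let inner = CreateCompletionRequest {
        stream: Some(true),
        ..request
    };
    assert_eq!(inner.stream, Some(true));
    Ok(())
}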
lib/llm/src/preprocessor/prompt/template/oai.rs

@@ -60,7 +60,7 @@ impl OAIChatLikeRequest for CompletionRequest {
         let message = async_openai::types::ChatCompletionRequestMessage::User(
             async_openai::types::ChatCompletionRequestUserMessage {
                 content: async_openai::types::ChatCompletionRequestUserMessageContent::Text(
-                    self.prompt.clone(),
+                    crate::protocols::openai::completions::prompt_to_string(&self.inner.prompt),
                 ),
                 name: None,
             },
lib/llm/src/protocols/openai.rs

@@ -22,11 +22,9 @@ pub mod nvext;
 use anyhow::Result;
 use serde::{Deserialize, Serialize};
 use std::{
-    collections::HashMap,
     fmt::Display,
     ops::{Add, Div, Mul, Sub},
 };
-use validator::ValidationError;
 
 use super::{
     common::{self, SamplingOptionsProvider, StopConditionsProvider},

@@ -263,17 +261,6 @@ pub struct GenericCompletionResponse<C>
     // TODO() - add NvResponseExtention
 }
 
-fn validate_logit_bias(logit_bias: &HashMap<String, i32>) -> Result<(), ValidationError> {
-    for key in logit_bias.keys() {
-        if key.parse::<i32>().is_err() {
-            return Err(ValidationError::new("logit_bias")
-                .with_message("Keys must be integers".into()));
-        }
-    }
-    Ok(())
-}
-
 // todo - move to common location
 fn validate_range<T>(value: Option<T>, range: &(T, T)) -> Result<Option<T>>
 where
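For reference, the dropped logit_bias key check in standalone form (same logic as the removed function; the small main below is added here only to illustrate what no longer gets validated at this layer):

use std::collections::HashMap;
use validator::ValidationError;

// Same check the commit removes: every logit_bias key must parse as an integer token id.
fn validate_logit_bias(logit_bias: &HashMap<String, i32>) -> Result<(), ValidationError> {
    for key in logit_bias.keys() {
        if key.parse::<i32>().is_err() {
            return Err(ValidationError::new("logit_bias")
                .with_message("Keys must be integers".into()));
        }
    }
    Ok(())
}

fn main() {
    let ok = HashMap::from([("42".to_string(), 100)]);
    let bad = HashMap::from([("not an int".to_string(), -100)]);
    assert!(validate_logit_bias(&ok).is_ok());
    assert!(validate_logit_bias(&bad).is_err());
}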
lib/llm/src/protocols/openai/completions.rs
@@ -22,285 +22,26 @@ use validator::Validate;
 mod aggregator;
 mod delta;
 
-// pub use aggregator::DeltaAggregator;
 pub use aggregator::DeltaAggregator;
 pub use delta::DeltaGenerator;
 
 use super::{
     common::{self, SamplingOptionsProvider, StopConditionsProvider},
     nvext::{NvExt, NvExtProvider},
-    validate_logit_bias, CompletionUsage, ContentProvider, OpenAISamplingOptionsProvider,
-    OpenAIStopConditionsProvider, MAX_FREQUENCY_PENALTY, MAX_PRESENCE_PENALTY, MAX_TEMPERATURE,
-    MAX_TOP_P, MIN_FREQUENCY_PENALTY, MIN_PRESENCE_PENALTY, MIN_TEMPERATURE, MIN_TOP_P,
+    CompletionUsage, ContentProvider, OpenAISamplingOptionsProvider, OpenAIStopConditionsProvider,
 };
 use triton_distributed_runtime::protocols::annotated::AnnotationsProvider;
 
 /// Legacy OpenAI CompletionRequest
 ///
 /// Reference: <https://platform.openai.com/docs/api-reference/completions>
-#[derive(Serialize, Deserialize, Builder, Validate, Debug, Clone)]
-#[builder(build_fn(private, name = "build_internal", validate = "Self::validate"))]
+#[derive(Serialize, Deserialize, Validate, Debug, Clone)]
 pub struct CompletionRequest {
-    /// ID of the model to use.
-    #[builder(setter(into))]
-    pub model: String,
-
-    /// The prompt(s) to generate completions for, encoded as a string, array of
-    /// strings, array of tokens, or array of token arrays.
-    ///
-    /// NIM Compatibility:
-    /// The NIM LLM API only supports a single prompt as a string at this time.
-    #[builder(setter(into))]
-    pub prompt: String,
-
-    /// The maximum number of tokens that can be generated in the completion.
-    /// The token count of your prompt plus max_tokens cannot exceed the model's context length.
-    #[serde(skip_serializing_if = "Option::is_none")]
-    #[builder(default, setter(into, strip_option))]
-    pub max_tokens: Option<u32>,
-
-    /// The minimum number of tokens to generate. We ignore stop tokens until we see this many
-    /// tokens. Leave this None unless you are working on the pre-processor.
-    #[serde(skip_serializing_if = "Option::is_none")]
-    #[builder(default, setter(into, strip_option))]
-    pub min_tokens: Option<u32>,
-
-    /// If set, partial message deltas will be sent, like in ChatGPT. Tokens will be sent as data-only
-    /// server-sent events as they become available, with the stream terminated by a data: \[DONE\]
-    ///
-    /// If this is set to true, but the response cannot be streamed an error will be returned.
-    ///
-    /// NIM Compatibility:
-    /// The NIM SDK can send extra meta data in the SSE stream using the `:` comment, `event:`,
-    /// or `id:` fields. See the `enable_sse_metadata` field in the NvExt object.
-    #[serde(skip_serializing_if = "Option::is_none")]
-    #[builder(default, setter(strip_option))]
-    pub stream: Option<bool>,
-
-    /// How many completions to generate for each prompt.
-    ///
-    /// Note: Because this parameter generates many completions, it can quickly consume your token quota.
-    /// Use carefully and ensure that you have reasonable settings for `max_tokens` and `stop`.
-    ///
-    /// NIM Compatibility:
-    /// At this time, the NIM LLM API does not support `n` completions.
-    #[serde(skip_serializing_if = "Option::is_none")]
-    #[builder(default, setter(into, strip_option))]
-    pub n: Option<i32>,
-
-    /// Generates `best_of` completions server-side and returns the "best" (the one with the
-    /// highest log probability per token). Results cannot be streamed.
-    ///
-    /// When used with `n`, best_of controls the number of candidate completions and `n` specifies
-    /// how many to return – `best_of` must be greater than `n`.
-    ///
-    /// NIM Compatibility:
-    /// At this time, the NIM LLM API does not support `best_of` completions.
-    #[serde(skip_serializing_if = "Option::is_none")]
-    #[builder(default, setter(into, strip_option))]
-    pub best_of: Option<i32>,
-
-    /// What sampling `temperature` to use, between 0 and 2. Higher values like 0.8 will make the
-    /// output more random, while lower values like 0.2 will make it more focused and deterministic.
-    ///
-    /// We generally recommend altering this or `top_p` but not both.
-    #[serde(skip_serializing_if = "Option::is_none")]
-    #[validate(range(min = "MIN_TEMPERATURE", max = "MAX_TEMPERATURE"))]
-    #[builder(default, setter(into, strip_option))]
-    pub temperature: Option<f32>,
-
-    /// An alternative to sampling with `temperature`, called nucleus sampling, where the model
-    /// considers the results of the tokens with `top_p` probability mass. So 0.1 means only the tokens
-    /// comprising the top 10% probability mass are considered.
-    ///
-    /// We generally recommend altering this or `temperature` but not both.
-    #[serde(skip_serializing_if = "Option::is_none")]
-    #[validate(range(min = "MIN_TOP_P", max = "MAX_TOP_P"))]
-    #[builder(default, setter(into, strip_option))]
-    pub top_p: Option<f32>,
-
-    /// Include the log probabilities on the logprobs most likely output tokens, as well the chosen tokens.
-    /// For example, if logprobs is 5, the API will return a list of the 5 most likely tokens. The API will
-    /// always return the logprob of the sampled token, so there may be up to logprobs+1 elements in the
-    /// response.
-    #[serde(skip_serializing_if = "Option::is_none")]
-    #[builder(default, setter(into, strip_option))]
-    pub logprobs: Option<i32>,
-
-    /// Echo back the prompt in addition to the completion
-    ///
-    /// NIM Compatibility:
-    /// At this time, the NIM LLM API does not support `echo` completions.
-    #[serde(skip_serializing_if = "Option::is_none")]
-    #[builder(default, setter(strip_option))]
-    pub echo: Option<bool>,
-
-    /// Up to 4 sequences where the API will stop generating further tokens. The returned text will not
-    /// contain the stop sequence.
-    #[serde(skip_serializing_if = "Option::is_none")]
-    // #[builder(default, setter(into, strip_option))]
-    #[builder(default, setter(strip_option))]
-    pub stop: Option<Vec<String>>,
-
+    #[serde(flatten)]
+    pub inner: async_openai::types::CreateCompletionRequest,
+
-    /// Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency
-    /// in the text so far, decreasing the model's likelihood to repeat the same line verbatim.
-    #[serde(skip_serializing_if = "Option::is_none")]
-    #[validate(range(min = "MIN_FREQUENCY_PENALTY", max = "MAX_FREQUENCY_PENALTY"))]
-    #[builder(default, setter(into, strip_option))]
-    pub frequency_penalty: Option<f32>,
-
-    /// Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in
-    /// the text so far, increasing the model's likelihood to talk about new topics.
-    #[serde(skip_serializing_if = "Option::is_none")]
-    #[validate(range(min = "MIN_PRESENCE_PENALTY", max = "MAX_PRESENCE_PENALTY"))]
-    #[builder(default, setter(into, strip_option))]
-    pub presence_penalty: Option<f32>,
-
-    /// Modify the likelihood of specified tokens appearing in the completion.
-    ///
-    /// Accepts a JSON object that maps tokens (specified by their token ID in the GPT tokenizer) to an
-    /// associated bias value from -100 to 100. You can use this tokenizer tool to convert text to token IDs.
-    /// Mathematically, the bias is added to the logits generated by the model prior to sampling. The exact
-    /// effect will vary per model, but values between -1 and 1 should decrease or increase likelihood of
-    /// selection; values like -100 or 100 should result in a ban or exclusive selection of the relevant token.
-    ///
-    /// As specified in the OpenAI examples, this is a map of tokens_ids as strings to a bias value that
-    /// is an integer.
-    ///
-    /// However, the OpenAI blog using the SDK shows that it can also be specified more accurately as a
-    /// map of token_ids as ints to a bias value that is also an int.
-    ///
-    /// NIM Compatibility:
-    /// In the conversion of the OpenAI request to the internal NIM format, the keys of this map will be
-    /// validated to ensure they are integers. Since different models may have different tokenizers, the
-    /// range and values will again be validated on the compute backend to ensure they map to valid tokens
-    /// in the vocabulary of the model.
-    ///
-    /// ```rust
-    /// use triton_distributed_llm::protocols::openai::completions::CompletionRequest;
-    ///
-    /// let request = CompletionRequest::builder()
-    ///     .prompt("What is the meaning of life?")
-    ///     .model("gpt-3.5-turbo")
-    ///     .add_logit_bias(1337, -100) // using an int as a key is ok
-    ///     .add_logit_bias("42", 100) // using a string as a key is also ok
-    ///     .build()
-    ///     .expect("Should not fail");
-    ///
-    /// assert!(CompletionRequest::builder()
-    ///     .prompt("What is the meaning of life?")
-    ///     .model("gpt-3.5-turbo")
-    ///     .add_logit_bias("some non int", -100)
-    ///     .build()
-    ///     .is_err());
-    /// ```
-    #[serde(skip_serializing_if = "Option::is_none")]
-    #[validate(custom(function = "validate_logit_bias"))]
-    #[builder(default)]
-    pub logit_bias: Option<HashMap<String, i32>>,
-
-    /// A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse.
-    ///
-    /// NIM Compatibility:
-    /// If provided, then the value of this field will be included in the trace metadata and the accounting
-    /// data (if enabled).
-    #[serde(skip_serializing_if = "Option::is_none")]
-    #[builder(default, setter(into, strip_option))]
-    pub user: Option<String>,
-
-    /// OpenAI specific API parameter; this is not supported by NIM models; however,
-    /// is preserved as part of the API for compatibility.
-    ///
-    /// OpenAI API Reference:
-    /// <https://platform.openai.com/docs/api-reference/completions/create>
-    ///
-    /// A validation error will be thrown if this field is set when executing against
-    /// any NIM model.
-    #[serde(skip_serializing_if = "Option::is_none")]
-    #[builder(default, setter(into, strip_option))]
-    pub suffix: Option<String>,
-
     /// NVIDIA extension to OpenAI's legacy v1::completion::CompletionRequest
     #[serde(skip_serializing_if = "Option::is_none")]
-    #[builder(default, setter(strip_option))]
     pub nvext: Option<NvExt>,
 }
 
-impl CompletionRequest {
-    /// Create a new CompletionRequestBuilder
-    pub fn builder() -> CompletionRequestBuilder {
-        CompletionRequestBuilder::default()
-    }
-}
-
-impl CompletionRequestBuilder {
-    // This is a pre-build validate function
-    // This is called before the generated build method, in this case build_internal, is called
-    // This has access to the internal state of the builder
-    fn validate(&self) -> Result<(), String> {
-        Ok(())
-    }
-
-    /// Builds and validates the CompletionRequest
-    ///
-    /// ```rust
-    /// use triton_distributed_llm::protocols::openai::completions::CompletionRequest;
-    ///
-    /// let request = CompletionRequest::builder()
-    ///     .model("mixtral-8x7b-instruct-v0.1")
-    ///     .prompt("Hello")
-    ///     .max_tokens(16_u32)
-    ///     .build()
-    ///     .expect("Failed to build CompletionRequest");
-    /// ```
-    pub fn build(&self) -> anyhow::Result<CompletionRequest> {
-        // Calls the build_internal, validates the result, then performs addition
-        // post build validation. This is where we might handle any mutually exclusive fields
-        // and ensure there are no collisions.
-        let request = self
-            .build_internal()
-            .map_err(|e| anyhow::anyhow!("Failed to build CompletionRequest: {}", e))?;
-        request
-            .validate()
-            .map_err(|e| anyhow::anyhow!("Failed to validate CompletionRequest: {}", e))?;
-        Ok(request)
-    }
-
-    /// Add a stop condition to the `Vec<String>` in the ChatCompletionRequest
-    /// This will either create or append to the `Vec<String>`
-    pub fn add_stop(&mut self, stop: impl Into<String>) -> &mut Self {
-        if self.stop.is_none() {
-            self.stop = Some(Some(vec![]));
-        }
-        self.stop.as_mut().unwrap().as_mut().unwrap().push(stop.into());
-        self
-    }
-
-    /// Add a tool to the `HashMap<String, i32>` in the ChatCompletionRequest
-    /// This will either create or update the `HashMap<String, i32>`
-    pub fn add_logit_bias<T>(&mut self, key: T, value: i32) -> &mut Self
-    where
-        T: std::fmt::Display,
-    {
-        if self.logit_bias.is_none() {
-            self.logit_bias = Some(Some(HashMap::new()));
-        }
-        self.logit_bias
-            .as_mut()
-            .unwrap()
-            .as_mut()
-            .unwrap()
-            .insert(key.to_string(), value);
-        self
-    }
-}
-
 /// Legacy OpenAI CompletionResponse
 /// Represents a completion response from the API.
 /// Note: both the streamed and non-streamed response objects share the same
@@ -377,6 +118,29 @@ pub struct LogprobResult {
     pub text_offset: Vec<i32>,
 }
 
+pub fn prompt_to_string(prompt: &async_openai::types::Prompt) -> String {
+    match prompt {
+        async_openai::types::Prompt::String(s) => s.clone(),
+        async_openai::types::Prompt::StringArray(arr) => arr.join(" "), // Join strings with spaces
+        async_openai::types::Prompt::IntegerArray(arr) => arr
+            .iter()
+            .map(|&num| num.to_string())
+            .collect::<Vec<_>>()
+            .join(" "),
+        async_openai::types::Prompt::ArrayOfIntegerArray(arr) => arr
+            .iter()
+            .map(|inner| {
+                inner
+                    .iter()
+                    .map(|&num| num.to_string())
+                    .collect::<Vec<_>>()
+                    .join(" ")
+            })
+            .collect::<Vec<_>>()
+            .join(" | "), // Separate arrays with a delimiter
+    }
+}
+
 impl NvExtProvider for CompletionRequest {
     fn nvext(&self) -> Option<&NvExt> {
         self.nvext.as_ref()
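A small usage sketch of the new helper, assuming it is reachable at triton_distributed_llm::protocols::openai::completions::prompt_to_string (the path used elsewhere in this crate); only the string variants are exercised here:

use async_openai::types::Prompt;
use triton_distributed_llm::protocols::openai::completions::prompt_to_string;

fn main() {
    // Single string prompts pass through unchanged.
    let single = Prompt::String("What is the meaning of life?".to_string());
    assert_eq!(prompt_to_string(&single), "What is the meaning of life?");

    // String arrays are joined with single spaces, per the match arm above.
    let parts = Prompt::StringArray(vec!["What is".to_string(), "the meaning of life?".to_string()]);
    assert_eq!(prompt_to_string(&parts), "What is the meaning of life?");
}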
@@ -386,7 +150,7 @@ impl NvExtProvider for CompletionRequest {
         if let Some(nvext) = self.nvext.as_ref() {
             if let Some(use_raw_prompt) = nvext.use_raw_prompt {
                 if use_raw_prompt {
-                    return Some(self.prompt.clone());
+                    return Some(prompt_to_string(&self.inner.prompt));
                 }
             }
         }
@@ -412,19 +176,19 @@ impl AnnotationsProvider for CompletionRequest {
 impl OpenAISamplingOptionsProvider for CompletionRequest {
     fn get_temperature(&self) -> Option<f32> {
-        self.temperature
+        self.inner.temperature
     }
 
     fn get_top_p(&self) -> Option<f32> {
-        self.top_p
+        self.inner.top_p
     }
 
     fn get_frequency_penalty(&self) -> Option<f32> {
-        self.frequency_penalty
+        self.inner.frequency_penalty
     }
 
     fn get_presence_penalty(&self) -> Option<f32> {
-        self.presence_penalty
+        self.inner.presence_penalty
     }
 
     fn nvext(&self) -> Option<&NvExt> {
@@ -434,15 +198,15 @@ impl OpenAISamplingOptionsProvider for CompletionRequest {
 impl OpenAIStopConditionsProvider for CompletionRequest {
     fn get_max_tokens(&self) -> Option<u32> {
-        self.max_tokens
+        self.inner.max_tokens
     }
 
     fn get_min_tokens(&self) -> Option<u32> {
-        self.min_tokens
+        None
     }
 
     fn get_stop(&self) -> Option<Vec<String>> {
-        self.stop.clone()
+        None
     }
 
     fn nvext(&self) -> Option<&NvExt> {
@@ -516,7 +280,7 @@ impl TryFrom<CompletionRequest> for common::CompletionRequest {
         //
         // ** no supported
-        if request.suffix.is_some() {
+        if request.inner.suffix.is_some() {
             return Err(anyhow::anyhow!("suffix is not supported"));
         }
@@ -529,7 +293,7 @@ impl TryFrom<CompletionRequest> for common::CompletionRequest {
             .map_err(|e| anyhow::anyhow!("Failed to extract sampling options: {}", e))?;
 
         let prompt = common::PromptType::Completion(common::CompletionContext {
-            prompt: request.prompt,
+            prompt: prompt_to_string(&request.inner.prompt),
            system_prompt: None,
         });
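The #[serde(flatten)] attribute on inner is what keeps the wire format flat: the async-openai fields serialize at the top level of the JSON object instead of under an "inner" key. A minimal sketch with stand-in structs (serde and serde_json assumed available, as the insta snapshots below imply):

use serde::Serialize;

// Stand-ins only: `Inner` plays the role of async_openai's CreateCompletionRequest
// and `Wrapper` the role of the refactored CompletionRequest.
#[derive(Serialize)]
struct Inner {
    model: String,
    prompt: String,
}

#[derive(Serialize)]
struct Wrapper {
    #[serde(flatten)]
    inner: Inner,
    #[serde(skip_serializing_if = "Option::is_none")]
    nvext: Option<String>, // the real field is Option<NvExt>
}

fn main() {
    let request = Wrapper {
        inner: Inner {
            model: "gpt-3.5-turbo".to_string(),
            prompt: "What is the meaning of life?".to_string(),
        },
        nvext: None,
    };
    // Prints: {"model":"gpt-3.5-turbo","prompt":"What is the meaning of life?"}
    println!("{}", serde_json::to_string(&request).unwrap());
}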
lib/llm/src/protocols/openai/completions/delta.rs

@@ -26,7 +26,7 @@ impl CompletionRequest {
             enable_logprobs: false,
         };
 
-        DeltaGenerator::new(self.model.clone(), options)
+        DeltaGenerator::new(self.inner.model.clone(), options)
     }
 }
lib/llm/tests/http-service.rs

@@ -387,7 +387,7 @@ async fn test_http_service() {
     // ==== ChatCompletions / Unary / Error ====
 
     // ==== Completions / Unary / Error ====
-    let mut request = CompletionRequest::builder()
+    let mut request = async_openai::types::CreateCompletionRequestArgs::default()
         .model("bar")
         .prompt("hi")
         .build()
lib/llm/tests/openai_completions.rs
@@ -13,15 +13,9 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
+use async_openai::types::CreateCompletionRequestArgs;
 use serde::{Deserialize, Serialize};
-use triton_distributed_llm::protocols::{
-    common,
-    openai::{
-        self,
-        completions::{CompletionRequest, CompletionRequestBuilder},
-        nvext::NvExt,
-    },
-};
+use triton_distributed_llm::protocols::openai::{self, completions::CompletionRequest};
 
 #[derive(Serialize, Deserialize, Debug, Clone)]
 struct CompletionSample {
@@ -32,15 +26,20 @@ struct CompletionSample {
 impl CompletionSample {
     fn new<F>(description: impl Into<String>, configure: F) -> Result<Self, String>
     where
-        F: FnOnce(&mut CompletionRequestBuilder) -> &mut CompletionRequestBuilder,
+        F: FnOnce(&mut CreateCompletionRequestArgs) -> &mut CreateCompletionRequestArgs,
     {
-        let mut builder = CompletionRequestBuilder::default();
+        let mut builder = CreateCompletionRequestArgs::default();
         builder
             .model("gpt-3.5-turbo")
             .prompt("What is the meaning of life?");
         configure(&mut builder);
+        let inner = builder.build().unwrap();
+        let request = CompletionRequest { inner, nvext: None };
         Ok(Self {
-            request: builder.build().unwrap(),
+            request,
             description: description.into(),
         })
    }
@@ -48,7 +47,7 @@ impl CompletionSample {
 #[test]
 fn minimum_viable_request() {
-    let request = CompletionRequest::builder()
+    let request = CreateCompletionRequestArgs::default()
         .prompt("What is the meaning of life?")
         .model("gpt-3.5-turbo")
         .build()
@@ -57,57 +56,6 @@ fn minimum_viable_request() {
     insta::assert_json_snapshot!(request);
 }
 
-#[test]
-fn missing_model() {
-    let request = CompletionRequest::builder()
-        .prompt("What is the meaning of life?")
-        .build();
-    assert!(request.is_err());
-}
-
-#[test]
-fn missing_prompt() {
-    let request = CompletionRequest::builder()
-        .model("gpt-3.5-turbo")
-        .build();
-    assert!(request.is_err());
-}
-
-#[test]
-fn out_of_range() {
-    let request = CompletionRequest::builder()
-        .prompt("What is the meaning of life?")
-        .model("gpt-3.5-turbo")
-        .temperature(openai::MAX_TEMPERATURE + 1.0)
-        .build();
-    assert!(request.is_err());
-
-    let request = CompletionRequest::builder()
-        .prompt("What is the meaning of life?")
-        .model("gpt-3.5-turbo")
-        .temperature(openai::MIN_TEMPERATURE - 1.0)
-        .build();
-    assert!(request.is_err());
-}
-
-#[test]
-fn ignore_eos() {
-    let request = CompletionRequest::builder()
-        .prompt("What is the meaning of life?")
-        .model("gpt-3.5-turbo")
-        .nvext(
-            NvExt::builder()
-                .ignore_eos(true)
-                .build()
-                .expect("error building nvext"),
-        )
-        .build()
-        .expect("error building request");
-
-    let request = common::CompletionRequest::try_from(request).expect("error converting request");
-
-    let ignore_eos = request.stop_conditions.ignore_eos.unwrap();
-    assert!(ignore_eos);
-}
-
 #[test]
 fn valid_samples() {
     let mut settings = insta::Settings::clone_current();
@@ -174,10 +122,5 @@ fn build_samples() -> Result<Vec<CompletionSample>, String> {
         |builder| builder.stream(true),
     )?);
-    samples.push(CompletionSample::new(
-        "should have prompt, model, and logit_bias fields with the logits_bias having two key/value pairs",
-        |builder| builder.add_logit_bias(1337, -100).add_logit_bias("42", 100),
-    )?);
 
     Ok(samples)
 }
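For illustration, a sketch of how a sample is now registered through the updated helper; the closure mirrors the stream(true) sample visible in the hunk above, while the description string here is hypothetical:

// Assumes the CompletionSample type and its new() helper from this test file.
fn streaming_sample() -> Result<CompletionSample, String> {
    CompletionSample::new(
        "should have prompt, model, and stream fields",
        |builder| builder.stream(true),
    )
}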
lib/llm/tests/snapshots/openai_completions__valid_samples-2.snap

 ---
-source: triton-llm/tests/openai_completions.rs
+source: tests/openai_completions.rs
 description: "should have prompt, model, and max_tokens fields"
 expression: sample.request
 ---
 {
+  "max_tokens": 10,
   "model": "gpt-3.5-turbo",
-  "prompt": "What is the meaning of life?",
-  "max_tokens": 10
+  "prompt": "What is the meaning of life?"
 }
lib/llm/tests/snapshots/openai_completions__valid_samples-5.snap

 ---
-source: triton-llm/tests/openai_completions.rs
+source: tests/openai_completions.rs
 description: "should have prompt, model, and frequency_penalty fields"
 expression: sample.request
 ---
 {
+  "frequency_penalty": -2.0,
   "model": "gpt-3.5-turbo",
-  "prompt": "What is the meaning of life?",
-  "frequency_penalty": -2.0
+  "prompt": "What is the meaning of life?"
 }
lib/llm/tests/snapshots/openai_completions__valid_samples-6.snap

 ---
-source: triton-llm/tests/openai_completions.rs
+source: tests/openai_completions.rs
 description: "should have prompt, model, and presence_penalty fields"
 expression: sample.request
 ---
 {
   "model": "gpt-3.5-turbo",
-  "prompt": "What is the meaning of life?",
-  "presence_penalty": -2.0
+  "presence_penalty": -2.0,
+  "prompt": "What is the meaning of life?"
 }
lib/llm/tests/snapshots/openai_completions__valid_samples-8.snap

 ---
-source: triton-llm/tests/openai_completions.rs
+source: tests/openai_completions.rs
 description: "should have prompt, model, and echo fields"
 expression: sample.request
 ---
 {
+  "echo": true,
   "model": "gpt-3.5-turbo",
-  "prompt": "What is the meaning of life?",
-  "echo": true
+  "prompt": "What is the meaning of life?"
 }