Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
dynamo
Commits
ee86bad3
Unverified
Commit
ee86bad3
authored
Jul 01, 2025
by
Nathan Barry
Committed by
GitHub
Jul 01, 2025
Browse files
feat: Validation engine for validating OpenAI api request data (#1674)
parent
f0652d89
Changes
7
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
663 additions
and
63 deletions
+663
-63
lib/llm/src/engines.rs
lib/llm/src/engines.rs
+41
-1
lib/llm/src/protocols/openai.rs
lib/llm/src/protocols/openai.rs
+4
-52
lib/llm/src/protocols/openai/chat_completions.rs
lib/llm/src/protocols/openai/chat_completions.rs
+45
-4
lib/llm/src/protocols/openai/completions.rs
lib/llm/src/protocols/openai/completions.rs
+30
-1
lib/llm/src/protocols/openai/validate.rs
lib/llm/src/protocols/openai/validate.rs
+529
-0
lib/llm/tests/openai_completions.rs
lib/llm/tests/openai_completions.rs
+5
-5
lib/runtime/src/pipeline/context.rs
lib/runtime/src/pipeline/context.rs
+9
-0
No files found.
lib/llm/src/engines.rs
View file @
ee86bad3
...
@@ -124,8 +124,19 @@ fn delta_core(tok: u32) -> Annotated<LLMEngineOutput> {
...
@@ -124,8 +124,19 @@ fn delta_core(tok: u32) -> Annotated<LLMEngineOutput> {
/// Useful for testing ingress such as service-http.
/// Useful for testing ingress such as service-http.
struct
EchoEngineFull
{}
struct
EchoEngineFull
{}
/// Engine wrapper that holds an inner engine of type `E`.
///
/// The wrapper itself places no bounds on `E`; bounds live on the trait
/// implementations that need them.
pub struct ValidateEngine<E> {
    /// The wrapped engine.
    inner: E,
}

impl<E> ValidateEngine<E> {
    /// Wraps `inner` in a new `ValidateEngine`.
    pub fn new(inner: E) -> Self {
        ValidateEngine { inner }
    }
}
/// Engine that dispatches requests to either OpenAICompletions
/// Engine that dispatches requests to either OpenAICompletions
//or OpenAIChatCompletions engine
//
/
or OpenAIChatCompletions engine
pub
struct
EngineDispatcher
<
E
>
{
pub
struct
EngineDispatcher
<
E
>
{
inner
:
E
,
inner
:
E
,
}
}
...
@@ -136,6 +147,11 @@ impl<E> EngineDispatcher<E> {
...
@@ -136,6 +147,11 @@ impl<E> EngineDispatcher<E> {
}
}
}
}
/// Implemented by request types whose data can be validated.
pub trait ValidateRequest {
    /// Checks the request data.
    ///
    /// # Errors
    /// Returns an `anyhow::Error` when the implementation rejects the data.
    fn validate(&self) -> anyhow::Result<()>;
}
/// Trait that allows handling both completion and chat completions requests
/// Trait that allows handling both completion and chat completions requests
#[async_trait]
#[async_trait]
pub
trait
StreamingEngine
:
Send
+
Sync
{
pub
trait
StreamingEngine
:
Send
+
Sync
{
...
@@ -267,6 +283,30 @@ impl
...
@@ -267,6 +283,30 @@ impl
}
}
}
}
#[async_trait]
impl
<
E
,
Req
,
Resp
>
AsyncEngine
<
SingleIn
<
Req
>
,
ManyOut
<
Annotated
<
Resp
>>
,
Error
>
for
ValidateEngine
<
E
>
where
E
:
AsyncEngine
<
SingleIn
<
Req
>
,
ManyOut
<
Annotated
<
Resp
>>
,
Error
>
+
Send
+
Sync
,
Req
:
ValidateRequest
+
Send
+
Sync
+
'static
,
Resp
:
Send
+
Sync
+
'static
,
{
async
fn
generate
(
&
self
,
incoming_request
:
SingleIn
<
Req
>
,
)
->
Result
<
ManyOut
<
Annotated
<
Resp
>>
,
Error
>
{
let
(
request
,
context
)
=
incoming_request
.into_parts
();
// Validate the request first
if
let
Err
(
validation_error
)
=
request
.validate
()
{
return
Err
(
anyhow
::
anyhow!
(
"Validation failed: {}"
,
validation_error
));
}
// Forward to inner engine if validation passes
let
validated_request
=
SingleIn
::
rejoin
(
request
,
context
);
self
.inner
.generate
(
validated_request
)
.await
}
}
#[async_trait]
#[async_trait]
impl
<
E
>
StreamingEngine
for
EngineDispatcher
<
E
>
impl
<
E
>
StreamingEngine
for
EngineDispatcher
<
E
>
where
where
...
...
lib/llm/src/protocols/openai.rs
View file @
ee86bad3
...
@@ -13,8 +13,6 @@
...
@@ -13,8 +13,6 @@
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// limitations under the License.
// limitations under the License.
use
std
::
fmt
::
Display
;
use
anyhow
::
Result
;
use
anyhow
::
Result
;
use
serde
::{
Deserialize
,
Serialize
};
use
serde
::{
Deserialize
,
Serialize
};
...
@@ -29,42 +27,11 @@ pub mod embeddings;
...
@@ -29,42 +27,11 @@ pub mod embeddings;
pub
mod
models
;
pub
mod
models
;
pub
mod
nvext
;
pub
mod
nvext
;
pub
mod
responses
;
pub
mod
responses
;
pub
mod
validate
;
/// Minimum allowed value for OpenAI's `temperature` sampling option
use
validate
::{
pub
const
MIN_TEMPERATURE
:
f32
=
0.0
;
validate_range
,
FREQUENCY_PENALTY_RANGE
,
PRESENCE_PENALTY_RANGE
,
TEMPERATURE_RANGE
,
TOP_P_RANGE
,
};
/// Maximum allowed value for OpenAI's `temperature` sampling option
pub
const
MAX_TEMPERATURE
:
f32
=
2.0
;
/// Allowed range of values for OpenAI's `temperature`` sampling option
pub
const
TEMPERATURE_RANGE
:
(
f32
,
f32
)
=
(
MIN_TEMPERATURE
,
MAX_TEMPERATURE
);
/// Minimum allowed value for OpenAI's `top_p` sampling option
pub
const
MIN_TOP_P
:
f32
=
0.0
;
/// Maximum allowed value for OpenAI's `top_p` sampling option
pub
const
MAX_TOP_P
:
f32
=
1.0
;
/// Allowed range of values for OpenAI's `top_p` sampling option
pub
const
TOP_P_RANGE
:
(
f32
,
f32
)
=
(
MIN_TOP_P
,
MAX_TOP_P
);
/// Minimum allowed value for OpenAI's `frequency_penalty` sampling option
pub
const
MIN_FREQUENCY_PENALTY
:
f32
=
-
2.0
;
/// Maximum allowed value for OpenAI's `frequency_penalty` sampling option
pub
const
MAX_FREQUENCY_PENALTY
:
f32
=
2.0
;
/// Allowed range of values for OpenAI's `frequency_penalty` sampling option
pub
const
FREQUENCY_PENALTY_RANGE
:
(
f32
,
f32
)
=
(
MIN_FREQUENCY_PENALTY
,
MAX_FREQUENCY_PENALTY
);
/// Minimum allowed value for OpenAI's `presence_penalty` sampling option
pub
const
MIN_PRESENCE_PENALTY
:
f32
=
-
2.0
;
/// Maximum allowed value for OpenAI's `presence_penalty` sampling option
pub
const
MAX_PRESENCE_PENALTY
:
f32
=
2.0
;
/// Allowed range of values for OpenAI's `presence_penalty` sampling option
pub
const
PRESENCE_PENALTY_RANGE
:
(
f32
,
f32
)
=
(
MIN_PRESENCE_PENALTY
,
MAX_PRESENCE_PENALTY
);
#[derive(Serialize,
Deserialize,
Debug)]
#[derive(Serialize,
Deserialize,
Debug)]
pub
struct
AnnotatedDelta
<
R
>
{
pub
struct
AnnotatedDelta
<
R
>
{
...
@@ -166,21 +133,6 @@ impl<T: OpenAIStopConditionsProvider> StopConditionsProvider for T {
...
@@ -166,21 +133,6 @@ impl<T: OpenAIStopConditionsProvider> StopConditionsProvider for T {
}
}
}
}
// todo - move to common location
/// Checks that an optional value lies within the inclusive `range`.
///
/// Returns `Ok(None)` when `value` is `None` and `Ok(Some(value))` when the
/// value is not below `range.0` and not above `range.1`.
///
/// # Errors
/// Returns an error naming the value and the range when the value is below
/// `range.0` or above `range.1`.
fn validate_range<T>(value: Option<T>, range: &(T, T)) -> Result<Option<T>>
where
    T: PartialOrd + Display,
{
    // Match on the option directly instead of `is_none()` followed by `unwrap()`.
    match value {
        None => Ok(None),
        Some(value) => {
            if value < range.0 || value > range.1 {
                anyhow::bail!(
                    "Value {} is out of range [{}, {}]",
                    value,
                    range.0,
                    range.1
                );
            }
            Ok(Some(value))
        }
    }
}
pub
trait
DeltaGeneratorExt
<
ResponseType
:
Send
+
Sync
+
'static
+
std
::
fmt
::
Debug
>
:
pub
trait
DeltaGeneratorExt
<
ResponseType
:
Send
+
Sync
+
'static
+
std
::
fmt
::
Debug
>
:
Send
+
Sync
+
'static
Send
+
Sync
+
'static
{
{
...
...
lib/llm/src/protocols/openai/chat_completions.rs
View file @
ee86bad3
...
@@ -17,10 +17,12 @@ use dynamo_runtime::protocols::annotated::AnnotationsProvider;
...
@@ -17,10 +17,12 @@ use dynamo_runtime::protocols::annotated::AnnotationsProvider;
use
serde
::{
Deserialize
,
Serialize
};
use
serde
::{
Deserialize
,
Serialize
};
use
validator
::
Validate
;
use
validator
::
Validate
;
use
super
::
nvext
::
NvExt
;
use
crate
::
engines
::
ValidateRequest
;
use
super
::
nvext
::
NvExtProvider
;
use
super
::
OpenAISamplingOptionsProvider
;
use
super
::{
use
super
::
OpenAIStopConditionsProvider
;
nvext
::
NvExt
,
nvext
::
NvExtProvider
,
validate
,
OpenAISamplingOptionsProvider
,
OpenAIStopConditionsProvider
,
};
mod
aggregator
;
mod
aggregator
;
mod
delta
;
mod
delta
;
...
@@ -174,3 +176,42 @@ impl OpenAIStopConditionsProvider for NvCreateChatCompletionRequest {
...
@@ -174,3 +176,42 @@ impl OpenAIStopConditionsProvider for NvCreateChatCompletionRequest {
self
.nvext
.as_ref
()
self
.nvext
.as_ref
()
}
}
}
}
/// Implements `ValidateRequest` for `NvCreateChatCompletionRequest`,
/// allowing us to validate the data.
impl
ValidateRequest
for
NvCreateChatCompletionRequest
{
fn
validate
(
&
self
)
->
Result
<
(),
anyhow
::
Error
>
{
validate
::
validate_messages
(
&
self
.inner.messages
)
?
;
validate
::
validate_model
(
&
self
.inner.model
)
?
;
// none for store
validate
::
validate_reasoning_effort
(
&
self
.inner.reasoning_effort
)
?
;
validate
::
validate_metadata
(
&
self
.inner.metadata
)
?
;
validate
::
validate_frequency_penalty
(
self
.inner.frequency_penalty
)
?
;
validate
::
validate_logit_bias
(
&
self
.inner.logit_bias
)
?
;
// none for logprobs
validate
::
validate_top_logprobs
(
self
.inner.top_logprobs
)
?
;
// validate::validate_max_tokens(self.inner.max_tokens)?; // warning depricated field
validate
::
validate_max_completion_tokens
(
self
.inner.max_completion_tokens
)
?
;
validate
::
validate_n
(
self
.inner.n
)
?
;
// none for modalities
// none for prediction
// none for audio
validate
::
validate_presence_penalty
(
self
.inner.presence_penalty
)
?
;
// none for response_format
// none for seed
validate
::
validate_service_tier
(
&
self
.inner.service_tier
)
?
;
validate
::
validate_stop
(
&
self
.inner.stop
)
?
;
// none for stream
// none for stream_options
validate
::
validate_temperature
(
self
.inner.temperature
)
?
;
validate
::
validate_top_p
(
self
.inner.top_p
)
?
;
validate
::
validate_tools
(
&
self
.inner.tools
.as_deref
())
?
;
// none for tool_choice
// none for parallel_tool_calls
validate
::
validate_user
(
self
.inner.user
.as_deref
())
?
;
// none for function call
// none for functions
Ok
(())
}
}
lib/llm/src/protocols/openai/completions.rs
View file @
ee86bad3
...
@@ -18,10 +18,12 @@ use dynamo_runtime::protocols::annotated::AnnotationsProvider;
...
@@ -18,10 +18,12 @@ use dynamo_runtime::protocols::annotated::AnnotationsProvider;
use
serde
::{
Deserialize
,
Serialize
};
use
serde
::{
Deserialize
,
Serialize
};
use
validator
::
Validate
;
use
validator
::
Validate
;
use
crate
::
engines
::
ValidateRequest
;
use
super
::{
use
super
::{
common
::{
self
,
SamplingOptionsProvider
,
StopConditionsProvider
},
common
::{
self
,
SamplingOptionsProvider
,
StopConditionsProvider
},
nvext
::{
NvExt
,
NvExtProvider
},
nvext
::{
NvExt
,
NvExtProvider
},
ContentProvider
,
OpenAISamplingOptionsProvider
,
OpenAIStopConditionsProvider
,
validate
,
ContentProvider
,
OpenAISamplingOptionsProvider
,
OpenAIStopConditionsProvider
,
};
};
mod
aggregator
;
mod
aggregator
;
...
@@ -275,3 +277,30 @@ impl TryFrom<common::StreamingCompletionResponse> for async_openai::types::Choic
...
@@ -275,3 +277,30 @@ impl TryFrom<common::StreamingCompletionResponse> for async_openai::types::Choic
Ok
(
choice
)
Ok
(
choice
)
}
}
}
}
/// Implements `ValidateRequest` for `NvCreateCompletionRequest`,
/// allowing us to validate the data.
impl
ValidateRequest
for
NvCreateCompletionRequest
{
fn
validate
(
&
self
)
->
Result
<
(),
anyhow
::
Error
>
{
validate
::
validate_model
(
&
self
.inner.model
)
?
;
validate
::
validate_prompt
(
&
self
.inner.prompt
)
?
;
validate
::
validate_suffix
(
self
.inner.suffix
.as_deref
())
?
;
validate
::
validate_max_tokens
(
self
.inner.max_tokens
)
?
;
validate
::
validate_temperature
(
self
.inner.temperature
)
?
;
validate
::
validate_top_p
(
self
.inner.top_p
)
?
;
validate
::
validate_n
(
self
.inner.n
)
?
;
// none for stream
// none for stream_options
validate
::
validate_logprobs
(
self
.inner.logprobs
)
?
;
// none for echo
validate
::
validate_stop
(
&
self
.inner.stop
)
?
;
validate
::
validate_presence_penalty
(
self
.inner.presence_penalty
)
?
;
validate
::
validate_frequency_penalty
(
self
.inner.frequency_penalty
)
?
;
validate
::
validate_best_of
(
self
.inner.best_of
,
self
.inner.n
)
?
;
validate
::
validate_logit_bias
(
&
self
.inner.logit_bias
)
?
;
validate
::
validate_user
(
self
.inner.user
.as_deref
())
?
;
// none for seed
Ok
(())
}
}
lib/llm/src/protocols/openai/validate.rs
0 → 100644
View file @
ee86bad3
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use
std
::
fmt
::
Display
;
//
// Hyperparameter Constraints
//
/// Minimum allowed value for OpenAI's `temperature` sampling option
pub
const
MIN_TEMPERATURE
:
f32
=
0.0
;
/// Maximum allowed value for OpenAI's `temperature` sampling option
pub
const
MAX_TEMPERATURE
:
f32
=
2.0
;
/// Allowed range of values for OpenAI's `temperature` sampling option
pub
const
TEMPERATURE_RANGE
:
(
f32
,
f32
)
=
(
MIN_TEMPERATURE
,
MAX_TEMPERATURE
);
/// Minimum allowed value for OpenAI's `top_p` sampling option
pub
const
MIN_TOP_P
:
f32
=
0.0
;
/// Maximum allowed value for OpenAI's `top_p` sampling option
pub
const
MAX_TOP_P
:
f32
=
1.0
;
/// Allowed range of values for OpenAI's `top_p` sampling option
pub
const
TOP_P_RANGE
:
(
f32
,
f32
)
=
(
MIN_TOP_P
,
MAX_TOP_P
);
/// Minimum allowed value for OpenAI's `frequency_penalty` sampling option
pub
const
MIN_FREQUENCY_PENALTY
:
f32
=
-
2.0
;
/// Maximum allowed value for OpenAI's `frequency_penalty` sampling option
pub
const
MAX_FREQUENCY_PENALTY
:
f32
=
2.0
;
/// Allowed range of values for OpenAI's `frequency_penalty` sampling option
pub
const
FREQUENCY_PENALTY_RANGE
:
(
f32
,
f32
)
=
(
MIN_FREQUENCY_PENALTY
,
MAX_FREQUENCY_PENALTY
);
/// Minimum allowed value for OpenAI's `presence_penalty` sampling option
pub
const
MIN_PRESENCE_PENALTY
:
f32
=
-
2.0
;
/// Maximum allowed value for OpenAI's `presence_penalty` sampling option
pub
const
MAX_PRESENCE_PENALTY
:
f32
=
2.0
;
/// Allowed range of values for OpenAI's `presence_penalty` sampling option
pub
const
PRESENCE_PENALTY_RANGE
:
(
f32
,
f32
)
=
(
MIN_PRESENCE_PENALTY
,
MAX_PRESENCE_PENALTY
);
/// Minimum allowed value for `top_logprobs`
pub
const
MIN_TOP_LOGPROBS
:
u8
=
0
;
/// Maximum allowed value for `top_logprobs`
pub
const
MAX_TOP_LOGPROBS
:
u8
=
20
;
/// Minimum allowed value for `logprobs` in completion requests
pub
const
MIN_LOGPROBS
:
u8
=
0
;
/// Maximum allowed value for `logprobs` in completion requests
pub
const
MAX_LOGPROBS
:
u8
=
5
;
/// Minimum allowed value for `n` (number of choices)
pub
const
MIN_N
:
u8
=
1
;
/// Maximum allowed value for `n` (number of choices)
pub
const
MAX_N
:
u8
=
128
;
/// Minimum allowed value for OpenAI's `logit_bias` values
pub
const
MIN_LOGIT_BIAS
:
f32
=
-
100.0
;
/// Maximum allowed value for OpenAI's `logit_bias` values
pub
const
MAX_LOGIT_BIAS
:
f32
=
100.0
;
/// Minimum allowed value for `best_of`
pub
const
MIN_BEST_OF
:
u8
=
0
;
/// Maximum allowed value for `best_of`
pub
const
MAX_BEST_OF
:
u8
=
20
;
/// Maximum allowed number of stop sequences
pub
const
MAX_STOP_SEQUENCES
:
usize
=
4
;
/// Maximum allowed number of tools
pub
const
MAX_TOOLS
:
usize
=
128
;
/// Maximum allowed number of metadata key-value pairs
pub
const
MAX_METADATA_PAIRS
:
usize
=
16
;
/// Maximum allowed length for metadata keys
pub
const
MAX_METADATA_KEY_LENGTH
:
usize
=
64
;
/// Maximum allowed length for metadata values
pub
const
MAX_METADATA_VALUE_LENGTH
:
usize
=
512
;
/// Maximum allowed length for function names
pub
const
MAX_FUNCTION_NAME_LENGTH
:
usize
=
64
;
/// Maximum allowed value for Prompt IntegerArray elements
pub
const
MAX_PROMPT_TOKEN_ID
:
u32
=
50256
;
//
// Shared Fields
//
/// Validates the temperature parameter
pub
fn
validate_temperature
(
temperature
:
Option
<
f32
>
)
->
Result
<
(),
anyhow
::
Error
>
{
if
let
Some
(
temp
)
=
temperature
{
if
!
(
MIN_TEMPERATURE
..=
MAX_TEMPERATURE
)
.contains
(
&
temp
)
{
anyhow
::
bail!
(
"Temperature must be between {} and {}, got {}"
,
MIN_TEMPERATURE
,
MAX_TEMPERATURE
,
temp
);
}
}
Ok
(())
}
/// Validates the top_p parameter
pub
fn
validate_top_p
(
top_p
:
Option
<
f32
>
)
->
Result
<
(),
anyhow
::
Error
>
{
if
let
Some
(
p
)
=
top_p
{
if
!
(
MIN_TOP_P
..=
MAX_TOP_P
)
.contains
(
&
p
)
{
anyhow
::
bail!
(
"Top_p must be between {} and {}, got {}"
,
MIN_TOP_P
,
MAX_TOP_P
,
p
);
}
}
Ok
(())
}
/// Validates mutual exclusion of temperature and top_p
pub
fn
validate_temperature_top_p_exclusion
(
temperature
:
Option
<
f32
>
,
top_p
:
Option
<
f32
>
,
)
->
Result
<
(),
anyhow
::
Error
>
{
match
(
temperature
,
top_p
)
{
(
Some
(
t
),
Some
(
p
))
if
t
!=
1.0
&&
p
!=
1.0
=>
{
anyhow
::
bail!
(
"Only one of temperature or top_p should be set (not both)"
);
}
_
=>
Ok
(()),
}
}
/// Validates frequency penalty parameter
pub
fn
validate_frequency_penalty
(
frequency_penalty
:
Option
<
f32
>
)
->
Result
<
(),
anyhow
::
Error
>
{
if
let
Some
(
penalty
)
=
frequency_penalty
{
if
!
(
MIN_FREQUENCY_PENALTY
..=
MAX_FREQUENCY_PENALTY
)
.contains
(
&
penalty
)
{
anyhow
::
bail!
(
"Frequency penalty must be between {} and {}, got {}"
,
MIN_FREQUENCY_PENALTY
,
MAX_FREQUENCY_PENALTY
,
penalty
);
}
}
Ok
(())
}
/// Validates presence penalty parameter
pub
fn
validate_presence_penalty
(
presence_penalty
:
Option
<
f32
>
)
->
Result
<
(),
anyhow
::
Error
>
{
if
let
Some
(
penalty
)
=
presence_penalty
{
if
!
(
MIN_PRESENCE_PENALTY
..=
MAX_PRESENCE_PENALTY
)
.contains
(
&
penalty
)
{
anyhow
::
bail!
(
"Presence penalty must be between {} and {}, got {}"
,
MIN_PRESENCE_PENALTY
,
MAX_PRESENCE_PENALTY
,
penalty
);
}
}
Ok
(())
}
/// Validates logit bias map
pub
fn
validate_logit_bias
(
logit_bias
:
&
Option
<
std
::
collections
::
HashMap
<
String
,
serde_json
::
Value
>>
,
)
->
Result
<
(),
anyhow
::
Error
>
{
let
logit_bias
=
match
logit_bias
{
Some
(
val
)
=>
val
,
None
=>
return
Ok
(()),
};
for
(
token
,
bias_value
)
in
logit_bias
{
let
bias
=
bias_value
.as_f64
()
.ok_or_else
(||
{
anyhow
::
anyhow!
(
"Logit bias value for token '{}' must be a number, got {:?}"
,
token
,
bias_value
)
})
?
as
f32
;
if
!
(
MIN_LOGIT_BIAS
..=
MAX_LOGIT_BIAS
)
.contains
(
&
bias
)
{
anyhow
::
bail!
(
"Logit bias for token '{}' must be between {} and {}, got {}"
,
token
,
MIN_LOGIT_BIAS
,
MAX_LOGIT_BIAS
,
bias
);
}
}
Ok
(())
}
/// Validates n parameter (number of choices)
pub
fn
validate_n
(
n
:
Option
<
u8
>
)
->
Result
<
(),
anyhow
::
Error
>
{
if
let
Some
(
value
)
=
n
{
if
!
(
MIN_N
..=
MAX_N
)
.contains
(
&
value
)
{
anyhow
::
bail!
(
"n must be between {} and {}, got {}"
,
MIN_N
,
MAX_N
,
value
);
}
}
Ok
(())
}
/// Validates model parameter
pub
fn
validate_model
(
model
:
&
str
)
->
Result
<
(),
anyhow
::
Error
>
{
if
model
.trim
()
.is_empty
()
{
anyhow
::
bail!
(
"Model cannot be empty"
);
}
Ok
(())
}
/// Validates user parameter
pub
fn
validate_user
(
user
:
Option
<&
str
>
)
->
Result
<
(),
anyhow
::
Error
>
{
if
let
Some
(
user_id
)
=
user
{
if
user_id
.trim
()
.is_empty
()
{
anyhow
::
bail!
(
"User ID cannot be empty"
);
}
}
Ok
(())
}
/// Validates stop sequences
pub
fn
validate_stop
(
stop
:
&
Option
<
async_openai
::
types
::
Stop
>
)
->
Result
<
(),
anyhow
::
Error
>
{
if
let
Some
(
stop_value
)
=
stop
{
match
stop_value
{
async_openai
::
types
::
Stop
::
String
(
s
)
=>
{
if
s
.is_empty
()
{
anyhow
::
bail!
(
"Stop sequence cannot be empty"
);
}
}
async_openai
::
types
::
Stop
::
StringArray
(
sequences
)
=>
{
if
sequences
.is_empty
()
{
anyhow
::
bail!
(
"Stop sequences array cannot be empty"
);
}
if
sequences
.len
()
>
MAX_STOP_SEQUENCES
{
anyhow
::
bail!
(
"Maximum of {} stop sequences allowed, got {}"
,
MAX_STOP_SEQUENCES
,
sequences
.len
()
);
}
for
(
i
,
sequence
)
in
sequences
.iter
()
.enumerate
()
{
if
sequence
.is_empty
()
{
anyhow
::
bail!
(
"Stop sequence at index {} cannot be empty"
,
i
);
}
}
}
}
}
Ok
(())
}
//
// Chat Completion Specific
//
/// Validates messages array
pub
fn
validate_messages
(
messages
:
&
[
async_openai
::
types
::
ChatCompletionRequestMessage
],
)
->
Result
<
(),
anyhow
::
Error
>
{
if
messages
.is_empty
()
{
anyhow
::
bail!
(
"Messages array cannot be empty"
);
}
Ok
(())
}
/// Validates top_logprobs parameter
pub
fn
validate_top_logprobs
(
top_logprobs
:
Option
<
u8
>
)
->
Result
<
(),
anyhow
::
Error
>
{
if
let
Some
(
value
)
=
top_logprobs
{
if
!
(
0
..=
20
)
.contains
(
&
value
)
{
anyhow
::
bail!
(
"Top_logprobs must be between 0 and {}, got {}"
,
MAX_TOP_LOGPROBS
,
value
);
}
}
Ok
(())
}
/// Validates tools array
pub
fn
validate_tools
(
tools
:
&
Option
<&
[
async_openai
::
types
::
ChatCompletionTool
]
>
,
)
->
Result
<
(),
anyhow
::
Error
>
{
let
tools
=
match
tools
{
Some
(
val
)
=>
val
,
None
=>
return
Ok
(()),
};
if
tools
.len
()
>
MAX_TOOLS
{
anyhow
::
bail!
(
"Maximum of {} tools are supported, got {}"
,
MAX_TOOLS
,
tools
.len
()
);
}
for
(
i
,
tool
)
in
tools
.iter
()
.enumerate
()
{
if
tool
.function.name
.len
()
>
MAX_FUNCTION_NAME_LENGTH
{
anyhow
::
bail!
(
"Function name at index {} exceeds {} character limit, got {} characters"
,
i
,
MAX_FUNCTION_NAME_LENGTH
,
tool
.function.name
.len
()
);
}
if
tool
.function.name
.trim
()
.is_empty
()
{
anyhow
::
bail!
(
"Function name at index {} cannot be empty"
,
i
);
}
}
Ok
(())
}
/// Validates metadata
pub
fn
validate_metadata
(
metadata
:
&
Option
<
serde_json
::
Value
>
)
->
Result
<
(),
anyhow
::
Error
>
{
let
metadata
=
match
metadata
{
Some
(
val
)
=>
val
,
None
=>
return
Ok
(()),
};
if
let
Some
(
obj
)
=
metadata
.as_object
()
{
if
obj
.len
()
>
MAX_METADATA_PAIRS
{
anyhow
::
bail!
(
"Metadata cannot have more than {} key-value pairs, got {}"
,
MAX_METADATA_PAIRS
,
obj
.len
()
);
}
for
(
key
,
value
)
in
obj
{
if
key
.len
()
>
MAX_METADATA_KEY_LENGTH
{
anyhow
::
bail!
(
"Metadata key '{}' exceeds {} character limit"
,
key
,
MAX_METADATA_KEY_LENGTH
);
}
if
let
Some
(
value_str
)
=
value
.as_str
()
{
if
value_str
.len
()
>
MAX_METADATA_VALUE_LENGTH
{
anyhow
::
bail!
(
"Metadata value for key '{}' exceeds {} character limit"
,
key
,
MAX_METADATA_VALUE_LENGTH
);
}
}
}
}
Ok
(())
}
/// Validates reasoning effort parameter
pub
fn
validate_reasoning_effort
(
_
reasoning_effort
:
&
Option
<
async_openai
::
types
::
ReasoningEffort
>
,
)
->
Result
<
(),
anyhow
::
Error
>
{
// TODO ADD HERE
// ReasoningEffort is an enum, so if it exists, it's valid by definition
// This function is here for completeness and future validation needs
Ok
(())
}
/// Validates service tier parameter
pub
fn
validate_service_tier
(
_
service_tier
:
&
Option
<
async_openai
::
types
::
ServiceTier
>
,
)
->
Result
<
(),
anyhow
::
Error
>
{
// TODO ADD HERE
// ServiceTier is an enum, so if it exists, it's valid by definition
// This function is here for completeness and future validation needs
Ok
(())
}
//
// Completion Specific
//
/// Validates prompt
pub
fn
validate_prompt
(
prompt
:
&
async_openai
::
types
::
Prompt
)
->
Result
<
(),
anyhow
::
Error
>
{
match
prompt
{
async_openai
::
types
::
Prompt
::
String
(
s
)
=>
{
if
s
.is_empty
()
{
anyhow
::
bail!
(
"Prompt string cannot be empty"
);
}
}
async_openai
::
types
::
Prompt
::
StringArray
(
arr
)
=>
{
if
arr
.is_empty
()
{
anyhow
::
bail!
(
"Prompt string array cannot be empty"
);
}
for
(
i
,
s
)
in
arr
.iter
()
.enumerate
()
{
if
s
.is_empty
()
{
anyhow
::
bail!
(
"Prompt string at index {} cannot be empty"
,
i
);
}
}
}
async_openai
::
types
::
Prompt
::
IntegerArray
(
arr
)
=>
{
if
arr
.is_empty
()
{
anyhow
::
bail!
(
"Prompt integer array cannot be empty"
);
}
for
(
i
,
&
token_id
)
in
arr
.iter
()
.enumerate
()
{
if
token_id
>
MAX_PROMPT_TOKEN_ID
{
anyhow
::
bail!
(
"Token ID at index {} must be between 0 and {}, got {}"
,
i
,
MAX_PROMPT_TOKEN_ID
,
token_id
);
}
}
}
async_openai
::
types
::
Prompt
::
ArrayOfIntegerArray
(
arr
)
=>
{
if
arr
.is_empty
()
{
anyhow
::
bail!
(
"Prompt array of integer arrays cannot be empty"
);
}
for
(
i
,
inner_arr
)
in
arr
.iter
()
.enumerate
()
{
if
inner_arr
.is_empty
()
{
anyhow
::
bail!
(
"Prompt integer array at index {} cannot be empty"
,
i
);
}
for
(
j
,
&
token_id
)
in
inner_arr
.iter
()
.enumerate
()
{
if
token_id
>
MAX_PROMPT_TOKEN_ID
{
anyhow
::
bail!
(
"Token ID at index [{}][{}] must be between 0 and {}, got {}"
,
i
,
j
,
MAX_PROMPT_TOKEN_ID
,
token_id
);
}
}
}
}
}
Ok
(())
}
/// Validates logprobs parameter (for completion requests)
pub
fn
validate_logprobs
(
logprobs
:
Option
<
u8
>
)
->
Result
<
(),
anyhow
::
Error
>
{
if
let
Some
(
value
)
=
logprobs
{
if
!
(
MIN_LOGPROBS
..=
MAX_LOGPROBS
)
.contains
(
&
value
)
{
anyhow
::
bail!
(
"Logprobs must be between 0 and {}, got {}"
,
MAX_LOGPROBS
,
value
);
}
}
Ok
(())
}
/// Validates best_of parameter
pub
fn
validate_best_of
(
best_of
:
Option
<
u8
>
,
n
:
Option
<
u8
>
)
->
Result
<
(),
anyhow
::
Error
>
{
if
let
Some
(
best_of_value
)
=
best_of
{
if
!
(
MIN_BEST_OF
..=
MAX_BEST_OF
)
.contains
(
&
best_of_value
)
{
anyhow
::
bail!
(
"Best_of must be between 0 and {}, got {}"
,
MAX_BEST_OF
,
best_of_value
);
}
if
let
Some
(
n_value
)
=
n
{
if
best_of_value
<
n_value
{
anyhow
::
bail!
(
"Best_of must be greater than or equal to n, got best_of={} and n={}"
,
best_of_value
,
n_value
);
}
}
}
Ok
(())
}
/// Validates suffix parameter
pub
fn
validate_suffix
(
suffix
:
Option
<&
str
>
)
->
Result
<
(),
anyhow
::
Error
>
{
if
let
Some
(
suffix_str
)
=
suffix
{
// Suffix can be empty, but if it's very long it might cause issues
if
suffix_str
.len
()
>
10000
{
anyhow
::
bail!
(
"Suffix is too long, maximum 10000 characters"
);
}
}
Ok
(())
}
/// Validates max_tokens parameter
pub
fn
validate_max_tokens
(
max_tokens
:
Option
<
u32
>
)
->
Result
<
(),
anyhow
::
Error
>
{
if
let
Some
(
tokens
)
=
max_tokens
{
if
tokens
==
0
{
anyhow
::
bail!
(
"Max tokens must be greater than 0, got {}"
,
tokens
);
}
}
Ok
(())
}
/// Validates max_completion_tokens parameter
pub
fn
validate_max_completion_tokens
(
max_completion_tokens
:
Option
<
u32
>
,
)
->
Result
<
(),
anyhow
::
Error
>
{
if
let
Some
(
tokens
)
=
max_completion_tokens
{
if
tokens
==
0
{
anyhow
::
bail!
(
"Max completion tokens must be greater than 0, got {}"
,
tokens
);
}
}
Ok
(())
}
//
// Helpers
//
/// Checks that an optional value lies within the inclusive `range`.
///
/// Returns `Ok(None)` when `value` is `None` and `Ok(Some(value))` when the
/// value is not below `range.0` and not above `range.1`.
///
/// # Errors
/// Returns an error naming the value and the range when the value is below
/// `range.0` or above `range.1`.
pub fn validate_range<T>(value: Option<T>, range: &(T, T)) -> anyhow::Result<Option<T>>
where
    T: PartialOrd + Display,
{
    // Match on the option directly instead of `is_none()` followed by `unwrap()`.
    match value {
        None => Ok(None),
        Some(value) => {
            if value < range.0 || value > range.1 {
                anyhow::bail!(
                    "Value {} is out of range [{}, {}]",
                    value,
                    range.0,
                    range.1
                );
            }
            Ok(Some(value))
        }
    }
}
lib/llm/tests/openai_completions.rs
View file @
ee86bad3
...
@@ -14,7 +14,7 @@
...
@@ -14,7 +14,7 @@
// limitations under the License.
// limitations under the License.
use
async_openai
::
types
::
CreateCompletionRequestArgs
;
use
async_openai
::
types
::
CreateCompletionRequestArgs
;
use
dynamo_llm
::
protocols
::
openai
::{
self
,
completions
::
NvCreateCompletionRequest
};
use
dynamo_llm
::
protocols
::
openai
::{
completions
::
NvCreateCompletionRequest
,
validate
};
use
serde
::{
Deserialize
,
Serialize
};
use
serde
::{
Deserialize
,
Serialize
};
#[derive(Serialize,
Deserialize,
Debug,
Clone)]
#[derive(Serialize,
Deserialize,
Debug,
Clone)]
...
@@ -89,22 +89,22 @@ fn build_samples() -> Result<Vec<CompletionSample>, String> {
...
@@ -89,22 +89,22 @@ fn build_samples() -> Result<Vec<CompletionSample>, String> {
samples
.push
(
CompletionSample
::
new
(
samples
.push
(
CompletionSample
::
new
(
"should have prompt, model, and temperature fields"
,
"should have prompt, model, and temperature fields"
,
|
builder
|
builder
.temperature
(
openai
::
MIN_TEMPERATURE
),
|
builder
|
builder
.temperature
(
validate
::
MIN_TEMPERATURE
),
)
?
);
)
?
);
samples
.push
(
CompletionSample
::
new
(
samples
.push
(
CompletionSample
::
new
(
"should have prompt, model, and top_p fields"
,
"should have prompt, model, and top_p fields"
,
|
builder
|
builder
.top_p
(
openai
::
MIN_TOP_P
),
|
builder
|
builder
.top_p
(
validate
::
MIN_TOP_P
),
)
?
);
)
?
);
samples
.push
(
CompletionSample
::
new
(
samples
.push
(
CompletionSample
::
new
(
"should have prompt, model, and frequency_penalty fields"
,
"should have prompt, model, and frequency_penalty fields"
,
|
builder
|
builder
.frequency_penalty
(
openai
::
MIN_FREQUENCY_PENALTY
),
|
builder
|
builder
.frequency_penalty
(
validate
::
MIN_FREQUENCY_PENALTY
),
)
?
);
)
?
);
samples
.push
(
CompletionSample
::
new
(
samples
.push
(
CompletionSample
::
new
(
"should have prompt, model, and presence_penalty fields"
,
"should have prompt, model, and presence_penalty fields"
,
|
builder
|
builder
.presence_penalty
(
openai
::
MIN_PRESENCE_PENALTY
),
|
builder
|
builder
.presence_penalty
(
validate
::
MIN_PRESENCE_PENALTY
),
)
?
);
)
?
);
samples
.push
(
CompletionSample
::
new
(
samples
.push
(
CompletionSample
::
new
(
...
...
lib/runtime/src/pipeline/context.rs
View file @
ee86bad3
...
@@ -48,6 +48,15 @@ impl<T: Send + Sync + 'static> Context<T> {
...
@@ -48,6 +48,15 @@ impl<T: Send + Sync + 'static> Context<T> {
}
}
}
}
pub
fn
rejoin
<
U
:
Send
+
Sync
+
'static
>
(
current
:
T
,
context
:
Context
<
U
>
)
->
Self
{
Context
{
current
,
controller
:
context
.controller
,
registry
:
context
.registry
,
stages
:
context
.stages
,
}
}
pub
fn
with_controller
(
current
:
T
,
controller
:
Controller
)
->
Self
{
pub
fn
with_controller
(
current
:
T
,
controller
:
Controller
)
->
Self
{
Context
{
Context
{
current
,
current
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment