Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
dynamo
Commits
3de04dd9
Unverified
Commit
3de04dd9
authored
Sep 17, 2025
by
Greg Clark
Committed by
GitHub
Sep 17, 2025
Browse files
chore: fillout sampling params (seed, n, best_of, min_p) (#3055)
Signed-off-by:
Greg Clark
<
grclark@nvidia.com
>
parent
e2c0e8d1
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
108 additions
and
7 deletions
+108
-7
lib/llm/src/protocols/common.rs
lib/llm/src/protocols/common.rs
+2
-2
lib/llm/src/protocols/openai.rs
lib/llm/src/protocols/openai.rs
+20
-5
lib/llm/src/protocols/openai/chat_completions.rs
lib/llm/src/protocols/openai/chat_completions.rs
+22
-0
lib/llm/src/protocols/openai/common_ext.rs
lib/llm/src/protocols/openai/common_ext.rs
+8
-0
lib/llm/src/protocols/openai/completions.rs
lib/llm/src/protocols/openai/completions.rs
+20
-0
lib/llm/src/protocols/openai/nvext.rs
lib/llm/src/protocols/openai/nvext.rs
+6
-0
lib/llm/src/protocols/openai/responses.rs
lib/llm/src/protocols/openai/responses.rs
+12
-0
lib/llm/src/protocols/openai/validate.rs
lib/llm/src/protocols/openai/validate.rs
+18
-0
No files found.
lib/llm/src/protocols/common.rs
View file @
3de04dd9
...
@@ -274,14 +274,14 @@ pub const FREQUENCY_PENALTY_RANGE: (f32, f32) = (-1.0, 1.0);
...
@@ -274,14 +274,14 @@ pub const FREQUENCY_PENALTY_RANGE: (f32, f32) = (-1.0, 1.0);
#[derive(Serialize,
Deserialize,
Debug,
Clone,
Default)]
#[derive(Serialize,
Deserialize,
Debug,
Clone,
Default)]
pub
struct
SamplingOptions
{
pub
struct
SamplingOptions
{
/// Number of output sequences to return for the given prompt
/// Number of output sequences to return for the given prompt
pub
n
:
Option
<
i32
>
,
pub
n
:
Option
<
u8
>
,
/// Number of output sequences that are generated from the prompt.
/// Number of output sequences that are generated from the prompt.
/// From these `best_of` sequences, the top `n` sequences are returned.
/// From these `best_of` sequences, the top `n` sequences are returned.
/// `best_of` must be greater than or equal to `n`. This is treated as
/// `best_of` must be greater than or equal to `n`. This is treated as
/// the beam width when `use_beam_search` is True. By default, `best_of`
/// the beam width when `use_beam_search` is True. By default, `best_of`
/// is set to `n`.
/// is set to `n`.
pub
best_of
:
Option
<
i32
>
,
pub
best_of
:
Option
<
u8
>
,
/// Float that penalizes new tokens based on whether they
/// Float that penalizes new tokens based on whether they
/// appear in the generated text so far. Values > 0 encourage the model
/// appear in the generated text so far. Values > 0 encourage the model
...
...
lib/llm/src/protocols/openai.rs
View file @
3de04dd9
...
@@ -20,7 +20,8 @@ pub mod responses;
...
@@ -20,7 +20,8 @@ pub mod responses;
pub
mod
validate
;
pub
mod
validate
;
use
validate
::{
use
validate
::{
FREQUENCY_PENALTY_RANGE
,
PRESENCE_PENALTY_RANGE
,
TEMPERATURE_RANGE
,
TOP_P_RANGE
,
validate_range
,
BEST_OF_RANGE
,
FREQUENCY_PENALTY_RANGE
,
MIN_P_RANGE
,
N_RANGE
,
PRESENCE_PENALTY_RANGE
,
TEMPERATURE_RANGE
,
TOP_P_RANGE
,
validate_range
,
};
};
#[derive(Serialize,
Deserialize,
Debug)]
#[derive(Serialize,
Deserialize,
Debug)]
...
@@ -40,6 +41,12 @@ trait OpenAISamplingOptionsProvider {
...
@@ -40,6 +41,12 @@ trait OpenAISamplingOptionsProvider {
fn
get_presence_penalty
(
&
self
)
->
Option
<
f32
>
;
fn
get_presence_penalty
(
&
self
)
->
Option
<
f32
>
;
/// Returns the request's optional sampling seed; `None` when the request
/// does not provide one.
fn
get_seed
(
&
self
)
->
Option
<
i64
>
;
/// Returns the request's optional `n` value; `None` when the request does
/// not provide one.
fn
get_n
(
&
self
)
->
Option
<
u8
>
;
/// Returns the request's optional `best_of` value; `None` when the request
/// does not provide (or does not support) one.
fn
get_best_of
(
&
self
)
->
Option
<
u8
>
;
fn
nvext
(
&
self
)
->
Option
<&
nvext
::
NvExt
>
;
fn
nvext
(
&
self
)
->
Option
<&
nvext
::
NvExt
>
;
}
}
...
@@ -104,6 +111,14 @@ impl<T: OpenAISamplingOptionsProvider + CommonExtProvider> SamplingOptionsProvid
...
@@ -104,6 +111,14 @@ impl<T: OpenAISamplingOptionsProvider + CommonExtProvider> SamplingOptionsProvid
let
top_k
=
CommonExtProvider
::
get_top_k
(
self
);
let
top_k
=
CommonExtProvider
::
get_top_k
(
self
);
let
repetition_penalty
=
CommonExtProvider
::
get_repetition_penalty
(
self
);
let
repetition_penalty
=
CommonExtProvider
::
get_repetition_penalty
(
self
);
let
include_stop_str_in_output
=
CommonExtProvider
::
get_include_stop_str_in_output
(
self
);
let
include_stop_str_in_output
=
CommonExtProvider
::
get_include_stop_str_in_output
(
self
);
let
seed
=
self
.get_seed
();
let
n
=
validate_range
(
self
.get_n
(),
&
N_RANGE
)
.map_err
(|
e
|
anyhow
::
anyhow!
(
"Error validating n: {}"
,
e
))
?
;
let
best_of
=
validate_range
(
self
.get_best_of
(),
&
BEST_OF_RANGE
)
.map_err
(|
e
|
anyhow
::
anyhow!
(
"Error validating best_of: {}"
,
e
))
?
;
let
min_p
=
validate_range
(
CommonExtProvider
::
get_min_p
(
self
),
&
MIN_P_RANGE
)
.map_err
(|
e
|
anyhow
::
anyhow!
(
"Error validating min_p: {}"
,
e
))
?
;
if
let
Some
(
nvext
)
=
self
.nvext
()
{
if
let
Some
(
nvext
)
=
self
.nvext
()
{
let
greedy
=
nvext
.greed_sampling
.unwrap_or
(
false
);
let
greedy
=
nvext
.greed_sampling
.unwrap_or
(
false
);
...
@@ -135,16 +150,16 @@ impl<T: OpenAISamplingOptionsProvider + CommonExtProvider> SamplingOptionsProvid
...
@@ -135,16 +150,16 @@ impl<T: OpenAISamplingOptionsProvider + CommonExtProvider> SamplingOptionsProvid
};
};
Ok
(
common
::
SamplingOptions
{
Ok
(
common
::
SamplingOptions
{
n
:
None
,
n
,
best_of
:
None
,
best_of
,
frequency_penalty
,
frequency_penalty
,
presence_penalty
,
presence_penalty
,
repetition_penalty
,
repetition_penalty
,
temperature
,
temperature
,
top_p
,
top_p
,
top_k
,
top_k
,
min_p
:
None
,
min_p
,
seed
:
None
,
seed
,
use_beam_search
:
None
,
use_beam_search
:
None
,
length_penalty
:
None
,
length_penalty
:
None
,
guided_decoding
,
guided_decoding
,
...
...
lib/llm/src/protocols/openai/chat_completions.rs
View file @
3de04dd9
...
@@ -131,6 +131,20 @@ impl OpenAISamplingOptionsProvider for NvCreateChatCompletionRequest {
...
@@ -131,6 +131,20 @@ impl OpenAISamplingOptionsProvider for NvCreateChatCompletionRequest {
fn
nvext
(
&
self
)
->
Option
<&
NvExt
>
{
fn
nvext
(
&
self
)
->
Option
<&
NvExt
>
{
self
.nvext
.as_ref
()
self
.nvext
.as_ref
()
}
}
/// Returns the `seed` stored on the wrapped request (`self.inner`), or
/// `None` when the caller did not set one.
fn get_seed(&self) -> Option<i64> {
    self.inner.seed
}
/// Returns the `n` field stored on the wrapped request (`self.inner`), i.e.
/// how many completions were requested per prompt, or `None` when unset.
fn get_n(&self) -> Option<u8> {
    self.inner.n
}
/// Always returns `None`: `best_of` is not supported in chat completions.
fn get_best_of(&self) -> Option<u8> {
    None
}
}
}
/// Implements `CommonExtProvider` for `NvCreateChatCompletionRequest`,
/// Implements `CommonExtProvider` for `NvCreateChatCompletionRequest`,
...
@@ -199,6 +213,14 @@ impl CommonExtProvider for NvCreateChatCompletionRequest {
...
@@ -199,6 +213,14 @@ impl CommonExtProvider for NvCreateChatCompletionRequest {
)
)
}
}
/// Resolves `min_p` for this request by handing both places it can be set
/// (the flattened `common` fields and the `nvext` extension) to
/// `choose_with_deprecation`, which decides which value is returned.
fn get_min_p(&self) -> Option<f32> {
    let from_common = self.common.min_p.as_ref();
    let from_nvext = self.nvext.as_ref().and_then(|ext| ext.min_p.as_ref());
    choose_with_deprecation("min_p", from_common, from_nvext)
}
fn
get_repetition_penalty
(
&
self
)
->
Option
<
f32
>
{
fn
get_repetition_penalty
(
&
self
)
->
Option
<
f32
>
{
choose_with_deprecation
(
choose_with_deprecation
(
"repetition_penalty"
,
"repetition_penalty"
,
...
...
lib/llm/src/protocols/openai/common_ext.rs
View file @
3de04dd9
...
@@ -28,6 +28,12 @@ pub struct CommonExt {
...
@@ -28,6 +28,12 @@ pub struct CommonExt {
#[validate(custom(function
=
"validate_top_k"
))]
#[validate(custom(function
=
"validate_top_k"
))]
pub
top_k
:
Option
<
i32
>
,
pub
top_k
:
Option
<
i32
>
,
/// Relative probability floor
#[serde(default,
skip_serializing_if
=
"Option::is_none"
)]
#[builder(default,
setter(strip_option))]
#[validate(range(min
=
0.0
,
max
=
1.0
))]
pub
min_p
:
Option
<
f32
>
,
/// How much to penalize tokens based on how frequently they occur in the text.
/// How much to penalize tokens based on how frequently they occur in the text.
/// A value of 1 means no penalty, while values larger than 1 discourage and values smaller encourage.
/// A value of 1 means no penalty, while values larger than 1 discourage and values smaller encourage.
#[serde(default,
skip_serializing_if
=
"Option::is_none"
)]
#[serde(default,
skip_serializing_if
=
"Option::is_none"
)]
...
@@ -87,6 +93,7 @@ pub trait CommonExtProvider {
...
@@ -87,6 +93,7 @@ pub trait CommonExtProvider {
/// Other sampling Options
/// Other sampling Options
fn
get_top_k
(
&
self
)
->
Option
<
i32
>
;
fn
get_top_k
(
&
self
)
->
Option
<
i32
>
;
/// Returns the request's optional `min_p` value; `None` when it is not set.
fn
get_min_p
(
&
self
)
->
Option
<
f32
>
;
fn
get_repetition_penalty
(
&
self
)
->
Option
<
f32
>
;
fn
get_repetition_penalty
(
&
self
)
->
Option
<
f32
>
;
fn
get_include_stop_str_in_output
(
&
self
)
->
Option
<
bool
>
;
fn
get_include_stop_str_in_output
(
&
self
)
->
Option
<
bool
>
;
}
}
...
@@ -200,6 +207,7 @@ mod tests {
...
@@ -200,6 +207,7 @@ mod tests {
ignore_eos
:
None
,
ignore_eos
:
None
,
min_tokens
:
Some
(
0
),
// Should be valid (min = 0)
min_tokens
:
Some
(
0
),
// Should be valid (min = 0)
top_k
:
None
,
top_k
:
None
,
min_p
:
None
,
repetition_penalty
:
None
,
repetition_penalty
:
None
,
include_stop_str_in_output
:
None
,
include_stop_str_in_output
:
None
,
guided_json
:
None
,
guided_json
:
None
,
...
...
lib/llm/src/protocols/openai/completions.rs
View file @
3de04dd9
...
@@ -124,6 +124,18 @@ impl OpenAISamplingOptionsProvider for NvCreateCompletionRequest {
...
@@ -124,6 +124,18 @@ impl OpenAISamplingOptionsProvider for NvCreateCompletionRequest {
fn
nvext
(
&
self
)
->
Option
<&
NvExt
>
{
fn
nvext
(
&
self
)
->
Option
<&
NvExt
>
{
self
.nvext
.as_ref
()
self
.nvext
.as_ref
()
}
}
/// Returns the `seed` stored on the wrapped request (`self.inner`), or
/// `None` when the caller did not set one.
fn get_seed(&self) -> Option<i64> {
    self.inner.seed
}
/// Returns the `n` field stored on the wrapped request (`self.inner`), or
/// `None` when the caller did not set one.
fn get_n(&self) -> Option<u8> {
    self.inner.n
}
/// Returns the `best_of` field stored on the wrapped request
/// (`self.inner`), or `None` when the caller did not set one.
fn get_best_of(&self) -> Option<u8> {
    self.inner.best_of
}
}
}
impl
CommonExtProvider
for
NvCreateCompletionRequest
{
impl
CommonExtProvider
for
NvCreateCompletionRequest
{
...
@@ -189,6 +201,14 @@ impl CommonExtProvider for NvCreateCompletionRequest {
...
@@ -189,6 +201,14 @@ impl CommonExtProvider for NvCreateCompletionRequest {
)
)
}
}
/// Resolves `min_p` for this request by handing both places it can be set
/// (the flattened `common` fields and the `nvext` extension) to
/// `choose_with_deprecation`, which decides which value is returned.
fn get_min_p(&self) -> Option<f32> {
    let from_common = self.common.min_p.as_ref();
    let from_nvext = self.nvext.as_ref().and_then(|ext| ext.min_p.as_ref());
    choose_with_deprecation("min_p", from_common, from_nvext)
}
fn
get_repetition_penalty
(
&
self
)
->
Option
<
f32
>
{
fn
get_repetition_penalty
(
&
self
)
->
Option
<
f32
>
{
choose_with_deprecation
(
choose_with_deprecation
(
"repetition_penalty"
,
"repetition_penalty"
,
...
...
lib/llm/src/protocols/openai/nvext.rs
View file @
3de04dd9
...
@@ -24,6 +24,12 @@ pub struct NvExt {
...
@@ -24,6 +24,12 @@ pub struct NvExt {
#[serde(default,
skip_serializing_if
=
"Option::is_none"
)]
#[serde(default,
skip_serializing_if
=
"Option::is_none"
)]
pub
top_k
:
Option
<
i32
>
,
pub
top_k
:
Option
<
i32
>
,
/// Relative probability floor
#[serde(default,
skip_serializing_if
=
"Option::is_none"
)]
#[builder(default,
setter(strip_option))]
#[validate(range(min
=
0.0
,
max
=
1.0
))]
pub
min_p
:
Option
<
f32
>
,
/// How much to penalize tokens based on how frequently they occur in the text.
/// How much to penalize tokens based on how frequently they occur in the text.
/// A value of 1 means no penalty, while values larger than 1 discourage and values smaller encourage.
/// A value of 1 means no penalty, while values larger than 1 discourage and values smaller encourage.
#[builder(default,
setter(strip_option))]
#[builder(default,
setter(strip_option))]
...
...
lib/llm/src/protocols/openai/responses.rs
View file @
3de04dd9
...
@@ -100,6 +100,18 @@ impl OpenAISamplingOptionsProvider for NvCreateResponse {
...
@@ -100,6 +100,18 @@ impl OpenAISamplingOptionsProvider for NvCreateResponse {
fn
nvext
(
&
self
)
->
Option
<&
NvExt
>
{
fn
nvext
(
&
self
)
->
Option
<&
NvExt
>
{
self
.nvext
.as_ref
()
self
.nvext
.as_ref
()
}
}
/// Always returns `None` for now.
// TODO: placeholder until this request type supplies a seed.
fn get_seed(&self) -> Option<i64> {
    None
}
/// Always returns `None` for now.
// TODO: placeholder until this request type supplies `n`.
fn get_n(&self) -> Option<u8> {
    None
}
/// Always returns `None` for now.
// TODO: placeholder until this request type supplies `best_of`.
fn get_best_of(&self) -> Option<u8> {
    None
}
}
}
/// Implements `OpenAIStopConditionsProvider` for `NvCreateResponse`,
/// Implements `OpenAIStopConditionsProvider` for `NvCreateResponse`,
...
...
lib/llm/src/protocols/openai/validate.rs
View file @
3de04dd9
...
@@ -21,6 +21,13 @@ pub const MAX_TOP_P: f32 = 1.0;
...
@@ -21,6 +21,13 @@ pub const MAX_TOP_P: f32 = 1.0;
/// Allowed range of values for OpenAI's `top_p` sampling option
/// Allowed range of values for OpenAI's `top_p` sampling option
pub
const
TOP_P_RANGE
:
(
f32
,
f32
)
=
(
MIN_TOP_P
,
MAX_TOP_P
);
pub
const
TOP_P_RANGE
:
(
f32
,
f32
)
=
(
MIN_TOP_P
,
MAX_TOP_P
);
/// Lower bound accepted for the `min_p` sampling option.
pub const MIN_MIN_P: f32 = 0.0;

/// Upper bound accepted for the `min_p` sampling option.
pub const MAX_MIN_P: f32 = 1.0;

/// `(min, max)` pair of accepted `min_p` values, built from [`MIN_MIN_P`]
/// and [`MAX_MIN_P`].
pub const MIN_P_RANGE: (f32, f32) = (MIN_MIN_P, MAX_MIN_P);
/// Minimum allowed value for OpenAI's `frequency_penalty` sampling option
/// Minimum allowed value for OpenAI's `frequency_penalty` sampling option
pub
const
MIN_FREQUENCY_PENALTY
:
f32
=
-
2.0
;
pub
const
MIN_FREQUENCY_PENALTY
:
f32
=
-
2.0
;
/// Maximum allowed value for OpenAI's `frequency_penalty` sampling option
/// Maximum allowed value for OpenAI's `frequency_penalty` sampling option
...
@@ -35,6 +42,13 @@ pub const MAX_PRESENCE_PENALTY: f32 = 2.0;
...
@@ -35,6 +42,13 @@ pub const MAX_PRESENCE_PENALTY: f32 = 2.0;
/// Allowed range of values for OpenAI's `presence_penalty` sampling option
/// Allowed range of values for OpenAI's `presence_penalty` sampling option
pub
const
PRESENCE_PENALTY_RANGE
:
(
f32
,
f32
)
=
(
MIN_PRESENCE_PENALTY
,
MAX_PRESENCE_PENALTY
);
pub
const
PRESENCE_PENALTY_RANGE
:
(
f32
,
f32
)
=
(
MIN_PRESENCE_PENALTY
,
MAX_PRESENCE_PENALTY
);
/// Lower bound accepted for the `length_penalty` option.
pub const MIN_LENGTH_PENALTY: f32 = -2.0;

/// Upper bound accepted for the `length_penalty` option.
pub const MAX_LENGTH_PENALTY: f32 = 2.0;

/// `(min, max)` pair of accepted `length_penalty` values, built from
/// [`MIN_LENGTH_PENALTY`] and [`MAX_LENGTH_PENALTY`].
pub const LENGTH_PENALTY_RANGE: (f32, f32) = (MIN_LENGTH_PENALTY, MAX_LENGTH_PENALTY);
/// Minimum allowed value for `top_logprobs`
/// Minimum allowed value for `top_logprobs`
pub
const
MIN_TOP_LOGPROBS
:
u8
=
0
;
pub
const
MIN_TOP_LOGPROBS
:
u8
=
0
;
/// Maximum allowed value for `top_logprobs`
/// Maximum allowed value for `top_logprobs`
...
@@ -49,6 +63,8 @@ pub const MAX_LOGPROBS: u8 = 5;
...
@@ -49,6 +63,8 @@ pub const MAX_LOGPROBS: u8 = 5;
pub
const
MIN_N
:
u8
=
1
;
pub
const
MIN_N
:
u8
=
1
;
/// Maximum allowed value for `n` (number of choices)
/// Maximum allowed value for `n` (number of choices)
pub
const
MAX_N
:
u8
=
128
;
pub
const
MAX_N
:
u8
=
128
;
/// `(min, max)` pair of accepted values for `n` (number of choices), built
/// from `MIN_N` and `MAX_N`.
pub const N_RANGE: (u8, u8) = (MIN_N, MAX_N);
/// Minimum allowed value for OpenAI's `logit_bias` values
/// Minimum allowed value for OpenAI's `logit_bias` values
pub
const
MIN_LOGIT_BIAS
:
f32
=
-
100.0
;
pub
const
MIN_LOGIT_BIAS
:
f32
=
-
100.0
;
...
@@ -59,6 +75,8 @@ pub const MAX_LOGIT_BIAS: f32 = 100.0;
...
@@ -59,6 +75,8 @@ pub const MAX_LOGIT_BIAS: f32 = 100.0;
pub
const
MIN_BEST_OF
:
u8
=
0
;
pub
const
MIN_BEST_OF
:
u8
=
0
;
/// Maximum allowed value for `best_of`
/// Maximum allowed value for `best_of`
pub
const
MAX_BEST_OF
:
u8
=
20
;
pub
const
MAX_BEST_OF
:
u8
=
20
;
/// `(min, max)` pair of accepted values for `best_of`, built from
/// `MIN_BEST_OF` and `MAX_BEST_OF`.
pub const BEST_OF_RANGE: (u8, u8) = (MIN_BEST_OF, MAX_BEST_OF);
/// Maximum allowed number of stop sequences
/// Maximum allowed number of stop sequences
pub
const
MAX_STOP_SEQUENCES
:
usize
=
4
;
pub
const
MAX_STOP_SEQUENCES
:
usize
=
4
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment