Unverified Commit 3de04dd9 authored by Greg Clark's avatar Greg Clark Committed by GitHub
Browse files

chore: fillout sampling params (seed, n, best_of, min_p) (#3055)


Signed-off-by: default avatarGreg Clark <grclark@nvidia.com>
parent e2c0e8d1
...@@ -274,14 +274,14 @@ pub const FREQUENCY_PENALTY_RANGE: (f32, f32) = (-1.0, 1.0); ...@@ -274,14 +274,14 @@ pub const FREQUENCY_PENALTY_RANGE: (f32, f32) = (-1.0, 1.0);
#[derive(Serialize, Deserialize, Debug, Clone, Default)] #[derive(Serialize, Deserialize, Debug, Clone, Default)]
pub struct SamplingOptions { pub struct SamplingOptions {
/// Number of output sequences to return for the given prompt /// Number of output sequences to return for the given prompt
pub n: Option<i32>, pub n: Option<u8>,
/// Number of output sequences that are generated from the prompt. /// Number of output sequences that are generated from the prompt.
/// From these `best_of` sequences, the top `n` sequences are returned. /// From these `best_of` sequences, the top `n` sequences are returned.
/// `best_of` must be greater than or equal to `n`. This is treated as /// `best_of` must be greater than or equal to `n`. This is treated as
/// the beam width when `use_beam_search` is True. By default, `best_of` /// the beam width when `use_beam_search` is True. By default, `best_of`
/// is set to `n`. /// is set to `n`.
pub best_of: Option<i32>, pub best_of: Option<u8>,
/// Float that penalizes new tokens based on whether they /// Float that penalizes new tokens based on whether they
/// appear in the generated text so far. Values > 0 encourage the model /// appear in the generated text so far. Values > 0 encourage the model
......
...@@ -20,7 +20,8 @@ pub mod responses; ...@@ -20,7 +20,8 @@ pub mod responses;
pub mod validate; pub mod validate;
use validate::{ use validate::{
FREQUENCY_PENALTY_RANGE, PRESENCE_PENALTY_RANGE, TEMPERATURE_RANGE, TOP_P_RANGE, validate_range, BEST_OF_RANGE, FREQUENCY_PENALTY_RANGE, MIN_P_RANGE, N_RANGE, PRESENCE_PENALTY_RANGE,
TEMPERATURE_RANGE, TOP_P_RANGE, validate_range,
}; };
#[derive(Serialize, Deserialize, Debug)] #[derive(Serialize, Deserialize, Debug)]
...@@ -40,6 +41,12 @@ trait OpenAISamplingOptionsProvider { ...@@ -40,6 +41,12 @@ trait OpenAISamplingOptionsProvider {
fn get_presence_penalty(&self) -> Option<f32>; fn get_presence_penalty(&self) -> Option<f32>;
fn get_seed(&self) -> Option<i64>;
fn get_n(&self) -> Option<u8>;
fn get_best_of(&self) -> Option<u8>;
fn nvext(&self) -> Option<&nvext::NvExt>; fn nvext(&self) -> Option<&nvext::NvExt>;
} }
...@@ -104,6 +111,14 @@ impl<T: OpenAISamplingOptionsProvider + CommonExtProvider> SamplingOptionsProvid ...@@ -104,6 +111,14 @@ impl<T: OpenAISamplingOptionsProvider + CommonExtProvider> SamplingOptionsProvid
let top_k = CommonExtProvider::get_top_k(self); let top_k = CommonExtProvider::get_top_k(self);
let repetition_penalty = CommonExtProvider::get_repetition_penalty(self); let repetition_penalty = CommonExtProvider::get_repetition_penalty(self);
let include_stop_str_in_output = CommonExtProvider::get_include_stop_str_in_output(self); let include_stop_str_in_output = CommonExtProvider::get_include_stop_str_in_output(self);
let seed = self.get_seed();
let n = validate_range(self.get_n(), &N_RANGE)
.map_err(|e| anyhow::anyhow!("Error validating n: {}", e))?;
let best_of = validate_range(self.get_best_of(), &BEST_OF_RANGE)
.map_err(|e| anyhow::anyhow!("Error validating best_of: {}", e))?;
let min_p = validate_range(CommonExtProvider::get_min_p(self), &MIN_P_RANGE)
.map_err(|e| anyhow::anyhow!("Error validating min_p: {}", e))?;
if let Some(nvext) = self.nvext() { if let Some(nvext) = self.nvext() {
let greedy = nvext.greed_sampling.unwrap_or(false); let greedy = nvext.greed_sampling.unwrap_or(false);
...@@ -135,16 +150,16 @@ impl<T: OpenAISamplingOptionsProvider + CommonExtProvider> SamplingOptionsProvid ...@@ -135,16 +150,16 @@ impl<T: OpenAISamplingOptionsProvider + CommonExtProvider> SamplingOptionsProvid
}; };
Ok(common::SamplingOptions { Ok(common::SamplingOptions {
n: None, n,
best_of: None, best_of,
frequency_penalty, frequency_penalty,
presence_penalty, presence_penalty,
repetition_penalty, repetition_penalty,
temperature, temperature,
top_p, top_p,
top_k, top_k,
min_p: None, min_p,
seed: None, seed,
use_beam_search: None, use_beam_search: None,
length_penalty: None, length_penalty: None,
guided_decoding, guided_decoding,
......
...@@ -131,6 +131,20 @@ impl OpenAISamplingOptionsProvider for NvCreateChatCompletionRequest { ...@@ -131,6 +131,20 @@ impl OpenAISamplingOptionsProvider for NvCreateChatCompletionRequest {
fn nvext(&self) -> Option<&NvExt> { fn nvext(&self) -> Option<&NvExt> {
self.nvext.as_ref() self.nvext.as_ref()
} }
/// Retrieves the seed value for random number generation, if set.
fn get_seed(&self) -> Option<i64> {
self.inner.seed
}
/// Retrieves the number of completions to generate for each prompt, if set.
fn get_n(&self) -> Option<u8> {
self.inner.n
}
/// Retrieves the best_of parameter, if set.
fn get_best_of(&self) -> Option<u8> {
None // Not supported in chat completions
}
} }
/// Implements `CommonExtProvider` for `NvCreateChatCompletionRequest`, /// Implements `CommonExtProvider` for `NvCreateChatCompletionRequest`,
...@@ -199,6 +213,14 @@ impl CommonExtProvider for NvCreateChatCompletionRequest { ...@@ -199,6 +213,14 @@ impl CommonExtProvider for NvCreateChatCompletionRequest {
) )
} }
fn get_min_p(&self) -> Option<f32> {
choose_with_deprecation(
"min_p",
self.common.min_p.as_ref(),
self.nvext.as_ref().and_then(|nv| nv.min_p.as_ref()),
)
}
fn get_repetition_penalty(&self) -> Option<f32> { fn get_repetition_penalty(&self) -> Option<f32> {
choose_with_deprecation( choose_with_deprecation(
"repetition_penalty", "repetition_penalty",
......
...@@ -28,6 +28,12 @@ pub struct CommonExt { ...@@ -28,6 +28,12 @@ pub struct CommonExt {
#[validate(custom(function = "validate_top_k"))] #[validate(custom(function = "validate_top_k"))]
pub top_k: Option<i32>, pub top_k: Option<i32>,
/// Relative probability floor
#[serde(default, skip_serializing_if = "Option::is_none")]
#[builder(default, setter(strip_option))]
#[validate(range(min = 0.0, max = 1.0))]
pub min_p: Option<f32>,
/// How much to penalize tokens based on how frequently they occur in the text. /// How much to penalize tokens based on how frequently they occur in the text.
/// A value of 1 means no penalty, while values larger than 1 discourage and values smaller encourage. /// A value of 1 means no penalty, while values larger than 1 discourage and values smaller encourage.
#[serde(default, skip_serializing_if = "Option::is_none")] #[serde(default, skip_serializing_if = "Option::is_none")]
...@@ -87,6 +93,7 @@ pub trait CommonExtProvider { ...@@ -87,6 +93,7 @@ pub trait CommonExtProvider {
/// Other sampling Options /// Other sampling Options
fn get_top_k(&self) -> Option<i32>; fn get_top_k(&self) -> Option<i32>;
fn get_min_p(&self) -> Option<f32>;
fn get_repetition_penalty(&self) -> Option<f32>; fn get_repetition_penalty(&self) -> Option<f32>;
fn get_include_stop_str_in_output(&self) -> Option<bool>; fn get_include_stop_str_in_output(&self) -> Option<bool>;
} }
...@@ -200,6 +207,7 @@ mod tests { ...@@ -200,6 +207,7 @@ mod tests {
ignore_eos: None, ignore_eos: None,
min_tokens: Some(0), // Should be valid (min = 0) min_tokens: Some(0), // Should be valid (min = 0)
top_k: None, top_k: None,
min_p: None,
repetition_penalty: None, repetition_penalty: None,
include_stop_str_in_output: None, include_stop_str_in_output: None,
guided_json: None, guided_json: None,
......
...@@ -124,6 +124,18 @@ impl OpenAISamplingOptionsProvider for NvCreateCompletionRequest { ...@@ -124,6 +124,18 @@ impl OpenAISamplingOptionsProvider for NvCreateCompletionRequest {
fn nvext(&self) -> Option<&NvExt> { fn nvext(&self) -> Option<&NvExt> {
self.nvext.as_ref() self.nvext.as_ref()
} }
fn get_seed(&self) -> Option<i64> {
self.inner.seed
}
fn get_n(&self) -> Option<u8> {
self.inner.n
}
fn get_best_of(&self) -> Option<u8> {
self.inner.best_of
}
} }
impl CommonExtProvider for NvCreateCompletionRequest { impl CommonExtProvider for NvCreateCompletionRequest {
...@@ -189,6 +201,14 @@ impl CommonExtProvider for NvCreateCompletionRequest { ...@@ -189,6 +201,14 @@ impl CommonExtProvider for NvCreateCompletionRequest {
) )
} }
fn get_min_p(&self) -> Option<f32> {
choose_with_deprecation(
"min_p",
self.common.min_p.as_ref(),
self.nvext.as_ref().and_then(|nv| nv.min_p.as_ref()),
)
}
fn get_repetition_penalty(&self) -> Option<f32> { fn get_repetition_penalty(&self) -> Option<f32> {
choose_with_deprecation( choose_with_deprecation(
"repetition_penalty", "repetition_penalty",
......
...@@ -24,6 +24,12 @@ pub struct NvExt { ...@@ -24,6 +24,12 @@ pub struct NvExt {
#[serde(default, skip_serializing_if = "Option::is_none")] #[serde(default, skip_serializing_if = "Option::is_none")]
pub top_k: Option<i32>, pub top_k: Option<i32>,
/// Relative probability floor
#[serde(default, skip_serializing_if = "Option::is_none")]
#[builder(default, setter(strip_option))]
#[validate(range(min = 0.0, max = 1.0))]
pub min_p: Option<f32>,
/// How much to penalize tokens based on how frequently they occur in the text. /// How much to penalize tokens based on how frequently they occur in the text.
/// A value of 1 means no penalty, while values larger than 1 discourage and values smaller encourage. /// A value of 1 means no penalty, while values larger than 1 discourage and values smaller encourage.
#[builder(default, setter(strip_option))] #[builder(default, setter(strip_option))]
......
...@@ -100,6 +100,18 @@ impl OpenAISamplingOptionsProvider for NvCreateResponse { ...@@ -100,6 +100,18 @@ impl OpenAISamplingOptionsProvider for NvCreateResponse {
fn nvext(&self) -> Option<&NvExt> { fn nvext(&self) -> Option<&NvExt> {
self.nvext.as_ref() self.nvext.as_ref()
} }
fn get_seed(&self) -> Option<i64> {
None // TODO setting as None for now
}
fn get_n(&self) -> Option<u8> {
None // TODO setting as None for now
}
fn get_best_of(&self) -> Option<u8> {
None // TODO setting as None for now
}
} }
/// Implements `OpenAIStopConditionsProvider` for `NvCreateResponse`, /// Implements `OpenAIStopConditionsProvider` for `NvCreateResponse`,
......
...@@ -21,6 +21,13 @@ pub const MAX_TOP_P: f32 = 1.0; ...@@ -21,6 +21,13 @@ pub const MAX_TOP_P: f32 = 1.0;
/// Allowed range of values for OpenAI's `top_p` sampling option /// Allowed range of values for OpenAI's `top_p` sampling option
pub const TOP_P_RANGE: (f32, f32) = (MIN_TOP_P, MAX_TOP_P); pub const TOP_P_RANGE: (f32, f32) = (MIN_TOP_P, MAX_TOP_P);
/// Minimum allowed value for `min_p`
pub const MIN_MIN_P: f32 = 0.0;
/// Maximum allowed value for `min_p`
pub const MAX_MIN_P: f32 = 1.0;
/// Allowed range of values for `min_p`
pub const MIN_P_RANGE: (f32, f32) = (MIN_MIN_P, MAX_MIN_P);
/// Minimum allowed value for OpenAI's `frequency_penalty` sampling option /// Minimum allowed value for OpenAI's `frequency_penalty` sampling option
pub const MIN_FREQUENCY_PENALTY: f32 = -2.0; pub const MIN_FREQUENCY_PENALTY: f32 = -2.0;
/// Maximum allowed value for OpenAI's `frequency_penalty` sampling option /// Maximum allowed value for OpenAI's `frequency_penalty` sampling option
...@@ -35,6 +42,13 @@ pub const MAX_PRESENCE_PENALTY: f32 = 2.0; ...@@ -35,6 +42,13 @@ pub const MAX_PRESENCE_PENALTY: f32 = 2.0;
/// Allowed range of values for OpenAI's `presence_penalty` sampling option /// Allowed range of values for OpenAI's `presence_penalty` sampling option
pub const PRESENCE_PENALTY_RANGE: (f32, f32) = (MIN_PRESENCE_PENALTY, MAX_PRESENCE_PENALTY); pub const PRESENCE_PENALTY_RANGE: (f32, f32) = (MIN_PRESENCE_PENALTY, MAX_PRESENCE_PENALTY);
/// Minimum allowed value for `length_penalty`
pub const MIN_LENGTH_PENALTY: f32 = -2.0;
/// Maximum allowed value for `length_penalty`
pub const MAX_LENGTH_PENALTY: f32 = 2.0;
/// Allowed range of values for `length_penalty`
pub const LENGTH_PENALTY_RANGE: (f32, f32) = (MIN_LENGTH_PENALTY, MAX_LENGTH_PENALTY);
/// Maximum allowed value for `top_logprobs` /// Maximum allowed value for `top_logprobs`
pub const MIN_TOP_LOGPROBS: u8 = 0; pub const MIN_TOP_LOGPROBS: u8 = 0;
/// Maximum allowed value for `top_logprobs` /// Maximum allowed value for `top_logprobs`
...@@ -49,6 +63,8 @@ pub const MAX_LOGPROBS: u8 = 5; ...@@ -49,6 +63,8 @@ pub const MAX_LOGPROBS: u8 = 5;
pub const MIN_N: u8 = 1; pub const MIN_N: u8 = 1;
/// Maximum allowed value for `n` (number of choices) /// Maximum allowed value for `n` (number of choices)
pub const MAX_N: u8 = 128; pub const MAX_N: u8 = 128;
/// Allowed range of values for `n` (number of choices)
pub const N_RANGE: (u8, u8) = (MIN_N, MAX_N);
/// Minimum allowed value for OpenAI's `logit_bias` values /// Minimum allowed value for OpenAI's `logit_bias` values
pub const MIN_LOGIT_BIAS: f32 = -100.0; pub const MIN_LOGIT_BIAS: f32 = -100.0;
...@@ -59,6 +75,8 @@ pub const MAX_LOGIT_BIAS: f32 = 100.0; ...@@ -59,6 +75,8 @@ pub const MAX_LOGIT_BIAS: f32 = 100.0;
pub const MIN_BEST_OF: u8 = 0; pub const MIN_BEST_OF: u8 = 0;
/// Maximum allowed value for `best_of` /// Maximum allowed value for `best_of`
pub const MAX_BEST_OF: u8 = 20; pub const MAX_BEST_OF: u8 = 20;
/// Allowed range of values for `best_of`
pub const BEST_OF_RANGE: (u8, u8) = (MIN_BEST_OF, MAX_BEST_OF);
/// Maximum allowed number of stop sequences /// Maximum allowed number of stop sequences
pub const MAX_STOP_SEQUENCES: usize = 4; pub const MAX_STOP_SEQUENCES: usize = 4;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment