Unverified Commit 3de04dd9 authored by Greg Clark's avatar Greg Clark Committed by GitHub
Browse files

chore: fill out sampling params (seed, n, best_of, min_p) (#3055)


Signed-off-by: Greg Clark <grclark@nvidia.com>
parent e2c0e8d1
......@@ -274,14 +274,14 @@ pub const FREQUENCY_PENALTY_RANGE: (f32, f32) = (-1.0, 1.0);
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
pub struct SamplingOptions {
/// Number of output sequences to return for the given prompt
pub n: Option<i32>,
pub n: Option<u8>,
/// Number of output sequences that are generated from the prompt.
/// From these `best_of` sequences, the top `n` sequences are returned.
/// `best_of` must be greater than or equal to `n`. This is treated as
/// the beam width when `use_beam_search` is True. By default, `best_of`
/// is set to `n`.
pub best_of: Option<i32>,
pub best_of: Option<u8>,
/// Float that penalizes new tokens based on whether they
/// appear in the generated text so far. Values > 0 encourage the model
......
......@@ -20,7 +20,8 @@ pub mod responses;
pub mod validate;
use validate::{
FREQUENCY_PENALTY_RANGE, PRESENCE_PENALTY_RANGE, TEMPERATURE_RANGE, TOP_P_RANGE, validate_range,
BEST_OF_RANGE, FREQUENCY_PENALTY_RANGE, MIN_P_RANGE, N_RANGE, PRESENCE_PENALTY_RANGE,
TEMPERATURE_RANGE, TOP_P_RANGE, validate_range,
};
#[derive(Serialize, Deserialize, Debug)]
......@@ -40,6 +41,12 @@ trait OpenAISamplingOptionsProvider {
fn get_presence_penalty(&self) -> Option<f32>;
fn get_seed(&self) -> Option<i64>;
fn get_n(&self) -> Option<u8>;
fn get_best_of(&self) -> Option<u8>;
fn nvext(&self) -> Option<&nvext::NvExt>;
}
......@@ -104,6 +111,14 @@ impl<T: OpenAISamplingOptionsProvider + CommonExtProvider> SamplingOptionsProvid
let top_k = CommonExtProvider::get_top_k(self);
let repetition_penalty = CommonExtProvider::get_repetition_penalty(self);
let include_stop_str_in_output = CommonExtProvider::get_include_stop_str_in_output(self);
let seed = self.get_seed();
let n = validate_range(self.get_n(), &N_RANGE)
.map_err(|e| anyhow::anyhow!("Error validating n: {}", e))?;
let best_of = validate_range(self.get_best_of(), &BEST_OF_RANGE)
.map_err(|e| anyhow::anyhow!("Error validating best_of: {}", e))?;
let min_p = validate_range(CommonExtProvider::get_min_p(self), &MIN_P_RANGE)
.map_err(|e| anyhow::anyhow!("Error validating min_p: {}", e))?;
if let Some(nvext) = self.nvext() {
let greedy = nvext.greed_sampling.unwrap_or(false);
......@@ -135,16 +150,16 @@ impl<T: OpenAISamplingOptionsProvider + CommonExtProvider> SamplingOptionsProvid
};
Ok(common::SamplingOptions {
n: None,
best_of: None,
n,
best_of,
frequency_penalty,
presence_penalty,
repetition_penalty,
temperature,
top_p,
top_k,
min_p: None,
seed: None,
min_p,
seed,
use_beam_search: None,
length_penalty: None,
guided_decoding,
......
......@@ -131,6 +131,20 @@ impl OpenAISamplingOptionsProvider for NvCreateChatCompletionRequest {
/// Returns a borrowed reference to this request's `nvext` extension block,
/// or `None` when the request carries no `nvext`.
fn nvext(&self) -> Option<&NvExt> {
self.nvext.as_ref()
}
/// Retrieves the seed value for random number generation, if set.
///
/// The value is read directly from the `seed` field of the inner request;
/// no validation is applied here.
fn get_seed(&self) -> Option<i64> {
self.inner.seed
}
/// Retrieves the number of completions to generate for each prompt, if set.
///
/// The value is read directly from the `n` field of the inner request;
/// this accessor performs no range checking itself.
fn get_n(&self) -> Option<u8> {
self.inner.n
}
/// Always returns `None`: `best_of` is not supported in chat completions,
/// so there is no request field to read it from.
fn get_best_of(&self) -> Option<u8> {
None // Not supported in chat completions
}
}
/// Implements `CommonExtProvider` for `NvCreateChatCompletionRequest`,
......@@ -199,6 +213,14 @@ impl CommonExtProvider for NvCreateChatCompletionRequest {
)
}
/// Resolves the effective `min_p` for this request.
///
/// Two candidate sources are gathered — the `common` extension field and the
/// `nvext` field — and both are handed to `choose_with_deprecation`, which
/// decides which one is used.
fn get_min_p(&self) -> Option<f32> {
    let common_value = self.common.min_p.as_ref();
    let nvext_value = self.nvext.as_ref().and_then(|ext| ext.min_p.as_ref());
    choose_with_deprecation("min_p", common_value, nvext_value)
}
fn get_repetition_penalty(&self) -> Option<f32> {
choose_with_deprecation(
"repetition_penalty",
......
......@@ -28,6 +28,12 @@ pub struct CommonExt {
#[validate(custom(function = "validate_top_k"))]
pub top_k: Option<i32>,
/// Relative probability floor
#[serde(default, skip_serializing_if = "Option::is_none")]
#[builder(default, setter(strip_option))]
#[validate(range(min = 0.0, max = 1.0))]
pub min_p: Option<f32>,
/// How much to penalize tokens based on how frequently they occur in the text.
/// A value of 1 means no penalty, while values larger than 1 discourage and values smaller encourage.
#[serde(default, skip_serializing_if = "Option::is_none")]
......@@ -87,6 +93,7 @@ pub trait CommonExtProvider {
/// Other sampling Options
fn get_top_k(&self) -> Option<i32>;
fn get_min_p(&self) -> Option<f32>;
fn get_repetition_penalty(&self) -> Option<f32>;
fn get_include_stop_str_in_output(&self) -> Option<bool>;
}
......@@ -200,6 +207,7 @@ mod tests {
ignore_eos: None,
min_tokens: Some(0), // Should be valid (min = 0)
top_k: None,
min_p: None,
repetition_penalty: None,
include_stop_str_in_output: None,
guided_json: None,
......
......@@ -124,6 +124,18 @@ impl OpenAISamplingOptionsProvider for NvCreateCompletionRequest {
/// Returns a borrowed reference to this request's `nvext` extension block,
/// or `None` when the request carries no `nvext`.
fn nvext(&self) -> Option<&NvExt> {
self.nvext.as_ref()
}
/// Returns the `seed` field of the inner request, if set.
fn get_seed(&self) -> Option<i64> {
self.inner.seed
}
/// Returns the `n` field of the inner request, if set.
/// No range checking is performed by this accessor.
fn get_n(&self) -> Option<u8> {
self.inner.n
}
/// Returns the `best_of` field of the inner request, if set.
/// No range checking is performed by this accessor.
fn get_best_of(&self) -> Option<u8> {
self.inner.best_of
}
}
impl CommonExtProvider for NvCreateCompletionRequest {
......@@ -189,6 +201,14 @@ impl CommonExtProvider for NvCreateCompletionRequest {
)
}
/// Resolves the effective `min_p` for this request.
///
/// Two candidate sources are gathered — the `common` extension field and the
/// `nvext` field — and both are handed to `choose_with_deprecation`, which
/// decides which one is used.
fn get_min_p(&self) -> Option<f32> {
    let common_value = self.common.min_p.as_ref();
    let nvext_value = self.nvext.as_ref().and_then(|ext| ext.min_p.as_ref());
    choose_with_deprecation("min_p", common_value, nvext_value)
}
fn get_repetition_penalty(&self) -> Option<f32> {
choose_with_deprecation(
"repetition_penalty",
......
......@@ -24,6 +24,12 @@ pub struct NvExt {
#[serde(default, skip_serializing_if = "Option::is_none")]
pub top_k: Option<i32>,
/// Relative probability floor
#[serde(default, skip_serializing_if = "Option::is_none")]
#[builder(default, setter(strip_option))]
#[validate(range(min = 0.0, max = 1.0))]
pub min_p: Option<f32>,
/// How much to penalize tokens based on how frequently they occur in the text.
/// A value of 1 means no penalty, while values larger than 1 discourage and values smaller encourage.
#[builder(default, setter(strip_option))]
......
......@@ -100,6 +100,18 @@ impl OpenAISamplingOptionsProvider for NvCreateResponse {
/// Returns a borrowed reference to this request's `nvext` extension block,
/// or `None` when the request carries no `nvext`.
fn nvext(&self) -> Option<&NvExt> {
self.nvext.as_ref()
}
/// Placeholder: a seed is not yet read from this request type, so this
/// always returns `None`.
fn get_seed(&self) -> Option<i64> {
None // TODO setting as None for now
}
/// Placeholder: `n` is not yet read from this request type, so this
/// always returns `None`.
fn get_n(&self) -> Option<u8> {
None // TODO setting as None for now
}
/// Placeholder: `best_of` is not yet read from this request type, so this
/// always returns `None`.
fn get_best_of(&self) -> Option<u8> {
None // TODO setting as None for now
}
}
/// Implements `OpenAIStopConditionsProvider` for `NvCreateResponse`,
......
......@@ -21,6 +21,13 @@ pub const MAX_TOP_P: f32 = 1.0;
/// Allowed range of values for OpenAI's `top_p` sampling option
pub const TOP_P_RANGE: (f32, f32) = (MIN_TOP_P, MAX_TOP_P);
/// Minimum allowed value for `min_p`
pub const MIN_MIN_P: f32 = 0.0;
/// Maximum allowed value for `min_p`
pub const MAX_MIN_P: f32 = 1.0;
/// Allowed range of values for `min_p`
pub const MIN_P_RANGE: (f32, f32) = (MIN_MIN_P, MAX_MIN_P);
/// Minimum allowed value for OpenAI's `frequency_penalty` sampling option
pub const MIN_FREQUENCY_PENALTY: f32 = -2.0;
/// Maximum allowed value for OpenAI's `frequency_penalty` sampling option
......@@ -35,6 +42,13 @@ pub const MAX_PRESENCE_PENALTY: f32 = 2.0;
/// Allowed range of values for OpenAI's `presence_penalty` sampling option
pub const PRESENCE_PENALTY_RANGE: (f32, f32) = (MIN_PRESENCE_PENALTY, MAX_PRESENCE_PENALTY);
/// Minimum allowed value for `length_penalty`
pub const MIN_LENGTH_PENALTY: f32 = -2.0;
/// Maximum allowed value for `length_penalty`
pub const MAX_LENGTH_PENALTY: f32 = 2.0;
/// Allowed range of values for `length_penalty`
pub const LENGTH_PENALTY_RANGE: (f32, f32) = (MIN_LENGTH_PENALTY, MAX_LENGTH_PENALTY);
/// Minimum allowed value for `top_logprobs`
pub const MIN_TOP_LOGPROBS: u8 = 0;
/// Maximum allowed value for `top_logprobs`
......@@ -49,6 +63,8 @@ pub const MAX_LOGPROBS: u8 = 5;
pub const MIN_N: u8 = 1;
/// Maximum allowed value for `n` (number of choices)
pub const MAX_N: u8 = 128;
/// Allowed range of values for `n` (number of choices)
pub const N_RANGE: (u8, u8) = (MIN_N, MAX_N);
/// Minimum allowed value for OpenAI's `logit_bias` values
pub const MIN_LOGIT_BIAS: f32 = -100.0;
......@@ -59,6 +75,8 @@ pub const MAX_LOGIT_BIAS: f32 = 100.0;
pub const MIN_BEST_OF: u8 = 0;
/// Maximum allowed value for `best_of`
pub const MAX_BEST_OF: u8 = 20;
/// Allowed range of values for `best_of`
pub const BEST_OF_RANGE: (u8, u8) = (MIN_BEST_OF, MAX_BEST_OF);
/// Maximum allowed number of stop sequences
pub const MAX_STOP_SEQUENCES: usize = 4;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment