Unverified Commit 63f5bbc0 authored by ryan-lempka's avatar ryan-lempka Committed by GitHub
Browse files

chore: deprecate nvext.top_k and nvext.repetition_penalty and make available top level (#2767)


Signed-off-by: default avatarRyan Lempka <rlempka@nvidia.com>
parent 79a9d69d
...@@ -95,6 +95,8 @@ impl<T: OpenAISamplingOptionsProvider + CommonExtProvider> SamplingOptionsProvid ...@@ -95,6 +95,8 @@ impl<T: OpenAISamplingOptionsProvider + CommonExtProvider> SamplingOptionsProvid
.map_err(|e| anyhow::anyhow!("Error validating frequency_penalty: {}", e))?; .map_err(|e| anyhow::anyhow!("Error validating frequency_penalty: {}", e))?;
let presence_penalty = validate_range(self.get_presence_penalty(), &PRESENCE_PENALTY_RANGE) let presence_penalty = validate_range(self.get_presence_penalty(), &PRESENCE_PENALTY_RANGE)
.map_err(|e| anyhow::anyhow!("Error validating presence_penalty: {}", e))?; .map_err(|e| anyhow::anyhow!("Error validating presence_penalty: {}", e))?;
let top_k = CommonExtProvider::get_top_k(self);
let repetition_penalty = CommonExtProvider::get_repetition_penalty(self);
if let Some(nvext) = self.nvext() { if let Some(nvext) = self.nvext() {
let greedy = nvext.greed_sampling.unwrap_or(false); let greedy = nvext.greed_sampling.unwrap_or(false);
...@@ -130,10 +132,10 @@ impl<T: OpenAISamplingOptionsProvider + CommonExtProvider> SamplingOptionsProvid ...@@ -130,10 +132,10 @@ impl<T: OpenAISamplingOptionsProvider + CommonExtProvider> SamplingOptionsProvid
best_of: None, best_of: None,
frequency_penalty, frequency_penalty,
presence_penalty, presence_penalty,
repetition_penalty: None, repetition_penalty,
temperature, temperature,
top_p, top_p,
top_k: None, top_k,
min_p: None, min_p: None,
seed: None, seed: None,
use_beam_search: None, use_beam_search: None,
......
...@@ -198,6 +198,24 @@ impl CommonExtProvider for NvCreateChatCompletionRequest { ...@@ -198,6 +198,24 @@ impl CommonExtProvider for NvCreateChatCompletionRequest {
.and_then(|nv| nv.guided_decoding_backend.as_ref()), .and_then(|nv| nv.guided_decoding_backend.as_ref()),
) )
} }
fn get_top_k(&self) -> Option<i32> {
choose_with_deprecation(
"top_k",
self.common.top_k.as_ref(),
self.nvext.as_ref().and_then(|nv| nv.top_k.as_ref()),
)
}
fn get_repetition_penalty(&self) -> Option<f32> {
choose_with_deprecation(
"repetition_penalty",
self.common.repetition_penalty.as_ref(),
self.nvext
.as_ref()
.and_then(|nv| nv.repetition_penalty.as_ref()),
)
}
} }
/// Implements `OpenAIStopConditionsProvider` for `NvCreateChatCompletionRequest`, /// Implements `OpenAIStopConditionsProvider` for `NvCreateChatCompletionRequest`,
......
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
use super::nvext::validate_top_k;
use derive_builder::Builder; use derive_builder::Builder;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use validator::Validate; use validator::Validate;
...@@ -21,6 +22,19 @@ pub struct CommonExt { ...@@ -21,6 +22,19 @@ pub struct CommonExt {
#[builder(default, setter(strip_option))] #[builder(default, setter(strip_option))]
pub min_tokens: Option<u32>, pub min_tokens: Option<u32>,
/// Integer that controls the number of top tokens to consider. Set to -1 to consider all tokens.
#[serde(default, skip_serializing_if = "Option::is_none")]
#[builder(default, setter(strip_option))]
#[validate(custom(function = "validate_top_k"))]
pub top_k: Option<i32>,
/// How much to penalize tokens based on how frequently they occur in the text.
/// A value of 1 means no penalty, while values larger than 1 discourage and values smaller encourage.
#[serde(default, skip_serializing_if = "Option::is_none")]
#[builder(default, setter(strip_option))]
#[validate(range(exclusive_min = 0.0, max = 2.0))]
pub repetition_penalty: Option<f32>,
/// Guided Decoding Options /// Guided Decoding Options
/// If specified, the output will be a JSON object. Can be a string, an object, or null. /// If specified, the output will be a JSON object. Can be a string, an object, or null.
#[serde(default, skip_serializing_if = "Option::is_none")] #[serde(default, skip_serializing_if = "Option::is_none")]
...@@ -65,6 +79,10 @@ pub trait CommonExtProvider { ...@@ -65,6 +79,10 @@ pub trait CommonExtProvider {
fn get_guided_grammar(&self) -> Option<String>; fn get_guided_grammar(&self) -> Option<String>;
fn get_guided_choice(&self) -> Option<Vec<String>>; fn get_guided_choice(&self) -> Option<Vec<String>>;
fn get_guided_decoding_backend(&self) -> Option<String>; fn get_guided_decoding_backend(&self) -> Option<String>;
/// Other sampling Options
fn get_top_k(&self) -> Option<i32>;
fn get_repetition_penalty(&self) -> Option<f32>;
} }
/// Helper function to emit deprecation warnings for nvext parameters /// Helper function to emit deprecation warnings for nvext parameters
...@@ -107,6 +125,8 @@ mod tests { ...@@ -107,6 +125,8 @@ mod tests {
let common_ext = CommonExt::builder().build().unwrap(); let common_ext = CommonExt::builder().build().unwrap();
assert_eq!(common_ext.ignore_eos, None); assert_eq!(common_ext.ignore_eos, None);
assert_eq!(common_ext.min_tokens, None); assert_eq!(common_ext.min_tokens, None);
assert_eq!(common_ext.top_k, None);
assert_eq!(common_ext.repetition_penalty, None);
assert_eq!(common_ext.guided_json, None); assert_eq!(common_ext.guided_json, None);
assert_eq!(common_ext.guided_regex, None); assert_eq!(common_ext.guided_regex, None);
assert_eq!(common_ext.guided_grammar, None); assert_eq!(common_ext.guided_grammar, None);
...@@ -119,6 +139,8 @@ mod tests { ...@@ -119,6 +139,8 @@ mod tests {
let common_ext = CommonExt::builder() let common_ext = CommonExt::builder()
.ignore_eos(true) .ignore_eos(true)
.min_tokens(10) .min_tokens(10)
.top_k(50)
.repetition_penalty(1.2)
.guided_json(serde_json::json!({"key": "value"})) .guided_json(serde_json::json!({"key": "value"}))
.guided_regex("regex".to_string()) .guided_regex("regex".to_string())
.guided_grammar("grammar".to_string()) .guided_grammar("grammar".to_string())
...@@ -129,6 +151,8 @@ mod tests { ...@@ -129,6 +151,8 @@ mod tests {
assert_eq!(common_ext.ignore_eos, Some(true)); assert_eq!(common_ext.ignore_eos, Some(true));
assert_eq!(common_ext.min_tokens, Some(10)); assert_eq!(common_ext.min_tokens, Some(10));
assert_eq!(common_ext.top_k, Some(50));
assert_eq!(common_ext.repetition_penalty, Some(1.2));
assert_eq!( assert_eq!(
common_ext.guided_json.as_ref(), common_ext.guided_json.as_ref(),
Some(&serde_json::json!({"key": "value"})) Some(&serde_json::json!({"key": "value"}))
...@@ -164,6 +188,8 @@ mod tests { ...@@ -164,6 +188,8 @@ mod tests {
let common_ext = CommonExt { let common_ext = CommonExt {
ignore_eos: None, ignore_eos: None,
min_tokens: Some(0), // Should be valid (min = 0) min_tokens: Some(0), // Should be valid (min = 0)
top_k: None,
repetition_penalty: None,
guided_json: None, guided_json: None,
guided_regex: None, guided_regex: None,
guided_grammar: None, guided_grammar: None,
...@@ -180,6 +206,8 @@ mod tests { ...@@ -180,6 +206,8 @@ mod tests {
assert_eq!(common_ext.ignore_eos, None); assert_eq!(common_ext.ignore_eos, None);
assert_eq!(common_ext.min_tokens, None); assert_eq!(common_ext.min_tokens, None);
assert_eq!(common_ext.top_k, None);
assert_eq!(common_ext.repetition_penalty, None);
assert!(common_ext.validate().is_ok()); assert!(common_ext.validate().is_ok());
} }
...@@ -190,6 +218,8 @@ mod tests { ...@@ -190,6 +218,8 @@ mod tests {
assert_eq!(common_ext.ignore_eos, None); assert_eq!(common_ext.ignore_eos, None);
assert_eq!(common_ext.min_tokens, None); assert_eq!(common_ext.min_tokens, None);
assert_eq!(common_ext.top_k, None);
assert_eq!(common_ext.repetition_penalty, None);
assert!(common_ext.validate().is_ok()); assert!(common_ext.validate().is_ok());
} }
......
...@@ -192,6 +192,24 @@ impl CommonExtProvider for NvCreateCompletionRequest { ...@@ -192,6 +192,24 @@ impl CommonExtProvider for NvCreateCompletionRequest {
.and_then(|nv| nv.guided_decoding_backend.as_ref()), .and_then(|nv| nv.guided_decoding_backend.as_ref()),
) )
} }
fn get_top_k(&self) -> Option<i32> {
choose_with_deprecation(
"top_k",
self.common.top_k.as_ref(),
self.nvext.as_ref().and_then(|nv| nv.top_k.as_ref()),
)
}
fn get_repetition_penalty(&self) -> Option<f32> {
choose_with_deprecation(
"repetition_penalty",
self.common.repetition_penalty.as_ref(),
self.nvext
.as_ref()
.and_then(|nv| nv.repetition_penalty.as_ref()),
)
}
} }
impl OpenAIStopConditionsProvider for NvCreateCompletionRequest { impl OpenAIStopConditionsProvider for NvCreateCompletionRequest {
......
...@@ -34,13 +34,13 @@ pub struct NvExt { ...@@ -34,13 +34,13 @@ pub struct NvExt {
#[builder(default, setter(strip_option))] // NIM LLM might default to -1 #[builder(default, setter(strip_option))] // NIM LLM might default to -1
#[validate(custom(function = "validate_top_k"))] #[validate(custom(function = "validate_top_k"))]
#[serde(default, skip_serializing_if = "Option::is_none")] #[serde(default, skip_serializing_if = "Option::is_none")]
pub top_k: Option<i64>, pub top_k: Option<i32>,
/// How much to penalize tokens based on how frequently they occur in the text. /// How much to penalize tokens based on how frequently they occur in the text.
/// A value of 1 means no penalty, while values larger than 1 discourage and values smaller encourage. /// A value of 1 means no penalty, while values larger than 1 discourage and values smaller encourage.
#[builder(default, setter(strip_option))] #[builder(default, setter(strip_option))]
#[validate(range(exclusive_min = 0.0, max = 2.0))] #[validate(range(exclusive_min = 0.0, max = 2.0))]
pub repetition_penalty: Option<f64>, pub repetition_penalty: Option<f32>,
/// If true, sampling will be forced to be greedy. /// If true, sampling will be forced to be greedy.
/// The backend is responsible for selecting the correct backend-specific options to /// The backend is responsible for selecting the correct backend-specific options to
...@@ -118,7 +118,7 @@ fn validate_nv_ext(_nv_ext: &NvExt) -> Result<(), ValidationError> { ...@@ -118,7 +118,7 @@ fn validate_nv_ext(_nv_ext: &NvExt) -> Result<(), ValidationError> {
Ok(()) Ok(())
} }
fn validate_top_k(top_k: i64) -> Result<(), ValidationError> { pub fn validate_top_k(top_k: i32) -> Result<(), ValidationError> {
if top_k == -1 || (top_k >= 1) { if top_k == -1 || (top_k >= 1) {
return Ok(()); return Ok(());
} }
...@@ -200,7 +200,7 @@ mod tests { ...@@ -200,7 +200,7 @@ mod tests {
// Test invalid `top_k` validation using proptest // Test invalid `top_k` validation using proptest
proptest! { proptest! {
#[test] #[test]
fn test_invalid_top_k_value(top_k in any::<i64>().prop_filter("Invalid top_k", |&k| k < -1 || (k > 0 && k < 1))) { fn test_invalid_top_k_value(top_k in any::<i32>().prop_filter("Invalid top_k", |&k| k < -1 || (k > 0 && k < 1))) {
let nv_ext = NvExt::builder() let nv_ext = NvExt::builder()
.top_k(top_k) .top_k(top_k)
.build() .build()
...@@ -227,7 +227,7 @@ mod tests { ...@@ -227,7 +227,7 @@ mod tests {
// Test valid repetition_penalty values // Test valid repetition_penalty values
proptest! { proptest! {
#[test] #[test]
fn test_valid_repetition_penalty_values(repetition_penalty in 0.01f64..=2.0f64) { fn test_valid_repetition_penalty_values(repetition_penalty in 0.01f32..=2.0f32) {
let nv_ext = NvExt::builder() let nv_ext = NvExt::builder()
.repetition_penalty(repetition_penalty) .repetition_penalty(repetition_penalty)
.build() .build()
...@@ -241,7 +241,7 @@ mod tests { ...@@ -241,7 +241,7 @@ mod tests {
// Test invalid repetition_penalty values // Test invalid repetition_penalty values
proptest! { proptest! {
#[test] #[test]
fn test_invalid_repetition_penalty_values(repetition_penalty in -10.0f64..0.0f64) { fn test_invalid_repetition_penalty_values(repetition_penalty in -10.0f32..0.0f32) {
let nv_ext = NvExt::builder() let nv_ext = NvExt::builder()
.repetition_penalty(repetition_penalty) .repetition_penalty(repetition_penalty)
.build() .build()
......
...@@ -280,3 +280,26 @@ fn test_min_tokens_only_at_root_level() { ...@@ -280,3 +280,26 @@ fn test_min_tokens_only_at_root_level() {
let stop_conditions = request.extract_stop_conditions().unwrap(); let stop_conditions = request.extract_stop_conditions().unwrap();
assert_eq!(stop_conditions.min_tokens, Some(150)); assert_eq!(stop_conditions.min_tokens, Some(150));
} }
#[test]
fn test_sampling_parameters_extraction() {
use dynamo_llm::protocols::common::SamplingOptionsProvider;
use dynamo_llm::protocols::openai::chat_completions::NvCreateChatCompletionRequest;
use dynamo_llm::protocols::openai::common_ext::CommonExt;
// Test that top_k and repetition_penalty are extracted in sampling options when passed a top level
let request = NvCreateChatCompletionRequest {
inner: Default::default(),
common: CommonExt::builder()
.top_k(42)
.repetition_penalty(1.3)
.build()
.unwrap(),
nvext: None,
};
let sampling_options = request.extract_sampling_options().unwrap();
assert_eq!(sampling_options.top_k, Some(42));
assert_eq!(sampling_options.repetition_penalty, Some(1.3));
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment