Unverified commit d4f0d2bc, authored by Ayush Agarwal, committed by GitHub
Browse files

chore: added more api error code validations (#3231)


Signed-off-by: ayushag <ayushag@nvidia.com>
parent 19f48dba
...@@ -711,12 +711,6 @@ pub fn validate_chat_completion_unsupported_fields( ...@@ -711,12 +711,6 @@ pub fn validate_chat_completion_unsupported_fields(
) -> Result<(), ErrorResponse> { ) -> Result<(), ErrorResponse> {
let inner = &request.inner; let inner = &request.inner;
if inner.parallel_tool_calls == Some(true) {
return Err(ErrorMessage::not_implemented_error(
"`parallel_tool_calls: true` is not supported.",
));
}
if inner.function_call.is_some() { if inner.function_call.is_some() {
return Err(ErrorMessage::not_implemented_error( return Err(ErrorMessage::not_implemented_error(
"`function_call` is deprecated. Please migrate to use `tool_choice` instead.", "`function_call` is deprecated. Please migrate to use `tool_choice` instead.",
......
...@@ -360,6 +360,10 @@ pub struct GuidedDecodingOptions { ...@@ -360,6 +360,10 @@ pub struct GuidedDecodingOptions {
/// If specified, the backend to use for guided decoding, can be backends like xgrammar or custom guided decoding backend /// If specified, the backend to use for guided decoding, can be backends like xgrammar or custom guided decoding backend
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub backend: Option<String>, pub backend: Option<String>,
/// If specified, whitespace pattern to use for guided decoding. Can be a string or null.
#[serde(skip_serializing_if = "Option::is_none")]
pub whitespace_pattern: Option<String>,
} }
impl GuidedDecodingOptions { impl GuidedDecodingOptions {
...@@ -370,6 +374,7 @@ impl GuidedDecodingOptions { ...@@ -370,6 +374,7 @@ impl GuidedDecodingOptions {
choice: Option<Vec<String>>, choice: Option<Vec<String>>,
grammar: Option<String>, grammar: Option<String>,
backend: Option<String>, backend: Option<String>,
whitespace_pattern: Option<String>,
) -> Self { ) -> Self {
Self { Self {
json, json,
...@@ -377,6 +382,7 @@ impl GuidedDecodingOptions { ...@@ -377,6 +382,7 @@ impl GuidedDecodingOptions {
choice, choice,
grammar, grammar,
backend, backend,
whitespace_pattern,
} }
} }
...@@ -387,8 +393,9 @@ impl GuidedDecodingOptions { ...@@ -387,8 +393,9 @@ impl GuidedDecodingOptions {
choice: Option<Vec<String>>, choice: Option<Vec<String>>,
grammar: Option<String>, grammar: Option<String>,
backend: Option<String>, backend: Option<String>,
whitespace_pattern: Option<String>,
) -> Result<Self> { ) -> Result<Self> {
let instance = Self::new(json, regex, choice, grammar, backend); let instance = Self::new(json, regex, choice, grammar, backend, whitespace_pattern);
instance.validate()?; instance.validate()?;
Ok(instance) Ok(instance)
} }
...@@ -400,12 +407,18 @@ impl GuidedDecodingOptions { ...@@ -400,12 +407,18 @@ impl GuidedDecodingOptions {
choice: Option<Vec<String>>, choice: Option<Vec<String>>,
grammar: Option<String>, grammar: Option<String>,
backend: Option<String>, backend: Option<String>,
whitespace_pattern: Option<String>,
) -> Result<Option<Self>> { ) -> Result<Option<Self>> {
let is_empty_choice = choice.as_ref().is_none_or(|v| v.is_empty()); let is_empty_choice = choice.as_ref().is_none_or(|v| v.is_empty());
if json.is_none() && regex.is_none() && is_empty_choice && grammar.is_none() { if json.is_none()
&& regex.is_none()
&& is_empty_choice
&& grammar.is_none()
&& whitespace_pattern.is_none()
{
return Ok(None); return Ok(None);
} }
let instance = Self::validated(json, regex, choice, grammar, backend)?; let instance = Self::validated(json, regex, choice, grammar, backend, whitespace_pattern)?;
Ok(Some(instance)) Ok(Some(instance))
} }
...@@ -416,6 +429,7 @@ impl GuidedDecodingOptions { ...@@ -416,6 +429,7 @@ impl GuidedDecodingOptions {
self.regex.is_some(), self.regex.is_some(),
self.choice.as_ref().is_some_and(|v| !v.is_empty()), self.choice.as_ref().is_some_and(|v| !v.is_empty()),
self.grammar.is_some(), self.grammar.is_some(),
self.whitespace_pattern.is_some(),
] ]
.iter() .iter()
.filter(|&&v| v) .filter(|&&v| v)
...@@ -674,7 +688,6 @@ mod tests { ...@@ -674,7 +688,6 @@ mod tests {
} }
#[test] #[test]
fn test_guided_decoding_options_new_and_exclusive() { fn test_guided_decoding_options_new_and_exclusive() {
// Only JSON set // Only JSON set
let json_val = serde_json::json!({"type": "object"}); let json_val = serde_json::json!({"type": "object"});
...@@ -685,6 +698,7 @@ mod tests { ...@@ -685,6 +698,7 @@ mod tests {
None, None,
None, None,
backend.clone(), backend.clone(),
None,
); );
assert!(opts.is_ok()); assert!(opts.is_ok());
let opts = opts.unwrap(); let opts = opts.unwrap();
...@@ -693,36 +707,58 @@ mod tests { ...@@ -693,36 +707,58 @@ mod tests {
assert!(opts.choice.is_none()); assert!(opts.choice.is_none());
assert!(opts.grammar.is_none()); assert!(opts.grammar.is_none());
assert_eq!(opts.backend, backend); assert_eq!(opts.backend, backend);
assert!(opts.whitespace_pattern.is_none());
// Only regex set // Only regex set
let regex = Some(r"\d+".to_string()); let regex = Some(r"\d+".to_string());
let opts = GuidedDecodingOptions::validated(None, regex.clone(), None, None, None); let opts = GuidedDecodingOptions::validated(None, regex.clone(), None, None, None, None);
assert!(opts.is_ok()); assert!(opts.is_ok());
let opts = opts.unwrap(); let opts = opts.unwrap();
assert_eq!(opts.regex, regex); assert_eq!(opts.regex, regex);
assert!(opts.json.is_none()); assert!(opts.json.is_none());
assert!(opts.choice.is_none()); assert!(opts.choice.is_none());
assert!(opts.grammar.is_none()); assert!(opts.grammar.is_none());
assert!(opts.whitespace_pattern.is_none());
// Only choice set // Only choice set
let choice = Some(vec!["A".to_string(), "B".to_string()]); let choice = Some(vec!["A".to_string(), "B".to_string()]);
let opts = GuidedDecodingOptions::validated(None, None, choice.clone(), None, None); let opts = GuidedDecodingOptions::validated(None, None, choice.clone(), None, None, None);
assert!(opts.is_ok()); assert!(opts.is_ok());
let opts = opts.unwrap(); let opts = opts.unwrap();
assert_eq!(opts.choice, choice); assert_eq!(opts.choice, choice);
assert!(opts.json.is_none()); assert!(opts.json.is_none());
assert!(opts.regex.is_none()); assert!(opts.regex.is_none());
assert!(opts.grammar.is_none()); assert!(opts.grammar.is_none());
assert!(opts.whitespace_pattern.is_none());
// Only grammar set // Only grammar set
let grammar = Some("root ::= 'yes' | 'no'".to_string()); let grammar = Some("root ::= 'yes' | 'no'".to_string());
let opts = GuidedDecodingOptions::validated(None, None, None, grammar.clone(), None); let opts = GuidedDecodingOptions::validated(None, None, None, grammar.clone(), None, None);
assert!(opts.is_ok()); assert!(opts.is_ok());
let opts = opts.unwrap(); let opts = opts.unwrap();
assert_eq!(opts.grammar, grammar); assert_eq!(opts.grammar, grammar);
assert!(opts.json.is_none()); assert!(opts.json.is_none());
assert!(opts.regex.is_none()); assert!(opts.regex.is_none());
assert!(opts.choice.is_none()); assert!(opts.choice.is_none());
assert!(opts.whitespace_pattern.is_none());
// Only whitespace_pattern set
let whitespace_pattern = Some(r"\s+".to_string());
let opts = GuidedDecodingOptions::validated(
None,
None,
None,
None,
None,
whitespace_pattern.clone(),
);
assert!(opts.is_ok());
let opts = opts.unwrap();
assert_eq!(opts.whitespace_pattern, whitespace_pattern);
assert!(opts.json.is_none());
assert!(opts.regex.is_none());
assert!(opts.choice.is_none());
assert!(opts.grammar.is_none());
// Multiple fields set (should error) // Multiple fields set (should error)
let opts = GuidedDecodingOptions::validated( let opts = GuidedDecodingOptions::validated(
...@@ -731,6 +767,7 @@ mod tests { ...@@ -731,6 +767,7 @@ mod tests {
None, None,
None, None,
None, None,
None,
); );
assert!(opts.is_err()); assert!(opts.is_err());
...@@ -740,6 +777,7 @@ mod tests { ...@@ -740,6 +777,7 @@ mod tests {
Some(vec!["A".to_string()]), Some(vec!["A".to_string()]),
None, None,
None, None,
None,
); );
assert!(opts.is_err()); assert!(opts.is_err());
...@@ -749,24 +787,26 @@ mod tests { ...@@ -749,24 +787,26 @@ mod tests {
Some(vec!["A".to_string()]), Some(vec!["A".to_string()]),
Some("root ::= 'yes'".to_string()), Some("root ::= 'yes'".to_string()),
None, None,
None,
); );
assert!(opts.is_err()); assert!(opts.is_err());
// All fields None (should be ok, but not useful) // All fields None (should be ok, but not useful)
let opts = GuidedDecodingOptions::validated(None, None, None, None, None); let opts = GuidedDecodingOptions::validated(None, None, None, None, None, None);
assert!(opts.is_ok()); assert!(opts.is_ok());
} }
#[test] #[test]
fn test_guided_decoding_options_from_optional() { fn test_guided_decoding_options_from_optional() {
// All None returns Ok(None) // All None returns Ok(None)
let opts = GuidedDecodingOptions::from_optional(None, None, None, None, None); let opts = GuidedDecodingOptions::from_optional(None, None, None, None, None, None);
assert!(opts.is_ok()); assert!(opts.is_ok());
assert!(opts.unwrap().is_none()); assert!(opts.unwrap().is_none());
// Only one set returns Ok(Some) // Only one set returns Ok(Some)
let regex = Some(r"\w+".to_string()); let regex = Some(r"\w+".to_string());
let opts = GuidedDecodingOptions::from_optional(None, regex.clone(), None, None, None); let opts =
GuidedDecodingOptions::from_optional(None, regex.clone(), None, None, None, None);
assert!(opts.is_ok()); assert!(opts.is_ok());
let val = opts.unwrap(); let val = opts.unwrap();
assert!(val.is_some()); assert!(val.is_some());
...@@ -780,11 +820,12 @@ mod tests { ...@@ -780,11 +820,12 @@ mod tests {
None, None,
None, None,
None, None,
None,
); );
assert!(opts.is_err()); assert!(opts.is_err());
// Choice set but empty vector should not count as set // Choice set but empty vector should not count as set
let opts = GuidedDecodingOptions::from_optional(None, None, Some(vec![]), None, None); let opts = GuidedDecodingOptions::from_optional(None, None, Some(vec![]), None, None, None);
assert!(opts.is_ok()); assert!(opts.is_ok());
let val = opts.unwrap(); let val = opts.unwrap();
assert!(val.is_none()); assert!(val.is_none());
...@@ -796,6 +837,7 @@ mod tests { ...@@ -796,6 +837,7 @@ mod tests {
Some(vec!["A".to_string()]), Some(vec!["A".to_string()]),
None, None,
None, None,
None,
); );
assert!(opts.is_ok()); assert!(opts.is_ok());
let val = opts.unwrap(); let val = opts.unwrap();
......
...@@ -133,6 +133,7 @@ impl<T: OpenAISamplingOptionsProvider + CommonExtProvider> SamplingOptionsProvid ...@@ -133,6 +133,7 @@ impl<T: OpenAISamplingOptionsProvider + CommonExtProvider> SamplingOptionsProvid
let guided_regex = self.get_guided_regex(); let guided_regex = self.get_guided_regex();
let guided_grammar = self.get_guided_grammar(); let guided_grammar = self.get_guided_grammar();
let guided_choice = self.get_guided_choice(); let guided_choice = self.get_guided_choice();
let guided_whitespace_pattern = self.get_guided_whitespace_pattern();
let guided_decoding = match common::GuidedDecodingOptions::from_optional( let guided_decoding = match common::GuidedDecodingOptions::from_optional(
guided_json.cloned(), guided_json.cloned(),
...@@ -140,6 +141,7 @@ impl<T: OpenAISamplingOptionsProvider + CommonExtProvider> SamplingOptionsProvid ...@@ -140,6 +141,7 @@ impl<T: OpenAISamplingOptionsProvider + CommonExtProvider> SamplingOptionsProvid
guided_choice, guided_choice,
guided_grammar, guided_grammar,
guided_decoding_backend, guided_decoding_backend,
guided_whitespace_pattern,
) { ) {
Ok(options) => options, Ok(options) => options,
Err(e) => { Err(e) => {
......
...@@ -206,6 +206,16 @@ impl CommonExtProvider for NvCreateChatCompletionRequest { ...@@ -206,6 +206,16 @@ impl CommonExtProvider for NvCreateChatCompletionRequest {
) )
} }
/// Resolves the `guided_whitespace_pattern` option for this chat completion request.
///
/// Both the top-level (`common`) value and the `nvext` value are handed to
/// `choose_with_deprecation`; whatever that helper yields is returned.
fn get_guided_whitespace_pattern(&self) -> Option<String> {
    let from_common = self.common.guided_whitespace_pattern.as_ref();
    let from_nvext = self
        .nvext
        .as_ref()
        .and_then(|ext| ext.guided_whitespace_pattern.as_ref());

    choose_with_deprecation("guided_whitespace_pattern", from_common, from_nvext)
}
fn get_top_k(&self) -> Option<i32> { fn get_top_k(&self) -> Option<i32> {
choose_with_deprecation( choose_with_deprecation(
"top_k", "top_k",
...@@ -349,6 +359,8 @@ impl ValidateRequest for NvCreateChatCompletionRequest { ...@@ -349,6 +359,8 @@ impl ValidateRequest for NvCreateChatCompletionRequest {
// none for functions // none for functions
// Common Ext // Common Ext
validate::validate_repetition_penalty(self.get_repetition_penalty())?; validate::validate_repetition_penalty(self.get_repetition_penalty())?;
validate::validate_min_p(self.get_min_p())?;
validate::validate_top_k(self.get_top_k())?;
Ok(()) Ok(())
} }
......
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
use super::nvext::validate_top_k;
use derive_builder::Builder; use derive_builder::Builder;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use validator::Validate; use validator::Validate;
...@@ -25,20 +23,17 @@ pub struct CommonExt { ...@@ -25,20 +23,17 @@ pub struct CommonExt {
/// Integer that controls the number of top tokens to consider. Set to -1 to consider all tokens. /// Integer that controls the number of top tokens to consider. Set to -1 to consider all tokens.
#[serde(default, skip_serializing_if = "Option::is_none")] #[serde(default, skip_serializing_if = "Option::is_none")]
#[builder(default, setter(strip_option))] #[builder(default, setter(strip_option))]
#[validate(custom(function = "validate_top_k"))]
pub top_k: Option<i32>, pub top_k: Option<i32>,
/// Relative probability floor /// Relative probability floor
#[serde(default, skip_serializing_if = "Option::is_none")] #[serde(default, skip_serializing_if = "Option::is_none")]
#[builder(default, setter(strip_option))] #[builder(default, setter(strip_option))]
#[validate(range(min = 0.0, max = 1.0))]
pub min_p: Option<f32>, pub min_p: Option<f32>,
/// How much to penalize tokens based on how frequently they occur in the text. /// How much to penalize tokens based on how frequently they occur in the text.
/// A value of 1 means no penalty, while values larger than 1 discourage and values smaller encourage. /// A value of 1 means no penalty, while values larger than 1 discourage and values smaller encourage.
#[serde(default, skip_serializing_if = "Option::is_none")] #[serde(default, skip_serializing_if = "Option::is_none")]
#[builder(default, setter(strip_option))] #[builder(default, setter(strip_option))]
#[validate(range(exclusive_min = 0.0, max = 2.0))]
pub repetition_penalty: Option<f32>, pub repetition_penalty: Option<f32>,
/// include_stop_str_in_output /// include_stop_str_in_output
...@@ -71,6 +66,12 @@ pub struct CommonExt { ...@@ -71,6 +66,12 @@ pub struct CommonExt {
#[serde(default, skip_serializing_if = "Option::is_none")] #[serde(default, skip_serializing_if = "Option::is_none")]
#[builder(default, setter(strip_option))] #[builder(default, setter(strip_option))]
pub guided_decoding_backend: Option<String>, pub guided_decoding_backend: Option<String>,
/// If specified, the output will follow the whitespace pattern. Can be a string or null.
#[serde(default, skip_serializing_if = "Option::is_none")]
#[builder(default, setter(strip_option))]
#[allow(unused)] // Not used
pub guided_whitespace_pattern: Option<String>,
} }
impl CommonExt { impl CommonExt {
...@@ -90,6 +91,8 @@ pub trait CommonExtProvider { ...@@ -90,6 +91,8 @@ pub trait CommonExtProvider {
fn get_guided_grammar(&self) -> Option<String>; fn get_guided_grammar(&self) -> Option<String>;
fn get_guided_choice(&self) -> Option<Vec<String>>; fn get_guided_choice(&self) -> Option<Vec<String>>;
fn get_guided_decoding_backend(&self) -> Option<String>; fn get_guided_decoding_backend(&self) -> Option<String>;
/// Returns the whitespace pattern requested for guided decoding, if one was supplied.
fn get_guided_whitespace_pattern(&self) -> Option<String>;
/// Other sampling Options /// Other sampling Options
fn get_top_k(&self) -> Option<i32>; fn get_top_k(&self) -> Option<i32>;
...@@ -215,6 +218,7 @@ mod tests { ...@@ -215,6 +218,7 @@ mod tests {
guided_grammar: None, guided_grammar: None,
guided_choice: None, guided_choice: None,
guided_decoding_backend: None, guided_decoding_backend: None,
guided_whitespace_pattern: None,
}; };
assert!(common_ext.validate().is_ok()); assert!(common_ext.validate().is_ok());
} }
......
...@@ -193,6 +193,16 @@ impl CommonExtProvider for NvCreateCompletionRequest { ...@@ -193,6 +193,16 @@ impl CommonExtProvider for NvCreateCompletionRequest {
) )
} }
/// Resolves the `guided_whitespace_pattern` option for this completion request.
///
/// The top-level (`common`) value and the `nvext` value are both passed to
/// `choose_with_deprecation`, and its result is returned unchanged.
fn get_guided_whitespace_pattern(&self) -> Option<String> {
    let common_value = self.common.guided_whitespace_pattern.as_ref();
    let nvext_value = self
        .nvext
        .as_ref()
        .and_then(|ext| ext.guided_whitespace_pattern.as_ref());

    choose_with_deprecation("guided_whitespace_pattern", common_value, nvext_value)
}
fn get_top_k(&self) -> Option<i32> { fn get_top_k(&self) -> Option<i32> {
choose_with_deprecation( choose_with_deprecation(
"top_k", "top_k",
...@@ -435,6 +445,8 @@ impl ValidateRequest for NvCreateCompletionRequest { ...@@ -435,6 +445,8 @@ impl ValidateRequest for NvCreateCompletionRequest {
// Common Ext // Common Ext
validate::validate_repetition_penalty(self.get_repetition_penalty())?; validate::validate_repetition_penalty(self.get_repetition_penalty())?;
validate::validate_min_p(self.get_min_p())?;
validate::validate_top_k(self.get_top_k())?;
Ok(()) Ok(())
} }
......
...@@ -20,20 +20,17 @@ pub struct NvExt { ...@@ -20,20 +20,17 @@ pub struct NvExt {
pub ignore_eos: Option<bool>, pub ignore_eos: Option<bool>,
#[builder(default, setter(strip_option))] // NIM LLM might default to -1 #[builder(default, setter(strip_option))] // NIM LLM might default to -1
#[validate(custom(function = "validate_top_k"))]
#[serde(default, skip_serializing_if = "Option::is_none")] #[serde(default, skip_serializing_if = "Option::is_none")]
pub top_k: Option<i32>, pub top_k: Option<i32>,
/// Relative probability floor /// Relative probability floor
#[serde(default, skip_serializing_if = "Option::is_none")] #[serde(default, skip_serializing_if = "Option::is_none")]
#[builder(default, setter(strip_option))] #[builder(default, setter(strip_option))]
#[validate(range(min = 0.0, max = 1.0))]
pub min_p: Option<f32>, pub min_p: Option<f32>,
/// How much to penalize tokens based on how frequently they occur in the text. /// How much to penalize tokens based on how frequently they occur in the text.
/// A value of 1 means no penalty, while values larger than 1 discourage and values smaller encourage. /// A value of 1 means no penalty, while values larger than 1 discourage and values smaller encourage.
#[builder(default, setter(strip_option))] #[builder(default, setter(strip_option))]
#[validate(range(exclusive_min = 0.0, max = 2.0))]
pub repetition_penalty: Option<f32>, pub repetition_penalty: Option<f32>,
/// If true, sampling will be forced to be greedy. /// If true, sampling will be forced to be greedy.
...@@ -95,6 +92,11 @@ pub struct NvExt { ...@@ -95,6 +92,11 @@ pub struct NvExt {
#[builder(default, setter(strip_option))] #[builder(default, setter(strip_option))]
pub guided_decoding_backend: Option<String>, pub guided_decoding_backend: Option<String>,
/// If specified, the output will follow the whitespace pattern. Can be a string or null.
#[serde(default, skip_serializing_if = "Option::is_none")]
#[builder(default, setter(strip_option))]
pub guided_whitespace_pattern: Option<String>,
/// Maximum number of thinking tokens allowed /// Maximum number of thinking tokens allowed
/// NOTE: Currently passed through to backends as a no-op for future implementation /// NOTE: Currently passed through to backends as a no-op for future implementation
#[serde(default, skip_serializing_if = "Option::is_none")] #[serde(default, skip_serializing_if = "Option::is_none")]
...@@ -118,15 +120,6 @@ fn validate_nv_ext(_nv_ext: &NvExt) -> Result<(), ValidationError> { ...@@ -118,15 +120,6 @@ fn validate_nv_ext(_nv_ext: &NvExt) -> Result<(), ValidationError> {
Ok(()) Ok(())
} }
/// Checks a `top_k` value: `-1` and every value `>= 1` are accepted.
///
/// # Errors
/// Any other value (`0` or anything below `-1`) yields a `ValidationError`
/// created with `ValidationError::new("top_k")` and an explanatory message.
pub fn validate_top_k(top_k: i32) -> Result<(), ValidationError> {
    let accepted = top_k == -1 || top_k >= 1;
    if !accepted {
        let mut rejection = ValidationError::new("top_k");
        rejection.message = Some("top_k must be -1 or greater than or equal to 1".into());
        return Err(rejection);
    }
    Ok(())
}
impl NvExtBuilder { impl NvExtBuilder {
pub fn add_annotation(&mut self, annotation: impl Into<String>) -> &mut Self { pub fn add_annotation(&mut self, annotation: impl Into<String>) -> &mut Self {
self.annotations self.annotations
...@@ -140,7 +133,6 @@ impl NvExtBuilder { ...@@ -140,7 +133,6 @@ impl NvExtBuilder {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use proptest::prelude::*;
use validator::Validate; use validator::Validate;
use super::*; use super::*;
...@@ -157,6 +149,7 @@ mod tests { ...@@ -157,6 +149,7 @@ mod tests {
assert_eq!(nv_ext.guided_regex, None); assert_eq!(nv_ext.guided_regex, None);
assert_eq!(nv_ext.guided_grammar, None); assert_eq!(nv_ext.guided_grammar, None);
assert_eq!(nv_ext.guided_choice, None); assert_eq!(nv_ext.guided_choice, None);
assert_eq!(nv_ext.guided_whitespace_pattern, None);
assert_eq!(nv_ext.max_thinking_tokens, None); assert_eq!(nv_ext.max_thinking_tokens, None);
} }
...@@ -199,59 +192,4 @@ mod tests { ...@@ -199,59 +192,4 @@ mod tests {
// Validate the built struct // Validate the built struct
assert!(nv_ext.validate().is_ok()); assert!(nv_ext.validate().is_ok());
} }
// Test invalid `top_k` validation using proptest.
// For an `i32` the values expected to be rejected are `0` and everything below `-1`;
// the filter keeps exactly those so that `0` is actually exercised.
proptest! {
    #[test]
    fn test_invalid_top_k_value(top_k in any::<i32>().prop_filter("Invalid top_k", |&k| k < -1 || k == 0)) {
        let nv_ext = NvExt::builder()
            .top_k(top_k)
            .build()
            .unwrap();

        let validation_result = nv_ext.validate();
        assert!(validation_result.is_err(), "top_k should fail validation if it is 0 or less than -1");
    }
}
// Valid `top_k` values (-1, 1 and 10) must each build an `NvExt` that passes validation.
#[test]
fn test_valid_top_k_values() {
    for accepted in [-1, 1, 10] {
        let built = NvExt::builder().top_k(accepted).build().unwrap();
        assert!(built.validate().is_ok());
    }
}
// Property test: every `repetition_penalty` sampled from 0.01..=2.0 is expected to pass validation.
proptest! {
    #[test]
    fn test_valid_repetition_penalty_values(repetition_penalty in 0.01f32..=2.0f32) {
        let built = NvExt::builder()
            .repetition_penalty(repetition_penalty)
            .build()
            .unwrap();

        assert!(built.validate().is_ok(), "repetition_penalty should be valid within the range (0, 2]");
    }
}
// Property test: every `repetition_penalty` sampled from -10.0..0.0 is expected to fail validation.
proptest! {
    #[test]
    fn test_invalid_repetition_penalty_values(repetition_penalty in -10.0f32..0.0f32) {
        let built = NvExt::builder()
            .repetition_penalty(repetition_penalty)
            .build()
            .unwrap();

        assert!(built.validate().is_err(), "repetition_penalty should fail validation when outside the range (0, 2]");
    }
}
} }
...@@ -131,6 +131,15 @@ pub fn validate_top_p(top_p: Option<f32>) -> Result<(), anyhow::Error> { ...@@ -131,6 +131,15 @@ pub fn validate_top_p(top_p: Option<f32>) -> Result<(), anyhow::Error> {
Ok(()) Ok(())
} }
/// Validates the `top_k` parameter.
///
/// `None`, `-1`, and any value greater than or equal to `1` are accepted.
///
/// # Errors
/// Every other value (`0` or anything below `-1`) produces an error.
pub fn validate_top_k(top_k: Option<i32>) -> Result<(), anyhow::Error> {
    // An absent value needs no checking.
    let Some(k) = top_k else {
        return Ok(());
    };

    if k != -1 && k < 1 {
        anyhow::bail!("Top_k must be null, -1, or greater than or equal to 1");
    }
    Ok(())
}
/// Validates mutual exclusion of temperature and top_p /// Validates mutual exclusion of temperature and top_p
pub fn validate_temperature_top_p_exclusion( pub fn validate_temperature_top_p_exclusion(
temperature: Option<f32>, temperature: Option<f32>,
...@@ -175,8 +184,9 @@ pub fn validate_presence_penalty(presence_penalty: Option<f32>) -> Result<(), an ...@@ -175,8 +184,9 @@ pub fn validate_presence_penalty(presence_penalty: Option<f32>) -> Result<(), an
} }
pub fn validate_repetition_penalty(repetition_penalty: Option<f32>) -> Result<(), anyhow::Error> { pub fn validate_repetition_penalty(repetition_penalty: Option<f32>) -> Result<(), anyhow::Error> {
// It should be greater than 0.0 and less than equal to 2.0
if let Some(penalty) = repetition_penalty if let Some(penalty) = repetition_penalty
&& !(MIN_REPETITION_PENALTY..=MAX_REPETITION_PENALTY).contains(&penalty) && (penalty <= MIN_REPETITION_PENALTY || penalty > MAX_REPETITION_PENALTY)
{ {
anyhow::bail!( anyhow::bail!(
"Repetition penalty must be between {} and {}, got {}", "Repetition penalty must be between {} and {}, got {}",
...@@ -188,6 +198,21 @@ pub fn validate_repetition_penalty(repetition_penalty: Option<f32>) -> Result<() ...@@ -188,6 +198,21 @@ pub fn validate_repetition_penalty(repetition_penalty: Option<f32>) -> Result<()
Ok(()) Ok(())
} }
/// Validates the `min_p` parameter.
///
/// `None` is accepted, as is any value inside the inclusive range
/// `MIN_MIN_P..=MAX_MIN_P`.
///
/// # Errors
/// A value outside that range produces an error naming both bounds and the
/// offending value.
pub fn validate_min_p(min_p: Option<f32>) -> Result<(), anyhow::Error> {
    // An absent value needs no checking.
    let Some(value) = min_p else {
        return Ok(());
    };

    if (MIN_MIN_P..=MAX_MIN_P).contains(&value) {
        return Ok(());
    }

    anyhow::bail!(
        "Min_p must be between {} and {}, got {}",
        MIN_MIN_P,
        MAX_MIN_P,
        value
    );
}
/// Validates logit bias map /// Validates logit bias map
pub fn validate_logit_bias( pub fn validate_logit_bias(
logit_bias: &Option<std::collections::HashMap<String, serde_json::Value>>, logit_bias: &Option<std::collections::HashMap<String, serde_json::Value>>,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment