Unverified Commit d4f0d2bc authored by Ayush Agarwal's avatar Ayush Agarwal Committed by GitHub
Browse files

chore: added more api error code validations (#3231)


Signed-off-by: default avatarayushag <ayushag@nvidia.com>
parent 19f48dba
......@@ -711,12 +711,6 @@ pub fn validate_chat_completion_unsupported_fields(
) -> Result<(), ErrorResponse> {
let inner = &request.inner;
if inner.parallel_tool_calls == Some(true) {
return Err(ErrorMessage::not_implemented_error(
"`parallel_tool_calls: true` is not supported.",
));
}
if inner.function_call.is_some() {
return Err(ErrorMessage::not_implemented_error(
"`function_call` is deprecated. Please migrate to use `tool_choice` instead.",
......
......@@ -360,6 +360,10 @@ pub struct GuidedDecodingOptions {
/// If specified, the backend to use for guided decoding, can be backends like xgrammar or custom guided decoding backend
#[serde(skip_serializing_if = "Option::is_none")]
pub backend: Option<String>,
/// If specified, whitespace pattern to use for guided decoding. Can be a string or null.
#[serde(skip_serializing_if = "Option::is_none")]
pub whitespace_pattern: Option<String>,
}
impl GuidedDecodingOptions {
......@@ -370,6 +374,7 @@ impl GuidedDecodingOptions {
choice: Option<Vec<String>>,
grammar: Option<String>,
backend: Option<String>,
whitespace_pattern: Option<String>,
) -> Self {
Self {
json,
......@@ -377,6 +382,7 @@ impl GuidedDecodingOptions {
choice,
grammar,
backend,
whitespace_pattern,
}
}
......@@ -387,8 +393,9 @@ impl GuidedDecodingOptions {
choice: Option<Vec<String>>,
grammar: Option<String>,
backend: Option<String>,
whitespace_pattern: Option<String>,
) -> Result<Self> {
let instance = Self::new(json, regex, choice, grammar, backend);
let instance = Self::new(json, regex, choice, grammar, backend, whitespace_pattern);
instance.validate()?;
Ok(instance)
}
......@@ -400,12 +407,18 @@ impl GuidedDecodingOptions {
choice: Option<Vec<String>>,
grammar: Option<String>,
backend: Option<String>,
whitespace_pattern: Option<String>,
) -> Result<Option<Self>> {
let is_empty_choice = choice.as_ref().is_none_or(|v| v.is_empty());
if json.is_none() && regex.is_none() && is_empty_choice && grammar.is_none() {
if json.is_none()
&& regex.is_none()
&& is_empty_choice
&& grammar.is_none()
&& whitespace_pattern.is_none()
{
return Ok(None);
}
let instance = Self::validated(json, regex, choice, grammar, backend)?;
let instance = Self::validated(json, regex, choice, grammar, backend, whitespace_pattern)?;
Ok(Some(instance))
}
......@@ -416,6 +429,7 @@ impl GuidedDecodingOptions {
self.regex.is_some(),
self.choice.as_ref().is_some_and(|v| !v.is_empty()),
self.grammar.is_some(),
self.whitespace_pattern.is_some(),
]
.iter()
.filter(|&&v| v)
......@@ -674,7 +688,6 @@ mod tests {
}
#[test]
fn test_guided_decoding_options_new_and_exclusive() {
// Only JSON set
let json_val = serde_json::json!({"type": "object"});
......@@ -685,6 +698,7 @@ mod tests {
None,
None,
backend.clone(),
None,
);
assert!(opts.is_ok());
let opts = opts.unwrap();
......@@ -693,36 +707,58 @@ mod tests {
assert!(opts.choice.is_none());
assert!(opts.grammar.is_none());
assert_eq!(opts.backend, backend);
assert!(opts.whitespace_pattern.is_none());
// Only regex set
let regex = Some(r"\d+".to_string());
let opts = GuidedDecodingOptions::validated(None, regex.clone(), None, None, None);
let opts = GuidedDecodingOptions::validated(None, regex.clone(), None, None, None, None);
assert!(opts.is_ok());
let opts = opts.unwrap();
assert_eq!(opts.regex, regex);
assert!(opts.json.is_none());
assert!(opts.choice.is_none());
assert!(opts.grammar.is_none());
assert!(opts.whitespace_pattern.is_none());
// Only choice set
let choice = Some(vec!["A".to_string(), "B".to_string()]);
let opts = GuidedDecodingOptions::validated(None, None, choice.clone(), None, None);
let opts = GuidedDecodingOptions::validated(None, None, choice.clone(), None, None, None);
assert!(opts.is_ok());
let opts = opts.unwrap();
assert_eq!(opts.choice, choice);
assert!(opts.json.is_none());
assert!(opts.regex.is_none());
assert!(opts.grammar.is_none());
assert!(opts.whitespace_pattern.is_none());
// Only grammar set
let grammar = Some("root ::= 'yes' | 'no'".to_string());
let opts = GuidedDecodingOptions::validated(None, None, None, grammar.clone(), None);
let opts = GuidedDecodingOptions::validated(None, None, None, grammar.clone(), None, None);
assert!(opts.is_ok());
let opts = opts.unwrap();
assert_eq!(opts.grammar, grammar);
assert!(opts.json.is_none());
assert!(opts.regex.is_none());
assert!(opts.choice.is_none());
assert!(opts.whitespace_pattern.is_none());
// Only whitespace_pattern set
let whitespace_pattern = Some(r"\s+".to_string());
let opts = GuidedDecodingOptions::validated(
None,
None,
None,
None,
None,
whitespace_pattern.clone(),
);
assert!(opts.is_ok());
let opts = opts.unwrap();
assert_eq!(opts.whitespace_pattern, whitespace_pattern);
assert!(opts.json.is_none());
assert!(opts.regex.is_none());
assert!(opts.choice.is_none());
assert!(opts.grammar.is_none());
// Multiple fields set (should error)
let opts = GuidedDecodingOptions::validated(
......@@ -731,6 +767,7 @@ mod tests {
None,
None,
None,
None,
);
assert!(opts.is_err());
......@@ -740,6 +777,7 @@ mod tests {
Some(vec!["A".to_string()]),
None,
None,
None,
);
assert!(opts.is_err());
......@@ -749,24 +787,26 @@ mod tests {
Some(vec!["A".to_string()]),
Some("root ::= 'yes'".to_string()),
None,
None,
);
assert!(opts.is_err());
// All fields None (should be ok, but not useful)
let opts = GuidedDecodingOptions::validated(None, None, None, None, None);
let opts = GuidedDecodingOptions::validated(None, None, None, None, None, None);
assert!(opts.is_ok());
}
#[test]
fn test_guided_decoding_options_from_optional() {
// All None returns Ok(None)
let opts = GuidedDecodingOptions::from_optional(None, None, None, None, None);
let opts = GuidedDecodingOptions::from_optional(None, None, None, None, None, None);
assert!(opts.is_ok());
assert!(opts.unwrap().is_none());
// Only one set returns Ok(Some)
let regex = Some(r"\w+".to_string());
let opts = GuidedDecodingOptions::from_optional(None, regex.clone(), None, None, None);
let opts =
GuidedDecodingOptions::from_optional(None, regex.clone(), None, None, None, None);
assert!(opts.is_ok());
let val = opts.unwrap();
assert!(val.is_some());
......@@ -780,11 +820,12 @@ mod tests {
None,
None,
None,
None,
);
assert!(opts.is_err());
// Choice set but empty vector should not count as set
let opts = GuidedDecodingOptions::from_optional(None, None, Some(vec![]), None, None);
let opts = GuidedDecodingOptions::from_optional(None, None, Some(vec![]), None, None, None);
assert!(opts.is_ok());
let val = opts.unwrap();
assert!(val.is_none());
......@@ -796,6 +837,7 @@ mod tests {
Some(vec!["A".to_string()]),
None,
None,
None,
);
assert!(opts.is_ok());
let val = opts.unwrap();
......
......@@ -133,6 +133,7 @@ impl<T: OpenAISamplingOptionsProvider + CommonExtProvider> SamplingOptionsProvid
let guided_regex = self.get_guided_regex();
let guided_grammar = self.get_guided_grammar();
let guided_choice = self.get_guided_choice();
let guided_whitespace_pattern = self.get_guided_whitespace_pattern();
let guided_decoding = match common::GuidedDecodingOptions::from_optional(
guided_json.cloned(),
......@@ -140,6 +141,7 @@ impl<T: OpenAISamplingOptionsProvider + CommonExtProvider> SamplingOptionsProvid
guided_choice,
guided_grammar,
guided_decoding_backend,
guided_whitespace_pattern,
) {
Ok(options) => options,
Err(e) => {
......
......@@ -206,6 +206,16 @@ impl CommonExtProvider for NvCreateChatCompletionRequest {
)
}
fn get_guided_whitespace_pattern(&self) -> Option<String> {
choose_with_deprecation(
"guided_whitespace_pattern",
self.common.guided_whitespace_pattern.as_ref(),
self.nvext
.as_ref()
.and_then(|nv| nv.guided_whitespace_pattern.as_ref()),
)
}
fn get_top_k(&self) -> Option<i32> {
choose_with_deprecation(
"top_k",
......@@ -349,6 +359,8 @@ impl ValidateRequest for NvCreateChatCompletionRequest {
// none for functions
// Common Ext
validate::validate_repetition_penalty(self.get_repetition_penalty())?;
validate::validate_min_p(self.get_min_p())?;
validate::validate_top_k(self.get_top_k())?;
Ok(())
}
......
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use super::nvext::validate_top_k;
use derive_builder::Builder;
use serde::{Deserialize, Serialize};
use validator::Validate;
......@@ -25,20 +23,17 @@ pub struct CommonExt {
/// Integer that controls the number of top tokens to consider. Set to -1 to consider all tokens.
#[serde(default, skip_serializing_if = "Option::is_none")]
#[builder(default, setter(strip_option))]
#[validate(custom(function = "validate_top_k"))]
pub top_k: Option<i32>,
/// Relative probability floor
#[serde(default, skip_serializing_if = "Option::is_none")]
#[builder(default, setter(strip_option))]
#[validate(range(min = 0.0, max = 1.0))]
pub min_p: Option<f32>,
/// How much to penalize tokens based on how frequently they occur in the text.
/// A value of 1 means no penalty, while values larger than 1 discourage and values smaller encourage.
#[serde(default, skip_serializing_if = "Option::is_none")]
#[builder(default, setter(strip_option))]
#[validate(range(exclusive_min = 0.0, max = 2.0))]
pub repetition_penalty: Option<f32>,
/// include_stop_str_in_output
......@@ -71,6 +66,12 @@ pub struct CommonExt {
#[serde(default, skip_serializing_if = "Option::is_none")]
#[builder(default, setter(strip_option))]
pub guided_decoding_backend: Option<String>,
/// If specified, the output will follow the whitespace pattern. Can be a string or null.
#[serde(default, skip_serializing_if = "Option::is_none")]
#[builder(default, setter(strip_option))]
#[allow(unused)] // Not used
pub guided_whitespace_pattern: Option<String>,
}
impl CommonExt {
......@@ -90,6 +91,8 @@ pub trait CommonExtProvider {
fn get_guided_grammar(&self) -> Option<String>;
fn get_guided_choice(&self) -> Option<Vec<String>>;
fn get_guided_decoding_backend(&self) -> Option<String>;
#[allow(unused)] // Not used
fn get_guided_whitespace_pattern(&self) -> Option<String>;
/// Other sampling Options
fn get_top_k(&self) -> Option<i32>;
......@@ -215,6 +218,7 @@ mod tests {
guided_grammar: None,
guided_choice: None,
guided_decoding_backend: None,
guided_whitespace_pattern: None,
};
assert!(common_ext.validate().is_ok());
}
......
......@@ -193,6 +193,16 @@ impl CommonExtProvider for NvCreateCompletionRequest {
)
}
fn get_guided_whitespace_pattern(&self) -> Option<String> {
choose_with_deprecation(
"guided_whitespace_pattern",
self.common.guided_whitespace_pattern.as_ref(),
self.nvext
.as_ref()
.and_then(|nv| nv.guided_whitespace_pattern.as_ref()),
)
}
fn get_top_k(&self) -> Option<i32> {
choose_with_deprecation(
"top_k",
......@@ -435,6 +445,8 @@ impl ValidateRequest for NvCreateCompletionRequest {
// Common Ext
validate::validate_repetition_penalty(self.get_repetition_penalty())?;
validate::validate_min_p(self.get_min_p())?;
validate::validate_top_k(self.get_top_k())?;
Ok(())
}
......
......@@ -20,20 +20,17 @@ pub struct NvExt {
pub ignore_eos: Option<bool>,
#[builder(default, setter(strip_option))] // NIM LLM might default to -1
#[validate(custom(function = "validate_top_k"))]
#[serde(default, skip_serializing_if = "Option::is_none")]
pub top_k: Option<i32>,
/// Relative probability floor
#[serde(default, skip_serializing_if = "Option::is_none")]
#[builder(default, setter(strip_option))]
#[validate(range(min = 0.0, max = 1.0))]
pub min_p: Option<f32>,
/// How much to penalize tokens based on how frequently they occur in the text.
/// A value of 1 means no penalty, while values larger than 1 discourage and values smaller encourage.
#[builder(default, setter(strip_option))]
#[validate(range(exclusive_min = 0.0, max = 2.0))]
pub repetition_penalty: Option<f32>,
/// If true, sampling will be forced to be greedy.
......@@ -95,6 +92,11 @@ pub struct NvExt {
#[builder(default, setter(strip_option))]
pub guided_decoding_backend: Option<String>,
/// If specified, the output will follow the whitespace pattern. Can be a string or null.
#[serde(default, skip_serializing_if = "Option::is_none")]
#[builder(default, setter(strip_option))]
pub guided_whitespace_pattern: Option<String>,
/// Maximum number of thinking tokens allowed
/// NOTE: Currently passed through to backends as a no-op for future implementation
#[serde(default, skip_serializing_if = "Option::is_none")]
......@@ -118,15 +120,6 @@ fn validate_nv_ext(_nv_ext: &NvExt) -> Result<(), ValidationError> {
Ok(())
}
pub fn validate_top_k(top_k: i32) -> Result<(), ValidationError> {
if top_k == -1 || (top_k >= 1) {
return Ok(());
}
let mut error = ValidationError::new("top_k");
error.message = Some("top_k must be -1 or greater than or equal to 1".into());
Err(error)
}
impl NvExtBuilder {
pub fn add_annotation(&mut self, annotation: impl Into<String>) -> &mut Self {
self.annotations
......@@ -140,7 +133,6 @@ impl NvExtBuilder {
#[cfg(test)]
mod tests {
use proptest::prelude::*;
use validator::Validate;
use super::*;
......@@ -157,6 +149,7 @@ mod tests {
assert_eq!(nv_ext.guided_regex, None);
assert_eq!(nv_ext.guided_grammar, None);
assert_eq!(nv_ext.guided_choice, None);
assert_eq!(nv_ext.guided_whitespace_pattern, None);
assert_eq!(nv_ext.max_thinking_tokens, None);
}
......@@ -199,59 +192,4 @@ mod tests {
// Validate the built struct
assert!(nv_ext.validate().is_ok());
}
// Test invalid `top_k` validation using proptest
proptest! {
#[test]
fn test_invalid_top_k_value(top_k in any::<i32>().prop_filter("Invalid top_k", |&k| k < -1 || (k > 0 && k < 1))) {
let nv_ext = NvExt::builder()
.top_k(top_k)
.build()
.unwrap();
let validation_result = nv_ext.validate();
assert!(validation_result.is_err(), "top_k should fail validation if less than -1 or in the invalid range 0 < top_k < 1");
}
}
// Test valid `top_k` values
#[test]
fn test_valid_top_k_values() {
let nv_ext = NvExt::builder().top_k(-1).build().unwrap();
assert!(nv_ext.validate().is_ok());
let nv_ext = NvExt::builder().top_k(1).build().unwrap();
assert!(nv_ext.validate().is_ok());
let nv_ext = NvExt::builder().top_k(10).build().unwrap();
assert!(nv_ext.validate().is_ok());
}
// Test valid repetition_penalty values
proptest! {
#[test]
fn test_valid_repetition_penalty_values(repetition_penalty in 0.01f32..=2.0f32) {
let nv_ext = NvExt::builder()
.repetition_penalty(repetition_penalty)
.build()
.unwrap();
let validation_result = nv_ext.validate();
assert!(validation_result.is_ok(), "repetition_penalty should be valid within the range (0, 2]");
}
}
// Test invalid repetition_penalty values
proptest! {
#[test]
fn test_invalid_repetition_penalty_values(repetition_penalty in -10.0f32..0.0f32) {
let nv_ext = NvExt::builder()
.repetition_penalty(repetition_penalty)
.build()
.unwrap();
let validation_result = nv_ext.validate();
assert!(validation_result.is_err(), "repetition_penalty should fail validation when outside the range (0, 2]");
}
}
}
......@@ -131,6 +131,15 @@ pub fn validate_top_p(top_p: Option<f32>) -> Result<(), anyhow::Error> {
Ok(())
}
// Validate top_k
pub fn validate_top_k(top_k: Option<i32>) -> Result<(), anyhow::Error> {
match top_k {
None => Ok(()),
Some(k) if k == -1 || k >= 1 => Ok(()),
_ => anyhow::bail!("Top_k must be null, -1, or greater than or equal to 1"),
}
}
/// Validates mutual exclusion of temperature and top_p
pub fn validate_temperature_top_p_exclusion(
temperature: Option<f32>,
......@@ -175,8 +184,9 @@ pub fn validate_presence_penalty(presence_penalty: Option<f32>) -> Result<(), an
}
pub fn validate_repetition_penalty(repetition_penalty: Option<f32>) -> Result<(), anyhow::Error> {
// It should be greater than 0.0 and less than equal to 2.0
if let Some(penalty) = repetition_penalty
&& !(MIN_REPETITION_PENALTY..=MAX_REPETITION_PENALTY).contains(&penalty)
&& (penalty <= MIN_REPETITION_PENALTY || penalty > MAX_REPETITION_PENALTY)
{
anyhow::bail!(
"Repetition penalty must be between {} and {}, got {}",
......@@ -188,6 +198,21 @@ pub fn validate_repetition_penalty(repetition_penalty: Option<f32>) -> Result<()
Ok(())
}
/// Validates min_p parameter
pub fn validate_min_p(min_p: Option<f32>) -> Result<(), anyhow::Error> {
if let Some(p) = min_p
&& !(MIN_MIN_P..=MAX_MIN_P).contains(&p)
{
anyhow::bail!(
"Min_p must be between {} and {}, got {}",
MIN_MIN_P,
MAX_MIN_P,
p
);
}
Ok(())
}
/// Validates logit bias map
pub fn validate_logit_bias(
logit_bias: &Option<std::collections::HashMap<String, serde_json::Value>>,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment