Unverified Commit b954a249 authored by ryan-lempka's avatar ryan-lempka Committed by GitHub
Browse files

chore: remove deprecated nvext parameters for 6.0 (#3551)

parent 227846f2
...@@ -8,7 +8,7 @@ use super::{ ...@@ -8,7 +8,7 @@ use super::{
ContentProvider, ContentProvider,
common::{self, OutputOptionsProvider, SamplingOptionsProvider, StopConditionsProvider}, common::{self, OutputOptionsProvider, SamplingOptionsProvider, StopConditionsProvider},
}; };
use crate::protocols::openai::common_ext::{CommonExtProvider, choose_with_deprecation}; use crate::protocols::openai::common_ext::CommonExtProvider;
pub mod chat_completions; pub mod chat_completions;
pub mod common_ext; pub mod common_ext;
...@@ -65,14 +65,9 @@ trait OpenAIStopConditionsProvider { ...@@ -65,14 +65,9 @@ trait OpenAIStopConditionsProvider {
None None
} }
/// Get the effective ignore_eos value, considering both CommonExt and NvExt. /// Get the effective ignore_eos value from CommonExt.
/// CommonExt (root-level) takes precedence over NvExt.
fn get_ignore_eos(&self) -> Option<bool> { fn get_ignore_eos(&self) -> Option<bool> {
choose_with_deprecation( self.get_common_ignore_eos()
"ignore_eos",
self.get_common_ignore_eos().as_ref(),
self.nvext().and_then(|nv| nv.ignore_eos.as_ref()),
)
} }
/// Get max_thinking_tokens from nvext /// Get max_thinking_tokens from nvext
......
...@@ -9,9 +9,7 @@ use crate::engines::ValidateRequest; ...@@ -9,9 +9,7 @@ use crate::engines::ValidateRequest;
use super::{ use super::{
OpenAIOutputOptionsProvider, OpenAISamplingOptionsProvider, OpenAIStopConditionsProvider, OpenAIOutputOptionsProvider, OpenAISamplingOptionsProvider, OpenAIStopConditionsProvider,
common_ext::{ common_ext::{CommonExt, CommonExtProvider},
CommonExt, CommonExtProvider, choose_with_deprecation, emit_nvext_deprecation_warning,
},
nvext::NvExt, nvext::NvExt,
nvext::NvExtProvider, nvext::NvExtProvider,
validate, validate,
...@@ -158,88 +156,39 @@ impl CommonExtProvider for NvCreateChatCompletionRequest { ...@@ -158,88 +156,39 @@ impl CommonExtProvider for NvCreateChatCompletionRequest {
/// Guided Decoding Options /// Guided Decoding Options
fn get_guided_json(&self) -> Option<&serde_json::Value> { fn get_guided_json(&self) -> Option<&serde_json::Value> {
// Note: This one needs special handling since it returns a reference self.common.guided_json.as_ref()
if let Some(nvext) = &self.nvext
&& nvext.guided_json.is_some()
{
emit_nvext_deprecation_warning("guided_json", true, self.common.guided_json.is_some());
}
self.common
.guided_json
.as_ref()
.or_else(|| self.nvext.as_ref().and_then(|nv| nv.guided_json.as_ref()))
} }
fn get_guided_regex(&self) -> Option<String> { fn get_guided_regex(&self) -> Option<String> {
choose_with_deprecation( self.common.guided_regex.clone()
"guided_regex",
self.common.guided_regex.as_ref(),
self.nvext.as_ref().and_then(|nv| nv.guided_regex.as_ref()),
)
} }
fn get_guided_grammar(&self) -> Option<String> { fn get_guided_grammar(&self) -> Option<String> {
choose_with_deprecation( self.common.guided_grammar.clone()
"guided_grammar",
self.common.guided_grammar.as_ref(),
self.nvext
.as_ref()
.and_then(|nv| nv.guided_grammar.as_ref()),
)
} }
fn get_guided_choice(&self) -> Option<Vec<String>> { fn get_guided_choice(&self) -> Option<Vec<String>> {
choose_with_deprecation( self.common.guided_choice.clone()
"guided_choice",
self.common.guided_choice.as_ref(),
self.nvext.as_ref().and_then(|nv| nv.guided_choice.as_ref()),
)
} }
fn get_guided_decoding_backend(&self) -> Option<String> { fn get_guided_decoding_backend(&self) -> Option<String> {
choose_with_deprecation( self.common.guided_decoding_backend.clone()
"guided_decoding_backend",
self.common.guided_decoding_backend.as_ref(),
self.nvext
.as_ref()
.and_then(|nv| nv.guided_decoding_backend.as_ref()),
)
} }
fn get_guided_whitespace_pattern(&self) -> Option<String> { fn get_guided_whitespace_pattern(&self) -> Option<String> {
choose_with_deprecation( self.common.guided_whitespace_pattern.clone()
"guided_whitespace_pattern",
self.common.guided_whitespace_pattern.as_ref(),
self.nvext
.as_ref()
.and_then(|nv| nv.guided_whitespace_pattern.as_ref()),
)
} }
fn get_top_k(&self) -> Option<i32> { fn get_top_k(&self) -> Option<i32> {
choose_with_deprecation( self.common.top_k
"top_k",
self.common.top_k.as_ref(),
self.nvext.as_ref().and_then(|nv| nv.top_k.as_ref()),
)
} }
fn get_min_p(&self) -> Option<f32> { fn get_min_p(&self) -> Option<f32> {
choose_with_deprecation( self.common.min_p
"min_p",
self.common.min_p.as_ref(),
self.nvext.as_ref().and_then(|nv| nv.min_p.as_ref()),
)
} }
fn get_repetition_penalty(&self) -> Option<f32> { fn get_repetition_penalty(&self) -> Option<f32> {
choose_with_deprecation( self.common.repetition_penalty
"repetition_penalty",
self.common.repetition_penalty.as_ref(),
self.nvext
.as_ref()
.and_then(|nv| nv.repetition_penalty.as_ref()),
)
} }
fn get_include_stop_str_in_output(&self) -> Option<bool> { fn get_include_stop_str_in_output(&self) -> Option<bool> {
...@@ -287,14 +236,9 @@ impl OpenAIStopConditionsProvider for NvCreateChatCompletionRequest { ...@@ -287,14 +236,9 @@ impl OpenAIStopConditionsProvider for NvCreateChatCompletionRequest {
self.common.ignore_eos self.common.ignore_eos
} }
/// Get the effective ignore_eos value, considering both CommonExt and NvExt. /// Get the effective ignore_eos value from CommonExt.
/// CommonExt (root-level) takes precedence over NvExt.
fn get_ignore_eos(&self) -> Option<bool> { fn get_ignore_eos(&self) -> Option<bool> {
choose_with_deprecation( self.common.ignore_eos
"ignore_eos",
self.get_common_ignore_eos().as_ref(),
NvExtProvider::nvext(self).and_then(|nv| nv.ignore_eos.as_ref()),
)
} }
} }
......
...@@ -101,35 +101,6 @@ pub trait CommonExtProvider { ...@@ -101,35 +101,6 @@ pub trait CommonExtProvider {
fn get_include_stop_str_in_output(&self) -> Option<bool>; fn get_include_stop_str_in_output(&self) -> Option<bool>;
} }
/// Helper function to emit deprecation warnings for nvext parameters
pub fn emit_nvext_deprecation_warning(
field_name: &str,
nvext_has_value: bool,
common_has_value: bool,
) {
if nvext_has_value && !common_has_value {
tracing::warn!(
"DEPRECATION WARNING: 'nvext.{field_name}' is deprecated and will be removed in a future release. Use '{field_name}' at the top level or in 'extra_body' instead."
);
} else if nvext_has_value && common_has_value {
tracing::warn!(
"DEPRECATION WARNING: 'nvext.{field_name}' is deprecated and will be removed in a future release. Top-level '{field_name}' takes precedence. Use '{field_name}' at the top level or in 'extra_body' instead."
);
}
}
/// Helper function to choose between common and nvext values with deprecation warnings
pub fn choose_with_deprecation<T: Clone>(
field: &'static str,
common: Option<&T>,
nv: Option<&T>,
) -> Option<T> {
if nv.is_some() {
emit_nvext_deprecation_warning(field, true, common.is_some());
}
common.cloned().or_else(|| nv.cloned())
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
...@@ -248,23 +219,4 @@ mod tests { ...@@ -248,23 +219,4 @@ mod tests {
assert_eq!(common_ext.include_stop_str_in_output, None); assert_eq!(common_ext.include_stop_str_in_output, None);
assert!(common_ext.validate().is_ok()); assert!(common_ext.validate().is_ok());
} }
#[test]
fn test_choose_with_deprecation() {
// Common takes precedence
let result = choose_with_deprecation(
"test_field",
Some(&"common_value".to_string()),
Some(&"nvext_value".to_string()),
);
assert_eq!(result, Some("common_value".to_string()));
// Fallback to nvext
let result = choose_with_deprecation("test_field", None, Some(&"nvext_value".to_string()));
assert_eq!(result, Some("nvext_value".to_string()));
// Both None
let result: Option<String> = choose_with_deprecation("test_field", None, None);
assert_eq!(result, None);
}
} }
...@@ -12,9 +12,7 @@ use super::{ ...@@ -12,9 +12,7 @@ use super::{
ContentProvider, OpenAIOutputOptionsProvider, OpenAISamplingOptionsProvider, ContentProvider, OpenAIOutputOptionsProvider, OpenAISamplingOptionsProvider,
OpenAIStopConditionsProvider, OpenAIStopConditionsProvider,
common::{self, OutputOptionsProvider, SamplingOptionsProvider, StopConditionsProvider}, common::{self, OutputOptionsProvider, SamplingOptionsProvider, StopConditionsProvider},
common_ext::{ common_ext::{CommonExt, CommonExtProvider},
CommonExt, CommonExtProvider, choose_with_deprecation, emit_nvext_deprecation_warning,
},
nvext::{NvExt, NvExtProvider}, nvext::{NvExt, NvExtProvider},
validate, validate,
}; };
...@@ -149,88 +147,39 @@ impl CommonExtProvider for NvCreateCompletionRequest { ...@@ -149,88 +147,39 @@ impl CommonExtProvider for NvCreateCompletionRequest {
/// Guided Decoding Options /// Guided Decoding Options
fn get_guided_json(&self) -> Option<&serde_json::Value> { fn get_guided_json(&self) -> Option<&serde_json::Value> {
// Note: This one needs special handling since it returns a reference self.common.guided_json.as_ref()
if let Some(nvext) = &self.nvext
&& nvext.guided_json.is_some()
{
emit_nvext_deprecation_warning("guided_json", true, self.common.guided_json.is_some());
}
self.common
.guided_json
.as_ref()
.or_else(|| self.nvext.as_ref().and_then(|nv| nv.guided_json.as_ref()))
} }
fn get_guided_regex(&self) -> Option<String> { fn get_guided_regex(&self) -> Option<String> {
choose_with_deprecation( self.common.guided_regex.clone()
"guided_regex",
self.common.guided_regex.as_ref(),
self.nvext.as_ref().and_then(|nv| nv.guided_regex.as_ref()),
)
} }
fn get_guided_grammar(&self) -> Option<String> { fn get_guided_grammar(&self) -> Option<String> {
choose_with_deprecation( self.common.guided_grammar.clone()
"guided_grammar",
self.common.guided_grammar.as_ref(),
self.nvext
.as_ref()
.and_then(|nv| nv.guided_grammar.as_ref()),
)
} }
fn get_guided_choice(&self) -> Option<Vec<String>> { fn get_guided_choice(&self) -> Option<Vec<String>> {
choose_with_deprecation( self.common.guided_choice.clone()
"guided_choice",
self.common.guided_choice.as_ref(),
self.nvext.as_ref().and_then(|nv| nv.guided_choice.as_ref()),
)
} }
fn get_guided_decoding_backend(&self) -> Option<String> { fn get_guided_decoding_backend(&self) -> Option<String> {
choose_with_deprecation( self.common.guided_decoding_backend.clone()
"guided_decoding_backend",
self.common.guided_decoding_backend.as_ref(),
self.nvext
.as_ref()
.and_then(|nv| nv.guided_decoding_backend.as_ref()),
)
} }
fn get_guided_whitespace_pattern(&self) -> Option<String> { fn get_guided_whitespace_pattern(&self) -> Option<String> {
choose_with_deprecation( self.common.guided_whitespace_pattern.clone()
"guided_whitespace_pattern",
self.common.guided_whitespace_pattern.as_ref(),
self.nvext
.as_ref()
.and_then(|nv| nv.guided_whitespace_pattern.as_ref()),
)
} }
fn get_top_k(&self) -> Option<i32> { fn get_top_k(&self) -> Option<i32> {
choose_with_deprecation( self.common.top_k
"top_k",
self.common.top_k.as_ref(),
self.nvext.as_ref().and_then(|nv| nv.top_k.as_ref()),
)
} }
fn get_min_p(&self) -> Option<f32> { fn get_min_p(&self) -> Option<f32> {
choose_with_deprecation( self.common.min_p
"min_p",
self.common.min_p.as_ref(),
self.nvext.as_ref().and_then(|nv| nv.min_p.as_ref()),
)
} }
fn get_repetition_penalty(&self) -> Option<f32> { fn get_repetition_penalty(&self) -> Option<f32> {
choose_with_deprecation( self.common.repetition_penalty
"repetition_penalty",
self.common.repetition_penalty.as_ref(),
self.nvext
.as_ref()
.and_then(|nv| nv.repetition_penalty.as_ref()),
)
} }
fn get_include_stop_str_in_output(&self) -> Option<bool> { fn get_include_stop_str_in_output(&self) -> Option<bool> {
...@@ -259,14 +208,9 @@ impl OpenAIStopConditionsProvider for NvCreateCompletionRequest { ...@@ -259,14 +208,9 @@ impl OpenAIStopConditionsProvider for NvCreateCompletionRequest {
self.common.ignore_eos self.common.ignore_eos
} }
/// Get the effective ignore_eos value, considering both CommonExt and NvExt. /// Get the effective ignore_eos value from CommonExt.
/// CommonExt (root-level) takes precedence over NvExt.
fn get_ignore_eos(&self) -> Option<bool> { fn get_ignore_eos(&self) -> Option<bool> {
choose_with_deprecation( self.common.ignore_eos
"ignore_eos",
self.get_common_ignore_eos().as_ref(),
NvExtProvider::nvext(self).and_then(|nv| nv.ignore_eos.as_ref()),
)
} }
} }
......
...@@ -14,25 +14,6 @@ pub trait NvExtProvider { ...@@ -14,25 +14,6 @@ pub trait NvExtProvider {
#[derive(Serialize, Deserialize, Builder, Validate, Debug, Clone)] #[derive(Serialize, Deserialize, Builder, Validate, Debug, Clone)]
#[validate(schema(function = "validate_nv_ext"))] #[validate(schema(function = "validate_nv_ext"))]
pub struct NvExt { pub struct NvExt {
/// If true, the model will ignore the end of string token and generate to max_tokens.
#[serde(default, skip_serializing_if = "Option::is_none")]
#[builder(default, setter(strip_option))]
pub ignore_eos: Option<bool>,
#[builder(default, setter(strip_option))] // NIM LLM might default to -1
#[serde(default, skip_serializing_if = "Option::is_none")]
pub top_k: Option<i32>,
/// Relative probability floor
#[serde(default, skip_serializing_if = "Option::is_none")]
#[builder(default, setter(strip_option))]
pub min_p: Option<f32>,
/// How much to penalize tokens based on how frequently they occur in the text.
/// A value of 1 means no penalty, while values larger than 1 discourage and values smaller encourage.
#[builder(default, setter(strip_option))]
pub repetition_penalty: Option<f32>,
/// If true, sampling will be forced to be greedy. /// If true, sampling will be forced to be greedy.
/// The backend is responsible for selecting the correct backend-specific options to /// The backend is responsible for selecting the correct backend-specific options to
/// implement this. /// implement this.
...@@ -66,36 +47,6 @@ pub struct NvExt { ...@@ -66,36 +47,6 @@ pub struct NvExt {
#[builder(default, setter(strip_option))] #[builder(default, setter(strip_option))]
#[serde(default, skip_serializing_if = "Option::is_none")] #[serde(default, skip_serializing_if = "Option::is_none")]
pub token_data: Option<Vec<u32>>, pub token_data: Option<Vec<u32>>,
/// Guided Decoding Options
/// If specified, the output will be a JSON object. Can be a string, an object, or null.
#[serde(default, skip_serializing_if = "Option::is_none")]
#[builder(default, setter(strip_option))]
pub guided_json: Option<serde_json::Value>,
/// If specified, the output will follow the regex pattern. Can be a string or null.
#[serde(default, skip_serializing_if = "Option::is_none")]
#[builder(default, setter(strip_option))]
pub guided_regex: Option<String>,
/// If specified, the output will follow the context-free grammar. Can be a string or null.
#[serde(default, skip_serializing_if = "Option::is_none")]
#[builder(default, setter(strip_option))]
pub guided_grammar: Option<String>,
/// If specified, the output will be exactly one of the choices.
#[serde(default, skip_serializing_if = "Option::is_none")]
#[builder(default, setter(strip_option))]
pub guided_choice: Option<Vec<String>>,
/// If specified, the backend to use for guided decoding, can be backends like xgrammar or custom guided decoding backend
#[serde(default, skip_serializing_if = "Option::is_none")]
#[builder(default, setter(strip_option))]
pub guided_decoding_backend: Option<String>,
/// If specified, the output will follow the whitespace pattern. Can be a string or null.
#[serde(default, skip_serializing_if = "Option::is_none")]
#[builder(default, setter(strip_option))]
pub guided_whitespace_pattern: Option<String>,
/// Maximum number of thinking tokens allowed /// Maximum number of thinking tokens allowed
/// NOTE: Currently passed through to backends as a no-op for future implementation /// NOTE: Currently passed through to backends as a no-op for future implementation
...@@ -141,15 +92,11 @@ mod tests { ...@@ -141,15 +92,11 @@ mod tests {
#[test] #[test]
fn test_nv_ext_builder_default() { fn test_nv_ext_builder_default() {
let nv_ext = NvExt::builder().build().unwrap(); let nv_ext = NvExt::builder().build().unwrap();
assert_eq!(nv_ext.ignore_eos, None);
assert_eq!(nv_ext.top_k, None);
assert_eq!(nv_ext.repetition_penalty, None);
assert_eq!(nv_ext.greed_sampling, None); assert_eq!(nv_ext.greed_sampling, None);
assert_eq!(nv_ext.guided_json, None); assert_eq!(nv_ext.use_raw_prompt, None);
assert_eq!(nv_ext.guided_regex, None); assert_eq!(nv_ext.annotations, None);
assert_eq!(nv_ext.guided_grammar, None); assert_eq!(nv_ext.backend_instance_id, None);
assert_eq!(nv_ext.guided_choice, None); assert_eq!(nv_ext.token_data, None);
assert_eq!(nv_ext.guided_whitespace_pattern, None);
assert_eq!(nv_ext.max_thinking_tokens, None); assert_eq!(nv_ext.max_thinking_tokens, None);
} }
...@@ -157,37 +104,18 @@ mod tests { ...@@ -157,37 +104,18 @@ mod tests {
#[test] #[test]
fn test_nv_ext_builder_custom() { fn test_nv_ext_builder_custom() {
let nv_ext = NvExt::builder() let nv_ext = NvExt::builder()
.ignore_eos(true)
.top_k(10)
.repetition_penalty(1.5)
.greed_sampling(true) .greed_sampling(true)
.guided_json(serde_json::json!({"type": "object"})) .use_raw_prompt(true)
.guided_regex("^[0-9]+$".to_string()) .backend_instance_id(42)
.guided_grammar("S -> 'a' S 'b' | 'c'".to_string()) .token_data(vec![1, 2, 3, 4])
.guided_choice(vec!["choice1".to_string(), "choice2".to_string()])
.guided_decoding_backend("xgrammar".to_string())
.max_thinking_tokens(1024) .max_thinking_tokens(1024)
.build() .build()
.unwrap(); .unwrap();
assert_eq!(nv_ext.ignore_eos, Some(true));
assert_eq!(nv_ext.top_k, Some(10));
assert_eq!(nv_ext.repetition_penalty, Some(1.5));
assert_eq!(nv_ext.greed_sampling, Some(true)); assert_eq!(nv_ext.greed_sampling, Some(true));
assert_eq!( assert_eq!(nv_ext.use_raw_prompt, Some(true));
nv_ext.guided_json, assert_eq!(nv_ext.backend_instance_id, Some(42));
Some(serde_json::json!({"type": "object"})) assert_eq!(nv_ext.token_data, Some(vec![1, 2, 3, 4]));
);
assert_eq!(nv_ext.guided_regex, Some("^[0-9]+$".to_string()));
assert_eq!(
nv_ext.guided_grammar,
Some("S -> 'a' S 'b' | 'c'".to_string())
);
assert_eq!(
nv_ext.guided_choice,
Some(vec!["choice1".to_string(), "choice2".to_string()])
);
assert_eq!(nv_ext.guided_decoding_backend, Some("xgrammar".to_string()));
assert_eq!(nv_ext.max_thinking_tokens, Some(1024)); assert_eq!(nv_ext.max_thinking_tokens, Some(1024));
// Validate the built struct // Validate the built struct
assert!(nv_ext.validate().is_ok()); assert!(nv_ext.validate().is_ok());
......
...@@ -156,32 +156,24 @@ fn test_chat_completions_guided_decoding_from_common() { ...@@ -156,32 +156,24 @@ fn test_chat_completions_guided_decoding_from_common() {
} }
#[test] #[test]
fn test_chat_completions_common_overrides_nvext() { fn test_chat_completions_common_values() {
// Test that root-level ignore_eos overrides nvext ignore_eos // Test that ignore_eos and guided_regex are read from common (root level)
let json_str = r#"{ let json_str = r#"{
"model": "test-model", "model": "test-model",
"messages": [{"role": "user", "content": "Hello"}], "messages": [{"role": "user", "content": "Hello"}],
"ignore_eos": false, "ignore_eos": false,
"guided_regex": ".*", "guided_regex": ".*",
"min_tokens": 50, "min_tokens": 50
"nvext": {
"ignore_eos": true,
"guided_regex": "./*"
}
}"#; }"#;
let request: NvCreateChatCompletionRequest = serde_json::from_str(json_str).unwrap(); let request: NvCreateChatCompletionRequest = serde_json::from_str(json_str).unwrap();
assert_eq!(request.common.ignore_eos, Some(false)); assert_eq!(request.common.ignore_eos, Some(false));
assert_eq!(request.common.guided_regex, Some(".*".to_string())); assert_eq!(request.common.guided_regex, Some(".*".to_string()));
assert_eq!( assert_eq!(request.get_guided_regex(), Some(".*".to_string()));
request.nvext.as_ref().and_then(|nv| nv.ignore_eos), // Verify extraction through stop conditions
Some(true)
);
assert_eq!(request.get_guided_regex(), Some(".*".to_string())); // common value takes precedence
// Verify precedence through stop conditions extraction
let stop_conditions = request.extract_stop_conditions().unwrap(); let stop_conditions = request.extract_stop_conditions().unwrap();
assert_eq!(stop_conditions.ignore_eos, Some(false)); // common value takes precedence assert_eq!(stop_conditions.ignore_eos, Some(false));
assert_eq!(stop_conditions.min_tokens, Some(50)); assert_eq!(stop_conditions.min_tokens, Some(50));
} }
...@@ -220,39 +212,21 @@ fn test_max_thinking_tokens_extraction() { ...@@ -220,39 +212,21 @@ fn test_max_thinking_tokens_extraction() {
} }
#[test] #[test]
fn test_chat_completions_backward_compatibility() { fn test_chat_completions_no_common_values() {
// Test backward compatibility - ignore_eos and guided_json only in nvext // Test that when no common values are set, we get None
let json_str = r#"{ let json_str = r#"{
"model": "test-model", "model": "test-model",
"messages": [{"role": "user", "content": "Hello"}], "messages": [{"role": "user", "content": "Hello"}]
"nvext": {
"ignore_eos": true,
"guided_json": {"key": "value"}
}
}"#; }"#;
let request: NvCreateChatCompletionRequest = serde_json::from_str(json_str).unwrap(); let request: NvCreateChatCompletionRequest = serde_json::from_str(json_str).unwrap();
assert_eq!(request.common.ignore_eos, None); assert_eq!(request.common.ignore_eos, None);
assert_eq!(request.common.guided_json, None); assert_eq!(request.common.guided_json, None);
assert_eq!( assert_eq!(request.get_guided_json(), None);
request.nvext.as_ref().and_then(|nv| nv.ignore_eos),
Some(true)
);
assert_eq!(
request
.nvext
.as_ref()
.and_then(|nv| nv.guided_json.as_ref()),
Some(&serde_json::json!({"key": "value"}))
);
assert_eq!(
request.get_guided_json(),
Some(&serde_json::json!({"key": "value"}))
);
// Verify through stop conditions extraction // Verify through stop conditions extraction
let stop_conditions = request.extract_stop_conditions().unwrap(); let stop_conditions = request.extract_stop_conditions().unwrap();
assert_eq!(stop_conditions.ignore_eos, Some(true)); assert_eq!(stop_conditions.ignore_eos, None);
assert_eq!(stop_conditions.min_tokens, None); assert_eq!(stop_conditions.min_tokens, None);
} }
...@@ -278,28 +252,21 @@ fn test_completions_ignore_eos_from_common() { ...@@ -278,28 +252,21 @@ fn test_completions_ignore_eos_from_common() {
} }
#[test] #[test]
fn test_completions_common_overrides_nvext() { fn test_completions_common_values() {
// Test that root-level ignore_eos overrides nvext ignore_eos for completions // Test that root-level ignore_eos is read from common for completions
let json_str = r#"{ let json_str = r#"{
"model": "test-model", "model": "test-model",
"prompt": "Hello world", "prompt": "Hello world",
"ignore_eos": false, "ignore_eos": false,
"min_tokens": 75, "min_tokens": 75
"nvext": {
"ignore_eos": true
}
}"#; }"#;
let request: NvCreateCompletionRequest = serde_json::from_str(json_str).unwrap(); let request: NvCreateCompletionRequest = serde_json::from_str(json_str).unwrap();
assert_eq!(request.common.ignore_eos, Some(false)); assert_eq!(request.common.ignore_eos, Some(false));
assert_eq!( // Verify extraction through stop conditions
request.nvext.as_ref().and_then(|nv| nv.ignore_eos),
Some(true)
);
// Verify precedence through stop conditions extraction
let stop_conditions = request.extract_stop_conditions().unwrap(); let stop_conditions = request.extract_stop_conditions().unwrap();
assert_eq!(stop_conditions.ignore_eos, Some(false)); // common value takes precedence assert_eq!(stop_conditions.ignore_eos, Some(false));
assert_eq!(stop_conditions.min_tokens, Some(75)); assert_eq!(stop_conditions.min_tokens, Some(75));
} }
...@@ -325,7 +292,7 @@ fn test_serialization_preserves_structure() { ...@@ -325,7 +292,7 @@ fn test_serialization_preserves_structure() {
..Default::default() ..Default::default()
}, },
nvext: Some(NvExt { nvext: Some(NvExt {
ignore_eos: Some(false), greed_sampling: Some(false),
..Default::default() ..Default::default()
}), }),
chat_template_args: None, chat_template_args: None,
...@@ -337,11 +304,11 @@ fn test_serialization_preserves_structure() { ...@@ -337,11 +304,11 @@ fn test_serialization_preserves_structure() {
assert_eq!(json["model"], "test-model"); assert_eq!(json["model"], "test-model");
assert_eq!(json["ignore_eos"], true); // From common (flattened) assert_eq!(json["ignore_eos"], true); // From common (flattened)
assert_eq!(json["min_tokens"], 100); // From common (flattened) assert_eq!(json["min_tokens"], 100); // From common (flattened)
assert_eq!(json["nvext"]["ignore_eos"], false); // From nvext assert_eq!(json["nvext"]["greed_sampling"], false); // From nvext
// Verify precedence through stop conditions extraction // Verify extraction through stop conditions
let stop_conditions = request.extract_stop_conditions().unwrap(); let stop_conditions = request.extract_stop_conditions().unwrap();
assert_eq!(stop_conditions.ignore_eos, Some(true)); // common overrides nvext assert_eq!(stop_conditions.ignore_eos, Some(true));
assert_eq!(stop_conditions.min_tokens, Some(100)); assert_eq!(stop_conditions.min_tokens, Some(100));
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment