Commit 18bb779e (unverified), authored by KrishnanPrash, committed by GitHub
Browse files

feat: Add frontend support for `min_tokens` and `ignore_eos` (outside of...


feat: Add frontend support for `min_tokens` and `ignore_eos` (outside of `nvext`) and Structured Output / Guided Decoding (#2380)
Signed-off-by: KrishnanPrash <140860868+KrishnanPrash@users.noreply.github.com>
Co-authored-by: Ryan McCormick <rmccormick@nvidia.com>
Co-authored-by: Ayush Agarwal <ayushag@nvidia.com>
parent 38cf0f8d
...@@ -222,7 +222,11 @@ async fn evaluate( ...@@ -222,7 +222,11 @@ async fn evaluate(
) )
.temperature(template.as_ref().map_or(0.7, |t| t.temperature)) .temperature(template.as_ref().map_or(0.7, |t| t.temperature))
.build()?; .build()?;
let req = NvCreateChatCompletionRequest { inner, nvext: None }; let req = NvCreateChatCompletionRequest {
inner,
common: Default::default(),
nvext: None,
};
let mut stream = engine.generate(Context::new(req)).await?; let mut stream = engine.generate(Context::new(req)).await?;
let mut output = String::new(); let mut output = String::new();
while let Some(item) = stream.next().await { while let Some(item) = stream.next().await {
......
...@@ -118,6 +118,7 @@ async fn main_loop( ...@@ -118,6 +118,7 @@ async fn main_loop(
let req = NvCreateChatCompletionRequest { let req = NvCreateChatCompletionRequest {
inner, inner,
common: Default::default(),
nvext: Some(nvext), nvext: Some(nvext),
}; };
......
...@@ -1237,6 +1237,7 @@ mod tests { ...@@ -1237,6 +1237,7 @@ mod tests {
messages: vec![], messages: vec![],
..Default::default() ..Default::default()
}, },
common: Default::default(),
nvext: None, nvext: None,
}; };
let result = validate_chat_completion_required_fields(&request); let result = validate_chat_completion_required_fields(&request);
...@@ -1263,6 +1264,7 @@ mod tests { ...@@ -1263,6 +1264,7 @@ mod tests {
)], )],
..Default::default() ..Default::default()
}, },
common: Default::default(),
nvext: None, nvext: None,
}; };
let result = validate_chat_completion_required_fields(&request); let result = validate_chat_completion_required_fields(&request);
......
...@@ -20,8 +20,10 @@ use super::{ ...@@ -20,8 +20,10 @@ use super::{
common::{self, SamplingOptionsProvider, StopConditionsProvider}, common::{self, SamplingOptionsProvider, StopConditionsProvider},
ContentProvider, ContentProvider,
}; };
use crate::protocols::openai::common_ext::CommonExtProvider;
pub mod chat_completions; pub mod chat_completions;
pub mod common_ext;
pub mod completions; pub mod completions;
pub mod embeddings; pub mod embeddings;
pub mod models; pub mod models;
...@@ -61,9 +63,23 @@ trait OpenAIStopConditionsProvider { ...@@ -61,9 +63,23 @@ trait OpenAIStopConditionsProvider {
fn get_stop(&self) -> Option<Vec<String>>; fn get_stop(&self) -> Option<Vec<String>>;
fn nvext(&self) -> Option<&nvext::NvExt>; fn nvext(&self) -> Option<&nvext::NvExt>;
/// Returns the root-level (`CommonExt`) `ignore_eos` value, if the implementing
/// type carries one.
///
/// The default implementation returns `None`, which makes `get_ignore_eos`
/// fall through to the `nvext` value. Request types that embed a `CommonExt`
/// override this method to expose their own field.
fn get_common_ignore_eos(&self) -> Option<bool> {
None
}
/// Resolves the effective `ignore_eos` setting for this request.
///
/// A root-level (`CommonExt`) value wins whenever one is present; the value
/// inside `nvext` is consulted only when no root-level value was supplied.
/// Returns `None` when neither location specifies the flag.
fn get_ignore_eos(&self) -> Option<bool> {
    let root_level = self.get_common_ignore_eos();
    if root_level.is_some() {
        // Root-level setting takes precedence over anything in nvext.
        return root_level;
    }

    // No root-level value: fall back to nvext (absent nvext yields `None`).
    self.nvext()?.ignore_eos
}
} }
impl<T: OpenAISamplingOptionsProvider> SamplingOptionsProvider for T { impl<T: OpenAISamplingOptionsProvider + CommonExtProvider> SamplingOptionsProvider for T {
fn extract_sampling_options(&self) -> Result<common::SamplingOptions> { fn extract_sampling_options(&self) -> Result<common::SamplingOptions> {
// let result = self.validate(); // let result = self.validate();
// if let Err(e) = result { // if let Err(e) = result {
...@@ -88,29 +104,26 @@ impl<T: OpenAISamplingOptionsProvider> SamplingOptionsProvider for T { ...@@ -88,29 +104,26 @@ impl<T: OpenAISamplingOptionsProvider> SamplingOptionsProvider for T {
} }
} }
let mut guided_decoding = None; let guided_decoding_backend = self.get_guided_decoding_backend();
if let Some(nvext) = self.nvext() { let guided_json = self.get_guided_json();
let guided_decoding_backend = nvext.guided_decoding_backend.clone(); let guided_regex = self.get_guided_regex();
let guided_json = nvext.guided_json.clone(); let guided_grammar = self.get_guided_grammar();
let guided_regex = nvext.guided_regex.clone(); let guided_choice = self.get_guided_choice();
let guided_grammar = nvext.guided_grammar.clone();
let guided_choice = nvext.guided_choice.clone(); let guided_decoding = match common::GuidedDecodingOptions::from_optional(
guided_json.cloned(),
match common::GuidedDecodingOptions::from_optional(
guided_json,
guided_regex, guided_regex,
guided_choice, guided_choice,
guided_grammar, guided_grammar,
guided_decoding_backend, guided_decoding_backend,
) { ) {
Ok(options) => guided_decoding = options, Ok(options) => options,
Err(e) => { Err(e) => {
// Handle the validation error (log, return error, etc.) // Handle the validation error (log, return error, etc.)
tracing::error!("Invalid guided decoding options: {}", e); tracing::error!("Invalid guided decoding options: {:?}", e);
return Err(e); return Err(e);
} }
} };
}
Ok(common::SamplingOptions { Ok(common::SamplingOptions {
n: None, n: None,
...@@ -142,11 +155,8 @@ impl<T: OpenAIStopConditionsProvider> StopConditionsProvider for T { ...@@ -142,11 +155,8 @@ impl<T: OpenAIStopConditionsProvider> StopConditionsProvider for T {
} }
} }
let mut ignore_eos = None; // Use the trait method to get ignore_eos, which handles precedence
let ignore_eos = self.get_ignore_eos();
if let Some(nvext) = self.nvext() {
ignore_eos = nvext.ignore_eos;
}
Ok(common::StopConditions { Ok(common::StopConditions {
max_tokens, max_tokens,
......
...@@ -20,8 +20,10 @@ use validator::Validate; ...@@ -20,8 +20,10 @@ use validator::Validate;
use crate::engines::ValidateRequest; use crate::engines::ValidateRequest;
use super::{ use super::{
nvext::NvExt, nvext::NvExtProvider, validate, OpenAISamplingOptionsProvider, common_ext::{CommonExt, CommonExtProvider},
OpenAIStopConditionsProvider, nvext::NvExt,
nvext::NvExtProvider,
validate, OpenAISamplingOptionsProvider, OpenAIStopConditionsProvider,
}; };
mod aggregator; mod aggregator;
...@@ -31,17 +33,21 @@ pub use aggregator::DeltaAggregator; ...@@ -31,17 +33,21 @@ pub use aggregator::DeltaAggregator;
pub use delta::DeltaGenerator; pub use delta::DeltaGenerator;
/// A request structure for creating a chat completion, extending OpenAI's /// A request structure for creating a chat completion, extending OpenAI's
/// `CreateChatCompletionRequest` with [`NvExt`] extensions. /// `CreateChatCompletionRequest` with [`NvExt`] extensions and common fields.
/// ///
/// # Fields /// # Fields
/// - `inner`: The base OpenAI chat completion request, embedded using `serde(flatten)`. /// - `inner`: The base OpenAI chat completion request, embedded using `serde(flatten)`.
/// - `nvext`: The optional NVIDIA extension field. See [`NvExt`] for /// - `common`: Common extension fields (ignore_eos, min_tokens) at root level, embedded using `serde(flatten)`.
/// more details. /// - `nvext`: The optional NVIDIA extension field. See [`NvExt`] for more details.
/// Note: If ignore_eos is specified in both common and nvext, the common (root-level) value takes precedence.
#[derive(Serialize, Deserialize, Validate, Debug, Clone)] #[derive(Serialize, Deserialize, Validate, Debug, Clone)]
pub struct NvCreateChatCompletionRequest { pub struct NvCreateChatCompletionRequest {
#[serde(flatten)] #[serde(flatten)]
pub inner: async_openai::types::CreateChatCompletionRequest, pub inner: async_openai::types::CreateChatCompletionRequest,
#[serde(flatten, default)]
pub common: CommonExt,
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub nvext: Option<NvExt>, pub nvext: Option<NvExt>,
} }
...@@ -139,6 +145,52 @@ impl OpenAISamplingOptionsProvider for NvCreateChatCompletionRequest { ...@@ -139,6 +145,52 @@ impl OpenAISamplingOptionsProvider for NvCreateChatCompletionRequest {
} }
} }
/// `CommonExtProvider` implementation for `NvCreateChatCompletionRequest`.
///
/// Every guided-decoding getter applies the same per-field rule: the
/// root-level value stored in `common` is returned when it is set, otherwise
/// the matching field of `nvext` (if `nvext` is present) is used.
impl CommonExtProvider for NvCreateChatCompletionRequest {
    /// Returns the embedded root-level `CommonExt`; always `Some` for this type.
    fn common_ext(&self) -> Option<&CommonExt> {
        Some(&self.common)
    }

    /// Borrowed `guided_json`, preferring the root-level value over `nvext`.
    fn get_guided_json(&self) -> Option<&serde_json::Value> {
        if let Some(schema) = self.common.guided_json.as_ref() {
            return Some(schema);
        }
        self.nvext.as_ref()?.guided_json.as_ref()
    }

    /// Owned copy of `guided_regex`, preferring the root-level value over `nvext`.
    fn get_guided_regex(&self) -> Option<String> {
        match &self.common.guided_regex {
            Some(pattern) => Some(pattern.clone()),
            None => self.nvext.as_ref()?.guided_regex.clone(),
        }
    }

    /// Owned copy of `guided_grammar`, preferring the root-level value over `nvext`.
    fn get_guided_grammar(&self) -> Option<String> {
        match &self.common.guided_grammar {
            Some(grammar) => Some(grammar.clone()),
            None => self.nvext.as_ref()?.guided_grammar.clone(),
        }
    }

    /// Owned copy of `guided_choice`, preferring the root-level value over `nvext`.
    fn get_guided_choice(&self) -> Option<Vec<String>> {
        if let Some(choices) = &self.common.guided_choice {
            return Some(choices.clone());
        }
        self.nvext.as_ref()?.guided_choice.clone()
    }

    /// Owned copy of `guided_decoding_backend`, preferring the root-level value
    /// over `nvext`.
    fn get_guided_decoding_backend(&self) -> Option<String> {
        if let Some(backend) = &self.common.guided_decoding_backend {
            return Some(backend.clone());
        }
        self.nvext.as_ref()?.guided_decoding_backend.clone()
    }
}
/// Implements `OpenAIStopConditionsProvider` for `NvCreateChatCompletionRequest`, /// Implements `OpenAIStopConditionsProvider` for `NvCreateChatCompletionRequest`,
/// providing access to stop conditions that control chat completion behavior. /// providing access to stop conditions that control chat completion behavior.
impl OpenAIStopConditionsProvider for NvCreateChatCompletionRequest { impl OpenAIStopConditionsProvider for NvCreateChatCompletionRequest {
...@@ -149,12 +201,10 @@ impl OpenAIStopConditionsProvider for NvCreateChatCompletionRequest { ...@@ -149,12 +201,10 @@ impl OpenAIStopConditionsProvider for NvCreateChatCompletionRequest {
} }
/// Retrieves the minimum number of tokens required in the response. /// Retrieves the minimum number of tokens required in the response.
/// /// Returns `min_tokens` Value
/// # Note /// `min_tokens` is not an OpenAI-supported parameter.
/// This method is currently a placeholder and always returns `None`
/// since `min_tokens` is not an OpenAI-supported parameter.
fn get_min_tokens(&self) -> Option<u32> { fn get_min_tokens(&self) -> Option<u32> {
None self.common.min_tokens
} }
/// Retrieves the stop conditions that terminate the chat completion response. /// Retrieves the stop conditions that terminate the chat completion response.
...@@ -175,6 +225,11 @@ impl OpenAIStopConditionsProvider for NvCreateChatCompletionRequest { ...@@ -175,6 +225,11 @@ impl OpenAIStopConditionsProvider for NvCreateChatCompletionRequest {
fn nvext(&self) -> Option<&NvExt> { fn nvext(&self) -> Option<&NvExt> {
self.nvext.as_ref() self.nvext.as_ref()
} }
/// Returns the root-level `ignore_eos` value stored in this request's
/// `common` field, or `None` when it was not specified.
fn get_common_ignore_eos(&self) -> Option<bool> {
self.common.ignore_eos
}
} }
/// Implements `ValidateRequest` for `NvCreateChatCompletionRequest`, /// Implements `ValidateRequest` for `NvCreateChatCompletionRequest`,
......
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use derive_builder::Builder;
use serde::{Deserialize, Serialize};
use validator::Validate;
/// Common extensions for OpenAI API requests that are not part of the standard OpenAI spec
/// but are commonly needed across different request types.
///
/// Every field is optional: a missing key deserializes to `None`
/// (`serde(default)`), and `None` fields are omitted when serializing
/// (`skip_serializing_if`). The generated builder's setters take the inner
/// value directly (`strip_option`) and leave unset fields as `None`.
#[derive(Serialize, Deserialize, Builder, Validate, Debug, Clone, Default)]
pub struct CommonExt {
/// If true, the model will ignore the end-of-sequence (EOS) token and generate to max_tokens.
/// This field can also be specified in nvext, but the root-level value takes precedence.
#[serde(default, skip_serializing_if = "Option::is_none")]
#[builder(default, setter(strip_option))]
pub ignore_eos: Option<bool>,
/// The minimum number of tokens to generate.
/// This is a common parameter needed across different request types.
/// No range validation is attached to this field.
#[serde(default, skip_serializing_if = "Option::is_none")]
#[builder(default, setter(strip_option))]
pub min_tokens: Option<u32>,
// ---- Guided decoding options ----
/// If specified, the output will be a JSON object. Can be a string, an object, or null.
#[serde(default, skip_serializing_if = "Option::is_none")]
#[builder(default, setter(strip_option))]
pub guided_json: Option<serde_json::Value>,
/// If specified, the output will follow the regex pattern. Can be a string or null.
#[serde(default, skip_serializing_if = "Option::is_none")]
#[builder(default, setter(strip_option))]
pub guided_regex: Option<String>,
/// If specified, the output will follow the context-free grammar. Can be a string or null.
#[serde(default, skip_serializing_if = "Option::is_none")]
#[builder(default, setter(strip_option))]
pub guided_grammar: Option<String>,
/// If specified, the output will be exactly one of the choices.
#[serde(default, skip_serializing_if = "Option::is_none")]
#[builder(default, setter(strip_option))]
pub guided_choice: Option<Vec<String>>,
/// If specified, the backend to use for guided decoding, e.g. `xgrammar` or a
/// custom guided decoding backend.
#[serde(default, skip_serializing_if = "Option::is_none")]
#[builder(default, setter(strip_option))]
pub guided_decoding_backend: Option<String>,
}
impl CommonExt {
pub fn builder() -> CommonExtBuilder {
CommonExtBuilder::default()
}
}
/// Trait for types that provide CommonExt fields
///
/// The guided-decoding getters return the *effective* value for a request.
/// The request types implementing this trait in this change prefer the
/// root-level `CommonExt` value and fall back to the matching `nvext` field.
pub trait CommonExtProvider {
/// Get a reference to the CommonExt struct if available
fn common_ext(&self) -> Option<&CommonExt>;
// ---- Guided decoding options ----
/// Effective `guided_json` value, borrowed from the request.
fn get_guided_json(&self) -> Option<&serde_json::Value>;
/// Effective `guided_regex` pattern, returned as an owned `String`.
fn get_guided_regex(&self) -> Option<String>;
/// Effective `guided_grammar`, returned as an owned `String`.
fn get_guided_grammar(&self) -> Option<String>;
/// Effective `guided_choice` list, returned as an owned `Vec`.
fn get_guided_choice(&self) -> Option<Vec<String>>;
/// Effective guided decoding backend name, returned as an owned `String`.
fn get_guided_decoding_backend(&self) -> Option<String>;
}
#[cfg(test)]
mod tests {
    // `serde_json` is reachable through the extern prelude, so no explicit
    // `use serde_json;` is needed here.
    use super::*;

    /// A builder with no setters invoked leaves every field unset.
    #[test]
    fn test_common_ext_builder_default() {
        let common_ext = CommonExt::builder().build().unwrap();
        assert_eq!(common_ext.ignore_eos, None);
        assert_eq!(common_ext.min_tokens, None);
        assert_eq!(common_ext.guided_json, None);
        assert_eq!(common_ext.guided_regex, None);
        assert_eq!(common_ext.guided_grammar, None);
        assert_eq!(common_ext.guided_choice, None);
        assert_eq!(common_ext.guided_decoding_backend, None);
    }

    /// Every builder setter stores the value it was given, wrapped in `Some`.
    #[test]
    fn test_common_ext_builder_with_values() {
        let common_ext = CommonExt::builder()
            .ignore_eos(true)
            .min_tokens(10)
            .guided_json(serde_json::json!({"key": "value"}))
            .guided_regex("regex".to_string())
            .guided_grammar("grammar".to_string())
            .guided_choice(vec!["choice1".to_string(), "choice2".to_string()])
            .guided_decoding_backend("backend".to_string())
            .build()
            .unwrap();

        assert_eq!(common_ext.ignore_eos, Some(true));
        assert_eq!(common_ext.min_tokens, Some(10));
        assert_eq!(
            common_ext.guided_json.as_ref(),
            Some(&serde_json::json!({"key": "value"}))
        );
        assert_eq!(common_ext.guided_regex, Some("regex".to_string()));
        assert_eq!(common_ext.guided_grammar, Some("grammar".to_string()));
        assert_eq!(
            common_ext.guided_choice,
            Some(vec!["choice1".to_string(), "choice2".to_string()])
        );
        assert_eq!(
            common_ext.guided_decoding_backend,
            Some("backend".to_string())
        );
    }

    /// `ignore_eos` and `min_tokens` can be set on their own and read back.
    #[test]
    fn test_common_ext_fields() {
        let common_ext = CommonExt::builder()
            .ignore_eos(false)
            .min_tokens(5)
            .build()
            .unwrap();

        assert_eq!(common_ext.ignore_eos, Some(false));
        assert_eq!(common_ext.min_tokens, Some(5));
    }

    /// `min_tokens` carries no range validator, so the smallest `u32` value
    /// must pass validation.
    #[test]
    fn test_validation_min_tokens() {
        // Struct-update syntax keeps this test compiling when fields are added.
        let common_ext = CommonExt {
            min_tokens: Some(0),
            ..Default::default()
        };
        assert!(common_ext.validate().is_ok());
    }

    /// A builder-produced value with neither `ignore_eos` nor `min_tokens`
    /// set is valid.
    #[test]
    fn test_common_ext_neither_specified() {
        let common_ext = CommonExt::builder().build().unwrap();

        assert_eq!(common_ext.ignore_eos, None);
        assert_eq!(common_ext.min_tokens, None);
        assert!(common_ext.validate().is_ok());
    }

    /// The derived `Default` implementation yields an unset, valid value.
    #[test]
    fn test_common_ext_default() {
        let common_ext = CommonExt::default();

        assert_eq!(common_ext.ignore_eos, None);
        assert_eq!(common_ext.min_tokens, None);
        assert!(common_ext.validate().is_ok());
    }
}
...@@ -22,6 +22,7 @@ use crate::engines::ValidateRequest; ...@@ -22,6 +22,7 @@ use crate::engines::ValidateRequest;
use super::{ use super::{
common::{self, SamplingOptionsProvider, StopConditionsProvider}, common::{self, SamplingOptionsProvider, StopConditionsProvider},
common_ext::{CommonExt, CommonExtProvider},
nvext::{NvExt, NvExtProvider}, nvext::{NvExt, NvExtProvider},
validate, ContentProvider, OpenAISamplingOptionsProvider, OpenAIStopConditionsProvider, validate, ContentProvider, OpenAISamplingOptionsProvider, OpenAIStopConditionsProvider,
}; };
...@@ -37,6 +38,9 @@ pub struct NvCreateCompletionRequest { ...@@ -37,6 +38,9 @@ pub struct NvCreateCompletionRequest {
#[serde(flatten)] #[serde(flatten)]
pub inner: async_openai::types::CreateCompletionRequest, pub inner: async_openai::types::CreateCompletionRequest,
#[serde(flatten)]
pub common: CommonExt,
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub nvext: Option<NvExt>, pub nvext: Option<NvExt>,
} }
...@@ -131,13 +135,56 @@ impl OpenAISamplingOptionsProvider for NvCreateCompletionRequest { ...@@ -131,13 +135,56 @@ impl OpenAISamplingOptionsProvider for NvCreateCompletionRequest {
} }
} }
/// `CommonExtProvider` implementation for `NvCreateCompletionRequest`.
///
/// Each guided-decoding getter returns the root-level value from `common`
/// when it is set and otherwise falls back to the matching `nvext` field.
impl CommonExtProvider for NvCreateCompletionRequest {
    /// Returns the embedded root-level `CommonExt`; always `Some` for this type.
    fn common_ext(&self) -> Option<&CommonExt> {
        Some(&self.common)
    }

    /// Borrowed `guided_json`, preferring the root-level value over `nvext`.
    fn get_guided_json(&self) -> Option<&serde_json::Value> {
        match self.common.guided_json.as_ref() {
            Some(schema) => Some(schema),
            None => self.nvext.as_ref()?.guided_json.as_ref(),
        }
    }

    /// Owned copy of `guided_regex`, preferring the root-level value over `nvext`.
    fn get_guided_regex(&self) -> Option<String> {
        if let Some(pattern) = &self.common.guided_regex {
            return Some(pattern.clone());
        }
        self.nvext.as_ref()?.guided_regex.clone()
    }

    /// Owned copy of `guided_grammar`, preferring the root-level value over `nvext`.
    fn get_guided_grammar(&self) -> Option<String> {
        if let Some(grammar) = &self.common.guided_grammar {
            return Some(grammar.clone());
        }
        self.nvext.as_ref()?.guided_grammar.clone()
    }

    /// Owned copy of `guided_choice`, preferring the root-level value over `nvext`.
    fn get_guided_choice(&self) -> Option<Vec<String>> {
        match &self.common.guided_choice {
            Some(choices) => Some(choices.clone()),
            None => self.nvext.as_ref()?.guided_choice.clone(),
        }
    }

    /// Owned copy of `guided_decoding_backend`, preferring the root-level value
    /// over `nvext`.
    fn get_guided_decoding_backend(&self) -> Option<String> {
        match &self.common.guided_decoding_backend {
            Some(backend) => Some(backend.clone()),
            None => self.nvext.as_ref()?.guided_decoding_backend.clone(),
        }
    }
}
impl OpenAIStopConditionsProvider for NvCreateCompletionRequest { impl OpenAIStopConditionsProvider for NvCreateCompletionRequest {
fn get_max_tokens(&self) -> Option<u32> { fn get_max_tokens(&self) -> Option<u32> {
self.inner.max_tokens self.inner.max_tokens
} }
fn get_min_tokens(&self) -> Option<u32> { fn get_min_tokens(&self) -> Option<u32> {
None self.common.min_tokens
} }
fn get_stop(&self) -> Option<Vec<String>> { fn get_stop(&self) -> Option<Vec<String>> {
...@@ -147,6 +194,10 @@ impl OpenAIStopConditionsProvider for NvCreateCompletionRequest { ...@@ -147,6 +194,10 @@ impl OpenAIStopConditionsProvider for NvCreateCompletionRequest {
fn nvext(&self) -> Option<&NvExt> { fn nvext(&self) -> Option<&NvExt> {
self.nvext.as_ref() self.nvext.as_ref()
} }
/// Returns the root-level `ignore_eos` value stored in this request's
/// `common` field, or `None` when it was not specified.
fn get_common_ignore_eos(&self) -> Option<bool> {
self.common.ignore_eos
}
} }
#[derive(Builder)] #[derive(Builder)]
......
...@@ -185,6 +185,7 @@ impl TryFrom<NvCreateResponse> for NvCreateChatCompletionRequest { ...@@ -185,6 +185,7 @@ impl TryFrom<NvCreateResponse> for NvCreateChatCompletionRequest {
stream: Some(true), // Set this to Some(True) by default to aggregate stream stream: Some(true), // Set this to Some(True) by default to aggregate stream
..Default::default() ..Default::default()
}, },
common: Default::default(),
nvext: resp.nvext, nvext: resp.nvext,
}) })
} }
......
...@@ -765,6 +765,7 @@ async fn test_nv_custom_client( ...@@ -765,6 +765,7 @@ async fn test_nv_custom_client(
let request = NvCreateChatCompletionRequest { let request = NvCreateChatCompletionRequest {
inner: inner_request, inner: inner_request,
common: Default::default(),
nvext: None, nvext: None,
}; };
...@@ -802,6 +803,7 @@ async fn test_nv_custom_client( ...@@ -802,6 +803,7 @@ async fn test_nv_custom_client(
let request = NvCreateChatCompletionRequest { let request = NvCreateChatCompletionRequest {
inner: inner_request, inner: inner_request,
common: Default::default(),
nvext: None, nvext: None,
}; };
...@@ -840,6 +842,7 @@ async fn test_nv_custom_client( ...@@ -840,6 +842,7 @@ async fn test_nv_custom_client(
let request = NvCreateChatCompletionRequest { let request = NvCreateChatCompletionRequest {
inner: inner_request, inner: inner_request,
common: Default::default(),
nvext: None, nvext: None,
}; };
......
...@@ -36,7 +36,11 @@ impl CompletionSample { ...@@ -36,7 +36,11 @@ impl CompletionSample {
let inner = builder.build().unwrap(); let inner = builder.build().unwrap();
let request = NvCreateCompletionRequest { inner, nvext: None }; let request = NvCreateCompletionRequest {
inner,
common: Default::default(),
nvext: None,
};
Ok(Self { Ok(Self {
request, request,
......
...@@ -266,7 +266,11 @@ impl Request { ...@@ -266,7 +266,11 @@ impl Request {
} }
let inner = inner.build().unwrap(); let inner = inner.build().unwrap();
NvCreateChatCompletionRequest { inner, nvext: None } NvCreateChatCompletionRequest {
inner,
common: Default::default(),
nvext: None,
}
} }
} }
......
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use dynamo_llm::protocols::{
common::StopConditionsProvider,
openai::{
chat_completions::NvCreateChatCompletionRequest,
common_ext::{CommonExt, CommonExtProvider},
completions::NvCreateCompletionRequest,
nvext::NvExt,
},
};
/// Root-level `ignore_eos` and `min_tokens` in a chat completion payload are
/// deserialized into the flattened `common` field.
#[test]
fn test_chat_completions_ignore_eos_from_common() {
    let payload = r#"{
"model": "test-model",
"messages": [{"role": "user", "content": "Hello"}],
"ignore_eos": true,
"min_tokens": 100
}"#;

    let parsed = serde_json::from_str::<NvCreateChatCompletionRequest>(payload).unwrap();
    let common = &parsed.common;

    assert_eq!(common.ignore_eos, Some(true));
    assert_eq!(common.min_tokens, Some(100));
}
/// Each guided-decoding option supplied at the root of a chat completion
/// payload lands in `common` and is returned by the matching
/// `CommonExtProvider` getter.
#[test]
fn test_chat_completions_guided_decoding_from_common() {
    // Shared deserialization step; panics if the payload cannot be parsed.
    let parse = |payload: &str| -> NvCreateChatCompletionRequest {
        serde_json::from_str(payload).unwrap()
    };

    // guided_json at the root level
    let with_json = parse(
        r#"{
"model": "test-model",
"messages": [{"role": "user", "content": "Hello"}],
"guided_json": {"key": "value"}
}"#,
    );
    let expected_json = serde_json::json!({"key": "value"});
    assert_eq!(with_json.common.guided_json, Some(expected_json.clone()));
    assert_eq!(with_json.get_guided_json(), Some(&expected_json));

    // guided_regex at the root level
    let with_regex = parse(
        r#"{
"model": "test-model",
"messages": [{"role": "user", "content": "Hello"}],
"guided_regex": "*"
}"#,
    );
    assert_eq!(with_regex.common.guided_regex, Some("*".to_string()));
    assert_eq!(with_regex.get_guided_regex(), Some("*".to_string()));

    // guided_grammar at the root level
    let with_grammar = parse(
        r#"{
"model": "test-model",
"messages": [{"role": "user", "content": "Hello"}],
"guided_grammar": "::=[1-9]"
}"#,
    );
    assert_eq!(
        with_grammar.common.guided_grammar,
        Some("::=[1-9]".to_string())
    );
    assert_eq!(
        with_grammar.get_guided_grammar(),
        Some("::=[1-9]".to_string())
    );

    // guided_choice at the root level
    let with_choice = parse(
        r#"{
"model": "test-model",
"messages": [{"role": "user", "content": "Hello"}],
"guided_choice": ["choice1", "choice2"]
}"#,
    );
    let expected_choices = vec!["choice1".to_string(), "choice2".to_string()];
    assert_eq!(
        with_choice.common.guided_choice,
        Some(expected_choices.clone())
    );
    assert_eq!(with_choice.get_guided_choice(), Some(expected_choices));

    // guided_decoding_backend at the root level
    let with_backend = parse(
        r#"{
"model": "test-model",
"messages": [{"role": "user", "content": "Hello"}],
"guided_decoding_backend": "backend"
}"#,
    );
    assert_eq!(
        with_backend.common.guided_decoding_backend,
        Some("backend".to_string())
    );
    assert_eq!(
        with_backend.get_guided_decoding_backend(),
        Some("backend".to_string())
    );
}
/// When `ignore_eos` and `guided_regex` appear both at the root level and
/// inside `nvext`, the root-level values are the effective ones.
#[test]
fn test_chat_completions_common_overrides_nvext() {
    let payload = r#"{
"model": "test-model",
"messages": [{"role": "user", "content": "Hello"}],
"ignore_eos": false,
"guided_regex": ".*",
"min_tokens": 50,
"nvext": {
"ignore_eos": true,
"guided_regex": "./*"
}
}"#;
    let req: NvCreateChatCompletionRequest = serde_json::from_str(payload).unwrap();

    // Both layers keep the raw values they were given.
    assert_eq!(req.common.ignore_eos, Some(false));
    assert_eq!(req.common.guided_regex, Some(".*".to_string()));
    let nvext_ignore_eos = req.nvext.as_ref().and_then(|ext| ext.ignore_eos);
    assert_eq!(nvext_ignore_eos, Some(true));

    // The getter resolves to the root-level regex.
    assert_eq!(req.get_guided_regex(), Some(".*".to_string()));

    // Extraction resolves `ignore_eos` to the root-level value as well.
    let stop = req.extract_stop_conditions().unwrap();
    assert_eq!(stop.ignore_eos, Some(false));
    assert_eq!(stop.min_tokens, Some(50));
}
/// Payloads that only use `nvext` (no root-level extension fields) keep
/// working: the `nvext` values become the effective ones.
#[test]
fn test_chat_completions_backward_compatibility() {
    let payload = r#"{
"model": "test-model",
"messages": [{"role": "user", "content": "Hello"}],
"nvext": {
"ignore_eos": true,
"guided_json": {"key": "value"}
}
}"#;
    let req: NvCreateChatCompletionRequest = serde_json::from_str(payload).unwrap();
    let expected_json = serde_json::json!({"key": "value"});

    // Nothing was supplied at the root level.
    assert_eq!(req.common.ignore_eos, None);
    assert_eq!(req.common.guided_json, None);

    // The nvext layer holds the supplied values.
    let nvext = req.nvext.as_ref();
    assert_eq!(nvext.and_then(|ext| ext.ignore_eos), Some(true));
    assert_eq!(
        nvext.and_then(|ext| ext.guided_json.as_ref()),
        Some(&expected_json)
    );

    // The getter falls back to nvext.
    assert_eq!(req.get_guided_json(), Some(&expected_json));

    // Extraction falls back to nvext for `ignore_eos`; `min_tokens` stays unset.
    let stop = req.extract_stop_conditions().unwrap();
    assert_eq!(stop.ignore_eos, Some(true));
    assert_eq!(stop.min_tokens, None);
}
/// Root-level `ignore_eos` and `min_tokens` in a completions payload are
/// parsed into `common` and carried through to the extracted stop conditions.
#[test]
fn test_completions_ignore_eos_from_common() {
    let payload = r#"{
"model": "test-model",
"prompt": "Hello world",
"ignore_eos": true,
"min_tokens": 200
}"#;
    let req = serde_json::from_str::<NvCreateCompletionRequest>(payload).unwrap();

    // Parsed values
    let common = &req.common;
    assert_eq!(common.ignore_eos, Some(true));
    assert_eq!(common.min_tokens, Some(200));

    // Extracted values
    let stop = req.extract_stop_conditions().unwrap();
    assert_eq!(stop.ignore_eos, Some(true));
    assert_eq!(stop.min_tokens, Some(200));
}
/// For completions requests, a root-level `ignore_eos` overrides the value
/// given inside `nvext`.
#[test]
fn test_completions_common_overrides_nvext() {
    let payload = r#"{
"model": "test-model",
"prompt": "Hello world",
"ignore_eos": false,
"min_tokens": 75,
"nvext": {
"ignore_eos": true
}
}"#;
    let req: NvCreateCompletionRequest = serde_json::from_str(payload).unwrap();

    // Each layer retains the value it was given.
    assert_eq!(req.common.ignore_eos, Some(false));
    let nvext_ignore_eos = req.nvext.as_ref().and_then(|ext| ext.ignore_eos);
    assert_eq!(nvext_ignore_eos, Some(true));

    // Extraction picks the root-level value.
    let stop = req.extract_stop_conditions().unwrap();
    assert_eq!(stop.ignore_eos, Some(false));
    assert_eq!(stop.min_tokens, Some(75));
}
/// Serializing a request emits the `common` fields at the root of the JSON
/// object and keeps the `nvext` fields nested under `"nvext"`.
#[test]
fn test_serialization_preserves_structure() {
    use async_openai::types::{
        ChatCompletionRequestMessage, ChatCompletionRequestUserMessage,
        ChatCompletionRequestUserMessageContent, CreateChatCompletionRequest,
    };

    // Assemble the request piece by piece.
    let user_message = ChatCompletionRequestMessage::User(ChatCompletionRequestUserMessage {
        content: ChatCompletionRequestUserMessageContent::Text("Hello".to_string()),
        ..Default::default()
    });
    let inner = CreateChatCompletionRequest {
        model: "test-model".to_string(),
        messages: vec![user_message],
        ..Default::default()
    };
    let common = CommonExt {
        ignore_eos: Some(true),
        min_tokens: Some(100),
        ..Default::default()
    };
    let nvext = NvExt {
        ignore_eos: Some(false),
        ..Default::default()
    };
    let req = NvCreateChatCompletionRequest {
        inner,
        common,
        nvext: Some(nvext),
    };

    let serialized = serde_json::to_value(&req).unwrap();

    // Flattened fields sit at the root; nvext stays nested.
    assert_eq!(serialized["model"], "test-model");
    assert_eq!(serialized["ignore_eos"], true);
    assert_eq!(serialized["min_tokens"], 100);
    assert_eq!(serialized["nvext"]["ignore_eos"], false);

    // Extraction still resolves `ignore_eos` to the root-level value.
    let stop = req.extract_stop_conditions().unwrap();
    assert_eq!(stop.ignore_eos, Some(true));
    assert_eq!(stop.min_tokens, Some(100));
}
/// A `min_tokens` value given at the root level is parsed into `common` and
/// propagated to the extracted stop conditions.
///
/// NOTE(review): the test name implies `nvext` offers no `min_tokens`; that
/// is not asserted here.
#[test]
fn test_min_tokens_only_at_root_level() {
    let payload = r#"{
"model": "test-model",
"messages": [{"role": "user", "content": "Hello"}],
"min_tokens": 150
}"#;
    let req = serde_json::from_str::<NvCreateChatCompletionRequest>(payload).unwrap();

    assert_eq!(req.common.min_tokens, Some(150));

    // The same value must survive stop-condition extraction.
    let stop = req.extract_stop_conditions().unwrap();
    assert_eq!(stop.min_tokens, Some(150));
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment