// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

// NOTE(review): in this copy of the file several generic type arguments look
// stripped (e.g. bare `Option` fields and `-> Option> {` return types). They are
// left exactly as found; restore them from the upstream struct/trait definitions
// before building.

use dynamo_runtime::protocols::annotated::AnnotationsProvider;
use serde::{Deserialize, Serialize};
use utoipa::ToSchema;
use validator::Validate;

use crate::engines::ValidateRequest;
use crate::preprocessor::media::MediaDecoder;

use super::{
    OpenAIOutputOptionsProvider, OpenAISamplingOptionsProvider, OpenAIStopConditionsProvider,
    common_ext::{CommonExt, CommonExtProvider},
    nvext::NvExt,
    nvext::NvExtProvider,
    tools, validate,
};

pub mod aggregator;
mod delta;
pub mod jail;

pub use aggregator::DeltaAggregator;
pub use delta::DeltaGenerator;

/// A request structure for creating a chat completion, extending OpenAI's
/// `CreateChatCompletionRequest` with [`NvExt`] extensions and common fields.
///
/// # Fields
/// - `inner`: The base OpenAI chat completion request, embedded using `serde(flatten)`.
/// - `common`: Common extension fields (ignore_eos, min_tokens) at root level, embedded using `serde(flatten)`.
/// - `nvext`: The optional NVIDIA extension field. See [`NvExt`] for more details.
/// - `chat_template_args`: Optional extra arguments for chat template rendering
///   (also accepted under the name `chat_template_kwargs`).
/// - `media_io_kwargs`: Optional runtime media decoding parameters.
/// - `unsupported_fields`: Flattened catch-all for fields not claimed by any other member.
///
/// Note: If ignore_eos is specified in both common and nvext, the common (root-level) value takes precedence.
#[derive(ToSchema, Serialize, Deserialize, Validate, Debug, Clone)]
pub struct NvCreateChatCompletionRequest {
    /// Base OpenAI request. Its fields are flattened into the top level of the
    /// serialized object; the generated schema describes it as a plain object.
    #[serde(flatten)]
    #[schema(value_type = Object)]
    pub inner: dynamo_protocols::types::CreateChatCompletionRequest,

    /// Common extension fields, flattened into the top level and defaulted
    /// when absent from the input.
    #[serde(flatten, default)]
    pub common: CommonExt,

    /// Optional NVIDIA extension payload; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub nvext: Option,

    /// Extra args to pass to the chat template rendering context.
    /// Also accepts "chat_template_kwargs" as an alias for compatibility.
    /// Defaults when absent and is omitted from serialized output when `None`.
    #[serde(
        default,
        skip_serializing_if = "Option::is_none",
        alias = "chat_template_kwargs"
    )]
    pub chat_template_args: Option>,

    /// Runtime media decoding parameters.
    /// When provided, these override the MDC defaults.
    /// Example: `{"video": {"num_frames": 16}}`
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub media_io_kwargs: Option,

    /// Catch-all for unsupported fields - checked during validation.
    /// Populated on deserialization only; it is never serialized back out.
    #[serde(flatten, default, skip_serializing)]
    pub unsupported_fields: std::collections::HashMap,
}

/// A response structure for unary chat completion responses, embedding OpenAI's
/// `CreateChatCompletionResponse` with optional NVIDIA extension metadata.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct NvCreateChatCompletionResponse {
    /// Base OpenAI response, flattened into the top level of the serialized object.
    #[serde(flatten)]
    pub inner: dynamo_protocols::types::CreateChatCompletionResponse,

    /// Optional NVIDIA extension metadata; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub nvext: Option,
}

/// A response structure for streamed chat completions, embedding OpenAI's
/// `CreateChatCompletionStreamResponse` with optional NVIDIA extension metadata.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct NvCreateChatCompletionStreamResponse {
    /// Base OpenAI stream chunk, flattened into the top level of the serialized object.
    #[serde(flatten)]
    pub inner: dynamo_protocols::types::CreateChatCompletionStreamResponse,

    /// Optional NVIDIA extension metadata; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub nvext: Option,
}

/// Implements `NvExtProvider` for `NvCreateChatCompletionRequest`,
/// providing access to NVIDIA-specific extensions.
impl NvExtProvider for NvCreateChatCompletionRequest {
    /// Returns a reference to the optional `NvExt` extension, if available.
    fn nvext(&self) -> Option<&NvExt> {
        self.nvext.as_ref()
    }

    /// Always returns `None`, as raw prompt extraction is not implemented
    /// for chat completion requests.
    fn raw_prompt(&self) -> Option {
        None
    }
}

/// Implements `AnnotationsProvider` for `NvCreateChatCompletionRequest`,
/// enabling retrieval and management of request annotations.
impl AnnotationsProvider for NvCreateChatCompletionRequest {
    /// Retrieves the list of annotations from `NvExt`, if present.
    ///
    /// The result is an owned clone of `nvext.annotations`; it is `None` when
    /// either `nvext` or its `annotations` field is unset.
    fn annotations(&self) -> Option> {
        self.nvext
            .as_ref()
            .and_then(|nvext| nvext.annotations.clone())
    }

    /// Checks whether a specific annotation exists in the request.
    ///
    /// # Arguments
    /// * `annotation` - A string slice representing the annotation to check.
    ///
    /// # Returns
    /// `true` if the annotation exists, `false` otherwise (including when no
    /// `nvext` or no annotation list is present).
    fn has_annotation(&self, annotation: &str) -> bool {
        self.nvext
            .as_ref()
            .and_then(|nvext| nvext.annotations.as_ref())
            // `contains` is given an owned copy of `annotation` to compare against.
            .map(|annotations| annotations.contains(&annotation.to_string()))
            .unwrap_or(false)
    }
}

/// Implements `OpenAISamplingOptionsProvider` for `NvCreateChatCompletionRequest`,
/// exposing OpenAI's sampling parameters for chat completion.
impl OpenAISamplingOptionsProvider for NvCreateChatCompletionRequest {
    /// Retrieves the temperature parameter for sampling, if set.
    fn get_temperature(&self) -> Option {
        self.inner.temperature
    }

    /// Retrieves the top-p (nucleus sampling) parameter, if set.
    fn get_top_p(&self) -> Option {
        self.inner.top_p
    }

    /// Retrieves the frequency penalty parameter, if set.
    fn get_frequency_penalty(&self) -> Option {
        self.inner.frequency_penalty
    }

    /// Retrieves the presence penalty parameter, if set.
    fn get_presence_penalty(&self) -> Option {
        self.inner.presence_penalty
    }

    /// Returns a reference to the optional `NvExt` extension, if available.
    fn nvext(&self) -> Option<&NvExt> {
        self.nvext.as_ref()
    }

    /// Retrieves the seed value for random number generation, if set.
    fn get_seed(&self) -> Option {
        self.inner.seed
    }

    /// Retrieves the number of completions to generate for each prompt, if set.
    fn get_n(&self) -> Option {
        self.inner.n
    }

    /// Always returns `None`: `best_of` is not supported in chat completions.
    fn get_best_of(&self) -> Option {
        None // Not supported in chat completions
    }
}

/// Implements `CommonExtProvider` for `NvCreateChatCompletionRequest`,
/// providing access to common extension fields.
impl CommonExtProvider for NvCreateChatCompletionRequest {
    /// Returns a reference to the CommonExt struct (always `Some`).
    fn common_ext(&self) -> Option<&CommonExt> {
        Some(&self.common)
    }

    /// Guided Decoding Options
    ///
    /// Resolves the JSON schema used for guided decoding. Sources are tried in
    /// this order and the first one that yields a schema wins:
    ///
    /// 1. An explicit `guided_json` in the common extension fields.
    /// 2. A schema derived from `tool_choice` + `tools`, when both are set.
    ///    A derivation error is logged as a warning and otherwise ignored.
    /// 3. The OpenAI `response_format`: `JsonObject` maps to `{"type": "object"}`,
    ///    `JsonSchema` uses its embedded schema when present, `Text` yields nothing.
    ///
    /// Returns `None` when no source produces a schema.
    fn get_guided_json(&self) -> Option {
        if let Some(value) = self.common.guided_json.clone() {
            return Some(value);
        }

        // 1) Tool-call guided decoding (highest precedence after explicit guided_json)
        if let (Some(tool_choice), Some(tools)) =
            (self.inner.tool_choice.as_ref(), self.inner.tools.as_deref())
        {
            match tools::get_json_schema_from_tools(Some(tool_choice), Some(tools)) {
                Ok(Some(schema)) => return Some(schema),
                // No schema for this tool_choice: fall through to response_format.
                Ok(None) => {}
                // Best-effort: log and fall through rather than failing the request here.
                Err(err) => {
                    tracing::warn!(
                        error = %err,
                        "failed to derive guided_json from tool_choice"
                    );
                }
            }
        }

        // 2) OpenAI `response_format` (applies to assistant content, not tool calls)
        if let Some(response_format) = self.inner.response_format.as_ref() {
            use dynamo_protocols::types::ResponseFormat;
            match response_format {
                ResponseFormat::Text => {}
                ResponseFormat::JsonObject => {
                    // Minimal JSON Schema for "any JSON object"
                    return Some(serde_json::json!({
                        "type": "object"
                    }));
                }
                ResponseFormat::JsonSchema { json_schema } => {
                    // validate_response_format ensures schema is present when type=json_schema
                    if let Some(schema) = json_schema.schema.clone() {
                        return Some(schema);
                    }
                }
            }
        }

        None
    }

    /// Returns a clone of the `guided_regex` common extension field, if set.
    fn get_guided_regex(&self) -> Option {
        self.common.guided_regex.clone()
    }

    /// Returns a clone of the `guided_grammar` common extension field, if set.
    fn get_guided_grammar(&self) -> Option {
        self.common.guided_grammar.clone()
    }

    /// Returns a clone of the `guided_choice` common extension field, if set.
    fn get_guided_choice(&self) -> Option> {
        self.common.guided_choice.clone()
    }

    /// Returns a clone of the `guided_decoding_backend` common extension field, if set.
    fn get_guided_decoding_backend(&self) -> Option {
        self.common.guided_decoding_backend.clone()
    }

    /// Returns a clone of the `guided_whitespace_pattern` common extension field, if set.
    fn get_guided_whitespace_pattern(&self) -> Option {
        self.common.guided_whitespace_pattern.clone()
    }

    /// Returns the `top_k` common extension field, if set.
    fn get_top_k(&self) -> Option {
        self.common.top_k
    }

    /// Returns the `min_p` common extension field, if set.
    fn get_min_p(&self) -> Option {
        self.common.min_p
    }

    /// Returns the `repetition_penalty` common extension field, if set.
    fn get_repetition_penalty(&self) -> Option {
        self.common.repetition_penalty
    }

    /// Returns the `include_stop_str_in_output` common extension field, if set.
    fn get_include_stop_str_in_output(&self) -> Option {
        self.common.include_stop_str_in_output
    }

    /// Returns the `skip_special_tokens` common extension field, if set.
    fn get_skip_special_tokens(&self) -> Option {
        self.common.skip_special_tokens
    }
}
// NOTE(review): return types written as bare `Option` / `Option>` below look like
// they lost their generic arguments in this copy of the file; they are left as found.

/// Implements `OpenAIStopConditionsProvider` for `NvCreateChatCompletionRequest`,
/// providing access to stop conditions that control chat completion behavior.
impl OpenAIStopConditionsProvider for NvCreateChatCompletionRequest {
    /// Retrieves the maximum number of tokens allowed in the response.
    ///
    /// Prefers `max_completion_tokens` and falls back to `max_tokens` when the
    /// former is unset.
    // NOTE(review): the `allow` is presumably needed because `max_tokens` is
    // marked deprecated on the inner request type — confirm upstream.
    #[allow(deprecated)]
    fn get_max_tokens(&self) -> Option {
        self.inner.max_completion_tokens.or(self.inner.max_tokens)
    }

    /// Retrieves the minimum number of tokens required in the response.
    ///
    /// Returns the `min_tokens` value from the common extension fields.
    /// `min_tokens` is not an OpenAI-supported parameter.
    fn get_min_tokens(&self) -> Option {
        self.common.min_tokens
    }

    /// Retrieves the stop conditions that terminate the chat completion response.
    ///
    /// Converts OpenAI's `Stop` enum to a vector, normalizing the representation:
    /// a single string becomes a one-element vector, an array is cloned as is.
    ///
    /// # Returns
    /// * `Some(..)` with the stop sequences if stop conditions are set.
    /// * `None` if no stop conditions are defined.
    fn get_stop(&self) -> Option> {
        self.inner.stop.as_ref().map(|stop| match stop {
            dynamo_protocols::types::Stop::String(s) => vec![s.clone()],
            dynamo_protocols::types::Stop::StringArray(arr) => arr.clone(),
        })
    }

    /// Returns a reference to the optional `NvExt` extension, if available.
    fn nvext(&self) -> Option<&NvExt> {
        self.nvext.as_ref()
    }

    /// Get ignore_eos from CommonExt.
    fn get_common_ignore_eos(&self) -> Option {
        self.common.ignore_eos
    }

    /// Get the effective ignore_eos value from CommonExt.
    ///
    /// Only the root-level common field is read here; `nvext` is not consulted.
    fn get_ignore_eos(&self) -> Option {
        self.common.ignore_eos
    }
}

/// Implements `OpenAIOutputOptionsProvider` for `NvCreateChatCompletionRequest`,
/// exposing the output-related options of a chat completion request.
impl OpenAIOutputOptionsProvider for NvCreateChatCompletionRequest {
    /// Returns the number of log probabilities to produce per token.
    ///
    /// When `logprobs` is `Some(true)`, this is `top_logprobs` converted to
    /// `u32`, or `1` if `top_logprobs` is unset. When `logprobs` is
    /// `Some(false)` or unset, this is `None`.
    fn get_logprobs(&self) -> Option {
        match self.inner.logprobs {
            Some(true) => match self.inner.top_logprobs {
                Some(top_logprobs) => Some(top_logprobs as u32),
                None => Some(1_u32),
            },
            Some(false) => None,
            None => None,
        }
    }

    /// Always returns `None`; this implementation provides no prompt logprobs setting.
    fn get_prompt_logprobs(&self) -> Option {
        None
    }

    /// Returns `skip_special_tokens` by delegating to the `CommonExtProvider`
    /// method of the same name (fully qualified to select that trait's method).
    fn get_skip_special_tokens(&self) -> Option {
        CommonExtProvider::get_skip_special_tokens(self)
    }

    /// Always returns `None`; this implementation provides no formatted-prompt setting.
    fn get_formatted_prompt(&self) -> Option {
        None
    }
}

/// Implements `ValidateRequest` for `NvCreateChatCompletionRequest`,
/// allowing us to validate the data.
impl ValidateRequest for NvCreateChatCompletionRequest {
    /// Validates the request field by field.
    ///
    /// Checks run in the order written and the first failing check's error is
    /// returned via `?`. Fields marked "none for ..." have no validator here.
    ///
    /// # Errors
    /// Returns the error produced by the first validator that rejects its input.
    fn validate(&self) -> Result<(), anyhow::Error> {
        validate::validate_no_unsupported_fields(&self.unsupported_fields)?;
        validate::validate_messages(&self.inner.messages)?;
        validate::validate_model(&self.inner.model)?;
        // none for store
        validate::validate_reasoning_effort(&self.inner.reasoning_effort)?;
        // none for metadata
        validate::validate_frequency_penalty(self.inner.frequency_penalty)?;
        validate::validate_logit_bias(&self.inner.logit_bias)?;
        // none for logprobs
        validate::validate_top_logprobs(self.inner.top_logprobs)?;
        // validate::validate_max_tokens(self.inner.max_tokens)?; // warning deprecated field
        validate::validate_max_completion_tokens(self.inner.max_completion_tokens)?;
        validate::validate_n(self.inner.n)?;
        // none for modalities
        // none for prediction
        // none for audio
        validate::validate_presence_penalty(self.inner.presence_penalty)?;
        validate::validate_response_format(&self.inner.response_format)?;
        // none for seed
        validate::validate_service_tier(&self.inner.service_tier)?;
        validate::validate_stop(&self.inner.stop)?;
        // none for stream
        // none for stream_options
        validate::validate_temperature(self.inner.temperature)?;
        validate::validate_top_p(self.inner.top_p)?;
        validate::validate_tools(&self.inner.tools.as_deref())?;
        // none for tool_choice
        // none for parallel_tool_calls
        validate::validate_user(self.inner.user.as_deref())?;
        // none for function call
        // none for functions

        // Common Ext
        validate::validate_repetition_penalty(self.get_repetition_penalty())?;
        validate::validate_min_p(self.get_min_p())?;
        validate::validate_top_k(self.get_top_k())?;

        // Cross-field validation
        validate::validate_n_with_temperature(self.inner.n, self.inner.temperature)?;

        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocols::common::OutputOptionsProvider;
    use serde_json::json;

    /// A request without `skip_special_tokens` leaves it `None` both on the
    /// parsed request and in the extracted output options.
    #[test]
    fn test_skip_special_tokens_none() {
        let json_str = json!({
            "model": "test-model",
            "messages": [
                {"role": "user", "content": "Hello"}
            ]
        });

        let request: NvCreateChatCompletionRequest =
            serde_json::from_value(json_str).expect("Failed to deserialize request");

        assert_eq!(request.common.skip_special_tokens, None);

        let output_options = request
            .extract_output_options()
            .expect("Failed to extract output options");

        assert_eq!(output_options.skip_special_tokens, None);
    }

    /// A root-level `skip_special_tokens` value (both `true` and `false`) is
    /// carried through to the extracted output options unchanged.
    #[test]
    fn test_skip_special_tokens_propagates() {
        for skip_value in [true, false] {
            let json_str = json!({
                "model": "test-model",
                "messages": [
                    {"role": "user", "content": "Hello"}
                ],
                "skip_special_tokens": skip_value
            });

            let request: NvCreateChatCompletionRequest =
                serde_json::from_value(json_str).expect("Failed to deserialize request");

            let output_options = request
                .extract_output_options()
                .expect("Failed to extract output options");

            assert_eq!(output_options.skip_special_tokens, Some(skip_value));
        }
    }
}