// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-License-Identifier: Apache-2.0 use dynamo_runtime::protocols::annotated::AnnotationsProvider; use serde::{Deserialize, Serialize}; use validator::Validate; use crate::engines::ValidateRequest; use super::{ OpenAIOutputOptionsProvider, OpenAISamplingOptionsProvider, OpenAIStopConditionsProvider, common_ext::{ CommonExt, CommonExtProvider, choose_with_deprecation, emit_nvext_deprecation_warning, }, nvext::NvExt, nvext::NvExtProvider, validate, }; pub mod aggregator; mod delta; pub use aggregator::DeltaAggregator; pub use delta::DeltaGenerator; /// A request structure for creating a chat completion, extending OpenAI's /// `CreateChatCompletionRequest` with [`NvExt`] extensions and common fields. /// /// # Fields /// - `inner`: The base OpenAI chat completion request, embedded using `serde(flatten)`. /// - `common`: Common extension fields (ignore_eos, min_tokens) at root level, embedded using `serde(flatten)`. /// - `nvext`: The optional NVIDIA extension field. See [`NvExt`] for more details. /// Note: If ignore_eos is specified in both common and nvext, the common (root-level) value takes precedence. #[derive(Serialize, Deserialize, Validate, Debug, Clone)] pub struct NvCreateChatCompletionRequest { #[serde(flatten)] pub inner: dynamo_async_openai::types::CreateChatCompletionRequest, #[serde(flatten, default)] pub common: CommonExt, #[serde(skip_serializing_if = "Option::is_none")] pub nvext: Option, /// Extra args to pass to the chat template rendering context #[serde(default, skip_serializing_if = "Option::is_none")] pub chat_template_args: Option>, } /// A response structure for unary chat completion responses, embedding OpenAI's /// `CreateChatCompletionResponse`. /// /// # Fields /// - `inner`: The base OpenAI unary chat completion response, embedded /// using `serde(flatten)`. pub type NvCreateChatCompletionResponse = dynamo_async_openai::types::CreateChatCompletionResponse; /// A response structure for streamed chat completions, embedding OpenAI's /// `CreateChatCompletionStreamResponse`. /// /// # Fields /// - `inner`: The base OpenAI streaming chat completion response, embedded /// using `serde(flatten)`. pub type NvCreateChatCompletionStreamResponse = dynamo_async_openai::types::CreateChatCompletionStreamResponse; /// Implements `NvExtProvider` for `NvCreateChatCompletionRequest`, /// providing access to NVIDIA-specific extensions. impl NvExtProvider for NvCreateChatCompletionRequest { /// Returns a reference to the optional `NvExt` extension, if available. fn nvext(&self) -> Option<&NvExt> { self.nvext.as_ref() } /// Returns `None`, as raw prompt extraction is not implemented. fn raw_prompt(&self) -> Option { None } } /// Implements `AnnotationsProvider` for `NvCreateChatCompletionRequest`, /// enabling retrieval and management of request annotations. impl AnnotationsProvider for NvCreateChatCompletionRequest { /// Retrieves the list of annotations from `NvExt`, if present. fn annotations(&self) -> Option> { self.nvext .as_ref() .and_then(|nvext| nvext.annotations.clone()) } /// Checks whether a specific annotation exists in the request. /// /// # Arguments /// * `annotation` - A string slice representing the annotation to check. /// /// # Returns /// `true` if the annotation exists, `false` otherwise. fn has_annotation(&self, annotation: &str) -> bool { self.nvext .as_ref() .and_then(|nvext| nvext.annotations.as_ref()) .map(|annotations| annotations.contains(&annotation.to_string())) .unwrap_or(false) } } /// Implements `OpenAISamplingOptionsProvider` for `NvCreateChatCompletionRequest`, /// exposing OpenAI's sampling parameters for chat completion. impl OpenAISamplingOptionsProvider for NvCreateChatCompletionRequest { /// Retrieves the temperature parameter for sampling, if set. fn get_temperature(&self) -> Option { self.inner.temperature } /// Retrieves the top-p (nucleus sampling) parameter, if set. fn get_top_p(&self) -> Option { self.inner.top_p } /// Retrieves the frequency penalty parameter, if set. fn get_frequency_penalty(&self) -> Option { self.inner.frequency_penalty } /// Retrieves the presence penalty parameter, if set. fn get_presence_penalty(&self) -> Option { self.inner.presence_penalty } /// Returns a reference to the optional `NvExt` extension, if available. fn nvext(&self) -> Option<&NvExt> { self.nvext.as_ref() } /// Retrieves the seed value for random number generation, if set. fn get_seed(&self) -> Option { self.inner.seed } /// Retrieves the number of completions to generate for each prompt, if set. fn get_n(&self) -> Option { self.inner.n } /// Retrieves the best_of parameter, if set. fn get_best_of(&self) -> Option { None // Not supported in chat completions } } /// Implements `CommonExtProvider` for `NvCreateChatCompletionRequest`, /// providing access to common extension fields. impl CommonExtProvider for NvCreateChatCompletionRequest { /// Returns a reference to the CommonExt struct. fn common_ext(&self) -> Option<&CommonExt> { Some(&self.common) } /// Guided Decoding Options fn get_guided_json(&self) -> Option<&serde_json::Value> { // Note: This one needs special handling since it returns a reference if let Some(nvext) = &self.nvext && nvext.guided_json.is_some() { emit_nvext_deprecation_warning("guided_json", true, self.common.guided_json.is_some()); } self.common .guided_json .as_ref() .or_else(|| self.nvext.as_ref().and_then(|nv| nv.guided_json.as_ref())) } fn get_guided_regex(&self) -> Option { choose_with_deprecation( "guided_regex", self.common.guided_regex.as_ref(), self.nvext.as_ref().and_then(|nv| nv.guided_regex.as_ref()), ) } fn get_guided_grammar(&self) -> Option { choose_with_deprecation( "guided_grammar", self.common.guided_grammar.as_ref(), self.nvext .as_ref() .and_then(|nv| nv.guided_grammar.as_ref()), ) } fn get_guided_choice(&self) -> Option> { choose_with_deprecation( "guided_choice", self.common.guided_choice.as_ref(), self.nvext.as_ref().and_then(|nv| nv.guided_choice.as_ref()), ) } fn get_guided_decoding_backend(&self) -> Option { choose_with_deprecation( "guided_decoding_backend", self.common.guided_decoding_backend.as_ref(), self.nvext .as_ref() .and_then(|nv| nv.guided_decoding_backend.as_ref()), ) } fn get_top_k(&self) -> Option { choose_with_deprecation( "top_k", self.common.top_k.as_ref(), self.nvext.as_ref().and_then(|nv| nv.top_k.as_ref()), ) } fn get_min_p(&self) -> Option { choose_with_deprecation( "min_p", self.common.min_p.as_ref(), self.nvext.as_ref().and_then(|nv| nv.min_p.as_ref()), ) } fn get_repetition_penalty(&self) -> Option { choose_with_deprecation( "repetition_penalty", self.common.repetition_penalty.as_ref(), self.nvext .as_ref() .and_then(|nv| nv.repetition_penalty.as_ref()), ) } fn get_include_stop_str_in_output(&self) -> Option { self.common.include_stop_str_in_output } } /// Implements `OpenAIStopConditionsProvider` for `NvCreateChatCompletionRequest`, /// providing access to stop conditions that control chat completion behavior. impl OpenAIStopConditionsProvider for NvCreateChatCompletionRequest { /// Retrieves the maximum number of tokens allowed in the response. #[allow(deprecated)] fn get_max_tokens(&self) -> Option { self.inner.max_completion_tokens.or(self.inner.max_tokens) } /// Retrieves the minimum number of tokens required in the response. /// Returns `min_tokens` Value /// `min_tokens` is not an OpenAI-supported parameter. fn get_min_tokens(&self) -> Option { self.common.min_tokens } /// Retrieves the stop conditions that terminate the chat completion response. /// /// Converts OpenAI's `Stop` enum to a `Vec`, normalizing the representation. /// /// # Returns /// * `Some(Vec)` if stop conditions are set. /// * `None` if no stop conditions are defined. fn get_stop(&self) -> Option> { self.inner.stop.as_ref().map(|stop| match stop { dynamo_async_openai::types::Stop::String(s) => vec![s.clone()], dynamo_async_openai::types::Stop::StringArray(arr) => arr.clone(), }) } /// Returns a reference to the optional `NvExt` extension, if available. fn nvext(&self) -> Option<&NvExt> { self.nvext.as_ref() } /// Get ignore_eos from CommonExt. fn get_common_ignore_eos(&self) -> Option { self.common.ignore_eos } /// Get the effective ignore_eos value, considering both CommonExt and NvExt. /// CommonExt (root-level) takes precedence over NvExt. fn get_ignore_eos(&self) -> Option { choose_with_deprecation( "ignore_eos", self.get_common_ignore_eos().as_ref(), NvExtProvider::nvext(self).and_then(|nv| nv.ignore_eos.as_ref()), ) } } impl OpenAIOutputOptionsProvider for NvCreateChatCompletionRequest { fn get_logprobs(&self) -> Option { match self.inner.logprobs { Some(true) => match self.inner.top_logprobs { Some(top_logprobs) => Some(top_logprobs as u32), None => Some(1_u32), }, Some(false) => None, None => None, } } fn get_prompt_logprobs(&self) -> Option { None } fn get_skip_special_tokens(&self) -> Option { None } fn get_formatted_prompt(&self) -> Option { None } } /// Implements `ValidateRequest` for `NvCreateChatCompletionRequest`, /// allowing us to validate the data. impl ValidateRequest for NvCreateChatCompletionRequest { fn validate(&self) -> Result<(), anyhow::Error> { validate::validate_messages(&self.inner.messages)?; validate::validate_model(&self.inner.model)?; // none for store validate::validate_reasoning_effort(&self.inner.reasoning_effort)?; validate::validate_metadata(&self.inner.metadata)?; validate::validate_frequency_penalty(self.inner.frequency_penalty)?; validate::validate_logit_bias(&self.inner.logit_bias)?; // none for logprobs validate::validate_top_logprobs(self.inner.top_logprobs)?; // validate::validate_max_tokens(self.inner.max_tokens)?; // warning depricated field validate::validate_max_completion_tokens(self.inner.max_completion_tokens)?; validate::validate_n(self.inner.n)?; // none for modalities // none for prediction // none for audio validate::validate_presence_penalty(self.inner.presence_penalty)?; // none for response_format // none for seed validate::validate_service_tier(&self.inner.service_tier)?; validate::validate_stop(&self.inner.stop)?; // none for stream // none for stream_options validate::validate_temperature(self.inner.temperature)?; validate::validate_top_p(self.inner.top_p)?; validate::validate_tools(&self.inner.tools.as_deref())?; // none for tool_choice // none for parallel_tool_calls validate::validate_user(self.inner.user.as_deref())?; // none for function call // none for functions // Common Ext validate::validate_repetition_penalty(self.get_repetition_penalty())?; Ok(()) } }