// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-License-Identifier: Apache-2.0 use derive_builder::Builder; use serde::{Deserialize, Serialize}; use utoipa::ToSchema; use validator::{Validate, ValidationError}; pub use crate::protocols::common::timing::TimingInfo; pub trait NvExtProvider { fn nvext(&self) -> Option<&NvExt>; fn raw_prompt(&self) -> Option; } /// Worker ID information for disaggregated serving #[derive(ToSchema, Serialize, Deserialize, Debug, Clone, PartialEq)] pub struct WorkerIdInfo { /// The prefill worker ID that processed this request #[serde(skip_serializing_if = "Option::is_none")] pub prefill_worker_id: Option, /// The decode worker ID that processed this request #[serde(skip_serializing_if = "Option::is_none")] pub decode_worker_id: Option, } /// NVIDIA LLM response extensions #[derive(ToSchema, Serialize, Deserialize, Debug, Clone)] pub struct NvExtResponse { /// Worker ID information (prefill and decode worker IDs) #[serde(skip_serializing_if = "Option::is_none")] pub worker_id: Option, /// Per-request timing information /// Populated when client requests `extra_fields: ["timing"]` #[serde(skip_serializing_if = "Option::is_none")] pub timing: Option, /// Token IDs for GAIE Stage 1 query-only mode /// Contains the tokenized prompt for reuse in Stage 2 #[serde(skip_serializing_if = "Option::is_none")] pub token_ids: Option>, } /// NVIDIA LLM extensions to the OpenAI API #[derive(ToSchema, Serialize, Deserialize, Builder, Validate, Debug, Clone)] #[validate(schema(function = "validate_nv_ext"))] pub struct NvExt { /// If true, sampling will be forced to be greedy. /// The backend is responsible for selecting the correct backend-specific options to /// implement this. #[serde(default, skip_serializing_if = "Option::is_none")] #[builder(default, setter(strip_option))] pub greed_sampling: Option, /// If true, the preproessor will try to bypass the prompt template and pass the prompt directly to /// to the tokenizer. #[serde(default, skip_serializing_if = "Option::is_none")] #[builder(default, setter(strip_option))] pub use_raw_prompt: Option, /// Annotations /// User requests triggers which result in the request issue back out-of-band information in the SSE /// stream using the `event:` field. #[serde(default, skip_serializing_if = "Option::is_none")] #[builder(default, setter(strip_option))] pub annotations: Option>, /// Targeted backend instance ID for the request /// If set, the request will be routed to backend instance with the given ID. /// If not set, the request will be routed to the best matching instance. #[builder(default, setter(strip_option))] #[serde(default, skip_serializing_if = "Option::is_none")] pub backend_instance_id: Option, /// Pre-tokenized data to use instead of tokenizing the prompt /// If provided along with backend_instance_id, these tokens will be used directly /// and tokenization will be skipped. #[builder(default, setter(strip_option))] #[serde(default, skip_serializing_if = "Option::is_none")] pub token_data: Option>, /// Maximum number of thinking tokens allowed /// NOTE: Currently passed through to backends as a no-op for future implementation #[serde(default, skip_serializing_if = "Option::is_none")] #[builder(default, setter(strip_option))] pub max_thinking_tokens: Option, /// Extra fields to be included in the response's nvext /// This is a list of field names that should be populated in the response /// Supported fields: "worker_id", "timing", which has a 1:1 mapping with the NvExtResponse names #[serde(default, skip_serializing_if = "Option::is_none")] #[builder(default, setter(strip_option))] pub extra_fields: Option>, /// Targeted prefill worker ID for disaggregated serving (GAIE Stage 2) /// When set, the request will be routed to this specific prefill worker. #[builder(default, setter(strip_option))] #[serde(default, skip_serializing_if = "Option::is_none")] pub prefill_worker_id: Option, /// Targeted decode worker ID for disaggregated serving (GAIE Stage 2) /// When set, the request will be routed to this specific decode worker. #[builder(default, setter(strip_option))] #[serde(default, skip_serializing_if = "Option::is_none")] pub decode_worker_id: Option, /// Controls whether the router should manage local bookkeeping (add_request, /// mark_prefill_completed, free) for this request. /// /// - `None` or `true`: Router handles bookkeeping locally (default behavior) /// - `false`: External caller (e.g., GAIE sidecar) handles bookkeeping via C FFI /// /// Set to `false` for GAIE Stage 2 when the EPP/sidecar manages request lifecycle. #[builder(default, setter(strip_option))] #[serde(default, skip_serializing_if = "Option::is_none")] pub enable_local_updates: Option, } impl Default for NvExt { fn default() -> Self { NvExt::builder().build().unwrap() } } impl NvExt { pub fn builder() -> NvExtBuilder { NvExtBuilder::default() } } fn validate_nv_ext(_nv_ext: &NvExt) -> Result<(), ValidationError> { Ok(()) } impl NvExtBuilder { pub fn add_annotation(&mut self, annotation: impl Into) -> &mut Self { self.annotations .get_or_insert_with(|| Some(vec![])) .as_mut() .expect("stop should always be Some(Vec)") .push(annotation.into()); self } } #[cfg(test)] mod tests { use validator::Validate; use super::*; // Test default builder configuration #[test] fn test_nv_ext_builder_default() { let nv_ext = NvExt::builder().build().unwrap(); assert_eq!(nv_ext.greed_sampling, None); assert_eq!(nv_ext.use_raw_prompt, None); assert_eq!(nv_ext.annotations, None); assert_eq!(nv_ext.backend_instance_id, None); assert_eq!(nv_ext.token_data, None); assert_eq!(nv_ext.max_thinking_tokens, None); assert_eq!(nv_ext.extra_fields, None); assert_eq!(nv_ext.prefill_worker_id, None); assert_eq!(nv_ext.decode_worker_id, None); assert_eq!(nv_ext.enable_local_updates, None); } // Test valid builder configurations #[test] fn test_nv_ext_builder_custom() { let nv_ext = NvExt::builder() .greed_sampling(true) .use_raw_prompt(true) .backend_instance_id(42) .token_data(vec![1, 2, 3, 4]) .max_thinking_tokens(1024) .extra_fields(vec!["worker_id".to_string()]) .build() .unwrap(); assert_eq!(nv_ext.greed_sampling, Some(true)); assert_eq!(nv_ext.use_raw_prompt, Some(true)); assert_eq!(nv_ext.backend_instance_id, Some(42)); assert_eq!(nv_ext.token_data, Some(vec![1, 2, 3, 4])); assert_eq!(nv_ext.max_thinking_tokens, Some(1024)); assert_eq!(nv_ext.extra_fields, Some(vec!["worker_id".to_string()])); // Validate the built struct assert!(nv_ext.validate().is_ok()); } // Test GAIE Stage 2 disaggregated worker IDs #[test] fn test_nv_ext_disagg_worker_ids() { let nv_ext = NvExt::builder() .prefill_worker_id(100) .decode_worker_id(200) .build() .unwrap(); assert_eq!(nv_ext.prefill_worker_id, Some(100)); assert_eq!(nv_ext.decode_worker_id, Some(200)); assert!(nv_ext.validate().is_ok()); } }