use std::collections::HashMap; use serde::{Deserialize, Serialize}; use serde_json::Value; use validator::Validate; use super::{ common::{default_true, GenerationRequest, InputIds}, sampling_params::SamplingParams, }; use crate::protocols::validated::Normalizable; // ============================================================================ // SGLang Generate API (native format) // ============================================================================ #[derive(Clone, Debug, Serialize, Deserialize, Validate)] #[validate(schema(function = "validate_generate_request"))] pub struct GenerateRequest { /// Text input - SGLang native format #[serde(skip_serializing_if = "Option::is_none")] pub text: Option, /// Input IDs for tokenized input #[serde(skip_serializing_if = "Option::is_none")] pub input_ids: Option, /// Input embeddings for direct embedding input /// Can be a 2D array (single request) or 3D array (batch of requests) /// Placeholder for future use #[serde(skip_serializing_if = "Option::is_none")] pub input_embeds: Option, /// Image input data /// Can be an image instance, file name, URL, or base64 encoded string /// Supports single images, lists of images, or nested lists for batch processing /// Placeholder for future use #[serde(skip_serializing_if = "Option::is_none")] pub image_data: Option, /// Video input data /// Can be a file name, URL, or base64 encoded string /// Supports single videos, lists of videos, or nested lists for batch processing /// Placeholder for future use #[serde(skip_serializing_if = "Option::is_none")] pub video_data: Option, /// Audio input data /// Can be a file name, URL, or base64 encoded string /// Supports single audio files, lists of audio, or nested lists for batch processing /// Placeholder for future use #[serde(skip_serializing_if = "Option::is_none")] pub audio_data: Option, /// Sampling parameters (sglang style) #[serde(skip_serializing_if = "Option::is_none")] pub sampling_params: Option, /// Whether to return logprobs #[serde(skip_serializing_if = "Option::is_none")] pub return_logprob: Option, /// If return logprobs, the start location in the prompt for returning logprobs. #[serde(skip_serializing_if = "Option::is_none")] pub logprob_start_len: Option, /// If return logprobs, the number of top logprobs to return at each position. #[serde(skip_serializing_if = "Option::is_none")] pub top_logprobs_num: Option, /// If return logprobs, the token ids to return logprob for. #[serde(skip_serializing_if = "Option::is_none")] pub token_ids_logprob: Option>, /// Whether to detokenize tokens in text in the returned logprobs. #[serde(default)] pub return_text_in_logprobs: bool, /// Whether to stream the response #[serde(default)] pub stream: bool, /// Whether to log metrics for this request (e.g. health_generate calls do not log metrics) #[serde(default = "default_true")] pub log_metrics: bool, /// Return model hidden states #[serde(default)] pub return_hidden_states: bool, /// The modalities of the image data [image, multi-images, video] #[serde(skip_serializing_if = "Option::is_none")] pub modalities: Option>, /// Session parameters for continual prompting #[serde(skip_serializing_if = "Option::is_none")] pub session_params: Option>, /// Path to LoRA adapter(s) for model customization #[serde(skip_serializing_if = "Option::is_none")] pub lora_path: Option, /// LoRA adapter ID (if pre-loaded) #[serde(skip_serializing_if = "Option::is_none")] pub lora_id: Option, /// Custom logit processor for advanced sampling control. Must be a serialized instance /// of `CustomLogitProcessor` in python/sglang/srt/sampling/custom_logit_processor.py /// Use the processor's `to_str()` method to generate the serialized string. #[serde(skip_serializing_if = "Option::is_none")] pub custom_logit_processor: Option, /// For disaggregated inference #[serde(skip_serializing_if = "Option::is_none")] pub bootstrap_host: Option, /// For disaggregated inference #[serde(skip_serializing_if = "Option::is_none")] pub bootstrap_port: Option, /// For disaggregated inference #[serde(skip_serializing_if = "Option::is_none")] pub bootstrap_room: Option, /// For disaggregated inference #[serde(skip_serializing_if = "Option::is_none")] pub bootstrap_pair_key: Option, /// Data parallel rank routing #[serde(skip_serializing_if = "Option::is_none")] pub data_parallel_rank: Option, /// Background response #[serde(default)] pub background: bool, /// Conversation ID for tracking #[serde(skip_serializing_if = "Option::is_none")] pub conversation_id: Option, /// Priority for the request #[serde(skip_serializing_if = "Option::is_none")] pub priority: Option, /// Extra key for classifying the request (e.g. cache_salt) #[serde(skip_serializing_if = "Option::is_none")] pub extra_key: Option, /// Whether to disallow logging for this request (e.g. due to ZDR) #[serde(default)] pub no_logs: bool, /// Custom metric labels #[serde(skip_serializing_if = "Option::is_none")] pub custom_labels: Option>, /// Whether to return bytes for image generation #[serde(default)] pub return_bytes: bool, /// Whether to return entropy #[serde(default)] pub return_entropy: bool, /// Request ID for tracking (inherited from BaseReq in Python) #[serde(skip_serializing_if = "Option::is_none")] pub rid: Option, } impl Normalizable for GenerateRequest { // Use default no-op implementation - no normalization needed for GenerateRequest } /// Validation function for GenerateRequest - ensure exactly one input type is provided fn validate_generate_request(req: &GenerateRequest) -> Result<(), validator::ValidationError> { // Exactly one of text or input_ids must be provided // Note: input_embeds not yet supported in Rust implementation let has_text = req.text.is_some(); let has_input_ids = req.input_ids.is_some(); let count = [has_text, has_input_ids].iter().filter(|&&x| x).count(); if count == 0 { return Err(validator::ValidationError::new( "Either text or input_ids should be provided.", )); } if count > 1 { return Err(validator::ValidationError::new( "Either text or input_ids should be provided.", )); } Ok(()) } impl GenerationRequest for GenerateRequest { fn is_stream(&self) -> bool { self.stream } fn get_model(&self) -> Option<&str> { // Generate requests typically don't have a model field None } fn extract_text_for_routing(&self) -> String { // Check fields in priority order: text, input_ids if let Some(ref text) = self.text { return text.clone(); } if let Some(ref input_ids) = self.input_ids { return match input_ids { InputIds::Single(ids) => ids .iter() .map(|&id| id.to_string()) .collect::>() .join(" "), InputIds::Batch(batches) => batches .iter() .flat_map(|batch| batch.iter().map(|&id| id.to_string())) .collect::>() .join(" "), }; } // No text input found String::new() } } // ============================================================================ // SGLang Generate Response Types // ============================================================================ /// SGLang generate response (single completion or array for n>1) /// /// Format for n=1: /// ```json /// { /// "text": "...", /// "output_ids": [...], /// "meta_info": { ... } /// } /// ``` /// /// Format for n>1: /// ```json /// [ /// {"text": "...", "output_ids": [...], "meta_info": {...}}, /// {"text": "...", "output_ids": [...], "meta_info": {...}} /// ] /// ``` #[derive(Debug, Clone, Serialize, Deserialize)] pub struct GenerateResponse { pub text: String, pub output_ids: Vec, pub meta_info: GenerateMetaInfo, } /// Metadata for a single generate completion #[derive(Debug, Clone, Serialize, Deserialize)] pub struct GenerateMetaInfo { pub id: String, pub finish_reason: GenerateFinishReason, pub prompt_tokens: u32, pub weight_version: String, #[serde(skip_serializing_if = "Option::is_none")] pub input_token_logprobs: Option>>>, #[serde(skip_serializing_if = "Option::is_none")] pub output_token_logprobs: Option>>>, pub completion_tokens: u32, pub cached_tokens: u32, pub e2e_latency: f64, #[serde(skip_serializing_if = "Option::is_none")] pub matched_stop: Option, } /// Finish reason for generate endpoint #[derive(Debug, Clone, Serialize, Deserialize)] #[serde(tag = "type", rename_all = "lowercase")] pub enum GenerateFinishReason { Length { length: u32, }, Stop, #[serde(untagged)] Other(Value), }