Unverified commit 5ef545e6, authored by Keyang Ru, committed by GitHub

[router] Move all protocols to spec.rs file (#9519)

parent c4500233
// Supporting types for Responses API
use crate::protocols::openai::common::ChatLogProbs;
use serde::{Deserialize, Serialize};
// ============= Tool Definitions =============
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ResponseTool {
#[serde(rename = "type")]
pub r#type: ResponseToolType,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum ResponseToolType {
WebSearchPreview,
CodeInterpreter,
}
// ============= Reasoning Configuration =============
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ResponseReasoningParam {
#[serde(default = "default_reasoning_effort")]
pub effort: Option<ReasoningEffort>,
}
fn default_reasoning_effort() -> Option<ReasoningEffort> {
Some(ReasoningEffort::Medium)
}
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum ReasoningEffort {
Low,
Medium,
High,
}
// ============= Input/Output Items =============
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(tag = "type")]
#[serde(rename_all = "snake_case")]
pub enum ResponseInputOutputItem {
#[serde(rename = "message")]
Message {
id: String,
role: String,
content: Vec<ResponseContentPart>,
#[serde(skip_serializing_if = "Option::is_none")]
status: Option<String>,
},
#[serde(rename = "reasoning")]
Reasoning {
id: String,
#[serde(skip_serializing_if = "Vec::is_empty")]
summary: Vec<String>,
content: Vec<ResponseReasoningContent>,
#[serde(skip_serializing_if = "Option::is_none")]
status: Option<String>,
},
#[serde(rename = "function_tool_call")]
FunctionToolCall {
id: String,
name: String,
arguments: String,
#[serde(skip_serializing_if = "Option::is_none")]
output: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
status: Option<String>,
},
}
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(tag = "type")]
#[serde(rename_all = "snake_case")]
pub enum ResponseContentPart {
#[serde(rename = "output_text")]
OutputText {
text: String,
#[serde(skip_serializing_if = "Vec::is_empty")]
annotations: Vec<String>,
#[serde(skip_serializing_if = "Option::is_none")]
logprobs: Option<ChatLogProbs>,
},
}
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(tag = "type")]
#[serde(rename_all = "snake_case")]
pub enum ResponseReasoningContent {
#[serde(rename = "reasoning_text")]
ReasoningText { text: String },
}
// ============= Output Items for Response =============
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(tag = "type")]
#[serde(rename_all = "snake_case")]
pub enum ResponseOutputItem {
#[serde(rename = "message")]
Message {
id: String,
role: String,
content: Vec<ResponseContentPart>,
status: String,
},
#[serde(rename = "reasoning")]
Reasoning {
id: String,
#[serde(skip_serializing_if = "Vec::is_empty")]
summary: Vec<String>,
content: Vec<ResponseReasoningContent>,
#[serde(skip_serializing_if = "Option::is_none")]
status: Option<String>,
},
#[serde(rename = "function_tool_call")]
FunctionToolCall {
id: String,
name: String,
arguments: String,
#[serde(skip_serializing_if = "Option::is_none")]
output: Option<String>,
status: String,
},
}
// ============= Service Tier =============
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum ServiceTier {
Auto,
Default,
Flex,
Scale,
Priority,
}
impl Default for ServiceTier {
fn default() -> Self {
Self::Auto
}
}
// ============= Tool Choice =============
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum ToolChoice {
Auto,
Required,
None,
}
impl Default for ToolChoice {
fn default() -> Self {
Self::Auto
}
}
// ============= Truncation =============
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum Truncation {
Auto,
Disabled,
}
impl Default for Truncation {
fn default() -> Self {
Self::Disabled
}
}
// ============= Response Status =============
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum ResponseStatus {
Queued,
InProgress,
Completed,
Failed,
Cancelled,
}
// ============= Include Fields =============
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum IncludeField {
#[serde(rename = "code_interpreter_call.outputs")]
CodeInterpreterCallOutputs,
#[serde(rename = "computer_call_output.output.image_url")]
ComputerCallOutputImageUrl,
#[serde(rename = "file_search_call.results")]
FileSearchCallResults,
#[serde(rename = "message.input_image.image_url")]
MessageInputImageUrl,
#[serde(rename = "message.output_text.logprobs")]
MessageOutputTextLogprobs,
#[serde(rename = "reasoning.encrypted_content")]
ReasoningEncryptedContent,
}
// ============= Usage Info =============
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct UsageInfo {
pub prompt_tokens: u32,
pub completion_tokens: u32,
pub total_tokens: u32,
#[serde(skip_serializing_if = "Option::is_none")]
pub reasoning_tokens: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub prompt_tokens_details: Option<PromptTokenUsageInfo>,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct PromptTokenUsageInfo {
pub cached_tokens: u32,
}
// ============= Response Usage Format =============
/// OpenAI Responses API usage format (different from standard UsageInfo)
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ResponseUsage {
pub input_tokens: u32,
pub output_tokens: u32,
pub total_tokens: u32,
#[serde(skip_serializing_if = "Option::is_none")]
pub input_tokens_details: Option<InputTokensDetails>,
#[serde(skip_serializing_if = "Option::is_none")]
pub output_tokens_details: Option<OutputTokensDetails>,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct InputTokensDetails {
pub cached_tokens: u32,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct OutputTokensDetails {
pub reasoning_tokens: u32,
}
impl UsageInfo {
/// Convert to OpenAI Responses API format
pub fn to_response_usage(&self) -> ResponseUsage {
ResponseUsage {
input_tokens: self.prompt_tokens,
output_tokens: self.completion_tokens,
total_tokens: self.total_tokens,
input_tokens_details: self.prompt_tokens_details.as_ref().map(|details| {
InputTokensDetails {
cached_tokens: details.cached_tokens,
}
}),
output_tokens_details: self.reasoning_tokens.map(|tokens| OutputTokensDetails {
reasoning_tokens: tokens,
}),
}
}
}
impl From<UsageInfo> for ResponseUsage {
fn from(usage: UsageInfo) -> Self {
usage.to_response_usage()
}
}
impl ResponseUsage {
/// Convert back to standard UsageInfo format
pub fn to_usage_info(&self) -> UsageInfo {
UsageInfo {
prompt_tokens: self.input_tokens,
completion_tokens: self.output_tokens,
total_tokens: self.total_tokens,
reasoning_tokens: self
.output_tokens_details
.as_ref()
.map(|details| details.reasoning_tokens),
prompt_tokens_details: self.input_tokens_details.as_ref().map(|details| {
PromptTokenUsageInfo {
cached_tokens: details.cached_tokens,
}
}),
}
}
}
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::HashMap;
// # Protocol Specifications
//
// This module contains all protocol definitions for OpenAI and SGLang APIs.
//
// ## Table of Contents
//
// 1. **OPENAI SPEC - Chat Completions API**
// - Message Types
// - Response Format Types
// - Tool/Function Types
// - Streaming Delta Types
// - Request/Response structures
//
// 2. **OPENAI SPEC - Completions API**
// - Request/Response structures
// - Streaming support
//
// 3. **OPENAI SPEC - Responses API**
// - Tool Definitions
// - Reasoning Configuration
// - Input/Output Items
// - Service Tier & Tool Choice
// - Request/Response structures
//
// 4. **OPENAI SPEC - Common**
// - Shared Request Components
// - Tool Choice Types
// - Usage Tracking
// - Logprobs Types
// - Error Response Types
//
// 5. **SGLANG SPEC - GENERATE API**
// - Generate Parameters
// - Sampling Parameters
// - Request/Response structures
//
// 6. **COMMON**
// - GenerationRequest trait
// - StringOrArray & LoRAPath types
// - Helper functions
// ==================================================================
// = OPENAI SPEC - Chat Completions API =
// ==================================================================
// ============= Message Types =============
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(untagged)]
pub enum ChatMessage {
System {
role: String,
content: String,
#[serde(skip_serializing_if = "Option::is_none")]
name: Option<String>,
},
User {
role: String, // "user"
content: UserMessageContent,
#[serde(skip_serializing_if = "Option::is_none")]
name: Option<String>,
},
Assistant {
role: String, // "assistant"
#[serde(skip_serializing_if = "Option::is_none")]
content: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
name: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
tool_calls: Option<Vec<ToolCall>>,
#[serde(skip_serializing_if = "Option::is_none")]
function_call: Option<FunctionCallResponse>,
/// Reasoning content for O1-style models (SGLang extension)
#[serde(skip_serializing_if = "Option::is_none")]
reasoning_content: Option<String>,
},
Tool {
role: String, // "tool"
content: String,
tool_call_id: String,
},
Function {
role: String, // "function"
content: String,
name: String,
},
}
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(untagged)]
pub enum UserMessageContent {
Text(String),
Parts(Vec<ContentPart>),
}
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(tag = "type")]
pub enum ContentPart {
#[serde(rename = "text")]
Text { text: String },
#[serde(rename = "image_url")]
ImageUrl { image_url: ImageUrl },
}
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ImageUrl {
pub url: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub detail: Option<String>, // "auto", "low", or "high"
}
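// Sketch (illustrative JSON; with `#[serde(untagged)]` serde picks the first
// variant whose shape matches): a multi-part user message can only land on
// `User { content: Parts(..) }`, since every other variant expects a string.
#[cfg(test)]
mod chat_message_example {
    use super::*;

    #[test]
    fn multi_part_user_message_deserializes_as_parts() {
        let json = r#"{
            "role": "user",
            "content": [
                {"type": "text", "text": "What is in this image?"},
                {"type": "image_url", "image_url": {"url": "https://example.com/cat.png", "detail": "low"}}
            ]
        }"#;
        let msg: ChatMessage = serde_json::from_str(json).expect("valid user message");
        match msg {
            ChatMessage::User {
                content: UserMessageContent::Parts(parts),
                ..
            } => assert_eq!(parts.len(), 2),
            other => panic!("expected a multi-part user message, got {:?}", other),
        }
    }
}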
// ============= Response Format Types =============
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(tag = "type")]
pub enum ResponseFormat {
#[serde(rename = "text")]
Text,
#[serde(rename = "json_object")]
JsonObject,
#[serde(rename = "json_schema")]
JsonSchema { json_schema: JsonSchemaFormat },
}
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct JsonSchemaFormat {
pub name: String,
pub schema: Value,
#[serde(skip_serializing_if = "Option::is_none")]
pub strict: Option<bool>,
}
// ============= Streaming Delta Types =============
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ChatMessageDelta {
#[serde(skip_serializing_if = "Option::is_none")]
pub role: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub content: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_calls: Option<Vec<ToolCallDelta>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub function_call: Option<FunctionCallDelta>,
/// Reasoning content delta for O1-style models (SGLang extension)
#[serde(skip_serializing_if = "Option::is_none")]
pub reasoning_content: Option<String>,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ToolCallDelta {
pub index: u32,
#[serde(skip_serializing_if = "Option::is_none")]
pub id: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
#[serde(rename = "type")]
pub tool_type: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub function: Option<FunctionCallDelta>,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct FunctionCallDelta {
#[serde(skip_serializing_if = "Option::is_none")]
pub name: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub arguments: Option<String>,
}
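// Sketch: streamed tool calls arrive as argument fragments keyed by `index`;
// the client concatenates them until the stream finishes (values illustrative).
#[cfg(test)]
mod tool_call_delta_example {
    use super::*;

    #[test]
    fn argument_fragments_concatenate() {
        let chunks = [
            r#"{"index": 0, "function": {"arguments": "{\"city\":"}}"#,
            r#"{"index": 0, "function": {"arguments": "\"Paris\"}"}}"#,
        ];
        let mut arguments = String::new();
        for chunk in chunks {
            let delta: ToolCallDelta = serde_json::from_str(chunk).unwrap();
            if let Some(function) = delta.function {
                arguments.push_str(function.arguments.as_deref().unwrap_or(""));
            }
        }
        assert_eq!(arguments, r#"{"city":"Paris"}"#);
    }
}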
// ============= Request =============
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ChatCompletionRequest {
/// ID of the model to use
pub model: String,
/// A list of messages comprising the conversation so far
pub messages: Vec<ChatMessage>,
/// What sampling temperature to use, between 0 and 2
#[serde(skip_serializing_if = "Option::is_none")]
pub temperature: Option<f32>,
/// An alternative to sampling with temperature
#[serde(skip_serializing_if = "Option::is_none")]
pub top_p: Option<f32>,
/// How many chat completion choices to generate for each input message
#[serde(skip_serializing_if = "Option::is_none")]
pub n: Option<u32>,
/// If set, partial message deltas will be sent
#[serde(default)]
pub stream: bool,
/// Options for streaming response
#[serde(skip_serializing_if = "Option::is_none")]
pub stream_options: Option<StreamOptions>,
/// Up to 4 sequences where the API will stop generating further tokens
#[serde(skip_serializing_if = "Option::is_none")]
pub stop: Option<StringOrArray>,
/// The maximum number of tokens to generate
#[serde(skip_serializing_if = "Option::is_none")]
pub max_tokens: Option<u32>,
/// An upper bound for the number of tokens that can be generated for a completion
#[serde(skip_serializing_if = "Option::is_none")]
pub max_completion_tokens: Option<u32>,
/// Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far
#[serde(skip_serializing_if = "Option::is_none")]
pub presence_penalty: Option<f32>,
/// Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far
#[serde(skip_serializing_if = "Option::is_none")]
pub frequency_penalty: Option<f32>,
/// Modify the likelihood of specified tokens appearing in the completion
#[serde(skip_serializing_if = "Option::is_none")]
pub logit_bias: Option<HashMap<String, f32>>,
/// A unique identifier representing your end-user
#[serde(skip_serializing_if = "Option::is_none")]
pub user: Option<String>,
/// If specified, our system will make a best effort to sample deterministically
#[serde(skip_serializing_if = "Option::is_none")]
pub seed: Option<i64>,
/// Whether to return log probabilities of the output tokens
#[serde(default)]
pub logprobs: bool,
/// An integer between 0 and 20 specifying the number of most likely tokens to return
#[serde(skip_serializing_if = "Option::is_none")]
pub top_logprobs: Option<u32>,
/// An object specifying the format that the model must output
#[serde(skip_serializing_if = "Option::is_none")]
pub response_format: Option<ResponseFormat>,
/// A list of tools the model may call
#[serde(skip_serializing_if = "Option::is_none")]
pub tools: Option<Vec<Tool>>,
/// Controls which (if any) tool is called by the model
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_choice: Option<ToolChoice>,
/// Whether to enable parallel function calling during tool use
#[serde(skip_serializing_if = "Option::is_none")]
pub parallel_tool_calls: Option<bool>,
/// Deprecated: use tools instead
#[serde(skip_serializing_if = "Option::is_none")]
pub functions: Option<Vec<Function>>,
/// Deprecated: use tool_choice instead
#[serde(skip_serializing_if = "Option::is_none")]
pub function_call: Option<FunctionCall>,
// ============= SGLang Extensions =============
/// Top-k sampling parameter (-1 to disable)
#[serde(skip_serializing_if = "Option::is_none")]
pub top_k: Option<i32>,
/// Min-p nucleus sampling parameter
#[serde(skip_serializing_if = "Option::is_none")]
pub min_p: Option<f32>,
/// Minimum number of tokens to generate
#[serde(skip_serializing_if = "Option::is_none")]
pub min_tokens: Option<u32>,
/// Repetition penalty for reducing repetitive text
#[serde(skip_serializing_if = "Option::is_none")]
pub repetition_penalty: Option<f32>,
/// Regex constraint for output generation
#[serde(skip_serializing_if = "Option::is_none")]
pub regex: Option<String>,
/// EBNF grammar constraint for structured output
#[serde(skip_serializing_if = "Option::is_none")]
pub ebnf: Option<String>,
/// Specific token IDs to use as stop conditions
#[serde(skip_serializing_if = "Option::is_none")]
pub stop_token_ids: Option<Vec<i32>>,
/// Skip trimming stop tokens from output
#[serde(default)]
pub no_stop_trim: bool,
/// Ignore end-of-sequence tokens during generation
#[serde(default)]
pub ignore_eos: bool,
/// Continue generating from final assistant message
#[serde(default)]
pub continue_final_message: bool,
/// Skip special tokens during detokenization
#[serde(default = "default_true")]
pub skip_special_tokens: bool,
// ============= SGLang Extensions (continued) =============
/// Path to LoRA adapter(s) for model customization
#[serde(skip_serializing_if = "Option::is_none")]
pub lora_path: Option<LoRAPath>,
/// Session parameters for continual prompting
#[serde(skip_serializing_if = "Option::is_none")]
pub session_params: Option<HashMap<String, serde_json::Value>>,
/// Separate reasoning content from final answer (O1-style models)
#[serde(default = "default_true")]
pub separate_reasoning: bool,
/// Stream reasoning tokens during generation
#[serde(default = "default_true")]
pub stream_reasoning: bool,
/// Return model hidden states
#[serde(default)]
pub return_hidden_states: bool,
}
impl GenerationRequest for ChatCompletionRequest {
fn is_stream(&self) -> bool {
self.stream
}
fn get_model(&self) -> Option<&str> {
Some(&self.model)
}
fn extract_text_for_routing(&self) -> String {
// Extract text from messages for routing decisions
self.messages
.iter()
.filter_map(|msg| match msg {
ChatMessage::System { content, .. } => Some(content.clone()),
ChatMessage::User { content, .. } => match content {
UserMessageContent::Text(text) => Some(text.clone()),
UserMessageContent::Parts(parts) => {
let texts: Vec<String> = parts
.iter()
.filter_map(|part| match part {
ContentPart::Text { text } => Some(text.clone()),
_ => None,
})
.collect();
Some(texts.join(" "))
}
},
ChatMessage::Assistant {
content,
reasoning_content,
..
} => {
// Combine content and reasoning content for routing decisions
let main_content = content.clone().unwrap_or_default();
let reasoning = reasoning_content.clone().unwrap_or_default();
if main_content.is_empty() && reasoning.is_empty() {
None
} else {
Some(format!("{} {}", main_content, reasoning).trim().to_string())
}
}
ChatMessage::Tool { content, .. } => Some(content.clone()),
ChatMessage::Function { content, .. } => Some(content.clone()),
})
.collect::<Vec<String>>()
.join(" ")
}
}
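// Sketch: routing needs only the concatenated message text, and every other
// request field carries a serde default, so this minimal JSON deserializes
// (model name illustrative).
#[cfg(test)]
mod chat_routing_text_example {
    use super::*;

    #[test]
    fn extracts_text_from_messages() {
        let request: ChatCompletionRequest = serde_json::from_str(
            r#"{"model": "example-model", "messages": [{"role": "user", "content": "hello router"}]}"#,
        )
        .expect("minimal request");
        assert_eq!(request.extract_text_for_routing(), "hello router");
    }
}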
// ============= Regular Response =============
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ChatCompletionResponse {
pub id: String,
pub object: String, // "chat.completion"
pub created: u64,
pub model: String,
pub choices: Vec<ChatChoice>,
#[serde(skip_serializing_if = "Option::is_none")]
pub usage: Option<Usage>,
#[serde(skip_serializing_if = "Option::is_none")]
pub system_fingerprint: Option<String>,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ChatChoice {
pub index: u32,
pub message: ChatMessage,
#[serde(skip_serializing_if = "Option::is_none")]
pub logprobs: Option<ChatLogProbs>,
pub finish_reason: Option<String>, // "stop", "length", "tool_calls", "content_filter", "function_call"
/// Information about which stop condition was matched
#[serde(skip_serializing_if = "Option::is_none")]
pub matched_stop: Option<serde_json::Value>, // Can be string or integer
/// Hidden states from the model (SGLang extension)
#[serde(skip_serializing_if = "Option::is_none")]
pub hidden_states: Option<Vec<f32>>,
}
// ============= Streaming Response =============
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ChatCompletionStreamResponse {
pub id: String,
pub object: String, // "chat.completion.chunk"
pub created: u64,
pub model: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub system_fingerprint: Option<String>,
pub choices: Vec<ChatStreamChoice>,
#[serde(skip_serializing_if = "Option::is_none")]
pub usage: Option<Usage>,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ChatStreamChoice {
pub index: u32,
pub delta: ChatMessageDelta,
#[serde(skip_serializing_if = "Option::is_none")]
pub logprobs: Option<ChatLogProbs>,
pub finish_reason: Option<String>,
}
// ==================================================================
// = OPENAI SPEC - Completions API =
// ==================================================================
// Completions API request types (v1/completions) - DEPRECATED but still supported
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct CompletionRequest {
/// ID of the model to use (required by OpenAI; some implementations, such as SGLang, treat it as optional)
pub model: String,
/// The prompt(s) to generate completions for
pub prompt: StringOrArray,
/// The suffix that comes after a completion of inserted text
#[serde(skip_serializing_if = "Option::is_none")]
pub suffix: Option<String>,
/// The maximum number of tokens to generate
#[serde(skip_serializing_if = "Option::is_none")]
pub max_tokens: Option<u32>,
/// What sampling temperature to use, between 0 and 2
#[serde(skip_serializing_if = "Option::is_none")]
pub temperature: Option<f32>,
/// An alternative to sampling with temperature (nucleus sampling)
#[serde(skip_serializing_if = "Option::is_none")]
pub top_p: Option<f32>,
/// How many completions to generate for each prompt
#[serde(skip_serializing_if = "Option::is_none")]
pub n: Option<u32>,
/// Whether to stream back partial progress
#[serde(default)]
pub stream: bool,
/// Options for streaming response
#[serde(skip_serializing_if = "Option::is_none")]
pub stream_options: Option<StreamOptions>,
/// Include the log probabilities on the logprobs most likely tokens
#[serde(skip_serializing_if = "Option::is_none")]
pub logprobs: Option<u32>,
/// Echo back the prompt in addition to the completion
#[serde(default)]
pub echo: bool,
/// Up to 4 sequences where the API will stop generating further tokens
#[serde(skip_serializing_if = "Option::is_none")]
pub stop: Option<StringOrArray>,
/// Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far
#[serde(skip_serializing_if = "Option::is_none")]
pub presence_penalty: Option<f32>,
/// Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far
#[serde(skip_serializing_if = "Option::is_none")]
pub frequency_penalty: Option<f32>,
/// Generates best_of completions server-side and returns the "best"
#[serde(skip_serializing_if = "Option::is_none")]
pub best_of: Option<u32>,
/// Modify the likelihood of specified tokens appearing in the completion
#[serde(skip_serializing_if = "Option::is_none")]
pub logit_bias: Option<HashMap<String, f32>>,
/// A unique identifier representing your end-user
#[serde(skip_serializing_if = "Option::is_none")]
pub user: Option<String>,
/// If specified, our system will make a best effort to sample deterministically
#[serde(skip_serializing_if = "Option::is_none")]
pub seed: Option<i64>,
// ============= SGLang Extensions =============
/// Top-k sampling parameter (-1 to disable)
#[serde(skip_serializing_if = "Option::is_none")]
pub top_k: Option<i32>,
/// Min-p nucleus sampling parameter
#[serde(skip_serializing_if = "Option::is_none")]
pub min_p: Option<f32>,
/// Minimum number of tokens to generate
#[serde(skip_serializing_if = "Option::is_none")]
pub min_tokens: Option<u32>,
/// Repetition penalty for reducing repetitive text
#[serde(skip_serializing_if = "Option::is_none")]
pub repetition_penalty: Option<f32>,
/// Regex constraint for output generation
#[serde(skip_serializing_if = "Option::is_none")]
pub regex: Option<String>,
/// EBNF grammar constraint for structured output
#[serde(skip_serializing_if = "Option::is_none")]
pub ebnf: Option<String>,
/// JSON schema constraint for structured output
#[serde(skip_serializing_if = "Option::is_none")]
pub json_schema: Option<String>,
/// Specific token IDs to use as stop conditions
#[serde(skip_serializing_if = "Option::is_none")]
pub stop_token_ids: Option<Vec<i32>>,
/// Skip trimming stop tokens from output
#[serde(default)]
pub no_stop_trim: bool,
/// Ignore end-of-sequence tokens during generation
#[serde(default)]
pub ignore_eos: bool,
/// Skip special tokens during detokenization
#[serde(default = "default_true")]
pub skip_special_tokens: bool,
// ============= SGLang Extensions (continued) =============
/// Path to LoRA adapter(s) for model customization
#[serde(skip_serializing_if = "Option::is_none")]
pub lora_path: Option<LoRAPath>,
/// Session parameters for continual prompting
#[serde(skip_serializing_if = "Option::is_none")]
pub session_params: Option<HashMap<String, serde_json::Value>>,
/// Return model hidden states
#[serde(default)]
pub return_hidden_states: bool,
/// Additional fields including bootstrap info for PD routing
#[serde(flatten)]
pub other: serde_json::Map<String, serde_json::Value>,
}
impl GenerationRequest for CompletionRequest {
fn is_stream(&self) -> bool {
self.stream
}
fn get_model(&self) -> Option<&str> {
Some(&self.model)
}
fn extract_text_for_routing(&self) -> String {
match &self.prompt {
StringOrArray::String(s) => s.clone(),
StringOrArray::Array(v) => v.join(" "),
}
}
}
// ============= Regular Response =============
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct CompletionResponse {
pub id: String,
pub object: String, // "text_completion"
pub created: u64,
pub model: String,
pub choices: Vec<CompletionChoice>,
#[serde(skip_serializing_if = "Option::is_none")]
pub usage: Option<Usage>,
#[serde(skip_serializing_if = "Option::is_none")]
pub system_fingerprint: Option<String>,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct CompletionChoice {
pub text: String,
pub index: u32,
#[serde(skip_serializing_if = "Option::is_none")]
pub logprobs: Option<LogProbs>,
pub finish_reason: Option<String>, // "stop", "length", "content_filter", etc.
/// Information about which stop condition was matched
#[serde(skip_serializing_if = "Option::is_none")]
pub matched_stop: Option<serde_json::Value>, // Can be string or integer
/// Hidden states from the model (SGLang extension)
#[serde(skip_serializing_if = "Option::is_none")]
pub hidden_states: Option<Vec<f32>>,
}
// ============= Streaming Response =============
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct CompletionStreamResponse {
pub id: String,
pub object: String, // "text_completion"
pub created: u64,
pub choices: Vec<CompletionStreamChoice>,
pub model: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub system_fingerprint: Option<String>,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct CompletionStreamChoice {
pub text: String,
pub index: u32,
#[serde(skip_serializing_if = "Option::is_none")]
pub logprobs: Option<LogProbs>,
pub finish_reason: Option<String>,
}
// ==================================================================
// = OPENAI SPEC - Responses API =
// ==================================================================
// ============= Tool Definitions =============
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ResponseTool {
#[serde(rename = "type")]
pub r#type: ResponseToolType,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum ResponseToolType {
WebSearchPreview,
CodeInterpreter,
}
// ============= Reasoning Configuration =============
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ResponseReasoningParam {
#[serde(default = "default_reasoning_effort")]
pub effort: Option<ReasoningEffort>,
}
fn default_reasoning_effort() -> Option<ReasoningEffort> {
Some(ReasoningEffort::Medium)
}
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum ReasoningEffort {
Low,
Medium,
High,
}
// ============= Input/Output Items =============
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(tag = "type")]
#[serde(rename_all = "snake_case")]
pub enum ResponseInputOutputItem {
#[serde(rename = "message")]
Message {
id: String,
role: String,
content: Vec<ResponseContentPart>,
#[serde(skip_serializing_if = "Option::is_none")]
status: Option<String>,
},
#[serde(rename = "reasoning")]
Reasoning {
id: String,
#[serde(skip_serializing_if = "Vec::is_empty")]
summary: Vec<String>,
content: Vec<ResponseReasoningContent>,
#[serde(skip_serializing_if = "Option::is_none")]
status: Option<String>,
},
#[serde(rename = "function_tool_call")]
FunctionToolCall {
id: String,
name: String,
arguments: String,
#[serde(skip_serializing_if = "Option::is_none")]
output: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
status: Option<String>,
},
}
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(tag = "type")]
#[serde(rename_all = "snake_case")]
pub enum ResponseContentPart {
#[serde(rename = "output_text")]
OutputText {
text: String,
#[serde(skip_serializing_if = "Vec::is_empty")]
annotations: Vec<String>,
#[serde(skip_serializing_if = "Option::is_none")]
logprobs: Option<ChatLogProbs>,
},
}
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(tag = "type")]
#[serde(rename_all = "snake_case")]
pub enum ResponseReasoningContent {
#[serde(rename = "reasoning_text")]
ReasoningText { text: String },
}
// ============= Output Items for Response =============
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(tag = "type")]
#[serde(rename_all = "snake_case")]
pub enum ResponseOutputItem {
#[serde(rename = "message")]
Message {
id: String,
role: String,
content: Vec<ResponseContentPart>,
status: String,
},
#[serde(rename = "reasoning")]
Reasoning {
id: String,
#[serde(skip_serializing_if = "Vec::is_empty")]
summary: Vec<String>,
content: Vec<ResponseReasoningContent>,
#[serde(skip_serializing_if = "Option::is_none")]
status: Option<String>,
},
#[serde(rename = "function_tool_call")]
FunctionToolCall {
id: String,
name: String,
arguments: String,
#[serde(skip_serializing_if = "Option::is_none")]
output: Option<String>,
status: String,
},
}
// ============= Service Tier =============
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum ServiceTier {
Auto,
Default,
Flex,
Scale,
Priority,
}
impl Default for ServiceTier {
fn default() -> Self {
Self::Auto
}
}
// ============= Truncation =============
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum Truncation {
Auto,
Disabled,
}
impl Default for Truncation {
fn default() -> Self {
Self::Disabled
}
}
// ============= Response Status =============
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum ResponseStatus {
Queued,
InProgress,
Completed,
Failed,
Cancelled,
}
// ============= Include Fields =============
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum IncludeField {
#[serde(rename = "code_interpreter_call.outputs")]
CodeInterpreterCallOutputs,
#[serde(rename = "computer_call_output.output.image_url")]
ComputerCallOutputImageUrl,
#[serde(rename = "file_search_call.results")]
FileSearchCallResults,
#[serde(rename = "message.input_image.image_url")]
MessageInputImageUrl,
#[serde(rename = "message.output_text.logprobs")]
MessageOutputTextLogprobs,
#[serde(rename = "reasoning.encrypted_content")]
ReasoningEncryptedContent,
}
// ============= Usage Info =============
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct UsageInfo {
pub prompt_tokens: u32,
pub completion_tokens: u32,
pub total_tokens: u32,
#[serde(skip_serializing_if = "Option::is_none")]
pub reasoning_tokens: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub prompt_tokens_details: Option<PromptTokenUsageInfo>,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct PromptTokenUsageInfo {
pub cached_tokens: u32,
}
// ============= Response Usage Format =============
/// OpenAI Responses API usage format (different from standard UsageInfo)
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ResponseUsage {
pub input_tokens: u32,
pub output_tokens: u32,
pub total_tokens: u32,
#[serde(skip_serializing_if = "Option::is_none")]
pub input_tokens_details: Option<InputTokensDetails>,
#[serde(skip_serializing_if = "Option::is_none")]
pub output_tokens_details: Option<OutputTokensDetails>,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct InputTokensDetails {
pub cached_tokens: u32,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct OutputTokensDetails {
pub reasoning_tokens: u32,
}
impl UsageInfo {
/// Convert to OpenAI Responses API format
pub fn to_response_usage(&self) -> ResponseUsage {
ResponseUsage {
input_tokens: self.prompt_tokens,
output_tokens: self.completion_tokens,
total_tokens: self.total_tokens,
input_tokens_details: self.prompt_tokens_details.as_ref().map(|details| {
InputTokensDetails {
cached_tokens: details.cached_tokens,
}
}),
output_tokens_details: self.reasoning_tokens.map(|tokens| OutputTokensDetails {
reasoning_tokens: tokens,
}),
}
}
}
impl From<UsageInfo> for ResponseUsage {
fn from(usage: UsageInfo) -> Self {
usage.to_response_usage()
}
}
impl ResponseUsage {
/// Convert back to standard UsageInfo format
pub fn to_usage_info(&self) -> UsageInfo {
UsageInfo {
prompt_tokens: self.input_tokens,
completion_tokens: self.output_tokens,
total_tokens: self.total_tokens,
reasoning_tokens: self
.output_tokens_details
.as_ref()
.map(|details| details.reasoning_tokens),
prompt_tokens_details: self.input_tokens_details.as_ref().map(|details| {
PromptTokenUsageInfo {
cached_tokens: details.cached_tokens,
}
}),
}
}
}
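// Sketch: the two usage formats carry the same counts under different names,
// so converting to the Responses API shape and back preserves every field.
#[cfg(test)]
mod usage_conversion_example {
    use super::*;

    #[test]
    fn usage_round_trips_between_formats() {
        let usage = UsageInfo::new_with_cached(100, 40, Some(12), 25);
        let response_usage: ResponseUsage = usage.into();
        assert_eq!(response_usage.input_tokens, 100);
        assert_eq!(response_usage.output_tokens, 40);
        assert_eq!(
            response_usage.output_tokens_details.as_ref().map(|d| d.reasoning_tokens),
            Some(12)
        );

        let back = response_usage.to_usage_info();
        assert_eq!(back.prompt_tokens, 100);
        assert_eq!(back.prompt_tokens_details.map(|d| d.cached_tokens), Some(25));
    }
}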
fn generate_request_id() -> String {
format!("resp_{}", uuid::Uuid::new_v4().simple())
}
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ResponsesRequest {
// ============= Core OpenAI API fields =============
/// Run the request in the background
#[serde(default)]
pub background: bool,
/// Fields to include in the response
#[serde(skip_serializing_if = "Option::is_none")]
pub include: Option<Vec<IncludeField>>,
/// Input content - can be string or structured items
pub input: ResponseInput,
/// System instructions for the model
#[serde(skip_serializing_if = "Option::is_none")]
pub instructions: Option<String>,
/// Maximum number of output tokens
#[serde(skip_serializing_if = "Option::is_none")]
pub max_output_tokens: Option<u32>,
/// Maximum number of tool calls
#[serde(skip_serializing_if = "Option::is_none")]
pub max_tool_calls: Option<u32>,
/// Additional metadata
#[serde(skip_serializing_if = "Option::is_none")]
pub metadata: Option<HashMap<String, serde_json::Value>>,
/// Model to use (optional to match vLLM)
#[serde(skip_serializing_if = "Option::is_none")]
pub model: Option<String>,
/// Whether to enable parallel tool calls
#[serde(default = "default_true")]
pub parallel_tool_calls: bool,
/// ID of previous response to continue from
#[serde(skip_serializing_if = "Option::is_none")]
pub previous_response_id: Option<String>,
/// Reasoning configuration
#[serde(skip_serializing_if = "Option::is_none")]
pub reasoning: Option<ResponseReasoningParam>,
/// Service tier
#[serde(default)]
pub service_tier: ServiceTier,
/// Whether to store the response
#[serde(default = "default_true")]
pub store: bool,
/// Whether to stream the response
#[serde(default)]
pub stream: bool,
/// Temperature for sampling
#[serde(skip_serializing_if = "Option::is_none")]
pub temperature: Option<f32>,
/// Tool choice behavior
#[serde(default)]
pub tool_choice: ToolChoice,
/// Available tools
#[serde(default)]
pub tools: Vec<ResponseTool>,
/// Number of top logprobs to return
#[serde(default)]
pub top_logprobs: u32,
/// Top-p sampling parameter
#[serde(skip_serializing_if = "Option::is_none")]
pub top_p: Option<f32>,
/// Truncation behavior
#[serde(default)]
pub truncation: Truncation,
/// User identifier
#[serde(skip_serializing_if = "Option::is_none")]
pub user: Option<String>,
// ============= SGLang Extensions =============
/// Request ID
#[serde(default = "generate_request_id")]
pub request_id: String,
/// Request priority
#[serde(default)]
pub priority: i32,
/// Frequency penalty
#[serde(default)]
pub frequency_penalty: f32,
/// Presence penalty
#[serde(default)]
pub presence_penalty: f32,
/// Stop sequences
#[serde(skip_serializing_if = "Option::is_none")]
pub stop: Option<StringOrArray>,
/// Top-k sampling parameter
#[serde(default = "default_top_k")]
pub top_k: i32,
/// Min-p sampling parameter
#[serde(default)]
pub min_p: f32,
/// Repetition penalty
#[serde(default = "default_repetition_penalty")]
pub repetition_penalty: f32,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(untagged)]
pub enum ResponseInput {
Text(String),
Items(Vec<ResponseInputOutputItem>),
}
fn default_top_k() -> i32 {
-1
}
fn default_repetition_penalty() -> f32 {
1.0
}
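// Sketch: `ResponseInput` is untagged, so a bare JSON string becomes `Text`
// and an array of typed items becomes `Items`. Note that required `Vec`
// fields such as `annotations` must be present on input; `skip_serializing_if`
// only affects serialization.
#[cfg(test)]
mod response_input_example {
    use super::*;

    #[test]
    fn input_accepts_string_or_items() {
        let text: ResponseInput = serde_json::from_str(r#""summarize this""#).unwrap();
        assert!(matches!(text, ResponseInput::Text(_)));

        let items: ResponseInput = serde_json::from_str(
            r#"[{
                "type": "message",
                "id": "msg_1",
                "role": "user",
                "content": [{"type": "output_text", "text": "hi", "annotations": []}]
            }]"#,
        )
        .unwrap();
        assert!(matches!(items, ResponseInput::Items(_)));
    }
}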
impl ResponsesRequest {
/// Default sampling parameters
const DEFAULT_TEMPERATURE: f32 = 0.7;
const DEFAULT_TOP_P: f32 = 1.0;
/// Convert to sampling parameters for generation
pub fn to_sampling_params(
&self,
default_max_tokens: u32,
default_params: Option<HashMap<String, serde_json::Value>>,
) -> HashMap<String, serde_json::Value> {
let mut params = HashMap::new();
// Use max_output_tokens if available
let max_tokens = if let Some(max_output) = self.max_output_tokens {
std::cmp::min(max_output, default_max_tokens)
} else {
default_max_tokens
};
// Reserve one token of headroom to avoid exceeding the context length
let max_tokens = max_tokens.saturating_sub(1);
// Temperature
let temperature = self.temperature.unwrap_or_else(|| {
default_params
.as_ref()
.and_then(|p| p.get("temperature"))
.and_then(|v| v.as_f64())
.map(|v| v as f32)
.unwrap_or(Self::DEFAULT_TEMPERATURE)
});
// Top-p
let top_p = self.top_p.unwrap_or_else(|| {
default_params
.as_ref()
.and_then(|p| p.get("top_p"))
.and_then(|v| v.as_f64())
.map(|v| v as f32)
.unwrap_or(Self::DEFAULT_TOP_P)
});
params.insert(
"max_new_tokens".to_string(),
serde_json::Value::Number(serde_json::Number::from(max_tokens)),
);
params.insert(
"temperature".to_string(),
serde_json::Value::Number(serde_json::Number::from_f64(temperature as f64).unwrap()),
);
params.insert(
"top_p".to_string(),
serde_json::Value::Number(serde_json::Number::from_f64(top_p as f64).unwrap()),
);
params.insert(
"frequency_penalty".to_string(),
serde_json::Value::Number(
serde_json::Number::from_f64(self.frequency_penalty as f64).unwrap(),
),
);
params.insert(
"presence_penalty".to_string(),
serde_json::Value::Number(
serde_json::Number::from_f64(self.presence_penalty as f64).unwrap(),
),
);
params.insert(
"top_k".to_string(),
serde_json::Value::Number(serde_json::Number::from(self.top_k)),
);
params.insert(
"min_p".to_string(),
serde_json::Value::Number(serde_json::Number::from_f64(self.min_p as f64).unwrap()),
);
params.insert(
"repetition_penalty".to_string(),
serde_json::Value::Number(
serde_json::Number::from_f64(self.repetition_penalty as f64).unwrap(),
),
);
if let Some(ref stop) = self.stop {
match serde_json::to_value(stop) {
Ok(value) => params.insert("stop".to_string(), value),
Err(_) => params.insert("stop".to_string(), serde_json::Value::Null),
};
}
// Apply any additional default parameters
if let Some(default_params) = default_params {
for (key, value) in default_params {
params.entry(key).or_insert(value);
}
}
params
}
}
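// Sketch: every non-optional field of `ResponsesRequest` carries a serde
// default, so `{"input": "..."}` is already a complete request; the
// conversion then caps the token budget and fills in sampling defaults.
#[cfg(test)]
mod sampling_params_example {
    use super::*;

    #[test]
    fn minimal_request_gets_default_sampling_params() {
        let request: ResponsesRequest =
            serde_json::from_str(r#"{"input": "hello"}"#).expect("minimal request");
        let params = request.to_sampling_params(2048, None);
        // 2048 allowed, minus one token of context headroom.
        assert_eq!(params["max_new_tokens"], serde_json::json!(2047));
        assert_eq!(params["top_k"], serde_json::json!(-1));
        let temperature = params["temperature"].as_f64().expect("temperature is a number");
        assert!((temperature - 0.7).abs() < 1e-6);
    }
}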
impl GenerationRequest for ResponsesRequest {
fn is_stream(&self) -> bool {
self.stream
}
fn get_model(&self) -> Option<&str> {
self.model.as_deref()
}
fn extract_text_for_routing(&self) -> String {
match &self.input {
ResponseInput::Text(text) => text.clone(),
ResponseInput::Items(items) => items
.iter()
.filter_map(|item| match item {
ResponseInputOutputItem::Message { content, .. } => {
let texts: Vec<String> = content
.iter()
.map(|part| match part {
ResponseContentPart::OutputText { text, .. } => text.clone(),
})
.collect();
if texts.is_empty() {
None
} else {
Some(texts.join(" "))
}
}
ResponseInputOutputItem::Reasoning { content, .. } => {
let texts: Vec<String> = content
.iter()
.map(|part| match part {
ResponseReasoningContent::ReasoningText { text } => text.clone(),
})
.collect();
if texts.is_empty() {
None
} else {
Some(texts.join(" "))
}
}
ResponseInputOutputItem::FunctionToolCall { arguments, .. } => {
Some(arguments.clone())
}
})
.collect::<Vec<String>>()
.join(" "),
}
}
}
fn generate_response_id() -> String {
format!("resp_{}", uuid::Uuid::new_v4().simple())
}
fn current_timestamp() -> i64 {
std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_else(|_| std::time::Duration::from_secs(0))
.as_secs() as i64
}
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ResponsesResponse {
/// Response ID
#[serde(default = "generate_response_id")]
pub id: String,
/// Object type
#[serde(default = "default_object_type")]
pub object: String,
/// Creation timestamp
#[serde(default = "current_timestamp")]
pub created_at: i64,
/// Model name
pub model: String,
/// Output items
#[serde(default)]
pub output: Vec<ResponseOutputItem>,
/// Response status
pub status: ResponseStatus,
/// Usage information
#[serde(skip_serializing_if = "Option::is_none")]
pub usage: Option<UsageInfo>,
/// Whether parallel tool calls are enabled
#[serde(default = "default_true")]
pub parallel_tool_calls: bool,
/// Tool choice setting
#[serde(default = "default_tool_choice")]
pub tool_choice: String,
/// Available tools
#[serde(default)]
pub tools: Vec<ResponseTool>,
}
fn default_object_type() -> String {
"response".to_string()
}
fn default_tool_choice() -> String {
"auto".to_string()
}
impl ResponsesResponse {
/// Create a response from a request
#[allow(clippy::too_many_arguments)]
pub fn from_request(
request: &ResponsesRequest,
_sampling_params: &HashMap<String, serde_json::Value>,
model_name: String,
created_time: i64,
output: Vec<ResponseOutputItem>,
status: ResponseStatus,
usage: Option<UsageInfo>,
) -> Self {
Self {
id: request.request_id.clone(),
object: "response".to_string(),
created_at: created_time,
model: model_name,
output,
status,
usage,
parallel_tool_calls: request.parallel_tool_calls,
tool_choice: match &request.tool_choice {
ToolChoice::Value(ToolChoiceValue::Auto) => "auto".to_string(),
ToolChoice::Value(ToolChoiceValue::Required) => "required".to_string(),
ToolChoice::Value(ToolChoiceValue::None) => "none".to_string(),
ToolChoice::Function { .. } => "function".to_string(),
},
tools: request.tools.clone(),
}
}
/// Create a new response with default values
pub fn new(request_id: String, model: String, status: ResponseStatus) -> Self {
Self {
id: request_id,
object: "response".to_string(),
created_at: current_timestamp(),
model,
output: Vec::new(),
status,
usage: None,
parallel_tool_calls: true,
tool_choice: "auto".to_string(),
tools: Vec::new(),
}
}
/// Add an output item to the response
pub fn add_output(&mut self, item: ResponseOutputItem) {
self.output.push(item);
}
/// Set the usage information
pub fn set_usage(&mut self, usage: UsageInfo) {
self.usage = Some(usage);
}
/// Update the status
pub fn set_status(&mut self, status: ResponseStatus) {
self.status = status;
}
/// Check if the response is complete
pub fn is_complete(&self) -> bool {
matches!(self.status, ResponseStatus::Completed)
}
/// Check if the response is in progress
pub fn is_in_progress(&self) -> bool {
matches!(self.status, ResponseStatus::InProgress)
}
/// Check if the response failed
pub fn is_failed(&self) -> bool {
matches!(self.status, ResponseStatus::Failed)
}
/// Check if the response was cancelled
pub fn is_cancelled(&self) -> bool {
matches!(self.status, ResponseStatus::Cancelled)
}
/// Check if the response is queued
pub fn is_queued(&self) -> bool {
matches!(self.status, ResponseStatus::Queued)
}
/// Convert usage to OpenAI Responses API format
pub fn usage_in_response_format(&self) -> Option<ResponseUsage> {
self.usage.as_ref().map(|usage| usage.to_response_usage())
}
/// Get the response as a JSON value with usage in response format
pub fn to_response_format(&self) -> serde_json::Value {
let mut response = serde_json::to_value(self).unwrap_or(serde_json::Value::Null);
// Convert usage to response format if present
if let Some(usage) = &self.usage {
if let Ok(usage_value) = serde_json::to_value(usage.to_response_usage()) {
response["usage"] = usage_value;
}
}
response
}
}
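// Sketch: `ResponsesResponse` serializes `usage` with the internal field
// names; `to_response_format` swaps in the Responses API names
// (`input_tokens`, `output_tokens`) for the wire format.
#[cfg(test)]
mod response_format_example {
    use super::*;

    #[test]
    fn usage_is_renamed_in_response_format() {
        let mut response = ResponsesResponse::new(
            "resp_example".to_string(),
            "example-model".to_string(), // illustrative model name
            ResponseStatus::Completed,
        );
        response.set_usage(UsageInfo::new(10, 5, None));
        let value = response.to_response_format();
        assert_eq!(value["usage"]["input_tokens"], serde_json::json!(10));
        assert_eq!(value["usage"]["output_tokens"], serde_json::json!(5));
        assert_eq!(value["usage"]["total_tokens"], serde_json::json!(15));
    }
}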
// ============= Helper Functions =============
impl ResponseOutputItem {
/// Create a new message output item
pub fn new_message(
id: String,
role: String,
content: Vec<ResponseContentPart>,
status: String,
) -> Self {
Self::Message {
id,
role,
content,
status,
}
}
/// Create a new reasoning output item
pub fn new_reasoning(
id: String,
summary: Vec<String>,
content: Vec<ResponseReasoningContent>,
status: Option<String>,
) -> Self {
Self::Reasoning {
id,
summary,
content,
status,
}
}
/// Create a new function tool call output item
pub fn new_function_tool_call(
id: String,
name: String,
arguments: String,
output: Option<String>,
status: String,
) -> Self {
Self::FunctionToolCall {
id,
name,
arguments,
output,
status,
}
}
}
impl ResponseContentPart {
/// Create a new text content part
pub fn new_text(
text: String,
annotations: Vec<String>,
logprobs: Option<ChatLogProbs>,
) -> Self {
Self::OutputText {
text,
annotations,
logprobs,
}
}
}
impl ResponseReasoningContent {
/// Create a new reasoning text content
pub fn new_reasoning_text(text: String) -> Self {
Self::ReasoningText { text }
}
}
impl UsageInfo {
/// Create a new usage info with token counts
pub fn new(prompt_tokens: u32, completion_tokens: u32, reasoning_tokens: Option<u32>) -> Self {
Self {
prompt_tokens,
completion_tokens,
total_tokens: prompt_tokens + completion_tokens,
reasoning_tokens,
prompt_tokens_details: None,
}
}
/// Create usage info with cached token details
pub fn new_with_cached(
prompt_tokens: u32,
completion_tokens: u32,
reasoning_tokens: Option<u32>,
cached_tokens: u32,
) -> Self {
Self {
prompt_tokens,
completion_tokens,
total_tokens: prompt_tokens + completion_tokens,
reasoning_tokens,
prompt_tokens_details: Some(PromptTokenUsageInfo { cached_tokens }),
}
}
}
// ==================================================================
// = OPENAI SPEC - Common =
// ==================================================================
// ============= Shared Request Components =============
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct StreamOptions {
#[serde(skip_serializing_if = "Option::is_none")]
pub include_usage: Option<bool>,
}
// ============= Tool Choice Types =============
/// Tool choice value for simple string options
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum ToolChoiceValue {
Auto,
Required,
None,
}
/// Tool choice for both Chat Completion and Responses APIs
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(untagged)]
pub enum ToolChoice {
Value(ToolChoiceValue),
Function {
#[serde(rename = "type")]
tool_type: String, // "function"
function: FunctionChoice,
},
}
impl Default for ToolChoice {
fn default() -> Self {
Self::Value(ToolChoiceValue::Auto)
}
}
/// Function choice specification for ToolChoice::Function
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct FunctionChoice {
pub name: String,
}
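// Sketch: `ToolChoice` is untagged, so the wire format is either a bare
// string ("auto" | "required" | "none") or an object naming a function
// (the function name below is illustrative).
#[cfg(test)]
mod tool_choice_example {
    use super::*;

    #[test]
    fn tool_choice_accepts_both_wire_shapes() {
        let auto: ToolChoice = serde_json::from_str(r#""auto""#).unwrap();
        assert!(matches!(auto, ToolChoice::Value(ToolChoiceValue::Auto)));

        let named: ToolChoice = serde_json::from_str(
            r#"{"type": "function", "function": {"name": "get_weather"}}"#,
        )
        .unwrap();
        assert!(matches!(named, ToolChoice::Function { .. }));
    }
}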
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Tool {
#[serde(rename = "type")]
pub tool_type: String, // "function"
pub function: Function,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Function {
pub name: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub description: Option<String>,
pub parameters: Value, // JSON Schema
}
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ToolCall {
pub id: String,
#[serde(rename = "type")]
pub tool_type: String, // "function"
pub function: FunctionCallResponse,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(untagged)]
pub enum FunctionCall {
None,
Auto,
Function { name: String },
}
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct FunctionCallResponse {
pub name: String,
pub arguments: String, // JSON string
}
// ============= Usage Tracking =============
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Usage {
pub prompt_tokens: u32,
pub completion_tokens: u32,
pub total_tokens: u32,
#[serde(skip_serializing_if = "Option::is_none")]
pub completion_tokens_details: Option<CompletionTokensDetails>,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct CompletionTokensDetails {
pub reasoning_tokens: Option<u32>,
}
// ============= Logprobs Types =============
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct LogProbs {
pub tokens: Vec<String>,
pub token_logprobs: Vec<Option<f32>>,
pub top_logprobs: Vec<Option<HashMap<String, f32>>>,
pub text_offset: Vec<u32>,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ChatLogProbs {
pub content: Option<Vec<ChatLogProbsContent>>,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ChatLogProbsContent {
pub token: String,
pub logprob: f32,
pub bytes: Option<Vec<u8>>,
pub top_logprobs: Vec<TopLogProb>,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct TopLogProb {
pub token: String,
pub logprob: f32,
pub bytes: Option<Vec<u8>>,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ErrorResponse {
pub error: ErrorDetail,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ErrorDetail {
pub message: String,
#[serde(rename = "type")]
pub error_type: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub param: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub code: Option<String>,
}
// ==================================================================
// = SGLANG SPEC - GENERATE API =
// ==================================================================
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(untagged)]
pub enum InputIds {
Single(Vec<i32>),
Batch(Vec<Vec<i32>>),
}
#[derive(Debug, Clone, Deserialize, Serialize, Default)]
pub struct GenerateParameters {
#[serde(skip_serializing_if = "Option::is_none")]
pub best_of: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub decoder_input_details: Option<bool>,
#[serde(skip_serializing_if = "Option::is_none")]
pub details: Option<bool>,
#[serde(skip_serializing_if = "Option::is_none")]
pub do_sample: Option<bool>,
#[serde(skip_serializing_if = "Option::is_none")]
pub max_new_tokens: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub repetition_penalty: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub return_full_text: Option<bool>,
#[serde(skip_serializing_if = "Option::is_none")]
pub seed: Option<u64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub stop: Option<Vec<String>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub temperature: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub top_k: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub top_p: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub truncate: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub typical_p: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub watermark: Option<bool>,
}
#[derive(Debug, Clone, Deserialize, Serialize, Default)]
pub struct SamplingParams {
#[serde(skip_serializing_if = "Option::is_none")]
pub temperature: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub max_new_tokens: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub top_p: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub top_k: Option<i32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub frequency_penalty: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub presence_penalty: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub repetition_penalty: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub stop: Option<StringOrArray>,
#[serde(skip_serializing_if = "Option::is_none")]
pub ignore_eos: Option<bool>,
#[serde(skip_serializing_if = "Option::is_none")]
pub skip_special_tokens: Option<bool>,
#[serde(skip_serializing_if = "Option::is_none")]
pub json_schema: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub regex: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub ebnf: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub min_p: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub min_tokens: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub stop_token_ids: Option<Vec<i32>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub no_stop_trim: Option<bool>,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct GenerateRequest {
/// The prompt to generate from (OpenAI style)
#[serde(skip_serializing_if = "Option::is_none")]
pub prompt: Option<StringOrArray>,
/// Text input - SGLang native format
#[serde(skip_serializing_if = "Option::is_none")]
pub text: Option<String>,
/// Input IDs for tokenized input
#[serde(skip_serializing_if = "Option::is_none")]
pub input_ids: Option<InputIds>,
/// Generation parameters
#[serde(default, skip_serializing_if = "Option::is_none")]
pub parameters: Option<GenerateParameters>,
/// Sampling parameters (sglang style)
#[serde(skip_serializing_if = "Option::is_none")]
pub sampling_params: Option<SamplingParams>,
/// Whether to stream the response
#[serde(default)]
pub stream: bool,
/// Whether to return logprobs
#[serde(default)]
pub return_logprob: bool,
// ============= SGLang Extensions =============
/// Path to LoRA adapter(s) for model customization
#[serde(skip_serializing_if = "Option::is_none")]
pub lora_path: Option<LoRAPath>,
/// Session parameters for continual prompting
#[serde(skip_serializing_if = "Option::is_none")]
pub session_params: Option<HashMap<String, serde_json::Value>>,
/// Return model hidden states
#[serde(default)]
pub return_hidden_states: bool,
/// Request ID for tracking
#[serde(skip_serializing_if = "Option::is_none")]
pub rid: Option<String>,
}
impl GenerationRequest for GenerateRequest {
fn is_stream(&self) -> bool {
self.stream
}
fn get_model(&self) -> Option<&str> {
// Generate requests typically don't have a model field
None
}
fn extract_text_for_routing(&self) -> String {
// Check fields in priority order: text, prompt, input_ids
if let Some(ref text) = self.text {
return text.clone();
}
if let Some(ref prompt) = self.prompt {
return match prompt {
StringOrArray::String(s) => s.clone(),
StringOrArray::Array(v) => v.join(" "),
};
}
if let Some(ref input_ids) = self.input_ids {
return match input_ids {
InputIds::Single(ids) => ids
.iter()
.map(|&id| id.to_string())
.collect::<Vec<String>>()
.join(" "),
InputIds::Batch(batches) => batches
.iter()
.flat_map(|batch| batch.iter().map(|&id| id.to_string()))
.collect::<Vec<String>>()
.join(" "),
};
}
// No text input found
String::new()
}
}
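// Sketch: routing text comes from the first populated field in priority
// order (`text`, then `prompt`, then `input_ids` rendered as id strings).
#[cfg(test)]
mod generate_routing_example {
    use super::*;

    #[test]
    fn text_takes_priority_over_input_ids() {
        let both: GenerateRequest =
            serde_json::from_str(r#"{"text": "hello", "input_ids": [1, 2, 3]}"#).unwrap();
        assert_eq!(both.extract_text_for_routing(), "hello");

        let ids_only: GenerateRequest =
            serde_json::from_str(r#"{"input_ids": [1, 2, 3]}"#).unwrap();
        assert_eq!(ids_only.extract_text_for_routing(), "1 2 3");
    }
}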
// ==================================================================
// = COMMON =
// ==================================================================
/// Helper function for serde default value
pub fn default_true() -> bool {
true
}
/// Common trait for all generation requests across different APIs
pub trait GenerationRequest: Send + Sync {
/// Check if the request is for streaming
fn is_stream(&self) -> bool;
/// Get the model name if specified
fn get_model(&self) -> Option<&str>;
/// Extract text content for routing decisions
fn extract_text_for_routing(&self) -> String;
}
/// Helper type for string or array of strings
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(untagged)]
pub enum StringOrArray {
String(String),
Array(Vec<String>),
}
impl StringOrArray {
/// Get the number of items in the StringOrArray
pub fn len(&self) -> usize {
match self {
StringOrArray::String(_) => 1,
StringOrArray::Array(arr) => arr.len(),
}
}
/// Check if the StringOrArray is empty
pub fn is_empty(&self) -> bool {
match self {
StringOrArray::String(s) => s.is_empty(),
StringOrArray::Array(arr) => arr.is_empty(),
}
}
/// Convert to a vector of strings
pub fn to_vec(&self) -> Vec<String> {
match self {
StringOrArray::String(s) => vec![s.clone()],
StringOrArray::Array(arr) => arr.clone(),
}
}
}
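// Sketch: `StringOrArray` accepts both JSON shapes and normalizes them.
#[cfg(test)]
mod string_or_array_example {
    use super::*;

    #[test]
    fn both_shapes_normalize() {
        let single: StringOrArray = serde_json::from_str(r#""stop""#).unwrap();
        assert_eq!(single.to_vec(), vec!["stop".to_string()]);

        let many: StringOrArray = serde_json::from_str(r#"["a", "b"]"#).unwrap();
        assert_eq!(many.len(), 2);
        assert!(!many.is_empty());
    }
}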
/// LoRA adapter path - can be single path or batch of paths (SGLang extension)
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(untagged)]
pub enum LoRAPath {
Single(Option<String>),
Batch(Vec<Option<String>>),
}
@@ -4,6 +4,11 @@ use anyhow::Result;
use serde::{Deserialize, Serialize};
use std::fmt::Display;
// Import types from spec module
use crate::protocols::spec::{
ChatCompletionRequest, ChatMessage, ResponseFormat, StringOrArray, UserMessageContent,
};
/// Validation constants for OpenAI API parameters
pub mod constants {
/// Temperature range: 0.0 to 2.0 (OpenAI spec)
@@ -257,7 +262,7 @@ pub mod utils {
) -> Result<(), ValidationError> {
if let Some(stop) = request.get_stop_sequences() {
match stop {
-crate::protocols::common::StringOrArray::String(s) => {
+StringOrArray::String(s) => {
if s.is_empty() {
return Err(ValidationError::InvalidValue {
parameter: "stop".to_string(),
@@ -266,7 +271,7 @@ });
});
}
}
-crate::protocols::common::StringOrArray::Array(arr) => {
+StringOrArray::Array(arr) => {
validate_max_items(arr, constants::MAX_STOP_SEQUENCES, "stop")?;
for (i, s) in arr.iter().enumerate() {
if s.is_empty() {
......@@ -469,7 +474,7 @@ pub trait SamplingOptionsProvider {
/// Trait for validating stop conditions
pub trait StopConditionsProvider {
/// Get stop sequences
fn get_stop_sequences(&self) -> Option<&StringOrArray>;
}
/// Trait for validating token limits
......@@ -532,25 +537,237 @@ pub trait ValidatableRequest:
}
}
// ==================================================================
// = OPENAI CHAT COMPLETION VALIDATION =
// ==================================================================
impl SamplingOptionsProvider for ChatCompletionRequest {
fn get_temperature(&self) -> Option<f32> {
self.temperature
}
fn get_top_p(&self) -> Option<f32> {
self.top_p
}
fn get_frequency_penalty(&self) -> Option<f32> {
self.frequency_penalty
}
fn get_presence_penalty(&self) -> Option<f32> {
self.presence_penalty
}
}
impl StopConditionsProvider for ChatCompletionRequest {
fn get_stop_sequences(&self) -> Option<&StringOrArray> {
self.stop.as_ref()
}
}
impl TokenLimitsProvider for ChatCompletionRequest {
fn get_max_tokens(&self) -> Option<u32> {
// Prefer max_completion_tokens over max_tokens if both are set
self.max_completion_tokens.or(self.max_tokens)
}
fn get_min_tokens(&self) -> Option<u32> {
self.min_tokens
}
}
impl LogProbsProvider for ChatCompletionRequest {
fn get_logprobs(&self) -> Option<u32> {
// The chat API exposes logprobs as a boolean; map true to a count of 1 so the shared numeric validation applies
if self.logprobs {
Some(1)
} else {
None
}
}
fn get_top_logprobs(&self) -> Option<u32> {
self.top_logprobs
}
}
impl SGLangExtensionsProvider for ChatCompletionRequest {
fn get_top_k(&self) -> Option<i32> {
self.top_k
}
fn get_min_p(&self) -> Option<f32> {
self.min_p
}
fn get_repetition_penalty(&self) -> Option<f32> {
self.repetition_penalty
}
}
impl CompletionCountProvider for ChatCompletionRequest {
fn get_n(&self) -> Option<u32> {
self.n
}
}
impl ChatCompletionRequest {
/// Validate message-specific requirements
pub fn validate_messages(&self) -> Result<(), ValidationError> {
// Ensure messages array is not empty
utils::validate_non_empty_array(&self.messages, "messages")?;
// Validate message content is not empty
for (i, msg) in self.messages.iter().enumerate() {
if let ChatMessage::User { content, .. } = msg {
match content {
UserMessageContent::Text(text) if text.is_empty() => {
return Err(ValidationError::InvalidValue {
parameter: format!("messages[{}].content", i),
value: "empty".to_string(),
reason: "message content cannot be empty".to_string(),
});
}
UserMessageContent::Parts(parts) if parts.is_empty() => {
return Err(ValidationError::InvalidValue {
parameter: format!("messages[{}].content", i),
value: "empty array".to_string(),
reason: "message content parts cannot be empty".to_string(),
});
}
_ => {}
}
}
}
Ok(())
}
/// Validate response format if specified
pub fn validate_response_format(&self) -> Result<(), ValidationError> {
if let Some(ResponseFormat::JsonSchema { json_schema }) = &self.response_format {
if json_schema.name.is_empty() {
return Err(ValidationError::InvalidValue {
parameter: "response_format.json_schema.name".to_string(),
value: "empty".to_string(),
reason: "JSON schema name cannot be empty".to_string(),
});
}
}
Ok(())
}
/// Validate chat API specific logprobs requirements
pub fn validate_chat_logprobs(&self) -> Result<(), ValidationError> {
// In chat API, if logprobs=true, top_logprobs must be specified
if self.logprobs && self.top_logprobs.is_none() {
return Err(ValidationError::MissingRequired {
parameter: "top_logprobs".to_string(),
});
}
// If top_logprobs is specified, logprobs should be true
if self.top_logprobs.is_some() && !self.logprobs {
return Err(ValidationError::InvalidValue {
parameter: "logprobs".to_string(),
value: "false".to_string(),
reason: "must be true when top_logprobs is specified".to_string(),
});
}
Ok(())
}
/// Validate cross-parameter relationships specific to chat completions
pub fn validate_chat_cross_parameters(&self) -> Result<(), ValidationError> {
// Validate that both max_tokens and max_completion_tokens aren't set
utils::validate_conflicting_parameters(
"max_tokens",
self.max_tokens.is_some(),
"max_completion_tokens",
self.max_completion_tokens.is_some(),
"cannot specify both max_tokens and max_completion_tokens",
)?;
// Validate that tools and functions aren't both specified (deprecated)
utils::validate_conflicting_parameters(
"tools",
self.tools.is_some(),
"functions",
self.functions.is_some(),
"functions is deprecated, use tools instead",
)?;
// Validate structured output constraints don't conflict with JSON response format
let has_json_format = matches!(
self.response_format,
Some(ResponseFormat::JsonObject | ResponseFormat::JsonSchema { .. })
);
utils::validate_conflicting_parameters(
"response_format",
has_json_format,
"regex",
self.regex.is_some(),
"cannot use regex constraint with JSON response format",
)?;
utils::validate_conflicting_parameters(
"response_format",
has_json_format,
"ebnf",
self.ebnf.is_some(),
"cannot use EBNF constraint with JSON response format",
)?;
// Only one structured output constraint should be active
let structured_constraints = [
("regex", self.regex.is_some()),
("ebnf", self.ebnf.is_some()),
(
"json_schema",
matches!(
self.response_format,
Some(ResponseFormat::JsonSchema { .. })
),
),
];
utils::validate_mutually_exclusive_options(
&structured_constraints,
"Only one structured output constraint (regex, ebnf, or json_schema) can be active at a time",
)?;
Ok(())
}
}
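// Illustrative sketch (not part of the original commit): the core of the
// mutual-exclusion rule above, reduced to counting active (name, is_set)
// pairs. This is a simplified stand-in, not the actual
// validate_mutually_exclusive_options utility, which rejects more than one.
#[cfg(test)]
mod mutual_exclusion_example {
fn at_most_one_active(options: &[(&str, bool)]) -> bool {
options.iter().filter(|(_, is_set)| *is_set).count() <= 1
}
#[test]
fn only_one_structured_output_constraint_may_be_set() {
// A single active constraint passes.
assert!(at_most_one_active(&[
("regex", true),
("ebnf", false),
("json_schema", false),
]));
// Two active constraints are rejected.
assert!(!at_most_one_active(&[
("regex", true),
("ebnf", true),
("json_schema", false),
]));
}
}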
impl ValidatableRequest for ChatCompletionRequest {
fn validate(&self) -> Result<(), ValidationError> {
// Call the common validation function from the validation module
utils::validate_common_request_params(self)?;
// Then validate chat-specific parameters
self.validate_messages()?;
self.validate_response_format()?;
self.validate_chat_logprobs()?;
self.validate_chat_cross_parameters()?;
Ok(())
}
}
#[cfg(test)]
mod tests {
use super::constants::*;
use super::utils::*;
use super::*;
use crate::protocols::spec::StringOrArray;
// Mock request type for testing validation traits
#[derive(Debug, Default)]
struct MockRequest {
temperature: Option<f32>,
stop: Option<StringOrArray>,
max_tokens: Option<u32>,
min_tokens: Option<u32>,
}
......@@ -558,13 +775,13 @@ mod tests {
impl SamplingOptionsProvider for MockRequest {
fn get_temperature(&self) -> Option<f32> {
self.temperature
}
fn get_top_p(&self) -> Option<f32> {
None
}
fn get_frequency_penalty(&self) -> Option<f32> {
None
}
fn get_presence_penalty(&self) -> Option<f32> {
None
}
}
......@@ -585,173 +802,362 @@ mod tests {
impl LogProbsProvider for MockRequest {
fn get_logprobs(&self) -> Option<u32> {
None
}
fn get_top_logprobs(&self) -> Option<u32> {
None
}
}
impl SGLangExtensionsProvider for MockRequest {}
impl CompletionCountProvider for MockRequest {}
impl ValidatableRequest for MockRequest {}
#[test]
fn test_range_validation() {
// Valid range
assert!(validate_range(1.5f32, &TEMPERATURE_RANGE, "temperature").is_ok());
// Invalid range
assert!(validate_range(-0.1f32, &TEMPERATURE_RANGE, "temperature").is_err());
assert!(validate_range(3.0f32, &TEMPERATURE_RANGE, "temperature").is_err());
}
#[test]
fn test_sglang_top_k_validation() {
assert!(validate_top_k(-1).is_ok()); // Disabled
assert!(validate_top_k(50).is_ok()); // Valid positive
assert!(validate_top_k(0).is_err()); // Invalid
assert!(validate_top_k(-5).is_err()); // Invalid
}
#[test]
fn test_stop_sequences_limits() {
let request = MockRequest {
stop: Some(StringOrArray::Array(vec![
"stop1".to_string(),
"stop2".to_string(),
"stop3".to_string(),
"stop4".to_string(),
"stop5".to_string(), // Too many
])),
..Default::default()
};
assert!(request.validate().is_err());
}
#[test]
fn test_token_limits_conflict() {
let request = MockRequest {
min_tokens: Some(100),
max_tokens: Some(50), // min > max
..Default::default()
};
assert!(request.validate().is_err());
}
#[test]
fn test_valid_request() {
let request = MockRequest {
temperature: Some(1.0),
stop: Some(StringOrArray::Array(vec!["stop".to_string()])),
max_tokens: Some(100),
min_tokens: Some(10),
};
assert!(request.validate().is_ok());
}
// Chat completion specific tests
#[cfg(test)]
mod chat_tests {
use super::*;
fn create_valid_chat_request() -> ChatCompletionRequest {
ChatCompletionRequest {
model: "gpt-4".to_string(),
messages: vec![ChatMessage::User {
role: "user".to_string(),
content: UserMessageContent::Text("Hello".to_string()),
name: None,
}],
temperature: Some(1.0),
top_p: Some(0.9),
n: Some(1),
stream: false,
stream_options: None,
stop: None,
max_tokens: Some(100),
max_completion_tokens: None,
presence_penalty: Some(0.0),
frequency_penalty: Some(0.0),
logit_bias: None,
user: None,
seed: None,
logprobs: false,
top_logprobs: None,
response_format: None,
tools: None,
tool_choice: None,
parallel_tool_calls: None,
functions: None,
function_call: None,
// SGLang extensions
top_k: None,
min_p: None,
min_tokens: None,
repetition_penalty: None,
regex: None,
ebnf: None,
stop_token_ids: None,
no_stop_trim: false,
ignore_eos: false,
continue_final_message: false,
skip_special_tokens: true,
lora_path: None,
session_params: None,
separate_reasoning: true,
stream_reasoning: true,
return_hidden_states: false,
}
}
#[test]
fn test_chat_validation_basics() {
// Valid request
assert!(create_valid_chat_request().validate().is_ok());
// Empty messages
let mut request = create_valid_chat_request();
request.messages = vec![];
assert!(request.validate().is_err());
// Invalid temperature
let mut request = create_valid_chat_request();
request.temperature = Some(3.0);
assert!(request.validate().is_err());
}
#[test]
fn test_chat_conflicts() {
let mut request = create_valid_chat_request();
// Conflicting max_tokens
request.max_tokens = Some(100);
request.max_completion_tokens = Some(200);
assert!(request.validate().is_err());
// Logprobs without top_logprobs
request.max_tokens = None;
request.logprobs = true;
request.top_logprobs = None;
assert!(request.validate().is_err());
}
#[test]
fn test_sglang_extensions() {
let mut request = create_valid_chat_request();
// Valid SGLang parameters
request.top_k = Some(-1);
request.min_p = Some(0.1);
request.repetition_penalty = Some(1.2);
assert!(request.validate().is_ok());
// Invalid parameters
request.top_k = Some(0); // Invalid
assert!(request.validate().is_err());
}
#[test]
fn test_parameter_ranges() {
let mut request = create_valid_chat_request();
// Test temperature range (0.0 to 2.0)
request.temperature = Some(1.5);
assert!(request.validate().is_ok());
request.temperature = Some(-0.1);
assert!(request.validate().is_err());
request.temperature = Some(3.0);
assert!(request.validate().is_err());
// Test top_p range (0.0 to 1.0)
request.temperature = Some(1.0); // Reset
request.top_p = Some(0.9);
assert!(request.validate().is_ok());
request.top_p = Some(-0.1);
assert!(request.validate().is_err());
request.top_p = Some(1.5);
assert!(request.validate().is_err());
// Test frequency_penalty range (-2.0 to 2.0)
request.top_p = Some(0.9); // Reset
request.frequency_penalty = Some(1.5);
assert!(request.validate().is_ok());
request.frequency_penalty = Some(-2.5);
assert!(request.validate().is_err());
request.frequency_penalty = Some(3.0);
assert!(request.validate().is_err());
// Test presence_penalty range (-2.0 to 2.0)
request.frequency_penalty = Some(0.0); // Reset
request.presence_penalty = Some(-1.5);
assert!(request.validate().is_ok());
request.presence_penalty = Some(-3.0);
assert!(request.validate().is_err());
request.presence_penalty = Some(2.5);
assert!(request.validate().is_err());
// Test repetition_penalty range (0.0 to 2.0)
request.presence_penalty = Some(0.0); // Reset
request.repetition_penalty = Some(1.2);
assert!(request.validate().is_ok());
request.repetition_penalty = Some(-0.1);
assert!(request.validate().is_err());
request.repetition_penalty = Some(2.1);
assert!(request.validate().is_err());
// Test min_p range (0.0 to 1.0)
request.repetition_penalty = Some(1.0); // Reset
request.min_p = Some(0.5);
assert!(request.validate().is_ok());
request.min_p = Some(-0.1);
assert!(request.validate().is_err());
request.min_p = Some(1.5);
assert!(request.validate().is_err());
}
#[test]
fn test_structured_output_conflicts() {
let mut request = create_valid_chat_request();
// JSON response format with regex should conflict
request.response_format = Some(ResponseFormat::JsonObject);
request.regex = Some(".*".to_string());
assert!(request.validate().is_err());
// JSON response format with EBNF should conflict
request.regex = None;
request.ebnf = Some("grammar".to_string());
assert!(request.validate().is_err());
// Multiple structured constraints should conflict
request.response_format = None;
request.regex = Some(".*".to_string());
request.ebnf = Some("grammar".to_string());
assert!(request.validate().is_err());
// Only one constraint should work
request.ebnf = None;
request.regex = Some(".*".to_string());
assert!(request.validate().is_ok());
request.regex = None;
request.ebnf = Some("grammar".to_string());
assert!(request.validate().is_ok());
request.ebnf = None;
request.response_format = Some(ResponseFormat::JsonObject);
assert!(request.validate().is_ok());
}
#[test]
fn test_stop_sequences_validation() {
let mut request = create_valid_chat_request();
// Valid stop sequences
request.stop = Some(StringOrArray::Array(vec![
"stop1".to_string(),
"stop2".to_string(),
]));
assert!(request.validate().is_ok());
// Too many stop sequences (max 4)
request.stop = Some(StringOrArray::Array(vec![
"stop1".to_string(),
"stop2".to_string(),
"stop3".to_string(),
"stop4".to_string(),
"stop5".to_string(), // Too many
])),
..Default::default()
};
"stop5".to_string(),
]));
assert!(request.validate().is_err());
let result = request.validate();
assert!(result.is_err());
match result.unwrap_err() {
ValidationError::TooManyItems {
parameter,
count,
max,
} => {
assert_eq!(parameter, "stop");
assert_eq!(count, 5);
assert_eq!(max, MAX_STOP_SEQUENCES);
}
_ => panic!("Expected TooManyItems error"),
// Empty stop sequence should fail
request.stop = Some(StringOrArray::String("".to_string()));
assert!(request.validate().is_err());
// Empty string in array should fail
request.stop = Some(StringOrArray::Array(vec![
"stop1".to_string(),
"".to_string(),
]));
assert!(request.validate().is_err());
}
}
#[test]
fn test_logprobs_validation() {
let mut request = create_valid_chat_request();
// Valid logprobs configuration
request.logprobs = true;
request.top_logprobs = Some(10);
assert!(request.validate().is_ok());
// logprobs=true without top_logprobs should fail
request.top_logprobs = None;
assert!(request.validate().is_err());
// top_logprobs without logprobs=true should fail
request.logprobs = false;
request.top_logprobs = Some(10);
assert!(request.validate().is_err());
// top_logprobs out of range (0-20)
request.logprobs = true;
request.top_logprobs = Some(25);
assert!(request.validate().is_err());
}
#[test]
fn test_n_parameter_validation() {
let mut request = create_valid_chat_request();
// Valid n values (1-10)
request.n = Some(1);
assert!(request.validate().is_ok());
request.n = Some(5);
assert!(request.validate().is_ok());
request.n = Some(10);
assert!(request.validate().is_ok());
// Invalid n values
request.n = Some(0);
assert!(request.validate().is_err());
request.n = Some(15);
assert!(request.validate().is_err());
}
#[test]
fn test_min_max_tokens_validation() {
let mut request = create_valid_chat_request();
// Valid token limits
request.min_tokens = Some(10);
request.max_tokens = Some(100);
assert!(request.validate().is_ok());
// min_tokens > max_tokens should fail
request.min_tokens = Some(150);
request.max_tokens = Some(100);
assert!(request.validate().is_err());
// Should work with max_completion_tokens instead
request.max_tokens = None;
request.max_completion_tokens = Some(200);
request.min_tokens = Some(50);
assert!(request.validate().is_ok());
// min_tokens > max_completion_tokens should fail
request.min_tokens = Some(250);
assert!(request.validate().is_err());
}
}
}
......@@ -9,10 +9,7 @@ use axum::{
};
use std::fmt::Debug;
use crate::protocols::spec::{ChatCompletionRequest, CompletionRequest, GenerateRequest};
pub mod factory;
pub mod header_utils;
......
......@@ -12,13 +12,9 @@ use crate::core::{
};
use crate::metrics::RouterMetrics;
use crate::policies::LoadBalancingPolicy;
use crate::protocols::spec::{
ChatCompletionRequest, ChatMessage, CompletionRequest, GenerateRequest, StringOrArray,
UserMessageContent,
};
use crate::routers::{RouterTrait, WorkerManagement};
use async_trait::async_trait;
......
......@@ -9,10 +9,8 @@ use crate::core::{
};
use crate::metrics::RouterMetrics;
use crate::policies::LoadBalancingPolicy;
use crate::protocols::spec::{
ChatCompletionRequest, CompletionRequest, GenerateRequest, GenerationRequest,
};
use crate::routers::{RouterTrait, WorkerManagement};
use axum::{
......
use crate::config::RouterConfig;
use crate::logging::{self, LoggingConfig};
use crate::metrics::{self, PrometheusConfig};
use crate::protocols::spec::{ChatCompletionRequest, CompletionRequest, GenerateRequest};
use crate::routers::{RouterFactory, RouterTrait};
use crate::service_discovery::{start_service_discovery, ServiceDiscoveryConfig};
use axum::{
......
......@@ -5,13 +5,9 @@
use serde_json::{from_str, to_string, to_value};
use sglang_router_rs::core::{BasicWorker, WorkerType};
use sglang_router_rs::protocols::spec::{
ChatCompletionRequest, ChatMessage, CompletionRequest, GenerateParameters, GenerateRequest,
SamplingParams, StringOrArray, UserMessageContent,
};
/// Create a default GenerateRequest for benchmarks with minimal fields set
......
// Integration test for Responses API
use sglang_router_rs::protocols::spec::{
GenerationRequest, ReasoningEffort, ResponseInput, ResponseReasoningParam, ResponseStatus,
ResponseTool, ResponseToolType, ResponsesRequest, ResponsesResponse, ServiceTier, ToolChoice,
ToolChoiceValue, Truncation, UsageInfo,
};
#[test]
fn test_responses_request_creation() {
......@@ -24,7 +26,7 @@ fn test_responses_request_creation() {
store: true,
stream: false,
temperature: Some(0.7),
tool_choice: ToolChoice::Value(ToolChoiceValue::Auto),
tools: vec![ResponseTool {
r#type: ResponseToolType::WebSearchPreview,
}],
......@@ -67,7 +69,7 @@ fn test_sampling_params_conversion() {
store: true, // Use default true
stream: false,
temperature: Some(0.8),
tool_choice: ToolChoice::Value(ToolChoiceValue::Auto),
tools: vec![],
top_logprobs: 0, // Use default 0
top_p: Some(0.95),
......@@ -177,7 +179,7 @@ fn test_json_serialization() {
store: false,
stream: true,
temperature: Some(0.9),
tool_choice: ToolChoice::Value(ToolChoiceValue::Required),
tools: vec![ResponseTool {
r#type: ResponseToolType::CodeInterpreter,
}],
......