Unverified commit 27ef1459, authored by Chang Su and committed via GitHub
Browse files

[router][protocols] Add Axum validate extractor and use it for...

[router][protocols] Add Axum validate extractor and use it for `/v1/chat/completions` endpoint (#11588)
parent e4358a45
......@@ -56,6 +56,7 @@ parking_lot = "0.12.4"
thiserror = "2.0.12"
regex = "1.10"
url = "2.5.4"
validator = { version = "0.18", features = ["derive"] }
tokio-stream = { version = "0.1", features = ["sync"] }
anyhow = "1.0"
tokenizers = { version = "0.22.0" }
......
......@@ -4,8 +4,8 @@ use std::time::Instant;
use sglang_router_rs::core::{BasicWorker, BasicWorkerBuilder, Worker, WorkerType};
use sglang_router_rs::protocols::spec::{
ChatCompletionRequest, ChatMessage, CompletionRequest, GenerateParameters, GenerateRequest,
SamplingParams, StringOrArray, UserMessageContent,
ChatCompletionRequest, ChatMessage, CompletionRequest, GenerateRequest, SamplingParams,
StringOrArray, UserMessageContent,
};
use sglang_router_rs::routers::http::pd_types::{generate_room_id, RequestWithBootstrap};
......@@ -31,7 +31,6 @@ fn default_generate_request() -> GenerateRequest {
prompt: None,
input_ids: None,
stream: false,
parameters: None,
sampling_params: None,
return_logprob: false,
// SGLang Extensions
......@@ -101,14 +100,6 @@ fn default_completion_request() -> CompletionRequest {
fn create_sample_generate_request() -> GenerateRequest {
GenerateRequest {
text: Some("Write a story about artificial intelligence".to_string()),
parameters: Some(GenerateParameters {
max_new_tokens: Some(100),
temperature: Some(0.8),
top_p: Some(0.9),
top_k: Some(50),
repetition_penalty: Some(1.0),
..Default::default()
}),
sampling_params: Some(SamplingParams {
temperature: Some(0.8),
top_p: Some(0.9),
......@@ -128,12 +119,10 @@ fn create_sample_chat_completion_request() -> ChatCompletionRequest {
model: "gpt-3.5-turbo".to_string(),
messages: vec![
ChatMessage::System {
role: "system".to_string(),
content: "You are a helpful assistant".to_string(),
name: None,
},
ChatMessage::User {
role: "user".to_string(),
content: UserMessageContent::Text(
"Explain quantum computing in simple terms".to_string(),
),
......@@ -170,7 +159,6 @@ fn create_sample_completion_request() -> CompletionRequest {
#[allow(deprecated)]
fn create_large_chat_completion_request() -> ChatCompletionRequest {
let mut messages = vec![ChatMessage::System {
role: "system".to_string(),
content: "You are a helpful assistant with extensive knowledge.".to_string(),
name: None,
}];
......@@ -178,12 +166,10 @@ fn create_large_chat_completion_request() -> ChatCompletionRequest {
// Add many user/assistant pairs to simulate a long conversation
for i in 0..50 {
messages.push(ChatMessage::User {
role: "user".to_string(),
content: UserMessageContent::Text(format!("Question {}: What do you think about topic number {} which involves complex reasoning about multiple interconnected systems and their relationships?", i, i)),
name: None,
});
messages.push(ChatMessage::Assistant {
role: "assistant".to_string(),
content: Some(format!("Answer {}: This is a detailed response about topic {} that covers multiple aspects and provides comprehensive analysis of the interconnected systems you mentioned.", i, i)),
name: None,
tool_calls: None,
......
......@@ -123,6 +123,7 @@ fn create_test_tools() -> Vec<Tool> {
"limit": {"type": "number"}
}
}),
strict: None,
},
},
Tool {
......@@ -137,6 +138,7 @@ fn create_test_tools() -> Vec<Tool> {
"code": {"type": "string"}
}
}),
strict: None,
},
},
]
......
......@@ -301,13 +301,7 @@ impl SglangSchedulerClient {
) -> Result<proto::SamplingParams, String> {
let stop_sequences = self.extract_stop_strings(request);
// Handle max tokens: prefer max_completion_tokens (new) over max_tokens (deprecated)
// If neither is specified, use None to let the backend decide the default
#[allow(deprecated)]
let max_new_tokens = request
.max_completion_tokens
.or(request.max_tokens)
.map(|v| v as i32);
let max_new_tokens = request.max_completion_tokens.map(|v| v as i32);
// Handle skip_special_tokens: set to false if tools are present and tool_choice is not "none"
let skip_special_tokens = if request.tools.is_some() {
......@@ -322,7 +316,6 @@ impl SglangSchedulerClient {
request.skip_special_tokens
};
#[allow(deprecated)]
Ok(proto::SamplingParams {
temperature: request.temperature.unwrap_or(1.0),
top_p: request.top_p.unwrap_or(1.0),
......@@ -485,10 +478,10 @@ impl SglangSchedulerClient {
})?);
}
// Handle min_tokens with conversion
if let Some(min_tokens) = p.min_tokens {
sampling.min_new_tokens = i32::try_from(min_tokens)
.map_err(|_| "min_tokens must fit into a 32-bit signed integer".to_string())?;
// Handle min_new_tokens with conversion
if let Some(min_new_tokens) = p.min_new_tokens {
sampling.min_new_tokens = i32::try_from(min_new_tokens)
.map_err(|_| "min_new_tokens must fit into a 32-bit signed integer".to_string())?;
}
// Handle n with conversion
......
......@@ -2,5 +2,5 @@
// This module provides a structured approach to handling different API protocols
pub mod spec;
pub mod validation;
pub mod validated;
pub mod worker_spec;
use serde::{Deserialize, Serialize};
use serde_json::{to_value, Map, Number, Value};
use std::collections::HashMap;
use validator::Validate;
use crate::protocols::validated::Normalizable;
// Default model value when not specified
fn default_model() -> String {
......@@ -55,22 +58,22 @@ fn default_model() -> String {
// - Helper functions
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(untagged)]
#[serde(tag = "role")]
pub enum ChatMessage {
#[serde(rename = "system")]
System {
role: String,
content: String,
#[serde(skip_serializing_if = "Option::is_none")]
name: Option<String>,
},
#[serde(rename = "user")]
User {
role: String, // "user"
content: UserMessageContent,
#[serde(skip_serializing_if = "Option::is_none")]
name: Option<String>,
},
#[serde(rename = "assistant")]
Assistant {
role: String, // "assistant"
#[serde(skip_serializing_if = "Option::is_none")]
content: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
......@@ -81,16 +84,13 @@ pub enum ChatMessage {
#[serde(skip_serializing_if = "Option::is_none")]
reasoning_content: Option<String>,
},
#[serde(rename = "tool")]
Tool {
role: String, // "tool"
content: String,
tool_call_id: String,
},
Function {
role: String, // "function"
content: String,
name: String,
},
#[serde(rename = "function")]
Function { content: String, name: String },
}
#[derive(Debug, Clone, Deserialize, Serialize)]
......@@ -168,9 +168,11 @@ pub struct FunctionCallDelta {
pub arguments: Option<String>,
}
#[derive(Debug, Clone, Deserialize, Serialize, Default)]
#[derive(Debug, Clone, Deserialize, Serialize, Default, Validate)]
#[validate(schema(function = "validate_chat_cross_parameters"))]
pub struct ChatCompletionRequest {
/// A list of messages comprising the conversation so far
#[validate(custom(function = "validate_messages"))]
pub messages: Vec<ChatMessage>,
/// ID of the model to use
......@@ -179,6 +181,7 @@ pub struct ChatCompletionRequest {
/// Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far
#[serde(skip_serializing_if = "Option::is_none")]
#[validate(range(min = -2.0, max = 2.0))]
pub frequency_penalty: Option<f32>,
/// Deprecated: Replaced by tool_choice
......@@ -202,10 +205,12 @@ pub struct ChatCompletionRequest {
/// Deprecated: Replaced by max_completion_tokens
#[serde(skip_serializing_if = "Option::is_none")]
#[deprecated(note = "Use max_completion_tokens instead")]
#[validate(range(min = 1))]
pub max_tokens: Option<u32>,
/// An upper bound for the number of tokens that can be generated for a completion
#[serde(skip_serializing_if = "Option::is_none")]
#[validate(range(min = 1))]
pub max_completion_tokens: Option<u32>,
/// Developer-defined tags and values used for filtering completions in the dashboard
......@@ -218,6 +223,7 @@ pub struct ChatCompletionRequest {
/// How many chat completion choices to generate for each input message
#[serde(skip_serializing_if = "Option::is_none")]
#[validate(range(min = 1, max = 10))]
pub n: Option<u32>,
/// Whether to enable parallel function calling during tool use
......@@ -226,6 +232,7 @@ pub struct ChatCompletionRequest {
/// Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far
#[serde(skip_serializing_if = "Option::is_none")]
#[validate(range(min = -2.0, max = 2.0))]
pub presence_penalty: Option<f32>,
/// Cache key for prompts (beta feature)
......@@ -255,6 +262,7 @@ pub struct ChatCompletionRequest {
/// Up to 4 sequences where the API will stop generating further tokens
#[serde(skip_serializing_if = "Option::is_none")]
#[validate(custom(function = "validate_stop"))]
pub stop: Option<StringOrArray>,
/// If set, partial message deltas will be sent
......@@ -267,6 +275,7 @@ pub struct ChatCompletionRequest {
/// What sampling temperature to use, between 0 and 2
#[serde(skip_serializing_if = "Option::is_none")]
#[validate(range(min = 0.0, max = 2.0))]
pub temperature: Option<f32>,
/// Controls which (if any) tool is called by the model
......@@ -279,30 +288,42 @@ pub struct ChatCompletionRequest {
/// An integer between 0 and 20 specifying the number of most likely tokens to return
#[serde(skip_serializing_if = "Option::is_none")]
#[validate(range(min = 0, max = 20))]
pub top_logprobs: Option<u32>,
/// An alternative to sampling with temperature
#[serde(skip_serializing_if = "Option::is_none")]
#[validate(custom(function = "validate_top_p_value"))]
pub top_p: Option<f32>,
/// Verbosity level for debugging
#[serde(skip_serializing_if = "Option::is_none")]
pub verbosity: Option<i32>,
// =============================================================================
// Engine-Specific Sampling Parameters
// =============================================================================
// These parameters are extensions beyond the OpenAI API specification and
// control model generation behavior in engine-specific ways.
// =============================================================================
/// Top-k sampling parameter (-1 to disable)
#[serde(skip_serializing_if = "Option::is_none")]
#[validate(custom(function = "validate_top_k_value"))]
pub top_k: Option<i32>,
/// Min-p nucleus sampling parameter
#[serde(skip_serializing_if = "Option::is_none")]
#[validate(range(min = 0.0, max = 1.0))]
pub min_p: Option<f32>,
/// Minimum number of tokens to generate
#[serde(skip_serializing_if = "Option::is_none")]
#[validate(range(min = 1))]
pub min_tokens: Option<u32>,
/// Repetition penalty for reducing repetitive text
#[serde(skip_serializing_if = "Option::is_none")]
#[validate(range(min = 0.0, max = 2.0))]
pub repetition_penalty: Option<f32>,
/// Regex constraint for output generation
......@@ -362,6 +383,290 @@ pub struct ChatCompletionRequest {
pub sampling_seed: Option<u64>,
}
// Validation functions for ChatCompletionRequest
// These are automatically called by the validator derive macro
/// Validates `stop` sequences: at most 4 entries, and no entry may be empty.
fn validate_stop(stop: &StringOrArray) -> Result<(), validator::ValidationError> {
    // Shared error for the empty-string case (applies to both variants).
    let empty_err = || validator::ValidationError::new("stop sequences cannot be empty");
    match stop {
        StringOrArray::String(seq) => {
            if seq.is_empty() {
                return Err(empty_err());
            }
        }
        StringOrArray::Array(seqs) => {
            if seqs.len() > 4 {
                return Err(validator::ValidationError::new(
                    "maximum 4 stop sequences allowed",
                ));
            }
            if seqs.iter().any(|s| s.is_empty()) {
                return Err(empty_err());
            }
        }
    }
    Ok(())
}
/// Validates that `messages` is non-empty and that every user message carries
/// non-empty content.
fn validate_messages(messages: &[ChatMessage]) -> Result<(), validator::ValidationError> {
    if messages.is_empty() {
        return Err(validator::ValidationError::new("messages cannot be empty"));
    }
    // Only user messages are content-checked here; other roles pass through.
    for msg in messages {
        let ChatMessage::User { content, .. } = msg else {
            continue;
        };
        match content {
            UserMessageContent::Text(text) if text.is_empty() => {
                return Err(validator::ValidationError::new(
                    "message content cannot be empty",
                ));
            }
            UserMessageContent::Parts(parts) if parts.is_empty() => {
                return Err(validator::ValidationError::new(
                    "message content parts cannot be empty",
                ));
            }
            _ => {}
        }
    }
    Ok(())
}
/// Validates top_p: must lie in the half-open interval (0.0, 1.0].
/// The exclusive lower bound rules out the derive `range` validator, hence a
/// custom function. NaN fails both comparisons and is therefore rejected too.
fn validate_top_p_value(top_p: f32) -> Result<(), validator::ValidationError> {
    if top_p > 0.0 && top_p <= 1.0 {
        Ok(())
    } else {
        Err(validator::ValidationError::new(
            "top_p must be in (0, 1] - greater than 0.0 and at most 1.0",
        ))
    }
}
/// Validates top_k: either the sentinel -1 (sampling disabled) or a positive
/// count. The sentinel makes a plain `range` validator unusable.
fn validate_top_k_value(top_k: i32) -> Result<(), validator::ValidationError> {
    match top_k {
        -1 | 1.. => Ok(()),
        _ => Err(validator::ValidationError::new(
            "top_k must be -1 (disabled) or at least 1",
        )),
    }
}
/// Schema-level validation for cross-field dependencies on `ChatCompletionRequest`.
///
/// Runs after the per-field `#[validate]` attribute checks. Performs seven
/// checks in order and returns the first failure:
/// 1. `top_logprobs` requires `logprobs`
/// 2. `stream_options` requires `stream`
/// 3. `min_tokens` must not exceed `max_completion_tokens`
/// 4. regex/EBNF constraints conflict with a JSON response format
/// 5. at most one structured-output constraint (regex/ebnf/json_schema) at a time
/// 6. a `json_schema` response format must carry a non-empty name
/// 7. `tool_choice` must be consistent with `tools`
fn validate_chat_cross_parameters(
    req: &ChatCompletionRequest,
) -> Result<(), validator::ValidationError> {
    // 1. Validate logprobs dependency
    if req.top_logprobs.is_some() && !req.logprobs {
        let mut e = validator::ValidationError::new("top_logprobs_requires_logprobs");
        e.message = Some("top_logprobs is only allowed when logprobs is enabled".into());
        return Err(e);
    }

    // 2. Validate stream_options dependency
    if req.stream_options.is_some() && !req.stream {
        let mut e = validator::ValidationError::new("stream_options_requires_stream");
        e.message =
            Some("The 'stream_options' parameter is only allowed when 'stream' is enabled".into());
        return Err(e);
    }

    // 3. Validate token limits - min <= max
    // NOTE(review): only max_completion_tokens is consulted; this assumes the
    // deprecated max_tokens has already been migrated by normalize() — confirm
    // normalization runs before validation in the extractor pipeline.
    if let (Some(min), Some(max)) = (req.min_tokens, req.max_completion_tokens) {
        if min > max {
            let mut e = validator::ValidationError::new("min_tokens_exceeds_max");
            e.message = Some("min_tokens cannot exceed max_tokens/max_completion_tokens".into());
            return Err(e);
        }
    }

    // 4. Validate structured output conflicts: grammar-style constraints cannot
    // be combined with a JSON response format.
    let has_json_format = matches!(
        req.response_format,
        Some(ResponseFormat::JsonObject | ResponseFormat::JsonSchema { .. })
    );
    if has_json_format && req.regex.is_some() {
        let mut e = validator::ValidationError::new("regex_conflicts_with_json");
        e.message = Some("cannot use regex constraint with JSON response format".into());
        return Err(e);
    }
    if has_json_format && req.ebnf.is_some() {
        let mut e = validator::ValidationError::new("ebnf_conflicts_with_json");
        e.message = Some("cannot use EBNF constraint with JSON response format".into());
        return Err(e);
    }

    // 5. Validate mutually exclusive structured output constraints
    let constraint_count = [
        req.regex.is_some(),
        req.ebnf.is_some(),
        matches!(req.response_format, Some(ResponseFormat::JsonSchema { .. })),
    ]
    .iter()
    .filter(|&&x| x)
    .count();
    if constraint_count > 1 {
        let mut e = validator::ValidationError::new("multiple_constraints");
        e.message = Some("only one structured output constraint (regex, ebnf, or json_schema) can be active at a time".into());
        return Err(e);
    }

    // 6. Validate response format JSON schema name
    if let Some(ResponseFormat::JsonSchema { json_schema }) = &req.response_format {
        if json_schema.name.is_empty() {
            let mut e = validator::ValidationError::new("json_schema_name_empty");
            e.message = Some("JSON schema name cannot be empty".into());
            return Err(e);
        }
    }

    // 7. Validate tool_choice requires tools (except for "none")
    if let Some(ref tool_choice) = req.tool_choice {
        let has_tools = req.tools.as_ref().is_some_and(|t| !t.is_empty());
        // Check if tool_choice is anything other than "none"
        let is_some_choice = !matches!(tool_choice, ToolChoice::Value(ToolChoiceValue::None));
        if is_some_choice && !has_tools {
            let mut e = validator::ValidationError::new("tool_choice_requires_tools");
            e.message = Some("Invalid value for 'tool_choice': 'tool_choice' is only allowed when 'tools' are specified.".into());
            return Err(e);
        }

        // Additional validation when tools are present
        if has_tools {
            let tools = req.tools.as_ref().unwrap();
            match tool_choice {
                ToolChoice::Function { function, .. } => {
                    // Validate that the specified function name exists in tools
                    let function_exists = tools.iter().any(|tool| {
                        tool.tool_type == "function" && tool.function.name == function.name
                    });
                    if !function_exists {
                        let mut e =
                            validator::ValidationError::new("tool_choice_function_not_found");
                        e.message = Some(
                            format!(
                                "Invalid value for 'tool_choice': function '{}' not found in 'tools'.",
                                function.name
                            )
                            .into(),
                        );
                        return Err(e);
                    }
                }
                ToolChoice::AllowedTools {
                    mode,
                    tools: allowed_tools,
                    ..
                } => {
                    // Validate mode is "auto" or "required"
                    if mode != "auto" && mode != "required" {
                        let mut e = validator::ValidationError::new("tool_choice_invalid_mode");
                        e.message = Some(format!(
                            "Invalid value for 'tool_choice.mode': must be 'auto' or 'required', got '{}'.",
                            mode
                        ).into());
                        return Err(e);
                    }
                    // Validate that all referenced tool names exist in tools
                    for tool_ref in allowed_tools {
                        let tool_exists = tools.iter().any(|tool| {
                            tool.tool_type == tool_ref.tool_type
                                && tool.function.name == tool_ref.name
                        });
                        if !tool_exists {
                            let mut e =
                                validator::ValidationError::new("tool_choice_tool_not_found");
                            e.message = Some(format!(
                                "Invalid value for 'tool_choice.tools': tool '{}' not found in 'tools'.",
                                tool_ref.name
                            ).into());
                            return Err(e);
                        }
                    }
                }
                // Plain "auto"/"required" values need no per-tool checks here.
                _ => {}
            }
        }
    }

    Ok(())
}
impl Normalizable for ChatCompletionRequest {
    /// Normalize the request by applying migrations and defaults:
    /// 1. Migrate deprecated fields to their replacements
    /// 2. Clear deprecated fields and log warnings
    /// 3. Apply OpenAI defaults for tool_choice
    ///
    /// Each migration only fires when the replacement field is unset, so a
    /// request supplying both keeps the non-deprecated value.
    fn normalize(&mut self) {
        // Migrate deprecated max_tokens → max_completion_tokens
        #[allow(deprecated)]
        if self.max_completion_tokens.is_none() && self.max_tokens.is_some() {
            tracing::warn!("max_tokens is deprecated, use max_completion_tokens instead");
            self.max_completion_tokens = self.max_tokens;
            self.max_tokens = None; // Clear deprecated field
        }

        // Migrate deprecated functions → tools: each legacy function becomes a
        // "function"-typed Tool wrapping a clone of its definition.
        #[allow(deprecated)]
        if self.tools.is_none() && self.functions.is_some() {
            tracing::warn!("functions is deprecated, use tools instead");
            self.tools = self.functions.as_ref().map(|functions| {
                functions
                    .iter()
                    .map(|func| Tool {
                        tool_type: "function".to_string(),
                        function: func.clone(),
                    })
                    .collect()
            });
            self.functions = None; // Clear deprecated field
        }

        // Migrate deprecated function_call → tool_choice (1:1 variant mapping)
        #[allow(deprecated)]
        if self.tool_choice.is_none() && self.function_call.is_some() {
            tracing::warn!("function_call is deprecated, use tool_choice instead");
            self.tool_choice = self.function_call.as_ref().map(|fc| match fc {
                FunctionCall::None => ToolChoice::Value(ToolChoiceValue::None),
                FunctionCall::Auto => ToolChoice::Value(ToolChoiceValue::Auto),
                FunctionCall::Function { name } => ToolChoice::Function {
                    tool_type: "function".to_string(),
                    function: FunctionChoice { name: name.clone() },
                },
            });
            self.function_call = None; // Clear deprecated field
        }

        // Apply tool_choice defaults: "auto" when tools are present, otherwise
        // "none" (the OpenAI defaults noted in the doc comment above).
        if self.tool_choice.is_none() {
            let has_tools = self.tools.as_ref().is_some_and(|t| !t.is_empty());
            self.tool_choice = if has_tools {
                Some(ToolChoice::Value(ToolChoiceValue::Auto))
            } else {
                Some(ToolChoice::Value(ToolChoiceValue::None))
            };
        }
    }
}
impl GenerationRequest for ChatCompletionRequest {
fn is_stream(&self) -> bool {
self.stream
......@@ -553,6 +858,7 @@ pub struct CompletionRequest {
#[serde(skip_serializing_if = "Option::is_none")]
pub seed: Option<i64>,
// -------- Engine Specific Sampling Parameters --------
/// Top-k sampling parameter (-1 to disable)
#[serde(skip_serializing_if = "Option::is_none")]
pub top_k: Option<i32>,
......@@ -1816,6 +2122,9 @@ pub struct Function {
#[serde(skip_serializing_if = "Option::is_none")]
pub description: Option<String>,
pub parameters: Value, // JSON Schema
/// Whether to enable strict schema adherence (OpenAI structured outputs)
#[serde(skip_serializing_if = "Option::is_none")]
pub strict: Option<bool>,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
......@@ -1911,55 +2220,33 @@ pub enum InputIds {
Batch(Vec<Vec<i32>>),
}
#[derive(Debug, Clone, Deserialize, Serialize, Default)]
pub struct GenerateParameters {
#[serde(skip_serializing_if = "Option::is_none")]
pub best_of: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub decoder_input_details: Option<bool>,
#[serde(skip_serializing_if = "Option::is_none")]
pub details: Option<bool>,
#[serde(skip_serializing_if = "Option::is_none")]
pub do_sample: Option<bool>,
#[serde(skip_serializing_if = "Option::is_none")]
pub max_new_tokens: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub repetition_penalty: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub return_full_text: Option<bool>,
#[serde(skip_serializing_if = "Option::is_none")]
pub seed: Option<u64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub stop: Option<Vec<String>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub temperature: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub top_k: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub top_p: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub truncate: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub typical_p: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub watermark: Option<bool>,
}
#[derive(Debug, Clone, Deserialize, Serialize, Default)]
#[derive(Debug, Clone, Deserialize, Serialize, Default, Validate)]
#[validate(schema(function = "validate_sampling_params"))]
pub struct SamplingParams {
/// Temperature for sampling (must be >= 0.0, no upper limit)
#[serde(skip_serializing_if = "Option::is_none")]
#[validate(range(min = 0.0))]
pub temperature: Option<f32>,
/// Maximum number of new tokens to generate (must be >= 0)
#[serde(skip_serializing_if = "Option::is_none")]
#[validate(range(min = 0))]
pub max_new_tokens: Option<u32>,
/// Top-p nucleus sampling (0.0 < top_p <= 1.0)
#[serde(skip_serializing_if = "Option::is_none")]
#[validate(custom(function = "validate_top_p_value"))]
pub top_p: Option<f32>,
/// Top-k sampling (-1 to disable, or >= 1)
#[serde(skip_serializing_if = "Option::is_none")]
#[validate(custom(function = "validate_top_k_value"))]
pub top_k: Option<i32>,
#[serde(skip_serializing_if = "Option::is_none")]
#[validate(range(min = -2.0, max = 2.0))]
pub frequency_penalty: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
#[validate(range(min = -2.0, max = 2.0))]
pub presence_penalty: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
#[validate(range(min = 0.0, max = 2.0))]
pub repetition_penalty: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub stop: Option<StringOrArray>,
......@@ -1974,9 +2261,11 @@ pub struct SamplingParams {
#[serde(skip_serializing_if = "Option::is_none")]
pub ebnf: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
#[validate(range(min = 0.0, max = 1.0))]
pub min_p: Option<f32>,
/// Minimum number of new tokens (validated in schema function for cross-field check with max_new_tokens)
#[serde(skip_serializing_if = "Option::is_none")]
pub min_tokens: Option<u32>,
pub min_new_tokens: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub stop_token_ids: Option<Vec<u32>>,
#[serde(skip_serializing_if = "Option::is_none")]
......@@ -1987,7 +2276,38 @@ pub struct SamplingParams {
pub sampling_seed: Option<u64>,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
/// Cross-field validation for `SamplingParams`; per-field ranges are enforced
/// by the `#[validate]` attributes on the struct itself.
fn validate_sampling_params(params: &SamplingParams) -> Result<(), validator::ValidationError> {
    // 1. The generation floor must not exceed the ceiling when both are given.
    if let (Some(min), Some(max)) = (params.min_new_tokens, params.max_new_tokens) {
        if min > max {
            return Err(validator::ValidationError::new(
                "min_new_tokens cannot exceed max_new_tokens",
            ));
        }
    }

    // 2. regex / ebnf / json_schema are mutually exclusive output constraints.
    let active_constraints = usize::from(params.regex.is_some())
        + usize::from(params.ebnf.is_some())
        + usize::from(params.json_schema.is_some());
    if active_constraints > 1 {
        return Err(validator::ValidationError::new(
            "only one of regex, ebnf, or json_schema can be set",
        ));
    }

    Ok(())
}
#[derive(Clone, Debug, Serialize, Deserialize, Validate)]
#[validate(schema(function = "validate_generate_request"))]
pub struct GenerateRequest {
/// The prompt to generate from (OpenAI style)
#[serde(skip_serializing_if = "Option::is_none")]
......@@ -2001,10 +2321,6 @@ pub struct GenerateRequest {
#[serde(skip_serializing_if = "Option::is_none")]
pub input_ids: Option<InputIds>,
/// Generation parameters
#[serde(default, skip_serializing_if = "Option::is_none")]
pub parameters: Option<GenerateParameters>,
/// Sampling parameters (sglang style)
#[serde(skip_serializing_if = "Option::is_none")]
pub sampling_params: Option<SamplingParams>,
......@@ -2034,6 +2350,34 @@ pub struct GenerateRequest {
pub rid: Option<String>,
}
impl Normalizable for GenerateRequest {
    // Use default no-op implementation - no normalization needed for GenerateRequest
    // (it has no deprecated fields to migrate and no defaults to backfill).
}
/// Validation for `GenerateRequest`: exactly one input kind must be supplied,
/// either text (`text`/`prompt`) or `input_ids`.
/// Note: input_embeds not yet supported in Rust implementation.
fn validate_generate_request(req: &GenerateRequest) -> Result<(), validator::ValidationError> {
    let has_text = req.text.is_some() || req.prompt.is_some();
    let has_input_ids = req.input_ids.is_some();
    // Equal flags mean either neither or both input kinds were provided —
    // both cases are rejected with the same message.
    if has_text == has_input_ids {
        return Err(validator::ValidationError::new(
            "Either text or input_ids should be provided.",
        ));
    }
    Ok(())
}
impl GenerationRequest for GenerateRequest {
fn is_stream(&self) -> bool {
self.stream
......@@ -2168,7 +2512,7 @@ pub struct RerankRequest {
pub user: Option<String>,
}
fn default_model_name() -> String {
/// Returns the fallback model name used when a request omits `model`.
// NOTE(review): widened to `pub` in this change — presumably referenced from
// outside this module (e.g. serde `default = "..."` attributes); confirm call sites.
pub fn default_model_name() -> String {
    DEFAULT_MODEL_NAME.to_string()
}
......@@ -2441,710 +2785,3 @@ pub enum LoRAPath {
Single(Option<String>),
Batch(Vec<Option<String>>),
}
#[cfg(test)]
mod tests {
use super::*;
use serde_json::{from_str, json, to_string};
#[test]
fn test_rerank_request_serialization() {
    // Round-trip a fully-populated request through serde JSON and compare fields.
    let request = RerankRequest {
        query: "test query".to_string(),
        documents: vec!["doc1".to_string(), "doc2".to_string()],
        model: "test-model".to_string(),
        top_k: Some(5),
        return_documents: true,
        rid: Some(StringOrArray::String("req-123".to_string())),
        user: Some("user-456".to_string()),
    };
    let serialized = to_string(&request).unwrap();
    let deserialized: RerankRequest = from_str(&serialized).unwrap();
    assert_eq!(deserialized.query, request.query);
    assert_eq!(deserialized.documents, request.documents);
    assert_eq!(deserialized.model, request.model);
    assert_eq!(deserialized.top_k, request.top_k);
    assert_eq!(deserialized.return_documents, request.return_documents);
    assert_eq!(deserialized.rid, request.rid);
    assert_eq!(deserialized.user, request.user);
}

#[test]
fn test_rerank_request_deserialization_with_defaults() {
    // Minimal JSON: optional fields must take their serde defaults.
    let json = r#"{
        "query": "test query",
        "documents": ["doc1", "doc2"]
    }"#;
    let request: RerankRequest = from_str(json).unwrap();
    assert_eq!(request.query, "test query");
    assert_eq!(request.documents, vec!["doc1", "doc2"]);
    assert_eq!(request.model, default_model_name());
    assert_eq!(request.top_k, None);
    assert!(request.return_documents);
    assert_eq!(request.rid, None);
    assert_eq!(request.user, None);
}

#[test]
fn test_rerank_request_validation_success() {
    // A well-formed request passes validate().
    let request = RerankRequest {
        query: "valid query".to_string(),
        documents: vec!["doc1".to_string(), "doc2".to_string()],
        model: "test-model".to_string(),
        top_k: Some(2),
        return_documents: true,
        rid: None,
        user: None,
    };
    assert!(request.validate().is_ok());
}

#[test]
fn test_rerank_request_validation_empty_query() {
    // An empty query string is rejected.
    let request = RerankRequest {
        query: "".to_string(),
        documents: vec!["doc1".to_string()],
        model: "test-model".to_string(),
        top_k: None,
        return_documents: true,
        rid: None,
        user: None,
    };
    let result = request.validate();
    assert!(result.is_err());
    assert_eq!(result.unwrap_err(), "Query cannot be empty");
}

#[test]
fn test_rerank_request_validation_whitespace_query() {
    // A whitespace-only query is treated the same as an empty query.
    let request = RerankRequest {
        query: " ".to_string(),
        documents: vec!["doc1".to_string()],
        model: "test-model".to_string(),
        top_k: None,
        return_documents: true,
        rid: None,
        user: None,
    };
    let result = request.validate();
    assert!(result.is_err());
    assert_eq!(result.unwrap_err(), "Query cannot be empty");
}
#[test]
fn test_rerank_request_validation_empty_documents() {
    // An empty documents list is rejected.
    let request = RerankRequest {
        query: "test query".to_string(),
        documents: vec![],
        model: "test-model".to_string(),
        top_k: None,
        return_documents: true,
        rid: None,
        user: None,
    };
    let result = request.validate();
    assert!(result.is_err());
    assert_eq!(result.unwrap_err(), "Documents list cannot be empty");
}

#[test]
fn test_rerank_request_validation_top_k_zero() {
    // top_k == 0 is rejected.
    let request = RerankRequest {
        query: "test query".to_string(),
        documents: vec!["doc1".to_string(), "doc2".to_string()],
        model: "test-model".to_string(),
        top_k: Some(0),
        return_documents: true,
        rid: None,
        user: None,
    };
    let result = request.validate();
    assert!(result.is_err());
    assert_eq!(result.unwrap_err(), "top_k must be greater than 0");
}

#[test]
fn test_rerank_request_validation_top_k_greater_than_docs() {
    // top_k larger than the document count is allowed (not a hard error).
    let request = RerankRequest {
        query: "test query".to_string(),
        documents: vec!["doc1".to_string(), "doc2".to_string()],
        model: "test-model".to_string(),
        top_k: Some(5),
        return_documents: true,
        rid: None,
        user: None,
    };
    // This should pass but log a warning
    assert!(request.validate().is_ok());
}

#[test]
fn test_rerank_request_effective_top_k() {
    // When top_k is set, effective_top_k() returns it.
    let request = RerankRequest {
        query: "test query".to_string(),
        documents: vec!["doc1".to_string(), "doc2".to_string(), "doc3".to_string()],
        model: "test-model".to_string(),
        top_k: Some(2),
        return_documents: true,
        rid: None,
        user: None,
    };
    assert_eq!(request.effective_top_k(), 2);
}

#[test]
fn test_rerank_request_effective_top_k_none() {
    // When top_k is None, effective_top_k() falls back to the document count.
    let request = RerankRequest {
        query: "test query".to_string(),
        documents: vec!["doc1".to_string(), "doc2".to_string(), "doc3".to_string()],
        model: "test-model".to_string(),
        top_k: None,
        return_documents: true,
        rid: None,
        user: None,
    };
    assert_eq!(request.effective_top_k(), 3);
}
#[test]
fn test_rerank_response_creation() {
    // RerankResponse::new fills metadata (object tag, created timestamp) itself.
    let results = vec![
        RerankResult {
            score: 0.8,
            document: Some("doc1".to_string()),
            index: 0,
            meta_info: None,
        },
        RerankResult {
            score: 0.6,
            document: Some("doc2".to_string()),
            index: 1,
            meta_info: None,
        },
    ];
    let response = RerankResponse::new(
        results.clone(),
        "test-model".to_string(),
        Some(StringOrArray::String("req-123".to_string())),
    );
    assert_eq!(response.results.len(), 2);
    assert_eq!(response.model, "test-model");
    assert_eq!(
        response.id,
        Some(StringOrArray::String("req-123".to_string()))
    );
    assert_eq!(response.object, "rerank");
    assert!(response.created > 0);
}

#[test]
fn test_rerank_response_serialization() {
    // Responses survive a serde JSON round-trip intact.
    let results = vec![RerankResult {
        score: 0.8,
        document: Some("doc1".to_string()),
        index: 0,
        meta_info: None,
    }];
    let response = RerankResponse::new(
        results,
        "test-model".to_string(),
        Some(StringOrArray::String("req-123".to_string())),
    );
    let serialized = to_string(&response).unwrap();
    let deserialized: RerankResponse = from_str(&serialized).unwrap();
    assert_eq!(deserialized.results.len(), response.results.len());
    assert_eq!(deserialized.model, response.model);
    assert_eq!(deserialized.id, response.id);
    assert_eq!(deserialized.object, response.object);
}

#[test]
fn test_rerank_response_sort_by_score() {
    // sort_by_score() orders results descending by score; original indices
    // travel with their results.
    let results = vec![
        RerankResult {
            score: 0.6,
            document: Some("doc2".to_string()),
            index: 1,
            meta_info: None,
        },
        RerankResult {
            score: 0.8,
            document: Some("doc1".to_string()),
            index: 0,
            meta_info: None,
        },
        RerankResult {
            score: 0.4,
            document: Some("doc3".to_string()),
            index: 2,
            meta_info: None,
        },
    ];
    let mut response = RerankResponse::new(
        results,
        "test-model".to_string(),
        Some(StringOrArray::String("req-123".to_string())),
    );
    response.sort_by_score();
    assert_eq!(response.results[0].score, 0.8);
    assert_eq!(response.results[0].index, 0);
    assert_eq!(response.results[1].score, 0.6);
    assert_eq!(response.results[1].index, 1);
    assert_eq!(response.results[2].score, 0.4);
    assert_eq!(response.results[2].index, 2);
}
#[test]
fn test_rerank_response_apply_top_k() {
let results = vec![
RerankResult {
score: 0.8,
document: Some("doc1".to_string()),
index: 0,
meta_info: None,
},
RerankResult {
score: 0.6,
document: Some("doc2".to_string()),
index: 1,
meta_info: None,
},
RerankResult {
score: 0.4,
document: Some("doc3".to_string()),
index: 2,
meta_info: None,
},
];
let mut response = RerankResponse::new(
results,
"test-model".to_string(),
Some(StringOrArray::String("req-123".to_string())),
);
response.apply_top_k(2);
assert_eq!(response.results.len(), 2);
assert_eq!(response.results[0].score, 0.8);
assert_eq!(response.results[1].score, 0.6);
}
#[test]
fn test_rerank_response_apply_top_k_larger_than_results() {
let results = vec![RerankResult {
score: 0.8,
document: Some("doc1".to_string()),
index: 0,
meta_info: None,
}];
let mut response = RerankResponse::new(
results,
"test-model".to_string(),
Some(StringOrArray::String("req-123".to_string())),
);
response.apply_top_k(5);
assert_eq!(response.results.len(), 1);
}
#[test]
fn test_rerank_response_drop_documents() {
let results = vec![RerankResult {
score: 0.8,
document: Some("doc1".to_string()),
index: 0,
meta_info: None,
}];
let mut response = RerankResponse::new(
results,
"test-model".to_string(),
Some(StringOrArray::String("req-123".to_string())),
);
response.drop_documents();
assert_eq!(response.results[0].document, None);
}
#[test]
fn test_rerank_result_serialization() {
let result = RerankResult {
score: 0.85,
document: Some("test document".to_string()),
index: 42,
meta_info: Some(HashMap::from([
("confidence".to_string(), Value::String("high".to_string())),
(
"processing_time".to_string(),
Value::Number(Number::from(150)),
),
])),
};
let serialized = to_string(&result).unwrap();
let deserialized: RerankResult = from_str(&serialized).unwrap();
assert_eq!(deserialized.score, result.score);
assert_eq!(deserialized.document, result.document);
assert_eq!(deserialized.index, result.index);
assert_eq!(deserialized.meta_info, result.meta_info);
}
#[test]
fn test_rerank_result_serialization_without_document() {
let result = RerankResult {
score: 0.85,
document: None,
index: 42,
meta_info: None,
};
let serialized = to_string(&result).unwrap();
let deserialized: RerankResult = from_str(&serialized).unwrap();
assert_eq!(deserialized.score, result.score);
assert_eq!(deserialized.document, result.document);
assert_eq!(deserialized.index, result.index);
assert_eq!(deserialized.meta_info, result.meta_info);
}
#[test]
fn test_v1_rerank_req_input_serialization() {
let v1_input = V1RerankReqInput {
query: "test query".to_string(),
documents: vec!["doc1".to_string(), "doc2".to_string()],
};
let serialized = to_string(&v1_input).unwrap();
let deserialized: V1RerankReqInput = from_str(&serialized).unwrap();
assert_eq!(deserialized.query, v1_input.query);
assert_eq!(deserialized.documents, v1_input.documents);
}
#[test]
fn test_v1_to_rerank_request_conversion() {
let v1_input = V1RerankReqInput {
query: "test query".to_string(),
documents: vec!["doc1".to_string(), "doc2".to_string()],
};
let request: RerankRequest = v1_input.into();
assert_eq!(request.query, "test query");
assert_eq!(request.documents, vec!["doc1", "doc2"]);
assert_eq!(request.model, default_model_name());
assert_eq!(request.top_k, None);
assert!(request.return_documents);
assert_eq!(request.rid, None);
assert_eq!(request.user, None);
}
#[test]
fn test_rerank_request_generation_request_trait() {
let request = RerankRequest {
query: "test query".to_string(),
documents: vec!["doc1".to_string()],
model: "test-model".to_string(),
top_k: None,
return_documents: true,
rid: None,
user: None,
};
assert_eq!(request.get_model(), Some("test-model"));
assert!(!request.is_stream());
assert_eq!(request.extract_text_for_routing(), "test query");
}
#[test]
fn test_rerank_request_very_long_query() {
let long_query = "a".repeat(100000);
let request = RerankRequest {
query: long_query,
documents: vec!["doc1".to_string()],
model: "test-model".to_string(),
top_k: None,
return_documents: true,
rid: None,
user: None,
};
assert!(request.validate().is_ok());
}
#[test]
fn test_rerank_request_many_documents() {
let documents: Vec<String> = (0..1000).map(|i| format!("doc{}", i)).collect();
let request = RerankRequest {
query: "test query".to_string(),
documents,
model: "test-model".to_string(),
top_k: Some(100),
return_documents: true,
rid: None,
user: None,
};
assert!(request.validate().is_ok());
assert_eq!(request.effective_top_k(), 100);
}
#[test]
fn test_rerank_request_special_characters() {
let request = RerankRequest {
query: "query with émojis 🚀 and unicode: 测试".to_string(),
documents: vec![
"doc with émojis 🎉".to_string(),
"doc with unicode: 测试".to_string(),
],
model: "test-model".to_string(),
top_k: None,
return_documents: true,
rid: Some(StringOrArray::String("req-🚀-123".to_string())),
user: Some("user-🎉-456".to_string()),
};
assert!(request.validate().is_ok());
}
#[test]
fn test_rerank_request_rid_array() {
let request = RerankRequest {
query: "test query".to_string(),
documents: vec!["doc1".to_string()],
model: "test-model".to_string(),
top_k: None,
return_documents: true,
rid: Some(StringOrArray::Array(vec![
"req1".to_string(),
"req2".to_string(),
])),
user: None,
};
assert!(request.validate().is_ok());
}
#[test]
fn test_rerank_response_with_usage_info() {
let results = vec![RerankResult {
score: 0.8,
document: Some("doc1".to_string()),
index: 0,
meta_info: None,
}];
let mut response = RerankResponse::new(
results,
"test-model".to_string(),
Some(StringOrArray::String("req-123".to_string())),
);
response.usage = Some(UsageInfo {
prompt_tokens: 100,
completion_tokens: 50,
total_tokens: 150,
reasoning_tokens: None,
prompt_tokens_details: None,
});
let serialized = to_string(&response).unwrap();
let deserialized: RerankResponse = from_str(&serialized).unwrap();
assert!(deserialized.usage.is_some());
let usage = deserialized.usage.unwrap();
assert_eq!(usage.prompt_tokens, 100);
assert_eq!(usage.completion_tokens, 50);
assert_eq!(usage.total_tokens, 150);
}
#[test]
fn test_full_rerank_workflow() {
// Create request
let request = RerankRequest {
query: "machine learning".to_string(),
documents: vec![
"Introduction to machine learning algorithms".to_string(),
"Deep learning for computer vision".to_string(),
"Natural language processing basics".to_string(),
"Statistics and probability theory".to_string(),
],
model: "rerank-model".to_string(),
top_k: Some(2),
return_documents: true,
rid: Some(StringOrArray::String("req-123".to_string())),
user: Some("user-456".to_string()),
};
// Validate request
assert!(request.validate().is_ok());
// Simulate reranking results (in real scenario, this would come from the model)
let results = vec![
RerankResult {
score: 0.95,
document: Some("Introduction to machine learning algorithms".to_string()),
index: 0,
meta_info: None,
},
RerankResult {
score: 0.87,
document: Some("Deep learning for computer vision".to_string()),
index: 1,
meta_info: None,
},
RerankResult {
score: 0.72,
document: Some("Natural language processing basics".to_string()),
index: 2,
meta_info: None,
},
RerankResult {
score: 0.45,
document: Some("Statistics and probability theory".to_string()),
index: 3,
meta_info: None,
},
];
// Create response
let mut response = RerankResponse::new(results, request.model.clone(), request.rid.clone());
// Sort by score
response.sort_by_score();
// Apply top_k
response.apply_top_k(request.effective_top_k());
assert_eq!(response.results.len(), 2);
assert_eq!(response.results[0].score, 0.95);
assert_eq!(response.results[0].index, 0);
assert_eq!(response.results[1].score, 0.87);
assert_eq!(response.results[1].index, 1);
assert_eq!(response.model, "rerank-model");
// Serialize and deserialize
let serialized = to_string(&response).unwrap();
let deserialized: RerankResponse = from_str(&serialized).unwrap();
assert_eq!(deserialized.results.len(), 2);
assert_eq!(deserialized.model, response.model);
}
#[test]
fn test_embedding_request_serialization_string_input() {
let req = EmbeddingRequest {
model: "test-emb".to_string(),
input: Value::String("hello".to_string()),
encoding_format: Some("float".to_string()),
user: Some("user-1".to_string()),
dimensions: Some(128),
rid: Some("rid-123".to_string()),
};
let serialized = to_string(&req).unwrap();
let deserialized: EmbeddingRequest = from_str(&serialized).unwrap();
assert_eq!(deserialized.model, req.model);
assert_eq!(deserialized.input, req.input);
assert_eq!(deserialized.encoding_format, req.encoding_format);
assert_eq!(deserialized.user, req.user);
assert_eq!(deserialized.dimensions, req.dimensions);
assert_eq!(deserialized.rid, req.rid);
}
#[test]
fn test_embedding_request_serialization_array_input() {
let req = EmbeddingRequest {
model: "test-emb".to_string(),
input: json!(["a", "b", "c"]),
encoding_format: None,
user: None,
dimensions: None,
rid: None,
};
let serialized = to_string(&req).unwrap();
let de: EmbeddingRequest = from_str(&serialized).unwrap();
assert_eq!(de.model, req.model);
assert_eq!(de.input, req.input);
}
#[test]
fn test_embedding_generation_request_trait_string() {
let req = EmbeddingRequest {
model: "emb-model".to_string(),
input: Value::String("hello".to_string()),
encoding_format: None,
user: None,
dimensions: None,
rid: None,
};
assert!(!req.is_stream());
assert_eq!(req.get_model(), Some("emb-model"));
assert_eq!(req.extract_text_for_routing(), "hello");
}
#[test]
fn test_embedding_generation_request_trait_array() {
let req = EmbeddingRequest {
model: "emb-model".to_string(),
input: json!(["hello", "world"]),
encoding_format: None,
user: None,
dimensions: None,
rid: None,
};
assert_eq!(req.extract_text_for_routing(), "hello world");
}
#[test]
fn test_embedding_generation_request_trait_non_text() {
let req = EmbeddingRequest {
model: "emb-model".to_string(),
input: json!({"tokens": [1, 2, 3]}),
encoding_format: None,
user: None,
dimensions: None,
rid: None,
};
assert_eq!(req.extract_text_for_routing(), "");
}
#[test]
fn test_embedding_generation_request_trait_mixed_array_ignores_nested() {
let req = EmbeddingRequest {
model: "emb-model".to_string(),
input: json!(["a", ["b", "c"], 123, {"k": "v"}]),
encoding_format: None,
user: None,
dimensions: None,
rid: None,
};
// Only top-level string elements are extracted
assert_eq!(req.extract_text_for_routing(), "a");
}
}
// Validated JSON extractor for automatic request validation
//
// This module provides a ValidatedJson extractor that automatically validates
// requests using the validator crate's Validate trait.
use axum::{
extract::{rejection::JsonRejection, FromRequest, Request},
http::StatusCode,
response::{IntoResponse, Response},
Json,
};
use serde::de::DeserializeOwned;
use serde_json::json;
use validator::Validate;
/// Trait for request types that need post-deserialization normalization.
///
/// Implementors may override [`Normalizable::normalize`] to fill in defaults
/// or rewrite fields once the body has been deserialized; the provided
/// default implementation does nothing.
pub trait Normalizable {
    /// Normalize the request by applying defaults and transformations.
    fn normalize(&mut self) {}
}
/// A JSON extractor that automatically validates and normalizes the request body
///
/// This extractor deserializes the request body, calls `.normalize()` on the
/// result, and then calls `.validate()` on types that implement the
/// `Validate` trait. If deserialization or validation fails, it returns
/// a 400 Bad Request with an OpenAI-style JSON error body.
///
/// # Example
///
/// ```rust,ignore
/// async fn create_chat(
///     ValidatedJson(request): ValidatedJson<ChatCompletionRequest>,
/// ) -> Response {
///     // request is guaranteed to be valid here
///     process_request(request).await
/// }
/// ```
pub struct ValidatedJson<T>(pub T);
// Axum extractor plumbing: deserialize -> normalize -> validate, rejecting
// with a 400 + OpenAI-style error body at the first failing stage.
impl<S, T> FromRequest<S> for ValidatedJson<T>
where
    T: DeserializeOwned + Validate + Normalizable + Send,
    S: Send + Sync,
{
    // Rejections are fully-built HTTP responses so handlers never see them.
    type Rejection = Response;
    async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
        // First, extract and deserialize the JSON
        let Json(mut data) =
            Json::<T>::from_request(req, state)
                .await
                .map_err(|err: JsonRejection| {
                    // Map each Json rejection variant to a client-friendly message.
                    let error_message = match err {
                        JsonRejection::JsonDataError(e) => {
                            format!("Invalid JSON data: {}", e)
                        }
                        JsonRejection::JsonSyntaxError(e) => {
                            format!("JSON syntax error: {}", e)
                        }
                        JsonRejection::MissingJsonContentType(_) => {
                            "Missing Content-Type: application/json header".to_string()
                        }
                        // JsonRejection is #[non_exhaustive]; fall back generically.
                        _ => format!("Failed to parse JSON: {}", err),
                    };
                    (
                        StatusCode::BAD_REQUEST,
                        Json(json!({
                            "error": {
                                "message": error_message,
                                "type": "invalid_request_error",
                                "code": "json_parse_error"
                            }
                        })),
                    )
                        .into_response()
                })?;
        // Normalize the request (apply defaults based on other fields)
        data.normalize();
        // Then, automatically validate the data
        data.validate().map_err(|validation_errors| {
            // Extract the first error message from the validation errors.
            // NOTE(review): field_errors() only surfaces field-level errors;
            // struct-level validation errors would fall through to the
            // generic "Validation failed" message — confirm that is intended.
            let error_message = validation_errors
                .field_errors()
                .values()
                .flat_map(|errors| errors.iter())
                .find_map(|e| e.message.as_ref())
                .map(|m| m.to_string())
                .unwrap_or_else(|| "Validation failed".to_string());
            (
                StatusCode::BAD_REQUEST,
                Json(json!({
                    "error": {
                        "message": error_message,
                        "type": "invalid_request_error",
                        // NOTE(review): this "code" is a number while the JSON
                        // parse branch above uses a string code — confirm
                        // whether clients expect a consistent type here.
                        "code": 400
                    }
                })),
            )
                .into_response()
        })?;
        Ok(ValidatedJson(data))
    }
}
// Implement Deref to allow transparent access to the inner value
// (handlers can call methods on the wrapped request without destructuring).
impl<T> std::ops::Deref for ValidatedJson<T> {
    type Target = T;
    fn deref(&self) -> &Self::Target {
        &self.0
    }
}
// Mutable counterpart of the Deref impl above.
impl<T> std::ops::DerefMut for ValidatedJson<T> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.0
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde::{Deserialize, Serialize};
    use validator::Validate;

    /// Minimal request type exercising both a range and a length validator.
    #[derive(Debug, Deserialize, Serialize, Validate)]
    struct TestRequest {
        #[validate(range(min = 0.0, max = 1.0))]
        value: f32,
        #[validate(length(min = 1))]
        name: String,
    }

    impl Normalizable for TestRequest {
        // Use default no-op implementation
    }

    // These tests exercise the derived `Validate` impl directly; driving the
    // full extractor would require an Axum test harness. They are synchronous
    // assertions, so a plain #[test] suffices — no tokio runtime needed.

    /// A request inside all constraints passes validation.
    #[test]
    fn test_validated_json_valid() {
        let request = TestRequest {
            value: 0.5,
            name: "test".to_string(),
        };
        assert!(request.validate().is_ok());
    }

    /// A `value` above the allowed range is rejected.
    #[test]
    fn test_validated_json_invalid_range() {
        let request = TestRequest {
            value: 1.5, // Out of range
            name: "test".to_string(),
        };
        assert!(request.validate().is_err());
    }

    /// An empty `name` violates the length(min = 1) rule.
    #[test]
    fn test_validated_json_invalid_length() {
        let request = TestRequest {
            value: 0.5,
            name: "".to_string(), // Empty name
        };
        assert!(request.validate().is_err());
    }
}
// Core validation infrastructure for API parameter validation
use anyhow::Result;
use serde::{Deserialize, Serialize};
use std::fmt::Display;
// Import types from spec module
use crate::protocols::spec::{
ChatCompletionRequest, ChatMessage, ResponseFormat, StringOrArray, UserMessageContent,
};
/// Validation constants for OpenAI API parameters
///
/// All ranges are `(min, max)` tuples consumed by `utils::validate_range`.
pub mod constants {
    /// Temperature range: 0.0 to 2.0 (OpenAI spec)
    pub const TEMPERATURE_RANGE: (f32, f32) = (0.0, 2.0);
    /// Top-p range: 0.0 to 1.0 (exclusive of 0.0)
    ///
    /// NOTE(review): `utils::validate_range` treats both bounds as inclusive,
    /// so top_p == 0.0 currently passes — confirm whether the "exclusive of
    /// 0.0" wording should be enforced.
    pub const TOP_P_RANGE: (f32, f32) = (0.0, 1.0);
    /// Presence penalty range: -2.0 to 2.0 (OpenAI spec)
    pub const PRESENCE_PENALTY_RANGE: (f32, f32) = (-2.0, 2.0);
    /// Frequency penalty range: -2.0 to 2.0 (OpenAI spec)
    pub const FREQUENCY_PENALTY_RANGE: (f32, f32) = (-2.0, 2.0);
    /// Logprobs range for completions API: 0 to 5
    pub const LOGPROBS_RANGE: (u32, u32) = (0, 5);
    /// Top logprobs range for chat completions: 0 to 20
    pub const TOP_LOGPROBS_RANGE: (u32, u32) = (0, 20);
    /// Maximum number of stop sequences allowed
    pub const MAX_STOP_SEQUENCES: usize = 4;
    /// SGLang-specific validation constants
    pub mod sglang {
        /// Min-p range: 0.0 to 1.0 (SGLang extension)
        pub const MIN_P_RANGE: (f32, f32) = (0.0, 1.0);
        /// Top-k minimum value: -1 to disable, otherwise positive
        pub const TOP_K_MIN: i32 = -1;
        /// Repetition penalty range: 0.0 to 2.0 (SGLang extension)
        /// 1.0 = no penalty, >1.0 = discourage repetition, <1.0 = encourage repetition
        pub const REPETITION_PENALTY_RANGE: (f32, f32) = (0.0, 2.0);
    }
}
/// Core validation error types
///
/// Serializable so failures can be embedded directly in API error responses;
/// rendered for humans via the `Display` impl below.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ValidationError {
    /// Parameter value out of valid range
    /// (bounds are stored as strings so one variant covers all numeric types)
    OutOfRange {
        parameter: String,
        value: String,
        min: String,
        max: String,
    },
    /// Invalid parameter value format or type
    InvalidValue {
        parameter: String,
        value: String,
        reason: String,
    },
    /// Cross-parameter validation failure
    ConflictingParameters {
        parameter1: String,
        parameter2: String,
        reason: String,
    },
    /// Required parameter missing
    MissingRequired { parameter: String },
    /// Too many items in array parameter
    TooManyItems {
        parameter: String,
        count: usize,
        max: usize,
    },
    /// Custom validation error (free-form message, e.g. mutual-exclusion checks)
    Custom(String),
}
impl Display for ValidationError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
ValidationError::OutOfRange {
parameter,
value,
min,
max,
} => {
write!(
f,
"Parameter '{}' must be between {} and {}, got {}",
parameter, min, max, value
)
}
ValidationError::InvalidValue {
parameter,
value,
reason,
} => {
write!(
f,
"Invalid value for parameter '{}': {} ({})",
parameter, value, reason
)
}
ValidationError::ConflictingParameters {
parameter1,
parameter2,
reason,
} => {
write!(
f,
"Conflicting parameters '{}' and '{}': {}",
parameter1, parameter2, reason
)
}
ValidationError::MissingRequired { parameter } => {
write!(f, "Required parameter '{}' is missing", parameter)
}
ValidationError::TooManyItems {
parameter,
count,
max,
} => {
write!(
f,
"Parameter '{}' has too many items: {} (maximum: {})",
parameter, count, max
)
}
ValidationError::Custom(msg) => write!(f, "{}", msg),
}
}
}
// Marker impl so ValidationError composes with `?`, anyhow, and Box<dyn Error>.
impl std::error::Error for ValidationError {}
/// Core validation utility functions
///
/// Free functions shared by all request types; each returns the first
/// `ValidationError` encountered so callers can chain them with `?`.
pub mod utils {
    use super::*;

    /// Check that `value` lies inside the inclusive `range`, returning it on success.
    pub fn validate_range<T>(
        value: T,
        range: &(T, T),
        param_name: &str,
    ) -> Result<T, ValidationError>
    where
        T: PartialOrd + Display + Copy,
    {
        let (min, max) = *range;
        // Positive condition kept as-is so NaN (which fails both comparisons)
        // is still rejected.
        if value >= min && value <= max {
            return Ok(value);
        }
        Err(ValidationError::OutOfRange {
            parameter: param_name.to_string(),
            value: value.to_string(),
            min: min.to_string(),
            max: max.to_string(),
        })
    }

    /// Ensure a numeric value is strictly greater than its type's default (zero).
    pub fn validate_positive<T>(value: T, param_name: &str) -> Result<T, ValidationError>
    where
        T: PartialOrd + Display + Copy + Default,
    {
        match value > T::default() {
            true => Ok(value),
            false => Err(ValidationError::InvalidValue {
                parameter: param_name.to_string(),
                value: value.to_string(),
                reason: "must be positive".to_string(),
            }),
        }
    }

    /// Reject slices longer than `max_count`.
    pub fn validate_max_items<T>(
        items: &[T],
        max_count: usize,
        param_name: &str,
    ) -> Result<(), ValidationError> {
        if items.len() > max_count {
            return Err(ValidationError::TooManyItems {
                parameter: param_name.to_string(),
                count: items.len(),
                max: max_count,
            });
        }
        Ok(())
    }

    /// Borrow the inner value of a required `Option`, erroring when absent.
    pub fn validate_required<'a, T>(
        value: &'a Option<T>,
        param_name: &str,
    ) -> Result<&'a T, ValidationError> {
        match value {
            Some(inner) => Ok(inner),
            None => Err(ValidationError::MissingRequired {
                parameter: param_name.to_string(),
            }),
        }
    }

    /// Validate top_k (SGLang extension): -1 means "disabled", otherwise it
    /// must be strictly positive.
    pub fn validate_top_k(top_k: i32) -> Result<i32, ValidationError> {
        let disabled = top_k == constants::sglang::TOP_K_MIN;
        if disabled || top_k > 0 {
            Ok(top_k)
        } else {
            Err(ValidationError::InvalidValue {
                parameter: "top_k".to_string(),
                value: top_k.to_string(),
                reason: "must be -1 (disabled) or positive".to_string(),
            })
        }
    }

    /// Range-check every sampling parameter the request actually supplies.
    pub fn validate_sampling_options<T: SamplingOptionsProvider + ?Sized>(
        request: &T,
    ) -> Result<(), ValidationError> {
        // Table-driven: (value, allowed range, parameter name), checked in order.
        let checks = [
            (
                request.get_temperature(),
                &constants::TEMPERATURE_RANGE,
                "temperature",
            ),
            (request.get_top_p(), &constants::TOP_P_RANGE, "top_p"),
            (
                request.get_frequency_penalty(),
                &constants::FREQUENCY_PENALTY_RANGE,
                "frequency_penalty",
            ),
            (
                request.get_presence_penalty(),
                &constants::PRESENCE_PENALTY_RANGE,
                "presence_penalty",
            ),
        ];
        for (maybe_value, range, name) in checks {
            if let Some(v) = maybe_value {
                validate_range(v, range, name)?;
            }
        }
        Ok(())
    }

    /// Validate stop sequences: at most MAX_STOP_SEQUENCES entries and no
    /// empty strings in either the scalar or array form.
    pub fn validate_stop_conditions<T: StopConditionsProvider + ?Sized>(
        request: &T,
    ) -> Result<(), ValidationError> {
        // Shared error constructor for the empty-string case.
        let empty_err = |param: String| ValidationError::InvalidValue {
            parameter: param,
            value: "empty string".to_string(),
            reason: "stop sequences cannot be empty".to_string(),
        };
        match request.get_stop_sequences() {
            Some(StringOrArray::String(s)) if s.is_empty() => {
                Err(empty_err("stop".to_string()))
            }
            Some(StringOrArray::Array(arr)) => {
                validate_max_items(arr, constants::MAX_STOP_SEQUENCES, "stop")?;
                for (i, s) in arr.iter().enumerate() {
                    if s.is_empty() {
                        return Err(empty_err(format!("stop[{}]", i)));
                    }
                }
                Ok(())
            }
            _ => Ok(()),
        }
    }

    /// Both token limits, when present, must be strictly positive.
    pub fn validate_token_limits<T: TokenLimitsProvider + ?Sized>(
        request: &T,
    ) -> Result<(), ValidationError> {
        for (limit, name) in [
            (request.get_max_tokens(), "max_tokens"),
            (request.get_min_tokens(), "min_tokens"),
        ] {
            if let Some(v) = limit {
                validate_positive(v, name)?;
            }
        }
        Ok(())
    }

    /// Range-check logprobs (completions: 0-5) and top_logprobs (chat: 0-20).
    pub fn validate_logprobs<T: LogProbsProvider + ?Sized>(
        request: &T,
    ) -> Result<(), ValidationError> {
        if let Some(lp) = request.get_logprobs() {
            validate_range(lp, &constants::LOGPROBS_RANGE, "logprobs")?;
        }
        if let Some(tlp) = request.get_top_logprobs() {
            validate_range(tlp, &constants::TOP_LOGPROBS_RANGE, "top_logprobs")?;
        }
        Ok(())
    }

    /// Cross-parameter rule: min_tokens must not exceed max_tokens when both
    /// are supplied.
    pub fn validate_cross_parameters<T: TokenLimitsProvider + ?Sized>(
        request: &T,
    ) -> Result<(), ValidationError> {
        if let (Some(lo), Some(hi)) = (request.get_min_tokens(), request.get_max_tokens()) {
            if lo > hi {
                return Err(ValidationError::ConflictingParameters {
                    parameter1: "min_tokens".to_string(),
                    parameter2: "max_tokens".to_string(),
                    reason: format!(
                        "min_tokens ({}) cannot be greater than max_tokens ({})",
                        lo, hi
                    ),
                });
            }
        }
        Ok(())
    }

    /// Error when two mutually-exclusive boolean flags are both set.
    pub fn validate_conflicting_parameters(
        param1_name: &str,
        param1_value: bool,
        param2_name: &str,
        param2_value: bool,
        reason: &str,
    ) -> Result<(), ValidationError> {
        match (param1_value, param2_value) {
            (true, true) => Err(ValidationError::ConflictingParameters {
                parameter1: param1_name.to_string(),
                parameter2: param2_name.to_string(),
                reason: reason.to_string(),
            }),
            _ => Ok(()),
        }
    }

    /// At most one of the named options may be active.
    pub fn validate_mutually_exclusive_options(
        options: &[(&str, bool)],
        error_msg: &str,
    ) -> Result<(), ValidationError> {
        let active = options.iter().filter(|(_, is_active)| *is_active).count();
        if active <= 1 {
            Ok(())
        } else {
            Err(ValidationError::Custom(error_msg.to_string()))
        }
    }

    /// Range-check the SGLang-specific knobs (top_k, min_p, repetition_penalty).
    pub fn validate_sglang_extensions<T: SGLangExtensionsProvider + ?Sized>(
        request: &T,
    ) -> Result<(), ValidationError> {
        if let Some(k) = request.get_top_k() {
            validate_top_k(k)?;
        }
        if let Some(p) = request.get_min_p() {
            validate_range(p, &constants::sglang::MIN_P_RANGE, "min_p")?;
        }
        if let Some(rp) = request.get_repetition_penalty() {
            validate_range(
                rp,
                &constants::sglang::REPETITION_PENALTY_RANGE,
                "repetition_penalty",
            )?;
        }
        Ok(())
    }

    /// Validate n (number of completions): 1 to 10 when supplied.
    pub fn validate_completion_count<T: CompletionCountProvider + ?Sized>(
        request: &T,
    ) -> Result<(), ValidationError> {
        const N_RANGE: (u32, u32) = (1, 10);
        if let Some(n) = request.get_n() {
            validate_range(n, &N_RANGE, "n")?;
        }
        Ok(())
    }

    /// An empty slice for a required array parameter is treated as missing.
    pub fn validate_non_empty_array<T>(
        items: &[T],
        param_name: &str,
    ) -> Result<(), ValidationError> {
        if items.is_empty() {
            Err(ValidationError::MissingRequired {
                parameter: param_name.to_string(),
            })
        } else {
            Ok(())
        }
    }

    /// Run every shared check in a fixed order, stopping at the first failure.
    pub fn validate_common_request_params<T>(request: &T) -> Result<(), ValidationError>
    where
        T: SamplingOptionsProvider
            + StopConditionsProvider
            + TokenLimitsProvider
            + LogProbsProvider
            + SGLangExtensionsProvider
            + CompletionCountProvider
            + ?Sized,
    {
        // Standard OpenAI parameters first...
        validate_sampling_options(request)?;
        validate_stop_conditions(request)?;
        validate_token_limits(request)?;
        validate_logprobs(request)?;
        // ...then SGLang extensions and completion count...
        validate_sglang_extensions(request)?;
        validate_completion_count(request)?;
        // ...and finally relationships between parameters.
        validate_cross_parameters(request)?;
        Ok(())
    }
}
/// Core validation traits for different parameter categories
///
/// Each accessor returns `None` when the client did not supply the
/// parameter, in which case its range check is skipped.
pub trait SamplingOptionsProvider {
    /// Get temperature parameter
    fn get_temperature(&self) -> Option<f32>;
    /// Get top_p parameter
    fn get_top_p(&self) -> Option<f32>;
    /// Get frequency penalty parameter
    fn get_frequency_penalty(&self) -> Option<f32>;
    /// Get presence penalty parameter
    fn get_presence_penalty(&self) -> Option<f32>;
}
/// Trait for validating stop conditions
pub trait StopConditionsProvider {
    /// Get stop sequences (a single string or an array of strings), if any
    fn get_stop_sequences(&self) -> Option<&StringOrArray>;
}
/// Trait for validating token limits
pub trait TokenLimitsProvider {
    /// Get maximum tokens parameter
    fn get_max_tokens(&self) -> Option<u32>;
    /// Get minimum tokens parameter (SGLang extension)
    fn get_min_tokens(&self) -> Option<u32>;
}
/// Trait for validating logprobs parameters
pub trait LogProbsProvider {
    /// Get logprobs parameter (completions API, numeric 0-5)
    fn get_logprobs(&self) -> Option<u32>;
    /// Get top_logprobs parameter (chat API, numeric 0-20)
    fn get_top_logprobs(&self) -> Option<u32>;
}
/// Trait for SGLang-specific extensions
///
/// All methods default to `None` so request types without these extensions
/// can implement the trait with an empty impl block.
pub trait SGLangExtensionsProvider {
    /// Get top_k parameter
    fn get_top_k(&self) -> Option<i32> {
        None
    }
    /// Get min_p parameter
    fn get_min_p(&self) -> Option<f32> {
        None
    }
    /// Get repetition_penalty parameter
    fn get_repetition_penalty(&self) -> Option<f32> {
        None
    }
}
/// Trait for n parameter (number of completions)
///
/// Defaults to `None` for request types that do not support multiple choices.
pub trait CompletionCountProvider {
    /// Get n parameter
    fn get_n(&self) -> Option<u32> {
        None
    }
}
/// Comprehensive validation trait that combines all validation aspects
///
/// Any type implementing every provider trait gets a default `validate`
/// that runs the shared checks in `utils::validate_common_request_params`;
/// request types may override it to append type-specific rules.
pub trait ValidatableRequest:
    SamplingOptionsProvider
    + StopConditionsProvider
    + TokenLimitsProvider
    + LogProbsProvider
    + SGLangExtensionsProvider
    + CompletionCountProvider
{
    /// Perform comprehensive validation of the entire request
    fn validate(&self) -> Result<(), ValidationError> {
        // Use the common validation function
        utils::validate_common_request_params(self)
    }
}
// Sampling-parameter accessors for chat requests: straight field pass-throughs.
impl SamplingOptionsProvider for ChatCompletionRequest {
    fn get_temperature(&self) -> Option<f32> {
        self.temperature
    }
    fn get_top_p(&self) -> Option<f32> {
        self.top_p
    }
    fn get_frequency_penalty(&self) -> Option<f32> {
        self.frequency_penalty
    }
    fn get_presence_penalty(&self) -> Option<f32> {
        self.presence_penalty
    }
}
// Stop-sequence accessor: borrows the request's optional `stop` field.
impl StopConditionsProvider for ChatCompletionRequest {
    fn get_stop_sequences(&self) -> Option<&StringOrArray> {
        self.stop.as_ref()
    }
}
// Token-limit accessors; `max_tokens` is deprecated upstream in favor of
// `max_completion_tokens`, hence the allow(deprecated).
impl TokenLimitsProvider for ChatCompletionRequest {
    #[allow(deprecated)]
    fn get_max_tokens(&self) -> Option<u32> {
        // Prefer max_completion_tokens over max_tokens if both are set
        self.max_completion_tokens.or(self.max_tokens)
    }
    fn get_min_tokens(&self) -> Option<u32> {
        self.min_tokens
    }
}
// Logprobs accessors adapted to the numeric validation API.
impl LogProbsProvider for ChatCompletionRequest {
    /// The chat API models `logprobs` as a bool; report it as 1 when enabled
    /// so the shared range check can run, and None otherwise.
    fn get_logprobs(&self) -> Option<u32> {
        self.logprobs.then_some(1)
    }

    fn get_top_logprobs(&self) -> Option<u32> {
        self.top_logprobs
    }
}
// SGLang extension accessors: straight field pass-throughs.
impl SGLangExtensionsProvider for ChatCompletionRequest {
    fn get_top_k(&self) -> Option<i32> {
        self.top_k
    }
    fn get_min_p(&self) -> Option<f32> {
        self.min_p
    }
    fn get_repetition_penalty(&self) -> Option<f32> {
        self.repetition_penalty
    }
}
// Number-of-choices accessor: straight field pass-through.
impl CompletionCountProvider for ChatCompletionRequest {
    fn get_n(&self) -> Option<u32> {
        self.n
    }
}
impl ChatCompletionRequest {
    /// Validate message-specific requirements.
    ///
    /// Rejects an empty `messages` array, and any user message whose text
    /// content or content-part list is empty.
    pub fn validate_messages(&self) -> Result<(), ValidationError> {
        utils::validate_non_empty_array(&self.messages, "messages")?;
        for (i, msg) in self.messages.iter().enumerate() {
            // Only user-message content is checked here.
            let ChatMessage::User { content, .. } = msg else {
                continue;
            };
            let problem = match content {
                UserMessageContent::Text(text) if text.is_empty() => {
                    Some(("empty", "message content cannot be empty"))
                }
                UserMessageContent::Parts(parts) if parts.is_empty() => {
                    Some(("empty array", "message content parts cannot be empty"))
                }
                _ => None,
            };
            if let Some((value, reason)) = problem {
                return Err(ValidationError::InvalidValue {
                    parameter: format!("messages[{}].content", i),
                    value: value.to_string(),
                    reason: reason.to_string(),
                });
            }
        }
        Ok(())
    }

    /// Validate response format if specified: a JSON schema must carry a name.
    pub fn validate_response_format(&self) -> Result<(), ValidationError> {
        match &self.response_format {
            Some(ResponseFormat::JsonSchema { json_schema }) if json_schema.name.is_empty() => {
                Err(ValidationError::InvalidValue {
                    parameter: "response_format.json_schema.name".to_string(),
                    value: "empty".to_string(),
                    reason: "JSON schema name cannot be empty".to_string(),
                })
            }
            _ => Ok(()),
        }
    }

    /// Chat-specific logprobs rule (OpenAI): `top_logprobs` requires
    /// `logprobs == true`, while `logprobs == true` alone is fine.
    pub fn validate_chat_logprobs(&self) -> Result<(), ValidationError> {
        if let Some(top) = self.top_logprobs {
            if !self.logprobs {
                return Err(ValidationError::InvalidValue {
                    parameter: "top_logprobs".to_string(),
                    value: top.to_string(),
                    reason: "top_logprobs is only allowed when logprobs is enabled".to_string(),
                });
            }
        }
        Ok(())
    }

    /// Validate cross-parameter relationships specific to chat completions.
    #[allow(deprecated)]
    pub fn validate_chat_cross_parameters(&self) -> Result<(), ValidationError> {
        // max_tokens (deprecated) and max_completion_tokens are mutually exclusive.
        utils::validate_conflicting_parameters(
            "max_tokens",
            self.max_tokens.is_some(),
            "max_completion_tokens",
            self.max_completion_tokens.is_some(),
            "cannot specify both max_tokens and max_completion_tokens",
        )?;
        // Same for the deprecated `functions` field versus `tools`.
        utils::validate_conflicting_parameters(
            "tools",
            self.tools.is_some(),
            "functions",
            self.functions.is_some(),
            "functions is deprecated, use tools instead",
        )?;
        // Grammar constraints cannot be combined with a JSON response format;
        // regex is checked before ebnf, matching the error precedence.
        let has_json_format = matches!(
            self.response_format,
            Some(ResponseFormat::JsonObject | ResponseFormat::JsonSchema { .. })
        );
        for (name, present, reason) in [
            (
                "regex",
                self.regex.is_some(),
                "cannot use regex constraint with JSON response format",
            ),
            (
                "ebnf",
                self.ebnf.is_some(),
                "cannot use EBNF constraint with JSON response format",
            ),
        ] {
            utils::validate_conflicting_parameters(
                "response_format",
                has_json_format,
                name,
                present,
                reason,
            )?;
        }
        // At most one structured-output mechanism may be active.
        let structured_constraints = [
            ("regex", self.regex.is_some()),
            ("ebnf", self.ebnf.is_some()),
            (
                "json_schema",
                matches!(
                    self.response_format,
                    Some(ResponseFormat::JsonSchema { .. })
                ),
            ),
        ];
        utils::validate_mutually_exclusive_options(
            &structured_constraints,
            "Only one structured output constraint (regex, ebnf, or json_schema) can be active at a time",
        )?;
        Ok(())
    }
}
impl ValidatableRequest for ChatCompletionRequest {
    /// Run the shared parameter checks first, then the chat-specific
    /// invariants, short-circuiting on the first failure.
    fn validate(&self) -> Result<(), ValidationError> {
        utils::validate_common_request_params(self)?;
        self.validate_messages()
            .and_then(|_| self.validate_response_format())
            .and_then(|_| self.validate_chat_logprobs())
            .and_then(|_| self.validate_chat_cross_parameters())
    }
}
#[cfg(test)]
mod tests {
    use super::constants::*;
    use super::utils::*;
    use super::*;
    use crate::protocols::spec::StringOrArray;

    // Mock request type for testing validation traits.
    // Implements just enough of the provider traits for the blanket
    // ValidatableRequest impl to run; knobs the tests never exercise
    // report None.
    #[derive(Debug, Default)]
    struct MockRequest {
        temperature: Option<f32>,
        stop: Option<StringOrArray>,
        max_tokens: Option<u32>,
        min_tokens: Option<u32>,
    }

    impl SamplingOptionsProvider for MockRequest {
        fn get_temperature(&self) -> Option<f32> {
            self.temperature
        }
        fn get_top_p(&self) -> Option<f32> {
            None
        }
        fn get_frequency_penalty(&self) -> Option<f32> {
            None
        }
        fn get_presence_penalty(&self) -> Option<f32> {
            None
        }
    }

    impl StopConditionsProvider for MockRequest {
        fn get_stop_sequences(&self) -> Option<&StringOrArray> {
            self.stop.as_ref()
        }
    }

    impl TokenLimitsProvider for MockRequest {
        fn get_max_tokens(&self) -> Option<u32> {
            self.max_tokens
        }
        fn get_min_tokens(&self) -> Option<u32> {
            self.min_tokens
        }
    }

    impl LogProbsProvider for MockRequest {
        fn get_logprobs(&self) -> Option<u32> {
            None
        }
        fn get_top_logprobs(&self) -> Option<u32> {
            None
        }
    }

    // Marker trait impls with no overridden methods; defaults apply.
    impl SGLangExtensionsProvider for MockRequest {}
    impl CompletionCountProvider for MockRequest {}
    impl ValidatableRequest for MockRequest {}

    // Boundary checks for the generic range validator.
    #[test]
    fn test_range_validation() {
        // Valid range
        assert!(validate_range(1.5f32, &TEMPERATURE_RANGE, "temperature").is_ok());
        // Invalid range
        assert!(validate_range(-0.1f32, &TEMPERATURE_RANGE, "temperature").is_err());
        assert!(validate_range(3.0f32, &TEMPERATURE_RANGE, "temperature").is_err());
    }

    // top_k accepts -1 (disabled) or any positive value; zero and other
    // negatives must be rejected.
    #[test]
    fn test_sglang_top_k_validation() {
        assert!(validate_top_k(-1).is_ok()); // Disabled
        assert!(validate_top_k(50).is_ok()); // Valid positive
        assert!(validate_top_k(0).is_err()); // Invalid
        assert!(validate_top_k(-5).is_err()); // Invalid
    }

    // More than four stop sequences must fail validation.
    #[test]
    fn test_stop_sequences_limits() {
        let request = MockRequest {
            stop: Some(StringOrArray::Array(vec![
                "stop1".to_string(),
                "stop2".to_string(),
                "stop3".to_string(),
                "stop4".to_string(),
                "stop5".to_string(), // Too many
            ])),
            ..Default::default()
        };
        assert!(request.validate().is_err());
    }

    // min_tokens greater than max_tokens is an invalid combination.
    #[test]
    fn test_token_limits_conflict() {
        let request = MockRequest {
            min_tokens: Some(100),
            max_tokens: Some(50), // min > max
            ..Default::default()
        };
        assert!(request.validate().is_err());
    }

    // A fully in-range request should validate cleanly.
    #[test]
    fn test_valid_request() {
        let request = MockRequest {
            temperature: Some(1.0),
            stop: Some(StringOrArray::Array(vec!["stop".to_string()])),
            max_tokens: Some(100),
            min_tokens: Some(10),
        };
        assert!(request.validate().is_ok());
    }

    // Chat completion specific tests
    #[cfg(test)]
    mod chat_tests {
        use super::*;

        // Baseline valid chat request; individual tests mutate single
        // fields from here to isolate each validation rule.
        #[allow(deprecated)]
        fn create_valid_chat_request() -> ChatCompletionRequest {
            ChatCompletionRequest {
                messages: vec![ChatMessage::User {
                    role: "user".to_string(),
                    content: UserMessageContent::Text("Hello".to_string()),
                    name: None,
                }],
                model: "gpt-4".to_string(),
                // Set specific fields we want to test
                temperature: Some(1.0),
                top_p: Some(0.9),
                n: Some(1),
                max_tokens: Some(100),
                frequency_penalty: Some(0.0),
                presence_penalty: Some(0.0),
                // Use default for all other fields
                ..Default::default()
            }
        }

        // Baseline passes; empty messages and out-of-range temperature fail.
        #[test]
        fn test_chat_validation_basics() {
            // Valid request
            assert!(create_valid_chat_request().validate().is_ok());
            // Empty messages
            let mut request = create_valid_chat_request();
            request.messages = vec![];
            assert!(request.validate().is_err());
            // Invalid temperature
            let mut request = create_valid_chat_request();
            request.temperature = Some(3.0);
            assert!(request.validate().is_err());
        }

        // Conflicting parameter pairs and the logprobs dependency rule.
        #[test]
        #[allow(deprecated)]
        fn test_chat_cross_parameter_conflicts() {
            let mut request = create_valid_chat_request();
            request.max_tokens = Some(100);
            request.max_completion_tokens = Some(200);
            assert!(
                request.validate().is_err(),
                "Should reject both max_tokens and max_completion_tokens"
            );
            // Reset for next test
            request.max_tokens = None;
            request.max_completion_tokens = None;
            request.tools = Some(vec![]);
            request.functions = Some(vec![]);
            assert!(
                request.validate().is_err(),
                "Should reject both tools and functions"
            );
            let mut request = create_valid_chat_request();
            request.logprobs = true;
            request.top_logprobs = None;
            assert!(
                request.validate().is_ok(),
                "logprobs=true without top_logprobs should be valid"
            );
            let mut request = create_valid_chat_request();
            request.logprobs = false;
            request.top_logprobs = Some(5);
            assert!(
                request.validate().is_err(),
                "top_logprobs without logprobs=true should fail"
            );
        }

        // SGLang-only sampling knobs: top_k / min_p / repetition_penalty.
        #[test]
        fn test_sglang_extensions() {
            let mut request = create_valid_chat_request();
            // Valid SGLang parameters
            request.top_k = Some(-1);
            request.min_p = Some(0.1);
            request.repetition_penalty = Some(1.2);
            assert!(request.validate().is_ok());
            // Invalid parameters
            request.top_k = Some(0); // Invalid
            assert!(request.validate().is_err());
        }

        // Sweep each sampling parameter through in-range and out-of-range
        // values; each parameter is reset to a valid value before the next
        // group so failures are attributable to a single field.
        #[test]
        fn test_parameter_ranges() {
            let mut request = create_valid_chat_request();
            request.temperature = Some(1.5);
            assert!(request.validate().is_ok());
            request.temperature = Some(-0.1);
            assert!(request.validate().is_err());
            request.temperature = Some(3.0);
            assert!(request.validate().is_err());
            request.temperature = Some(1.0); // Reset
            request.top_p = Some(0.9);
            assert!(request.validate().is_ok());
            request.top_p = Some(-0.1);
            assert!(request.validate().is_err());
            request.top_p = Some(1.5);
            assert!(request.validate().is_err());
            request.top_p = Some(0.9); // Reset
            request.frequency_penalty = Some(1.5);
            assert!(request.validate().is_ok());
            request.frequency_penalty = Some(-2.5);
            assert!(request.validate().is_err());
            request.frequency_penalty = Some(3.0);
            assert!(request.validate().is_err());
            request.frequency_penalty = Some(0.0); // Reset
            request.presence_penalty = Some(-1.5);
            assert!(request.validate().is_ok());
            request.presence_penalty = Some(-3.0);
            assert!(request.validate().is_err());
            request.presence_penalty = Some(2.5);
            assert!(request.validate().is_err());
            request.presence_penalty = Some(0.0); // Reset
            request.repetition_penalty = Some(1.2);
            assert!(request.validate().is_ok());
            request.repetition_penalty = Some(-0.1);
            assert!(request.validate().is_err());
            request.repetition_penalty = Some(2.1);
            assert!(request.validate().is_err());
            request.repetition_penalty = Some(1.0); // Reset
            request.min_p = Some(0.5);
            assert!(request.validate().is_ok());
            request.min_p = Some(-0.1);
            assert!(request.validate().is_err());
            request.min_p = Some(1.5);
            assert!(request.validate().is_err());
        }

        // Mutual exclusion between response_format and grammar constraints,
        // and among the structured-output constraints themselves.
        #[test]
        fn test_structured_output_conflicts() {
            let mut request = create_valid_chat_request();
            // JSON response format with regex should conflict
            request.response_format = Some(ResponseFormat::JsonObject);
            request.regex = Some(".*".to_string());
            assert!(request.validate().is_err());
            // JSON response format with EBNF should conflict
            request.regex = None;
            request.ebnf = Some("grammar".to_string());
            assert!(request.validate().is_err());
            // Multiple structured constraints should conflict
            request.response_format = None;
            request.regex = Some(".*".to_string());
            request.ebnf = Some("grammar".to_string());
            assert!(request.validate().is_err());
            // Only one constraint should work
            request.ebnf = None;
            request.regex = Some(".*".to_string());
            assert!(request.validate().is_ok());
            request.regex = None;
            request.ebnf = Some("grammar".to_string());
            assert!(request.validate().is_ok());
            request.ebnf = None;
            request.response_format = Some(ResponseFormat::JsonObject);
            assert!(request.validate().is_ok());
        }

        // Stop sequence count and non-empty content rules.
        #[test]
        fn test_stop_sequences_validation() {
            let mut request = create_valid_chat_request();
            // Valid stop sequences
            request.stop = Some(StringOrArray::Array(vec![
                "stop1".to_string(),
                "stop2".to_string(),
            ]));
            assert!(request.validate().is_ok());
            // Too many stop sequences (max 4)
            request.stop = Some(StringOrArray::Array(vec![
                "stop1".to_string(),
                "stop2".to_string(),
                "stop3".to_string(),
                "stop4".to_string(),
                "stop5".to_string(),
            ]));
            assert!(request.validate().is_err());
            // Empty stop sequence should fail
            request.stop = Some(StringOrArray::String("".to_string()));
            assert!(request.validate().is_err());
            // Empty string in array should fail
            request.stop = Some(StringOrArray::Array(vec![
                "stop1".to_string(),
                "".to_string(),
            ]));
            assert!(request.validate().is_err());
        }

        // logprobs / top_logprobs dependency and range rules.
        #[test]
        fn test_logprobs_validation() {
            let mut request = create_valid_chat_request();
            // Valid logprobs configuration with top_logprobs
            request.logprobs = true;
            request.top_logprobs = Some(10);
            assert!(request.validate().is_ok());
            // logprobs=true without top_logprobs should be valid (OpenAI behavior)
            request.top_logprobs = None;
            assert!(
                request.validate().is_ok(),
                "logprobs=true without top_logprobs should be valid"
            );
            // top_logprobs without logprobs=true should fail
            request.logprobs = false;
            request.top_logprobs = Some(10);
            assert!(request.validate().is_err());
            // top_logprobs out of range (0-20)
            request.logprobs = true;
            request.top_logprobs = Some(25);
            assert!(request.validate().is_err());
        }

        // n (completion count) must be within 1..=10.
        #[test]
        fn test_n_parameter_validation() {
            let mut request = create_valid_chat_request();
            // Valid n values (1-10)
            request.n = Some(1);
            assert!(request.validate().is_ok());
            request.n = Some(5);
            assert!(request.validate().is_ok());
            request.n = Some(10);
            assert!(request.validate().is_ok());
            // Invalid n values
            request.n = Some(0);
            assert!(request.validate().is_err());
            request.n = Some(15);
            assert!(request.validate().is_err());
        }

        // min_tokens is checked against whichever of max_tokens /
        // max_completion_tokens is set.
        #[test]
        #[allow(deprecated)]
        fn test_min_max_tokens_validation() {
            let mut request = create_valid_chat_request();
            // Valid token limits
            request.min_tokens = Some(10);
            request.max_tokens = Some(100);
            assert!(request.validate().is_ok());
            // min_tokens > max_tokens should fail
            request.min_tokens = Some(150);
            request.max_tokens = Some(100);
            assert!(request.validate().is_err());
            // Should work with max_completion_tokens instead
            request.max_tokens = None;
            request.max_completion_tokens = Some(200);
            request.min_tokens = Some(50);
            assert!(request.validate().is_ok());
            // min_tokens > max_completion_tokens should fail
            request.min_tokens = Some(250);
            assert!(request.validate().is_err());
        }
    }
}
......@@ -959,7 +959,6 @@ mod tests {
#[test]
fn test_transform_messages_string_format() {
let messages = vec![ChatMessage::User {
role: "user".to_string(),
content: UserMessageContent::Parts(vec![
ContentPart::Text {
text: "Hello".to_string(),
......@@ -993,7 +992,6 @@ mod tests {
#[test]
fn test_transform_messages_openai_format() {
let messages = vec![ChatMessage::User {
role: "user".to_string(),
content: UserMessageContent::Parts(vec![
ContentPart::Text {
text: "Describe this image:".to_string(),
......@@ -1028,7 +1026,6 @@ mod tests {
#[test]
fn test_transform_messages_simple_string_content() {
let messages = vec![ChatMessage::User {
role: "user".to_string(),
content: UserMessageContent::Text("Simple text message".to_string()),
name: None,
}];
......@@ -1049,12 +1046,10 @@ mod tests {
fn test_transform_messages_multiple_messages() {
let messages = vec![
ChatMessage::System {
role: "system".to_string(),
content: "System prompt".to_string(),
name: None,
},
ChatMessage::User {
role: "user".to_string(),
content: UserMessageContent::Parts(vec![
ContentPart::Text {
text: "User message".to_string(),
......@@ -1086,7 +1081,6 @@ mod tests {
#[test]
fn test_transform_messages_empty_text_parts() {
let messages = vec![ChatMessage::User {
role: "user".to_string(),
content: UserMessageContent::Parts(vec![ContentPart::ImageUrl {
image_url: ImageUrl {
url: "https://example.com/image.jpg".to_string(),
......@@ -1109,12 +1103,10 @@ mod tests {
fn test_transform_messages_mixed_content_types() {
let messages = vec![
ChatMessage::User {
role: "user".to_string(),
content: UserMessageContent::Text("Plain text".to_string()),
name: None,
},
ChatMessage::User {
role: "user".to_string(),
content: UserMessageContent::Parts(vec![
ContentPart::Text {
text: "With image".to_string(),
......
......@@ -16,6 +16,7 @@ use crate::{
ChatCompletionRequest, CompletionRequest, EmbeddingRequest, GenerateRequest,
RerankRequest, ResponsesGetParams, ResponsesRequest, V1RerankReqInput,
},
validated::ValidatedJson,
worker_spec::{WorkerApiResponse, WorkerConfigRequest, WorkerErrorResponse},
},
reasoning_parser::ParserFactory as ReasoningParserFactory,
......@@ -291,7 +292,7 @@ async fn generate(
async fn v1_chat_completions(
State(state): State<Arc<AppState>>,
headers: http::HeaderMap,
Json(body): Json<ChatCompletionRequest>,
ValidatedJson(body): ValidatedJson<ChatCompletionRequest>,
) -> Response {
state.router.route_chat(Some(&headers), &body, None).await
}
......
......@@ -1461,39 +1461,6 @@ mod error_tests {
ctx.shutdown().await;
}
#[tokio::test]
async fn test_missing_required_fields() {
let ctx = TestContext::new(vec![MockWorkerConfig {
port: 18405,
worker_type: WorkerType::Regular,
health_status: HealthStatus::Healthy,
response_delay_ms: 0,
fail_rate: 0.0,
}])
.await;
let app = ctx.create_app().await;
// Missing messages in chat completion
let payload = json!({
"model": "test-model"
// missing "messages"
});
let req = Request::builder()
.method("POST")
.uri("/v1/chat/completions")
.header(CONTENT_TYPE, "application/json")
.body(Body::from(serde_json::to_string(&payload).unwrap()))
.unwrap();
let resp = app.oneshot(req).await.unwrap();
// Axum validates JSON schema - returns 422 for validation errors
assert_eq!(resp.status(), StatusCode::UNPROCESSABLE_ENTITY);
ctx.shutdown().await;
}
#[tokio::test]
async fn test_invalid_model() {
let ctx = TestContext::new(vec![MockWorkerConfig {
......
......@@ -172,14 +172,12 @@ assistant:
let processor = ChatTemplateProcessor::new(template.to_string());
let messages = vec![
let messages = [
spec::ChatMessage::System {
role: "system".to_string(),
content: "You are helpful".to_string(),
name: None,
},
spec::ChatMessage::User {
role: "user".to_string(),
content: spec::UserMessageContent::Text("Hello".to_string()),
name: None,
},
......@@ -216,7 +214,6 @@ fn test_chat_template_with_tokens_unit_test() {
let processor = ChatTemplateProcessor::new(template.to_string());
let messages = [spec::ChatMessage::User {
role: "user".to_string(),
content: spec::UserMessageContent::Text("Test".to_string()),
name: None,
}];
......
......@@ -18,7 +18,6 @@ fn test_simple_chat_template() {
let processor = ChatTemplateProcessor::new(template.to_string());
let messages = [spec::ChatMessage::User {
role: "user".to_string(),
content: spec::UserMessageContent::Text("Test".to_string()),
name: None,
}];
......@@ -53,7 +52,6 @@ fn test_chat_template_with_tokens() {
let processor = ChatTemplateProcessor::new(template.to_string());
let messages = [spec::ChatMessage::User {
role: "user".to_string(),
content: spec::UserMessageContent::Text("Test".to_string()),
name: None,
}];
......@@ -113,14 +111,12 @@ fn test_llama_style_template() {
let processor = ChatTemplateProcessor::new(template.to_string());
let messages = vec![
let messages = [
spec::ChatMessage::System {
role: "system".to_string(),
content: "You are a helpful assistant".to_string(),
name: None,
},
spec::ChatMessage::User {
role: "user".to_string(),
content: spec::UserMessageContent::Text("What is 2+2?".to_string()),
name: None,
},
......@@ -172,19 +168,16 @@ fn test_chatml_template() {
let messages = vec![
spec::ChatMessage::User {
role: "user".to_string(),
content: spec::UserMessageContent::Text("Hello".to_string()),
name: None,
},
spec::ChatMessage::Assistant {
role: "assistant".to_string(),
content: Some("Hi there!".to_string()),
name: None,
tool_calls: None,
reasoning_content: None,
},
spec::ChatMessage::User {
role: "user".to_string(),
content: spec::UserMessageContent::Text("How are you?".to_string()),
name: None,
},
......@@ -227,7 +220,6 @@ assistant:
let processor = ChatTemplateProcessor::new(template.to_string());
let messages = [spec::ChatMessage::User {
role: "user".to_string(),
content: spec::UserMessageContent::Text("Test".to_string()),
name: None,
}];
......@@ -315,7 +307,6 @@ fn test_template_with_multimodal_content() {
let processor = ChatTemplateProcessor::new(template.to_string());
let messages = [spec::ChatMessage::User {
role: "user".to_string(),
content: spec::UserMessageContent::Parts(vec![
spec::ContentPart::Text {
text: "Look at this:".to_string(),
......
......@@ -57,14 +57,12 @@ mod tests {
)
.unwrap();
let messages = vec![
let messages = [
spec::ChatMessage::User {
role: "user".to_string(),
content: spec::UserMessageContent::Text("Hello".to_string()),
name: None,
},
spec::ChatMessage::Assistant {
role: "assistant".to_string(),
content: Some("Hi there".to_string()),
name: None,
tool_calls: None,
......@@ -143,7 +141,6 @@ mod tests {
.unwrap();
let messages = [spec::ChatMessage::User {
role: "user".to_string(),
content: spec::UserMessageContent::Text("Test".to_string()),
name: None,
}];
......@@ -201,14 +198,12 @@ mod tests {
"NEW: {% for msg in messages %}{{ msg.role }}: {{ msg.content }}; {% endfor %}";
tokenizer.set_chat_template(new_template.to_string());
let messages = vec![
let messages = [
spec::ChatMessage::User {
role: "user".to_string(),
content: spec::UserMessageContent::Text("Hello".to_string()),
name: None,
},
spec::ChatMessage::Assistant {
role: "assistant".to_string(),
content: Some("World".to_string()),
name: None,
tool_calls: None,
......
......@@ -119,6 +119,7 @@ pub fn create_test_tools() -> Vec<Tool> {
"query": {"type": "string"}
}
}),
strict: None,
},
},
Tool {
......@@ -135,6 +136,7 @@ pub fn create_test_tools() -> Vec<Tool> {
"units": {"type": "string"}
}
}),
strict: None,
},
},
Tool {
......@@ -149,6 +151,7 @@ pub fn create_test_tools() -> Vec<Tool> {
"y": {"type": "number"}
}
}),
strict: None,
},
},
Tool {
......@@ -164,6 +167,7 @@ pub fn create_test_tools() -> Vec<Tool> {
"target_lang": {"type": "string"}
}
}),
strict: None,
},
},
Tool {
......@@ -178,6 +182,7 @@ pub fn create_test_tools() -> Vec<Tool> {
"format": {"type": "string"}
}
}),
strict: None,
},
},
Tool {
......@@ -192,6 +197,7 @@ pub fn create_test_tools() -> Vec<Tool> {
"format": {"type": "string"}
}
}),
strict: None,
},
},
Tool {
......@@ -206,6 +212,7 @@ pub fn create_test_tools() -> Vec<Tool> {
"notifications": {"type": "boolean"}
}
}),
strict: None,
},
},
Tool {
......@@ -214,6 +221,7 @@ pub fn create_test_tools() -> Vec<Tool> {
name: "ping".to_string(),
description: Some("Ping service".to_string()),
parameters: json!({"type": "object", "properties": {}}),
strict: None,
},
},
Tool {
......@@ -222,6 +230,7 @@ pub fn create_test_tools() -> Vec<Tool> {
name: "test".to_string(),
description: Some("Test function".to_string()),
parameters: json!({"type": "object", "properties": {}}),
strict: None,
},
},
Tool {
......@@ -239,6 +248,7 @@ pub fn create_test_tools() -> Vec<Tool> {
"text": {"type": "string"}
}
}),
strict: None,
},
},
Tool {
......@@ -254,6 +264,7 @@ pub fn create_test_tools() -> Vec<Tool> {
"search_type": {"type": "string"}
}
}),
strict: None,
},
},
Tool {
......@@ -267,6 +278,7 @@ pub fn create_test_tools() -> Vec<Tool> {
"city": {"type": "string"}
}
}),
strict: None,
},
},
Tool {
......@@ -282,6 +294,7 @@ pub fn create_test_tools() -> Vec<Tool> {
"optional": {"type": "null"}
}
}),
strict: None,
},
},
Tool {
......@@ -297,6 +310,7 @@ pub fn create_test_tools() -> Vec<Tool> {
"none_val": {"type": "null"}
}
}),
strict: None,
},
},
Tool {
......@@ -311,6 +325,7 @@ pub fn create_test_tools() -> Vec<Tool> {
"email": {"type": "string"}
}
}),
strict: None,
},
},
Tool {
......@@ -325,6 +340,7 @@ pub fn create_test_tools() -> Vec<Tool> {
"y": {"type": "number"}
}
}),
strict: None,
},
},
Tool {
......@@ -338,6 +354,7 @@ pub fn create_test_tools() -> Vec<Tool> {
"x": {"type": "number"}
}
}),
strict: None,
},
},
Tool {
......@@ -346,6 +363,7 @@ pub fn create_test_tools() -> Vec<Tool> {
name: "func1".to_string(),
description: Some("Function 1".to_string()),
parameters: json!({"type": "object", "properties": {}}),
strict: None,
},
},
Tool {
......@@ -359,6 +377,7 @@ pub fn create_test_tools() -> Vec<Tool> {
"y": {"type": "number"}
}
}),
strict: None,
},
},
Tool {
......@@ -367,6 +386,7 @@ pub fn create_test_tools() -> Vec<Tool> {
name: "tool1".to_string(),
description: Some("Tool 1".to_string()),
parameters: json!({"type": "object", "properties": {}}),
strict: None,
},
},
Tool {
......@@ -380,6 +400,7 @@ pub fn create_test_tools() -> Vec<Tool> {
"y": {"type": "number"}
}
}),
strict: None,
},
},
]
......
use serde_json::json;
use sglang_router_rs::protocols::spec::{
ChatCompletionRequest, ChatMessage, Function, FunctionCall, FunctionChoice, StreamOptions,
Tool, ToolChoice, ToolChoiceValue, ToolReference, UserMessageContent,
};
use sglang_router_rs::protocols::validated::Normalizable;
use validator::Validate;
// Deprecated fields normalization tests
/// max_tokens (deprecated) should be moved into max_completion_tokens by
/// normalize(), leaving the deprecated field empty and the request valid.
#[test]
fn test_max_tokens_normalizes_to_max_completion_tokens() {
    let user_message = ChatMessage::User {
        content: UserMessageContent::Text("hello".to_string()),
        name: None,
    };
    #[allow(deprecated)]
    let mut request = ChatCompletionRequest {
        model: "test-model".to_string(),
        messages: vec![user_message],
        max_tokens: Some(100),
        max_completion_tokens: None,
        ..Default::default()
    };
    request.normalize();
    assert_eq!(
        request.max_completion_tokens,
        Some(100),
        "max_tokens should be copied to max_completion_tokens"
    );
    #[allow(deprecated)]
    {
        assert!(
            request.max_tokens.is_none(),
            "Deprecated field should be cleared"
        );
    }
    assert!(
        request.validate().is_ok(),
        "Should be valid after normalization"
    );
}
/// When both token-limit fields are supplied, normalize() must keep the
/// non-deprecated max_completion_tokens value.
#[test]
fn test_max_completion_tokens_takes_precedence() {
    let user_message = ChatMessage::User {
        content: UserMessageContent::Text("hello".to_string()),
        name: None,
    };
    #[allow(deprecated)]
    let mut request = ChatCompletionRequest {
        model: "test-model".to_string(),
        messages: vec![user_message],
        max_tokens: Some(100),
        max_completion_tokens: Some(200),
        ..Default::default()
    };
    request.normalize();
    assert_eq!(
        request.max_completion_tokens,
        Some(200),
        "max_completion_tokens should take precedence"
    );
    assert!(
        request.validate().is_ok(),
        "Should be valid after normalization"
    );
}
/// The deprecated `functions` list should be migrated into `tools` by
/// normalize(), preserving the function definition.
#[test]
fn test_functions_normalizes_to_tools() {
    let legacy_function = Function {
        name: "test_func".to_string(),
        description: Some("Test function".to_string()),
        parameters: json!({}),
        strict: None,
    };
    #[allow(deprecated)]
    let mut request = ChatCompletionRequest {
        model: "test-model".to_string(),
        messages: vec![ChatMessage::User {
            content: UserMessageContent::Text("hello".to_string()),
            name: None,
        }],
        functions: Some(vec![legacy_function]),
        tools: None,
        ..Default::default()
    };
    request.normalize();
    // Migration should produce exactly one tool carrying the old function.
    let tools = request
        .tools
        .as_ref()
        .expect("functions should be migrated to tools");
    assert_eq!(tools.len(), 1);
    assert_eq!(tools[0].function.name, "test_func");
    #[allow(deprecated)]
    {
        assert!(
            request.functions.is_none(),
            "Deprecated field should be cleared"
        );
    }
    assert!(
        request.validate().is_ok(),
        "Should be valid after normalization"
    );
}
/// function_call's None variant should become tool_choice "none" after
/// normalization, with the deprecated field cleared.
#[test]
fn test_function_call_normalizes_to_tool_choice() {
    let user_message = ChatMessage::User {
        content: UserMessageContent::Text("hello".to_string()),
        name: None,
    };
    #[allow(deprecated)]
    let mut request = ChatCompletionRequest {
        model: "test-model".to_string(),
        messages: vec![user_message],
        function_call: Some(FunctionCall::None),
        tool_choice: None,
        ..Default::default()
    };
    request.normalize();
    assert!(
        request.tool_choice.is_some(),
        "function_call should be migrated to tool_choice"
    );
    assert!(matches!(
        request.tool_choice,
        Some(ToolChoice::Value(ToolChoiceValue::None))
    ));
    #[allow(deprecated)]
    {
        assert!(
            request.function_call.is_none(),
            "Deprecated field should be cleared"
        );
    }
    assert!(
        request.validate().is_ok(),
        "Should be valid after normalization"
    );
}
/// function_call naming a specific function should map to the
/// ToolChoice::Function variant referencing the same function name.
#[test]
fn test_function_call_function_variant_normalizes() {
    let declared_tool = Tool {
        tool_type: "function".to_string(),
        function: Function {
            name: "my_function".to_string(),
            description: None,
            parameters: json!({}),
            strict: None,
        },
    };
    #[allow(deprecated)]
    let mut request = ChatCompletionRequest {
        model: "test-model".to_string(),
        messages: vec![ChatMessage::User {
            content: UserMessageContent::Text("hello".to_string()),
            name: None,
        }],
        function_call: Some(FunctionCall::Function {
            name: "my_function".to_string(),
        }),
        tool_choice: None,
        tools: Some(vec![declared_tool]),
        ..Default::default()
    };
    request.normalize();
    assert!(
        request.tool_choice.is_some(),
        "function_call should be migrated to tool_choice"
    );
    // The migrated choice must target the same function by name.
    match &request.tool_choice {
        Some(ToolChoice::Function { function, .. }) => {
            assert_eq!(function.name, "my_function");
        }
        _ => panic!("Expected ToolChoice::Function variant"),
    }
    #[allow(deprecated)]
    {
        assert!(
            request.function_call.is_none(),
            "Deprecated field should be cleared"
        );
    }
    assert!(
        request.validate().is_ok(),
        "Should be valid after normalization"
    );
}
// Stream options validation tests
/// stream_options without stream=true must be rejected, and the error
/// message must point at the dependency.
#[test]
fn test_stream_options_requires_stream_enabled() {
    let request = ChatCompletionRequest {
        model: "test-model".to_string(),
        messages: vec![ChatMessage::User {
            content: UserMessageContent::Text("hello".to_string()),
            name: None,
        }],
        stream: false,
        stream_options: Some(StreamOptions {
            include_usage: Some(true),
        }),
        ..Default::default()
    };
    let outcome = request.validate();
    assert!(
        outcome.is_err(),
        "Should reject stream_options when stream is false"
    );
    let message = outcome.unwrap_err().to_string();
    assert!(
        message.contains("stream_options")
            && message.contains("stream")
            && message.contains("enabled"),
        "Error should mention stream dependency: {}",
        message
    );
}
/// stream_options combined with stream=true is a valid configuration.
#[test]
fn test_stream_options_valid_when_stream_enabled() {
    let request = ChatCompletionRequest {
        model: "test-model".to_string(),
        messages: vec![ChatMessage::User {
            content: UserMessageContent::Text("hello".to_string()),
            name: None,
        }],
        stream: true,
        stream_options: Some(StreamOptions {
            include_usage: Some(true),
        }),
        ..Default::default()
    };
    assert!(
        request.validate().is_ok(),
        "Should accept stream_options when stream is true"
    );
}
/// Omitting stream_options while stream=false is perfectly valid.
#[test]
fn test_no_stream_options_valid_when_stream_disabled() {
    let request = ChatCompletionRequest {
        model: "test-model".to_string(),
        messages: vec![ChatMessage::User {
            content: UserMessageContent::Text("hello".to_string()),
            name: None,
        }],
        stream: false,
        stream_options: None,
        ..Default::default()
    };
    assert!(
        request.validate().is_ok(),
        "Should accept no stream_options when stream is false"
    );
}
// Tool choice validation tests
/// tool_choice naming a function absent from `tools` must be rejected
/// with an error that names the missing function.
#[test]
fn test_tool_choice_function_not_found() {
    let weather_tool = Tool {
        tool_type: "function".to_string(),
        function: Function {
            name: "get_weather".to_string(),
            description: Some("Get weather".to_string()),
            parameters: json!({}),
            strict: None,
        },
    };
    let request = ChatCompletionRequest {
        model: "test-model".to_string(),
        messages: vec![ChatMessage::User {
            content: UserMessageContent::Text("hello".to_string()),
            name: None,
        }],
        tools: Some(vec![weather_tool]),
        tool_choice: Some(ToolChoice::Function {
            function: FunctionChoice {
                name: "nonexistent_function".to_string(),
            },
            tool_type: "function".to_string(),
        }),
        ..Default::default()
    };
    let outcome = request.validate();
    assert!(outcome.is_err(), "Should reject nonexistent function name");
    let message = outcome.unwrap_err().to_string();
    assert!(
        message.contains("function 'nonexistent_function' not found"),
        "Error should mention the missing function: {}",
        message
    );
}
/// tool_choice naming a function that is declared in `tools` passes.
#[test]
fn test_tool_choice_function_exists_valid() {
    let weather_tool = Tool {
        tool_type: "function".to_string(),
        function: Function {
            name: "get_weather".to_string(),
            description: Some("Get weather".to_string()),
            parameters: json!({}),
            strict: None,
        },
    };
    let request = ChatCompletionRequest {
        model: "test-model".to_string(),
        messages: vec![ChatMessage::User {
            content: UserMessageContent::Text("hello".to_string()),
            name: None,
        }],
        tools: Some(vec![weather_tool]),
        tool_choice: Some(ToolChoice::Function {
            function: FunctionChoice {
                name: "get_weather".to_string(),
            },
            tool_type: "function".to_string(),
        }),
        ..Default::default()
    };
    assert!(
        request.validate().is_ok(),
        "Should accept existing function name"
    );
}
/// AllowedTools mode must be 'auto' or 'required'; any other string fails
/// with an error that lists the valid modes.
#[test]
fn test_tool_choice_allowed_tools_invalid_mode() {
    let weather_tool = Tool {
        tool_type: "function".to_string(),
        function: Function {
            name: "get_weather".to_string(),
            description: Some("Get weather".to_string()),
            parameters: json!({}),
            strict: None,
        },
    };
    let request = ChatCompletionRequest {
        model: "test-model".to_string(),
        messages: vec![ChatMessage::User {
            content: UserMessageContent::Text("hello".to_string()),
            name: None,
        }],
        tools: Some(vec![weather_tool]),
        tool_choice: Some(ToolChoice::AllowedTools {
            mode: "invalid_mode".to_string(),
            tools: vec![ToolReference {
                tool_type: "function".to_string(),
                name: "get_weather".to_string(),
            }],
            tool_type: "function".to_string(),
        }),
        ..Default::default()
    };
    let outcome = request.validate();
    assert!(outcome.is_err(), "Should reject invalid mode");
    let message = outcome.unwrap_err().to_string();
    assert!(
        message.contains("must be 'auto' or 'required'"),
        "Error should mention valid modes: {}",
        message
    );
}
/// AllowedTools with mode 'auto' and a valid tool reference passes.
#[test]
fn test_tool_choice_allowed_tools_valid_mode_auto() {
    let weather_tool = Tool {
        tool_type: "function".to_string(),
        function: Function {
            name: "get_weather".to_string(),
            description: Some("Get weather".to_string()),
            parameters: json!({}),
            strict: None,
        },
    };
    let request = ChatCompletionRequest {
        model: "test-model".to_string(),
        messages: vec![ChatMessage::User {
            content: UserMessageContent::Text("hello".to_string()),
            name: None,
        }],
        tools: Some(vec![weather_tool]),
        tool_choice: Some(ToolChoice::AllowedTools {
            mode: "auto".to_string(),
            tools: vec![ToolReference {
                tool_type: "function".to_string(),
                name: "get_weather".to_string(),
            }],
            tool_type: "function".to_string(),
        }),
        ..Default::default()
    };
    assert!(request.validate().is_ok(), "Should accept 'auto' mode");
}
/// AllowedTools with mode 'required' and a valid tool reference passes.
#[test]
fn test_tool_choice_allowed_tools_valid_mode_required() {
    let weather_tool = Tool {
        tool_type: "function".to_string(),
        function: Function {
            name: "get_weather".to_string(),
            description: Some("Get weather".to_string()),
            parameters: json!({}),
            strict: None,
        },
    };
    let request = ChatCompletionRequest {
        model: "test-model".to_string(),
        messages: vec![ChatMessage::User {
            content: UserMessageContent::Text("hello".to_string()),
            name: None,
        }],
        tools: Some(vec![weather_tool]),
        tool_choice: Some(ToolChoice::AllowedTools {
            mode: "required".to_string(),
            tools: vec![ToolReference {
                tool_type: "function".to_string(),
                name: "get_weather".to_string(),
            }],
            tool_type: "function".to_string(),
        }),
        ..Default::default()
    };
    assert!(request.validate().is_ok(), "Should accept 'required' mode");
}
/// An AllowedTools reference to an undeclared tool must be rejected with
/// an error that names the missing tool.
#[test]
fn test_tool_choice_allowed_tools_tool_not_found() {
    let weather_tool = Tool {
        tool_type: "function".to_string(),
        function: Function {
            name: "get_weather".to_string(),
            description: Some("Get weather".to_string()),
            parameters: json!({}),
            strict: None,
        },
    };
    let request = ChatCompletionRequest {
        model: "test-model".to_string(),
        messages: vec![ChatMessage::User {
            content: UserMessageContent::Text("hello".to_string()),
            name: None,
        }],
        tools: Some(vec![weather_tool]),
        tool_choice: Some(ToolChoice::AllowedTools {
            mode: "auto".to_string(),
            tools: vec![ToolReference {
                tool_type: "function".to_string(),
                name: "nonexistent_tool".to_string(),
            }],
            tool_type: "function".to_string(),
        }),
        ..Default::default()
    };
    let outcome = request.validate();
    assert!(outcome.is_err(), "Should reject nonexistent tool name");
    let message = outcome.unwrap_err().to_string();
    assert!(
        message.contains("tool 'nonexistent_tool' not found"),
        "Error should mention the missing tool: {}",
        message
    );
}
/// AllowedTools may reference several declared tools; all-valid references
/// must pass validation.
#[test]
fn test_tool_choice_allowed_tools_multiple_tools_valid() {
    let weather_tool = Tool {
        tool_type: "function".to_string(),
        function: Function {
            name: "get_weather".to_string(),
            description: Some("Get weather".to_string()),
            parameters: json!({}),
            strict: None,
        },
    };
    let time_tool = Tool {
        tool_type: "function".to_string(),
        function: Function {
            name: "get_time".to_string(),
            description: Some("Get time".to_string()),
            parameters: json!({}),
            strict: None,
        },
    };
    let references = vec![
        ToolReference {
            tool_type: "function".to_string(),
            name: "get_weather".to_string(),
        },
        ToolReference {
            tool_type: "function".to_string(),
            name: "get_time".to_string(),
        },
    ];
    let request = ChatCompletionRequest {
        model: "test-model".to_string(),
        messages: vec![ChatMessage::User {
            content: UserMessageContent::Text("hello".to_string()),
            name: None,
        }],
        tools: Some(vec![weather_tool, time_tool]),
        tool_choice: Some(ToolChoice::AllowedTools {
            mode: "auto".to_string(),
            tools: references,
            tool_type: "function".to_string(),
        }),
        ..Default::default()
    };
    assert!(
        request.validate().is_ok(),
        "Should accept all valid tool references"
    );
}
#[test]
fn test_tool_choice_allowed_tools_one_invalid_among_valid() {
    // A single unresolved reference poisons an otherwise valid allowed-tools list.
    let make_tool = |name: &str, desc: &str| Tool {
        tool_type: "function".to_string(),
        function: Function {
            name: name.to_string(),
            description: Some(desc.to_string()),
            parameters: json!({}),
            strict: None,
        },
    };
    let make_ref = |name: &str| ToolReference {
        tool_type: "function".to_string(),
        name: name.to_string(),
    };
    let request = ChatCompletionRequest {
        model: "test-model".to_string(),
        messages: vec![ChatMessage::User {
            content: UserMessageContent::Text("hello".to_string()),
            name: None,
        }],
        tools: Some(vec![
            make_tool("get_weather", "Get weather"),
            make_tool("get_time", "Get time"),
        ]),
        tool_choice: Some(ToolChoice::AllowedTools {
            mode: "auto".to_string(),
            tools: vec![make_ref("get_weather"), make_ref("nonexistent_tool")],
            tool_type: "function".to_string(),
        }),
        ..Default::default()
    };
    let outcome = request.validate();
    assert!(
        outcome.is_err(),
        "Should reject if any tool reference is invalid"
    );
    let message = outcome.unwrap_err().to_string();
    assert!(
        message.contains("tool 'nonexistent_tool' not found"),
        "Error should mention the missing tool: {}",
        message
    );
}
use serde_json::json;
use sglang_router_rs::protocols::spec::{ChatMessage, UserMessageContent};
#[test]
fn test_chat_message_tagged_by_role_system() {
    // The "role" field selects the System variant during deserialization.
    let payload = json!({
        "role": "system",
        "content": "You are a helpful assistant"
    });
    let msg: ChatMessage = serde_json::from_value(payload).unwrap();
    let ChatMessage::System { content, .. } = msg else {
        panic!("Expected System variant")
    };
    assert_eq!(content, "You are a helpful assistant");
}
#[test]
fn test_chat_message_tagged_by_role_user() {
    // role = "user" with a plain string deserializes to User with text content.
    let payload = json!({
        "role": "user",
        "content": "Hello"
    });
    let msg: ChatMessage = serde_json::from_value(payload).unwrap();
    let ChatMessage::User { content, .. } = msg else {
        panic!("Expected User variant")
    };
    let UserMessageContent::Text(text) = content else {
        panic!("Expected text content")
    };
    assert_eq!(text, "Hello");
}
#[test]
fn test_chat_message_tagged_by_role_assistant() {
    // role = "assistant" deserializes to Assistant; content is optional, so wrapped in Some.
    let payload = json!({
        "role": "assistant",
        "content": "Hi there!"
    });
    let msg: ChatMessage = serde_json::from_value(payload).unwrap();
    let ChatMessage::Assistant { content, .. } = msg else {
        panic!("Expected Assistant variant")
    };
    assert_eq!(content, Some("Hi there!".to_string()));
}
#[test]
fn test_chat_message_tagged_by_role_tool() {
    // role = "tool" carries both the tool output and the originating call id.
    let payload = json!({
        "role": "tool",
        "content": "Tool result",
        "tool_call_id": "call_123"
    });
    let msg: ChatMessage = serde_json::from_value(payload).unwrap();
    let ChatMessage::Tool {
        content,
        tool_call_id,
    } = msg
    else {
        panic!("Expected Tool variant")
    };
    assert_eq!(content, "Tool result");
    assert_eq!(tool_call_id, "call_123");
}
#[test]
fn test_chat_message_wrong_role_rejected() {
    // An unknown role tag must not deserialize into any variant.
    let payload = json!({
        "role": "invalid_role",
        "content": "test"
    });
    assert!(
        serde_json::from_value::<ChatMessage>(payload).is_err(),
        "Should reject invalid role"
    );
}
use serde_json::{from_str, json, to_string};
use sglang_router_rs::protocols::spec::{EmbeddingRequest, GenerationRequest};
#[test]
fn test_embedding_request_serialization_string_input() {
    // A fully-populated embedding request must survive a serde round trip.
    let original = EmbeddingRequest {
        model: "test-emb".to_string(),
        input: json!("hello"),
        encoding_format: Some("float".to_string()),
        user: Some("user-1".to_string()),
        dimensions: Some(128),
        rid: Some("rid-123".to_string()),
    };
    let round_tripped: EmbeddingRequest = from_str(&to_string(&original).unwrap()).unwrap();
    assert_eq!(round_tripped.model, original.model);
    assert_eq!(round_tripped.input, original.input);
    assert_eq!(round_tripped.encoding_format, original.encoding_format);
    assert_eq!(round_tripped.user, original.user);
    assert_eq!(round_tripped.dimensions, original.dimensions);
    assert_eq!(round_tripped.rid, original.rid);
}
#[test]
fn test_embedding_request_serialization_array_input() {
    // Array-of-strings input round-trips unchanged.
    let original = EmbeddingRequest {
        input: json!(["a", "b", "c"]),
        model: "test-emb".to_string(),
        encoding_format: None,
        user: None,
        dimensions: None,
        rid: None,
    };
    let round_tripped: EmbeddingRequest = from_str(&to_string(&original).unwrap()).unwrap();
    assert_eq!(round_tripped.model, original.model);
    assert_eq!(round_tripped.input, original.input);
}
#[test]
fn test_embedding_generation_request_trait_string() {
    // Embedding requests never stream and route on their raw text input.
    let req = EmbeddingRequest {
        input: json!("hello"),
        model: "emb-model".to_string(),
        encoding_format: None,
        user: None,
        dimensions: None,
        rid: None,
    };
    assert_eq!(req.get_model(), Some("emb-model"));
    assert_eq!(req.extract_text_for_routing(), "hello");
    assert!(!req.is_stream());
}
#[test]
fn test_embedding_generation_request_trait_array() {
    // Top-level string elements are joined with spaces for routing.
    let req = EmbeddingRequest {
        input: json!(["hello", "world"]),
        model: "emb-model".to_string(),
        encoding_format: None,
        user: None,
        dimensions: None,
        rid: None,
    };
    assert_eq!(req.extract_text_for_routing(), "hello world");
}
#[test]
fn test_embedding_generation_request_trait_non_text() {
    // Non-textual input (e.g. a token-id object) yields an empty routing string.
    let req = EmbeddingRequest {
        input: json!({"tokens": [1, 2, 3]}),
        model: "emb-model".to_string(),
        encoding_format: None,
        user: None,
        dimensions: None,
        rid: None,
    };
    assert_eq!(req.extract_text_for_routing(), "");
}
#[test]
fn test_embedding_generation_request_trait_mixed_array_ignores_nested() {
    // Nested arrays, numbers, and objects are skipped; only top-level strings count.
    let req = EmbeddingRequest {
        input: json!(["a", ["b", "c"], 123, {"k": "v"}]),
        model: "emb-model".to_string(),
        encoding_format: None,
        user: None,
        dimensions: None,
        rid: None,
    };
    assert_eq!(req.extract_text_for_routing(), "a");
}
// Protocol specification tests
// These tests were originally in src/protocols/spec.rs and have been moved here
// to reduce the size of that file and improve test organization.
mod chat_completion;
mod chat_message;
mod embedding;
mod rerank;
use serde_json::{from_str, to_string, Number, Value};
use sglang_router_rs::protocols::spec::{
default_model_name, GenerationRequest, RerankRequest, RerankResponse, RerankResult,
StringOrArray, UsageInfo, V1RerankReqInput,
};
use std::collections::HashMap;
#[test]
fn test_rerank_request_serialization() {
    // A fully-populated rerank request survives a serde round trip intact.
    let original = RerankRequest {
        model: "test-model".to_string(),
        query: "test query".to_string(),
        documents: vec!["doc1".to_string(), "doc2".to_string()],
        top_k: Some(5),
        return_documents: true,
        rid: Some(StringOrArray::String("req-123".to_string())),
        user: Some("user-456".to_string()),
    };
    let round_tripped: RerankRequest = from_str(&to_string(&original).unwrap()).unwrap();
    assert_eq!(round_tripped.query, original.query);
    assert_eq!(round_tripped.documents, original.documents);
    assert_eq!(round_tripped.model, original.model);
    assert_eq!(round_tripped.top_k, original.top_k);
    assert_eq!(round_tripped.return_documents, original.return_documents);
    assert_eq!(round_tripped.rid, original.rid);
    assert_eq!(round_tripped.user, original.user);
}
#[test]
fn test_rerank_request_deserialization_with_defaults() {
    // Fields omitted from the payload fall back to their serde defaults.
    let payload = r#"{
        "query": "test query",
        "documents": ["doc1", "doc2"]
    }"#;
    let req: RerankRequest = from_str(payload).unwrap();
    assert_eq!(req.query, "test query");
    assert_eq!(req.documents, vec!["doc1", "doc2"]);
    assert_eq!(req.model, default_model_name());
    assert_eq!(req.top_k, None);
    assert!(req.return_documents);
    assert_eq!(req.rid, None);
    assert_eq!(req.user, None);
}
#[test]
fn test_rerank_request_validation_success() {
    // A well-formed request passes validation.
    let req = RerankRequest {
        model: "test-model".to_string(),
        query: "valid query".to_string(),
        documents: vec!["doc1".to_string(), "doc2".to_string()],
        top_k: Some(2),
        return_documents: true,
        rid: None,
        user: None,
    };
    assert!(req.validate().is_ok());
}
#[test]
fn test_rerank_request_validation_empty_query() {
    // An empty query string is rejected with a specific message.
    let req = RerankRequest {
        model: "test-model".to_string(),
        query: "".to_string(),
        documents: vec!["doc1".to_string()],
        top_k: None,
        return_documents: true,
        rid: None,
        user: None,
    };
    let outcome = req.validate();
    assert!(outcome.is_err());
    assert_eq!(outcome.unwrap_err(), "Query cannot be empty");
}
#[test]
fn test_rerank_request_validation_whitespace_query() {
    // A whitespace-only query is treated the same as an empty one.
    let req = RerankRequest {
        model: "test-model".to_string(),
        query: " ".to_string(),
        documents: vec!["doc1".to_string()],
        top_k: None,
        return_documents: true,
        rid: None,
        user: None,
    };
    let outcome = req.validate();
    assert!(outcome.is_err());
    assert_eq!(outcome.unwrap_err(), "Query cannot be empty");
}
#[test]
fn test_rerank_request_validation_empty_documents() {
    // There must be at least one document to rerank.
    let req = RerankRequest {
        model: "test-model".to_string(),
        query: "test query".to_string(),
        documents: vec![],
        top_k: None,
        return_documents: true,
        rid: None,
        user: None,
    };
    let outcome = req.validate();
    assert!(outcome.is_err());
    assert_eq!(outcome.unwrap_err(), "Documents list cannot be empty");
}
#[test]
fn test_rerank_request_validation_top_k_zero() {
    // top_k = 0 is meaningless and must be rejected.
    let req = RerankRequest {
        model: "test-model".to_string(),
        query: "test query".to_string(),
        documents: vec!["doc1".to_string(), "doc2".to_string()],
        top_k: Some(0),
        return_documents: true,
        rid: None,
        user: None,
    };
    let outcome = req.validate();
    assert!(outcome.is_err());
    assert_eq!(outcome.unwrap_err(), "top_k must be greater than 0");
}
#[test]
fn test_rerank_request_validation_top_k_greater_than_docs() {
    // top_k larger than the document count is tolerated (warned about, not rejected).
    let req = RerankRequest {
        model: "test-model".to_string(),
        query: "test query".to_string(),
        documents: vec!["doc1".to_string(), "doc2".to_string()],
        top_k: Some(5),
        return_documents: true,
        rid: None,
        user: None,
    };
    assert!(req.validate().is_ok());
}
#[test]
fn test_rerank_request_effective_top_k() {
    // With top_k set, effective_top_k returns it as-is.
    let req = RerankRequest {
        model: "test-model".to_string(),
        query: "test query".to_string(),
        documents: vec!["doc1".to_string(), "doc2".to_string(), "doc3".to_string()],
        top_k: Some(2),
        return_documents: true,
        rid: None,
        user: None,
    };
    assert_eq!(req.effective_top_k(), 2);
}
#[test]
fn test_rerank_request_effective_top_k_none() {
    // Without top_k, effective_top_k defaults to the document count.
    let req = RerankRequest {
        model: "test-model".to_string(),
        query: "test query".to_string(),
        documents: vec!["doc1".to_string(), "doc2".to_string(), "doc3".to_string()],
        top_k: None,
        return_documents: true,
        rid: None,
        user: None,
    };
    assert_eq!(req.effective_top_k(), 3);
}
#[test]
fn test_rerank_response_creation() {
    // RerankResponse::new fills in model, id, the "rerank" object tag, and a timestamp.
    let entries = vec![
        RerankResult {
            index: 0,
            score: 0.8,
            document: Some("doc1".to_string()),
            meta_info: None,
        },
        RerankResult {
            index: 1,
            score: 0.6,
            document: Some("doc2".to_string()),
            meta_info: None,
        },
    ];
    let response = RerankResponse::new(
        entries.clone(),
        "test-model".to_string(),
        Some(StringOrArray::String("req-123".to_string())),
    );
    assert_eq!(response.results.len(), 2);
    assert_eq!(response.model, "test-model");
    assert_eq!(
        response.id,
        Some(StringOrArray::String("req-123".to_string()))
    );
    assert_eq!(response.object, "rerank");
    assert!(response.created > 0);
}
#[test]
fn test_rerank_response_serialization() {
    // A response round-trips through serde with all top-level fields preserved.
    let response = RerankResponse::new(
        vec![RerankResult {
            index: 0,
            score: 0.8,
            document: Some("doc1".to_string()),
            meta_info: None,
        }],
        "test-model".to_string(),
        Some(StringOrArray::String("req-123".to_string())),
    );
    let round_tripped: RerankResponse = from_str(&to_string(&response).unwrap()).unwrap();
    assert_eq!(round_tripped.results.len(), response.results.len());
    assert_eq!(round_tripped.model, response.model);
    assert_eq!(round_tripped.id, response.id);
    assert_eq!(round_tripped.object, response.object);
}
#[test]
fn test_rerank_response_sort_by_score() {
    // sort_by_score orders results by descending score, carrying indices along.
    let entries = vec![
        RerankResult {
            index: 1,
            score: 0.6,
            document: Some("doc2".to_string()),
            meta_info: None,
        },
        RerankResult {
            index: 0,
            score: 0.8,
            document: Some("doc1".to_string()),
            meta_info: None,
        },
        RerankResult {
            index: 2,
            score: 0.4,
            document: Some("doc3".to_string()),
            meta_info: None,
        },
    ];
    let mut response = RerankResponse::new(
        entries,
        "test-model".to_string(),
        Some(StringOrArray::String("req-123".to_string())),
    );
    response.sort_by_score();
    // Expected (score, original index) pairs after the descending sort.
    let expected = [(0.8, 0), (0.6, 1), (0.4, 2)];
    for (pos, (score, index)) in expected.iter().enumerate() {
        assert_eq!(response.results[pos].score, *score);
        assert_eq!(response.results[pos].index, *index);
    }
}
#[test]
fn test_rerank_response_apply_top_k() {
    // apply_top_k truncates the result list to the requested length, keeping order.
    let entries = vec![
        RerankResult {
            index: 0,
            score: 0.8,
            document: Some("doc1".to_string()),
            meta_info: None,
        },
        RerankResult {
            index: 1,
            score: 0.6,
            document: Some("doc2".to_string()),
            meta_info: None,
        },
        RerankResult {
            index: 2,
            score: 0.4,
            document: Some("doc3".to_string()),
            meta_info: None,
        },
    ];
    let mut response = RerankResponse::new(
        entries,
        "test-model".to_string(),
        Some(StringOrArray::String("req-123".to_string())),
    );
    response.apply_top_k(2);
    assert_eq!(response.results.len(), 2);
    assert_eq!(response.results[0].score, 0.8);
    assert_eq!(response.results[1].score, 0.6);
}
#[test]
fn test_rerank_response_apply_top_k_larger_than_results() {
    // A top_k beyond the result count is a no-op, not a panic.
    let mut response = RerankResponse::new(
        vec![RerankResult {
            index: 0,
            score: 0.8,
            document: Some("doc1".to_string()),
            meta_info: None,
        }],
        "test-model".to_string(),
        Some(StringOrArray::String("req-123".to_string())),
    );
    response.apply_top_k(5);
    assert_eq!(response.results.len(), 1);
}
#[test]
fn test_rerank_response_drop_documents() {
    // drop_documents clears the document text while keeping scores and indices.
    let mut response = RerankResponse::new(
        vec![RerankResult {
            index: 0,
            score: 0.8,
            document: Some("doc1".to_string()),
            meta_info: None,
        }],
        "test-model".to_string(),
        Some(StringOrArray::String("req-123".to_string())),
    );
    response.drop_documents();
    assert_eq!(response.results[0].document, None);
}
#[test]
fn test_rerank_result_serialization() {
    // A result carrying meta_info round-trips through serde unchanged.
    let meta = HashMap::from([
        ("confidence".to_string(), Value::String("high".to_string())),
        (
            "processing_time".to_string(),
            Value::Number(Number::from(150)),
        ),
    ]);
    let original = RerankResult {
        index: 42,
        score: 0.85,
        document: Some("test document".to_string()),
        meta_info: Some(meta),
    };
    let round_tripped: RerankResult = from_str(&to_string(&original).unwrap()).unwrap();
    assert_eq!(round_tripped.score, original.score);
    assert_eq!(round_tripped.document, original.document);
    assert_eq!(round_tripped.index, original.index);
    assert_eq!(round_tripped.meta_info, original.meta_info);
}
#[test]
fn test_rerank_result_serialization_without_document() {
    // Optional fields set to None survive the round trip as None.
    let original = RerankResult {
        index: 42,
        score: 0.85,
        document: None,
        meta_info: None,
    };
    let round_tripped: RerankResult = from_str(&to_string(&original).unwrap()).unwrap();
    assert_eq!(round_tripped.score, original.score);
    assert_eq!(round_tripped.document, original.document);
    assert_eq!(round_tripped.index, original.index);
    assert_eq!(round_tripped.meta_info, original.meta_info);
}
#[test]
fn test_v1_rerank_req_input_serialization() {
    // The minimal v1 input shape round-trips through serde.
    let original = V1RerankReqInput {
        query: "test query".to_string(),
        documents: vec!["doc1".to_string(), "doc2".to_string()],
    };
    let round_tripped: V1RerankReqInput = from_str(&to_string(&original).unwrap()).unwrap();
    assert_eq!(round_tripped.query, original.query);
    assert_eq!(round_tripped.documents, original.documents);
}
#[test]
fn test_v1_to_rerank_request_conversion() {
    // Converting a v1 input yields a full request with default model and options.
    let v1 = V1RerankReqInput {
        query: "test query".to_string(),
        documents: vec!["doc1".to_string(), "doc2".to_string()],
    };
    let converted: RerankRequest = v1.into();
    assert_eq!(converted.query, "test query");
    assert_eq!(converted.documents, vec!["doc1", "doc2"]);
    assert_eq!(converted.model, default_model_name());
    assert_eq!(converted.top_k, None);
    assert!(converted.return_documents);
    assert_eq!(converted.rid, None);
    assert_eq!(converted.user, None);
}
#[test]
fn test_rerank_request_generation_request_trait() {
    // Rerank requests never stream and route on the query text.
    let req = RerankRequest {
        model: "test-model".to_string(),
        query: "test query".to_string(),
        documents: vec!["doc1".to_string()],
        top_k: None,
        return_documents: true,
        rid: None,
        user: None,
    };
    assert!(!req.is_stream());
    assert_eq!(req.get_model(), Some("test-model"));
    assert_eq!(req.extract_text_for_routing(), "test query");
}
#[test]
fn test_rerank_request_very_long_query() {
    // Validation imposes no upper bound on query length.
    let req = RerankRequest {
        model: "test-model".to_string(),
        query: "a".repeat(100000),
        documents: vec!["doc1".to_string()],
        top_k: None,
        return_documents: true,
        rid: None,
        user: None,
    };
    assert!(req.validate().is_ok());
}
#[test]
fn test_rerank_request_many_documents() {
    // Large document lists validate, and effective_top_k honors the explicit top_k.
    let docs: Vec<String> = (0..1000).map(|i| format!("doc{}", i)).collect();
    let req = RerankRequest {
        model: "test-model".to_string(),
        query: "test query".to_string(),
        documents: docs,
        top_k: Some(100),
        return_documents: true,
        rid: None,
        user: None,
    };
    assert!(req.validate().is_ok());
    assert_eq!(req.effective_top_k(), 100);
}
#[test]
fn test_rerank_request_special_characters() {
    // Non-ASCII text (emoji, CJK) in every string field is accepted.
    let req = RerankRequest {
        model: "test-model".to_string(),
        query: "query with émojis 🚀 and unicode: 测试".to_string(),
        documents: vec![
            "doc with émojis 🎉".to_string(),
            "doc with unicode: 测试".to_string(),
        ],
        top_k: None,
        return_documents: true,
        rid: Some(StringOrArray::String("req-🚀-123".to_string())),
        user: Some("user-🎉-456".to_string()),
    };
    assert!(req.validate().is_ok());
}
#[test]
fn test_rerank_request_rid_array() {
    // The rid field also accepts the array form of StringOrArray.
    let req = RerankRequest {
        model: "test-model".to_string(),
        query: "test query".to_string(),
        documents: vec!["doc1".to_string()],
        top_k: None,
        return_documents: true,
        rid: Some(StringOrArray::Array(vec![
            "req1".to_string(),
            "req2".to_string(),
        ])),
        user: None,
    };
    assert!(req.validate().is_ok());
}
#[test]
fn test_rerank_response_with_usage_info() {
    // UsageInfo attached after construction is preserved across serialization.
    let mut response = RerankResponse::new(
        vec![RerankResult {
            index: 0,
            score: 0.8,
            document: Some("doc1".to_string()),
            meta_info: None,
        }],
        "test-model".to_string(),
        Some(StringOrArray::String("req-123".to_string())),
    );
    response.usage = Some(UsageInfo {
        prompt_tokens: 100,
        completion_tokens: 50,
        total_tokens: 150,
        reasoning_tokens: None,
        prompt_tokens_details: None,
    });
    let round_tripped: RerankResponse = from_str(&to_string(&response).unwrap()).unwrap();
    assert!(round_tripped.usage.is_some());
    let usage = round_tripped.usage.unwrap();
    assert_eq!(usage.prompt_tokens, 100);
    assert_eq!(usage.completion_tokens, 50);
    assert_eq!(usage.total_tokens, 150);
}
#[test]
fn test_full_rerank_workflow() {
    // End-to-end: validate a request, rank simulated results, trim to top_k,
    // and confirm the response survives a serde round trip.
    let request = RerankRequest {
        model: "rerank-model".to_string(),
        query: "machine learning".to_string(),
        documents: vec![
            "Introduction to machine learning algorithms".to_string(),
            "Deep learning for computer vision".to_string(),
            "Natural language processing basics".to_string(),
            "Statistics and probability theory".to_string(),
        ],
        top_k: Some(2),
        return_documents: true,
        rid: Some(StringOrArray::String("req-123".to_string())),
        user: Some("user-456".to_string()),
    };
    assert!(request.validate().is_ok());
    // Scores a real model might have produced, one per document in input order.
    let ranked = vec![
        RerankResult {
            index: 0,
            score: 0.95,
            document: Some("Introduction to machine learning algorithms".to_string()),
            meta_info: None,
        },
        RerankResult {
            index: 1,
            score: 0.87,
            document: Some("Deep learning for computer vision".to_string()),
            meta_info: None,
        },
        RerankResult {
            index: 2,
            score: 0.72,
            document: Some("Natural language processing basics".to_string()),
            meta_info: None,
        },
        RerankResult {
            index: 3,
            score: 0.45,
            document: Some("Statistics and probability theory".to_string()),
            meta_info: None,
        },
    ];
    let mut response = RerankResponse::new(ranked, request.model.clone(), request.rid.clone());
    // Order by descending score, then keep only the requested number of hits.
    response.sort_by_score();
    response.apply_top_k(request.effective_top_k());
    assert_eq!(response.results.len(), 2);
    assert_eq!(response.results[0].score, 0.95);
    assert_eq!(response.results[0].index, 0);
    assert_eq!(response.results[1].score, 0.87);
    assert_eq!(response.results[1].index, 1);
    assert_eq!(response.model, "rerank-model");
    let round_tripped: RerankResponse = from_str(&to_string(&response).unwrap()).unwrap();
    assert_eq!(round_tripped.results.len(), 2);
    assert_eq!(round_tripped.model, response.model);
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment