Unverified Commit 03ce92e5 authored by Chang Su, committed by GitHub

router-spec: Reorder `ChatCompletionRequest` and fix validation logic (#10675)

parent 00eb5eb7
@@ -48,50 +48,15 @@ fn default_generate_request() -> GenerateRequest {
 }
 
 /// Create a default ChatCompletionRequest for benchmarks with minimal fields set
+#[allow(deprecated)]
 fn default_chat_completion_request() -> ChatCompletionRequest {
     ChatCompletionRequest {
-        model: String::new(),
+        // Required fields in OpenAI order
         messages: vec![],
-        max_tokens: None,
-        max_completion_tokens: None,
-        temperature: None,
-        top_p: None,
-        n: None,
-        stream: false,
-        stream_options: None,
-        stop: None,
-        presence_penalty: None,
-        frequency_penalty: None,
-        logit_bias: None,
-        logprobs: false,
-        top_logprobs: None,
-        user: None,
-        response_format: None,
-        seed: None,
-        tools: None,
-        tool_choice: None,
-        parallel_tool_calls: None,
-        function_call: None,
-        functions: None,
-        // SGLang Extensions
-        top_k: None,
-        min_p: None,
-        min_tokens: None,
-        repetition_penalty: None,
-        regex: None,
-        ebnf: None,
-        stop_token_ids: None,
-        no_stop_trim: false,
-        ignore_eos: false,
-        continue_final_message: false,
-        skip_special_tokens: true,
-        // SGLang Extensions
-        lora_path: None,
-        session_params: None,
-        separate_reasoning: true,
-        stream_reasoning: true,
-        chat_template_kwargs: None,
-        return_hidden_states: false,
+        model: String::new(),
+        // Use default for all other fields
+        ..Default::default()
     }
 }
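The benchmark fixtures now lean on Rust's struct-update syntax: a call site names only the fields a scenario cares about and inherits everything else from the shared default above. A minimal sketch of such a caller (the function name and field values here are illustrative, not part of the change):

    #[allow(deprecated)]
    fn streaming_chat_request() -> ChatCompletionRequest {
        ChatCompletionRequest {
            // Only the fields this scenario measures are overridden.
            model: "test-model".to_string(),
            stream: true,
            max_completion_tokens: Some(128),
            // Everything else comes from the shared benchmark default.
            ..default_chat_completion_request()
        }
    }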
@@ -161,6 +126,7 @@ fn create_sample_generate_request() -> GenerateRequest {
     }
 }
 
+#[allow(deprecated)]
 fn create_sample_chat_completion_request() -> ChatCompletionRequest {
     ChatCompletionRequest {
         model: "gpt-3.5-turbo".to_string(),
@@ -205,6 +171,7 @@ fn create_sample_completion_request() -> CompletionRequest {
     }
 }
 
+#[allow(deprecated)]
 fn create_large_chat_completion_request() -> ChatCompletionRequest {
     let mut messages = vec![ChatMessage::System {
         role: "system".to_string(),
@@ -240,7 +207,6 @@ fn create_large_chat_completion_request() -> ChatCompletionRequest {
         presence_penalty: Some(0.1),
         frequency_penalty: Some(0.1),
         top_logprobs: Some(5),
-        user: Some("benchmark_user".to_string()),
         seed: Some(42),
         parallel_tool_calls: Some(true),
         ..default_chat_completion_request()
@@ -179,97 +179,125 @@ pub struct FunctionCallDelta {
 // ============= Request =============
 
-#[derive(Debug, Clone, Deserialize, Serialize)]
+#[derive(Debug, Clone, Deserialize, Serialize, Default)]
 pub struct ChatCompletionRequest {
-    /// ID of the model to use
-    pub model: String,
-    /// A list of messages comprising the conversation so far
-    pub messages: Vec<ChatMessage>,
-    /// What sampling temperature to use, between 0 and 2
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub temperature: Option<f32>,
-    /// An alternative to sampling with temperature
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub top_p: Option<f32>,
-    /// How many chat completion choices to generate for each input message
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub n: Option<u32>,
-    /// If set, partial message deltas will be sent
-    #[serde(default)]
-    pub stream: bool,
-    /// Options for streaming response
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub stream_options: Option<StreamOptions>,
-    /// Up to 4 sequences where the API will stop generating further tokens
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub stop: Option<StringOrArray>,
-    /// The maximum number of tokens to generate
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub max_tokens: Option<u32>,
-    /// An upper bound for the number of tokens that can be generated for a completion
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub max_completion_tokens: Option<u32>,
-    /// Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub presence_penalty: Option<f32>,
-    /// Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub frequency_penalty: Option<f32>,
-    /// Modify the likelihood of specified tokens appearing in the completion
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub logit_bias: Option<HashMap<String, f32>>,
-    /// A unique identifier representing your end-user
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub user: Option<String>,
-    /// If specified, our system will make a best effort to sample deterministically
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub seed: Option<i64>,
-    /// Whether to return log probabilities of the output tokens
-    #[serde(default)]
-    pub logprobs: bool,
-    /// An integer between 0 and 20 specifying the number of most likely tokens to return
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub top_logprobs: Option<u32>,
-    /// An object specifying the format that the model must output
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub response_format: Option<ResponseFormat>,
-    /// A list of tools the model may call
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub tools: Option<Vec<Tool>>,
-    /// Controls which (if any) tool is called by the model
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub tool_choice: Option<ToolChoice>,
-    /// Whether to enable parallel function calling during tool use
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub parallel_tool_calls: Option<bool>,
-    /// Deprecated: use tools instead
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub functions: Option<Vec<Function>>,
-    /// Deprecated: use tool_choice instead
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub function_call: Option<FunctionCall>,
+    /// A list of messages comprising the conversation so far
+    pub messages: Vec<ChatMessage>,
+    /// ID of the model to use
+    pub model: String,
+    /// Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub frequency_penalty: Option<f32>,
+    /// Deprecated: Replaced by tool_choice
+    #[serde(skip_serializing_if = "Option::is_none")]
+    #[deprecated(note = "Use tool_choice instead")]
+    pub function_call: Option<FunctionCall>,
+    /// Deprecated: Replaced by tools
+    #[serde(skip_serializing_if = "Option::is_none")]
+    #[deprecated(note = "Use tools instead")]
+    pub functions: Option<Vec<Function>>,
+    /// Modify the likelihood of specified tokens appearing in the completion
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub logit_bias: Option<HashMap<String, f32>>,
+    /// Whether to return log probabilities of the output tokens
+    #[serde(default)]
+    pub logprobs: bool,
+    /// Deprecated: Replaced by max_completion_tokens
+    #[serde(skip_serializing_if = "Option::is_none")]
+    #[deprecated(note = "Use max_completion_tokens instead")]
+    pub max_tokens: Option<u32>,
+    /// An upper bound for the number of tokens that can be generated for a completion
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub max_completion_tokens: Option<u32>,
+    /// Developer-defined tags and values used for filtering completions in the dashboard
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub metadata: Option<HashMap<String, String>>,
+    /// Output types that you would like the model to generate for this request
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub modalities: Option<Vec<String>>,
+    /// How many chat completion choices to generate for each input message
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub n: Option<u32>,
+    /// Whether to enable parallel function calling during tool use
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub parallel_tool_calls: Option<bool>,
+    /// Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub presence_penalty: Option<f32>,
+    /// Cache key for prompts (beta feature)
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub prompt_cache_key: Option<String>,
+    /// Effort level for reasoning models (low, medium, high)
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub reasoning_effort: Option<String>,
+    /// An object specifying the format that the model must output
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub response_format: Option<ResponseFormat>,
+    /// Safety identifier for content moderation
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub safety_identifier: Option<String>,
+    /// Deprecated: This feature is in Legacy mode
+    #[serde(skip_serializing_if = "Option::is_none")]
+    #[deprecated(note = "This feature is in Legacy mode")]
+    pub seed: Option<i64>,
+    /// The service tier to use for this request
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub service_tier: Option<String>,
+    /// Up to 4 sequences where the API will stop generating further tokens
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub stop: Option<StringOrArray>,
+    /// If set, partial message deltas will be sent
+    #[serde(default)]
+    pub stream: bool,
+    /// Options for streaming response
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub stream_options: Option<StreamOptions>,
+    /// What sampling temperature to use, between 0 and 2
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub temperature: Option<f32>,
+    /// Controls which (if any) tool is called by the model
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub tool_choice: Option<ToolChoice>,
+    /// A list of tools the model may call
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub tools: Option<Vec<Tool>>,
+    /// An integer between 0 and 20 specifying the number of most likely tokens to return
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub top_logprobs: Option<u32>,
+    /// An alternative to sampling with temperature
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub top_p: Option<f32>,
+    /// Verbosity level for debugging
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub verbosity: Option<i32>,
 
     // ============= SGLang Extensions =============
     /// Top-k sampling parameter (-1 to disable)
@@ -316,7 +344,6 @@ pub struct ChatCompletionRequest {
     #[serde(default = "default_true")]
     pub skip_special_tokens: bool,
 
-    // ============= SGLang Extensions =============
     /// Path to LoRA adapter(s) for model customization
     #[serde(skip_serializing_if = "Option::is_none")]
     pub lora_path: Option<LoRAPath>,
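Because `ChatCompletionRequest` now derives `Default`, downstream code can build a request from the required OpenAI fields and fall back to defaults for the rest. A rough sketch, assuming the spec types are in scope (the model name and token limit are made up):

    #[allow(deprecated)] // mirrors how the fixtures above silence the lint for deprecated fields
    let request = ChatCompletionRequest {
        messages: vec![],
        model: "example-model".to_string(),
        max_completion_tokens: Some(64),
        // All remaining fields take their derived defaults.
        ..Default::default()
    };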
@@ -563,6 +563,7 @@ impl StopConditionsProvider for ChatCompletionRequest {
 }
 
 impl TokenLimitsProvider for ChatCompletionRequest {
+    #[allow(deprecated)]
     fn get_max_tokens(&self) -> Option<u32> {
         // Prefer max_completion_tokens over max_tokens if both are set
         self.max_completion_tokens.or(self.max_tokens)
@@ -656,19 +657,13 @@ impl ChatCompletionRequest {
     /// Validate chat API specific logprobs requirements
     pub fn validate_chat_logprobs(&self) -> Result<(), ValidationError> {
-        // In chat API, if logprobs=true, top_logprobs must be specified
-        if self.logprobs && self.top_logprobs.is_none() {
-            return Err(ValidationError::MissingRequired {
-                parameter: "top_logprobs".to_string(),
-            });
-        }
-
-        // If top_logprobs is specified, logprobs should be true
+        // OpenAI rule: If top_logprobs is specified, logprobs must be true
+        // But logprobs=true without top_logprobs is valid (returns basic logprobs)
         if self.top_logprobs.is_some() && !self.logprobs {
             return Err(ValidationError::InvalidValue {
-                parameter: "logprobs".to_string(),
-                value: "false".to_string(),
-                reason: "must be true when top_logprobs is specified".to_string(),
+                parameter: "top_logprobs".to_string(),
+                value: self.top_logprobs.unwrap().to_string(),
+                reason: "top_logprobs is only allowed when logprobs is enabled".to_string(),
             });
         }
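The net effect of the reworked check: `logprobs` may be enabled on its own (basic logprobs are returned), while `top_logprobs` is only accepted alongside `logprobs`. A small sketch of exercising the helper directly, reusing the `create_valid_chat_request` test fixture defined further down:

    let mut req = create_valid_chat_request();

    req.logprobs = true;        // basic logprobs alone is accepted
    assert!(req.validate_chat_logprobs().is_ok());

    req.logprobs = false;
    req.top_logprobs = Some(5); // top_logprobs without logprobs is rejected
    assert!(req.validate_chat_logprobs().is_err());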
@@ -676,6 +671,7 @@ impl ChatCompletionRequest {
     }
 
     /// Validate cross-parameter relationships specific to chat completions
+    #[allow(deprecated)]
     pub fn validate_chat_cross_parameters(&self) -> Result<(), ValidationError> {
         // Validate that both max_tokens and max_completion_tokens aren't set
         utils::validate_conflicting_parameters(
@@ -871,53 +867,24 @@ mod tests {
     mod chat_tests {
         use super::*;
 
+        #[allow(deprecated)]
         fn create_valid_chat_request() -> ChatCompletionRequest {
             ChatCompletionRequest {
-                model: "gpt-4".to_string(),
                 messages: vec![ChatMessage::User {
                     role: "user".to_string(),
                     content: UserMessageContent::Text("Hello".to_string()),
                     name: None,
                 }],
+                model: "gpt-4".to_string(),
+                // Set specific fields we want to test
                 temperature: Some(1.0),
                 top_p: Some(0.9),
                 n: Some(1),
-                stream: false,
-                stream_options: None,
-                stop: None,
                 max_tokens: Some(100),
-                max_completion_tokens: None,
-                presence_penalty: Some(0.0),
                 frequency_penalty: Some(0.0),
-                logit_bias: None,
-                user: None,
-                seed: None,
-                logprobs: false,
-                top_logprobs: None,
-                response_format: None,
-                tools: None,
-                tool_choice: None,
-                parallel_tool_calls: None,
-                functions: None,
-                function_call: None,
-                // SGLang extensions
-                top_k: None,
-                min_p: None,
-                min_tokens: None,
-                repetition_penalty: None,
-                regex: None,
-                ebnf: None,
-                stop_token_ids: None,
-                no_stop_trim: false,
-                ignore_eos: false,
-                continue_final_message: false,
-                skip_special_tokens: true,
-                lora_path: None,
-                session_params: None,
-                separate_reasoning: true,
-                stream_reasoning: true,
-                chat_template_kwargs: None,
-                return_hidden_states: false,
+                presence_penalty: Some(0.0),
+                // Use default for all other fields
+                ..Default::default()
             }
         }
@@ -938,19 +905,47 @@ mod tests {
         }
 
         #[test]
-        fn test_chat_conflicts() {
+        #[allow(deprecated)]
+        fn test_chat_cross_parameter_conflicts() {
             let mut request = create_valid_chat_request();
 
-            // Conflicting max_tokens
+            // Test 1: max_tokens vs max_completion_tokens conflict
             request.max_tokens = Some(100);
             request.max_completion_tokens = Some(200);
-            assert!(request.validate().is_err());
+            assert!(
+                request.validate().is_err(),
+                "Should reject both max_tokens and max_completion_tokens"
+            );
 
-            // Logprobs without top_logprobs
+            // Reset for next test
             request.max_tokens = None;
+            request.max_completion_tokens = None;
+
+            // Test 2: tools vs functions conflict (deprecated)
+            request.tools = Some(vec![]);
+            request.functions = Some(vec![]);
+            assert!(
+                request.validate().is_err(),
+                "Should reject both tools and functions"
+            );
+
+            // Test 3: logprobs=true without top_logprobs should be valid
+            let mut request = create_valid_chat_request();
             request.logprobs = true;
             request.top_logprobs = None;
-            assert!(request.validate().is_err());
+            assert!(
+                request.validate().is_ok(),
+                "logprobs=true without top_logprobs should be valid"
+            );
+
+            // Test 4: top_logprobs without logprobs=true should fail (OpenAI rule)
+            let mut request = create_valid_chat_request();
+            request.logprobs = false;
+            request.top_logprobs = Some(5);
+            assert!(
+                request.validate().is_err(),
+                "top_logprobs without logprobs=true should fail"
+            );
         }
 
         #[test]
@@ -1097,14 +1092,17 @@ mod tests {
         fn test_logprobs_validation() {
             let mut request = create_valid_chat_request();
 
-            // Valid logprobs configuration
+            // Valid logprobs configuration with top_logprobs
             request.logprobs = true;
             request.top_logprobs = Some(10);
             assert!(request.validate().is_ok());
 
-            // logprobs=true without top_logprobs should fail
+            // logprobs=true without top_logprobs should be valid (OpenAI behavior)
             request.top_logprobs = None;
-            assert!(request.validate().is_err());
+            assert!(
+                request.validate().is_ok(),
+                "logprobs=true without top_logprobs should be valid"
+            );
 
             // top_logprobs without logprobs=true should fail
             request.logprobs = false;
@@ -1137,6 +1135,7 @@ mod tests {
         }
 
         #[test]
+        #[allow(deprecated)]
        fn test_min_max_tokens_validation() {
            let mut request = create_valid_chat_request();