"vscode:/vscode.git/clone" did not exist on "e5ddc62b3715f2f8cb1bcaa5d327dcc16ff2afa0"
Unverified commit d2478cd4, authored by Keyang Ru and committed by GitHub

[router] Fix response api related spec (#11621)

parent 30ea4c46
@@ -69,6 +69,8 @@ fn generate_request_id(path: &str) -> String {
         "cmpl-"
     } else if path.contains("/generate") {
         "gnt-"
+    } else if path.contains("/responses") {
+        "resp-"
     } else {
         "req-"
     };
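Note: with this hunk, requests to the Responses API endpoints get IDs carrying a "resp-" prefix, matching the existing "cmpl-"/"gnt-"/"req-" convention. A minimal sketch of how a full ID could be assembled from that prefix (illustrative only; the rest of generate_request_id is not shown in this diff, and the UUID-based suffix is an assumption):

    // Hypothetical helper, not the router's actual code: pick a prefix by path
    // and append a v4 UUID, e.g. "resp-550e8400e29b41d4a716446655440000".
    fn make_request_id(path: &str) -> String {
        let prefix = if path.contains("/responses") {
            "resp-"
        } else {
            "req-"
        };
        format!("{}{}", prefix, uuid::Uuid::new_v4().simple())
    }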
...
 use serde::{Deserialize, Serialize};
-use serde_json::{to_value, Map, Number, Value};
+use serde_json::{to_value, Map, Value};
 use std::collections::HashMap;
 use validator::Validate;
@@ -1325,10 +1325,6 @@ impl ResponsesUsage {
     }
 }
 
-fn generate_request_id() -> String {
-    format!("resp_{}", uuid::Uuid::new_v4().simple())
-}
-
 #[derive(Debug, Clone, Deserialize, Serialize)]
 pub struct ResponsesRequest {
     /// Run the request in the background
@@ -1419,8 +1415,8 @@ pub struct ResponsesRequest {
     pub user: Option<String>,
 
     /// Request ID
-    #[serde(default = "generate_request_id")]
-    pub request_id: String,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub request_id: Option<String>,
 
     /// Request priority
     #[serde(default)]
@@ -1438,15 +1434,15 @@ pub struct ResponsesRequest {
     #[serde(skip_serializing_if = "Option::is_none")]
     pub stop: Option<StringOrArray>,
 
-    /// Top-k sampling parameter
+    /// Top-k sampling parameter (SGLang extension)
     #[serde(default = "default_top_k")]
     pub top_k: i32,
 
-    /// Min-p sampling parameter
+    /// Min-p sampling parameter (SGLang extension)
     #[serde(default)]
     pub min_p: f32,
 
-    /// Repetition penalty
+    /// Repetition penalty (SGLang extension)
     #[serde(default = "default_repetition_penalty")]
     pub repetition_penalty: f32,
 }
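Note: request_id is now Option<String> with skip_serializing_if, so an unset ID is simply omitted when the request is re-serialized rather than being filled with a locally generated value. A small sketch of the serde behavior this relies on (struct and field names here are illustrative, not taken from spec.rs):

    use serde::{Deserialize, Serialize};

    #[derive(Serialize, Deserialize)]
    struct Req {
        #[serde(skip_serializing_if = "Option::is_none")]
        request_id: Option<String>,
    }

    // With None the field disappears from the JSON; with Some it is emitted:
    // serde_json::to_string(&Req { request_id: None }).unwrap() == "{}"
    // serde_json::to_string(&Req { request_id: Some("resp-1".into()) }).unwrap()
    //     == r#"{"request_id":"resp-1"}"#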
@@ -1491,7 +1487,7 @@ impl Default for ResponsesRequest {
             top_p: None,
             truncation: None,
             user: None,
-            request_id: generate_request_id(),
+            request_id: None,
             priority: 0,
             frequency_penalty: None,
             presence_penalty: None,
@@ -1503,101 +1499,6 @@ impl Default for ResponsesRequest {
     }
 }
 
-impl ResponsesRequest {
-    /// Default sampling parameters
-    const DEFAULT_TEMPERATURE: f32 = 0.7;
-    const DEFAULT_TOP_P: f32 = 1.0;
-
-    /// Convert to sampling parameters for generation
-    pub fn to_sampling_params(
-        &self,
-        default_max_tokens: u32,
-        default_params: Option<HashMap<String, Value>>,
-    ) -> HashMap<String, Value> {
-        let mut params = HashMap::new();
-
-        // Use max_output_tokens if available
-        let max_tokens = if let Some(max_output) = self.max_output_tokens {
-            std::cmp::min(max_output, default_max_tokens)
-        } else {
-            default_max_tokens
-        };
-
-        // Avoid exceeding context length by minus 1 token
-        let max_tokens = max_tokens.saturating_sub(1);
-
-        // Temperature
-        let temperature = self.temperature.unwrap_or_else(|| {
-            default_params
-                .as_ref()
-                .and_then(|p| p.get("temperature"))
-                .and_then(|v| v.as_f64())
-                .map(|v| v as f32)
-                .unwrap_or(Self::DEFAULT_TEMPERATURE)
-        });
-
-        // Top-p
-        let top_p = self.top_p.unwrap_or_else(|| {
-            default_params
-                .as_ref()
-                .and_then(|p| p.get("top_p"))
-                .and_then(|v| v.as_f64())
-                .map(|v| v as f32)
-                .unwrap_or(Self::DEFAULT_TOP_P)
-        });
-
-        params.insert(
-            "max_new_tokens".to_string(),
-            Value::Number(Number::from(max_tokens)),
-        );
-        params.insert(
-            "temperature".to_string(),
-            Value::Number(Number::from_f64(temperature as f64).unwrap()),
-        );
-        params.insert(
-            "top_p".to_string(),
-            Value::Number(Number::from_f64(top_p as f64).unwrap()),
-        );
-
-        if let Some(fp) = self.frequency_penalty {
-            params.insert(
-                "frequency_penalty".to_string(),
-                Value::Number(Number::from_f64(fp as f64).unwrap()),
-            );
-        }
-        if let Some(pp) = self.presence_penalty {
-            params.insert(
-                "presence_penalty".to_string(),
-                Value::Number(Number::from_f64(pp as f64).unwrap()),
-            );
-        }
-
-        params.insert("top_k".to_string(), Value::Number(Number::from(self.top_k)));
-        params.insert(
-            "min_p".to_string(),
-            Value::Number(Number::from_f64(self.min_p as f64).unwrap()),
-        );
-        params.insert(
-            "repetition_penalty".to_string(),
-            Value::Number(Number::from_f64(self.repetition_penalty as f64).unwrap()),
-        );
-
-        if let Some(ref stop) = self.stop {
-            match to_value(stop) {
-                Ok(value) => params.insert("stop".to_string(), value),
-                Err(_) => params.insert("stop".to_string(), Value::Null),
-            };
-        }
-
-        // Apply any additional default parameters
-        if let Some(default_params) = default_params {
-            for (key, value) in default_params {
-                params.entry(key).or_insert(value);
-            }
-        }
-
-        params
-    }
-}
-
 impl GenerationRequest for ResponsesRequest {
     fn is_stream(&self) -> bool {
         self.stream.unwrap_or(false)
@@ -1776,7 +1677,10 @@ impl ResponsesResponse {
         usage: Option<UsageInfo>,
     ) -> Self {
         Self {
-            id: request.request_id.clone(),
+            id: request
+                .request_id
+                .clone()
+                .expect("request_id should be set by middleware"),
             object: "response".to_string(),
             created_at: created_time,
             status,
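Note: because request_id is now optional on the request, the response constructor unwraps it with expect, i.e. the spec assumes an earlier layer (per the expect message, a middleware) has already populated the ID, and building a ResponsesResponse without one is treated as a programming error. A caller that preferred not to panic could fall back to generating an ID at the call site, along these lines (illustrative only, not part of this change):

    let id = request
        .request_id
        .clone()
        .unwrap_or_else(|| format!("resp-{}", uuid::Uuid::new_v4().simple()));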
@@ -2535,9 +2439,6 @@ pub enum GenerateFinishReason {
     Other(Value),
 }
 
-// Constants for rerank API
-pub const DEFAULT_MODEL_NAME: &str = "default";
-
 /// Rerank request for scoring documents against a query
 /// Used for RAG systems and document relevance scoring
 #[derive(Debug, Clone, Serialize, Deserialize)]
@@ -2549,7 +2450,7 @@ pub struct RerankRequest {
     pub documents: Vec<String>,
 
     /// Model to use for reranking
-    #[serde(default = "default_model_name")]
+    #[serde(default = "default_model")]
     pub model: String,
 
     /// Maximum number of documents to return (optional)
@@ -2567,10 +2468,6 @@ pub struct RerankRequest {
     pub user: Option<String>,
 }
 
-pub fn default_model_name() -> String {
-    DEFAULT_MODEL_NAME.to_string()
-}
-
 fn default_return_documents() -> bool {
     true
 }
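Note: the rerank-specific DEFAULT_MODEL_NAME constant and default_model_name() helper are dropped in favor of a shared default_model() function used by the serde default above. That function is not shown in this diff; judging from the updated test assertions below, it resolves to "unknown", so its shape is presumably something like:

    // Assumed shape, inferred from the tests asserting model == "unknown".
    fn default_model() -> String {
        "unknown".to_string()
    }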
@@ -2634,7 +2531,7 @@ impl From<V1RerankReqInput> for RerankRequest {
         RerankRequest {
             query: v1.query,
             documents: v1.documents,
-            model: default_model_name(),
+            model: default_model(),
             top_k: None,
             return_documents: true,
             rid: None,
...
@@ -2156,7 +2156,7 @@ mod rerank_tests {
     assert!(body_json.get("model").is_some());
 
     // V1 API should use default model name
-    assert_eq!(body_json["model"], "default");
+    assert_eq!(body_json["model"], "unknown");
 
     let results = body_json["results"].as_array().unwrap();
     assert_eq!(results.len(), 3); // All documents should be returned
...
@@ -115,7 +115,7 @@ async fn test_non_streaming_mcp_minimal_e2e_with_persistence() {
         top_p: None,
         truncation: Some(Truncation::Disabled),
         user: None,
-        request_id: "resp_test_mcp_e2e".to_string(),
+        request_id: Some("resp_test_mcp_e2e".to_string()),
         priority: 0,
         frequency_penalty: Some(0.0),
         presence_penalty: Some(0.0),
@@ -361,7 +361,7 @@ fn test_responses_request_creation() {
         top_p: Some(0.9),
         truncation: Some(Truncation::Disabled),
         user: Some("test-user".to_string()),
-        request_id: "resp_test123".to_string(),
+        request_id: Some("resp_test123".to_string()),
         priority: 0,
         frequency_penalty: Some(0.0),
        presence_penalty: Some(0.0),
@@ -379,7 +379,8 @@ fn test_responses_request_creation() {
 }
 
 #[test]
-fn test_sampling_params_conversion() {
+fn test_responses_request_sglang_extensions() {
+    // Test that SGLang-specific sampling parameters are present and serializable
     let request = ResponsesRequest {
         background: Some(false),
         include: None,
@@ -389,37 +390,44 @@ fn test_sampling_params_conversion() {
         max_tool_calls: None,
         metadata: None,
         model: Some("test-model".to_string()),
-        parallel_tool_calls: Some(true), // Use default true
+        parallel_tool_calls: Some(true),
         previous_response_id: None,
         reasoning: None,
         service_tier: Some(ServiceTier::Auto),
-        store: Some(true), // Use default true
+        store: Some(true),
         stream: Some(false),
         temperature: Some(0.8),
         tool_choice: Some(ToolChoice::Value(ToolChoiceValue::Auto)),
         tools: Some(vec![]),
-        top_logprobs: Some(0), // Use default 0
+        top_logprobs: Some(0),
         top_p: Some(0.95),
         truncation: Some(Truncation::Auto),
         user: None,
-        request_id: "resp_test456".to_string(),
+        request_id: Some("resp_test456".to_string()),
         priority: 0,
         frequency_penalty: Some(0.1),
         presence_penalty: Some(0.2),
         stop: None,
+        // SGLang-specific extensions:
         top_k: 10,
         min_p: 0.05,
         repetition_penalty: 1.1,
         conversation: None,
     };
 
-    let params = request.to_sampling_params(1000, None);
+    // Verify SGLang extensions are present
+    assert_eq!(request.top_k, 10);
+    assert_eq!(request.min_p, 0.05);
+    assert_eq!(request.repetition_penalty, 1.1);
 
-    // Check that parameters are converted correctly
-    assert!(params.contains_key("temperature"));
-    assert!(params.contains_key("top_p"));
-    assert!(params.contains_key("frequency_penalty"));
-    assert!(params.contains_key("max_new_tokens"));
+    // Verify serialization works with SGLang extensions
+    let json = serde_json::to_string(&request).expect("Serialization should work");
+    let parsed: ResponsesRequest =
+        serde_json::from_str(&json).expect("Deserialization should work");
+
+    assert_eq!(parsed.top_k, 10);
+    assert_eq!(parsed.min_p, 0.05);
+    assert_eq!(parsed.repetition_penalty, 1.1);
 }
 
 #[test]
@@ -516,7 +524,7 @@ fn test_json_serialization() {
         top_p: Some(0.8),
         truncation: Some(Truncation::Auto),
         user: Some("test_user".to_string()),
-        request_id: "resp_comprehensive_test".to_string(),
+        request_id: Some("resp_comprehensive_test".to_string()),
         priority: 1,
         frequency_penalty: Some(0.3),
         presence_penalty: Some(0.4),
@@ -531,7 +539,10 @@ fn test_json_serialization() {
     let parsed: ResponsesRequest =
         serde_json::from_str(&json).expect("Deserialization should work");
 
-    assert_eq!(parsed.request_id, "resp_comprehensive_test");
+    assert_eq!(
+        parsed.request_id,
+        Some("resp_comprehensive_test".to_string())
+    );
     assert_eq!(parsed.model, Some("gpt-4".to_string()));
     assert_eq!(parsed.background, Some(true));
     assert_eq!(parsed.stream, Some(true));
@@ -643,7 +654,7 @@ async fn test_multi_turn_loop_with_mcp() {
         top_p: Some(1.0),
         truncation: Some(Truncation::Disabled),
         user: None,
-        request_id: "resp_multi_turn_test".to_string(),
+        request_id: Some("resp_multi_turn_test".to_string()),
         priority: 0,
         frequency_penalty: Some(0.0),
         presence_penalty: Some(0.0),
@@ -816,7 +827,7 @@ async fn test_max_tool_calls_limit() {
         top_p: Some(1.0),
         truncation: Some(Truncation::Disabled),
         user: None,
-        request_id: "resp_max_calls_test".to_string(),
+        request_id: Some("resp_max_calls_test".to_string()),
         priority: 0,
         frequency_penalty: Some(0.0),
         presence_penalty: Some(0.0),
@@ -1011,7 +1022,7 @@ async fn test_streaming_with_mcp_tool_calls() {
         top_p: Some(1.0),
         truncation: Some(Truncation::Disabled),
         user: None,
-        request_id: "resp_streaming_mcp_test".to_string(),
+        request_id: Some("resp_streaming_mcp_test".to_string()),
         priority: 0,
         frequency_penalty: Some(0.0),
         presence_penalty: Some(0.0),
@@ -1290,7 +1301,7 @@ async fn test_streaming_multi_turn_with_mcp() {
         top_p: Some(1.0),
         truncation: Some(Truncation::Disabled),
         user: None,
-        request_id: "resp_streaming_multiturn_test".to_string(),
+        request_id: Some("resp_streaming_multiturn_test".to_string()),
         priority: 0,
         frequency_penalty: Some(0.0),
         presence_penalty: Some(0.0),
...
 use serde_json::{from_str, to_string, Number, Value};
 use sglang_router_rs::protocols::spec::{
-    default_model_name, GenerationRequest, RerankRequest, RerankResponse, RerankResult,
-    StringOrArray, UsageInfo, V1RerankReqInput,
+    GenerationRequest, RerankRequest, RerankResponse, RerankResult, StringOrArray, UsageInfo,
+    V1RerankReqInput,
 };
 use std::collections::HashMap;
@@ -40,7 +40,7 @@ fn test_rerank_request_deserialization_with_defaults() {
     assert_eq!(request.query, "test query");
     assert_eq!(request.documents, vec!["doc1", "doc2"]);
-    assert_eq!(request.model, default_model_name());
+    assert_eq!(request.model, "unknown");
     assert_eq!(request.top_k, None);
     assert!(request.return_documents);
     assert_eq!(request.rid, None);
@@ -414,7 +414,7 @@ fn test_v1_to_rerank_request_conversion() {
     assert_eq!(request.query, "test query");
     assert_eq!(request.documents, vec!["doc1", "doc2"]);
-    assert_eq!(request.model, default_model_name());
+    assert_eq!(request.model, "unknown");
     assert_eq!(request.top_k, None);
     assert!(request.return_documents);
     assert_eq!(request.rid, None);
...