Unverified Commit 5d62b56f authored by Simo Lin, committed by GitHub

[router] complete router oai spec (#8828)

parent 3ae8e3ea
@@ -8,12 +8,116 @@ use sglang_router_rs::openai_api_types::{
};
use sglang_router_rs::routers::request_adapter::{RouteableRequest, ToPdRequest};
/// Create a default GenerateRequest for benchmarks with minimal fields set
fn default_generate_request() -> GenerateRequest {
GenerateRequest {
text: None,
prompt: None,
input_ids: None,
stream: false,
parameters: None,
sampling_params: None,
return_logprob: false,
// SGLang Extensions
lora_path: None,
session_params: None,
return_hidden_states: false,
rid: None,
}
}
/// Create a default ChatCompletionRequest for benchmarks with minimal fields set
fn default_chat_completion_request() -> ChatCompletionRequest {
ChatCompletionRequest {
model: String::new(),
messages: vec![],
max_tokens: None,
max_completion_tokens: None,
temperature: None,
top_p: None,
n: None,
stream: false,
stream_options: None,
stop: None,
presence_penalty: None,
frequency_penalty: None,
logit_bias: None,
logprobs: false,
top_logprobs: None,
user: None,
response_format: None,
seed: None,
tools: None,
tool_choice: None,
parallel_tool_calls: None,
function_call: None,
functions: None,
// SGLang Extensions
top_k: None,
min_p: None,
min_tokens: None,
repetition_penalty: None,
regex: None,
ebnf: None,
stop_token_ids: None,
no_stop_trim: false,
ignore_eos: false,
continue_final_message: false,
skip_special_tokens: true,
// SGLang Extensions
lora_path: None,
session_params: None,
separate_reasoning: true,
stream_reasoning: true,
return_hidden_states: false,
}
}
/// Create a default CompletionRequest for benchmarks with minimal fields set
fn default_completion_request() -> CompletionRequest {
CompletionRequest {
model: String::new(),
prompt: StringOrArray::String(String::new()),
suffix: None,
max_tokens: None,
temperature: None,
top_p: None,
n: None,
stream: false,
stream_options: None,
logprobs: None,
echo: false,
stop: None,
presence_penalty: None,
frequency_penalty: None,
best_of: None,
logit_bias: None,
user: None,
seed: None,
// SGLang Extensions
top_k: None,
min_p: None,
min_tokens: None,
repetition_penalty: None,
regex: None,
ebnf: None,
json_schema: None,
stop_token_ids: None,
no_stop_trim: false,
ignore_eos: false,
skip_special_tokens: true,
// SGLang Extensions
lora_path: None,
session_params: None,
return_hidden_states: false,
other: serde_json::Map::new(),
}
}
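// Illustrative use of the helpers above (a hedged sketch, not part of the original
// benchmarks): override only the fields a benchmark cares about and take the rest
// from the default via struct update syntax.
//
// let req = GenerateRequest {
//     text: Some("Hi".to_string()),
//     ..default_generate_request()
// };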
// Sample request data for benchmarks
fn create_sample_generate_request() -> GenerateRequest {
GenerateRequest {
text: Some("Write a story about artificial intelligence".to_string()),
parameters: Some(GenerateParameters {
max_new_tokens: Some(100),
temperature: Some(0.8),
@@ -31,8 +135,7 @@ fn create_sample_generate_request() -> GenerateRequest {
repetition_penalty: Some(1.0),
..Default::default()
}),
..default_generate_request()
}
}
@@ -58,22 +161,10 @@ fn create_sample_chat_completion_request() -> ChatCompletionRequest {
temperature: Some(0.7),
top_p: Some(1.0),
n: Some(1),
presence_penalty: Some(0.0),
frequency_penalty: Some(0.0),
parallel_tool_calls: Some(true),
..default_chat_completion_request()
}
}
@@ -81,23 +172,14 @@ fn create_sample_completion_request() -> CompletionRequest {
CompletionRequest {
model: "text-davinci-003".to_string(),
prompt: StringOrArray::String("Complete this sentence: The future of AI is".to_string()),
max_tokens: Some(50),
temperature: Some(0.8),
top_p: Some(1.0),
n: Some(1),
presence_penalty: Some(0.0),
frequency_penalty: Some(0.0),
best_of: Some(1),
..default_completion_request()
}
}
@@ -121,6 +203,7 @@ fn create_large_chat_completion_request() -> ChatCompletionRequest {
name: None,
tool_calls: None,
function_call: None,
reasoning_content: None,
});
}
@@ -132,22 +215,13 @@ fn create_large_chat_completion_request() -> ChatCompletionRequest {
temperature: Some(0.7),
top_p: Some(0.95),
n: Some(1),
presence_penalty: Some(0.1),
frequency_penalty: Some(0.1),
top_logprobs: Some(5),
user: Some("benchmark_user".to_string()),
seed: Some(42),
parallel_tool_calls: Some(true),
..default_chat_completion_request()
}
}
@@ -331,32 +405,17 @@ fn bench_throughput_by_size(c: &mut Criterion) {
// Create requests of different sizes
let small_generate = GenerateRequest {
text: Some("Hi".to_string()),
..default_generate_request()
};
let medium_generate = GenerateRequest {
text: Some("Write a medium length story about AI".repeat(10)),
..default_generate_request()
};
let large_generate = GenerateRequest {
text: Some("Write a very long and detailed story about artificial intelligence and its impact on society".repeat(100)),
..default_generate_request()
};
for (name, req) in [
...
@@ -6,6 +6,21 @@ use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::HashMap;
/// Helper function for serde default value
fn default_true() -> bool {
true
}
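// Hedged sketch of how this helper behaves with serde: a field declared as
// `#[serde(default = "default_true")] pub skip_special_tokens: bool` deserializes
// to `true` when the key is absent from the JSON, while an explicit
// `{"skip_special_tokens": false}` still overrides it. Plain `#[serde(default)]`
// would yield `false`, which is the wrong default for fields like this one.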
// ============= SGLang-Specific Types =============
/// LoRA adapter path - can be single path or batch of paths
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(untagged)]
pub enum LoRAPath {
Single(Option<String>),
Batch(Vec<Option<String>>),
}
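// Hedged illustration (not part of this commit) of the expected serde behavior of
// the untagged enum above; the module and test names here are hypothetical.
#[cfg(test)]
mod lora_path_untagged_examples {
use super::LoRAPath;
use serde_json::json;

#[test]
fn accepts_string_null_and_array() {
// A bare string matches the first variant: Single(Some(..)).
let single: LoRAPath = serde_json::from_value(json!("adapter_a")).unwrap();
assert!(matches!(single, LoRAPath::Single(Some(_))));

// JSON null also matches Single, as Single(None).
let none: LoRAPath = serde_json::from_value(json!(null)).unwrap();
assert!(matches!(none, LoRAPath::Single(None)));

// An array, possibly containing nulls, matches Batch.
let batch: LoRAPath = serde_json::from_value(json!(["adapter_a", null])).unwrap();
assert!(matches!(batch, LoRAPath::Batch(_)));
}
}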
/// Common trait for all generation requests
pub trait GenerationRequest: Send + Sync {
/// Check if the request is for streaming
@@ -92,6 +107,64 @@ pub struct CompletionRequest {
#[serde(skip_serializing_if = "Option::is_none")]
pub seed: Option<i64>,
// ============= SGLang Extensions =============
/// Top-k sampling parameter (-1 to disable)
#[serde(skip_serializing_if = "Option::is_none")]
pub top_k: Option<i32>,
/// Min-p nucleus sampling parameter
#[serde(skip_serializing_if = "Option::is_none")]
pub min_p: Option<f32>,
/// Minimum number of tokens to generate
#[serde(skip_serializing_if = "Option::is_none")]
pub min_tokens: Option<u32>,
/// Repetition penalty for reducing repetitive text
#[serde(skip_serializing_if = "Option::is_none")]
pub repetition_penalty: Option<f32>,
/// Regex constraint for output generation
#[serde(skip_serializing_if = "Option::is_none")]
pub regex: Option<String>,
/// EBNF grammar constraint for structured output
#[serde(skip_serializing_if = "Option::is_none")]
pub ebnf: Option<String>,
/// JSON schema constraint for structured output
#[serde(skip_serializing_if = "Option::is_none")]
pub json_schema: Option<String>,
/// Specific token IDs to use as stop conditions
#[serde(skip_serializing_if = "Option::is_none")]
pub stop_token_ids: Option<Vec<i32>>,
/// Skip trimming stop tokens from output
#[serde(default)]
pub no_stop_trim: bool,
/// Ignore end-of-sequence tokens during generation
#[serde(default)]
pub ignore_eos: bool,
/// Skip special tokens during detokenization
#[serde(default = "default_true")]
pub skip_special_tokens: bool,
// ============= SGLang Extensions =============
/// Path to LoRA adapter(s) for model customization
#[serde(skip_serializing_if = "Option::is_none")]
pub lora_path: Option<LoRAPath>,
/// Session parameters for continual prompting
#[serde(skip_serializing_if = "Option::is_none")]
pub session_params: Option<HashMap<String, serde_json::Value>>,
/// Return model hidden states
#[serde(default)]
pub return_hidden_states: bool,
/// Additional fields including bootstrap info for PD routing
#[serde(flatten)]
pub other: serde_json::Map<String, serde_json::Value>,
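// Hedged sketch of the flatten behavior above: keys not declared on the struct
// (for example bootstrap info injected for PD routing) are captured in `other`
// instead of failing deserialization, roughly:
//
// let req: CompletionRequest = serde_json::from_str(
//     r#"{"model":"m","prompt":"p","bootstrap_host":"10.0.0.1"}"#).unwrap();
// assert_eq!(req.other["bootstrap_host"], "10.0.0.1");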
@@ -166,7 +239,7 @@ pub struct ChatCompletionRequest {
/// Modify the likelihood of specified tokens appearing in the completion
#[serde(skip_serializing_if = "Option::is_none")]
pub logit_bias: Option<HashMap<String, f32>>,
/// A unique identifier representing your end-user
#[serde(skip_serializing_if = "Option::is_none")]
@@ -207,6 +280,72 @@ pub struct ChatCompletionRequest {
/// Deprecated: use tool_choice instead
#[serde(skip_serializing_if = "Option::is_none")]
pub function_call: Option<FunctionCall>,
// ============= SGLang Extensions =============
/// Top-k sampling parameter (-1 to disable)
#[serde(skip_serializing_if = "Option::is_none")]
pub top_k: Option<i32>,
/// Min-p nucleus sampling parameter
#[serde(skip_serializing_if = "Option::is_none")]
pub min_p: Option<f32>,
/// Minimum number of tokens to generate
#[serde(skip_serializing_if = "Option::is_none")]
pub min_tokens: Option<u32>,
/// Repetition penalty for reducing repetitive text
#[serde(skip_serializing_if = "Option::is_none")]
pub repetition_penalty: Option<f32>,
/// Regex constraint for output generation
#[serde(skip_serializing_if = "Option::is_none")]
pub regex: Option<String>,
/// EBNF grammar constraint for structured output
#[serde(skip_serializing_if = "Option::is_none")]
pub ebnf: Option<String>,
/// Specific token IDs to use as stop conditions
#[serde(skip_serializing_if = "Option::is_none")]
pub stop_token_ids: Option<Vec<i32>>,
/// Skip trimming stop tokens from output
#[serde(default)]
pub no_stop_trim: bool,
/// Ignore end-of-sequence tokens during generation
#[serde(default)]
pub ignore_eos: bool,
/// Continue generating from final assistant message
#[serde(default)]
pub continue_final_message: bool,
/// Skip special tokens during detokenization
#[serde(default = "default_true")]
pub skip_special_tokens: bool,
// ============= SGLang Extensions =============
/// Path to LoRA adapter(s) for model customization
#[serde(skip_serializing_if = "Option::is_none")]
pub lora_path: Option<LoRAPath>,
/// Session parameters for continual prompting
#[serde(skip_serializing_if = "Option::is_none")]
pub session_params: Option<HashMap<String, serde_json::Value>>,
/// Separate reasoning content from final answer (O1-style models)
#[serde(default = "default_true")]
pub separate_reasoning: bool,
/// Stream reasoning tokens during generation
#[serde(default = "default_true")]
pub stream_reasoning: bool,
/// Return model hidden states
#[serde(default)]
pub return_hidden_states: bool,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
@@ -234,6 +373,9 @@ pub enum ChatMessage {
tool_calls: Option<Vec<ToolCall>>,
#[serde(skip_serializing_if = "Option::is_none")]
function_call: Option<FunctionCallResponse>,
/// Reasoning content for O1-style models (SGLang extension)
#[serde(skip_serializing_if = "Option::is_none")]
reasoning_content: Option<String>,
},
Tool {
role: String, // "tool"
@@ -378,7 +520,20 @@ impl GenerationRequest for ChatCompletionRequest {
Some(texts.join(" "))
}
},
ChatMessage::Assistant {
content,
reasoning_content,
..
} => {
// Combine content and reasoning content for routing decisions
let main_content = content.clone().unwrap_or_default();
let reasoning = reasoning_content.clone().unwrap_or_default();
if main_content.is_empty() && reasoning.is_empty() {
None
} else {
Some(format!("{} {}", main_content, reasoning).trim().to_string())
}
}
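// Illustrative outcomes of the combination above (restating the code, for clarity):
//   content = Some("Paris"), reasoning_content = Some("The user asked...")
//     -> Some("Paris The user asked...")
//   content = Some("Paris"), reasoning_content = None -> Some("Paris")
//   content = None, reasoning_content = None -> None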
ChatMessage::Tool { content, .. } => Some(content.clone()),
ChatMessage::Function { content, .. } => Some(content.clone()),
})
@@ -418,6 +573,23 @@ pub struct GenerateRequest {
/// Whether to return logprobs
#[serde(default)]
pub return_logprob: bool,
// ============= SGLang Extensions =============
/// Path to LoRA adapter(s) for model customization
#[serde(skip_serializing_if = "Option::is_none")]
pub lora_path: Option<LoRAPath>,
/// Session parameters for continual prompting
#[serde(skip_serializing_if = "Option::is_none")]
pub session_params: Option<HashMap<String, serde_json::Value>>,
/// Return model hidden states
#[serde(default)]
pub return_hidden_states: bool,
/// Request ID for tracking
#[serde(skip_serializing_if = "Option::is_none")]
pub rid: Option<String>,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
@@ -485,6 +657,18 @@ pub struct SamplingParams {
pub skip_special_tokens: Option<bool>,
#[serde(skip_serializing_if = "Option::is_none")]
pub json_schema: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub regex: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub ebnf: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub min_p: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub min_tokens: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub stop_token_ids: Option<Vec<i32>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub no_stop_trim: Option<bool>,
}
impl GenerationRequest for GenerateRequest {
@@ -561,6 +745,12 @@ pub struct CompletionChoice {
#[serde(skip_serializing_if = "Option::is_none")]
pub logprobs: Option<LogProbs>,
pub finish_reason: Option<String>, // "stop", "length", "content_filter", etc.
/// Information about which stop condition was matched
#[serde(skip_serializing_if = "Option::is_none")]
pub matched_stop: Option<serde_json::Value>, // Can be string or integer
/// Hidden states from the model (SGLang extension)
#[serde(skip_serializing_if = "Option::is_none")]
pub hidden_states: Option<Vec<f32>>,
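// Hedged examples for the loosely-typed field above: matched_stop may arrive as a
// string (e.g. "\n\n" when a stop string fired) or as an integer (a stop token id),
// which is why serde_json::Value is used rather than a concrete type.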
}
#[derive(Debug, Clone, Deserialize, Serialize)]
@@ -591,6 +781,12 @@ pub struct ChatChoice {
#[serde(skip_serializing_if = "Option::is_none")]
pub logprobs: Option<ChatLogProbs>,
pub finish_reason: Option<String>, // "stop", "length", "tool_calls", "content_filter", "function_call"
/// Information about which stop condition was matched
#[serde(skip_serializing_if = "Option::is_none")]
pub matched_stop: Option<serde_json::Value>, // Can be string or integer
/// Hidden states from the model (SGLang extension)
#[serde(skip_serializing_if = "Option::is_none")]
pub hidden_states: Option<Vec<f32>>,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
@@ -681,6 +877,9 @@ pub struct ChatMessageDelta {
pub tool_calls: Option<Vec<ToolCallDelta>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub function_call: Option<FunctionCallDelta>,
/// Reasoning content delta for O1-style models (SGLang extension)
#[serde(skip_serializing_if = "Option::is_none")]
pub reasoning_content: Option<String>,
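// Illustrative stream shape (hedged, not from this commit): with stream_reasoning
// enabled, deltas carrying reasoning_content can arrive before the answer deltas:
//   {"delta": {"reasoning_content": "Let me check..."}}
//   {"delta": {"content": "The answer is 42."}}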
}
#[derive(Debug, Clone, Deserialize, Serialize)]
...
@@ -278,11 +278,11 @@ mod bootstrap_tests {
use crate::core::BasicWorker;
use crate::openai_api_types::StringOrArray;
/// Create a default CompletionRequest for testing with minimal fields set
fn default_completion_request() -> CompletionRequest {
CompletionRequest {
model: String::new(),
prompt: StringOrArray::String(String::new()),
n: None,
other: serde_json::Map::new(),
suffix: None,
@@ -300,6 +300,31 @@ mod bootstrap_tests {
logit_bias: None,
user: None,
seed: None,
// SGLang Extensions
top_k: None,
min_p: None,
min_tokens: None,
repetition_penalty: None,
regex: None,
ebnf: None,
json_schema: None,
stop_token_ids: None,
no_stop_trim: false,
ignore_eos: false,
skip_special_tokens: true,
// SGLang Extensions
lora_path: None,
session_params: None,
return_hidden_states: false,
}
}
#[test]
fn test_completion_batch_size_with_array_prompt() {
let req = CompletionRequest {
model: "test".to_string(),
prompt: StringOrArray::Array(vec!["prompt1".to_string(), "prompt2".to_string()]),
..default_completion_request()
};
// Should return batch size for array prompt
@@ -311,23 +336,7 @@ mod bootstrap_tests {
let req = CompletionRequest {
model: "test".to_string(),
prompt: StringOrArray::String("single prompt".to_string()),
..default_completion_request()
};
// Should return None for single prompt
@@ -340,22 +349,7 @@ mod bootstrap_tests {
model: "test".to_string(),
prompt: StringOrArray::String("single prompt".to_string()),
n: Some(3),
..default_completion_request()
};
// Should return None for single string prompt, even with n > 1
@@ -368,23 +362,7 @@ mod bootstrap_tests {
let mut req = CompletionRequest {
model: "test".to_string(),
prompt: StringOrArray::Array(vec!["prompt1".to_string(), "prompt2".to_string()]),
..default_completion_request()
};
// Set bootstrap info - should always use single values
@@ -418,23 +396,7 @@ mod bootstrap_tests {
let mut req = CompletionRequest {
model: "test".to_string(),
prompt: StringOrArray::Array(vec!["prompt1".to_string(), "prompt2".to_string()]),
..default_completion_request()
};
// Set bootstrap info with arrays
...
@@ -176,6 +176,33 @@ impl ToPdRequest for CompletionRequest {
self.stream => "stream"
);
// Add SGLang extension fields
insert_if_some!(other,
// SGLang Extensions - Priority 1
self.top_k => "top_k",
self.min_p => "min_p",
self.min_tokens => "min_tokens",
self.repetition_penalty => "repetition_penalty",
self.regex => "regex",
self.ebnf => "ebnf",
self.stop_token_ids => "stop_token_ids",
// SGLang Extensions - Priority 2
self.lora_path => "lora_path",
self.session_params => "session_params"
);
// SGLang boolean extensions (CompletionRequest has these as bool, not Option<bool>)
other.insert("no_stop_trim".to_string(), self.no_stop_trim.into());
other.insert("ignore_eos".to_string(), self.ignore_eos.into());
other.insert(
"skip_special_tokens".to_string(),
self.skip_special_tokens.into(),
);
other.insert(
"return_hidden_states".to_string(),
self.return_hidden_states.into(),
);
GenerateReqInput {
text,
input_ids: None,
@@ -226,14 +253,46 @@ impl ToPdRequest for ChatCompletionRequest {
self.tool_choice => "tool_choice",
self.parallel_tool_calls => "parallel_tool_calls",
self.functions => "functions",
self.function_call => "function_call",
// SGLang Extensions - Priority 1
self.top_k => "top_k",
self.min_p => "min_p",
self.min_tokens => "min_tokens",
self.repetition_penalty => "repetition_penalty",
self.regex => "regex",
self.ebnf => "ebnf",
self.stop_token_ids => "stop_token_ids",
// SGLang Extensions - Priority 2
self.lora_path => "lora_path",
self.session_params => "session_params"
);
// Handle boolean flags
if self.logprobs {
other.insert("logprobs".to_string(), true.into());
}
// SGLang boolean extensions (ChatCompletionRequest has these as bool, not Option<bool>)
other.insert("no_stop_trim".to_string(), self.no_stop_trim.into());
other.insert("ignore_eos".to_string(), self.ignore_eos.into());
other.insert(
"continue_final_message".to_string(),
self.continue_final_message.into(),
);
other.insert(
"skip_special_tokens".to_string(),
self.skip_special_tokens.into(),
);
other.insert(
"separate_reasoning".to_string(),
self.separate_reasoning.into(),
);
other.insert("stream_reasoning".to_string(), self.stream_reasoning.into());
other.insert(
"return_hidden_states".to_string(),
self.return_hidden_states.into(),
);
ChatReqInput {
stream: self.stream,
bootstrap_host: None,
@@ -271,18 +330,136 @@ mod tests {
use serde_json::json;
use std::collections::HashMap;
// ============= Test Helper Functions =============
//
// These helper functions create default request instances with all required SGLang extension fields
// properly initialized. Use the struct spread operator `..default_*_request()` to override only
// the fields you need for specific tests, avoiding repetitive boilerplate code.
//
// Example usage:
// let req = GenerateRequest {
//     text: Some("Custom text".to_string()),
//     stream: true,
//     ..default_generate_request()
// };
/// Create a default GenerateRequest with minimal fields set
fn default_generate_request() -> GenerateRequest {
GenerateRequest {
text: None,
prompt: None,
input_ids: None,
stream: false,
parameters: None,
sampling_params: None,
return_logprob: false,
// SGLang Extensions
lora_path: None,
session_params: None,
return_hidden_states: false,
rid: None,
}
}
/// Create a default CompletionRequest with minimal fields set
fn default_completion_request() -> CompletionRequest {
CompletionRequest {
model: "test-model".to_string(),
prompt: StringOrArray::String("test prompt".to_string()),
max_tokens: None,
temperature: None,
top_p: None,
n: None,
stream: false,
stream_options: None,
logprobs: None,
echo: false,
stop: None,
presence_penalty: None,
frequency_penalty: None,
best_of: None,
logit_bias: None,
user: None,
seed: None,
suffix: None,
// SGLang Extensions
top_k: None,
min_p: None,
min_tokens: None,
repetition_penalty: None,
regex: None,
ebnf: None,
json_schema: None,
stop_token_ids: None,
no_stop_trim: false,
ignore_eos: false,
skip_special_tokens: true,
// SGLang Extensions
lora_path: None,
session_params: None,
return_hidden_states: false,
other: serde_json::Map::new(),
}
}
/// Create a default ChatCompletionRequest with minimal fields set
fn default_chat_completion_request() -> ChatCompletionRequest {
ChatCompletionRequest {
model: "test-model".to_string(),
messages: vec![ChatMessage::User {
role: "user".to_string(),
content: UserMessageContent::Text("test message".to_string()),
name: None,
}],
temperature: None,
top_p: None,
n: None,
stream: false,
stream_options: None,
stop: None,
max_tokens: None,
max_completion_tokens: None,
presence_penalty: None,
frequency_penalty: None,
logit_bias: None,
logprobs: false,
top_logprobs: None,
user: None,
seed: None,
response_format: None,
tools: None,
tool_choice: None,
parallel_tool_calls: None,
functions: None,
function_call: None,
// SGLang Extensions
top_k: None,
min_p: None,
min_tokens: None,
repetition_penalty: None,
regex: None,
ebnf: None,
stop_token_ids: None,
no_stop_trim: false,
ignore_eos: false,
continue_final_message: false,
skip_special_tokens: true,
// SGLang Extensions
lora_path: None,
session_params: None,
separate_reasoning: true,
stream_reasoning: true,
return_hidden_states: false,
}
}
// ============= GenerateRequest to_pd_request Tests =============
#[test]
fn test_generate_to_pd_request_with_text_only() {
let req = GenerateRequest {
text: Some("Hello world".to_string()),
..default_generate_request()
};
let pd_req = req.to_pd_request();
@@ -308,13 +485,10 @@ mod tests {
#[test]
fn test_generate_to_pd_request_with_prompt_string() {
let req = GenerateRequest {
prompt: Some(StringOrArray::String("Test prompt".to_string())),
stream: true,
return_logprob: true,
..default_generate_request()
};
let pd_req = req.to_pd_request();
@@ -342,6 +516,7 @@ mod tests {
parameters: None,
sampling_params: None,
return_logprob: false,
..default_generate_request()
};
let pd_req = req.to_pd_request();
@@ -360,13 +535,8 @@ mod tests {
#[test]
fn test_generate_to_pd_request_with_single_input_ids() {
let req = GenerateRequest {
input_ids: Some(InputIds::Single(vec![100, 200, 300, 400])),
..default_generate_request()
};
let pd_req = req.to_pd_request();
@@ -381,17 +551,12 @@ mod tests {
#[test]
fn test_generate_to_pd_request_with_batch_input_ids() {
let req = GenerateRequest {
input_ids: Some(InputIds::Batch(vec![
vec![1, 2, 3],
vec![4, 5, 6, 7],
vec![8, 9],
])),
..default_generate_request()
};
let pd_req = req.to_pd_request();
@@ -413,10 +578,7 @@ mod tests {
text: Some("SGLang text".to_string()),
prompt: Some(StringOrArray::String("OpenAI prompt".to_string())),
input_ids: Some(InputIds::Single(vec![1, 2, 3])),
..default_generate_request()
};
let pd_req = req.to_pd_request();
@@ -429,13 +591,9 @@ mod tests {
#[test]
fn test_generate_to_pd_request_priority_prompt_over_input_ids() {
let req = GenerateRequest {
prompt: Some(StringOrArray::String("OpenAI prompt".to_string())),
input_ids: Some(InputIds::Single(vec![1, 2, 3])),
..default_generate_request()
};
let pd_req = req.to_pd_request();
@@ -459,12 +617,8 @@ mod tests {
let req = GenerateRequest {
text: Some("test".to_string()),
parameters: Some(params),
..default_generate_request()
};
let pd_req = req.to_pd_request();
@@ -497,12 +651,8 @@ mod tests {
let req = GenerateRequest {
text: Some("test".to_string()),
sampling_params: Some(sampling),
..default_generate_request()
};
let pd_req = req.to_pd_request();
@@ -546,6 +696,7 @@ mod tests {
parameters: Some(params),
sampling_params: Some(sampling),
return_logprob: false,
..default_generate_request()
};
let pd_req = req.to_pd_request();
@@ -568,6 +719,7 @@ mod tests {
parameters: Some(params),
sampling_params: None,
return_logprob: false,
..default_generate_request()
};
let pd_req = req.to_pd_request();
@@ -603,6 +755,7 @@ mod tests {
parameters: Some(params),
sampling_params: Some(sampling),
return_logprob: true,
..default_generate_request()
};
let pd_req = req.to_pd_request();
@@ -632,23 +785,7 @@ mod tests {
let req = CompletionRequest {
model: "gpt-3.5-turbo".to_string(),
prompt: StringOrArray::String("Complete this sentence".to_string()),
..default_completion_request()
};
let pd_req = req.to_pd_request();
@@ -672,23 +809,7 @@ mod tests {
"First prompt".to_string(),
"Second prompt".to_string(),
]),
..default_completion_request()
};
let pd_req = req.to_pd_request();
@@ -727,7 +848,7 @@ mod tests {
user: Some("user123".to_string()),
seed: Some(42),
suffix: Some("...".to_string()),
..default_completion_request()
};
let pd_req = req.to_pd_request();
@@ -771,7 +892,7 @@ mod tests {
user: None,
seed: None,
suffix: None,
..default_completion_request()
};
let pd_req = req.to_pd_request();
@@ -803,7 +924,7 @@ mod tests {
user: None,
seed: None,
suffix: None,
..default_completion_request()
};
let pd_req = req.to_pd_request();
@@ -834,27 +955,7 @@ mod tests {
let req = ChatCompletionRequest {
messages,
model: "gpt-4".to_string(),
..default_chat_completion_request()
};
let pd_req = req.to_pd_request();
@@ -883,7 +984,7 @@ mod tests {
}];
let mut logit_bias = HashMap::new();
logit_bias.insert("50256".to_string(), -100.0f32);
let tool = Tool {
tool_type: "function".to_string(),
@@ -920,6 +1021,7 @@ mod tests {
parallel_tool_calls: Some(false),
functions: None,
function_call: None,
..default_chat_completion_request()
};
let pd_req = req.to_pd_request();
@@ -968,27 +1070,7 @@ mod tests {
let req = ChatCompletionRequest {
messages,
model: "gpt-4-vision".to_string(),
..default_chat_completion_request()
};
let pd_req = req.to_pd_request();
@@ -1037,6 +1119,7 @@ mod tests {
parallel_tool_calls: None,
functions: None,
function_call: None,
..default_chat_completion_request()
};
let pd_req = req.to_pd_request();
@@ -1054,32 +1137,13 @@ mod tests {
name: None,
tool_calls: None,
function_call: None,
reasoning_content: None,
}];
let req = ChatCompletionRequest {
messages,
model: "gpt-3.5-turbo".to_string(),
..default_chat_completion_request()
};
let pd_req = req.to_pd_request();
@@ -1101,12 +1165,7 @@ mod tests {
fn test_routeable_request_to_json() {
let req = GenerateRequest {
text: Some("test".to_string()),
..default_generate_request()
};
let json = req.to_json().unwrap();
@@ -1166,6 +1225,7 @@ mod tests {
parameters: Some(params),
sampling_params: None,
return_logprob: false,
..default_generate_request()
};
let pd_req = req.to_pd_request();
@@ -1187,6 +1247,7 @@ mod tests {
parameters: None,
sampling_params: None,
return_logprob: false,
..default_generate_request()
};
let pd_req = req.to_pd_request();
@@ -1206,12 +1267,7 @@ mod tests {
let req = GenerateRequest {
text: Some(unicode_text.clone()),
..default_generate_request()
};
let pd_req = req.to_pd_request();
@@ -1250,6 +1306,7 @@ mod tests {
parameters: Some(params),
sampling_params: None,
return_logprob: false,
..default_generate_request()
};
let pd_req = req.to_pd_request();
@@ -1265,12 +1322,7 @@ mod tests {
fn test_bootstrap_fields_none() {
let req = GenerateRequest {
text: Some("test".to_string()),
..default_generate_request()
};
let pd_req = req.to_pd_request();
@@ -1279,4 +1331,182 @@ mod tests {
assert_eq!(pd_req.bootstrap_port, None);
assert_eq!(pd_req.bootstrap_room, None);
}
// ============= SGLang Extension Field Pass-Through Tests =============
#[test]
fn test_chat_completion_sglang_extensions_passed_through() {
let messages = vec![ChatMessage::User {
role: "user".to_string(),
content: UserMessageContent::Text("Test".to_string()),
name: None,
}];
let mut session_params = std::collections::HashMap::new();
session_params.insert(
"key".to_string(),
serde_json::Value::String("value".to_string()),
);
let req = ChatCompletionRequest {
messages,
model: "test-model".to_string(),
// SGLang Extensions - Priority 1
top_k: Some(40),
min_p: Some(0.05),
min_tokens: Some(10),
repetition_penalty: Some(1.1),
regex: Some("test_regex".to_string()),
ebnf: Some("test_ebnf".to_string()),
stop_token_ids: Some(vec![1, 2, 3]),
// SGLang Extensions - Priority 2
lora_path: Some(LoRAPath::Single(Some("test_lora.bin".to_string()))),
session_params: Some(session_params.clone()),
// Boolean extensions (ChatCompletionRequest has these as bool, not Option<bool>)
no_stop_trim: true,
ignore_eos: false,
continue_final_message: true,
skip_special_tokens: false,
separate_reasoning: true,
stream_reasoning: false,
return_hidden_states: true,
..default_chat_completion_request()
};
let pd_req = req.to_pd_request();
let other = pd_req.other.as_object().unwrap();
// Verify SGLang extensions are passed through
assert_eq!(other.get("top_k"), Some(&json!(40)));
assert!((other.get("min_p").unwrap().as_f64().unwrap() - 0.05).abs() < 0.0001);
assert_eq!(other.get("min_tokens"), Some(&json!(10)));
assert!((other.get("repetition_penalty").unwrap().as_f64().unwrap() - 1.1).abs() < 0.0001);
assert_eq!(other.get("regex"), Some(&json!("test_regex")));
assert_eq!(other.get("ebnf"), Some(&json!("test_ebnf")));
assert_eq!(other.get("stop_token_ids"), Some(&json!(vec![1, 2, 3])));
assert_eq!(other.get("lora_path"), Some(&json!("test_lora.bin")));
assert_eq!(
other.get("session_params"),
Some(&serde_json::to_value(&session_params).unwrap())
);
// Verify boolean extensions
assert_eq!(other.get("no_stop_trim"), Some(&json!(true)));
assert_eq!(other.get("ignore_eos"), Some(&json!(false)));
assert_eq!(other.get("continue_final_message"), Some(&json!(true)));
assert_eq!(other.get("skip_special_tokens"), Some(&json!(false)));
assert_eq!(other.get("separate_reasoning"), Some(&json!(true)));
assert_eq!(other.get("stream_reasoning"), Some(&json!(false)));
assert_eq!(other.get("return_hidden_states"), Some(&json!(true)));
}
#[test]
fn test_completion_request_sglang_extensions_passed_through() {
let mut session_params = std::collections::HashMap::new();
session_params.insert(
"key".to_string(),
serde_json::Value::String("value".to_string()),
);
let req = CompletionRequest {
prompt: StringOrArray::String("Test prompt".to_string()),
model: "test-model".to_string(),
// SGLang Extensions - Priority 1
top_k: Some(40),
min_p: Some(0.05),
min_tokens: Some(10),
repetition_penalty: Some(1.1),
regex: Some("test_regex".to_string()),
ebnf: Some("test_ebnf".to_string()),
stop_token_ids: Some(vec![1, 2, 3]),
// SGLang Extensions - Priority 2
lora_path: Some(LoRAPath::Single(Some("test_lora.bin".to_string()))),
session_params: Some(session_params.clone()),
// Boolean extensions (CompletionRequest only has these 4 boolean fields)
no_stop_trim: true,
ignore_eos: false,
skip_special_tokens: false,
return_hidden_states: true,
..default_completion_request()
};
let pd_req = req.to_pd_request();
let other = pd_req.other.as_object().unwrap();
// Verify SGLang extensions are passed through
assert_eq!(other.get("top_k"), Some(&json!(40)));
assert!((other.get("min_p").unwrap().as_f64().unwrap() - 0.05).abs() < 0.0001);
assert_eq!(other.get("min_tokens"), Some(&json!(10)));
assert!((other.get("repetition_penalty").unwrap().as_f64().unwrap() - 1.1).abs() < 0.0001);
assert_eq!(other.get("regex"), Some(&json!("test_regex")));
assert_eq!(other.get("ebnf"), Some(&json!("test_ebnf")));
assert_eq!(other.get("stop_token_ids"), Some(&json!(vec![1, 2, 3])));
assert_eq!(other.get("lora_path"), Some(&json!("test_lora.bin")));
assert_eq!(
other.get("session_params"),
Some(&serde_json::to_value(&session_params).unwrap())
);
// Verify boolean extensions (only the ones CompletionRequest has)
assert_eq!(other.get("no_stop_trim"), Some(&json!(true)));
assert_eq!(other.get("ignore_eos"), Some(&json!(false)));
assert_eq!(other.get("skip_special_tokens"), Some(&json!(false)));
assert_eq!(other.get("return_hidden_states"), Some(&json!(true)));
}
#[test]
fn test_sglang_extensions_none_values_not_passed_through() {
let messages = vec![ChatMessage::User {
role: "user".to_string(),
content: UserMessageContent::Text("Test".to_string()),
name: None,
}];
let req = ChatCompletionRequest {
messages,
model: "test-model".to_string(),
// All SGLang extensions as None/default - Optional fields won't appear, bools will use defaults
top_k: None,
min_p: None,
min_tokens: None,
repetition_penalty: None,
regex: None,
ebnf: None,
stop_token_ids: None,
lora_path: None,
session_params: None,
// Boolean fields use defaults (false for most, true for some with default_true)
no_stop_trim: false,
ignore_eos: false,
continue_final_message: false,
skip_special_tokens: true, // This has default_true
separate_reasoning: true, // This has default_true
stream_reasoning: true, // This has default_true
return_hidden_states: false,
..default_chat_completion_request()
};
let pd_req = req.to_pd_request();
let other = pd_req.other.as_object().unwrap();
// Verify None values are not included
assert!(!other.contains_key("top_k"));
assert!(!other.contains_key("min_p"));
assert!(!other.contains_key("min_tokens"));
assert!(!other.contains_key("repetition_penalty"));
assert!(!other.contains_key("regex"));
assert!(!other.contains_key("ebnf"));
assert!(!other.contains_key("stop_token_ids"));
assert!(!other.contains_key("lora_path"));
assert!(!other.contains_key("session_params"));
// Boolean fields are always present with their values (can't be None)
assert_eq!(other.get("no_stop_trim"), Some(&json!(false)));
assert_eq!(other.get("ignore_eos"), Some(&json!(false)));
assert_eq!(other.get("continue_final_message"), Some(&json!(false)));
assert_eq!(other.get("skip_special_tokens"), Some(&json!(true))); // default_true
assert_eq!(other.get("separate_reasoning"), Some(&json!(true))); // default_true
assert_eq!(other.get("stream_reasoning"), Some(&json!(true))); // default_true
assert_eq!(other.get("return_hidden_states"), Some(&json!(false)));
}
}
@@ -8,14 +8,118 @@ use sglang_router_rs::openai_api_types::{
};
use sglang_router_rs::routers::request_adapter::{RouteableRequest, ToPdRequest};
/// Create a default GenerateRequest for benchmarks with minimal fields set
fn default_generate_request() -> GenerateRequest {
GenerateRequest {
text: None,
prompt: None,
input_ids: None,
stream: false,
parameters: None,
sampling_params: None,
return_logprob: false,
// SGLang Extensions
lora_path: None,
session_params: None,
return_hidden_states: false,
rid: None,
}
}
/// Create a default ChatCompletionRequest for benchmarks with minimal fields set
fn default_chat_completion_request() -> ChatCompletionRequest {
ChatCompletionRequest {
model: String::new(),
messages: vec![],
max_tokens: None,
max_completion_tokens: None,
temperature: None,
top_p: None,
n: None,
stream: false,
stream_options: None,
stop: None,
presence_penalty: None,
frequency_penalty: None,
logit_bias: None,
logprobs: false,
top_logprobs: None,
user: None,
response_format: None,
seed: None,
tools: None,
tool_choice: None,
parallel_tool_calls: None,
function_call: None,
functions: None,
// SGLang Extensions
top_k: None,
min_p: None,
min_tokens: None,
repetition_penalty: None,
regex: None,
ebnf: None,
stop_token_ids: None,
no_stop_trim: false,
ignore_eos: false,
continue_final_message: false,
skip_special_tokens: true,
// SGLang Extensions
lora_path: None,
session_params: None,
separate_reasoning: true,
stream_reasoning: true,
return_hidden_states: false,
}
}
/// Create a default CompletionRequest for benchmarks with minimal fields set
fn default_completion_request() -> CompletionRequest {
CompletionRequest {
model: String::new(),
prompt: StringOrArray::String(String::new()),
suffix: None,
max_tokens: None,
temperature: None,
top_p: None,
n: None,
stream: false,
stream_options: None,
logprobs: None,
echo: false,
stop: None,
presence_penalty: None,
frequency_penalty: None,
best_of: None,
logit_bias: None,
user: None,
seed: None,
// SGLang Extensions
top_k: None,
min_p: None,
min_tokens: None,
repetition_penalty: None,
regex: None,
ebnf: None,
json_schema: None,
stop_token_ids: None,
no_stop_trim: false,
ignore_eos: false,
skip_special_tokens: true,
// SGLang Extensions
lora_path: None,
session_params: None,
return_hidden_states: false,
other: serde_json::Map::new(),
}
}
#[test]
fn test_benchmark_request_creation() {
// Ensure all benchmark request types can be created without panicking
let generate_req = GenerateRequest {
text: Some("Test prompt".to_string()),
parameters: Some(GenerateParameters {
max_new_tokens: Some(100),
temperature: Some(0.8),
@@ -33,8 +137,7 @@ fn test_benchmark_request_creation() {
repetition_penalty: Some(1.0),
..Default::default()
}),
..default_generate_request()
};
let chat_req = ChatCompletionRequest {
@@ -49,44 +152,23 @@ fn test_benchmark_request_creation() {
temperature: Some(0.7),
top_p: Some(1.0),
n: Some(1),
presence_penalty: Some(0.0),
frequency_penalty: Some(0.0),
parallel_tool_calls: Some(true),
..default_chat_completion_request()
};
let completion_req = CompletionRequest {
model: "test-model".to_string(),
prompt: StringOrArray::String("Test prompt".to_string()),
max_tokens: Some(50),
temperature: Some(0.8),
top_p: Some(1.0),
n: Some(1),
presence_penalty: Some(0.0),
frequency_penalty: Some(0.0),
best_of: Some(1),
..default_completion_request()
};
// Test serialization works
@@ -101,12 +183,7 @@ fn test_benchmark_serialization_roundtrip() {
let generate_req = GenerateRequest {
text: Some("Test prompt".to_string()),
..default_generate_request()
};
// Serialize and deserialize
@@ -125,12 +202,7 @@ fn test_benchmark_request_adaptation() {
let generate_req = GenerateRequest {
text: Some("Test prompt".to_string()),
..default_generate_request()
};
let chat_req = ChatCompletionRequest {
@@ -145,44 +217,23 @@ fn test_benchmark_request_adaptation() {
temperature: Some(0.7),
top_p: Some(1.0),
n: Some(1),
presence_penalty: Some(0.0),
frequency_penalty: Some(0.0),
parallel_tool_calls: Some(true),
..default_chat_completion_request()
};
let completion_req = CompletionRequest {
model: "test-model".to_string(),
prompt: StringOrArray::String("Test prompt".to_string()),
max_tokens: Some(50),
temperature: Some(0.8),
top_p: Some(1.0),
n: Some(1),
presence_penalty: Some(0.0),
frequency_penalty: Some(0.0),
best_of: Some(1),
..default_completion_request()
};
// Test PD adaptation (should not panic)
@@ -197,12 +248,7 @@ fn test_benchmark_regular_routing() {
let generate_req = GenerateRequest {
text: Some("Test prompt".to_string()),
..default_generate_request()
};
// Test regular routing methods (should not panic)
@@ -217,12 +263,7 @@ fn test_benchmark_performance_baseline() {
let generate_req = GenerateRequest {
text: Some("Short test prompt".to_string()),
..default_generate_request()
};
// Serialization should be fast (< 1ms for simple requests)
...