"vscode:/vscode.git/clone" did not exist on "782527d48195606e1045bdfa9e143e8793e29c9d"
Unverified commit a69b6370, authored by Simo Lin and committed by GitHub
Browse files

[router] fix req handling order, improve serialization, remove retry (#8888)

parent 2d120f8b
...@@ -2,12 +2,12 @@ use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criteri ...@@ -2,12 +2,12 @@ use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criteri
use serde_json::{from_str, to_string, to_value, to_vec}; use serde_json::{from_str, to_string, to_value, to_vec};
use std::time::Instant; use std::time::Instant;
use sglang_router_rs::core::{BasicWorker, WorkerType}; use sglang_router_rs::core::{BasicWorker, Worker, WorkerType};
use sglang_router_rs::openai_api_types::{ use sglang_router_rs::openai_api_types::{
ChatCompletionRequest, ChatMessage, CompletionRequest, GenerateParameters, GenerateRequest, ChatCompletionRequest, ChatMessage, CompletionRequest, GenerateParameters, GenerateRequest,
SamplingParams, StringOrArray, UserMessageContent, SamplingParams, StringOrArray, UserMessageContent,
}; };
use sglang_router_rs::routers::bootstrap_injector::inject_bootstrap_fields; use sglang_router_rs::routers::pd_types::{generate_room_id, get_hostname, RequestWithBootstrap};
fn create_test_worker() -> BasicWorker { fn create_test_worker() -> BasicWorker {
BasicWorker::new( BasicWorker::new(
...@@ -18,6 +18,16 @@ fn create_test_worker() -> BasicWorker { ...@@ -18,6 +18,16 @@ fn create_test_worker() -> BasicWorker {
) )
} }
/// Resolve the (hostname, bootstrap_port) pair for a worker.
///
/// The hostname is derived from the worker's URL. Only prefill workers carry
/// a bootstrap port; every other worker type yields `None`.
fn get_bootstrap_info(worker: &BasicWorker) -> (String, Option<u16>) {
    let port = if let WorkerType::Prefill { bootstrap_port } = worker.worker_type() {
        bootstrap_port.clone()
    } else {
        None
    };
    (get_hostname(worker.url()), port)
}
/// Create a default GenerateRequest for benchmarks with minimal fields set /// Create a default GenerateRequest for benchmarks with minimal fields set
fn default_generate_request() -> GenerateRequest { fn default_generate_request() -> GenerateRequest {
GenerateRequest { GenerateRequest {
...@@ -331,35 +341,56 @@ fn bench_bootstrap_injection(c: &mut Criterion) { ...@@ -331,35 +341,56 @@ fn bench_bootstrap_injection(c: &mut Criterion) {
let completion_req = create_sample_completion_request(); let completion_req = create_sample_completion_request();
let large_chat_req = create_large_chat_completion_request(); let large_chat_req = create_large_chat_completion_request();
let worker = create_test_worker(); let worker = create_test_worker();
let (hostname, bootstrap_port) = get_bootstrap_info(&worker);
group.bench_function("generate_bootstrap_injection", |b| { group.bench_function("generate_bootstrap_injection", |b| {
b.iter(|| { b.iter(|| {
let mut json = to_value(black_box(&generate_req)).unwrap(); let request_with_bootstrap = RequestWithBootstrap {
inject_bootstrap_fields(&mut json, &worker).unwrap(); original: &generate_req,
bootstrap_host: hostname.clone(),
bootstrap_port,
bootstrap_room: generate_room_id(),
};
let json = to_value(black_box(&request_with_bootstrap)).unwrap();
black_box(json); black_box(json);
}); });
}); });
group.bench_function("chat_completion_bootstrap_injection", |b| { group.bench_function("chat_completion_bootstrap_injection", |b| {
b.iter(|| { b.iter(|| {
let mut json = to_value(black_box(&chat_req)).unwrap(); let request_with_bootstrap = RequestWithBootstrap {
inject_bootstrap_fields(&mut json, &worker).unwrap(); original: &chat_req,
bootstrap_host: hostname.clone(),
bootstrap_port,
bootstrap_room: generate_room_id(),
};
let json = to_value(black_box(&request_with_bootstrap)).unwrap();
black_box(json); black_box(json);
}); });
}); });
group.bench_function("completion_bootstrap_injection", |b| { group.bench_function("completion_bootstrap_injection", |b| {
b.iter(|| { b.iter(|| {
let mut json = to_value(black_box(&completion_req)).unwrap(); let request_with_bootstrap = RequestWithBootstrap {
inject_bootstrap_fields(&mut json, &worker).unwrap(); original: &completion_req,
bootstrap_host: hostname.clone(),
bootstrap_port,
bootstrap_room: generate_room_id(),
};
let json = to_value(black_box(&request_with_bootstrap)).unwrap();
black_box(json); black_box(json);
}); });
}); });
group.bench_function("large_chat_completion_bootstrap_injection", |b| { group.bench_function("large_chat_completion_bootstrap_injection", |b| {
b.iter(|| { b.iter(|| {
let mut json = to_value(black_box(&large_chat_req)).unwrap(); let request_with_bootstrap = RequestWithBootstrap {
inject_bootstrap_fields(&mut json, &worker).unwrap(); original: &large_chat_req,
bootstrap_host: hostname.clone(),
bootstrap_port,
bootstrap_room: generate_room_id(),
};
let json = to_value(black_box(&request_with_bootstrap)).unwrap();
black_box(json); black_box(json);
}); });
}); });
...@@ -441,6 +472,7 @@ fn bench_throughput_by_size(c: &mut Criterion) { ...@@ -441,6 +472,7 @@ fn bench_throughput_by_size(c: &mut Criterion) {
}; };
let worker = create_test_worker(); let worker = create_test_worker();
let (hostname, bootstrap_port) = get_bootstrap_info(&worker);
for (name, req) in [ for (name, req) in [
("small", &small_generate), ("small", &small_generate),
...@@ -449,6 +481,7 @@ fn bench_throughput_by_size(c: &mut Criterion) { ...@@ -449,6 +481,7 @@ fn bench_throughput_by_size(c: &mut Criterion) {
] { ] {
let json = to_string(req).unwrap(); let json = to_string(req).unwrap();
let size_bytes = json.len(); let size_bytes = json.len();
let hostname_clone = hostname.clone();
group.throughput(Throughput::Bytes(size_bytes as u64)); group.throughput(Throughput::Bytes(size_bytes as u64));
group.bench_with_input(BenchmarkId::new("serialize", name), &req, |b, req| { group.bench_with_input(BenchmarkId::new("serialize", name), &req, |b, req| {
...@@ -472,10 +505,16 @@ fn bench_throughput_by_size(c: &mut Criterion) { ...@@ -472,10 +505,16 @@ fn bench_throughput_by_size(c: &mut Criterion) {
group.bench_with_input( group.bench_with_input(
BenchmarkId::new("bootstrap_inject", name), BenchmarkId::new("bootstrap_inject", name),
&req, &req,
|b, req| { move |b, req| {
let hostname = hostname_clone.clone();
b.iter(|| { b.iter(|| {
let mut json = to_value(req).unwrap(); let request_with_bootstrap = RequestWithBootstrap {
inject_bootstrap_fields(&mut json, &worker).unwrap(); original: req,
bootstrap_host: hostname.clone(),
bootstrap_port,
bootstrap_room: generate_room_id(),
};
let json = to_value(&request_with_bootstrap).unwrap();
black_box(json); black_box(json);
}); });
}, },
...@@ -493,17 +532,21 @@ fn bench_full_round_trip(c: &mut Criterion) { ...@@ -493,17 +532,21 @@ fn bench_full_round_trip(c: &mut Criterion) {
let chat_json = to_string(&create_sample_chat_completion_request()).unwrap(); let chat_json = to_string(&create_sample_chat_completion_request()).unwrap();
let completion_json = to_string(&create_sample_completion_request()).unwrap(); let completion_json = to_string(&create_sample_completion_request()).unwrap();
let worker = create_test_worker(); let worker = create_test_worker();
let (hostname, bootstrap_port) = get_bootstrap_info(&worker);
group.bench_function("generate_openai_to_pd_pipeline", |b| { group.bench_function("generate_openai_to_pd_pipeline", |b| {
b.iter(|| { b.iter(|| {
// Deserialize OpenAI request // Deserialize OpenAI request
let req: GenerateRequest = from_str(black_box(&generate_json)).unwrap(); let req: GenerateRequest = from_str(black_box(&generate_json)).unwrap();
// Convert to JSON Value // Create wrapper with bootstrap fields
let mut json = to_value(&req).unwrap(); let request_with_bootstrap = RequestWithBootstrap {
// Inject bootstrap fields original: &req,
inject_bootstrap_fields(&mut json, &worker).unwrap(); bootstrap_host: hostname.clone(),
bootstrap_port,
bootstrap_room: generate_room_id(),
};
// Serialize final request // Serialize final request
let pd_json = to_string(&json).unwrap(); let pd_json = to_string(&request_with_bootstrap).unwrap();
black_box(pd_json); black_box(pd_json);
}); });
}); });
...@@ -511,9 +554,13 @@ fn bench_full_round_trip(c: &mut Criterion) { ...@@ -511,9 +554,13 @@ fn bench_full_round_trip(c: &mut Criterion) {
group.bench_function("chat_completion_openai_to_pd_pipeline", |b| { group.bench_function("chat_completion_openai_to_pd_pipeline", |b| {
b.iter(|| { b.iter(|| {
let req: ChatCompletionRequest = from_str(black_box(&chat_json)).unwrap(); let req: ChatCompletionRequest = from_str(black_box(&chat_json)).unwrap();
let mut json = to_value(&req).unwrap(); let request_with_bootstrap = RequestWithBootstrap {
inject_bootstrap_fields(&mut json, &worker).unwrap(); original: &req,
let pd_json = to_string(&json).unwrap(); bootstrap_host: hostname.clone(),
bootstrap_port,
bootstrap_room: generate_room_id(),
};
let pd_json = to_string(&request_with_bootstrap).unwrap();
black_box(pd_json); black_box(pd_json);
}); });
}); });
...@@ -521,9 +568,13 @@ fn bench_full_round_trip(c: &mut Criterion) { ...@@ -521,9 +568,13 @@ fn bench_full_round_trip(c: &mut Criterion) {
group.bench_function("completion_openai_to_pd_pipeline", |b| { group.bench_function("completion_openai_to_pd_pipeline", |b| {
b.iter(|| { b.iter(|| {
let req: CompletionRequest = from_str(black_box(&completion_json)).unwrap(); let req: CompletionRequest = from_str(black_box(&completion_json)).unwrap();
let mut json = to_value(&req).unwrap(); let request_with_bootstrap = RequestWithBootstrap {
inject_bootstrap_fields(&mut json, &worker).unwrap(); original: &req,
let pd_json = to_string(&json).unwrap(); bootstrap_host: hostname.clone(),
bootstrap_port,
bootstrap_room: generate_room_id(),
};
let pd_json = to_string(&request_with_bootstrap).unwrap();
black_box(pd_json); black_box(pd_json);
}); });
}); });
...@@ -575,10 +626,16 @@ fn benchmark_summary(c: &mut Criterion) { ...@@ -575,10 +626,16 @@ fn benchmark_summary(c: &mut Criterion) {
); );
// Measure bootstrap injection (replaces adaptation) // Measure bootstrap injection (replaces adaptation)
let (hostname, bootstrap_port) = get_bootstrap_info(&worker);
let start = Instant::now(); let start = Instant::now();
for _ in 0..1000 { for _ in 0..1000 {
let mut json = to_value(&generate_req).unwrap(); let request_with_bootstrap = RequestWithBootstrap {
let _ = black_box(inject_bootstrap_fields(&mut json, &worker)); original: &generate_req,
bootstrap_host: hostname.clone(),
bootstrap_port,
bootstrap_room: generate_room_id(),
};
let _ = black_box(to_value(&request_with_bootstrap).unwrap());
} }
let inject_time = start.elapsed().as_nanos() / 1000; let inject_time = start.elapsed().as_nanos() / 1000;
println!(" * Bootstrap Injection (avg): {:>6} ns/req", inject_time); println!(" * Bootstrap Injection (avg): {:>6} ns/req", inject_time);
......
...@@ -121,6 +121,8 @@ class BenchmarkRunner: ...@@ -121,6 +121,8 @@ class BenchmarkRunner:
results["serialization_time"] = self._extract_time(line) results["serialization_time"] = self._extract_time(line)
elif "Deserialization (avg):" in line: elif "Deserialization (avg):" in line:
results["deserialization_time"] = self._extract_time(line) results["deserialization_time"] = self._extract_time(line)
elif "Bootstrap Injection (avg):" in line:
results["bootstrap_injection_time"] = self._extract_time(line)
elif "Total Pipeline (avg):" in line: elif "Total Pipeline (avg):" in line:
results["total_time"] = self._extract_time(line) results["total_time"] = self._extract_time(line)
...@@ -143,6 +145,7 @@ class BenchmarkRunner: ...@@ -143,6 +145,7 @@ class BenchmarkRunner:
thresholds = { thresholds = {
"serialization_time": 2000, # 2μs max "serialization_time": 2000, # 2μs max
"deserialization_time": 2000, # 2μs max "deserialization_time": 2000, # 2μs max
"bootstrap_injection_time": 5000, # 5μs max
"total_time": 10000, # 10μs max "total_time": 10000, # 10μs max
} }
......
...@@ -230,6 +230,10 @@ impl LoadBalancingPolicy for CacheAwarePolicy { ...@@ -230,6 +230,10 @@ impl LoadBalancingPolicy for CacheAwarePolicy {
"cache_aware" "cache_aware"
} }
/// Cache-aware routing keys its worker selection off the request text
/// (for cache affinity), so the router must extract the body text before
/// calling this policy.
fn needs_request_text(&self) -> bool {
    true // Cache-aware policy needs request text for cache affinity
}
fn on_request_complete(&self, worker_url: &str, success: bool) { fn on_request_complete(&self, worker_url: &str, success: bool) {
// Could track success rates per worker for more intelligent routing // Could track success rates per worker for more intelligent routing
if !success { if !success {
......
...@@ -59,6 +59,11 @@ pub trait LoadBalancingPolicy: Send + Sync + Debug { ...@@ -59,6 +59,11 @@ pub trait LoadBalancingPolicy: Send + Sync + Debug {
/// Get policy name for metrics and debugging /// Get policy name for metrics and debugging
fn name(&self) -> &'static str; fn name(&self) -> &'static str;
/// Check if this policy needs request text for routing decisions
///
/// Most policies (round-robin, load-based, random) route without looking at
/// the request body, so the default is `false` and the router can skip the
/// cost of extracting the text. Text-sensitive policies (e.g. cache-aware
/// routing) override this to return `true`.
fn needs_request_text(&self) -> bool {
    false // Default: most policies don't need request text
}
/// Update worker load information /// Update worker load information
/// ///
/// This is called periodically with current load information for load-aware policies. /// This is called periodically with current load information for load-aware policies.
......
// Bootstrap field injection for PD routing
// Directly injects bootstrap fields into JSON requests without intermediate type conversions
use crate::core::{Worker, WorkerType};
use crate::routers::pd_types::get_hostname;
use serde_json::{json, Value};
/// Inject bootstrap fields directly into a JSON request
/// This replaces the complex ToPdRequest -> Bootstrap trait pattern
pub fn inject_bootstrap_fields(json: &mut Value, worker: &dyn Worker) -> Result<(), String> {
let batch_size = extract_batch_size(json)?;
// Extract bootstrap port from prefill worker if it's a prefill type
let bootstrap_port = match worker.worker_type() {
WorkerType::Prefill { bootstrap_port } => bootstrap_port,
_ => None,
};
let hostname = get_hostname(worker.url());
if let Some(batch_size) = batch_size {
// Batch scenario - create arrays of bootstrap values
json["bootstrap_host"] = json!(vec![hostname; batch_size]);
json["bootstrap_port"] = json!(vec![bootstrap_port; batch_size]);
json["bootstrap_room"] = json!((0..batch_size)
.map(|_| {
// Generate a value in the range [0, 2^63 - 1] to match Python's random.randint(0, 2**63 - 1)
rand::random::<u64>() & (i64::MAX as u64)
})
.collect::<Vec<_>>());
} else {
// Single scenario - create single bootstrap values
json["bootstrap_host"] = json!(hostname);
json["bootstrap_port"] = json!(bootstrap_port);
json["bootstrap_room"] = json!(rand::random::<u64>() & (i64::MAX as u64));
}
Ok(())
}
/// Extract batch size from various JSON request formats
/// Handles chat completions, completions, and generate requests
fn extract_batch_size(json: &Value) -> Result<Option<usize>, String> {
// Check for chat completions 'n' parameter (number of choices)
if let Some(n) = json.get("n").and_then(|v| v.as_u64()) {
if n > 1 {
return Ok(Some(n as usize));
}
}
// Check for array prompts (completions API)
if let Some(prompt) = json.get("prompt") {
if let Some(arr) = prompt.as_array() {
if arr.is_empty() {
return Err("Batch prompt array is empty".to_string());
}
return Ok(Some(arr.len()));
}
}
// Check for array texts (generate API)
if let Some(text) = json.get("text") {
if let Some(arr) = text.as_array() {
if arr.is_empty() {
return Err("Batch text array is empty".to_string());
}
return Ok(Some(arr.len()));
}
}
// Check for batch input_ids (generate API)
if let Some(input_ids) = json.get("input_ids") {
if let Some(arr) = input_ids.as_array() {
if arr.is_empty() {
return Err("Batch input_ids array is empty".to_string());
}
return Ok(Some(arr.len()));
}
}
// No batch indicators found - single request
Ok(None)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::BasicWorker;
    use serde_json::json;

    // Prefill worker with a fixed bootstrap port (5678); used by most tests.
    fn create_test_worker() -> BasicWorker {
        BasicWorker::new(
            "http://test-server:8000".to_string(),
            WorkerType::Prefill {
                bootstrap_port: Some(5678),
            },
        )
    }

    // Single (non-batch) request: scalar bootstrap fields, originals preserved.
    #[test]
    fn test_inject_bootstrap_single_request() {
        let worker = create_test_worker();
        let mut json = json!({
            "model": "test-model",
            "prompt": "Hello world",
            "max_tokens": 100
        });
        let result = inject_bootstrap_fields(&mut json, &worker);
        assert!(result.is_ok());
        // Verify bootstrap fields were added
        assert_eq!(json["bootstrap_host"], json!("test-server"));
        assert_eq!(json["bootstrap_port"], json!(5678));
        assert!(json["bootstrap_room"].is_number());
        // Verify original fields preserved
        assert_eq!(json["model"], json!("test-model"));
        assert_eq!(json["prompt"], json!("Hello world"));
        assert_eq!(json["max_tokens"], json!(100));
    }

    // Array prompt (completions API): bootstrap fields become per-entry arrays.
    #[test]
    fn test_inject_bootstrap_batch_prompt() {
        let worker = create_test_worker();
        let mut json = json!({
            "model": "test-model",
            "prompt": ["Hello", "World"],
            "max_tokens": 100
        });
        let result = inject_bootstrap_fields(&mut json, &worker);
        assert!(result.is_ok());
        // Verify batch bootstrap fields
        assert_eq!(
            json["bootstrap_host"],
            json!(["test-server", "test-server"])
        );
        assert_eq!(json["bootstrap_port"], json!([5678, 5678]));
        let bootstrap_rooms = json["bootstrap_room"].as_array().unwrap();
        assert_eq!(bootstrap_rooms.len(), 2);
        for room in bootstrap_rooms {
            assert!(room.is_number());
            let room_val = room.as_u64().unwrap();
            // Rooms must stay within i64 range (Python-compatible).
            assert!(room_val <= i64::MAX as u64);
        }
    }

    // Chat completions 'n' parameter: n > 1 is treated as a batch of size n.
    #[test]
    fn test_inject_bootstrap_chat_n_parameter() {
        let worker = create_test_worker();
        let mut json = json!({
            "model": "gpt-4",
            "messages": [{"role": "user", "content": "Hello"}],
            "n": 3
        });
        let result = inject_bootstrap_fields(&mut json, &worker);
        assert!(result.is_ok());
        // Verify batch bootstrap fields for n=3
        let bootstrap_hosts = json["bootstrap_host"].as_array().unwrap();
        assert_eq!(bootstrap_hosts.len(), 3);
        assert_eq!(bootstrap_hosts[0], json!("test-server"));
        let bootstrap_ports = json["bootstrap_port"].as_array().unwrap();
        assert_eq!(bootstrap_ports.len(), 3);
        assert_eq!(bootstrap_ports[0], json!(5678));
        let bootstrap_rooms = json["bootstrap_room"].as_array().unwrap();
        assert_eq!(bootstrap_rooms.len(), 3);
    }

    // Generate API with an array "text" field: batch of len(text).
    #[test]
    fn test_inject_bootstrap_generate_text_array() {
        let worker = create_test_worker();
        let mut json = json!({
            "text": ["First prompt", "Second prompt"],
            "stream": false
        });
        let result = inject_bootstrap_fields(&mut json, &worker);
        assert!(result.is_ok());
        // Verify batch bootstrap fields
        let bootstrap_hosts = json["bootstrap_host"].as_array().unwrap();
        assert_eq!(bootstrap_hosts.len(), 2);
        let bootstrap_rooms = json["bootstrap_room"].as_array().unwrap();
        assert_eq!(bootstrap_rooms.len(), 2);
        // Ensure room values are different (randomness)
        // NOTE(review): probabilistic — a collision over [0, 2^63) is
        // astronomically unlikely but not impossible.
        assert_ne!(bootstrap_rooms[0], bootstrap_rooms[1]);
    }

    // Generate API with batched "input_ids": batch size is the outer length.
    #[test]
    fn test_inject_bootstrap_input_ids_array() {
        let worker = create_test_worker();
        let mut json = json!({
            "input_ids": [[1, 2, 3], [4, 5, 6]],
            "stream": false
        });
        let result = inject_bootstrap_fields(&mut json, &worker);
        assert!(result.is_ok());
        // Verify batch bootstrap fields
        let bootstrap_hosts = json["bootstrap_host"].as_array().unwrap();
        assert_eq!(bootstrap_hosts.len(), 2);
    }

    // An empty batch array is an explicit error, not a silent single request.
    #[test]
    fn test_extract_batch_size_empty_array_error() {
        let json = json!({
            "prompt": [],
            "model": "test"
        });
        let result = extract_batch_size(&json);
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("empty"));
    }

    // All the shapes that must be classified as a single (non-batch) request.
    #[test]
    fn test_extract_batch_size_single_requests() {
        // Single string prompt
        let json = json!({
            "prompt": "Hello world",
            "model": "test"
        });
        assert_eq!(extract_batch_size(&json).unwrap(), None);
        // Single text
        let json = json!({
            "text": "Hello world",
            "stream": false
        });
        assert_eq!(extract_batch_size(&json).unwrap(), None);
        // Chat with n=1 (default)
        let json = json!({
            "messages": [{"role": "user", "content": "Hello"}],
            "n": 1
        });
        assert_eq!(extract_batch_size(&json).unwrap(), None);
        // Chat without n parameter
        let json = json!({
            "messages": [{"role": "user", "content": "Hello"}]
        });
        assert_eq!(extract_batch_size(&json).unwrap(), None);
    }

    // Injection must be additive: unrelated request fields pass through intact.
    #[test]
    fn test_inject_bootstrap_preserves_sglang_fields() {
        let worker = create_test_worker();
        let mut json = json!({
            "model": "test-model",
            "prompt": "Hello",
            // SGLang extensions should be preserved
            "top_k": 40,
            "min_p": 0.05,
            "repetition_penalty": 1.1,
            "regex": "test_pattern",
            "lora_path": "test.bin",
            "no_stop_trim": true,
            "ignore_eos": false
        });
        let result = inject_bootstrap_fields(&mut json, &worker);
        assert!(result.is_ok());
        // Verify bootstrap fields added
        assert!(json.get("bootstrap_host").is_some());
        assert!(json.get("bootstrap_port").is_some());
        assert!(json.get("bootstrap_room").is_some());
        // Verify all SGLang fields preserved
        assert_eq!(json["top_k"], json!(40));
        assert_eq!(json["min_p"], json!(0.05));
        assert_eq!(json["repetition_penalty"], json!(1.1));
        assert_eq!(json["regex"], json!("test_pattern"));
        assert_eq!(json["lora_path"], json!("test.bin"));
        assert_eq!(json["no_stop_trim"], json!(true));
        assert_eq!(json["ignore_eos"], json!(false));
    }

    // Rooms must always fit in an i64 (Python randint(0, 2**63 - 1) parity).
    #[test]
    fn test_bootstrap_room_range() {
        let worker = create_test_worker();
        // Test single request room generation
        for _ in 0..1000 {
            let mut json = json!({"prompt": "test"});
            inject_bootstrap_fields(&mut json, &worker).unwrap();
            let room = json["bootstrap_room"].as_u64().unwrap();
            assert!(room <= i64::MAX as u64, "Room {} exceeds i64::MAX", room);
        }
        // Test batch request room generation
        for _ in 0..100 {
            let mut json = json!({"prompt": ["test1", "test2"]});
            inject_bootstrap_fields(&mut json, &worker).unwrap();
            let rooms = json["bootstrap_room"].as_array().unwrap();
            for room_val in rooms {
                let room = room_val.as_u64().unwrap();
                assert!(room <= i64::MAX as u64, "Room {} exceeds i64::MAX", room);
            }
        }
    }

    // Non-prefill workers have no bootstrap port; injection uses null.
    #[test]
    fn test_worker_without_bootstrap_port() {
        let worker = BasicWorker::new(
            "http://decode-only:8000".to_string(),
            WorkerType::Decode, // No bootstrap port
        );
        let mut json = json!({
            "prompt": "Hello world"
        });
        let result = inject_bootstrap_fields(&mut json, &worker);
        assert!(result.is_ok());
        // Verify bootstrap fields with null port
        assert_eq!(json["bootstrap_host"], json!("decode-only"));
        assert_eq!(json["bootstrap_port"], json!(null));
        assert!(json["bootstrap_room"].is_number());
    }
}
...@@ -11,7 +11,6 @@ use std::fmt::Debug; ...@@ -11,7 +11,6 @@ use std::fmt::Debug;
use crate::openai_api_types::{ChatCompletionRequest, CompletionRequest, GenerateRequest}; use crate::openai_api_types::{ChatCompletionRequest, CompletionRequest, GenerateRequest};
pub mod bootstrap_injector;
pub mod factory; pub mod factory;
pub mod pd_router; pub mod pd_router;
pub mod pd_types; pub mod pd_types;
......
This diff is collapsed.
...@@ -40,6 +40,34 @@ pub fn get_hostname(url: &str) -> String { ...@@ -40,6 +40,34 @@ pub fn get_hostname(url: &str) -> String {
url.split(':').next().unwrap_or("localhost").to_string() url.split(':').next().unwrap_or("localhost").to_string()
} }
use serde::Serialize;
/// Optimized bootstrap wrapper for single requests.
///
/// Serializing this wrapper flattens the borrowed original request and
/// appends the three bootstrap fields in a single pass, avoiding an
/// intermediate `serde_json::Value` round-trip.
#[derive(Serialize)]
pub struct RequestWithBootstrap<'a, T: Serialize> {
    // Flattened so the original request's fields appear at the top level.
    #[serde(flatten)]
    pub original: &'a T,
    pub bootstrap_host: String,
    // None for non-prefill workers; serialized as null.
    pub bootstrap_port: Option<u16>,
    pub bootstrap_room: u64,
}
/// Optimized bootstrap wrapper for batch requests.
///
/// Batch variant of `RequestWithBootstrap`: each bootstrap field is a vector
/// with one entry per request in the batch, serialized alongside the
/// flattened original request.
#[derive(Serialize)]
pub struct BatchRequestWithBootstrap<'a, T: Serialize> {
    // Flattened so the original request's fields appear at the top level.
    #[serde(flatten)]
    pub original: &'a T,
    pub bootstrap_host: Vec<String>,
    pub bootstrap_port: Vec<Option<u16>>,
    pub bootstrap_room: Vec<u64>,
}
/// Helper to generate a bootstrap room ID.
///
/// Returns a random value in the range [0, 2^63 - 1] to match Python's
/// `random.randint(0, 2**63 - 1)`.
pub fn generate_room_id() -> u64 {
    // Mask off the top bit so the value always fits in an i64.
    rand::random::<u64>() & (i64::MAX as u64)
}
// PD-specific routing policies // PD-specific routing policies
#[derive(Debug, Clone, PartialEq)] #[derive(Debug, Clone, PartialEq)]
pub enum PDSelectionPolicy { pub enum PDSelectionPolicy {
......
...@@ -269,7 +269,7 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Err ...@@ -269,7 +269,7 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Err
let client = Client::builder() let client = Client::builder()
.pool_idle_timeout(Some(Duration::from_secs(50))) .pool_idle_timeout(Some(Duration::from_secs(50)))
.pool_max_idle_per_host(100) // Increase from default of 1 to allow more concurrent connections .pool_max_idle_per_host(500) // Increase to 500 connections per host
.timeout(Duration::from_secs(config.request_timeout_secs)) .timeout(Duration::from_secs(config.request_timeout_secs))
.connect_timeout(Duration::from_secs(10)) // Separate connection timeout .connect_timeout(Duration::from_secs(10)) // Separate connection timeout
.tcp_nodelay(true) .tcp_nodelay(true)
......
...@@ -9,7 +9,6 @@ use sglang_router_rs::openai_api_types::{ ...@@ -9,7 +9,6 @@ use sglang_router_rs::openai_api_types::{
ChatCompletionRequest, ChatMessage, CompletionRequest, GenerateParameters, GenerateRequest, ChatCompletionRequest, ChatMessage, CompletionRequest, GenerateParameters, GenerateRequest,
SamplingParams, StringOrArray, UserMessageContent, SamplingParams, StringOrArray, UserMessageContent,
}; };
use sglang_router_rs::routers::bootstrap_injector::inject_bootstrap_fields;
/// Create a default GenerateRequest for benchmarks with minimal fields set /// Create a default GenerateRequest for benchmarks with minimal fields set
fn default_generate_request() -> GenerateRequest { fn default_generate_request() -> GenerateRequest {
...@@ -208,63 +207,6 @@ fn test_benchmark_serialization_roundtrip() { ...@@ -208,63 +207,6 @@ fn test_benchmark_serialization_roundtrip() {
assert_eq!(generate_req.return_logprob, deserialized.return_logprob); assert_eq!(generate_req.return_logprob, deserialized.return_logprob);
} }
// Smoke test: bootstrap injection succeeds for all three benchmark request
// types (generate, chat completion, completion) and adds the expected fields.
#[test]
fn test_benchmark_bootstrap_injection() {
    // Test that bootstrap injection works for benchmark types (replaces PD request adaptation)
    let generate_req = GenerateRequest {
        text: Some("Test prompt".to_string()),
        ..default_generate_request()
    };

    let chat_req = ChatCompletionRequest {
        model: "test-model".to_string(),
        messages: vec![ChatMessage::User {
            role: "user".to_string(),
            content: UserMessageContent::Text("Test message".to_string()),
            name: None,
        }],
        max_tokens: Some(150),
        max_completion_tokens: Some(150),
        temperature: Some(0.7),
        top_p: Some(1.0),
        n: Some(1),
        presence_penalty: Some(0.0),
        frequency_penalty: Some(0.0),
        parallel_tool_calls: Some(true),
        ..default_chat_completion_request()
    };

    let completion_req = CompletionRequest {
        model: "test-model".to_string(),
        prompt: StringOrArray::String("Test prompt".to_string()),
        max_tokens: Some(50),
        temperature: Some(0.8),
        top_p: Some(1.0),
        n: Some(1),
        presence_penalty: Some(0.0),
        frequency_penalty: Some(0.0),
        best_of: Some(1),
        ..default_completion_request()
    };

    let worker = create_test_worker();

    // Test bootstrap injection (should not panic)
    let mut generate_json = to_value(&generate_req).unwrap();
    let mut chat_json = to_value(&chat_req).unwrap();
    let mut completion_json = to_value(&completion_req).unwrap();

    assert!(inject_bootstrap_fields(&mut generate_json, &worker).is_ok());
    assert!(inject_bootstrap_fields(&mut chat_json, &worker).is_ok());
    assert!(inject_bootstrap_fields(&mut completion_json, &worker).is_ok());

    // Verify bootstrap fields were added (spot-checked on the generate request)
    assert!(generate_json.get("bootstrap_host").is_some());
    assert!(generate_json.get("bootstrap_port").is_some());
    assert!(generate_json.get("bootstrap_room").is_some());
}
#[test] #[test]
fn test_benchmark_direct_json_routing() { fn test_benchmark_direct_json_routing() {
// Test direct JSON routing functionality for benchmark types (replaces regular routing) // Test direct JSON routing functionality for benchmark types (replaces regular routing)
...@@ -283,47 +225,3 @@ fn test_benchmark_direct_json_routing() { ...@@ -283,47 +225,3 @@ fn test_benchmark_direct_json_routing() {
assert!(!json_string.is_empty()); assert!(!json_string.is_empty());
assert!(!bytes.is_empty()); assert!(!bytes.is_empty());
} }
// Wall-clock sanity check for the simplified pipeline.
// NOTE(review): timing-based assertions like these can be flaky on loaded
// CI machines; the 5 ms / 3x margins are generous to compensate.
#[test]
fn test_benchmark_performance_baseline() {
    // Basic performance sanity check - ensure operations complete quickly
    use std::time::Instant;

    let generate_req = GenerateRequest {
        text: Some("Short test prompt".to_string()),
        ..default_generate_request()
    };

    // Test the actual simplified pipeline: to_value + bootstrap injection
    let start = Instant::now();
    let worker = create_test_worker();
    // This mirrors the actual router pipeline
    let mut json = to_value(&generate_req).unwrap();
    let _ = inject_bootstrap_fields(&mut json, &worker);
    let total_duration = start.elapsed();

    assert!(
        total_duration.as_millis() < 5,
        "Simplified pipeline took too long: {:?} (should be faster than old adapter approach)",
        total_duration
    );

    // Individual components should also be fast
    let start = Instant::now();
    let _json = to_value(&generate_req).unwrap();
    let to_value_duration = start.elapsed();

    let start = Instant::now();
    let mut json = to_value(&generate_req).unwrap();
    let _ = inject_bootstrap_fields(&mut json, &worker);
    let inject_duration = start.elapsed();

    // Bootstrap injection should be faster than the JSON conversion
    assert!(
        inject_duration <= to_value_duration * 3,
        "Bootstrap injection ({:?}) should not be much slower than JSON conversion ({:?})",
        inject_duration,
        to_value_duration
    );
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment