Unverified Commit 8c7bb39d authored by Simo Lin's avatar Simo Lin Committed by GitHub
Browse files

[router] PD Router Simplification and Reorganization (#8838)

parent ca47e24f
use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion, Throughput};
use serde_json::{from_str, to_string, to_vec};
use serde_json::{from_str, to_string, to_value, to_vec};
use std::time::Instant;
use sglang_router_rs::core::{BasicWorker, WorkerType};
use sglang_router_rs::openai_api_types::{
ChatCompletionRequest, ChatMessage, CompletionRequest, GenerateParameters, GenerateRequest,
SamplingParams, StringOrArray, UserMessageContent,
};
use sglang_router_rs::routers::request_adapter::{RouteableRequest, ToPdRequest};
use sglang_router_rs::routers::bootstrap_injector::inject_bootstrap_fields;
/// Build the prefill-worker fixture used across the benchmarks: a placeholder
/// URL with a fixed bootstrap port so injected fields are deterministic.
fn create_test_worker() -> BasicWorker {
    let endpoint = String::from("http://test-server:8000");
    let kind = WorkerType::Prefill {
        bootstrap_port: Some(5678),
    };
    BasicWorker::new(endpoint, kind)
}
/// Create a default GenerateRequest for benchmarks with minimal fields set
fn default_generate_request() -> GenerateRequest {
......@@ -312,49 +322,54 @@ fn bench_json_deserialization(c: &mut Criterion) {
group.finish();
}
// Benchmark request adaptation from OpenAI to PD format
fn bench_request_adaptation(c: &mut Criterion) {
let mut group = c.benchmark_group("request_adaptation");
// Benchmark bootstrap injection (replaces request adaptation)
fn bench_bootstrap_injection(c: &mut Criterion) {
let mut group = c.benchmark_group("bootstrap_injection");
let generate_req = create_sample_generate_request();
let chat_req = create_sample_chat_completion_request();
let completion_req = create_sample_completion_request();
let large_chat_req = create_large_chat_completion_request();
let worker = create_test_worker();
group.bench_function("generate_to_pd", |b| {
group.bench_function("generate_bootstrap_injection", |b| {
b.iter(|| {
let pd_req = black_box(generate_req.clone()).to_pd_request();
black_box(pd_req);
let mut json = to_value(black_box(&generate_req)).unwrap();
inject_bootstrap_fields(&mut json, &worker).unwrap();
black_box(json);
});
});
group.bench_function("chat_completion_to_pd", |b| {
group.bench_function("chat_completion_bootstrap_injection", |b| {
b.iter(|| {
let pd_req = black_box(chat_req.clone()).to_pd_request();
black_box(pd_req);
let mut json = to_value(black_box(&chat_req)).unwrap();
inject_bootstrap_fields(&mut json, &worker).unwrap();
black_box(json);
});
});
group.bench_function("completion_to_pd", |b| {
group.bench_function("completion_bootstrap_injection", |b| {
b.iter(|| {
let pd_req = black_box(completion_req.clone()).to_pd_request();
black_box(pd_req);
let mut json = to_value(black_box(&completion_req)).unwrap();
inject_bootstrap_fields(&mut json, &worker).unwrap();
black_box(json);
});
});
group.bench_function("large_chat_completion_to_pd", |b| {
group.bench_function("large_chat_completion_bootstrap_injection", |b| {
b.iter(|| {
let pd_req = black_box(large_chat_req.clone()).to_pd_request();
black_box(pd_req);
let mut json = to_value(black_box(&large_chat_req)).unwrap();
inject_bootstrap_fields(&mut json, &worker).unwrap();
black_box(json);
});
});
group.finish();
}
// Benchmark regular routing (RouteableRequest methods)
fn bench_regular_routing(c: &mut Criterion) {
let mut group = c.benchmark_group("regular_routing");
// Benchmark direct JSON routing (replaces regular routing)
fn bench_direct_json_routing(c: &mut Criterion) {
let mut group = c.benchmark_group("direct_json_routing");
let generate_req = create_sample_generate_request();
let chat_req = create_sample_chat_completion_request();
......@@ -362,35 +377,42 @@ fn bench_regular_routing(c: &mut Criterion) {
group.bench_function("generate_to_json", |b| {
b.iter(|| {
let json = black_box(&generate_req).to_json().unwrap();
let json = to_value(black_box(&generate_req)).unwrap();
black_box(json);
});
});
group.bench_function("generate_to_json_string", |b| {
b.iter(|| {
let json = to_string(black_box(&generate_req)).unwrap();
black_box(json);
});
});
group.bench_function("generate_to_bytes", |b| {
b.iter(|| {
let bytes = black_box(&generate_req).to_bytes().unwrap();
let bytes = to_vec(black_box(&generate_req)).unwrap();
black_box(bytes);
});
});
group.bench_function("chat_completion_to_json", |b| {
b.iter(|| {
let json = black_box(&chat_req).to_json().unwrap();
let json = to_value(black_box(&chat_req)).unwrap();
black_box(json);
});
});
group.bench_function("chat_completion_to_bytes", |b| {
group.bench_function("chat_completion_to_json_string", |b| {
b.iter(|| {
let bytes = black_box(&chat_req).to_bytes().unwrap();
black_box(bytes);
let json = to_string(black_box(&chat_req)).unwrap();
black_box(json);
});
});
group.bench_function("completion_to_json", |b| {
b.iter(|| {
let json = black_box(&completion_req).to_json().unwrap();
let json = to_value(black_box(&completion_req)).unwrap();
black_box(json);
});
});
......@@ -418,6 +440,8 @@ fn bench_throughput_by_size(c: &mut Criterion) {
..default_generate_request()
};
let worker = create_test_worker();
for (name, req) in [
("small", &small_generate),
("medium", &medium_generate),
......@@ -445,33 +469,41 @@ fn bench_throughput_by_size(c: &mut Criterion) {
},
);
group.bench_with_input(BenchmarkId::new("adapt_to_pd", name), &req, |b, req| {
b.iter(|| {
let pd_req = (*req).clone().to_pd_request();
black_box(pd_req);
});
});
group.bench_with_input(
BenchmarkId::new("bootstrap_inject", name),
&req,
|b, req| {
b.iter(|| {
let mut json = to_value(req).unwrap();
inject_bootstrap_fields(&mut json, &worker).unwrap();
black_box(json);
});
},
);
}
group.finish();
}
// Benchmark full round-trip: deserialize -> adapt -> serialize
// Benchmark full round-trip: deserialize -> inject bootstrap -> serialize
fn bench_full_round_trip(c: &mut Criterion) {
let mut group = c.benchmark_group("full_round_trip");
let generate_json = to_string(&create_sample_generate_request()).unwrap();
let chat_json = to_string(&create_sample_chat_completion_request()).unwrap();
let completion_json = to_string(&create_sample_completion_request()).unwrap();
let worker = create_test_worker();
group.bench_function("generate_openai_to_pd_pipeline", |b| {
b.iter(|| {
// Deserialize OpenAI request
let req: GenerateRequest = from_str(black_box(&generate_json)).unwrap();
// Adapt to PD format
let pd_req = req.to_pd_request();
// Serialize PD request
let pd_json = to_string(&pd_req).unwrap();
// Convert to JSON Value
let mut json = to_value(&req).unwrap();
// Inject bootstrap fields
inject_bootstrap_fields(&mut json, &worker).unwrap();
// Serialize final request
let pd_json = to_string(&json).unwrap();
black_box(pd_json);
});
});
......@@ -479,8 +511,9 @@ fn bench_full_round_trip(c: &mut Criterion) {
group.bench_function("chat_completion_openai_to_pd_pipeline", |b| {
b.iter(|| {
let req: ChatCompletionRequest = from_str(black_box(&chat_json)).unwrap();
let pd_req = req.to_pd_request();
let pd_json = to_string(&pd_req).unwrap();
let mut json = to_value(&req).unwrap();
inject_bootstrap_fields(&mut json, &worker).unwrap();
let pd_json = to_string(&json).unwrap();
black_box(pd_json);
});
});
......@@ -488,19 +521,21 @@ fn bench_full_round_trip(c: &mut Criterion) {
group.bench_function("completion_openai_to_pd_pipeline", |b| {
b.iter(|| {
let req: CompletionRequest = from_str(black_box(&completion_json)).unwrap();
let pd_req = req.to_pd_request();
let pd_json = to_string(&pd_req).unwrap();
let mut json = to_value(&req).unwrap();
inject_bootstrap_fields(&mut json, &worker).unwrap();
let pd_json = to_string(&json).unwrap();
black_box(pd_json);
});
});
group.bench_function("generate_regular_routing_pipeline", |b| {
group.bench_function("generate_direct_json_pipeline", |b| {
b.iter(|| {
// Deserialize OpenAI request
let req: GenerateRequest = from_str(black_box(&generate_json)).unwrap();
// Convert to JSON for regular routing
let routing_json = req.to_json().unwrap();
black_box(routing_json);
// Convert to JSON for direct routing (no bootstrap injection)
let routing_json = to_value(&req).unwrap();
let json_string = to_string(&routing_json).unwrap();
black_box(json_string);
});
});
......@@ -515,6 +550,7 @@ fn benchmark_summary(c: &mut Criterion) {
// Quick performance overview
let generate_req = create_sample_generate_request();
let worker = create_test_worker();
println!("\nQuick Performance Overview:");
......@@ -538,32 +574,39 @@ fn benchmark_summary(c: &mut Criterion) {
deserialize_time
);
// Measure adaptation
// Measure bootstrap injection (replaces adaptation)
let start = Instant::now();
for _ in 0..1000 {
let _ = black_box(generate_req.clone().to_pd_request());
let mut json = to_value(&generate_req).unwrap();
let _ = black_box(inject_bootstrap_fields(&mut json, &worker));
}
let adapt_time = start.elapsed().as_nanos() / 1000;
println!(" * PD Adaptation (avg): {:>8} ns/req", adapt_time);
let inject_time = start.elapsed().as_nanos() / 1000;
println!(" * Bootstrap Injection (avg): {:>6} ns/req", inject_time);
// Calculate ratios
let total_pipeline = serialize_time + deserialize_time + adapt_time;
let total_pipeline = serialize_time + deserialize_time + inject_time;
println!(" * Total Pipeline (avg): {:>8} ns/req", total_pipeline);
println!("\nPerformance Insights:");
if deserialize_time > serialize_time * 2 {
println!(" • Deserialization is significantly faster than serialization");
}
if adapt_time < serialize_time / 10 {
if inject_time < serialize_time / 10 {
println!(
" • PD adaptation overhead is negligible ({:.1}% of serialization)",
(adapt_time as f64 / serialize_time as f64) * 100.0
" • Bootstrap injection overhead is negligible ({:.1}% of serialization)",
(inject_time as f64 / serialize_time as f64) * 100.0
);
}
if total_pipeline < 10_000 {
println!(" • Total pipeline latency is excellent (< 10μs)");
if total_pipeline < 100_000 {
println!(" • Total pipeline latency is excellent (< 100μs)");
}
println!("\nSimplification Benefits:");
println!(" • Eliminated complex type conversion layer");
println!(" • Reduced memory allocations");
println!(" • Automatic field preservation (no manual mapping)");
println!(" • Direct JSON manipulation improves performance");
println!("\nRecommendations:");
if serialize_time > deserialize_time {
println!(" • Focus optimization efforts on serialization rather than deserialization");
......@@ -581,8 +624,8 @@ criterion_group!(
benchmark_summary,
bench_json_serialization,
bench_json_deserialization,
bench_request_adaptation,
bench_regular_routing,
bench_bootstrap_injection,
bench_direct_json_routing,
bench_throughput_by_size,
bench_full_round_trip
);
......
......@@ -121,8 +121,6 @@ class BenchmarkRunner:
results["serialization_time"] = self._extract_time(line)
elif "Deserialization (avg):" in line:
results["deserialization_time"] = self._extract_time(line)
elif "PD Adaptation (avg):" in line:
results["adaptation_time"] = self._extract_time(line)
elif "Total Pipeline (avg):" in line:
results["total_time"] = self._extract_time(line)
......@@ -145,7 +143,6 @@ class BenchmarkRunner:
thresholds = {
"serialization_time": 2000, # 2μs max
"deserialization_time": 2000, # 2μs max
"adaptation_time": 5000, # 5μs max
"total_time": 10000, # 10μs max
}
......
// Bootstrap field injection for PD routing
// Directly injects bootstrap fields into JSON requests without intermediate type conversions
use crate::core::{Worker, WorkerType};
use crate::routers::pd_types::get_hostname;
use serde_json::{json, Value};
/// Inject bootstrap fields directly into a JSON request.
/// This replaces the complex ToPdRequest -> Bootstrap trait pattern.
///
/// Adds `bootstrap_host`, `bootstrap_port`, and `bootstrap_room` to the request
/// object. For batch requests (detected via `extract_batch_size`) each field is
/// an array with one entry per batch element; otherwise each is a scalar.
///
/// Returns `Err` if the request contains an empty batch array or is not a JSON
/// object.
pub fn inject_bootstrap_fields(json: &mut Value, worker: &dyn Worker) -> Result<(), String> {
    // Indexed assignment (`json["..."] = ...`) panics on non-object values
    // (serde_json's IndexMut), so reject those up front with a recoverable
    // error instead of crashing the router on a malformed request body.
    if !json.is_object() {
        return Err("Request body must be a JSON object".to_string());
    }

    let batch_size = extract_batch_size(json)?;

    // Extract bootstrap port from prefill worker if it's a prefill type;
    // all other worker types serialize the port as null.
    let bootstrap_port = match worker.worker_type() {
        WorkerType::Prefill { bootstrap_port } => bootstrap_port,
        _ => None,
    };

    let hostname = get_hostname(worker.url());

    if let Some(batch_size) = batch_size {
        // Batch scenario - create arrays of bootstrap values
        json["bootstrap_host"] = json!(vec![hostname; batch_size]);
        json["bootstrap_port"] = json!(vec![bootstrap_port; batch_size]);
        json["bootstrap_room"] = json!((0..batch_size)
            .map(|_| {
                // Generate a value in the range [0, 2^63 - 1] to match Python's random.randint(0, 2**63 - 1)
                rand::random::<u64>() & (i64::MAX as u64)
            })
            .collect::<Vec<_>>());
    } else {
        // Single scenario - create single bootstrap values
        json["bootstrap_host"] = json!(hostname);
        json["bootstrap_port"] = json!(bootstrap_port);
        json["bootstrap_room"] = json!(rand::random::<u64>() & (i64::MAX as u64));
    }

    Ok(())
}
/// Extract batch size from various JSON request formats
/// Handles chat completions, completions, and generate requests
fn extract_batch_size(json: &Value) -> Result<Option<usize>, String> {
// Check for chat completions 'n' parameter (number of choices)
if let Some(n) = json.get("n").and_then(|v| v.as_u64()) {
if n > 1 {
return Ok(Some(n as usize));
}
}
// Check for array prompts (completions API)
if let Some(prompt) = json.get("prompt") {
if let Some(arr) = prompt.as_array() {
if arr.is_empty() {
return Err("Batch prompt array is empty".to_string());
}
return Ok(Some(arr.len()));
}
}
// Check for array texts (generate API)
if let Some(text) = json.get("text") {
if let Some(arr) = text.as_array() {
if arr.is_empty() {
return Err("Batch text array is empty".to_string());
}
return Ok(Some(arr.len()));
}
}
// Check for batch input_ids (generate API)
if let Some(input_ids) = json.get("input_ids") {
if let Some(arr) = input_ids.as_array() {
if arr.is_empty() {
return Err("Batch input_ids array is empty".to_string());
}
return Ok(Some(arr.len()));
}
}
// No batch indicators found - single request
Ok(None)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::BasicWorker;
    use serde_json::json;

    // Prefill worker fixture with a known hostname ("test-server") and a
    // fixed bootstrap port (5678) so injected values are predictable.
    fn create_test_worker() -> BasicWorker {
        BasicWorker::new(
            "http://test-server:8000".to_string(),
            WorkerType::Prefill {
                bootstrap_port: Some(5678),
            },
        )
    }

    // Single (non-batch) request: bootstrap fields must be scalars and the
    // original request fields must survive untouched.
    #[test]
    fn test_inject_bootstrap_single_request() {
        let worker = create_test_worker();
        let mut json = json!({
            "model": "test-model",
            "prompt": "Hello world",
            "max_tokens": 100
        });

        let result = inject_bootstrap_fields(&mut json, &worker);
        assert!(result.is_ok());

        // Verify bootstrap fields were added
        assert_eq!(json["bootstrap_host"], json!("test-server"));
        assert_eq!(json["bootstrap_port"], json!(5678));
        assert!(json["bootstrap_room"].is_number());

        // Verify original fields preserved
        assert_eq!(json["model"], json!("test-model"));
        assert_eq!(json["prompt"], json!("Hello world"));
        assert_eq!(json["max_tokens"], json!(100));
    }

    // Array prompt (completions API): bootstrap fields become per-element arrays.
    #[test]
    fn test_inject_bootstrap_batch_prompt() {
        let worker = create_test_worker();
        let mut json = json!({
            "model": "test-model",
            "prompt": ["Hello", "World"],
            "max_tokens": 100
        });

        let result = inject_bootstrap_fields(&mut json, &worker);
        assert!(result.is_ok());

        // Verify batch bootstrap fields
        assert_eq!(
            json["bootstrap_host"],
            json!(["test-server", "test-server"])
        );
        assert_eq!(json["bootstrap_port"], json!([5678, 5678]));

        let bootstrap_rooms = json["bootstrap_room"].as_array().unwrap();
        assert_eq!(bootstrap_rooms.len(), 2);
        // Each room id must stay within the non-negative i64 range (Python parity).
        for room in bootstrap_rooms {
            assert!(room.is_number());
            let room_val = room.as_u64().unwrap();
            assert!(room_val <= i64::MAX as u64);
        }
    }

    // Chat completions: n > 1 is treated as a batch of n choices.
    #[test]
    fn test_inject_bootstrap_chat_n_parameter() {
        let worker = create_test_worker();
        let mut json = json!({
            "model": "gpt-4",
            "messages": [{"role": "user", "content": "Hello"}],
            "n": 3
        });

        let result = inject_bootstrap_fields(&mut json, &worker);
        assert!(result.is_ok());

        // Verify batch bootstrap fields for n=3
        let bootstrap_hosts = json["bootstrap_host"].as_array().unwrap();
        assert_eq!(bootstrap_hosts.len(), 3);
        assert_eq!(bootstrap_hosts[0], json!("test-server"));

        let bootstrap_ports = json["bootstrap_port"].as_array().unwrap();
        assert_eq!(bootstrap_ports.len(), 3);
        assert_eq!(bootstrap_ports[0], json!(5678));

        let bootstrap_rooms = json["bootstrap_room"].as_array().unwrap();
        assert_eq!(bootstrap_rooms.len(), 3);
    }

    // Generate API with a 'text' array: batch size follows the array length.
    #[test]
    fn test_inject_bootstrap_generate_text_array() {
        let worker = create_test_worker();
        let mut json = json!({
            "text": ["First prompt", "Second prompt"],
            "stream": false
        });

        let result = inject_bootstrap_fields(&mut json, &worker);
        assert!(result.is_ok());

        // Verify batch bootstrap fields
        let bootstrap_hosts = json["bootstrap_host"].as_array().unwrap();
        assert_eq!(bootstrap_hosts.len(), 2);

        let bootstrap_rooms = json["bootstrap_room"].as_array().unwrap();
        assert_eq!(bootstrap_rooms.len(), 2);

        // Ensure room values are different (randomness)
        // NOTE(review): probabilistic — a collision is astronomically unlikely
        // but not impossible for two random 63-bit values.
        assert_ne!(bootstrap_rooms[0], bootstrap_rooms[1]);
    }

    // Generate API with batched 'input_ids': batch size follows the outer array.
    #[test]
    fn test_inject_bootstrap_input_ids_array() {
        let worker = create_test_worker();
        let mut json = json!({
            "input_ids": [[1, 2, 3], [4, 5, 6]],
            "stream": false
        });

        let result = inject_bootstrap_fields(&mut json, &worker);
        assert!(result.is_ok());

        // Verify batch bootstrap fields
        let bootstrap_hosts = json["bootstrap_host"].as_array().unwrap();
        assert_eq!(bootstrap_hosts.len(), 2);
    }

    // An empty batch array is a client error, not a batch of size zero.
    #[test]
    fn test_extract_batch_size_empty_array_error() {
        let json = json!({
            "prompt": [],
            "model": "test"
        });

        let result = extract_batch_size(&json);
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("empty"));
    }

    // All single-request shapes must report no batch (None).
    #[test]
    fn test_extract_batch_size_single_requests() {
        // Single string prompt
        let json = json!({
            "prompt": "Hello world",
            "model": "test"
        });
        assert_eq!(extract_batch_size(&json).unwrap(), None);

        // Single text
        let json = json!({
            "text": "Hello world",
            "stream": false
        });
        assert_eq!(extract_batch_size(&json).unwrap(), None);

        // Chat with n=1 (default)
        let json = json!({
            "messages": [{"role": "user", "content": "Hello"}],
            "n": 1
        });
        assert_eq!(extract_batch_size(&json).unwrap(), None);

        // Chat without n parameter
        let json = json!({
            "messages": [{"role": "user", "content": "Hello"}]
        });
        assert_eq!(extract_batch_size(&json).unwrap(), None);
    }

    // Injection must not drop or alter any unrelated fields, including
    // SGLang-specific sampling extensions.
    #[test]
    fn test_inject_bootstrap_preserves_sglang_fields() {
        let worker = create_test_worker();
        let mut json = json!({
            "model": "test-model",
            "prompt": "Hello",
            // SGLang extensions should be preserved
            "top_k": 40,
            "min_p": 0.05,
            "repetition_penalty": 1.1,
            "regex": "test_pattern",
            "lora_path": "test.bin",
            "no_stop_trim": true,
            "ignore_eos": false
        });

        let result = inject_bootstrap_fields(&mut json, &worker);
        assert!(result.is_ok());

        // Verify bootstrap fields added
        assert!(json.get("bootstrap_host").is_some());
        assert!(json.get("bootstrap_port").is_some());
        assert!(json.get("bootstrap_room").is_some());

        // Verify all SGLang fields preserved
        assert_eq!(json["top_k"], json!(40));
        assert_eq!(json["min_p"], json!(0.05));
        assert_eq!(json["repetition_penalty"], json!(1.1));
        assert_eq!(json["regex"], json!("test_pattern"));
        assert_eq!(json["lora_path"], json!("test.bin"));
        assert_eq!(json["no_stop_trim"], json!(true));
        assert_eq!(json["ignore_eos"], json!(false));
    }

    // Room ids must always fit in [0, 2^63 - 1], for both single and batch paths.
    #[test]
    fn test_bootstrap_room_range() {
        let worker = create_test_worker();

        // Test single request room generation
        for _ in 0..1000 {
            let mut json = json!({"prompt": "test"});
            inject_bootstrap_fields(&mut json, &worker).unwrap();

            let room = json["bootstrap_room"].as_u64().unwrap();
            assert!(room <= i64::MAX as u64, "Room {} exceeds i64::MAX", room);
        }

        // Test batch request room generation
        for _ in 0..100 {
            let mut json = json!({"prompt": ["test1", "test2"]});
            inject_bootstrap_fields(&mut json, &worker).unwrap();

            let rooms = json["bootstrap_room"].as_array().unwrap();
            for room_val in rooms {
                let room = room_val.as_u64().unwrap();
                assert!(room <= i64::MAX as u64, "Room {} exceeds i64::MAX", room);
            }
        }
    }

    // Non-prefill workers carry no bootstrap port: the field is injected as null.
    #[test]
    fn test_worker_without_bootstrap_port() {
        let worker = BasicWorker::new(
            "http://decode-only:8000".to_string(),
            WorkerType::Decode, // No bootstrap port
        );

        let mut json = json!({
            "prompt": "Hello world"
        });

        let result = inject_bootstrap_fields(&mut json, &worker);
        assert!(result.is_ok());

        // Verify bootstrap fields with null port
        assert_eq!(json["bootstrap_host"], json!("decode-only"));
        assert_eq!(json["bootstrap_port"], json!(null));
        assert!(json["bootstrap_room"].is_number());
    }
}
......@@ -11,10 +11,10 @@ use std::fmt::Debug;
use crate::openai_api_types::{ChatCompletionRequest, CompletionRequest, GenerateRequest};
pub mod bootstrap_injector;
pub mod factory;
pub mod pd_router;
pub mod pd_types;
pub mod request_adapter;
pub mod router;
pub use factory::RouterFactory;
......
// PD (Prefill-Decode) Router Implementation
// This module handles routing for disaggregated prefill-decode systems
use super::pd_types::{api_path, Bootstrap, ChatReqInput, GenerateReqInput, PDRouterError};
use super::request_adapter::ToPdRequest;
use super::bootstrap_injector::inject_bootstrap_fields;
use super::pd_types::{api_path, PDRouterError};
use crate::config::types::RetryConfig;
use crate::core::{HealthChecker, Worker, WorkerFactory, WorkerLoadGuard};
use crate::metrics::RouterMetrics;
use crate::openai_api_types::{ChatCompletionRequest, CompletionRequest, GenerateRequest};
use crate::policies::LoadBalancingPolicy;
use crate::routers::{RouterTrait, WorkerManagement};
use crate::tree::Tree;
use async_trait::async_trait;
use axum::{
body::Body,
extract::Request,
......@@ -46,18 +48,26 @@ pub struct PDRouter {
impl PDRouter {
// Dynamic worker management methods for service discovery
// Private helper method to perform health check on a new server
/// Poll a newly added server until the shared readiness check passes,
/// mapping any failure into a PD-specific health-check error for this URL.
async fn wait_for_server_health(&self, url: &str) -> Result<(), PDRouterError> {
    let targets = [url.to_string()];
    let outcome = crate::routers::router::Router::wait_for_healthy_workers(
        &targets,
        self.timeout_secs,
        self.interval_secs,
    );
    match outcome {
        Ok(ok) => Ok(ok),
        Err(_) => Err(PDRouterError::HealthCheckFailed {
            url: url.to_string(),
        }),
    }
}
pub async fn add_prefill_server(
&self,
url: String,
bootstrap_port: Option<u16>,
) -> Result<String, PDRouterError> {
// Wait for the new server to be healthy
crate::routers::router::Router::wait_for_healthy_workers(
&[url.clone()],
self.timeout_secs,
self.interval_secs,
)
.map_err(|_| PDRouterError::HealthCheckFailed { url: url.clone() })?;
self.wait_for_server_health(&url).await?;
// Create Worker for the new prefill server
let worker = WorkerFactory::create_prefill(url.clone(), bootstrap_port);
......@@ -88,12 +98,7 @@ impl PDRouter {
pub async fn add_decode_server(&self, url: String) -> Result<String, PDRouterError> {
// Wait for the new server to be healthy
crate::routers::router::Router::wait_for_healthy_workers(
&[url.clone()],
self.timeout_secs,
self.interval_secs,
)
.map_err(|_| PDRouterError::HealthCheckFailed { url: url.clone() })?;
self.wait_for_server_health(&url).await?;
// Create Worker for the new decode server
let worker = WorkerFactory::create_decode(url.clone());
......@@ -332,189 +337,6 @@ impl PDRouter {
.into_response()
}
// Route a typed generate request
/// Route a typed generate request to a prefill/decode worker pair:
/// select the pair, inject bootstrap info for the prefill worker, then
/// dispatch the JSON body to both servers.
pub async fn route_generate(
    &self,
    headers: Option<&HeaderMap>,
    mut typed_req: GenerateReqInput,
    route: &str,
) -> Response {
    let start = Instant::now();

    // Get stream flag and return_logprob flag before moving the request
    let is_stream = typed_req.stream;
    let return_logprob = typed_req
        .other
        .get("return_logprob")
        .and_then(|v| v.as_bool())
        .unwrap_or(false);

    // Extract text for cache-aware routing from the typed request.
    // For batch input, only the first element is used as the routing key.
    let request_text = typed_req.text.as_ref().and_then(|t| match t {
        super::pd_types::InputText::Single(s) => Some(s.as_str()),
        super::pd_types::InputText::Batch(v) => v.first().map(|s| s.as_str()),
    });

    // Select servers
    let (prefill, decode) = match self.select_pd_pair(request_text).await {
        Ok(pair) => pair,
        Err(e) => return Self::handle_server_selection_error(e),
    };

    // Log routing decision
    info!(
        "PD routing decision route={} prefill_url={} decode_url={}",
        route,
        prefill.url(),
        decode.url()
    );

    // Add bootstrap info using the trait method (targets the prefill worker)
    if let Err(e) = typed_req.add_bootstrap_info(prefill.as_ref()) {
        return Self::handle_bootstrap_error(e);
    }

    // Convert to JSON after bootstrap injection
    let json_with_bootstrap = match serde_json::to_value(&typed_req) {
        Ok(json) => json,
        Err(e) => return Self::handle_serialization_error(e),
    };

    // Execute dual dispatch: same payload goes to both prefill and decode
    self.execute_dual_dispatch(
        headers,
        json_with_bootstrap,
        route,
        prefill.as_ref(),
        decode.as_ref(),
        is_stream,
        return_logprob,
        start,
    )
    .await
}
// Route a typed chat request
/// Route a typed chat request through the same select → inject-bootstrap →
/// dual-dispatch pipeline as `route_generate`.
pub async fn route_chat(
    &self,
    headers: Option<&HeaderMap>,
    mut typed_req: ChatReqInput,
    route: &str,
) -> Response {
    let start = Instant::now();

    // Get stream flag and return_logprob flag before moving the request
    let is_stream = typed_req.stream;
    let return_logprob = typed_req
        .other
        .get("return_logprob")
        .and_then(|v| v.as_bool())
        .unwrap_or(false);

    // Extract text for cache-aware routing from chat messages.
    // Only the first message's string content is used as the routing key;
    // structured (non-string) content yields None.
    let request_text = typed_req
        .other
        .get("messages")
        .and_then(|messages| messages.as_array())
        .and_then(|arr| arr.first())
        .and_then(|msg| msg.get("content"))
        .and_then(|content| content.as_str());

    // Select servers
    let (prefill, decode) = match self.select_pd_pair(request_text).await {
        Ok(pair) => pair,
        Err(e) => return Self::handle_server_selection_error(e),
    };

    // Log routing decision
    info!(
        "PD routing decision route={} prefill_url={} decode_url={}",
        route,
        prefill.url(),
        decode.url()
    );

    // Inject bootstrap info targeting the selected prefill worker
    if let Err(e) = typed_req.add_bootstrap_info(prefill.as_ref()) {
        return Self::handle_bootstrap_error(e);
    }

    // Convert to JSON after bootstrap injection
    let json_with_bootstrap = match serde_json::to_value(&typed_req) {
        Ok(json) => json,
        Err(e) => return Self::handle_serialization_error(e),
    };

    // Execute dual dispatch
    self.execute_dual_dispatch(
        headers,
        json_with_bootstrap,
        route,
        prefill.as_ref(),
        decode.as_ref(),
        is_stream,
        return_logprob,
        start,
    )
    .await
}
// Route a completion request while preserving OpenAI format
/// Route a completion request while preserving OpenAI format; follows the
/// same select → inject-bootstrap → dual-dispatch pipeline.
pub async fn route_completion(
    &self,
    headers: Option<&HeaderMap>,
    mut typed_req: CompletionRequest,
    route: &str,
) -> Response {
    let start = Instant::now();

    // Get stream flag and return_logprob flag before moving the request.
    // For completions, presence of `logprobs` stands in for return_logprob.
    let is_stream = typed_req.stream;
    let return_logprob = typed_req.logprobs.is_some();

    // Extract text for cache-aware routing from the typed request
    // (batch prompt: only the first element is used as the routing key)
    let request_text = match &typed_req.prompt {
        crate::openai_api_types::StringOrArray::String(s) => Some(s.as_str()),
        crate::openai_api_types::StringOrArray::Array(arr) => arr.first().map(|s| s.as_str()),
    };

    // Select servers
    let (prefill, decode) = match self.select_pd_pair(request_text).await {
        Ok(pair) => pair,
        Err(e) => return Self::handle_server_selection_error(e),
    };

    // Log routing decision
    info!(
        "PD routing decision route={} prefill_url={} decode_url={}",
        route,
        prefill.url(),
        decode.url()
    );

    // Inject bootstrap info targeting the selected prefill worker
    if let Err(e) = typed_req.add_bootstrap_info(prefill.as_ref()) {
        return Self::handle_bootstrap_error(e);
    }

    // Convert to JSON after bootstrap injection
    let json_with_bootstrap = match serde_json::to_value(&typed_req) {
        Ok(json) => json,
        Err(e) => return Self::handle_serialization_error(e),
    };

    // Execute dual dispatch
    self.execute_dual_dispatch(
        headers,
        json_with_bootstrap,
        route,
        prefill.as_ref(),
        decode.as_ref(),
        is_stream,
        return_logprob,
        start,
    )
    .await
}
// Execute the dual dispatch to prefill and decode servers with retry logic
async fn execute_dual_dispatch(
&self,
......@@ -1090,7 +912,7 @@ impl PDRouter {
// Helper functions
async fn get_worker_load(client: &reqwest::Client, worker_url: &str) -> Option<isize> {
async fn get_worker_load(client: &Client, worker_url: &str) -> Option<isize> {
match client.get(format!("{}/get_load", worker_url)).send().await {
Ok(res) if res.status().is_success() => match res.bytes().await {
Ok(bytes) => match serde_json::from_slice::<Value>(&bytes) {
......@@ -1123,9 +945,96 @@ async fn get_worker_load(client: &reqwest::Client, worker_url: &str) -> Option<i
}
}
// PD-specific endpoints
impl PDRouter {
pub async fn health_generate(&self) -> Response {
#[async_trait]
impl WorkerManagement for PDRouter {
    /// Generic worker addition is unsupported: PD routing needs to know
    /// whether a server is prefill or decode, so callers must use
    /// `add_prefill_server` / `add_decode_server` instead.
    async fn add_worker(&self, _worker_url: &str) -> Result<String, String> {
        Err(
            "PD router requires specific add_prefill_server or add_decode_server methods"
                .to_string(),
        )
    }

    /// Remove a worker by URL, checking the prefill pool first and falling
    /// back to the decode pool. At most one matching entry is removed.
    fn remove_worker(&self, worker_url: &str) {
        if let Ok(mut pool) = self.prefill_workers.write() {
            if let Some(pos) = pool.iter().position(|w| w.url() == worker_url) {
                pool.remove(pos);
                info!("Removed prefill worker: {}", worker_url);
                return;
            }
        }
        if let Ok(mut pool) = self.decode_workers.write() {
            if let Some(pos) = pool.iter().position(|w| w.url() == worker_url) {
                pool.remove(pos);
                info!("Removed decode worker: {}", worker_url);
            }
        }
    }

    /// Collect the URLs of every known worker: prefill first, then decode.
    /// Pools whose lock is poisoned are silently skipped.
    fn get_worker_urls(&self) -> Vec<String> {
        let mut collected = Vec::new();
        if let Ok(pool) = self.prefill_workers.read() {
            collected.extend(pool.iter().map(|w| w.url().to_string()));
        }
        if let Ok(pool) = self.decode_workers.read() {
            collected.extend(pool.iter().map(|w| w.url().to_string()));
        }
        collected
    }
}
#[async_trait]
impl RouterTrait for PDRouter {
fn as_any(&self) -> &dyn std::any::Any {
    // Exposes the concrete type so callers holding `dyn RouterTrait`
    // can downcast back to `PDRouter`.
    self
}
/// Server readiness check: reports 200 when every configured prefill and
/// decode worker is currently marked healthy, otherwise 503 with the list
/// of unhealthy servers. Workers run their own health checks in the
/// background; this only reads the cached status.
async fn health(&self, _req: Request<Body>) -> Response {
    let mut failing = Vec::new();

    // Gather any unhealthy prefill servers
    for worker in self.prefill_workers.read().unwrap().iter() {
        if !worker.is_healthy() {
            failing.push(format!("Prefill: {}", worker.url()));
        }
    }

    // Gather any unhealthy decode servers
    for worker in self.decode_workers.read().unwrap().iter() {
        if !worker.is_healthy() {
            failing.push(format!("Decode: {}", worker.url()));
        }
    }

    if failing.is_empty() {
        (StatusCode::OK, "All servers healthy").into_response()
    } else {
        (
            StatusCode::SERVICE_UNAVAILABLE,
            format!("Unhealthy servers: {:?}", failing),
        )
            .into_response()
    }
}
async fn health_generate(&self, _req: Request<Body>) -> Response {
// Test model generation capability by selecting a random pair and testing them
// Note: This endpoint actually causes the model to generate tokens, so we only test one pair
......@@ -1206,7 +1115,7 @@ impl PDRouter {
}
}
pub async fn get_server_info(&self) -> Response {
async fn get_server_info(&self, _req: Request<Body>) -> Response {
// Get info from the first decode server to match sglang's server info format
let first_decode_url = if let Ok(workers) = self.decode_workers.read() {
workers.first().map(|w| w.url().to_string())
......@@ -1269,7 +1178,7 @@ impl PDRouter {
}
}
pub async fn get_models(&self, req: Request<Body>) -> Response {
async fn get_models(&self, req: Request<Body>) -> Response {
// Extract headers first to avoid Send issues
let headers = crate::routers::router::copy_request_headers(&req);
......@@ -1285,32 +1194,43 @@ impl PDRouter {
};
if let Some(worker_url) = first_worker_url {
// Send request directly without going through Router
let mut request_builder = self.client.get(format!("{}/v1/models", worker_url));
let url = format!("{}/v1/models", worker_url);
let mut request_builder = self.client.get(&url);
// Add headers
for (name, value) in headers {
if name.to_lowercase() != "content-type" && name.to_lowercase() != "content-length"
{
request_builder = request_builder.header(name, value);
}
request_builder = request_builder.header(name, value);
}
match request_builder.send().await {
Ok(res) => {
let status = StatusCode::from_u16(res.status().as_u16())
.unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
match res.bytes().await {
Ok(body) => (status, body).into_response(),
Err(e) => (
Ok(res) if res.status().is_success() => match res.bytes().await {
Ok(body) => (StatusCode::OK, body).into_response(),
Err(e) => {
error!("Failed to read response body: {}", e);
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Failed to read response body: {}", e),
)
.into_response(),
.into_response()
}
},
Ok(res) => {
let status = StatusCode::from_u16(res.status().as_u16())
.unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
(
status,
format!("Prefill server returned status: {}", res.status()),
)
.into_response()
}
Err(e) => {
error!("Failed to get models: {}", e);
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Failed to get models: {}", e),
)
.into_response()
}
Err(e) => (
StatusCode::INTERNAL_SERVER_ERROR,
format!("Failed to send request: {}", e),
)
.into_response(),
}
} else {
(
......@@ -1321,53 +1241,10 @@ impl PDRouter {
}
}
/// Report the current load of every prefill and decode worker as a JSON
/// document of the form `{"prefill": [...], "decode": [...]}`.
///
/// Workers that fail to answer are still listed, with a sentinel load of -1.
pub async fn get_loads(&self, client: &reqwest::Client) -> Response {
    // Snapshot the URLs first so no RwLock guard is held across an await.
    let prefill_urls: Vec<String> = self
        .prefill_workers
        .read()
        .unwrap()
        .iter()
        .map(|w| w.url().to_string())
        .collect();
    let decode_urls: Vec<String> = self
        .decode_workers
        .read()
        .unwrap()
        .iter()
        .map(|w| w.url().to_string())
        .collect();

    let mut prefill_loads = Vec::with_capacity(prefill_urls.len());
    for url in &prefill_urls {
        // unwrap_or(-1) flags workers whose load query failed.
        let load = get_worker_load(client, url).await.unwrap_or(-1);
        prefill_loads.push(serde_json::json!({
            "engine": format!("(Prefill@{})", url),
            "load": load as i64
        }));
    }

    let mut decode_loads = Vec::with_capacity(decode_urls.len());
    for url in &decode_urls {
        let load = get_worker_load(client, url).await.unwrap_or(-1);
        decode_loads.push(serde_json::json!({
            "engine": format!("(Decode@{})", url),
            "load": load as i64
        }));
    }

    Json(serde_json::json!({
        "prefill": prefill_loads,
        "decode": decode_loads
    }))
    .into_response()
}
pub async fn get_model_info(&self, req: Request<Body>) -> Response {
async fn get_model_info(&self, req: Request<Body>) -> Response {
// Extract headers first to avoid Send issues
let headers = crate::routers::router::copy_request_headers(&req);
// Get model info from the first prefill server (matches original Rust PDLB behavior)
// Get first prefill worker URL to avoid holding lock across await
let first_worker_url = if let Ok(workers) = self.prefill_workers.read() {
workers.first().map(|w| w.url().to_string())
......@@ -1380,31 +1257,43 @@ impl PDRouter {
};
if let Some(worker_url) = first_worker_url {
let mut request_builder = self.client.get(format!("{}/get_model_info", worker_url));
let url = format!("{}/get_model_info", worker_url);
let mut request_builder = self.client.get(&url);
// Add headers
for (name, value) in headers {
if name.to_lowercase() != "content-type" && name.to_lowercase() != "content-length"
{
request_builder = request_builder.header(name, value);
}
request_builder = request_builder.header(name, value);
}
match request_builder.send().await {
Ok(res) => {
let status = StatusCode::from_u16(res.status().as_u16())
.unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
match res.bytes().await {
Ok(body) => (status, body).into_response(),
Err(e) => (
Ok(res) if res.status().is_success() => match res.bytes().await {
Ok(body) => (StatusCode::OK, body).into_response(),
Err(e) => {
error!("Failed to read response body: {}", e);
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Failed to read response body: {}", e),
)
.into_response(),
.into_response()
}
},
Ok(res) => {
let status = StatusCode::from_u16(res.status().as_u16())
.unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
(
status,
format!("Prefill server returned status: {}", res.status()),
)
.into_response()
}
Err(e) => {
error!("Failed to get model info: {}", e);
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Failed to get model info: {}", e),
)
.into_response()
}
Err(e) => (
StatusCode::INTERNAL_SERVER_ERROR,
format!("Failed to send request: {}", e),
)
.into_response(),
}
} else {
(
......@@ -1415,205 +1304,319 @@ impl PDRouter {
}
}
pub async fn flush_cache(&self, client: &reqwest::Client) -> Response {
let mut tasks = Vec::new();
async fn route_generate(
&self,
headers: Option<&HeaderMap>,
body: &GenerateRequest,
) -> Response {
let start = Instant::now();
// Flush cache on all prefill servers
for worker in self.prefill_workers.read().unwrap().iter() {
let url = format!("{}/flush_cache", worker.url());
tasks.push(client.post(&url).send());
}
// Convert directly to JSON to preserve all fields automatically
let mut json = match serde_json::to_value(body) {
Ok(json) => json,
Err(e) => return Self::handle_serialization_error(e),
};
// Flush cache on all decode servers
for worker in self.decode_workers.read().unwrap().iter() {
let url = format!("{}/flush_cache", worker.url());
tasks.push(client.post(&url).send());
}
// Extract flags for routing logic
let is_stream = body.stream;
let return_logprob = body.return_logprob;
let results = futures_util::future::join_all(tasks).await;
// Extract text for cache-aware routing
let request_text = body.text.as_deref().or_else(|| {
body.prompt.as_ref().and_then(|p| match p {
crate::openai_api_types::StringOrArray::String(s) => Some(s.as_str()),
crate::openai_api_types::StringOrArray::Array(v) => v.first().map(|s| s.as_str()),
})
});
let mut all_success = true;
for (i, result) in results.into_iter().enumerate() {
match result {
Ok(res) if res.status().is_success() => {}
Ok(res) => {
all_success = false;
warn!(
"Server {} returned status {} for flush_cache",
i,
res.status()
);
}
Err(e) => {
all_success = false;
error!("Server {} error during flush_cache: {}", i, e);
}
}
}
// Select servers
let (prefill, decode) = match self.select_pd_pair(request_text).await {
Ok(pair) => pair,
Err(e) => return Self::handle_server_selection_error(e),
};
if all_success {
(StatusCode::OK, "Cache flushed on all servers").into_response()
} else {
(
StatusCode::INTERNAL_SERVER_ERROR,
"Cache flush failed on one or more servers",
)
.into_response()
}
}
}
// Log routing decision
info!(
"PD routing decision route=/generate prefill_url={} decode_url={}",
prefill.url(),
decode.url()
);
use crate::routers::{RouterTrait, WorkerManagement};
use async_trait::async_trait;
// Inject bootstrap fields directly into JSON
if let Err(e) = inject_bootstrap_fields(&mut json, prefill.as_ref()) {
return Self::handle_bootstrap_error(e);
}
#[async_trait]
impl WorkerManagement for PDRouter {
async fn add_worker(&self, _worker_url: &str) -> Result<String, String> {
// For PD router, we don't support adding workers via this generic method
Err(
"PD router requires specific add_prefill_server or add_decode_server methods"
.to_string(),
// Execute dual dispatch
self.execute_dual_dispatch(
headers,
json,
"/generate",
prefill.as_ref(),
decode.as_ref(),
is_stream,
return_logprob,
start,
)
.await
}
fn remove_worker(&self, worker_url: &str) {
// For PD router, we would need to know if it's a prefill or decode server
// For now, try both
if let Ok(mut workers) = self.prefill_workers.write() {
if let Some(index) = workers.iter().position(|w| w.url() == worker_url) {
workers.remove(index);
info!("Removed prefill worker: {}", worker_url);
return;
}
}
async fn route_chat(
&self,
headers: Option<&HeaderMap>,
body: &ChatCompletionRequest,
) -> Response {
let start = Instant::now();
if let Ok(mut workers) = self.decode_workers.write() {
if let Some(index) = workers.iter().position(|w| w.url() == worker_url) {
workers.remove(index);
info!("Removed decode worker: {}", worker_url);
// Convert directly to JSON to preserve all fields automatically
let mut json = match serde_json::to_value(body) {
Ok(json) => json,
Err(e) => return Self::handle_serialization_error(e),
};
// Extract flags for routing logic
let is_stream = body.stream;
let return_logprob = body.logprobs;
// Extract text for cache-aware routing from chat messages
let request_text = body.messages.first().and_then(|msg| match msg {
crate::openai_api_types::ChatMessage::User { content, .. } => {
match content {
crate::openai_api_types::UserMessageContent::Text(text) => Some(text.as_str()),
crate::openai_api_types::UserMessageContent::Parts(_) => None, // Skip complex content
}
}
crate::openai_api_types::ChatMessage::System { content, .. } => Some(content.as_str()),
_ => None,
});
// Select servers
let (prefill, decode) = match self.select_pd_pair(request_text).await {
Ok(pair) => pair,
Err(e) => return Self::handle_server_selection_error(e),
};
// Log routing decision
info!(
"PD routing decision route=/v1/chat/completions prefill_url={} decode_url={}",
prefill.url(),
decode.url()
);
// Inject bootstrap fields directly into JSON
if let Err(e) = inject_bootstrap_fields(&mut json, prefill.as_ref()) {
return Self::handle_bootstrap_error(e);
}
// Execute dual dispatch
self.execute_dual_dispatch(
headers,
json,
"/v1/chat/completions",
prefill.as_ref(),
decode.as_ref(),
is_stream,
return_logprob,
start,
)
.await
}
fn get_worker_urls(&self) -> Vec<String> {
let mut urls = Vec::new();
async fn route_completion(
&self,
headers: Option<&HeaderMap>,
body: &CompletionRequest,
) -> Response {
let start = Instant::now();
// Add prefill worker URLs
if let Ok(workers) = self.prefill_workers.read() {
for worker in workers.iter() {
urls.push(worker.url().to_string());
}
}
// Convert directly to JSON to preserve all fields automatically
let mut json = match serde_json::to_value(body) {
Ok(json) => json,
Err(e) => return Self::handle_serialization_error(e),
};
// Add decode worker URLs
if let Ok(workers) = self.decode_workers.read() {
for worker in workers.iter() {
urls.push(worker.url().to_string());
}
// Extract flags for routing logic
let is_stream = body.stream;
let return_logprob = body.logprobs.is_some();
// Extract text for cache-aware routing
let request_text = match &body.prompt {
crate::openai_api_types::StringOrArray::String(s) => Some(s.as_str()),
crate::openai_api_types::StringOrArray::Array(v) => v.first().map(|s| s.as_str()),
};
// Select servers
let (prefill, decode) = match self.select_pd_pair(request_text).await {
Ok(pair) => pair,
Err(e) => return Self::handle_server_selection_error(e),
};
// Log routing decision
info!(
"PD routing decision route=/v1/completions prefill_url={} decode_url={}",
prefill.url(),
decode.url()
);
// Inject bootstrap fields directly into JSON
if let Err(e) = inject_bootstrap_fields(&mut json, prefill.as_ref()) {
return Self::handle_bootstrap_error(e);
}
urls
// Execute dual dispatch
self.execute_dual_dispatch(
headers,
json,
"/v1/completions",
prefill.as_ref(),
decode.as_ref(),
is_stream,
return_logprob,
start,
)
.await
}
}
#[async_trait]
impl RouterTrait for PDRouter {
fn as_any(&self) -> &dyn std::any::Any {
self
}
async fn flush_cache(&self) -> Response {
let mut results = Vec::new();
let mut errors = Vec::new();
async fn health(&self, _req: Request<Body>) -> Response {
// This is a server readiness check - checking if we have healthy workers
// Workers handle their own health checks in the background
let mut all_healthy = true;
let mut unhealthy_servers = Vec::new();
// Get prefill worker URLs first to avoid holding lock across await
let prefill_urls = if let Ok(workers) = self.prefill_workers.read() {
workers
.iter()
.map(|w| w.url().to_string())
.collect::<Vec<_>>()
} else {
errors.push("Failed to access prefill workers".to_string());
Vec::new()
};
// Check prefill servers
for worker in self.prefill_workers.read().unwrap().iter() {
if !worker.is_healthy() {
all_healthy = false;
unhealthy_servers.push(format!("Prefill: {}", worker.url()));
// Flush prefill workers
for worker_url in prefill_urls {
let url = format!("{}/flush_cache", worker_url);
match self.client.post(&url).send().await {
Ok(res) if res.status().is_success() => {
results.push(format!("Prefill {}: OK", worker_url));
}
Ok(res) => {
errors.push(format!(
"Prefill {} returned status: {}",
worker_url,
res.status()
));
}
Err(e) => {
errors.push(format!("Prefill {} error: {}", worker_url, e));
}
}
}
// Check decode servers
for worker in self.decode_workers.read().unwrap().iter() {
if !worker.is_healthy() {
all_healthy = false;
unhealthy_servers.push(format!("Decode: {}", worker.url()));
// Get decode worker URLs first to avoid holding lock across await
let decode_urls = if let Ok(workers) = self.decode_workers.read() {
workers
.iter()
.map(|w| w.url().to_string())
.collect::<Vec<_>>()
} else {
errors.push("Failed to access decode workers".to_string());
Vec::new()
};
// Flush decode workers
for worker_url in decode_urls {
let url = format!("{}/flush_cache", worker_url);
match self.client.post(&url).send().await {
Ok(res) if res.status().is_success() => {
results.push(format!("Decode {}: OK", worker_url));
}
Ok(res) => {
errors.push(format!(
"Decode {} returned status: {}",
worker_url,
res.status()
));
}
Err(e) => {
errors.push(format!("Decode {} error: {}", worker_url, e));
}
}
}
if all_healthy {
(StatusCode::OK, "All servers healthy").into_response()
if errors.is_empty() {
(
StatusCode::OK,
format!("Cache flushed successfully: {:?}", results),
)
.into_response()
} else {
(
StatusCode::SERVICE_UNAVAILABLE,
format!("Unhealthy servers: {:?}", unhealthy_servers),
StatusCode::PARTIAL_CONTENT,
format!(
"Partial success. Results: {:?}, Errors: {:?}",
results, errors
),
)
.into_response()
}
}
async fn health_generate(&self, _req: Request<Body>) -> Response {
// Use the existing PDRouter health_generate method
PDRouter::health_generate(self).await
}
async fn get_server_info(&self, _req: Request<Body>) -> Response {
// Use the existing PDRouter get_server_info method
PDRouter::get_server_info(self).await
}
async fn get_models(&self, req: Request<Body>) -> Response {
// Use the existing PDRouter get_models method
PDRouter::get_models(self, req).await
}
async fn get_model_info(&self, req: Request<Body>) -> Response {
// Use the existing PDRouter get_model_info method
PDRouter::get_model_info(self, req).await
}
async fn route_generate(
&self,
headers: Option<&HeaderMap>,
body: &GenerateRequest,
) -> Response {
// Convert OpenAI format to PD format
let pd_req = body.clone().to_pd_request();
async fn get_worker_loads(&self) -> Response {
let mut loads = HashMap::new();
let mut errors = Vec::new();
PDRouter::route_generate(self, headers, pd_req, "/generate").await
}
// Get prefill worker URLs first to avoid holding lock across await
let prefill_urls = if let Ok(workers) = self.prefill_workers.read() {
workers
.iter()
.map(|w| w.url().to_string())
.collect::<Vec<_>>()
} else {
errors.push("Failed to access prefill workers".to_string());
Vec::new()
};
async fn route_chat(
&self,
headers: Option<&HeaderMap>,
body: &ChatCompletionRequest,
) -> Response {
// Convert OpenAI format to PD format
let pd_req = body.clone().to_pd_request();
// Get loads from prefill workers
for worker_url in prefill_urls {
match get_worker_load(&self.client, &worker_url).await {
Some(load) => {
loads.insert(format!("prefill_{}", worker_url), load);
}
None => {
errors.push(format!("Failed to get load from prefill {}", worker_url));
}
}
}
PDRouter::route_chat(self, headers, pd_req, "/v1/chat/completions").await
}
// Get decode worker URLs first to avoid holding lock across await
let decode_urls = if let Ok(workers) = self.decode_workers.read() {
workers
.iter()
.map(|w| w.url().to_string())
.collect::<Vec<_>>()
} else {
errors.push("Failed to access decode workers".to_string());
Vec::new()
};
async fn route_completion(
&self,
headers: Option<&HeaderMap>,
body: &CompletionRequest,
) -> Response {
// Use the new method that preserves OpenAI format
PDRouter::route_completion(self, headers, body.clone(), "/v1/completions").await
}
// Get loads from decode workers
for worker_url in decode_urls {
match get_worker_load(&self.client, &worker_url).await {
Some(load) => {
loads.insert(format!("decode_{}", worker_url), load);
}
None => {
errors.push(format!("Failed to get load from decode {}", worker_url));
}
}
}
async fn flush_cache(&self) -> Response {
// Use the existing PDRouter flush_cache method
PDRouter::flush_cache(self, &self.client).await
}
let response_data = serde_json::json!({
"loads": loads,
"errors": errors
});
async fn get_worker_loads(&self) -> Response {
// Use the existing PDRouter get_loads method
PDRouter::get_loads(self, &self.client).await
(StatusCode::OK, Json(response_data)).into_response()
}
fn router_type(&self) -> &'static str {
......@@ -1688,7 +1691,6 @@ mod tests {
use super::*;
use crate::core::{BasicWorker, WorkerType};
use crate::policies::{CacheAwarePolicy, RandomPolicy};
use crate::routers::pd_types::SingleOrBatch;
fn create_test_pd_router() -> PDRouter {
let prefill_policy = Arc::new(RandomPolicy::new());
......@@ -1935,90 +1937,6 @@ mod tests {
assert!(result.is_ok());
}
// ============= Bootstrap Injection Tests =============
#[test]
fn test_bootstrap_injection_with_existing_fields() {
let mut req = GenerateReqInput {
text: Some(SingleOrBatch::Single("Test".to_string())),
input_ids: None,
stream: false,
bootstrap_host: Some(SingleOrBatch::Single("existing-host".to_string())),
bootstrap_port: Some(SingleOrBatch::Single(Some(9999))),
bootstrap_room: Some(SingleOrBatch::Single(12345)),
other: Value::Object(serde_json::Map::new()),
};
let prefill_worker = create_test_worker(
"http://new-host:8000".to_string(),
WorkerType::Prefill {
bootstrap_port: Some(8080),
},
true,
);
// Bootstrap info is added regardless of existing fields
let result = req.add_bootstrap_info(prefill_worker.as_ref());
assert!(result.is_ok());
// Bootstrap info should be updated with new values
assert_eq!(
req.bootstrap_host,
Some(SingleOrBatch::Single("new-host".to_string()))
);
assert_eq!(req.bootstrap_port, Some(SingleOrBatch::Single(Some(8080))));
// Room should be regenerated (different from original)
if let Some(SingleOrBatch::Single(room)) = req.bootstrap_room {
assert_ne!(room, 12345);
} else {
panic!("Expected single room ID");
}
}
#[test]
fn test_bootstrap_room_generation() {
let mut req1 = GenerateReqInput {
text: Some(SingleOrBatch::Single("Test".to_string())),
input_ids: None,
stream: false,
bootstrap_host: None,
bootstrap_port: None,
bootstrap_room: None,
other: Value::Object(serde_json::Map::new()),
};
let mut req2 = GenerateReqInput {
text: Some(SingleOrBatch::Single("Test".to_string())),
input_ids: None,
stream: false,
bootstrap_host: None,
bootstrap_port: None,
bootstrap_room: None,
other: Value::Object(serde_json::Map::new()),
};
let prefill_worker = create_test_worker(
"http://host:8000".to_string(),
WorkerType::Prefill {
bootstrap_port: Some(8080),
},
true,
);
// Add bootstrap info to both requests
let _ = req1.add_bootstrap_info(prefill_worker.as_ref());
let _ = req2.add_bootstrap_info(prefill_worker.as_ref());
// Room IDs should be different
if let (Some(SingleOrBatch::Single(room1)), Some(SingleOrBatch::Single(room2))) =
(req1.bootstrap_room, req2.bootstrap_room)
{
assert_ne!(room1, room2, "Room IDs should be unique");
} else {
panic!("Expected single room IDs");
}
}
// ============= Worker Selection Tests =============
#[tokio::test]
......@@ -2196,4 +2114,158 @@ mod tests {
let workers = router.prefill_workers.read().unwrap();
assert_eq!(workers.len(), 5);
}
/// End-to-end check of the simplified routing path for /generate: serialize
/// the typed request to JSON, inject bootstrap fields, and verify that every
/// SGLang-specific extension field survives the round trip untouched.
#[tokio::test]
async fn test_simplified_routing_preserves_sglang_fields() {
    use crate::openai_api_types::GenerateRequest;
    use crate::routers::bootstrap_injector::inject_bootstrap_fields;

    // Create a prefill worker with an explicit bootstrap port.
    let worker = BasicWorker::new(
        "http://test-server:8000".to_string(),
        WorkerType::Prefill {
            bootstrap_port: Some(5678),
        },
    );

    // Create a GenerateRequest exercising the SGLang extension fields.
    let mut session_params = std::collections::HashMap::new();
    session_params.insert("test_key".to_string(), serde_json::json!("test_value"));

    let request = GenerateRequest {
        text: Some("Test prompt".to_string()),
        stream: false,
        return_logprob: true,
        // SGLang extensions
        lora_path: Some(crate::openai_api_types::LoRAPath::Single(Some(
            "test.bin".to_string(),
        ))),
        session_params: Some(session_params.clone()),
        return_hidden_states: true,
        rid: Some("test-request-id".to_string()),
        // Other fields default to None/false
        prompt: None,
        input_ids: None,
        parameters: None,
        sampling_params: None,
    };

    // Convert to JSON (simulating the simplified routing path).
    let mut json = serde_json::to_value(&request).unwrap();

    // Inject bootstrap fields into the JSON in place.
    let result = inject_bootstrap_fields(&mut json, &worker);
    assert!(result.is_ok());

    // Verify all SGLang fields are preserved verbatim.
    assert_eq!(json["text"], serde_json::json!("Test prompt"));
    assert_eq!(json["stream"], serde_json::json!(false));
    assert_eq!(json["return_logprob"], serde_json::json!(true));
    assert_eq!(json["lora_path"], serde_json::json!("test.bin")); // LoRAPath::Single serializes as just the inner value
    assert_eq!(
        json["session_params"],
        serde_json::to_value(&session_params).unwrap()
    );
    assert_eq!(json["return_hidden_states"], serde_json::json!(true));
    assert_eq!(json["rid"], serde_json::json!("test-request-id"));

    // Verify bootstrap fields were added: host from the worker URL, port from
    // the worker's bootstrap_port, room freshly generated.
    assert_eq!(json["bootstrap_host"], serde_json::json!("test-server"));
    assert_eq!(json["bootstrap_port"], serde_json::json!(5678));
    assert!(json["bootstrap_room"].is_number());
}
/// End-to-end check of the simplified routing path for chat completions:
/// with `n = 2`, bootstrap injection must emit *batch* (array) bootstrap
/// fields — one entry per generated completion — while preserving every
/// original request field.
#[tokio::test]
async fn test_simplified_routing_chat_completion() {
    use crate::openai_api_types::{ChatCompletionRequest, ChatMessage, UserMessageContent};
    use crate::routers::bootstrap_injector::inject_bootstrap_fields;

    // Create a prefill worker with an explicit bootstrap port.
    let worker = BasicWorker::new(
        "http://chat-server:8000".to_string(),
        WorkerType::Prefill {
            bootstrap_port: Some(9999),
        },
    );

    // Create a ChatCompletionRequest with SGLang extensions.
    let request = ChatCompletionRequest {
        model: "gpt-4".to_string(),
        messages: vec![ChatMessage::User {
            role: "user".to_string(),
            content: UserMessageContent::Text("Hello world!".to_string()),
            name: None,
        }],
        stream: false,
        n: Some(2), // This should create batch bootstrap
        // SGLang extensions
        top_k: Some(50),
        separate_reasoning: false,
        stream_reasoning: true,
        // Set all other fields to defaults
        temperature: None,
        top_p: None,
        stream_options: None,
        stop: None,
        max_tokens: None,
        max_completion_tokens: None,
        presence_penalty: None,
        frequency_penalty: None,
        logit_bias: None,
        user: None,
        seed: None,
        logprobs: false,
        top_logprobs: None,
        response_format: None,
        tools: None,
        tool_choice: None,
        parallel_tool_calls: None,
        functions: None,
        function_call: None,
        min_p: None,
        min_tokens: None,
        repetition_penalty: None,
        regex: None,
        ebnf: None,
        stop_token_ids: None,
        no_stop_trim: false,
        ignore_eos: false,
        continue_final_message: false,
        skip_special_tokens: true,
        lora_path: None,
        session_params: None,
        return_hidden_states: false,
    };

    // Convert to JSON (simulating the simplified routing path).
    let mut json = serde_json::to_value(&request).unwrap();

    // Inject bootstrap fields into the JSON in place.
    let result = inject_bootstrap_fields(&mut json, &worker);
    assert!(result.is_ok());

    // Verify original fields preserved.
    assert_eq!(json["model"], serde_json::json!("gpt-4"));
    assert_eq!(json["stream"], serde_json::json!(false));
    assert_eq!(json["n"], serde_json::json!(2));
    assert_eq!(json["top_k"], serde_json::json!(50));
    assert_eq!(json["separate_reasoning"], serde_json::json!(false));
    assert_eq!(json["stream_reasoning"], serde_json::json!(true));

    // Verify batch bootstrap fields for n=2: each array holds one entry per
    // completion, all pointing at the same prefill worker.
    let bootstrap_hosts = json["bootstrap_host"].as_array().unwrap();
    assert_eq!(bootstrap_hosts.len(), 2);
    assert_eq!(bootstrap_hosts[0], serde_json::json!("chat-server"));
    assert_eq!(bootstrap_hosts[1], serde_json::json!("chat-server"));

    let bootstrap_ports = json["bootstrap_port"].as_array().unwrap();
    assert_eq!(bootstrap_ports.len(), 2);
    assert_eq!(bootstrap_ports[0], serde_json::json!(9999));
    assert_eq!(bootstrap_ports[1], serde_json::json!(9999));

    let bootstrap_rooms = json["bootstrap_room"].as_array().unwrap();
    assert_eq!(bootstrap_rooms.len(), 2);
    // Rooms are drawn independently at random, so they should differ.
    assert_ne!(bootstrap_rooms[0], bootstrap_rooms[1]);
}
}
// Essential PDLB types extracted for PD routing
use crate::core::{Worker, WorkerType};
use crate::openai_api_types::{CompletionRequest, StringOrArray};
use serde::{Deserialize, Serialize};
use serde_json::Value;
// Custom error type for PD router operations
#[derive(Debug, thiserror::Error)]
pub enum PDRouterError {
......@@ -58,428 +51,3 @@ pub enum PDSelectionPolicy {
balance_rel_threshold: f32,
},
}
// Bootstrap types from PDLB

/// A value that is either a single item or one item per batch element.
/// `#[serde(untagged)]` makes it (de)serialize as a bare scalar or a JSON
/// array, matching the wire format the workers expect.
#[derive(Debug, Deserialize, Serialize, PartialEq)]
#[serde(untagged)]
pub enum SingleOrBatch<T> {
    Single(T),
    Batch(Vec<T>),
}

/// Tokenized prompt(s): one id sequence, or one per batch element.
pub type InputIds = SingleOrBatch<Vec<i32>>;
/// Prompt text: one string, or one per batch element.
pub type InputText = SingleOrBatch<String>;
pub type BootstrapHost = SingleOrBatch<String>;
/// `None` means the prefill worker advertised no bootstrap port.
pub type BootstrapPort = SingleOrBatch<Option<u16>>;
pub type BootstrapRoom = SingleOrBatch<u64>;
// Bootstrap trait for request handling
pub trait Bootstrap: Send + Sync {
    /// Whether the request asked for a streaming response.
    fn is_stream(&self) -> bool;

    /// Batch size implied by the request, or `None` for a single request.
    fn get_batch_size(&self) -> Result<Option<usize>, String>;

    /// Store the resolved bootstrap metadata on the request.
    fn set_bootstrap_info(
        &mut self,
        bootstrap_host: BootstrapHost,
        bootstrap_port: BootstrapPort,
        bootstrap_room: BootstrapRoom,
    );

    /// Derive bootstrap host/port/room from the selected prefill worker and
    /// attach them to the request. Batched requests receive one value per
    /// batch element; single requests receive scalar values.
    fn add_bootstrap_info(&mut self, prefill_worker: &dyn Worker) -> Result<(), String> {
        let batch_size = self.get_batch_size()?;

        // Room IDs mirror Python's random.randint(0, 2**63 - 1): draw a u64
        // and mask it into the non-negative i64 range to minimize collisions.
        let fresh_room = || rand::random::<u64>() & (i64::MAX as u64);

        // Only prefill workers carry a bootstrap port; anything else gets None.
        let bootstrap_port = match prefill_worker.worker_type() {
            WorkerType::Prefill { bootstrap_port } => bootstrap_port,
            _ => None,
        };
        let hostname = get_hostname(prefill_worker.url());

        match batch_size {
            Some(batch_size) => self.set_bootstrap_info(
                BootstrapHost::Batch(vec![hostname; batch_size]),
                BootstrapPort::Batch(vec![bootstrap_port; batch_size]),
                BootstrapRoom::Batch((0..batch_size).map(|_| fresh_room()).collect()),
            ),
            None => self.set_bootstrap_info(
                BootstrapHost::Single(hostname),
                BootstrapPort::Single(bootstrap_port),
                BootstrapRoom::Single(fresh_room()),
            ),
        }
        Ok(())
    }
}
// Request types

/// A /generate request in the worker-native shape, plus optional bootstrap
/// metadata used by prefill/decode (PD) disaggregated routing.
#[derive(Debug, Deserialize, Serialize)]
pub struct GenerateReqInput {
    /// Prompt text: a single string or a batch of strings.
    pub text: Option<InputText>,
    /// Pre-tokenized prompt ids; mutually exclusive with `text`
    /// (enforced by `get_batch_size`).
    pub input_ids: Option<InputIds>,
    #[serde(default)]
    pub stream: bool,
    /// Bootstrap metadata, filled in before dispatch to a prefill worker.
    pub bootstrap_host: Option<BootstrapHost>,
    pub bootstrap_port: Option<BootstrapPort>,
    pub bootstrap_room: Option<BootstrapRoom>,
    /// Any additional request fields, preserved verbatim via `flatten`.
    #[serde(flatten)]
    pub other: Value,
}
impl GenerateReqInput {
    /// Determine the batch size implied by this request, if any.
    ///
    /// Returns `Ok(None)` for single (non-batch) inputs, `Ok(Some(n))` for a
    /// batch of `n` elements, and `Err` for malformed inputs: both `text` and
    /// `input_ids` set, an empty batch array, or an empty token sequence.
    pub fn get_batch_size(&self) -> Result<Option<usize>, String> {
        // text and input_ids are mutually exclusive ways to supply the prompt.
        if self.text.is_some() && self.input_ids.is_some() {
            return Err("Both text and input_ids are present in the request".to_string());
        }

        match (&self.text, &self.input_ids) {
            (Some(InputText::Batch(texts)), _) => {
                if texts.is_empty() {
                    Err("Batch text array is empty".to_string())
                } else {
                    Ok(Some(texts.len()))
                }
            }
            (_, Some(InputIds::Batch(ids))) => {
                if ids.is_empty() {
                    return Err("Batch input_ids array is empty".to_string());
                }
                // Reject batches containing empty token sequences.
                if let Some(i) = ids.iter().position(|seq| seq.is_empty()) {
                    return Err(format!("Input sequence at index {} is empty", i));
                }
                Ok(Some(ids.len()))
            }
            // Single text / single ids / neither: not a batch.
            _ => Ok(None),
        }
    }
}
impl Bootstrap for GenerateReqInput {
fn is_stream(&self) -> bool {
self.stream
}
fn get_batch_size(&self) -> Result<Option<usize>, String> {
self.get_batch_size()
}
fn set_bootstrap_info(
&mut self,
bootstrap_host: BootstrapHost,
bootstrap_port: BootstrapPort,
bootstrap_room: BootstrapRoom,
) {
self.bootstrap_host = Some(bootstrap_host);
self.bootstrap_port = Some(bootstrap_port);
self.bootstrap_room = Some(bootstrap_room);
}
}
/// A chat-completion request seen only as the fields PD routing needs;
/// everything else is carried opaquely in `other`.
#[derive(Debug, Deserialize, Serialize)]
pub struct ChatReqInput {
    #[serde(default)]
    pub stream: bool,
    /// Bootstrap metadata, filled in before dispatch to a prefill worker.
    pub bootstrap_host: Option<BootstrapHost>,
    pub bootstrap_port: Option<BootstrapPort>,
    pub bootstrap_room: Option<BootstrapRoom>,
    /// All remaining chat-completion fields, preserved verbatim — including
    /// `n`, which `get_batch_size` reads to detect batched generation.
    #[serde(flatten)]
    pub other: Value,
}
impl Bootstrap for ChatReqInput {
    fn is_stream(&self) -> bool {
        self.stream
    }

    fn get_batch_size(&self) -> Result<Option<usize>, String> {
        // A chat request counts as a batch only when the OpenAI `n`
        // parameter (carried in the flattened `other` map) exceeds 1.
        let n = self
            .other
            .get("n")
            .and_then(|value| value.as_u64())
            .filter(|&n| n > 1);
        Ok(n.map(|n| n as usize))
    }

    fn set_bootstrap_info(
        &mut self,
        bootstrap_host: BootstrapHost,
        bootstrap_port: BootstrapPort,
        bootstrap_room: BootstrapRoom,
    ) {
        // This type has dedicated optional fields for the bootstrap metadata.
        self.bootstrap_host = Some(bootstrap_host);
        self.bootstrap_port = Some(bootstrap_port);
        self.bootstrap_room = Some(bootstrap_room);
    }
}
// Bootstrap implementation for CompletionRequest to preserve OpenAI format
impl Bootstrap for CompletionRequest {
    fn is_stream(&self) -> bool {
        self.stream
    }

    fn get_batch_size(&self) -> Result<Option<usize>, String> {
        // Only an array prompt constitutes a batch; a plain string prompt is
        // a single request regardless of other parameters.
        match &self.prompt {
            StringOrArray::Array(prompts) if prompts.is_empty() => {
                Err("Batch prompt array is empty".to_string())
            }
            StringOrArray::Array(prompts) => Ok(Some(prompts.len())),
            StringOrArray::String(_) => Ok(None),
        }
    }

    fn set_bootstrap_info(
        &mut self,
        bootstrap_host: BootstrapHost,
        bootstrap_port: BootstrapPort,
        bootstrap_room: BootstrapRoom,
    ) {
        // The bootstrap values go into the untyped `other` map so the
        // OpenAI-shaped struct itself stays unchanged. Single/Batch variants
        // serialize to a scalar or an array respectively (untagged enum).
        if let Ok(value) = serde_json::to_value(&bootstrap_host) {
            self.other.insert("bootstrap_host".to_string(), value);
        }
        if let Ok(value) = serde_json::to_value(&bootstrap_port) {
            self.other.insert("bootstrap_port".to_string(), value);
        }
        if let Ok(value) = serde_json::to_value(&bootstrap_room) {
            self.other.insert("bootstrap_room".to_string(), value);
        }
    }
}
#[cfg(test)]
mod bootstrap_tests {
use super::*;
use crate::core::BasicWorker;
use crate::openai_api_types::StringOrArray;
/// Create a default CompletionRequest for testing with minimal fields set.
/// Intended for struct-update syntax in the tests below:
/// `CompletionRequest { prompt, ..default_completion_request() }`.
fn default_completion_request() -> CompletionRequest {
    CompletionRequest {
        model: String::new(),
        prompt: StringOrArray::String(String::new()),
        n: None,
        other: serde_json::Map::new(),
        suffix: None,
        max_tokens: None,
        temperature: None,
        top_p: None,
        stream: false,
        stream_options: None,
        logprobs: None,
        echo: false,
        stop: None,
        presence_penalty: None,
        frequency_penalty: None,
        best_of: None,
        logit_bias: None,
        user: None,
        seed: None,
        // SGLang extensions: sampling / constrained decoding
        top_k: None,
        min_p: None,
        min_tokens: None,
        repetition_penalty: None,
        regex: None,
        ebnf: None,
        json_schema: None,
        stop_token_ids: None,
        no_stop_trim: false,
        ignore_eos: false,
        skip_special_tokens: true,
        // SGLang extensions: serving features
        lora_path: None,
        session_params: None,
        return_hidden_states: false,
    }
}
#[test]
fn test_completion_batch_size_with_array_prompt() {
    // A two-element array prompt must be reported as a batch of two.
    let request = CompletionRequest {
        model: "test".to_string(),
        prompt: StringOrArray::Array(vec!["prompt1".to_string(), "prompt2".to_string()]),
        ..default_completion_request()
    };
    assert_eq!(request.get_batch_size().unwrap(), Some(2));
}
#[test]
fn test_completion_batch_size_with_single_prompt() {
    // A plain string prompt is not a batch, so no batch size is reported.
    let request = CompletionRequest {
        model: "test".to_string(),
        prompt: StringOrArray::String("single prompt".to_string()),
        ..default_completion_request()
    };
    assert_eq!(request.get_batch_size().unwrap(), None);
}
#[test]
fn test_completion_batch_size_with_n_parameter() {
    // `n` alone does not turn a single string prompt into a batch:
    // SGLang handles the n parameter differently than batched prompts.
    let request = CompletionRequest {
        model: "test".to_string(),
        prompt: StringOrArray::String("single prompt".to_string()),
        n: Some(3),
        ..default_completion_request()
    };
    assert_eq!(request.get_batch_size().unwrap(), None);
}
/// Single-variant bootstrap values must land in the `other` map as bare
/// scalars (the untagged SingleOrBatch serialization), not as arrays.
#[test]
fn test_completion_bootstrap_single_values() {
    let mut req = CompletionRequest {
        model: "test".to_string(),
        prompt: StringOrArray::Array(vec!["prompt1".to_string(), "prompt2".to_string()]),
        ..default_completion_request()
    };

    // Set bootstrap info - should always use single values
    req.set_bootstrap_info(
        BootstrapHost::Single("test-server".to_string()),
        BootstrapPort::Single(Some(5678)),
        BootstrapRoom::Single(12345),
    );

    // Verify scalar JSON values were created for each bootstrap key.
    assert!(req.other.get("bootstrap_host").unwrap().is_string());
    assert!(req.other.get("bootstrap_port").unwrap().is_number());
    assert!(req.other.get("bootstrap_room").unwrap().is_number());

    assert_eq!(
        req.other.get("bootstrap_host").unwrap().as_str().unwrap(),
        "test-server"
    );
    assert_eq!(
        req.other.get("bootstrap_port").unwrap().as_u64().unwrap(),
        5678
    );
    assert_eq!(
        req.other.get("bootstrap_room").unwrap().as_u64().unwrap(),
        12345
    );
}
/// Batch-variant bootstrap values must land in the `other` map as JSON
/// arrays with one entry per batch element, preserving element order.
#[test]
fn test_completion_bootstrap_array_values() {
    let mut req = CompletionRequest {
        model: "test".to_string(),
        prompt: StringOrArray::Array(vec!["prompt1".to_string(), "prompt2".to_string()]),
        ..default_completion_request()
    };

    // Set bootstrap info with arrays
    req.set_bootstrap_info(
        BootstrapHost::Batch(vec!["test-server".to_string(); 2]),
        BootstrapPort::Batch(vec![Some(5678); 2]),
        BootstrapRoom::Batch(vec![12345, 67890]),
    );

    // Verify arrays were created correctly
    assert!(req.other.get("bootstrap_host").unwrap().is_array());
    assert!(req.other.get("bootstrap_port").unwrap().is_array());
    assert!(req.other.get("bootstrap_room").unwrap().is_array());

    let hosts = req.other.get("bootstrap_host").unwrap().as_array().unwrap();
    assert_eq!(hosts.len(), 2);
    assert_eq!(hosts[0].as_str().unwrap(), "test-server");

    let ports = req.other.get("bootstrap_port").unwrap().as_array().unwrap();
    assert_eq!(ports.len(), 2);
    assert_eq!(ports[0].as_u64().unwrap(), 5678);

    // Room order must match the order supplied to set_bootstrap_info.
    let rooms = req.other.get("bootstrap_room").unwrap().as_array().unwrap();
    assert_eq!(rooms.len(), 2);
    assert_eq!(rooms[0].as_u64().unwrap(), 12345);
    assert_eq!(rooms[1].as_u64().unwrap(), 67890);
}
#[test]
fn test_bootstrap_room_range() {
// Test that bootstrap_room values are within the expected range [0, 2^63 - 1]
let worker = BasicWorker::new(
"http://test:8000".to_string(),
WorkerType::Prefill {
bootstrap_port: Some(8080),
},
);
// Test single request
let mut single_req = GenerateReqInput {
text: Some(InputText::Single("test".to_string())),
input_ids: None,
stream: false,
bootstrap_host: None,
bootstrap_port: None,
bootstrap_room: None,
other: Value::Object(serde_json::Map::new()),
};
for _ in 0..200000 {
single_req.add_bootstrap_info(&worker).unwrap();
if let Some(BootstrapRoom::Single(room)) = single_req.bootstrap_room {
// Verify the room value is within signed 64-bit range
assert!(room <= i64::MAX as u64, "Room {} exceeds i64::MAX", room);
} else {
panic!("Expected single bootstrap room");
}
}
// Test batch request
let mut batch_req = GenerateReqInput {
text: Some(InputText::Batch(vec![
"test1".to_string(),
"test2".to_string(),
])),
input_ids: None,
stream: false,
bootstrap_host: None,
bootstrap_port: None,
bootstrap_room: None,
other: Value::Object(serde_json::Map::new()),
};
for _ in 0..200000 {
batch_req.add_bootstrap_info(&worker).unwrap();
if let Some(BootstrapRoom::Batch(rooms)) = &batch_req.bootstrap_room {
for room in rooms {
// Verify each room value is within signed 64-bit range
assert!(*room <= i64::MAX as u64, "Room {} exceeds i64::MAX", room);
}
} else {
panic!("Expected batch bootstrap rooms");
}
}
}
}
// Request adapter to bridge OpenAI API types with PD routing requirements
use super::pd_types::{Bootstrap, ChatReqInput, GenerateReqInput, SingleOrBatch};
use crate::openai_api_types::{
ChatCompletionRequest, CompletionRequest, GenerateRequest, GenerationRequest, StringOrArray,
};
use serde_json::Value;
/// Adapter trait to convert OpenAI requests to PD-compatible requests.
///
/// Implementors consume the incoming request and build the PD input type,
/// which must support later bootstrap injection (host/port/room).
pub trait ToPdRequest {
    /// The PD-side request type produced by the conversion.
    type Output: Bootstrap;
    /// Consume `self` and build the PD-compatible request.
    fn to_pd_request(self) -> Self::Output;
}
// Helper macro to insert optional fields into a map.
// Each `Option` field is serialized only when `Some`; a serialization
// failure degrades to `Value::Null` rather than aborting the conversion.
macro_rules! insert_if_some {
    ($map:expr, $($field:expr => $key:expr),* $(,)?) => {
        $(
            if let Some(value) = $field {
                $map.insert($key.to_string(), serde_json::to_value(value).unwrap_or(Value::Null));
            }
        )*
    };
}
// Helper macro for simple value insertions.
// For fields with an infallible `Into<Value>` conversion (bools, strings).
macro_rules! insert_value {
    ($map:expr, $($field:expr => $key:expr),* $(,)?) => {
        $(
            $map.insert($key.to_string(), $field.into());
        )*
    };
}
// ============= Generate Request Adapter =============
impl ToPdRequest for GenerateRequest {
    type Output = GenerateReqInput;

    /// Convert a native/OpenAI-style generate request into the PD input form.
    ///
    /// Input resolution priority: `text` (SGLang native) > `prompt` (OpenAI)
    /// > `input_ids`. `max_new_tokens` and `temperature` are hoisted to the
    /// top level of `other`; `sampling_params` is processed after
    /// `parameters`, so its hoisted values win when both are present.
    fn to_pd_request(self) -> Self::Output {
        // Build the other fields first
        let mut other = serde_json::Map::new();

        // Handle text input - check in priority order: text (SGLang), prompt (OpenAI)
        let (text, input_ids) = if let Some(text_str) = self.text {
            // SGLang native format
            (Some(SingleOrBatch::Single(text_str)), None)
        } else if let Some(prompt) = self.prompt {
            // OpenAI style prompt
            let text = match prompt {
                StringOrArray::String(s) => Some(SingleOrBatch::Single(s)),
                StringOrArray::Array(v) => Some(SingleOrBatch::Batch(v)),
            };
            (text, None)
        } else if let Some(ids) = self.input_ids {
            // Input IDs case
            let input_ids = match ids {
                crate::openai_api_types::InputIds::Single(ids) => Some(SingleOrBatch::Single(ids)),
                crate::openai_api_types::InputIds::Batch(ids) => Some(SingleOrBatch::Batch(ids)),
            };
            (None, input_ids)
        } else {
            // No input provided
            (None, None)
        };

        // Old-style `parameters`: hoist common fields out, keep the remainder.
        if let Some(params) = self.parameters {
            let mut params_value = serde_json::to_value(&params).unwrap_or(Value::Null);
            if let Value::Object(ref mut params_map) = params_value {
                // Move max_new_tokens to top level if it exists
                if let Some(max_new_tokens) = params_map.remove("max_new_tokens") {
                    other.insert("max_new_tokens".to_string(), max_new_tokens);
                }
                // Move temperature to top level if it exists
                if let Some(temperature) = params_map.remove("temperature") {
                    other.insert("temperature".to_string(), temperature);
                }
            }
            // Only keep "parameters" if fields remain after hoisting. A
            // non-empty object can never be null, so checking object
            // emptiness alone suffices (the former extra `!is_null()` guard
            // was redundant).
            if params_value.as_object().map_or(false, |m| !m.is_empty()) {
                other.insert("parameters".to_string(), params_value);
            }
        }

        // New-style `sampling_params`: hoist copies of the common fields but
        // keep the full object intact. Processed last on purpose, so these
        // hoisted values override the ones taken from `parameters`.
        if let Some(sampling_params) = self.sampling_params {
            let params_value = serde_json::to_value(&sampling_params).unwrap_or(Value::Null);
            if !params_value.is_null() {
                if let Value::Object(ref params_map) = params_value {
                    if let Some(max_new_tokens) = params_map.get("max_new_tokens") {
                        other.insert("max_new_tokens".to_string(), max_new_tokens.clone());
                    }
                    if let Some(temperature) = params_map.get("temperature") {
                        other.insert("temperature".to_string(), temperature.clone());
                    }
                }
                other.insert("sampling_params".to_string(), params_value);
            }
        }

        // Mirror the flags into the JSON payload as well.
        insert_value!(other,
            self.stream => "stream",
            self.return_logprob => "return_logprob"
        );

        GenerateReqInput {
            text,
            input_ids,
            stream: self.stream,
            bootstrap_host: None,
            bootstrap_port: None,
            bootstrap_room: None,
            other: Value::Object(other),
        }
    }
}
// ============= Completion Request Adapter =============
impl ToPdRequest for CompletionRequest {
    type Output = GenerateReqInput;

    /// Translate an OpenAI `/v1/completions` request into the PD generate form.
    fn to_pd_request(self) -> Self::Output {
        // The prompt becomes the generate text (single string or batch).
        let text = Some(match self.prompt {
            StringOrArray::String(s) => SingleOrBatch::Single(s),
            StringOrArray::Array(v) => SingleOrBatch::Batch(v),
        });

        let mut extra = serde_json::Map::new();

        // Build the internal "parameters" object from the OpenAI fields.
        let mut gen_params = serde_json::Map::new();
        insert_if_some!(gen_params,
            self.max_tokens => "max_new_tokens",
            self.temperature => "temperature",
            self.top_p => "top_p",
            self.n => "best_of",
            self.logprobs => "top_n_tokens",
            self.seed => "seed"
        );
        // presence_penalty maps onto repetition_penalty, offset by 1.0.
        if let Some(penalty) = self.presence_penalty {
            gen_params.insert("repetition_penalty".to_string(), (1.0 + penalty).into());
        }
        // Normalize stop to an array of sequences.
        if let Some(stop) = self.stop {
            let sequences = match stop {
                StringOrArray::String(s) => vec![s],
                StringOrArray::Array(v) => v,
            };
            gen_params.insert("stop".to_string(), sequences.into());
        }
        // echo=true asks the backend to return the prompt with the output.
        if self.echo {
            gen_params.insert("return_full_text".to_string(), true.into());
        }
        extra.insert("parameters".to_string(), Value::Object(gen_params));

        // Preserve the original model name and stream flag.
        insert_value!(extra,
            self.model => "model",
            self.stream => "stream"
        );

        // SGLang extension fields (all optional).
        insert_if_some!(extra,
            // SGLang Extensions - Priority 1
            self.top_k => "top_k",
            self.min_p => "min_p",
            self.min_tokens => "min_tokens",
            self.repetition_penalty => "repetition_penalty",
            self.regex => "regex",
            self.ebnf => "ebnf",
            self.stop_token_ids => "stop_token_ids",
            // SGLang Extensions - Priority 2
            self.lora_path => "lora_path",
            self.session_params => "session_params"
        );

        // SGLang boolean extensions are plain bools on CompletionRequest,
        // so they are always forwarded.
        insert_value!(extra,
            self.no_stop_trim => "no_stop_trim",
            self.ignore_eos => "ignore_eos",
            self.skip_special_tokens => "skip_special_tokens",
            self.return_hidden_states => "return_hidden_states"
        );

        GenerateReqInput {
            text,
            input_ids: None,
            stream: self.stream,
            bootstrap_host: None,
            bootstrap_port: None,
            bootstrap_room: None,
            other: Value::Object(extra),
        }
    }
}
// ============= Chat Completion Request Adapter =============
impl ToPdRequest for ChatCompletionRequest {
    type Output = ChatReqInput;
    /// Convert an OpenAI chat completion request into the PD chat form.
    ///
    /// Unlike the generate/completion adapters, chat requests are passed
    /// through mostly verbatim: each field is copied into the `other` JSON
    /// map under its original name, and the bootstrap fields are left unset
    /// for later injection.
    fn to_pd_request(self) -> Self::Output {
        let mut other = serde_json::Map::new();
        // Add required fields
        // (messages is always present; wrapped in Some only to reuse the macro)
        insert_if_some!(other,
            Some(&self.messages) => "messages"
        );
        insert_value!(other,
            self.model => "model",
            self.stream => "stream"
        );
        // Add all optional fields
        insert_if_some!(other,
            self.temperature => "temperature",
            self.top_p => "top_p",
            self.n => "n",
            self.stream_options => "stream_options",
            self.stop => "stop",
            self.max_tokens => "max_tokens",
            self.max_completion_tokens => "max_completion_tokens",
            self.presence_penalty => "presence_penalty",
            self.frequency_penalty => "frequency_penalty",
            self.logit_bias => "logit_bias",
            self.user => "user",
            self.seed => "seed",
            self.top_logprobs => "top_logprobs",
            self.response_format => "response_format",
            self.tools => "tools",
            self.tool_choice => "tool_choice",
            self.parallel_tool_calls => "parallel_tool_calls",
            self.functions => "functions",
            self.function_call => "function_call",
            // SGLang Extensions - Priority 1
            self.top_k => "top_k",
            self.min_p => "min_p",
            self.min_tokens => "min_tokens",
            self.repetition_penalty => "repetition_penalty",
            self.regex => "regex",
            self.ebnf => "ebnf",
            self.stop_token_ids => "stop_token_ids",
            // SGLang Extensions - Priority 2
            self.lora_path => "lora_path",
            self.session_params => "session_params"
        );
        // Handle boolean flags
        // (logprobs is only emitted when true)
        if self.logprobs {
            other.insert("logprobs".to_string(), true.into());
        }
        // SGLang boolean extensions (ChatCompletionRequest has these as bool, not Option<bool>)
        // — always forwarded, regardless of value.
        other.insert("no_stop_trim".to_string(), self.no_stop_trim.into());
        other.insert("ignore_eos".to_string(), self.ignore_eos.into());
        other.insert(
            "continue_final_message".to_string(),
            self.continue_final_message.into(),
        );
        other.insert(
            "skip_special_tokens".to_string(),
            self.skip_special_tokens.into(),
        );
        other.insert(
            "separate_reasoning".to_string(),
            self.separate_reasoning.into(),
        );
        other.insert("stream_reasoning".to_string(), self.stream_reasoning.into());
        other.insert(
            "return_hidden_states".to_string(),
            self.return_hidden_states.into(),
        );
        ChatReqInput {
            stream: self.stream,
            bootstrap_host: None,
            bootstrap_port: None,
            bootstrap_room: None,
            other: Value::Object(other),
        }
    }
}
// ============= Direct routing support for regular router =============
/// Extension trait for routing requests directly, without PD conversion.
pub trait RouteableRequest: GenerationRequest + serde::Serialize + Clone {
    /// Serialize the request into a JSON value for the backend.
    fn to_json(&self) -> Result<Value, serde_json::Error> {
        serde_json::to_value(self)
    }
    /// Serialize the request into raw bytes for the legacy routing path.
    fn to_bytes(&self) -> Result<bytes::Bytes, serde_json::Error> {
        serde_json::to_vec(self).map(bytes::Bytes::from)
    }
}
// All three OpenAI request types serialize directly via serde, so the
// trait's default method bodies suffice — no per-type overrides needed.
impl RouteableRequest for GenerateRequest {}
impl RouteableRequest for CompletionRequest {}
impl RouteableRequest for ChatCompletionRequest {}
#[cfg(test)]
mod tests {
use super::*;
use crate::openai_api_types::*;
use serde_json::json;
use std::collections::HashMap;
// ============= Test Helper Functions =============
//
// These helper functions create default request instances with all required SGLang extension fields
// properly initialized. Use the struct spread operator `..default_*_request()` to override only
// the fields you need for specific tests, avoiding repetitive boilerplate code.
//
// Example usage:
// let req = GenerateRequest {
// text: Some("Custom text".to_string()),
// stream: true,
// ..default_generate_request()
// };
    /// Create a default GenerateRequest with minimal fields set.
    ///
    /// All inputs and tuning knobs are unset; override individual fields in
    /// tests with struct-update syntax (`..default_generate_request()`).
    fn default_generate_request() -> GenerateRequest {
        GenerateRequest {
            // No input of any kind by default (text/prompt/input_ids).
            text: None,
            prompt: None,
            input_ids: None,
            stream: false,
            parameters: None,
            sampling_params: None,
            return_logprob: false,
            // SGLang Extensions
            lora_path: None,
            session_params: None,
            return_hidden_states: false,
            rid: None,
        }
    }
    /// Create a default CompletionRequest with minimal fields set.
    ///
    /// Provides a valid model/prompt pair with every optional knob unset;
    /// override individual fields with `..default_completion_request()`.
    fn default_completion_request() -> CompletionRequest {
        CompletionRequest {
            model: "test-model".to_string(),
            prompt: StringOrArray::String("test prompt".to_string()),
            max_tokens: None,
            temperature: None,
            top_p: None,
            n: None,
            stream: false,
            stream_options: None,
            logprobs: None,
            echo: false,
            stop: None,
            presence_penalty: None,
            frequency_penalty: None,
            best_of: None,
            logit_bias: None,
            user: None,
            seed: None,
            suffix: None,
            // SGLang Extensions
            top_k: None,
            min_p: None,
            min_tokens: None,
            repetition_penalty: None,
            regex: None,
            ebnf: None,
            json_schema: None,
            stop_token_ids: None,
            no_stop_trim: false,
            ignore_eos: false,
            skip_special_tokens: true,
            // SGLang Extensions
            lora_path: None,
            session_params: None,
            return_hidden_states: false,
            other: serde_json::Map::new(),
        }
    }
    /// Create a default ChatCompletionRequest with minimal fields set.
    ///
    /// Contains a single user message and a model name; every optional field
    /// is unset. Override fields with `..default_chat_completion_request()`.
    fn default_chat_completion_request() -> ChatCompletionRequest {
        ChatCompletionRequest {
            model: "test-model".to_string(),
            messages: vec![ChatMessage::User {
                role: "user".to_string(),
                content: UserMessageContent::Text("test message".to_string()),
                name: None,
            }],
            temperature: None,
            top_p: None,
            n: None,
            stream: false,
            stream_options: None,
            stop: None,
            max_tokens: None,
            max_completion_tokens: None,
            presence_penalty: None,
            frequency_penalty: None,
            logit_bias: None,
            logprobs: false,
            top_logprobs: None,
            user: None,
            seed: None,
            response_format: None,
            tools: None,
            tool_choice: None,
            parallel_tool_calls: None,
            functions: None,
            function_call: None,
            // SGLang Extensions
            top_k: None,
            min_p: None,
            min_tokens: None,
            repetition_penalty: None,
            regex: None,
            ebnf: None,
            stop_token_ids: None,
            no_stop_trim: false,
            ignore_eos: false,
            continue_final_message: false,
            skip_special_tokens: true,
            // SGLang Extensions
            lora_path: None,
            session_params: None,
            separate_reasoning: true,
            stream_reasoning: true,
            return_hidden_states: false,
        }
    }
// ============= GenerateRequest to_pd_request Tests =============
#[test]
fn test_generate_to_pd_request_with_text_only() {
let req = GenerateRequest {
text: Some("Hello world".to_string()),
..default_generate_request()
};
let pd_req = req.to_pd_request();
// Check text field conversion
assert!(matches!(pd_req.text, Some(SingleOrBatch::Single(ref s)) if s == "Hello world"));
assert!(pd_req.input_ids.is_none());
// Check bootstrap fields are None
assert!(pd_req.bootstrap_host.is_none());
assert!(pd_req.bootstrap_port.is_none());
assert!(pd_req.bootstrap_room.is_none());
// Check stream flag
assert_eq!(pd_req.stream, false);
// Check other fields
let other = pd_req.other.as_object().unwrap();
assert_eq!(other.get("stream"), Some(&json!(false)));
assert_eq!(other.get("return_logprob"), Some(&json!(false)));
}
#[test]
fn test_generate_to_pd_request_with_prompt_string() {
let req = GenerateRequest {
prompt: Some(StringOrArray::String("Test prompt".to_string())),
stream: true,
return_logprob: true,
..default_generate_request()
};
let pd_req = req.to_pd_request();
assert!(matches!(pd_req.text, Some(SingleOrBatch::Single(ref s)) if s == "Test prompt"));
assert!(pd_req.input_ids.is_none());
assert_eq!(pd_req.stream, true);
let other = pd_req.other.as_object().unwrap();
assert_eq!(other.get("stream"), Some(&json!(true)));
assert_eq!(other.get("return_logprob"), Some(&json!(true)));
}
#[test]
fn test_generate_to_pd_request_with_prompt_array() {
let req = GenerateRequest {
text: None,
prompt: Some(StringOrArray::Array(vec![
"Prompt 1".to_string(),
"Prompt 2".to_string(),
"Prompt 3".to_string(),
])),
input_ids: None,
stream: false,
parameters: None,
sampling_params: None,
return_logprob: false,
..default_generate_request()
};
let pd_req = req.to_pd_request();
match pd_req.text {
Some(SingleOrBatch::Batch(ref batch)) => {
assert_eq!(batch.len(), 3);
assert_eq!(batch[0], "Prompt 1");
assert_eq!(batch[1], "Prompt 2");
assert_eq!(batch[2], "Prompt 3");
}
_ => panic!("Expected batch text"),
}
}
#[test]
fn test_generate_to_pd_request_with_single_input_ids() {
let req = GenerateRequest {
input_ids: Some(InputIds::Single(vec![100, 200, 300, 400])),
..default_generate_request()
};
let pd_req = req.to_pd_request();
assert!(pd_req.text.is_none());
assert!(matches!(
pd_req.input_ids,
Some(SingleOrBatch::Single(ref ids)) if ids == &vec![100, 200, 300, 400]
));
}
#[test]
fn test_generate_to_pd_request_with_batch_input_ids() {
let req = GenerateRequest {
input_ids: Some(InputIds::Batch(vec![
vec![1, 2, 3],
vec![4, 5, 6, 7],
vec![8, 9],
])),
..default_generate_request()
};
let pd_req = req.to_pd_request();
match pd_req.input_ids {
Some(SingleOrBatch::Batch(ref batch)) => {
assert_eq!(batch.len(), 3);
assert_eq!(batch[0], vec![1, 2, 3]);
assert_eq!(batch[1], vec![4, 5, 6, 7]);
assert_eq!(batch[2], vec![8, 9]);
}
_ => panic!("Expected batch input_ids"),
}
}
#[test]
fn test_generate_to_pd_request_priority_text_over_prompt() {
let req = GenerateRequest {
text: Some("SGLang text".to_string()),
prompt: Some(StringOrArray::String("OpenAI prompt".to_string())),
input_ids: Some(InputIds::Single(vec![1, 2, 3])),
..default_generate_request()
};
let pd_req = req.to_pd_request();
// text should take priority
assert!(matches!(pd_req.text, Some(SingleOrBatch::Single(ref s)) if s == "SGLang text"));
assert!(pd_req.input_ids.is_none());
}
#[test]
fn test_generate_to_pd_request_priority_prompt_over_input_ids() {
let req = GenerateRequest {
prompt: Some(StringOrArray::String("OpenAI prompt".to_string())),
input_ids: Some(InputIds::Single(vec![1, 2, 3])),
..default_generate_request()
};
let pd_req = req.to_pd_request();
// prompt should take priority over input_ids
assert!(matches!(pd_req.text, Some(SingleOrBatch::Single(ref s)) if s == "OpenAI prompt"));
assert!(pd_req.input_ids.is_none());
}
#[test]
fn test_generate_to_pd_request_with_parameters() {
let params = GenerateParameters {
max_new_tokens: Some(100),
temperature: Some(0.8),
top_p: Some(0.95),
seed: Some(12345),
stop: Some(vec!["END".to_string(), "STOP".to_string()]),
repetition_penalty: Some(1.1),
..Default::default()
};
let req = GenerateRequest {
text: Some("test".to_string()),
parameters: Some(params),
..default_generate_request()
};
let pd_req = req.to_pd_request();
let other = pd_req.other.as_object().unwrap();
// Check that max_new_tokens and temperature were extracted to top level
assert_eq!(other.get("max_new_tokens"), Some(&json!(100)));
assert!(other.get("temperature").unwrap().as_f64().unwrap() - 0.8 < 0.0001);
// Check that other parameters remain under "parameters"
let params = other.get("parameters").unwrap().as_object().unwrap();
assert!(params.get("top_p").unwrap().as_f64().unwrap() - 0.95 < 0.0001);
assert_eq!(params.get("seed"), Some(&json!(12345)));
assert_eq!(params.get("stop"), Some(&json!(vec!["END", "STOP"])));
assert!(params.get("repetition_penalty").unwrap().as_f64().unwrap() - 1.1 < 0.0001);
}
#[test]
fn test_generate_to_pd_request_with_sampling_params() {
let sampling = SamplingParams {
max_new_tokens: Some(200),
temperature: Some(0.7),
top_p: Some(0.9),
top_k: Some(50),
frequency_penalty: Some(0.1),
presence_penalty: Some(0.2),
repetition_penalty: Some(1.05),
..Default::default()
};
let req = GenerateRequest {
text: Some("test".to_string()),
sampling_params: Some(sampling),
..default_generate_request()
};
let pd_req = req.to_pd_request();
let other = pd_req.other.as_object().unwrap();
// Check extracted top-level fields
assert_eq!(other.get("max_new_tokens"), Some(&json!(200)));
assert!(other.get("temperature").unwrap().as_f64().unwrap() - 0.7 < 0.0001);
// Check full sampling_params is preserved
let sampling = other.get("sampling_params").unwrap().as_object().unwrap();
assert_eq!(sampling.get("max_new_tokens"), Some(&json!(200)));
assert!(sampling.get("temperature").unwrap().as_f64().unwrap() - 0.7 < 0.0001);
assert!(sampling.get("top_p").unwrap().as_f64().unwrap() - 0.9 < 0.0001);
assert_eq!(sampling.get("top_k"), Some(&json!(50)));
assert!(sampling.get("frequency_penalty").unwrap().as_f64().unwrap() - 0.1 < 0.0001);
assert!(sampling.get("presence_penalty").unwrap().as_f64().unwrap() - 0.2 < 0.0001);
}
#[test]
fn test_generate_to_pd_request_sampling_params_override_parameters() {
// When both parameters and sampling_params have max_new_tokens/temperature,
// sampling_params should take precedence (processed last)
let params = GenerateParameters {
max_new_tokens: Some(100),
temperature: Some(0.5),
..Default::default()
};
let sampling = SamplingParams {
max_new_tokens: Some(200),
temperature: Some(0.9),
..Default::default()
};
let req = GenerateRequest {
text: Some("test".to_string()),
prompt: None,
input_ids: None,
stream: false,
parameters: Some(params),
sampling_params: Some(sampling),
return_logprob: false,
..default_generate_request()
};
let pd_req = req.to_pd_request();
let other = pd_req.other.as_object().unwrap();
// Should use values from sampling_params since they're processed last
assert_eq!(other.get("max_new_tokens"), Some(&json!(200)));
assert!(other.get("temperature").unwrap().as_f64().unwrap() - 0.9 < 0.0001);
}
#[test]
fn test_generate_to_pd_request_empty_parameters() {
let params = GenerateParameters::default();
let req = GenerateRequest {
text: Some("test".to_string()),
prompt: None,
input_ids: None,
stream: false,
parameters: Some(params),
sampling_params: None,
return_logprob: false,
..default_generate_request()
};
let pd_req = req.to_pd_request();
let other = pd_req.other.as_object().unwrap();
// Should not have parameters field if all values are None/default
assert!(!other.contains_key("parameters"));
assert!(!other.contains_key("max_new_tokens"));
assert!(!other.contains_key("temperature"));
}
#[test]
fn test_generate_to_pd_request_all_fields() {
let params = GenerateParameters {
max_new_tokens: Some(150),
temperature: Some(0.6),
top_k: Some(40),
..Default::default()
};
let sampling = SamplingParams {
max_new_tokens: Some(250), // Will override parameters
temperature: Some(0.8), // Will override parameters
presence_penalty: Some(0.1),
..Default::default()
};
let req = GenerateRequest {
text: Some("Complex test".to_string()),
prompt: Some(StringOrArray::String("Ignored prompt".to_string())),
input_ids: None,
stream: true,
parameters: Some(params),
sampling_params: Some(sampling),
return_logprob: true,
..default_generate_request()
};
let pd_req = req.to_pd_request();
// Verify all fields
assert!(matches!(pd_req.text, Some(SingleOrBatch::Single(ref s)) if s == "Complex test"));
assert!(pd_req.input_ids.is_none());
assert_eq!(pd_req.stream, true);
assert!(pd_req.bootstrap_host.is_none());
assert!(pd_req.bootstrap_port.is_none());
assert!(pd_req.bootstrap_room.is_none());
let other = pd_req.other.as_object().unwrap();
assert_eq!(other.get("stream"), Some(&json!(true)));
assert_eq!(other.get("return_logprob"), Some(&json!(true)));
// Sampling params override parameters
assert_eq!(other.get("max_new_tokens"), Some(&json!(250)));
assert!(other.get("temperature").unwrap().as_f64().unwrap() - 0.8 < 0.0001);
assert!(other.contains_key("parameters"));
assert!(other.contains_key("sampling_params"));
}
// ============= CompletionRequest to_pd_request Tests =============
#[test]
fn test_completion_to_pd_request_basic() {
let req = CompletionRequest {
model: "gpt-3.5-turbo".to_string(),
prompt: StringOrArray::String("Complete this sentence".to_string()),
..default_completion_request()
};
let pd_req = req.to_pd_request();
assert!(
matches!(pd_req.text, Some(SingleOrBatch::Single(ref s)) if s == "Complete this sentence")
);
assert!(pd_req.input_ids.is_none());
assert_eq!(pd_req.stream, false);
let other = pd_req.other.as_object().unwrap();
assert_eq!(other.get("model"), Some(&json!("gpt-3.5-turbo")));
assert_eq!(other.get("stream"), Some(&json!(false)));
}
#[test]
fn test_completion_to_pd_request_array_prompt() {
let req = CompletionRequest {
model: "test".to_string(),
prompt: StringOrArray::Array(vec![
"First prompt".to_string(),
"Second prompt".to_string(),
]),
..default_completion_request()
};
let pd_req = req.to_pd_request();
match pd_req.text {
Some(SingleOrBatch::Batch(ref batch)) => {
assert_eq!(batch.len(), 2);
assert_eq!(batch[0], "First prompt");
assert_eq!(batch[1], "Second prompt");
}
_ => panic!("Expected batch text"),
}
}
#[test]
fn test_completion_to_pd_request_parameter_mapping() {
let req = CompletionRequest {
model: "test".to_string(),
prompt: StringOrArray::String("test".to_string()),
max_tokens: Some(150), // -> max_new_tokens
temperature: Some(0.75),
top_p: Some(0.92),
n: Some(3), // -> best_of
stream: true,
stream_options: None,
logprobs: Some(10), // -> top_n_tokens
echo: true, // -> return_full_text
stop: Some(StringOrArray::Array(vec![
"\\n".to_string(),
"END".to_string(),
])),
presence_penalty: Some(0.5), // -> repetition_penalty = 1.5
frequency_penalty: Some(0.2),
best_of: Some(5),
logit_bias: None,
user: Some("user123".to_string()),
seed: Some(42),
suffix: Some("...".to_string()),
..default_completion_request()
};
let pd_req = req.to_pd_request();
let other = pd_req.other.as_object().unwrap();
let params = other.get("parameters").unwrap().as_object().unwrap();
// Check parameter mappings
assert_eq!(params.get("max_new_tokens"), Some(&json!(150)));
assert!(params.get("temperature").unwrap().as_f64().unwrap() - 0.75 < 0.0001);
assert!(params.get("top_p").unwrap().as_f64().unwrap() - 0.92 < 0.0001);
assert_eq!(params.get("best_of"), Some(&json!(3)));
assert_eq!(params.get("top_n_tokens"), Some(&json!(10)));
assert_eq!(params.get("return_full_text"), Some(&json!(true)));
assert_eq!(params.get("stop"), Some(&json!(vec!["\\n", "END"])));
assert!(params.get("repetition_penalty").unwrap().as_f64().unwrap() - 1.5 < 0.0001);
assert_eq!(params.get("seed"), Some(&json!(42)));
// Check other fields
assert_eq!(other.get("model"), Some(&json!("test")));
assert_eq!(other.get("stream"), Some(&json!(true)));
}
#[test]
fn test_completion_to_pd_request_stop_string() {
let req = CompletionRequest {
model: "test".to_string(),
prompt: StringOrArray::String("test".to_string()),
stop: Some(StringOrArray::String("STOP".to_string())),
max_tokens: None,
temperature: None,
top_p: None,
n: None,
stream: false,
stream_options: None,
logprobs: None,
echo: false,
presence_penalty: None,
frequency_penalty: None,
best_of: None,
logit_bias: None,
user: None,
seed: None,
suffix: None,
..default_completion_request()
};
let pd_req = req.to_pd_request();
let other = pd_req.other.as_object().unwrap();
let params = other.get("parameters").unwrap().as_object().unwrap();
// Single string stop should be converted to array
assert_eq!(params.get("stop"), Some(&json!(vec!["STOP"])));
}
#[test]
fn test_completion_to_pd_request_no_presence_penalty() {
let req = CompletionRequest {
model: "test".to_string(),
prompt: StringOrArray::String("test".to_string()),
presence_penalty: None,
max_tokens: None,
temperature: None,
top_p: None,
n: None,
stream: false,
stream_options: None,
logprobs: None,
echo: false,
stop: None,
frequency_penalty: None,
best_of: None,
logit_bias: None,
user: None,
seed: None,
suffix: None,
..default_completion_request()
};
let pd_req = req.to_pd_request();
let other = pd_req.other.as_object().unwrap();
let params = other.get("parameters").unwrap().as_object().unwrap();
// Should not have repetition_penalty if presence_penalty is None
assert!(!params.contains_key("repetition_penalty"));
}
// ============= ChatCompletionRequest to_pd_request Tests =============
#[test]
fn test_chat_to_pd_request_basic() {
let messages = vec![
ChatMessage::System {
role: "system".to_string(),
content: "You are a helpful assistant".to_string(),
name: None,
},
ChatMessage::User {
role: "user".to_string(),
content: UserMessageContent::Text("Hello!".to_string()),
name: None,
},
];
let req = ChatCompletionRequest {
messages,
model: "gpt-4".to_string(),
..default_chat_completion_request()
};
let pd_req = req.to_pd_request();
assert_eq!(pd_req.stream, false);
assert!(pd_req.bootstrap_host.is_none());
assert!(pd_req.bootstrap_port.is_none());
assert!(pd_req.bootstrap_room.is_none());
let other = pd_req.other.as_object().unwrap();
assert!(other.contains_key("messages"));
assert_eq!(other.get("model"), Some(&json!("gpt-4")));
assert_eq!(other.get("stream"), Some(&json!(false)));
// Check messages are preserved
let messages = other.get("messages").unwrap().as_array().unwrap();
assert_eq!(messages.len(), 2);
}
#[test]
fn test_chat_to_pd_request_with_all_optional_fields() {
let messages = vec![ChatMessage::User {
role: "user".to_string(),
content: UserMessageContent::Text("Test".to_string()),
name: Some("test_user".to_string()),
}];
let mut logit_bias = HashMap::new();
logit_bias.insert("50256".to_string(), -100.0f32);
let tool = Tool {
tool_type: "function".to_string(),
function: Function {
name: "get_weather".to_string(),
description: Some("Get weather info".to_string()),
parameters: json!({"type": "object"}),
},
};
let req = ChatCompletionRequest {
messages,
model: "gpt-4".to_string(),
temperature: Some(0.8),
top_p: Some(0.95),
n: Some(2),
stream: true,
stream_options: Some(StreamOptions {
include_usage: Some(true),
}),
stop: Some(StringOrArray::String("\\n\\n".to_string())),
max_tokens: Some(200),
max_completion_tokens: Some(150),
presence_penalty: Some(0.1),
frequency_penalty: Some(0.2),
logit_bias: Some(logit_bias),
logprobs: true,
top_logprobs: Some(5),
user: Some("user456".to_string()),
seed: Some(12345),
response_format: Some(ResponseFormat::JsonObject),
tools: Some(vec![tool]),
tool_choice: Some(ToolChoice::Auto),
parallel_tool_calls: Some(false),
functions: None,
function_call: None,
..default_chat_completion_request()
};
let pd_req = req.to_pd_request();
let other = pd_req.other.as_object().unwrap();
// Check all fields are preserved
assert!(other.get("temperature").unwrap().as_f64().unwrap() - 0.8 < 0.0001);
assert!(other.get("top_p").unwrap().as_f64().unwrap() - 0.95 < 0.0001);
assert_eq!(other.get("n"), Some(&json!(2)));
assert_eq!(other.get("stream"), Some(&json!(true)));
assert!(other.contains_key("stream_options"));
assert!(other.contains_key("stop"));
assert_eq!(other.get("max_tokens"), Some(&json!(200)));
assert_eq!(other.get("max_completion_tokens"), Some(&json!(150)));
assert!(other.get("presence_penalty").unwrap().as_f64().unwrap() - 0.1 < 0.0001);
assert!(other.get("frequency_penalty").unwrap().as_f64().unwrap() - 0.2 < 0.0001);
assert!(other.contains_key("logit_bias"));
assert_eq!(other.get("logprobs"), Some(&json!(true)));
assert_eq!(other.get("top_logprobs"), Some(&json!(5)));
assert_eq!(other.get("user"), Some(&json!("user456")));
assert_eq!(other.get("seed"), Some(&json!(12345)));
assert!(other.contains_key("response_format"));
assert!(other.contains_key("tools"));
assert!(other.contains_key("tool_choice"));
assert_eq!(other.get("parallel_tool_calls"), Some(&json!(false)));
}
#[test]
fn test_chat_to_pd_request_multimodal_content() {
    // Build a single user message whose content is a list of parts
    // (text + image URL) rather than a plain string.
    let parts = vec![
        ContentPart::Text {
            text: "What's in this image?".to_string(),
        },
        ContentPart::ImageUrl {
            image_url: ImageUrl {
                url: "https://example.com/image.jpg".to_string(),
                detail: Some("high".to_string()),
            },
        },
    ];
    let user_message = ChatMessage::User {
        role: "user".to_string(),
        content: UserMessageContent::Parts(parts),
        name: None,
    };
    let request = ChatCompletionRequest {
        messages: vec![user_message],
        model: "gpt-4-vision".to_string(),
        ..default_chat_completion_request()
    };
    let adapted = request.to_pd_request();
    let fields = adapted.other.as_object().unwrap();
    // The multimodal message list must survive adaptation untouched.
    assert!(fields.contains_key("messages"));
    let preserved = fields.get("messages").unwrap().as_array().unwrap();
    assert_eq!(preserved.len(), 1);
    // The message keeps its role and its array-shaped content.
    let first = &preserved[0];
    assert_eq!(first["role"], "user");
    assert!(first["content"].is_array());
}
#[test]
fn test_chat_to_pd_request_logprobs_boolean() {
    // `logprobs` on ChatCompletionRequest is a plain bool (OpenAI-style),
    // not an Option<bool>; verify it and its companion `top_logprobs`
    // both pass through PD adaptation.
    let messages = vec![ChatMessage::User {
        role: "user".to_string(),
        content: UserMessageContent::Text("Test".to_string()),
        name: None,
    }];
    let req = ChatCompletionRequest {
        messages,
        model: "test".to_string(),
        logprobs: true, // Boolean logprobs flag
        top_logprobs: Some(3),
        // All other optionals are cleared explicitly so only the
        // logprobs-related fields are exercised by this test.
        temperature: None,
        top_p: None,
        n: None,
        stream: false,
        stream_options: None,
        stop: None,
        max_tokens: None,
        max_completion_tokens: None,
        presence_penalty: None,
        frequency_penalty: None,
        logit_bias: None,
        user: None,
        seed: None,
        response_format: None,
        tools: None,
        tool_choice: None,
        parallel_tool_calls: None,
        functions: None,
        function_call: None,
        ..default_chat_completion_request()
    };
    let pd_req = req.to_pd_request();
    let other = pd_req.other.as_object().unwrap();
    // Both fields must appear verbatim in the pass-through map.
    assert_eq!(other.get("logprobs"), Some(&json!(true)));
    assert_eq!(other.get("top_logprobs"), Some(&json!(3)));
}
#[test]
fn test_chat_to_pd_request_minimal_fields() {
    // A request that sets nothing beyond the defaults: the adapted
    // payload should carry only the always-present keys.
    let assistant = ChatMessage::Assistant {
        role: "assistant".to_string(),
        content: Some("I can help with that.".to_string()),
        name: None,
        tool_calls: None,
        function_call: None,
        reasoning_content: None,
    };
    let request = ChatCompletionRequest {
        messages: vec![assistant],
        model: "gpt-3.5-turbo".to_string(),
        ..default_chat_completion_request()
    };
    let adapted = request.to_pd_request();
    let fields = adapted.other.as_object().unwrap();
    // Required keys are always emitted.
    for key in ["messages", "model", "stream"] {
        assert!(fields.contains_key(key));
    }
    // Unset optional fields must not leak into the payload.
    for key in ["temperature", "top_p", "max_tokens", "stop"] {
        assert!(!fields.contains_key(key));
    }
}
#[test]
fn test_routeable_request_to_json() {
    // to_json() serializes the request into a JSON value mirroring
    // the struct's fields.
    let request = GenerateRequest {
        text: Some("test".to_string()),
        ..default_generate_request()
    };
    let serialized = request.to_json().unwrap();
    assert_eq!(serialized["text"], "test");
    assert_eq!(serialized["stream"], false);
}
// ============= Macro Tests =============
#[test]
fn test_insert_if_some_macro() {
    // insert_if_some! adds Some(...) values under the given key and
    // silently skips None values.
    let mut fields = serde_json::Map::new();
    let present: Option<i32> = Some(42);
    let missing: Option<i32> = None;
    insert_if_some!(fields,
        present => "present",
        missing => "absent"
    );
    assert_eq!(fields.get("present"), Some(&json!(42)));
    assert!(!fields.contains_key("absent"));
}
#[test]
fn test_insert_value_macro() {
    // insert_value! unconditionally inserts each value under its key.
    let mut fields = serde_json::Map::new();
    let text = "test";
    let number = 42;
    insert_value!(fields,
        text => "string_field",
        number => "int_field"
    );
    assert_eq!(fields.get("string_field"), Some(&json!("test")));
    assert_eq!(fields.get("int_field"), Some(&json!(42)));
}
// ============= Edge Cases and Error Handling =============
#[test]
fn test_null_value_handling() {
    // A parameters struct whose every field is None should be dropped
    // from the adapted request entirely rather than emitted as nulls.
    let empty_params = GenerateParameters {
        max_new_tokens: None,
        temperature: None,
        ..Default::default()
    };
    let request = GenerateRequest {
        text: Some("test".to_string()),
        prompt: None,
        input_ids: None,
        stream: false,
        parameters: Some(empty_params),
        sampling_params: None,
        return_logprob: false,
        ..default_generate_request()
    };
    let adapted = request.to_pd_request();
    let fields = adapted.other.as_object().unwrap();
    // No "parameters" key when the struct carried no values.
    assert!(!fields.contains_key("parameters"));
}
#[test]
fn test_large_batch_conversion() {
    // A 1000-element prompt array must come through as a Batch of the
    // same size with element order preserved.
    let items: Vec<String> = (0..1000).map(|i| format!("item_{}", i)).collect();
    let request = GenerateRequest {
        text: None,
        prompt: Some(StringOrArray::Array(items)),
        input_ids: None,
        stream: false,
        parameters: None,
        sampling_params: None,
        return_logprob: false,
        ..default_generate_request()
    };
    match request.to_pd_request().text {
        Some(SingleOrBatch::Batch(batch)) => {
            assert_eq!(batch.len(), 1000);
            // Spot-check both ends of the batch for ordering.
            assert_eq!(batch[0], "item_0");
            assert_eq!(batch[999], "item_999");
        }
        _ => panic!("Expected batch text"),
    }
}
#[test]
fn test_unicode_string_handling() {
    // Multi-script text (CJK, emoji, Devanagari, Cyrillic) must
    // round-trip through PD adaptation unchanged.
    let unicode_text = "Hello 世界 🌍 नमस्ते мир";
    let request = GenerateRequest {
        text: Some(unicode_text.to_string()),
        ..default_generate_request()
    };
    match request.to_pd_request().text {
        Some(SingleOrBatch::Single(text)) => assert_eq!(text, unicode_text),
        _ => panic!("Expected single text"),
    }
}
#[test]
fn test_deeply_nested_parameters() {
    // NOTE(review): the previous version of this test built a deeply nested
    // serde_json::Map that was never attached to the request (GenerateParameters
    // is a typed struct with no slot for arbitrary JSON), so it only produced
    // an unused-variable warning and never exercised nesting. The dead
    // construction has been removed; the test verifies that typed parameters
    // survive adaptation into the flattened `other` map.
    let params = GenerateParameters {
        max_new_tokens: Some(100),
        ..Default::default()
    };
    let req = GenerateRequest {
        text: Some("test".to_string()),
        prompt: None,
        input_ids: None,
        stream: false,
        parameters: Some(params),
        sampling_params: None,
        return_logprob: false,
        ..default_generate_request()
    };
    let pd_req = req.to_pd_request();
    let other = pd_req.other.as_object().unwrap();
    // Typed parameter fields are flattened into the pass-through map.
    assert!(other.contains_key("max_new_tokens"));
}
// ============= Bootstrap Field Tests =============
#[test]
fn test_bootstrap_fields_none() {
    // Plain adaptation must not invent bootstrap routing fields;
    // those are injected later by the PD router.
    let request = GenerateRequest {
        text: Some("test".to_string()),
        ..default_generate_request()
    };
    let adapted = request.to_pd_request();
    assert!(adapted.bootstrap_host.is_none());
    assert!(adapted.bootstrap_port.is_none());
    assert!(adapted.bootstrap_room.is_none());
}
// ============= SGLang Extension Field Pass-Through Tests =============
#[test]
fn test_chat_completion_sglang_extensions_passed_through() {
    // Exercises the full set of SGLang-specific extension fields on
    // ChatCompletionRequest and checks each one lands verbatim in the
    // adapted PD request's `other` pass-through map.
    let messages = vec![ChatMessage::User {
        role: "user".to_string(),
        content: UserMessageContent::Text("Test".to_string()),
        name: None,
    }];
    let mut session_params = std::collections::HashMap::new();
    session_params.insert(
        "key".to_string(),
        serde_json::Value::String("value".to_string()),
    );
    let req = ChatCompletionRequest {
        messages,
        model: "test-model".to_string(),
        // SGLang Extensions - Priority 1
        top_k: Some(40),
        min_p: Some(0.05),
        min_tokens: Some(10),
        repetition_penalty: Some(1.1),
        regex: Some("test_regex".to_string()),
        ebnf: Some("test_ebnf".to_string()),
        stop_token_ids: Some(vec![1, 2, 3]),
        // SGLang Extensions - Priority 2
        lora_path: Some(LoRAPath::Single(Some("test_lora.bin".to_string()))),
        session_params: Some(session_params.clone()),
        // Boolean extensions (ChatCompletionRequest has these as bool, not Option<bool>)
        no_stop_trim: true,
        ignore_eos: false,
        continue_final_message: true,
        skip_special_tokens: false,
        separate_reasoning: true,
        stream_reasoning: false,
        return_hidden_states: true,
        ..default_chat_completion_request()
    };
    let pd_req = req.to_pd_request();
    let other = pd_req.other.as_object().unwrap();
    // Verify SGLang extensions are passed through
    assert_eq!(other.get("top_k"), Some(&json!(40)));
    // Float fields are compared with a tolerance rather than exact
    // equality to survive the JSON round-trip.
    assert!((other.get("min_p").unwrap().as_f64().unwrap() - 0.05).abs() < 0.0001);
    assert_eq!(other.get("min_tokens"), Some(&json!(10)));
    assert!((other.get("repetition_penalty").unwrap().as_f64().unwrap() - 1.1).abs() < 0.0001);
    assert_eq!(other.get("regex"), Some(&json!("test_regex")));
    assert_eq!(other.get("ebnf"), Some(&json!("test_ebnf")));
    assert_eq!(other.get("stop_token_ids"), Some(&json!(vec![1, 2, 3])));
    // LoRAPath::Single serializes to the bare path string.
    assert_eq!(other.get("lora_path"), Some(&json!("test_lora.bin")));
    assert_eq!(
        other.get("session_params"),
        Some(&serde_json::to_value(&session_params).unwrap())
    );
    // Verify boolean extensions
    assert_eq!(other.get("no_stop_trim"), Some(&json!(true)));
    assert_eq!(other.get("ignore_eos"), Some(&json!(false)));
    assert_eq!(other.get("continue_final_message"), Some(&json!(true)));
    assert_eq!(other.get("skip_special_tokens"), Some(&json!(false)));
    assert_eq!(other.get("separate_reasoning"), Some(&json!(true)));
    assert_eq!(other.get("stream_reasoning"), Some(&json!(false)));
    assert_eq!(other.get("return_hidden_states"), Some(&json!(true)));
}
#[test]
fn test_completion_request_sglang_extensions_passed_through() {
    // Same pass-through check as the chat-completion variant above, but
    // for CompletionRequest, which carries a smaller set of boolean
    // extension fields.
    let mut session_params = std::collections::HashMap::new();
    session_params.insert(
        "key".to_string(),
        serde_json::Value::String("value".to_string()),
    );
    let req = CompletionRequest {
        prompt: StringOrArray::String("Test prompt".to_string()),
        model: "test-model".to_string(),
        // SGLang Extensions - Priority 1
        top_k: Some(40),
        min_p: Some(0.05),
        min_tokens: Some(10),
        repetition_penalty: Some(1.1),
        regex: Some("test_regex".to_string()),
        ebnf: Some("test_ebnf".to_string()),
        stop_token_ids: Some(vec![1, 2, 3]),
        // SGLang Extensions - Priority 2
        lora_path: Some(LoRAPath::Single(Some("test_lora.bin".to_string()))),
        session_params: Some(session_params.clone()),
        // Boolean extensions (CompletionRequest only has these 4 boolean fields)
        no_stop_trim: true,
        ignore_eos: false,
        skip_special_tokens: false,
        return_hidden_states: true,
        ..default_completion_request()
    };
    let pd_req = req.to_pd_request();
    let other = pd_req.other.as_object().unwrap();
    // Verify SGLang extensions are passed through
    assert_eq!(other.get("top_k"), Some(&json!(40)));
    // Tolerance-based comparison for floats after JSON round-trip.
    assert!((other.get("min_p").unwrap().as_f64().unwrap() - 0.05).abs() < 0.0001);
    assert_eq!(other.get("min_tokens"), Some(&json!(10)));
    assert!((other.get("repetition_penalty").unwrap().as_f64().unwrap() - 1.1).abs() < 0.0001);
    assert_eq!(other.get("regex"), Some(&json!("test_regex")));
    assert_eq!(other.get("ebnf"), Some(&json!("test_ebnf")));
    assert_eq!(other.get("stop_token_ids"), Some(&json!(vec![1, 2, 3])));
    // LoRAPath::Single serializes to the bare path string.
    assert_eq!(other.get("lora_path"), Some(&json!("test_lora.bin")));
    assert_eq!(
        other.get("session_params"),
        Some(&serde_json::to_value(&session_params).unwrap())
    );
    // Verify boolean extensions (only the ones CompletionRequest has)
    assert_eq!(other.get("no_stop_trim"), Some(&json!(true)));
    assert_eq!(other.get("ignore_eos"), Some(&json!(false)));
    assert_eq!(other.get("skip_special_tokens"), Some(&json!(false)));
    assert_eq!(other.get("return_hidden_states"), Some(&json!(true)));
}
#[test]
fn test_sglang_extensions_none_values_not_passed_through() {
    // Complement of the pass-through tests above: optional extension
    // fields left as None must be omitted from the adapted payload,
    // while bool fields (which cannot be None) always appear with
    // their current value.
    let messages = vec![ChatMessage::User {
        role: "user".to_string(),
        content: UserMessageContent::Text("Test".to_string()),
        name: None,
    }];
    let req = ChatCompletionRequest {
        messages,
        model: "test-model".to_string(),
        // All SGLang extensions as None/default - Optional fields won't appear, bools will use defaults
        top_k: None,
        min_p: None,
        min_tokens: None,
        repetition_penalty: None,
        regex: None,
        ebnf: None,
        stop_token_ids: None,
        lora_path: None,
        session_params: None,
        // Boolean fields use defaults (false for most, true for some with default_true)
        no_stop_trim: false,
        ignore_eos: false,
        continue_final_message: false,
        skip_special_tokens: true, // This has default_true
        separate_reasoning: true,  // This has default_true
        stream_reasoning: true,    // This has default_true
        return_hidden_states: false,
        ..default_chat_completion_request()
    };
    let pd_req = req.to_pd_request();
    let other = pd_req.other.as_object().unwrap();
    // Verify None values are not included
    assert!(!other.contains_key("top_k"));
    assert!(!other.contains_key("min_p"));
    assert!(!other.contains_key("min_tokens"));
    assert!(!other.contains_key("repetition_penalty"));
    assert!(!other.contains_key("regex"));
    assert!(!other.contains_key("ebnf"));
    assert!(!other.contains_key("stop_token_ids"));
    assert!(!other.contains_key("lora_path"));
    assert!(!other.contains_key("session_params"));
    // Boolean fields are always present with their values (can't be None)
    assert_eq!(other.get("no_stop_trim"), Some(&json!(false)));
    assert_eq!(other.get("ignore_eos"), Some(&json!(false)));
    assert_eq!(other.get("continue_final_message"), Some(&json!(false)));
    assert_eq!(other.get("skip_special_tokens"), Some(&json!(true))); // default_true
    assert_eq!(other.get("separate_reasoning"), Some(&json!(true))); // default_true
    assert_eq!(other.get("stream_reasoning"), Some(&json!(true))); // default_true
    assert_eq!(other.get("return_hidden_states"), Some(&json!(false)));
}
}
// Integration test to ensure benchmarks compile and basic functionality works
// This prevents benchmarks from breaking in CI
//
// UPDATED: Removed deprecated ToPdRequest usage, now uses direct JSON serialization
use serde_json::{from_str, to_string};
use serde_json::{from_str, to_string, to_value};
use sglang_router_rs::core::{BasicWorker, WorkerType};
use sglang_router_rs::openai_api_types::{
ChatCompletionRequest, ChatMessage, CompletionRequest, GenerateParameters, GenerateRequest,
SamplingParams, StringOrArray, UserMessageContent,
};
use sglang_router_rs::routers::request_adapter::{RouteableRequest, ToPdRequest};
use sglang_router_rs::routers::bootstrap_injector::inject_bootstrap_fields;
/// Create a default GenerateRequest for benchmarks with minimal fields set
fn default_generate_request() -> GenerateRequest {
......@@ -114,6 +117,15 @@ fn default_completion_request() -> CompletionRequest {
}
}
/// Build a prefill worker pointed at a dummy URL, shared by the
/// benchmark-parity tests below.
fn create_test_worker() -> BasicWorker {
    let worker_type = WorkerType::Prefill {
        bootstrap_port: Some(5678),
    };
    BasicWorker::new("http://test-server:8000".to_string(), worker_type)
}
#[test]
fn test_benchmark_request_creation() {
// Ensure all benchmark request types can be created without panicking
......@@ -197,8 +209,8 @@ fn test_benchmark_serialization_roundtrip() {
}
#[test]
fn test_benchmark_request_adaptation() {
// Test that PD request adaptation works for benchmark types
fn test_benchmark_bootstrap_injection() {
// Test that bootstrap injection works for benchmark types (replaces PD request adaptation)
let generate_req = GenerateRequest {
text: Some("Test prompt".to_string()),
......@@ -236,24 +248,40 @@ fn test_benchmark_request_adaptation() {
..default_completion_request()
};
// Test PD adaptation (should not panic)
let _pd_generate = generate_req.to_pd_request();
let _pd_chat = chat_req.to_pd_request();
let _pd_completion = completion_req.to_pd_request();
let worker = create_test_worker();
// Test bootstrap injection (should not panic)
let mut generate_json = to_value(&generate_req).unwrap();
let mut chat_json = to_value(&chat_req).unwrap();
let mut completion_json = to_value(&completion_req).unwrap();
assert!(inject_bootstrap_fields(&mut generate_json, &worker).is_ok());
assert!(inject_bootstrap_fields(&mut chat_json, &worker).is_ok());
assert!(inject_bootstrap_fields(&mut completion_json, &worker).is_ok());
// Verify bootstrap fields were added
assert!(generate_json.get("bootstrap_host").is_some());
assert!(generate_json.get("bootstrap_port").is_some());
assert!(generate_json.get("bootstrap_room").is_some());
}
#[test]
fn test_benchmark_direct_json_routing() {
    // Test direct JSON routing for benchmark types (replaces the removed
    // RouteableRequest-based regular routing).
    // NOTE(review): the original span interleaved removed and added lines
    // from the commit diff (two fn headers, two bodies); this is the
    // reconstructed post-commit version.
    let generate_req = GenerateRequest {
        text: Some("Test prompt".to_string()),
        ..default_generate_request()
    };
    // Direct JSON conversion mirrors what the simplified router pipeline does.
    let json = to_value(&generate_req).unwrap();
    let json_string = to_string(&json).unwrap();
    let bytes = json_string.as_bytes();
    // Verify conversions produce non-empty payloads.
    assert!(!json_string.is_empty());
    assert!(!bytes.is_empty());
}
#[test]
......@@ -266,23 +294,36 @@ fn test_benchmark_performance_baseline() {
..default_generate_request()
};
// Serialization should be fast (< 1ms for simple requests)
// Test the actual simplified pipeline: to_value + bootstrap injection
let start = Instant::now();
let _json = to_string(&generate_req).unwrap();
let serialize_duration = start.elapsed();
let worker = create_test_worker();
// This mirrors the actual router pipeline
let mut json = to_value(&generate_req).unwrap();
let _ = inject_bootstrap_fields(&mut json, &worker);
let total_duration = start.elapsed();
assert!(
serialize_duration.as_millis() < 1,
"Serialization took too long: {:?}",
serialize_duration
total_duration.as_millis() < 5,
"Simplified pipeline took too long: {:?} (should be faster than old adapter approach)",
total_duration
);
// PD adaptation should be very fast (< 1ms)
// Individual components should also be fast
let start = Instant::now();
let _pd_req = generate_req.to_pd_request();
let adapt_duration = start.elapsed();
let _json = to_value(&generate_req).unwrap();
let to_value_duration = start.elapsed();
let start = Instant::now();
let mut json = to_value(&generate_req).unwrap();
let _ = inject_bootstrap_fields(&mut json, &worker);
let inject_duration = start.elapsed();
// Bootstrap injection should be faster than the JSON conversion
assert!(
adapt_duration.as_millis() < 1,
"PD adaptation took too long: {:?}",
adapt_duration
inject_duration <= to_value_duration * 3,
"Bootstrap injection ({:?}) should not be much slower than JSON conversion ({:?})",
inject_duration,
to_value_duration
);
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment