"git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "d76bc437205a865435a529d0c791d6eba729cde6"
Unverified Commit 828a4fe9 authored by Simo Lin's avatar Simo Lin Committed by GitHub
Browse files

[router] Implement HTTP Dependency Injection Pattern for Router System (#8714)

parent 8ada1ab6
//! Factory for creating router instances //! Factory for creating router instances
use super::{pd_router::PDRouter, router::Router, RouterTrait}; use super::{pd_router::PDRouter, router::Router, RouterTrait};
use crate::config::{PolicyConfig, RouterConfig, RoutingMode}; use crate::config::{PolicyConfig, RoutingMode};
use crate::policies::PolicyFactory; use crate::policies::PolicyFactory;
use crate::server::AppContext;
use std::sync::Arc;
/// Factory for creating router instances based on configuration /// Factory for creating router instances based on configuration
pub struct RouterFactory; pub struct RouterFactory;
impl RouterFactory { impl RouterFactory {
/// Create a router instance from configuration /// Create a router instance from application context
pub fn create_router(config: &RouterConfig) -> Result<Box<dyn RouterTrait>, String> { pub fn create_router(ctx: &Arc<AppContext>) -> Result<Box<dyn RouterTrait>, String> {
match &config.mode { match &ctx.router_config.mode {
RoutingMode::Regular { worker_urls } => { RoutingMode::Regular { worker_urls } => {
Self::create_regular_router(worker_urls, &config.policy, config) Self::create_regular_router(worker_urls, &ctx.router_config.policy, ctx)
} }
RoutingMode::PrefillDecode { RoutingMode::PrefillDecode {
prefill_urls, prefill_urls,
...@@ -24,8 +26,8 @@ impl RouterFactory { ...@@ -24,8 +26,8 @@ impl RouterFactory {
decode_urls, decode_urls,
prefill_policy.as_ref(), prefill_policy.as_ref(),
decode_policy.as_ref(), decode_policy.as_ref(),
&config.policy, &ctx.router_config.policy,
config, ctx,
), ),
} }
} }
...@@ -34,19 +36,20 @@ impl RouterFactory { ...@@ -34,19 +36,20 @@ impl RouterFactory {
fn create_regular_router( fn create_regular_router(
worker_urls: &[String], worker_urls: &[String],
policy_config: &PolicyConfig, policy_config: &PolicyConfig,
router_config: &RouterConfig, ctx: &Arc<AppContext>,
) -> Result<Box<dyn RouterTrait>, String> { ) -> Result<Box<dyn RouterTrait>, String> {
// Create policy // Create policy
let policy = PolicyFactory::create_from_config(policy_config); let policy = PolicyFactory::create_from_config(policy_config);
// Create regular router with injected policy // Create regular router with injected policy and client
let router = Router::new( let router = Router::new(
worker_urls.to_vec(), worker_urls.to_vec(),
policy, policy,
router_config.worker_startup_timeout_secs, ctx.client.clone(),
router_config.worker_startup_check_interval_secs, ctx.router_config.worker_startup_timeout_secs,
router_config.dp_aware, ctx.router_config.worker_startup_check_interval_secs,
router_config.api_key.clone(), ctx.router_config.dp_aware,
ctx.router_config.api_key.clone(),
)?; )?;
Ok(Box::new(router)) Ok(Box::new(router))
...@@ -59,7 +62,7 @@ impl RouterFactory { ...@@ -59,7 +62,7 @@ impl RouterFactory {
prefill_policy_config: Option<&PolicyConfig>, prefill_policy_config: Option<&PolicyConfig>,
decode_policy_config: Option<&PolicyConfig>, decode_policy_config: Option<&PolicyConfig>,
main_policy_config: &PolicyConfig, main_policy_config: &PolicyConfig,
router_config: &RouterConfig, ctx: &Arc<AppContext>,
) -> Result<Box<dyn RouterTrait>, String> { ) -> Result<Box<dyn RouterTrait>, String> {
// Create policies - use specific policies if provided, otherwise fall back to main policy // Create policies - use specific policies if provided, otherwise fall back to main policy
let prefill_policy = let prefill_policy =
...@@ -67,14 +70,15 @@ impl RouterFactory { ...@@ -67,14 +70,15 @@ impl RouterFactory {
let decode_policy = let decode_policy =
PolicyFactory::create_from_config(decode_policy_config.unwrap_or(main_policy_config)); PolicyFactory::create_from_config(decode_policy_config.unwrap_or(main_policy_config));
// Create PD router with separate policies // Create PD router with separate policies and client
let router = PDRouter::new( let router = PDRouter::new(
prefill_urls.to_vec(), prefill_urls.to_vec(),
decode_urls.to_vec(), decode_urls.to_vec(),
prefill_policy, prefill_policy,
decode_policy, decode_policy,
router_config.worker_startup_timeout_secs, ctx.client.clone(),
router_config.worker_startup_check_interval_secs, ctx.router_config.worker_startup_timeout_secs,
ctx.router_config.worker_startup_check_interval_secs,
)?; )?;
Ok(Box::new(router)) Ok(Box::new(router))
......
...@@ -7,7 +7,6 @@ use axum::{ ...@@ -7,7 +7,6 @@ use axum::{
http::{HeaderMap, StatusCode}, http::{HeaderMap, StatusCode},
response::{IntoResponse, Response}, response::{IntoResponse, Response},
}; };
use reqwest::Client;
use std::fmt::Debug; use std::fmt::Debug;
use crate::openai_api_types::{ChatCompletionRequest, CompletionRequest, GenerateRequest}; use crate::openai_api_types::{ChatCompletionRequest, CompletionRequest, GenerateRequest};
...@@ -46,32 +45,27 @@ pub trait RouterTrait: Send + Sync + Debug + WorkerManagement { ...@@ -46,32 +45,27 @@ pub trait RouterTrait: Send + Sync + Debug + WorkerManagement {
fn as_any(&self) -> &dyn std::any::Any; fn as_any(&self) -> &dyn std::any::Any;
/// Route a health check request /// Route a health check request
async fn health(&self, client: &Client, req: Request<Body>) -> Response; async fn health(&self, req: Request<Body>) -> Response;
/// Route a health generate request /// Route a health generate request
async fn health_generate(&self, client: &Client, req: Request<Body>) -> Response; async fn health_generate(&self, req: Request<Body>) -> Response;
/// Get server information /// Get server information
async fn get_server_info(&self, client: &Client, req: Request<Body>) -> Response; async fn get_server_info(&self, req: Request<Body>) -> Response;
/// Get available models /// Get available models
async fn get_models(&self, client: &Client, req: Request<Body>) -> Response; async fn get_models(&self, req: Request<Body>) -> Response;
/// Get model information /// Get model information
async fn get_model_info(&self, client: &Client, req: Request<Body>) -> Response; async fn get_model_info(&self, req: Request<Body>) -> Response;
/// Route a generate request /// Route a generate request
async fn route_generate( async fn route_generate(&self, headers: Option<&HeaderMap>, body: &GenerateRequest)
&self, -> Response;
client: &Client,
headers: Option<&HeaderMap>,
body: &GenerateRequest,
) -> Response;
/// Route a chat completion request /// Route a chat completion request
async fn route_chat( async fn route_chat(
&self, &self,
client: &Client,
headers: Option<&HeaderMap>, headers: Option<&HeaderMap>,
body: &ChatCompletionRequest, body: &ChatCompletionRequest,
) -> Response; ) -> Response;
...@@ -79,16 +73,15 @@ pub trait RouterTrait: Send + Sync + Debug + WorkerManagement { ...@@ -79,16 +73,15 @@ pub trait RouterTrait: Send + Sync + Debug + WorkerManagement {
/// Route a completion request /// Route a completion request
async fn route_completion( async fn route_completion(
&self, &self,
client: &Client,
headers: Option<&HeaderMap>, headers: Option<&HeaderMap>,
body: &CompletionRequest, body: &CompletionRequest,
) -> Response; ) -> Response;
/// Flush cache on all workers /// Flush cache on all workers
async fn flush_cache(&self, client: &Client) -> Response; async fn flush_cache(&self) -> Response;
/// Get worker loads (for monitoring) /// Get worker loads (for monitoring)
async fn get_worker_loads(&self, client: &Client) -> Response; async fn get_worker_loads(&self) -> Response;
/// Get router type name /// Get router type name
fn router_type(&self) -> &'static str; fn router_type(&self) -> &'static str;
......
...@@ -35,7 +35,7 @@ pub struct PDRouter { ...@@ -35,7 +35,7 @@ pub struct PDRouter {
pub interval_secs: u64, pub interval_secs: u64,
pub worker_loads: Arc<tokio::sync::watch::Receiver<HashMap<String, isize>>>, pub worker_loads: Arc<tokio::sync::watch::Receiver<HashMap<String, isize>>>,
pub load_monitor_handle: Option<Arc<tokio::task::JoinHandle<()>>>, pub load_monitor_handle: Option<Arc<tokio::task::JoinHandle<()>>>,
pub http_client: Client, pub client: Client,
_prefill_health_checker: Option<HealthChecker>, _prefill_health_checker: Option<HealthChecker>,
_decode_health_checker: Option<HealthChecker>, _decode_health_checker: Option<HealthChecker>,
} }
...@@ -177,6 +177,7 @@ impl PDRouter { ...@@ -177,6 +177,7 @@ impl PDRouter {
decode_urls: Vec<String>, decode_urls: Vec<String>,
prefill_policy: Arc<dyn LoadBalancingPolicy>, prefill_policy: Arc<dyn LoadBalancingPolicy>,
decode_policy: Arc<dyn LoadBalancingPolicy>, decode_policy: Arc<dyn LoadBalancingPolicy>,
client: Client,
timeout_secs: u64, timeout_secs: u64,
interval_secs: u64, interval_secs: u64,
) -> Result<Self, String> { ) -> Result<Self, String> {
...@@ -215,17 +216,11 @@ impl PDRouter { ...@@ -215,17 +216,11 @@ impl PDRouter {
let (tx, rx) = tokio::sync::watch::channel(HashMap::new()); let (tx, rx) = tokio::sync::watch::channel(HashMap::new());
let worker_loads = Arc::new(rx); let worker_loads = Arc::new(rx);
// Create a shared HTTP client for all operations
let http_client = Client::builder()
.timeout(Duration::from_secs(timeout_secs))
.build()
.map_err(|e| format!("Failed to create HTTP client: {}", e))?;
let load_monitor_handle = let load_monitor_handle =
if prefill_policy.name() == "power_of_two" || decode_policy.name() == "power_of_two" { if prefill_policy.name() == "power_of_two" || decode_policy.name() == "power_of_two" {
let monitor_urls = all_urls.clone(); let monitor_urls = all_urls.clone();
let monitor_interval = interval_secs; let monitor_interval = interval_secs;
let monitor_client = http_client.clone(); let monitor_client = client.clone();
let prefill_policy_clone = Arc::clone(&prefill_policy); let prefill_policy_clone = Arc::clone(&prefill_policy);
let decode_policy_clone = Arc::clone(&decode_policy); let decode_policy_clone = Arc::clone(&decode_policy);
...@@ -264,7 +259,7 @@ impl PDRouter { ...@@ -264,7 +259,7 @@ impl PDRouter {
interval_secs, interval_secs,
worker_loads, worker_loads,
load_monitor_handle, load_monitor_handle,
http_client, client,
_prefill_health_checker: Some(prefill_health_checker), _prefill_health_checker: Some(prefill_health_checker),
_decode_health_checker: Some(decode_health_checker), _decode_health_checker: Some(decode_health_checker),
}) })
...@@ -302,7 +297,6 @@ impl PDRouter { ...@@ -302,7 +297,6 @@ impl PDRouter {
// Route a typed generate request // Route a typed generate request
pub async fn route_generate( pub async fn route_generate(
&self, &self,
client: &Client,
headers: Option<&HeaderMap>, headers: Option<&HeaderMap>,
mut typed_req: GenerateReqInput, mut typed_req: GenerateReqInput,
route: &str, route: &str,
...@@ -371,7 +365,6 @@ impl PDRouter { ...@@ -371,7 +365,6 @@ impl PDRouter {
// Execute dual dispatch // Execute dual dispatch
self.execute_dual_dispatch( self.execute_dual_dispatch(
client,
headers, headers,
json_with_bootstrap, json_with_bootstrap,
route, route,
...@@ -387,7 +380,6 @@ impl PDRouter { ...@@ -387,7 +380,6 @@ impl PDRouter {
// Route a typed chat request // Route a typed chat request
pub async fn route_chat( pub async fn route_chat(
&self, &self,
client: &Client,
headers: Option<&HeaderMap>, headers: Option<&HeaderMap>,
mut typed_req: ChatReqInput, mut typed_req: ChatReqInput,
route: &str, route: &str,
...@@ -459,7 +451,6 @@ impl PDRouter { ...@@ -459,7 +451,6 @@ impl PDRouter {
// Execute dual dispatch // Execute dual dispatch
self.execute_dual_dispatch( self.execute_dual_dispatch(
client,
headers, headers,
json_with_bootstrap, json_with_bootstrap,
route, route,
...@@ -475,7 +466,6 @@ impl PDRouter { ...@@ -475,7 +466,6 @@ impl PDRouter {
// Route a completion request while preserving OpenAI format // Route a completion request while preserving OpenAI format
pub async fn route_completion( pub async fn route_completion(
&self, &self,
client: &Client,
headers: Option<&HeaderMap>, headers: Option<&HeaderMap>,
mut typed_req: CompletionRequest, mut typed_req: CompletionRequest,
route: &str, route: &str,
...@@ -540,7 +530,6 @@ impl PDRouter { ...@@ -540,7 +530,6 @@ impl PDRouter {
// Execute dual dispatch // Execute dual dispatch
self.execute_dual_dispatch( self.execute_dual_dispatch(
client,
headers, headers,
json_with_bootstrap, json_with_bootstrap,
route, route,
...@@ -554,10 +543,8 @@ impl PDRouter { ...@@ -554,10 +543,8 @@ impl PDRouter {
} }
// Execute the dual dispatch to prefill and decode servers // Execute the dual dispatch to prefill and decode servers
#[allow(clippy::too_many_arguments)]
async fn execute_dual_dispatch( async fn execute_dual_dispatch(
&self, &self,
client: &Client,
headers: Option<&HeaderMap>, headers: Option<&HeaderMap>,
json_request: Value, json_request: Value,
route: &str, route: &str,
...@@ -571,11 +558,13 @@ impl PDRouter { ...@@ -571,11 +558,13 @@ impl PDRouter {
let _guard = WorkerLoadGuard::new_multi(vec![prefill, decode]); let _guard = WorkerLoadGuard::new_multi(vec![prefill, decode]);
// Build requests using .json() method // Build requests using .json() method
let mut prefill_request = client let mut prefill_request = self
.client
.post(api_path(prefill.url(), route)) .post(api_path(prefill.url(), route))
.json(&json_request); .json(&json_request);
let mut decode_request = client let mut decode_request = self
.client
.post(api_path(decode.url(), route)) .post(api_path(decode.url(), route))
.json(&json_request); .json(&json_request);
...@@ -987,7 +976,7 @@ async fn get_worker_load(client: &reqwest::Client, worker_url: &str) -> Option<i ...@@ -987,7 +976,7 @@ async fn get_worker_load(client: &reqwest::Client, worker_url: &str) -> Option<i
// PD-specific endpoints // PD-specific endpoints
impl PDRouter { impl PDRouter {
pub async fn health_generate(&self, client: &reqwest::Client) -> Response { pub async fn health_generate(&self) -> Response {
// Test model generation capability by selecting a random pair and testing them // Test model generation capability by selecting a random pair and testing them
// Note: This endpoint actually causes the model to generate tokens, so we only test one pair // Note: This endpoint actually causes the model to generate tokens, so we only test one pair
...@@ -1005,11 +994,11 @@ impl PDRouter { ...@@ -1005,11 +994,11 @@ impl PDRouter {
// Test prefill server's health_generate // Test prefill server's health_generate
let prefill_url = format!("{}/health_generate", prefill.url()); let prefill_url = format!("{}/health_generate", prefill.url());
let prefill_result = client.get(&prefill_url).send().await; let prefill_result = self.client.get(&prefill_url).send().await;
// Test decode server's health_generate // Test decode server's health_generate
let decode_url = format!("{}/health_generate", decode.url()); let decode_url = format!("{}/health_generate", decode.url());
let decode_result = client.get(&decode_url).send().await; let decode_result = self.client.get(&decode_url).send().await;
// Check results // Check results
let mut errors = Vec::new(); let mut errors = Vec::new();
...@@ -1068,7 +1057,7 @@ impl PDRouter { ...@@ -1068,7 +1057,7 @@ impl PDRouter {
} }
} }
pub async fn get_server_info(&self, client: &reqwest::Client) -> Response { pub async fn get_server_info(&self) -> Response {
// Get info from the first decode server to match sglang's server info format // Get info from the first decode server to match sglang's server info format
let first_decode_url = if let Ok(workers) = self.decode_workers.read() { let first_decode_url = if let Ok(workers) = self.decode_workers.read() {
workers.first().map(|w| w.url().to_string()) workers.first().map(|w| w.url().to_string())
...@@ -1081,7 +1070,8 @@ impl PDRouter { ...@@ -1081,7 +1070,8 @@ impl PDRouter {
}; };
if let Some(worker_url) = first_decode_url { if let Some(worker_url) = first_decode_url {
match client match self
.client
.get(format!("{}/get_server_info", worker_url)) .get(format!("{}/get_server_info", worker_url))
.send() .send()
.await .await
...@@ -1130,7 +1120,7 @@ impl PDRouter { ...@@ -1130,7 +1120,7 @@ impl PDRouter {
} }
} }
pub async fn get_models(&self, client: &reqwest::Client, req: Request<Body>) -> Response { pub async fn get_models(&self, req: Request<Body>) -> Response {
// Extract headers first to avoid Send issues // Extract headers first to avoid Send issues
let headers = crate::routers::router::copy_request_headers(&req); let headers = crate::routers::router::copy_request_headers(&req);
...@@ -1147,7 +1137,7 @@ impl PDRouter { ...@@ -1147,7 +1137,7 @@ impl PDRouter {
if let Some(worker_url) = first_worker_url { if let Some(worker_url) = first_worker_url {
// Send request directly without going through Router // Send request directly without going through Router
let mut request_builder = client.get(format!("{}/v1/models", worker_url)); let mut request_builder = self.client.get(format!("{}/v1/models", worker_url));
for (name, value) in headers { for (name, value) in headers {
if name.to_lowercase() != "content-type" && name.to_lowercase() != "content-length" if name.to_lowercase() != "content-type" && name.to_lowercase() != "content-length"
{ {
...@@ -1224,7 +1214,7 @@ impl PDRouter { ...@@ -1224,7 +1214,7 @@ impl PDRouter {
.into_response() .into_response()
} }
pub async fn get_model_info(&self, client: &reqwest::Client, req: Request<Body>) -> Response { pub async fn get_model_info(&self, req: Request<Body>) -> Response {
// Extract headers first to avoid Send issues // Extract headers first to avoid Send issues
let headers = crate::routers::router::copy_request_headers(&req); let headers = crate::routers::router::copy_request_headers(&req);
...@@ -1241,7 +1231,7 @@ impl PDRouter { ...@@ -1241,7 +1231,7 @@ impl PDRouter {
}; };
if let Some(worker_url) = first_worker_url { if let Some(worker_url) = first_worker_url {
let mut request_builder = client.get(format!("{}/get_model_info", worker_url)); let mut request_builder = self.client.get(format!("{}/get_model_info", worker_url));
for (name, value) in headers { for (name, value) in headers {
if name.to_lowercase() != "content-type" && name.to_lowercase() != "content-length" if name.to_lowercase() != "content-type" && name.to_lowercase() != "content-length"
{ {
...@@ -1384,7 +1374,7 @@ impl RouterTrait for PDRouter { ...@@ -1384,7 +1374,7 @@ impl RouterTrait for PDRouter {
self self
} }
async fn health(&self, _client: &Client, _req: Request<Body>) -> Response { async fn health(&self, _req: Request<Body>) -> Response {
// This is a server readiness check - checking if we have healthy workers // This is a server readiness check - checking if we have healthy workers
// Workers handle their own health checks in the background // Workers handle their own health checks in the background
let mut all_healthy = true; let mut all_healthy = true;
...@@ -1417,68 +1407,65 @@ impl RouterTrait for PDRouter { ...@@ -1417,68 +1407,65 @@ impl RouterTrait for PDRouter {
} }
} }
async fn health_generate(&self, client: &Client, _req: Request<Body>) -> Response { async fn health_generate(&self, _req: Request<Body>) -> Response {
// Use the existing PDRouter health_generate method // Use the existing PDRouter health_generate method
PDRouter::health_generate(self, client).await PDRouter::health_generate(self).await
} }
async fn get_server_info(&self, client: &Client, _req: Request<Body>) -> Response { async fn get_server_info(&self, _req: Request<Body>) -> Response {
// Use the existing PDRouter get_server_info method // Use the existing PDRouter get_server_info method
PDRouter::get_server_info(self, client).await PDRouter::get_server_info(self).await
} }
async fn get_models(&self, client: &Client, req: Request<Body>) -> Response { async fn get_models(&self, req: Request<Body>) -> Response {
// Use the existing PDRouter get_models method // Use the existing PDRouter get_models method
PDRouter::get_models(self, client, req).await PDRouter::get_models(self, req).await
} }
async fn get_model_info(&self, client: &Client, req: Request<Body>) -> Response { async fn get_model_info(&self, req: Request<Body>) -> Response {
// Use the existing PDRouter get_model_info method // Use the existing PDRouter get_model_info method
PDRouter::get_model_info(self, client, req).await PDRouter::get_model_info(self, req).await
} }
async fn route_generate( async fn route_generate(
&self, &self,
client: &Client,
headers: Option<&HeaderMap>, headers: Option<&HeaderMap>,
body: &GenerateRequest, body: &GenerateRequest,
) -> Response { ) -> Response {
// Convert OpenAI format to PD format // Convert OpenAI format to PD format
let pd_req = body.clone().to_pd_request(); let pd_req = body.clone().to_pd_request();
PDRouter::route_generate(self, client, headers, pd_req, "/generate").await PDRouter::route_generate(self, headers, pd_req, "/generate").await
} }
async fn route_chat( async fn route_chat(
&self, &self,
client: &Client,
headers: Option<&HeaderMap>, headers: Option<&HeaderMap>,
body: &ChatCompletionRequest, body: &ChatCompletionRequest,
) -> Response { ) -> Response {
// Convert OpenAI format to PD format // Convert OpenAI format to PD format
let pd_req = body.clone().to_pd_request(); let pd_req = body.clone().to_pd_request();
PDRouter::route_chat(self, client, headers, pd_req, "/v1/chat/completions").await PDRouter::route_chat(self, headers, pd_req, "/v1/chat/completions").await
} }
async fn route_completion( async fn route_completion(
&self, &self,
client: &Client,
headers: Option<&HeaderMap>, headers: Option<&HeaderMap>,
body: &CompletionRequest, body: &CompletionRequest,
) -> Response { ) -> Response {
// Use the new method that preserves OpenAI format // Use the new method that preserves OpenAI format
PDRouter::route_completion(self, client, headers, body.clone(), "/v1/completions").await PDRouter::route_completion(self, headers, body.clone(), "/v1/completions").await
} }
async fn flush_cache(&self, client: &Client) -> Response { async fn flush_cache(&self) -> Response {
// Use the existing PDRouter flush_cache method // Use the existing PDRouter flush_cache method
PDRouter::flush_cache(self, client).await PDRouter::flush_cache(self, &self.client).await
} }
async fn get_worker_loads(&self, client: &Client) -> Response { async fn get_worker_loads(&self) -> Response {
// Use the existing PDRouter get_loads method // Use the existing PDRouter get_loads method
PDRouter::get_loads(self, client).await PDRouter::get_loads(self, &self.client).await
} }
fn router_type(&self) -> &'static str { fn router_type(&self) -> &'static str {
...@@ -1570,7 +1557,7 @@ mod tests { ...@@ -1570,7 +1557,7 @@ mod tests {
interval_secs: 1, interval_secs: 1,
worker_loads: Arc::new(tokio::sync::watch::channel(HashMap::new()).1), worker_loads: Arc::new(tokio::sync::watch::channel(HashMap::new()).1),
load_monitor_handle: None, load_monitor_handle: None,
http_client: reqwest::Client::new(), client: Client::new(),
_prefill_health_checker: None, _prefill_health_checker: None,
_decode_health_checker: None, _decode_health_checker: None,
} }
...@@ -1959,11 +1946,10 @@ mod tests { ...@@ -1959,11 +1946,10 @@ mod tests {
router.decode_workers.write().unwrap().push(decode_worker); router.decode_workers.write().unwrap().push(decode_worker);
// Test health endpoint // Test health endpoint
let client = reqwest::Client::new();
let http_req = axum::http::Request::builder() let http_req = axum::http::Request::builder()
.body(axum::body::Body::empty()) .body(axum::body::Body::empty())
.unwrap(); .unwrap();
let response = router.health(&client, http_req).await; let response = router.health(http_req).await;
assert_eq!(response.status(), 200); assert_eq!(response.status(), 200);
......
...@@ -34,6 +34,7 @@ pub fn copy_request_headers(req: &Request<Body>) -> Vec<(String, String)> { ...@@ -34,6 +34,7 @@ pub fn copy_request_headers(req: &Request<Body>) -> Vec<(String, String)> {
pub struct Router { pub struct Router {
workers: Arc<RwLock<Vec<Box<dyn Worker>>>>, workers: Arc<RwLock<Vec<Box<dyn Worker>>>>,
policy: Arc<dyn LoadBalancingPolicy>, policy: Arc<dyn LoadBalancingPolicy>,
client: Client,
timeout_secs: u64, timeout_secs: u64,
interval_secs: u64, interval_secs: u64,
dp_aware: bool, dp_aware: bool,
...@@ -44,10 +45,11 @@ pub struct Router { ...@@ -44,10 +45,11 @@ pub struct Router {
} }
impl Router { impl Router {
/// Create a new router with injected policy /// Create a new router with injected policy and client
pub fn new( pub fn new(
worker_urls: Vec<String>, worker_urls: Vec<String>,
policy: Arc<dyn LoadBalancingPolicy>, policy: Arc<dyn LoadBalancingPolicy>,
client: Client,
timeout_secs: u64, timeout_secs: u64,
interval_secs: u64, interval_secs: u64,
dp_aware: bool, dp_aware: bool,
...@@ -94,9 +96,17 @@ impl Router { ...@@ -94,9 +96,17 @@ impl Router {
let monitor_urls = worker_urls.clone(); let monitor_urls = worker_urls.clone();
let monitor_interval = interval_secs; let monitor_interval = interval_secs;
let policy_clone = Arc::clone(&policy); let policy_clone = Arc::clone(&policy);
let client_clone = client.clone();
Some(Arc::new(tokio::spawn(async move { Some(Arc::new(tokio::spawn(async move {
Self::monitor_worker_loads(monitor_urls, tx, monitor_interval, policy_clone).await; Self::monitor_worker_loads(
monitor_urls,
tx,
monitor_interval,
policy_clone,
client_clone,
)
.await;
}))) })))
} else { } else {
None None
...@@ -105,6 +115,7 @@ impl Router { ...@@ -105,6 +115,7 @@ impl Router {
Ok(Router { Ok(Router {
workers, workers,
policy, policy,
client,
timeout_secs, timeout_secs,
interval_secs, interval_secs,
dp_aware, dp_aware,
...@@ -245,7 +256,7 @@ impl Router { ...@@ -245,7 +256,7 @@ impl Router {
} }
} }
pub async fn send_health_check(&self, client: &Client, worker_url: &str) -> Response { pub async fn send_health_check(&self, worker_url: &str) -> Response {
let health_url = if self.dp_aware { let health_url = if self.dp_aware {
// Need to extract the URL from "http://host:port@dp_rank" // Need to extract the URL from "http://host:port@dp_rank"
match Self::extract_dp_rank(worker_url) { match Self::extract_dp_rank(worker_url) {
...@@ -263,7 +274,7 @@ impl Router { ...@@ -263,7 +274,7 @@ impl Router {
worker_url worker_url
}; };
let request_builder = client.get(format!("{}/health", health_url)); let request_builder = self.client.get(format!("{}/health", health_url));
let response = match request_builder.send().await { let response = match request_builder.send().await {
Ok(res) => { Ok(res) => {
...@@ -305,17 +316,12 @@ impl Router { ...@@ -305,17 +316,12 @@ impl Router {
} }
// Helper method to proxy GET requests to the first available worker // Helper method to proxy GET requests to the first available worker
async fn proxy_get_request( async fn proxy_get_request(&self, req: Request<Body>, endpoint: &str) -> Response {
&self,
client: &Client,
req: Request<Body>,
endpoint: &str,
) -> Response {
let headers = copy_request_headers(&req); let headers = copy_request_headers(&req);
match self.select_first_worker() { match self.select_first_worker() {
Ok(worker_url) => { Ok(worker_url) => {
let mut request_builder = client.get(format!("{}/{}", worker_url, endpoint)); let mut request_builder = self.client.get(format!("{}/{}", worker_url, endpoint));
for (name, value) in headers { for (name, value) in headers {
if name.to_lowercase() != "content-type" if name.to_lowercase() != "content-type"
&& name.to_lowercase() != "content-length" && name.to_lowercase() != "content-length"
...@@ -353,7 +359,6 @@ impl Router { ...@@ -353,7 +359,6 @@ impl Router {
T: crate::openai_api_types::GenerationRequest + serde::Serialize + Clone, T: crate::openai_api_types::GenerationRequest + serde::Serialize + Clone,
>( >(
&self, &self,
client: &reqwest::Client,
headers: Option<&HeaderMap>, headers: Option<&HeaderMap>,
typed_req: &T, typed_req: &T,
route: &str, route: &str,
...@@ -397,7 +402,6 @@ impl Router { ...@@ -397,7 +402,6 @@ impl Router {
// Send typed request directly // Send typed request directly
let response = self let response = self
.send_typed_request( .send_typed_request(
client,
headers, headers,
typed_req, typed_req,
route, route,
...@@ -413,7 +417,7 @@ impl Router { ...@@ -413,7 +417,7 @@ impl Router {
return response; return response;
} else { } else {
// if the worker is healthy, it means the request is bad, so return the error response // if the worker is healthy, it means the request is bad, so return the error response
let health_response = self.send_health_check(client, &worker_url).await; let health_response = self.send_health_check(&worker_url).await;
if health_response.status().is_success() { if health_response.status().is_success() {
RouterMetrics::record_request_error(route, "request_failed"); RouterMetrics::record_request_error(route, "request_failed");
return response; return response;
...@@ -483,7 +487,6 @@ impl Router { ...@@ -483,7 +487,6 @@ impl Router {
// Send typed request directly without conversion // Send typed request directly without conversion
async fn send_typed_request<T: serde::Serialize>( async fn send_typed_request<T: serde::Serialize>(
&self, &self,
client: &reqwest::Client,
headers: Option<&HeaderMap>, headers: Option<&HeaderMap>,
typed_req: &T, typed_req: &T,
route: &str, route: &str,
...@@ -536,11 +539,11 @@ impl Router { ...@@ -536,11 +539,11 @@ impl Router {
.into_response(); .into_response();
} }
client self.client
.post(format!("{}{}", worker_url_prefix, route)) .post(format!("{}{}", worker_url_prefix, route))
.json(&json_val) .json(&json_val)
} else { } else {
client self.client
.post(format!("{}{}", worker_url, route)) .post(format!("{}{}", worker_url, route))
.json(typed_req) // Use json() directly with typed request .json(typed_req) // Use json() directly with typed request
}; };
...@@ -866,7 +869,7 @@ impl Router { ...@@ -866,7 +869,7 @@ impl Router {
} }
} }
async fn get_worker_load(&self, client: &reqwest::Client, worker_url: &str) -> Option<isize> { async fn get_worker_load(&self, worker_url: &str) -> Option<isize> {
let worker_url = if self.dp_aware { let worker_url = if self.dp_aware {
// Need to extract the URL from "http://host:port@dp_rank" // Need to extract the URL from "http://host:port@dp_rank"
let (worker_url_prefix, _dp_rank) = match Self::extract_dp_rank(worker_url) { let (worker_url_prefix, _dp_rank) = match Self::extract_dp_rank(worker_url) {
...@@ -881,7 +884,12 @@ impl Router { ...@@ -881,7 +884,12 @@ impl Router {
worker_url worker_url
}; };
match client.get(&format!("{}/get_load", worker_url)).send().await { match self
.client
.get(&format!("{}/get_load", worker_url))
.send()
.await
{
Ok(res) if res.status().is_success() => match res.bytes().await { Ok(res) if res.status().is_success() => match res.bytes().await {
Ok(bytes) => match serde_json::from_slice::<serde_json::Value>(&bytes) { Ok(bytes) => match serde_json::from_slice::<serde_json::Value>(&bytes) {
Ok(data) => data Ok(data) => data
...@@ -919,18 +927,8 @@ impl Router { ...@@ -919,18 +927,8 @@ impl Router {
tx: tokio::sync::watch::Sender<HashMap<String, isize>>, tx: tokio::sync::watch::Sender<HashMap<String, isize>>,
interval_secs: u64, interval_secs: u64,
policy: Arc<dyn LoadBalancingPolicy>, policy: Arc<dyn LoadBalancingPolicy>,
client: Client,
) { ) {
let client = match reqwest::Client::builder()
.timeout(Duration::from_secs(5))
.build()
{
Ok(c) => c,
Err(e) => {
error!("Failed to create HTTP client for load monitoring: {}", e);
return;
}
};
let mut interval = tokio::time::interval(Duration::from_secs(interval_secs)); let mut interval = tokio::time::interval(Duration::from_secs(interval_secs));
loop { loop {
...@@ -1028,7 +1026,7 @@ impl RouterTrait for Router { ...@@ -1028,7 +1026,7 @@ impl RouterTrait for Router {
self self
} }
async fn health(&self, _client: &Client, _req: Request<Body>) -> Response { async fn health(&self, _req: Request<Body>) -> Response {
let workers = self.workers.read().unwrap(); let workers = self.workers.read().unwrap();
let unhealthy_servers: Vec<_> = workers let unhealthy_servers: Vec<_> = workers
.iter() .iter()
...@@ -1047,53 +1045,49 @@ impl RouterTrait for Router { ...@@ -1047,53 +1045,49 @@ impl RouterTrait for Router {
} }
} }
async fn health_generate(&self, client: &Client, req: Request<Body>) -> Response { async fn health_generate(&self, req: Request<Body>) -> Response {
self.proxy_get_request(client, req, "health_generate").await self.proxy_get_request(req, "health_generate").await
} }
async fn get_server_info(&self, client: &Client, req: Request<Body>) -> Response { async fn get_server_info(&self, req: Request<Body>) -> Response {
self.proxy_get_request(client, req, "get_server_info").await self.proxy_get_request(req, "get_server_info").await
} }
async fn get_models(&self, client: &Client, req: Request<Body>) -> Response { async fn get_models(&self, req: Request<Body>) -> Response {
self.proxy_get_request(client, req, "v1/models").await self.proxy_get_request(req, "v1/models").await
} }
async fn get_model_info(&self, client: &Client, req: Request<Body>) -> Response { async fn get_model_info(&self, req: Request<Body>) -> Response {
self.proxy_get_request(client, req, "get_model_info").await self.proxy_get_request(req, "get_model_info").await
} }
async fn route_generate( async fn route_generate(
&self, &self,
client: &Client,
headers: Option<&HeaderMap>, headers: Option<&HeaderMap>,
body: &GenerateRequest, body: &GenerateRequest,
) -> Response { ) -> Response {
self.route_typed_request(client, headers, body, "/generate") self.route_typed_request(headers, body, "/generate").await
.await
} }
async fn route_chat( async fn route_chat(
&self, &self,
client: &Client,
headers: Option<&HeaderMap>, headers: Option<&HeaderMap>,
body: &ChatCompletionRequest, body: &ChatCompletionRequest,
) -> Response { ) -> Response {
self.route_typed_request(client, headers, body, "/v1/chat/completions") self.route_typed_request(headers, body, "/v1/chat/completions")
.await .await
} }
async fn route_completion( async fn route_completion(
&self, &self,
client: &Client,
headers: Option<&HeaderMap>, headers: Option<&HeaderMap>,
body: &CompletionRequest, body: &CompletionRequest,
) -> Response { ) -> Response {
self.route_typed_request(client, headers, body, "/v1/completions") self.route_typed_request(headers, body, "/v1/completions")
.await .await
} }
async fn flush_cache(&self, client: &Client) -> Response { async fn flush_cache(&self) -> Response {
// Get all worker URLs // Get all worker URLs
let worker_urls = self.get_worker_urls(); let worker_urls = self.get_worker_urls();
...@@ -1117,7 +1111,7 @@ impl RouterTrait for Router { ...@@ -1117,7 +1111,7 @@ impl RouterTrait for Router {
} else { } else {
worker_url worker_url
}; };
let request_builder = client.post(format!("{}/flush_cache", worker_url)); let request_builder = self.client.post(format!("{}/flush_cache", worker_url));
tasks.push(request_builder.send()); tasks.push(request_builder.send());
} }
...@@ -1142,13 +1136,13 @@ impl RouterTrait for Router { ...@@ -1142,13 +1136,13 @@ impl RouterTrait for Router {
} }
} }
async fn get_worker_loads(&self, client: &Client) -> Response { async fn get_worker_loads(&self) -> Response {
let urls = self.get_worker_urls(); let urls = self.get_worker_urls();
let mut loads = Vec::new(); let mut loads = Vec::new();
// Get loads from all workers // Get loads from all workers
for url in &urls { for url in &urls {
let load = self.get_worker_load(client, url).await.unwrap_or(-1); let load = self.get_worker_load(url).await.unwrap_or(-1);
loads.push(serde_json::json!({ loads.push(serde_json::json!({
"worker": url, "worker": url,
"load": load "load": load
...@@ -1215,6 +1209,7 @@ mod tests { ...@@ -1215,6 +1209,7 @@ mod tests {
interval_secs: 1, interval_secs: 1,
dp_aware: false, dp_aware: false,
api_key: None, api_key: None,
client: Client::new(),
_worker_loads: Arc::new(rx), _worker_loads: Arc::new(rx),
_load_monitor_handle: None, _load_monitor_handle: None,
_health_checker: None, _health_checker: None,
......
...@@ -22,29 +22,34 @@ use tokio::spawn; ...@@ -22,29 +22,34 @@ use tokio::spawn;
use tracing::{error, info, warn, Level}; use tracing::{error, info, warn, Level};
#[derive(Clone)] #[derive(Clone)]
pub struct AppState { pub struct AppContext {
pub router: Arc<dyn RouterTrait>,
pub client: Client, pub client: Client,
pub _concurrency_limiter: Arc<tokio::sync::Semaphore>, pub router_config: RouterConfig,
pub concurrency_limiter: Arc<tokio::sync::Semaphore>,
// Future dependencies can be added here
} }
impl AppState { impl AppContext {
pub fn new( pub fn new(
router_config: RouterConfig, router_config: RouterConfig,
client: Client, client: Client,
max_concurrent_requests: usize, max_concurrent_requests: usize,
) -> Result<Self, String> { ) -> Self {
let router = RouterFactory::create_router(&router_config)?;
let router = Arc::from(router);
let concurrency_limiter = Arc::new(tokio::sync::Semaphore::new(max_concurrent_requests)); let concurrency_limiter = Arc::new(tokio::sync::Semaphore::new(max_concurrent_requests));
Ok(Self { Self {
router,
client, client,
_concurrency_limiter: concurrency_limiter, router_config,
}) concurrency_limiter,
}
} }
} }
#[derive(Clone)]
pub struct AppState {
pub router: Arc<dyn RouterTrait>,
pub context: Arc<AppContext>,
}
// Fallback handler for unmatched routes // Fallback handler for unmatched routes
async fn sink_handler() -> Response { async fn sink_handler() -> Response {
StatusCode::NOT_FOUND.into_response() StatusCode::NOT_FOUND.into_response()
...@@ -60,23 +65,23 @@ async fn readiness(State(state): State<Arc<AppState>>) -> Response { ...@@ -60,23 +65,23 @@ async fn readiness(State(state): State<Arc<AppState>>) -> Response {
} }
async fn health(State(state): State<Arc<AppState>>, req: Request) -> Response { async fn health(State(state): State<Arc<AppState>>, req: Request) -> Response {
state.router.health(&state.client, req).await state.router.health(req).await
} }
async fn health_generate(State(state): State<Arc<AppState>>, req: Request) -> Response { async fn health_generate(State(state): State<Arc<AppState>>, req: Request) -> Response {
state.router.health_generate(&state.client, req).await state.router.health_generate(req).await
} }
async fn get_server_info(State(state): State<Arc<AppState>>, req: Request) -> Response { async fn get_server_info(State(state): State<Arc<AppState>>, req: Request) -> Response {
state.router.get_server_info(&state.client, req).await state.router.get_server_info(req).await
} }
async fn v1_models(State(state): State<Arc<AppState>>, req: Request) -> Response { async fn v1_models(State(state): State<Arc<AppState>>, req: Request) -> Response {
state.router.get_models(&state.client, req).await state.router.get_models(req).await
} }
async fn get_model_info(State(state): State<Arc<AppState>>, req: Request) -> Response { async fn get_model_info(State(state): State<Arc<AppState>>, req: Request) -> Response {
state.router.get_model_info(&state.client, req).await state.router.get_model_info(req).await
} }
// Generation endpoints // Generation endpoints
...@@ -86,10 +91,7 @@ async fn generate( ...@@ -86,10 +91,7 @@ async fn generate(
headers: http::HeaderMap, headers: http::HeaderMap,
Json(body): Json<GenerateRequest>, Json(body): Json<GenerateRequest>,
) -> Response { ) -> Response {
state state.router.route_generate(Some(&headers), &body).await
.router
.route_generate(&state.client, Some(&headers), &body)
.await
} }
async fn v1_chat_completions( async fn v1_chat_completions(
...@@ -97,10 +99,7 @@ async fn v1_chat_completions( ...@@ -97,10 +99,7 @@ async fn v1_chat_completions(
headers: http::HeaderMap, headers: http::HeaderMap,
Json(body): Json<ChatCompletionRequest>, Json(body): Json<ChatCompletionRequest>,
) -> Response { ) -> Response {
state state.router.route_chat(Some(&headers), &body).await
.router
.route_chat(&state.client, Some(&headers), &body)
.await
} }
async fn v1_completions( async fn v1_completions(
...@@ -108,10 +107,7 @@ async fn v1_completions( ...@@ -108,10 +107,7 @@ async fn v1_completions(
headers: http::HeaderMap, headers: http::HeaderMap,
Json(body): Json<CompletionRequest>, Json(body): Json<CompletionRequest>,
) -> Response { ) -> Response {
state state.router.route_completion(Some(&headers), &body).await
.router
.route_completion(&state.client, Some(&headers), &body)
.await
} }
// Worker management endpoints // Worker management endpoints
...@@ -159,11 +155,11 @@ async fn remove_worker( ...@@ -159,11 +155,11 @@ async fn remove_worker(
} }
async fn flush_cache(State(state): State<Arc<AppState>>, _req: Request) -> Response { async fn flush_cache(State(state): State<Arc<AppState>>, _req: Request) -> Response {
state.router.flush_cache(&state.client).await state.router.flush_cache().await
} }
async fn get_loads(State(state): State<Arc<AppState>>, _req: Request) -> Response { async fn get_loads(State(state): State<Arc<AppState>>, _req: Request) -> Response {
state.router.get_worker_loads(&state.client).await state.router.get_worker_loads().await
} }
pub struct ServerConfig { pub struct ServerConfig {
...@@ -281,11 +277,21 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Err ...@@ -281,11 +277,21 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Err
.build() .build()
.expect("Failed to create HTTP client"); .expect("Failed to create HTTP client");
let app_state = Arc::new(AppState::new( // Create the application context with all dependencies
let app_context = Arc::new(AppContext::new(
config.router_config.clone(), config.router_config.clone(),
client.clone(), client.clone(),
config.router_config.max_concurrent_requests, config.router_config.max_concurrent_requests,
)?); ));
// Create router with the context
let router = RouterFactory::create_router(&app_context)?;
// Create app state with router and context
let app_state = Arc::new(AppState {
router: Arc::from(router),
context: app_context.clone(),
});
let router_arc = Arc::clone(&app_state.router); let router_arc = Arc::clone(&app_state.router);
// Start the service discovery if enabled // Start the service discovery if enabled
......
...@@ -40,7 +40,6 @@ impl Default for ServiceDiscoveryConfig { ...@@ -40,7 +40,6 @@ impl Default for ServiceDiscoveryConfig {
check_interval: Duration::from_secs(60), check_interval: Duration::from_secs(60),
port: 8000, // Standard port for modern services port: 8000, // Standard port for modern services
namespace: None, // None means watch all namespaces namespace: None, // None means watch all namespaces
// PD mode defaults
pd_mode: false, pd_mode: false,
prefill_selector: HashMap::new(), prefill_selector: HashMap::new(),
decode_selector: HashMap::new(), decode_selector: HashMap::new(),
...@@ -581,7 +580,8 @@ mod tests { ...@@ -581,7 +580,8 @@ mod tests {
use crate::routers::router::Router; use crate::routers::router::Router;
let policy = PolicyFactory::create_from_config(&PolicyConfig::Random); let policy = PolicyFactory::create_from_config(&PolicyConfig::Random);
let router = Router::new(vec![], policy, 5, 1, false, None).unwrap(); let router =
Router::new(vec![], policy, reqwest::Client::new(), 5, 1, false, None).unwrap();
Arc::new(router) as Arc<dyn RouterTrait> Arc::new(router) as Arc<dyn RouterTrait>
} }
......
...@@ -83,12 +83,12 @@ impl TestContext { ...@@ -83,12 +83,12 @@ impl TestContext {
.build() .build()
.unwrap(); .unwrap();
// Clone config for the closure // Create app context
let config_clone = config.clone(); let app_context = common::create_test_context(config.clone());
// Create router using sync factory in a blocking context // Create router using sync factory in a blocking context
let router = let router =
tokio::task::spawn_blocking(move || RouterFactory::create_router(&config_clone)) tokio::task::spawn_blocking(move || RouterFactory::create_router(&app_context))
.await .await
.unwrap() .unwrap()
.unwrap(); .unwrap();
...@@ -1433,9 +1433,12 @@ mod pd_mode_tests { ...@@ -1433,9 +1433,12 @@ mod pd_mode_tests {
cors_allowed_origins: vec![], cors_allowed_origins: vec![],
}; };
// Create app context
let app_context = common::create_test_context(config);
// Create router - this might fail due to health check issues // Create router - this might fail due to health check issues
let router_result = let router_result =
tokio::task::spawn_blocking(move || RouterFactory::create_router(&config)) tokio::task::spawn_blocking(move || RouterFactory::create_router(&app_context))
.await .await
.unwrap(); .unwrap();
......
pub mod mock_worker; pub mod mock_worker;
pub mod test_app; pub mod test_app;
use sglang_router_rs::config::RouterConfig;
use sglang_router_rs::server::AppContext;
use std::sync::Arc;
/// Helper function to create AppContext for tests
pub fn create_test_context(config: RouterConfig) -> Arc<AppContext> {
Arc::new(AppContext::new(
config.clone(),
reqwest::Client::new(),
config.max_concurrent_requests,
))
}
...@@ -3,7 +3,7 @@ use reqwest::Client; ...@@ -3,7 +3,7 @@ use reqwest::Client;
use sglang_router_rs::{ use sglang_router_rs::{
config::RouterConfig, config::RouterConfig,
routers::RouterTrait, routers::RouterTrait,
server::{build_app, AppState}, server::{build_app, AppContext, AppState},
}; };
use std::sync::Arc; use std::sync::Arc;
...@@ -13,13 +13,17 @@ pub fn create_test_app( ...@@ -13,13 +13,17 @@ pub fn create_test_app(
client: Client, client: Client,
router_config: &RouterConfig, router_config: &RouterConfig,
) -> Router { ) -> Router {
// Create AppState with the test router // Create AppContext
let app_context = Arc::new(AppContext::new(
router_config.clone(),
client,
router_config.max_concurrent_requests,
));
// Create AppState with the test router and context
let app_state = Arc::new(AppState { let app_state = Arc::new(AppState {
router, router,
client, context: app_context,
_concurrency_limiter: Arc::new(tokio::sync::Semaphore::new(
router_config.max_concurrent_requests,
)),
}); });
// Configure request ID headers (use defaults if not specified) // Configure request ID headers (use defaults if not specified)
......
...@@ -53,10 +53,12 @@ impl TestContext { ...@@ -53,10 +53,12 @@ impl TestContext {
config.mode = RoutingMode::Regular { worker_urls }; config.mode = RoutingMode::Regular { worker_urls };
let router = tokio::task::spawn_blocking(move || RouterFactory::create_router(&config)) let app_context = common::create_test_context(config);
.await let router =
.unwrap() tokio::task::spawn_blocking(move || RouterFactory::create_router(&app_context))
.unwrap(); .await
.unwrap()
.unwrap();
let router = Arc::from(router); let router = Arc::from(router);
if !workers.is_empty() { if !workers.is_empty() {
......
...@@ -54,10 +54,12 @@ impl TestContext { ...@@ -54,10 +54,12 @@ impl TestContext {
config.mode = RoutingMode::Regular { worker_urls }; config.mode = RoutingMode::Regular { worker_urls };
let router = tokio::task::spawn_blocking(move || RouterFactory::create_router(&config)) let app_context = common::create_test_context(config);
.await let router =
.unwrap() tokio::task::spawn_blocking(move || RouterFactory::create_router(&app_context))
.unwrap(); .await
.unwrap()
.unwrap();
let router = Arc::from(router); let router = Arc::from(router);
if !workers.is_empty() { if !workers.is_empty() {
......
...@@ -181,7 +181,10 @@ mod test_pd_routing { ...@@ -181,7 +181,10 @@ mod test_pd_routing {
}; };
// Router creation will fail due to health checks, but config should be valid // Router creation will fail due to health checks, but config should be valid
let result = RouterFactory::create_router(&config); let app_context =
sglang_router_rs::server::AppContext::new(config, reqwest::Client::new(), 64);
let app_context = std::sync::Arc::new(app_context);
let result = RouterFactory::create_router(&app_context);
assert!(result.is_err()); assert!(result.is_err());
let error_msg = result.unwrap_err(); let error_msg = result.unwrap_err();
// Error should be about health/timeout, not configuration // Error should be about health/timeout, not configuration
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment