Unverified Commit 53430588 authored by Simo Lin's avatar Simo Lin Committed by GitHub
Browse files

[router] grpc router bootstraps (#9759)

parent fce7ae33
...@@ -7,7 +7,9 @@ use sglang_router_rs::protocols::spec::{ ...@@ -7,7 +7,9 @@ use sglang_router_rs::protocols::spec::{
ChatCompletionRequest, ChatMessage, CompletionRequest, GenerateParameters, GenerateRequest, ChatCompletionRequest, ChatMessage, CompletionRequest, GenerateParameters, GenerateRequest,
SamplingParams, StringOrArray, UserMessageContent, SamplingParams, StringOrArray, UserMessageContent,
}; };
use sglang_router_rs::routers::pd_types::{generate_room_id, get_hostname, RequestWithBootstrap}; use sglang_router_rs::routers::http::pd_types::{
generate_room_id, get_hostname, RequestWithBootstrap,
};
fn create_test_worker() -> BasicWorker { fn create_test_worker() -> BasicWorker {
BasicWorker::new( BasicWorker::new(
......
...@@ -19,6 +19,6 @@ pub use circuit_breaker::{ ...@@ -19,6 +19,6 @@ pub use circuit_breaker::{
pub use error::{WorkerError, WorkerResult}; pub use error::{WorkerError, WorkerResult};
pub use retry::{is_retryable_status, BackoffCalculator, RetryError, RetryExecutor}; pub use retry::{is_retryable_status, BackoffCalculator, RetryError, RetryExecutor};
pub use worker::{ pub use worker::{
start_health_checker, BasicWorker, DPAwareWorker, HealthChecker, HealthConfig, Worker, start_health_checker, BasicWorker, ConnectionMode, DPAwareWorker, HealthChecker, HealthConfig,
WorkerCollection, WorkerFactory, WorkerLoadGuard, WorkerType, Worker, WorkerCollection, WorkerFactory, WorkerLoadGuard, WorkerType,
}; };
...@@ -24,6 +24,9 @@ pub trait Worker: Send + Sync + fmt::Debug { ...@@ -24,6 +24,9 @@ pub trait Worker: Send + Sync + fmt::Debug {
/// Get the worker's type (Regular, Prefill, or Decode) /// Get the worker's type (Regular, Prefill, or Decode)
fn worker_type(&self) -> WorkerType; fn worker_type(&self) -> WorkerType;
/// Get the worker's connection mode (HTTP or gRPC)
fn connection_mode(&self) -> ConnectionMode;
/// Check if the worker is currently healthy /// Check if the worker is currently healthy
fn is_healthy(&self) -> bool; fn is_healthy(&self) -> bool;
...@@ -152,6 +155,30 @@ pub trait Worker: Send + Sync + fmt::Debug { ...@@ -152,6 +155,30 @@ pub trait Worker: Send + Sync + fmt::Debug {
} }
} }
/// Connection mode for worker communication
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum ConnectionMode {
/// HTTP/REST connection
Http,
/// gRPC connection
Grpc {
/// Optional port for gRPC endpoint (if different from URL)
port: Option<u16>,
},
}
impl fmt::Display for ConnectionMode {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
ConnectionMode::Http => write!(f, "HTTP"),
ConnectionMode::Grpc { port } => match port {
Some(p) => write!(f, "gRPC(port:{})", p),
None => write!(f, "gRPC"),
},
}
}
}
/// Worker type classification /// Worker type classification
#[derive(Debug, Clone, PartialEq, Eq, Hash)] #[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum WorkerType { pub enum WorkerType {
...@@ -213,6 +240,8 @@ pub struct WorkerMetadata { ...@@ -213,6 +240,8 @@ pub struct WorkerMetadata {
pub url: String, pub url: String,
/// Worker type /// Worker type
pub worker_type: WorkerType, pub worker_type: WorkerType,
/// Connection mode
pub connection_mode: ConnectionMode,
/// Additional labels/tags /// Additional labels/tags
pub labels: std::collections::HashMap<String, String>, pub labels: std::collections::HashMap<String, String>,
/// Health check configuration /// Health check configuration
...@@ -233,9 +262,18 @@ pub struct BasicWorker { ...@@ -233,9 +262,18 @@ pub struct BasicWorker {
impl BasicWorker { impl BasicWorker {
pub fn new(url: String, worker_type: WorkerType) -> Self { pub fn new(url: String, worker_type: WorkerType) -> Self {
Self::with_connection_mode(url, worker_type, ConnectionMode::Http)
}
pub fn with_connection_mode(
url: String,
worker_type: WorkerType,
connection_mode: ConnectionMode,
) -> Self {
let metadata = WorkerMetadata { let metadata = WorkerMetadata {
url: url.clone(), url: url.clone(),
worker_type, worker_type,
connection_mode,
labels: std::collections::HashMap::new(), labels: std::collections::HashMap::new(),
health_config: HealthConfig::default(), health_config: HealthConfig::default(),
}; };
...@@ -298,6 +336,10 @@ impl Worker for BasicWorker { ...@@ -298,6 +336,10 @@ impl Worker for BasicWorker {
self.metadata.worker_type.clone() self.metadata.worker_type.clone()
} }
fn connection_mode(&self) -> ConnectionMode {
self.metadata.connection_mode.clone()
}
fn is_healthy(&self) -> bool { fn is_healthy(&self) -> bool {
self.healthy.load(Ordering::Acquire) self.healthy.load(Ordering::Acquire)
} }
...@@ -434,6 +476,10 @@ impl Worker for DPAwareWorker { ...@@ -434,6 +476,10 @@ impl Worker for DPAwareWorker {
self.base_worker.worker_type() self.base_worker.worker_type()
} }
fn connection_mode(&self) -> ConnectionMode {
self.base_worker.connection_mode()
}
fn is_healthy(&self) -> bool { fn is_healthy(&self) -> bool {
self.base_worker.is_healthy() self.base_worker.is_healthy()
} }
...@@ -603,6 +649,28 @@ impl WorkerFactory { ...@@ -603,6 +649,28 @@ impl WorkerFactory {
(regular_workers, prefill_workers, decode_workers) (regular_workers, prefill_workers, decode_workers)
} }
/// Create a gRPC worker
pub fn create_grpc(url: String, worker_type: WorkerType, port: Option<u16>) -> Box<dyn Worker> {
Box::new(BasicWorker::with_connection_mode(
url,
worker_type,
ConnectionMode::Grpc { port },
))
}
/// Create a gRPC worker with custom circuit breaker configuration
pub fn create_grpc_with_config(
url: String,
worker_type: WorkerType,
port: Option<u16>,
circuit_breaker_config: CircuitBreakerConfig,
) -> Box<dyn Worker> {
Box::new(
BasicWorker::with_connection_mode(url, worker_type, ConnectionMode::Grpc { port })
.with_circuit_breaker_config(circuit_breaker_config),
)
}
/// Create a DP-aware worker of specified type /// Create a DP-aware worker of specified type
pub fn create_dp_aware( pub fn create_dp_aware(
base_url: String, base_url: String,
......
//! Factory for creating router instances //! Factory for creating router instances
use super::{pd_router::PDRouter, router::Router, RouterTrait}; use super::{
http::{pd_router::PDRouter, router::Router},
RouterTrait,
};
use crate::config::{PolicyConfig, RoutingMode}; use crate::config::{PolicyConfig, RoutingMode};
use crate::policies::PolicyFactory; use crate::policies::PolicyFactory;
use crate::server::AppContext; use crate::server::AppContext;
...@@ -17,7 +20,9 @@ impl RouterFactory { ...@@ -17,7 +20,9 @@ impl RouterFactory {
return Self::create_igw_router(ctx).await; return Self::create_igw_router(ctx).await;
} }
// Default to proxy mode // TODO: Add gRPC mode check here when implementing gRPC support
// Default to HTTP proxy mode
match &ctx.router_config.mode { match &ctx.router_config.mode {
RoutingMode::Regular { worker_urls } => { RoutingMode::Regular { worker_urls } => {
Self::create_regular_router(worker_urls, &ctx.router_config.policy, ctx).await Self::create_regular_router(worker_urls, &ctx.router_config.policy, ctx).await
...@@ -101,6 +106,29 @@ impl RouterFactory { ...@@ -101,6 +106,29 @@ impl RouterFactory {
Ok(Box::new(router)) Ok(Box::new(router))
} }
/// Create a gRPC router with injected policy
pub async fn create_grpc_router(
_worker_urls: &[String],
_policy_config: &PolicyConfig,
_ctx: &Arc<AppContext>,
) -> Result<Box<dyn RouterTrait>, String> {
// For now, return an error as gRPC router is not yet implemented
Err("gRPC router is not yet implemented".to_string())
}
/// Create a gRPC PD router (placeholder for now)
pub async fn create_grpc_pd_router(
_prefill_urls: &[(String, Option<u16>)],
_decode_urls: &[String],
_prefill_policy_config: Option<&PolicyConfig>,
_decode_policy_config: Option<&PolicyConfig>,
_main_policy_config: &PolicyConfig,
_ctx: &Arc<AppContext>,
) -> Result<Box<dyn RouterTrait>, String> {
// For now, return an error as gRPC PD router is not yet implemented
Err("gRPC PD router is not yet implemented".to_string())
}
/// Create an IGW router (placeholder for future implementation) /// Create an IGW router (placeholder for future implementation)
async fn create_igw_router(_ctx: &Arc<AppContext>) -> Result<Box<dyn RouterTrait>, String> { async fn create_igw_router(_ctx: &Arc<AppContext>) -> Result<Box<dyn RouterTrait>, String> {
// For now, return an error indicating IGW is not yet implemented // For now, return an error indicating IGW is not yet implemented
......
//! gRPC router implementations
pub mod pd_router;
pub mod router;
// PD (Prefill-Decode) gRPC Router Implementation
// TODO: Implement gRPC-based PD router for disaggregated prefill-decode systems
use crate::routers::{RouterTrait, WorkerManagement};
use async_trait::async_trait;
use axum::{
body::Body,
extract::Request,
http::{HeaderMap, StatusCode},
response::{IntoResponse, Response},
};
/// Placeholder for gRPC PD router
#[derive(Debug)]
pub struct GrpcPDRouter;
impl GrpcPDRouter {
pub async fn new() -> Result<Self, String> {
// TODO: Implement gRPC PD router initialization
Err("gRPC PD router not yet implemented".to_string())
}
}
#[async_trait]
impl RouterTrait for GrpcPDRouter {
fn as_any(&self) -> &dyn std::any::Any {
self
}
async fn health(&self, _req: Request<Body>) -> Response {
(StatusCode::NOT_IMPLEMENTED).into_response()
}
async fn health_generate(&self, _req: Request<Body>) -> Response {
(StatusCode::NOT_IMPLEMENTED).into_response()
}
async fn get_server_info(&self, _req: Request<Body>) -> Response {
(StatusCode::NOT_IMPLEMENTED).into_response()
}
async fn get_models(&self, _req: Request<Body>) -> Response {
(StatusCode::NOT_IMPLEMENTED).into_response()
}
async fn get_model_info(&self, _req: Request<Body>) -> Response {
(StatusCode::NOT_IMPLEMENTED).into_response()
}
async fn route_generate(
&self,
_headers: Option<&HeaderMap>,
_body: &crate::protocols::spec::GenerateRequest,
) -> Response {
(StatusCode::NOT_IMPLEMENTED).into_response()
}
async fn route_chat(
&self,
_headers: Option<&HeaderMap>,
_body: &crate::protocols::spec::ChatCompletionRequest,
) -> Response {
(StatusCode::NOT_IMPLEMENTED).into_response()
}
async fn route_completion(
&self,
_headers: Option<&HeaderMap>,
_body: &crate::protocols::spec::CompletionRequest,
) -> Response {
(StatusCode::NOT_IMPLEMENTED).into_response()
}
async fn route_embeddings(&self, _headers: Option<&HeaderMap>, _body: Body) -> Response {
(StatusCode::NOT_IMPLEMENTED).into_response()
}
async fn route_rerank(&self, _headers: Option<&HeaderMap>, _body: Body) -> Response {
(StatusCode::NOT_IMPLEMENTED).into_response()
}
async fn flush_cache(&self) -> Response {
(StatusCode::NOT_IMPLEMENTED).into_response()
}
async fn get_worker_loads(&self) -> Response {
(StatusCode::NOT_IMPLEMENTED).into_response()
}
fn router_type(&self) -> &'static str {
"grpc_pd"
}
fn readiness(&self) -> Response {
(StatusCode::SERVICE_UNAVAILABLE).into_response()
}
}
#[async_trait]
impl WorkerManagement for GrpcPDRouter {
async fn add_worker(&self, _worker_url: &str) -> Result<String, String> {
Err("Not implemented".to_string())
}
fn remove_worker(&self, _worker_url: &str) {}
fn get_worker_urls(&self) -> Vec<String> {
vec![]
}
}
// gRPC Router Implementation
// TODO: Implement gRPC-based router
use crate::routers::{RouterTrait, WorkerManagement};
use async_trait::async_trait;
use axum::{
body::Body,
extract::Request,
http::{HeaderMap, StatusCode},
response::{IntoResponse, Response},
};
/// Placeholder for gRPC router
#[derive(Debug)]
pub struct GrpcRouter;
impl GrpcRouter {
pub async fn new() -> Result<Self, String> {
// TODO: Implement gRPC router initialization
Err("gRPC router not yet implemented".to_string())
}
}
#[async_trait]
impl RouterTrait for GrpcRouter {
fn as_any(&self) -> &dyn std::any::Any {
self
}
async fn health(&self, _req: Request<Body>) -> Response {
(StatusCode::NOT_IMPLEMENTED).into_response()
}
async fn health_generate(&self, _req: Request<Body>) -> Response {
(StatusCode::NOT_IMPLEMENTED).into_response()
}
async fn get_server_info(&self, _req: Request<Body>) -> Response {
(StatusCode::NOT_IMPLEMENTED).into_response()
}
async fn get_models(&self, _req: Request<Body>) -> Response {
(StatusCode::NOT_IMPLEMENTED).into_response()
}
async fn get_model_info(&self, _req: Request<Body>) -> Response {
(StatusCode::NOT_IMPLEMENTED).into_response()
}
async fn route_generate(
&self,
_headers: Option<&HeaderMap>,
_body: &crate::protocols::spec::GenerateRequest,
) -> Response {
(StatusCode::NOT_IMPLEMENTED).into_response()
}
async fn route_chat(
&self,
_headers: Option<&HeaderMap>,
_body: &crate::protocols::spec::ChatCompletionRequest,
) -> Response {
(StatusCode::NOT_IMPLEMENTED).into_response()
}
async fn route_completion(
&self,
_headers: Option<&HeaderMap>,
_body: &crate::protocols::spec::CompletionRequest,
) -> Response {
(StatusCode::NOT_IMPLEMENTED).into_response()
}
async fn route_embeddings(&self, _headers: Option<&HeaderMap>, _body: Body) -> Response {
(StatusCode::NOT_IMPLEMENTED).into_response()
}
async fn route_rerank(&self, _headers: Option<&HeaderMap>, _body: Body) -> Response {
(StatusCode::NOT_IMPLEMENTED).into_response()
}
async fn flush_cache(&self) -> Response {
(StatusCode::NOT_IMPLEMENTED).into_response()
}
async fn get_worker_loads(&self) -> Response {
(StatusCode::NOT_IMPLEMENTED).into_response()
}
fn router_type(&self) -> &'static str {
"grpc"
}
fn readiness(&self) -> Response {
(StatusCode::SERVICE_UNAVAILABLE).into_response()
}
}
#[async_trait]
impl WorkerManagement for GrpcRouter {
async fn add_worker(&self, _worker_url: &str) -> Result<String, String> {
Err("Not implemented".to_string())
}
fn remove_worker(&self, _worker_url: &str) {}
fn get_worker_urls(&self) -> Vec<String> {
vec![]
}
}
//! HTTP router implementations
pub mod pd_router;
pub mod pd_types;
pub mod router;
// PD (Prefill-Decode) Router Implementation // PD (Prefill-Decode) Router Implementation
// This module handles routing for disaggregated prefill-decode systems // This module handles routing for disaggregated prefill-decode systems
use super::header_utils;
use super::pd_types::{api_path, PDRouterError}; use super::pd_types::{api_path, PDRouterError};
use crate::config::types::{ use crate::config::types::{
CircuitBreakerConfig as ConfigCircuitBreakerConfig, CircuitBreakerConfig as ConfigCircuitBreakerConfig,
...@@ -16,6 +15,7 @@ use crate::protocols::spec::{ ...@@ -16,6 +15,7 @@ use crate::protocols::spec::{
ChatCompletionRequest, ChatMessage, CompletionRequest, GenerateRequest, StringOrArray, ChatCompletionRequest, ChatMessage, CompletionRequest, GenerateRequest, StringOrArray,
UserMessageContent, UserMessageContent,
}; };
use crate::routers::header_utils;
use crate::routers::{RouterTrait, WorkerManagement}; use crate::routers::{RouterTrait, WorkerManagement};
use async_trait::async_trait; use async_trait::async_trait;
use axum::{ use axum::{
...@@ -72,7 +72,7 @@ impl PDRouter { ...@@ -72,7 +72,7 @@ impl PDRouter {
// Private helper method to perform health check on a new server // Private helper method to perform health check on a new server
async fn wait_for_server_health(&self, url: &str) -> Result<(), PDRouterError> { async fn wait_for_server_health(&self, url: &str) -> Result<(), PDRouterError> {
crate::routers::router::Router::wait_for_healthy_workers( crate::routers::http::router::Router::wait_for_healthy_workers(
&[url.to_string()], &[url.to_string()],
self.timeout_secs, self.timeout_secs,
self.interval_secs, self.interval_secs,
...@@ -435,7 +435,7 @@ impl PDRouter { ...@@ -435,7 +435,7 @@ impl PDRouter {
.map(|worker| worker.url().to_string()) .map(|worker| worker.url().to_string())
.collect(); .collect();
if !all_urls.is_empty() { if !all_urls.is_empty() {
crate::routers::router::Router::wait_for_healthy_workers( crate::routers::http::router::Router::wait_for_healthy_workers(
&all_urls, &all_urls,
timeout_secs, timeout_secs,
interval_secs, interval_secs,
...@@ -1935,6 +1935,14 @@ impl RouterTrait for PDRouter { ...@@ -1935,6 +1935,14 @@ impl RouterTrait for PDRouter {
self.execute_dual_dispatch(headers, body, context).await self.execute_dual_dispatch(headers, body, context).await
} }
async fn route_embeddings(&self, _headers: Option<&HeaderMap>, _body: Body) -> Response {
todo!()
}
async fn route_rerank(&self, _headers: Option<&HeaderMap>, _body: Body) -> Response {
todo!()
}
async fn flush_cache(&self) -> Response { async fn flush_cache(&self) -> Response {
// Process both prefill and decode workers // Process both prefill and decode workers
let (prefill_results, prefill_errors) = self let (prefill_results, prefill_errors) = self
...@@ -2040,7 +2048,7 @@ impl RouterTrait for PDRouter { ...@@ -2040,7 +2048,7 @@ impl RouterTrait for PDRouter {
let total_decode = self.decode_workers.read().unwrap().len(); let total_decode = self.decode_workers.read().unwrap().len();
if healthy_prefill_count > 0 && healthy_decode_count > 0 { if healthy_prefill_count > 0 && healthy_decode_count > 0 {
Json(serde_json::json!({ Json(json!({
"status": "ready", "status": "ready",
"prefill": { "prefill": {
"healthy": healthy_prefill_count, "healthy": healthy_prefill_count,
......
use super::header_utils;
use crate::config::types::{ use crate::config::types::{
CircuitBreakerConfig as ConfigCircuitBreakerConfig, CircuitBreakerConfig as ConfigCircuitBreakerConfig,
HealthCheckConfig as ConfigHealthCheckConfig, RetryConfig, HealthCheckConfig as ConfigHealthCheckConfig, RetryConfig,
...@@ -12,6 +11,7 @@ use crate::policies::LoadBalancingPolicy; ...@@ -12,6 +11,7 @@ use crate::policies::LoadBalancingPolicy;
use crate::protocols::spec::{ use crate::protocols::spec::{
ChatCompletionRequest, CompletionRequest, GenerateRequest, GenerationRequest, ChatCompletionRequest, CompletionRequest, GenerateRequest, GenerationRequest,
}; };
use crate::routers::header_utils;
use crate::routers::{RouterTrait, WorkerManagement}; use crate::routers::{RouterTrait, WorkerManagement};
use axum::{ use axum::{
body::Body, body::Body,
...@@ -393,7 +393,7 @@ impl Router { ...@@ -393,7 +393,7 @@ impl Router {
// Helper method to proxy GET requests to the first available worker // Helper method to proxy GET requests to the first available worker
async fn proxy_get_request(&self, req: Request<Body>, endpoint: &str) -> Response { async fn proxy_get_request(&self, req: Request<Body>, endpoint: &str) -> Response {
let headers = super::header_utils::copy_request_headers(&req); let headers = header_utils::copy_request_headers(&req);
match self.select_first_worker() { match self.select_first_worker() {
Ok(worker_url) => { Ok(worker_url) => {
...@@ -667,7 +667,7 @@ impl Router { ...@@ -667,7 +667,7 @@ impl Router {
if !is_stream { if !is_stream {
// For non-streaming requests, preserve headers // For non-streaming requests, preserve headers
let response_headers = super::header_utils::preserve_response_headers(res.headers()); let response_headers = header_utils::preserve_response_headers(res.headers());
let response = match res.bytes().await { let response = match res.bytes().await {
Ok(body) => { Ok(body) => {
...@@ -1198,6 +1198,14 @@ impl RouterTrait for Router { ...@@ -1198,6 +1198,14 @@ impl RouterTrait for Router {
.await .await
} }
async fn route_embeddings(&self, _headers: Option<&HeaderMap>, _body: Body) -> Response {
todo!()
}
async fn route_rerank(&self, _headers: Option<&HeaderMap>, _body: Body) -> Response {
todo!()
}
async fn flush_cache(&self) -> Response { async fn flush_cache(&self) -> Response {
// Get all worker URLs // Get all worker URLs
let worker_urls = self.get_worker_urls(); let worker_urls = self.get_worker_urls();
......
...@@ -12,10 +12,9 @@ use std::fmt::Debug; ...@@ -12,10 +12,9 @@ use std::fmt::Debug;
use crate::protocols::spec::{ChatCompletionRequest, CompletionRequest, GenerateRequest}; use crate::protocols::spec::{ChatCompletionRequest, CompletionRequest, GenerateRequest};
pub mod factory; pub mod factory;
pub mod grpc;
pub mod header_utils; pub mod header_utils;
pub mod pd_router; pub mod http;
pub mod pd_types;
pub mod router;
pub use factory::RouterFactory; pub use factory::RouterFactory;
...@@ -77,6 +76,10 @@ pub trait RouterTrait: Send + Sync + Debug + WorkerManagement { ...@@ -77,6 +76,10 @@ pub trait RouterTrait: Send + Sync + Debug + WorkerManagement {
body: &CompletionRequest, body: &CompletionRequest,
) -> Response; ) -> Response;
async fn route_embeddings(&self, headers: Option<&HeaderMap>, body: Body) -> Response;
async fn route_rerank(&self, headers: Option<&HeaderMap>, body: Body) -> Response;
/// Flush cache on all workers /// Flush cache on all workers
async fn flush_cache(&self) -> Response; async fn flush_cache(&self) -> Response;
......
...@@ -383,7 +383,7 @@ async fn handle_pod_event( ...@@ -383,7 +383,7 @@ async fn handle_pod_event(
// Handle PD mode with specific pod types // Handle PD mode with specific pod types
let result = if pd_mode && pod_info.pod_type.is_some() { let result = if pd_mode && pod_info.pod_type.is_some() {
// Need to import PDRouter type // Need to import PDRouter type
use crate::routers::pd_router::PDRouter; use crate::routers::http::pd_router::PDRouter;
// Try to downcast to PDRouter // Try to downcast to PDRouter
if let Some(pd_router) = router.as_any().downcast_ref::<PDRouter>() { if let Some(pd_router) = router.as_any().downcast_ref::<PDRouter>() {
...@@ -453,7 +453,7 @@ async fn handle_pod_deletion( ...@@ -453,7 +453,7 @@ async fn handle_pod_deletion(
// Handle PD mode removal // Handle PD mode removal
if pd_mode && pod_info.pod_type.is_some() { if pd_mode && pod_info.pod_type.is_some() {
use crate::routers::pd_router::PDRouter; use crate::routers::http::pd_router::PDRouter;
// Try to downcast to PDRouter for PD-specific removal // Try to downcast to PDRouter for PD-specific removal
if let Some(pd_router) = router.as_any().downcast_ref::<PDRouter>() { if let Some(pd_router) = router.as_any().downcast_ref::<PDRouter>() {
...@@ -581,7 +581,7 @@ mod tests { ...@@ -581,7 +581,7 @@ mod tests {
async fn create_test_router() -> Arc<dyn RouterTrait> { async fn create_test_router() -> Arc<dyn RouterTrait> {
use crate::config::PolicyConfig; use crate::config::PolicyConfig;
use crate::policies::PolicyFactory; use crate::policies::PolicyFactory;
use crate::routers::router::Router; use crate::routers::http::router::Router;
let policy = PolicyFactory::create_from_config(&PolicyConfig::Random); let policy = PolicyFactory::create_from_config(&PolicyConfig::Random);
let router = Router::new( let router = Router::new(
......
...@@ -5,8 +5,8 @@ mod test_pd_routing { ...@@ -5,8 +5,8 @@ mod test_pd_routing {
CircuitBreakerConfig, PolicyConfig, RetryConfig, RouterConfig, RoutingMode, CircuitBreakerConfig, PolicyConfig, RetryConfig, RouterConfig, RoutingMode,
}; };
use sglang_router_rs::core::{WorkerFactory, WorkerType}; use sglang_router_rs::core::{WorkerFactory, WorkerType};
use sglang_router_rs::routers::pd_types::get_hostname; use sglang_router_rs::routers::http::pd_types::get_hostname;
use sglang_router_rs::routers::pd_types::PDSelectionPolicy; use sglang_router_rs::routers::http::pd_types::PDSelectionPolicy;
use sglang_router_rs::routers::RouterFactory; use sglang_router_rs::routers::RouterFactory;
// Test-only struct to help validate PD request parsing // Test-only struct to help validate PD request parsing
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment