Unverified Commit 53430588 authored by Simo Lin's avatar Simo Lin Committed by GitHub
Browse files

[router] grpc router bootstraps (#9759)

parent fce7ae33
......@@ -7,7 +7,9 @@ use sglang_router_rs::protocols::spec::{
ChatCompletionRequest, ChatMessage, CompletionRequest, GenerateParameters, GenerateRequest,
SamplingParams, StringOrArray, UserMessageContent,
};
use sglang_router_rs::routers::pd_types::{generate_room_id, get_hostname, RequestWithBootstrap};
use sglang_router_rs::routers::http::pd_types::{
generate_room_id, get_hostname, RequestWithBootstrap,
};
fn create_test_worker() -> BasicWorker {
BasicWorker::new(
......
......@@ -19,6 +19,6 @@ pub use circuit_breaker::{
pub use error::{WorkerError, WorkerResult};
pub use retry::{is_retryable_status, BackoffCalculator, RetryError, RetryExecutor};
pub use worker::{
start_health_checker, BasicWorker, DPAwareWorker, HealthChecker, HealthConfig, Worker,
WorkerCollection, WorkerFactory, WorkerLoadGuard, WorkerType,
start_health_checker, BasicWorker, ConnectionMode, DPAwareWorker, HealthChecker, HealthConfig,
Worker, WorkerCollection, WorkerFactory, WorkerLoadGuard, WorkerType,
};
......@@ -24,6 +24,9 @@ pub trait Worker: Send + Sync + fmt::Debug {
/// Get the worker's type (Regular, Prefill, or Decode)
fn worker_type(&self) -> WorkerType;
/// Get the worker's connection mode (HTTP or gRPC)
fn connection_mode(&self) -> ConnectionMode;
/// Check if the worker is currently healthy
fn is_healthy(&self) -> bool;
......@@ -152,6 +155,30 @@ pub trait Worker: Send + Sync + fmt::Debug {
}
}
/// Transport a worker is reached over.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum ConnectionMode {
    /// Plain HTTP/REST transport.
    Http,
    /// gRPC transport.
    Grpc {
        /// Port of the gRPC endpoint, when it differs from the one in the worker URL.
        port: Option<u16>,
    },
}

impl fmt::Display for ConnectionMode {
    /// Render the mode as `HTTP`, `gRPC`, or `gRPC(port:<n>)` when a port is set.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Http => f.write_str("HTTP"),
            Self::Grpc { port: None } => f.write_str("gRPC"),
            Self::Grpc { port: Some(p) } => write!(f, "gRPC(port:{})", p),
        }
    }
}
/// Worker type classification
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum WorkerType {
......@@ -213,6 +240,8 @@ pub struct WorkerMetadata {
pub url: String,
/// Worker type
pub worker_type: WorkerType,
/// Connection mode
pub connection_mode: ConnectionMode,
/// Additional labels/tags
pub labels: std::collections::HashMap<String, String>,
/// Health check configuration
......@@ -233,9 +262,18 @@ pub struct BasicWorker {
impl BasicWorker {
/// Create a worker whose connection mode is [`ConnectionMode::Http`].
///
/// Shorthand for `with_connection_mode` with the HTTP mode filled in.
pub fn new(url: String, worker_type: WorkerType) -> Self {
    let mode = ConnectionMode::Http;
    Self::with_connection_mode(url, worker_type, mode)
}
pub fn with_connection_mode(
url: String,
worker_type: WorkerType,
connection_mode: ConnectionMode,
) -> Self {
let metadata = WorkerMetadata {
url: url.clone(),
worker_type,
connection_mode,
labels: std::collections::HashMap::new(),
health_config: HealthConfig::default(),
};
......@@ -298,6 +336,10 @@ impl Worker for BasicWorker {
self.metadata.worker_type.clone()
}
/// Return a clone of the connection mode stored in this worker's metadata.
fn connection_mode(&self) -> ConnectionMode {
    let metadata = &self.metadata;
    metadata.connection_mode.clone()
}
/// Read the worker's `healthy` flag using `Acquire` ordering.
fn is_healthy(&self) -> bool {
    let flag = &self.healthy;
    flag.load(Ordering::Acquire)
}
......@@ -434,6 +476,10 @@ impl Worker for DPAwareWorker {
self.base_worker.worker_type()
}
/// Report the connection mode of the wrapped base worker.
fn connection_mode(&self) -> ConnectionMode {
    let inner = &self.base_worker;
    inner.connection_mode()
}
/// Report the health of the wrapped base worker.
fn is_healthy(&self) -> bool {
    let inner = &self.base_worker;
    inner.is_healthy()
}
......@@ -603,6 +649,28 @@ impl WorkerFactory {
(regular_workers, prefill_workers, decode_workers)
}
/// Build a boxed [`BasicWorker`] whose connection mode is gRPC.
///
/// `port` is stored in [`ConnectionMode::Grpc`] as given; `None` leaves it unset.
pub fn create_grpc(url: String, worker_type: WorkerType, port: Option<u16>) -> Box<dyn Worker> {
    let mode = ConnectionMode::Grpc { port };
    let worker = BasicWorker::with_connection_mode(url, worker_type, mode);
    Box::new(worker)
}
/// Build a boxed gRPC [`BasicWorker`] that uses the supplied circuit breaker configuration.
///
/// `port` is stored in [`ConnectionMode::Grpc`] as given; `circuit_breaker_config`
/// is applied through `with_circuit_breaker_config` before the worker is boxed.
pub fn create_grpc_with_config(
    url: String,
    worker_type: WorkerType,
    port: Option<u16>,
    circuit_breaker_config: CircuitBreakerConfig,
) -> Box<dyn Worker> {
    let mode = ConnectionMode::Grpc { port };
    let worker = BasicWorker::with_connection_mode(url, worker_type, mode)
        .with_circuit_breaker_config(circuit_breaker_config);
    Box::new(worker)
}
/// Create a DP-aware worker of specified type
pub fn create_dp_aware(
base_url: String,
......
//! Factory for creating router instances
use super::{pd_router::PDRouter, router::Router, RouterTrait};
use super::{
http::{pd_router::PDRouter, router::Router},
RouterTrait,
};
use crate::config::{PolicyConfig, RoutingMode};
use crate::policies::PolicyFactory;
use crate::server::AppContext;
......@@ -17,7 +20,9 @@ impl RouterFactory {
return Self::create_igw_router(ctx).await;
}
// Default to proxy mode
// TODO: Add gRPC mode check here when implementing gRPC support
// Default to HTTP proxy mode
match &ctx.router_config.mode {
RoutingMode::Regular { worker_urls } => {
Self::create_regular_router(worker_urls, &ctx.router_config.policy, ctx).await
......@@ -101,6 +106,29 @@ impl RouterFactory {
Ok(Box::new(router))
}
/// Placeholder entry point for building a gRPC router.
///
/// Every argument is currently ignored.
///
/// # Errors
/// Always returns an error stating that the gRPC router is not implemented.
pub async fn create_grpc_router(
    _worker_urls: &[String],
    _policy_config: &PolicyConfig,
    _ctx: &Arc<AppContext>,
) -> Result<Box<dyn RouterTrait>, String> {
    // No gRPC router exists yet, so report that to the caller.
    let reason = String::from("gRPC router is not yet implemented");
    Err(reason)
}
/// Placeholder entry point for building a gRPC PD router.
///
/// Every argument is currently ignored.
///
/// # Errors
/// Always returns an error stating that the gRPC PD router is not implemented.
pub async fn create_grpc_pd_router(
    _prefill_urls: &[(String, Option<u16>)],
    _decode_urls: &[String],
    _prefill_policy_config: Option<&PolicyConfig>,
    _decode_policy_config: Option<&PolicyConfig>,
    _main_policy_config: &PolicyConfig,
    _ctx: &Arc<AppContext>,
) -> Result<Box<dyn RouterTrait>, String> {
    // No gRPC PD router exists yet, so report that to the caller.
    let reason = String::from("gRPC PD router is not yet implemented");
    Err(reason)
}
/// Create an IGW router (placeholder for future implementation)
async fn create_igw_router(_ctx: &Arc<AppContext>) -> Result<Box<dyn RouterTrait>, String> {
// For now, return an error indicating IGW is not yet implemented
......
//! gRPC router implementations
pub mod pd_router;
pub mod router;
// PD (Prefill-Decode) gRPC Router Implementation
// TODO: Implement gRPC-based PD router for disaggregated prefill-decode systems
use crate::routers::{RouterTrait, WorkerManagement};
use async_trait::async_trait;
use axum::{
body::Body,
extract::Request,
http::{HeaderMap, StatusCode},
response::{IntoResponse, Response},
};
/// Stand-in type for the gRPC PD (prefill-decode) router, which is not implemented yet.
#[derive(Debug)]
pub struct GrpcPDRouter;

impl GrpcPDRouter {
    /// Attempt to construct the router.
    ///
    /// # Errors
    /// Always fails with a message saying the gRPC PD router is not implemented.
    pub async fn new() -> Result<Self, String> {
        // TODO: Implement gRPC PD router initialization
        let reason = String::from("gRPC PD router not yet implemented");
        Err(reason)
    }
}
/// Build the `501 Not Implemented` reply shared by the placeholder endpoints below.
fn grpc_pd_not_implemented() -> Response {
    StatusCode::NOT_IMPLEMENTED.into_response()
}

/// Placeholder `RouterTrait` implementation: every routing and info endpoint
/// answers `501 Not Implemented`, and readiness answers `503 Service Unavailable`.
#[async_trait]
impl RouterTrait for GrpcPDRouter {
    /// Expose the router as `Any` so callers can downcast to the concrete type.
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    /// Placeholder: replies `501 Not Implemented`.
    async fn health(&self, _req: Request<Body>) -> Response {
        grpc_pd_not_implemented()
    }

    /// Placeholder: replies `501 Not Implemented`.
    async fn health_generate(&self, _req: Request<Body>) -> Response {
        grpc_pd_not_implemented()
    }

    /// Placeholder: replies `501 Not Implemented`.
    async fn get_server_info(&self, _req: Request<Body>) -> Response {
        grpc_pd_not_implemented()
    }

    /// Placeholder: replies `501 Not Implemented`.
    async fn get_models(&self, _req: Request<Body>) -> Response {
        grpc_pd_not_implemented()
    }

    /// Placeholder: replies `501 Not Implemented`.
    async fn get_model_info(&self, _req: Request<Body>) -> Response {
        grpc_pd_not_implemented()
    }

    /// Placeholder: ignores the request and replies `501 Not Implemented`.
    async fn route_generate(
        &self,
        _headers: Option<&HeaderMap>,
        _body: &crate::protocols::spec::GenerateRequest,
    ) -> Response {
        grpc_pd_not_implemented()
    }

    /// Placeholder: ignores the request and replies `501 Not Implemented`.
    async fn route_chat(
        &self,
        _headers: Option<&HeaderMap>,
        _body: &crate::protocols::spec::ChatCompletionRequest,
    ) -> Response {
        grpc_pd_not_implemented()
    }

    /// Placeholder: ignores the request and replies `501 Not Implemented`.
    async fn route_completion(
        &self,
        _headers: Option<&HeaderMap>,
        _body: &crate::protocols::spec::CompletionRequest,
    ) -> Response {
        grpc_pd_not_implemented()
    }

    /// Placeholder: ignores the request and replies `501 Not Implemented`.
    async fn route_embeddings(&self, _headers: Option<&HeaderMap>, _body: Body) -> Response {
        grpc_pd_not_implemented()
    }

    /// Placeholder: ignores the request and replies `501 Not Implemented`.
    async fn route_rerank(&self, _headers: Option<&HeaderMap>, _body: Body) -> Response {
        grpc_pd_not_implemented()
    }

    /// Placeholder: replies `501 Not Implemented`.
    async fn flush_cache(&self) -> Response {
        grpc_pd_not_implemented()
    }

    /// Placeholder: replies `501 Not Implemented`.
    async fn get_worker_loads(&self) -> Response {
        grpc_pd_not_implemented()
    }

    /// Identifier for this router kind.
    fn router_type(&self) -> &'static str {
        "grpc_pd"
    }

    /// Replies `503 Service Unavailable`; the placeholder never reports ready.
    fn readiness(&self) -> Response {
        StatusCode::SERVICE_UNAVAILABLE.into_response()
    }
}
/// Placeholder worker management: no workers can be added and none are tracked.
#[async_trait]
impl WorkerManagement for GrpcPDRouter {
    /// Always fails with `"Not implemented"`; the URL is ignored.
    async fn add_worker(&self, _worker_url: &str) -> Result<String, String> {
        Err(String::from("Not implemented"))
    }

    /// Does nothing; the URL is ignored.
    fn remove_worker(&self, _worker_url: &str) {}

    /// Always returns an empty list.
    fn get_worker_urls(&self) -> Vec<String> {
        Vec::new()
    }
}
// gRPC Router Implementation
// TODO: Implement gRPC-based router
use crate::routers::{RouterTrait, WorkerManagement};
use async_trait::async_trait;
use axum::{
body::Body,
extract::Request,
http::{HeaderMap, StatusCode},
response::{IntoResponse, Response},
};
/// Stand-in type for the gRPC router, which is not implemented yet.
#[derive(Debug)]
pub struct GrpcRouter;

impl GrpcRouter {
    /// Attempt to construct the router.
    ///
    /// # Errors
    /// Always fails with a message saying the gRPC router is not implemented.
    pub async fn new() -> Result<Self, String> {
        // TODO: Implement gRPC router initialization
        let reason = String::from("gRPC router not yet implemented");
        Err(reason)
    }
}
/// Build the `501 Not Implemented` reply shared by the placeholder endpoints below.
fn grpc_not_implemented() -> Response {
    StatusCode::NOT_IMPLEMENTED.into_response()
}

/// Placeholder `RouterTrait` implementation: every routing and info endpoint
/// answers `501 Not Implemented`, and readiness answers `503 Service Unavailable`.
#[async_trait]
impl RouterTrait for GrpcRouter {
    /// Expose the router as `Any` so callers can downcast to the concrete type.
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    /// Placeholder: replies `501 Not Implemented`.
    async fn health(&self, _req: Request<Body>) -> Response {
        grpc_not_implemented()
    }

    /// Placeholder: replies `501 Not Implemented`.
    async fn health_generate(&self, _req: Request<Body>) -> Response {
        grpc_not_implemented()
    }

    /// Placeholder: replies `501 Not Implemented`.
    async fn get_server_info(&self, _req: Request<Body>) -> Response {
        grpc_not_implemented()
    }

    /// Placeholder: replies `501 Not Implemented`.
    async fn get_models(&self, _req: Request<Body>) -> Response {
        grpc_not_implemented()
    }

    /// Placeholder: replies `501 Not Implemented`.
    async fn get_model_info(&self, _req: Request<Body>) -> Response {
        grpc_not_implemented()
    }

    /// Placeholder: ignores the request and replies `501 Not Implemented`.
    async fn route_generate(
        &self,
        _headers: Option<&HeaderMap>,
        _body: &crate::protocols::spec::GenerateRequest,
    ) -> Response {
        grpc_not_implemented()
    }

    /// Placeholder: ignores the request and replies `501 Not Implemented`.
    async fn route_chat(
        &self,
        _headers: Option<&HeaderMap>,
        _body: &crate::protocols::spec::ChatCompletionRequest,
    ) -> Response {
        grpc_not_implemented()
    }

    /// Placeholder: ignores the request and replies `501 Not Implemented`.
    async fn route_completion(
        &self,
        _headers: Option<&HeaderMap>,
        _body: &crate::protocols::spec::CompletionRequest,
    ) -> Response {
        grpc_not_implemented()
    }

    /// Placeholder: ignores the request and replies `501 Not Implemented`.
    async fn route_embeddings(&self, _headers: Option<&HeaderMap>, _body: Body) -> Response {
        grpc_not_implemented()
    }

    /// Placeholder: ignores the request and replies `501 Not Implemented`.
    async fn route_rerank(&self, _headers: Option<&HeaderMap>, _body: Body) -> Response {
        grpc_not_implemented()
    }

    /// Placeholder: replies `501 Not Implemented`.
    async fn flush_cache(&self) -> Response {
        grpc_not_implemented()
    }

    /// Placeholder: replies `501 Not Implemented`.
    async fn get_worker_loads(&self) -> Response {
        grpc_not_implemented()
    }

    /// Identifier for this router kind.
    fn router_type(&self) -> &'static str {
        "grpc"
    }

    /// Replies `503 Service Unavailable`; the placeholder never reports ready.
    fn readiness(&self) -> Response {
        StatusCode::SERVICE_UNAVAILABLE.into_response()
    }
}
/// Placeholder worker management: no workers can be added and none are tracked.
#[async_trait]
impl WorkerManagement for GrpcRouter {
    /// Always fails with `"Not implemented"`; the URL is ignored.
    async fn add_worker(&self, _worker_url: &str) -> Result<String, String> {
        Err(String::from("Not implemented"))
    }

    /// Does nothing; the URL is ignored.
    fn remove_worker(&self, _worker_url: &str) {}

    /// Always returns an empty list.
    fn get_worker_urls(&self) -> Vec<String> {
        Vec::new()
    }
}
//! HTTP router implementations
pub mod pd_router;
pub mod pd_types;
pub mod router;
// PD (Prefill-Decode) Router Implementation
// This module handles routing for disaggregated prefill-decode systems
use super::header_utils;
use super::pd_types::{api_path, PDRouterError};
use crate::config::types::{
CircuitBreakerConfig as ConfigCircuitBreakerConfig,
......@@ -16,6 +15,7 @@ use crate::protocols::spec::{
ChatCompletionRequest, ChatMessage, CompletionRequest, GenerateRequest, StringOrArray,
UserMessageContent,
};
use crate::routers::header_utils;
use crate::routers::{RouterTrait, WorkerManagement};
use async_trait::async_trait;
use axum::{
......@@ -72,7 +72,7 @@ impl PDRouter {
// Private helper method to perform health check on a new server
async fn wait_for_server_health(&self, url: &str) -> Result<(), PDRouterError> {
crate::routers::router::Router::wait_for_healthy_workers(
crate::routers::http::router::Router::wait_for_healthy_workers(
&[url.to_string()],
self.timeout_secs,
self.interval_secs,
......@@ -435,7 +435,7 @@ impl PDRouter {
.map(|worker| worker.url().to_string())
.collect();
if !all_urls.is_empty() {
crate::routers::router::Router::wait_for_healthy_workers(
crate::routers::http::router::Router::wait_for_healthy_workers(
&all_urls,
timeout_secs,
interval_secs,
......@@ -1935,6 +1935,14 @@ impl RouterTrait for PDRouter {
self.execute_dual_dispatch(headers, body, context).await
}
async fn route_embeddings(&self, _headers: Option<&HeaderMap>, _body: Body) -> Response {
todo!()
}
async fn route_rerank(&self, _headers: Option<&HeaderMap>, _body: Body) -> Response {
todo!()
}
async fn flush_cache(&self) -> Response {
// Process both prefill and decode workers
let (prefill_results, prefill_errors) = self
......@@ -2040,7 +2048,7 @@ impl RouterTrait for PDRouter {
let total_decode = self.decode_workers.read().unwrap().len();
if healthy_prefill_count > 0 && healthy_decode_count > 0 {
Json(serde_json::json!({
Json(json!({
"status": "ready",
"prefill": {
"healthy": healthy_prefill_count,
......
use super::header_utils;
use crate::config::types::{
CircuitBreakerConfig as ConfigCircuitBreakerConfig,
HealthCheckConfig as ConfigHealthCheckConfig, RetryConfig,
......@@ -12,6 +11,7 @@ use crate::policies::LoadBalancingPolicy;
use crate::protocols::spec::{
ChatCompletionRequest, CompletionRequest, GenerateRequest, GenerationRequest,
};
use crate::routers::header_utils;
use crate::routers::{RouterTrait, WorkerManagement};
use axum::{
body::Body,
......@@ -393,7 +393,7 @@ impl Router {
// Helper method to proxy GET requests to the first available worker
async fn proxy_get_request(&self, req: Request<Body>, endpoint: &str) -> Response {
let headers = super::header_utils::copy_request_headers(&req);
let headers = header_utils::copy_request_headers(&req);
match self.select_first_worker() {
Ok(worker_url) => {
......@@ -667,7 +667,7 @@ impl Router {
if !is_stream {
// For non-streaming requests, preserve headers
let response_headers = super::header_utils::preserve_response_headers(res.headers());
let response_headers = header_utils::preserve_response_headers(res.headers());
let response = match res.bytes().await {
Ok(body) => {
......@@ -1198,6 +1198,14 @@ impl RouterTrait for Router {
.await
}
async fn route_embeddings(&self, _headers: Option<&HeaderMap>, _body: Body) -> Response {
todo!()
}
async fn route_rerank(&self, _headers: Option<&HeaderMap>, _body: Body) -> Response {
todo!()
}
async fn flush_cache(&self) -> Response {
// Get all worker URLs
let worker_urls = self.get_worker_urls();
......
......@@ -12,10 +12,9 @@ use std::fmt::Debug;
use crate::protocols::spec::{ChatCompletionRequest, CompletionRequest, GenerateRequest};
pub mod factory;
pub mod grpc;
pub mod header_utils;
pub mod pd_router;
pub mod pd_types;
pub mod router;
pub mod http;
pub use factory::RouterFactory;
......@@ -77,6 +76,10 @@ pub trait RouterTrait: Send + Sync + Debug + WorkerManagement {
body: &CompletionRequest,
) -> Response;
/// Route an embeddings request.
///
/// `headers` are the optional incoming request headers; `body` is passed as a raw
/// `Body` rather than a parsed request struct.
/// NOTE(review): the implementations visible alongside this declaration are all
/// stubs — confirm the expected payload format before relying on it.
async fn route_embeddings(&self, headers: Option<&HeaderMap>, body: Body) -> Response;
/// Route a rerank request.
///
/// `headers` are the optional incoming request headers; `body` is passed as a raw
/// `Body` rather than a parsed request struct.
/// NOTE(review): the implementations visible alongside this declaration are all
/// stubs — confirm the expected payload format before relying on it.
async fn route_rerank(&self, headers: Option<&HeaderMap>, body: Body) -> Response;
/// Flush cache on all workers
async fn flush_cache(&self) -> Response;
......
......@@ -383,7 +383,7 @@ async fn handle_pod_event(
// Handle PD mode with specific pod types
let result = if pd_mode && pod_info.pod_type.is_some() {
// Need to import PDRouter type
use crate::routers::pd_router::PDRouter;
use crate::routers::http::pd_router::PDRouter;
// Try to downcast to PDRouter
if let Some(pd_router) = router.as_any().downcast_ref::<PDRouter>() {
......@@ -453,7 +453,7 @@ async fn handle_pod_deletion(
// Handle PD mode removal
if pd_mode && pod_info.pod_type.is_some() {
use crate::routers::pd_router::PDRouter;
use crate::routers::http::pd_router::PDRouter;
// Try to downcast to PDRouter for PD-specific removal
if let Some(pd_router) = router.as_any().downcast_ref::<PDRouter>() {
......@@ -581,7 +581,7 @@ mod tests {
async fn create_test_router() -> Arc<dyn RouterTrait> {
use crate::config::PolicyConfig;
use crate::policies::PolicyFactory;
use crate::routers::router::Router;
use crate::routers::http::router::Router;
let policy = PolicyFactory::create_from_config(&PolicyConfig::Random);
let router = Router::new(
......
......@@ -5,8 +5,8 @@ mod test_pd_routing {
CircuitBreakerConfig, PolicyConfig, RetryConfig, RouterConfig, RoutingMode,
};
use sglang_router_rs::core::{WorkerFactory, WorkerType};
use sglang_router_rs::routers::pd_types::get_hostname;
use sglang_router_rs::routers::pd_types::PDSelectionPolicy;
use sglang_router_rs::routers::http::pd_types::get_hostname;
use sglang_router_rs::routers::http::pd_types::PDSelectionPolicy;
use sglang_router_rs::routers::RouterFactory;
// Test-only struct to help validate PD request parsing
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment