Unverified Commit 66a398f4 authored by Simo Lin's avatar Simo Lin Committed by GitHub
Browse files

[router] migrate router from actix to axum (#8479)

parent 29980334
...@@ -10,41 +10,41 @@ name = "sglang_router_rs" ...@@ -10,41 +10,41 @@ name = "sglang_router_rs"
crate-type = ["cdylib", "rlib"] crate-type = ["cdylib", "rlib"]
[dependencies] [dependencies]
actix-web = "4.0" axum = { version = "0.8.4", features = ["macros", "ws", "tracing"] }
tower = { version = "0.5", features = ["full"] }
tower-http = { version = "0.6", features = ["trace", "compression-gzip", "cors", "timeout", "limit", "request-id", "util"] }
serde = { version = "1.0", features = ["derive"] } serde = { version = "1.0", features = ["derive"] }
clap = { version = "4.4", features = ["derive"] } serde_json = "1.0"
bytes = "1.8.0" bytes = "1.8.0"
rand = "0.8.5" rand = "0.8.5"
reqwest = { version = "0.12.8", features = ["stream", "blocking", "json"] } reqwest = { version = "0.12.8", features = ["stream", "blocking", "json"] }
futures-util = "0.3" futures-util = "0.3"
serde_json = "1.0" futures = "0.3"
pyo3 = { version = "0.22.5", features = ["extension-module"] } pyo3 = { version = "0.22.5", features = ["extension-module"] }
dashmap = "6.1.0" dashmap = "6.1.0"
http = "1.1.0" http = "1.1.0"
tokio = { version = "1.42.0", features = ["macros", "rt-multi-thread"] } tokio = { version = "1.42.0", features = ["full"] }
# Added for enhanced logging system async-trait = "0.1"
once_cell = "1.21"
tracing = "0.1" tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter", "json", "chrono"] } tracing-subscriber = { version = "0.3", features = ["env-filter", "json", "chrono"] }
tracing-log = "0.2" tracing-log = "0.2"
tracing-appender = "0.2.3" tracing-appender = "0.2.3"
chrono = "0.4"
kube = { version = "0.88.1", features = ["runtime", "derive"] } kube = { version = "0.88.1", features = ["runtime", "derive"] }
k8s-openapi = { version = "0.21.0", features = ["v1_29"] } k8s-openapi = { version = "0.21.0", features = ["v1_29"] }
futures = "0.3"
async-trait = "0.1"
once_cell = "1.21"
# Added for metrics
metrics = "0.24.2" metrics = "0.24.2"
metrics-exporter-prometheus = "0.17.0" metrics-exporter-prometheus = "0.17.0"
# Added for request tracing
uuid = { version = "1.10", features = ["v4", "serde"] } uuid = { version = "1.10", features = ["v4", "serde"] }
thiserror = "2.0.12" thiserror = "2.0.12"
url = "2.5.4" url = "2.5.4"
tokio-stream = { version = "0.1", features = ["sync"] }
[dev-dependencies] [dev-dependencies]
criterion = { version = "0.5", features = ["html_reports"] } criterion = { version = "0.5", features = ["html_reports"] }
tokio-stream = "0.1" tower = { version = "0.5", features = ["util"] }
actix-http = "3.0" http-body-util = "0.1"
futures = "0.3" portpicker = "0.1"
[[bench]] [[bench]]
name = "request_processing" name = "request_processing"
......
...@@ -68,6 +68,12 @@ class RouterArgs: ...@@ -68,6 +68,12 @@ class RouterArgs:
prometheus_host: Optional[str] = None prometheus_host: Optional[str] = None
# Request ID headers configuration # Request ID headers configuration
request_id_headers: Optional[List[str]] = None request_id_headers: Optional[List[str]] = None
# Request timeout in seconds
request_timeout_secs: int = 600
# Max concurrent requests for rate limiting
max_concurrent_requests: int = 64
# CORS allowed origins
cors_allowed_origins: List[str] = dataclasses.field(default_factory=list)
@staticmethod @staticmethod
def add_cli_args( def add_cli_args(
...@@ -276,6 +282,25 @@ class RouterArgs: ...@@ -276,6 +282,25 @@ class RouterArgs:
nargs="*", nargs="*",
help="Custom HTTP headers to check for request IDs (e.g., x-request-id x-trace-id). If not specified, uses common defaults.", help="Custom HTTP headers to check for request IDs (e.g., x-request-id x-trace-id). If not specified, uses common defaults.",
) )
parser.add_argument(
f"--{prefix}request-timeout-secs",
type=int,
default=RouterArgs.request_timeout_secs,
help="Request timeout in seconds",
)
parser.add_argument(
f"--{prefix}max-concurrent-requests",
type=int,
default=RouterArgs.max_concurrent_requests,
help="Maximum number of concurrent requests allowed (for rate limiting)",
)
parser.add_argument(
f"--{prefix}cors-allowed-origins",
type=str,
nargs="*",
default=[],
help="CORS allowed origins (e.g., http://localhost:3000 https://example.com)",
)
@classmethod @classmethod
def from_cli_args( def from_cli_args(
...@@ -337,6 +362,15 @@ class RouterArgs: ...@@ -337,6 +362,15 @@ class RouterArgs:
prometheus_port=getattr(args, f"{prefix}prometheus_port", None), prometheus_port=getattr(args, f"{prefix}prometheus_port", None),
prometheus_host=getattr(args, f"{prefix}prometheus_host", None), prometheus_host=getattr(args, f"{prefix}prometheus_host", None),
request_id_headers=getattr(args, f"{prefix}request_id_headers", None), request_id_headers=getattr(args, f"{prefix}request_id_headers", None),
request_timeout_secs=getattr(
args, f"{prefix}request_timeout_secs", RouterArgs.request_timeout_secs
),
max_concurrent_requests=getattr(
args,
f"{prefix}max_concurrent_requests",
RouterArgs.max_concurrent_requests,
),
cors_allowed_origins=getattr(args, f"{prefix}cors_allowed_origins", []),
) )
@staticmethod @staticmethod
...@@ -490,6 +524,7 @@ def launch_router(args: argparse.Namespace) -> Optional[Router]: ...@@ -490,6 +524,7 @@ def launch_router(args: argparse.Namespace) -> Optional[Router]:
decode_selector=router_args.decode_selector, decode_selector=router_args.decode_selector,
prometheus_port=router_args.prometheus_port, prometheus_port=router_args.prometheus_port,
prometheus_host=router_args.prometheus_host, prometheus_host=router_args.prometheus_host,
request_timeout_secs=router_args.request_timeout_secs,
pd_disaggregation=router_args.pd_disaggregation, pd_disaggregation=router_args.pd_disaggregation,
prefill_urls=( prefill_urls=(
router_args.prefill_urls if router_args.pd_disaggregation else None router_args.prefill_urls if router_args.pd_disaggregation else None
...@@ -508,6 +543,8 @@ def launch_router(args: argparse.Namespace) -> Optional[Router]: ...@@ -508,6 +543,8 @@ def launch_router(args: argparse.Namespace) -> Optional[Router]:
else None else None
), ),
request_id_headers=router_args.request_id_headers, request_id_headers=router_args.request_id_headers,
max_concurrent_requests=router_args.max_concurrent_requests,
cors_allowed_origins=router_args.cors_allowed_origins,
) )
router.start() router.start()
......
...@@ -61,6 +61,11 @@ class Router: ...@@ -61,6 +61,11 @@ class Router:
request_id_headers: List of HTTP headers to check for request IDs. If not specified, request_id_headers: List of HTTP headers to check for request IDs. If not specified,
uses common defaults: ['x-request-id', 'x-correlation-id', 'x-trace-id', 'request-id']. uses common defaults: ['x-request-id', 'x-correlation-id', 'x-trace-id', 'request-id'].
Example: ['x-my-request-id', 'x-custom-trace-id']. Default: None Example: ['x-my-request-id', 'x-custom-trace-id']. Default: None
bootstrap_port_annotation: Kubernetes annotation name for bootstrap port (PD mode).
Default: 'sglang.ai/bootstrap-port'
request_timeout_secs: Request timeout in seconds. Default: 600
max_concurrent_requests: Maximum number of concurrent requests allowed for rate limiting. Default: 64
cors_allowed_origins: List of allowed origins for CORS. Empty list allows all origins. Default: []
""" """
def __init__( def __init__(
...@@ -87,14 +92,18 @@ class Router: ...@@ -87,14 +92,18 @@ class Router:
service_discovery_namespace: Optional[str] = None, service_discovery_namespace: Optional[str] = None,
prefill_selector: Dict[str, str] = None, prefill_selector: Dict[str, str] = None,
decode_selector: Dict[str, str] = None, decode_selector: Dict[str, str] = None,
bootstrap_port_annotation: str = "sglang.ai/bootstrap-port",
prometheus_port: Optional[int] = None, prometheus_port: Optional[int] = None,
prometheus_host: Optional[str] = None, prometheus_host: Optional[str] = None,
request_timeout_secs: int = 600,
request_id_headers: Optional[List[str]] = None,
pd_disaggregation: bool = False, pd_disaggregation: bool = False,
prefill_urls: Optional[List[tuple]] = None, prefill_urls: Optional[List[tuple]] = None,
decode_urls: Optional[List[str]] = None, decode_urls: Optional[List[str]] = None,
prefill_policy: Optional[PolicyType] = None, prefill_policy: Optional[PolicyType] = None,
decode_policy: Optional[PolicyType] = None, decode_policy: Optional[PolicyType] = None,
request_id_headers: Optional[List[str]] = None, max_concurrent_requests: int = 64,
cors_allowed_origins: List[str] = None,
): ):
if selector is None: if selector is None:
selector = {} selector = {}
...@@ -102,6 +111,8 @@ class Router: ...@@ -102,6 +111,8 @@ class Router:
prefill_selector = {} prefill_selector = {}
if decode_selector is None: if decode_selector is None:
decode_selector = {} decode_selector = {}
if cors_allowed_origins is None:
cors_allowed_origins = []
self._router = _Router( self._router = _Router(
worker_urls=worker_urls, worker_urls=worker_urls,
...@@ -126,14 +137,18 @@ class Router: ...@@ -126,14 +137,18 @@ class Router:
service_discovery_namespace=service_discovery_namespace, service_discovery_namespace=service_discovery_namespace,
prefill_selector=prefill_selector, prefill_selector=prefill_selector,
decode_selector=decode_selector, decode_selector=decode_selector,
bootstrap_port_annotation=bootstrap_port_annotation,
prometheus_port=prometheus_port, prometheus_port=prometheus_port,
prometheus_host=prometheus_host, prometheus_host=prometheus_host,
request_timeout_secs=request_timeout_secs,
request_id_headers=request_id_headers,
pd_disaggregation=pd_disaggregation, pd_disaggregation=pd_disaggregation,
prefill_urls=prefill_urls, prefill_urls=prefill_urls,
decode_urls=decode_urls, decode_urls=decode_urls,
prefill_policy=prefill_policy, prefill_policy=prefill_policy,
decode_policy=decode_policy, decode_policy=decode_policy,
request_id_headers=request_id_headers, max_concurrent_requests=max_concurrent_requests,
cors_allowed_origins=cors_allowed_origins,
) )
def start(self) -> None: def start(self) -> None:
......
...@@ -46,11 +46,12 @@ class TestLaunchRouter(unittest.TestCase): ...@@ -46,11 +46,12 @@ class TestLaunchRouter(unittest.TestCase):
dp_aware=False, dp_aware=False,
prometheus_port=None, prometheus_port=None,
prometheus_host=None, prometheus_host=None,
# PD-specific attributes request_timeout_secs=60,
max_concurrent_requests=64,
cors_allowed_origins=[],
pd_disaggregation=False, pd_disaggregation=False,
prefill=None, prefill=None,
decode=None, decode=None,
# Keep worker_urls for regular mode
worker_urls=[], worker_urls=[],
) )
......
...@@ -35,6 +35,10 @@ pub struct RouterConfig { ...@@ -35,6 +35,10 @@ pub struct RouterConfig {
pub log_level: Option<String>, pub log_level: Option<String>,
/// Custom request ID headers to check (defaults to common headers) /// Custom request ID headers to check (defaults to common headers)
pub request_id_headers: Option<Vec<String>>, pub request_id_headers: Option<Vec<String>>,
/// Maximum concurrent requests allowed (for rate limiting)
pub max_concurrent_requests: usize,
/// CORS allowed origins
pub cors_allowed_origins: Vec<String>,
} }
/// Routing mode configuration /// Routing mode configuration
...@@ -216,6 +220,8 @@ impl Default for RouterConfig { ...@@ -216,6 +220,8 @@ impl Default for RouterConfig {
log_dir: None, log_dir: None,
log_level: None, log_level: None,
request_id_headers: None, request_id_headers: None,
max_concurrent_requests: 64,
cors_allowed_origins: vec![],
} }
} }
} }
...@@ -324,6 +330,8 @@ mod tests { ...@@ -324,6 +330,8 @@ mod tests {
log_dir: Some("/var/log".to_string()), log_dir: Some("/var/log".to_string()),
log_level: Some("debug".to_string()), log_level: Some("debug".to_string()),
request_id_headers: None, request_id_headers: None,
max_concurrent_requests: 64,
cors_allowed_origins: vec![],
}; };
let json = serde_json::to_string(&config).unwrap(); let json = serde_json::to_string(&config).unwrap();
...@@ -749,6 +757,8 @@ mod tests { ...@@ -749,6 +757,8 @@ mod tests {
log_dir: Some("/var/log/sglang".to_string()), log_dir: Some("/var/log/sglang".to_string()),
log_level: Some("info".to_string()), log_level: Some("info".to_string()),
request_id_headers: None, request_id_headers: None,
max_concurrent_requests: 64,
cors_allowed_origins: vec![],
}; };
assert!(config.mode.is_pd_mode()); assert!(config.mode.is_pd_mode());
...@@ -798,6 +808,8 @@ mod tests { ...@@ -798,6 +808,8 @@ mod tests {
log_dir: None, log_dir: None,
log_level: Some("debug".to_string()), log_level: Some("debug".to_string()),
request_id_headers: None, request_id_headers: None,
max_concurrent_requests: 64,
cors_allowed_origins: vec![],
}; };
assert!(!config.mode.is_pd_mode()); assert!(!config.mode.is_pd_mode());
...@@ -843,6 +855,8 @@ mod tests { ...@@ -843,6 +855,8 @@ mod tests {
log_dir: Some("/opt/logs/sglang".to_string()), log_dir: Some("/opt/logs/sglang".to_string()),
log_level: Some("trace".to_string()), log_level: Some("trace".to_string()),
request_id_headers: None, request_id_headers: None,
max_concurrent_requests: 64,
cors_allowed_origins: vec![],
}; };
assert!(config.has_service_discovery()); assert!(config.has_service_discovery());
......
...@@ -60,6 +60,9 @@ struct Router { ...@@ -60,6 +60,9 @@ struct Router {
decode_urls: Option<Vec<String>>, decode_urls: Option<Vec<String>>,
prefill_policy: Option<PolicyType>, prefill_policy: Option<PolicyType>,
decode_policy: Option<PolicyType>, decode_policy: Option<PolicyType>,
// Additional server config fields
max_concurrent_requests: usize,
cors_allowed_origins: Vec<String>,
} }
impl Router { impl Router {
...@@ -145,6 +148,8 @@ impl Router { ...@@ -145,6 +148,8 @@ impl Router {
log_dir: self.log_dir.clone(), log_dir: self.log_dir.clone(),
log_level: self.log_level.clone(), log_level: self.log_level.clone(),
request_id_headers: self.request_id_headers.clone(), request_id_headers: self.request_id_headers.clone(),
max_concurrent_requests: self.max_concurrent_requests,
cors_allowed_origins: self.cors_allowed_origins.clone(),
}) })
} }
} }
...@@ -184,7 +189,9 @@ impl Router { ...@@ -184,7 +189,9 @@ impl Router {
prefill_urls = None, prefill_urls = None,
decode_urls = None, decode_urls = None,
prefill_policy = None, prefill_policy = None,
decode_policy = None decode_policy = None,
max_concurrent_requests = 64,
cors_allowed_origins = vec![]
))] ))]
fn new( fn new(
worker_urls: Vec<String>, worker_urls: Vec<String>,
...@@ -219,6 +226,8 @@ impl Router { ...@@ -219,6 +226,8 @@ impl Router {
decode_urls: Option<Vec<String>>, decode_urls: Option<Vec<String>>,
prefill_policy: Option<PolicyType>, prefill_policy: Option<PolicyType>,
decode_policy: Option<PolicyType>, decode_policy: Option<PolicyType>,
max_concurrent_requests: usize,
cors_allowed_origins: Vec<String>,
) -> PyResult<Self> { ) -> PyResult<Self> {
Ok(Router { Ok(Router {
host, host,
...@@ -253,6 +262,8 @@ impl Router { ...@@ -253,6 +262,8 @@ impl Router {
decode_urls, decode_urls,
prefill_policy, prefill_policy,
decode_policy, decode_policy,
max_concurrent_requests,
cors_allowed_origins,
}) })
} }
......
use actix_web::{ use axum::{extract::Request, http::HeaderValue, response::Response};
dev::{forward_ready, Service, ServiceRequest, ServiceResponse, Transform}, use std::sync::Arc;
Error, HttpMessage, HttpRequest, use std::time::Instant;
}; use tower::{Layer, Service};
use futures_util::future::LocalBoxFuture; use tower_http::trace::{MakeSpan, OnRequest, OnResponse, TraceLayer};
use std::future::{ready, Ready}; use tracing::{field::Empty, info_span, Span};
/// Generate OpenAI-compatible request ID based on endpoint /// Generate OpenAI-compatible request ID based on endpoint
fn generate_request_id(path: &str) -> String { fn generate_request_id(path: &str) -> String {
...@@ -31,67 +31,67 @@ fn generate_request_id(path: &str) -> String { ...@@ -31,67 +31,67 @@ fn generate_request_id(path: &str) -> String {
format!("{}{}", prefix, random_part) format!("{}{}", prefix, random_part)
} }
/// Extract request ID from request extensions or generate a new one /// Extension type for storing request ID
pub fn get_request_id(req: &HttpRequest) -> String { #[derive(Clone, Debug)]
req.extensions() pub struct RequestId(pub String);
.get::<String>()
.cloned()
.unwrap_or_else(|| generate_request_id(req.path()))
}
/// Middleware for injecting request ID into request extensions /// Tower Layer for request ID middleware
pub struct RequestIdMiddleware { #[derive(Clone)]
headers: Vec<String>, pub struct RequestIdLayer {
headers: Arc<Vec<String>>,
} }
impl RequestIdMiddleware { impl RequestIdLayer {
pub fn new(headers: Vec<String>) -> Self { pub fn new(headers: Vec<String>) -> Self {
Self { headers } Self {
headers: Arc::new(headers),
}
} }
} }
impl<S, B> Transform<S, ServiceRequest> for RequestIdMiddleware impl<S> Layer<S> for RequestIdLayer {
where type Service = RequestIdMiddleware<S>;
S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
S::Future: 'static, fn layer(&self, inner: S) -> Self::Service {
B: 'static, RequestIdMiddleware {
{ inner,
type Response = ServiceResponse<B>;
type Error = Error;
type InitError = ();
type Transform = RequestIdMiddlewareService<S>;
type Future = Ready<Result<Self::Transform, Self::InitError>>;
fn new_transform(&self, service: S) -> Self::Future {
ready(Ok(RequestIdMiddlewareService {
service,
headers: self.headers.clone(), headers: self.headers.clone(),
})) }
} }
} }
pub struct RequestIdMiddlewareService<S> { /// Tower Service for request ID middleware
service: S, #[derive(Clone)]
headers: Vec<String>, pub struct RequestIdMiddleware<S> {
inner: S,
headers: Arc<Vec<String>>,
} }
impl<S, B> Service<ServiceRequest> for RequestIdMiddlewareService<S> impl<S> Service<Request> for RequestIdMiddleware<S>
where where
S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>, S: Service<Request, Response = Response> + Send + 'static,
S::Future: 'static, S::Future: Send + 'static,
B: 'static,
{ {
type Response = ServiceResponse<B>; type Response = S::Response;
type Error = Error; type Error = S::Error;
type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>; type Future = std::pin::Pin<
Box<dyn std::future::Future<Output = Result<Self::Response, Self::Error>> + Send>,
>;
forward_ready!(service); fn poll_ready(
&mut self,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx)
}
fn call(&mut self, mut req: Request) -> Self::Future {
let headers = self.headers.clone();
fn call(&self, req: ServiceRequest) -> Self::Future {
// Extract request ID from headers or generate new one // Extract request ID from headers or generate new one
let mut request_id = None; let mut request_id = None;
for header_name in &self.headers { for header_name in headers.iter() {
if let Some(header_value) = req.headers().get(header_name) { if let Some(header_value) = req.headers().get(header_name) {
if let Ok(value) = header_value.to_str() { if let Ok(value) = header_value.to_str() {
request_id = Some(value.to_string()); request_id = Some(value.to_string());
...@@ -100,12 +100,216 @@ where ...@@ -100,12 +100,216 @@ where
} }
} }
let request_id = request_id.unwrap_or_else(|| generate_request_id(req.path())); let request_id = request_id.unwrap_or_else(|| generate_request_id(req.uri().path()));
// Insert request ID into request extensions // Insert request ID into request extensions
req.extensions_mut().insert(request_id); req.extensions_mut().insert(RequestId(request_id.clone()));
// Create a span with the request ID for this request
let span = tracing::info_span!(
"http_request",
method = %req.method(),
uri = %req.uri(),
version = ?req.version(),
request_id = %request_id
);
// Log within the span
let _enter = span.enter();
tracing::info!(
target: "sglang_router_rs::request",
"started processing request"
);
drop(_enter);
// Capture values we need in the async block
let method = req.method().clone();
let uri = req.uri().clone();
let version = req.version();
// Call the inner service
let future = self.inner.call(req);
Box::pin(async move {
let start_time = Instant::now();
let mut response = future.await?;
let latency = start_time.elapsed();
let fut = self.service.call(req); // Add request ID to response headers
Box::pin(async move { fut.await }) response.headers_mut().insert(
"x-request-id",
HeaderValue::from_str(&request_id)
.unwrap_or_else(|_| HeaderValue::from_static("invalid-request-id")),
);
// Log the response with proper request ID in span
let status = response.status();
let span = tracing::info_span!(
"http_request",
method = %method,
uri = %uri,
version = ?version,
request_id = %request_id,
status = %status,
latency = ?latency
);
let _enter = span.enter();
if status.is_server_error() {
tracing::error!(
target: "sglang_router_rs::response",
"request failed with server error"
);
} else if status.is_client_error() {
tracing::warn!(
target: "sglang_router_rs::response",
"request failed with client error"
);
} else {
tracing::info!(
target: "sglang_router_rs::response",
"finished processing request"
);
}
Ok(response)
})
}
}
// ============= Logging Middleware =============
/// Custom span maker that includes request ID
///
/// Creates the per-request tracing span for `TraceLayer`. The `request_id`
/// (and status/latency/error) fields are declared `Empty` here and recorded
/// later, because this runs before `RequestIdLayer` has populated the
/// request extensions.
#[derive(Clone, Debug)]
pub struct RequestSpan;

impl<B> MakeSpan<B> for RequestSpan {
    fn make_span(&mut self, request: &Request<B>) -> Span {
        // Don't try to extract request ID here - it won't be available yet.
        // The RequestIdLayer runs after TraceLayer creates the span.
        info_span!(
            "http_request",
            method = %request.method(),
            uri = %request.uri(),
            version = ?request.version(),
            request_id = Empty, // Will be set later
            status_code = Empty,
            latency = Empty,
            error = Empty,
        )
    }
}
/// Custom on_request handler
///
/// Invoked by `TraceLayer` when a request starts. Copies the request ID out
/// of the request extensions into the span's `request_id` field.
#[derive(Clone, Debug)]
pub struct RequestLogger;

impl<B> OnRequest<B> for RequestLogger {
    fn on_request(&mut self, request: &Request<B>, span: &Span) {
        let _enter = span.enter();
        // Try to get the request ID from extensions.
        // This only works if RequestIdLayer has already run for this request
        // (i.e. it is installed outside/before TraceLayer in the stack) —
        // NOTE(review): confirm layer ordering in server setup.
        if let Some(request_id) = request.extensions().get::<RequestId>() {
            span.record("request_id", &request_id.0.as_str());
        }
        // Don't log here - we already log in RequestIdService with the proper request_id
    }
}
/// Custom on_response handler
///
/// Records the response status code and latency on the request span for
/// structured logging/observability tools. The latency value is supplied by
/// `TraceLayer` itself, so no timing state is kept here.
///
/// Fix: the previous version carried a dead `_start_time: Instant` field.
/// It was never read (latency comes from `TraceLayer`) and was misleading,
/// since `Default` initialized it once when the layer was built, not per
/// request. The field and the hand-written `Default` impl are removed;
/// `ResponseLogger::default()` still works via the derive.
#[derive(Clone, Debug, Default)]
pub struct ResponseLogger;

impl<B> OnResponse<B> for ResponseLogger {
    fn on_response(self, response: &Response<B>, latency: std::time::Duration, span: &Span) {
        let status = response.status();
        // Record these in the span for structured logging/observability tools
        span.record("status_code", status.as_u16());
        span.record("latency", format!("{:?}", latency));
        // Don't log here - the request-ID middleware handles all logging with proper request IDs
    }
}
/// Build the configured `TraceLayer` used for HTTP logging.
///
/// Note: actual request/response logging with request IDs is done in the
/// request-ID middleware; this layer only creates the span and records
/// status/latency fields on it.
pub fn create_logging_layer() -> TraceLayer<
    tower_http::classify::SharedClassifier<tower_http::classify::ServerErrorsAsFailures>,
    RequestSpan,
    RequestLogger,
    ResponseLogger,
> {
    let base = TraceLayer::new_for_http();
    base.make_span_with(RequestSpan)
        .on_request(RequestLogger)
        .on_response(ResponseLogger::default())
}
/// Structured logging data for requests
///
/// One record per completed HTTP request, serializable for JSON log sinks.
#[derive(Debug, serde::Serialize)]
pub struct RequestLogEntry {
    // Timestamp string; format is chosen by the caller (not fixed here).
    pub timestamp: String,
    // Request ID propagated from headers or generated per request.
    pub request_id: String,
    // HTTP method, e.g. "GET".
    pub method: String,
    // Full request URI.
    pub uri: String,
    // HTTP response status code; drives log severity in `log_request`.
    pub status: u16,
    // Total request latency in milliseconds.
    pub latency_ms: u64,
    // User-Agent header value, if present.
    pub user_agent: Option<String>,
    // Remote peer address, if known.
    pub remote_addr: Option<String>,
    // Error description for failed requests, if any (only logged for 5xx).
    pub error: Option<String>,
}
/// Log a request with structured data
pub fn log_request(entry: RequestLogEntry) {
if entry.status >= 500 {
tracing::error!(
target: "sglang_router_rs::http",
request_id = %entry.request_id,
method = %entry.method,
uri = %entry.uri,
status = entry.status,
latency_ms = entry.latency_ms,
user_agent = ?entry.user_agent,
remote_addr = ?entry.remote_addr,
error = ?entry.error,
"HTTP request failed"
);
} else if entry.status >= 400 {
tracing::warn!(
target: "sglang_router_rs::http",
request_id = %entry.request_id,
method = %entry.method,
uri = %entry.uri,
status = entry.status,
latency_ms = entry.latency_ms,
user_agent = ?entry.user_agent,
remote_addr = ?entry.remote_addr,
"HTTP request client error"
);
} else {
tracing::info!(
target: "sglang_router_rs::http",
request_id = %entry.request_id,
method = %entry.method,
uri = %entry.uri,
status = entry.status,
latency_ms = entry.latency_ms,
user_agent = ?entry.user_agent,
remote_addr = ?entry.remote_addr,
"HTTP request completed"
);
} }
} }
//! Router implementations //! Router implementations
use actix_web::{HttpRequest, HttpResponse};
use async_trait::async_trait; use async_trait::async_trait;
use axum::{
body::Body,
extract::Request,
http::{HeaderMap, StatusCode},
response::{IntoResponse, Response},
};
use reqwest::Client; use reqwest::Client;
use std::fmt::Debug; use std::fmt::Debug;
use crate::openai_api_types::{ChatCompletionRequest, CompletionRequest, GenerateRequest};
pub mod factory; pub mod factory;
pub mod pd_router; pub mod pd_router;
pub mod pd_types; pub mod pd_types;
...@@ -33,54 +40,55 @@ pub trait WorkerManagement: Send + Sync { ...@@ -33,54 +40,55 @@ pub trait WorkerManagement: Send + Sync {
/// ///
/// This trait provides a unified interface for routing requests, /// This trait provides a unified interface for routing requests,
/// regardless of whether it's a regular router or PD router. /// regardless of whether it's a regular router or PD router.
#[async_trait(?Send)] #[async_trait]
pub trait RouterTrait: Send + Sync + Debug + WorkerManagement { pub trait RouterTrait: Send + Sync + Debug + WorkerManagement {
/// Get a reference to self as Any for downcasting /// Get a reference to self as Any for downcasting
fn as_any(&self) -> &dyn std::any::Any; fn as_any(&self) -> &dyn std::any::Any;
/// Route a health check request /// Route a health check request
async fn health(&self, client: &Client, req: &HttpRequest) -> HttpResponse; async fn health(&self, client: &Client, req: Request<Body>) -> Response;
/// Route a health generate request /// Route a health generate request
async fn health_generate(&self, client: &Client, req: &HttpRequest) -> HttpResponse; async fn health_generate(&self, client: &Client, req: Request<Body>) -> Response;
/// Get server information /// Get server information
async fn get_server_info(&self, client: &Client, req: &HttpRequest) -> HttpResponse; async fn get_server_info(&self, client: &Client, req: Request<Body>) -> Response;
/// Get available models /// Get available models
async fn get_models(&self, client: &Client, req: &HttpRequest) -> HttpResponse; async fn get_models(&self, client: &Client, req: Request<Body>) -> Response;
/// Get model information /// Get model information
async fn get_model_info(&self, client: &Client, req: &HttpRequest) -> HttpResponse; async fn get_model_info(&self, client: &Client, req: Request<Body>) -> Response;
/// Route a generate request /// Route a generate request
async fn route_generate( async fn route_generate(
&self, &self,
client: &Client, client: &Client,
req: &HttpRequest, headers: Option<&HeaderMap>,
body: serde_json::Value, body: &GenerateRequest,
) -> HttpResponse; ) -> Response;
/// Route a chat completion request /// Route a chat completion request
async fn route_chat( async fn route_chat(
&self, &self,
client: &Client, client: &Client,
req: &HttpRequest, headers: Option<&HeaderMap>,
body: serde_json::Value, body: &ChatCompletionRequest,
) -> HttpResponse; ) -> Response;
/// Route a completion request /// Route a completion request
async fn route_completion( async fn route_completion(
&self, &self,
client: &Client, client: &Client,
req: &HttpRequest, headers: Option<&HeaderMap>,
body: serde_json::Value, body: &CompletionRequest,
) -> HttpResponse; ) -> Response;
/// Flush cache on all workers /// Flush cache on all workers
async fn flush_cache(&self, client: &Client) -> HttpResponse; async fn flush_cache(&self, client: &Client) -> Response;
/// Get worker loads (for monitoring) /// Get worker loads (for monitoring)
async fn get_worker_loads(&self, client: &Client) -> HttpResponse; async fn get_worker_loads(&self, client: &Client) -> Response;
/// Get router type name /// Get router type name
fn router_type(&self) -> &'static str; fn router_type(&self) -> &'static str;
...@@ -91,11 +99,11 @@ pub trait RouterTrait: Send + Sync + Debug + WorkerManagement { ...@@ -91,11 +99,11 @@ pub trait RouterTrait: Send + Sync + Debug + WorkerManagement {
} }
/// Server liveness check - is the server process running /// Server liveness check - is the server process running
fn liveness(&self) -> HttpResponse { fn liveness(&self) -> Response {
// Simple liveness check - if we can respond, we're alive // Simple liveness check - if we can respond, we're alive
HttpResponse::Ok().body("OK") (StatusCode::OK, "OK").into_response()
} }
/// Server readiness check - is the server ready to handle requests /// Server readiness check - is the server ready to handle requests
fn readiness(&self) -> HttpResponse; fn readiness(&self) -> Response;
} }
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
pub mod mock_worker; pub mod mock_worker;
pub mod test_app;
use actix_web::web;
use reqwest::Client;
use sglang_router_rs::config::{PolicyConfig, RouterConfig, RoutingMode};
use sglang_router_rs::server::AppState;
/// Helper function to create test router configuration
/// Helper function to create test router configuration.
///
/// Fix: the struct literal was missing the `max_concurrent_requests` and
/// `cors_allowed_origins` fields added to `RouterConfig` in this change,
/// which would fail to compile; they are initialized to the config defaults.
pub fn create_test_config(worker_urls: Vec<String>) -> RouterConfig {
    RouterConfig {
        mode: RoutingMode::Regular { worker_urls },
        policy: PolicyConfig::Random,
        host: "127.0.0.1".to_string(),
        port: 3001,
        max_payload_size: 256 * 1024 * 1024, // 256MB
        request_timeout_secs: 600,
        worker_startup_timeout_secs: 300,
        worker_startup_check_interval_secs: 10,
        dp_aware: false,
        api_key: None,
        discovery: None,
        metrics: None,
        log_dir: None,
        log_level: None,
        request_id_headers: None,
        // New fields introduced with the axum migration (defaults match RouterConfig::default()).
        max_concurrent_requests: 64,
        cors_allowed_origins: vec![],
    }
}
/// Helper function to create test router configuration with no health check
/// Helper function to create test router configuration with no health check.
///
/// Fix: the struct literal was missing the `max_concurrent_requests` and
/// `cors_allowed_origins` fields added to `RouterConfig` in this change,
/// which would fail to compile; they are initialized to the config defaults.
pub fn create_test_config_no_workers() -> RouterConfig {
    RouterConfig {
        mode: RoutingMode::Regular {
            worker_urls: vec![],
        }, // Empty to skip health check
        policy: PolicyConfig::Random,
        host: "127.0.0.1".to_string(),
        port: 3001,
        max_payload_size: 256 * 1024 * 1024, // 256MB
        request_timeout_secs: 600,
        worker_startup_timeout_secs: 0, // No wait
        worker_startup_check_interval_secs: 10,
        dp_aware: false,
        api_key: None,
        discovery: None,
        metrics: None,
        log_dir: None,
        log_level: None,
        request_id_headers: None,
        // New fields introduced with the axum migration (defaults match RouterConfig::default()).
        max_concurrent_requests: 64,
        cors_allowed_origins: vec![],
    }
}
/// Helper function to create test app state
pub async fn create_test_app_state(config: RouterConfig) -> Result<web::Data<AppState>, String> {
// Create a non-blocking client
let client = Client::builder()
.timeout(std::time::Duration::from_secs(config.request_timeout_secs))
.build()
.map_err(|e| e.to_string())?;
let app_state = AppState::new(config, client)?;
Ok(web::Data::new(app_state))
}
use axum::Router;
use reqwest::Client;
use sglang_router_rs::{
config::RouterConfig,
routers::RouterTrait,
server::{build_app, AppState},
};
use std::sync::Arc;
/// Create a test Axum application using the actual server's build_app function.
///
/// Wires the router under test and a concurrency limiter (sized from the
/// config) into an `AppState`, then delegates to the real `build_app` so
/// tests exercise the same middleware stack as production.
pub fn create_test_app(
    router: Arc<dyn RouterTrait>,
    client: Client,
    router_config: &RouterConfig,
) -> Router {
    let limiter = Arc::new(tokio::sync::Semaphore::new(
        router_config.max_concurrent_requests,
    ));
    let app_state = Arc::new(AppState {
        router,
        client,
        _concurrency_limiter: limiter,
    });

    // Fall back to the common default request-ID headers when none are configured.
    let request_id_headers = match router_config.request_id_headers.clone() {
        Some(headers) => headers,
        None => ["x-request-id", "x-correlation-id", "x-trace-id", "request-id"]
            .iter()
            .map(|s| s.to_string())
            .collect(),
    };

    build_app(
        app_state,
        router_config.max_payload_size,
        request_id_headers,
        router_config.cors_allowed_origins.clone(),
    )
}
This diff is collapsed.
This diff is collapsed.
...@@ -176,6 +176,8 @@ mod test_pd_routing { ...@@ -176,6 +176,8 @@ mod test_pd_routing {
log_dir: None, log_dir: None,
log_level: None, log_level: None,
request_id_headers: None, request_id_headers: None,
max_concurrent_requests: 64,
cors_allowed_origins: vec![],
}; };
// Router creation will fail due to health checks, but config should be valid // Router creation will fail due to health checks, but config should be valid
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment