Unverified Commit fe6a445d authored by Simo Lin's avatar Simo Lin Committed by GitHub
Browse files

[router] improve router logs and request id header (#8415)

parent dd487e55
...@@ -93,6 +93,19 @@ python -m sglang_router.launch_router \ ...@@ -93,6 +93,19 @@ python -m sglang_router.launch_router \
--prometheus-port 9000 --prometheus-port 9000
``` ```
### Request ID Tracking
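Track requests across distributed systems with configurable request ID headers. The router checks the configured headers in order and takes the request ID from the first one that carries a usable value; if none does, it generates an ID. To use custom headers: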
```bash
# Use custom request ID headers
python -m sglang_router.launch_router \
--worker-urls http://localhost:8080 \
--request-id-headers x-trace-id x-request-id
```
Default headers: `x-request-id`, `x-correlation-id`, `x-trace-id`, `request-id`
## Advanced Features ## Advanced Features
### Kubernetes Service Discovery ### Kubernetes Service Discovery
......
...@@ -64,6 +64,8 @@ class RouterArgs: ...@@ -64,6 +64,8 @@ class RouterArgs:
# Prometheus configuration # Prometheus configuration
prometheus_port: Optional[int] = None prometheus_port: Optional[int] = None
prometheus_host: Optional[str] = None prometheus_host: Optional[str] = None
# Request ID headers configuration
request_id_headers: Optional[List[str]] = None
@staticmethod @staticmethod
def add_cli_args( def add_cli_args(
...@@ -255,6 +257,12 @@ class RouterArgs: ...@@ -255,6 +257,12 @@ class RouterArgs:
default="127.0.0.1", default="127.0.0.1",
help="Host address to bind the Prometheus metrics server", help="Host address to bind the Prometheus metrics server",
) )
parser.add_argument(
f"--{prefix}request-id-headers",
type=str,
nargs="*",
help="Custom HTTP headers to check for request IDs (e.g., x-request-id x-trace-id). If not specified, uses common defaults.",
)
@classmethod @classmethod
def from_cli_args( def from_cli_args(
...@@ -313,6 +321,7 @@ class RouterArgs: ...@@ -313,6 +321,7 @@ class RouterArgs:
bootstrap_port_annotation="sglang.ai/bootstrap-port", # Mooncake-specific annotation bootstrap_port_annotation="sglang.ai/bootstrap-port", # Mooncake-specific annotation
prometheus_port=getattr(args, f"{prefix}prometheus_port", None), prometheus_port=getattr(args, f"{prefix}prometheus_port", None),
prometheus_host=getattr(args, f"{prefix}prometheus_host", None), prometheus_host=getattr(args, f"{prefix}prometheus_host", None),
request_id_headers=getattr(args, f"{prefix}request_id_headers", None),
) )
@staticmethod @staticmethod
...@@ -481,6 +490,7 @@ def launch_router(args: argparse.Namespace) -> Optional[Router]: ...@@ -481,6 +490,7 @@ def launch_router(args: argparse.Namespace) -> Optional[Router]:
if router_args.decode_policy if router_args.decode_policy
else None else None
), ),
request_id_headers=router_args.request_id_headers,
) )
router.start() router.start()
......
...@@ -54,6 +54,9 @@ class Router: ...@@ -54,6 +54,9 @@ class Router:
If not specified, uses the main policy. Default: None If not specified, uses the main policy. Default: None
decode_policy: Specific load balancing policy for decode nodes (PD mode only). decode_policy: Specific load balancing policy for decode nodes (PD mode only).
If not specified, uses the main policy. Default: None If not specified, uses the main policy. Default: None
request_id_headers: List of HTTP headers to check for request IDs. If not specified,
uses common defaults: ['x-request-id', 'x-correlation-id', 'x-trace-id', 'request-id'].
Example: ['x-my-request-id', 'x-custom-trace-id']. Default: None
""" """
def __init__( def __init__(
...@@ -85,6 +88,7 @@ class Router: ...@@ -85,6 +88,7 @@ class Router:
decode_urls: Optional[List[str]] = None, decode_urls: Optional[List[str]] = None,
prefill_policy: Optional[PolicyType] = None, prefill_policy: Optional[PolicyType] = None,
decode_policy: Optional[PolicyType] = None, decode_policy: Optional[PolicyType] = None,
request_id_headers: Optional[List[str]] = None,
): ):
if selector is None: if selector is None:
selector = {} selector = {}
...@@ -121,6 +125,7 @@ class Router: ...@@ -121,6 +125,7 @@ class Router:
decode_urls=decode_urls, decode_urls=decode_urls,
prefill_policy=prefill_policy, prefill_policy=prefill_policy,
decode_policy=decode_policy, decode_policy=decode_policy,
request_id_headers=request_id_headers,
) )
def start(self) -> None: def start(self) -> None:
......
...@@ -29,6 +29,8 @@ pub struct RouterConfig { ...@@ -29,6 +29,8 @@ pub struct RouterConfig {
pub log_dir: Option<String>, pub log_dir: Option<String>,
/// Log level (None = info) /// Log level (None = info)
pub log_level: Option<String>, pub log_level: Option<String>,
/// Custom request ID headers to check (defaults to common headers)
pub request_id_headers: Option<Vec<String>>,
} }
/// Routing mode configuration /// Routing mode configuration
...@@ -207,6 +209,7 @@ impl Default for RouterConfig { ...@@ -207,6 +209,7 @@ impl Default for RouterConfig {
metrics: None, metrics: None,
log_dir: None, log_dir: None,
log_level: None, log_level: None,
request_id_headers: None,
} }
} }
} }
...@@ -312,6 +315,7 @@ mod tests { ...@@ -312,6 +315,7 @@ mod tests {
metrics: Some(MetricsConfig::default()), metrics: Some(MetricsConfig::default()),
log_dir: Some("/var/log".to_string()), log_dir: Some("/var/log".to_string()),
log_level: Some("debug".to_string()), log_level: Some("debug".to_string()),
request_id_headers: None,
}; };
let json = serde_json::to_string(&config).unwrap(); let json = serde_json::to_string(&config).unwrap();
...@@ -734,6 +738,7 @@ mod tests { ...@@ -734,6 +738,7 @@ mod tests {
}), }),
log_dir: Some("/var/log/sglang".to_string()), log_dir: Some("/var/log/sglang".to_string()),
log_level: Some("info".to_string()), log_level: Some("info".to_string()),
request_id_headers: None,
}; };
assert!(config.mode.is_pd_mode()); assert!(config.mode.is_pd_mode());
...@@ -780,6 +785,7 @@ mod tests { ...@@ -780,6 +785,7 @@ mod tests {
metrics: Some(MetricsConfig::default()), metrics: Some(MetricsConfig::default()),
log_dir: None, log_dir: None,
log_level: Some("debug".to_string()), log_level: Some("debug".to_string()),
request_id_headers: None,
}; };
assert!(!config.mode.is_pd_mode()); assert!(!config.mode.is_pd_mode());
...@@ -822,6 +828,7 @@ mod tests { ...@@ -822,6 +828,7 @@ mod tests {
}), }),
log_dir: Some("/opt/logs/sglang".to_string()), log_dir: Some("/opt/logs/sglang".to_string()),
log_level: Some("trace".to_string()), log_level: Some("trace".to_string()),
request_id_headers: None,
}; };
assert!(config.has_service_discovery()); assert!(config.has_service_discovery());
......
...@@ -411,7 +411,7 @@ pub fn start_health_checker( ...@@ -411,7 +411,7 @@ pub fn start_health_checker(
// Check for shutdown signal // Check for shutdown signal
if shutdown_clone.load(Ordering::Acquire) { if shutdown_clone.load(Ordering::Acquire) {
tracing::info!("Health checker shutting down"); tracing::debug!("Health checker shutting down");
break; break;
} }
...@@ -439,6 +439,9 @@ pub fn start_health_checker( ...@@ -439,6 +439,9 @@ pub fn start_health_checker(
Err(e) => { Err(e) => {
if was_healthy { if was_healthy {
tracing::warn!("Worker {} health check failed: {}", worker_url, e); tracing::warn!("Worker {} health check failed: {}", worker_url, e);
} else {
// Worker was already unhealthy, log at debug level
tracing::debug!("Worker {} remains unhealthy: {}", worker_url, e);
} }
} }
} }
......
...@@ -4,6 +4,7 @@ pub mod logging; ...@@ -4,6 +4,7 @@ pub mod logging;
use std::collections::HashMap; use std::collections::HashMap;
pub mod core; pub mod core;
pub mod metrics; pub mod metrics;
pub mod middleware;
pub mod openai_api_types; pub mod openai_api_types;
pub mod policies; pub mod policies;
pub mod routers; pub mod routers;
...@@ -49,6 +50,7 @@ struct Router { ...@@ -49,6 +50,7 @@ struct Router {
prometheus_port: Option<u16>, prometheus_port: Option<u16>,
prometheus_host: Option<String>, prometheus_host: Option<String>,
request_timeout_secs: u64, request_timeout_secs: u64,
request_id_headers: Option<Vec<String>>,
// PD mode flag // PD mode flag
pd_disaggregation: bool, pd_disaggregation: bool,
// PD-specific fields (only used when pd_disaggregation is true) // PD-specific fields (only used when pd_disaggregation is true)
...@@ -138,6 +140,7 @@ impl Router { ...@@ -138,6 +140,7 @@ impl Router {
metrics, metrics,
log_dir: self.log_dir.clone(), log_dir: self.log_dir.clone(),
log_level: self.log_level.clone(), log_level: self.log_level.clone(),
request_id_headers: self.request_id_headers.clone(),
}) })
} }
} }
...@@ -170,6 +173,7 @@ impl Router { ...@@ -170,6 +173,7 @@ impl Router {
prometheus_port = None, prometheus_port = None,
prometheus_host = None, prometheus_host = None,
request_timeout_secs = 600, // Add configurable request timeout request_timeout_secs = 600, // Add configurable request timeout
request_id_headers = None, // Custom request ID headers
pd_disaggregation = false, // New flag for PD mode pd_disaggregation = false, // New flag for PD mode
prefill_urls = None, prefill_urls = None,
decode_urls = None, decode_urls = None,
...@@ -201,6 +205,7 @@ impl Router { ...@@ -201,6 +205,7 @@ impl Router {
prometheus_port: Option<u16>, prometheus_port: Option<u16>,
prometheus_host: Option<String>, prometheus_host: Option<String>,
request_timeout_secs: u64, request_timeout_secs: u64,
request_id_headers: Option<Vec<String>>,
pd_disaggregation: bool, pd_disaggregation: bool,
prefill_urls: Option<Vec<(String, Option<u16>)>>, prefill_urls: Option<Vec<(String, Option<u16>)>>,
decode_urls: Option<Vec<String>>, decode_urls: Option<Vec<String>>,
...@@ -232,6 +237,7 @@ impl Router { ...@@ -232,6 +237,7 @@ impl Router {
prometheus_port, prometheus_port,
prometheus_host, prometheus_host,
request_timeout_secs, request_timeout_secs,
request_id_headers,
pd_disaggregation, pd_disaggregation,
prefill_urls, prefill_urls,
decode_urls, decode_urls,
...@@ -297,6 +303,7 @@ impl Router { ...@@ -297,6 +303,7 @@ impl Router {
service_discovery_config, service_discovery_config,
prometheus_config, prometheus_config,
request_timeout_secs: self.request_timeout_secs, request_timeout_secs: self.request_timeout_secs,
request_id_headers: self.request_id_headers.clone(),
}) })
.await .await
.map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string())) .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))
......
use actix_web::{
dev::{forward_ready, Service, ServiceRequest, ServiceResponse, Transform},
Error, HttpMessage, HttpRequest,
};
use futures_util::future::LocalBoxFuture;
use std::future::{ready, Ready};
/// Generate an OpenAI-compatible request ID based on the endpoint path.
///
/// The prefix is chosen from `path`:
/// - paths containing `/chat/completions` get `chatcmpl-`
/// - other paths containing `/completions` get `cmpl-`
/// - paths containing `/generate` get `gnt-`
/// - everything else gets `req-`
///
/// The prefix is followed by 24 random ASCII alphanumeric characters.
fn generate_request_id(path: &str) -> String {
    // Alphabet the random suffix is drawn from (ASCII letters and digits).
    const CHARSET: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789";
    // Number of random characters appended after the prefix.
    const SUFFIX_LEN: usize = 24;

    // `/chat/completions` must be tested first: it also contains `/completions`.
    let prefix = if path.contains("/chat/completions") {
        "chatcmpl-"
    } else if path.contains("/completions") {
        "cmpl-"
    } else if path.contains("/generate") {
        "gnt-"
    } else {
        "req-"
    };

    let mut request_id = String::with_capacity(prefix.len() + SUFFIX_LEN);
    request_id.push_str(prefix);
    // Index the byte table directly instead of walking `chars().nth(..)`
    // from the start of the alphabet for every generated character.
    request_id.extend(
        (0..SUFFIX_LEN).map(|_| CHARSET[rand::random::<usize>() % CHARSET.len()] as char),
    );
    request_id
}
/// Return the request ID associated with `req`.
///
/// Looks for a `String` stored in the request's extensions and returns a clone of
/// it. When no `String` is stored there, a new ID is generated from the request path.
pub fn get_request_id(req: &HttpRequest) -> String {
    let stored_id = req.extensions().get::<String>().cloned();
    match stored_id {
        Some(id) => id,
        None => generate_request_id(req.path()),
    }
}
/// Middleware for injecting a request ID into request extensions.
pub struct RequestIdMiddleware {
    /// Header names consulted, in order, when looking for a caller-supplied request ID.
    headers: Vec<String>,
}

impl RequestIdMiddleware {
    /// Create the middleware from the list of header names to check for a request ID.
    pub fn new(headers: Vec<String>) -> Self {
        RequestIdMiddleware { headers }
    }
}
// Factory side of the middleware: wraps the next service in a
// `RequestIdMiddlewareService` that carries its own copy of the header list.
impl<S, B> Transform<S, ServiceRequest> for RequestIdMiddleware
where
    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
    S::Future: 'static,
    B: 'static,
{
    type Response = ServiceResponse<B>;
    type Error = Error;
    type InitError = ();
    type Transform = RequestIdMiddlewareService<S>;
    type Future = Ready<Result<Self::Transform, Self::InitError>>;

    /// Build the wrapping service; construction cannot fail, so the future is ready immediately.
    fn new_transform(&self, service: S) -> Self::Future {
        let headers = self.headers.clone();
        let wrapped = RequestIdMiddlewareService { service, headers };
        ready(Ok(wrapped))
    }
}
/// Service wrapper that resolves a request ID for every request and stores it in
/// the request's extensions before delegating to the wrapped service.
pub struct RequestIdMiddlewareService<S> {
    /// The next service in the chain.
    service: S,
    /// Header names consulted, in order, for a caller-supplied request ID.
    headers: Vec<String>,
}

impl<S, B> Service<ServiceRequest> for RequestIdMiddlewareService<S>
where
    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
    S::Future: 'static,
    B: 'static,
{
    type Response = ServiceResponse<B>;
    type Error = Error;
    type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;

    forward_ready!(service);

    /// Resolve the request ID, attach it to the request, and forward the request.
    ///
    /// The configured headers are checked in order. The first one whose value is
    /// readable as a string and is not blank supplies the ID; otherwise a new ID is
    /// generated from the request path.
    fn call(&self, req: ServiceRequest) -> Self::Future {
        // Headers whose value cannot be read as a string, or is empty/whitespace-only,
        // are skipped so a blank ID never ends up in the logs.
        let supplied_id = self.headers.iter().find_map(|header_name| {
            req.headers()
                .get(header_name)
                .and_then(|value| value.to_str().ok())
                .filter(|value| !value.trim().is_empty())
                .map(str::to_owned)
        });
        let request_id = supplied_id.unwrap_or_else(|| generate_request_id(req.path()));

        // Stored as a plain `String`, which is what `get_request_id` looks up.
        // NOTE(review): extensions appear to be keyed by type, so another component
        // storing a bare `String` would replace this value; a dedicated wrapper type
        // may be safer. Confirm before changing, since readers must be updated too.
        req.extensions_mut().insert(request_id);

        // The inner future is already `'static`, so it can be boxed directly.
        Box::pin(self.service.call(req))
    }
}
...@@ -66,7 +66,7 @@ use crate::tree::Tree; ...@@ -66,7 +66,7 @@ use crate::tree::Tree;
use std::sync::{Arc, Mutex}; use std::sync::{Arc, Mutex};
use std::thread; use std::thread;
use std::time::Duration; use std::time::Duration;
use tracing::{debug, info}; use tracing::debug;
/// Cache-aware routing policy /// Cache-aware routing policy
/// ///
...@@ -164,10 +164,8 @@ impl LoadBalancingPolicy for CacheAwarePolicy { ...@@ -164,10 +164,8 @@ impl LoadBalancingPolicy for CacheAwarePolicy {
.map(|w| (w.url().to_string(), w.load())) .map(|w| (w.url().to_string(), w.load()))
.collect(); .collect();
info!( debug!(
"Load balancing triggered due to workload imbalance:\n\ "Load balancing triggered | max: {} | min: {} | workers: {:?}",
Max load: {}, Min load: {}\n\
Current worker loads: {:?}",
max_load, min_load, worker_loads max_load, min_load, worker_loads
); );
......
...@@ -5,6 +5,7 @@ use super::pd_types::{api_path, Bootstrap, ChatReqInput, GenerateReqInput, PDRou ...@@ -5,6 +5,7 @@ use super::pd_types::{api_path, Bootstrap, ChatReqInput, GenerateReqInput, PDRou
use super::request_adapter::ToPdRequest; use super::request_adapter::ToPdRequest;
use crate::core::{HealthChecker, Worker, WorkerFactory, WorkerLoadGuard}; use crate::core::{HealthChecker, Worker, WorkerFactory, WorkerLoadGuard};
use crate::metrics::RouterMetrics; use crate::metrics::RouterMetrics;
use crate::middleware::get_request_id;
use crate::openai_api_types::{ChatCompletionRequest, CompletionRequest, GenerateRequest}; use crate::openai_api_types::{ChatCompletionRequest, CompletionRequest, GenerateRequest};
use crate::policies::LoadBalancingPolicy; use crate::policies::LoadBalancingPolicy;
use crate::tree::Tree; use crate::tree::Tree;
...@@ -16,7 +17,6 @@ use std::collections::HashMap; ...@@ -16,7 +17,6 @@ use std::collections::HashMap;
use std::sync::{Arc, Mutex, RwLock}; use std::sync::{Arc, Mutex, RwLock};
use std::time::{Duration, Instant}; use std::time::{Duration, Instant};
use tracing::{debug, error, info, warn}; use tracing::{debug, error, info, warn};
use uuid::Uuid;
#[derive(Debug)] #[derive(Debug)]
pub struct PDRouter { pub struct PDRouter {
...@@ -307,8 +307,8 @@ impl PDRouter { ...@@ -307,8 +307,8 @@ impl PDRouter {
mut typed_req: GenerateReqInput, mut typed_req: GenerateReqInput,
route: &str, route: &str,
) -> HttpResponse { ) -> HttpResponse {
let request_id = get_request_id(req);
let start = Instant::now(); let start = Instant::now();
let _request_id = Uuid::new_v4();
// Get stream flag and return_logprob flag before moving the request // Get stream flag and return_logprob flag before moving the request
let is_stream = typed_req.stream; let is_stream = typed_req.stream;
...@@ -328,7 +328,10 @@ impl PDRouter { ...@@ -328,7 +328,10 @@ impl PDRouter {
let (prefill, decode) = match self.select_pd_pair(client, request_text).await { let (prefill, decode) = match self.select_pd_pair(client, request_text).await {
Ok(pair) => pair, Ok(pair) => pair,
Err(e) => { Err(e) => {
error!("Failed to select PD pair: {}", e); error!(
request_id = %request_id,
"Failed to select PD pair error={}", e
);
RouterMetrics::record_pd_error("server_selection"); RouterMetrics::record_pd_error("server_selection");
return HttpResponse::ServiceUnavailable() return HttpResponse::ServiceUnavailable()
.body(format!("No available servers: {}", e)); .body(format!("No available servers: {}", e));
...@@ -337,15 +340,17 @@ impl PDRouter { ...@@ -337,15 +340,17 @@ impl PDRouter {
// Log routing decision // Log routing decision
info!( info!(
"PD routing: {} -> prefill={}, decode={}", request_id = %request_id,
route, "PD routing decision route={} prefill_url={} decode_url={}",
prefill.url(), route, prefill.url(), decode.url()
decode.url()
); );
// Add bootstrap info using the trait method // Add bootstrap info using the trait method
if let Err(e) = typed_req.add_bootstrap_info(prefill.as_ref()) { if let Err(e) = typed_req.add_bootstrap_info(prefill.as_ref()) {
error!("Failed to add bootstrap info: {}", e); error!(
request_id = %request_id,
"Failed to add bootstrap info error={}", e
);
RouterMetrics::record_pd_error("bootstrap_injection"); RouterMetrics::record_pd_error("bootstrap_injection");
return HttpResponse::InternalServerError() return HttpResponse::InternalServerError()
.body(format!("Bootstrap injection failed: {}", e)); .body(format!("Bootstrap injection failed: {}", e));
...@@ -355,7 +360,10 @@ impl PDRouter { ...@@ -355,7 +360,10 @@ impl PDRouter {
let json_with_bootstrap = match serde_json::to_value(&typed_req) { let json_with_bootstrap = match serde_json::to_value(&typed_req) {
Ok(json) => json, Ok(json) => json,
Err(e) => { Err(e) => {
error!("Failed to serialize request: {}", e); error!(
request_id = %request_id,
"Failed to serialize request error={}", e
);
return HttpResponse::InternalServerError().body("Failed to serialize request"); return HttpResponse::InternalServerError().body("Failed to serialize request");
} }
}; };
...@@ -383,6 +391,7 @@ impl PDRouter { ...@@ -383,6 +391,7 @@ impl PDRouter {
mut typed_req: ChatReqInput, mut typed_req: ChatReqInput,
route: &str, route: &str,
) -> HttpResponse { ) -> HttpResponse {
let request_id = get_request_id(req);
let start = Instant::now(); let start = Instant::now();
// Get stream flag and return_logprob flag before moving the request // Get stream flag and return_logprob flag before moving the request
...@@ -406,7 +415,10 @@ impl PDRouter { ...@@ -406,7 +415,10 @@ impl PDRouter {
let (prefill, decode) = match self.select_pd_pair(client, request_text).await { let (prefill, decode) = match self.select_pd_pair(client, request_text).await {
Ok(pair) => pair, Ok(pair) => pair,
Err(e) => { Err(e) => {
error!("Failed to select PD pair: {}", e); error!(
request_id = %request_id,
"Failed to select PD pair error={}", e
);
RouterMetrics::record_pd_error("server_selection"); RouterMetrics::record_pd_error("server_selection");
return HttpResponse::ServiceUnavailable() return HttpResponse::ServiceUnavailable()
.body(format!("No available servers: {}", e)); .body(format!("No available servers: {}", e));
...@@ -415,15 +427,17 @@ impl PDRouter { ...@@ -415,15 +427,17 @@ impl PDRouter {
// Log routing decision // Log routing decision
info!( info!(
"PD routing: {} -> prefill={}, decode={}", request_id = %request_id,
route, "PD routing decision route={} prefill_url={} decode_url={}",
prefill.url(), route, prefill.url(), decode.url()
decode.url()
); );
// Add bootstrap info using the trait method // Add bootstrap info using the trait method
if let Err(e) = typed_req.add_bootstrap_info(prefill.as_ref()) { if let Err(e) = typed_req.add_bootstrap_info(prefill.as_ref()) {
error!("Failed to add bootstrap info: {}", e); error!(
request_id = %request_id,
"Failed to add bootstrap info error={}", e
);
RouterMetrics::record_pd_error("bootstrap_injection"); RouterMetrics::record_pd_error("bootstrap_injection");
return HttpResponse::InternalServerError() return HttpResponse::InternalServerError()
.body(format!("Bootstrap injection failed: {}", e)); .body(format!("Bootstrap injection failed: {}", e));
...@@ -433,7 +447,10 @@ impl PDRouter { ...@@ -433,7 +447,10 @@ impl PDRouter {
let json_with_bootstrap = match serde_json::to_value(&typed_req) { let json_with_bootstrap = match serde_json::to_value(&typed_req) {
Ok(json) => json, Ok(json) => json,
Err(e) => { Err(e) => {
error!("Failed to serialize request: {}", e); error!(
request_id = %request_id,
"Failed to serialize request error={}", e
);
return HttpResponse::InternalServerError().body("Failed to serialize request"); return HttpResponse::InternalServerError().body("Failed to serialize request");
} }
}; };
...@@ -461,6 +478,7 @@ impl PDRouter { ...@@ -461,6 +478,7 @@ impl PDRouter {
mut typed_req: CompletionRequest, mut typed_req: CompletionRequest,
route: &str, route: &str,
) -> HttpResponse { ) -> HttpResponse {
let request_id = get_request_id(req);
let start = Instant::now(); let start = Instant::now();
// Get stream flag and return_logprob flag before moving the request // Get stream flag and return_logprob flag before moving the request
...@@ -477,7 +495,10 @@ impl PDRouter { ...@@ -477,7 +495,10 @@ impl PDRouter {
let (prefill, decode) = match self.select_pd_pair(client, request_text).await { let (prefill, decode) = match self.select_pd_pair(client, request_text).await {
Ok(pair) => pair, Ok(pair) => pair,
Err(e) => { Err(e) => {
error!("Failed to select PD pair: {}", e); error!(
request_id = %request_id,
"Failed to select PD pair error={}", e
);
RouterMetrics::record_pd_error("server_selection"); RouterMetrics::record_pd_error("server_selection");
return HttpResponse::ServiceUnavailable() return HttpResponse::ServiceUnavailable()
.body(format!("No available servers: {}", e)); .body(format!("No available servers: {}", e));
...@@ -486,15 +507,17 @@ impl PDRouter { ...@@ -486,15 +507,17 @@ impl PDRouter {
// Log routing decision // Log routing decision
info!( info!(
"PD routing: {} -> prefill={}, decode={}", request_id = %request_id,
route, "PD routing decision route={} prefill_url={} decode_url={}",
prefill.url(), route, prefill.url(), decode.url()
decode.url()
); );
// Add bootstrap info using the trait method // Add bootstrap info using the trait method
if let Err(e) = typed_req.add_bootstrap_info(prefill.as_ref()) { if let Err(e) = typed_req.add_bootstrap_info(prefill.as_ref()) {
error!("Failed to add bootstrap info: {}", e); error!(
request_id = %request_id,
"Failed to add bootstrap info error={}", e
);
RouterMetrics::record_pd_error("bootstrap_injection"); RouterMetrics::record_pd_error("bootstrap_injection");
return HttpResponse::InternalServerError() return HttpResponse::InternalServerError()
.body(format!("Bootstrap injection failed: {}", e)); .body(format!("Bootstrap injection failed: {}", e));
...@@ -504,7 +527,10 @@ impl PDRouter { ...@@ -504,7 +527,10 @@ impl PDRouter {
let json_with_bootstrap = match serde_json::to_value(&typed_req) { let json_with_bootstrap = match serde_json::to_value(&typed_req) {
Ok(json) => json, Ok(json) => json,
Err(e) => { Err(e) => {
error!("Failed to serialize request: {}", e); error!(
request_id = %request_id,
"Failed to serialize request error={}", e
);
return HttpResponse::InternalServerError().body("Failed to serialize request"); return HttpResponse::InternalServerError().body("Failed to serialize request");
} }
}; };
...@@ -538,6 +564,7 @@ impl PDRouter { ...@@ -538,6 +564,7 @@ impl PDRouter {
return_logprob: bool, return_logprob: bool,
start_time: Instant, start_time: Instant,
) -> HttpResponse { ) -> HttpResponse {
let request_id = get_request_id(req);
// Update load tracking for both workers // Update load tracking for both workers
let _guard = WorkerLoadGuard::new_multi(vec![prefill, decode]); let _guard = WorkerLoadGuard::new_multi(vec![prefill, decode]);
...@@ -578,9 +605,9 @@ impl PDRouter { ...@@ -578,9 +605,9 @@ impl PDRouter {
if !status.is_success() { if !status.is_success() {
RouterMetrics::record_pd_decode_error(decode.url()); RouterMetrics::record_pd_decode_error(decode.url());
error!( error!(
"Decode server {} returned error status: {}", request_id = %request_id,
decode.url(), "Decode server returned error status decode_url={} status={}",
status decode.url(), status
); );
// Return the error response from decode server // Return the error response from decode server
...@@ -598,9 +625,9 @@ impl PDRouter { ...@@ -598,9 +625,9 @@ impl PDRouter {
// Log prefill errors for debugging // Log prefill errors for debugging
if let Err(e) = &prefill_result { if let Err(e) = &prefill_result {
error!( error!(
"Prefill server {} failed (non-critical): {}", request_id = %request_id,
prefill.url(), "Prefill server failed (non-critical) prefill_url={} error={}",
e prefill.url(), e
); );
RouterMetrics::record_pd_prefill_error(prefill.url()); RouterMetrics::record_pd_prefill_error(prefill.url());
} }
...@@ -684,7 +711,12 @@ impl PDRouter { ...@@ -684,7 +711,12 @@ impl PDRouter {
} }
} }
Err(e) => { Err(e) => {
error!("Decode request failed: {}", e); error!(
request_id = %request_id,
decode_url = %decode.url(),
error = %e,
"Decode request failed"
);
RouterMetrics::record_pd_decode_error(decode.url()); RouterMetrics::record_pd_decode_error(decode.url());
HttpResponse::BadGateway().body(format!("Decode server error: {}", e)) HttpResponse::BadGateway().body(format!("Decode server error: {}", e))
} }
......
use crate::core::{HealthChecker, Worker, WorkerFactory}; use crate::core::{HealthChecker, Worker, WorkerFactory};
use crate::metrics::RouterMetrics; use crate::metrics::RouterMetrics;
use crate::middleware::get_request_id;
use crate::policies::LoadBalancingPolicy; use crate::policies::LoadBalancingPolicy;
use actix_web::http::header::{HeaderValue, CONTENT_TYPE}; use actix_web::http::header::{HeaderValue, CONTENT_TYPE};
use actix_web::{HttpRequest, HttpResponse}; use actix_web::{HttpRequest, HttpResponse};
...@@ -134,32 +135,26 @@ impl Router { ...@@ -134,32 +135,26 @@ impl Router {
match sync_client.get(&format!("{}/health", url)).send() { match sync_client.get(&format!("{}/health", url)).send() {
Ok(res) => { Ok(res) => {
if !res.status().is_success() { if !res.status().is_success() {
let msg = format!(
"Worker heatlh check is pending with status {}",
res.status()
);
info!("{}", msg);
all_healthy = false; all_healthy = false;
unhealthy_workers.push((url, msg)); unhealthy_workers.push((url, format!("status: {}", res.status())));
} }
} }
Err(_) => { Err(_) => {
let msg = format!("Worker is not ready yet");
info!("{}", msg);
all_healthy = false; all_healthy = false;
unhealthy_workers.push((url, msg)); unhealthy_workers.push((url, "not ready".to_string()));
} }
} }
} }
if all_healthy { if all_healthy {
info!("All workers are healthy"); info!("All {} workers are healthy", worker_urls.len());
return Ok(()); return Ok(());
} else { } else {
info!("Initializing workers:"); debug!(
for (url, reason) in &unhealthy_workers { "Waiting for {} workers to become healthy ({} unhealthy)",
info!(" {} - {}", url, reason); worker_urls.len(),
} unhealthy_workers.len()
);
thread::sleep(Duration::from_secs(interval_secs)); thread::sleep(Duration::from_secs(interval_secs));
} }
} }
...@@ -181,6 +176,7 @@ impl Router { ...@@ -181,6 +176,7 @@ impl Router {
route: &str, route: &str,
req: &HttpRequest, req: &HttpRequest,
) -> HttpResponse { ) -> HttpResponse {
let request_id = get_request_id(req);
let start = Instant::now(); let start = Instant::now();
let mut request_builder = client.get(format!("{}{}", worker_url, route)); let mut request_builder = client.get(format!("{}{}", worker_url, route));
...@@ -202,14 +198,32 @@ impl Router { ...@@ -202,14 +198,32 @@ impl Router {
match res.bytes().await { match res.bytes().await {
Ok(body) => HttpResponse::build(status).body(body.to_vec()), Ok(body) => HttpResponse::build(status).body(body.to_vec()),
Err(e) => HttpResponse::InternalServerError() Err(e) => {
.body(format!("Failed to read response body: {}", e)), error!(
request_id = %request_id,
worker_url = %worker_url,
route = %route,
error = %e,
"Failed to read response body"
);
HttpResponse::InternalServerError()
.body(format!("Failed to read response body: {}", e))
}
} }
} }
Err(e) => HttpResponse::InternalServerError().body(format!( Err(e) => {
"Failed to send request to worker {}: {}", error!(
worker_url, e request_id = %request_id,
)), worker_url = %worker_url,
route = %route,
error = %e,
"Failed to send request to worker"
);
HttpResponse::InternalServerError().body(format!(
"Failed to send request to worker {}: {}",
worker_url, e
))
}
}; };
// Record request metrics // Record request metrics
...@@ -231,6 +245,7 @@ impl Router { ...@@ -231,6 +245,7 @@ impl Router {
route: &str, route: &str,
req: &HttpRequest, req: &HttpRequest,
) -> HttpResponse { ) -> HttpResponse {
let request_id = get_request_id(req);
const MAX_REQUEST_RETRIES: u32 = 3; const MAX_REQUEST_RETRIES: u32 = 3;
const MAX_TOTAL_RETRIES: u32 = 6; const MAX_TOTAL_RETRIES: u32 = 6;
let mut total_retries = 0; let mut total_retries = 0;
...@@ -260,17 +275,23 @@ impl Router { ...@@ -260,17 +275,23 @@ impl Router {
} }
warn!( warn!(
"Request to {} failed (attempt {}/{})", request_id = %request_id,
worker_url, route = %route,
request_retries + 1, worker_url = %worker_url,
MAX_REQUEST_RETRIES attempt = request_retries + 1,
max_attempts = MAX_REQUEST_RETRIES,
"Request failed"
); );
request_retries += 1; request_retries += 1;
total_retries += 1; total_retries += 1;
if request_retries == MAX_REQUEST_RETRIES { if request_retries == MAX_REQUEST_RETRIES {
warn!("Removing failed worker: {}", worker_url); warn!(
request_id = %request_id,
worker_url = %worker_url,
"Removing failed worker"
);
self.remove_worker(&worker_url); self.remove_worker(&worker_url);
break; break;
} }
...@@ -293,6 +314,7 @@ impl Router { ...@@ -293,6 +314,7 @@ impl Router {
typed_req: &T, typed_req: &T,
route: &str, route: &str,
) -> HttpResponse { ) -> HttpResponse {
let request_id = get_request_id(req);
// Handle retries like the original implementation // Handle retries like the original implementation
let start = Instant::now(); let start = Instant::now();
const MAX_REQUEST_RETRIES: u32 = 3; const MAX_REQUEST_RETRIES: u32 = 3;
...@@ -357,17 +379,19 @@ impl Router { ...@@ -357,17 +379,19 @@ impl Router {
} }
warn!( warn!(
"Generate request to {} failed (attempt {}/{})", request_id = %request_id,
worker_url, "Generate request failed route={} worker_url={} attempt={} max_attempts={}",
request_retries + 1, route, worker_url, request_retries + 1, MAX_REQUEST_RETRIES
MAX_REQUEST_RETRIES
); );
request_retries += 1; request_retries += 1;
total_retries += 1; total_retries += 1;
if request_retries == MAX_REQUEST_RETRIES { if request_retries == MAX_REQUEST_RETRIES {
warn!("Removing failed worker: {}", worker_url); warn!(
request_id = %request_id,
"Removing failed worker after typed request failures worker_url={}", worker_url
);
self.remove_worker(&worker_url); self.remove_worker(&worker_url);
break; break;
} }
...@@ -402,13 +426,9 @@ impl Router { ...@@ -402,13 +426,9 @@ impl Router {
is_stream: bool, is_stream: bool,
load_incremented: bool, // Whether load was incremented for this request load_incremented: bool, // Whether load was incremented for this request
) -> HttpResponse { ) -> HttpResponse {
let request_id = get_request_id(req);
let start = Instant::now(); let start = Instant::now();
// Debug: Log what we're sending
if let Ok(json_str) = serde_json::to_string_pretty(typed_req) {
debug!("Sending request to {}: {}", route, json_str);
}
let mut request_builder = client let mut request_builder = client
.post(format!("{}{}", worker_url, route)) .post(format!("{}{}", worker_url, route))
.json(typed_req); // Use json() directly with typed request .json(typed_req); // Use json() directly with typed request
...@@ -424,7 +444,11 @@ impl Router { ...@@ -424,7 +444,11 @@ impl Router {
let res = match request_builder.send().await { let res = match request_builder.send().await {
Ok(res) => res, Ok(res) => res,
Err(e) => { Err(e) => {
error!("Failed to send request to {}: {}", worker_url, e); error!(
request_id = %request_id,
"Failed to send typed request worker_url={} route={} error={}",
worker_url, route, e
);
// Decrement load on error if it was incremented // Decrement load on error if it was incremented
if load_incremented { if load_incremented {
...@@ -497,7 +521,6 @@ impl Router { ...@@ -497,7 +521,6 @@ impl Router {
&worker_url, &worker_url,
worker.load(), worker.load(),
); );
debug!("Streaming is done!!")
} }
} }
} }
...@@ -536,7 +559,6 @@ impl Router { ...@@ -536,7 +559,6 @@ impl Router {
match client.get(&format!("{}/health", worker_url)).send().await { match client.get(&format!("{}/health", worker_url)).send().await {
Ok(res) => { Ok(res) => {
if res.status().is_success() { if res.status().is_success() {
info!("Worker {} health check passed", worker_url);
let mut workers_guard = self.workers.write().unwrap(); let mut workers_guard = self.workers.write().unwrap();
if workers_guard.iter().any(|w| w.url() == worker_url) { if workers_guard.iter().any(|w| w.url() == worker_url) {
return Err(format!("Worker {} already exists", worker_url)); return Err(format!("Worker {} already exists", worker_url));
...@@ -560,8 +582,8 @@ impl Router { ...@@ -560,8 +582,8 @@ impl Router {
return Ok(format!("Successfully added worker: {}", worker_url)); return Ok(format!("Successfully added worker: {}", worker_url));
} else { } else {
info!( debug!(
"Worker {} health check is pending with status: {}.", "Worker {} health check pending - status: {}",
worker_url, worker_url,
res.status() res.status()
); );
...@@ -576,10 +598,7 @@ impl Router { ...@@ -576,10 +598,7 @@ impl Router {
} }
} }
Err(e) => { Err(e) => {
info!( debug!("Worker {} health check pending - error: {}", worker_url, e);
"Worker {} health check is pending with error: {}",
worker_url, e
);
// if the url does not have http or https prefix, warn users // if the url does not have http or https prefix, warn users
if !worker_url.starts_with("http://") && !worker_url.starts_with("https://") { if !worker_url.starts_with("http://") && !worker_url.starts_with("https://") {
...@@ -611,7 +630,6 @@ impl Router { ...@@ -611,7 +630,6 @@ impl Router {
.downcast_ref::<crate::policies::CacheAwarePolicy>() .downcast_ref::<crate::policies::CacheAwarePolicy>()
{ {
cache_aware.remove_worker(worker_url); cache_aware.remove_worker(worker_url);
info!("Removed worker from tree: {}", worker_url);
} }
} }
...@@ -675,7 +693,6 @@ impl Router { ...@@ -675,7 +693,6 @@ impl Router {
for url in &worker_urls { for url in &worker_urls {
if let Some(load) = Self::get_worker_load_static(&client, url).await { if let Some(load) = Self::get_worker_load_static(&client, url).await {
loads.insert(url.clone(), load); loads.insert(url.clone(), load);
debug!("Worker {} load: {}", url, load);
} }
} }
......
use crate::config::RouterConfig; use crate::config::RouterConfig;
use crate::logging::{self, LoggingConfig}; use crate::logging::{self, LoggingConfig};
use crate::metrics::{self, PrometheusConfig}; use crate::metrics::{self, PrometheusConfig};
use crate::middleware::{get_request_id, RequestIdMiddleware};
use crate::openai_api_types::{ChatCompletionRequest, CompletionRequest, GenerateRequest}; use crate::openai_api_types::{ChatCompletionRequest, CompletionRequest, GenerateRequest};
use crate::routers::{RouterFactory, RouterTrait}; use crate::routers::{RouterFactory, RouterTrait};
use crate::service_discovery::{start_service_discovery, ServiceDiscoveryConfig}; use crate::service_discovery::{start_service_discovery, ServiceDiscoveryConfig};
...@@ -46,13 +47,13 @@ async fn sink_handler(_req: HttpRequest, mut payload: web::Payload) -> Result<Ht ...@@ -46,13 +47,13 @@ async fn sink_handler(_req: HttpRequest, mut payload: web::Payload) -> Result<Ht
} }
// Custom error handler for JSON payload errors. // Custom error handler for JSON payload errors.
fn json_error_handler(err: error::JsonPayloadError, _req: &HttpRequest) -> Error { fn json_error_handler(err: error::JsonPayloadError, req: &HttpRequest) -> Error {
error!("JSON payload error: {:?}", err); let request_id = get_request_id(req);
match &err { match &err {
error::JsonPayloadError::OverflowKnownLength { length, limit } => { error::JsonPayloadError::OverflowKnownLength { length, limit } => {
error!( error!(
"Payload too large: {} bytes exceeds limit of {} bytes", request_id = %request_id,
length, limit "Payload too large length={} limit={}", length, limit
); );
error::ErrorPayloadTooLarge(format!( error::ErrorPayloadTooLarge(format!(
"Payload too large: {} bytes exceeds limit of {} bytes", "Payload too large: {} bytes exceeds limit of {} bytes",
...@@ -60,10 +61,19 @@ fn json_error_handler(err: error::JsonPayloadError, _req: &HttpRequest) -> Error ...@@ -60,10 +61,19 @@ fn json_error_handler(err: error::JsonPayloadError, _req: &HttpRequest) -> Error
)) ))
} }
error::JsonPayloadError::Overflow { limit } => { error::JsonPayloadError::Overflow { limit } => {
error!("Payload overflow: exceeds limit of {} bytes", limit); error!(
request_id = %request_id,
"Payload overflow limit={}", limit
);
error::ErrorPayloadTooLarge(format!("Payload exceeds limit of {} bytes", limit)) error::ErrorPayloadTooLarge(format!("Payload exceeds limit of {} bytes", limit))
} }
_ => error::ErrorBadRequest(format!("Invalid JSON payload: {}", err)), _ => {
error!(
request_id = %request_id,
"Invalid JSON payload error={}", err
);
error::ErrorBadRequest(format!("Invalid JSON payload: {}", err))
}
} }
} }
...@@ -108,8 +118,20 @@ async fn generate( ...@@ -108,8 +118,20 @@ async fn generate(
body: web::Json<GenerateRequest>, body: web::Json<GenerateRequest>,
state: web::Data<AppState>, state: web::Data<AppState>,
) -> Result<HttpResponse, Error> { ) -> Result<HttpResponse, Error> {
let json_body = serde_json::to_value(body.into_inner()) let request_id = get_request_id(&req);
.map_err(|e| error::ErrorBadRequest(format!("Invalid JSON: {}", e)))?; info!(
request_id = %request_id,
"Received generate request method=\"POST\" path=\"/generate\""
);
let json_body = serde_json::to_value(body.into_inner()).map_err(|e| {
error!(
request_id = %request_id,
"Failed to parse generate request body error={}", e
);
error::ErrorBadRequest(format!("Invalid JSON: {}", e))
})?;
Ok(state Ok(state
.router .router
.route_generate(&state.client, &req, json_body) .route_generate(&state.client, &req, json_body)
...@@ -122,8 +144,20 @@ async fn v1_chat_completions( ...@@ -122,8 +144,20 @@ async fn v1_chat_completions(
body: web::Json<ChatCompletionRequest>, body: web::Json<ChatCompletionRequest>,
state: web::Data<AppState>, state: web::Data<AppState>,
) -> Result<HttpResponse, Error> { ) -> Result<HttpResponse, Error> {
let json_body = serde_json::to_value(body.into_inner()) let request_id = get_request_id(&req);
.map_err(|e| error::ErrorBadRequest(format!("Invalid JSON: {}", e)))?; info!(
request_id = %request_id,
"Received chat completion request method=\"POST\" path=\"/v1/chat/completions\""
);
let json_body = serde_json::to_value(body.into_inner()).map_err(|e| {
error!(
request_id = %request_id,
"Failed to parse chat completion request body error={}", e
);
error::ErrorBadRequest(format!("Invalid JSON: {}", e))
})?;
Ok(state Ok(state
.router .router
.route_chat(&state.client, &req, json_body) .route_chat(&state.client, &req, json_body)
...@@ -136,8 +170,20 @@ async fn v1_completions( ...@@ -136,8 +170,20 @@ async fn v1_completions(
body: web::Json<CompletionRequest>, body: web::Json<CompletionRequest>,
state: web::Data<AppState>, state: web::Data<AppState>,
) -> Result<HttpResponse, Error> { ) -> Result<HttpResponse, Error> {
let json_body = serde_json::to_value(body.into_inner()) let request_id = get_request_id(&req);
.map_err(|e| error::ErrorBadRequest(format!("Invalid JSON: {}", e)))?; info!(
request_id = %request_id,
"Received completion request method=\"POST\" path=\"/v1/completions\""
);
let json_body = serde_json::to_value(body.into_inner()).map_err(|e| {
error!(
request_id = %request_id,
"Failed to parse completion request body error={}", e
);
error::ErrorBadRequest(format!("Invalid JSON: {}", e))
})?;
Ok(state Ok(state
.router .router
.route_completion(&state.client, &req, json_body) .route_completion(&state.client, &req, json_body)
...@@ -146,20 +192,48 @@ async fn v1_completions( ...@@ -146,20 +192,48 @@ async fn v1_completions(
#[post("/add_worker")] #[post("/add_worker")]
async fn add_worker( async fn add_worker(
req: HttpRequest,
query: web::Query<HashMap<String, String>>, query: web::Query<HashMap<String, String>>,
data: web::Data<AppState>, data: web::Data<AppState>,
) -> impl Responder { ) -> impl Responder {
let request_id = get_request_id(&req);
let worker_url = match query.get("url") { let worker_url = match query.get("url") {
Some(url) => url.to_string(), Some(url) => url.to_string(),
None => { None => {
warn!(
request_id = %request_id,
"Add worker request missing URL parameter"
);
return HttpResponse::BadRequest() return HttpResponse::BadRequest()
.body("Worker URL required. Provide 'url' query parameter") .body("Worker URL required. Provide 'url' query parameter");
} }
}; };
info!(
request_id = %request_id,
worker_url = %worker_url,
"Adding worker"
);
match data.router.add_worker(&worker_url).await { match data.router.add_worker(&worker_url).await {
Ok(message) => HttpResponse::Ok().body(message), Ok(message) => {
Err(error) => HttpResponse::BadRequest().body(error), info!(
request_id = %request_id,
worker_url = %worker_url,
"Successfully added worker"
);
HttpResponse::Ok().body(message)
}
Err(error) => {
error!(
request_id = %request_id,
worker_url = %worker_url,
error = %error,
"Failed to add worker"
);
HttpResponse::BadRequest().body(error)
}
} }
} }
...@@ -171,13 +245,29 @@ async fn list_workers(data: web::Data<AppState>) -> impl Responder { ...@@ -171,13 +245,29 @@ async fn list_workers(data: web::Data<AppState>) -> impl Responder {
#[post("/remove_worker")] #[post("/remove_worker")]
async fn remove_worker( async fn remove_worker(
req: HttpRequest,
query: web::Query<HashMap<String, String>>, query: web::Query<HashMap<String, String>>,
data: web::Data<AppState>, data: web::Data<AppState>,
) -> impl Responder { ) -> impl Responder {
let request_id = get_request_id(&req);
let worker_url = match query.get("url") { let worker_url = match query.get("url") {
Some(url) => url.to_string(), Some(url) => url.to_string(),
None => return HttpResponse::BadRequest().finish(), None => {
warn!(
request_id = %request_id,
"Remove worker request missing URL parameter"
);
return HttpResponse::BadRequest().finish();
}
}; };
info!(
request_id = %request_id,
worker_url = %worker_url,
"Removing worker"
);
data.router.remove_worker(&worker_url); data.router.remove_worker(&worker_url);
HttpResponse::Ok().body(format!("Successfully removed worker: {}", worker_url)) HttpResponse::Ok().body(format!("Successfully removed worker: {}", worker_url))
} }
...@@ -202,6 +292,7 @@ pub struct ServerConfig { ...@@ -202,6 +292,7 @@ pub struct ServerConfig {
pub service_discovery_config: Option<ServiceDiscoveryConfig>, pub service_discovery_config: Option<ServiceDiscoveryConfig>,
pub prometheus_config: Option<PrometheusConfig>, pub prometheus_config: Option<PrometheusConfig>,
pub request_timeout_secs: u64, pub request_timeout_secs: u64,
pub request_id_headers: Option<Vec<String>>,
} }
pub async fn startup(config: ServerConfig) -> std::io::Result<()> { pub async fn startup(config: ServerConfig) -> std::io::Result<()> {
...@@ -233,31 +324,18 @@ pub async fn startup(config: ServerConfig) -> std::io::Result<()> { ...@@ -233,31 +324,18 @@ pub async fn startup(config: ServerConfig) -> std::io::Result<()> {
// Initialize prometheus metrics exporter // Initialize prometheus metrics exporter
if let Some(prometheus_config) = config.prometheus_config { if let Some(prometheus_config) = config.prometheus_config {
info!(
"🚧 Initializing Prometheus metrics on {}:{}",
prometheus_config.host, prometheus_config.port
);
metrics::start_prometheus(prometheus_config); metrics::start_prometheus(prometheus_config);
} else {
info!("🚧 Prometheus metrics disabled");
} }
info!("🚧 Initializing router on {}:{}", config.host, config.port);
info!("🚧 Router mode: {:?}", config.router_config.mode);
info!("🚧 Policy: {:?}", config.router_config.policy);
info!( info!(
"🚧 Max payload size: {} MB", "Starting router on {}:{} | mode: {:?} | policy: {:?} | max_payload: {}MB",
config.host,
config.port,
config.router_config.mode,
config.router_config.policy,
config.max_payload_size / (1024 * 1024) config.max_payload_size / (1024 * 1024)
); );
// Log service discovery status
if let Some(service_discovery_config) = &config.service_discovery_config {
info!("🚧 Service discovery enabled");
info!("🚧 Selector: {:?}", service_discovery_config.selector);
} else {
info!("🚧 Service discovery disabled");
}
let client = Client::builder() let client = Client::builder()
.pool_idle_timeout(Some(Duration::from_secs(50))) .pool_idle_timeout(Some(Duration::from_secs(50)))
.timeout(Duration::from_secs(config.request_timeout_secs)) // Use configurable timeout .timeout(Duration::from_secs(config.request_timeout_secs)) // Use configurable timeout
...@@ -272,11 +350,9 @@ pub async fn startup(config: ServerConfig) -> std::io::Result<()> { ...@@ -272,11 +350,9 @@ pub async fn startup(config: ServerConfig) -> std::io::Result<()> {
// Start the service discovery if enabled // Start the service discovery if enabled
if let Some(service_discovery_config) = config.service_discovery_config { if let Some(service_discovery_config) = config.service_discovery_config {
if service_discovery_config.enabled { if service_discovery_config.enabled {
info!("🚧 Initializing Kubernetes service discovery");
// Pass the Arc<Router> directly
match start_service_discovery(service_discovery_config, router_arc).await { match start_service_discovery(service_discovery_config, router_arc).await {
Ok(handle) => { Ok(handle) => {
info!("Service discovery started successfully"); info!("Service discovery started");
// Spawn a task to handle the service discovery thread // Spawn a task to handle the service discovery thread
spawn(async move { spawn(async move {
if let Err(e) = handle.await { if let Err(e) = handle.await {
...@@ -292,14 +368,26 @@ pub async fn startup(config: ServerConfig) -> std::io::Result<()> { ...@@ -292,14 +368,26 @@ pub async fn startup(config: ServerConfig) -> std::io::Result<()> {
} }
} }
info!("✅ Serving router on {}:{}", config.host, config.port);
info!( info!(
"✅ Serving workers on {:?}", "Router ready | workers: {:?}",
app_state.router.get_worker_urls() app_state.router.get_worker_urls()
); );
// Configure request ID headers
let request_id_headers = config.request_id_headers.clone().unwrap_or_else(|| {
vec![
"x-request-id".to_string(),
"x-correlation-id".to_string(),
"x-trace-id".to_string(),
"request-id".to_string(),
]
});
HttpServer::new(move || { HttpServer::new(move || {
let request_id_middleware = RequestIdMiddleware::new(request_id_headers.clone());
App::new() App::new()
.wrap(request_id_middleware)
.app_data(app_state.clone()) .app_data(app_state.clone())
.app_data( .app_data(
web::JsonConfig::default() web::JsonConfig::default()
......
...@@ -209,7 +209,7 @@ pub async fn start_service_discovery( ...@@ -209,7 +209,7 @@ pub async fn start_service_discovery(
.join(","); .join(",");
info!( info!(
"Starting Kubernetes service discovery in PD mode with prefill_selector: '{}', decode_selector: '{}'", "Starting K8s service discovery | PD mode | prefill: '{}' | decode: '{}'",
prefill_selector, decode_selector prefill_selector, decode_selector
); );
} else { } else {
...@@ -221,7 +221,7 @@ pub async fn start_service_discovery( ...@@ -221,7 +221,7 @@ pub async fn start_service_discovery(
.join(","); .join(",");
info!( info!(
"Starting Kubernetes service discovery with selector: '{}'", "Starting K8s service discovery | selector: '{}'",
label_selector label_selector
); );
} }
...@@ -238,7 +238,7 @@ pub async fn start_service_discovery( ...@@ -238,7 +238,7 @@ pub async fn start_service_discovery(
Api::all(client) Api::all(client)
}; };
info!("Kubernetes service discovery initialized successfully"); debug!("K8s service discovery initialized");
// Create Arcs for configuration data // Create Arcs for configuration data
let config_arc = Arc::new(config.clone()); let config_arc = Arc::new(config.clone());
...@@ -375,7 +375,7 @@ async fn handle_pod_event( ...@@ -375,7 +375,7 @@ async fn handle_pod_event(
if should_add { if should_add {
info!( info!(
"Healthy pod found: {} (type: {:?}). Adding worker: {}", "Adding pod: {} | type: {:?} | url: {}",
pod_info.name, pod_info.pod_type, worker_url pod_info.name, pod_info.pod_type, worker_url
); );
...@@ -409,8 +409,8 @@ async fn handle_pod_event( ...@@ -409,8 +409,8 @@ async fn handle_pod_event(
}; };
match result { match result {
Ok(msg) => { Ok(_) => {
info!("Successfully added worker: {}", msg); debug!("Worker added: {}", worker_url);
} }
Err(e) => { Err(e) => {
error!("Failed to add worker {} to router: {}", worker_url, e); error!("Failed to add worker {} to router: {}", worker_url, e);
...@@ -446,7 +446,7 @@ async fn handle_pod_deletion( ...@@ -446,7 +446,7 @@ async fn handle_pod_deletion(
if was_tracked { if was_tracked {
info!( info!(
"Pod deleted: {} (type: {:?}). Removing worker: {}", "Removing pod: {} | type: {:?} | url: {}",
pod_info.name, pod_info.pod_type, worker_url pod_info.name, pod_info.pod_type, worker_url
); );
......
...@@ -35,6 +35,7 @@ impl TestContext { ...@@ -35,6 +35,7 @@ impl TestContext {
metrics: None, metrics: None,
log_dir: None, log_dir: None,
log_level: None, log_level: None,
request_id_headers: None,
}; };
Self::new_with_config(config, worker_configs).await Self::new_with_config(config, worker_configs).await
...@@ -953,6 +954,7 @@ mod error_tests { ...@@ -953,6 +954,7 @@ mod error_tests {
metrics: None, metrics: None,
log_dir: None, log_dir: None,
log_level: None, log_level: None,
request_id_headers: None,
}; };
let ctx = TestContext::new_with_config( let ctx = TestContext::new_with_config(
......
...@@ -20,6 +20,7 @@ pub fn create_test_config(worker_urls: Vec<String>) -> RouterConfig { ...@@ -20,6 +20,7 @@ pub fn create_test_config(worker_urls: Vec<String>) -> RouterConfig {
metrics: None, metrics: None,
log_dir: None, log_dir: None,
log_level: None, log_level: None,
request_id_headers: None,
} }
} }
...@@ -40,6 +41,7 @@ pub fn create_test_config_no_workers() -> RouterConfig { ...@@ -40,6 +41,7 @@ pub fn create_test_config_no_workers() -> RouterConfig {
metrics: None, metrics: None,
log_dir: None, log_dir: None,
log_level: None, log_level: None,
request_id_headers: None,
} }
} }
......
...@@ -46,6 +46,7 @@ impl RequestTestContext { ...@@ -46,6 +46,7 @@ impl RequestTestContext {
metrics: None, metrics: None,
log_dir: None, log_dir: None,
log_level: None, log_level: None,
request_id_headers: None,
}; };
let client = Client::builder() let client = Client::builder()
......
...@@ -50,6 +50,7 @@ impl StreamingTestContext { ...@@ -50,6 +50,7 @@ impl StreamingTestContext {
metrics: None, metrics: None,
log_dir: None, log_dir: None,
log_level: None, log_level: None,
request_id_headers: None,
}; };
let client = Client::builder() let client = Client::builder()
......
...@@ -173,6 +173,7 @@ mod test_pd_routing { ...@@ -173,6 +173,7 @@ mod test_pd_routing {
metrics: None, metrics: None,
log_dir: None, log_dir: None,
log_level: None, log_level: None,
request_id_headers: None,
}; };
// Router creation will fail due to health checks, but config should be valid // Router creation will fail due to health checks, but config should be valid
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment